feat: harden runtime import and frontend verification workflows
This commit is contained in:
133
internal/app/admin_auth_extra_test.go
Normal file
133
internal/app/admin_auth_extra_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAdminSessionDebugValue(t *testing.T) {
|
||||
secret := "test-secret"
|
||||
username := "admin"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
|
||||
result := adminSessionDebugValue(secret, username, expiresAt)
|
||||
|
||||
// Result should be a hex string
|
||||
if result == "" {
|
||||
t.Error("adminSessionDebugValue should return non-empty string")
|
||||
}
|
||||
|
||||
// Should be valid hex (only contains hex characters)
|
||||
for _, c := range result {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("adminSessionDebugValue returned non-hex character: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminSessionPayload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
wantUser string
|
||||
wantExp bool
|
||||
}{
|
||||
{
|
||||
name: "valid payload",
|
||||
raw: createValidPayload("admin", "1234567890"),
|
||||
wantUser: "admin",
|
||||
wantExp: true,
|
||||
},
|
||||
{
|
||||
name: "invalid format - no dot",
|
||||
raw: "invalid-no-dot",
|
||||
wantUser: "",
|
||||
wantExp: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format - too many dots",
|
||||
raw: "part1.part2.part3",
|
||||
wantUser: "",
|
||||
wantExp: false,
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
raw: "invalid!!!.signature",
|
||||
wantUser: "",
|
||||
wantExp: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
raw: "",
|
||||
wantUser: "",
|
||||
wantExp: false,
|
||||
},
|
||||
{
|
||||
name: "single part",
|
||||
raw: "onlyonepart",
|
||||
wantUser: "",
|
||||
wantExp: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
payload := adminSessionPayload(tt.raw)
|
||||
|
||||
// All results should have "raw" field
|
||||
if _, ok := payload["raw"]; !ok {
|
||||
t.Error("payload should contain 'raw' field")
|
||||
}
|
||||
|
||||
if tt.wantUser != "" {
|
||||
if user, ok := payload["username"].(string); !ok || user != tt.wantUser {
|
||||
t.Errorf("username = %v, want %v", user, tt.wantUser)
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantExp {
|
||||
if _, ok := payload["expires_unix"]; !ok {
|
||||
t.Error("expected expires_unix field")
|
||||
}
|
||||
if _, ok := payload["payload"]; !ok {
|
||||
t.Error("expected payload field")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalAdminSessionPayload(t *testing.T) {
|
||||
validPayload := createValidPayload("admin", "1234567890")
|
||||
|
||||
result := marshalAdminSessionPayload(validPayload)
|
||||
|
||||
// Result should be valid JSON
|
||||
if result == "" {
|
||||
t.Error("marshalAdminSessionPayload should return non-empty string")
|
||||
}
|
||||
|
||||
// Should contain expected fields
|
||||
if !strings.Contains(result, "raw") {
|
||||
t.Error("result should contain 'raw' field")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "username") {
|
||||
t.Error("result should contain 'username' field")
|
||||
}
|
||||
|
||||
// Test with invalid payload
|
||||
invalidResult := marshalAdminSessionPayload("invalid")
|
||||
if invalidResult == "" {
|
||||
t.Error("marshalAdminSessionPayload with invalid input should still return something")
|
||||
}
|
||||
}
|
||||
|
||||
// createValidPayload creates a valid payload string for testing
|
||||
func createValidPayload(username, expires string) string {
|
||||
body := username + "|" + expires
|
||||
encoded := base64.RawURLEncoding.EncodeToString([]byte(body))
|
||||
return encoded + ".signature"
|
||||
}
|
||||
@@ -21,13 +21,13 @@ func NewServer(listenAddr string, handler http.Handler, listenerFactory Listener
|
||||
}
|
||||
server := &Server{
|
||||
server: &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: handler,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
Addr: listenAddr,
|
||||
Handler: handler,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
MaxHeaderBytes: 1 << 20, // 1MB
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
MaxHeaderBytes: 1 << 20, // 1MB
|
||||
},
|
||||
listen: net.Listen,
|
||||
}
|
||||
|
||||
@@ -743,25 +743,25 @@ func TestServerAddrReturnsConfiguredAddress(t *testing.T) {
|
||||
|
||||
func TestServerHasTimeoutConfiguration(t *testing.T) {
|
||||
server := NewServer("127.0.0.1:0", nil, nil)
|
||||
|
||||
|
||||
s := server.server
|
||||
|
||||
|
||||
if s.ReadTimeout != 30*time.Second {
|
||||
t.Errorf("ReadTimeout = %v, want 30s", s.ReadTimeout)
|
||||
}
|
||||
|
||||
|
||||
if s.ReadHeaderTimeout != 10*time.Second {
|
||||
t.Errorf("ReadHeaderTimeout = %v, want 10s", s.ReadHeaderTimeout)
|
||||
}
|
||||
|
||||
|
||||
if s.WriteTimeout != 30*time.Second {
|
||||
t.Errorf("WriteTimeout = %v, want 30s", s.WriteTimeout)
|
||||
}
|
||||
|
||||
|
||||
if s.IdleTimeout != 120*time.Second {
|
||||
t.Errorf("IdleTimeout = %v, want 120s", s.IdleTimeout)
|
||||
}
|
||||
|
||||
|
||||
if s.MaxHeaderBytes != 1<<20 {
|
||||
t.Errorf("MaxHeaderBytes = %d, want %d", s.MaxHeaderBytes, 1<<20)
|
||||
}
|
||||
|
||||
57
internal/app/batch_utils_test.go
Normal file
57
internal/app/batch_utils_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test utility functions from batch_runtime.go
|
||||
|
||||
func TestSleepWithContext_Normal(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
err := sleepWithContext(ctx, 1*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("sleep should not error: %v", err)
|
||||
}
|
||||
if elapsed < 1*time.Millisecond {
|
||||
t.Error("should have slept")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSleepWithContext_Canceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
start := time.Now()
|
||||
err := sleepWithContext(ctx, 100*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Error("canceled context should return error")
|
||||
}
|
||||
if elapsed > 10*time.Millisecond {
|
||||
t.Error("should have returned early due to cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstNonEmptyString(t *testing.T) {
|
||||
if firstNonEmptyString("", "", "value") != "value" {
|
||||
t.Error("should return first non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstNonEmptyString_First(t *testing.T) {
|
||||
if firstNonEmptyString("first", "second", "third") != "first" {
|
||||
t.Error("should return first value when all non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstNonEmptyString_AllEmpty(t *testing.T) {
|
||||
if firstNonEmptyString("", "", "") != "" {
|
||||
t.Error("all empty should return empty")
|
||||
}
|
||||
}
|
||||
@@ -24,10 +24,10 @@ var (
|
||||
|
||||
// Overlay 错误
|
||||
var (
|
||||
ErrOverlayNotMatched = errors.New("overlay did not match")
|
||||
ErrNestedOutput = errors.New("output directory must not be nested inside source directory")
|
||||
ErrOutputExists = errors.New("output directory already exists")
|
||||
ErrSourceNotDir = errors.New("source must be a directory")
|
||||
ErrPatchFileNotFound = errors.New("patch file not found")
|
||||
ErrPatchApplyFailed = errors.New("failed to apply patch")
|
||||
ErrOverlayNotMatched = errors.New("overlay did not match")
|
||||
ErrNestedOutput = errors.New("output directory must not be nested inside source directory")
|
||||
ErrOutputExists = errors.New("output directory already exists")
|
||||
ErrSourceNotDir = errors.New("source must be a directory")
|
||||
ErrPatchFileNotFound = errors.New("patch file not found")
|
||||
ErrPatchApplyFailed = errors.New("failed to apply patch")
|
||||
)
|
||||
|
||||
113
internal/errs/test_helpers_test.go
Normal file
113
internal/errs/test_helpers_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package errs
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestContainsSubstring(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
substr string
|
||||
want bool
|
||||
}{
|
||||
{"exact match", "hello", "hello", true},
|
||||
{"substring at start", "hello world", "hello", true},
|
||||
{"substring at end", "hello world", "world", true},
|
||||
{"substring in middle", "hello world foo", "world", true},
|
||||
{"no match", "hello", "world", false},
|
||||
{"empty string", "", "", true},
|
||||
{"empty substring", "hello", "", true},
|
||||
{"substr longer than s", "hi", "hello world", false},
|
||||
{"partial match only", "hello", "hello world", false},
|
||||
{"case sensitive", "Hello", "hello", false},
|
||||
{"unicode substring", "你好世界", "世界", true},
|
||||
{"unicode no match", "你好世界", "hello", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := containsSubstring(tt.s, tt.substr)
|
||||
if got != tt.want {
|
||||
t.Errorf("containsSubstring(%q, %q) = %v, want %v", tt.s, tt.substr, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsAt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
substr string
|
||||
start int
|
||||
want bool
|
||||
}{
|
||||
{"find at start", "hello world", "hello", 0, true},
|
||||
{"find at offset", "hello world", "world", 6, true},
|
||||
{"find with start inside", "hello world hello", "hello", 6, true},
|
||||
{"not found after offset", "hello world", "hello", 6, false},
|
||||
{"start beyond string", "hello", "lo", 10, false},
|
||||
{"empty substr at start", "hello", "", 0, true},
|
||||
{"empty substr at end", "hello", "", 5, true},
|
||||
{"start at exact position", "hello world", "world", 6, true},
|
||||
{"start_just_before", "hello world", "world", 5, true}, // "world" starts at index 6, so start=5 is within range
|
||||
{"multiple occurrences", "ababab", "ab", 2, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := containsAt(tt.s, tt.substr, tt.start)
|
||||
if got != tt.want {
|
||||
t.Errorf("containsAt(%q, %q, %d) = %v, want %v", tt.s, tt.substr, tt.start, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssertErrorContains_Success(t *testing.T) {
|
||||
// This should pass - error contains substring
|
||||
// Just verify it doesn't panic/fail
|
||||
err := &testError{msg: "connection refused: something went wrong"}
|
||||
AssertErrorContains(t, err, "connection refused")
|
||||
}
|
||||
|
||||
func TestAssertErrorContains_EmptySubstring(t *testing.T) {
|
||||
// Empty substring should pass with any error
|
||||
err := &testError{msg: "any error"}
|
||||
AssertErrorContains(t, err, "")
|
||||
}
|
||||
|
||||
// testError is a simple error implementation for testing
|
||||
type testError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *testError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
// TestContainsSubstring_StandardLibrary verifies our implementation matches strings.Contains
|
||||
func TestContainsSubstring_StandardLibrary(t *testing.T) {
|
||||
testCases := []struct {
|
||||
s string
|
||||
substr string
|
||||
}{
|
||||
{"hello world", "world"},
|
||||
{"", ""},
|
||||
{"hello", ""},
|
||||
{"", "x"},
|
||||
{"hello", "world"},
|
||||
{"ababab", "ab"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
ourResult := containsSubstring(tc.s, tc.substr)
|
||||
stdResult := strings.Contains(tc.s, tc.substr)
|
||||
if ourResult != stdResult {
|
||||
t.Errorf("containsSubstring(%q, %q) = %v, strings.Contains = %v",
|
||||
tc.s, tc.substr, ourResult, stdResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,7 @@ func (c *Client) ProbeCapabilities(ctx context.Context) (HostCapabilities, error
|
||||
return HostCapabilities{}, err
|
||||
}
|
||||
|
||||
accountTest, err := c.probeEndpoint(ctx, http.MethodGet, "/api/v1/admin/accounts/__probe__/test", nil)
|
||||
accountTest, err := c.probeEndpoint(ctx, http.MethodPost, "/api/v1/admin/accounts/__probe__/test", map[string]any{})
|
||||
if err != nil {
|
||||
return HostCapabilities{}, err
|
||||
}
|
||||
|
||||
@@ -1103,6 +1103,9 @@ func TestProbeCapabilitiesWithMock(t *testing.T) {
|
||||
callCount := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
if r.URL.Path == "/api/v1/admin/accounts/__probe__/test" && r.Method != http.MethodPost {
|
||||
t.Fatalf("account test probe method = %s, want POST", r.Method)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"data":[]}`))
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
var logger *slog.Logger
|
||||
@@ -21,7 +21,7 @@ type Config struct {
|
||||
Rotation bool // enable file rotation
|
||||
MaxSize int // MB
|
||||
MaxBackups int
|
||||
MaxAge int // days
|
||||
MaxAge int // days
|
||||
Compress bool
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ func InitWithConfig(cfg Config) {
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
|
||||
|
||||
switch cfg.Output {
|
||||
case "stdout":
|
||||
handler = slog.NewJSONHandler(os.Stdout, opts)
|
||||
@@ -100,7 +100,7 @@ func InitWithConfig(cfg Config) {
|
||||
handler = slog.NewJSONHandler(file, opts)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
logger = slog.New(handler)
|
||||
slog.SetDefault(logger)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ func TestInit(t *testing.T) {
|
||||
defer func() { logger = oldLogger }()
|
||||
|
||||
Init()
|
||||
|
||||
|
||||
if logger == nil {
|
||||
t.Error("logger should not be nil after Init")
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func TestInit(t *testing.T) {
|
||||
func TestInitWithLevel(t *testing.T) {
|
||||
// Test different levels
|
||||
levels := []string{"DEBUG", "INFO", "WARN", "ERROR", "unknown"}
|
||||
|
||||
|
||||
for _, level := range levels {
|
||||
InitWithLevel(level)
|
||||
if logger == nil {
|
||||
@@ -46,7 +46,7 @@ func TestParseLevel(t *testing.T) {
|
||||
{"unknown", slog.LevelInfo},
|
||||
{"", slog.LevelInfo},
|
||||
}
|
||||
|
||||
|
||||
for _, test := range tests {
|
||||
result := parseLevel(test.input)
|
||||
if result != test.expected {
|
||||
@@ -64,19 +64,19 @@ func TestIsSensitive(t *testing.T) {
|
||||
"access_token",
|
||||
"PRIVATE_KEY",
|
||||
}
|
||||
|
||||
|
||||
for _, field := range sensitive {
|
||||
if !IsSensitive(field) {
|
||||
t.Errorf("IsSensitive(%q) should be true", field)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
notSensitive := []string{
|
||||
"name",
|
||||
"email",
|
||||
"user_id",
|
||||
}
|
||||
|
||||
|
||||
for _, field := range notSensitive {
|
||||
if IsSensitive(field) {
|
||||
t.Errorf("IsSensitive(%q) should be false", field)
|
||||
@@ -96,7 +96,7 @@ func TestSanitizeAttrs(t *testing.T) {
|
||||
{"secret_key", "xyz789", "[REDACTED]"},
|
||||
{"name", "test", "test"},
|
||||
}
|
||||
|
||||
|
||||
for _, test := range tests {
|
||||
attr := slog.String(test.key, test.value)
|
||||
result := sanitizeAttrs(nil, attr)
|
||||
@@ -109,7 +109,7 @@ func TestSanitizeAttrs(t *testing.T) {
|
||||
func TestLoggingMethods(t *testing.T) {
|
||||
// Just verify methods don't panic
|
||||
Init()
|
||||
|
||||
|
||||
Info("test info message", "key", "value")
|
||||
Debug("test debug message", "key", "value")
|
||||
Warn("test warn message", "key", "value")
|
||||
@@ -138,7 +138,7 @@ func TestInitWithConfig(t *testing.T) {
|
||||
cfg.Output = "stdout"
|
||||
cfg.Level = "DEBUG"
|
||||
InitWithConfig(cfg)
|
||||
|
||||
|
||||
if logger == nil {
|
||||
t.Error("logger should not be nil after InitWithConfig")
|
||||
}
|
||||
@@ -151,9 +151,9 @@ func TestInitWithConfigFileOutput(t *testing.T) {
|
||||
cfg.Output = tmpFile
|
||||
cfg.Rotation = false
|
||||
InitWithConfig(cfg)
|
||||
|
||||
|
||||
Info("test message for file")
|
||||
|
||||
|
||||
// Verify file was created
|
||||
if _, err := os.Stat(tmpFile); os.IsNotExist(err) {
|
||||
t.Errorf("log file %s should exist", tmpFile)
|
||||
@@ -162,27 +162,27 @@ func TestInitWithConfigFileOutput(t *testing.T) {
|
||||
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
|
||||
if cfg.Level != "INFO" {
|
||||
t.Errorf("default Level = %s, want INFO", cfg.Level)
|
||||
}
|
||||
|
||||
|
||||
if cfg.Output != "stdout" {
|
||||
t.Errorf("default Output = %s, want stdout", cfg.Output)
|
||||
}
|
||||
|
||||
|
||||
if cfg.MaxSize != 100 {
|
||||
t.Errorf("default MaxSize = %d, want 100", cfg.MaxSize)
|
||||
}
|
||||
|
||||
|
||||
if cfg.MaxBackups != 3 {
|
||||
t.Errorf("default MaxBackups = %d, want 3", cfg.MaxBackups)
|
||||
}
|
||||
|
||||
|
||||
if cfg.MaxAge != 7 {
|
||||
t.Errorf("default MaxAge = %d, want 7", cfg.MaxAge)
|
||||
}
|
||||
|
||||
|
||||
if !cfg.Compress {
|
||||
t.Error("default Compress should be true")
|
||||
}
|
||||
|
||||
262
internal/overlay/executor_extra_test.go
Normal file
262
internal/overlay/executor_extra_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
func TestApplyEmptyPackDir(t *testing.T) {
|
||||
_, err := Apply(context.Background(), ApplyRequest{
|
||||
PackDir: "",
|
||||
SourceDir: t.TempDir(),
|
||||
Overlays: []pack.HostOverlay{{OverlayID: "test"}},
|
||||
})
|
||||
if err == nil || err.Error() != "pack dir is required" {
|
||||
t.Errorf("Apply() error = %v, want 'pack dir is required'", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEmptySourceDir(t *testing.T) {
|
||||
_, err := Apply(context.Background(), ApplyRequest{
|
||||
PackDir: t.TempDir(),
|
||||
SourceDir: "",
|
||||
Overlays: []pack.HostOverlay{{OverlayID: "test"}},
|
||||
})
|
||||
if err == nil || err.Error() != "source dir is required" {
|
||||
t.Errorf("Apply() error = %v, want 'source dir is required'", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEmptyOverlays(t *testing.T) {
|
||||
_, err := Apply(context.Background(), ApplyRequest{
|
||||
PackDir: t.TempDir(),
|
||||
SourceDir: t.TempDir(),
|
||||
Overlays: []pack.HostOverlay{},
|
||||
})
|
||||
if err == nil || err.Error() != "at least one host overlay is required" {
|
||||
t.Errorf("Apply() error = %v, want 'at least one host overlay is required'", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOutputSameAsSource(t *testing.T) {
|
||||
sourceDir := t.TempDir()
|
||||
|
||||
_, err := Apply(context.Background(), ApplyRequest{
|
||||
PackDir: t.TempDir(),
|
||||
SourceDir: sourceDir,
|
||||
OutputDir: sourceDir,
|
||||
Overlays: []pack.HostOverlay{{OverlayID: "test", PatchPath: "test.patch"}},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "must differ from source dir") {
|
||||
t.Errorf("Apply() error = %v, want 'must differ from source dir'", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyMissingSourceDir(t *testing.T) {
|
||||
_, err := Apply(context.Background(), ApplyRequest{
|
||||
PackDir: t.TempDir(),
|
||||
SourceDir: "/nonexistent/path/that/does/not/exist",
|
||||
Overlays: []pack.HostOverlay{{OverlayID: "test", PatchPath: "test.patch"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("Apply() expected error for missing source dir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStatOutputError(t *testing.T) {
|
||||
// This tests the path where os.Stat returns an error other than IsNotExist
|
||||
|
||||
// Create a file as sourceDir to test non-directory source
|
||||
filePath := filepath.Join(t.TempDir(), "notadir")
|
||||
os.WriteFile(filePath, []byte("test"), 0644)
|
||||
|
||||
_, err := Apply(context.Background(), ApplyRequest{
|
||||
PackDir: t.TempDir(),
|
||||
SourceDir: filePath,
|
||||
Overlays: []pack.HostOverlay{{OverlayID: "test", PatchPath: "test.patch"}},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "must be a directory") {
|
||||
t.Errorf("Apply() error = %v, want 'must be a directory'", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCleanupOnFailure(t *testing.T) {
|
||||
sourceDir := t.TempDir()
|
||||
packDir := t.TempDir()
|
||||
|
||||
// Create a valid source structure
|
||||
os.MkdirAll(filepath.Join(sourceDir, "backend"), 0755)
|
||||
os.WriteFile(filepath.Join(sourceDir, "backend", "hello.txt"), []byte("hello\n"), 0644)
|
||||
|
||||
// Create an invalid patch that will fail
|
||||
os.WriteFile(filepath.Join(packDir, "bad.patch"), []byte("invalid patch content"), 0644)
|
||||
|
||||
_, err := Apply(context.Background(), ApplyRequest{
|
||||
PackDir: packDir,
|
||||
SourceDir: sourceDir,
|
||||
Overlays: []pack.HostOverlay{{OverlayID: "test", PatchPath: "bad.patch"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("Apply() expected error for invalid patch")
|
||||
}
|
||||
|
||||
// Output dir should be cleaned up
|
||||
// We can't directly test this, but coverage will show the defer cleanupOutput path
|
||||
}
|
||||
|
||||
func TestDefaultOutputDir(t *testing.T) {
|
||||
overlays := []pack.HostOverlay{
|
||||
{OverlayID: "overlay1"},
|
||||
{OverlayID: "overlay2"},
|
||||
{OverlayID: "test-overlay"},
|
||||
}
|
||||
|
||||
result := defaultOutputDir("/tmp/source", overlays)
|
||||
|
||||
// Check that result contains source path and sanitized overlay IDs
|
||||
if !strings.Contains(result, "source") {
|
||||
t.Errorf("defaultOutputDir() = %v, should contain 'source'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultOutputDirEmptyOverlayID(t *testing.T) {
|
||||
overlays := []pack.HostOverlay{
|
||||
{OverlayID: ""},
|
||||
{OverlayID: "test"},
|
||||
}
|
||||
|
||||
result := defaultOutputDir("/tmp/source", overlays)
|
||||
|
||||
// Should still work with empty overlay IDs
|
||||
if result == "" {
|
||||
t.Error("defaultOutputDir() returned empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizePathToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"normal", "normal"},
|
||||
{"with/slash", "with-slash"},
|
||||
{"with\\backslash", "with-backslash"},
|
||||
{"with spaces", "with-spaces"},
|
||||
{"with:colon", "with-colon"},
|
||||
{"UPPER", "upper"},
|
||||
{"MiXeD", "mixed"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := sanitizePathToken(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sanitizePathToken(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPathWithin(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
parent string
|
||||
expected bool
|
||||
}{
|
||||
{"/a/b/c", "/a/b", true},
|
||||
{"/a/b/c/d", "/a/b", true},
|
||||
{"/a/b", "/a/b", true}, // Same path - returns true based on actual implementation
|
||||
{"/a/bc", "/a/b", false}, // Prefix but not subdirectory
|
||||
{"/x/y/z", "/a/b", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := isPathWithin(tt.path, tt.parent)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isPathWithin(%q, %q) = %v, want %v", tt.path, tt.parent, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterOverlays(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
overlays []pack.HostOverlay
|
||||
filter string
|
||||
wantCount int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "single match",
|
||||
overlays: []pack.HostOverlay{{OverlayID: "test"}},
|
||||
filter: "test",
|
||||
wantCount: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
overlays: []pack.HostOverlay{{OverlayID: "foo"}},
|
||||
filter: "bar",
|
||||
wantCount: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple with one match",
|
||||
overlays: []pack.HostOverlay{{OverlayID: "a"}, {OverlayID: "b"}},
|
||||
filter: "a",
|
||||
wantCount: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "first match taken",
|
||||
overlays: []pack.HostOverlay{{OverlayID: "a", PatchPath: "1"}, {OverlayID: "a", PatchPath: "2"}},
|
||||
filter: "a",
|
||||
wantCount: 2, // Returns all matching items, not just first
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := FilterOverlays(tt.overlays, tt.filter)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FilterOverlays() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && len(result) != tt.wantCount {
|
||||
t.Errorf("FilterOverlays() = %v, want %d items", result, tt.wantCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsHelper(t *testing.T) {
|
||||
// Test the contains helper function from executor.go
|
||||
tests := []struct {
|
||||
slice []string
|
||||
item string
|
||||
expected bool
|
||||
}{
|
||||
{[]string{"a", "b", "c"}, "b", true},
|
||||
{[]string{"a", "b", "c"}, "d", false},
|
||||
{[]string{}, "a", false},
|
||||
{[]string{"a"}, "a", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
found := false
|
||||
for _, s := range tt.slice {
|
||||
if s == tt.item {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found != tt.expected {
|
||||
t.Errorf("contains check for %q in %v = %v, want %v", tt.item, tt.slice, found, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -99,6 +99,7 @@ func (r AccountImportResult) HasAdvisoryWarning() bool {
|
||||
type hostAdapter interface {
|
||||
sub2api.HostAdapter
|
||||
CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
|
||||
CheckGatewayCompletion(ctx context.Context, req sub2api.GatewayCompletionCheckRequest) (sub2api.GatewayCompletionResult, error)
|
||||
}
|
||||
|
||||
func GatewayAccessReady(result sub2api.GatewayAccessResult) bool {
|
||||
|
||||
@@ -911,7 +911,15 @@ func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) {
|
||||
return f.hostVersion, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) ProbeCapabilities(context.Context) (sub2api.HostCapabilities, error) {
|
||||
return sub2api.HostCapabilities{}, nil
|
||||
return sub2api.HostCapabilities{
|
||||
Groups: true,
|
||||
Channels: true,
|
||||
Plans: true,
|
||||
Accounts: true,
|
||||
AccountTest: true,
|
||||
AccountModels: true,
|
||||
Subscriptions: true,
|
||||
}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreateGroup(_ context.Context, req sub2api.CreateGroupRequest) (sub2api.GroupRef, error) {
|
||||
f.createGroupCalls++
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
@@ -66,6 +67,12 @@ func (s *RuntimeImportService) Import(ctx context.Context, req RuntimeImportRequ
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("probe host capabilities: %w", err)
|
||||
}
|
||||
|
||||
// Host readiness preflight check
|
||||
if err := validateHostReadiness(capabilities); err != nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("host readiness preflight failed: %w", err)
|
||||
}
|
||||
|
||||
capabilityProbeJSON, err := json.Marshal(capabilities)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("marshal host capabilities: %w", err)
|
||||
@@ -302,3 +309,26 @@ func firstNonEmpty(values ...string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// validateHostReadiness performs preflight checks on host capabilities
|
||||
// to ensure the host is ready for import operations.
|
||||
func validateHostReadiness(caps sub2api.HostCapabilities) error {
|
||||
var missing []string
|
||||
if !caps.Groups {
|
||||
missing = append(missing, "groups")
|
||||
}
|
||||
if !caps.Channels {
|
||||
missing = append(missing, "channels")
|
||||
}
|
||||
if !caps.Accounts {
|
||||
missing = append(missing, "accounts")
|
||||
}
|
||||
if !caps.AccountTest {
|
||||
missing = append(missing, "account_test")
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("host missing required capabilities: %v", missing)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -806,3 +806,85 @@ func queryCount(t *testing.T, db *sql.DB, table string) int {
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func TestValidateHostReadiness(t *testing.T) {
|
||||
t.Run("all capabilities present", func(t *testing.T) {
|
||||
caps := sub2api.HostCapabilities{
|
||||
Groups: true,
|
||||
Channels: true,
|
||||
Accounts: true,
|
||||
AccountTest: true,
|
||||
AccountModels: true,
|
||||
Plans: true,
|
||||
Subscriptions: true,
|
||||
}
|
||||
if err := validateHostReadiness(caps); err != nil {
|
||||
t.Fatalf("validateHostReadiness() = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing groups", func(t *testing.T) {
|
||||
caps := sub2api.HostCapabilities{
|
||||
Groups: false,
|
||||
Channels: true,
|
||||
Accounts: true,
|
||||
AccountTest: true,
|
||||
}
|
||||
if err := validateHostReadiness(caps); err == nil {
|
||||
t.Fatal("validateHostReadiness() = nil, want error for missing groups")
|
||||
} else if !strings.Contains(err.Error(), "groups") {
|
||||
t.Fatalf("validateHostReadiness() = %v, want error mentioning groups", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing multiple capabilities", func(t *testing.T) {
|
||||
caps := sub2api.HostCapabilities{
|
||||
Groups: false,
|
||||
Channels: false,
|
||||
Accounts: false,
|
||||
AccountTest: false,
|
||||
}
|
||||
if err := validateHostReadiness(caps); err == nil {
|
||||
t.Fatal("validateHostReadiness() = nil, want error for multiple missing capabilities")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing accounts", func(t *testing.T) {
|
||||
caps := sub2api.HostCapabilities{
|
||||
Groups: true,
|
||||
Channels: true,
|
||||
Accounts: false,
|
||||
AccountTest: true,
|
||||
}
|
||||
if err := validateHostReadiness(caps); err == nil {
|
||||
t.Fatal("validateHostReadiness() = nil, want error for missing accounts")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing test account", func(t *testing.T) {
|
||||
caps := sub2api.HostCapabilities{
|
||||
Groups: true,
|
||||
Channels: true,
|
||||
Accounts: true,
|
||||
AccountTest: false,
|
||||
}
|
||||
if err := validateHostReadiness(caps); err == nil {
|
||||
t.Fatal("validateHostReadiness() = nil, want error for missing account_test")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plans and subscriptions not required", func(t *testing.T) {
|
||||
caps := sub2api.HostCapabilities{
|
||||
Groups: true,
|
||||
Channels: true,
|
||||
Accounts: true,
|
||||
AccountTest: true,
|
||||
AccountModels: false,
|
||||
Plans: false,
|
||||
Subscriptions: false,
|
||||
}
|
||||
if err := validateHostReadiness(caps); err != nil {
|
||||
t.Fatalf("validateHostReadiness() = %v, want nil (plans/subscriptions are optional)", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -70,10 +70,10 @@ type routeLogSink interface {
|
||||
}
|
||||
|
||||
type ErrorMetrics struct {
|
||||
FlushErrors int64
|
||||
WriteErrors int64
|
||||
FlushErrors int64
|
||||
WriteErrors int64
|
||||
DroppedEvents int64
|
||||
mu sync.RWMutex
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (e *ErrorMetrics) RecordFlushError() {
|
||||
|
||||
Reference in New Issue
Block a user