package setup import ( "fmt" "os" "path/filepath" "strings" "testing" "github.com/stretchr/testify/assert" ) // ============================================================================= // Test: setup.go — quoteIdentifier SQL Injection Prevention // 验证 PostgreSQL 标识符引用能正确防御 SQL 注入 // ============================================================================= func TestQuoteIdentifier_SQLInjectionDefense(t *testing.T) { t.Parallel() tests := []struct { name string input string expectedQuoted string description string }{ { name: "normal identifier", input: "mydatabase", expectedQuoted: `"mydatabase"`, description: "Normal database name should be quoted as-is", }, { name: "identifier with underscores", input: "my_db_name", expectedQuoted: `"my_db_name"`, description: "Underscores are valid in identifiers", }, { name: "identifier with numbers", input: "db123", expectedQuoted: `"db123"`, description: "Numbers after first char are valid", }, { name: "identifier starting with number", input: "123db", expectedQuoted: `"123db"`, description: "Numbers at start need quoting but are valid", }, { name: "SQL injection via double quote escape", input: `mydb"; DROP TABLE users; --`, expectedQuoted: `"mydb""; DROP TABLE users; --"`, description: "Double quotes must be escaped by doubling to prevent injection", }, { name: "SQL injection single double quote", input: `foo"bar`, expectedQuoted: `"foo""bar"`, description: "Single internal double quote gets doubled", }, { name: "SQL injection multiple double quotes", input: `a"b"c"d"e`, expectedQuoted: `"a""b""c""d""e"`, description: "All double quotes must be escaped", }, { name: "empty string produces empty quoted", input: "", expectedQuoted: `""`, description: "Empty input becomes empty quoted identifier", }, { name: "SQL injection UNION attack", input: `db" UNION SELECT * FROM secrets --`, expectedQuoted: `"db"" UNION SELECT * FROM secrets --"`, description: "UNION injection attempt neutralized by quote escaping", }, { name: "SQL injection with semicolon and comment", input: `test; SELECT 1--`, expectedQuoted: `"test; SELECT 1--"`, description: "Semicolons and comments inside quotes are literal text, not SQL syntax", }, { name: "whitespace is preserved inside quotes", input: `my db name`, expectedQuoted: `"my db name"`, description: "Spaces inside quoted identifiers are preserved", }, { name: "special characters preserved", input: `my-db.name$v2.0`, expectedQuoted: `"my-db.name$v2.0"`, description: "Non-quote special characters pass through (PostgreSQL allows these)", }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() got := quoteIdentifier(tc.input) assert.Equal(t, tc.expectedQuoted, got, "quoteIdentifier(%q): got %q, want %q — %s", tc.input, got, tc.expectedQuoted, tc.description) }) } } func TestQuoteIdentifier_SafetyInvariant(t *testing.T) { t.Parallel() attackStrings := []string{ `mydb`, `my_db_123`, `; COPY users TO '/etc/passwd'; --`, } for _, attack := range attackStrings { attack := attack safeName := fmt.Sprintf("inv_%d", hashString(attack)) t.Run(safeName, func(t *testing.T) { t.Parallel() quoted := quoteIdentifier(attack) // Invariant 1: Output always starts and ends with exactly one double quote if !strings.HasPrefix(quoted, `"`) { t.Errorf("must start with double quote") } if !strings.HasSuffix(quoted, `"`) { t.Errorf("must end with double quote") } // Invariant 2: All internal double quotes are escaped (doubled) inner := quoted[1 : len(quoted)-1] for i := 0; i < len(inner)-1; i++ { if inner[i] == '"' && inner[i+1] != '"' { t.Errorf("unescaped double quote at position %d in inner content", i) } } // Invariant 3: When used in SQL, the result is a single valid identifier sql := fmt.Sprintf("CREATE DATABASE %s", quoted) if !strings.Contains(sql, quoted) { t.Error("SQL must contain the exact quoted identifier") } }) } } func min(a, b int) int { if a < b { return a } return b } func hashString(s string) int { h := 0 for _, c := range s { h = h*31 + int(c) } if h < 0 { h = -h } return h % 10000 } // ============================================================================= // Test: setup.go — readExistingJWTSecret / JWT Secret Mismatch Detection // ============================================================================= func TestReadExistingJWTSecret(t *testing.T) { t.Run("returns empty when no config file exists", func(t *testing.T) { dir := t.TempDir() t.Setenv("DATA_DIR", dir) secret := readExistingJWTSecret() assert.Empty(t, secret) }) t.Run("reads jwt.secret from config file", func(t *testing.T) { dir := t.TempDir() t.Setenv("DATA_DIR", dir) configPath := filepath.Join(dir, "config.yaml") content := []byte(`jwt: secret: my-test-secret-32-bytes-long-value!! `) assert.NoError(t, os.WriteFile(configPath, content, 0o644)) secret := readExistingJWTSecret() assert.Equal(t, "my-test-secret-32-bytes-long-value!!", secret) }) t.Run("returns empty for missing jwt.secret key", func(t *testing.T) { dir := t.TempDir() t.Setenv("DATA_DIR", dir) configPath := filepath.Join(dir, "config.yaml") content := []byte(`server: port: 8080 `) assert.NoError(t, os.WriteFile(configPath, content, 0o644)) secret := readExistingJWTSecret() assert.Empty(t, secret) }) t.Run("trims whitespace from secret", func(t *testing.T) { dir := t.TempDir() t.Setenv("DATA_DIR", dir) configPath := filepath.Join(dir, "config.yaml") content := []byte("jwt:\n secret: spaced-secret-32b \n") assert.NoError(t, os.WriteFile(configPath, content, 0o644)) secret := readExistingJWTSecret() assert.Equal(t, "spaced-secret-32b", secret) }) t.Run("returns empty on malformed YAML", func(t *testing.T) { dir := t.TempDir() t.Setenv("DATA_DIR", dir) configPath := filepath.Join(dir, "config.yaml") content := []byte(`{invalid yaml [[[`) assert.NoError(t, os.WriteFile(configPath, content, 0o644)) secret := readExistingJWTSecret() assert.Empty(t, secret, "malformed YAML should return empty secret without error") }) } // ============================================================================= // Test: setup.go — AutoSetupFromEnv helpers // ============================================================================= func TestGetEnvOrDefault(t *testing.T) { t.Run("returns env var value", func(t *testing.T) { t.Setenv("TEST_GETENV_KEY", "hello_value") assert.Equal(t, "hello_value", getEnvOrDefault("TEST_GETENV_KEY", "default")) }) t.Run("returns default when not set", func(t *testing.T) { os.Unsetenv("TEST_NONEXISTENT_KEY_XYZ") assert.Equal(t, "fallback", getEnvOrDefault("TEST_NONEXISTENT_KEY_XYZ", "fallback")) }) t.Run("returns default for empty string env", func(t *testing.T) { t.Setenv("TEST_EMPTY_ENV_KEY", "") assert.Equal(t, "fallback", getEnvOrDefault("TEST_EMPTY_ENV_KEY", "fallback")) }) } func TestGetEnvIntOrDefault(t *testing.T) { t.Run("parses valid integer", func(t *testing.T) { t.Setenv("TEST_INT_KEY", "5432") assert.Equal(t, 5432, getEnvIntOrDefault("TEST_INT_KEY", 0)) }) t.Run("returns default for invalid int", func(t *testing.T) { t.Setenv("TEST_BAD_INT", "not_a_number") assert.Equal(t, 9999, getEnvIntOrDefault("TEST_BAD_INT", 9999)) }) t.Run("returns default for empty", func(t *testing.T) { os.Unsetenv("TEST_EMPTY_INT_KEY") assert.Equal(t, 42, getEnvIntOrDefault("TEST_EMPTY_INT_KEY", 42)) }) } func TestAutoSetupEnabled(t *testing.T) { cases := map[string]bool{ "true": true, "1": true, "yes": true, "false": false, "0": false, "no": false, "": false, "TRUE": false, "Yes": false, // case-sensitive } for val, expected := range cases { val, expected := val, expected t.Run(fmt.Sprintf("AUTO_SETUP=%q", val), func(t *testing.T) { t.Setenv("AUTO_SETUP", val) assert.Equal(t, expected, AutoSetupEnabled()) }) } } // ============================================================================= // Test: setup.go — GetDataDir / NeedsSetup // ============================================================================= func TestGetDataDir_Priority(t *testing.T) { t.Run("DATA_DIR env takes priority", func(t *testing.T) { t.Setenv("DATA_DIR", "/custom/data/path") assert.Equal(t, "/custom/data/path", GetDataDir()) }) t.Run("falls back to current directory when no DATA_DIR and no /app/data", func(t *testing.T) { os.Unsetenv("DATA_DIR") // /app/data likely doesn't exist on dev machine dir := GetDataDir() assert.NotEmpty(t, dir) // Should be "." or similar fallback }) } func TestNeedsSetup_WithNoFiles(t *testing.T) { dir := t.TempDir() t.Setenv("DATA_DIR", dir) // No config.yaml or .installed → needs setup assert.True(t, NeedsSetup(), "should need setup when no config/lock files exist") } func TestNeedsSetup_WithConfigFile(t *testing.T) { dir := t.TempDir() t.Setenv("DATA_DIR", dir) configPath := filepath.Join(dir, "config.yaml") assert.NoError(t, os.WriteFile(configPath, []byte("test: data"), 0o644)) assert.False(t, NeedsSetup(), "should NOT need setup when config.yaml exists") } func TestNeedsSetup_WithLockFile(t *testing.T) { dir := t.TempDir() t.Setenv("DATA_DIR", dir) lockPath := filepath.Join(dir, ".installed") assert.NoError(t, os.WriteFile(lockPath, []byte("installed_at=2024"), 0o644)) assert.False(t, NeedsSetup(), "should NOT need setup when .installed lock exists") } // ============================================================================= // Test: setup.go — generateSecret // ============================================================================= func TestGenerateSecret(t *testing.T) { t.Parallel() t.Run("generates hex-encoded string of correct length", func(t *testing.T) { s, err := generateSecret(16) assert.NoError(t, err) assert.Len(t, s, 32) // 16 bytes = 32 hex chars }) t.Run("generates different values each call", func(t *testing.T) { s1, _ := generateSecret(16) s2, _ := generateSecret(16) assert.NotEqual(t, s1, s2) }) t.Run("valid hex characters only", func(t *testing.T) { s, err := generateSecret(32) assert.NoError(t, err) for _, c := range s { assert.True(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'), "invalid hex char: %c", c) } }) }