// Commit notes accompanying this file:
//   - Add quoteIdentifier for SQL injection defense in setup.go
//   - Add setup_security_test.go for security tests
//   - Add admin auth middleware improvements
//   - Add admin auth test coverage
// (Source listing: 342 lines, 11 KiB, Go.)

package setup
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// =============================================================================
|
|
// Test: setup.go — quoteIdentifier SQL Injection Prevention
|
|
// Verifies that PostgreSQL identifier quoting correctly defends against SQL injection.
|
|
// =============================================================================
|
|
|
|
func TestQuoteIdentifier_SQLInjectionDefense(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expectedQuoted string
|
|
description string
|
|
}{
|
|
{
|
|
name: "normal identifier",
|
|
input: "mydatabase",
|
|
expectedQuoted: `"mydatabase"`,
|
|
description: "Normal database name should be quoted as-is",
|
|
},
|
|
{
|
|
name: "identifier with underscores",
|
|
input: "my_db_name",
|
|
expectedQuoted: `"my_db_name"`,
|
|
description: "Underscores are valid in identifiers",
|
|
},
|
|
{
|
|
name: "identifier with numbers",
|
|
input: "db123",
|
|
expectedQuoted: `"db123"`,
|
|
description: "Numbers after first char are valid",
|
|
},
|
|
{
|
|
name: "identifier starting with number",
|
|
input: "123db",
|
|
expectedQuoted: `"123db"`,
|
|
description: "Numbers at start need quoting but are valid",
|
|
},
|
|
{
|
|
name: "SQL injection via double quote escape",
|
|
input: `mydb"; DROP TABLE users; --`,
|
|
expectedQuoted: `"mydb""; DROP TABLE users; --"`,
|
|
description: "Double quotes must be escaped by doubling to prevent injection",
|
|
},
|
|
{
|
|
name: "SQL injection single double quote",
|
|
input: `foo"bar`,
|
|
expectedQuoted: `"foo""bar"`,
|
|
description: "Single internal double quote gets doubled",
|
|
},
|
|
{
|
|
name: "SQL injection multiple double quotes",
|
|
input: `a"b"c"d"e`,
|
|
expectedQuoted: `"a""b""c""d""e"`,
|
|
description: "All double quotes must be escaped",
|
|
},
|
|
{
|
|
name: "empty string produces empty quoted",
|
|
input: "",
|
|
expectedQuoted: `""`,
|
|
description: "Empty input becomes empty quoted identifier",
|
|
},
|
|
{
|
|
name: "SQL injection UNION attack",
|
|
input: `db" UNION SELECT * FROM secrets --`,
|
|
expectedQuoted: `"db"" UNION SELECT * FROM secrets --"`,
|
|
description: "UNION injection attempt neutralized by quote escaping",
|
|
},
|
|
{
|
|
name: "SQL injection with semicolon and comment",
|
|
input: `test; SELECT 1--`,
|
|
expectedQuoted: `"test; SELECT 1--"`,
|
|
description: "Semicolons and comments inside quotes are literal text, not SQL syntax",
|
|
},
|
|
{
|
|
name: "whitespace is preserved inside quotes",
|
|
input: `my db name`,
|
|
expectedQuoted: `"my db name"`,
|
|
description: "Spaces inside quoted identifiers are preserved",
|
|
},
|
|
{
|
|
name: "special characters preserved",
|
|
input: `my-db.name$v2.0`,
|
|
expectedQuoted: `"my-db.name$v2.0"`,
|
|
description: "Non-quote special characters pass through (PostgreSQL allows these)",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
got := quoteIdentifier(tc.input)
|
|
assert.Equal(t, tc.expectedQuoted, got,
|
|
"quoteIdentifier(%q): got %q, want %q — %s", tc.input, got, tc.expectedQuoted, tc.description)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestQuoteIdentifier_SafetyInvariant(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
attackStrings := []string{
|
|
`mydb`,
|
|
`my_db_123`,
|
|
`; COPY users TO '/etc/passwd'; --`,
|
|
}
|
|
|
|
for _, attack := range attackStrings {
|
|
attack := attack
|
|
safeName := fmt.Sprintf("inv_%d", hashString(attack))
|
|
t.Run(safeName, func(t *testing.T) {
|
|
t.Parallel()
|
|
quoted := quoteIdentifier(attack)
|
|
|
|
// Invariant 1: Output always starts and ends with exactly one double quote
|
|
if !strings.HasPrefix(quoted, `"`) { t.Errorf("must start with double quote") }
|
|
if !strings.HasSuffix(quoted, `"`) { t.Errorf("must end with double quote") }
|
|
|
|
// Invariant 2: All internal double quotes are escaped (doubled)
|
|
inner := quoted[1 : len(quoted)-1]
|
|
for i := 0; i < len(inner)-1; i++ {
|
|
if inner[i] == '"' && inner[i+1] != '"' {
|
|
t.Errorf("unescaped double quote at position %d in inner content", i)
|
|
}
|
|
}
|
|
|
|
// Invariant 3: When used in SQL, the result is a single valid identifier
|
|
sql := fmt.Sprintf("CREATE DATABASE %s", quoted)
|
|
if !strings.Contains(sql, quoted) { t.Error("SQL must contain the exact quoted identifier") }
|
|
})
|
|
}
|
|
}
|
|
|
|
// min returns the smaller of a and b.
//
// NOTE(review): this package-level helper shadows the builtin min (Go 1.21+)
// and is not referenced in the visible part of this file; if no other file in
// the package uses it, it can be removed in favor of the builtin.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
|
|
|
// hashString derives a small, stable, non-negative number from s, used to
// build readable subtest names. It folds the runes of s with a multiply-by-31
// polynomial hash (wrapping silently on int overflow) and reduces the result
// into the range [0, 10000).
func hashString(s string) int {
	h := 0
	for _, r := range s {
		h = h*31 + int(r)
	}
	// Reduce first, then take the absolute value. The previous "negate, then
	// reduce" order returned a negative number when h == math.MinInt, because
	// negating MinInt overflows back to itself. Go's % truncates toward zero,
	// so (-h)%m == -(h%m), and this order gives identical results for every
	// other h.
	bucket := h % 10000
	if bucket < 0 {
		bucket = -bucket
	}
	return bucket
}
|
|
|
|
// =============================================================================
|
|
// Test: setup.go — readExistingJWTSecret / JWT Secret Mismatch Detection
|
|
// =============================================================================
|
|
|
|
func TestReadExistingJWTSecret(t *testing.T) {
|
|
t.Run("returns empty when no config file exists", func(t *testing.T) {
|
|
dir := t.TempDir()
|
|
t.Setenv("DATA_DIR", dir)
|
|
secret := readExistingJWTSecret()
|
|
assert.Empty(t, secret)
|
|
})
|
|
|
|
t.Run("reads jwt.secret from config file", func(t *testing.T) {
|
|
dir := t.TempDir()
|
|
t.Setenv("DATA_DIR", dir)
|
|
configPath := filepath.Join(dir, "config.yaml")
|
|
content := []byte(`jwt:
|
|
secret: my-test-secret-32-bytes-long-value!!
|
|
`)
|
|
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
|
|
|
|
secret := readExistingJWTSecret()
|
|
assert.Equal(t, "my-test-secret-32-bytes-long-value!!", secret)
|
|
})
|
|
|
|
t.Run("returns empty for missing jwt.secret key", func(t *testing.T) {
|
|
dir := t.TempDir()
|
|
t.Setenv("DATA_DIR", dir)
|
|
configPath := filepath.Join(dir, "config.yaml")
|
|
content := []byte(`server:
|
|
port: 8080
|
|
`)
|
|
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
|
|
|
|
secret := readExistingJWTSecret()
|
|
assert.Empty(t, secret)
|
|
})
|
|
|
|
t.Run("trims whitespace from secret", func(t *testing.T) {
|
|
dir := t.TempDir()
|
|
t.Setenv("DATA_DIR", dir)
|
|
configPath := filepath.Join(dir, "config.yaml")
|
|
content := []byte("jwt:\n secret: spaced-secret-32b \n")
|
|
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
|
|
|
|
secret := readExistingJWTSecret()
|
|
assert.Equal(t, "spaced-secret-32b", secret)
|
|
})
|
|
|
|
t.Run("returns empty on malformed YAML", func(t *testing.T) {
|
|
dir := t.TempDir()
|
|
t.Setenv("DATA_DIR", dir)
|
|
configPath := filepath.Join(dir, "config.yaml")
|
|
content := []byte(`{invalid yaml [[[`)
|
|
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
|
|
|
|
secret := readExistingJWTSecret()
|
|
assert.Empty(t, secret, "malformed YAML should return empty secret without error")
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test: setup.go — AutoSetupFromEnv helpers
|
|
// =============================================================================
|
|
|
|
func TestGetEnvOrDefault(t *testing.T) {
|
|
t.Run("returns env var value", func(t *testing.T) {
|
|
t.Setenv("TEST_GETENV_KEY", "hello_value")
|
|
assert.Equal(t, "hello_value", getEnvOrDefault("TEST_GETENV_KEY", "default"))
|
|
})
|
|
|
|
t.Run("returns default when not set", func(t *testing.T) {
|
|
os.Unsetenv("TEST_NONEXISTENT_KEY_XYZ")
|
|
assert.Equal(t, "fallback", getEnvOrDefault("TEST_NONEXISTENT_KEY_XYZ", "fallback"))
|
|
})
|
|
|
|
t.Run("returns default for empty string env", func(t *testing.T) {
|
|
t.Setenv("TEST_EMPTY_ENV_KEY", "")
|
|
assert.Equal(t, "fallback", getEnvOrDefault("TEST_EMPTY_ENV_KEY", "fallback"))
|
|
})
|
|
}
|
|
|
|
func TestGetEnvIntOrDefault(t *testing.T) {
|
|
t.Run("parses valid integer", func(t *testing.T) {
|
|
t.Setenv("TEST_INT_KEY", "5432")
|
|
assert.Equal(t, 5432, getEnvIntOrDefault("TEST_INT_KEY", 0))
|
|
})
|
|
|
|
t.Run("returns default for invalid int", func(t *testing.T) {
|
|
t.Setenv("TEST_BAD_INT", "not_a_number")
|
|
assert.Equal(t, 9999, getEnvIntOrDefault("TEST_BAD_INT", 9999))
|
|
})
|
|
|
|
t.Run("returns default for empty", func(t *testing.T) {
|
|
os.Unsetenv("TEST_EMPTY_INT_KEY")
|
|
assert.Equal(t, 42, getEnvIntOrDefault("TEST_EMPTY_INT_KEY", 42))
|
|
})
|
|
}
|
|
|
|
func TestAutoSetupEnabled(t *testing.T) {
|
|
cases := map[string]bool{
|
|
"true": true, "1": true, "yes": true,
|
|
"false": false, "0": false, "no": false,
|
|
"": false, "TRUE": false, "Yes": false, // case-sensitive
|
|
}
|
|
for val, expected := range cases {
|
|
val, expected := val, expected
|
|
t.Run(fmt.Sprintf("AUTO_SETUP=%q", val), func(t *testing.T) {
|
|
t.Setenv("AUTO_SETUP", val)
|
|
assert.Equal(t, expected, AutoSetupEnabled())
|
|
})
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test: setup.go — GetDataDir / NeedsSetup
|
|
// =============================================================================
|
|
|
|
func TestGetDataDir_Priority(t *testing.T) {
|
|
t.Run("DATA_DIR env takes priority", func(t *testing.T) {
|
|
t.Setenv("DATA_DIR", "/custom/data/path")
|
|
assert.Equal(t, "/custom/data/path", GetDataDir())
|
|
})
|
|
|
|
t.Run("falls back to current directory when no DATA_DIR and no /app/data", func(t *testing.T) {
|
|
os.Unsetenv("DATA_DIR")
|
|
// /app/data likely doesn't exist on dev machine
|
|
dir := GetDataDir()
|
|
assert.NotEmpty(t, dir)
|
|
// Should be "." or similar fallback
|
|
})
|
|
}
|
|
|
|
func TestNeedsSetup_WithNoFiles(t *testing.T) {
|
|
dir := t.TempDir()
|
|
t.Setenv("DATA_DIR", dir)
|
|
// No config.yaml or .installed → needs setup
|
|
assert.True(t, NeedsSetup(), "should need setup when no config/lock files exist")
|
|
}
|
|
|
|
func TestNeedsSetup_WithConfigFile(t *testing.T) {
|
|
dir := t.TempDir()
|
|
t.Setenv("DATA_DIR", dir)
|
|
configPath := filepath.Join(dir, "config.yaml")
|
|
assert.NoError(t, os.WriteFile(configPath, []byte("test: data"), 0o644))
|
|
assert.False(t, NeedsSetup(), "should NOT need setup when config.yaml exists")
|
|
}
|
|
|
|
func TestNeedsSetup_WithLockFile(t *testing.T) {
|
|
dir := t.TempDir()
|
|
t.Setenv("DATA_DIR", dir)
|
|
lockPath := filepath.Join(dir, ".installed")
|
|
assert.NoError(t, os.WriteFile(lockPath, []byte("installed_at=2024"), 0o644))
|
|
assert.False(t, NeedsSetup(), "should NOT need setup when .installed lock exists")
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test: setup.go — generateSecret
|
|
// =============================================================================
|
|
|
|
func TestGenerateSecret(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("generates hex-encoded string of correct length", func(t *testing.T) {
|
|
s, err := generateSecret(16)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, s, 32) // 16 bytes = 32 hex chars
|
|
})
|
|
|
|
t.Run("generates different values each call", func(t *testing.T) {
|
|
s1, _ := generateSecret(16)
|
|
s2, _ := generateSecret(16)
|
|
assert.NotEqual(t, s1, s2)
|
|
})
|
|
|
|
t.Run("valid hex characters only", func(t *testing.T) {
|
|
s, err := generateSecret(32)
|
|
assert.NoError(t, err)
|
|
for _, c := range s {
|
|
assert.True(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'),
|
|
"invalid hex char: %c", c)
|
|
}
|
|
})
|
|
}
|