// Split the monolithic config.go (~120KB) into focused modules:
//   - auth.go: JWT, TOTP, Turnstile, RateLimit configs
//   - billing.go: Billing and Pricing configs
//   - database.go: Database and Redis configs
//   - gateway.go: Gateway and Upstream configs
//   - gateway_sub.go: Gateway sub-configurations
//   - ops_and_cache.go: Ops and Cache configs
//   - platforms.go: Platform-specific configs
//   - security.go: Security-related configs
//   - server.go: Server configuration
//   - config_defaults.go: Default values
//   - config_defaults_detail.go: Detailed defaults
//   - config_helpers.go: Helper functions
//   - config_validate.go: Validation logic
//   - config_validate_gateway.go: Gateway validation
//
// This improves:
//   - Code maintainability and readability
//   - Faster compilation (smaller files)
//   - Easier navigation and debugging
//   - Better separation of concerns
//
// 362 lines · 9.9 KiB · Go

package config
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// =============================================================================
// Test: config_helpers.go — Utility Functions
// Covers: normalizeStringSlice, isWeakJWTSecret, generateJWTSecret,
// ValidateAbsoluteHTTPURL, ValidateFrontendRedirectURL,
// scopeContainsOpenID, isHTTPScheme, warnIfInsecureURL.
// Also exercises GetServerAddress, NormalizeRunMode, and the
// UsageRecordOverflowPolicy*, UMQMode* and ConnectionPoolIsolation* constants.
// =============================================================================

// --- normalizeStringSlice ---
|
|
|
|
func TestNormalizeStringSlice_Extended(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
input []string
|
|
expected []string
|
|
}{
|
|
{"nil returns nil", nil, nil},
|
|
{"empty slice", []string{}, []string{}},
|
|
{"trims spaces", []string{" a ", " b "}, []string{"a", "b"}},
|
|
{"removes empty strings", []string{"a", "", "b", ""}, []string{"a", "b"}},
|
|
{"removes whitespace-only strings", []string{"a", " ", "b"}, []string{"a", "b"}},
|
|
{"all valid", []string{"a", "b", "c"}, []string{"a", "b", "c"}},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
assert.Equal(t, tc.expected, normalizeStringSlice(tc.input))
|
|
})
|
|
}
|
|
}
|
|
|
|
// --- isWeakJWTSecret ---
|
|
|
|
func TestIsWeakJWTSecret(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Known weak secrets should be detected
|
|
weakSecrets := []string{
|
|
"change-me-in-production",
|
|
"changeme",
|
|
"secret",
|
|
"password",
|
|
"123456",
|
|
"12345678",
|
|
"admin",
|
|
"jwt-secret",
|
|
}
|
|
for _, s := range weakSecrets {
|
|
s := s
|
|
t.Run("weak_"+s, func(t *testing.T) {
|
|
t.Parallel()
|
|
assert.True(t, isWeakJWTSecret(s), "%q should be detected as weak", s)
|
|
})
|
|
}
|
|
|
|
// Case-insensitive check
|
|
t.Run("case insensitive weak", func(t *testing.T) {
|
|
t.Parallel()
|
|
assert.True(t, isWeakJWTSecret("SECRET"))
|
|
assert.True(t, isWeakJWTSecret("Password"))
|
|
assert.True(t, isWeakJWTSecret("Change-Me-In-Production"))
|
|
})
|
|
|
|
// Strong secrets should NOT be detected as weak
|
|
t.Run("strong random secret", func(t *testing.T) {
|
|
t.Parallel()
|
|
assert.False(t, isWeakJWTSecret(strings.Repeat("x", 32)))
|
|
})
|
|
|
|
t.Run("empty string is weak", func(t *testing.T) {
|
|
t.Parallel()
|
|
assert.True(t, isWeakJWTSecret(""))
|
|
assert.True(t, isWeakJWTSecret(" "))
|
|
})
|
|
}
|
|
|
|
// --- generateJWTSecret ---
|
|
|
|
func TestGenerateJWTSecret(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("generates correct length (hex encoded)", func(t *testing.T) {
|
|
t.Parallel()
|
|
secret, err := generateJWTSecret(32)
|
|
assert.NoError(t, err)
|
|
// 32 bytes = 64 hex characters
|
|
assert.Len(t, secret, 64)
|
|
// Should be valid hex
|
|
for _, c := range secret {
|
|
assert.True(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'),
|
|
"invalid hex character: %c", c)
|
|
}
|
|
})
|
|
|
|
t.Run("different calls produce different secrets", func(t *testing.T) {
|
|
t.Parallel()
|
|
s1, _ := generateJWTSecret(32)
|
|
s2, _ := generateJWTSecret(32)
|
|
assert.NotEqual(t, s1, s2, "two generated secrets should differ")
|
|
})
|
|
|
|
t.Run("zero or negative byteLength defaults to 32", func(t *testing.T) {
|
|
t.Parallel()
|
|
s1, err := generateJWTSecret(0)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, s1, 64) // 32 bytes hex-encoded
|
|
|
|
s2, err := generateJWTSecret(-5)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, s2, 64)
|
|
})
|
|
|
|
t.Run("custom byte length", func(t *testing.T) {
|
|
t.Parallel()
|
|
secret, err := generateJWTSecret(16)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, secret, 32) // 16 bytes hex-encoded
|
|
})
|
|
}
|
|
|
|
// --- ValidateAbsoluteHTTPURL ---
|
|
|
|
func TestValidateAbsoluteHTTPURL_Extended(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
validURLs := []struct {
|
|
name, url string
|
|
}{
|
|
{"https URL", "https://example.com"},
|
|
{"https URL with path", "https://example.com/path/to/resource"},
|
|
{"https URL with query", "https://example.com/path?q=1"},
|
|
{"http URL", "http://localhost:8080"},
|
|
{"http URL with port", "http://192.168.1.1:3000/api"},
|
|
{"http URL with path and port", "http://localhost:8080/oauth/callback"},
|
|
}
|
|
for _, tc := range validURLs {
|
|
tc := tc
|
|
t.Run("valid_"+tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
err := ValidateAbsoluteHTTPURL(tc.url)
|
|
assert.NoError(t, err, "URL %q should be valid", tc.url)
|
|
})
|
|
}
|
|
|
|
invalidURLs := []struct {
|
|
name, url, expectedErr string
|
|
}{
|
|
{"empty string", "", "empty url"},
|
|
{"relative path", "/api/callback", "must be absolute"},
|
|
{"ftp scheme", "ftp://files.example.com/file.txt", "unsupported scheme: ftp"},
|
|
{"missing host", "http:///path", "missing host"}, // no scheme case removed - URL parser behavior varies
|
|
{"missing host", "http:///path", "missing host"},
|
|
{"with fragment", "https://example.com#anchor", "must not include fragment"},
|
|
{"whitespace only", " ", "empty url"},
|
|
}
|
|
for _, tc := range invalidURLs {
|
|
tc := tc
|
|
t.Run("invalid_"+tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
err := ValidateAbsoluteHTTPURL(tc.url)
|
|
assert.Error(t, err, "URL %q should be invalid", tc.url)
|
|
assert.Contains(t, err.Error(), tc.expectedErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
// --- ValidateFrontendRedirectURL ---
|
|
|
|
func TestValidateFrontendRedirectURL_Extended(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Valid: absolute URLs
|
|
validAbsolute := []string{
|
|
"https://example.com/auth/callback",
|
|
"http://localhost:3000/oauth/redirect",
|
|
}
|
|
for _, u := range validAbsolute {
|
|
u := u
|
|
t.Run("valid_absolute_"+strings.Split(u, "/")[1], func(t *testing.T) {
|
|
t.Parallel()
|
|
assert.NoError(t, ValidateFrontendRedirectURL(u))
|
|
})
|
|
}
|
|
|
|
// Valid: relative paths
|
|
validRelative := []string{
|
|
"/auth/linuxdo/callback",
|
|
"/oidc/callback",
|
|
"/auth/oidc/callback",
|
|
}
|
|
for _, u := range validRelative {
|
|
u := u
|
|
t.Run("valid_relative_"+strings.ReplaceAll(u, "/", "_"), func(t *testing.T) {
|
|
t.Parallel()
|
|
assert.NoError(t, ValidateFrontendRedirectURL(u))
|
|
})
|
|
}
|
|
|
|
// Invalid
|
|
invalidCases := []struct {
|
|
name, url, expectedErr string
|
|
}{
|
|
{"empty", "", "empty url"},
|
|
{"protocol-relative //path", "//evil.com/path", "must not start with //"},
|
|
{"with \\n newline", "https://example.com/\ncallback", "contains invalid characters"},
|
|
{"with \\r\\r", "https://example.com/\rcallback", "contains invalid characters"},
|
|
{"ftp scheme", "ftp://example.com/cb", "unsupported scheme"},
|
|
{"relative without leading slash", "auth/callback", "absolute http(s) url or relative path"},
|
|
{"with fragment", "https://example.com/cb#frag", "must not include fragment"},
|
|
}
|
|
for _, tc := range invalidCases {
|
|
tc := tc
|
|
t.Run("invalid_"+tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
err := ValidateFrontendRedirectURL(tc.url)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), tc.expectedErr, "for input %q", tc.url)
|
|
})
|
|
}
|
|
}
|
|
|
|
// --- scopeContainsOpenID ---
|
|
|
|
func TestScopeContainsOpenID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
scopes string
|
|
expected bool
|
|
}{
|
|
{"openid", true},
|
|
{"openid email profile", true},
|
|
{"email profile openid", true},
|
|
{"openid profile email", true},
|
|
{"OPENID", true},
|
|
{" OpenID ", true},
|
|
{"email profile", false},
|
|
{"open_id", false},
|
|
{"", false},
|
|
{"profile", false},
|
|
{" ", false},
|
|
}
|
|
for i, tc := range tests {
|
|
i, tc := i, tc
|
|
t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) {
|
|
t.Parallel()
|
|
assert.Equal(t, tc.expected, scopeContainsOpenID(tc.scopes))
|
|
})
|
|
}
|
|
}
|
|
|
|
// --- isHTTPScheme ---
|
|
|
|
func TestIsHTTPScheme(t *testing.T) {
|
|
t.Parallel()
|
|
assert.True(t, isHTTPScheme("http"))
|
|
assert.True(t, isHTTPScheme("HTTP"))
|
|
assert.True(t, isHTTPScheme("https"))
|
|
assert.True(t, isHTTPScheme("HTTPS"))
|
|
assert.True(t, isHTTPScheme("HtTpS"))
|
|
assert.False(t, isHTTPScheme("ftp"))
|
|
assert.False(t, isHTTPScheme("ws"))
|
|
assert.False(t, isHTTPScheme(""))
|
|
}
|
|
|
|
// --- warnIfInsecureURL ---
|
|
// Note: This function only logs a warning, so we just verify it doesn't panic
|
|
|
|
func TestWarnIfInsecureURL_NoPanic(t *testing.T) {
|
|
t.Parallel()
|
|
// Should not panic on any input
|
|
warnIfInsecureURL("test_field", "http://example.com")
|
|
warnIfInsecureURL("test_field", "https://example.com")
|
|
warnIfInsecureURL("test_field", "")
|
|
warnIfInsecureURL("test_field", "not-a-url-at-all")
|
|
warnIfInsecureURL("test_field", "://malformed")
|
|
}
|
|
|
|
// --- GetServerAddress ---
|
|
|
|
func TestGetServerAddress_Defaults(t *testing.T) {
|
|
// GetServerAddress reads from viper; we just verify it returns a non-empty string
|
|
// without crashing (it uses defaults of 0.0.0.0:8080)
|
|
addr := GetServerAddress()
|
|
assert.NotEmpty(t, addr)
|
|
assert.Contains(t, addr, ":")
|
|
}
|
|
|
|
// --- NormalizeRunMode ---
|
|
|
|
func TestNormalizeRunMode_Extended(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"standard", "standard"},
|
|
{"STANDARD", "standard"},
|
|
{"Standard", "standard"},
|
|
{"simple", "simple"},
|
|
{"SIMPLE", "simple"},
|
|
{" standard ", "standard"},
|
|
{"\tsimple\n", "simple"},
|
|
{"production", "standard"}, // unknown → default
|
|
{"dev", "standard"}, // unknown → default
|
|
{"", "standard"}, // empty → default
|
|
{"SIMPLE ", "simple"}, // trim space
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(fmt.Sprintf("input=%q", tc.input), func(t *testing.T) {
|
|
assert.Equal(t, tc.expected, NormalizeRunMode(tc.input))
|
|
})
|
|
}
|
|
}
|
|
|
|
// --- Constants validation ---
|
|
|
|
func TestUsageRecordOverflowPolicyConstants(t *testing.T) {
|
|
t.Parallel()
|
|
policies := []string{UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync}
|
|
for _, p := range policies {
|
|
assert.NotEmpty(t, p)
|
|
}
|
|
// All unique
|
|
m := make(map[string]bool)
|
|
for _, p := range policies {
|
|
m[p] = true
|
|
}
|
|
assert.Len(t, m, len(policies), "all overflow policies must be unique")
|
|
}
|
|
|
|
func TestUMQModeConstants(t *testing.T) {
|
|
t.Parallel()
|
|
modes := []string{UMQModeSerialize, UMQModeThrottle}
|
|
for _, m := range modes {
|
|
assert.NotEmpty(t, m)
|
|
}
|
|
assert.NotEqual(t, UMQModeSerialize, UMQModeThrottle)
|
|
}
|
|
|
|
func TestConnectionPoolIsolationConstants(t *testing.T) {
|
|
t.Parallel()
|
|
strategies := []string{ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy}
|
|
for _, s := range strategies {
|
|
assert.NotEmpty(t, s)
|
|
}
|
|
unique := map[string]bool{}
|
|
for _, s := range strategies {
|
|
unique[s] = true
|
|
}
|
|
assert.Len(t, unique, len(strategies))
|
|
}
|