261 lines
12 KiB
Go
261 lines
12 KiB
Go
// Package config provides configuration loading, defaults, and validation.
|
||
//
|
||
// Type definitions are organized by domain into companion files:
|
||
//
|
||
// - server.go — ServerConfig, H2CConfig, CORSConfig, ConcurrencyConfig
|
||
// - security.go — SecurityConfig, URLAllowlist, CSP, ResponseHeaders
|
||
// - database.go — DatabaseConfig, RedisConfig (with DSN helpers)
|
||
// - auth.go — JWTConfig, TotpConfig, TurnstileConfig, DefaultConfig, RateLimitConfig
|
||
// - billing.go — BillingConfig, PricingConfig
|
||
// - gateway.go — GatewayConfig, UserMessageQueue, SchedulingConfig
|
||
// - gateway_sub.go — OpenAIWS, UsageRecord, TLSFingerprint sub-structs
|
||
// - platforms.go — Sora, Gemini, LinuxDo, OIDC, Update, Idempotency configs
|
||
// - ops_and_cache.go— LogConfig, OpsConfig, Dashboard, Cache, Cleanup configs
|
||
package config
|
||
|
||
import (
|
||
"fmt"
|
||
"log/slog"
|
||
"os"
|
||
"strings"
|
||
|
||
"github.com/spf13/viper"
|
||
)
|
||
|
||
const (
|
||
RunModeStandard = "standard"
|
||
RunModeSimple = "simple"
|
||
)
|
||
|
||
// 使用量记录队列溢出策略
|
||
const (
|
||
UsageRecordOverflowPolicyDrop = "drop"
|
||
UsageRecordOverflowPolicySample = "sample"
|
||
UsageRecordOverflowPolicySync = "sync"
|
||
)
|
||
|
||
// UMQ(用户消息队列)模式常量
|
||
const (
|
||
UMQModeSerialize = "serialize"
|
||
UMQModeThrottle = "throttle"
|
||
)
|
||
|
||
// 连接池隔离策略常量
|
||
const (
|
||
ConnectionPoolIsolationProxy = "proxy"
|
||
ConnectionPoolIsolationAccount = "account"
|
||
ConnectionPoolIsolationAccountProxy = "account_proxy"
|
||
)
|
||
|
||
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support.
|
||
const DefaultCSPPolicy = `default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'`
|
||
|
||
// Config is the top-level application configuration.
|
||
type Config struct {
|
||
Server ServerConfig `mapstructure:"server"`
|
||
Log LogConfig `mapstructure:"log"`
|
||
CORS CORSConfig `mapstructure:"cors"`
|
||
Security SecurityConfig `mapstructure:"security"`
|
||
Billing BillingConfig `mapstructure:"billing"`
|
||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||
Database DatabaseConfig `mapstructure:"database"`
|
||
Redis RedisConfig `mapstructure:"redis"`
|
||
Ops OpsConfig `mapstructure:"ops"`
|
||
JWT JWTConfig `mapstructure:"jwt"`
|
||
Totp TotpConfig `mapstructure:"totp"`
|
||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
|
||
Default DefaultConfig `mapstructure:"default"`
|
||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||
Pricing PricingConfig `mapstructure:"pricing"`
|
||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
|
||
SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"`
|
||
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||
Timezone string `mapstructure:"timezone"`
|
||
Sora SoraConfig `mapstructure:"sora"`
|
||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||
Update UpdateConfig `mapstructure:"update"`
|
||
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
|
||
}
|
||
|
||
func NormalizeRunMode(value string) string {
|
||
normalized := strings.ToLower(strings.TrimSpace(value))
|
||
switch normalized {
|
||
case RunModeStandard, RunModeSimple:
|
||
return normalized
|
||
default:
|
||
return RunModeStandard
|
||
}
|
||
}
|
||
|
||
// Load 读取并校验完整配置(要求 jwt.secret 已显式提供)。
|
||
func Load() (*Config, error) { return load(false) }
|
||
|
||
// LoadForBootstrap 读取启动阶段配置。允许 jwt.secret 先留空。
|
||
func LoadForBootstrap() (*Config, error) { return load(true) }
|
||
|
||
func load(allowMissingJWTSecret bool) (*Config, error) {
|
||
viper.SetConfigName("config")
|
||
viper.SetConfigType("yaml")
|
||
|
||
// Add config paths in priority order
|
||
if dataDir := os.Getenv("DATA_DIR"); dataDir != "" {
|
||
viper.AddConfigPath(dataDir)
|
||
}
|
||
viper.AddConfigPath("/app/data")
|
||
viper.AddConfigPath(".")
|
||
viper.AddConfigPath("./config")
|
||
viper.AddConfigPath("/etc/sub2api")
|
||
|
||
viper.AutomaticEnv()
|
||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||
setDefaults()
|
||
|
||
if err := viper.ReadInConfig(); err != nil {
|
||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||
return nil, fmt.Errorf("read config error: %w", err)
|
||
}
|
||
}
|
||
|
||
var cfg Config
|
||
if err := viper.Unmarshal(&cfg); err != nil {
|
||
return nil, fmt.Errorf("unmarshal config error: %w", err)
|
||
}
|
||
|
||
cfg.RunMode = NormalizeRunMode(cfg.RunMode)
|
||
cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
|
||
if cfg.Server.Mode == "" {
|
||
cfg.Server.Mode = "debug"
|
||
}
|
||
cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL)
|
||
normalizeAllStringFields(&cfg)
|
||
|
||
if err := loadCodexTemplate(&cfg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 兼容旧键 sticky_previous_response_ttl_seconds
|
||
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
|
||
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
||
}
|
||
|
||
// Normalize UMQ mode
|
||
if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle {
|
||
slog.Warn("invalid user_message_queue mode, disabling", "mode", m,
|
||
"valid_modes", []string{UMQModeSerialize, UMQModeThrottle})
|
||
cfg.Gateway.UserMessageQueue.Mode = ""
|
||
}
|
||
|
||
// Auto-generate TOTP encryption key if not set
|
||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||
if cfg.Totp.EncryptionKey == "" {
|
||
key, err := generateJWTSecret(32)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
|
||
}
|
||
cfg.Totp.EncryptionKey = key
|
||
cfg.Totp.EncryptionKeyConfigured = false
|
||
slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||
} else {
|
||
cfg.Totp.EncryptionKeyConfigured = true
|
||
}
|
||
|
||
originalJWTSecret := cfg.JWT.Secret
|
||
cfg.JWT.SecretConfigured = strings.TrimSpace(originalJWTSecret) != ""
|
||
if allowMissingJWTSecret && originalJWTSecret == "" {
|
||
cfg.JWT.Secret = strings.Repeat("0", 32)
|
||
}
|
||
|
||
if err := cfg.Validate(); err != nil {
|
||
return nil, fmt.Errorf("validate config error: %w", err)
|
||
}
|
||
|
||
if allowMissingJWTSecret && originalJWTSecret == "" {
|
||
cfg.JWT.Secret = ""
|
||
}
|
||
|
||
logSecurityWarnings(&cfg)
|
||
return &cfg, nil
|
||
}
|
||
|
||
// normalizeAllStringFields trims all string fields loaded from config.
|
||
func normalizeAllStringFields(cfg *Config) {
|
||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
|
||
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
|
||
cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL)
|
||
cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL)
|
||
cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL)
|
||
cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes)
|
||
cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL)
|
||
cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL)
|
||
cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod))
|
||
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
|
||
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
|
||
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
|
||
cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName)
|
||
cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID)
|
||
cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret)
|
||
cfg.OIDC.IssuerURL = strings.TrimSpace(cfg.OIDC.IssuerURL)
|
||
cfg.OIDC.DiscoveryURL = strings.TrimSpace(cfg.OIDC.DiscoveryURL)
|
||
cfg.OIDC.AuthorizeURL = strings.TrimSpace(cfg.OIDC.AuthorizeURL)
|
||
cfg.OIDC.TokenURL = strings.TrimSpace(cfg.OIDC.TokenURL)
|
||
cfg.OIDC.UserInfoURL = strings.TrimSpace(cfg.OIDC.UserInfoURL)
|
||
cfg.OIDC.JWKSURL = strings.TrimSpace(cfg.OIDC.JWKSURL)
|
||
cfg.OIDC.Scopes = strings.TrimSpace(cfg.OIDC.Scopes)
|
||
cfg.OIDC.RedirectURL = strings.TrimSpace(cfg.OIDC.RedirectURL)
|
||
cfg.OIDC.FrontendRedirectURL = strings.TrimSpace(cfg.OIDC.FrontendRedirectURL)
|
||
cfg.OIDC.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.OIDC.TokenAuthMethod))
|
||
cfg.OIDC.AllowedSigningAlgs = strings.TrimSpace(cfg.OIDC.AllowedSigningAlgs)
|
||
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
|
||
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
|
||
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
|
||
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
|
||
cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
|
||
cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format))
|
||
cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName)
|
||
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
|
||
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
|
||
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
|
||
cfg.Gateway.ForcedCodexInstructionsTemplateFile = strings.TrimSpace(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
|
||
}
|
||
|
||
func loadCodexTemplate(cfg *Config) error {
|
||
if cfg.Gateway.ForcedCodexInstructionsTemplateFile != "" {
|
||
content, err := os.ReadFile(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
|
||
if err != nil {
|
||
return fmt.Errorf("read forced codex instructions template %q: %w", cfg.Gateway.ForcedCodexInstructionsTemplateFile, err)
|
||
}
|
||
cfg.Gateway.ForcedCodexInstructionsTemplate = string(content)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func logSecurityWarnings(cfg *Config) {
|
||
if !cfg.Security.URLAllowlist.Enabled {
|
||
slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
|
||
}
|
||
if !cfg.Security.ResponseHeaders.Enabled {
|
||
slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
|
||
}
|
||
if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
|
||
slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.")
|
||
}
|
||
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
|
||
slog.Info("response header policy configured",
|
||
"additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed,
|
||
"force_remove", cfg.Security.ResponseHeaders.ForceRemove)
|
||
}
|
||
}
|