Files
tokens-reef/backend/internal/config/config.go
2026-04-24 08:32:16 +08:00

261 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package config provides configuration loading, defaults, and validation.
//
// Type definitions are organized by domain into companion files:
//
// - server.go — ServerConfig, H2CConfig, CORSConfig, ConcurrencyConfig
// - security.go — SecurityConfig, URLAllowlist, CSP, ResponseHeaders
// - database.go — DatabaseConfig, RedisConfig (with DSN helpers)
// - auth.go — JWTConfig, TotpConfig, TurnstileConfig, DefaultConfig, RateLimitConfig
// - billing.go — BillingConfig, PricingConfig
// - gateway.go — GatewayConfig, UserMessageQueue, SchedulingConfig
// - gateway_sub.go — OpenAIWS, UsageRecord, TLSFingerprint sub-structs
// - platforms.go — Sora, Gemini, LinuxDo, OIDC, Update, Idempotency configs
// - ops_and_cache.go — LogConfig, OpsConfig, Dashboard, Cache, Cleanup configs
package config
import (
	"errors"
	"fmt"
	"log/slog"
	"os"
	"strings"

	"github.com/spf13/viper"
)
// Run mode values. NormalizeRunMode returns one of these two strings and
// falls back to RunModeStandard for any input it does not recognize.
const (
RunModeStandard = "standard"
RunModeSimple = "simple"
)
// Overflow policies for the usage-record queue.
// NOTE(review): the behavior behind each policy is implemented outside this
// file; only the accepted string values are defined here.
const (
UsageRecordOverflowPolicyDrop = "drop"
UsageRecordOverflowPolicySample = "sample"
UsageRecordOverflowPolicySync = "sync"
)
// Mode constants for the UMQ (user message queue). load clears any configured
// mode that is non-empty and not one of these values.
const (
UMQModeSerialize = "serialize"
UMQModeThrottle = "throttle"
)
// Connection pool isolation strategy constants.
const (
ConnectionPoolIsolationProxy = "proxy"
ConnectionPoolIsolationAccount = "account"
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support.
//
// The script-src directive contains the literal token __CSP_NONCE__.
// NOTE(review): presumably the code that emits the header replaces that token
// with a per-response nonce — confirm against the consumer of this constant.
const DefaultCSPPolicy = `default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'`
// Config is the top-level application configuration.
//
// load fills it with viper.Unmarshal, so each field is populated from the
// configuration key named in its mapstructure tag. The section types are
// declared in the companion files listed in the package comment.
type Config struct {
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
// JWT.Secret may be empty after LoadForBootstrap; load records whether a
// secret was supplied in JWT.SecretConfigured.
JWT JWTConfig `mapstructure:"jwt"`
// Totp.EncryptionKey is auto-generated by load when it is not configured;
// Totp.EncryptionKeyConfigured records which case applied.
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
// RunMode is passed through NormalizeRunMode by load, so after loading it
// is always RunModeStandard or RunModeSimple.
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"`
Sora SoraConfig `mapstructure:"sora"`
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
}
func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
case RunModeStandard, RunModeSimple:
return normalized
default:
return RunModeStandard
}
}
// Load reads and validates the complete configuration (jwt.secret is expected
// to be provided explicitly). It calls load(false), so the bootstrap
// placeholder for an empty jwt.secret is never installed before validation.
// NOTE(review): the rejection of an empty secret presumably happens in
// Config.Validate, which is not visible here.
func Load() (*Config, error) { return load(false) }
// LoadForBootstrap reads the configuration for the bootstrap phase, during
// which jwt.secret is allowed to be left empty. It calls load(true); when the
// secret is empty the returned Config still carries an empty JWT secret.
func LoadForBootstrap() (*Config, error) { return load(true) }
// load builds the Config from the global viper instance and normalizes it.
//
// Steps, in order:
//  1. Register the config name/type, the search paths, environment-variable
//     overrides ("." in a key maps to "_") and the defaults from setDefaults.
//  2. Read the config file; a missing file is tolerated, any other read
//     failure is returned.
//  3. Unmarshal into Config and normalize run mode, server mode and the
//     string fields handled by normalizeAllStringFields.
//  4. Load the forced Codex instructions template, apply the legacy
//     sticky-response TTL key, and clear an unrecognized user-message-queue
//     mode (with a warning).
//  5. Generate a TOTP encryption key when none is configured.
//  6. Validate the result and log security warnings.
//
// When allowMissingJWTSecret is true and jwt.secret is empty, a placeholder
// secret is installed only for the duration of Validate and removed again, so
// the returned Config still has an empty JWT secret.
//
// On any failure it returns a nil Config and an error; errors from reading,
// unmarshalling, key generation and validation are wrapped with context.
func load(allowMissingJWTSecret bool) (*Config, error) {
	viper.SetConfigName("config")
	viper.SetConfigType("yaml")

	// Add config paths in priority order; DATA_DIR, when set, comes first.
	if dataDir := os.Getenv("DATA_DIR"); dataDir != "" {
		viper.AddConfigPath(dataDir)
	}
	viper.AddConfigPath("/app/data")
	viper.AddConfigPath(".")
	viper.AddConfigPath("./config")
	viper.AddConfigPath("/etc/sub2api")

	viper.AutomaticEnv()
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))

	setDefaults()

	if err := viper.ReadInConfig(); err != nil {
		// A missing config file is fine (defaults and environment still
		// apply). errors.As also recognizes the not-found error when it is
		// wrapped, which a direct type assertion would miss.
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			return nil, fmt.Errorf("read config error: %w", err)
		}
	}

	var cfg Config
	if err := viper.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("unmarshal config error: %w", err)
	}

	cfg.RunMode = NormalizeRunMode(cfg.RunMode)
	cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
	if cfg.Server.Mode == "" {
		cfg.Server.Mode = "debug"
	}
	cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL)
	normalizeAllStringFields(&cfg)

	if err := loadCodexTemplate(&cfg); err != nil {
		return nil, err
	}

	// Backward compatibility with the legacy key
	// sticky_previous_response_ttl_seconds: it is used only when the new key
	// has no positive value.
	if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
		cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
	}

	// Allowlist check for the UMQ mode: an unknown value is cleared, which
	// disables the feature, and reported once at load time.
	if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle {
		slog.Warn("invalid user_message_queue mode, disabling", "mode", m,
			"valid_modes", []string{UMQModeSerialize, UMQModeThrottle})
		cfg.Gateway.UserMessageQueue.Mode = ""
	}

	// Auto-generate TOTP encryption key if not set.
	cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
	if cfg.Totp.EncryptionKey == "" {
		key, err := generateJWTSecret(32)
		if err != nil {
			return nil, fmt.Errorf("generate totp encryption key error: %w", err)
		}
		cfg.Totp.EncryptionKey = key
		cfg.Totp.EncryptionKeyConfigured = false
		slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.")
	} else {
		cfg.Totp.EncryptionKeyConfigured = true
	}

	originalJWTSecret := cfg.JWT.Secret
	cfg.JWT.SecretConfigured = strings.TrimSpace(originalJWTSecret) != ""

	// During bootstrap an empty secret is temporarily replaced by a
	// placeholder so Validate can run; it is removed again right after.
	usePlaceholderSecret := allowMissingJWTSecret && originalJWTSecret == ""
	if usePlaceholderSecret {
		cfg.JWT.Secret = strings.Repeat("0", 32)
	}
	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("validate config error: %w", err)
	}
	if usePlaceholderSecret {
		cfg.JWT.Secret = ""
	}

	logSecurityWarnings(&cfg)
	return &cfg, nil
}
// normalizeAllStringFields cleans up the string settings listed below, in
// place. Most are only stripped of surrounding whitespace; the token auth
// methods and the log level, format and stacktrace level are additionally
// lower-cased; the CORS origin list and the two response-header lists are
// passed through normalizeStringSlice. Fields not listed here are left as
// loaded.
func normalizeAllStringFields(cfg *Config) {
	trim := strings.TrimSpace
	lower := func(s string) string { return strings.ToLower(strings.TrimSpace(s)) }

	cfg.JWT.Secret = trim(cfg.JWT.Secret)

	// LinuxDo connect settings.
	ld := &cfg.LinuxDo
	ld.ClientID = trim(ld.ClientID)
	ld.ClientSecret = trim(ld.ClientSecret)
	ld.AuthorizeURL = trim(ld.AuthorizeURL)
	ld.TokenURL = trim(ld.TokenURL)
	ld.UserInfoURL = trim(ld.UserInfoURL)
	ld.Scopes = trim(ld.Scopes)
	ld.RedirectURL = trim(ld.RedirectURL)
	ld.FrontendRedirectURL = trim(ld.FrontendRedirectURL)
	ld.TokenAuthMethod = lower(ld.TokenAuthMethod)
	ld.UserInfoEmailPath = trim(ld.UserInfoEmailPath)
	ld.UserInfoIDPath = trim(ld.UserInfoIDPath)
	ld.UserInfoUsernamePath = trim(ld.UserInfoUsernamePath)

	// OIDC connect settings.
	oidc := &cfg.OIDC
	oidc.ProviderName = trim(oidc.ProviderName)
	oidc.ClientID = trim(oidc.ClientID)
	oidc.ClientSecret = trim(oidc.ClientSecret)
	oidc.IssuerURL = trim(oidc.IssuerURL)
	oidc.DiscoveryURL = trim(oidc.DiscoveryURL)
	oidc.AuthorizeURL = trim(oidc.AuthorizeURL)
	oidc.TokenURL = trim(oidc.TokenURL)
	oidc.UserInfoURL = trim(oidc.UserInfoURL)
	oidc.JWKSURL = trim(oidc.JWKSURL)
	oidc.Scopes = trim(oidc.Scopes)
	oidc.RedirectURL = trim(oidc.RedirectURL)
	oidc.FrontendRedirectURL = trim(oidc.FrontendRedirectURL)
	oidc.TokenAuthMethod = lower(oidc.TokenAuthMethod)
	oidc.AllowedSigningAlgs = trim(oidc.AllowedSigningAlgs)
	oidc.UserInfoEmailPath = trim(oidc.UserInfoEmailPath)
	oidc.UserInfoIDPath = trim(oidc.UserInfoIDPath)
	oidc.UserInfoUsernamePath = trim(oidc.UserInfoUsernamePath)

	cfg.Dashboard.KeyPrefix = trim(cfg.Dashboard.KeyPrefix)

	// CORS and security settings.
	cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
	cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
	cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
	cfg.Security.CSP.Policy = trim(cfg.Security.CSP.Policy)

	// Logging settings.
	lg := &cfg.Log
	lg.Level = lower(lg.Level)
	lg.Format = lower(lg.Format)
	lg.ServiceName = trim(lg.ServiceName)
	lg.Environment = trim(lg.Environment)
	lg.StacktraceLevel = lower(lg.StacktraceLevel)
	lg.Output.FilePath = trim(lg.Output.FilePath)

	cfg.Gateway.ForcedCodexInstructionsTemplateFile = trim(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
}
// loadCodexTemplate reads the file named by
// cfg.Gateway.ForcedCodexInstructionsTemplateFile and stores its contents in
// cfg.Gateway.ForcedCodexInstructionsTemplate. An empty path is a no-op. A
// read failure is returned wrapped, with the offending path quoted in the
// message.
func loadCodexTemplate(cfg *Config) error {
	path := cfg.Gateway.ForcedCodexInstructionsTemplateFile
	if path == "" {
		return nil
	}

	raw, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("read forced codex instructions template %q: %w", path, err)
	}

	cfg.Gateway.ForcedCodexInstructionsTemplate = string(raw)
	return nil
}
// logSecurityWarnings reports security-relevant settings through slog. It
// warns when the URL allowlist is disabled, when response-header filtering is
// disabled, and when a non-empty JWT secret is judged weak by
// isWeakJWTSecret. If any additional-allowed or force-remove headers are
// configured, it also logs them at info level. It does not modify cfg.
func logSecurityWarnings(cfg *Config) {
	headers := cfg.Security.ResponseHeaders

	allowlistDisabled := !cfg.Security.URLAllowlist.Enabled
	if allowlistDisabled {
		slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
	}

	if !headers.Enabled {
		slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
	}

	// An empty secret (possible after bootstrap loading) is not reported as weak.
	if secret := cfg.JWT.Secret; secret != "" && isWeakJWTSecret(secret) {
		slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.")
	}

	extra, removed := headers.AdditionalAllowed, headers.ForceRemove
	if len(extra) == 0 && len(removed) == 0 {
		return
	}
	slog.Info("response header policy configured",
		"additional_allowed", extra,
		"force_remove", removed)
}