139 lines
4.2 KiB
Go
139 lines
4.2 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// Config 是应用配置结构
|
|
type Config struct {
|
|
Server ServerConfig `mapstructure:"server"`
|
|
Database DatabaseConfig `mapstructure:"database"`
|
|
Redis RedisConfig `mapstructure:"redis"`
|
|
Metrics MetricsConfig `mapstructure:"metrics"`
|
|
}
|
|
|
|
type ServerConfig struct {
|
|
Port int `mapstructure:"port"`
|
|
Mode string `mapstructure:"mode"` // development / production
|
|
JWTSecret string `mapstructure:"jwt_secret"`
|
|
MetricsAuth string `mapstructure:"metrics_auth"` // API Key for /metrics
|
|
}
|
|
|
|
type DatabaseConfig struct {
|
|
Host string `mapstructure:"host"`
|
|
Port int `mapstructure:"port"`
|
|
User string `mapstructure:"user"`
|
|
Password string `mapstructure:"password"`
|
|
DBName string `mapstructure:"dbname"`
|
|
SSLMode string `mapstructure:"sslmode"`
|
|
PoolSize int `mapstructure:"pool_size"`
|
|
}
|
|
|
|
type RedisConfig struct {
|
|
Host string `mapstructure:"host"`
|
|
Port int `mapstructure:"port"`
|
|
Password string `mapstructure:"password"`
|
|
DB int `mapstructure:"db"`
|
|
}
|
|
|
|
type MetricsConfig struct {
|
|
PrometheusURL string `mapstructure:"prometheus_url"`
|
|
RetentionDays int `mapstructure:"retention_days"`
|
|
}
|
|
|
|
// Load 从配置文件和环境变量加载配置
|
|
func Load(path string) (*Config, error) {
|
|
v := viper.New()
|
|
v.SetConfigFile(path)
|
|
v.SetEnvPrefix("AI_OPS")
|
|
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
|
v.AutomaticEnv()
|
|
|
|
// 默认值
|
|
v.SetDefault("server.port", 8080)
|
|
v.SetDefault("server.mode", "development")
|
|
v.SetDefault("database.host", "localhost")
|
|
v.SetDefault("database.port", 5432)
|
|
v.SetDefault("database.sslmode", "disable")
|
|
v.SetDefault("database.pool_size", 10)
|
|
v.SetDefault("redis.host", "localhost")
|
|
v.SetDefault("redis.port", 6379)
|
|
v.SetDefault("metrics.retention_days", 7)
|
|
|
|
if err := v.ReadInConfig(); err != nil {
|
|
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
|
return nil, fmt.Errorf("read config: %w", err)
|
|
}
|
|
}
|
|
|
|
var cfg Config
|
|
if err := v.Unmarshal(&cfg); err != nil {
|
|
return nil, fmt.Errorf("unmarshal config: %w", err)
|
|
}
|
|
|
|
// 环境变量覆盖
|
|
if host := os.Getenv("SPRING_DATASOURCE_URL"); host != "" {
|
|
// 兼容 Spring Boot 风格的数据库配置
|
|
cfg.Database.Host = host
|
|
}
|
|
applyExplicitEnvOverrides(&cfg)
|
|
if err := cfg.Validate(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &cfg, nil
|
|
}
|
|
|
|
func applyExplicitEnvOverrides(cfg *Config) {
|
|
setString := func(key string, dst *string) {
|
|
if v := os.Getenv(key); v != "" {
|
|
*dst = v
|
|
}
|
|
}
|
|
setString("AI_OPS_SERVER_JWT_SECRET", &cfg.Server.JWTSecret)
|
|
setString("AI_OPS_SERVER_METRICS_AUTH", &cfg.Server.MetricsAuth)
|
|
setString("AI_OPS_DATABASE_HOST", &cfg.Database.Host)
|
|
setString("AI_OPS_DATABASE_USER", &cfg.Database.User)
|
|
setString("AI_OPS_DATABASE_PASSWORD", &cfg.Database.Password)
|
|
setString("AI_OPS_DATABASE_DBNAME", &cfg.Database.DBName)
|
|
setString("AI_OPS_REDIS_HOST", &cfg.Redis.Host)
|
|
setString("AI_OPS_REDIS_PASSWORD", &cfg.Redis.Password)
|
|
}
|
|
|
|
func (c *Config) Validate() error {
|
|
if c.Server.Port <= 0 || c.Server.Port > 65535 {
|
|
return fmt.Errorf("invalid server.port: %d", c.Server.Port)
|
|
}
|
|
if c.Database.Port <= 0 || c.Database.Port > 65535 {
|
|
return fmt.Errorf("invalid database.port: %d", c.Database.Port)
|
|
}
|
|
if c.Database.PoolSize <= 0 {
|
|
return fmt.Errorf("invalid database.pool_size: %d", c.Database.PoolSize)
|
|
}
|
|
if c.Metrics.RetentionDays <= 0 {
|
|
return fmt.Errorf("invalid metrics.retention_days: %d", c.Metrics.RetentionDays)
|
|
}
|
|
if strings.EqualFold(c.Server.Mode, "production") {
|
|
if len(c.Server.JWTSecret) < 32 {
|
|
return fmt.Errorf("server.jwt_secret must be at least 32 characters in production")
|
|
}
|
|
if len(c.Server.MetricsAuth) < 16 {
|
|
return fmt.Errorf("server.metrics_auth must be at least 16 characters in production")
|
|
}
|
|
if c.Database.Host == "" || c.Database.User == "" || c.Database.Password == "" || c.Database.DBName == "" {
|
|
return fmt.Errorf("database host/user/password/dbname are required in production")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DSN 返回 PostgreSQL 连接字符串
|
|
func (c DatabaseConfig) DSN() string {
|
|
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s pool_max_conns=%d",
|
|
c.Host, c.Port, c.User, c.Password, c.DBName, c.SSLMode, c.PoolSize)
|
|
}
|