Files
ai-ops/internal/config/config.go
2026-05-12 17:48:22 +08:00

139 lines
4.2 KiB
Go

package config
import (
	"errors"
	"fmt"
	"net/url"
	"os"
	"strconv"
	"strings"

	"github.com/spf13/viper"
)
// Config is the root application configuration, filled in by Load through
// viper's Unmarshal. Each field holds the config section named by its
// mapstructure tag.
type Config struct {
	Server   ServerConfig   `mapstructure:"server"`
	Database DatabaseConfig `mapstructure:"database"`
	Redis    RedisConfig    `mapstructure:"redis"`
	Metrics  MetricsConfig  `mapstructure:"metrics"`
}
// ServerConfig holds the settings under the "server" key.
type ServerConfig struct {
	// Port is the server port. Load defaults it to 8080; Validate requires a
	// value in 1..65535.
	Port int `mapstructure:"port"`

	// Mode is expected to be "development" or "production"; Load defaults it
	// to "development". Validate runs its stricter production checks when
	// Mode equals "production" (case-insensitively).
	Mode string `mapstructure:"mode"`

	// JWTSecret is presumably the JWT signing key (inferred from the name).
	// Validate requires at least 32 bytes in production.
	JWTSecret string `mapstructure:"jwt_secret"`

	// MetricsAuth is the API key for /metrics. Validate requires at least
	// 16 bytes in production.
	MetricsAuth string `mapstructure:"metrics_auth"`
}
// DatabaseConfig holds the PostgreSQL connection settings under the
// "database" key. DSN renders them as a connection string.
type DatabaseConfig struct {
	// Host is the database host; Load defaults it to "localhost".
	Host string `mapstructure:"host"`

	// Port defaults to 5432; Validate requires a value in 1..65535.
	Port int `mapstructure:"port"`

	// User is the database user name; must be non-empty in production.
	User string `mapstructure:"user"`

	// Password must be non-empty in production; may be empty otherwise.
	Password string `mapstructure:"password"`

	// DBName is the database name; must be non-empty in production.
	DBName string `mapstructure:"dbname"`

	// SSLMode is passed through as the sslmode keyword; defaults to "disable".
	SSLMode string `mapstructure:"sslmode"`

	// PoolSize is emitted by DSN as pool_max_conns; defaults to 10 and
	// Validate requires it to be positive.
	PoolSize int `mapstructure:"pool_size"`
}
// RedisConfig holds the Redis connection settings under the "redis" key.
type RedisConfig struct {
	// Host defaults to "localhost".
	Host string `mapstructure:"host"`

	// Port defaults to 6379.
	Port int `mapstructure:"port"`

	// Password for the Redis server; may be empty.
	Password string `mapstructure:"password"`

	// DB is presumably the Redis logical database index — confirm against
	// the Redis client setup.
	DB int `mapstructure:"db"`
}
// MetricsConfig holds the settings under the "metrics" key.
type MetricsConfig struct {
	// PrometheusURL presumably points at the Prometheus server to query
	// (inferred from the name; it is not read anywhere in this file).
	PrometheusURL string `mapstructure:"prometheus_url"`

	// RetentionDays defaults to 7; Validate requires it to be positive.
	RetentionDays int `mapstructure:"retention_days"`
}
// Load builds the application configuration from the file at path and the
// process environment, then validates it.
//
// Values are layered, later layers winning:
//  1. the defaults set below;
//  2. the config file at path. A missing file is not an error, so the
//     service can run on defaults and environment variables alone; any other
//     read or parse error is returned;
//  3. AI_OPS_* environment variables, named after the key upper-cased with
//     "." replaced by "_" (database.pool_size -> AI_OPS_DATABASE_POOL_SIZE);
//  4. SPRING_DATASOURCE_URL, which supplies database host, port and dbname;
//  5. applyExplicitEnvOverrides, which re-applies a fixed list of AI_OPS_*
//     string variables (so AI_OPS_DATABASE_HOST and AI_OPS_DATABASE_DBNAME
//     beat step 4, while a port in the Spring URL beats AI_OPS_DATABASE_PORT).
//
// Errors from Validate are returned unwrapped.
func Load(path string) (*Config, error) {
	v := viper.New()
	v.SetConfigFile(path)
	v.SetEnvPrefix("AI_OPS")
	v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	v.AutomaticEnv()

	// Every key gets a default, even if only the zero value. During Unmarshal
	// viper's AutomaticEnv consults the environment only for keys it already
	// knows (from defaults or the file); without these entries variables such
	// as AI_OPS_REDIS_DB would be ignored whenever the file omits the key.
	v.SetDefault("server.port", 8080)
	v.SetDefault("server.mode", "development")
	v.SetDefault("server.jwt_secret", "")
	v.SetDefault("server.metrics_auth", "")
	v.SetDefault("database.host", "localhost")
	v.SetDefault("database.port", 5432)
	v.SetDefault("database.user", "")
	v.SetDefault("database.password", "")
	v.SetDefault("database.dbname", "")
	v.SetDefault("database.sslmode", "disable")
	v.SetDefault("database.pool_size", 10)
	v.SetDefault("redis.host", "localhost")
	v.SetDefault("redis.port", 6379)
	v.SetDefault("redis.password", "")
	v.SetDefault("redis.db", 0)
	v.SetDefault("metrics.prometheus_url", "")
	v.SetDefault("metrics.retention_days", 7)

	if err := v.ReadInConfig(); err != nil && !isConfigFileMissing(err) {
		return nil, fmt.Errorf("read config: %w", err)
	}

	var cfg Config
	if err := v.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("unmarshal config: %w", err)
	}

	// Compatibility with Spring Boot style deployments that describe the
	// database as a single JDBC URL.
	if raw := os.Getenv("SPRING_DATASOURCE_URL"); raw != "" {
		if err := applySpringDatasourceURL(raw, &cfg); err != nil {
			return nil, err
		}
	}
	applyExplicitEnvOverrides(&cfg)

	if err := cfg.Validate(); err != nil {
		return nil, err
	}
	return &cfg, nil
}

// isConfigFileMissing reports whether an error from ReadInConfig only means
// "there is no config file". viper returns ConfigFileNotFoundError when it
// searches config paths, but with SetConfigFile a missing file surfaces as
// an os-level not-exist error instead, so both must be recognised.
func isConfigFileMissing(err error) bool {
	var notFound viper.ConfigFileNotFoundError
	return errors.As(err, &notFound) || errors.Is(err, os.ErrNotExist)
}

// applySpringDatasourceURL copies host, port and database name from a Spring
// Boot style datasource URL such as "jdbc:postgresql://db:5432/ai_ops" into
// cfg.Database. Components missing from the URL leave the current values
// untouched; credentials and query parameters in the URL are ignored.
//
// A value without "://" is not a URL and is used verbatim as the host name.
func applySpringDatasourceURL(raw string, cfg *Config) error {
	rest := strings.TrimPrefix(raw, "jdbc:")
	if !strings.Contains(rest, "://") {
		cfg.Database.Host = raw
		return nil
	}

	u, err := url.Parse(rest)
	if err != nil {
		// url.Error embeds the whole URL, which may carry a password; report
		// only the underlying cause.
		var uerr *url.Error
		if errors.As(err, &uerr) {
			err = uerr.Err
		}
		return fmt.Errorf("parse SPRING_DATASOURCE_URL: %w", err)
	}

	if host := u.Hostname(); host != "" {
		cfg.Database.Host = host
	}
	if p := u.Port(); p != "" {
		port, err := strconv.Atoi(p)
		if err != nil {
			return fmt.Errorf("parse SPRING_DATASOURCE_URL port %q: %w", p, err)
		}
		cfg.Database.Port = port
	}
	if name := strings.TrimPrefix(u.Path, "/"); name != "" {
		cfg.Database.DBName = name
	}
	return nil
}
// applyExplicitEnvOverrides copies a fixed set of AI_OPS_* environment
// variables onto cfg. An unset or empty variable leaves the current value
// untouched; a non-empty one replaces it unconditionally, so these variables
// take precedence over the config file.
//
// Reading the variables directly covers keys that viper's AutomaticEnv skips
// during Unmarshal when they have no default and are absent from the file.
func applyExplicitEnvOverrides(cfg *Config) {
	overrides := []struct {
		env string
		dst *string
	}{
		{"AI_OPS_SERVER_JWT_SECRET", &cfg.Server.JWTSecret},
		{"AI_OPS_SERVER_METRICS_AUTH", &cfg.Server.MetricsAuth},
		{"AI_OPS_DATABASE_HOST", &cfg.Database.Host},
		{"AI_OPS_DATABASE_USER", &cfg.Database.User},
		{"AI_OPS_DATABASE_PASSWORD", &cfg.Database.Password},
		{"AI_OPS_DATABASE_DBNAME", &cfg.Database.DBName},
		{"AI_OPS_REDIS_HOST", &cfg.Redis.Host},
		{"AI_OPS_REDIS_PASSWORD", &cfg.Redis.Password},
		// Listed so the Prometheus URL can come purely from the environment,
		// like the other string settings above.
		{"AI_OPS_METRICS_PROMETHEUS_URL", &cfg.Metrics.PrometheusURL},
	}
	for _, o := range overrides {
		if val := os.Getenv(o.env); val != "" {
			*o.dst = val
		}
	}
}
// Validate reports the first configuration problem it finds, or nil.
//
// Checks that always apply:
//   - server.port, database.port and redis.port are within 1..65535;
//   - database.pool_size and metrics.retention_days are positive.
//
// When server.mode equals "production" (case-insensitively) it also requires:
//   - server.jwt_secret of at least 32 bytes;
//   - server.metrics_auth of at least 16 bytes;
//   - non-empty database host, user, password and dbname.
//
// Lengths are measured with len, i.e. in bytes rather than runes.
func (c *Config) Validate() error {
	switch {
	case c.Server.Port <= 0 || c.Server.Port > 65535:
		return fmt.Errorf("invalid server.port: %d", c.Server.Port)
	case c.Database.Port <= 0 || c.Database.Port > 65535:
		return fmt.Errorf("invalid database.port: %d", c.Database.Port)
	case c.Redis.Port <= 0 || c.Redis.Port > 65535:
		// Checked like the other ports so a bad value fails at startup rather
		// than at the first Redis dial.
		return fmt.Errorf("invalid redis.port: %d", c.Redis.Port)
	case c.Database.PoolSize <= 0:
		return fmt.Errorf("invalid database.pool_size: %d", c.Database.PoolSize)
	case c.Metrics.RetentionDays <= 0:
		return fmt.Errorf("invalid metrics.retention_days: %d", c.Metrics.RetentionDays)
	}

	if !strings.EqualFold(c.Server.Mode, "production") {
		return nil
	}

	switch {
	case len(c.Server.JWTSecret) < 32:
		return fmt.Errorf("server.jwt_secret must be at least 32 characters in production")
	case len(c.Server.MetricsAuth) < 16:
		return fmt.Errorf("server.metrics_auth must be at least 16 characters in production")
	case c.Database.Host == "" || c.Database.User == "" || c.Database.Password == "" || c.Database.DBName == "":
		return fmt.Errorf("database host/user/password/dbname are required in production")
	}
	return nil
}
// DSN returns a PostgreSQL connection string in libpq keyword/value form.
//
// Every string value is single-quoted with backslash escaping. This keeps
// empty values (e.g. no password in development) and values containing
// spaces, quotes or backslashes intact; unquoted, an empty password would
// swallow the following "dbname=..." pair. PoolSize is emitted as
// pool_max_conns. NOTE(review): that keyword is a pgxpool extension rather
// than a libpq one — confirm the driver in use accepts it.
func (c DatabaseConfig) DSN() string {
	return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s pool_max_conns=%d",
		quoteDSNValue(c.Host), c.Port, quoteDSNValue(c.User), quoteDSNValue(c.Password),
		quoteDSNValue(c.DBName), quoteDSNValue(c.SSLMode), c.PoolSize)
}

// dsnValueEscaper escapes the two characters that are special inside a
// single-quoted libpq connection-string value: backslash and single quote.
var dsnValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// quoteDSNValue renders s as a single-quoted libpq keyword/value token.
func quoteDSNValue(s string) string {
	return "'" + dsnValueEscaper.Replace(s) + "'"
}