Files
lijiaoqiao/supply-api/internal/config/config.go

250 lines
7.1 KiB
Go
Raw Normal View History

package config
import (
	"fmt"
	"net"
	"net/url"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/spf13/viper"
)
// Config is the root application configuration. It groups one settings
// section per subsystem and is filled in by Load.
type Config struct {
	Server   ServerConfig   // HTTP server settings
	Database DatabaseConfig // PostgreSQL connection settings
	Redis    RedisConfig    // Redis connection settings
	Token    TokenConfig    // token runtime settings
	Audit    AuditConfig    // audit pipeline settings
}
// ServerConfig holds the HTTP service settings.
// Default values are registered in setDefaults.
type ServerConfig struct {
	Addr            string        // listen address, e.g. ":18082"
	ReadTimeout     time.Duration // read timeout
	WriteTimeout    time.Duration // write timeout
	IdleTimeout     time.Duration // idle (keep-alive) timeout
	ShutdownTimeout time.Duration // time budget for shutdown
}
// DatabaseConfig holds the PostgreSQL connection and pool settings.
// Default values are registered in setDefaults.
type DatabaseConfig struct {
	Host            string        // database host name or IP
	Port            int           // database TCP port
	User            string        // login user
	Password        string        // login password (secret; never log it)
	Database        string        // database name
	MaxOpenConns    int           // upper bound on open connections
	MaxIdleConns    int           // upper bound on idle connections
	ConnMaxLifetime time.Duration // maximum lifetime of a connection
	ConnMaxIdleTime time.Duration // maximum idle time of a connection
}
// RedisConfig holds the Redis connection settings.
// Default values are registered in setDefaults.
type RedisConfig struct {
	Host     string // Redis host name or IP
	Port     int    // Redis TCP port
	Password string // auth password (secret; never log it)
	DB       int    // Redis database number
	PoolSize int    // connection pool size
}
// TokenConfig holds the token runtime settings.
//
// SecretKey has no built-in default: it stays empty unless supplied by the
// config file or an environment variable (see bindEnvVars).
type TokenConfig struct {
	SecretKey          string        // secret key material (never log it)
	Issuer             string        // issuer name
	AccessTokenTTL     time.Duration // lifetime of access tokens
	RefreshTokenTTL    time.Duration // lifetime of refresh tokens
	RevocationCacheTTL time.Duration // TTL for cached revocation state
}
// AuditConfig holds the audit pipeline settings.
// Default values are registered in setDefaults.
type AuditConfig struct {
	BufferSize    int           // buffer capacity
	FlushInterval time.Duration // interval between flushes
	ExportTimeout time.Duration // timeout for an export
}
// DSN returns the PostgreSQL connection URL for this configuration
// (sslmode=disable).
//
// The result contains the plaintext password and is for internal use only
// (e.g. handing to the driver); use SafeDSN for anything that gets logged.
//
// The URL is assembled with net/url, so the user, password and database name
// are percent-encoded. A password containing '@', ':', '/', '#', '?' or '%'
// therefore no longer corrupts the URL. Host and port are joined with
// net.JoinHostPort, so IPv6 literals come out bracketed.
func (d *DatabaseConfig) DSN() string {
	return d.dsnURL(url.UserPassword(d.User, d.Password)).String()
}

// SafeDSN returns the same connection URL as DSN but with the password
// replaced by "***", for use in log output.
// P2-05: avoid leaking the database password into logs.
func (d *DatabaseConfig) SafeDSN() string {
	const prefix = "postgres://"
	// Build the URL without credentials and splice the masked userinfo in by
	// hand. url.UserPassword would percent-encode the '*' mask characters.
	rest := strings.TrimPrefix(d.dsnURL(nil).String(), prefix)
	return prefix + url.User(d.User).String() + ":***@" + rest
}

// dsnURL builds the postgres:// URL shared by DSN and SafeDSN, attaching the
// given userinfo (nil for none).
func (d *DatabaseConfig) dsnURL(user *url.Userinfo) *url.URL {
	// Strip brackets an operator may already have put around an IPv6 literal
	// so JoinHostPort does not double them.
	host := strings.Trim(d.Host, "[]")
	return &url.URL{
		Scheme:   "postgres",
		User:     user,
		Host:     net.JoinHostPort(host, strconv.Itoa(d.Port)),
		Path:     "/" + d.Database,
		RawQuery: "sslmode=disable",
	}
}
// Addr returns the Redis "host:port" address.
//
// Host and port are joined with net.JoinHostPort, so an IPv6 literal comes out
// bracketed ("[::1]:6379") instead of the ambiguous "::1:6379". Brackets
// already present in Host are stripped first so they are not doubled.
func (r *RedisConfig) Addr() string {
	return net.JoinHostPort(strings.Trim(r.Host, "[]"), strconv.Itoa(r.Port))
}
// Load builds the application configuration for environment env.
//
// Values come from viper: environment variables, the optional YAML file
// "config.<env>.yaml" (searched in "." then "./config"), and the defaults
// registered by setDefaults. Environment variables are matched automatically
// as SUPPLY_API_<KEY> with '.' replaced by '_' (e.g. SUPPLY_API_SERVER_ADDR).
// The explicit aliases registered in bindEnvVars (e.g. SUPPLY_DB_HOST) are
// also honored.
//
// A missing config file is not an error. Any other read/parse failure is
// returned wrapped.
func Load(env string) (*Config, error) {
	vp := viper.New()

	// Environment lookup: SUPPLY_API_ prefix, "." in keys maps to "_".
	vp.SetEnvPrefix("SUPPLY_API")
	vp.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))

	setDefaults(vp)

	// NOTE(review): viper documents SetConfigName as taking a name WITHOUT an
	// extension. Passing "config.<env>.yaml" relies on viper also matching the
	// exact file name when a config type is set. Confirm against the vendored
	// viper version.
	vp.SetConfigName(fmt.Sprintf("config.%s.yaml", env))
	vp.SetConfigType("yaml")
	vp.AddConfigPath(".")
	vp.AddConfigPath("./config")

	vp.AutomaticEnv()

	if err := vp.ReadInConfig(); err != nil {
		if _, missing := err.(viper.ConfigFileNotFoundError); !missing {
			return nil, fmt.Errorf("failed to read config: %w", err)
		}
		// No config file: carry on with defaults plus environment variables.
	}

	bindEnvVars(vp)

	cfg := &Config{
		Server: ServerConfig{
			Addr:            vp.GetString("server.addr"),
			ReadTimeout:     vp.GetDuration("server.read_timeout"),
			WriteTimeout:    vp.GetDuration("server.write_timeout"),
			IdleTimeout:     vp.GetDuration("server.idle_timeout"),
			ShutdownTimeout: vp.GetDuration("server.shutdown_timeout"),
		},
		Database: DatabaseConfig{
			Host:            vp.GetString("database.host"),
			Port:            vp.GetInt("database.port"),
			User:            vp.GetString("database.user"),
			Password:        vp.GetString("database.password"),
			Database:        vp.GetString("database.database"),
			MaxOpenConns:    vp.GetInt("database.max_open_conns"),
			MaxIdleConns:    vp.GetInt("database.max_idle_conns"),
			ConnMaxLifetime: vp.GetDuration("database.conn_max_lifetime"),
			ConnMaxIdleTime: vp.GetDuration("database.conn_max_idle_time"),
		},
		Redis: RedisConfig{
			Host:     vp.GetString("redis.host"),
			Port:     vp.GetInt("redis.port"),
			Password: vp.GetString("redis.password"),
			DB:       vp.GetInt("redis.db"),
			PoolSize: vp.GetInt("redis.pool_size"),
		},
		Token: TokenConfig{
			SecretKey:          vp.GetString("token.secret_key"),
			Issuer:             vp.GetString("token.issuer"),
			AccessTokenTTL:     vp.GetDuration("token.access_token_ttl"),
			RefreshTokenTTL:    vp.GetDuration("token.refresh_token_ttl"),
			RevocationCacheTTL: vp.GetDuration("token.revocation_cache_ttl"),
		},
		Audit: AuditConfig{
			BufferSize:    vp.GetInt("audit.buffer_size"),
			FlushInterval: vp.GetDuration("audit.flush_interval"),
			ExportTimeout: vp.GetDuration("audit.export_timeout"),
		},
	}
	return cfg, nil
}
// setDefaults registers built-in fallback values on v. Values from the config
// file or the environment override them.
//
// token.secret_key deliberately has no entry here, so it stays empty unless
// provided externally.
func setDefaults(v *viper.Viper) {
	defaults := []struct {
		key   string
		value interface{}
	}{
		// Server
		{"server.addr", ":18082"},
		{"server.read_timeout", 10 * time.Second},
		{"server.write_timeout", 15 * time.Second},
		{"server.idle_timeout", 30 * time.Second},
		{"server.shutdown_timeout", 5 * time.Second},

		// Database
		{"database.host", "localhost"},
		{"database.port", 5432},
		{"database.user", "postgres"},
		{"database.password", ""},
		{"database.database", "supply_db"},
		{"database.max_open_conns", 25},
		{"database.max_idle_conns", 5},
		{"database.conn_max_lifetime", 1 * time.Hour},
		{"database.conn_max_idle_time", 10 * time.Minute},

		// Redis
		{"redis.host", "localhost"},
		{"redis.port", 6379},
		{"redis.password", ""},
		{"redis.db", 0},
		{"redis.pool_size", 10},

		// Token
		{"token.issuer", "lijiaoqiao/supply-api"},
		{"token.access_token_ttl", 1 * time.Hour},
		{"token.refresh_token_ttl", 7 * 24 * time.Hour},
		{"token.revocation_cache_ttl", 30 * time.Second},

		// Audit
		{"audit.buffer_size", 1000},
		{"audit.flush_interval", 5 * time.Second},
		{"audit.export_timeout", 30 * time.Second},
	}
	for _, d := range defaults {
		v.SetDefault(d.key, d.value)
	}
}
// bindEnvVars registers short environment-variable aliases for selected keys,
// in addition to the automatic SUPPLY_API_<KEY> lookup enabled in Load.
// For example, database.host can also be set via SUPPLY_DB_HOST.
func bindEnvVars(v *viper.Viper) {
	aliases := [][2]string{
		{"server.addr", "SUPPLY_API_ADDR"},
		{"server.read_timeout", "SUPPLY_API_READ_TIMEOUT"},
		{"server.write_timeout", "SUPPLY_API_WRITE_TIMEOUT"},
		{"database.host", "SUPPLY_DB_HOST"},
		{"database.port", "SUPPLY_DB_PORT"},
		{"database.user", "SUPPLY_DB_USER"},
		{"database.password", "SUPPLY_DB_PASSWORD"},
		{"database.database", "SUPPLY_DB_NAME"},
		{"database.max_open_conns", "SUPPLY_DB_MAX_OPEN_CONNS"},
		{"database.max_idle_conns", "SUPPLY_DB_MAX_IDLE_CONNS"},
		{"redis.host", "SUPPLY_REDIS_HOST"},
		{"redis.port", "SUPPLY_REDIS_PORT"},
		{"redis.password", "SUPPLY_REDIS_PASSWORD"},
		{"redis.db", "SUPPLY_REDIS_DB"},
		{"token.secret_key", "SUPPLY_TOKEN_SECRET_KEY"},
	}
	for _, a := range aliases {
		// viper's BindEnv only reports an error when called with no key at
		// all (per its implementation), which cannot happen here.
		_ = v.BindEnv(a[0], a[1])
	}
}
// MustLoad is like Load but panics instead of returning an error. The panic
// value is the string "failed to load config: " followed by the error text.
func MustLoad(env string) *Config {
	cfg, err := Load(env)
	if err == nil {
		return cfg
	}
	panic("failed to load config: " + err.Error())
}
// GetEnvInt returns environment variable key parsed as a base-10 int.
// It returns defaultVal when the variable is unset or empty, or when it does
// not parse; parse errors are silently ignored.
func GetEnvInt(key string, defaultVal int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return defaultVal
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		return defaultVal
	}
	return n
}
// GetEnvDuration returns environment variable key parsed with
// time.ParseDuration (e.g. "1500ms", "2h"). It returns defaultVal when the
// variable is unset or empty, or when it does not parse (for example a bare
// number with no unit); parse errors are silently ignored.
func GetEnvDuration(key string, defaultVal time.Duration) time.Duration {
	raw := os.Getenv(key)
	if raw == "" {
		return defaultVal
	}
	d, err := time.ParseDuration(raw)
	if err != nil {
		return defaultVal
	}
	return d
}