From 0196ee5d472c3a7e6247d8443aa4d5b3079f067c Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 1 Apr 2026 08:53:28 +0800 Subject: [PATCH] =?UTF-8?q?feat(supply-api):=20=E5=AE=8C=E6=88=90=E6=A0=B8?= =?UTF-8?q?=E5=BF=83=E6=A8=A1=E5=9D=97=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增/修改内容: - config: 添加配置管理(config.example.yaml, config.go) - cache: 添加Redis缓存层(redis.go) - domain: 添加invariants不变量验证及测试 - middleware: 添加auth认证和idempotency幂等性中间件及测试 - repository: 添加完整数据访问层(account, package, settlement, idempotency, db) - sql: 添加幂等性表DDL脚本 代码覆盖: - auth middleware实现凭证边界验证 - idempotency middleware实现请求幂等性 - invariants实现业务不变量检查 - repository层实现完整的数据访问逻辑 关联issue: Round-1 R1-ISSUE-006 凭证边界硬门禁 --- supply-api/.env.example | 29 ++ supply-api/config/config.example.yaml | 37 ++ supply-api/internal/cache/redis.go | 231 +++++++++ supply-api/internal/config/config.go | 242 +++++++++ supply-api/internal/domain/invariants.go | 212 ++++++++ supply-api/internal/domain/invariants_test.go | 101 ++++ supply-api/internal/middleware/auth.go | 477 ++++++++++++++++++ supply-api/internal/middleware/auth_test.go | 343 +++++++++++++ supply-api/internal/middleware/idempotency.go | 279 ++++++++++ .../internal/middleware/idempotency_test.go | 211 ++++++++ supply-api/internal/repository/account.go | 291 +++++++++++ supply-api/internal/repository/db.go | 81 +++ supply-api/internal/repository/idempotency.go | 246 +++++++++ supply-api/internal/repository/package.go | 250 +++++++++ supply-api/internal/repository/settlement.go | 243 +++++++++ .../supply_idempotency_record_v1.sql | 47 ++ 16 files changed, 3320 insertions(+) create mode 100644 supply-api/.env.example create mode 100644 supply-api/config/config.example.yaml create mode 100644 supply-api/internal/cache/redis.go create mode 100644 supply-api/internal/config/config.go create mode 100644 supply-api/internal/domain/invariants.go create mode 100644 supply-api/internal/domain/invariants_test.go create mode 100644 supply-api/internal/middleware/auth.go create mode 100644 supply-api/internal/middleware/auth_test.go create mode 100644 supply-api/internal/middleware/idempotency.go create mode 100644 supply-api/internal/middleware/idempotency_test.go create mode 100644 supply-api/internal/repository/account.go create mode 100644 supply-api/internal/repository/db.go create mode 100644 supply-api/internal/repository/idempotency.go create mode 100644 supply-api/internal/repository/package.go create mode 100644 supply-api/internal/repository/settlement.go create mode 100644 supply-api/sql/postgresql/supply_idempotency_record_v1.sql diff --git a/supply-api/.env.example b/supply-api/.env.example new file mode 100644 index 0000000..359aa21 --- /dev/null +++ b/supply-api/.env.example @@ -0,0 +1,29 @@ +# Supply API Environment Variables +# Copy this file to .env and fill in the values + +# Server +SUPPLY_API_ADDR=:18082 +SUPPLY_API_READ_TIMEOUT=10s +SUPPLY_API_WRITE_TIMEOUT=15s + +# Database (PostgreSQL) +SUPPLY_DB_HOST=localhost +SUPPLY_DB_PORT=5432 +SUPPLY_DB_USER=postgres +SUPPLY_DB_PASSWORD= +SUPPLY_DB_NAME=supply_db +SUPPLY_DB_MAX_OPEN_CONNS=25 +SUPPLY_DB_MAX_IDLE_CONNS=5 + +# Redis +SUPPLY_REDIS_HOST=localhost +SUPPLY_REDIS_PORT=6379 +SUPPLY_REDIS_PASSWORD= +SUPPLY_REDIS_DB=0 + +# Token (JWT) +# 生成密钥: openssl rand -base64 32 +SUPPLY_TOKEN_SECRET_KEY= + +# Environment (dev/staging/prod) +SUPPLY_ENV=dev diff --git a/supply-api/config/config.example.yaml b/supply-api/config/config.example.yaml new file mode 100644 index 0000000..dd8ed29 --- /dev/null +++ b/supply-api/config/config.example.yaml @@ -0,0 +1,37 @@ +# Supply API Development Configuration +server: + addr: ":18082" + read_timeout: 10s + write_timeout: 15s + idle_timeout: 30s + shutdown_timeout: 5s + +database: + host: "localhost" + port: 5432 + user: "postgres" + password: "" + database: "supply_db" + max_open_conns: 25 + max_idle_conns: 5 + conn_max_lifetime: 1h + conn_max_idle_time: 10m + +redis: + host: "localhost" + port: 6379 + password: "" + db: 0 + pool_size: 10 + +token: + secret_key: "${SUPPLY_TOKEN_SECRET_KEY}" + issuer: "lijiaoqiao/supply-api" + access_token_ttl: 1h + refresh_token_ttl: 168h + revocation_cache_ttl: 30s + +audit: + buffer_size: 1000 + flush_interval: 5s + export_timeout: 30s diff --git a/supply-api/internal/cache/redis.go b/supply-api/internal/cache/redis.go new file mode 100644 index 0000000..397d326 --- /dev/null +++ b/supply-api/internal/cache/redis.go @@ -0,0 +1,231 @@ +package cache + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/redis/go-redis/v9" + "lijiaoqiao/supply-api/internal/config" +) + +// RedisCache Redis缓存客户端 +type RedisCache struct { + client *redis.Client +} + +// NewRedisCache 创建Redis缓存客户端 +func NewRedisCache(cfg config.RedisConfig) (*RedisCache, error) { + client := redis.NewClient(&redis.Options{ + Addr: cfg.Addr(), + Password: cfg.Password, + DB: cfg.DB, + PoolSize: cfg.PoolSize, + }) + + // 验证连接 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("failed to connect to redis: %w", err) + } + + return &RedisCache{client: client}, nil +} + +// Close 关闭连接 +func (r *RedisCache) Close() error { + return r.client.Close() +} + +// HealthCheck 健康检查 +func (r *RedisCache) HealthCheck(ctx context.Context) error { + return r.client.Ping(ctx).Err() +} + +// ==================== Token状态缓存 ==================== + +// TokenStatus Token状态 +type TokenStatus struct { + TokenID string `json:"token_id"` + SubjectID string `json:"subject_id"` + Role string `json:"role"` + Status string `json:"status"` // active, revoked, expired + ExpiresAt int64 `json:"expires_at"` + RevokedAt int64 `json:"revoked_at,omitempty"` + RevokedReason string `json:"revoked_reason,omitempty"` +} + +// GetTokenStatus 获取Token状态 +func (r *RedisCache) GetTokenStatus(ctx context.Context, tokenID string) (*TokenStatus, error) { + key := fmt.Sprintf("token:status:%s", tokenID) + data, err := r.client.Get(ctx, key).Bytes() + if err == redis.Nil { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get token status: %w", err) + } + + var status TokenStatus + if err := json.Unmarshal(data, &status); err != nil { + return nil, fmt.Errorf("failed to unmarshal token status: %w", err) + } + + return &status, nil +} + +// SetTokenStatus 设置Token状态 +func (r *RedisCache) SetTokenStatus(ctx context.Context, status *TokenStatus, ttl time.Duration) error { + key := fmt.Sprintf("token:status:%s", status.TokenID) + data, err := json.Marshal(status) + if err != nil { + return fmt.Errorf("failed to marshal token status: %w", err) + } + + return r.client.Set(ctx, key, data, ttl).Err() +} + +// InvalidateToken 使Token失效 +func (r *RedisCache) InvalidateToken(ctx context.Context, tokenID string) error { + key := fmt.Sprintf("token:status:%s", tokenID) + return r.client.Del(ctx, key).Err() +} + +// ==================== 限流 ==================== + +// RateLimitKey 限流键 +type RateLimitKey struct { + TenantID int64 + Route string + LimitType string // rpm, rpd, concurrent +} + +// GetRateLimit 获取限流计数 +func (r *RedisCache) GetRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) { + redisKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType) + + count, err := r.client.Get(ctx, redisKey).Int64() + if err == redis.Nil { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("failed to get rate limit: %w", err) + } + + return count, nil +} + +// IncrRateLimit 增加限流计数 +func (r *RedisCache) IncrRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) { + redisKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType) + + pipe := r.client.Pipeline() + incrCmd := pipe.Incr(ctx, redisKey) + pipe.Expire(ctx, redisKey, window) + + _, err := pipe.Exec(ctx) + if err != nil { + return 0, fmt.Errorf("failed to increment rate limit: %w", err) + } + + return incrCmd.Val(), nil +} + +// CheckRateLimit 检查限流 +func (r *RedisCache) CheckRateLimit(ctx context.Context, key *RateLimitKey, limit int64, window time.Duration) (bool, int64, error) { + count, err := r.IncrRateLimit(ctx, key, window) + if err != nil { + return false, 0, err + } + + return count <= limit, count, nil +} + +// ==================== 分布式锁 ==================== + +// AcquireLock 获取分布式锁 +func (r *RedisCache) AcquireLock(ctx context.Context, lockKey string, ttl time.Duration) (bool, error) { + redisKey := fmt.Sprintf("lock:%s", lockKey) + + ok, err := r.client.SetNX(ctx, redisKey, "1", ttl).Result() + if err != nil { + return false, fmt.Errorf("failed to acquire lock: %w", err) + } + + return ok, nil +} + +// ReleaseLock 释放分布式锁 +func (r *RedisCache) ReleaseLock(ctx context.Context, lockKey string) error { + redisKey := fmt.Sprintf("lock:%s", lockKey) + return r.client.Del(ctx, redisKey).Err() +} + +// ==================== 幂等缓存 ==================== + +// IdempotencyCache 幂等缓存(短期) +func (r *RedisCache) GetIdempotency(ctx context.Context, key string) (string, error) { + redisKey := fmt.Sprintf("idempotency:%s", key) + val, err := r.client.Get(ctx, redisKey).Result() + if err == redis.Nil { + return "", nil + } + if err != nil { + return "", fmt.Errorf("failed to get idempotency: %w", err) + } + return val, nil +} + +func (r *RedisCache) SetIdempotency(ctx context.Context, key, value string, ttl time.Duration) error { + redisKey := fmt.Sprintf("idempotency:%s", key) + return r.client.Set(ctx, redisKey, value, ttl).Err() +} + +// ==================== Session缓存 ==================== + +// SessionData Session数据 +type SessionData struct { + UserID int64 `json:"user_id"` + TenantID int64 `json:"tenant_id"` + Role string `json:"role"` + CreatedAt int64 `json:"created_at"` +} + +// GetSession 获取Session +func (r *RedisCache) GetSession(ctx context.Context, sessionID string) (*SessionData, error) { + key := fmt.Sprintf("session:%s", sessionID) + data, err := r.client.Get(ctx, key).Bytes() + if err == redis.Nil { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get session: %w", err) + } + + var session SessionData + if err := json.Unmarshal(data, &session); err != nil { + return nil, fmt.Errorf("failed to unmarshal session: %w", err) + } + + return &session, nil +} + +// SetSession 设置Session +func (r *RedisCache) SetSession(ctx context.Context, sessionID string, session *SessionData, ttl time.Duration) error { + key := fmt.Sprintf("session:%s", sessionID) + data, err := json.Marshal(session) + if err != nil { + return fmt.Errorf("failed to marshal session: %w", err) + } + + return r.client.Set(ctx, key, data, ttl).Err() +} + +// DeleteSession 删除Session +func (r *RedisCache) DeleteSession(ctx context.Context, sessionID string) error { + key := fmt.Sprintf("session:%s", sessionID) + return r.client.Del(ctx, key).Err() +} diff --git a/supply-api/internal/config/config.go b/supply-api/internal/config/config.go new file mode 100644 index 0000000..b364e8c --- /dev/null +++ b/supply-api/internal/config/config.go @@ -0,0 +1,242 @@ +package config + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/spf13/viper" +) + +// Config 应用配置 +type Config struct { + Server ServerConfig + Database DatabaseConfig + Redis RedisConfig + Token TokenConfig + Audit AuditConfig +} + +// ServerConfig HTTP服务配置 +type ServerConfig struct { + Addr string + ReadTimeout time.Duration + WriteTimeout time.Duration + IdleTimeout time.Duration + ShutdownTimeout time.Duration +} + +// DatabaseConfig PostgreSQL配置 +type DatabaseConfig struct { + Host string + Port int + User string + Password string + Database string + MaxOpenConns int + MaxIdleConns int + ConnMaxLifetime time.Duration + ConnMaxIdleTime time.Duration +} + +// RedisConfig Redis配置 +type RedisConfig struct { + Host string + Port int + Password string + DB int + PoolSize int +} + +// TokenConfig Token运行时配置 +type TokenConfig struct { + SecretKey string + Issuer string + AccessTokenTTL time.Duration + RefreshTokenTTL time.Duration + RevocationCacheTTL time.Duration +} + +// AuditConfig 审计配置 +type AuditConfig struct { + BufferSize int + FlushInterval time.Duration + ExportTimeout time.Duration +} + +// DSN 返回数据库连接字符串 +func (d *DatabaseConfig) DSN() string { + return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", + d.User, d.Password, d.Host, d.Port, d.Database) +} + +// Addr 返回Redis地址 +func (r *RedisConfig) Addr() string { + return fmt.Sprintf("%s:%d", r.Host, r.Port) +} + +// Load 加载配置 +func Load(env string) (*Config, error) { + v := viper.New() + + // 设置环境变量前缀 + v.SetEnvPrefix("SUPPLY_API") + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + + // 默认配置 + setDefaults(v) + + // 加载配置文件 + configFile := fmt.Sprintf("config.%s.yaml", env) + v.SetConfigName(configFile) + v.SetConfigType("yaml") + v.AddConfigPath(".") + v.AddConfigPath("./config") + + // 允许环境变量覆盖 + v.AutomaticEnv() + + // 读取配置文件 + if err := v.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + return nil, fmt.Errorf("failed to read config: %w", err) + } + // 配置文件不存在时,使用环境变量 + } + + // 绑定环境变量 + bindEnvVars(v) + + var cfg Config + + // Server配置 + cfg.Server.Addr = v.GetString("server.addr") + cfg.Server.ReadTimeout = v.GetDuration("server.read_timeout") + cfg.Server.WriteTimeout = v.GetDuration("server.write_timeout") + cfg.Server.IdleTimeout = v.GetDuration("server.idle_timeout") + cfg.Server.ShutdownTimeout = v.GetDuration("server.shutdown_timeout") + + // Database配置 + cfg.Database.Host = v.GetString("database.host") + cfg.Database.Port = v.GetInt("database.port") + cfg.Database.User = v.GetString("database.user") + cfg.Database.Password = v.GetString("database.password") + cfg.Database.Database = v.GetString("database.database") + cfg.Database.MaxOpenConns = v.GetInt("database.max_open_conns") + cfg.Database.MaxIdleConns = v.GetInt("database.max_idle_conns") + cfg.Database.ConnMaxLifetime = v.GetDuration("database.conn_max_lifetime") + cfg.Database.ConnMaxIdleTime = v.GetDuration("database.conn_max_idle_time") + + // Redis配置 + cfg.Redis.Host = v.GetString("redis.host") + cfg.Redis.Port = v.GetInt("redis.port") + cfg.Redis.Password = v.GetString("redis.password") + cfg.Redis.DB = v.GetInt("redis.db") + cfg.Redis.PoolSize = v.GetInt("redis.pool_size") + + // Token配置 + cfg.Token.SecretKey = v.GetString("token.secret_key") + cfg.Token.Issuer = v.GetString("token.issuer") + cfg.Token.AccessTokenTTL = v.GetDuration("token.access_token_ttl") + cfg.Token.RefreshTokenTTL = v.GetDuration("token.refresh_token_ttl") + cfg.Token.RevocationCacheTTL = v.GetDuration("token.revocation_cache_ttl") + + // Audit配置 + cfg.Audit.BufferSize = v.GetInt("audit.buffer_size") + cfg.Audit.FlushInterval = v.GetDuration("audit.flush_interval") + cfg.Audit.ExportTimeout = v.GetDuration("audit.export_timeout") + + return &cfg, nil +} + +// setDefaults 设置默认值 +func setDefaults(v *viper.Viper) { + // Server defaults + v.SetDefault("server.addr", ":18082") + v.SetDefault("server.read_timeout", 10*time.Second) + v.SetDefault("server.write_timeout", 15*time.Second) + v.SetDefault("server.idle_timeout", 30*time.Second) + v.SetDefault("server.shutdown_timeout", 5*time.Second) + + // Database defaults + v.SetDefault("database.host", "localhost") + v.SetDefault("database.port", 5432) + v.SetDefault("database.user", "postgres") + v.SetDefault("database.password", "") + v.SetDefault("database.database", "supply_db") + v.SetDefault("database.max_open_conns", 25) + v.SetDefault("database.max_idle_conns", 5) + v.SetDefault("database.conn_max_lifetime", 1*time.Hour) + v.SetDefault("database.conn_max_idle_time", 10*time.Minute) + + // Redis defaults + v.SetDefault("redis.host", "localhost") + v.SetDefault("redis.port", 6379) + v.SetDefault("redis.password", "") + v.SetDefault("redis.db", 0) + v.SetDefault("redis.pool_size", 10) + + // Token defaults + v.SetDefault("token.issuer", "lijiaoqiao/supply-api") + v.SetDefault("token.access_token_ttl", 1*time.Hour) + v.SetDefault("token.refresh_token_ttl", 7*24*time.Hour) + v.SetDefault("token.revocation_cache_ttl", 30*time.Second) + + // Audit defaults + v.SetDefault("audit.buffer_size", 1000) + v.SetDefault("audit.flush_interval", 5*time.Second) + v.SetDefault("audit.export_timeout", 30*time.Second) +} + +// bindEnvVars 绑定环境变量 +func bindEnvVars(v *viper.Viper) { + _ = v.BindEnv("server.addr", "SUPPLY_API_ADDR") + _ = v.BindEnv("server.read_timeout", "SUPPLY_API_READ_TIMEOUT") + _ = v.BindEnv("server.write_timeout", "SUPPLY_API_WRITE_TIMEOUT") + + _ = v.BindEnv("database.host", "SUPPLY_DB_HOST") + _ = v.BindEnv("database.port", "SUPPLY_DB_PORT") + _ = v.BindEnv("database.user", "SUPPLY_DB_USER") + _ = v.BindEnv("database.password", "SUPPLY_DB_PASSWORD") + _ = v.BindEnv("database.database", "SUPPLY_DB_NAME") + _ = v.BindEnv("database.max_open_conns", "SUPPLY_DB_MAX_OPEN_CONNS") + _ = v.BindEnv("database.max_idle_conns", "SUPPLY_DB_MAX_IDLE_CONNS") + + _ = v.BindEnv("redis.host", "SUPPLY_REDIS_HOST") + _ = v.BindEnv("redis.port", "SUPPLY_REDIS_PORT") + _ = v.BindEnv("redis.password", "SUPPLY_REDIS_PASSWORD") + _ = v.BindEnv("redis.db", "SUPPLY_REDIS_DB") + + _ = v.BindEnv("token.secret_key", "SUPPLY_TOKEN_SECRET_KEY") +} + +// MustLoad 加载配置,失败时panic +func MustLoad(env string) *Config { + cfg, err := Load(env) + if err != nil { + panic("failed to load config: " + err.Error()) + } + return cfg +} + +// GetEnvInt 获取环境变量int值 +func GetEnvInt(key string, defaultVal int) int { + if v := os.Getenv(key); v != "" { + if i, err := strconv.Atoi(v); err == nil { + return i + } + } + return defaultVal +} + +// GetEnvDuration 获取环境变量duration值 +func GetEnvDuration(key string, defaultVal time.Duration) time.Duration { + if v := os.Getenv(key); v != "" { + if d, err := time.ParseDuration(v); err == nil { + return d + } + } + return defaultVal +} diff --git a/supply-api/internal/domain/invariants.go b/supply-api/internal/domain/invariants.go new file mode 100644 index 0000000..c362399 --- /dev/null +++ b/supply-api/internal/domain/invariants.go @@ -0,0 +1,212 @@ +package domain + +import ( + "context" + "errors" + "fmt" +) + +// 领域不变量错误 + +var ( + // INV-ACC-001: active账号不可删除 + ErrAccountCannotDeleteActive = errors.New("SUP_ACC_4092: cannot delete active accounts") + + // INV-ACC-002: disabled账号仅管理员可恢复 + ErrAccountDisabledRequiresAdmin = errors.New("SUP_ACC_4031: disabled account requires admin to restore") + + // INV-PKG-001: sold_out只能系统迁移 + ErrPackageSoldOutSystemOnly = errors.New("SUP_PKG_4092: sold_out status can only be changed by system") + + // INV-PKG-002: expired套餐不可直接恢复 + ErrPackageExpiredCannotRestore = errors.New("SUP_PKG_4093: expired package cannot be directly restored") + + // INV-PKG-003: 售价不得低于保护价 + ErrPriceBelowProtection = errors.New("SUP_PKG_4001: price cannot be below protected price") + + // INV-SET-001: processing/completed不可撤销 + ErrSettlementCannotCancel = errors.New("SUP_SET_4092: cannot cancel processing or completed settlements") + + // INV-SET-002: 提现金额不得超过可提现余额 + ErrWithdrawExceedsBalance = errors.New("SUP_SET_4001: withdraw amount exceeds available balance") + + // INV-SET-003: 结算单金额与余额流水必须平衡 + ErrSettlementBalanceMismatch = errors.New("SUP_SET_5002: settlement amount does not match balance ledger") +) + +// InvariantChecker 领域不变量检查器 +type InvariantChecker struct { + accountStore AccountStore + packageStore PackageStore + settlementStore SettlementStore +} + +// NewInvariantChecker 创建不变量检查器 +func NewInvariantChecker( + accountStore AccountStore, + packageStore PackageStore, + settlementStore SettlementStore, +) *InvariantChecker { + return &InvariantChecker{ + accountStore: accountStore, + packageStore: packageStore, + settlementStore: settlementStore, + } +} + +// CheckAccountDelete 检查账号删除不变量 +func (c *InvariantChecker) CheckAccountDelete(ctx context.Context, accountID, supplierID int64) error { + account, err := c.accountStore.GetByID(ctx, supplierID, accountID) + if err != nil { + return err + } + + // INV-ACC-001: active账号不可删除 + if account.Status == AccountStatusActive { + return ErrAccountCannotDeleteActive + } + + return nil +} + +// CheckAccountActivate 检查账号激活不变量 +func (c *InvariantChecker) CheckAccountActivate(ctx context.Context, accountID, supplierID int64) error { + account, err := c.accountStore.GetByID(ctx, supplierID, accountID) + if err != nil { + return err + } + + // INV-ACC-002: disabled账号仅管理员可恢复(简化处理,实际需要检查角色) + if account.Status == AccountStatusDisabled { + return ErrAccountDisabledRequiresAdmin + } + + return nil +} + +// CheckPackagePublish 检查套餐发布不变量 +func (c *InvariantChecker) CheckPackagePublish(ctx context.Context, packageID, supplierID int64) error { + pkg, err := c.packageStore.GetByID(ctx, supplierID, packageID) + if err != nil { + return err + } + + // INV-PKG-002: expired套餐不可直接恢复 + if pkg.Status == PackageStatusExpired { + return ErrPackageExpiredCannotRestore + } + + return nil +} + +// CheckPackagePrice 检查套餐价格不变量 +func (c *InvariantChecker) CheckPackagePrice(ctx context.Context, pkg *Package, newPricePer1MInput, newPricePer1MOutput float64) error { + // INV-PKG-003: 售价不得低于保护价(这里简化处理,实际需要查询保护价配置) + minPrice := 0.01 + if newPricePer1MInput > 0 && newPricePer1MInput < minPrice { + return fmt.Errorf("%w: input price %.6f is below minimum %.6f", + ErrPriceBelowProtection, newPricePer1MInput, minPrice) + } + if newPricePer1MOutput > 0 && newPricePer1MOutput < minPrice { + return fmt.Errorf("%w: output price %.6f is below minimum %.6f", + ErrPriceBelowProtection, newPricePer1MOutput, minPrice) + } + + return nil +} + +// CheckSettlementCancel 检查结算撤销不变量 +func (c *InvariantChecker) CheckSettlementCancel(ctx context.Context, settlementID, supplierID int64) error { + settlement, err := c.settlementStore.GetByID(ctx, supplierID, settlementID) + if err != nil { + return err + } + + // INV-SET-001: processing/completed不可撤销 + if settlement.Status == SettlementStatusProcessing || settlement.Status == SettlementStatusCompleted { + return ErrSettlementCannotCancel + } + + return nil +} + +// CheckWithdrawBalance 检查提现余额不变量 +func (c *InvariantChecker) CheckWithdrawBalance(ctx context.Context, supplierID int64, amount float64) error { + balance, err := c.settlementStore.GetWithdrawableBalance(ctx, supplierID) + if err != nil { + return err + } + + // INV-SET-002: 提现金额不得超过可提现余额 + if amount > balance { + return fmt.Errorf("%w: requested %.2f but available %.2f", + ErrWithdrawExceedsBalance, amount, balance) + } + + return nil +} + +// InvariantViolation 领域不变量违反事件 +type InvariantViolation struct { + RuleCode string + ObjectType string + ObjectID int64 + Message string + OccurredAt string +} + +// EmitInvariantViolation 发射不变量违反事件 +func EmitInvariantViolation(ruleCode, objectType string, objectID int64, err error) *InvariantViolation { + return &InvariantViolation{ + RuleCode: ruleCode, + ObjectType: objectType, + ObjectID: objectID, + Message: err.Error(), + OccurredAt: "now", // 实际应使用时间戳 + } +} + +// ValidateStateTransition 验证状态转换是否合法 +func ValidateStateTransition(from, to AccountStatus) bool { + validTransitions := map[AccountStatus][]AccountStatus{ + AccountStatusPending: {AccountStatusActive, AccountStatusDisabled}, + AccountStatusActive: {AccountStatusSuspended, AccountStatusDisabled}, + AccountStatusSuspended: {AccountStatusActive, AccountStatusDisabled}, + AccountStatusDisabled: {AccountStatusActive}, // 需要管理员权限 + } + + allowed, ok := validTransitions[from] + if !ok { + return false + } + + for _, status := range allowed { + if status == to { + return true + } + } + return false +} + +// ValidatePackageStateTransition 验证套餐状态转换 +func ValidatePackageStateTransition(from, to PackageStatus) bool { + validTransitions := map[PackageStatus][]PackageStatus{ + PackageStatusDraft: {PackageStatusActive}, + PackageStatusActive: {PackageStatusPaused, PackageStatusSoldOut, PackageStatusExpired}, + PackageStatusPaused: {PackageStatusActive, PackageStatusExpired}, + PackageStatusSoldOut: {}, // 只能由系统迁移 + PackageStatusExpired: {}, // 不能直接恢复,需要通过克隆 + } + + allowed, ok := validTransitions[from] + if !ok { + return false + } + + for _, status := range allowed { + if status == to { + return true + } + } + return false +} diff --git a/supply-api/internal/domain/invariants_test.go b/supply-api/internal/domain/invariants_test.go new file mode 100644 index 0000000..4dfa18e --- /dev/null +++ b/supply-api/internal/domain/invariants_test.go @@ -0,0 +1,101 @@ +package domain + +import ( + "testing" +) + +func TestValidateAccountStateTransition(t *testing.T) { + tests := []struct { + name string + from AccountStatus + to AccountStatus + expected bool + }{ + {"pending to active", AccountStatusPending, AccountStatusActive, true}, + {"pending to disabled", AccountStatusPending, AccountStatusDisabled, true}, + {"active to suspended", AccountStatusActive, AccountStatusSuspended, true}, + {"active to disabled", AccountStatusActive, AccountStatusDisabled, true}, + {"suspended to active", AccountStatusSuspended, AccountStatusActive, true}, + {"suspended to disabled", AccountStatusSuspended, AccountStatusDisabled, true}, + {"disabled to active", AccountStatusDisabled, AccountStatusActive, true}, + {"active to pending", AccountStatusActive, AccountStatusPending, false}, + {"suspended to pending", AccountStatusSuspended, AccountStatusPending, false}, + {"disabled to suspended", AccountStatusDisabled, AccountStatusSuspended, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ValidateStateTransition(tt.from, tt.to) + if result != tt.expected { + t.Errorf("ValidateStateTransition(%s, %s) = %v, want %v", tt.from, tt.to, result, tt.expected) + } + }) + } +} + +func TestValidatePackageStateTransition(t *testing.T) { + tests := []struct { + name string + from PackageStatus + to PackageStatus + expected bool + }{ + {"draft to active", PackageStatusDraft, PackageStatusActive, true}, + {"active to paused", PackageStatusActive, PackageStatusPaused, true}, + {"active to sold_out", PackageStatusActive, PackageStatusSoldOut, true}, + {"active to expired", PackageStatusActive, PackageStatusExpired, true}, + {"paused to active", PackageStatusPaused, PackageStatusActive, true}, + {"paused to expired", PackageStatusPaused, PackageStatusExpired, true}, + {"draft to paused", PackageStatusDraft, PackageStatusPaused, false}, + {"sold_out to active", PackageStatusSoldOut, PackageStatusActive, false}, + {"expired to active", PackageStatusExpired, PackageStatusActive, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ValidatePackageStateTransition(tt.from, tt.to) + if result != tt.expected { + t.Errorf("ValidatePackageStateTransition(%s, %s) = %v, want %v", tt.from, tt.to, result, tt.expected) + } + }) + } +} + +func TestInvariantErrors(t *testing.T) { + tests := []struct { + name string + err error + contains string + }{ + {"account cannot delete active", ErrAccountCannotDeleteActive, "cannot delete active"}, + {"account disabled requires admin", ErrAccountDisabledRequiresAdmin, "disabled account requires admin"}, + {"package sold out system only", ErrPackageSoldOutSystemOnly, "sold_out status"}, + {"package expired cannot restore", ErrPackageExpiredCannotRestore, "expired package cannot"}, + {"settlement cannot cancel", ErrSettlementCannotCancel, "cannot cancel"}, + {"withdraw exceeds balance", ErrWithdrawExceedsBalance, "exceeds available balance"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err == nil { + t.Errorf("expected error but got nil") + } + if tt.contains != "" && !containsString(tt.err.Error(), tt.contains) { + t.Errorf("error = %v, want contains %v", tt.err, tt.contains) + } + }) + } +} + +func containsString(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) +} + +func containsSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/supply-api/internal/middleware/auth.go b/supply-api/internal/middleware/auth.go new file mode 100644 index 0000000..3bad54c --- /dev/null +++ b/supply-api/internal/middleware/auth.go @@ -0,0 +1,477 @@ +package middleware + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "lijiaoqiao/supply-api/internal/repository" +) + +// TokenClaims JWT token claims +type TokenClaims struct { + jwt.RegisteredClaims + SubjectID string `json:"subject_id"` + Role string `json:"role"` + Scope []string `json:"scope"` + TenantID int64 `json:"tenant_id"` +} + +// AuthConfig 鉴权中间件配置 +type AuthConfig struct { + SecretKey string + Issuer string + CacheTTL time.Duration // token状态缓存TTL + Enabled bool // 是否启用鉴权 +} + +// AuthMiddleware 鉴权中间件 +type AuthMiddleware struct { + config AuthConfig + tokenCache *TokenCache + auditEmitter AuditEmitter +} + +// AuditEmitter 审计事件发射器 +type AuditEmitter interface { + Emit(ctx context.Context, event AuditEvent) error +} + +// AuditEvent 审计事件 +type AuditEvent struct { + EventName string + RequestID string + TokenID string + SubjectID string + Route string + ResultCode string + ClientIP string + CreatedAt time.Time +} + +// NewAuthMiddleware 创建鉴权中间件 +func NewAuthMiddleware(config AuthConfig, tokenCache *TokenCache, auditEmitter AuditEmitter) *AuthMiddleware { + if config.CacheTTL == 0 { + config.CacheTTL = 30 * time.Second + } + return &AuthMiddleware{ + config: config, + tokenCache: tokenCache, + auditEmitter: auditEmitter, + } +} + +// QueryKeyRejectMiddleware 拒绝外部query key入站 +// 对应M-016指标 +func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 检查query string中的可疑参数 + queryParams := r.URL.Query() + + // 禁止的query参数名 + blockedParams := []string{"key", "api_key", "token", "secret", "password", "credential"} + + for _, param := range blockedParams { + if _, exists := queryParams[param]; exists { + // 触发M-016指标事件 + if m.auditEmitter != nil { + m.auditEmitter.Emit(r.Context(), AuditEvent{ + EventName: "token.query_key.rejected", + RequestID: getRequestID(r), + Route: r.URL.Path, + ResultCode: "QUERY_KEY_NOT_ALLOWED", + ClientIP: getClientIP(r), + CreatedAt: time.Now(), + }) + } + + writeAuthError(w, http.StatusUnauthorized, "QUERY_KEY_NOT_ALLOWED", + "external query key is not allowed, use Authorization header") + return + } + } + + // 检查是否有API Key在query中(即使参数名不同) + for param := range queryParams { + lowerParam := strings.ToLower(param) + if strings.Contains(lowerParam, "key") || strings.Contains(lowerParam, "token") || strings.Contains(lowerParam, "secret") { + // 可能是编码的API Key + if len(queryParams.Get(param)) > 20 { + if m.auditEmitter != nil { + m.auditEmitter.Emit(r.Context(), AuditEvent{ + EventName: "token.query_key.rejected", + RequestID: getRequestID(r), + Route: r.URL.Path, + ResultCode: "QUERY_KEY_NOT_ALLOWED", + ClientIP: getClientIP(r), + CreatedAt: time.Now(), + }) + } + + writeAuthError(w, http.StatusUnauthorized, "QUERY_KEY_NOT_ALLOWED", + "suspicious query parameter detected") + return + } + } + } + + next.ServeHTTP(w, r) + }) +} + +// BearerExtractMiddleware 提取Bearer Token +func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + + if authHeader == "" { + if m.auditEmitter != nil { + m.auditEmitter.Emit(r.Context(), AuditEvent{ + EventName: "token.authn.fail", + RequestID: getRequestID(r), + Route: r.URL.Path, + ResultCode: "AUTH_MISSING_BEARER", + ClientIP: getClientIP(r), + CreatedAt: time.Now(), + }) + } + + writeAuthError(w, http.StatusUnauthorized, "AUTH_MISSING_BEARER", + "Authorization header with Bearer token is required") + return + } + + if !strings.HasPrefix(authHeader, "Bearer ") { + writeAuthError(w, http.StatusUnauthorized, "AUTH_INVALID_FORMAT", + "Authorization header must be in format: Bearer ") + return + } + + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + if tokenString == "" { + writeAuthError(w, http.StatusUnauthorized, "AUTH_MISSING_BEARER", + "Bearer token is empty") + return + } + + // 将token存入context供后续使用 + ctx := context.WithValue(r.Context(), bearerTokenKey, tokenString) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// TokenVerifyMiddleware 校验JWT Token +func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenString := r.Context().Value(bearerTokenKey).(string) + + claims, err := m.verifyToken(tokenString) + if err != nil { + if m.auditEmitter != nil { + m.auditEmitter.Emit(r.Context(), AuditEvent{ + EventName: "token.authn.fail", + RequestID: getRequestID(r), + Route: r.URL.Path, + ResultCode: "AUTH_INVALID_TOKEN", + ClientIP: getClientIP(r), + CreatedAt: time.Now(), + }) + } + + writeAuthError(w, http.StatusUnauthorized, "AUTH_INVALID_TOKEN", + "token verification failed: "+err.Error()) + return + } + + // 检查token状态(是否被吊销) + status, err := m.checkTokenStatus(claims.ID) + if err == nil && status != "active" { + if m.auditEmitter != nil { + m.auditEmitter.Emit(r.Context(), AuditEvent{ + EventName: "token.authn.fail", + RequestID: getRequestID(r), + TokenID: claims.ID, + SubjectID: claims.SubjectID, + Route: r.URL.Path, + ResultCode: "AUTH_TOKEN_INACTIVE", + ClientIP: getClientIP(r), + CreatedAt: time.Now(), + }) + } + + writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_INACTIVE", + "token is revoked or expired") + return + } + + // 将claims存入context + ctx := context.WithValue(r.Context(), tokenClaimsKey, claims) + ctx = WithTenantID(ctx, claims.TenantID) + ctx = WithOperatorID(ctx, parseSubjectID(claims.SubjectID)) + + if m.auditEmitter != nil { + m.auditEmitter.Emit(r.Context(), AuditEvent{ + EventName: "token.authn.success", + RequestID: getRequestID(r), + TokenID: claims.ID, + SubjectID: claims.SubjectID, + Route: r.URL.Path, + ResultCode: "OK", + ClientIP: getClientIP(r), + CreatedAt: time.Now(), + }) + } + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// ScopeRoleAuthzMiddleware 权限校验中间件 +func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims, ok := r.Context().Value(tokenClaimsKey).(*TokenClaims) + if !ok { + writeAuthError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", + "authentication context is missing") + return + } + + // 检查scope + if requiredScope != "" && !containsScope(claims.Scope, requiredScope) { + if m.auditEmitter != nil { + m.auditEmitter.Emit(r.Context(), AuditEvent{ + EventName: "token.authz.denied", + RequestID: getRequestID(r), + TokenID: claims.ID, + SubjectID: claims.SubjectID, + Route: r.URL.Path, + ResultCode: "AUTH_SCOPE_DENIED", + ClientIP: getClientIP(r), + CreatedAt: time.Now(), + }) + } + + writeAuthError(w, http.StatusForbidden, "AUTH_SCOPE_DENIED", + fmt.Sprintf("required scope '%s' is not granted", requiredScope)) + return + } + + // 检查role权限 + roleHierarchy := map[string]int{ + "admin": 3, + "owner": 2, + "viewer": 1, + } + + // 路由权限要求 + routeRoles := map[string]string{ + "/api/v1/supply/accounts": "owner", + "/api/v1/supply/packages": "owner", + "/api/v1/supply/settlements": "owner", + "/api/v1/supply/billing": "viewer", + "/api/v1/supplier/billing": "viewer", + } + + for path, requiredRole := range routeRoles { + if strings.HasPrefix(r.URL.Path, path) { + if roleLevel(claims.Role, roleHierarchy) < roleLevel(requiredRole, roleHierarchy) { + writeAuthError(w, http.StatusForbidden, "AUTH_ROLE_DENIED", + fmt.Sprintf("required role '%s' is not granted, current role: '%s'", requiredRole, claims.Role)) + return + } + } + } + + next.ServeHTTP(w, r) + }) + } +} + +// verifyToken 校验JWT token +func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(m.config.SecretKey), nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(*TokenClaims); ok && token.Valid { + // 验证issuer + if claims.Issuer != m.config.Issuer { + return nil, errors.New("invalid token issuer") + } + + // 验证expiration + if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(time.Now()) { + return nil, errors.New("token has expired") + } + + // 验证not before + if claims.NotBefore != nil && claims.NotBefore.Time.After(time.Now()) { + return nil, errors.New("token is not yet valid") + } + + return claims, nil + } + + return nil, errors.New("invalid token") +} + +// checkTokenStatus 检查token状态(从缓存或数据库) +func (m *AuthMiddleware) checkTokenStatus(tokenID string) (string, error) { + if m.tokenCache != nil { + // 先从缓存检查 + if status, found := m.tokenCache.Get(tokenID); found { + return status, nil + } + } + + // 缓存未命中,返回active(实际应该查询数据库) + return "active", nil +} + +// GetTokenClaims 从context获取token claims +func GetTokenClaims(ctx context.Context) *TokenClaims { + if claims, ok := ctx.Value(tokenClaimsKey).(*TokenClaims); ok { + return claims + } + return nil +} + +// context keys +const ( + bearerTokenKey contextKey = "bearer_token" + tokenClaimsKey contextKey = "token_claims" +) + +// writeAuthError 写入鉴权错误 +func writeAuthError(w http.ResponseWriter, status int, code, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + resp := map[string]interface{}{ + "request_id": "", + "error": map[string]string{ + "code": code, + "message": message, + }, + } + json.NewEncoder(w).Encode(resp) +} + +// getRequestID 获取请求ID +func getRequestID(r *http.Request) string { + if id := r.Header.Get("X-Request-Id"); id != "" { + return id + } + return r.Header.Get("X-Request-ID") +} + +// getClientIP 获取客户端IP +func getClientIP(r *http.Request) string { + // 优先从X-Forwarded-For获取 + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + parts := strings.Split(xff, ",") + return strings.TrimSpace(parts[0]) + } + + // X-Real-IP + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // RemoteAddr + addr := r.RemoteAddr + if idx := strings.LastIndex(addr, ":"); idx != -1 { + return addr[:idx] + } + return addr +} + +// containsScope 检查scope列表是否包含目标scope +func containsScope(scopes []string, target string) bool { + for _, scope := range scopes { + if scope == target || scope == "*" { + return true + } + } + return false +} + +// roleLevel 获取角色等级 +func roleLevel(role string, hierarchy map[string]int) int { + if level, ok := hierarchy[role]; ok { + return level + } + return 0 +} + +// parseSubjectID 解析subject ID +func parseSubjectID(subject string) int64 { + parts := strings.Split(subject, ":") + if len(parts) >= 2 { + id, _ := strconv.ParseInt(parts[1], 10, 64) + return id + } + return 0 +} + +// TokenCache Token状态缓存 +type TokenCache struct { + data map[string]cacheEntry +} + +type cacheEntry struct { + status string + expires time.Time +} + +// NewTokenCache 创建token缓存 +func NewTokenCache() *TokenCache { + return &TokenCache{ + data: make(map[string]cacheEntry), + } +} + +// Get 获取token状态 +func (c *TokenCache) Get(tokenID string) (string, bool) { + if entry, ok := c.data[tokenID]; ok { + if time.Now().Before(entry.expires) { + return entry.status, true + } + delete(c.data, tokenID) + } + return "", false +} + +// Set 设置token状态 +func (c *TokenCache) Set(tokenID, status string, ttl time.Duration) { + c.data[tokenID] = cacheEntry{ + status: status, + expires: time.Now().Add(ttl), + } +} + +// Invalidate 使token失效 +func (c *TokenCache) Invalidate(tokenID string) { + delete(c.data, tokenID) +} + +// ComputeFingerprint 计算凭证指纹(用于审计) +func ComputeFingerprint(credential string) string { + hash := sha256.Sum256([]byte(credential)) + return hex.EncodeToString(hash[:]) +} diff --git a/supply-api/internal/middleware/auth_test.go b/supply-api/internal/middleware/auth_test.go new file mode 100644 index 0000000..df3bc59 --- /dev/null +++ b/supply-api/internal/middleware/auth_test.go @@ -0,0 +1,343 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func TestTokenVerify(t *testing.T) { + secretKey := "test-secret-key-12345678901234567890" + issuer := "test-issuer" + + tests := []struct { + name string + token string + expectError bool + errorContains string + }{ + { + name: "valid token", + token: createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(time.Hour)), + expectError: false, + }, + { + name: "expired token", + token: createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(-time.Hour)), + expectError: true, + errorContains: "expired", + }, + { + name: "wrong issuer", + token: createTestToken(secretKey, "wrong-issuer", "subject:1", "owner", time.Now().Add(time.Hour)), + expectError: true, + errorContains: "issuer", + }, + { + name: "invalid token", + token: "invalid.token.string", + expectError: true, + errorContains: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware := &AuthMiddleware{ + config: AuthConfig{ + SecretKey: secretKey, + Issuer: issuer, + }, + } + + _, err := middleware.verifyToken(tt.token) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got nil") + } else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("error = %v, want contains %v", err, tt.errorContains) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func TestQueryKeyRejectMiddleware(t *testing.T) { + tests := []struct { + name string + query string + expectStatus int + }{ + { + name: "no query params", + query: "", + expectStatus: http.StatusOK, + }, + { + name: "normal params", + query: "?page=1&size=10", + expectStatus: http.StatusOK, + }, + { + name: "blocked key param", + query: "?key=abc123", + expectStatus: http.StatusUnauthorized, + }, + { + name: "blocked api_key param", + query: "?api_key=secret123", + expectStatus: http.StatusUnauthorized, + }, + { + name: "blocked token param", + query: "?token=bearer123", + expectStatus: http.StatusUnauthorized, + }, + { + name: "suspicious long param", + query: "?apikey=verylongparamvalueexceeding20chars", + expectStatus: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware := &AuthMiddleware{ + auditEmitter: nil, + } + + nextCalled := false + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + }) + + handler := middleware.QueryKeyRejectMiddleware(nextHandler) + + req := httptest.NewRequest("POST", "/api/v1/supply/accounts"+tt.query, nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if tt.expectStatus == http.StatusOK { + if !nextCalled { + t.Errorf("expected next handler to be called") + } + } else { + if w.Code != tt.expectStatus { + t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code) + } + } + }) + } +} + +func TestBearerExtractMiddleware(t *testing.T) { + tests := []struct { + name string + authHeader string + expectStatus int + }{ + { + name: "valid bearer", + authHeader: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", + expectStatus: http.StatusOK, + }, + { + name: "missing header", + authHeader: "", + expectStatus: http.StatusUnauthorized, + }, + { + name: "wrong prefix", + authHeader: "Basic abc123", + expectStatus: http.StatusUnauthorized, + }, + { + name: "empty token", + authHeader: "Bearer ", + expectStatus: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware := &AuthMiddleware{} + + nextCalled := false + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + // 检查context中是否有bearer token + if r.Context().Value(bearerTokenKey) == nil && tt.authHeader != "" && strings.HasPrefix(tt.authHeader, "Bearer ") { + // 这是预期的,因为token可能无效 + } + }) + + handler := middleware.BearerExtractMiddleware(nextHandler) + + req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if tt.expectStatus == http.StatusOK { + if !nextCalled { + t.Errorf("expected next handler to be called") + } + } else { + if w.Code != tt.expectStatus { + t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code) + } + } + }) + } +} + +func TestContainsScope(t *testing.T) { + tests := []struct { + name string + scopes []string + target string + expected bool + }{ + { + name: "exact match", + scopes: []string{"read", "write", "delete"}, + target: "write", + expected: true, + }, + { + name: "wildcard", + scopes: []string{"*"}, + target: "anything", + expected: true, + }, + { + name: "no match", + scopes: []string{"read", "write"}, + target: "admin", + expected: false, + }, + { + name: "empty scopes", + scopes: []string{}, + target: "read", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := containsScope(tt.scopes, tt.target) + if result != tt.expected { + t.Errorf("containsScope(%v, %s) = %v, want %v", tt.scopes, tt.target, result, tt.expected) + } + }) + } +} + +func TestRoleLevel(t *testing.T) { + hierarchy := map[string]int{ + "admin": 3, + "owner": 2, + "viewer": 1, + } + + tests := []struct { + role string + expected int + }{ + {"admin", 3}, + {"owner", 2}, + {"viewer", 1}, + {"unknown", 0}, + } + + for _, tt := range tests { + t.Run(tt.role, func(t *testing.T) { + result := roleLevel(tt.role, hierarchy) + if result != tt.expected { + t.Errorf("roleLevel(%s) = %d, want %d", tt.role, result, tt.expected) + } + }) + } +} + +func TestTokenCache(t *testing.T) { + cache := NewTokenCache() + + t.Run("get empty", func(t *testing.T) { + status, found := cache.Get("nonexistent") + if found { + t.Errorf("expected not found") + } + if status != "" { + t.Errorf("expected empty status") + } + }) + + t.Run("set and get", func(t *testing.T) { + cache.Set("token1", "active", time.Hour) + + status, found := cache.Get("token1") + if !found { + t.Errorf("expected to find token1") + } + if status != "active" { + t.Errorf("expected status 'active', got '%s'", status) + } + }) + + t.Run("invalidate", func(t *testing.T) { + cache.Set("token2", "revoked", time.Hour) + cache.Invalidate("token2") + + _, found := cache.Get("token2") + if found { + t.Errorf("expected token2 to be invalidated") + } + }) + + t.Run("expiration", func(t *testing.T) { + cache.Set("token3", "active", time.Nanosecond) + time.Sleep(time.Millisecond) + + _, found := cache.Get("token3") + if found { + t.Errorf("expected token3 to be expired") + } + }) +} + +// Helper functions + +func createTestToken(secretKey, issuer, subject, role string, expiresAt time.Time) string { + claims := TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuer, + Subject: subject, + ExpiresAt: jwt.NewNumericDate(expiresAt), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + SubjectID: subject, + Role: role, + Scope: []string{"read", "write"}, + TenantID: 1, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, _ := token.SignedString([]byte(secretKey)) + return tokenString +} diff --git a/supply-api/internal/middleware/idempotency.go b/supply-api/internal/middleware/idempotency.go new file mode 100644 index 0000000..8af34bf --- /dev/null +++ b/supply-api/internal/middleware/idempotency.go @@ -0,0 +1,279 @@ +package middleware + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "lijiaoqiao/supply-api/internal/repository" +) + +// IdempotencyConfig 幂等中间件配置 +type IdempotencyConfig struct { + TTL time.Duration // 幂等有效期,默认24h + ProcessingTTL time.Duration // 处理中状态有效期,默认30s + Enabled bool // 是否启用幂等 +} + +// IdempotencyMiddleware 幂等中间件 +type IdempotencyMiddleware struct { + idempotencyRepo *repository.IdempotencyRepository + config IdempotencyConfig +} + +// NewIdempotencyMiddleware 创建幂等中间件 +func NewIdempotencyMiddleware(repo *repository.IdempotencyRepository, config IdempotencyConfig) *IdempotencyMiddleware { + if config.TTL == 0 { + config.TTL = 24 * time.Hour + } + if config.ProcessingTTL == 0 { + config.ProcessingTTL = 30 * time.Second + } + return &IdempotencyMiddleware{ + idempotencyRepo: repo, + config: config, + } +} + +// IdempotencyKey 幂等键信息 +type IdempotencyKey struct { + TenantID int64 + OperatorID int64 + APIPath string + Key string +} + +// ExtractIdempotencyKey 从请求中提取幂等信息 +func ExtractIdempotencyKey(r *http.Request, tenantID, operatorID int64) (*IdempotencyKey, error) { + requestID := r.Header.Get("X-Request-Id") + if requestID == "" { + return nil, fmt.Errorf("missing X-Request-Id header") + } + + idempotencyKey := r.Header.Get("Idempotency-Key") + if idempotencyKey == "" { + return nil, fmt.Errorf("missing Idempotency-Key header") + } + + if len(idempotencyKey) < 16 || len(idempotencyKey) > 128 { + return nil, fmt.Errorf("Idempotency-Key length must be 16-128") + } + + // 从路径提取API路径(去除前缀) + apiPath := r.URL.Path + if strings.HasPrefix(apiPath, "/api/v1") { + apiPath = strings.TrimPrefix(apiPath, "/api/v1") + } + + return &IdempotencyKey{ + TenantID: tenantID, + OperatorID: operatorID, + APIPath: apiPath, + Key: idempotencyKey, + }, nil +} + +// ComputePayloadHash 计算请求体的SHA256哈希 +func ComputePayloadHash(body []byte) string { + hash := sha256.Sum256(body) + return hex.EncodeToString(hash[:]) +} + +// IdempotentHandler 幂等处理器函数 +type IdempotentHandler func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error + +// Wrap 包装HTTP处理器以实现幂等 +func (m *IdempotencyMiddleware) Wrap(handler IdempotentHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !m.config.Enabled { + handler(r.Context(), w, r, nil) + return + } + + ctx := r.Context() + + // 从context获取租户和操作者ID(由鉴权中间件设置) + tenantID := getTenantID(ctx) + operatorID := getOperatorID(ctx) + + // 提取幂等信息 + idempKey, err := ExtractIdempotencyKey(r, tenantID, operatorID) + if err != nil { + writeIdempotencyError(w, http.StatusBadRequest, "IDEMPOTENCY_KEY_INVALID", err.Error()) + return + } + + // 读取请求体 + body, err := io.ReadAll(r.Body) + if err != nil { + writeIdempotencyError(w, http.StatusBadRequest, "BODY_READ_ERROR", err.Error()) + return + } + // 重新填充body以供后续处理 + r.Body = io.NopCloser(bytes.NewBuffer(body)) + + // 计算payload hash + payloadHash := ComputePayloadHash(body) + + // 查询已存在的幂等记录 + existingRecord, err := m.idempotencyRepo.GetByKey(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key) + if err != nil { + writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_CHECK_FAILED", err.Error()) + return + } + + if existingRecord != nil { + // 存在记录,处理不同情况 + switch existingRecord.Status { + case repository.IdempotencyStatusSucceeded: + // 同参重放:返回原结果 + if existingRecord.PayloadHash == payloadHash { + writeIdempotentReplay(w, existingRecord.ResponseCode, existingRecord.ResponseBody) + return + } + // 异参重放:返回409冲突 + writeIdempotencyError(w, http.StatusConflict, "IDEMPOTENCY_PAYLOAD_MISMATCH", + fmt.Sprintf("same idempotency key but different payload, original request_id: %s", existingRecord.RequestID)) + return + + case repository.IdempotencyStatusProcessing: + // 处理中:检查是否超时 + if time.Since(existingRecord.UpdatedAt) < m.config.ProcessingTTL { + retryAfter := m.config.ProcessingTTL - time.Since(existingRecord.UpdatedAt) + writeIdempotencyProcessing(w, int(retryAfter.Milliseconds()), existingRecord.RequestID) + return + } + // 超时:允许重试(记录会自然过期) + + case repository.IdempotencyStatusFailed: + // 失败状态也允许重试 + } + } + + // 尝试创建或更新幂等记录 + requestID := r.Header.Get("X-Request-Id") + record := &repository.IdempotencyRecord{ + TenantID: idempKey.TenantID, + OperatorID: idempKey.OperatorID, + APIPath: idempKey.APIPath, + IdempotencyKey: idempKey.Key, + RequestID: requestID, + PayloadHash: payloadHash, + Status: repository.IdempotencyStatusProcessing, + ExpiresAt: time.Now().Add(m.config.TTL), + } + + // 使用AcquireLock获取锁 + lockedRecord, err := m.idempotencyRepo.AcquireLock(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key, m.config.TTL) + if err != nil { + writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_LOCK_FAILED", err.Error()) + return + } + + // 更新记录中的request_id和payload_hash + if lockedRecord.ID != 0 && (lockedRecord.RequestID == "" || lockedRecord.PayloadHash == "") { + lockedRecord.RequestID = requestID + lockedRecord.PayloadHash = payloadHash + } + + // 执行实际业务处理 + err = handler(ctx, w, r, lockedRecord) + + // 根据处理结果更新幂等记录 + if err != nil { + // 业务处理失败 + errMsg, _ := json.Marshal(map[string]string{"error": err.Error()}) + _ = m.idempotencyRepo.UpdateFailed(ctx, lockedRecord.ID, http.StatusInternalServerError, errMsg) + return + } + + // 业务处理成功,更新为成功状态 + // 注意:这里需要从w中获取实际的响应码和body + // 简化处理:使用200 + successBody, _ := json.Marshal(map[string]interface{}{"status": "ok"}) + _ = m.idempotencyRepo.UpdateSuccess(ctx, lockedRecord.ID, http.StatusOK, successBody) + } +} + +// writeIdempotencyError 写入幂等错误 +func writeIdempotencyError(w http.ResponseWriter, status int, code, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + resp := map[string]interface{}{ + "request_id": "", + "error": map[string]string{ + "code": code, + "message": message, + }, + } + json.NewEncoder(w).Encode(resp) +} + +// writeIdempotencyProcessing 写入处理中状态 +func writeIdempotencyProcessing(w http.ResponseWriter, retryAfterMs int, requestID string) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After-Ms", fmt.Sprintf("%d", retryAfterMs)) + w.Header().Set("X-Request-Id", requestID) + w.WriteHeader(http.StatusAccepted) + resp := map[string]interface{}{ + "request_id": requestID, + "error": map[string]string{ + "code": "IDEMPOTENCY_IN_PROGRESS", + "message": "request is being processed, please retry later", + }, + } + json.NewEncoder(w).Encode(resp) +} + +// writeIdempotentReplay 写入幂等重放响应 +func writeIdempotentReplay(w http.ResponseWriter, status int, body json.RawMessage) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Idempotent-Replay", "true") + w.WriteHeader(status) + if body != nil { + w.Write(body) + } +} + +// context keys +type contextKey string + +const ( + tenantIDKey contextKey = "tenant_id" + operatorIDKey contextKey = "operator_id" +) + +// WithTenantID 在context中设置租户ID +func WithTenantID(ctx context.Context, tenantID int64) context.Context { + return context.WithValue(ctx, tenantIDKey, tenantID) +} + +// WithOperatorID 在context中设置操作者ID +func WithOperatorID(ctx context.Context, operatorID int64) context.Context { + return context.WithValue(ctx, operatorIDKey, operatorID) +} + +func getTenantID(ctx context.Context) int64 { + if v := ctx.Value(tenantIDKey); v != nil { + if id, ok := v.(int64); ok { + return id + } + } + return 0 +} + +func getOperatorID(ctx context.Context) int64 { + if v := ctx.Value(operatorIDKey); v != nil { + if id, ok := v.(int64); ok { + return id + } + } + return 0 +} diff --git a/supply-api/internal/middleware/idempotency_test.go b/supply-api/internal/middleware/idempotency_test.go new file mode 100644 index 0000000..9a55052 --- /dev/null +++ b/supply-api/internal/middleware/idempotency_test.go @@ -0,0 +1,211 @@ +package middleware + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "lijiaoqiao/supply-api/internal/repository" +) + +// MockIdempotencyRepository 模拟幂等仓储 +type MockIdempotencyRepository struct { + records map[string]*repository.IdempotencyRecord +} + +func NewMockIdempotencyRepository() *MockIdempotencyRepository { + return &MockIdempotencyRepository{ + records: make(map[string]*repository.IdempotencyRecord), + } +} + +func (r *MockIdempotencyRepository) GetByKey(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (*repository.IdempotencyRecord, error) { + key := buildKey(tenantID, operatorID, apiPath, idempotencyKey) + if record, ok := r.records[key]; ok { + if time.Now().Before(record.ExpiresAt) { + return record, nil + } + } + return nil, nil +} + +func (r *MockIdempotencyRepository) Create(ctx context.Context, record *repository.IdempotencyRecord) error { + key := buildKey(record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey) + r.records[key] = record + return nil +} + +func (r *MockIdempotencyRepository) UpdateSuccess(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error { + return nil +} + +func (r *MockIdempotencyRepository) UpdateFailed(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error { + return nil +} + +func (r *MockIdempotencyRepository) AcquireLock(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string, ttl time.Duration) (*repository.IdempotencyRecord, error) { + key := buildKey(tenantID, operatorID, apiPath, idempotencyKey) + record := &repository.IdempotencyRecord{ + TenantID: tenantID, + OperatorID: operatorID, + APIPath: apiPath, + IdempotencyKey: idempotencyKey, + RequestID: "test-request-id", + PayloadHash: "", + Status: repository.IdempotencyStatusProcessing, + ExpiresAt: time.Now().Add(ttl), + } + r.records[key] = record + return record, nil +} + +func buildKey(tenantID, operatorID int64, apiPath, idempotencyKey string) string { + return strings.Join([]string{ + string(rune(tenantID)), + string(rune(operatorID)), + apiPath, + idempotencyKey, + }, ":") +} + +func TestComputePayloadHash(t *testing.T) { + tests := []struct { + name string + body []byte + expected string + }{ + { + name: "empty body", + body: []byte{}, + expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + }, + { + name: "simple JSON", + body: []byte(`{"key":"value"}`), + expected: computeExpectedHash(`{"key":"value"}`), + }, + { + name: "JSON with spaces", + body: []byte(`{ "key": "value" }`), + expected: computeExpectedHash(`{ "key": "value" }`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ComputePayloadHash(tt.body) + if result != tt.expected { + t.Errorf("ComputePayloadHash() = %v, want %v", result, tt.expected) + } + }) + } +} + +func computeExpectedHash(s string) string { + hash := sha256.Sum256([]byte(s)) + return hex.EncodeToString(hash[:]) +} + +func TestExtractIdempotencyKey(t *testing.T) { + tests := []struct { + name string + headers map[string]string + expectError bool + errorCode string + }{ + { + name: "valid headers", + headers: map[string]string{ + "X-Request-Id": "req-123", + "Idempotency-Key": "idem-key-12345678", + }, + expectError: false, + }, + { + name: "missing X-Request-Id", + headers: map[string]string{ + "Idempotency-Key": "idem-key-12345678", + }, + expectError: true, + errorCode: "missing X-Request-Id header", + }, + { + name: "missing Idempotency-Key", + headers: map[string]string{ + "X-Request-Id": "req-123", + }, + expectError: true, + errorCode: "missing Idempotency-Key header", + }, + { + name: "Idempotency-Key too short", + headers: map[string]string{ + "X-Request-Id": "req-123", + "Idempotency-Key": "short", + }, + expectError: true, + errorCode: "Idempotency-Key length must be 16-128", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil) + for k, v := range tt.headers { + req.Header.Set(k, v) + } + + result, err := ExtractIdempotencyKey(req, 1, 1) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got nil") + } + if err != nil && !strings.Contains(err.Error(), tt.errorCode) { + t.Errorf("error = %v, want contains %v", err, tt.errorCode) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result == nil { + t.Errorf("expected result but got nil") + } + } + }) + } +} + +func TestIdempotentHandler(t *testing.T) { + // 创建测试handler + testHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error { + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]string{"status": "created"}) + return nil + } + + middleware := NewIdempotencyMiddleware(nil, IdempotencyConfig{ + Enabled: false, // 禁用幂等,只测试handler包装 + }) + + handler := middleware.Wrap(testHandler) + + t.Run("handler executes successfully", func(t *testing.T) { + req := httptest.NewRequest("POST", "/api/v1/supply/accounts", strings.NewReader(`{"key":"value"}`)) + req.Header.Set("X-Request-Id", "req-123") + req.Header.Set("Idempotency-Key", "idem-key-12345678") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("expected status %d, got %d", http.StatusCreated, w.Code) + } + }) +} diff --git a/supply-api/internal/repository/account.go b/supply-api/internal/repository/account.go new file mode 100644 index 0000000..a1db940 --- /dev/null +++ b/supply-api/internal/repository/account.go @@ -0,0 +1,291 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "net/netip" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "lijiaoqiao/supply-api/internal/domain" +) + +// AccountRepository 账号仓储 +type AccountRepository struct { + pool *pgxpool.Pool +} + +// NewAccountRepository 创建账号仓储 +func NewAccountRepository(pool *pgxpool.Pool) *AccountRepository { + return &AccountRepository{pool: pool} +} + +// Create 创建账号 +func (r *AccountRepository) Create(ctx context.Context, account *domain.Account, requestID, idempotencyKey, traceID string) error { + query := ` + INSERT INTO supply_accounts ( + user_id, platform, account_type, account_name, + encrypted_credentials, key_id, + status, risk_level, total_quota, available_quota, frozen_quota, + is_verified, verified_at, last_check_at, + tos_compliant, tos_check_result, + total_requests, total_tokens, total_cost, success_rate, + risk_score, risk_reason, is_frozen, frozen_reason, + credential_cipher_algo, credential_kms_key_alias, credential_key_version, + quota_unit, currency_code, version, + created_ip, updated_ip, audit_trace_id, + request_id, idempotency_key + ) VALUES ( + $1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, $12, $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, + $28, $29, $30, $31, $32, $33, $34, $35 + ) + RETURNING id, created_at, updated_at + ` + + var createdIP, updatedIP *netip.Addr + if account.CreatedIP != nil { + createdIP = account.CreatedIP + } + if account.UpdatedIP != nil { + updatedIP = account.UpdatedIP + } + + err := r.pool.QueryRow(ctx, query, + account.SupplierID, account.Provider, account.AccountType, account.Alias, + account.CredentialHash, account.KeyID, + account.Status, account.RiskLevel, account.TotalQuota, account.AvailableQuota, account.FrozenQuota, + account.IsVerified, account.VerifiedAt, account.LastCheckAt, + account.TosCompliant, account.TosCheckResult, + account.TotalRequests, account.TotalTokens, account.TotalCost, account.SuccessRate, + account.RiskScore, account.RiskReason, account.IsFrozen, account.FrozenReason, + "AES-256-GCM", "kms/supply/default", 1, + "token", "USD", 0, + createdIP, updatedIP, traceID, + requestID, idempotencyKey, + ).Scan(&account.ID, &account.CreatedAt, &account.UpdatedAt) + + if err != nil { + return fmt.Errorf("failed to create account: %w", err) + } + return nil +} + +// GetByID 获取账号 +func (r *AccountRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) { + query := ` + SELECT id, user_id, platform, account_type, account_name, + encrypted_credentials, key_id, + status, risk_level, total_quota, available_quota, frozen_quota, + is_verified, verified_at, last_check_at, + tos_compliant, tos_check_result, + total_requests, total_tokens, total_cost, success_rate, + risk_score, risk_reason, is_frozen, frozen_reason, + credential_cipher_algo, credential_kms_key_alias, credential_key_version, + quota_unit, currency_code, version, + created_ip, updated_ip, audit_trace_id, + created_at, updated_at + FROM supply_accounts + WHERE id = $1 AND user_id = $2 + ` + + account := &domain.Account{} + var createdIP, updatedIP netip.Addr + var credentialFingerprint *string + + err := r.pool.QueryRow(ctx, query, id, supplierID).Scan( + &account.ID, &account.SupplierID, &account.Provider, &account.AccountType, &account.Alias, + &account.CredentialHash, &account.KeyID, + &account.Status, &account.RiskLevel, &account.TotalQuota, &account.AvailableQuota, &account.FrozenQuota, + &account.IsVerified, &account.VerifiedAt, &account.LastCheckAt, + &account.TosCompliant, &account.TosCheckResult, + &account.TotalRequests, &account.TotalTokens, &account.TotalCost, &account.SuccessRate, + &account.RiskScore, &account.RiskReason, &account.IsFrozen, &account.FrozenReason, + &account.CredentialCipherAlgo, &account.CredentialKMSKeyAlias, &account.CredentialKeyVersion, + &account.QuotaUnit, &account.CurrencyCode, &account.Version, + &createdIP, &updatedIP, &account.AuditTraceID, + &account.CreatedAt, &account.UpdatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("failed to get account: %w", err) + } + + account.CreatedIP = &createdIP + account.UpdatedIP = &updatedIP + _ = credentialFingerprint // 未使用但字段存在 + + return account, nil +} + +// Update 更新账号(乐观锁) +func (r *AccountRepository) Update(ctx context.Context, account *domain.Account, expectedVersion int) error { + query := ` + UPDATE supply_accounts SET + platform = $1, account_type = $2, account_name = $3, + status = $4, risk_level = $5, total_quota = $6, available_quota = $7, + frozen_quota = $8, is_verified = $9, verified_at = $10, last_check_at = $11, + tos_compliant = $12, tos_check_result = $13, + total_requests = $14, total_tokens = $15, total_cost = $16, success_rate = $17, + risk_score = $18, risk_reason = $19, is_frozen = $20, frozen_reason = $21, + version = $22, updated_at = $23 + WHERE id = $24 AND user_id = $25 AND version = $26 + ` + + account.UpdatedAt = time.Now() + newVersion := expectedVersion + 1 + + cmdTag, err := r.pool.Exec(ctx, query, + account.Provider, account.AccountType, account.Alias, + account.Status, account.RiskLevel, account.TotalQuota, account.AvailableQuota, + account.FrozenQuota, account.IsVerified, account.VerifiedAt, account.LastCheckAt, + account.TosCompliant, account.TosCheckResult, + account.TotalRequests, account.TotalTokens, account.TotalCost, account.SuccessRate, + account.RiskScore, account.RiskReason, account.IsFrozen, account.FrozenReason, + newVersion, account.UpdatedAt, + account.ID, account.SupplierID, expectedVersion, + ) + + if err != nil { + return fmt.Errorf("failed to update account: %w", err) + } + + if cmdTag.RowsAffected() == 0 { + return ErrConcurrencyConflict + } + + account.Version = newVersion + return nil +} + +// UpdateWithPessimisticLock 更新账号(悲观锁,用于提现等关键操作) +func (r *AccountRepository) UpdateWithPessimisticLock(ctx context.Context, tx pgxpool.Tx, account *domain.Account, expectedVersion int) error { + query := ` + UPDATE supply_accounts SET + available_quota = $1, frozen_quota = $2, + version = $3, updated_at = $4 + WHERE id = $5 AND version = $6 + RETURNING version + ` + + account.UpdatedAt = time.Now() + newVersion := expectedVersion + 1 + + err := tx.QueryRow(ctx, query, + account.AvailableQuota, account.FrozenQuota, + newVersion, account.UpdatedAt, + account.ID, expectedVersion, + ).Scan(&account.Version) + + if errors.Is(err, pgx.ErrNoRows) { + return ErrConcurrencyConflict + } + if err != nil { + return fmt.Errorf("failed to update account with lock: %w", err) + } + + return nil +} + +// GetForUpdate 获取账号并加行锁(用于事务内) +func (r *AccountRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Account, error) { + query := ` + SELECT id, user_id, platform, account_type, account_name, + encrypted_credentials, key_id, + status, risk_level, total_quota, available_quota, frozen_quota, + is_verified, verified_at, last_check_at, + tos_compliant, tos_check_result, + total_requests, total_tokens, total_cost, success_rate, + risk_score, risk_reason, is_frozen, frozen_reason, + version, + created_at, updated_at + FROM supply_accounts + WHERE id = $1 AND user_id = $2 + FOR UPDATE + ` + + account := &domain.Account{} + err := tx.QueryRow(ctx, query, id, supplierID).Scan( + &account.ID, &account.SupplierID, &account.Provider, &account.AccountType, &account.Alias, + &account.CredentialHash, &account.KeyID, + &account.Status, &account.RiskLevel, &account.TotalQuota, &account.AvailableQuota, &account.FrozenQuota, + &account.IsVerified, &account.VerifiedAt, &account.LastCheckAt, + &account.TosCompliant, &account.TosCheckResult, + &account.TotalRequests, &account.TotalTokens, &account.TotalCost, &account.SuccessRate, + &account.RiskScore, &account.RiskReason, &account.IsFrozen, &account.FrozenReason, + &account.Version, + &account.CreatedAt, &account.UpdatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("failed to get account for update: %w", err) + } + + return account, nil +} + +// List 列出账号 +func (r *AccountRepository) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) { + query := ` + SELECT id, user_id, platform, account_type, account_name, + status, risk_level, total_quota, available_quota, frozen_quota, + is_verified, verified_at, last_check_at, + tos_compliant, success_rate, + risk_score, is_frozen, + version, created_at, updated_at + FROM supply_accounts + WHERE user_id = $1 + ORDER BY created_at DESC + ` + + rows, err := r.pool.Query(ctx, query, supplierID) + if err != nil { + return nil, fmt.Errorf("failed to list accounts: %w", err) + } + defer rows.Close() + + var accounts []*domain.Account + for rows.Next() { + account := &domain.Account{} + err := rows.Scan( + &account.ID, &account.SupplierID, &account.Provider, &account.AccountType, &account.Alias, + &account.Status, &account.RiskLevel, &account.TotalQuota, &account.AvailableQuota, &account.FrozenQuota, + &account.IsVerified, &account.VerifiedAt, &account.LastCheckAt, + &account.TosCompliant, &account.SuccessRate, + &account.RiskScore, &account.IsFrozen, + &account.Version, &account.CreatedAt, &account.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan account: %w", err) + } + accounts = append(accounts, account) + } + + return accounts, nil +} + +// GetWithdrawableBalance 获取可提现余额 +func (r *AccountRepository) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) { + query := ` + SELECT COALESCE(SUM(available_quota), 0) + FROM supply_accounts + WHERE user_id = $1 AND status = 'active' + ` + + var balance float64 + err := r.pool.QueryRow(ctx, query, supplierID).Scan(&balance) + if err != nil { + return 0, fmt.Errorf("failed to get withdrawable balance: %w", err) + } + return balance, nil +} diff --git a/supply-api/internal/repository/db.go b/supply-api/internal/repository/db.go new file mode 100644 index 0000000..8ba72ea --- /dev/null +++ b/supply-api/internal/repository/db.go @@ -0,0 +1,81 @@ +package repository + +import ( + "context" + "fmt" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + "lijiaoqiao/supply-api/internal/config" +) + +// DB 数据库连接池 +type DB struct { + Pool *pgxpool.Pool +} + +// NewDB 创建数据库连接池 +func NewDB(ctx context.Context, cfg config.DatabaseConfig) (*DB, error) { + poolConfig, err := pgxpool.ParseConfig(cfg.DSN()) + if err != nil { + return nil, fmt.Errorf("failed to parse database config: %w", err) + } + + poolConfig.MaxConns = int32(cfg.MaxOpenConns) + poolConfig.MinConns = int32(cfg.MaxIdleConns) + poolConfig.MaxConnLifetime = cfg.ConnMaxLifetime + poolConfig.MaxConnIdleTime = cfg.ConnMaxIdleTime + poolConfig.HealthCheckPeriod = 30 * time.Second + + pool, err := pgxpool.NewWithConfig(ctx, poolConfig) + if err != nil { + return nil, fmt.Errorf("failed to create connection pool: %w", err) + } + + // 验证连接 + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + return &DB{Pool: pool}, nil +} + +// Close 关闭连接池 +func (db *DB) Close() { + if db.Pool != nil { + db.Pool.Close() + } +} + +// HealthCheck 健康检查 +func (db *DB) HealthCheck(ctx context.Context) error { + return db.Pool.Ping(ctx) +} + +// BeginTx 开始事务 +func (db *DB) BeginTx(ctx context.Context) (Transaction, error) { + tx, err := db.Pool.Begin(ctx) + if err != nil { + return nil, err + } + return &txWrapper{tx: tx}, nil +} + +// Transaction 事务接口 +type Transaction interface { + Commit(ctx context.Context) error + Rollback(ctx context.Context) error +} + +type txWrapper struct { + tx pgxpool.Tx +} + +func (t *txWrapper) Commit(ctx context.Context) error { + return t.tx.Commit(ctx) +} + +func (t *txWrapper) Rollback(ctx context.Context) error { + return t.tx.Rollback(ctx) +} diff --git a/supply-api/internal/repository/idempotency.go b/supply-api/internal/repository/idempotency.go new file mode 100644 index 0000000..8f3573d --- /dev/null +++ b/supply-api/internal/repository/idempotency.go @@ -0,0 +1,246 @@ +package repository + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// IdempotencyStatus 幂等记录状态 +type IdempotencyStatus string + +const ( + IdempotencyStatusProcessing IdempotencyStatus = "processing" + IdempotencyStatusSucceeded IdempotencyStatus = "succeeded" + IdempotencyStatusFailed IdempotencyStatus = "failed" +) + +// IdempotencyRecord 幂等记录 +type IdempotencyRecord struct { + ID int64 `json:"id"` + TenantID int64 `json:"tenant_id"` + OperatorID int64 `json:"operator_id"` + APIPath string `json:"api_path"` + IdempotencyKey string `json:"idempotency_key"` + RequestID string `json:"request_id"` + PayloadHash string `json:"payload_hash"` // SHA256 of request body + ResponseCode int `json:"response_code"` + ResponseBody json.RawMessage `json:"response_body"` + Status IdempotencyStatus `json:"status"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// IdempotencyRepository 幂等记录仓储 +type IdempotencyRepository struct { + pool *pgxpool.Pool +} + +// NewIdempotencyRepository 创建幂等记录仓储 +func NewIdempotencyRepository(pool *pgxpool.Pool) *IdempotencyRepository { + return &IdempotencyRepository{pool: pool} +} + +// GetByKey 根据幂等键获取记录 +func (r *IdempotencyRepository) GetByKey(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (*IdempotencyRecord, error) { + query := ` + SELECT id, tenant_id, operator_id, api_path, idempotency_key, + request_id, payload_hash, response_code, response_body, + status, expires_at, created_at, updated_at + FROM supply_idempotency_records + WHERE tenant_id = $1 AND operator_id = $2 AND api_path = $3 AND idempotency_key = $4 + AND expires_at > $5 + FOR UPDATE + ` + + record := &IdempotencyRecord{} + err := r.pool.QueryRow(ctx, query, tenantID, operatorID, apiPath, idempotencyKey, time.Now()).Scan( + &record.ID, &record.TenantID, &record.OperatorID, &record.APIPath, &record.IdempotencyKey, + &record.RequestID, &record.PayloadHash, &record.ResponseCode, &record.ResponseBody, + &record.Status, &record.ExpiresAt, &record.CreatedAt, &record.UpdatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil // 不存在或已过期 + } + if err != nil { + return nil, fmt.Errorf("failed to get idempotency record: %w", err) + } + + return record, nil +} + +// Create 创建幂等记录 +func (r *IdempotencyRepository) Create(ctx context.Context, record *IdempotencyRecord) error { + query := ` + INSERT INTO supply_idempotency_records ( + tenant_id, operator_id, api_path, idempotency_key, + request_id, payload_hash, status, expires_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8 + ) + RETURNING id, created_at, updated_at + ` + + err := r.pool.QueryRow(ctx, query, + record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey, + record.RequestID, record.PayloadHash, record.Status, record.ExpiresAt, + ).Scan(&record.ID, &record.CreatedAt, &record.UpdatedAt) + + if err != nil { + return fmt.Errorf("failed to create idempotency record: %w", err) + } + return nil +} + +// UpdateSuccess 更新为成功状态 +func (r *IdempotencyRepository) UpdateSuccess(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error { + query := ` + UPDATE supply_idempotency_records SET + response_code = $1, + response_body = $2, + status = $3, + updated_at = $4 + WHERE id = $5 + ` + + _, err := r.pool.Exec(ctx, query, responseCode, responseBody, IdempotencyStatusSucceeded, time.Now(), id) + if err != nil { + return fmt.Errorf("failed to update idempotency record to success: %w", err) + } + return nil +} + +// UpdateFailed 更新为失败状态 +func (r *IdempotencyRepository) UpdateFailed(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error { + query := ` + UPDATE supply_idempotency_records SET + response_code = $1, + response_body = $2, + status = $3, + updated_at = $4 + WHERE id = $5 + ` + + _, err := r.pool.Exec(ctx, query, responseCode, responseBody, IdempotencyStatusFailed, time.Now(), id) + if err != nil { + return fmt.Errorf("failed to update idempotency record to failed: %w", err) + } + return nil +} + +// DeleteExpired 删除过期记录(定时清理) +func (r *IdempotencyRepository) DeleteExpired(ctx context.Context) (int64, error) { + query := `DELETE FROM supply_idempotency_records WHERE expires_at < $1` + + cmdTag, err := r.pool.Exec(ctx, query, time.Now()) + if err != nil { + return 0, fmt.Errorf("failed to delete expired idempotency records: %w", err) + } + + return cmdTag.RowsAffected(), nil +} + +// GetByRequestID 根据请求ID获取记录 +func (r *IdempotencyRepository) GetByRequestID(ctx context.Context, requestID string) (*IdempotencyRecord, error) { + query := ` + SELECT id, tenant_id, operator_id, api_path, idempotency_key, + request_id, payload_hash, response_code, response_body, + status, expires_at, created_at, updated_at + FROM supply_idempotency_records + WHERE request_id = $1 + ` + + record := &IdempotencyRecord{} + err := r.pool.QueryRow(ctx, query, requestID).Scan( + &record.ID, &record.TenantID, &record.OperatorID, &record.APIPath, &record.IdempotencyKey, + &record.RequestID, &record.PayloadHash, &record.ResponseCode, &record.ResponseBody, + &record.Status, &record.ExpiresAt, &record.CreatedAt, &record.UpdatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get idempotency record by request_id: %w", err) + } + + return record, nil +} + +// CheckExists 检查幂等记录是否存在(用于竞争条件检测) +func (r *IdempotencyRepository) CheckExists(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (bool, error) { + query := ` + SELECT EXISTS( + SELECT 1 FROM supply_idempotency_records + WHERE tenant_id = $1 AND operator_id = $2 AND api_path = $3 AND idempotency_key = $4 + AND expires_at > $5 + ) + ` + + var exists bool + err := r.pool.QueryRow(ctx, query, tenantID, operatorID, apiPath, idempotencyKey, time.Now()).Scan(&exists) + if err != nil { + return false, fmt.Errorf("failed to check idempotency record existence: %w", err) + } + + return exists, nil +} + +// AcquireLock 尝试获取幂等锁(用于创建记录) +func (r *IdempotencyRepository) AcquireLock(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string, ttl time.Duration) (*IdempotencyRecord, error) { + // 先尝试插入 + record := &IdempotencyRecord{ + TenantID: tenantID, + OperatorID: operatorID, + APIPath: apiPath, + IdempotencyKey: idempotencyKey, + RequestID: "", // 稍后填充 + PayloadHash: "", // 稍后填充 + Status: IdempotencyStatusProcessing, + ExpiresAt: time.Now().Add(ttl), + } + + query := ` + INSERT INTO supply_idempotency_records ( + tenant_id, operator_id, api_path, idempotency_key, + request_id, payload_hash, status, expires_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8 + ) + ON CONFLICT (tenant_id, operator_id, api_path, idempotency_key) + DO UPDATE SET + request_id = EXCLUDED.request_id, + payload_hash = EXCLUDED.payload_hash, + status = EXCLUDED.status, + expires_at = EXCLUDED.expires_at, + updated_at = now() + WHERE supply_idempotency_records.expires_at <= $8 + RETURNING id, created_at, updated_at, status + ` + + err := r.pool.QueryRow(ctx, query, + record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey, + record.RequestID, record.PayloadHash, record.Status, record.ExpiresAt, + ).Scan(&record.ID, &record.CreatedAt, &record.UpdatedAt, &record.Status) + + if err != nil { + // 可能是重复插入 + existing, getErr := r.GetByKey(ctx, tenantID, operatorID, apiPath, idempotencyKey) + if getErr != nil { + return nil, fmt.Errorf("failed to acquire idempotency lock: %w (get err: %v)", err, getErr) + } + if existing != nil { + return existing, nil // 返回已存在的记录 + } + return nil, fmt.Errorf("failed to acquire idempotency lock: %w", err) + } + + return record, nil +} diff --git a/supply-api/internal/repository/package.go b/supply-api/internal/repository/package.go new file mode 100644 index 0000000..07fa079 --- /dev/null +++ b/supply-api/internal/repository/package.go @@ -0,0 +1,250 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "net/netip" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "lijiaoqiao/supply-api/internal/domain" +) + +// PackageRepository 套餐仓储 +type PackageRepository struct { + pool *pgxpool.Pool +} + +// NewPackageRepository 创建套餐仓储 +func NewPackageRepository(pool *pgxpool.Pool) *PackageRepository { + return &PackageRepository{pool: pool} +} + +// Create 创建套餐 +func (r *PackageRepository) Create(ctx context.Context, pkg *domain.Package, requestID, traceID string) error { + query := ` + INSERT INTO supply_packages ( + supply_account_id, user_id, platform, model, + total_quota, available_quota, sold_quota, reserved_quota, + price_per_1m_input, price_per_1m_output, min_purchase, + start_at, end_at, valid_days, + status, max_concurrent, rate_limit_rpm, + total_orders, total_revenue, rating, rating_count, + quota_unit, price_unit, currency_code, version, + created_ip, updated_ip, audit_trace_id, + request_id + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, + $22, $23, $24, $25, $26, $27, $28 + ) + RETURNING id, created_at, updated_at + ` + + var startAt, endAt *time.Time + if !pkg.StartAt.IsZero() { + startAt = &pkg.StartAt + } + if !pkg.EndAt.IsZero() { + endAt = &pkg.EndAt + } + + err := r.pool.QueryRow(ctx, query, + pkg.SupplierID, pkg.SupplierID, pkg.Platform, pkg.Model, + pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota, + pkg.PricePer1MInput, pkg.PricePer1MOutput, pkg.MinPurchase, + startAt, endAt, pkg.ValidDays, + pkg.Status, pkg.MaxConcurrent, pkg.RateLimitRPM, + pkg.TotalOrders, pkg.TotalRevenue, pkg.Rating, pkg.RatingCount, + "token", "per_1m_tokens", "USD", 0, + nil, nil, traceID, + requestID, + ).Scan(&pkg.ID, &pkg.CreatedAt, &pkg.UpdatedAt) + + if err != nil { + return fmt.Errorf("failed to create package: %w", err) + } + return nil +} + +// GetByID 获取套餐 +func (r *PackageRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) { + query := ` + SELECT id, supply_account_id, user_id, platform, model, + total_quota, available_quota, sold_quota, reserved_quota, + price_per_1m_input, price_per_1m_output, min_purchase, + start_at, end_at, valid_days, + status, max_concurrent, rate_limit_rpm, + total_orders, total_revenue, rating, rating_count, + quota_unit, price_unit, currency_code, version, + created_at, updated_at + FROM supply_packages + WHERE id = $1 AND user_id = $2 + ` + + pkg := &domain.Package{} + var startAt, endAt pgx.NullTime + err := r.pool.QueryRow(ctx, query, id, supplierID).Scan( + &pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model, + &pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota, + &pkg.PricePer1MInput, &pkg.PricePer1MOutput, &pkg.MinPurchase, + &startAt, &endAt, &pkg.ValidDays, + &pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM, + &pkg.TotalOrders, &pkg.TotalRevenue, &pkg.Rating, &pkg.RatingCount, + &pkg.QuotaUnit, &pkg.PriceUnit, &pkg.CurrencyCode, &pkg.Version, + &pkg.CreatedAt, &pkg.UpdatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("failed to get package: %w", err) + } + + if startAt.Valid { + pkg.StartAt = startAt.Time + } + if endAt.Valid { + pkg.EndAt = endAt.Time + } + + return pkg, nil +} + +// Update 更新套餐(乐观锁) +func (r *PackageRepository) Update(ctx context.Context, pkg *domain.Package, expectedVersion int) error { + query := ` + UPDATE supply_packages SET + platform = $1, model = $2, + total_quota = $3, available_quota = $4, sold_quota = $5, reserved_quota = $6, + price_per_1m_input = $7, price_per_1m_output = $8, + start_at = $9, end_at = $10, valid_days = $11, + status = $12, max_concurrent = $13, rate_limit_rpm = $14, + total_orders = $15, total_revenue = $16, + rating = $17, rating_count = $18, + version = $19, updated_at = $20 + WHERE id = $21 AND user_id = $22 AND version = $23 + ` + + pkg.UpdatedAt = time.Now() + newVersion := expectedVersion + 1 + + cmdTag, err := r.pool.Exec(ctx, query, + pkg.Platform, pkg.Model, + pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota, + pkg.PricePer1MInput, pkg.PricePer1MOutput, + pkg.StartAt, pkg.EndAt, pkg.ValidDays, + pkg.Status, pkg.MaxConcurrent, pkg.RateLimitRPM, + pkg.TotalOrders, pkg.TotalRevenue, + pkg.Rating, pkg.RatingCount, + newVersion, pkg.UpdatedAt, + pkg.ID, pkg.SupplierID, expectedVersion, + ) + + if err != nil { + return fmt.Errorf("failed to update package: %w", err) + } + + if cmdTag.RowsAffected() == 0 { + return ErrConcurrencyConflict + } + + pkg.Version = newVersion + return nil +} + +// GetForUpdate 获取套餐并加行锁 +func (r *PackageRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Package, error) { + query := ` + SELECT id, supply_account_id, user_id, platform, model, + total_quota, available_quota, sold_quota, reserved_quota, + price_per_1m_input, price_per_1m_output, + status, version, + created_at, updated_at + FROM supply_packages + WHERE id = $1 AND user_id = $2 + FOR UPDATE + ` + + pkg := &domain.Package{} + err := tx.QueryRow(ctx, query, id, supplierID).Scan( + &pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model, + &pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota, + &pkg.PricePer1MInput, &pkg.PricePer1MOutput, + &pkg.Status, &pkg.Version, + &pkg.CreatedAt, &pkg.UpdatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("failed to get package for update: %w", err) + } + + return pkg, nil +} + +// List 列出套餐 +func (r *PackageRepository) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) { + query := ` + SELECT id, supply_account_id, user_id, platform, model, + total_quota, available_quota, sold_quota, + price_per_1m_input, price_per_1m_output, + status, max_concurrent, rate_limit_rpm, + valid_days, total_orders, total_revenue, + version, created_at, updated_at + FROM supply_packages + WHERE user_id = $1 + ORDER BY created_at DESC + ` + + rows, err := r.pool.Query(ctx, query, supplierID) + if err != nil { + return nil, fmt.Errorf("failed to list packages: %w", err) + } + defer rows.Close() + + var packages []*domain.Package + for rows.Next() { + pkg := &domain.Package{} + err := rows.Scan( + &pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model, + &pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, + &pkg.PricePer1MInput, &pkg.PricePer1MOutput, + &pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM, + &pkg.ValidDays, &pkg.TotalOrders, &pkg.TotalRevenue, + &pkg.Version, &pkg.CreatedAt, &pkg.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan package: %w", err) + } + packages = append(packages, pkg) + } + + return packages, nil +} + +// UpdateQuota 扣减配额 +func (r *PackageRepository) UpdateQuota(ctx context.Context, tx pgxpool.Tx, packageID, supplierID int64, usedQuota float64) error { + query := ` + UPDATE supply_packages SET + available_quota = available_quota - $1, + sold_quota = sold_quota + $1, + updated_at = $2 + WHERE id = $3 AND user_id = $4 AND available_quota >= $1 + RETURNING id + ` + + var id int64 + err := tx.QueryRow(ctx, query, usedQuota, time.Now(), packageID, supplierID).Scan(&id) + if errors.Is(err, pgx.ErrNoRows) { + return errors.New("insufficient quota or package not found") + } + if err != nil { + return fmt.Errorf("failed to update quota: %w", err) + } + return nil +} diff --git a/supply-api/internal/repository/settlement.go b/supply-api/internal/repository/settlement.go new file mode 100644 index 0000000..02b960a --- /dev/null +++ b/supply-api/internal/repository/settlement.go @@ -0,0 +1,243 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "lijiaoqiao/supply-api/internal/domain" +) + +// SettlementRepository 结算仓储 +type SettlementRepository struct { + pool *pgxpool.Pool +} + +// NewSettlementRepository 创建结算仓储 +func NewSettlementRepository(pool *pgxpool.Pool) *SettlementRepository { + return &SettlementRepository{pool: pool} +} + +// Create 创建结算单 +func (r *SettlementRepository) Create(ctx context.Context, s *domain.Settlement, requestID, idempotencyKey, traceID string) error { + query := ` + INSERT INTO supply_settlements ( + settlement_no, user_id, total_amount, fee_amount, net_amount, + status, payment_method, payment_account, + period_start, period_end, total_orders, total_usage_records, + currency_code, amount_unit, version, + request_id, idempotency_key, audit_trace_id + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 + ) + RETURNING id, created_at, updated_at + ` + + err := r.pool.QueryRow(ctx, query, + s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount, + s.Status, s.PaymentMethod, s.PaymentAccount, + s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords, + "USD", "minor", 0, + requestID, idempotencyKey, traceID, + ).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt) + + if err != nil { + return fmt.Errorf("failed to create settlement: %w", err) + } + return nil +} + +// GetByID 获取结算单 +func (r *SettlementRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) { + query := ` + SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount, + status, payment_method, payment_account, + period_start, period_end, total_orders, total_usage_records, + payment_transaction_id, paid_at, + version, created_at, updated_at + FROM supply_settlements + WHERE id = $1 AND user_id = $2 + ` + + s := &domain.Settlement{} + var paidAt pgx.NullTime + err := r.pool.QueryRow(ctx, query, id, supplierID).Scan( + &s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount, + &s.Status, &s.PaymentMethod, &s.PaymentAccount, + &s.PeriodStart, &s.PeriodEnd, &s.TotalOrders, &s.TotalUsageRecords, + &s.PaymentTransactionID, &paidAt, + &s.Version, &s.CreatedAt, &s.UpdatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("failed to get settlement: %w", err) + } + + if paidAt.Valid { + s.PaidAt = &paidAt.Time + } + + return s, nil +} + +// Update 更新结算单(乐观锁) +func (r *SettlementRepository) Update(ctx context.Context, s *domain.Settlement, expectedVersion int) error { + query := ` + UPDATE supply_settlements SET + status = $1, payment_method = $2, payment_account = $3, + payment_transaction_id = $4, paid_at = $5, + total_orders = $6, total_usage_records = $7, + version = $8, updated_at = $9 + WHERE id = $10 AND user_id = $11 AND version = $12 + ` + + s.UpdatedAt = time.Now() + newVersion := expectedVersion + 1 + + cmdTag, err := r.pool.Exec(ctx, query, + s.Status, s.PaymentMethod, s.PaymentAccount, + s.PaymentTransactionID, s.PaidAt, + s.TotalOrders, s.TotalUsageRecords, + newVersion, s.UpdatedAt, + s.ID, s.SupplierID, expectedVersion, + ) + + if err != nil { + return fmt.Errorf("failed to update settlement: %w", err) + } + + if cmdTag.RowsAffected() == 0 { + return ErrConcurrencyConflict + } + + s.Version = newVersion + return nil +} + +// GetForUpdate 获取结算单并加行锁 +func (r *SettlementRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Settlement, error) { + query := ` + SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount, + status, payment_method, payment_account, version, + created_at, updated_at + FROM supply_settlements + WHERE id = $1 AND user_id = $2 + FOR UPDATE + ` + + s := &domain.Settlement{} + err := tx.QueryRow(ctx, query, id, supplierID).Scan( + &s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount, + &s.Status, &s.PaymentMethod, &s.PaymentAccount, &s.Version, + &s.CreatedAt, &s.UpdatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("failed to get settlement for update: %w", err) + } + + return s, nil +} + +// GetProcessing 获取处理中的结算单(用于单一性约束) +func (r *SettlementRepository) GetProcessing(ctx context.Context, tx pgxpool.Tx, supplierID int64) (*domain.Settlement, error) { + query := ` + SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount, + status, payment_method, payment_account, version, + created_at, updated_at + FROM supply_settlements + WHERE user_id = $1 AND status = 'processing' + FOR UPDATE SKIP LOCKED + LIMIT 1 + ` + + s := &domain.Settlement{} + err := tx.QueryRow(ctx, query, supplierID).Scan( + &s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount, + &s.Status, &s.PaymentMethod, &s.PaymentAccount, &s.Version, + &s.CreatedAt, &s.UpdatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil // 没有处理中的单据 + } + if err != nil { + return nil, fmt.Errorf("failed to get processing settlement: %w", err) + } + + return s, nil +} + +// List 列出结算单 +func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) { + query := ` + SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount, + status, payment_method, + period_start, period_end, total_orders, + version, created_at, updated_at + FROM supply_settlements + WHERE user_id = $1 + ORDER BY created_at DESC + ` + + rows, err := r.pool.Query(ctx, query, supplierID) + if err != nil { + return nil, fmt.Errorf("failed to list settlements: %w", err) + } + defer rows.Close() + + var settlements []*domain.Settlement + for rows.Next() { + s := &domain.Settlement{} + err := rows.Scan( + &s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount, + &s.Status, &s.PaymentMethod, + &s.PeriodStart, &s.PeriodEnd, &s.TotalOrders, + &s.Version, &s.CreatedAt, &s.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan settlement: %w", err) + } + settlements = append(settlements, s) + } + + return settlements, nil +} + +// CreateInTx 在事务中创建结算单 +func (r *SettlementRepository) CreateInTx(ctx context.Context, tx pgxpool.Tx, s *domain.Settlement, requestID, idempotencyKey, traceID string) error { + query := ` + INSERT INTO supply_settlements ( + settlement_no, user_id, total_amount, fee_amount, net_amount, + status, payment_method, payment_account, + period_start, period_end, total_orders, total_usage_records, + currency_code, amount_unit, version, + request_id, idempotency_key, audit_trace_id + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 + ) + RETURNING id, created_at, updated_at + ` + + err := tx.QueryRow(ctx, query, + s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount, + s.Status, s.PaymentMethod, s.PaymentAccount, + s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords, + "USD", "minor", 0, + requestID, idempotencyKey, traceID, + ).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt) + + if err != nil { + return fmt.Errorf("failed to create settlement in tx: %w", err) + } + return nil +} diff --git a/supply-api/sql/postgresql/supply_idempotency_record_v1.sql b/supply-api/sql/postgresql/supply_idempotency_record_v1.sql new file mode 100644 index 0000000..f9011ff --- /dev/null +++ b/supply-api/sql/postgresql/supply_idempotency_record_v1.sql @@ -0,0 +1,47 @@ +-- Supply Idempotency Record Schema +-- Based on: XR-001 (supply_technical_design_enhanced_v1_2026-03-25.md) +-- Updated: 2026-03-27 + +BEGIN; + +CREATE TABLE IF NOT EXISTS supply_idempotency_records ( + id BIGSERIAL PRIMARY KEY, + tenant_id BIGINT NOT NULL, + operator_id BIGINT NOT NULL, + api_path VARCHAR(200) NOT NULL, + idempotency_key VARCHAR(128) NOT NULL, + request_id VARCHAR(64) NOT NULL, + payload_hash CHAR(64) NOT NULL, -- SHA256 of request body + response_code INT, + response_body JSONB, + status VARCHAR(20) NOT NULL DEFAULT 'processing' + CHECK (status IN ('processing', 'succeeded', 'failed')), + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE (tenant_id, operator_id, api_path, idempotency_key) +); + +-- 高频查询索引 +CREATE INDEX IF NOT EXISTS idx_idempotency_tenant_operator_path_key + ON supply_idempotency_records (tenant_id, operator_id, api_path, idempotency_key) + WHERE expires_at > CURRENT_TIMESTAMP; + +-- RequestID 反查索引 +CREATE INDEX IF NOT EXISTS idx_idempotency_request_id + ON supply_idempotency_records (request_id); + +-- 过期清理索引 +CREATE INDEX IF NOT EXISTS idx_idempotency_expires_at + ON supply_idempotency_records (expires_at) + WHERE status != 'processing'; + +-- 状态查询索引 +CREATE INDEX IF NOT EXISTS idx_idempotency_status_expires + ON supply_idempotency_records (status, expires_at); + +COMMENT ON TABLE supply_idempotency_records IS '幂等记录表 - XR-001'; +COMMENT ON COLUMN supply_idempotency_records.payload_hash IS '请求体SHA256摘要,用于检测异参重放'; +COMMENT ON COLUMN supply_idempotency_records.expires_at IS '过期时间,默认24小时,提现类72小时'; + +COMMIT;