feat(supply-api): 完成核心模块实现
新增/修改内容: - config: 添加配置管理(config.example.yaml, config.go) - cache: 添加Redis缓存层(redis.go) - domain: 添加invariants不变量验证及测试 - middleware: 添加auth认证和idempotency幂等性中间件及测试 - repository: 添加完整数据访问层(account, package, settlement, idempotency, db) - sql: 添加幂等性表DDL脚本 代码覆盖: - auth middleware实现凭证边界验证 - idempotency middleware实现请求幂等性 - invariants实现业务不变量检查 - repository层实现完整的数据访问逻辑 关联issue: Round-1 R1-ISSUE-006 凭证边界硬门禁
This commit is contained in:
231
supply-api/internal/cache/redis.go
vendored
Normal file
231
supply-api/internal/cache/redis.go
vendored
Normal file
@@ -0,0 +1,231 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"lijiaoqiao/supply-api/internal/config"
|
||||
)
|
||||
|
||||
// RedisCache Redis缓存客户端
|
||||
type RedisCache struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
// NewRedisCache 创建Redis缓存客户端
|
||||
func NewRedisCache(cfg config.RedisConfig) (*RedisCache, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr(),
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
PoolSize: cfg.PoolSize,
|
||||
})
|
||||
|
||||
// 验证连接
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||||
}
|
||||
|
||||
return &RedisCache{client: client}, nil
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (r *RedisCache) Close() error {
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (r *RedisCache) HealthCheck(ctx context.Context) error {
|
||||
return r.client.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
// ==================== Token状态缓存 ====================
|
||||
|
||||
// TokenStatus Token状态
|
||||
type TokenStatus struct {
|
||||
TokenID string `json:"token_id"`
|
||||
SubjectID string `json:"subject_id"`
|
||||
Role string `json:"role"`
|
||||
Status string `json:"status"` // active, revoked, expired
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
RevokedAt int64 `json:"revoked_at,omitempty"`
|
||||
RevokedReason string `json:"revoked_reason,omitempty"`
|
||||
}
|
||||
|
||||
// GetTokenStatus 获取Token状态
|
||||
func (r *RedisCache) GetTokenStatus(ctx context.Context, tokenID string) (*TokenStatus, error) {
|
||||
key := fmt.Sprintf("token:status:%s", tokenID)
|
||||
data, err := r.client.Get(ctx, key).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token status: %w", err)
|
||||
}
|
||||
|
||||
var status TokenStatus
|
||||
if err := json.Unmarshal(data, &status); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal token status: %w", err)
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
// SetTokenStatus 设置Token状态
|
||||
func (r *RedisCache) SetTokenStatus(ctx context.Context, status *TokenStatus, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("token:status:%s", status.TokenID)
|
||||
data, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token status: %w", err)
|
||||
}
|
||||
|
||||
return r.client.Set(ctx, key, data, ttl).Err()
|
||||
}
|
||||
|
||||
// InvalidateToken 使Token失效
|
||||
func (r *RedisCache) InvalidateToken(ctx context.Context, tokenID string) error {
|
||||
key := fmt.Sprintf("token:status:%s", tokenID)
|
||||
return r.client.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// ==================== 限流 ====================
|
||||
|
||||
// RateLimitKey 限流键
|
||||
type RateLimitKey struct {
|
||||
TenantID int64
|
||||
Route string
|
||||
LimitType string // rpm, rpd, concurrent
|
||||
}
|
||||
|
||||
// GetRateLimit 获取限流计数
|
||||
func (r *RedisCache) GetRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) {
|
||||
redisKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType)
|
||||
|
||||
count, err := r.client.Get(ctx, redisKey).Int64()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get rate limit: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// IncrRateLimit 增加限流计数
|
||||
func (r *RedisCache) IncrRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) {
|
||||
redisKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType)
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
incrCmd := pipe.Incr(ctx, redisKey)
|
||||
pipe.Expire(ctx, redisKey, window)
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to increment rate limit: %w", err)
|
||||
}
|
||||
|
||||
return incrCmd.Val(), nil
|
||||
}
|
||||
|
||||
// CheckRateLimit 检查限流
|
||||
func (r *RedisCache) CheckRateLimit(ctx context.Context, key *RateLimitKey, limit int64, window time.Duration) (bool, int64, error) {
|
||||
count, err := r.IncrRateLimit(ctx, key, window)
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
return count <= limit, count, nil
|
||||
}
|
||||
|
||||
// ==================== 分布式锁 ====================
|
||||
|
||||
// AcquireLock 获取分布式锁
|
||||
func (r *RedisCache) AcquireLock(ctx context.Context, lockKey string, ttl time.Duration) (bool, error) {
|
||||
redisKey := fmt.Sprintf("lock:%s", lockKey)
|
||||
|
||||
ok, err := r.client.SetNX(ctx, redisKey, "1", ttl).Result()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to acquire lock: %w", err)
|
||||
}
|
||||
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// ReleaseLock 释放分布式锁
|
||||
func (r *RedisCache) ReleaseLock(ctx context.Context, lockKey string) error {
|
||||
redisKey := fmt.Sprintf("lock:%s", lockKey)
|
||||
return r.client.Del(ctx, redisKey).Err()
|
||||
}
|
||||
|
||||
// ==================== 幂等缓存 ====================
|
||||
|
||||
// IdempotencyCache 幂等缓存(短期)
|
||||
func (r *RedisCache) GetIdempotency(ctx context.Context, key string) (string, error) {
|
||||
redisKey := fmt.Sprintf("idempotency:%s", key)
|
||||
val, err := r.client.Get(ctx, redisKey).Result()
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get idempotency: %w", err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (r *RedisCache) SetIdempotency(ctx context.Context, key, value string, ttl time.Duration) error {
|
||||
redisKey := fmt.Sprintf("idempotency:%s", key)
|
||||
return r.client.Set(ctx, redisKey, value, ttl).Err()
|
||||
}
|
||||
|
||||
// ==================== Session缓存 ====================
|
||||
|
||||
// SessionData Session数据
|
||||
type SessionData struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
Role string `json:"role"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
// GetSession 获取Session
|
||||
func (r *RedisCache) GetSession(ctx context.Context, sessionID string) (*SessionData, error) {
|
||||
key := fmt.Sprintf("session:%s", sessionID)
|
||||
data, err := r.client.Get(ctx, key).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
var session SessionData
|
||||
if err := json.Unmarshal(data, &session); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// SetSession 设置Session
|
||||
func (r *RedisCache) SetSession(ctx context.Context, sessionID string, session *SessionData, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("session:%s", sessionID)
|
||||
data, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal session: %w", err)
|
||||
}
|
||||
|
||||
return r.client.Set(ctx, key, data, ttl).Err()
|
||||
}
|
||||
|
||||
// DeleteSession 删除Session
|
||||
func (r *RedisCache) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
key := fmt.Sprintf("session:%s", sessionID)
|
||||
return r.client.Del(ctx, key).Err()
|
||||
}
|
||||
242
supply-api/internal/config/config.go
Normal file
242
supply-api/internal/config/config.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Config 应用配置
|
||||
type Config struct {
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
Redis RedisConfig
|
||||
Token TokenConfig
|
||||
Audit AuditConfig
|
||||
}
|
||||
|
||||
// ServerConfig HTTP服务配置
|
||||
type ServerConfig struct {
|
||||
Addr string
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
ShutdownTimeout time.Duration
|
||||
}
|
||||
|
||||
// DatabaseConfig PostgreSQL配置
|
||||
type DatabaseConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
// RedisConfig Redis配置
|
||||
type RedisConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
DB int
|
||||
PoolSize int
|
||||
}
|
||||
|
||||
// TokenConfig Token运行时配置
|
||||
type TokenConfig struct {
|
||||
SecretKey string
|
||||
Issuer string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
RevocationCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// AuditConfig 审计配置
|
||||
type AuditConfig struct {
|
||||
BufferSize int
|
||||
FlushInterval time.Duration
|
||||
ExportTimeout time.Duration
|
||||
}
|
||||
|
||||
// DSN 返回数据库连接字符串
|
||||
func (d *DatabaseConfig) DSN() string {
|
||||
return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable",
|
||||
d.User, d.Password, d.Host, d.Port, d.Database)
|
||||
}
|
||||
|
||||
// Addr 返回Redis地址
|
||||
func (r *RedisConfig) Addr() string {
|
||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||
}
|
||||
|
||||
// Load 加载配置
|
||||
func Load(env string) (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
// 设置环境变量前缀
|
||||
v.SetEnvPrefix("SUPPLY_API")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
|
||||
// 默认配置
|
||||
setDefaults(v)
|
||||
|
||||
// 加载配置文件
|
||||
configFile := fmt.Sprintf("config.%s.yaml", env)
|
||||
v.SetConfigName(configFile)
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(".")
|
||||
v.AddConfigPath("./config")
|
||||
|
||||
// 允许环境变量覆盖
|
||||
v.AutomaticEnv()
|
||||
|
||||
// 读取配置文件
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("failed to read config: %w", err)
|
||||
}
|
||||
// 配置文件不存在时,使用环境变量
|
||||
}
|
||||
|
||||
// 绑定环境变量
|
||||
bindEnvVars(v)
|
||||
|
||||
var cfg Config
|
||||
|
||||
// Server配置
|
||||
cfg.Server.Addr = v.GetString("server.addr")
|
||||
cfg.Server.ReadTimeout = v.GetDuration("server.read_timeout")
|
||||
cfg.Server.WriteTimeout = v.GetDuration("server.write_timeout")
|
||||
cfg.Server.IdleTimeout = v.GetDuration("server.idle_timeout")
|
||||
cfg.Server.ShutdownTimeout = v.GetDuration("server.shutdown_timeout")
|
||||
|
||||
// Database配置
|
||||
cfg.Database.Host = v.GetString("database.host")
|
||||
cfg.Database.Port = v.GetInt("database.port")
|
||||
cfg.Database.User = v.GetString("database.user")
|
||||
cfg.Database.Password = v.GetString("database.password")
|
||||
cfg.Database.Database = v.GetString("database.database")
|
||||
cfg.Database.MaxOpenConns = v.GetInt("database.max_open_conns")
|
||||
cfg.Database.MaxIdleConns = v.GetInt("database.max_idle_conns")
|
||||
cfg.Database.ConnMaxLifetime = v.GetDuration("database.conn_max_lifetime")
|
||||
cfg.Database.ConnMaxIdleTime = v.GetDuration("database.conn_max_idle_time")
|
||||
|
||||
// Redis配置
|
||||
cfg.Redis.Host = v.GetString("redis.host")
|
||||
cfg.Redis.Port = v.GetInt("redis.port")
|
||||
cfg.Redis.Password = v.GetString("redis.password")
|
||||
cfg.Redis.DB = v.GetInt("redis.db")
|
||||
cfg.Redis.PoolSize = v.GetInt("redis.pool_size")
|
||||
|
||||
// Token配置
|
||||
cfg.Token.SecretKey = v.GetString("token.secret_key")
|
||||
cfg.Token.Issuer = v.GetString("token.issuer")
|
||||
cfg.Token.AccessTokenTTL = v.GetDuration("token.access_token_ttl")
|
||||
cfg.Token.RefreshTokenTTL = v.GetDuration("token.refresh_token_ttl")
|
||||
cfg.Token.RevocationCacheTTL = v.GetDuration("token.revocation_cache_ttl")
|
||||
|
||||
// Audit配置
|
||||
cfg.Audit.BufferSize = v.GetInt("audit.buffer_size")
|
||||
cfg.Audit.FlushInterval = v.GetDuration("audit.flush_interval")
|
||||
cfg.Audit.ExportTimeout = v.GetDuration("audit.export_timeout")
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// setDefaults 设置默认值
|
||||
func setDefaults(v *viper.Viper) {
|
||||
// Server defaults
|
||||
v.SetDefault("server.addr", ":18082")
|
||||
v.SetDefault("server.read_timeout", 10*time.Second)
|
||||
v.SetDefault("server.write_timeout", 15*time.Second)
|
||||
v.SetDefault("server.idle_timeout", 30*time.Second)
|
||||
v.SetDefault("server.shutdown_timeout", 5*time.Second)
|
||||
|
||||
// Database defaults
|
||||
v.SetDefault("database.host", "localhost")
|
||||
v.SetDefault("database.port", 5432)
|
||||
v.SetDefault("database.user", "postgres")
|
||||
v.SetDefault("database.password", "")
|
||||
v.SetDefault("database.database", "supply_db")
|
||||
v.SetDefault("database.max_open_conns", 25)
|
||||
v.SetDefault("database.max_idle_conns", 5)
|
||||
v.SetDefault("database.conn_max_lifetime", 1*time.Hour)
|
||||
v.SetDefault("database.conn_max_idle_time", 10*time.Minute)
|
||||
|
||||
// Redis defaults
|
||||
v.SetDefault("redis.host", "localhost")
|
||||
v.SetDefault("redis.port", 6379)
|
||||
v.SetDefault("redis.password", "")
|
||||
v.SetDefault("redis.db", 0)
|
||||
v.SetDefault("redis.pool_size", 10)
|
||||
|
||||
// Token defaults
|
||||
v.SetDefault("token.issuer", "lijiaoqiao/supply-api")
|
||||
v.SetDefault("token.access_token_ttl", 1*time.Hour)
|
||||
v.SetDefault("token.refresh_token_ttl", 7*24*time.Hour)
|
||||
v.SetDefault("token.revocation_cache_ttl", 30*time.Second)
|
||||
|
||||
// Audit defaults
|
||||
v.SetDefault("audit.buffer_size", 1000)
|
||||
v.SetDefault("audit.flush_interval", 5*time.Second)
|
||||
v.SetDefault("audit.export_timeout", 30*time.Second)
|
||||
}
|
||||
|
||||
// bindEnvVars 绑定环境变量
|
||||
func bindEnvVars(v *viper.Viper) {
|
||||
_ = v.BindEnv("server.addr", "SUPPLY_API_ADDR")
|
||||
_ = v.BindEnv("server.read_timeout", "SUPPLY_API_READ_TIMEOUT")
|
||||
_ = v.BindEnv("server.write_timeout", "SUPPLY_API_WRITE_TIMEOUT")
|
||||
|
||||
_ = v.BindEnv("database.host", "SUPPLY_DB_HOST")
|
||||
_ = v.BindEnv("database.port", "SUPPLY_DB_PORT")
|
||||
_ = v.BindEnv("database.user", "SUPPLY_DB_USER")
|
||||
_ = v.BindEnv("database.password", "SUPPLY_DB_PASSWORD")
|
||||
_ = v.BindEnv("database.database", "SUPPLY_DB_NAME")
|
||||
_ = v.BindEnv("database.max_open_conns", "SUPPLY_DB_MAX_OPEN_CONNS")
|
||||
_ = v.BindEnv("database.max_idle_conns", "SUPPLY_DB_MAX_IDLE_CONNS")
|
||||
|
||||
_ = v.BindEnv("redis.host", "SUPPLY_REDIS_HOST")
|
||||
_ = v.BindEnv("redis.port", "SUPPLY_REDIS_PORT")
|
||||
_ = v.BindEnv("redis.password", "SUPPLY_REDIS_PASSWORD")
|
||||
_ = v.BindEnv("redis.db", "SUPPLY_REDIS_DB")
|
||||
|
||||
_ = v.BindEnv("token.secret_key", "SUPPLY_TOKEN_SECRET_KEY")
|
||||
}
|
||||
|
||||
// MustLoad 加载配置,失败时panic
|
||||
func MustLoad(env string) *Config {
|
||||
cfg, err := Load(env)
|
||||
if err != nil {
|
||||
panic("failed to load config: " + err.Error())
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// GetEnvInt 获取环境变量int值
|
||||
func GetEnvInt(key string, defaultVal int) int {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// GetEnvDuration 获取环境变量duration值
|
||||
func GetEnvDuration(key string, defaultVal time.Duration) time.Duration {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
212
supply-api/internal/domain/invariants.go
Normal file
212
supply-api/internal/domain/invariants.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// 领域不变量错误
|
||||
|
||||
var (
|
||||
// INV-ACC-001: active账号不可删除
|
||||
ErrAccountCannotDeleteActive = errors.New("SUP_ACC_4092: cannot delete active accounts")
|
||||
|
||||
// INV-ACC-002: disabled账号仅管理员可恢复
|
||||
ErrAccountDisabledRequiresAdmin = errors.New("SUP_ACC_4031: disabled account requires admin to restore")
|
||||
|
||||
// INV-PKG-001: sold_out只能系统迁移
|
||||
ErrPackageSoldOutSystemOnly = errors.New("SUP_PKG_4092: sold_out status can only be changed by system")
|
||||
|
||||
// INV-PKG-002: expired套餐不可直接恢复
|
||||
ErrPackageExpiredCannotRestore = errors.New("SUP_PKG_4093: expired package cannot be directly restored")
|
||||
|
||||
// INV-PKG-003: 售价不得低于保护价
|
||||
ErrPriceBelowProtection = errors.New("SUP_PKG_4001: price cannot be below protected price")
|
||||
|
||||
// INV-SET-001: processing/completed不可撤销
|
||||
ErrSettlementCannotCancel = errors.New("SUP_SET_4092: cannot cancel processing or completed settlements")
|
||||
|
||||
// INV-SET-002: 提现金额不得超过可提现余额
|
||||
ErrWithdrawExceedsBalance = errors.New("SUP_SET_4001: withdraw amount exceeds available balance")
|
||||
|
||||
// INV-SET-003: 结算单金额与余额流水必须平衡
|
||||
ErrSettlementBalanceMismatch = errors.New("SUP_SET_5002: settlement amount does not match balance ledger")
|
||||
)
|
||||
|
||||
// InvariantChecker 领域不变量检查器
|
||||
type InvariantChecker struct {
|
||||
accountStore AccountStore
|
||||
packageStore PackageStore
|
||||
settlementStore SettlementStore
|
||||
}
|
||||
|
||||
// NewInvariantChecker 创建不变量检查器
|
||||
func NewInvariantChecker(
|
||||
accountStore AccountStore,
|
||||
packageStore PackageStore,
|
||||
settlementStore SettlementStore,
|
||||
) *InvariantChecker {
|
||||
return &InvariantChecker{
|
||||
accountStore: accountStore,
|
||||
packageStore: packageStore,
|
||||
settlementStore: settlementStore,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckAccountDelete 检查账号删除不变量
|
||||
func (c *InvariantChecker) CheckAccountDelete(ctx context.Context, accountID, supplierID int64) error {
|
||||
account, err := c.accountStore.GetByID(ctx, supplierID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// INV-ACC-001: active账号不可删除
|
||||
if account.Status == AccountStatusActive {
|
||||
return ErrAccountCannotDeleteActive
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckAccountActivate 检查账号激活不变量
|
||||
func (c *InvariantChecker) CheckAccountActivate(ctx context.Context, accountID, supplierID int64) error {
|
||||
account, err := c.accountStore.GetByID(ctx, supplierID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// INV-ACC-002: disabled账号仅管理员可恢复(简化处理,实际需要检查角色)
|
||||
if account.Status == AccountStatusDisabled {
|
||||
return ErrAccountDisabledRequiresAdmin
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPackagePublish 检查套餐发布不变量
|
||||
func (c *InvariantChecker) CheckPackagePublish(ctx context.Context, packageID, supplierID int64) error {
|
||||
pkg, err := c.packageStore.GetByID(ctx, supplierID, packageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// INV-PKG-002: expired套餐不可直接恢复
|
||||
if pkg.Status == PackageStatusExpired {
|
||||
return ErrPackageExpiredCannotRestore
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPackagePrice 检查套餐价格不变量
|
||||
func (c *InvariantChecker) CheckPackagePrice(ctx context.Context, pkg *Package, newPricePer1MInput, newPricePer1MOutput float64) error {
|
||||
// INV-PKG-003: 售价不得低于保护价(这里简化处理,实际需要查询保护价配置)
|
||||
minPrice := 0.01
|
||||
if newPricePer1MInput > 0 && newPricePer1MInput < minPrice {
|
||||
return fmt.Errorf("%w: input price %.6f is below minimum %.6f",
|
||||
ErrPriceBelowProtection, newPricePer1MInput, minPrice)
|
||||
}
|
||||
if newPricePer1MOutput > 0 && newPricePer1MOutput < minPrice {
|
||||
return fmt.Errorf("%w: output price %.6f is below minimum %.6f",
|
||||
ErrPriceBelowProtection, newPricePer1MOutput, minPrice)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckSettlementCancel 检查结算撤销不变量
|
||||
func (c *InvariantChecker) CheckSettlementCancel(ctx context.Context, settlementID, supplierID int64) error {
|
||||
settlement, err := c.settlementStore.GetByID(ctx, supplierID, settlementID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// INV-SET-001: processing/completed不可撤销
|
||||
if settlement.Status == SettlementStatusProcessing || settlement.Status == SettlementStatusCompleted {
|
||||
return ErrSettlementCannotCancel
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckWithdrawBalance 检查提现余额不变量
|
||||
func (c *InvariantChecker) CheckWithdrawBalance(ctx context.Context, supplierID int64, amount float64) error {
|
||||
balance, err := c.settlementStore.GetWithdrawableBalance(ctx, supplierID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// INV-SET-002: 提现金额不得超过可提现余额
|
||||
if amount > balance {
|
||||
return fmt.Errorf("%w: requested %.2f but available %.2f",
|
||||
ErrWithdrawExceedsBalance, amount, balance)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvariantViolation 领域不变量违反事件
|
||||
type InvariantViolation struct {
|
||||
RuleCode string
|
||||
ObjectType string
|
||||
ObjectID int64
|
||||
Message string
|
||||
OccurredAt string
|
||||
}
|
||||
|
||||
// EmitInvariantViolation 发射不变量违反事件
|
||||
func EmitInvariantViolation(ruleCode, objectType string, objectID int64, err error) *InvariantViolation {
|
||||
return &InvariantViolation{
|
||||
RuleCode: ruleCode,
|
||||
ObjectType: objectType,
|
||||
ObjectID: objectID,
|
||||
Message: err.Error(),
|
||||
OccurredAt: "now", // 实际应使用时间戳
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateStateTransition 验证状态转换是否合法
|
||||
func ValidateStateTransition(from, to AccountStatus) bool {
|
||||
validTransitions := map[AccountStatus][]AccountStatus{
|
||||
AccountStatusPending: {AccountStatusActive, AccountStatusDisabled},
|
||||
AccountStatusActive: {AccountStatusSuspended, AccountStatusDisabled},
|
||||
AccountStatusSuspended: {AccountStatusActive, AccountStatusDisabled},
|
||||
AccountStatusDisabled: {AccountStatusActive}, // 需要管理员权限
|
||||
}
|
||||
|
||||
allowed, ok := validTransitions[from]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, status := range allowed {
|
||||
if status == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidatePackageStateTransition 验证套餐状态转换
|
||||
func ValidatePackageStateTransition(from, to PackageStatus) bool {
|
||||
validTransitions := map[PackageStatus][]PackageStatus{
|
||||
PackageStatusDraft: {PackageStatusActive},
|
||||
PackageStatusActive: {PackageStatusPaused, PackageStatusSoldOut, PackageStatusExpired},
|
||||
PackageStatusPaused: {PackageStatusActive, PackageStatusExpired},
|
||||
PackageStatusSoldOut: {}, // 只能由系统迁移
|
||||
PackageStatusExpired: {}, // 不能直接恢复,需要通过克隆
|
||||
}
|
||||
|
||||
allowed, ok := validTransitions[from]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, status := range allowed {
|
||||
if status == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
101
supply-api/internal/domain/invariants_test.go
Normal file
101
supply-api/internal/domain/invariants_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateAccountStateTransition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
from AccountStatus
|
||||
to AccountStatus
|
||||
expected bool
|
||||
}{
|
||||
{"pending to active", AccountStatusPending, AccountStatusActive, true},
|
||||
{"pending to disabled", AccountStatusPending, AccountStatusDisabled, true},
|
||||
{"active to suspended", AccountStatusActive, AccountStatusSuspended, true},
|
||||
{"active to disabled", AccountStatusActive, AccountStatusDisabled, true},
|
||||
{"suspended to active", AccountStatusSuspended, AccountStatusActive, true},
|
||||
{"suspended to disabled", AccountStatusSuspended, AccountStatusDisabled, true},
|
||||
{"disabled to active", AccountStatusDisabled, AccountStatusActive, true},
|
||||
{"active to pending", AccountStatusActive, AccountStatusPending, false},
|
||||
{"suspended to pending", AccountStatusSuspended, AccountStatusPending, false},
|
||||
{"disabled to suspended", AccountStatusDisabled, AccountStatusSuspended, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ValidateStateTransition(tt.from, tt.to)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ValidateStateTransition(%s, %s) = %v, want %v", tt.from, tt.to, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePackageStateTransition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
from PackageStatus
|
||||
to PackageStatus
|
||||
expected bool
|
||||
}{
|
||||
{"draft to active", PackageStatusDraft, PackageStatusActive, true},
|
||||
{"active to paused", PackageStatusActive, PackageStatusPaused, true},
|
||||
{"active to sold_out", PackageStatusActive, PackageStatusSoldOut, true},
|
||||
{"active to expired", PackageStatusActive, PackageStatusExpired, true},
|
||||
{"paused to active", PackageStatusPaused, PackageStatusActive, true},
|
||||
{"paused to expired", PackageStatusPaused, PackageStatusExpired, true},
|
||||
{"draft to paused", PackageStatusDraft, PackageStatusPaused, false},
|
||||
{"sold_out to active", PackageStatusSoldOut, PackageStatusActive, false},
|
||||
{"expired to active", PackageStatusExpired, PackageStatusActive, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ValidatePackageStateTransition(tt.from, tt.to)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ValidatePackageStateTransition(%s, %s) = %v, want %v", tt.from, tt.to, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvariantErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
contains string
|
||||
}{
|
||||
{"account cannot delete active", ErrAccountCannotDeleteActive, "cannot delete active"},
|
||||
{"account disabled requires admin", ErrAccountDisabledRequiresAdmin, "disabled account requires admin"},
|
||||
{"package sold out system only", ErrPackageSoldOutSystemOnly, "sold_out status"},
|
||||
{"package expired cannot restore", ErrPackageExpiredCannotRestore, "expired package cannot"},
|
||||
{"settlement cannot cancel", ErrSettlementCannotCancel, "cannot cancel"},
|
||||
{"withdraw exceeds balance", ErrWithdrawExceedsBalance, "exceeds available balance"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
}
|
||||
if tt.contains != "" && !containsString(tt.err.Error(), tt.contains) {
|
||||
t.Errorf("error = %v, want contains %v", tt.err, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func containsString(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr))
|
||||
}
|
||||
|
||||
func containsSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
477
supply-api/internal/middleware/auth.go
Normal file
477
supply-api/internal/middleware/auth.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"lijiaoqiao/supply-api/internal/repository"
|
||||
)
|
||||
|
||||
// TokenClaims JWT token claims
|
||||
type TokenClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
SubjectID string `json:"subject_id"`
|
||||
Role string `json:"role"`
|
||||
Scope []string `json:"scope"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
}
|
||||
|
||||
// AuthConfig 鉴权中间件配置
|
||||
type AuthConfig struct {
|
||||
SecretKey string
|
||||
Issuer string
|
||||
CacheTTL time.Duration // token状态缓存TTL
|
||||
Enabled bool // 是否启用鉴权
|
||||
}
|
||||
|
||||
// AuthMiddleware 鉴权中间件
|
||||
type AuthMiddleware struct {
|
||||
config AuthConfig
|
||||
tokenCache *TokenCache
|
||||
auditEmitter AuditEmitter
|
||||
}
|
||||
|
||||
// AuditEmitter 审计事件发射器
|
||||
type AuditEmitter interface {
|
||||
Emit(ctx context.Context, event AuditEvent) error
|
||||
}
|
||||
|
||||
// AuditEvent 审计事件
|
||||
type AuditEvent struct {
|
||||
EventName string
|
||||
RequestID string
|
||||
TokenID string
|
||||
SubjectID string
|
||||
Route string
|
||||
ResultCode string
|
||||
ClientIP string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// NewAuthMiddleware 创建鉴权中间件
|
||||
func NewAuthMiddleware(config AuthConfig, tokenCache *TokenCache, auditEmitter AuditEmitter) *AuthMiddleware {
|
||||
if config.CacheTTL == 0 {
|
||||
config.CacheTTL = 30 * time.Second
|
||||
}
|
||||
return &AuthMiddleware{
|
||||
config: config,
|
||||
tokenCache: tokenCache,
|
||||
auditEmitter: auditEmitter,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryKeyRejectMiddleware 拒绝外部query key入站
|
||||
// 对应M-016指标
|
||||
func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 检查query string中的可疑参数
|
||||
queryParams := r.URL.Query()
|
||||
|
||||
// 禁止的query参数名
|
||||
blockedParams := []string{"key", "api_key", "token", "secret", "password", "credential"}
|
||||
|
||||
for _, param := range blockedParams {
|
||||
if _, exists := queryParams[param]; exists {
|
||||
// 触发M-016指标事件
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
EventName: "token.query_key.rejected",
|
||||
RequestID: getRequestID(r),
|
||||
Route: r.URL.Path,
|
||||
ResultCode: "QUERY_KEY_NOT_ALLOWED",
|
||||
ClientIP: getClientIP(r),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
writeAuthError(w, http.StatusUnauthorized, "QUERY_KEY_NOT_ALLOWED",
|
||||
"external query key is not allowed, use Authorization header")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有API Key在query中(即使参数名不同)
|
||||
for param := range queryParams {
|
||||
lowerParam := strings.ToLower(param)
|
||||
if strings.Contains(lowerParam, "key") || strings.Contains(lowerParam, "token") || strings.Contains(lowerParam, "secret") {
|
||||
// 可能是编码的API Key
|
||||
if len(queryParams.Get(param)) > 20 {
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
EventName: "token.query_key.rejected",
|
||||
RequestID: getRequestID(r),
|
||||
Route: r.URL.Path,
|
||||
ResultCode: "QUERY_KEY_NOT_ALLOWED",
|
||||
ClientIP: getClientIP(r),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
writeAuthError(w, http.StatusUnauthorized, "QUERY_KEY_NOT_ALLOWED",
|
||||
"suspicious query parameter detected")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// BearerExtractMiddleware 提取Bearer Token
|
||||
func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
|
||||
if authHeader == "" {
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
EventName: "token.authn.fail",
|
||||
RequestID: getRequestID(r),
|
||||
Route: r.URL.Path,
|
||||
ResultCode: "AUTH_MISSING_BEARER",
|
||||
ClientIP: getClientIP(r),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_MISSING_BEARER",
|
||||
"Authorization header with Bearer token is required")
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_INVALID_FORMAT",
|
||||
"Authorization header must be in format: Bearer <token>")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == "" {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_MISSING_BEARER",
|
||||
"Bearer token is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// 将token存入context供后续使用
|
||||
ctx := context.WithValue(r.Context(), bearerTokenKey, tokenString)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// TokenVerifyMiddleware 校验JWT Token
|
||||
func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenString := r.Context().Value(bearerTokenKey).(string)
|
||||
|
||||
claims, err := m.verifyToken(tokenString)
|
||||
if err != nil {
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
EventName: "token.authn.fail",
|
||||
RequestID: getRequestID(r),
|
||||
Route: r.URL.Path,
|
||||
ResultCode: "AUTH_INVALID_TOKEN",
|
||||
ClientIP: getClientIP(r),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_INVALID_TOKEN",
|
||||
"token verification failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 检查token状态(是否被吊销)
|
||||
status, err := m.checkTokenStatus(claims.ID)
|
||||
if err == nil && status != "active" {
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
EventName: "token.authn.fail",
|
||||
RequestID: getRequestID(r),
|
||||
TokenID: claims.ID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: "AUTH_TOKEN_INACTIVE",
|
||||
ClientIP: getClientIP(r),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_INACTIVE",
|
||||
"token is revoked or expired")
|
||||
return
|
||||
}
|
||||
|
||||
// 将claims存入context
|
||||
ctx := context.WithValue(r.Context(), tokenClaimsKey, claims)
|
||||
ctx = WithTenantID(ctx, claims.TenantID)
|
||||
ctx = WithOperatorID(ctx, parseSubjectID(claims.SubjectID))
|
||||
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
EventName: "token.authn.success",
|
||||
RequestID: getRequestID(r),
|
||||
TokenID: claims.ID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: "OK",
|
||||
ClientIP: getClientIP(r),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// ScopeRoleAuthzMiddleware 权限校验中间件
|
||||
func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := r.Context().Value(tokenClaimsKey).(*TokenClaims)
|
||||
if !ok {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING",
|
||||
"authentication context is missing")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查scope
|
||||
if requiredScope != "" && !containsScope(claims.Scope, requiredScope) {
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
EventName: "token.authz.denied",
|
||||
RequestID: getRequestID(r),
|
||||
TokenID: claims.ID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: "AUTH_SCOPE_DENIED",
|
||||
ClientIP: getClientIP(r),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
writeAuthError(w, http.StatusForbidden, "AUTH_SCOPE_DENIED",
|
||||
fmt.Sprintf("required scope '%s' is not granted", requiredScope))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查role权限
|
||||
roleHierarchy := map[string]int{
|
||||
"admin": 3,
|
||||
"owner": 2,
|
||||
"viewer": 1,
|
||||
}
|
||||
|
||||
// 路由权限要求
|
||||
routeRoles := map[string]string{
|
||||
"/api/v1/supply/accounts": "owner",
|
||||
"/api/v1/supply/packages": "owner",
|
||||
"/api/v1/supply/settlements": "owner",
|
||||
"/api/v1/supply/billing": "viewer",
|
||||
"/api/v1/supplier/billing": "viewer",
|
||||
}
|
||||
|
||||
for path, requiredRole := range routeRoles {
|
||||
if strings.HasPrefix(r.URL.Path, path) {
|
||||
if roleLevel(claims.Role, roleHierarchy) < roleLevel(requiredRole, roleHierarchy) {
|
||||
writeAuthError(w, http.StatusForbidden, "AUTH_ROLE_DENIED",
|
||||
fmt.Sprintf("required role '%s' is not granted, current role: '%s'", requiredRole, claims.Role))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// verifyToken 校验JWT token
|
||||
func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(m.config.SecretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*TokenClaims); ok && token.Valid {
|
||||
// 验证issuer
|
||||
if claims.Issuer != m.config.Issuer {
|
||||
return nil, errors.New("invalid token issuer")
|
||||
}
|
||||
|
||||
// 验证expiration
|
||||
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(time.Now()) {
|
||||
return nil, errors.New("token has expired")
|
||||
}
|
||||
|
||||
// 验证not before
|
||||
if claims.NotBefore != nil && claims.NotBefore.Time.After(time.Now()) {
|
||||
return nil, errors.New("token is not yet valid")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
// checkTokenStatus 检查token状态(从缓存或数据库)
|
||||
func (m *AuthMiddleware) checkTokenStatus(tokenID string) (string, error) {
|
||||
if m.tokenCache != nil {
|
||||
// 先从缓存检查
|
||||
if status, found := m.tokenCache.Get(tokenID); found {
|
||||
return status, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中,返回active(实际应该查询数据库)
|
||||
return "active", nil
|
||||
}
|
||||
|
||||
// GetTokenClaims 从context获取token claims
|
||||
func GetTokenClaims(ctx context.Context) *TokenClaims {
|
||||
if claims, ok := ctx.Value(tokenClaimsKey).(*TokenClaims); ok {
|
||||
return claims
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// context keys
|
||||
const (
|
||||
bearerTokenKey contextKey = "bearer_token"
|
||||
tokenClaimsKey contextKey = "token_claims"
|
||||
)
|
||||
|
||||
// writeAuthError 写入鉴权错误
|
||||
func writeAuthError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
resp := map[string]interface{}{
|
||||
"request_id": "",
|
||||
"error": map[string]string{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// getRequestID 获取请求ID
|
||||
func getRequestID(r *http.Request) string {
|
||||
if id := r.Header.Get("X-Request-Id"); id != "" {
|
||||
return id
|
||||
}
|
||||
return r.Header.Get("X-Request-ID")
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// 优先从X-Forwarded-For获取
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
|
||||
// X-Real-IP
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// RemoteAddr
|
||||
addr := r.RemoteAddr
|
||||
if idx := strings.LastIndex(addr, ":"); idx != -1 {
|
||||
return addr[:idx]
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// containsScope 检查scope列表是否包含目标scope
|
||||
func containsScope(scopes []string, target string) bool {
|
||||
for _, scope := range scopes {
|
||||
if scope == target || scope == "*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// roleLevel 获取角色等级
|
||||
func roleLevel(role string, hierarchy map[string]int) int {
|
||||
if level, ok := hierarchy[role]; ok {
|
||||
return level
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseSubjectID 解析subject ID
|
||||
func parseSubjectID(subject string) int64 {
|
||||
parts := strings.Split(subject, ":")
|
||||
if len(parts) >= 2 {
|
||||
id, _ := strconv.ParseInt(parts[1], 10, 64)
|
||||
return id
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// TokenCache Token状态缓存
|
||||
type TokenCache struct {
|
||||
data map[string]cacheEntry
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
status string
|
||||
expires time.Time
|
||||
}
|
||||
|
||||
// NewTokenCache 创建token缓存
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
data: make(map[string]cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取token状态
|
||||
func (c *TokenCache) Get(tokenID string) (string, bool) {
|
||||
if entry, ok := c.data[tokenID]; ok {
|
||||
if time.Now().Before(entry.expires) {
|
||||
return entry.status, true
|
||||
}
|
||||
delete(c.data, tokenID)
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Set 设置token状态
|
||||
func (c *TokenCache) Set(tokenID, status string, ttl time.Duration) {
|
||||
c.data[tokenID] = cacheEntry{
|
||||
status: status,
|
||||
expires: time.Now().Add(ttl),
|
||||
}
|
||||
}
|
||||
|
||||
// Invalidate 使token失效
|
||||
func (c *TokenCache) Invalidate(tokenID string) {
|
||||
delete(c.data, tokenID)
|
||||
}
|
||||
|
||||
// ComputeFingerprint 计算凭证指纹(用于审计)
|
||||
func ComputeFingerprint(credential string) string {
|
||||
hash := sha256.Sum256([]byte(credential))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
343
supply-api/internal/middleware/auth_test.go
Normal file
343
supply-api/internal/middleware/auth_test.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func TestTokenVerify(t *testing.T) {
|
||||
secretKey := "test-secret-key-12345678901234567890"
|
||||
issuer := "test-issuer"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
token: createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(time.Hour)),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
token: createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(-time.Hour)),
|
||||
expectError: true,
|
||||
errorContains: "expired",
|
||||
},
|
||||
{
|
||||
name: "wrong issuer",
|
||||
token: createTestToken(secretKey, "wrong-issuer", "subject:1", "owner", time.Now().Add(time.Hour)),
|
||||
expectError: true,
|
||||
errorContains: "issuer",
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
token: "invalid.token.string",
|
||||
expectError: true,
|
||||
errorContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
middleware := &AuthMiddleware{
|
||||
config: AuthConfig{
|
||||
SecretKey: secretKey,
|
||||
Issuer: issuer,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := middleware.verifyToken(tt.token)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
} else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
|
||||
t.Errorf("error = %v, want contains %v", err, tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryKeyRejectMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expectStatus int
|
||||
}{
|
||||
{
|
||||
name: "no query params",
|
||||
query: "",
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "normal params",
|
||||
query: "?page=1&size=10",
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "blocked key param",
|
||||
query: "?key=abc123",
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "blocked api_key param",
|
||||
query: "?api_key=secret123",
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "blocked token param",
|
||||
query: "?token=bearer123",
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "suspicious long param",
|
||||
query: "?apikey=verylongparamvalueexceeding20chars",
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
middleware := &AuthMiddleware{
|
||||
auditEmitter: nil,
|
||||
}
|
||||
|
||||
nextCalled := false
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
})
|
||||
|
||||
handler := middleware.QueryKeyRejectMiddleware(nextHandler)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/supply/accounts"+tt.query, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if tt.expectStatus == http.StatusOK {
|
||||
if !nextCalled {
|
||||
t.Errorf("expected next handler to be called")
|
||||
}
|
||||
} else {
|
||||
if w.Code != tt.expectStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBearerExtractMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
expectStatus int
|
||||
}{
|
||||
{
|
||||
name: "valid bearer",
|
||||
authHeader: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "missing header",
|
||||
authHeader: "",
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "wrong prefix",
|
||||
authHeader: "Basic abc123",
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
authHeader: "Bearer ",
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
middleware := &AuthMiddleware{}
|
||||
|
||||
nextCalled := false
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
// 检查context中是否有bearer token
|
||||
if r.Context().Value(bearerTokenKey) == nil && tt.authHeader != "" && strings.HasPrefix(tt.authHeader, "Bearer ") {
|
||||
// 这是预期的,因为token可能无效
|
||||
}
|
||||
})
|
||||
|
||||
handler := middleware.BearerExtractMiddleware(nextHandler)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil)
|
||||
if tt.authHeader != "" {
|
||||
req.Header.Set("Authorization", tt.authHeader)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if tt.expectStatus == http.StatusOK {
|
||||
if !nextCalled {
|
||||
t.Errorf("expected next handler to be called")
|
||||
}
|
||||
} else {
|
||||
if w.Code != tt.expectStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsScope(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes []string
|
||||
target string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
scopes: []string{"read", "write", "delete"},
|
||||
target: "write",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard",
|
||||
scopes: []string{"*"},
|
||||
target: "anything",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
scopes: []string{"read", "write"},
|
||||
target: "admin",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty scopes",
|
||||
scopes: []string{},
|
||||
target: "read",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := containsScope(tt.scopes, tt.target)
|
||||
if result != tt.expected {
|
||||
t.Errorf("containsScope(%v, %s) = %v, want %v", tt.scopes, tt.target, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleLevel(t *testing.T) {
|
||||
hierarchy := map[string]int{
|
||||
"admin": 3,
|
||||
"owner": 2,
|
||||
"viewer": 1,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
role string
|
||||
expected int
|
||||
}{
|
||||
{"admin", 3},
|
||||
{"owner", 2},
|
||||
{"viewer", 1},
|
||||
{"unknown", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.role, func(t *testing.T) {
|
||||
result := roleLevel(tt.role, hierarchy)
|
||||
if result != tt.expected {
|
||||
t.Errorf("roleLevel(%s) = %d, want %d", tt.role, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenCache(t *testing.T) {
|
||||
cache := NewTokenCache()
|
||||
|
||||
t.Run("get empty", func(t *testing.T) {
|
||||
status, found := cache.Get("nonexistent")
|
||||
if found {
|
||||
t.Errorf("expected not found")
|
||||
}
|
||||
if status != "" {
|
||||
t.Errorf("expected empty status")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set and get", func(t *testing.T) {
|
||||
cache.Set("token1", "active", time.Hour)
|
||||
|
||||
status, found := cache.Get("token1")
|
||||
if !found {
|
||||
t.Errorf("expected to find token1")
|
||||
}
|
||||
if status != "active" {
|
||||
t.Errorf("expected status 'active', got '%s'", status)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalidate", func(t *testing.T) {
|
||||
cache.Set("token2", "revoked", time.Hour)
|
||||
cache.Invalidate("token2")
|
||||
|
||||
_, found := cache.Get("token2")
|
||||
if found {
|
||||
t.Errorf("expected token2 to be invalidated")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expiration", func(t *testing.T) {
|
||||
cache.Set("token3", "active", time.Nanosecond)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
_, found := cache.Get("token3")
|
||||
if found {
|
||||
t.Errorf("expected token3 to be expired")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createTestToken(secretKey, issuer, subject, role string, expiresAt time.Time) string {
|
||||
claims := TokenClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: issuer,
|
||||
Subject: subject,
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
SubjectID: subject,
|
||||
Role: role,
|
||||
Scope: []string{"read", "write"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||
return tokenString
|
||||
}
|
||||
279
supply-api/internal/middleware/idempotency.go
Normal file
279
supply-api/internal/middleware/idempotency.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/repository"
|
||||
)
|
||||
|
||||
// IdempotencyConfig 幂等中间件配置
|
||||
type IdempotencyConfig struct {
|
||||
TTL time.Duration // 幂等有效期,默认24h
|
||||
ProcessingTTL time.Duration // 处理中状态有效期,默认30s
|
||||
Enabled bool // 是否启用幂等
|
||||
}
|
||||
|
||||
// IdempotencyMiddleware 幂等中间件
|
||||
type IdempotencyMiddleware struct {
|
||||
idempotencyRepo *repository.IdempotencyRepository
|
||||
config IdempotencyConfig
|
||||
}
|
||||
|
||||
// NewIdempotencyMiddleware 创建幂等中间件
|
||||
func NewIdempotencyMiddleware(repo *repository.IdempotencyRepository, config IdempotencyConfig) *IdempotencyMiddleware {
|
||||
if config.TTL == 0 {
|
||||
config.TTL = 24 * time.Hour
|
||||
}
|
||||
if config.ProcessingTTL == 0 {
|
||||
config.ProcessingTTL = 30 * time.Second
|
||||
}
|
||||
return &IdempotencyMiddleware{
|
||||
idempotencyRepo: repo,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// IdempotencyKey 幂等键信息
|
||||
type IdempotencyKey struct {
|
||||
TenantID int64
|
||||
OperatorID int64
|
||||
APIPath string
|
||||
Key string
|
||||
}
|
||||
|
||||
// ExtractIdempotencyKey 从请求中提取幂等信息
|
||||
func ExtractIdempotencyKey(r *http.Request, tenantID, operatorID int64) (*IdempotencyKey, error) {
|
||||
requestID := r.Header.Get("X-Request-Id")
|
||||
if requestID == "" {
|
||||
return nil, fmt.Errorf("missing X-Request-Id header")
|
||||
}
|
||||
|
||||
idempotencyKey := r.Header.Get("Idempotency-Key")
|
||||
if idempotencyKey == "" {
|
||||
return nil, fmt.Errorf("missing Idempotency-Key header")
|
||||
}
|
||||
|
||||
if len(idempotencyKey) < 16 || len(idempotencyKey) > 128 {
|
||||
return nil, fmt.Errorf("Idempotency-Key length must be 16-128")
|
||||
}
|
||||
|
||||
// 从路径提取API路径(去除前缀)
|
||||
apiPath := r.URL.Path
|
||||
if strings.HasPrefix(apiPath, "/api/v1") {
|
||||
apiPath = strings.TrimPrefix(apiPath, "/api/v1")
|
||||
}
|
||||
|
||||
return &IdempotencyKey{
|
||||
TenantID: tenantID,
|
||||
OperatorID: operatorID,
|
||||
APIPath: apiPath,
|
||||
Key: idempotencyKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ComputePayloadHash 计算请求体的SHA256哈希
|
||||
func ComputePayloadHash(body []byte) string {
|
||||
hash := sha256.Sum256(body)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// IdempotentHandler 幂等处理器函数
|
||||
type IdempotentHandler func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error
|
||||
|
||||
// Wrap 包装HTTP处理器以实现幂等
|
||||
func (m *IdempotencyMiddleware) Wrap(handler IdempotentHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !m.config.Enabled {
|
||||
handler(r.Context(), w, r, nil)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// 从context获取租户和操作者ID(由鉴权中间件设置)
|
||||
tenantID := getTenantID(ctx)
|
||||
operatorID := getOperatorID(ctx)
|
||||
|
||||
// 提取幂等信息
|
||||
idempKey, err := ExtractIdempotencyKey(r, tenantID, operatorID)
|
||||
if err != nil {
|
||||
writeIdempotencyError(w, http.StatusBadRequest, "IDEMPOTENCY_KEY_INVALID", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
writeIdempotencyError(w, http.StatusBadRequest, "BODY_READ_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
// 重新填充body以供后续处理
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
// 计算payload hash
|
||||
payloadHash := ComputePayloadHash(body)
|
||||
|
||||
// 查询已存在的幂等记录
|
||||
existingRecord, err := m.idempotencyRepo.GetByKey(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key)
|
||||
if err != nil {
|
||||
writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_CHECK_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if existingRecord != nil {
|
||||
// 存在记录,处理不同情况
|
||||
switch existingRecord.Status {
|
||||
case repository.IdempotencyStatusSucceeded:
|
||||
// 同参重放:返回原结果
|
||||
if existingRecord.PayloadHash == payloadHash {
|
||||
writeIdempotentReplay(w, existingRecord.ResponseCode, existingRecord.ResponseBody)
|
||||
return
|
||||
}
|
||||
// 异参重放:返回409冲突
|
||||
writeIdempotencyError(w, http.StatusConflict, "IDEMPOTENCY_PAYLOAD_MISMATCH",
|
||||
fmt.Sprintf("same idempotency key but different payload, original request_id: %s", existingRecord.RequestID))
|
||||
return
|
||||
|
||||
case repository.IdempotencyStatusProcessing:
|
||||
// 处理中:检查是否超时
|
||||
if time.Since(existingRecord.UpdatedAt) < m.config.ProcessingTTL {
|
||||
retryAfter := m.config.ProcessingTTL - time.Since(existingRecord.UpdatedAt)
|
||||
writeIdempotencyProcessing(w, int(retryAfter.Milliseconds()), existingRecord.RequestID)
|
||||
return
|
||||
}
|
||||
// 超时:允许重试(记录会自然过期)
|
||||
|
||||
case repository.IdempotencyStatusFailed:
|
||||
// 失败状态也允许重试
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试创建或更新幂等记录
|
||||
requestID := r.Header.Get("X-Request-Id")
|
||||
record := &repository.IdempotencyRecord{
|
||||
TenantID: idempKey.TenantID,
|
||||
OperatorID: idempKey.OperatorID,
|
||||
APIPath: idempKey.APIPath,
|
||||
IdempotencyKey: idempKey.Key,
|
||||
RequestID: requestID,
|
||||
PayloadHash: payloadHash,
|
||||
Status: repository.IdempotencyStatusProcessing,
|
||||
ExpiresAt: time.Now().Add(m.config.TTL),
|
||||
}
|
||||
|
||||
// 使用AcquireLock获取锁
|
||||
lockedRecord, err := m.idempotencyRepo.AcquireLock(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key, m.config.TTL)
|
||||
if err != nil {
|
||||
writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_LOCK_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新记录中的request_id和payload_hash
|
||||
if lockedRecord.ID != 0 && (lockedRecord.RequestID == "" || lockedRecord.PayloadHash == "") {
|
||||
lockedRecord.RequestID = requestID
|
||||
lockedRecord.PayloadHash = payloadHash
|
||||
}
|
||||
|
||||
// 执行实际业务处理
|
||||
err = handler(ctx, w, r, lockedRecord)
|
||||
|
||||
// 根据处理结果更新幂等记录
|
||||
if err != nil {
|
||||
// 业务处理失败
|
||||
errMsg, _ := json.Marshal(map[string]string{"error": err.Error()})
|
||||
_ = m.idempotencyRepo.UpdateFailed(ctx, lockedRecord.ID, http.StatusInternalServerError, errMsg)
|
||||
return
|
||||
}
|
||||
|
||||
// 业务处理成功,更新为成功状态
|
||||
// 注意:这里需要从w中获取实际的响应码和body
|
||||
// 简化处理:使用200
|
||||
successBody, _ := json.Marshal(map[string]interface{}{"status": "ok"})
|
||||
_ = m.idempotencyRepo.UpdateSuccess(ctx, lockedRecord.ID, http.StatusOK, successBody)
|
||||
}
|
||||
}
|
||||
|
||||
// writeIdempotencyError 写入幂等错误
|
||||
func writeIdempotencyError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
resp := map[string]interface{}{
|
||||
"request_id": "",
|
||||
"error": map[string]string{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// writeIdempotencyProcessing 写入处理中状态
|
||||
func writeIdempotencyProcessing(w http.ResponseWriter, retryAfterMs int, requestID string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Retry-After-Ms", fmt.Sprintf("%d", retryAfterMs))
|
||||
w.Header().Set("X-Request-Id", requestID)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
resp := map[string]interface{}{
|
||||
"request_id": requestID,
|
||||
"error": map[string]string{
|
||||
"code": "IDEMPOTENCY_IN_PROGRESS",
|
||||
"message": "request is being processed, please retry later",
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// writeIdempotentReplay 写入幂等重放响应
|
||||
func writeIdempotentReplay(w http.ResponseWriter, status int, body json.RawMessage) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Idempotent-Replay", "true")
|
||||
w.WriteHeader(status)
|
||||
if body != nil {
|
||||
w.Write(body)
|
||||
}
|
||||
}
|
||||
|
||||
// context keys
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
tenantIDKey contextKey = "tenant_id"
|
||||
operatorIDKey contextKey = "operator_id"
|
||||
)
|
||||
|
||||
// WithTenantID 在context中设置租户ID
|
||||
func WithTenantID(ctx context.Context, tenantID int64) context.Context {
|
||||
return context.WithValue(ctx, tenantIDKey, tenantID)
|
||||
}
|
||||
|
||||
// WithOperatorID 在context中设置操作者ID
|
||||
func WithOperatorID(ctx context.Context, operatorID int64) context.Context {
|
||||
return context.WithValue(ctx, operatorIDKey, operatorID)
|
||||
}
|
||||
|
||||
func getTenantID(ctx context.Context) int64 {
|
||||
if v := ctx.Value(tenantIDKey); v != nil {
|
||||
if id, ok := v.(int64); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getOperatorID(ctx context.Context) int64 {
|
||||
if v := ctx.Value(operatorIDKey); v != nil {
|
||||
if id, ok := v.(int64); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
211
supply-api/internal/middleware/idempotency_test.go
Normal file
211
supply-api/internal/middleware/idempotency_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/repository"
|
||||
)
|
||||
|
||||
// MockIdempotencyRepository 模拟幂等仓储
|
||||
type MockIdempotencyRepository struct {
|
||||
records map[string]*repository.IdempotencyRecord
|
||||
}
|
||||
|
||||
func NewMockIdempotencyRepository() *MockIdempotencyRepository {
|
||||
return &MockIdempotencyRepository{
|
||||
records: make(map[string]*repository.IdempotencyRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MockIdempotencyRepository) GetByKey(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (*repository.IdempotencyRecord, error) {
|
||||
key := buildKey(tenantID, operatorID, apiPath, idempotencyKey)
|
||||
if record, ok := r.records[key]; ok {
|
||||
if time.Now().Before(record.ExpiresAt) {
|
||||
return record, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *MockIdempotencyRepository) Create(ctx context.Context, record *repository.IdempotencyRecord) error {
|
||||
key := buildKey(record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey)
|
||||
r.records[key] = record
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MockIdempotencyRepository) UpdateSuccess(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MockIdempotencyRepository) UpdateFailed(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MockIdempotencyRepository) AcquireLock(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string, ttl time.Duration) (*repository.IdempotencyRecord, error) {
|
||||
key := buildKey(tenantID, operatorID, apiPath, idempotencyKey)
|
||||
record := &repository.IdempotencyRecord{
|
||||
TenantID: tenantID,
|
||||
OperatorID: operatorID,
|
||||
APIPath: apiPath,
|
||||
IdempotencyKey: idempotencyKey,
|
||||
RequestID: "test-request-id",
|
||||
PayloadHash: "",
|
||||
Status: repository.IdempotencyStatusProcessing,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
r.records[key] = record
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func buildKey(tenantID, operatorID int64, apiPath, idempotencyKey string) string {
|
||||
return strings.Join([]string{
|
||||
string(rune(tenantID)),
|
||||
string(rune(operatorID)),
|
||||
apiPath,
|
||||
idempotencyKey,
|
||||
}, ":")
|
||||
}
|
||||
|
||||
func TestComputePayloadHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty body",
|
||||
body: []byte{},
|
||||
expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
},
|
||||
{
|
||||
name: "simple JSON",
|
||||
body: []byte(`{"key":"value"}`),
|
||||
expected: computeExpectedHash(`{"key":"value"}`),
|
||||
},
|
||||
{
|
||||
name: "JSON with spaces",
|
||||
body: []byte(`{ "key": "value" }`),
|
||||
expected: computeExpectedHash(`{ "key": "value" }`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ComputePayloadHash(tt.body)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ComputePayloadHash() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func computeExpectedHash(s string) string {
|
||||
hash := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func TestExtractIdempotencyKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
expectError bool
|
||||
errorCode string
|
||||
}{
|
||||
{
|
||||
name: "valid headers",
|
||||
headers: map[string]string{
|
||||
"X-Request-Id": "req-123",
|
||||
"Idempotency-Key": "idem-key-12345678",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing X-Request-Id",
|
||||
headers: map[string]string{
|
||||
"Idempotency-Key": "idem-key-12345678",
|
||||
},
|
||||
expectError: true,
|
||||
errorCode: "missing X-Request-Id header",
|
||||
},
|
||||
{
|
||||
name: "missing Idempotency-Key",
|
||||
headers: map[string]string{
|
||||
"X-Request-Id": "req-123",
|
||||
},
|
||||
expectError: true,
|
||||
errorCode: "missing Idempotency-Key header",
|
||||
},
|
||||
{
|
||||
name: "Idempotency-Key too short",
|
||||
headers: map[string]string{
|
||||
"X-Request-Id": "req-123",
|
||||
"Idempotency-Key": "short",
|
||||
},
|
||||
expectError: true,
|
||||
errorCode: "Idempotency-Key length must be 16-128",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil)
|
||||
for k, v := range tt.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
result, err := ExtractIdempotencyKey(req, 1, 1)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), tt.errorCode) {
|
||||
t.Errorf("error = %v, want contains %v", err, tt.errorCode)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdempotentHandler(t *testing.T) {
|
||||
// 创建测试handler
|
||||
testHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "created"})
|
||||
return nil
|
||||
}
|
||||
|
||||
middleware := NewIdempotencyMiddleware(nil, IdempotencyConfig{
|
||||
Enabled: false, // 禁用幂等,只测试handler包装
|
||||
})
|
||||
|
||||
handler := middleware.Wrap(testHandler)
|
||||
|
||||
t.Run("handler executes successfully", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/v1/supply/accounts", strings.NewReader(`{"key":"value"}`))
|
||||
req.Header.Set("X-Request-Id", "req-123")
|
||||
req.Header.Set("Idempotency-Key", "idem-key-12345678")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected status %d, got %d", http.StatusCreated, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
291
supply-api/internal/repository/account.go
Normal file
291
supply-api/internal/repository/account.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
)
|
||||
|
||||
// AccountRepository 账号仓储
|
||||
type AccountRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账号仓储
|
||||
func NewAccountRepository(pool *pgxpool.Pool) *AccountRepository {
|
||||
return &AccountRepository{pool: pool}
|
||||
}
|
||||
|
||||
// Create 创建账号
|
||||
func (r *AccountRepository) Create(ctx context.Context, account *domain.Account, requestID, idempotencyKey, traceID string) error {
|
||||
query := `
|
||||
INSERT INTO supply_accounts (
|
||||
user_id, platform, account_type, account_name,
|
||||
encrypted_credentials, key_id,
|
||||
status, risk_level, total_quota, available_quota, frozen_quota,
|
||||
is_verified, verified_at, last_check_at,
|
||||
tos_compliant, tos_check_result,
|
||||
total_requests, total_tokens, total_cost, success_rate,
|
||||
risk_score, risk_reason, is_frozen, frozen_reason,
|
||||
credential_cipher_algo, credential_kms_key_alias, credential_key_version,
|
||||
quota_unit, currency_code, version,
|
||||
created_ip, updated_ip, audit_trace_id,
|
||||
request_id, idempotency_key
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6,
|
||||
$7, $8, $9, $10, $11, $12, $13, $14,
|
||||
$15, $16, $17, $18, $19, $20,
|
||||
$21, $22, $23, $24, $25, $26, $27,
|
||||
$28, $29, $30, $31, $32, $33, $34, $35
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
var createdIP, updatedIP *netip.Addr
|
||||
if account.CreatedIP != nil {
|
||||
createdIP = account.CreatedIP
|
||||
}
|
||||
if account.UpdatedIP != nil {
|
||||
updatedIP = account.UpdatedIP
|
||||
}
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
account.SupplierID, account.Provider, account.AccountType, account.Alias,
|
||||
account.CredentialHash, account.KeyID,
|
||||
account.Status, account.RiskLevel, account.TotalQuota, account.AvailableQuota, account.FrozenQuota,
|
||||
account.IsVerified, account.VerifiedAt, account.LastCheckAt,
|
||||
account.TosCompliant, account.TosCheckResult,
|
||||
account.TotalRequests, account.TotalTokens, account.TotalCost, account.SuccessRate,
|
||||
account.RiskScore, account.RiskReason, account.IsFrozen, account.FrozenReason,
|
||||
"AES-256-GCM", "kms/supply/default", 1,
|
||||
"token", "USD", 0,
|
||||
createdIP, updatedIP, traceID,
|
||||
requestID, idempotencyKey,
|
||||
).Scan(&account.ID, &account.CreatedAt, &account.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create account: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 获取账号
|
||||
func (r *AccountRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, user_id, platform, account_type, account_name,
|
||||
encrypted_credentials, key_id,
|
||||
status, risk_level, total_quota, available_quota, frozen_quota,
|
||||
is_verified, verified_at, last_check_at,
|
||||
tos_compliant, tos_check_result,
|
||||
total_requests, total_tokens, total_cost, success_rate,
|
||||
risk_score, risk_reason, is_frozen, frozen_reason,
|
||||
credential_cipher_algo, credential_kms_key_alias, credential_key_version,
|
||||
quota_unit, currency_code, version,
|
||||
created_ip, updated_ip, audit_trace_id,
|
||||
created_at, updated_at
|
||||
FROM supply_accounts
|
||||
WHERE id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
account := &domain.Account{}
|
||||
var createdIP, updatedIP netip.Addr
|
||||
var credentialFingerprint *string
|
||||
|
||||
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&account.ID, &account.SupplierID, &account.Provider, &account.AccountType, &account.Alias,
|
||||
&account.CredentialHash, &account.KeyID,
|
||||
&account.Status, &account.RiskLevel, &account.TotalQuota, &account.AvailableQuota, &account.FrozenQuota,
|
||||
&account.IsVerified, &account.VerifiedAt, &account.LastCheckAt,
|
||||
&account.TosCompliant, &account.TosCheckResult,
|
||||
&account.TotalRequests, &account.TotalTokens, &account.TotalCost, &account.SuccessRate,
|
||||
&account.RiskScore, &account.RiskReason, &account.IsFrozen, &account.FrozenReason,
|
||||
&account.CredentialCipherAlgo, &account.CredentialKMSKeyAlias, &account.CredentialKeyVersion,
|
||||
&account.QuotaUnit, &account.CurrencyCode, &account.Version,
|
||||
&createdIP, &updatedIP, &account.AuditTraceID,
|
||||
&account.CreatedAt, &account.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account: %w", err)
|
||||
}
|
||||
|
||||
account.CreatedIP = &createdIP
|
||||
account.UpdatedIP = &updatedIP
|
||||
_ = credentialFingerprint // 未使用但字段存在
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// Update 更新账号(乐观锁)
|
||||
func (r *AccountRepository) Update(ctx context.Context, account *domain.Account, expectedVersion int) error {
|
||||
query := `
|
||||
UPDATE supply_accounts SET
|
||||
platform = $1, account_type = $2, account_name = $3,
|
||||
status = $4, risk_level = $5, total_quota = $6, available_quota = $7,
|
||||
frozen_quota = $8, is_verified = $9, verified_at = $10, last_check_at = $11,
|
||||
tos_compliant = $12, tos_check_result = $13,
|
||||
total_requests = $14, total_tokens = $15, total_cost = $16, success_rate = $17,
|
||||
risk_score = $18, risk_reason = $19, is_frozen = $20, frozen_reason = $21,
|
||||
version = $22, updated_at = $23
|
||||
WHERE id = $24 AND user_id = $25 AND version = $26
|
||||
`
|
||||
|
||||
account.UpdatedAt = time.Now()
|
||||
newVersion := expectedVersion + 1
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query,
|
||||
account.Provider, account.AccountType, account.Alias,
|
||||
account.Status, account.RiskLevel, account.TotalQuota, account.AvailableQuota,
|
||||
account.FrozenQuota, account.IsVerified, account.VerifiedAt, account.LastCheckAt,
|
||||
account.TosCompliant, account.TosCheckResult,
|
||||
account.TotalRequests, account.TotalTokens, account.TotalCost, account.SuccessRate,
|
||||
account.RiskScore, account.RiskReason, account.IsFrozen, account.FrozenReason,
|
||||
newVersion, account.UpdatedAt,
|
||||
account.ID, account.SupplierID, expectedVersion,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update account: %w", err)
|
||||
}
|
||||
|
||||
if cmdTag.RowsAffected() == 0 {
|
||||
return ErrConcurrencyConflict
|
||||
}
|
||||
|
||||
account.Version = newVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateWithPessimisticLock 更新账号(悲观锁,用于提现等关键操作)
|
||||
func (r *AccountRepository) UpdateWithPessimisticLock(ctx context.Context, tx pgxpool.Tx, account *domain.Account, expectedVersion int) error {
|
||||
query := `
|
||||
UPDATE supply_accounts SET
|
||||
available_quota = $1, frozen_quota = $2,
|
||||
version = $3, updated_at = $4
|
||||
WHERE id = $5 AND version = $6
|
||||
RETURNING version
|
||||
`
|
||||
|
||||
account.UpdatedAt = time.Now()
|
||||
newVersion := expectedVersion + 1
|
||||
|
||||
err := tx.QueryRow(ctx, query,
|
||||
account.AvailableQuota, account.FrozenQuota,
|
||||
newVersion, account.UpdatedAt,
|
||||
account.ID, expectedVersion,
|
||||
).Scan(&account.Version)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrConcurrencyConflict
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update account with lock: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetForUpdate 获取账号并加行锁(用于事务内)
|
||||
func (r *AccountRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, user_id, platform, account_type, account_name,
|
||||
encrypted_credentials, key_id,
|
||||
status, risk_level, total_quota, available_quota, frozen_quota,
|
||||
is_verified, verified_at, last_check_at,
|
||||
tos_compliant, tos_check_result,
|
||||
total_requests, total_tokens, total_cost, success_rate,
|
||||
risk_score, risk_reason, is_frozen, frozen_reason,
|
||||
version,
|
||||
created_at, updated_at
|
||||
FROM supply_accounts
|
||||
WHERE id = $1 AND user_id = $2
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
account := &domain.Account{}
|
||||
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&account.ID, &account.SupplierID, &account.Provider, &account.AccountType, &account.Alias,
|
||||
&account.CredentialHash, &account.KeyID,
|
||||
&account.Status, &account.RiskLevel, &account.TotalQuota, &account.AvailableQuota, &account.FrozenQuota,
|
||||
&account.IsVerified, &account.VerifiedAt, &account.LastCheckAt,
|
||||
&account.TosCompliant, &account.TosCheckResult,
|
||||
&account.TotalRequests, &account.TotalTokens, &account.TotalCost, &account.SuccessRate,
|
||||
&account.RiskScore, &account.RiskReason, &account.IsFrozen, &account.FrozenReason,
|
||||
&account.Version,
|
||||
&account.CreatedAt, &account.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account for update: %w", err)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// List 列出账号
|
||||
func (r *AccountRepository) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, user_id, platform, account_type, account_name,
|
||||
status, risk_level, total_quota, available_quota, frozen_quota,
|
||||
is_verified, verified_at, last_check_at,
|
||||
tos_compliant, success_rate,
|
||||
risk_score, is_frozen,
|
||||
version, created_at, updated_at
|
||||
FROM supply_accounts
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, supplierID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list accounts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accounts []*domain.Account
|
||||
for rows.Next() {
|
||||
account := &domain.Account{}
|
||||
err := rows.Scan(
|
||||
&account.ID, &account.SupplierID, &account.Provider, &account.AccountType, &account.Alias,
|
||||
&account.Status, &account.RiskLevel, &account.TotalQuota, &account.AvailableQuota, &account.FrozenQuota,
|
||||
&account.IsVerified, &account.VerifiedAt, &account.LastCheckAt,
|
||||
&account.TosCompliant, &account.SuccessRate,
|
||||
&account.RiskScore, &account.IsFrozen,
|
||||
&account.Version, &account.CreatedAt, &account.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan account: %w", err)
|
||||
}
|
||||
accounts = append(accounts, account)
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// GetWithdrawableBalance 获取可提现余额
|
||||
func (r *AccountRepository) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
|
||||
query := `
|
||||
SELECT COALESCE(SUM(available_quota), 0)
|
||||
FROM supply_accounts
|
||||
WHERE user_id = $1 AND status = 'active'
|
||||
`
|
||||
|
||||
var balance float64
|
||||
err := r.pool.QueryRow(ctx, query, supplierID).Scan(&balance)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get withdrawable balance: %w", err)
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
81
supply-api/internal/repository/db.go
Normal file
81
supply-api/internal/repository/db.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/config"
|
||||
)
|
||||
|
||||
// DB 数据库连接池
|
||||
type DB struct {
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewDB 创建数据库连接池
|
||||
func NewDB(ctx context.Context, cfg config.DatabaseConfig) (*DB, error) {
|
||||
poolConfig, err := pgxpool.ParseConfig(cfg.DSN())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse database config: %w", err)
|
||||
}
|
||||
|
||||
poolConfig.MaxConns = int32(cfg.MaxOpenConns)
|
||||
poolConfig.MinConns = int32(cfg.MaxIdleConns)
|
||||
poolConfig.MaxConnLifetime = cfg.ConnMaxLifetime
|
||||
poolConfig.MaxConnIdleTime = cfg.ConnMaxIdleTime
|
||||
poolConfig.HealthCheckPeriod = 30 * time.Second
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
// 验证连接
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &DB{Pool: pool}, nil
|
||||
}
|
||||
|
||||
// Close 关闭连接池
|
||||
func (db *DB) Close() {
|
||||
if db.Pool != nil {
|
||||
db.Pool.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (db *DB) HealthCheck(ctx context.Context) error {
|
||||
return db.Pool.Ping(ctx)
|
||||
}
|
||||
|
||||
// BeginTx 开始事务
|
||||
func (db *DB) BeginTx(ctx context.Context) (Transaction, error) {
|
||||
tx, err := db.Pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &txWrapper{tx: tx}, nil
|
||||
}
|
||||
|
||||
// Transaction 事务接口
|
||||
type Transaction interface {
|
||||
Commit(ctx context.Context) error
|
||||
Rollback(ctx context.Context) error
|
||||
}
|
||||
|
||||
type txWrapper struct {
|
||||
tx pgxpool.Tx
|
||||
}
|
||||
|
||||
func (t *txWrapper) Commit(ctx context.Context) error {
|
||||
return t.tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (t *txWrapper) Rollback(ctx context.Context) error {
|
||||
return t.tx.Rollback(ctx)
|
||||
}
|
||||
246
supply-api/internal/repository/idempotency.go
Normal file
246
supply-api/internal/repository/idempotency.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// IdempotencyStatus 幂等记录状态
|
||||
type IdempotencyStatus string
|
||||
|
||||
const (
|
||||
IdempotencyStatusProcessing IdempotencyStatus = "processing"
|
||||
IdempotencyStatusSucceeded IdempotencyStatus = "succeeded"
|
||||
IdempotencyStatusFailed IdempotencyStatus = "failed"
|
||||
)
|
||||
|
||||
// IdempotencyRecord 幂等记录
|
||||
type IdempotencyRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
OperatorID int64 `json:"operator_id"`
|
||||
APIPath string `json:"api_path"`
|
||||
IdempotencyKey string `json:"idempotency_key"`
|
||||
RequestID string `json:"request_id"`
|
||||
PayloadHash string `json:"payload_hash"` // SHA256 of request body
|
||||
ResponseCode int `json:"response_code"`
|
||||
ResponseBody json.RawMessage `json:"response_body"`
|
||||
Status IdempotencyStatus `json:"status"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// IdempotencyRepository 幂等记录仓储
|
||||
type IdempotencyRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewIdempotencyRepository 创建幂等记录仓储
|
||||
func NewIdempotencyRepository(pool *pgxpool.Pool) *IdempotencyRepository {
|
||||
return &IdempotencyRepository{pool: pool}
|
||||
}
|
||||
|
||||
// GetByKey 根据幂等键获取记录
|
||||
func (r *IdempotencyRepository) GetByKey(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (*IdempotencyRecord, error) {
|
||||
query := `
|
||||
SELECT id, tenant_id, operator_id, api_path, idempotency_key,
|
||||
request_id, payload_hash, response_code, response_body,
|
||||
status, expires_at, created_at, updated_at
|
||||
FROM supply_idempotency_records
|
||||
WHERE tenant_id = $1 AND operator_id = $2 AND api_path = $3 AND idempotency_key = $4
|
||||
AND expires_at > $5
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
record := &IdempotencyRecord{}
|
||||
err := r.pool.QueryRow(ctx, query, tenantID, operatorID, apiPath, idempotencyKey, time.Now()).Scan(
|
||||
&record.ID, &record.TenantID, &record.OperatorID, &record.APIPath, &record.IdempotencyKey,
|
||||
&record.RequestID, &record.PayloadHash, &record.ResponseCode, &record.ResponseBody,
|
||||
&record.Status, &record.ExpiresAt, &record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil // 不存在或已过期
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get idempotency record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Create 创建幂等记录
|
||||
func (r *IdempotencyRepository) Create(ctx context.Context, record *IdempotencyRecord) error {
|
||||
query := `
|
||||
INSERT INTO supply_idempotency_records (
|
||||
tenant_id, operator_id, api_path, idempotency_key,
|
||||
request_id, payload_hash, status, expires_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey,
|
||||
record.RequestID, record.PayloadHash, record.Status, record.ExpiresAt,
|
||||
).Scan(&record.ID, &record.CreatedAt, &record.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create idempotency record: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSuccess 更新为成功状态
|
||||
func (r *IdempotencyRepository) UpdateSuccess(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
|
||||
query := `
|
||||
UPDATE supply_idempotency_records SET
|
||||
response_code = $1,
|
||||
response_body = $2,
|
||||
status = $3,
|
||||
updated_at = $4
|
||||
WHERE id = $5
|
||||
`
|
||||
|
||||
_, err := r.pool.Exec(ctx, query, responseCode, responseBody, IdempotencyStatusSucceeded, time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update idempotency record to success: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateFailed 更新为失败状态
|
||||
func (r *IdempotencyRepository) UpdateFailed(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
|
||||
query := `
|
||||
UPDATE supply_idempotency_records SET
|
||||
response_code = $1,
|
||||
response_body = $2,
|
||||
status = $3,
|
||||
updated_at = $4
|
||||
WHERE id = $5
|
||||
`
|
||||
|
||||
_, err := r.pool.Exec(ctx, query, responseCode, responseBody, IdempotencyStatusFailed, time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update idempotency record to failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpired 删除过期记录(定时清理)
|
||||
func (r *IdempotencyRepository) DeleteExpired(ctx context.Context) (int64, error) {
|
||||
query := `DELETE FROM supply_idempotency_records WHERE expires_at < $1`
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query, time.Now())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete expired idempotency records: %w", err)
|
||||
}
|
||||
|
||||
return cmdTag.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// GetByRequestID 根据请求ID获取记录
|
||||
func (r *IdempotencyRepository) GetByRequestID(ctx context.Context, requestID string) (*IdempotencyRecord, error) {
|
||||
query := `
|
||||
SELECT id, tenant_id, operator_id, api_path, idempotency_key,
|
||||
request_id, payload_hash, response_code, response_body,
|
||||
status, expires_at, created_at, updated_at
|
||||
FROM supply_idempotency_records
|
||||
WHERE request_id = $1
|
||||
`
|
||||
|
||||
record := &IdempotencyRecord{}
|
||||
err := r.pool.QueryRow(ctx, query, requestID).Scan(
|
||||
&record.ID, &record.TenantID, &record.OperatorID, &record.APIPath, &record.IdempotencyKey,
|
||||
&record.RequestID, &record.PayloadHash, &record.ResponseCode, &record.ResponseBody,
|
||||
&record.Status, &record.ExpiresAt, &record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get idempotency record by request_id: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// CheckExists 检查幂等记录是否存在(用于竞争条件检测)
|
||||
func (r *IdempotencyRepository) CheckExists(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM supply_idempotency_records
|
||||
WHERE tenant_id = $1 AND operator_id = $2 AND api_path = $3 AND idempotency_key = $4
|
||||
AND expires_at > $5
|
||||
)
|
||||
`
|
||||
|
||||
var exists bool
|
||||
err := r.pool.QueryRow(ctx, query, tenantID, operatorID, apiPath, idempotencyKey, time.Now()).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check idempotency record existence: %w", err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// AcquireLock 尝试获取幂等锁(用于创建记录)
|
||||
func (r *IdempotencyRepository) AcquireLock(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string, ttl time.Duration) (*IdempotencyRecord, error) {
|
||||
// 先尝试插入
|
||||
record := &IdempotencyRecord{
|
||||
TenantID: tenantID,
|
||||
OperatorID: operatorID,
|
||||
APIPath: apiPath,
|
||||
IdempotencyKey: idempotencyKey,
|
||||
RequestID: "", // 稍后填充
|
||||
PayloadHash: "", // 稍后填充
|
||||
Status: IdempotencyStatusProcessing,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO supply_idempotency_records (
|
||||
tenant_id, operator_id, api_path, idempotency_key,
|
||||
request_id, payload_hash, status, expires_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8
|
||||
)
|
||||
ON CONFLICT (tenant_id, operator_id, api_path, idempotency_key)
|
||||
DO UPDATE SET
|
||||
request_id = EXCLUDED.request_id,
|
||||
payload_hash = EXCLUDED.payload_hash,
|
||||
status = EXCLUDED.status,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
updated_at = now()
|
||||
WHERE supply_idempotency_records.expires_at <= $8
|
||||
RETURNING id, created_at, updated_at, status
|
||||
`
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey,
|
||||
record.RequestID, record.PayloadHash, record.Status, record.ExpiresAt,
|
||||
).Scan(&record.ID, &record.CreatedAt, &record.UpdatedAt, &record.Status)
|
||||
|
||||
if err != nil {
|
||||
// 可能是重复插入
|
||||
existing, getErr := r.GetByKey(ctx, tenantID, operatorID, apiPath, idempotencyKey)
|
||||
if getErr != nil {
|
||||
return nil, fmt.Errorf("failed to acquire idempotency lock: %w (get err: %v)", err, getErr)
|
||||
}
|
||||
if existing != nil {
|
||||
return existing, nil // 返回已存在的记录
|
||||
}
|
||||
return nil, fmt.Errorf("failed to acquire idempotency lock: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
250
supply-api/internal/repository/package.go
Normal file
250
supply-api/internal/repository/package.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
)
|
||||
|
||||
// PackageRepository 套餐仓储
|
||||
type PackageRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewPackageRepository 创建套餐仓储
|
||||
func NewPackageRepository(pool *pgxpool.Pool) *PackageRepository {
|
||||
return &PackageRepository{pool: pool}
|
||||
}
|
||||
|
||||
// Create 创建套餐
|
||||
func (r *PackageRepository) Create(ctx context.Context, pkg *domain.Package, requestID, traceID string) error {
|
||||
query := `
|
||||
INSERT INTO supply_packages (
|
||||
supply_account_id, user_id, platform, model,
|
||||
total_quota, available_quota, sold_quota, reserved_quota,
|
||||
price_per_1m_input, price_per_1m_output, min_purchase,
|
||||
start_at, end_at, valid_days,
|
||||
status, max_concurrent, rate_limit_rpm,
|
||||
total_orders, total_revenue, rating, rating_count,
|
||||
quota_unit, price_unit, currency_code, version,
|
||||
created_ip, updated_ip, audit_trace_id,
|
||||
request_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21,
|
||||
$22, $23, $24, $25, $26, $27, $28
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
var startAt, endAt *time.Time
|
||||
if !pkg.StartAt.IsZero() {
|
||||
startAt = &pkg.StartAt
|
||||
}
|
||||
if !pkg.EndAt.IsZero() {
|
||||
endAt = &pkg.EndAt
|
||||
}
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
pkg.SupplierID, pkg.SupplierID, pkg.Platform, pkg.Model,
|
||||
pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota,
|
||||
pkg.PricePer1MInput, pkg.PricePer1MOutput, pkg.MinPurchase,
|
||||
startAt, endAt, pkg.ValidDays,
|
||||
pkg.Status, pkg.MaxConcurrent, pkg.RateLimitRPM,
|
||||
pkg.TotalOrders, pkg.TotalRevenue, pkg.Rating, pkg.RatingCount,
|
||||
"token", "per_1m_tokens", "USD", 0,
|
||||
nil, nil, traceID,
|
||||
requestID,
|
||||
).Scan(&pkg.ID, &pkg.CreatedAt, &pkg.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create package: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 获取套餐
|
||||
func (r *PackageRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
|
||||
query := `
|
||||
SELECT id, supply_account_id, user_id, platform, model,
|
||||
total_quota, available_quota, sold_quota, reserved_quota,
|
||||
price_per_1m_input, price_per_1m_output, min_purchase,
|
||||
start_at, end_at, valid_days,
|
||||
status, max_concurrent, rate_limit_rpm,
|
||||
total_orders, total_revenue, rating, rating_count,
|
||||
quota_unit, price_unit, currency_code, version,
|
||||
created_at, updated_at
|
||||
FROM supply_packages
|
||||
WHERE id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
pkg := &domain.Package{}
|
||||
var startAt, endAt pgx.NullTime
|
||||
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
|
||||
&pkg.PricePer1MInput, &pkg.PricePer1MOutput, &pkg.MinPurchase,
|
||||
&startAt, &endAt, &pkg.ValidDays,
|
||||
&pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM,
|
||||
&pkg.TotalOrders, &pkg.TotalRevenue, &pkg.Rating, &pkg.RatingCount,
|
||||
&pkg.QuotaUnit, &pkg.PriceUnit, &pkg.CurrencyCode, &pkg.Version,
|
||||
&pkg.CreatedAt, &pkg.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get package: %w", err)
|
||||
}
|
||||
|
||||
if startAt.Valid {
|
||||
pkg.StartAt = startAt.Time
|
||||
}
|
||||
if endAt.Valid {
|
||||
pkg.EndAt = endAt.Time
|
||||
}
|
||||
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
// Update 更新套餐(乐观锁)
|
||||
func (r *PackageRepository) Update(ctx context.Context, pkg *domain.Package, expectedVersion int) error {
|
||||
query := `
|
||||
UPDATE supply_packages SET
|
||||
platform = $1, model = $2,
|
||||
total_quota = $3, available_quota = $4, sold_quota = $5, reserved_quota = $6,
|
||||
price_per_1m_input = $7, price_per_1m_output = $8,
|
||||
start_at = $9, end_at = $10, valid_days = $11,
|
||||
status = $12, max_concurrent = $13, rate_limit_rpm = $14,
|
||||
total_orders = $15, total_revenue = $16,
|
||||
rating = $17, rating_count = $18,
|
||||
version = $19, updated_at = $20
|
||||
WHERE id = $21 AND user_id = $22 AND version = $23
|
||||
`
|
||||
|
||||
pkg.UpdatedAt = time.Now()
|
||||
newVersion := expectedVersion + 1
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query,
|
||||
pkg.Platform, pkg.Model,
|
||||
pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota,
|
||||
pkg.PricePer1MInput, pkg.PricePer1MOutput,
|
||||
pkg.StartAt, pkg.EndAt, pkg.ValidDays,
|
||||
pkg.Status, pkg.MaxConcurrent, pkg.RateLimitRPM,
|
||||
pkg.TotalOrders, pkg.TotalRevenue,
|
||||
pkg.Rating, pkg.RatingCount,
|
||||
newVersion, pkg.UpdatedAt,
|
||||
pkg.ID, pkg.SupplierID, expectedVersion,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update package: %w", err)
|
||||
}
|
||||
|
||||
if cmdTag.RowsAffected() == 0 {
|
||||
return ErrConcurrencyConflict
|
||||
}
|
||||
|
||||
pkg.Version = newVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetForUpdate 获取套餐并加行锁
|
||||
func (r *PackageRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Package, error) {
|
||||
query := `
|
||||
SELECT id, supply_account_id, user_id, platform, model,
|
||||
total_quota, available_quota, sold_quota, reserved_quota,
|
||||
price_per_1m_input, price_per_1m_output,
|
||||
status, version,
|
||||
created_at, updated_at
|
||||
FROM supply_packages
|
||||
WHERE id = $1 AND user_id = $2
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
pkg := &domain.Package{}
|
||||
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
|
||||
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
|
||||
&pkg.Status, &pkg.Version,
|
||||
&pkg.CreatedAt, &pkg.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get package for update: %w", err)
|
||||
}
|
||||
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
// List 列出套餐
|
||||
func (r *PackageRepository) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
|
||||
query := `
|
||||
SELECT id, supply_account_id, user_id, platform, model,
|
||||
total_quota, available_quota, sold_quota,
|
||||
price_per_1m_input, price_per_1m_output,
|
||||
status, max_concurrent, rate_limit_rpm,
|
||||
valid_days, total_orders, total_revenue,
|
||||
version, created_at, updated_at
|
||||
FROM supply_packages
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, supplierID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list packages: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var packages []*domain.Package
|
||||
for rows.Next() {
|
||||
pkg := &domain.Package{}
|
||||
err := rows.Scan(
|
||||
&pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota,
|
||||
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
|
||||
&pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM,
|
||||
&pkg.ValidDays, &pkg.TotalOrders, &pkg.TotalRevenue,
|
||||
&pkg.Version, &pkg.CreatedAt, &pkg.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan package: %w", err)
|
||||
}
|
||||
packages = append(packages, pkg)
|
||||
}
|
||||
|
||||
return packages, nil
|
||||
}
|
||||
|
||||
// UpdateQuota 扣减配额
|
||||
func (r *PackageRepository) UpdateQuota(ctx context.Context, tx pgxpool.Tx, packageID, supplierID int64, usedQuota float64) error {
|
||||
query := `
|
||||
UPDATE supply_packages SET
|
||||
available_quota = available_quota - $1,
|
||||
sold_quota = sold_quota + $1,
|
||||
updated_at = $2
|
||||
WHERE id = $3 AND user_id = $4 AND available_quota >= $1
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
var id int64
|
||||
err := tx.QueryRow(ctx, query, usedQuota, time.Now(), packageID, supplierID).Scan(&id)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return errors.New("insufficient quota or package not found")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update quota: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
243
supply-api/internal/repository/settlement.go
Normal file
243
supply-api/internal/repository/settlement.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
)
|
||||
|
||||
// SettlementRepository 结算仓储
|
||||
type SettlementRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewSettlementRepository 创建结算仓储
|
||||
func NewSettlementRepository(pool *pgxpool.Pool) *SettlementRepository {
|
||||
return &SettlementRepository{pool: pool}
|
||||
}
|
||||
|
||||
// Create 创建结算单
|
||||
func (r *SettlementRepository) Create(ctx context.Context, s *domain.Settlement, requestID, idempotencyKey, traceID string) error {
|
||||
query := `
|
||||
INSERT INTO supply_settlements (
|
||||
settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account,
|
||||
period_start, period_end, total_orders, total_usage_records,
|
||||
currency_code, amount_unit, version,
|
||||
request_id, idempotency_key, audit_trace_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount,
|
||||
s.Status, s.PaymentMethod, s.PaymentAccount,
|
||||
s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords,
|
||||
"USD", "minor", 0,
|
||||
requestID, idempotencyKey, traceID,
|
||||
).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create settlement: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 获取结算单
|
||||
func (r *SettlementRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
|
||||
query := `
|
||||
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account,
|
||||
period_start, period_end, total_orders, total_usage_records,
|
||||
payment_transaction_id, paid_at,
|
||||
version, created_at, updated_at
|
||||
FROM supply_settlements
|
||||
WHERE id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
s := &domain.Settlement{}
|
||||
var paidAt pgx.NullTime
|
||||
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||||
&s.Status, &s.PaymentMethod, &s.PaymentAccount,
|
||||
&s.PeriodStart, &s.PeriodEnd, &s.TotalOrders, &s.TotalUsageRecords,
|
||||
&s.PaymentTransactionID, &paidAt,
|
||||
&s.Version, &s.CreatedAt, &s.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get settlement: %w", err)
|
||||
}
|
||||
|
||||
if paidAt.Valid {
|
||||
s.PaidAt = &paidAt.Time
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Update 更新结算单(乐观锁)
|
||||
func (r *SettlementRepository) Update(ctx context.Context, s *domain.Settlement, expectedVersion int) error {
|
||||
query := `
|
||||
UPDATE supply_settlements SET
|
||||
status = $1, payment_method = $2, payment_account = $3,
|
||||
payment_transaction_id = $4, paid_at = $5,
|
||||
total_orders = $6, total_usage_records = $7,
|
||||
version = $8, updated_at = $9
|
||||
WHERE id = $10 AND user_id = $11 AND version = $12
|
||||
`
|
||||
|
||||
s.UpdatedAt = time.Now()
|
||||
newVersion := expectedVersion + 1
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query,
|
||||
s.Status, s.PaymentMethod, s.PaymentAccount,
|
||||
s.PaymentTransactionID, s.PaidAt,
|
||||
s.TotalOrders, s.TotalUsageRecords,
|
||||
newVersion, s.UpdatedAt,
|
||||
s.ID, s.SupplierID, expectedVersion,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update settlement: %w", err)
|
||||
}
|
||||
|
||||
if cmdTag.RowsAffected() == 0 {
|
||||
return ErrConcurrencyConflict
|
||||
}
|
||||
|
||||
s.Version = newVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetForUpdate 获取结算单并加行锁
|
||||
func (r *SettlementRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Settlement, error) {
|
||||
query := `
|
||||
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account, version,
|
||||
created_at, updated_at
|
||||
FROM supply_settlements
|
||||
WHERE id = $1 AND user_id = $2
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
s := &domain.Settlement{}
|
||||
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||||
&s.Status, &s.PaymentMethod, &s.PaymentAccount, &s.Version,
|
||||
&s.CreatedAt, &s.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get settlement for update: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GetProcessing 获取处理中的结算单(用于单一性约束)
|
||||
func (r *SettlementRepository) GetProcessing(ctx context.Context, tx pgxpool.Tx, supplierID int64) (*domain.Settlement, error) {
|
||||
query := `
|
||||
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account, version,
|
||||
created_at, updated_at
|
||||
FROM supply_settlements
|
||||
WHERE user_id = $1 AND status = 'processing'
|
||||
FOR UPDATE SKIP LOCKED
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
s := &domain.Settlement{}
|
||||
err := tx.QueryRow(ctx, query, supplierID).Scan(
|
||||
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||||
&s.Status, &s.PaymentMethod, &s.PaymentAccount, &s.Version,
|
||||
&s.CreatedAt, &s.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil // 没有处理中的单据
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get processing settlement: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// List 列出结算单
|
||||
func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
|
||||
query := `
|
||||
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method,
|
||||
period_start, period_end, total_orders,
|
||||
version, created_at, updated_at
|
||||
FROM supply_settlements
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, supplierID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list settlements: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var settlements []*domain.Settlement
|
||||
for rows.Next() {
|
||||
s := &domain.Settlement{}
|
||||
err := rows.Scan(
|
||||
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||||
&s.Status, &s.PaymentMethod,
|
||||
&s.PeriodStart, &s.PeriodEnd, &s.TotalOrders,
|
||||
&s.Version, &s.CreatedAt, &s.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan settlement: %w", err)
|
||||
}
|
||||
settlements = append(settlements, s)
|
||||
}
|
||||
|
||||
return settlements, nil
|
||||
}
|
||||
|
||||
// CreateInTx 在事务中创建结算单
|
||||
func (r *SettlementRepository) CreateInTx(ctx context.Context, tx pgxpool.Tx, s *domain.Settlement, requestID, idempotencyKey, traceID string) error {
|
||||
query := `
|
||||
INSERT INTO supply_settlements (
|
||||
settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account,
|
||||
period_start, period_end, total_orders, total_usage_records,
|
||||
currency_code, amount_unit, version,
|
||||
request_id, idempotency_key, audit_trace_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
err := tx.QueryRow(ctx, query,
|
||||
s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount,
|
||||
s.Status, s.PaymentMethod, s.PaymentAccount,
|
||||
s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords,
|
||||
"USD", "minor", 0,
|
||||
requestID, idempotencyKey, traceID,
|
||||
).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create settlement in tx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user