feat(supply-api): 完成核心模块实现
新增/修改内容: - config: 添加配置管理(config.example.yaml, config.go) - cache: 添加Redis缓存层(redis.go) - domain: 添加invariants不变量验证及测试 - middleware: 添加auth认证和idempotency幂等性中间件及测试 - repository: 添加完整数据访问层(account, package, settlement, idempotency, db) - sql: 添加幂等性表DDL脚本 代码覆盖: - auth middleware实现凭证边界验证 - idempotency middleware实现请求幂等性 - invariants实现业务不变量检查 - repository层实现完整的数据访问逻辑 关联issue: Round-1 R1-ISSUE-006 凭证边界硬门禁
This commit is contained in:
291
supply-api/internal/repository/account.go
Normal file
291
supply-api/internal/repository/account.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
)
|
||||
|
||||
// AccountRepository 账号仓储
|
||||
type AccountRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账号仓储
|
||||
func NewAccountRepository(pool *pgxpool.Pool) *AccountRepository {
|
||||
return &AccountRepository{pool: pool}
|
||||
}
|
||||
|
||||
// Create 创建账号
|
||||
func (r *AccountRepository) Create(ctx context.Context, account *domain.Account, requestID, idempotencyKey, traceID string) error {
|
||||
query := `
|
||||
INSERT INTO supply_accounts (
|
||||
user_id, platform, account_type, account_name,
|
||||
encrypted_credentials, key_id,
|
||||
status, risk_level, total_quota, available_quota, frozen_quota,
|
||||
is_verified, verified_at, last_check_at,
|
||||
tos_compliant, tos_check_result,
|
||||
total_requests, total_tokens, total_cost, success_rate,
|
||||
risk_score, risk_reason, is_frozen, frozen_reason,
|
||||
credential_cipher_algo, credential_kms_key_alias, credential_key_version,
|
||||
quota_unit, currency_code, version,
|
||||
created_ip, updated_ip, audit_trace_id,
|
||||
request_id, idempotency_key
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6,
|
||||
$7, $8, $9, $10, $11, $12, $13, $14,
|
||||
$15, $16, $17, $18, $19, $20,
|
||||
$21, $22, $23, $24, $25, $26, $27,
|
||||
$28, $29, $30, $31, $32, $33, $34, $35
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
var createdIP, updatedIP *netip.Addr
|
||||
if account.CreatedIP != nil {
|
||||
createdIP = account.CreatedIP
|
||||
}
|
||||
if account.UpdatedIP != nil {
|
||||
updatedIP = account.UpdatedIP
|
||||
}
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
account.SupplierID, account.Provider, account.AccountType, account.Alias,
|
||||
account.CredentialHash, account.KeyID,
|
||||
account.Status, account.RiskLevel, account.TotalQuota, account.AvailableQuota, account.FrozenQuota,
|
||||
account.IsVerified, account.VerifiedAt, account.LastCheckAt,
|
||||
account.TosCompliant, account.TosCheckResult,
|
||||
account.TotalRequests, account.TotalTokens, account.TotalCost, account.SuccessRate,
|
||||
account.RiskScore, account.RiskReason, account.IsFrozen, account.FrozenReason,
|
||||
"AES-256-GCM", "kms/supply/default", 1,
|
||||
"token", "USD", 0,
|
||||
createdIP, updatedIP, traceID,
|
||||
requestID, idempotencyKey,
|
||||
).Scan(&account.ID, &account.CreatedAt, &account.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create account: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 获取账号
|
||||
func (r *AccountRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, user_id, platform, account_type, account_name,
|
||||
encrypted_credentials, key_id,
|
||||
status, risk_level, total_quota, available_quota, frozen_quota,
|
||||
is_verified, verified_at, last_check_at,
|
||||
tos_compliant, tos_check_result,
|
||||
total_requests, total_tokens, total_cost, success_rate,
|
||||
risk_score, risk_reason, is_frozen, frozen_reason,
|
||||
credential_cipher_algo, credential_kms_key_alias, credential_key_version,
|
||||
quota_unit, currency_code, version,
|
||||
created_ip, updated_ip, audit_trace_id,
|
||||
created_at, updated_at
|
||||
FROM supply_accounts
|
||||
WHERE id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
account := &domain.Account{}
|
||||
var createdIP, updatedIP netip.Addr
|
||||
var credentialFingerprint *string
|
||||
|
||||
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&account.ID, &account.SupplierID, &account.Provider, &account.AccountType, &account.Alias,
|
||||
&account.CredentialHash, &account.KeyID,
|
||||
&account.Status, &account.RiskLevel, &account.TotalQuota, &account.AvailableQuota, &account.FrozenQuota,
|
||||
&account.IsVerified, &account.VerifiedAt, &account.LastCheckAt,
|
||||
&account.TosCompliant, &account.TosCheckResult,
|
||||
&account.TotalRequests, &account.TotalTokens, &account.TotalCost, &account.SuccessRate,
|
||||
&account.RiskScore, &account.RiskReason, &account.IsFrozen, &account.FrozenReason,
|
||||
&account.CredentialCipherAlgo, &account.CredentialKMSKeyAlias, &account.CredentialKeyVersion,
|
||||
&account.QuotaUnit, &account.CurrencyCode, &account.Version,
|
||||
&createdIP, &updatedIP, &account.AuditTraceID,
|
||||
&account.CreatedAt, &account.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account: %w", err)
|
||||
}
|
||||
|
||||
account.CreatedIP = &createdIP
|
||||
account.UpdatedIP = &updatedIP
|
||||
_ = credentialFingerprint // 未使用但字段存在
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// Update 更新账号(乐观锁)
|
||||
func (r *AccountRepository) Update(ctx context.Context, account *domain.Account, expectedVersion int) error {
|
||||
query := `
|
||||
UPDATE supply_accounts SET
|
||||
platform = $1, account_type = $2, account_name = $3,
|
||||
status = $4, risk_level = $5, total_quota = $6, available_quota = $7,
|
||||
frozen_quota = $8, is_verified = $9, verified_at = $10, last_check_at = $11,
|
||||
tos_compliant = $12, tos_check_result = $13,
|
||||
total_requests = $14, total_tokens = $15, total_cost = $16, success_rate = $17,
|
||||
risk_score = $18, risk_reason = $19, is_frozen = $20, frozen_reason = $21,
|
||||
version = $22, updated_at = $23
|
||||
WHERE id = $24 AND user_id = $25 AND version = $26
|
||||
`
|
||||
|
||||
account.UpdatedAt = time.Now()
|
||||
newVersion := expectedVersion + 1
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query,
|
||||
account.Provider, account.AccountType, account.Alias,
|
||||
account.Status, account.RiskLevel, account.TotalQuota, account.AvailableQuota,
|
||||
account.FrozenQuota, account.IsVerified, account.VerifiedAt, account.LastCheckAt,
|
||||
account.TosCompliant, account.TosCheckResult,
|
||||
account.TotalRequests, account.TotalTokens, account.TotalCost, account.SuccessRate,
|
||||
account.RiskScore, account.RiskReason, account.IsFrozen, account.FrozenReason,
|
||||
newVersion, account.UpdatedAt,
|
||||
account.ID, account.SupplierID, expectedVersion,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update account: %w", err)
|
||||
}
|
||||
|
||||
if cmdTag.RowsAffected() == 0 {
|
||||
return ErrConcurrencyConflict
|
||||
}
|
||||
|
||||
account.Version = newVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateWithPessimisticLock 更新账号(悲观锁,用于提现等关键操作)
|
||||
func (r *AccountRepository) UpdateWithPessimisticLock(ctx context.Context, tx pgxpool.Tx, account *domain.Account, expectedVersion int) error {
|
||||
query := `
|
||||
UPDATE supply_accounts SET
|
||||
available_quota = $1, frozen_quota = $2,
|
||||
version = $3, updated_at = $4
|
||||
WHERE id = $5 AND version = $6
|
||||
RETURNING version
|
||||
`
|
||||
|
||||
account.UpdatedAt = time.Now()
|
||||
newVersion := expectedVersion + 1
|
||||
|
||||
err := tx.QueryRow(ctx, query,
|
||||
account.AvailableQuota, account.FrozenQuota,
|
||||
newVersion, account.UpdatedAt,
|
||||
account.ID, expectedVersion,
|
||||
).Scan(&account.Version)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrConcurrencyConflict
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update account with lock: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetForUpdate 获取账号并加行锁(用于事务内)
|
||||
func (r *AccountRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, user_id, platform, account_type, account_name,
|
||||
encrypted_credentials, key_id,
|
||||
status, risk_level, total_quota, available_quota, frozen_quota,
|
||||
is_verified, verified_at, last_check_at,
|
||||
tos_compliant, tos_check_result,
|
||||
total_requests, total_tokens, total_cost, success_rate,
|
||||
risk_score, risk_reason, is_frozen, frozen_reason,
|
||||
version,
|
||||
created_at, updated_at
|
||||
FROM supply_accounts
|
||||
WHERE id = $1 AND user_id = $2
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
account := &domain.Account{}
|
||||
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&account.ID, &account.SupplierID, &account.Provider, &account.AccountType, &account.Alias,
|
||||
&account.CredentialHash, &account.KeyID,
|
||||
&account.Status, &account.RiskLevel, &account.TotalQuota, &account.AvailableQuota, &account.FrozenQuota,
|
||||
&account.IsVerified, &account.VerifiedAt, &account.LastCheckAt,
|
||||
&account.TosCompliant, &account.TosCheckResult,
|
||||
&account.TotalRequests, &account.TotalTokens, &account.TotalCost, &account.SuccessRate,
|
||||
&account.RiskScore, &account.RiskReason, &account.IsFrozen, &account.FrozenReason,
|
||||
&account.Version,
|
||||
&account.CreatedAt, &account.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account for update: %w", err)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// List 列出账号
|
||||
func (r *AccountRepository) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, user_id, platform, account_type, account_name,
|
||||
status, risk_level, total_quota, available_quota, frozen_quota,
|
||||
is_verified, verified_at, last_check_at,
|
||||
tos_compliant, success_rate,
|
||||
risk_score, is_frozen,
|
||||
version, created_at, updated_at
|
||||
FROM supply_accounts
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, supplierID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list accounts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accounts []*domain.Account
|
||||
for rows.Next() {
|
||||
account := &domain.Account{}
|
||||
err := rows.Scan(
|
||||
&account.ID, &account.SupplierID, &account.Provider, &account.AccountType, &account.Alias,
|
||||
&account.Status, &account.RiskLevel, &account.TotalQuota, &account.AvailableQuota, &account.FrozenQuota,
|
||||
&account.IsVerified, &account.VerifiedAt, &account.LastCheckAt,
|
||||
&account.TosCompliant, &account.SuccessRate,
|
||||
&account.RiskScore, &account.IsFrozen,
|
||||
&account.Version, &account.CreatedAt, &account.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan account: %w", err)
|
||||
}
|
||||
accounts = append(accounts, account)
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// GetWithdrawableBalance 获取可提现余额
|
||||
func (r *AccountRepository) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
|
||||
query := `
|
||||
SELECT COALESCE(SUM(available_quota), 0)
|
||||
FROM supply_accounts
|
||||
WHERE user_id = $1 AND status = 'active'
|
||||
`
|
||||
|
||||
var balance float64
|
||||
err := r.pool.QueryRow(ctx, query, supplierID).Scan(&balance)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get withdrawable balance: %w", err)
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
81
supply-api/internal/repository/db.go
Normal file
81
supply-api/internal/repository/db.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/config"
|
||||
)
|
||||
|
||||
// DB 数据库连接池
|
||||
type DB struct {
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewDB 创建数据库连接池
|
||||
func NewDB(ctx context.Context, cfg config.DatabaseConfig) (*DB, error) {
|
||||
poolConfig, err := pgxpool.ParseConfig(cfg.DSN())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse database config: %w", err)
|
||||
}
|
||||
|
||||
poolConfig.MaxConns = int32(cfg.MaxOpenConns)
|
||||
poolConfig.MinConns = int32(cfg.MaxIdleConns)
|
||||
poolConfig.MaxConnLifetime = cfg.ConnMaxLifetime
|
||||
poolConfig.MaxConnIdleTime = cfg.ConnMaxIdleTime
|
||||
poolConfig.HealthCheckPeriod = 30 * time.Second
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
// 验证连接
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &DB{Pool: pool}, nil
|
||||
}
|
||||
|
||||
// Close 关闭连接池
|
||||
func (db *DB) Close() {
|
||||
if db.Pool != nil {
|
||||
db.Pool.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (db *DB) HealthCheck(ctx context.Context) error {
|
||||
return db.Pool.Ping(ctx)
|
||||
}
|
||||
|
||||
// BeginTx 开始事务
|
||||
func (db *DB) BeginTx(ctx context.Context) (Transaction, error) {
|
||||
tx, err := db.Pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &txWrapper{tx: tx}, nil
|
||||
}
|
||||
|
||||
// Transaction 事务接口
|
||||
type Transaction interface {
|
||||
Commit(ctx context.Context) error
|
||||
Rollback(ctx context.Context) error
|
||||
}
|
||||
|
||||
type txWrapper struct {
|
||||
tx pgxpool.Tx
|
||||
}
|
||||
|
||||
func (t *txWrapper) Commit(ctx context.Context) error {
|
||||
return t.tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (t *txWrapper) Rollback(ctx context.Context) error {
|
||||
return t.tx.Rollback(ctx)
|
||||
}
|
||||
246
supply-api/internal/repository/idempotency.go
Normal file
246
supply-api/internal/repository/idempotency.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// IdempotencyStatus 幂等记录状态
|
||||
type IdempotencyStatus string
|
||||
|
||||
const (
|
||||
IdempotencyStatusProcessing IdempotencyStatus = "processing"
|
||||
IdempotencyStatusSucceeded IdempotencyStatus = "succeeded"
|
||||
IdempotencyStatusFailed IdempotencyStatus = "failed"
|
||||
)
|
||||
|
||||
// IdempotencyRecord 幂等记录
|
||||
type IdempotencyRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
OperatorID int64 `json:"operator_id"`
|
||||
APIPath string `json:"api_path"`
|
||||
IdempotencyKey string `json:"idempotency_key"`
|
||||
RequestID string `json:"request_id"`
|
||||
PayloadHash string `json:"payload_hash"` // SHA256 of request body
|
||||
ResponseCode int `json:"response_code"`
|
||||
ResponseBody json.RawMessage `json:"response_body"`
|
||||
Status IdempotencyStatus `json:"status"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// IdempotencyRepository 幂等记录仓储
|
||||
type IdempotencyRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewIdempotencyRepository 创建幂等记录仓储
|
||||
func NewIdempotencyRepository(pool *pgxpool.Pool) *IdempotencyRepository {
|
||||
return &IdempotencyRepository{pool: pool}
|
||||
}
|
||||
|
||||
// GetByKey 根据幂等键获取记录
|
||||
func (r *IdempotencyRepository) GetByKey(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (*IdempotencyRecord, error) {
|
||||
query := `
|
||||
SELECT id, tenant_id, operator_id, api_path, idempotency_key,
|
||||
request_id, payload_hash, response_code, response_body,
|
||||
status, expires_at, created_at, updated_at
|
||||
FROM supply_idempotency_records
|
||||
WHERE tenant_id = $1 AND operator_id = $2 AND api_path = $3 AND idempotency_key = $4
|
||||
AND expires_at > $5
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
record := &IdempotencyRecord{}
|
||||
err := r.pool.QueryRow(ctx, query, tenantID, operatorID, apiPath, idempotencyKey, time.Now()).Scan(
|
||||
&record.ID, &record.TenantID, &record.OperatorID, &record.APIPath, &record.IdempotencyKey,
|
||||
&record.RequestID, &record.PayloadHash, &record.ResponseCode, &record.ResponseBody,
|
||||
&record.Status, &record.ExpiresAt, &record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil // 不存在或已过期
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get idempotency record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Create 创建幂等记录
|
||||
func (r *IdempotencyRepository) Create(ctx context.Context, record *IdempotencyRecord) error {
|
||||
query := `
|
||||
INSERT INTO supply_idempotency_records (
|
||||
tenant_id, operator_id, api_path, idempotency_key,
|
||||
request_id, payload_hash, status, expires_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey,
|
||||
record.RequestID, record.PayloadHash, record.Status, record.ExpiresAt,
|
||||
).Scan(&record.ID, &record.CreatedAt, &record.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create idempotency record: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSuccess 更新为成功状态
|
||||
func (r *IdempotencyRepository) UpdateSuccess(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
|
||||
query := `
|
||||
UPDATE supply_idempotency_records SET
|
||||
response_code = $1,
|
||||
response_body = $2,
|
||||
status = $3,
|
||||
updated_at = $4
|
||||
WHERE id = $5
|
||||
`
|
||||
|
||||
_, err := r.pool.Exec(ctx, query, responseCode, responseBody, IdempotencyStatusSucceeded, time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update idempotency record to success: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateFailed 更新为失败状态
|
||||
func (r *IdempotencyRepository) UpdateFailed(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
|
||||
query := `
|
||||
UPDATE supply_idempotency_records SET
|
||||
response_code = $1,
|
||||
response_body = $2,
|
||||
status = $3,
|
||||
updated_at = $4
|
||||
WHERE id = $5
|
||||
`
|
||||
|
||||
_, err := r.pool.Exec(ctx, query, responseCode, responseBody, IdempotencyStatusFailed, time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update idempotency record to failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpired 删除过期记录(定时清理)
|
||||
func (r *IdempotencyRepository) DeleteExpired(ctx context.Context) (int64, error) {
|
||||
query := `DELETE FROM supply_idempotency_records WHERE expires_at < $1`
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query, time.Now())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete expired idempotency records: %w", err)
|
||||
}
|
||||
|
||||
return cmdTag.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// GetByRequestID 根据请求ID获取记录
|
||||
func (r *IdempotencyRepository) GetByRequestID(ctx context.Context, requestID string) (*IdempotencyRecord, error) {
|
||||
query := `
|
||||
SELECT id, tenant_id, operator_id, api_path, idempotency_key,
|
||||
request_id, payload_hash, response_code, response_body,
|
||||
status, expires_at, created_at, updated_at
|
||||
FROM supply_idempotency_records
|
||||
WHERE request_id = $1
|
||||
`
|
||||
|
||||
record := &IdempotencyRecord{}
|
||||
err := r.pool.QueryRow(ctx, query, requestID).Scan(
|
||||
&record.ID, &record.TenantID, &record.OperatorID, &record.APIPath, &record.IdempotencyKey,
|
||||
&record.RequestID, &record.PayloadHash, &record.ResponseCode, &record.ResponseBody,
|
||||
&record.Status, &record.ExpiresAt, &record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get idempotency record by request_id: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// CheckExists 检查幂等记录是否存在(用于竞争条件检测)
|
||||
func (r *IdempotencyRepository) CheckExists(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM supply_idempotency_records
|
||||
WHERE tenant_id = $1 AND operator_id = $2 AND api_path = $3 AND idempotency_key = $4
|
||||
AND expires_at > $5
|
||||
)
|
||||
`
|
||||
|
||||
var exists bool
|
||||
err := r.pool.QueryRow(ctx, query, tenantID, operatorID, apiPath, idempotencyKey, time.Now()).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check idempotency record existence: %w", err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// AcquireLock 尝试获取幂等锁(用于创建记录)
|
||||
func (r *IdempotencyRepository) AcquireLock(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string, ttl time.Duration) (*IdempotencyRecord, error) {
|
||||
// 先尝试插入
|
||||
record := &IdempotencyRecord{
|
||||
TenantID: tenantID,
|
||||
OperatorID: operatorID,
|
||||
APIPath: apiPath,
|
||||
IdempotencyKey: idempotencyKey,
|
||||
RequestID: "", // 稍后填充
|
||||
PayloadHash: "", // 稍后填充
|
||||
Status: IdempotencyStatusProcessing,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO supply_idempotency_records (
|
||||
tenant_id, operator_id, api_path, idempotency_key,
|
||||
request_id, payload_hash, status, expires_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8
|
||||
)
|
||||
ON CONFLICT (tenant_id, operator_id, api_path, idempotency_key)
|
||||
DO UPDATE SET
|
||||
request_id = EXCLUDED.request_id,
|
||||
payload_hash = EXCLUDED.payload_hash,
|
||||
status = EXCLUDED.status,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
updated_at = now()
|
||||
WHERE supply_idempotency_records.expires_at <= $8
|
||||
RETURNING id, created_at, updated_at, status
|
||||
`
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey,
|
||||
record.RequestID, record.PayloadHash, record.Status, record.ExpiresAt,
|
||||
).Scan(&record.ID, &record.CreatedAt, &record.UpdatedAt, &record.Status)
|
||||
|
||||
if err != nil {
|
||||
// 可能是重复插入
|
||||
existing, getErr := r.GetByKey(ctx, tenantID, operatorID, apiPath, idempotencyKey)
|
||||
if getErr != nil {
|
||||
return nil, fmt.Errorf("failed to acquire idempotency lock: %w (get err: %v)", err, getErr)
|
||||
}
|
||||
if existing != nil {
|
||||
return existing, nil // 返回已存在的记录
|
||||
}
|
||||
return nil, fmt.Errorf("failed to acquire idempotency lock: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
250
supply-api/internal/repository/package.go
Normal file
250
supply-api/internal/repository/package.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
)
|
||||
|
||||
// PackageRepository 套餐仓储
|
||||
type PackageRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewPackageRepository 创建套餐仓储
|
||||
func NewPackageRepository(pool *pgxpool.Pool) *PackageRepository {
|
||||
return &PackageRepository{pool: pool}
|
||||
}
|
||||
|
||||
// Create 创建套餐
|
||||
func (r *PackageRepository) Create(ctx context.Context, pkg *domain.Package, requestID, traceID string) error {
|
||||
query := `
|
||||
INSERT INTO supply_packages (
|
||||
supply_account_id, user_id, platform, model,
|
||||
total_quota, available_quota, sold_quota, reserved_quota,
|
||||
price_per_1m_input, price_per_1m_output, min_purchase,
|
||||
start_at, end_at, valid_days,
|
||||
status, max_concurrent, rate_limit_rpm,
|
||||
total_orders, total_revenue, rating, rating_count,
|
||||
quota_unit, price_unit, currency_code, version,
|
||||
created_ip, updated_ip, audit_trace_id,
|
||||
request_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21,
|
||||
$22, $23, $24, $25, $26, $27, $28
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
var startAt, endAt *time.Time
|
||||
if !pkg.StartAt.IsZero() {
|
||||
startAt = &pkg.StartAt
|
||||
}
|
||||
if !pkg.EndAt.IsZero() {
|
||||
endAt = &pkg.EndAt
|
||||
}
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
pkg.SupplierID, pkg.SupplierID, pkg.Platform, pkg.Model,
|
||||
pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota,
|
||||
pkg.PricePer1MInput, pkg.PricePer1MOutput, pkg.MinPurchase,
|
||||
startAt, endAt, pkg.ValidDays,
|
||||
pkg.Status, pkg.MaxConcurrent, pkg.RateLimitRPM,
|
||||
pkg.TotalOrders, pkg.TotalRevenue, pkg.Rating, pkg.RatingCount,
|
||||
"token", "per_1m_tokens", "USD", 0,
|
||||
nil, nil, traceID,
|
||||
requestID,
|
||||
).Scan(&pkg.ID, &pkg.CreatedAt, &pkg.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create package: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 获取套餐
|
||||
func (r *PackageRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
|
||||
query := `
|
||||
SELECT id, supply_account_id, user_id, platform, model,
|
||||
total_quota, available_quota, sold_quota, reserved_quota,
|
||||
price_per_1m_input, price_per_1m_output, min_purchase,
|
||||
start_at, end_at, valid_days,
|
||||
status, max_concurrent, rate_limit_rpm,
|
||||
total_orders, total_revenue, rating, rating_count,
|
||||
quota_unit, price_unit, currency_code, version,
|
||||
created_at, updated_at
|
||||
FROM supply_packages
|
||||
WHERE id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
pkg := &domain.Package{}
|
||||
var startAt, endAt pgx.NullTime
|
||||
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
|
||||
&pkg.PricePer1MInput, &pkg.PricePer1MOutput, &pkg.MinPurchase,
|
||||
&startAt, &endAt, &pkg.ValidDays,
|
||||
&pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM,
|
||||
&pkg.TotalOrders, &pkg.TotalRevenue, &pkg.Rating, &pkg.RatingCount,
|
||||
&pkg.QuotaUnit, &pkg.PriceUnit, &pkg.CurrencyCode, &pkg.Version,
|
||||
&pkg.CreatedAt, &pkg.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get package: %w", err)
|
||||
}
|
||||
|
||||
if startAt.Valid {
|
||||
pkg.StartAt = startAt.Time
|
||||
}
|
||||
if endAt.Valid {
|
||||
pkg.EndAt = endAt.Time
|
||||
}
|
||||
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
// Update 更新套餐(乐观锁)
|
||||
func (r *PackageRepository) Update(ctx context.Context, pkg *domain.Package, expectedVersion int) error {
|
||||
query := `
|
||||
UPDATE supply_packages SET
|
||||
platform = $1, model = $2,
|
||||
total_quota = $3, available_quota = $4, sold_quota = $5, reserved_quota = $6,
|
||||
price_per_1m_input = $7, price_per_1m_output = $8,
|
||||
start_at = $9, end_at = $10, valid_days = $11,
|
||||
status = $12, max_concurrent = $13, rate_limit_rpm = $14,
|
||||
total_orders = $15, total_revenue = $16,
|
||||
rating = $17, rating_count = $18,
|
||||
version = $19, updated_at = $20
|
||||
WHERE id = $21 AND user_id = $22 AND version = $23
|
||||
`
|
||||
|
||||
pkg.UpdatedAt = time.Now()
|
||||
newVersion := expectedVersion + 1
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query,
|
||||
pkg.Platform, pkg.Model,
|
||||
pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota,
|
||||
pkg.PricePer1MInput, pkg.PricePer1MOutput,
|
||||
pkg.StartAt, pkg.EndAt, pkg.ValidDays,
|
||||
pkg.Status, pkg.MaxConcurrent, pkg.RateLimitRPM,
|
||||
pkg.TotalOrders, pkg.TotalRevenue,
|
||||
pkg.Rating, pkg.RatingCount,
|
||||
newVersion, pkg.UpdatedAt,
|
||||
pkg.ID, pkg.SupplierID, expectedVersion,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update package: %w", err)
|
||||
}
|
||||
|
||||
if cmdTag.RowsAffected() == 0 {
|
||||
return ErrConcurrencyConflict
|
||||
}
|
||||
|
||||
pkg.Version = newVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetForUpdate 获取套餐并加行锁
|
||||
func (r *PackageRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Package, error) {
|
||||
query := `
|
||||
SELECT id, supply_account_id, user_id, platform, model,
|
||||
total_quota, available_quota, sold_quota, reserved_quota,
|
||||
price_per_1m_input, price_per_1m_output,
|
||||
status, version,
|
||||
created_at, updated_at
|
||||
FROM supply_packages
|
||||
WHERE id = $1 AND user_id = $2
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
pkg := &domain.Package{}
|
||||
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
|
||||
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
|
||||
&pkg.Status, &pkg.Version,
|
||||
&pkg.CreatedAt, &pkg.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get package for update: %w", err)
|
||||
}
|
||||
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
// List 列出套餐
|
||||
func (r *PackageRepository) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
|
||||
query := `
|
||||
SELECT id, supply_account_id, user_id, platform, model,
|
||||
total_quota, available_quota, sold_quota,
|
||||
price_per_1m_input, price_per_1m_output,
|
||||
status, max_concurrent, rate_limit_rpm,
|
||||
valid_days, total_orders, total_revenue,
|
||||
version, created_at, updated_at
|
||||
FROM supply_packages
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, supplierID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list packages: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var packages []*domain.Package
|
||||
for rows.Next() {
|
||||
pkg := &domain.Package{}
|
||||
err := rows.Scan(
|
||||
&pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota,
|
||||
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
|
||||
&pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM,
|
||||
&pkg.ValidDays, &pkg.TotalOrders, &pkg.TotalRevenue,
|
||||
&pkg.Version, &pkg.CreatedAt, &pkg.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan package: %w", err)
|
||||
}
|
||||
packages = append(packages, pkg)
|
||||
}
|
||||
|
||||
return packages, nil
|
||||
}
|
||||
|
||||
// UpdateQuota 扣减配额
|
||||
func (r *PackageRepository) UpdateQuota(ctx context.Context, tx pgxpool.Tx, packageID, supplierID int64, usedQuota float64) error {
|
||||
query := `
|
||||
UPDATE supply_packages SET
|
||||
available_quota = available_quota - $1,
|
||||
sold_quota = sold_quota + $1,
|
||||
updated_at = $2
|
||||
WHERE id = $3 AND user_id = $4 AND available_quota >= $1
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
var id int64
|
||||
err := tx.QueryRow(ctx, query, usedQuota, time.Now(), packageID, supplierID).Scan(&id)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return errors.New("insufficient quota or package not found")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update quota: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
243
supply-api/internal/repository/settlement.go
Normal file
243
supply-api/internal/repository/settlement.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
)
|
||||
|
||||
// SettlementRepository 结算仓储
|
||||
type SettlementRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewSettlementRepository 创建结算仓储
|
||||
func NewSettlementRepository(pool *pgxpool.Pool) *SettlementRepository {
|
||||
return &SettlementRepository{pool: pool}
|
||||
}
|
||||
|
||||
// Create 创建结算单
|
||||
func (r *SettlementRepository) Create(ctx context.Context, s *domain.Settlement, requestID, idempotencyKey, traceID string) error {
|
||||
query := `
|
||||
INSERT INTO supply_settlements (
|
||||
settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account,
|
||||
period_start, period_end, total_orders, total_usage_records,
|
||||
currency_code, amount_unit, version,
|
||||
request_id, idempotency_key, audit_trace_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount,
|
||||
s.Status, s.PaymentMethod, s.PaymentAccount,
|
||||
s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords,
|
||||
"USD", "minor", 0,
|
||||
requestID, idempotencyKey, traceID,
|
||||
).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create settlement: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 获取结算单
|
||||
func (r *SettlementRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
|
||||
query := `
|
||||
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account,
|
||||
period_start, period_end, total_orders, total_usage_records,
|
||||
payment_transaction_id, paid_at,
|
||||
version, created_at, updated_at
|
||||
FROM supply_settlements
|
||||
WHERE id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
s := &domain.Settlement{}
|
||||
var paidAt pgx.NullTime
|
||||
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||||
&s.Status, &s.PaymentMethod, &s.PaymentAccount,
|
||||
&s.PeriodStart, &s.PeriodEnd, &s.TotalOrders, &s.TotalUsageRecords,
|
||||
&s.PaymentTransactionID, &paidAt,
|
||||
&s.Version, &s.CreatedAt, &s.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get settlement: %w", err)
|
||||
}
|
||||
|
||||
if paidAt.Valid {
|
||||
s.PaidAt = &paidAt.Time
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Update 更新结算单(乐观锁)
|
||||
func (r *SettlementRepository) Update(ctx context.Context, s *domain.Settlement, expectedVersion int) error {
|
||||
query := `
|
||||
UPDATE supply_settlements SET
|
||||
status = $1, payment_method = $2, payment_account = $3,
|
||||
payment_transaction_id = $4, paid_at = $5,
|
||||
total_orders = $6, total_usage_records = $7,
|
||||
version = $8, updated_at = $9
|
||||
WHERE id = $10 AND user_id = $11 AND version = $12
|
||||
`
|
||||
|
||||
s.UpdatedAt = time.Now()
|
||||
newVersion := expectedVersion + 1
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query,
|
||||
s.Status, s.PaymentMethod, s.PaymentAccount,
|
||||
s.PaymentTransactionID, s.PaidAt,
|
||||
s.TotalOrders, s.TotalUsageRecords,
|
||||
newVersion, s.UpdatedAt,
|
||||
s.ID, s.SupplierID, expectedVersion,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update settlement: %w", err)
|
||||
}
|
||||
|
||||
if cmdTag.RowsAffected() == 0 {
|
||||
return ErrConcurrencyConflict
|
||||
}
|
||||
|
||||
s.Version = newVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetForUpdate 获取结算单并加行锁
|
||||
func (r *SettlementRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Settlement, error) {
|
||||
query := `
|
||||
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account, version,
|
||||
created_at, updated_at
|
||||
FROM supply_settlements
|
||||
WHERE id = $1 AND user_id = $2
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
s := &domain.Settlement{}
|
||||
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||||
&s.Status, &s.PaymentMethod, &s.PaymentAccount, &s.Version,
|
||||
&s.CreatedAt, &s.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get settlement for update: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GetProcessing 获取处理中的结算单(用于单一性约束)
|
||||
func (r *SettlementRepository) GetProcessing(ctx context.Context, tx pgxpool.Tx, supplierID int64) (*domain.Settlement, error) {
|
||||
query := `
|
||||
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account, version,
|
||||
created_at, updated_at
|
||||
FROM supply_settlements
|
||||
WHERE user_id = $1 AND status = 'processing'
|
||||
FOR UPDATE SKIP LOCKED
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
s := &domain.Settlement{}
|
||||
err := tx.QueryRow(ctx, query, supplierID).Scan(
|
||||
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||||
&s.Status, &s.PaymentMethod, &s.PaymentAccount, &s.Version,
|
||||
&s.CreatedAt, &s.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil // 没有处理中的单据
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get processing settlement: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// List 列出结算单
|
||||
func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
|
||||
query := `
|
||||
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method,
|
||||
period_start, period_end, total_orders,
|
||||
version, created_at, updated_at
|
||||
FROM supply_settlements
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, supplierID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list settlements: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var settlements []*domain.Settlement
|
||||
for rows.Next() {
|
||||
s := &domain.Settlement{}
|
||||
err := rows.Scan(
|
||||
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||||
&s.Status, &s.PaymentMethod,
|
||||
&s.PeriodStart, &s.PeriodEnd, &s.TotalOrders,
|
||||
&s.Version, &s.CreatedAt, &s.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan settlement: %w", err)
|
||||
}
|
||||
settlements = append(settlements, s)
|
||||
}
|
||||
|
||||
return settlements, nil
|
||||
}
|
||||
|
||||
// CreateInTx 在事务中创建结算单
|
||||
func (r *SettlementRepository) CreateInTx(ctx context.Context, tx pgxpool.Tx, s *domain.Settlement, requestID, idempotencyKey, traceID string) error {
|
||||
query := `
|
||||
INSERT INTO supply_settlements (
|
||||
settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||||
status, payment_method, payment_account,
|
||||
period_start, period_end, total_orders, total_usage_records,
|
||||
currency_code, amount_unit, version,
|
||||
request_id, idempotency_key, audit_trace_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
|
||||
)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
err := tx.QueryRow(ctx, query,
|
||||
s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount,
|
||||
s.Status, s.PaymentMethod, s.PaymentAccount,
|
||||
s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords,
|
||||
"USD", "minor", 0,
|
||||
requestID, idempotencyKey, traceID,
|
||||
).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create settlement in tx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user