Files
lijiaoqiao/supply-api/internal/storage/store.go
Your Name da385ee744 fix: P0-02 修复提现竞态条件
使用 SELECT ... FOR UPDATE SKIP LOCKED 实现原子化提现创建

问题:
- HasPendingOrProcessingWithdraw 和 CreateInTx 分开调用导致竞态
- 两个并发请求可能同时通过检查并创建提现

解决方案:
- 新增 CreateWithdrawTx 方法,先锁定 pending 记录再检查插入
- 使用 FOR UPDATE SKIP LOCKED 防止并发插入

涉及文件:
- internal/repository/settlement.go: 新增 CreateWithdrawTx
- internal/adapter/adapter.go: 实现 CreateWithdrawTx
- internal/domain/settlement.go: 使用 CreateWithdrawTx
- internal/storage/store.go: 实现内存存储版本
- sql/postgresql/settlement_withdraw_constraint_v1.sql: 文档说明

测试: go test -short ./... 通过
2026-04-09 22:16:08 +08:00

406 lines
9.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package storage
import (
	"context"
	"errors"
	"sort"
	"sync"
	"time"

	"lijiaoqiao/supply-api/internal/domain"
	"lijiaoqiao/supply-api/internal/repository"
)
// ErrNotFound is returned by the in-memory stores when a record does not
// exist, or exists but belongs to a different supplier.
var (
	ErrNotFound = errors.New("resource not found")
)
// InMemoryAccountStore is an in-memory account store guarded by a
// read/write mutex. Create hands out IDs sequentially.
type InMemoryAccountStore struct {
	mu       sync.RWMutex
	accounts map[int64]*domain.Account
	nextID   int64
}

// NewInMemoryAccountStore returns an empty store whose first assigned ID is 1.
func NewInMemoryAccountStore() *InMemoryAccountStore {
	store := &InMemoryAccountStore{nextID: 1}
	store.accounts = make(map[int64]*domain.Account)
	return store
}
// Create stores account under the next sequential ID.
//
// It assigns account.ID and stamps CreatedAt and UpdatedAt from a single
// clock reading, so a freshly created record never has an UpdatedAt later
// than its CreatedAt. The caller's pointer is stored as-is, so later
// mutations through it are visible to the store.
func (s *InMemoryAccountStore) Create(ctx context.Context, account *domain.Account) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	account.ID = s.nextID
	s.nextID++
	account.CreatedAt = now
	account.UpdatedAt = now
	s.accounts[account.ID] = account
	return nil
}
// GetByID returns the account with the given id. It returns ErrNotFound
// unless the account exists and belongs to supplierID. The stored pointer is
// returned, not a copy.
func (s *InMemoryAccountStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if acct, found := s.accounts[id]; found && acct.SupplierID == supplierID {
		return acct, nil
	}
	return nil, ErrNotFound
}
// Update replaces the stored account that has account.ID and refreshes
// UpdatedAt. It returns ErrNotFound if no such account exists or if it
// belongs to a different supplier than account.SupplierID. The caller's
// pointer replaces the stored one wholesale.
func (s *InMemoryAccountStore) Update(ctx context.Context, account *domain.Account) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if current, found := s.accounts[account.ID]; !found || current.SupplierID != account.SupplierID {
		return ErrNotFound
	}
	account.UpdatedAt = time.Now()
	s.accounts[account.ID] = account
	return nil
}
// List returns every account owned by supplierID, in unspecified (map
// iteration) order. The result is never nil; it is empty when nothing
// matches.
func (s *InMemoryAccountStore) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	matches := []*domain.Account{}
	for _, acct := range s.accounts {
		if acct.SupplierID != supplierID {
			continue
		}
		matches = append(matches, acct)
	}
	return matches, nil
}
// InMemoryPackageStore is an in-memory package store guarded by a
// read/write mutex. Create hands out IDs sequentially.
type InMemoryPackageStore struct {
	mu       sync.RWMutex
	packages map[int64]*domain.Package
	nextID   int64
}

// NewInMemoryPackageStore returns an empty store whose first assigned ID is 1.
func NewInMemoryPackageStore() *InMemoryPackageStore {
	store := &InMemoryPackageStore{nextID: 1}
	store.packages = make(map[int64]*domain.Package)
	return store
}
// Create stores pkg under the next sequential ID.
//
// It assigns pkg.ID and stamps CreatedAt and UpdatedAt from a single clock
// reading, so a freshly created record never has an UpdatedAt later than its
// CreatedAt. The caller's pointer is stored as-is.
func (s *InMemoryPackageStore) Create(ctx context.Context, pkg *domain.Package) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	pkg.ID = s.nextID
	s.nextID++
	pkg.CreatedAt = now
	pkg.UpdatedAt = now
	s.packages[pkg.ID] = pkg
	return nil
}
// GetByID returns the package with the given id. It returns ErrNotFound
// unless the package exists and belongs to supplierID. The stored pointer is
// returned, not a copy.
func (s *InMemoryPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if p, found := s.packages[id]; found && p.SupplierID == supplierID {
		return p, nil
	}
	return nil, ErrNotFound
}
// Update replaces the stored package that has pkg.ID and refreshes
// UpdatedAt. It returns ErrNotFound if no such package exists or if it
// belongs to a different supplier than pkg.SupplierID. The caller's pointer
// replaces the stored one wholesale.
func (s *InMemoryPackageStore) Update(ctx context.Context, pkg *domain.Package) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if current, found := s.packages[pkg.ID]; !found || current.SupplierID != pkg.SupplierID {
		return ErrNotFound
	}
	pkg.UpdatedAt = time.Now()
	s.packages[pkg.ID] = pkg
	return nil
}
// List returns every package owned by supplierID, in unspecified (map
// iteration) order. The result is never nil; it is empty when nothing
// matches.
func (s *InMemoryPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	matches := []*domain.Package{}
	for _, p := range s.packages {
		if p.SupplierID != supplierID {
			continue
		}
		matches = append(matches, p)
	}
	return matches, nil
}
// InMemorySettlementStore is an in-memory settlement store guarded by a
// read/write mutex. Create and CreateWithdrawTx hand out IDs sequentially.
type InMemorySettlementStore struct {
	mu          sync.RWMutex
	settlements map[int64]*domain.Settlement
	nextID      int64
}

// NewInMemorySettlementStore returns an empty store whose first assigned ID
// is 1.
func NewInMemorySettlementStore() *InMemorySettlementStore {
	store := &InMemorySettlementStore{nextID: 1}
	store.settlements = make(map[int64]*domain.Settlement)
	return store
}
// Create stores settlement under the next sequential ID without any
// duplicate-withdrawal check (see CreateWithdrawTx for the checked variant).
//
// It assigns settlement.ID and stamps CreatedAt and UpdatedAt from a single
// clock reading, so a fresh record never has an UpdatedAt later than its
// CreatedAt. Version is left as supplied by the caller. The caller's pointer
// is stored as-is.
func (s *InMemorySettlementStore) Create(ctx context.Context, settlement *domain.Settlement) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	settlement.ID = s.nextID
	s.nextID++
	settlement.CreatedAt = now
	settlement.UpdatedAt = now
	s.settlements[settlement.ID] = settlement
	return nil
}
// CreateInTx creates settlement "inside a transaction". The in-memory store
// has no real transactions, so this simply delegates to Create.
func (s *InMemorySettlementStore) CreateInTx(ctx context.Context, settlement *domain.Settlement) error {
	err := s.Create(ctx, settlement)
	return err
}
// CreateWithdrawTx atomically checks for and creates a withdrawal
// (in-memory implementation).
//
// The duplicate check and the insert both run under a single write lock on
// s.mu. That lock is what makes the operation atomic: two concurrent calls
// for the same supplier cannot both pass the check. The call fails if any
// settlement of the same supplier is pending or processing. On success,
// settlement.ID is assigned and CreatedAt/UpdatedAt are stamped from one
// clock reading.
//
// NOTE(review): the check looks at every settlement of the supplier and does
// not distinguish withdrawals from other settlement records. Confirm this
// matches the SQL implementation.
func (s *InMemorySettlementStore) CreateWithdrawTx(ctx context.Context, settlement *domain.Settlement) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	for _, existing := range s.settlements {
		if existing.SupplierID != settlement.SupplierID {
			continue
		}
		if existing.Status == domain.SettlementStatusPending || existing.Status == domain.SettlementStatusProcessing {
			return errors.New("already has pending or processing withdrawal")
		}
	}

	now := time.Now()
	settlement.ID = s.nextID
	s.nextID++
	settlement.CreatedAt = now
	settlement.UpdatedAt = now
	s.settlements[settlement.ID] = settlement
	return nil
}
// GetByID returns the settlement with the given id. It returns ErrNotFound
// unless the settlement exists and belongs to supplierID. The stored pointer
// is returned, not a copy.
func (s *InMemorySettlementStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if st, found := s.settlements[id]; found && st.SupplierID == supplierID {
		return st, nil
	}
	return nil, ErrNotFound
}
// Update replaces the stored settlement that has settlement.ID, using
// optimistic locking (P1-005).
//
// It returns ErrNotFound if the settlement is missing or belongs to another
// supplier. It returns repository.ErrConcurrencyConflict if the stored
// Version differs from expectedVersion. On success it sets
// settlement.Version to expectedVersion+1, refreshes UpdatedAt, and stores
// the caller's pointer.
//
// NOTE(review): GetByID hands out the stored pointer. A caller that mutates
// that object before calling Update is changing shared state outside s.mu.
// Passing the same object's (already bumped) Version as expectedVersion would
// also defeat the conflict check.
func (s *InMemorySettlementStore) Update(ctx context.Context, settlement *domain.Settlement, expectedVersion int) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	stored, found := s.settlements[settlement.ID]
	if !found || stored.SupplierID != settlement.SupplierID {
		return ErrNotFound
	}
	if stored.Version != expectedVersion {
		return repository.ErrConcurrencyConflict
	}

	settlement.Version, settlement.UpdatedAt = expectedVersion+1, time.Now()
	s.settlements[settlement.ID] = settlement
	return nil
}
// List returns every settlement owned by supplierID, in unspecified (map
// iteration) order. The result is never nil; it is empty when nothing
// matches.
func (s *InMemorySettlementStore) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	matches := []*domain.Settlement{}
	for _, st := range s.settlements {
		if st.SupplierID != supplierID {
			continue
		}
		matches = append(matches, st)
	}
	return matches, nil
}
// GetWithdrawableBalance is a stub. It ignores supplierID and always reports
// a fixed balance of 10000.0 with no error.
//
// NOTE(review): it does not account for pending or processing withdrawals.
func (s *InMemorySettlementStore) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
	const fixedBalance = 10000.0
	return fixedBalance, nil
}
// HasPendingOrProcessingWithdraw reports whether supplierID has any
// settlement whose status is pending or processing.
//
// NOTE(review): every settlement of the supplier is considered; withdrawals
// are not distinguished from other settlement records. On its own this check
// is not atomic with a later insert. CreateWithdrawTx performs the
// check-and-insert under one lock.
func (s *InMemorySettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	for _, st := range s.settlements {
		if st.SupplierID != supplierID {
			continue
		}
		switch st.Status {
		case domain.SettlementStatusPending, domain.SettlementStatusProcessing:
			return true, nil
		}
	}
	return false, nil
}
// InMemoryEarningStore is an in-memory store of earning records guarded by a
// read/write mutex.
//
// NOTE(review): nothing in this file writes to records or advances nextID.
// Presumably they are populated elsewhere in the package (e.g. tests);
// confirm.
type InMemoryEarningStore struct {
	mu      sync.RWMutex
	records map[int64]*domain.EarningRecord
	nextID  int64
}

// NewInMemoryEarningStore returns an empty store with nextID set to 1.
func NewInMemoryEarningStore() *InMemoryEarningStore {
	store := &InMemoryEarningStore{nextID: 1}
	store.records = make(map[int64]*domain.EarningRecord)
	return store
}
// ListRecords returns one page of supplierID's earning records, together with
// the total number of that supplier's records.
//
// Records are ordered by ascending map key, so consecutive pages are
// disjoint and stable. Go randomizes map iteration order on every range, so
// paging directly over the map could repeat or skip records between pages.
//
// page is 1-based. A page < 1 is treated as 1. A non-positive pageSize, or a
// page past the end, yields an empty (non-nil) page while still reporting
// total. startDate and endDate are currently ignored.
func (s *InMemoryEarningStore) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	ids := make([]int64, 0, len(s.records))
	for id, record := range s.records {
		if record.SupplierID == supplierID {
			ids = append(ids, id)
		}
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })
	total := len(ids)

	if page < 1 {
		page = 1
	}
	// Compare page-1 against total/pageSize before multiplying so that
	// (page-1)*pageSize cannot overflow for absurdly large page numbers. It
	// also avoids the negative slice index the naive formula produced for
	// page < 1.
	if pageSize <= 0 || page-1 > total/pageSize {
		return []*domain.EarningRecord{}, total, nil
	}
	start := (page - 1) * pageSize
	if start >= total {
		return []*domain.EarningRecord{}, total, nil
	}
	end := total
	if pageSize < total-start {
		end = start + pageSize
	}

	result := make([]*domain.EarningRecord, 0, end-start)
	for _, id := range ids[start:end] {
		result = append(result, s.records[id])
	}
	return result, total, nil
}
// GetBillingSummary is a stub. It echoes the requested period back and
// reports fixed totals regardless of supplierID or stored records.
func (s *InMemoryEarningStore) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
	summary := &domain.BillingSummary{}
	summary.Period = domain.BillingPeriod{Start: startDate, End: endDate}
	summary.Summary = domain.BillingTotal{
		TotalRevenue:   10000.0,
		TotalOrders:    100,
		TotalUsage:     1000000,
		TotalRequests:  50000,
		AvgSuccessRate: 99.5,
		PlatformFee:    100.0,
		NetEarnings:    9900.0,
	}
	return summary, nil
}
// 内存幂等存储
type InMemoryIdempotencyStore struct {
mu sync.RWMutex
records map[string]*IdempotencyRecord
cleanupCounter int64 // 清理触发计数器
}
type IdempotencyRecord struct {
Key string
Status string // processing, succeeded, failed
Response interface{}
CreatedAt time.Time
ExpiresAt time.Time
}
func NewInMemoryIdempotencyStore() *InMemoryIdempotencyStore {
return &InMemoryIdempotencyStore{
records: make(map[string]*IdempotencyRecord),
}
}
func (s *InMemoryIdempotencyStore) Get(key string) (*IdempotencyRecord, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
record, ok := s.records[key]
if ok && record.ExpiresAt.After(time.Now()) {
return record, true
}
return nil, false
}
func (s *InMemoryIdempotencyStore) SetProcessing(key string, ttl time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.records[key] = &IdempotencyRecord{
Key: key,
Status: "processing",
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(ttl),
}
s.triggerCleanupLocked()
}
func (s *InMemoryIdempotencyStore) SetSuccess(key string, response interface{}, ttl time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.records[key] = &IdempotencyRecord{
Key: key,
Status: "succeeded",
Response: response,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(ttl),
}
s.triggerCleanupLocked()
}
// triggerCleanupLocked 触发清理每100次操作清理一次过期记录
// 调用时必须持有锁
func (s *InMemoryIdempotencyStore) triggerCleanupLocked() {
s.cleanupCounter++
if s.cleanupCounter >= 100 {
s.cleanupCounter = 0
s.cleanupExpiredLocked()
}
}
// cleanupExpiredLocked 清理过期记录(需要持有锁)
func (s *InMemoryIdempotencyStore) cleanupExpiredLocked() {
now := time.Now()
for key, record := range s.records {
if record.ExpiresAt.Before(now) {
delete(s.records, key)
}
}
}
// CleanExpired 主动清理过期记录(可由外部定期调用)
func (s *InMemoryIdempotencyStore) CleanExpired() {
s.mu.Lock()
defer s.mu.Unlock()
s.cleanupExpiredLocked()
}
// Len 返回当前记录数量(用于监控)
func (s *InMemoryIdempotencyStore) Len() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.records)
}