使用 SELECT ... FOR UPDATE SKIP LOCKED 实现原子化提现创建 问题: - HasPendingOrProcessingWithdraw 和 CreateInTx 分开调用导致竞态 - 两个并发请求可能同时通过检查并创建提现 解决方案: - 新增 CreateWithdrawTx 方法,先锁定 pending 记录再检查插入 - 使用 FOR UPDATE SKIP LOCKED 防止并发插入 涉及文件: - internal/repository/settlement.go: 新增 CreateWithdrawTx - internal/adapter/adapter.go: 实现 CreateWithdrawTx - internal/domain/settlement.go: 使用 CreateWithdrawTx - internal/storage/store.go: 实现内存存储版本 - sql/postgresql/settlement_withdraw_constraint_v1.sql: 文档说明 测试: go test -short ./... 通过
406 lines
9.9 KiB
Go
406 lines
9.9 KiB
Go
package storage
|
||
|
||
import (
	"context"
	"errors"
	"sort"
	"sync"
	"time"

	"lijiaoqiao/supply-api/internal/domain"
	"lijiaoqiao/supply-api/internal/repository"
)
|
||
|
||
// Sentinel errors shared by the in-memory stores.
var (
	// ErrNotFound is returned when a requested record does not exist, or when
	// it exists but is owned by a different supplier than the one asked for.
	ErrNotFound = errors.New("resource not found")
)
|
||
|
||
// 内存账号存储
|
||
type InMemoryAccountStore struct {
|
||
mu sync.RWMutex
|
||
accounts map[int64]*domain.Account
|
||
nextID int64
|
||
}
|
||
|
||
func NewInMemoryAccountStore() *InMemoryAccountStore {
|
||
return &InMemoryAccountStore{
|
||
accounts: make(map[int64]*domain.Account),
|
||
nextID: 1,
|
||
}
|
||
}
|
||
|
||
func (s *InMemoryAccountStore) Create(ctx context.Context, account *domain.Account) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
account.ID = s.nextID
|
||
s.nextID++
|
||
account.CreatedAt = time.Now()
|
||
account.UpdatedAt = time.Now()
|
||
s.accounts[account.ID] = account
|
||
return nil
|
||
}
|
||
|
||
func (s *InMemoryAccountStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
account, ok := s.accounts[id]
|
||
if !ok || account.SupplierID != supplierID {
|
||
return nil, ErrNotFound
|
||
}
|
||
return account, nil
|
||
}
|
||
|
||
func (s *InMemoryAccountStore) Update(ctx context.Context, account *domain.Account) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
existing, ok := s.accounts[account.ID]
|
||
if !ok || existing.SupplierID != account.SupplierID {
|
||
return ErrNotFound
|
||
}
|
||
account.UpdatedAt = time.Now()
|
||
s.accounts[account.ID] = account
|
||
return nil
|
||
}
|
||
|
||
func (s *InMemoryAccountStore) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
result := make([]*domain.Account, 0)
|
||
for _, account := range s.accounts {
|
||
if account.SupplierID == supplierID {
|
||
result = append(result, account)
|
||
}
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// 内存套餐存储
|
||
type InMemoryPackageStore struct {
|
||
mu sync.RWMutex
|
||
packages map[int64]*domain.Package
|
||
nextID int64
|
||
}
|
||
|
||
func NewInMemoryPackageStore() *InMemoryPackageStore {
|
||
return &InMemoryPackageStore{
|
||
packages: make(map[int64]*domain.Package),
|
||
nextID: 1,
|
||
}
|
||
}
|
||
|
||
func (s *InMemoryPackageStore) Create(ctx context.Context, pkg *domain.Package) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
pkg.ID = s.nextID
|
||
s.nextID++
|
||
pkg.CreatedAt = time.Now()
|
||
pkg.UpdatedAt = time.Now()
|
||
s.packages[pkg.ID] = pkg
|
||
return nil
|
||
}
|
||
|
||
func (s *InMemoryPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
pkg, ok := s.packages[id]
|
||
if !ok || pkg.SupplierID != supplierID {
|
||
return nil, ErrNotFound
|
||
}
|
||
return pkg, nil
|
||
}
|
||
|
||
func (s *InMemoryPackageStore) Update(ctx context.Context, pkg *domain.Package) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
existing, ok := s.packages[pkg.ID]
|
||
if !ok || existing.SupplierID != pkg.SupplierID {
|
||
return ErrNotFound
|
||
}
|
||
pkg.UpdatedAt = time.Now()
|
||
s.packages[pkg.ID] = pkg
|
||
return nil
|
||
}
|
||
|
||
func (s *InMemoryPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
result := make([]*domain.Package, 0)
|
||
for _, pkg := range s.packages {
|
||
if pkg.SupplierID == supplierID {
|
||
result = append(result, pkg)
|
||
}
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// 内存结算存储
|
||
type InMemorySettlementStore struct {
|
||
mu sync.RWMutex
|
||
settlements map[int64]*domain.Settlement
|
||
nextID int64
|
||
}
|
||
|
||
func NewInMemorySettlementStore() *InMemorySettlementStore {
|
||
return &InMemorySettlementStore{
|
||
settlements: make(map[int64]*domain.Settlement),
|
||
nextID: 1,
|
||
}
|
||
}
|
||
|
||
func (s *InMemorySettlementStore) Create(ctx context.Context, settlement *domain.Settlement) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
settlement.ID = s.nextID
|
||
s.nextID++
|
||
settlement.CreatedAt = time.Now()
|
||
settlement.UpdatedAt = time.Now()
|
||
s.settlements[settlement.ID] = settlement
|
||
return nil
|
||
}
|
||
|
||
// CreateInTx 在事务中创建(内存存储不需要真实事务,直接调用Create)
|
||
func (s *InMemorySettlementStore) CreateInTx(ctx context.Context, settlement *domain.Settlement) error {
|
||
return s.Create(ctx, settlement)
|
||
}
|
||
|
||
// CreateWithdrawTx 原子化提现创建(内存存储实现)
|
||
// 注意:内存存储天然是串行的,不需要额外锁
|
||
func (s *InMemorySettlementStore) CreateWithdrawTx(ctx context.Context, settlement *domain.Settlement) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
// 检查是否有pending的提现
|
||
for _, existing := range s.settlements {
|
||
if existing.SupplierID == settlement.SupplierID &&
|
||
(existing.Status == domain.SettlementStatusPending || existing.Status == domain.SettlementStatusProcessing) {
|
||
return errors.New("already has pending or processing withdrawal")
|
||
}
|
||
}
|
||
|
||
settlement.ID = s.nextID
|
||
s.nextID++
|
||
settlement.CreatedAt = time.Now()
|
||
settlement.UpdatedAt = time.Now()
|
||
s.settlements[settlement.ID] = settlement
|
||
return nil
|
||
}
|
||
|
||
func (s *InMemorySettlementStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
settlement, ok := s.settlements[id]
|
||
if !ok || settlement.SupplierID != supplierID {
|
||
return nil, ErrNotFound
|
||
}
|
||
return settlement, nil
|
||
}
|
||
|
||
func (s *InMemorySettlementStore) Update(ctx context.Context, settlement *domain.Settlement, expectedVersion int) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
existing, ok := s.settlements[settlement.ID]
|
||
if !ok || existing.SupplierID != settlement.SupplierID {
|
||
return ErrNotFound
|
||
}
|
||
|
||
// P1-005: 乐观锁检查
|
||
if existing.Version != expectedVersion {
|
||
return repository.ErrConcurrencyConflict
|
||
}
|
||
|
||
settlement.Version = expectedVersion + 1
|
||
settlement.UpdatedAt = time.Now()
|
||
s.settlements[settlement.ID] = settlement
|
||
return nil
|
||
}
|
||
|
||
func (s *InMemorySettlementStore) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
result := make([]*domain.Settlement, 0)
|
||
for _, settlement := range s.settlements {
|
||
if settlement.SupplierID == supplierID {
|
||
result = append(result, settlement)
|
||
}
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
func (s *InMemorySettlementStore) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
|
||
return 10000.0, nil
|
||
}
|
||
|
||
func (s *InMemorySettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
for _, settlement := range s.settlements {
|
||
if settlement.SupplierID == supplierID {
|
||
if settlement.Status == domain.SettlementStatusPending || settlement.Status == domain.SettlementStatusProcessing {
|
||
return true, nil
|
||
}
|
||
}
|
||
}
|
||
return false, nil
|
||
}
|
||
|
||
// 内存收益存储
|
||
type InMemoryEarningStore struct {
|
||
mu sync.RWMutex
|
||
records map[int64]*domain.EarningRecord
|
||
nextID int64
|
||
}
|
||
|
||
func NewInMemoryEarningStore() *InMemoryEarningStore {
|
||
return &InMemoryEarningStore{
|
||
records: make(map[int64]*domain.EarningRecord),
|
||
nextID: 1,
|
||
}
|
||
}
|
||
|
||
func (s *InMemoryEarningStore) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
var result []*domain.EarningRecord
|
||
for _, record := range s.records {
|
||
if record.SupplierID == supplierID {
|
||
result = append(result, record)
|
||
}
|
||
}
|
||
|
||
total := len(result)
|
||
start := (page - 1) * pageSize
|
||
end := start + pageSize
|
||
|
||
if start >= total {
|
||
return []*domain.EarningRecord{}, total, nil
|
||
}
|
||
if end > total {
|
||
end = total
|
||
}
|
||
|
||
return result[start:end], total, nil
|
||
}
|
||
|
||
func (s *InMemoryEarningStore) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
|
||
return &domain.BillingSummary{
|
||
Period: domain.BillingPeriod{
|
||
Start: startDate,
|
||
End: endDate,
|
||
},
|
||
Summary: domain.BillingTotal{
|
||
TotalRevenue: 10000.0,
|
||
TotalOrders: 100,
|
||
TotalUsage: 1000000,
|
||
TotalRequests: 50000,
|
||
AvgSuccessRate: 99.5,
|
||
PlatformFee: 100.0,
|
||
NetEarnings: 9900.0,
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
// idempotencyCleanupEvery is the number of writes between automatic sweeps of
// expired idempotency records.
const idempotencyCleanupEvery = 100

// InMemoryIdempotencyStore keeps idempotency records in a mutex-guarded map
// keyed by idempotency key.
//
// Expired records are hidden from Get immediately. They are physically
// removed on every idempotencyCleanupEvery-th write, or when CleanExpired is
// called.
type InMemoryIdempotencyStore struct {
	mu             sync.RWMutex
	records        map[string]*IdempotencyRecord
	cleanupCounter int64 // writes since the last automatic cleanup
}

// IdempotencyRecord is the stored state for one idempotency key.
type IdempotencyRecord struct {
	Key       string
	Status    string // processing, succeeded, failed
	Response  interface{}
	CreatedAt time.Time
	ExpiresAt time.Time
}

// NewInMemoryIdempotencyStore returns an empty store.
func NewInMemoryIdempotencyStore() *InMemoryIdempotencyStore {
	return &InMemoryIdempotencyStore{
		records: make(map[string]*IdempotencyRecord),
	}
}

// Get returns the record stored for key, if it exists and has not expired.
// A record whose ExpiresAt is not strictly after the current time counts as
// expired.
func (s *InMemoryIdempotencyStore) Get(key string) (*IdempotencyRecord, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	record, ok := s.records[key]
	if ok && record.ExpiresAt.After(time.Now()) {
		return record, true
	}
	return nil, false
}

// SetProcessing stores key with status "processing", expiring ttl from now.
//
// Any existing record for key is overwritten unconditionally. This method
// alone therefore does not stop two callers from claiming the same key.
func (s *InMemoryIdempotencyStore) SetProcessing(key string, ttl time.Duration) {
	s.put(key, "processing", nil, ttl)
}

// SetSuccess stores key with status "succeeded" and the given response,
// expiring ttl from now. Any existing record for key is overwritten.
func (s *InMemoryIdempotencyStore) SetSuccess(key string, response interface{}, ttl time.Duration) {
	s.put(key, "succeeded", response, ttl)
}

// put writes one record under the write lock, then counts the write toward
// automatic cleanup. A single timestamp is used for CreatedAt and as the base
// of ExpiresAt, so ExpiresAt == CreatedAt + ttl exactly.
func (s *InMemoryIdempotencyStore) put(key, status string, response interface{}, ttl time.Duration) {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	s.records[key] = &IdempotencyRecord{
		Key:       key,
		Status:    status,
		Response:  response,
		CreatedAt: now,
		ExpiresAt: now.Add(ttl),
	}
	s.triggerCleanupLocked()
}

// triggerCleanupLocked counts one write and sweeps expired records every
// idempotencyCleanupEvery writes. s.mu must be held for writing.
func (s *InMemoryIdempotencyStore) triggerCleanupLocked() {
	s.cleanupCounter++
	if s.cleanupCounter >= idempotencyCleanupEvery {
		s.cleanupCounter = 0
		s.cleanupExpiredLocked()
	}
}

// cleanupExpiredLocked deletes every record that Get would already report as
// expired (ExpiresAt not strictly after now). s.mu must be held for writing.
func (s *InMemoryIdempotencyStore) cleanupExpiredLocked() {
	now := time.Now()
	for key, record := range s.records {
		if !record.ExpiresAt.After(now) {
			delete(s.records, key)
		}
	}
}

// CleanExpired removes all expired records immediately. It is safe to call
// periodically from outside the store.
func (s *InMemoryIdempotencyStore) CleanExpired() {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.cleanupExpiredLocked()
}

// Len returns the number of stored records, including expired records that
// have not been swept yet (useful for monitoring).
func (s *InMemoryIdempotencyStore) Len() int {
	s.mu.RLock()
	defer s.mu.RUnlock()

	return len(s.records)
}
|