fix: 修复提现唯一性检查问题 (PRD P0)
问题:Withdraw函数没有检查是否已有处理中的提现,可能导致并发提现 修复内容: 1. 添加新错误码 ErrWithdrawAlreadyProcessing (SUP_SET_4093) 2. 在 SettlementStore 接口添加 HasPendingOrProcessingWithdraw 方法 3. 在 Withdraw 函数中添加检查:已有pending/processing状态提现时拒绝新的提现 4. 在 Repository 中实现 HasPendingOrProcessingWithdraw(检查 pending 和 processing 状态) 5. 在所有 mock 实现中添加该方法 修改的文件: - domain/settlement.go: 接口定义和 Withdraw 逻辑 - domain/invariants.go: 新错误码 - repository/settlement.go: HasPendingOrProcessingWithdraw 实现 - storage/store.go: InMemorySettlementStore 实现 - cmd/supply-api/main.go: DBSettlementStore 和 InMemorySettlementStoreAdapter 实现 - test mocks: 添加 HasPendingOrProcessingWithdraw
This commit is contained in:
@@ -430,6 +430,10 @@ func (a *InMemorySettlementStoreAdapter) GetWithdrawableBalance(ctx context.Cont
|
||||
return a.store.GetWithdrawableBalance(ctx, supplierID)
|
||||
}
|
||||
|
||||
func (a *InMemorySettlementStoreAdapter) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
|
||||
return a.store.HasPendingOrProcessingWithdraw(ctx, supplierID)
|
||||
}
|
||||
|
||||
// InMemoryEarningStoreAdapter 内存收益存储适配器
|
||||
type InMemoryEarningStoreAdapter struct {
|
||||
store *storage.InMemoryEarningStore
|
||||
@@ -521,6 +525,10 @@ func (s *DBSettlementStore) GetWithdrawableBalance(ctx context.Context, supplier
|
||||
return s.accountRepo.GetWithdrawableBalance(ctx, supplierID)
|
||||
}
|
||||
|
||||
func (s *DBSettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
|
||||
return s.repo.HasPendingOrProcessingWithdraw(ctx, supplierID)
|
||||
}
|
||||
|
||||
// DBEarningStore DB-backed收益存储
|
||||
type DBEarningStore struct {
|
||||
usageRepo *repository.UsageRepository
|
||||
|
||||
403
supply-api/internal/benchmark/domain_bench_test.go
Normal file
403
supply-api/internal/benchmark/domain_bench_test.go
Normal file
@@ -0,0 +1,403 @@
|
||||
//go:build slow
|
||||
// +build slow
|
||||
|
||||
package benchmark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
)
|
||||
|
||||
// BenchmarkAccountService_Create 基准测试:账号创建性能
|
||||
func BenchmarkAccountService_Create(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("Skipping benchmark in short mode")
|
||||
}
|
||||
|
||||
store := newMockAccountStoreForBenchmark()
|
||||
auditStore := &mockAuditStoreForBenchmark{}
|
||||
svc := domain.NewAccountService(store, auditStore)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &domain.CreateAccountRequest{
|
||||
SupplierID: 1001,
|
||||
Provider: domain.ProviderOpenAI,
|
||||
AccountType: domain.AccountTypeAPIKey,
|
||||
Credential: "sk-test-key-benchmark",
|
||||
Alias: "bench-account",
|
||||
RiskAck: true,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req.Alias = fmt.Sprintf("bench-account-%d", i)
|
||||
_, _ = svc.Create(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAccountService_Verify 基准测试:账号验证性能
|
||||
func BenchmarkAccountService_Verify(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("Skipping benchmark in short mode")
|
||||
}
|
||||
|
||||
store := newMockAccountStoreForBenchmark()
|
||||
auditStore := &mockAuditStoreForBenchmark{}
|
||||
svc := domain.NewAccountService(store, auditStore)
|
||||
ctx := context.Background()
|
||||
|
||||
// 先创建一个账号
|
||||
req := &domain.CreateAccountRequest{
|
||||
SupplierID: 1001,
|
||||
Provider: domain.ProviderOpenAI,
|
||||
AccountType: domain.AccountTypeAPIKey,
|
||||
Credential: "sk-test-key-benchmark",
|
||||
Alias: "bench-account",
|
||||
RiskAck: true,
|
||||
}
|
||||
account, _ := svc.Create(ctx, req)
|
||||
_ = account
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = svc.Verify(ctx, 1001, domain.ProviderOpenAI, domain.AccountTypeAPIKey, "sk-test-key-benchmark")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPackageService_CreateDraft 基准测试:套餐创建性能
|
||||
func BenchmarkPackageService_CreateDraft(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("Skipping benchmark in short mode")
|
||||
}
|
||||
|
||||
store := newMockPackageStoreForBenchmark()
|
||||
accountStore := newMockAccountStoreForBenchmark()
|
||||
auditStore := &mockAuditStoreForBenchmark{}
|
||||
svc := domain.NewPackageService(store, accountStore, auditStore)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &domain.CreatePackageDraftRequest{
|
||||
SupplierID: 1001,
|
||||
AccountID: 1,
|
||||
Model: "gpt-4o-mini",
|
||||
TotalQuota: 1000000,
|
||||
PricePer1MInput: 0.5,
|
||||
PricePer1MOutput: 1.5,
|
||||
ValidDays: 30,
|
||||
MaxConcurrent: 10,
|
||||
RateLimitRPM: 1000,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = svc.CreateDraft(ctx, 1001, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPackageService_BatchUpdatePrice 基准测试:批量调价性能
|
||||
func BenchmarkPackageService_BatchUpdatePrice(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("Skipping benchmark in short mode")
|
||||
}
|
||||
|
||||
store := newMockPackageStoreForBenchmark()
|
||||
accountStore := newMockAccountStoreForBenchmark()
|
||||
auditStore := &mockAuditStoreForBenchmark{}
|
||||
svc := domain.NewPackageService(store, accountStore, auditStore)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建多个套餐
|
||||
for i := 0; i < 100; i++ {
|
||||
req := &domain.CreatePackageDraftRequest{
|
||||
SupplierID: 1001,
|
||||
AccountID: 1,
|
||||
Model: fmt.Sprintf("gpt-4o-mini-%d", i),
|
||||
TotalQuota: 1000000,
|
||||
PricePer1MInput: 0.5,
|
||||
PricePer1MOutput: 1.5,
|
||||
ValidDays: 30,
|
||||
}
|
||||
pkg, _ := svc.CreateDraft(ctx, 1001, req)
|
||||
_, _ = svc.Publish(ctx, 1001, pkg.ID)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := &domain.BatchUpdatePriceRequest{
|
||||
Items: make([]domain.BatchPriceItem, 50),
|
||||
}
|
||||
for j := 0; j < 50; j++ {
|
||||
req.Items[j] = domain.BatchPriceItem{
|
||||
PackageID: int64(j + 1),
|
||||
PricePer1MInput: float64(i) * 0.1,
|
||||
PricePer1MOutput: float64(i) * 0.2,
|
||||
}
|
||||
}
|
||||
_, _ = svc.BatchUpdatePrice(ctx, 1001, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSettlementService_Withdraw 基准测试:提现性能
|
||||
func BenchmarkSettlementService_Withdraw(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("Skipping benchmark in short mode")
|
||||
}
|
||||
|
||||
store := newMockSettlementStoreForBenchmark()
|
||||
earningStore := newMockEarningStoreForBenchmark()
|
||||
auditStore := &mockAuditStoreForBenchmark{}
|
||||
svc := domain.NewSettlementService(store, earningStore, auditStore)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := &domain.WithdrawRequest{
|
||||
Amount: 100.00,
|
||||
PaymentMethod: domain.PaymentMethodBank,
|
||||
PaymentAccount: "bank-1234567890",
|
||||
SMSCode: "123456",
|
||||
}
|
||||
_, _ = svc.Withdraw(ctx, 1001, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConcurrentAccountAccess 基准测试:并发账号访问
|
||||
func BenchmarkConcurrentAccountAccess(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("Skipping benchmark in short mode")
|
||||
}
|
||||
|
||||
store := newMockAccountStoreForBenchmark()
|
||||
auditStore := &mockAuditStoreForBenchmark{}
|
||||
svc := domain.NewAccountService(store, auditStore)
|
||||
ctx := context.Background()
|
||||
|
||||
// 先创建一个账号
|
||||
req := &domain.CreateAccountRequest{
|
||||
SupplierID: 1001,
|
||||
Provider: domain.ProviderOpenAI,
|
||||
AccountType: domain.AccountTypeAPIKey,
|
||||
Credential: "sk-test-key-benchmark",
|
||||
Alias: "bench-account",
|
||||
RiskAck: true,
|
||||
}
|
||||
account, _ := svc.Create(ctx, req)
|
||||
_ = account
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = store.GetByID(ctx, 1001, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSettlementConcurrency 基准测试:结算并发冲突
|
||||
func BenchmarkSettlementConcurrency(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("Skipping benchmark in short mode")
|
||||
}
|
||||
|
||||
store := newMockSettlementStoreForBenchmark()
|
||||
earningStore := newMockEarningStoreForBenchmark()
|
||||
auditStore := &mockAuditStoreForBenchmark{}
|
||||
svc := domain.NewSettlementService(store, earningStore, auditStore)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建一个待处理的结算单
|
||||
settlement, _ := svc.Withdraw(ctx, 1001, &domain.WithdrawRequest{
|
||||
Amount: 100.00,
|
||||
PaymentMethod: domain.PaymentMethodBank,
|
||||
PaymentAccount: "bank-1234567890",
|
||||
SMSCode: "123456",
|
||||
})
|
||||
_ = settlement
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
// 模拟并发取消
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = svc.Cancel(context.Background(), 1001, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助类型
|
||||
|
||||
type mockAccountStoreForBenchmark struct {
|
||||
accounts map[int64]*domain.Account
|
||||
nextID int64
|
||||
}
|
||||
|
||||
func newMockAccountStoreForBenchmark() *mockAccountStoreForBenchmark {
|
||||
return &mockAccountStoreForBenchmark{
|
||||
accounts: make(map[int64]*domain.Account),
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockAccountStoreForBenchmark) Create(ctx context.Context, account *domain.Account) error {
|
||||
account.ID = m.nextID
|
||||
m.nextID++
|
||||
m.accounts[account.ID] = account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAccountStoreForBenchmark) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
|
||||
if account, ok := m.accounts[id]; ok && account.SupplierID == supplierID {
|
||||
return account, nil
|
||||
}
|
||||
return nil, fmt.Errorf("account not found")
|
||||
}
|
||||
|
||||
func (m *mockAccountStoreForBenchmark) Update(ctx context.Context, account *domain.Account) error {
|
||||
m.accounts[account.ID] = account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAccountStoreForBenchmark) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
|
||||
var result []*domain.Account
|
||||
for _, account := range m.accounts {
|
||||
if account.SupplierID == supplierID {
|
||||
result = append(result, account)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type mockPackageStoreForBenchmark struct {
|
||||
packages map[int64]*domain.Package
|
||||
nextID int64
|
||||
}
|
||||
|
||||
func newMockPackageStoreForBenchmark() *mockPackageStoreForBenchmark {
|
||||
return &mockPackageStoreForBenchmark{
|
||||
packages: make(map[int64]*domain.Package),
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockPackageStoreForBenchmark) Create(ctx context.Context, pkg *domain.Package) error {
|
||||
pkg.ID = m.nextID
|
||||
m.nextID++
|
||||
m.packages[pkg.ID] = pkg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPackageStoreForBenchmark) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
|
||||
if pkg, ok := m.packages[id]; ok && pkg.SupplierID == supplierID {
|
||||
return pkg, nil
|
||||
}
|
||||
return nil, fmt.Errorf("package not found")
|
||||
}
|
||||
|
||||
func (m *mockPackageStoreForBenchmark) Update(ctx context.Context, pkg *domain.Package) error {
|
||||
m.packages[pkg.ID] = pkg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPackageStoreForBenchmark) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
|
||||
var result []*domain.Package
|
||||
for _, pkg := range m.packages {
|
||||
if pkg.SupplierID == supplierID {
|
||||
result = append(result, pkg)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type mockSettlementStoreForBenchmark struct {
|
||||
settlements map[int64]*domain.Settlement
|
||||
nextID int64
|
||||
balance float64
|
||||
}
|
||||
|
||||
func newMockSettlementStoreForBenchmark() *mockSettlementStoreForBenchmark {
|
||||
return &mockSettlementStoreForBenchmark{
|
||||
settlements: make(map[int64]*domain.Settlement),
|
||||
nextID: 1,
|
||||
balance: 100000.00,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockSettlementStoreForBenchmark) Create(ctx context.Context, s *domain.Settlement) error {
|
||||
s.ID = m.nextID
|
||||
m.nextID++
|
||||
m.settlements[s.ID] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettlementStoreForBenchmark) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
|
||||
if s, ok := m.settlements[id]; ok && s.SupplierID == supplierID {
|
||||
return s, nil
|
||||
}
|
||||
return nil, fmt.Errorf("settlement not found")
|
||||
}
|
||||
|
||||
func (m *mockSettlementStoreForBenchmark) Update(ctx context.Context, s *domain.Settlement, expectedVersion int) error {
|
||||
m.settlements[s.ID] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettlementStoreForBenchmark) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
|
||||
var result []*domain.Settlement
|
||||
for _, s := range m.settlements {
|
||||
if s.SupplierID == supplierID {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettlementStoreForBenchmark) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
|
||||
return m.balance, nil
|
||||
}
|
||||
|
||||
func (m *mockSettlementStoreForBenchmark) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
type mockEarningStoreForBenchmark struct{}
|
||||
|
||||
func newMockEarningStoreForBenchmark() *mockEarningStoreForBenchmark {
|
||||
return &mockEarningStoreForBenchmark{}
|
||||
}
|
||||
|
||||
func (m *mockEarningStoreForBenchmark) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
|
||||
return []*domain.EarningRecord{}, 0, nil
|
||||
}
|
||||
|
||||
func (m *mockEarningStoreForBenchmark) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
|
||||
return &domain.BillingSummary{}, nil
|
||||
}
|
||||
|
||||
type mockAuditStoreForBenchmark struct{}
|
||||
|
||||
func (m *mockAuditStoreForBenchmark) Emit(ctx context.Context, event audit.Event) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAuditStoreForBenchmark) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditStoreForBenchmark) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditStoreForBenchmark) GetByID(ctx context.Context, eventID string) (audit.Event, error) {
|
||||
return audit.Event{}, fmt.Errorf("not found")
|
||||
}
|
||||
@@ -32,6 +32,9 @@ var (
|
||||
|
||||
// INV-SET-003: 结算单金额与余额流水必须平衡
|
||||
ErrSettlementBalanceMismatch = errors.New("SUP_SET_5002: settlement amount does not match balance ledger")
|
||||
|
||||
// INV-SET-004: 已有处理中的提现时不允许再次提现
|
||||
ErrWithdrawAlreadyProcessing = errors.New("SUP_SET_4093: another withdrawal is already processing")
|
||||
)
|
||||
|
||||
// InvariantChecker 领域不变量检查器
|
||||
|
||||
@@ -130,6 +130,10 @@ func (m *mockSettlementStoreForInvariant) GetWithdrawableBalance(ctx context.Con
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockSettlementStoreForInvariant) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func TestValidateAccountStateTransition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -140,6 +140,8 @@ type SettlementStore interface {
|
||||
Update(ctx context.Context, s *Settlement, expectedVersion int) error
|
||||
List(ctx context.Context, supplierID int64) ([]*Settlement, error)
|
||||
GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error)
|
||||
// HasPendingOrProcessingWithdraw 检查是否有待处理或处理中的提现单
|
||||
HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error)
|
||||
}
|
||||
|
||||
// 收益仓储接口
|
||||
@@ -176,6 +178,15 @@ func (s *settlementService) Withdraw(ctx context.Context, supplierID int64, req
|
||||
return nil, errors.New("invalid sms code")
|
||||
}
|
||||
|
||||
// INV-SET-004: 检查是否已有待处理或处理中的提现
|
||||
hasPending, err := s.store.HasPendingOrProcessingWithdraw(ctx, supplierID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hasPending {
|
||||
return nil, ErrWithdrawAlreadyProcessing
|
||||
}
|
||||
|
||||
// 验证金额:必须为正数
|
||||
if req.Amount <= 0 {
|
||||
return nil, errors.New("SUP_SET_4003: withdraw amount must be positive")
|
||||
|
||||
@@ -65,6 +65,10 @@ func (m *mockSettlementStore) GetWithdrawableBalance(ctx context.Context, suppli
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockSettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// mockEarningStore Mock收益存储
|
||||
type mockEarningStore struct {
|
||||
records []*EarningRecord
|
||||
|
||||
@@ -209,6 +209,22 @@ func (r *SettlementRepository) GetProcessing(ctx context.Context, tx pgxpool.Tx,
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// HasPendingOrProcessingWithdraw 检查是否有待处理或处理中的提现单
|
||||
func (r *SettlementRepository) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM supply_settlements
|
||||
WHERE user_id = $1 AND status IN ('pending', 'processing')
|
||||
)
|
||||
`
|
||||
var exists bool
|
||||
err := r.pool.QueryRow(ctx, query, supplierID).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check pending/processing settlement: %w", err)
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// List 列出结算单
|
||||
func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
|
||||
query := `
|
||||
|
||||
@@ -213,6 +213,20 @@ func (s *InMemorySettlementStore) GetWithdrawableBalance(ctx context.Context, su
|
||||
return 10000.0, nil
|
||||
}
|
||||
|
||||
func (s *InMemorySettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for _, settlement := range s.settlements {
|
||||
if settlement.SupplierID == supplierID {
|
||||
if settlement.Status == domain.SettlementStatusPending || settlement.Status == domain.SettlementStatusProcessing {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 内存收益存储
|
||||
type InMemoryEarningStore struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
237
supply-api/internal/testutil/mock/mocks.go
Normal file
237
supply-api/internal/testutil/mock/mocks.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
)
|
||||
|
||||
// MockAccountStore 账号存储 mock
|
||||
type MockAccountStore struct {
|
||||
Accounts map[int64]*domain.Account
|
||||
NextID int64
|
||||
}
|
||||
|
||||
// NewMockAccountStore 创建账号存储 mock
|
||||
func NewMockAccountStore() *MockAccountStore {
|
||||
return &MockAccountStore{
|
||||
Accounts: make(map[int64]*domain.Account),
|
||||
NextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockAccountStore) Create(ctx context.Context, account *domain.Account) error {
|
||||
account.ID = m.NextID
|
||||
m.NextID++
|
||||
m.Accounts[account.ID] = account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockAccountStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
|
||||
if account, ok := m.Accounts[id]; ok && account.SupplierID == supplierID {
|
||||
return account, nil
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (m *MockAccountStore) Update(ctx context.Context, account *domain.Account) error {
|
||||
if _, ok := m.Accounts[account.ID]; ok {
|
||||
m.Accounts[account.ID] = account
|
||||
return nil
|
||||
}
|
||||
return errors.New("account not found")
|
||||
}
|
||||
|
||||
func (m *MockAccountStore) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
|
||||
var result []*domain.Account
|
||||
for _, account := range m.Accounts {
|
||||
if account.SupplierID == supplierID {
|
||||
result = append(result, account)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MockPackageStore 套餐存储 mock
|
||||
type MockPackageStore struct {
|
||||
Packages map[int64]*domain.Package
|
||||
}
|
||||
|
||||
// NewMockPackageStore 创建套餐存储 mock
|
||||
func NewMockPackageStore() *MockPackageStore {
|
||||
return &MockPackageStore{
|
||||
Packages: make(map[int64]*domain.Package),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockPackageStore) Create(ctx context.Context, pkg *domain.Package) error {
|
||||
if pkg.ID == 0 {
|
||||
pkg.ID = int64(len(m.Packages) + 1)
|
||||
}
|
||||
m.Packages[pkg.ID] = pkg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
|
||||
if pkg, ok := m.Packages[id]; ok && pkg.SupplierID == supplierID {
|
||||
return pkg, nil
|
||||
}
|
||||
return nil, errors.New("package not found")
|
||||
}
|
||||
|
||||
func (m *MockPackageStore) Update(ctx context.Context, pkg *domain.Package) error {
|
||||
if _, ok := m.Packages[pkg.ID]; ok {
|
||||
m.Packages[pkg.ID] = pkg
|
||||
return nil
|
||||
}
|
||||
return errors.New("package not found")
|
||||
}
|
||||
|
||||
func (m *MockPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
|
||||
var result []*domain.Package
|
||||
for _, pkg := range m.Packages {
|
||||
if pkg.SupplierID == supplierID {
|
||||
result = append(result, pkg)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MockSettlementStore 结算存储 mock
|
||||
type MockSettlementStore struct {
|
||||
Settlements map[int64]*domain.Settlement
|
||||
NextID int64
|
||||
Balance float64
|
||||
}
|
||||
|
||||
// NewMockSettlementStore 创建结算存储 mock
|
||||
func NewMockSettlementStore() *MockSettlementStore {
|
||||
return &MockSettlementStore{
|
||||
Settlements: make(map[int64]*domain.Settlement),
|
||||
NextID: 1,
|
||||
Balance: 10000.00, // 默认余额
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockSettlementStore) Create(ctx context.Context, s *domain.Settlement) error {
|
||||
s.ID = m.NextID
|
||||
m.NextID++
|
||||
m.Settlements[s.ID] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockSettlementStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
|
||||
if s, ok := m.Settlements[id]; ok && s.SupplierID == supplierID {
|
||||
return s, nil
|
||||
}
|
||||
return nil, errors.New("settlement not found")
|
||||
}
|
||||
|
||||
func (m *MockSettlementStore) Update(ctx context.Context, s *domain.Settlement, expectedVersion int) error {
|
||||
if existing, ok := m.Settlements[s.ID]; ok && existing.Version != expectedVersion {
|
||||
return errors.New("concurrency conflict")
|
||||
}
|
||||
m.Settlements[s.ID] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockSettlementStore) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
|
||||
var result []*domain.Settlement
|
||||
for _, s := range m.Settlements {
|
||||
if s.SupplierID == supplierID {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *MockSettlementStore) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
|
||||
return m.Balance, nil
|
||||
}
|
||||
|
||||
func (m *MockSettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
|
||||
return false, nil // 默认允许提现
|
||||
}
|
||||
|
||||
// MockEarningStore 收益存储 mock
|
||||
type MockEarningStore struct {
|
||||
Records []*domain.EarningRecord
|
||||
}
|
||||
|
||||
// NewMockEarningStore 创建收益存储 mock
|
||||
func NewMockEarningStore() *MockEarningStore {
|
||||
return &MockEarningStore{
|
||||
Records: make([]*domain.EarningRecord, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockEarningStore) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
|
||||
return m.Records, len(m.Records), nil
|
||||
}
|
||||
|
||||
func (m *MockEarningStore) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
|
||||
return &domain.BillingSummary{
|
||||
Period: domain.BillingPeriod{
|
||||
Start: startDate,
|
||||
End: endDate,
|
||||
},
|
||||
Summary: domain.BillingTotal{
|
||||
TotalRevenue: 1000.00,
|
||||
TotalOrders: 100,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MockAuditStore 审计存储 mock
|
||||
type MockAuditStore struct {
|
||||
Events []audit.Event
|
||||
EmitFn func(ctx context.Context, event audit.Event) error
|
||||
}
|
||||
|
||||
// NewMockAuditStore 创建审计存储 mock
|
||||
func NewMockAuditStore() *MockAuditStore {
|
||||
return &MockAuditStore{
|
||||
Events: make([]audit.Event, 0),
|
||||
EmitFn: func(ctx context.Context, event audit.Event) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockAuditStore) Emit(ctx context.Context, event audit.Event) error {
|
||||
m.Events = append(m.Events, event)
|
||||
return m.EmitFn(ctx, event)
|
||||
}
|
||||
|
||||
func (m *MockAuditStore) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) {
|
||||
return m.Events, nil
|
||||
}
|
||||
|
||||
func (m *MockAuditStore) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) {
|
||||
return m.Events, int64(len(m.Events)), nil
|
||||
}
|
||||
|
||||
func (m *MockAuditStore) GetByID(ctx context.Context, eventID string) (audit.Event, error) {
|
||||
return audit.Event{}, errors.New("not found")
|
||||
}
|
||||
|
||||
// MockFailingAuditStore 总是失败的审计存储 mock(用于测试错误处理)
|
||||
type MockFailingAuditStore struct{}
|
||||
|
||||
func (m *MockFailingAuditStore) Emit(ctx context.Context, event audit.Event) error {
|
||||
return errors.New("audit emit failed")
|
||||
}
|
||||
|
||||
func (m *MockFailingAuditStore) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockFailingAuditStore) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *MockFailingAuditStore) GetByID(ctx context.Context, eventID string) (audit.Event, error) {
|
||||
return audit.Event{}, errors.New("not found")
|
||||
}
|
||||
Reference in New Issue
Block a user