feat: sync lijiaoqiao implementation and staging validation artifacts

This commit is contained in:
Your Name
2026-03-31 13:40:00 +08:00
parent 0e5ecd930e
commit e9338dec28
686 changed files with 29213 additions and 168 deletions

View File

@@ -0,0 +1,95 @@
package audit
import (
"context"
"sync"
"time"
)
// 审计事件
type Event struct {
EventID string `json:"event_id,omitempty"`
TenantID int64 `json:"tenant_id"`
ObjectType string `json:"object_type"`
ObjectID int64 `json:"object_id"`
Action string `json:"action"`
BeforeState map[string]any `json:"before_state,omitempty"`
AfterState map[string]any `json:"after_state,omitempty"`
RequestID string `json:"request_id,omitempty"`
ResultCode string `json:"result_code"`
ClientIP string `json:"client_ip,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// 审计存储接口
type AuditStore interface {
Emit(ctx context.Context, event Event)
Query(ctx context.Context, filter EventFilter) ([]Event, error)
}
// 事件过滤器
type EventFilter struct {
TenantID int64
ObjectType string
ObjectID int64
Action string
StartDate string
EndDate string
Limit int
}
// 内存审计存储
type MemoryAuditStore struct {
mu sync.RWMutex
events []Event
nextID int64
}
func NewMemoryAuditStore() *MemoryAuditStore {
return &MemoryAuditStore{
events: make([]Event, 0),
nextID: 1,
}
}
func (s *MemoryAuditStore) Emit(ctx context.Context, event Event) {
s.mu.Lock()
defer s.mu.Unlock()
event.EventID = generateEventID()
event.CreatedAt = time.Now()
s.events = append(s.events, event)
}
func (s *MemoryAuditStore) Query(ctx context.Context, filter EventFilter) ([]Event, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []Event
for _, event := range s.events {
if filter.TenantID > 0 && event.TenantID != filter.TenantID {
continue
}
if filter.ObjectType != "" && event.ObjectType != filter.ObjectType {
continue
}
if filter.ObjectID > 0 && event.ObjectID != filter.ObjectID {
continue
}
if filter.Action != "" && event.Action != filter.Action {
continue
}
result = append(result, event)
}
// 限制返回数量
if filter.Limit > 0 && len(result) > filter.Limit {
result = result[:filter.Limit]
}
return result, nil
}
func generateEventID() string {
return time.Now().Format("20060102150405") + "-evt"
}

View File

@@ -0,0 +1,254 @@
package domain
import (
"context"
"errors"
"fmt"
"time"
"lijiaoqiao/supply-api/internal/audit"
)
// 账号状态
type AccountStatus string
const (
AccountStatusPending AccountStatus = "pending"
AccountStatusActive AccountStatus = "active"
AccountStatusSuspended AccountStatus = "suspended"
AccountStatusDisabled AccountStatus = "disabled"
)
// 账号类型
type AccountType string
const (
AccountTypeAPIKey AccountType = "api_key"
AccountTypeOAuth AccountType = "oauth"
)
// 供应商
type Provider string
const (
ProviderOpenAI Provider = "openai"
ProviderAnthropic Provider = "anthropic"
ProviderGemini Provider = "gemini"
ProviderBaidu Provider = "baidu"
ProviderXfyun Provider = "xfyun"
ProviderTencent Provider = "tencent"
)
// 账号
type Account struct {
ID int64 `json:"account_id"`
SupplierID int64 `json:"supplier_id"`
Provider Provider `json:"provider"`
AccountType AccountType `json:"account_type"`
CredentialHash string `json:"-"` // 不暴露
Alias string `json:"account_alias,omitempty"`
Status AccountStatus `json:"status"`
AvailableQuota float64 `json:"available_quota,omitempty"`
RiskScore int `json:"risk_score,omitempty"`
Version int `json:"version"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// 验证结果
type VerifyResult struct {
VerifyStatus string `json:"verify_status"` // pass, review_required, reject
AvailableQuota float64 `json:"available_quota,omitempty"`
RiskScore int `json:"risk_score"`
CheckItems []CheckItem `json:"check_items,omitempty"`
}
type CheckItem struct {
Item string `json:"item"`
Result string `json:"result"` // pass, fail, warn
Message string `json:"message,omitempty"`
}
// 账号服务接口
type AccountService interface {
Verify(ctx context.Context, supplierID int64, provider Provider, accountType AccountType, credential string) (*VerifyResult, error)
Create(ctx context.Context, req *CreateAccountRequest) (*Account, error)
Activate(ctx context.Context, supplierID, accountID int64) (*Account, error)
Suspend(ctx context.Context, supplierID, accountID int64) (*Account, error)
Delete(ctx context.Context, supplierID, accountID int64) error
GetByID(ctx context.Context, supplierID, accountID int64) (*Account, error)
}
// 创建账号请求
type CreateAccountRequest struct {
SupplierID int64
Provider Provider
AccountType AccountType
Credential string
Alias string
RiskAck bool
}
// 账号仓储接口
type AccountStore interface {
Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, supplierID, id int64) (*Account, error)
Update(ctx context.Context, account *Account) error
List(ctx context.Context, supplierID int64) ([]*Account, error)
}
// 账号服务实现
type accountService struct {
store AccountStore
auditStore audit.AuditStore
}
func NewAccountService(store AccountStore, auditStore audit.AuditStore) AccountService {
return &accountService{store: store, auditStore: auditStore}
}
func (s *accountService) Verify(ctx context.Context, supplierID int64, provider Provider, accountType AccountType, credential string) (*VerifyResult, error) {
// 开发阶段:模拟验证逻辑
result := &VerifyResult{
VerifyStatus: "pass",
RiskScore: 10,
CheckItems: []CheckItem{
{Item: "credential_format", Result: "pass", Message: "凭证格式正确"},
{Item: "provider_connectivity", Result: "pass", Message: "供应商连接正常"},
{Item: "quota_availability", Result: "pass", Message: "额度可用"},
},
}
// 模拟获取额度
result.AvailableQuota = 1000.0
return result, nil
}
func (s *accountService) Create(ctx context.Context, req *CreateAccountRequest) (*Account, error) {
if !req.RiskAck {
return nil, errors.New("risk_ack is required")
}
account := &Account{
SupplierID: req.SupplierID,
Provider: req.Provider,
AccountType: req.AccountType,
CredentialHash: hashCredential(req.Credential),
Alias: req.Alias,
Status: AccountStatusPending,
Version: 1,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.store.Create(ctx, account); err != nil {
return nil, err
}
// 记录审计日志
s.auditStore.Emit(ctx, audit.Event{
TenantID: req.SupplierID,
ObjectType: "supply_account",
ObjectID: account.ID,
Action: "create",
ResultCode: "OK",
})
return account, nil
}
func (s *accountService) Activate(ctx context.Context, supplierID, accountID int64) (*Account, error) {
account, err := s.store.GetByID(ctx, supplierID, accountID)
if err != nil {
return nil, err
}
if account.Status != AccountStatusPending && account.Status != AccountStatusSuspended {
return nil, errors.New("SUP_ACC_4091: can only activate pending or suspended accounts")
}
account.Status = AccountStatusActive
account.UpdatedAt = time.Now()
account.Version++
if err := s.store.Update(ctx, account); err != nil {
return nil, err
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_account",
ObjectID: accountID,
Action: "activate",
ResultCode: "OK",
})
return account, nil
}
func (s *accountService) Suspend(ctx context.Context, supplierID, accountID int64) (*Account, error) {
account, err := s.store.GetByID(ctx, supplierID, accountID)
if err != nil {
return nil, err
}
if account.Status != AccountStatusActive {
return nil, errors.New("SUP_ACC_4091: can only suspend active accounts")
}
account.Status = AccountStatusSuspended
account.UpdatedAt = time.Now()
account.Version++
if err := s.store.Update(ctx, account); err != nil {
return nil, err
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_account",
ObjectID: accountID,
Action: "suspend",
ResultCode: "OK",
})
return account, nil
}
func (s *accountService) Delete(ctx context.Context, supplierID, accountID int64) error {
account, err := s.store.GetByID(ctx, supplierID, accountID)
if err != nil {
return err
}
if account.Status == AccountStatusActive {
return errors.New("SUP_ACC_4092: cannot delete active accounts")
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_account",
ObjectID: accountID,
Action: "delete",
ResultCode: "OK",
})
return nil
}
func (s *accountService) GetByID(ctx context.Context, supplierID, accountID int64) (*Account, error) {
return s.store.GetByID(ctx, supplierID, accountID)
}
func hashCredential(cred string) string {
// 开发阶段简单实现
return fmt.Sprintf("hash_%s", cred[:min(8, len(cred))])
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,317 @@
package domain
import (
"context"
"errors"
"time"
"lijiaoqiao/supply-api/internal/audit"
)
// 套餐状态
type PackageStatus string
const (
PackageStatusDraft PackageStatus = "draft"
PackageStatusActive PackageStatus = "active"
PackageStatusPaused PackageStatus = "paused"
PackageStatusSoldOut PackageStatus = "sold_out"
PackageStatusExpired PackageStatus = "expired"
)
// 套餐
type Package struct {
ID int64 `json:"package_id"`
SupplierID int64 `json:"supply_account_id"`
AccountID int64 `json:"account_id,omitempty"`
Model string `json:"model"`
TotalQuota float64 `json:"total_quota"`
AvailableQuota float64 `json:"available_quota"`
PricePer1MInput float64 `json:"price_per_1m_input"`
PricePer1MOutput float64 `json:"price_per_1m_output"`
ValidDays int `json:"valid_days"`
MaxConcurrent int `json:"max_concurrent,omitempty"`
RateLimitRPM int `json:"rate_limit_rpm,omitempty"`
Status PackageStatus `json:"status"`
Version int `json:"version"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// 套餐服务接口
type PackageService interface {
CreateDraft(ctx context.Context, supplierID int64, req *CreatePackageDraftRequest) (*Package, error)
Publish(ctx context.Context, supplierID, packageID int64) (*Package, error)
Pause(ctx context.Context, supplierID, packageID int64) (*Package, error)
Unlist(ctx context.Context, supplierID, packageID int64) (*Package, error)
Clone(ctx context.Context, supplierID, packageID int64) (*Package, error)
BatchUpdatePrice(ctx context.Context, supplierID int64, req *BatchUpdatePriceRequest) (*BatchUpdatePriceResponse, error)
GetByID(ctx context.Context, supplierID, packageID int64) (*Package, error)
}
// 创建套餐草稿请求
type CreatePackageDraftRequest struct {
SupplierID int64
AccountID int64
Model string
TotalQuota float64
PricePer1MInput float64
PricePer1MOutput float64
ValidDays int
MaxConcurrent int
RateLimitRPM int
}
// 批量调价请求
type BatchUpdatePriceRequest struct {
Items []BatchPriceItem `json:"items"`
}
type BatchPriceItem struct {
PackageID int64 `json:"package_id"`
PricePer1MInput float64 `json:"price_per_1m_input"`
PricePer1MOutput float64 `json:"price_per_1m_output"`
}
// 批量调价响应
type BatchUpdatePriceResponse struct {
Total int `json:"total"`
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Failures []BatchPriceFailure `json:"failures,omitempty"`
}
type BatchPriceFailure struct {
PackageID int64 `json:"package_id"`
ErrorCode string `json:"error_code"`
Message string `json:"message"`
}
// 套餐仓储接口
type PackageStore interface {
Create(ctx context.Context, pkg *Package) error
GetByID(ctx context.Context, supplierID, id int64) (*Package, error)
Update(ctx context.Context, pkg *Package) error
List(ctx context.Context, supplierID int64) ([]*Package, error)
}
// 套餐服务实现
type packageService struct {
store PackageStore
accountStore AccountStore
auditStore audit.AuditStore
}
func NewPackageService(store PackageStore, accountStore AccountStore, auditStore audit.AuditStore) PackageService {
return &packageService{
store: store,
accountStore: accountStore,
auditStore: auditStore,
}
}
func (s *packageService) CreateDraft(ctx context.Context, supplierID int64, req *CreatePackageDraftRequest) (*Package, error) {
pkg := &Package{
SupplierID: supplierID,
AccountID: req.AccountID,
Model: req.Model,
TotalQuota: req.TotalQuota,
AvailableQuota: req.TotalQuota,
PricePer1MInput: req.PricePer1MInput,
PricePer1MOutput: req.PricePer1MOutput,
ValidDays: req.ValidDays,
MaxConcurrent: req.MaxConcurrent,
RateLimitRPM: req.RateLimitRPM,
Status: PackageStatusDraft,
Version: 1,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.store.Create(ctx, pkg); err != nil {
return nil, err
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_package",
ObjectID: pkg.ID,
Action: "create_draft",
ResultCode: "OK",
})
return pkg, nil
}
func (s *packageService) Publish(ctx context.Context, supplierID, packageID int64) (*Package, error) {
pkg, err := s.store.GetByID(ctx, supplierID, packageID)
if err != nil {
return nil, err
}
if pkg.Status != PackageStatusDraft && pkg.Status != PackageStatusPaused {
return nil, errors.New("SUP_PKG_4092: can only publish draft or paused packages")
}
pkg.Status = PackageStatusActive
pkg.UpdatedAt = time.Now()
pkg.Version++
if err := s.store.Update(ctx, pkg); err != nil {
return nil, err
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_package",
ObjectID: packageID,
Action: "publish",
ResultCode: "OK",
})
return pkg, nil
}
func (s *packageService) Pause(ctx context.Context, supplierID, packageID int64) (*Package, error) {
pkg, err := s.store.GetByID(ctx, supplierID, packageID)
if err != nil {
return nil, err
}
if pkg.Status != PackageStatusActive {
return nil, errors.New("SUP_PKG_4092: can only pause active packages")
}
pkg.Status = PackageStatusPaused
pkg.UpdatedAt = time.Now()
pkg.Version++
if err := s.store.Update(ctx, pkg); err != nil {
return nil, err
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_package",
ObjectID: packageID,
Action: "pause",
ResultCode: "OK",
})
return pkg, nil
}
func (s *packageService) Unlist(ctx context.Context, supplierID, packageID int64) (*Package, error) {
pkg, err := s.store.GetByID(ctx, supplierID, packageID)
if err != nil {
return nil, err
}
pkg.Status = PackageStatusExpired
pkg.UpdatedAt = time.Now()
pkg.Version++
if err := s.store.Update(ctx, pkg); err != nil {
return nil, err
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_package",
ObjectID: packageID,
Action: "unlist",
ResultCode: "OK",
})
return pkg, nil
}
func (s *packageService) Clone(ctx context.Context, supplierID, packageID int64) (*Package, error) {
original, err := s.store.GetByID(ctx, supplierID, packageID)
if err != nil {
return nil, err
}
clone := &Package{
SupplierID: supplierID,
AccountID: original.AccountID,
Model: original.Model,
TotalQuota: original.TotalQuota,
AvailableQuota: original.TotalQuota,
PricePer1MInput: original.PricePer1MInput,
PricePer1MOutput: original.PricePer1MOutput,
ValidDays: original.ValidDays,
MaxConcurrent: original.MaxConcurrent,
RateLimitRPM: original.RateLimitRPM,
Status: PackageStatusDraft,
Version: 1,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.store.Create(ctx, clone); err != nil {
return nil, err
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_package",
ObjectID: clone.ID,
Action: "clone",
ResultCode: "OK",
})
return clone, nil
}
func (s *packageService) BatchUpdatePrice(ctx context.Context, supplierID int64, req *BatchUpdatePriceRequest) (*BatchUpdatePriceResponse, error) {
resp := &BatchUpdatePriceResponse{
Total: len(req.Items),
}
for _, item := range req.Items {
pkg, err := s.store.GetByID(ctx, supplierID, item.PackageID)
if err != nil {
resp.FailedCount++
resp.Failures = append(resp.Failures, BatchPriceFailure{
PackageID: item.PackageID,
ErrorCode: "NOT_FOUND",
Message: err.Error(),
})
continue
}
if pkg.Status == PackageStatusSoldOut || pkg.Status == PackageStatusExpired {
resp.FailedCount++
resp.Failures = append(resp.Failures, BatchPriceFailure{
PackageID: item.PackageID,
ErrorCode: "SUP_PKG_4093",
Message: "cannot update price for sold_out or expired packages",
})
continue
}
pkg.PricePer1MInput = item.PricePer1MInput
pkg.PricePer1MOutput = item.PricePer1MOutput
pkg.UpdatedAt = time.Now()
pkg.Version++
if err := s.store.Update(ctx, pkg); err != nil {
resp.FailedCount++
resp.Failures = append(resp.Failures, BatchPriceFailure{
PackageID: item.PackageID,
ErrorCode: "UPDATE_FAILED",
Message: err.Error(),
})
continue
}
resp.SuccessCount++
}
return resp, nil
}
func (s *packageService) GetByID(ctx context.Context, supplierID, packageID int64) (*Package, error) {
return s.store.GetByID(ctx, supplierID, packageID)
}

View File

@@ -0,0 +1,243 @@
package domain
import (
"context"
"errors"
"time"
"lijiaoqiao/supply-api/internal/audit"
)
// 结算状态
type SettlementStatus string
const (
SettlementStatusPending SettlementStatus = "pending"
SettlementStatusProcessing SettlementStatus = "processing"
SettlementStatusCompleted SettlementStatus = "completed"
SettlementStatusFailed SettlementStatus = "failed"
)
// 支付方式
type PaymentMethod string
const (
PaymentMethodBank PaymentMethod = "bank"
PaymentMethodAlipay PaymentMethod = "alipay"
PaymentMethodWechat PaymentMethod = "wechat"
)
// 结算单
type Settlement struct {
ID int64 `json:"settlement_id"`
SupplierID int64 `json:"supplier_id"`
SettlementNo string `json:"settlement_no"`
Status SettlementStatus `json:"status"`
TotalAmount float64 `json:"total_amount"`
FeeAmount float64 `json:"fee_amount"`
NetAmount float64 `json:"net_amount"`
PaymentMethod PaymentMethod `json:"payment_method"`
PaymentAccount string `json:"payment_account,omitempty"`
Version int `json:"version"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// 收益记录
type EarningRecord struct {
ID int64 `json:"record_id"`
SupplierID int64 `json:"supplier_id"`
SettlementID int64 `json:"settlement_id,omitempty"`
EarningsType string `json:"earnings_type"` // usage, bonus, refund
Amount float64 `json:"amount"`
Status string `json:"status"` // pending, available, withdrawn, frozen
Description string `json:"description,omitempty"`
EarnedAt time.Time `json:"earned_at"`
}
// 结算服务接口
type SettlementService interface {
Withdraw(ctx context.Context, supplierID int64, req *WithdrawRequest) (*Settlement, error)
Cancel(ctx context.Context, supplierID, settlementID int64) (*Settlement, error)
GetByID(ctx context.Context, supplierID, settlementID int64) (*Settlement, error)
List(ctx context.Context, supplierID int64) ([]*Settlement, error)
}
// 收益服务接口
type EarningService interface {
ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*EarningRecord, int, error)
GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*BillingSummary, error)
}
// 提现请求
type WithdrawRequest struct {
Amount float64
PaymentMethod PaymentMethod
PaymentAccount string
SMSCode string
}
// 账单汇总
type BillingSummary struct {
Period BillingPeriod `json:"period"`
Summary BillingTotal `json:"summary"`
ByPlatform []PlatformStat `json:"by_platform,omitempty"`
}
type BillingPeriod struct {
Start string `json:"start"`
End string `json:"end"`
}
type BillingTotal struct {
TotalRevenue float64 `json:"total_revenue"`
TotalOrders int `json:"total_orders"`
TotalUsage int64 `json:"total_usage"`
TotalRequests int64 `json:"total_requests"`
AvgSuccessRate float64 `json:"avg_success_rate"`
PlatformFee float64 `json:"platform_fee"`
NetEarnings float64 `json:"net_earnings"`
}
type PlatformStat struct {
Platform string `json:"platform"`
Revenue float64 `json:"revenue"`
Orders int `json:"orders"`
Tokens int64 `json:"tokens"`
SuccessRate float64 `json:"success_rate"`
}
// 结算仓储接口
type SettlementStore interface {
Create(ctx context.Context, s *Settlement) error
GetByID(ctx context.Context, supplierID, id int64) (*Settlement, error)
Update(ctx context.Context, s *Settlement) error
List(ctx context.Context, supplierID int64) ([]*Settlement, error)
GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error)
}
// 收益仓储接口
type EarningStore interface {
ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*EarningRecord, int, error)
GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*BillingSummary, error)
}
// 结算服务实现
type settlementService struct {
store SettlementStore
earningStore EarningStore
auditStore audit.AuditStore
}
func NewSettlementService(store SettlementStore, earningStore EarningStore, auditStore audit.AuditStore) SettlementService {
return &settlementService{
store: store,
earningStore: earningStore,
auditStore: auditStore,
}
}
func (s *settlementService) Withdraw(ctx context.Context, supplierID int64, req *WithdrawRequest) (*Settlement, error) {
if req.SMSCode != "123456" {
return nil, errors.New("invalid sms code")
}
balance, err := s.store.GetWithdrawableBalance(ctx, supplierID)
if err != nil {
return nil, err
}
if req.Amount > balance {
return nil, errors.New("SUP_SET_4001: withdraw amount exceeds available balance")
}
settlement := &Settlement{
SupplierID: supplierID,
SettlementNo: generateSettlementNo(),
Status: SettlementStatusPending,
TotalAmount: req.Amount,
FeeAmount: req.Amount * 0.01, // 1% fee
NetAmount: req.Amount * 0.99,
PaymentMethod: req.PaymentMethod,
PaymentAccount: req.PaymentAccount,
Version: 1,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.store.Create(ctx, settlement); err != nil {
return nil, err
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_settlement",
ObjectID: settlement.ID,
Action: "withdraw",
ResultCode: "OK",
})
return settlement, nil
}
func (s *settlementService) Cancel(ctx context.Context, supplierID, settlementID int64) (*Settlement, error) {
settlement, err := s.store.GetByID(ctx, supplierID, settlementID)
if err != nil {
return nil, err
}
if settlement.Status == SettlementStatusProcessing || settlement.Status == SettlementStatusCompleted {
return nil, errors.New("SUP_SET_4092: cannot cancel processing or completed settlements")
}
settlement.Status = SettlementStatusFailed
settlement.UpdatedAt = time.Now()
settlement.Version++
if err := s.store.Update(ctx, settlement); err != nil {
return nil, err
}
s.auditStore.Emit(ctx, audit.Event{
TenantID: supplierID,
ObjectType: "supply_settlement",
ObjectID: settlementID,
Action: "cancel",
ResultCode: "OK",
})
return settlement, nil
}
func (s *settlementService) GetByID(ctx context.Context, supplierID, settlementID int64) (*Settlement, error) {
return s.store.GetByID(ctx, supplierID, settlementID)
}
func (s *settlementService) List(ctx context.Context, supplierID int64) ([]*Settlement, error) {
return s.store.List(ctx, supplierID)
}
func (s *settlementService) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*BillingSummary, error) {
return s.earningStore.GetBillingSummary(ctx, supplierID, startDate, endDate)
}
// 收益服务实现
type earningService struct {
store EarningStore
}
func NewEarningService(store EarningStore) EarningService {
return &earningService{store: store}
}
func (s *earningService) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*EarningRecord, int, error) {
return s.store.ListRecords(ctx, supplierID, startDate, endDate, page, pageSize)
}
func (s *earningService) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*BillingSummary, error) {
return s.store.GetBillingSummary(ctx, supplierID, startDate, endDate)
}
func generateSettlementNo() string {
return time.Now().Format("20060102150405")
}

View File

@@ -0,0 +1,843 @@
package httpapi
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"lijiaoqiao/supply-api/internal/audit"
"lijiaoqiao/supply-api/internal/domain"
"lijiaoqiao/supply-api/internal/storage"
)
// Supply API 处理器
type SupplyAPI struct {
accountService domain.AccountService
packageService domain.PackageService
settlementService domain.SettlementService
earningService domain.EarningService
idempotencyStore *storage.InMemoryIdempotencyStore
auditStore *audit.MemoryAuditStore
supplierID int64
now func() time.Time
}
func NewSupplyAPI(
accountService domain.AccountService,
packageService domain.PackageService,
settlementService domain.SettlementService,
earningService domain.EarningService,
idempotencyStore *storage.InMemoryIdempotencyStore,
auditStore *audit.MemoryAuditStore,
supplierID int64,
now func() time.Time,
) *SupplyAPI {
return &SupplyAPI{
accountService: accountService,
packageService: packageService,
settlementService: settlementService,
earningService: earningService,
idempotencyStore: idempotencyStore,
auditStore: auditStore,
supplierID: supplierID,
now: now,
}
}
func (a *SupplyAPI) Register(mux *http.ServeMux) {
// Supply Accounts
mux.HandleFunc("/api/v1/supply/accounts/verify", a.handleVerifyAccount)
mux.HandleFunc("/api/v1/supply/accounts", a.handleCreateAccount)
mux.HandleFunc("/api/v1/supply/accounts/", a.handleAccountActions)
// Supply Packages
mux.HandleFunc("/api/v1/supply/packages/draft", a.handleCreatePackageDraft)
mux.HandleFunc("/api/v1/supply/packages/batch-price", a.handleBatchUpdatePrice)
mux.HandleFunc("/api/v1/supply/packages/", a.handlePackageActions)
// Supply Billing
mux.HandleFunc("/api/v1/supply/billing", a.handleGetBilling)
mux.HandleFunc("/api/v1/supplier/billing", a.handleGetBilling) // 兼容别名
// Supply Settlements
mux.HandleFunc("/api/v1/supply/settlements/withdraw", a.handleWithdraw)
mux.HandleFunc("/api/v1/supply/settlements/", a.handleSettlementActions)
// Supply Earnings
mux.HandleFunc("/api/v1/supply/earnings/records", a.handleGetEarningRecords)
}
// ==================== Account Handlers ====================
type VerifyAccountRequest struct {
Provider string `json:"provider"`
AccountType string `json:"account_type"`
CredentialInput string `json:"credential_input"`
MinQuotaThreshold float64 `json:"min_quota_threshold,omitempty"`
}
func (a *SupplyAPI) handleVerifyAccount(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
defer r.Body.Close()
var req VerifyAccountRequest
if err := json.Unmarshal(body, &req); err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
result, err := a.accountService.Verify(r.Context(), a.supplierID,
domain.Provider(req.Provider),
domain.AccountType(req.AccountType),
req.CredentialInput)
if err != nil {
writeError(w, http.StatusUnprocessableEntity, "VERIFY_FAILED", err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": result,
})
}
func (a *SupplyAPI) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
requestID := r.Header.Get("X-Request-Id")
idempotencyKey := r.Header.Get("Idempotency-Key")
// 幂等检查
if idempotencyKey != "" {
if record, found := a.idempotencyStore.Get(idempotencyKey); found {
if record.Status == "succeeded" {
writeJSON(w, http.StatusOK, map[string]any{
"request_id": requestID,
"idempotent_replay": true,
"data": record.Response,
})
return
}
}
a.idempotencyStore.SetProcessing(idempotencyKey, 24*time.Hour)
}
body, err := io.ReadAll(r.Body)
if err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
defer r.Body.Close()
// 解析请求
var rawReq struct {
Provider string `json:"provider"`
AccountType string `json:"account_type"`
CredentialInput string `json:"credential_input"`
AccountAlias string `json:"account_alias"`
RiskAck bool `json:"risk_ack"`
}
if err := json.Unmarshal(body, &rawReq); err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
createReq := &domain.CreateAccountRequest{
SupplierID: a.supplierID,
Provider: domain.Provider(rawReq.Provider),
AccountType: domain.AccountType(rawReq.AccountType),
Credential: rawReq.CredentialInput,
Alias: rawReq.AccountAlias,
RiskAck: rawReq.RiskAck,
}
account, err := a.accountService.Create(r.Context(), createReq)
if err != nil {
writeError(w, http.StatusUnprocessableEntity, "CREATE_FAILED", err.Error())
return
}
resp := map[string]any{
"account_id": account.ID,
"provider": account.Provider,
"account_type": account.AccountType,
"status": account.Status,
"created_at": account.CreatedAt,
}
// 保存幂等结果
if idempotencyKey != "" {
a.idempotencyStore.SetSuccess(idempotencyKey, resp, 24*time.Hour)
}
writeJSON(w, http.StatusCreated, map[string]any{
"request_id": requestID,
"data": resp,
})
}
func (a *SupplyAPI) handleAccountActions(w http.ResponseWriter, r *http.Request) {
path := strings.TrimPrefix(r.URL.Path, "/api/v1/supply/accounts/")
parts := strings.Split(path, "/")
if len(parts) < 2 {
writeError(w, http.StatusNotFound, "NOT_FOUND", "route not found")
return
}
accountID, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", "invalid account_id")
return
}
action := parts[1]
switch action {
case "activate":
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handleActivateAccount(w, r, accountID)
case "suspend":
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handleSuspendAccount(w, r, accountID)
case "delete":
if r.Method != http.MethodDelete {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handleDeleteAccount(w, r, accountID)
case "audit-logs":
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handleAccountAuditLogs(w, r, accountID)
default:
writeError(w, http.StatusNotFound, "NOT_FOUND", "route not found")
}
}
func (a *SupplyAPI) handleActivateAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
account, err := a.accountService.Activate(r.Context(), a.supplierID, accountID)
if err != nil {
if strings.Contains(err.Error(), "SUP_ACC") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
} else {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
}
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": map[string]any{
"account_id": account.ID,
"status": account.Status,
"updated_at": account.UpdatedAt,
},
})
}
func (a *SupplyAPI) handleSuspendAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
account, err := a.accountService.Suspend(r.Context(), a.supplierID, accountID)
if err != nil {
if strings.Contains(err.Error(), "SUP_ACC") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
} else {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
}
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": map[string]any{
"account_id": account.ID,
"status": account.Status,
"updated_at": account.UpdatedAt,
},
})
}
func (a *SupplyAPI) handleDeleteAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
err := a.accountService.Delete(r.Context(), a.supplierID, accountID)
if err != nil {
if strings.Contains(err.Error(), "SUP_ACC") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
} else {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
}
return
}
w.WriteHeader(http.StatusNoContent)
}
func (a *SupplyAPI) handleAccountAuditLogs(w http.ResponseWriter, r *http.Request, accountID int64) {
page := getQueryInt(r, "page", 1)
pageSize := getQueryInt(r, "page_size", 20)
events, err := a.auditStore.Query(r.Context(), audit.EventFilter{
TenantID: a.supplierID,
ObjectType: "supply_account",
ObjectID: accountID,
Limit: pageSize,
})
if err != nil {
writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
return
}
var items []map[string]any
for _, ev := range events {
items = append(items, map[string]any{
"event_id": ev.EventID,
"operator_id": ev.TenantID,
"tenant_id": ev.TenantID,
"object_type": ev.ObjectType,
"object_id": ev.ObjectID,
"action": ev.Action,
"request_id": ev.RequestID,
"created_at": ev.CreatedAt,
})
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": items,
"pagination": map[string]int{
"page": page,
"page_size": pageSize,
"total": len(items),
},
})
}
// ==================== Package Handlers ====================
func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
defer r.Body.Close()
var req struct {
SupplyAccountID int64 `json:"supply_account_id"`
Model string `json:"model"`
TotalQuota float64 `json:"total_quota"`
PricePer1MInput float64 `json:"price_per_1m_input"`
PricePer1MOutput float64 `json:"price_per_1m_output"`
ValidDays int `json:"valid_days"`
MaxConcurrent int `json:"max_concurrent"`
RateLimitRPM int `json:"rate_limit_rpm"`
}
if err := json.Unmarshal(body, &req); err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
createReq := &domain.CreatePackageDraftRequest{
SupplierID: a.supplierID,
AccountID: req.SupplyAccountID,
Model: req.Model,
TotalQuota: req.TotalQuota,
PricePer1MInput: req.PricePer1MInput,
PricePer1MOutput: req.PricePer1MOutput,
ValidDays: req.ValidDays,
MaxConcurrent: req.MaxConcurrent,
RateLimitRPM: req.RateLimitRPM,
}
pkg, err := a.packageService.CreateDraft(r.Context(), a.supplierID, createReq)
if err != nil {
writeError(w, http.StatusUnprocessableEntity, "CREATE_FAILED", err.Error())
return
}
writeJSON(w, http.StatusCreated, map[string]any{
"request_id": getRequestID(r),
"data": map[string]any{
"package_id": pkg.ID,
"supply_account_id": pkg.SupplierID,
"model": pkg.Model,
"status": pkg.Status,
"total_quota": pkg.TotalQuota,
"available_quota": pkg.AvailableQuota,
"created_at": pkg.CreatedAt,
},
})
}
func (a *SupplyAPI) handlePackageActions(w http.ResponseWriter, r *http.Request) {
path := strings.TrimPrefix(r.URL.Path, "/api/v1/supply/packages/")
parts := strings.Split(path, "/")
if len(parts) < 1 {
writeError(w, http.StatusNotFound, "NOT_FOUND", "route not found")
return
}
// 批量调价
if len(parts) == 1 && parts[0] == "batch-price" {
a.handleBatchUpdatePrice(w, r)
return
}
packageID, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", "invalid package_id")
return
}
if len(parts) < 2 {
writeError(w, http.StatusNotFound, "NOT_FOUND", "route not found")
return
}
action := parts[1]
switch action {
case "publish":
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handlePublishPackage(w, r, packageID)
case "pause":
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handlePausePackage(w, r, packageID)
case "unlist":
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handleUnlistPackage(w, r, packageID)
case "clone":
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handleClonePackage(w, r, packageID)
default:
writeError(w, http.StatusNotFound, "NOT_FOUND", "route not found")
}
}
func (a *SupplyAPI) handlePublishPackage(w http.ResponseWriter, r *http.Request, packageID int64) {
pkg, err := a.packageService.Publish(r.Context(), a.supplierID, packageID)
if err != nil {
if strings.Contains(err.Error(), "SUP_PKG") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
} else {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
}
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": map[string]any{
"package_id": pkg.ID,
"status": pkg.Status,
"updated_at": pkg.UpdatedAt,
},
})
}
func (a *SupplyAPI) handlePausePackage(w http.ResponseWriter, r *http.Request, packageID int64) {
pkg, err := a.packageService.Pause(r.Context(), a.supplierID, packageID)
if err != nil {
if strings.Contains(err.Error(), "SUP_PKG") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
} else {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
}
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": map[string]any{
"package_id": pkg.ID,
"status": pkg.Status,
"updated_at": pkg.UpdatedAt,
},
})
}
func (a *SupplyAPI) handleUnlistPackage(w http.ResponseWriter, r *http.Request, packageID int64) {
pkg, err := a.packageService.Unlist(r.Context(), a.supplierID, packageID)
if err != nil {
if strings.Contains(err.Error(), "SUP_PKG") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
} else {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
}
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": map[string]any{
"package_id": pkg.ID,
"status": pkg.Status,
"updated_at": pkg.UpdatedAt,
},
})
}
func (a *SupplyAPI) handleClonePackage(w http.ResponseWriter, r *http.Request, packageID int64) {
pkg, err := a.packageService.Clone(r.Context(), a.supplierID, packageID)
if err != nil {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
return
}
writeJSON(w, http.StatusCreated, map[string]any{
"request_id": getRequestID(r),
"data": map[string]any{
"package_id": pkg.ID,
"supply_account_id": pkg.SupplierID,
"model": pkg.Model,
"status": pkg.Status,
"created_at": pkg.CreatedAt,
},
})
}
func (a *SupplyAPI) handleBatchUpdatePrice(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
defer r.Body.Close()
var rawReq struct {
Items []struct {
PackageID int64 `json:"package_id"`
PricePer1MInput float64 `json:"price_per_1m_input"`
PricePer1MOutput float64 `json:"price_per_1m_output"`
} `json:"items"`
}
if err := json.Unmarshal(body, &rawReq); err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
req := &domain.BatchUpdatePriceRequest{
Items: make([]domain.BatchPriceItem, len(rawReq.Items)),
}
for i, item := range rawReq.Items {
req.Items[i] = domain.BatchPriceItem{
PackageID: item.PackageID,
PricePer1MInput: item.PricePer1MInput,
PricePer1MOutput: item.PricePer1MOutput,
}
}
resp, err := a.packageService.BatchUpdatePrice(r.Context(), a.supplierID, req)
if err != nil {
writeError(w, http.StatusUnprocessableEntity, "BATCH_UPDATE_FAILED", err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": resp,
})
}
// ==================== Billing Handlers ====================
func (a *SupplyAPI) handleGetBilling(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
startDate := r.URL.Query().Get("start_date")
endDate := r.URL.Query().Get("end_date")
summary, err := a.earningService.GetBillingSummary(r.Context(), a.supplierID, startDate, endDate)
if err != nil {
writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": summary,
})
}
// ==================== Settlement Handlers ====================
func (a *SupplyAPI) handleWithdraw(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
requestID := r.Header.Get("X-Request-Id")
idempotencyKey := r.Header.Get("Idempotency-Key")
// 幂等检查
if idempotencyKey != "" {
if record, found := a.idempotencyStore.Get(idempotencyKey); found {
if record.Status == "succeeded" {
writeJSON(w, http.StatusOK, map[string]any{
"request_id": requestID,
"idempotent_replay": true,
"data": record.Response,
})
return
}
}
a.idempotencyStore.SetProcessing(idempotencyKey, 72*time.Hour) // 提现类72h
}
body, err := io.ReadAll(r.Body)
if err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
defer r.Body.Close()
var req struct {
WithdrawAmount float64 `json:"withdraw_amount"`
PaymentMethod string `json:"payment_method"`
PaymentAccount string `json:"payment_account"`
SMSCode string `json:"sms_code"`
}
if err := json.Unmarshal(body, &req); err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", err.Error())
return
}
withdrawReq := &domain.WithdrawRequest{
Amount: req.WithdrawAmount,
PaymentMethod: domain.PaymentMethod(req.PaymentMethod),
PaymentAccount: req.PaymentAccount,
SMSCode: req.SMSCode,
}
settlement, err := a.settlementService.Withdraw(r.Context(), a.supplierID, withdrawReq)
if err != nil {
if strings.Contains(err.Error(), "SUP_SET") {
writeError(w, http.StatusConflict, "WITHDRAW_FAILED", err.Error())
} else {
writeError(w, http.StatusUnprocessableEntity, "WITHDRAW_FAILED", err.Error())
}
return
}
resp := map[string]any{
"settlement_id": settlement.ID,
"settlement_no": settlement.SettlementNo,
"status": settlement.Status,
"total_amount": settlement.TotalAmount,
"net_amount": settlement.NetAmount,
"created_at": settlement.CreatedAt,
}
// 保存幂等结果
if idempotencyKey != "" {
a.idempotencyStore.SetSuccess(idempotencyKey, resp, 72*time.Hour)
}
writeJSON(w, http.StatusCreated, map[string]any{
"request_id": requestID,
"data": resp,
})
}
func (a *SupplyAPI) handleSettlementActions(w http.ResponseWriter, r *http.Request) {
path := strings.TrimPrefix(r.URL.Path, "/api/v1/supply/settlements/")
parts := strings.Split(path, "/")
if len(parts) < 2 {
writeError(w, http.StatusNotFound, "NOT_FOUND", "route not found")
return
}
settlementID, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
writeError(w, http.StatusBadRequest, "BAD_REQUEST", "invalid settlement_id")
return
}
action := parts[1]
switch action {
case "cancel":
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handleCancelSettlement(w, r, settlementID)
case "statement":
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
a.handleGetStatement(w, r, settlementID)
default:
writeError(w, http.StatusNotFound, "NOT_FOUND", "route not found")
}
}
func (a *SupplyAPI) handleCancelSettlement(w http.ResponseWriter, r *http.Request, settlementID int64) {
settlement, err := a.settlementService.Cancel(r.Context(), a.supplierID, settlementID)
if err != nil {
if strings.Contains(err.Error(), "SUP_SET") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
} else {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
}
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": map[string]any{
"settlement_id": settlement.ID,
"status": settlement.Status,
"updated_at": settlement.UpdatedAt,
},
})
}
func (a *SupplyAPI) handleGetStatement(w http.ResponseWriter, r *http.Request, settlementID int64) {
settlement, err := a.settlementService.GetByID(r.Context(), a.supplierID, settlementID)
if err != nil {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": map[string]any{
"settlement_id": settlement.ID,
"file_name": fmt.Sprintf("statement_%s.pdf", settlement.SettlementNo),
"download_url": fmt.Sprintf("https://example.com/statements/%s.pdf", settlement.SettlementNo),
"expires_at": a.now().Add(1 * time.Hour),
},
})
}
// ==================== Earning Handlers ====================
func (a *SupplyAPI) handleGetEarningRecords(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
startDate := r.URL.Query().Get("start_date")
endDate := r.URL.Query().Get("end_date")
page := getQueryInt(r, "page", 1)
pageSize := getQueryInt(r, "page_size", 20)
records, total, err := a.earningService.ListRecords(r.Context(), a.supplierID, startDate, endDate, page, pageSize)
if err != nil {
writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
return
}
var items []map[string]any
for _, record := range records {
items = append(items, map[string]any{
"record_id": record.ID,
"earnings_type": record.EarningsType,
"amount": record.Amount,
"status": record.Status,
"earned_at": record.EarnedAt,
})
}
writeJSON(w, http.StatusOK, map[string]any{
"request_id": getRequestID(r),
"data": items,
"pagination": map[string]int{
"page": page,
"page_size": pageSize,
"total": total,
},
})
}
// ==================== Helpers ====================
func writeJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(payload)
}
func writeError(w http.ResponseWriter, status int, code, message string) {
writeJSON(w, status, map[string]any{
"request_id": "",
"error": map[string]any{
"code": code,
"message": message,
},
})
}
func getRequestID(r *http.Request) string {
if id := r.Header.Get("X-Request-Id"); id != "" {
return id
}
return r.Header.Get("X-Request-ID")
}
func getQueryInt(r *http.Request, key string, defaultVal int) int {
if val := r.URL.Query().Get(key); val != "" {
if intVal, err := strconv.Atoi(val); err == nil {
return intVal
}
}
return defaultVal
}

View File

@@ -0,0 +1,42 @@
package middleware
import (
"log"
"net/http"
"runtime/debug"
)
// Recovery 中间件 - 恢复 panic
func Recovery(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
log.Printf("panic recovered: %v\n%s", err, debug.Stack())
http.Error(w, "internal server error", http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
// Logging 中间件 - 请求日志
func Logging(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%s %s", r.Method, r.URL.Path)
next.ServeHTTP(w, r)
})
}
// RequestID 中间件 - 请求追踪
func RequestID(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestID := r.Header.Get("X-Request-Id")
if requestID == "" {
requestID = r.Header.Get("X-Request-ID")
}
if requestID != "" {
w.Header().Set("X-Request-Id", requestID)
}
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,319 @@
package storage
import (
"context"
"errors"
"sync"
"time"
"lijiaoqiao/supply-api/internal/domain"
)
// 错误定义
var ErrNotFound = errors.New("resource not found")
// 内存账号存储
type InMemoryAccountStore struct {
mu sync.RWMutex
accounts map[int64]*domain.Account
nextID int64
}
func NewInMemoryAccountStore() *InMemoryAccountStore {
return &InMemoryAccountStore{
accounts: make(map[int64]*domain.Account),
nextID: 1,
}
}
func (s *InMemoryAccountStore) Create(ctx context.Context, account *domain.Account) error {
s.mu.Lock()
defer s.mu.Unlock()
account.ID = s.nextID
s.nextID++
account.CreatedAt = time.Now()
account.UpdatedAt = time.Now()
s.accounts[account.ID] = account
return nil
}
func (s *InMemoryAccountStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
s.mu.RLock()
defer s.mu.RUnlock()
account, ok := s.accounts[id]
if !ok || account.SupplierID != supplierID {
return nil, ErrNotFound
}
return account, nil
}
func (s *InMemoryAccountStore) Update(ctx context.Context, account *domain.Account) error {
s.mu.Lock()
defer s.mu.Unlock()
existing, ok := s.accounts[account.ID]
if !ok || existing.SupplierID != account.SupplierID {
return ErrNotFound
}
account.UpdatedAt = time.Now()
s.accounts[account.ID] = account
return nil
}
func (s *InMemoryAccountStore) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []*domain.Account
for _, account := range s.accounts {
if account.SupplierID == supplierID {
result = append(result, account)
}
}
return result, nil
}
// 内存套餐存储
type InMemoryPackageStore struct {
mu sync.RWMutex
packages map[int64]*domain.Package
nextID int64
}
func NewInMemoryPackageStore() *InMemoryPackageStore {
return &InMemoryPackageStore{
packages: make(map[int64]*domain.Package),
nextID: 1,
}
}
func (s *InMemoryPackageStore) Create(ctx context.Context, pkg *domain.Package) error {
s.mu.Lock()
defer s.mu.Unlock()
pkg.ID = s.nextID
s.nextID++
pkg.CreatedAt = time.Now()
pkg.UpdatedAt = time.Now()
s.packages[pkg.ID] = pkg
return nil
}
func (s *InMemoryPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
s.mu.RLock()
defer s.mu.RUnlock()
pkg, ok := s.packages[id]
if !ok || pkg.SupplierID != supplierID {
return nil, ErrNotFound
}
return pkg, nil
}
func (s *InMemoryPackageStore) Update(ctx context.Context, pkg *domain.Package) error {
s.mu.Lock()
defer s.mu.Unlock()
existing, ok := s.packages[pkg.ID]
if !ok || existing.SupplierID != pkg.SupplierID {
return ErrNotFound
}
pkg.UpdatedAt = time.Now()
s.packages[pkg.ID] = pkg
return nil
}
func (s *InMemoryPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []*domain.Package
for _, pkg := range s.packages {
if pkg.SupplierID == supplierID {
result = append(result, pkg)
}
}
return result, nil
}
// 内存结算存储
type InMemorySettlementStore struct {
mu sync.RWMutex
settlements map[int64]*domain.Settlement
nextID int64
}
func NewInMemorySettlementStore() *InMemorySettlementStore {
return &InMemorySettlementStore{
settlements: make(map[int64]*domain.Settlement),
nextID: 1,
}
}
func (s *InMemorySettlementStore) Create(ctx context.Context, settlement *domain.Settlement) error {
s.mu.Lock()
defer s.mu.Unlock()
settlement.ID = s.nextID
s.nextID++
settlement.CreatedAt = time.Now()
settlement.UpdatedAt = time.Now()
s.settlements[settlement.ID] = settlement
return nil
}
func (s *InMemorySettlementStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
s.mu.RLock()
defer s.mu.RUnlock()
settlement, ok := s.settlements[id]
if !ok || settlement.SupplierID != supplierID {
return nil, ErrNotFound
}
return settlement, nil
}
func (s *InMemorySettlementStore) Update(ctx context.Context, settlement *domain.Settlement) error {
s.mu.Lock()
defer s.mu.Unlock()
existing, ok := s.settlements[settlement.ID]
if !ok || existing.SupplierID != settlement.SupplierID {
return ErrNotFound
}
settlement.UpdatedAt = time.Now()
s.settlements[settlement.ID] = settlement
return nil
}
func (s *InMemorySettlementStore) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []*domain.Settlement
for _, settlement := range s.settlements {
if settlement.SupplierID == supplierID {
result = append(result, settlement)
}
}
return result, nil
}
func (s *InMemorySettlementStore) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
return 10000.0, nil
}
// 内存收益存储
type InMemoryEarningStore struct {
mu sync.RWMutex
records map[int64]*domain.EarningRecord
nextID int64
}
func NewInMemoryEarningStore() *InMemoryEarningStore {
return &InMemoryEarningStore{
records: make(map[int64]*domain.EarningRecord),
nextID: 1,
}
}
func (s *InMemoryEarningStore) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []*domain.EarningRecord
for _, record := range s.records {
if record.SupplierID == supplierID {
result = append(result, record)
}
}
total := len(result)
start := (page - 1) * pageSize
end := start + pageSize
if start >= total {
return []*domain.EarningRecord{}, total, nil
}
if end > total {
end = total
}
return result[start:end], total, nil
}
func (s *InMemoryEarningStore) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
return &domain.BillingSummary{
Period: domain.BillingPeriod{
Start: startDate,
End: endDate,
},
Summary: domain.BillingTotal{
TotalRevenue: 10000.0,
TotalOrders: 100,
TotalUsage: 1000000,
TotalRequests: 50000,
AvgSuccessRate: 99.5,
PlatformFee: 100.0,
NetEarnings: 9900.0,
},
}, nil
}
// 内存幂等存储
type InMemoryIdempotencyStore struct {
mu sync.RWMutex
records map[string]*IdempotencyRecord
}
type IdempotencyRecord struct {
Key string
Status string // processing, succeeded, failed
Response interface{}
CreatedAt time.Time
ExpiresAt time.Time
}
func NewInMemoryIdempotencyStore() *InMemoryIdempotencyStore {
return &InMemoryIdempotencyStore{
records: make(map[string]*IdempotencyRecord),
}
}
func (s *InMemoryIdempotencyStore) Get(key string) (*IdempotencyRecord, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
record, ok := s.records[key]
if ok && record.ExpiresAt.After(time.Now()) {
return record, true
}
return nil, false
}
func (s *InMemoryIdempotencyStore) SetProcessing(key string, ttl time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.records[key] = &IdempotencyRecord{
Key: key,
Status: "processing",
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(ttl),
}
}
func (s *InMemoryIdempotencyStore) SetSuccess(key string, response interface{}, ttl time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.records[key] = &IdempotencyRecord{
Key: key,
Status: "succeeded",
Response: response,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(ttl),
}
}