C-06: DBSettlementStore.GetWithdrawableBalance 使用AccountRepository真实查询余额 C-05: DBEarningStore 使用新建的UsageRepository实现ListRecords/GetBillingSummary C-04: 供应商ID从cfg.Server.DefaultSupplierID配置读取 C-07: PDF链接从cfg.Server.StatementBaseURL配置读取 新增: - internal/repository/usage.go: 用量记录仓储
207 lines
5.3 KiB
Go
207 lines
5.3 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"lijiaoqiao/supply-api/internal/domain"
|
|
)
|
|
|
|
// UsageRepository 用量记录仓储
|
|
type UsageRepository struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
// NewUsageRepository 创建用量记录仓储
|
|
func NewUsageRepository(pool *pgxpool.Pool) *UsageRepository {
|
|
return &UsageRepository{pool: pool}
|
|
}
|
|
|
|
// ListRecords 查询收益记录列表
|
|
func (r *UsageRepository) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
|
|
// 解析日期
|
|
start, err := time.Parse("2006-01-02", startDate)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("invalid start date: %w", err)
|
|
}
|
|
end, err := time.Parse("2006-01-02", endDate)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("invalid end date: %w", err)
|
|
}
|
|
|
|
// 查询总数
|
|
countQuery := `
|
|
SELECT COUNT(*)
|
|
FROM supply_usage_records
|
|
WHERE supplier_user_id = $1
|
|
AND started_at >= $2
|
|
AND started_at < $3
|
|
`
|
|
var total int
|
|
if err := r.pool.QueryRow(ctx, countQuery, supplierID, start, end.AddDate(0, 0, 1)).Scan(&total); err != nil {
|
|
return nil, 0, fmt.Errorf("failed to count records: %w", err)
|
|
}
|
|
|
|
// 查询记录
|
|
offset := (page - 1) * pageSize
|
|
query := `
|
|
SELECT
|
|
id,
|
|
supplier_user_id,
|
|
total_cost,
|
|
started_at,
|
|
platform,
|
|
model,
|
|
total_tokens,
|
|
success
|
|
FROM supply_usage_records
|
|
WHERE supplier_user_id = $1
|
|
AND started_at >= $2
|
|
AND started_at < $3
|
|
ORDER BY started_at DESC
|
|
LIMIT $4 OFFSET $5
|
|
`
|
|
|
|
rows, err := r.pool.Query(ctx, query, supplierID, start, end.AddDate(0, 0, 1), pageSize, offset)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("failed to query records: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var records []*domain.EarningRecord
|
|
for rows.Next() {
|
|
var id int64
|
|
var supplierUserID int64
|
|
var totalCost float64
|
|
var startedAt time.Time
|
|
var platform, model string
|
|
var totalTokens int64
|
|
var success bool
|
|
|
|
if err := rows.Scan(&id, &supplierUserID, &totalCost, &startedAt, &platform, &model, &totalTokens, &success); err != nil {
|
|
return nil, 0, fmt.Errorf("failed to scan record: %w", err)
|
|
}
|
|
|
|
status := "available"
|
|
if !success {
|
|
status = "pending"
|
|
}
|
|
|
|
records = append(records, &domain.EarningRecord{
|
|
ID: id,
|
|
SupplierID: supplierUserID,
|
|
Amount: totalCost,
|
|
EarningsType: "usage",
|
|
Status: status,
|
|
Description: fmt.Sprintf("%s %s %d tokens", platform, model, totalTokens),
|
|
EarnedAt: startedAt,
|
|
})
|
|
}
|
|
|
|
return records, total, nil
|
|
}
|
|
|
|
// GetBillingSummary 获取账单汇总
|
|
func (r *UsageRepository) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
|
|
// 解析日期
|
|
start, err := time.Parse("2006-01-02", startDate)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid start date: %w", err)
|
|
}
|
|
end, err := time.Parse("2006-01-02", endDate)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid end date: %w", err)
|
|
}
|
|
|
|
// 查询汇总数据
|
|
query := `
|
|
SELECT
|
|
COALESCE(SUM(total_cost), 0) as total_revenue,
|
|
COUNT(*) as total_orders,
|
|
COALESCE(SUM(total_tokens), 0) as total_usage,
|
|
COUNT(*) as total_requests,
|
|
COALESCE(AVG(CASE WHEN success THEN 100.0 ELSE 0.0 END), 0) as avg_success_rate
|
|
FROM supply_usage_records
|
|
WHERE supplier_user_id = $1
|
|
AND started_at >= $2
|
|
AND started_at < $3
|
|
`
|
|
|
|
var totalRevenue float64
|
|
var totalOrders, totalUsage, totalRequests int64
|
|
var avgSuccessRate float64
|
|
|
|
err = r.pool.QueryRow(ctx, query, supplierID, start, end.AddDate(0, 0, 1)).Scan(
|
|
&totalRevenue, &totalOrders, &totalUsage, &totalRequests, &avgSuccessRate,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query billing summary: %w", err)
|
|
}
|
|
|
|
// 平台费用假设为 1%
|
|
platformFee := totalRevenue * 0.01
|
|
netEarnings := totalRevenue - platformFee
|
|
|
|
// 查询按平台分组的统计数据
|
|
platformQuery := `
|
|
SELECT
|
|
platform,
|
|
COALESCE(SUM(total_cost), 0) as revenue,
|
|
COUNT(*) as orders,
|
|
COALESCE(SUM(total_tokens), 0) as tokens,
|
|
COALESCE(AVG(CASE WHEN success THEN 100.0 ELSE 0.0 END), 0) as success_rate
|
|
FROM supply_usage_records
|
|
WHERE supplier_user_id = $1
|
|
AND started_at >= $2
|
|
AND started_at < $3
|
|
GROUP BY platform
|
|
ORDER BY revenue DESC
|
|
`
|
|
|
|
platformRows, err := r.pool.Query(ctx, platformQuery, supplierID, start, end.AddDate(0, 0, 1))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query platform stats: %w", err)
|
|
}
|
|
defer platformRows.Close()
|
|
|
|
var byPlatform []domain.PlatformStat
|
|
for platformRows.Next() {
|
|
var platform string
|
|
var revenue float64
|
|
var orders int
|
|
var tokens int64
|
|
var successRate float64
|
|
|
|
if err := platformRows.Scan(&platform, &revenue, &orders, &tokens, &successRate); err != nil {
|
|
return nil, fmt.Errorf("failed to scan platform stat: %w", err)
|
|
}
|
|
|
|
byPlatform = append(byPlatform, domain.PlatformStat{
|
|
Platform: platform,
|
|
Revenue: revenue,
|
|
Orders: orders,
|
|
Tokens: tokens,
|
|
SuccessRate: successRate,
|
|
})
|
|
}
|
|
|
|
return &domain.BillingSummary{
|
|
Period: domain.BillingPeriod{
|
|
Start: startDate,
|
|
End: endDate,
|
|
},
|
|
Summary: domain.BillingTotal{
|
|
TotalRevenue: totalRevenue,
|
|
TotalOrders: int(totalOrders),
|
|
TotalUsage: totalUsage,
|
|
TotalRequests: totalRequests,
|
|
AvgSuccessRate: avgSuccessRate,
|
|
PlatformFee: platformFee,
|
|
NetEarnings: netEarnings,
|
|
},
|
|
ByPlatform: byPlatform,
|
|
}, nil
|
|
}
|