fix: close p0 auth and release gate gaps

This commit is contained in:
Your Name
2026-04-11 09:25:31 +08:00
parent b7b46dc827
commit 4adeee2e06
28 changed files with 3791 additions and 276 deletions

View File

@@ -18,14 +18,15 @@ import (
// SupplyAPI 处理器
type SupplyAPI struct {
accountService domain.AccountService
packageService domain.PackageService
settlementService domain.SettlementService
accountService domain.AccountService
packageService domain.PackageService
settlementService domain.SettlementService
earningService domain.EarningService
idempotencyMw *middleware.IdempotencyMiddleware // P0-P4修复: 使用DB-backed幂等中间件
auditStore audit.AuditStore // P0-R08修复: 使用接口支持DB-backed实现
fkValidator *repository.ForeignKeyValidator // P0-09修复: 外键校验器
auditStore audit.AuditStore // P0-R08修复: 使用接口支持DB-backed实现
fkValidator *repository.ForeignKeyValidator // P0-09修复: 外键校验器
supplierID int64
withdrawEnabled bool
statementBaseURL string
now func() time.Time
}
@@ -51,11 +52,16 @@ func NewSupplyAPI(
auditStore: auditStore,
fkValidator: fkValidator,
supplierID: supplierID,
withdrawEnabled: true,
statementBaseURL: statementBaseURL,
now: now,
}
}
func (a *SupplyAPI) SetWithdrawEnabled(enabled bool) {
a.withdrawEnabled = enabled
}
func (a *SupplyAPI) Register(mux *http.ServeMux) {
// Supply Accounts
mux.HandleFunc("/api/v1/supply/accounts/verify", a.handleVerifyAccount)
@@ -82,6 +88,25 @@ func (a *SupplyAPI) Register(mux *http.ServeMux) {
mux.HandleFunc("/api/v1/audit/events/", a.handleAuditEvent)
}
func (a *SupplyAPI) resolveSupplierID(ctx context.Context) (int64, error) {
if tenantID := middleware.GetTenantID(ctx); tenantID > 0 {
return tenantID, nil
}
if a.supplierID > 0 {
return a.supplierID, nil
}
return 0, fmt.Errorf("supplier context is missing")
}
func (a *SupplyAPI) requireSupplierID(w http.ResponseWriter, r *http.Request) (int64, bool) {
supplierID, err := a.resolveSupplierID(r.Context())
if err != nil {
writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error())
return 0, false
}
return supplierID, true
}
// ==================== Account Handlers ====================
type VerifyAccountRequest struct {
@@ -110,7 +135,12 @@ func (a *SupplyAPI) handleVerifyAccount(w http.ResponseWriter, r *http.Request)
return
}
result, err := a.accountService.Verify(r.Context(), a.supplierID,
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
result, err := a.accountService.Verify(r.Context(), supplierID,
domain.Provider(req.Provider),
domain.AccountType(req.AccountType),
req.CredentialInput)
@@ -139,12 +169,17 @@ func (a *SupplyAPI) handleCreateAccount(w http.ResponseWriter, r *http.Request)
}
// 降级:使用内联幂等逻辑(仅在幂等中间件未启用时)
a.createAccountHandler(context.Background(), w, r, nil)
a.createAccountHandler(r.Context(), w, r, nil)
}
// createAccountHandler 创建账号的业务逻辑(供幂等中间件包装)
func (a *SupplyAPI) createAccountHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, _ *repository.IdempotencyRecord) error {
requestID := r.Header.Get("X-Request-Id")
supplierID, err := a.resolveSupplierID(ctx)
if err != nil {
writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error())
return err
}
body, err := io.ReadAll(r.Body)
if err != nil {
@@ -169,14 +204,14 @@ func (a *SupplyAPI) createAccountHandler(ctx context.Context, w http.ResponseWri
// P0-09修复: 创建账户前校验外键引用
if a.fkValidator != nil {
if err := a.fkValidator.ValidateSupplyAccountOwner(ctx, a.supplierID); err != nil {
if err := a.fkValidator.ValidateSupplyAccountOwner(ctx, supplierID); err != nil {
writeError(w, http.StatusUnprocessableEntity, "FK_VALIDATION_FAILED", "supplier does not exist")
return err
}
}
createReq := &domain.CreateAccountRequest{
SupplierID: a.supplierID,
SupplierID: supplierID,
Provider: domain.Provider(rawReq.Provider),
AccountType: domain.AccountType(rawReq.AccountType),
Credential: rawReq.CredentialInput,
@@ -252,7 +287,12 @@ func (a *SupplyAPI) handleAccountActions(w http.ResponseWriter, r *http.Request)
}
func (a *SupplyAPI) handleActivateAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
account, err := a.accountService.Activate(r.Context(), a.supplierID, accountID)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
account, err := a.accountService.Activate(r.Context(), supplierID, accountID)
if err != nil {
if strings.Contains(err.Error(), "SUP_ACC") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
@@ -273,7 +313,12 @@ func (a *SupplyAPI) handleActivateAccount(w http.ResponseWriter, r *http.Request
}
func (a *SupplyAPI) handleSuspendAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
account, err := a.accountService.Suspend(r.Context(), a.supplierID, accountID)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
account, err := a.accountService.Suspend(r.Context(), supplierID, accountID)
if err != nil {
if strings.Contains(err.Error(), "SUP_ACC") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
@@ -294,7 +339,12 @@ func (a *SupplyAPI) handleSuspendAccount(w http.ResponseWriter, r *http.Request,
}
func (a *SupplyAPI) handleDeleteAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
err := a.accountService.Delete(r.Context(), a.supplierID, accountID)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
err := a.accountService.Delete(r.Context(), supplierID, accountID)
if err != nil {
if strings.Contains(err.Error(), "SUP_ACC") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
@@ -308,11 +358,27 @@ func (a *SupplyAPI) handleDeleteAccount(w http.ResponseWriter, r *http.Request,
}
func (a *SupplyAPI) handleAccountAuditLogs(w http.ResponseWriter, r *http.Request, accountID int64) {
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
page := getQueryInt(r, "page", 1)
pageSize := getQueryInt(r, "page_size", 20)
// 分页参数边界验证
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
if pageSize > 1000 {
pageSize = 1000
}
events, total, err := a.auditStore.QueryWithTotal(r.Context(), audit.EventFilter{
TenantID: a.supplierID,
TenantID: supplierID,
ObjectType: "supply_account",
ObjectID: accountID,
Limit: pageSize,
@@ -378,6 +444,11 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
return
}
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
// P0-09修复: 创建套餐前校验外键引用
if a.fkValidator != nil {
if err := a.fkValidator.ValidatePackageSupplyAccount(r.Context(), req.SupplyAccountID); err != nil {
@@ -387,7 +458,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
}
createReq := &domain.CreatePackageDraftRequest{
SupplierID: a.supplierID,
SupplierID: supplierID,
AccountID: req.SupplyAccountID,
Model: req.Model,
TotalQuota: req.TotalQuota,
@@ -398,7 +469,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
RateLimitRPM: req.RateLimitRPM,
}
pkg, err := a.packageService.CreateDraft(r.Context(), a.supplierID, createReq)
pkg, err := a.packageService.CreateDraft(r.Context(), supplierID, createReq)
if err != nil {
writeError(w, http.StatusUnprocessableEntity, "CREATE_FAILED", err.Error())
return
@@ -477,7 +548,12 @@ func (a *SupplyAPI) handlePackageActions(w http.ResponseWriter, r *http.Request)
}
func (a *SupplyAPI) handlePublishPackage(w http.ResponseWriter, r *http.Request, packageID int64) {
pkg, err := a.packageService.Publish(r.Context(), a.supplierID, packageID)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
pkg, err := a.packageService.Publish(r.Context(), supplierID, packageID)
if err != nil {
if strings.Contains(err.Error(), "SUP_PKG") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
@@ -498,7 +574,12 @@ func (a *SupplyAPI) handlePublishPackage(w http.ResponseWriter, r *http.Request,
}
func (a *SupplyAPI) handlePausePackage(w http.ResponseWriter, r *http.Request, packageID int64) {
pkg, err := a.packageService.Pause(r.Context(), a.supplierID, packageID)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
pkg, err := a.packageService.Pause(r.Context(), supplierID, packageID)
if err != nil {
if strings.Contains(err.Error(), "SUP_PKG") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
@@ -519,7 +600,12 @@ func (a *SupplyAPI) handlePausePackage(w http.ResponseWriter, r *http.Request, p
}
func (a *SupplyAPI) handleUnlistPackage(w http.ResponseWriter, r *http.Request, packageID int64) {
pkg, err := a.packageService.Unlist(r.Context(), a.supplierID, packageID)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
pkg, err := a.packageService.Unlist(r.Context(), supplierID, packageID)
if err != nil {
if strings.Contains(err.Error(), "SUP_PKG") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
@@ -540,7 +626,12 @@ func (a *SupplyAPI) handleUnlistPackage(w http.ResponseWriter, r *http.Request,
}
func (a *SupplyAPI) handleClonePackage(w http.ResponseWriter, r *http.Request, packageID int64) {
pkg, err := a.packageService.Clone(r.Context(), a.supplierID, packageID)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
pkg, err := a.packageService.Clone(r.Context(), supplierID, packageID)
if err != nil {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
return
@@ -595,7 +686,12 @@ func (a *SupplyAPI) handleBatchUpdatePrice(w http.ResponseWriter, r *http.Reques
}
}
resp, err := a.packageService.BatchUpdatePrice(r.Context(), a.supplierID, req)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
resp, err := a.packageService.BatchUpdatePrice(r.Context(), supplierID, req)
if err != nil {
writeError(w, http.StatusUnprocessableEntity, "BATCH_UPDATE_FAILED", err.Error())
return
@@ -618,7 +714,12 @@ func (a *SupplyAPI) handleGetBilling(w http.ResponseWriter, r *http.Request) {
startDate := r.URL.Query().Get("start_date")
endDate := r.URL.Query().Get("end_date")
summary, err := a.earningService.GetBillingSummary(r.Context(), a.supplierID, startDate, endDate)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
summary, err := a.earningService.GetBillingSummary(r.Context(), supplierID, startDate, endDate)
if err != nil {
writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
return
@@ -637,6 +738,10 @@ func (a *SupplyAPI) handleWithdraw(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
return
}
if !a.withdrawEnabled {
writeError(w, http.StatusServiceUnavailable, "FEATURE_DISABLED", "withdraw is disabled until SMS verification is integrated")
return
}
// P0-P4修复: 使用DB-backed幂等中间件
if a.idempotencyMw != nil {
@@ -645,12 +750,17 @@ func (a *SupplyAPI) handleWithdraw(w http.ResponseWriter, r *http.Request) {
}
// 降级:使用内联幂等逻辑(仅在幂等中间件未启用时)
a.withdrawHandler(context.Background(), w, r, nil)
a.withdrawHandler(r.Context(), w, r, nil)
}
// withdrawHandler 提现的业务逻辑(供幂等中间件包装)
func (a *SupplyAPI) withdrawHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, _ *repository.IdempotencyRecord) error {
requestID := r.Header.Get("X-Request-Id")
supplierID, err := a.resolveSupplierID(ctx)
if err != nil {
writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error())
return err
}
body, err := io.ReadAll(r.Body)
if err != nil {
@@ -678,7 +788,7 @@ func (a *SupplyAPI) withdrawHandler(ctx context.Context, w http.ResponseWriter,
SMSCode: req.SMSCode,
}
settlement, err := a.settlementService.Withdraw(ctx, a.supplierID, withdrawReq)
settlement, err := a.settlementService.Withdraw(ctx, supplierID, withdrawReq)
if err != nil {
if strings.Contains(err.Error(), "SUP_SET") {
writeError(w, http.StatusConflict, "WITHDRAW_FAILED", err.Error())
@@ -740,7 +850,12 @@ func (a *SupplyAPI) handleSettlementActions(w http.ResponseWriter, r *http.Reque
}
func (a *SupplyAPI) handleCancelSettlement(w http.ResponseWriter, r *http.Request, settlementID int64) {
settlement, err := a.settlementService.Cancel(r.Context(), a.supplierID, settlementID)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
settlement, err := a.settlementService.Cancel(r.Context(), supplierID, settlementID)
if err != nil {
if strings.Contains(err.Error(), "SUP_SET") {
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
@@ -761,7 +876,12 @@ func (a *SupplyAPI) handleCancelSettlement(w http.ResponseWriter, r *http.Reques
}
func (a *SupplyAPI) handleGetStatement(w http.ResponseWriter, r *http.Request, settlementID int64) {
settlement, err := a.settlementService.GetByID(r.Context(), a.supplierID, settlementID)
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
settlement, err := a.settlementService.GetByID(r.Context(), supplierID, settlementID)
if err != nil {
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
return
@@ -791,7 +911,23 @@ func (a *SupplyAPI) handleGetEarningRecords(w http.ResponseWriter, r *http.Reque
page := getQueryInt(r, "page", 1)
pageSize := getQueryInt(r, "page_size", 20)
records, total, err := a.earningService.ListRecords(r.Context(), a.supplierID, startDate, endDate, page, pageSize)
// 分页参数边界验证
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
if pageSize > 1000 {
pageSize = 1000
}
supplierID, ok := a.requireSupplierID(w, r)
if !ok {
return
}
records, total, err := a.earningService.ListRecords(r.Context(), supplierID, startDate, endDate, page, pageSize)
if err != nil {
writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
return

File diff suppressed because it is too large Load Diff