fix: close p0 auth and release gate gaps
This commit is contained in:
@@ -18,14 +18,15 @@ import (
|
||||
|
||||
// SupplyAPI 处理器
|
||||
type SupplyAPI struct {
|
||||
accountService domain.AccountService
|
||||
packageService domain.PackageService
|
||||
settlementService domain.SettlementService
|
||||
accountService domain.AccountService
|
||||
packageService domain.PackageService
|
||||
settlementService domain.SettlementService
|
||||
earningService domain.EarningService
|
||||
idempotencyMw *middleware.IdempotencyMiddleware // P0-P4修复: 使用DB-backed幂等中间件
|
||||
auditStore audit.AuditStore // P0-R08修复: 使用接口支持DB-backed实现
|
||||
fkValidator *repository.ForeignKeyValidator // P0-09修复: 外键校验器
|
||||
auditStore audit.AuditStore // P0-R08修复: 使用接口支持DB-backed实现
|
||||
fkValidator *repository.ForeignKeyValidator // P0-09修复: 外键校验器
|
||||
supplierID int64
|
||||
withdrawEnabled bool
|
||||
statementBaseURL string
|
||||
now func() time.Time
|
||||
}
|
||||
@@ -51,11 +52,16 @@ func NewSupplyAPI(
|
||||
auditStore: auditStore,
|
||||
fkValidator: fkValidator,
|
||||
supplierID: supplierID,
|
||||
withdrawEnabled: true,
|
||||
statementBaseURL: statementBaseURL,
|
||||
now: now,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) SetWithdrawEnabled(enabled bool) {
|
||||
a.withdrawEnabled = enabled
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) Register(mux *http.ServeMux) {
|
||||
// Supply Accounts
|
||||
mux.HandleFunc("/api/v1/supply/accounts/verify", a.handleVerifyAccount)
|
||||
@@ -82,6 +88,25 @@ func (a *SupplyAPI) Register(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/v1/audit/events/", a.handleAuditEvent)
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) resolveSupplierID(ctx context.Context) (int64, error) {
|
||||
if tenantID := middleware.GetTenantID(ctx); tenantID > 0 {
|
||||
return tenantID, nil
|
||||
}
|
||||
if a.supplierID > 0 {
|
||||
return a.supplierID, nil
|
||||
}
|
||||
return 0, fmt.Errorf("supplier context is missing")
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) requireSupplierID(w http.ResponseWriter, r *http.Request) (int64, bool) {
|
||||
supplierID, err := a.resolveSupplierID(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error())
|
||||
return 0, false
|
||||
}
|
||||
return supplierID, true
|
||||
}
|
||||
|
||||
// ==================== Account Handlers ====================
|
||||
|
||||
type VerifyAccountRequest struct {
|
||||
@@ -110,7 +135,12 @@ func (a *SupplyAPI) handleVerifyAccount(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := a.accountService.Verify(r.Context(), a.supplierID,
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := a.accountService.Verify(r.Context(), supplierID,
|
||||
domain.Provider(req.Provider),
|
||||
domain.AccountType(req.AccountType),
|
||||
req.CredentialInput)
|
||||
@@ -139,12 +169,17 @@ func (a *SupplyAPI) handleCreateAccount(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// 降级:使用内联幂等逻辑(仅在幂等中间件未启用时)
|
||||
a.createAccountHandler(context.Background(), w, r, nil)
|
||||
a.createAccountHandler(r.Context(), w, r, nil)
|
||||
}
|
||||
|
||||
// createAccountHandler 创建账号的业务逻辑(供幂等中间件包装)
|
||||
func (a *SupplyAPI) createAccountHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, _ *repository.IdempotencyRecord) error {
|
||||
requestID := r.Header.Get("X-Request-Id")
|
||||
supplierID, err := a.resolveSupplierID(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
@@ -169,14 +204,14 @@ func (a *SupplyAPI) createAccountHandler(ctx context.Context, w http.ResponseWri
|
||||
|
||||
// P0-09修复: 创建账户前校验外键引用
|
||||
if a.fkValidator != nil {
|
||||
if err := a.fkValidator.ValidateSupplyAccountOwner(ctx, a.supplierID); err != nil {
|
||||
if err := a.fkValidator.ValidateSupplyAccountOwner(ctx, supplierID); err != nil {
|
||||
writeError(w, http.StatusUnprocessableEntity, "FK_VALIDATION_FAILED", "supplier does not exist")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
createReq := &domain.CreateAccountRequest{
|
||||
SupplierID: a.supplierID,
|
||||
SupplierID: supplierID,
|
||||
Provider: domain.Provider(rawReq.Provider),
|
||||
AccountType: domain.AccountType(rawReq.AccountType),
|
||||
Credential: rawReq.CredentialInput,
|
||||
@@ -252,7 +287,12 @@ func (a *SupplyAPI) handleAccountActions(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleActivateAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
|
||||
account, err := a.accountService.Activate(r.Context(), a.supplierID, accountID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := a.accountService.Activate(r.Context(), supplierID, accountID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_ACC") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -273,7 +313,12 @@ func (a *SupplyAPI) handleActivateAccount(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleSuspendAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
|
||||
account, err := a.accountService.Suspend(r.Context(), a.supplierID, accountID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := a.accountService.Suspend(r.Context(), supplierID, accountID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_ACC") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -294,7 +339,12 @@ func (a *SupplyAPI) handleSuspendAccount(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleDeleteAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
|
||||
err := a.accountService.Delete(r.Context(), a.supplierID, accountID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
err := a.accountService.Delete(r.Context(), supplierID, accountID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_ACC") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -308,11 +358,27 @@ func (a *SupplyAPI) handleDeleteAccount(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleAccountAuditLogs(w http.ResponseWriter, r *http.Request, accountID int64) {
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
page := getQueryInt(r, "page", 1)
|
||||
pageSize := getQueryInt(r, "page_size", 20)
|
||||
|
||||
// 分页参数边界验证
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 20
|
||||
}
|
||||
if pageSize > 1000 {
|
||||
pageSize = 1000
|
||||
}
|
||||
|
||||
events, total, err := a.auditStore.QueryWithTotal(r.Context(), audit.EventFilter{
|
||||
TenantID: a.supplierID,
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_account",
|
||||
ObjectID: accountID,
|
||||
Limit: pageSize,
|
||||
@@ -378,6 +444,11 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// P0-09修复: 创建套餐前校验外键引用
|
||||
if a.fkValidator != nil {
|
||||
if err := a.fkValidator.ValidatePackageSupplyAccount(r.Context(), req.SupplyAccountID); err != nil {
|
||||
@@ -387,7 +458,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
|
||||
createReq := &domain.CreatePackageDraftRequest{
|
||||
SupplierID: a.supplierID,
|
||||
SupplierID: supplierID,
|
||||
AccountID: req.SupplyAccountID,
|
||||
Model: req.Model,
|
||||
TotalQuota: req.TotalQuota,
|
||||
@@ -398,7 +469,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
|
||||
RateLimitRPM: req.RateLimitRPM,
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.CreateDraft(r.Context(), a.supplierID, createReq)
|
||||
pkg, err := a.packageService.CreateDraft(r.Context(), supplierID, createReq)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnprocessableEntity, "CREATE_FAILED", err.Error())
|
||||
return
|
||||
@@ -477,7 +548,12 @@ func (a *SupplyAPI) handlePackageActions(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handlePublishPackage(w http.ResponseWriter, r *http.Request, packageID int64) {
|
||||
pkg, err := a.packageService.Publish(r.Context(), a.supplierID, packageID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.Publish(r.Context(), supplierID, packageID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_PKG") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -498,7 +574,12 @@ func (a *SupplyAPI) handlePublishPackage(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handlePausePackage(w http.ResponseWriter, r *http.Request, packageID int64) {
|
||||
pkg, err := a.packageService.Pause(r.Context(), a.supplierID, packageID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.Pause(r.Context(), supplierID, packageID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_PKG") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -519,7 +600,12 @@ func (a *SupplyAPI) handlePausePackage(w http.ResponseWriter, r *http.Request, p
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleUnlistPackage(w http.ResponseWriter, r *http.Request, packageID int64) {
|
||||
pkg, err := a.packageService.Unlist(r.Context(), a.supplierID, packageID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.Unlist(r.Context(), supplierID, packageID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_PKG") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -540,7 +626,12 @@ func (a *SupplyAPI) handleUnlistPackage(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleClonePackage(w http.ResponseWriter, r *http.Request, packageID int64) {
|
||||
pkg, err := a.packageService.Clone(r.Context(), a.supplierID, packageID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.Clone(r.Context(), supplierID, packageID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
|
||||
return
|
||||
@@ -595,7 +686,12 @@ func (a *SupplyAPI) handleBatchUpdatePrice(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := a.packageService.BatchUpdatePrice(r.Context(), a.supplierID, req)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := a.packageService.BatchUpdatePrice(r.Context(), supplierID, req)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnprocessableEntity, "BATCH_UPDATE_FAILED", err.Error())
|
||||
return
|
||||
@@ -618,7 +714,12 @@ func (a *SupplyAPI) handleGetBilling(w http.ResponseWriter, r *http.Request) {
|
||||
startDate := r.URL.Query().Get("start_date")
|
||||
endDate := r.URL.Query().Get("end_date")
|
||||
|
||||
summary, err := a.earningService.GetBillingSummary(r.Context(), a.supplierID, startDate, endDate)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
summary, err := a.earningService.GetBillingSummary(r.Context(), supplierID, startDate, endDate)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
|
||||
return
|
||||
@@ -637,6 +738,10 @@ func (a *SupplyAPI) handleWithdraw(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
return
|
||||
}
|
||||
if !a.withdrawEnabled {
|
||||
writeError(w, http.StatusServiceUnavailable, "FEATURE_DISABLED", "withdraw is disabled until SMS verification is integrated")
|
||||
return
|
||||
}
|
||||
|
||||
// P0-P4修复: 使用DB-backed幂等中间件
|
||||
if a.idempotencyMw != nil {
|
||||
@@ -645,12 +750,17 @@ func (a *SupplyAPI) handleWithdraw(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// 降级:使用内联幂等逻辑(仅在幂等中间件未启用时)
|
||||
a.withdrawHandler(context.Background(), w, r, nil)
|
||||
a.withdrawHandler(r.Context(), w, r, nil)
|
||||
}
|
||||
|
||||
// withdrawHandler 提现的业务逻辑(供幂等中间件包装)
|
||||
func (a *SupplyAPI) withdrawHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, _ *repository.IdempotencyRecord) error {
|
||||
requestID := r.Header.Get("X-Request-Id")
|
||||
supplierID, err := a.resolveSupplierID(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
@@ -678,7 +788,7 @@ func (a *SupplyAPI) withdrawHandler(ctx context.Context, w http.ResponseWriter,
|
||||
SMSCode: req.SMSCode,
|
||||
}
|
||||
|
||||
settlement, err := a.settlementService.Withdraw(ctx, a.supplierID, withdrawReq)
|
||||
settlement, err := a.settlementService.Withdraw(ctx, supplierID, withdrawReq)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_SET") {
|
||||
writeError(w, http.StatusConflict, "WITHDRAW_FAILED", err.Error())
|
||||
@@ -740,7 +850,12 @@ func (a *SupplyAPI) handleSettlementActions(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleCancelSettlement(w http.ResponseWriter, r *http.Request, settlementID int64) {
|
||||
settlement, err := a.settlementService.Cancel(r.Context(), a.supplierID, settlementID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
settlement, err := a.settlementService.Cancel(r.Context(), supplierID, settlementID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_SET") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -761,7 +876,12 @@ func (a *SupplyAPI) handleCancelSettlement(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleGetStatement(w http.ResponseWriter, r *http.Request, settlementID int64) {
|
||||
settlement, err := a.settlementService.GetByID(r.Context(), a.supplierID, settlementID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
settlement, err := a.settlementService.GetByID(r.Context(), supplierID, settlementID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
|
||||
return
|
||||
@@ -791,7 +911,23 @@ func (a *SupplyAPI) handleGetEarningRecords(w http.ResponseWriter, r *http.Reque
|
||||
page := getQueryInt(r, "page", 1)
|
||||
pageSize := getQueryInt(r, "page_size", 20)
|
||||
|
||||
records, total, err := a.earningService.ListRecords(r.Context(), a.supplierID, startDate, endDate, page, pageSize)
|
||||
// 分页参数边界验证
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 20
|
||||
}
|
||||
if pageSize > 1000 {
|
||||
pageSize = 1000
|
||||
}
|
||||
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
records, total, err := a.earningService.ListRecords(r.Context(), supplierID, startDate, endDate, page, pageSize)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
|
||||
return
|
||||
|
||||
1399
supply-api/internal/httpapi/supply_api_test.go
Normal file
1399
supply-api/internal/httpapi/supply_api_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user