Files
lijiaoqiao/supply-api/internal/httpapi/supply_api_test.go

1489 lines
41 KiB
Go
Raw Normal View History

package httpapi
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"lijiaoqiao/supply-api/internal/audit"
"lijiaoqiao/supply-api/internal/domain"
"lijiaoqiao/supply-api/internal/middleware"
)
// ==================== Mock Implementations ====================
// mockAccountService is a mock account service for handler tests.
// Every method succeeds with the canned verifyResult/account unless the
// matching *Err field is set, in which case that error is returned instead.
type mockAccountService struct {
	verifyResult *domain.VerifyResult // success value of Verify
	verifyErr    error                // forces Verify to fail
	account      *domain.Account      // success value of Create/Activate/Suspend/GetByID
	createErr    error
	activateErr  error
	suspendErr   error
	deleteErr    error
	// lastVerifySupplierID captures the supplierID passed to the most recent
	// Verify call so tests can assert which supplier the handler resolved.
	lastVerifySupplierID int64
}

// Verify remembers the supplierID it received (even when failing) and then
// returns verifyErr if set, otherwise verifyResult.
func (m *mockAccountService) Verify(ctx context.Context, supplierID int64, provider domain.Provider, accountType domain.AccountType, credential string) (*domain.VerifyResult, error) {
	m.lastVerifySupplierID = supplierID
	if m.verifyErr == nil {
		return m.verifyResult, nil
	}
	return nil, m.verifyErr
}

// Create returns createErr if set, otherwise the canned account.
func (m *mockAccountService) Create(ctx context.Context, req *domain.CreateAccountRequest) (*domain.Account, error) {
	if m.createErr == nil {
		return m.account, nil
	}
	return nil, m.createErr
}

// Activate returns activateErr if set, otherwise the canned account.
func (m *mockAccountService) Activate(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) {
	if m.activateErr == nil {
		return m.account, nil
	}
	return nil, m.activateErr
}

// Suspend returns suspendErr if set, otherwise the canned account.
func (m *mockAccountService) Suspend(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) {
	if m.suspendErr == nil {
		return m.account, nil
	}
	return nil, m.suspendErr
}

// Delete reports deleteErr; a nil deleteErr means success.
func (m *mockAccountService) Delete(ctx context.Context, supplierID, accountID int64) error {
	err := m.deleteErr
	return err
}

// GetByID always yields the canned account and never fails.
func (m *mockAccountService) GetByID(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) {
	acct := m.account
	return acct, nil
}
// mockPackageService is a mock package service for handler tests.
// Lifecycle methods return pkg on success; setting the matching *Err field
// makes that method fail with the given error instead.
type mockPackageService struct {
	pkg            *domain.Package // success value for all lifecycle methods and GetByID
	createDraftErr error
	publishErr     error
	pauseErr       error
	unlistErr      error
	cloneErr       error
	batchResp      *domain.BatchUpdatePriceResponse // success value of BatchUpdatePrice
	batchErr       error
}

// CreateDraft returns createDraftErr if set, otherwise pkg.
func (m *mockPackageService) CreateDraft(ctx context.Context, supplierID int64, req *domain.CreatePackageDraftRequest) (*domain.Package, error) {
	if m.createDraftErr == nil {
		return m.pkg, nil
	}
	return nil, m.createDraftErr
}

// Publish returns publishErr if set, otherwise pkg.
func (m *mockPackageService) Publish(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	if m.publishErr == nil {
		return m.pkg, nil
	}
	return nil, m.publishErr
}

// Pause returns pauseErr if set, otherwise pkg.
func (m *mockPackageService) Pause(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	if m.pauseErr == nil {
		return m.pkg, nil
	}
	return nil, m.pauseErr
}

// Unlist returns unlistErr if set, otherwise pkg.
func (m *mockPackageService) Unlist(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	if m.unlistErr == nil {
		return m.pkg, nil
	}
	return nil, m.unlistErr
}

// Clone returns cloneErr if set, otherwise pkg.
func (m *mockPackageService) Clone(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	if m.cloneErr == nil {
		return m.pkg, nil
	}
	return nil, m.cloneErr
}

// BatchUpdatePrice returns batchErr if set, otherwise batchResp.
func (m *mockPackageService) BatchUpdatePrice(ctx context.Context, supplierID int64, req *domain.BatchUpdatePriceRequest) (*domain.BatchUpdatePriceResponse, error) {
	if m.batchErr == nil {
		return m.batchResp, nil
	}
	return nil, m.batchErr
}

// GetByID always yields pkg and never fails.
func (m *mockPackageService) GetByID(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	p := m.pkg
	return p, nil
}
// mockSettlementService is a mock settlement service for handler tests.
// It serves a single canned settlement; withdrawErr, cancelErr and getErr
// make Withdraw, Cancel and GetByID fail respectively.
type mockSettlementService struct {
	settlement  *domain.Settlement // canned settlement (nil means "none")
	withdrawErr error
	cancelErr   error
	getErr      error
}

// Withdraw returns withdrawErr if set, otherwise the canned settlement.
func (m *mockSettlementService) Withdraw(ctx context.Context, supplierID int64, req *domain.WithdrawRequest) (*domain.Settlement, error) {
	if m.withdrawErr == nil {
		return m.settlement, nil
	}
	return nil, m.withdrawErr
}

// Cancel returns cancelErr if set, otherwise the canned settlement.
func (m *mockSettlementService) Cancel(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) {
	if m.cancelErr == nil {
		return m.settlement, nil
	}
	return nil, m.cancelErr
}

// GetByID returns getErr if set, otherwise the canned settlement.
func (m *mockSettlementService) GetByID(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) {
	if m.getErr == nil {
		return m.settlement, nil
	}
	return nil, m.getErr
}

// List yields a one-element slice holding the canned settlement, or a nil
// slice when there is none. It never fails.
func (m *mockSettlementService) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
	if m.settlement == nil {
		return nil, nil
	}
	return []*domain.Settlement{m.settlement}, nil
}

// GetBillingSummary is a no-op stub that returns (nil, nil).
func (m *mockSettlementService) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
	var summary *domain.BillingSummary
	return summary, nil
}
// mockEarningService is a mock earnings service for handler tests.
// Date-range and paging arguments are ignored; the canned values are returned
// unless listErr or billingErr is set.
type mockEarningService struct {
	records        []*domain.EarningRecord // page returned by ListRecords
	total          int                     // total count returned by ListRecords
	billingSummary *domain.BillingSummary  // success value of GetBillingSummary
	listErr        error
	billingErr     error
}

// ListRecords returns listErr if set, otherwise the canned records and total.
func (m *mockEarningService) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
	if m.listErr == nil {
		return m.records, m.total, nil
	}
	return nil, 0, m.listErr
}

// GetBillingSummary returns billingErr if set, otherwise the canned summary.
func (m *mockEarningService) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
	if m.billingErr == nil {
		return m.billingSummary, nil
	}
	return nil, m.billingErr
}
// mockAuditStore is an in-memory stand-in for the audit store. A single err
// field, when set, makes every method (including Emit) fail with it.
type mockAuditStore struct {
	events []audit.Event // returned by Query and QueryWithTotal
	event  audit.Event   // returned by GetByID
	err    error         // shared failure for all methods
}

// Emit discards the event and reports err (nil by default).
func (m *mockAuditStore) Emit(ctx context.Context, event audit.Event) error {
	return m.err
}

// Query ignores the filter and returns the canned events unless err is set.
func (m *mockAuditStore) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) {
	if m.err == nil {
		return m.events, nil
	}
	return nil, m.err
}

// QueryWithTotal behaves like Query and also reports len(events) as the total.
func (m *mockAuditStore) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) {
	if m.err == nil {
		total := int64(len(m.events))
		return m.events, total, nil
	}
	return nil, 0, m.err
}

// GetByID ignores eventID and returns the canned event unless err is set.
func (m *mockAuditStore) GetByID(ctx context.Context, eventID string) (audit.Event, error) {
	if m.err == nil {
		return m.event, nil
	}
	return audit.Event{}, m.err
}
// ==================== Test Helpers ====================
func newTestAPI() (*SupplyAPI, *mockAccountService, *mockPackageService, *mockSettlementService, *mockEarningService, *mockAuditStore) {
return newTestAPIWithIdempotency(middleware.NewIdempotencyMiddleware(nil, middleware.IdempotencyConfig{
Enabled: false,
}))
}
// newTestAPIWithoutIdempotencyForTest builds a SupplyAPI whose idempotency
// middleware is nil. It is used to exercise the handlers' behavior when that
// middleware is missing.
func newTestAPIWithoutIdempotencyForTest() (*SupplyAPI, *mockAccountService, *mockPackageService, *mockSettlementService, *mockEarningService, *mockAuditStore) {
	var noMiddleware *middleware.IdempotencyMiddleware
	return newTestAPIWithIdempotency(noMiddleware)
}
func newTestAPIWithIdempotency(idempotencyMw *middleware.IdempotencyMiddleware) (*SupplyAPI, *mockAccountService, *mockPackageService, *mockSettlementService, *mockEarningService, *mockAuditStore) {
accountSvc := &mockAccountService{
account: &domain.Account{
ID: 1,
SupplierID: 100,
Provider: domain.ProviderOpenAI,
AccountType: domain.AccountTypeAPIKey,
Status: domain.AccountStatusActive,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
verifyResult: &domain.VerifyResult{
VerifyStatus: "pass",
AvailableQuota: 1000,
RiskScore: 0,
},
}
packageSvc := &mockPackageService{
pkg: &domain.Package{
ID: 1,
SupplierID: 100,
Model: "gpt-4",
Status: domain.PackageStatusActive,
TotalQuota: 10000,
AvailableQuota: 8000,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
}
settlementSvc := &mockSettlementService{
settlement: &domain.Settlement{
ID: 1,
SupplierID: 100,
Status: domain.SettlementStatusPending,
TotalAmount: 1000,
NetAmount: 950,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
}
earningSvc := &mockEarningService{
records: []*domain.EarningRecord{
{
ID: 1,
Amount: 100,
Status: "available",
},
},
total: 1,
billingSummary: &domain.BillingSummary{},
}
auditSvc := &mockAuditStore{
events: []audit.Event{
{
EventID: "evt_123",
TenantID: 100,
ObjectType: "supply_account",
ObjectID: 1,
Action: "create",
CreatedAt: time.Now(),
},
},
event: audit.Event{
EventID: "evt_123",
TenantID: 100,
ObjectType: "supply_account",
ObjectID: 1,
Action: "create",
CreatedAt: time.Now(),
},
}
api, err := NewSupplyAPI(
accountSvc,
packageSvc,
settlementSvc,
earningSvc,
idempotencyMw,
auditSvc,
nil, // fkValidator
100, // supplierID
"https://statements.example.com",
time.Now,
)
if err != nil {
panic("expected api constructor to succeed: " + err.Error())
}
return api, accountSvc, packageSvc, settlementSvc, earningSvc, auditSvc
}
// TestNewSupplyAPI_ReturnsErrorWhenAccountServiceMissing checks that the
// constructor rejects a nil account service and returns no API value.
func TestNewSupplyAPI_ReturnsErrorWhenAccountServiceMissing(t *testing.T) {
	supply, err := NewSupplyAPI(
		nil, // account service intentionally missing
		&mockPackageService{},
		&mockSettlementService{},
		&mockEarningService{},
		nil, // idempotency middleware
		&mockAuditStore{},
		nil, // fkValidator
		100, // supplierID
		"https://statements.example.com",
		time.Now,
	)
	switch {
	case err == nil:
		t.Fatal("expected error")
	case supply != nil:
		t.Fatal("expected nil api")
	}
}
// TestNewSupplyAPI_DefaultsClockWhenNil checks that passing a nil clock makes
// the constructor install a default one instead of leaving now unset.
func TestNewSupplyAPI_DefaultsClockWhenNil(t *testing.T) {
	supply, err := NewSupplyAPI(
		&mockAccountService{},
		&mockPackageService{},
		&mockSettlementService{},
		&mockEarningService{},
		nil, // idempotency middleware
		&mockAuditStore{},
		nil, // fkValidator
		100, // supplierID
		"https://statements.example.com",
		nil, // clock left unset on purpose
	)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if clock := supply.now; clock == nil {
		t.Fatal("expected default clock")
	}
}
// ==================== Account Handler Tests ====================
// TestSupplyAPI_VerifyAccount_Success checks that a well-formed verify request
// returns 200 and that the JSON response echoes the X-Request-Id header as
// request_id.
func TestSupplyAPI_VerifyAccount_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()
	body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test123"}`
	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-Request-Id", "test-req-001")
	w := httptest.NewRecorder()
	api.handleVerifyAccount(w, req)
	// Stop on a wrong status. The body would then be an error payload, and the
	// request_id check below would only add a misleading second failure.
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d body=%s", w.Code, w.Body.String())
	}
	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if resp["request_id"] != "test-req-001" {
		t.Errorf("expected request_id test-req-001, got %v", resp["request_id"])
	}
}
// TestSupplyAPI_VerifyAccount_MethodNotAllowed checks that GET on the verify
// endpoint is answered with 405.
func TestSupplyAPI_VerifyAccount_MethodNotAllowed(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleVerifyAccount(rec, httptest.NewRequest("GET", "/api/v1/supply/accounts/verify", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_VerifyAccount_InvalidJSON checks that a malformed JSON body
// is rejected with 400.
func TestSupplyAPI_VerifyAccount_InvalidJSON(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	r := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(`{invalid json}`))
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleVerifyAccount(rec, r)
	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", got)
	}
}
// TestSupplyAPI_VerifyAccount_VerifyFailed checks that a SUP_ACC_4001 error
// from the account service surfaces as 422.
func TestSupplyAPI_VerifyAccount_VerifyFailed(t *testing.T) {
	supply, accounts, _, _, _, _ := newTestAPI()
	accounts.verifyErr = errors.New("SUP_ACC_4001: verification failed")
	payload := strings.NewReader(`{"provider":"openai","account_type":"resource","credential_input":"invalid"}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", payload)
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleVerifyAccount(rec, r)
	if got := rec.Code; got != http.StatusUnprocessableEntity {
		t.Errorf("expected status 422, got %d", got)
	}
}
// TestSupplyAPI_VerifyAccount_UsesTenantIDFromContext checks that a tenant ID
// stored in the request context (200) is forwarded to the account service in
// place of the API's configured supplier (100).
func TestSupplyAPI_VerifyAccount_UsesTenantIDFromContext(t *testing.T) {
	supply, accounts, _, _, _, _ := newTestAPI()
	payload := strings.NewReader(`{"provider":"openai","account_type":"resource","credential_input":"sk-test123"}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", payload)
	r = r.WithContext(middleware.WithTenantID(r.Context(), 200))
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleVerifyAccount(rec, r)
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status 200, got %d", got)
	}
	if got := accounts.lastVerifySupplierID; got != 200 {
		t.Fatalf("expected tenant supplier ID 200, got %d", got)
	}
}
// TestSupplyAPI_VerifyAccount_RejectsMissingTenantContextWithoutDefaultSupplier
// checks that the handler answers 401 when the request context carries no
// tenant and the API's fallback supplier ID is cleared to 0.
func TestSupplyAPI_VerifyAccount_RejectsMissingTenantContextWithoutDefaultSupplier(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	supply.supplierID = 0 // drop the fallback supplier
	payload := strings.NewReader(`{"provider":"openai","account_type":"resource","credential_input":"sk-test123"}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", payload)
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleVerifyAccount(rec, r)
	if got := rec.Code; got != http.StatusUnauthorized {
		t.Fatalf("expected status 401, got %d body=%s", got, rec.Body.String())
	}
}
// TestSupplyAPI_CreateAccount_Success checks that a valid create-account
// request is answered with 201.
func TestSupplyAPI_CreateAccount_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	payload := strings.NewReader(`{"provider":"openai","account_type":"resource","credential_input":"sk-test","account_alias":"test","risk_ack":true}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/accounts", payload)
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleCreateAccount(rec, r)
	if got := rec.Code; got != http.StatusCreated {
		t.Errorf("expected status 201, got %d", got)
	}
}
// TestHandleCreateAccount_RequiresIdempotencyMiddleware checks that account
// creation is refused with 503 when no idempotency middleware is wired in.
func TestHandleCreateAccount_RequiresIdempotencyMiddleware(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPIWithoutIdempotencyForTest()
	payload := strings.NewReader(`{"provider":"openai","account_type":"resource","credential_input":"sk-test","account_alias":"test","risk_ack":true}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/accounts", payload)
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleCreateAccount(rec, r)
	if got := rec.Code; got != http.StatusServiceUnavailable {
		t.Fatalf("expected 503 when idempotency middleware is missing, got=%d body=%s", got, rec.Body.String())
	}
}
// TestSupplyAPI_CreateAccount_MethodNotAllowed checks that GET on the
// account-creation endpoint is answered with 405.
func TestSupplyAPI_CreateAccount_MethodNotAllowed(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleCreateAccount(rec, httptest.NewRequest("GET", "/api/v1/supply/accounts", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_ActivateAccount_Success checks that POST .../accounts/1/activate
// is answered with 200.
func TestSupplyAPI_ActivateAccount_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("POST", "/api/v1/supply/accounts/1/activate", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status 200, got %d", got)
	}
}
// TestSupplyAPI_ActivateAccount_NotFound checks that an "account not found"
// error from the service is mapped to 404.
func TestSupplyAPI_ActivateAccount_NotFound(t *testing.T) {
	supply, accounts, _, _, _, _ := newTestAPI()
	accounts.activateErr = errors.New("account not found")
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("POST", "/api/v1/supply/accounts/1/activate", nil))
	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", got)
	}
}
// TestSupplyAPI_SuspendAccount_Success checks that POST .../accounts/1/suspend
// is answered with 200.
func TestSupplyAPI_SuspendAccount_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("POST", "/api/v1/supply/accounts/1/suspend", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status 200, got %d", got)
	}
}
// TestSupplyAPI_SuspendAccount_Conflict checks that a SUP_ACC_4091 state
// conflict from the service is mapped to 409.
func TestSupplyAPI_SuspendAccount_Conflict(t *testing.T) {
	supply, accounts, _, _, _, _ := newTestAPI()
	accounts.suspendErr = errors.New("SUP_ACC_4091: account state conflict")
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("POST", "/api/v1/supply/accounts/1/suspend", nil))
	if got := rec.Code; got != http.StatusConflict {
		t.Errorf("expected status 409, got %d", got)
	}
}
// TestSupplyAPI_SuspendAccount_WrongMethod checks that GET on the suspend
// action is answered with 405.
func TestSupplyAPI_SuspendAccount_WrongMethod(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("GET", "/api/v1/supply/accounts/1/suspend", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_DeleteAccount_Success checks that DELETE .../accounts/1/delete
// is answered with 204.
func TestSupplyAPI_DeleteAccount_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("DELETE", "/api/v1/supply/accounts/1/delete", nil))
	if got := rec.Code; got != http.StatusNoContent {
		t.Errorf("expected status 204, got %d", got)
	}
}
// TestSupplyAPI_DeleteAccount_Conflict checks that a SUP_ACC_4092 error from
// the service is mapped to 409.
func TestSupplyAPI_DeleteAccount_Conflict(t *testing.T) {
	supply, accounts, _, _, _, _ := newTestAPI()
	accounts.deleteErr = errors.New("SUP_ACC_4092: cannot delete account with active packages")
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("DELETE", "/api/v1/supply/accounts/1/delete", nil))
	if got := rec.Code; got != http.StatusConflict {
		t.Errorf("expected status 409, got %d", got)
	}
}
// TestSupplyAPI_DeleteAccount_WrongMethod checks that POST on the delete
// action is answered with 405.
func TestSupplyAPI_DeleteAccount_WrongMethod(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("POST", "/api/v1/supply/accounts/1/delete", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_AccountAuditLogs_Success exercises GET .../accounts/1/audit-logs.
// With a healthy store the route returns 200 and the seeded event under
// "data". With a failing store it returns 500. Each scenario runs as its own
// subtest on a fresh API, so one failing path cannot mask or skew the other.
func TestSupplyAPI_AccountAuditLogs_Success(t *testing.T) {
	t.Run("success", func(t *testing.T) {
		api, _, _, _, _, _ := newTestAPI()
		req := httptest.NewRequest("GET", "/api/v1/supply/accounts/1/audit-logs", nil)
		w := httptest.NewRecorder()
		api.handleAccountActions(w, req)
		// Fail fast: the body-shape checks below are meaningless on a non-200.
		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d body=%s", w.Code, w.Body.String())
		}
		var resp map[string]any
		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
			t.Fatalf("failed to unmarshal response: %v", err)
		}
		data, ok := resp["data"].([]any)
		if !ok {
			t.Fatal("expected data array in response")
		}
		if len(data) != 1 {
			t.Errorf("expected 1 event, got %d", len(data))
		}
	})

	t.Run("query error", func(t *testing.T) {
		api, _, _, _, _, auditSvc := newTestAPI()
		auditSvc.err = errors.New("query failed")
		req := httptest.NewRequest("GET", "/api/v1/supply/accounts/1/audit-logs", nil)
		w := httptest.NewRecorder()
		api.handleAccountActions(w, req)
		if w.Code != http.StatusInternalServerError {
			t.Errorf("expected status 500, got %d body=%s", w.Code, w.Body.String())
		}
	})
}
// TestSupplyAPI_AccountActions_InvalidID checks that a non-numeric account ID
// in the path is rejected with 400.
func TestSupplyAPI_AccountActions_InvalidID(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("POST", "/api/v1/supply/accounts/invalid/activate", nil))
	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", got)
	}
}
// TestSupplyAPI_AccountActions_UnknownRoute checks that an unrecognised
// account action is answered with 404.
func TestSupplyAPI_AccountActions_UnknownRoute(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleAccountActions(rec, httptest.NewRequest("POST", "/api/v1/supply/accounts/1/unknown", nil))
	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", got)
	}
}
// ==================== Package Handler Tests ====================
// TestSupplyAPI_CreatePackageDraft_Success checks that a complete draft
// request is answered with 201.
func TestSupplyAPI_CreatePackageDraft_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	payload := strings.NewReader(`{"supply_account_id":1,"model":"gpt-4","total_quota":10000,"price_per_1m_input":0.1,"price_per_1m_output":0.2,"valid_days":30,"max_concurrent":10,"rate_limit_rpm":1000}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/packages/draft", payload)
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleCreatePackageDraft(rec, r)
	if got := rec.Code; got != http.StatusCreated {
		t.Errorf("expected status 201, got %d", got)
	}
}
// TestSupplyAPI_CreatePackageDraft_MethodNotAllowed checks that GET on the
// draft endpoint is answered with 405.
func TestSupplyAPI_CreatePackageDraft_MethodNotAllowed(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleCreatePackageDraft(rec, httptest.NewRequest("GET", "/api/v1/supply/packages/draft", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_PublishPackage_Success checks that POST .../packages/1/publish
// is answered with 200.
func TestSupplyAPI_PublishPackage_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("POST", "/api/v1/supply/packages/1/publish", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status 200, got %d", got)
	}
}
// TestSupplyAPI_PublishPackage_NotFound checks that a "package not found"
// error from the service is mapped to 404.
func TestSupplyAPI_PublishPackage_NotFound(t *testing.T) {
	supply, _, packages, _, _, _ := newTestAPI()
	packages.publishErr = errors.New("package not found")
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("POST", "/api/v1/supply/packages/1/publish", nil))
	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", got)
	}
}
// TestSupplyAPI_PausePackage_Success checks that POST .../packages/1/pause
// is answered with 200.
func TestSupplyAPI_PausePackage_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("POST", "/api/v1/supply/packages/1/pause", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status 200, got %d", got)
	}
}
// TestSupplyAPI_PausePackage_Conflict checks that a SUP_PKG_4092 error from
// the service is mapped to 409.
func TestSupplyAPI_PausePackage_Conflict(t *testing.T) {
	supply, _, packages, _, _, _ := newTestAPI()
	packages.pauseErr = errors.New("SUP_PKG_4092: cannot pause active package")
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("POST", "/api/v1/supply/packages/1/pause", nil))
	if got := rec.Code; got != http.StatusConflict {
		t.Errorf("expected status 409, got %d", got)
	}
}
// TestSupplyAPI_PausePackage_WrongMethod checks that GET on the pause action
// is answered with 405.
func TestSupplyAPI_PausePackage_WrongMethod(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("GET", "/api/v1/supply/packages/1/pause", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_UnlistPackage_Success checks that POST .../packages/1/unlist
// is answered with 200.
func TestSupplyAPI_UnlistPackage_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("POST", "/api/v1/supply/packages/1/unlist", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status 200, got %d", got)
	}
}
// TestSupplyAPI_UnlistPackage_Conflict checks that a SUP_PKG_4093 error from
// the service is mapped to 409.
func TestSupplyAPI_UnlistPackage_Conflict(t *testing.T) {
	supply, _, packages, _, _, _ := newTestAPI()
	packages.unlistErr = errors.New("SUP_PKG_4093: cannot unlist package")
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("POST", "/api/v1/supply/packages/1/unlist", nil))
	if got := rec.Code; got != http.StatusConflict {
		t.Errorf("expected status 409, got %d", got)
	}
}
// TestSupplyAPI_UnlistPackage_WrongMethod checks that GET on the unlist action
// is answered with 405.
func TestSupplyAPI_UnlistPackage_WrongMethod(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("GET", "/api/v1/supply/packages/1/unlist", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_ClonePackage_WrongMethod checks that GET on the clone action
// is answered with 405.
func TestSupplyAPI_ClonePackage_WrongMethod(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("GET", "/api/v1/supply/packages/1/clone", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_ClonePackage_NotFound checks that a "package not found" error
// from the service is mapped to 404.
func TestSupplyAPI_ClonePackage_NotFound(t *testing.T) {
	supply, _, packages, _, _, _ := newTestAPI()
	packages.cloneErr = errors.New("package not found")
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("POST", "/api/v1/supply/packages/1/clone", nil))
	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", got)
	}
}
// TestSupplyAPI_PublishPackage_WrongMethod checks that GET on the publish
// action is answered with 405.
func TestSupplyAPI_PublishPackage_WrongMethod(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("GET", "/api/v1/supply/packages/1/publish", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_ClonePackage_Success checks that POST .../packages/1/clone is
// answered with 201.
func TestSupplyAPI_ClonePackage_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handlePackageActions(rec, httptest.NewRequest("POST", "/api/v1/supply/packages/1/clone", nil))
	if got := rec.Code; got != http.StatusCreated {
		t.Errorf("expected status 201, got %d", got)
	}
}
// TestSupplyAPI_BatchUpdatePrice_Success checks that a two-item batch price
// update is answered with 200 when the package service reports full success.
func TestSupplyAPI_BatchUpdatePrice_Success(t *testing.T) {
	api, _, packageSvc, _, _, _ := newTestAPI()
	// Configure the mock returned by newTestAPI, as every other test does,
	// rather than type-asserting api.packageService. This keeps the test
	// independent of the API's internal field name and wrapping.
	packageSvc.batchResp = &domain.BatchUpdatePriceResponse{
		Total:        2,
		SuccessCount: 2,
		FailedCount:  0,
	}
	body := `{"items":[{"package_id":1,"price_per_1m_input":0.15,"price_per_1m_output":0.25},{"package_id":2,"price_per_1m_input":0.12,"price_per_1m_output":0.22}]}`
	req := httptest.NewRequest("POST", "/api/v1/supply/packages/batch-price", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	api.handleBatchUpdatePrice(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d body=%s", w.Code, w.Body.String())
	}
}
// TestSupplyAPI_BatchUpdatePrice_MethodNotAllowed checks that GET on the
// batch-price endpoint is answered with 405.
func TestSupplyAPI_BatchUpdatePrice_MethodNotAllowed(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleBatchUpdatePrice(rec, httptest.NewRequest("GET", "/api/v1/supply/packages/batch-price", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_BatchUpdatePrice_InvalidJSON checks that a malformed JSON body
// is rejected with 400.
func TestSupplyAPI_BatchUpdatePrice_InvalidJSON(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	r := httptest.NewRequest("POST", "/api/v1/supply/packages/batch-price", strings.NewReader(`{invalid}`))
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleBatchUpdatePrice(rec, r)
	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", got)
	}
}
// TestSupplyAPI_BatchUpdatePrice_BatchFailed checks that a service-level batch
// failure is surfaced as 422.
func TestSupplyAPI_BatchUpdatePrice_BatchFailed(t *testing.T) {
	supply, _, packages, _, _, _ := newTestAPI()
	packages.batchErr = errors.New("batch update failed")
	payload := strings.NewReader(`{"items":[{"package_id":1,"price_per_1m_input":0.15,"price_per_1m_output":0.25}]}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/packages/batch-price", payload)
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleBatchUpdatePrice(rec, r)
	if got := rec.Code; got != http.StatusUnprocessableEntity {
		t.Errorf("expected status 422, got %d", got)
	}
}
// ==================== Billing Handler Tests ====================
// TestSupplyAPI_GetBilling_Success checks that a billing query with a date
// range is answered with 200.
func TestSupplyAPI_GetBilling_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleGetBilling(rec, httptest.NewRequest("GET", "/api/v1/supply/billing?start_date=2024-01-01&end_date=2024-01-31", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status 200, got %d", got)
	}
}
// TestSupplyAPI_GetBilling_MethodNotAllowed checks that POST on the billing
// endpoint is answered with 405.
func TestSupplyAPI_GetBilling_MethodNotAllowed(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleGetBilling(rec, httptest.NewRequest("POST", "/api/v1/supply/billing", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_GetBilling_QueryFailed checks that a billing-summary failure
// from the earnings service is surfaced as 500.
func TestSupplyAPI_GetBilling_QueryFailed(t *testing.T) {
	supply, _, _, _, earnings, _ := newTestAPI()
	earnings.billingErr = errors.New("query failed")
	rec := httptest.NewRecorder()
	supply.handleGetBilling(rec, httptest.NewRequest("GET", "/api/v1/supply/billing", nil))
	if got := rec.Code; got != http.StatusInternalServerError {
		t.Errorf("expected status 500, got %d", got)
	}
}
// ==================== Settlement Handler Tests ====================
// TestSupplyAPI_Withdraw_Success checks that a complete withdraw request is
// answered with 201.
func TestSupplyAPI_Withdraw_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	payload := strings.NewReader(`{"withdraw_amount":1000,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", payload)
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleWithdraw(rec, r)
	if got := rec.Code; got != http.StatusCreated {
		t.Errorf("expected status 201, got %d", got)
	}
}
// TestHandleWithdraw_RequiresIdempotencyMiddleware checks that withdrawals are
// refused with 503 when no idempotency middleware is wired in.
func TestHandleWithdraw_RequiresIdempotencyMiddleware(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPIWithoutIdempotencyForTest()
	payload := strings.NewReader(`{"withdraw_amount":1000,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", payload)
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleWithdraw(rec, r)
	if got := rec.Code; got != http.StatusServiceUnavailable {
		t.Fatalf("expected 503 when idempotency middleware is missing, got=%d body=%s", got, rec.Body.String())
	}
}
// TestSupplyAPI_Withdraw_MethodNotAllowed checks that GET on the withdraw
// endpoint is answered with 405.
func TestSupplyAPI_Withdraw_MethodNotAllowed(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleWithdraw(rec, httptest.NewRequest("GET", "/api/v1/supply/settlements/withdraw", nil))
	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_Withdraw_InvalidJSON checks that a malformed JSON body is
// rejected with 400.
func TestSupplyAPI_Withdraw_InvalidJSON(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	r := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(`{invalid}`))
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleWithdraw(rec, r)
	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", got)
	}
}
// TestSupplyAPI_Withdraw_WithdrawFailed checks that a SUP_SET_4001
// insufficient-balance error from the service is mapped to 409.
func TestSupplyAPI_Withdraw_WithdrawFailed(t *testing.T) {
	supply, _, _, settlements, _, _ := newTestAPI()
	settlements.withdrawErr = errors.New("SUP_SET_4001: insufficient balance")
	payload := strings.NewReader(`{"withdraw_amount":1000000,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}`)
	r := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", payload)
	r.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	supply.handleWithdraw(rec, r)
	if got := rec.Code; got != http.StatusConflict {
		t.Errorf("expected status 409, got %d", got)
	}
}
// TestSupplyAPI_CancelSettlement_Success checks that POST
// .../settlements/1/cancel is answered with 200.
func TestSupplyAPI_CancelSettlement_Success(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleSettlementActions(rec, httptest.NewRequest("POST", "/api/v1/supply/settlements/1/cancel", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status 200, got %d", got)
	}
}
// TestSupplyAPI_CancelSettlement_NotFound checks that a "settlement not found"
// error from the service is mapped to 404.
func TestSupplyAPI_CancelSettlement_NotFound(t *testing.T) {
	supply, _, _, settlements, _, _ := newTestAPI()
	settlements.cancelErr = errors.New("settlement not found")
	rec := httptest.NewRecorder()
	supply.handleSettlementActions(rec, httptest.NewRequest("POST", "/api/v1/supply/settlements/1/cancel", nil))
	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", got)
	}
}
// TestSupplyAPI_GetStatement_Success checks that GET .../settlements/1/statement
// returns 200 with a "data" object containing file_name, download_url and
// expires_at.
func TestSupplyAPI_GetStatement_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()
	req := httptest.NewRequest("GET", "/api/v1/supply/settlements/1/statement", nil)
	w := httptest.NewRecorder()
	api.handleSettlementActions(w, req)
	// Fail fast: on a non-200 the body is an error payload, and the field
	// checks below would only produce misleading follow-up failures.
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d body=%s", w.Code, w.Body.String())
	}
	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	data, ok := resp["data"].(map[string]any)
	if !ok {
		t.Fatal("expected data in response")
	}
	for _, key := range []string{"file_name", "download_url", "expires_at"} {
		if data[key] == nil {
			t.Errorf("expected %s in data", key)
		}
	}
}
// TestSupplyAPI_GetStatement_NotFound checks that a lookup failure for the
// settlement is mapped to 404 on the statement route.
func TestSupplyAPI_GetStatement_NotFound(t *testing.T) {
	supply, _, _, settlements, _, _ := newTestAPI()
	settlements.getErr = errors.New("settlement not found")
	rec := httptest.NewRecorder()
	supply.handleSettlementActions(rec, httptest.NewRequest("GET", "/api/v1/supply/settlements/1/statement", nil))
	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", got)
	}
}
// TestSupplyAPI_SettlementActions_InvalidID checks that a non-numeric
// settlement ID in the path is rejected with 400.
func TestSupplyAPI_SettlementActions_InvalidID(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleSettlementActions(rec, httptest.NewRequest("POST", "/api/v1/supply/settlements/invalid/cancel", nil))
	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", got)
	}
}
// TestSupplyAPI_SettlementActions_UnknownAction checks that an unrecognised
// settlement action is answered with 404.
func TestSupplyAPI_SettlementActions_UnknownAction(t *testing.T) {
	supply, _, _, _, _, _ := newTestAPI()
	rec := httptest.NewRecorder()
	supply.handleSettlementActions(rec, httptest.NewRequest("POST", "/api/v1/supply/settlements/1/unknown", nil))
	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", got)
	}
}
// ==================== Earning Handler Tests ====================
// TestSupplyAPI_GetEarningRecords_Success checks that a paged earnings query
// returns 200, the single seeded record under "data", and a pagination
// object whose total is 1.
func TestSupplyAPI_GetEarningRecords_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()
	req := httptest.NewRequest("GET", "/api/v1/supply/earnings/records?start_date=2024-01-01&end_date=2024-01-31&page=1&page_size=20", nil)
	w := httptest.NewRecorder()
	api.handleGetEarningRecords(w, req)
	// Fail fast: the body-shape checks below are meaningless on a non-200.
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d body=%s", w.Code, w.Body.String())
	}
	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	data, ok := resp["data"].([]any)
	if !ok {
		t.Fatal("expected data array in response")
	}
	if len(data) != 1 {
		t.Errorf("expected 1 record, got %d", len(data))
	}
	pagination, ok := resp["pagination"].(map[string]any)
	if !ok {
		t.Fatal("expected pagination in response")
	}
	// encoding/json decodes JSON numbers into float64 when the target is any.
	if pagination["total"] != float64(1) {
		t.Errorf("expected total 1, got %v", pagination["total"])
	}
}
// TestSupplyAPI_GetEarningRecords_MethodNotAllowed checks that POSTing to the
// read-only earnings records endpoint is answered with 405.
func TestSupplyAPI_GetEarningRecords_MethodNotAllowed(t *testing.T) {
	sapi, _, _, _, _, _ := newTestAPI()

	rec := httptest.NewRecorder()
	sapi.handleGetEarningRecords(rec, httptest.NewRequest(http.MethodPost, "/api/v1/supply/earnings/records", nil))

	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// TestSupplyAPI_GetEarningRecords_QueryFailed checks that an error from the
// earning service's list call is mapped to 500 Internal Server Error.
func TestSupplyAPI_GetEarningRecords_QueryFailed(t *testing.T) {
	sapi, _, _, _, earnings, _ := newTestAPI()
	earnings.listErr = errors.New("query failed")

	rec := httptest.NewRecorder()
	sapi.handleGetEarningRecords(rec, httptest.NewRequest(http.MethodGet, "/api/v1/supply/earnings/records", nil))

	if got := rec.Code; got != http.StatusInternalServerError {
		t.Errorf("expected status 500, got %d", got)
	}
}
// ==================== Audit Event Handler Tests ====================
// TestSupplyAPI_GetAuditEvent_Success checks that fetching an audit event by
// ID returns 200 and echoes that ID as data.event_id.
func TestSupplyAPI_GetAuditEvent_Success(t *testing.T) {
	sapi, _, _, _, _, _ := newTestAPI()

	rec := httptest.NewRecorder()
	sapi.handleAuditEvent(rec, httptest.NewRequest(http.MethodGet, "/api/v1/audit/events/evt_123", nil))
	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}

	var body map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}

	payload, isMap := body["data"].(map[string]any)
	if !isMap {
		t.Fatal("expected data in response")
	}
	if id := payload["event_id"]; id != "evt_123" {
		t.Errorf("expected event_id evt_123, got %v", id)
	}
}
// TestSupplyAPI_GetAuditEvent_NotFound checks that an audit-service error
// for the requested event ID is reported as 404 Not Found.
func TestSupplyAPI_GetAuditEvent_NotFound(t *testing.T) {
	sapi, _, _, _, _, audits := newTestAPI()
	audits.err = errors.New("not found")

	rec := httptest.NewRecorder()
	sapi.handleAuditEvent(rec, httptest.NewRequest(http.MethodGet, "/api/v1/audit/events/evt_999", nil))

	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", got)
	}
}
// TestSupplyAPI_GetAuditEvent_MissingID checks that a request whose path ends
// right after "/events/" (no event ID) is rejected with 400 Bad Request.
func TestSupplyAPI_GetAuditEvent_MissingID(t *testing.T) {
	sapi, _, _, _, _, _ := newTestAPI()

	rec := httptest.NewRecorder()
	sapi.handleAuditEvent(rec, httptest.NewRequest(http.MethodGet, "/api/v1/audit/events/", nil))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", got)
	}
}
// TestSupplyAPI_GetAuditEvent_MethodNotAllowed checks that POSTing to an
// audit event resource is answered with 405.
func TestSupplyAPI_GetAuditEvent_MethodNotAllowed(t *testing.T) {
	sapi, _, _, _, _, _ := newTestAPI()

	rec := httptest.NewRecorder()
	sapi.handleAuditEvent(rec, httptest.NewRequest(http.MethodPost, "/api/v1/audit/events/evt_123", nil))

	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", got)
	}
}
// ==================== Helper Function Tests ====================
// TestGetRequestID checks that getRequestID returns the request-ID header
// value when present and the empty string when it is absent.
func TestGetRequestID(t *testing.T) {
	// newReq builds a bare GET request, setting header key=val unless key is empty.
	newReq := func(key, val string) *http.Request {
		r := httptest.NewRequest("GET", "/", nil)
		if key != "" {
			r.Header.Set(key, val)
		}
		return r
	}

	// NOTE: Header.Set canonicalizes its key, so "X-Request-ID" is stored as
	// "X-Request-Id"; the first two cases therefore populate the same header.
	if got := getRequestID(newReq("X-Request-Id", "req-123")); got != "req-123" {
		t.Errorf("expected req-123, got %s", got)
	}
	if got := getRequestID(newReq("X-Request-ID", "req-456")); got != "req-456" {
		t.Errorf("expected req-456, got %s", got)
	}
	if got := getRequestID(newReq("", "")); got != "" {
		t.Errorf("expected empty string, got %s", got)
	}
}
// TestGetQueryInt checks that getQueryInt parses integer query parameters,
// falls back to the default when the key is absent, and also falls back
// when the value is present but not a valid integer.
func TestGetQueryInt(t *testing.T) {
	// "invalid" carries a non-numeric value so the parse-failure fallback is
	// actually exercised; "missing" is deliberately absent from the query.
	req := httptest.NewRequest("GET", "/?page=5&page_size=100&invalid=abc", nil)

	if got := getQueryInt(req, "page", 1); got != 5 {
		t.Errorf("expected page 5, got %v", got)
	}
	if got := getQueryInt(req, "page_size", 20); got != 100 {
		t.Errorf("expected page_size 100, got %v", got)
	}
	if got := getQueryInt(req, "missing", 10); got != 10 {
		t.Errorf("expected default 10 for missing param, got %v", got)
	}
	if got := getQueryInt(req, "invalid", 1); got != 1 {
		t.Errorf("expected default 1 for invalid value, got %v", got)
	}
}
// TestWriteJSON checks that writeJSON writes the given status code and sets
// Content-Type to exactly "application/json".
func TestWriteJSON(t *testing.T) {
	rec := httptest.NewRecorder()
	writeJSON(rec, http.StatusOK, map[string]any{"key": "value"})

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status 200, got %d", got)
	}
	if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
		t.Error("expected Content-Type application/json")
	}
}
// TestWriteError checks that writeError writes the given status code and a
// JSON body of the form {"error": {"code": ..., "message": ...}}.
func TestWriteError(t *testing.T) {
	w := httptest.NewRecorder()
	writeError(w, http.StatusBadRequest, "TEST_ERROR", "test message")
	if w.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", w.Code)
	}

	// Fail fast on a malformed body instead of reporting a misleading
	// "missing error object" failure further down.
	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}

	errObj, ok := resp["error"].(map[string]any)
	if !ok {
		t.Fatal("expected error object in response")
	}
	if errObj["code"] != "TEST_ERROR" {
		t.Errorf("expected code TEST_ERROR, got %v", errObj["code"])
	}
	if errObj["message"] != "test message" {
		t.Errorf("expected message 'test message', got %v", errObj["message"])
	}
}
// ==================== Integration Tests ====================
// TestSupplyAPI_Register checks that registering all routes on a fresh
// ServeMux completes without panicking (ServeMux panics on conflicting
// registrations, among other misuse).
func TestSupplyAPI_Register(t *testing.T) {
	sapi, _, _, _, _, _ := newTestAPI()
	// Verify route registration does not panic; reaching the end is the assertion.
	sapi.Register(http.NewServeMux())
}
// TestSupplyAPI_EndToEnd_Withdraw drives a full withdraw request against a
// stubbed pending settlement. It checks for 201 Created, that the
// X-Request-Id header is echoed as request_id, and that the created
// settlement's ID and "pending" status are returned under data.
func TestSupplyAPI_EndToEnd_Withdraw(t *testing.T) {
	sapi, _, _, settlements, _, _ := newTestAPI()
	now := time.Now()
	settlements.settlement = &domain.Settlement{
		ID:           1,
		SupplierID:   100,
		SettlementNo: "SET_20240101_001",
		Status:       domain.SettlementStatusPending,
		TotalAmount:  1000,
		NetAmount:    950,
		CreatedAt:    now,
		UpdatedAt:    now,
	}

	payload := strings.NewReader(`{"withdraw_amount":500,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}`)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/supply/settlements/withdraw", payload)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-Request-Id", "test-req-001")

	rec := httptest.NewRecorder()
	sapi.handleWithdraw(rec, req)
	if rec.Code != http.StatusCreated {
		t.Errorf("expected status 201, got %d. Body: %s", rec.Code, rec.Body.String())
	}

	var envelope map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &envelope); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if rid := envelope["request_id"]; rid != "test-req-001" {
		t.Errorf("expected request_id test-req-001, got %v", rid)
	}

	data, isMap := envelope["data"].(map[string]any)
	if !isMap {
		t.Fatal("expected data in response")
	}
	// JSON numbers decode into float64 when the target is map[string]any.
	if id := data["settlement_id"]; id != float64(1) {
		t.Errorf("expected settlement_id 1, got %v", id)
	}
	if status := data["status"]; status != "pending" {
		t.Errorf("expected status pending, got %v", status)
	}
}
// TestSupplyAPI_WithdrawDisabled_ReturnsServiceUnavailable checks that a
// well-formed withdraw request is refused with 503 when withdrawEnabled is
// false on the API.
func TestSupplyAPI_WithdrawDisabled_ReturnsServiceUnavailable(t *testing.T) {
	sapi, _, _, _, _, _ := newTestAPI()
	sapi.withdrawEnabled = false

	payload := strings.NewReader(`{"withdraw_amount":500,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}`)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/supply/settlements/withdraw", payload)
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	sapi.handleWithdraw(rec, req)

	if got := rec.Code; got != http.StatusServiceUnavailable {
		t.Fatalf("expected status 503, got %d body=%s", got, rec.Body.String())
	}
}
// TestSupplyAPI_EndToEnd_BillingSummary stubs a billing summary for January
// 2024. It checks that the billing endpoint returns 200 and exposes the
// stubbed total revenue at data.summary.total_revenue.
func TestSupplyAPI_EndToEnd_BillingSummary(t *testing.T) {
	sapi, _, _, _, earnings, _ := newTestAPI()
	earnings.billingSummary = &domain.BillingSummary{
		Period: domain.BillingPeriod{
			Start: "2024-01-01",
			End:   "2024-01-31",
		},
		Summary: domain.BillingTotal{
			TotalRevenue:   10000,
			TotalOrders:    100,
			TotalUsage:     1000000,
			TotalRequests:  5000000,
			AvgSuccessRate: 99.5,
		},
	}

	rec := httptest.NewRecorder()
	sapi.handleGetBilling(rec, httptest.NewRequest(http.MethodGet, "/api/v1/supply/billing?start_date=2024-01-01&end_date=2024-01-31", nil))
	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}

	var envelope map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &envelope); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}

	data, isMap := envelope["data"].(map[string]any)
	if !isMap {
		t.Fatal("expected data in response")
	}
	totals, isMap := data["summary"].(map[string]any)
	if !isMap {
		t.Fatal("expected summary in data")
	}
	// JSON numbers decode into float64 when the target is map[string]any.
	if revenue := totals["total_revenue"]; revenue != float64(10000) {
		t.Errorf("expected total_revenue 10000, got %v", revenue)
	}
}