diff --git a/gateway/cmd/gateway/main.go b/gateway/cmd/gateway/main.go index d6bb80c1..adb383b8 100644 --- a/gateway/cmd/gateway/main.go +++ b/gateway/cmd/gateway/main.go @@ -77,18 +77,36 @@ func main() { auditor = middleware.NewMemoryAuditEmitter() } - // 初始化 token 运行时(内存实现) - tokenRuntime := middleware.NewInMemoryTokenRuntime(time.Now) + // 初始化 token 运行时 + var tokenRuntime interface { + middleware.TokenVerifier + middleware.TokenStatusResolver + } + switch cfg.Auth.TokenRuntimeMode { + case "inmemory": + tokenRuntime = middleware.NewInMemoryTokenRuntime(time.Now) + case "remote_introspection": + tokenRuntime = middleware.NewRemoteTokenRuntime(cfg.Auth.TokenRuntimeURL, http.DefaultClient, time.Now) + default: + log.Fatalf("unsupported token runtime mode: %s", cfg.Auth.TokenRuntimeMode) + } // 构建认证中间件配置 authMiddlewareConfig := middleware.AuthMiddlewareConfig{ - Verifier: tokenRuntime, - StatusResolver: tokenRuntime, - Authorizer: middleware.NewScopeRoleAuthorizer(), - Auditor: auditor, - ProtectedPrefixes: []string{"/api/v1/supply", "/api/v1/platform"}, - ExcludedPrefixes: []string{"/health", "/healthz", "/metrics", "/readyz"}, - Now: time.Now, + Verifier: tokenRuntime, + StatusResolver: tokenRuntime, + Authorizer: middleware.NewScopeRoleAuthorizer(), + Auditor: auditor, + ProtectedPrefixes: []string{ + "/v1/chat/completions", + "/v1/completions", + "/api/v1/chat/completions", + "/api/v1/completions", + "/api/v1/supply", + "/api/v1/platform", + }, + ExcludedPrefixes: []string{"/health", "/healthz", "/metrics", "/readyz"}, + Now: time.Now, } // 初始化Handler @@ -132,33 +150,38 @@ func main() { func createMux(h *handler.Handler, limiter *ratelimit.Middleware, authConfig middleware.AuthMiddlewareConfig) http.Handler { mux := http.NewServeMux() - // 创建认证处理链 - authHandler := middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h.ChatCompletionsHandle(w, r) - })) + chatHandler := middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.ChatCompletionsHandle)) + completionsHandler := middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.CompletionsHandle)) - // Chat Completions - 应用限流和认证 mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - limiter.Limit(authHandler.ServeHTTP)(w, r) + limitHandler(limiter, chatHandler).ServeHTTP(w, r) }) - // Completions - 应用限流和认证 mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) { - limiter.Limit(authHandler.ServeHTTP)(w, r) + limitHandler(limiter, completionsHandler).ServeHTTP(w, r) }) - // Models - 公开接口 mux.HandleFunc("/v1/models", h.ModelsHandle) - // 旧版路径兼容 mux.HandleFunc("/api/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - h.ChatCompletionsHandle(w, r) + limitHandler(limiter, chatHandler).ServeHTTP(w, r) + }) + mux.HandleFunc("/api/v1/completions", func(w http.ResponseWriter, r *http.Request) { + limitHandler(limiter, completionsHandler).ServeHTTP(w, r) }) - // Health - 排除认证 mux.HandleFunc("/health", h.HealthHandle) mux.HandleFunc("/healthz", h.HealthHandle) mux.HandleFunc("/readyz", h.HealthHandle) - return mux + return middleware.CORSMiddleware(middleware.DefaultCORSConfig())(mux) +} + +func limitHandler(limiter *ratelimit.Middleware, next http.Handler) http.Handler { + if limiter == nil { + return next + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + limiter.Limit(next.ServeHTTP)(w, r) + }) } diff --git a/gateway/cmd/gateway/main_test.go b/gateway/cmd/gateway/main_test.go new file mode 100644 index 00000000..3f1a16ba --- /dev/null +++ b/gateway/cmd/gateway/main_test.go @@ -0,0 +1,173 @@ +package main + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "lijiaoqiao/gateway/internal/adapter" + "lijiaoqiao/gateway/internal/handler" + "lijiaoqiao/gateway/internal/middleware" + "lijiaoqiao/gateway/internal/ratelimit" + "lijiaoqiao/gateway/internal/router" +) + +type testProvider struct { + name string + models []string +} + +func (p *testProvider) ChatCompletion(_ context.Context, model string, messages []adapter.Message, _ adapter.CompletionOptions) (*adapter.CompletionResponse, error) { + content := "" + if len(messages) > 0 { + content = messages[0].Content + } + return &adapter.CompletionResponse{ + ID: "resp-1", + Object: "chat.completion", + Created: time.Now().Unix(), + Model: model, + Choices: []adapter.Choice{{Index: 0, Message: &adapter.Message{Role: "assistant", Content: content}, FinishReason: "stop"}}, + Usage: adapter.Usage{PromptTokens: 1, CompletionTokens: 1, TotalTokens: 2}, + }, nil +} + +func (p *testProvider) ChatCompletionStream(_ context.Context, _ string, _ []adapter.Message, _ adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) { + ch := make(chan *adapter.StreamChunk) + close(ch) + return ch, nil +} + +func (p *testProvider) GetUsage(response *adapter.CompletionResponse) adapter.Usage { + return response.Usage +} + +func (p *testProvider) MapError(err error) adapter.ProviderError { + return adapter.ProviderError{Code: "provider_error", Message: err.Error(), HTTPStatus: http.StatusBadGateway} +} + +func (p *testProvider) HealthCheck(context.Context) bool { + return true +} + +func (p *testProvider) ProviderName() string { + return p.name +} + +func (p *testProvider) SupportedModels() []string { + return p.models +} + +func TestCreateMux_ProtectsCompletionRoutes(t *testing.T) { + now := time.Now() + tokenRuntime := middleware.NewInMemoryTokenRuntime(func() time.Time { return now }) + r := router.NewRouter(router.StrategyLatency) + h := handler.NewHandler(r) + limiter := ratelimit.NewMiddleware(ratelimit.NewTokenBucketLimiter(60, 60000, 1.5)) + + authConfig := middleware.AuthMiddlewareConfig{ + Verifier: tokenRuntime, + StatusResolver: tokenRuntime, + Authorizer: middleware.NewScopeRoleAuthorizer(), + Auditor: middleware.NewMemoryAuditEmitter(), + ProtectedPrefixes: []string{ + "/v1/chat/completions", + "/v1/completions", + "/api/v1/chat/completions", + "/api/v1/completions", + }, + ExcludedPrefixes: []string{"/health", "/healthz", "/readyz"}, + Now: func() time.Time { return now }, + } + + mux := createMux(h, limiter, authConfig) + + for _, path := range []string{ + "/v1/chat/completions", + "/v1/completions", + "/api/v1/chat/completions", + "/api/v1/completions", + } { + req := httptest.NewRequest(http.MethodPost, path, bytes.NewBufferString(`{}`)) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected %s to return 401, got %d", path, rr.Code) + } + } +} + +func TestCreateMux_HealthRoutesRemainOpen(t *testing.T) { + now := time.Now() + tokenRuntime := middleware.NewInMemoryTokenRuntime(func() time.Time { return now }) + r := router.NewRouter(router.StrategyLatency) + h := handler.NewHandler(r) + limiter := ratelimit.NewMiddleware(ratelimit.NewTokenBucketLimiter(60, 60000, 1.5)) + + authConfig := middleware.AuthMiddlewareConfig{ + Verifier: tokenRuntime, + StatusResolver: tokenRuntime, + Authorizer: middleware.NewScopeRoleAuthorizer(), + ProtectedPrefixes: []string{"/v1/chat/completions"}, + ExcludedPrefixes: []string{"/health", "/healthz", "/readyz"}, + Now: func() time.Time { return now }, + } + + mux := createMux(h, limiter, authConfig) + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected /health to return 200, got %d", rr.Code) + } +} + +func TestCreateMux_CompletionsRouteUsesCompletionsHandler(t *testing.T) { + now := time.Now() + tokenRuntime := middleware.NewInMemoryTokenRuntime(func() time.Time { return now }) + token, err := tokenRuntime.Issue(context.Background(), "user1", "user", []string{"gateway:invoke"}, time.Hour) + if err != nil { + t.Fatalf("failed to issue token: %v", err) + } + + r := router.NewRouter(router.StrategyLatency) + r.RegisterProvider("test", &testProvider{name: "test", models: []string{"gpt-4"}}) + h := handler.NewHandler(r) + limiter := ratelimit.NewMiddleware(ratelimit.NewTokenBucketLimiter(60, 60000, 1.5)) + + authConfig := middleware.AuthMiddlewareConfig{ + Verifier: tokenRuntime, + StatusResolver: tokenRuntime, + Authorizer: middleware.NewScopeRoleAuthorizer(), + Auditor: middleware.NewMemoryAuditEmitter(), + ProtectedPrefixes: []string{ + "/v1/completions", + }, + ExcludedPrefixes: []string{"/health", "/healthz", "/readyz"}, + Now: func() time.Time { return now }, + } + + mux := createMux(h, limiter, authConfig) + reqBody := `{"model":"gpt-4","prompt":"hello","max_tokens":16}` + req := httptest.NewRequest(http.MethodPost, "/v1/completions", bytes.NewBufferString(reqBody)) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + + mux.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + body, _ := io.ReadAll(rr.Result().Body) + t.Fatalf("expected 200, got %d: %s", rr.Code, strings.TrimSpace(string(body))) + } + + body, _ := io.ReadAll(rr.Result().Body) + if !strings.Contains(string(body), `"object":"text_completion"`) { + t.Fatalf("expected completions response, got %s", string(body)) + } +} diff --git a/gateway/internal/config/config.go b/gateway/internal/config/config.go index 205ce3f0..e5a2306e 100644 --- a/gateway/internal/config/config.go +++ b/gateway/internal/config/config.go @@ -6,7 +6,9 @@ import ( "crypto/rand" "encoding/base64" "errors" + "fmt" "os" + "strings" "time" ) @@ -17,33 +19,41 @@ var encryptionKey = []byte(getEnv("PASSWORD_ENCRYPTION_KEY", "default-key-32-byt // Config 网关配置 type Config struct { - Server ServerConfig - Database DatabaseConfig - Redis RedisConfig - Router RouterConfig + Server ServerConfig + Database DatabaseConfig + Redis RedisConfig + Auth AuthConfig + Router RouterConfig RateLimit RateLimitConfig - Alert AlertConfig + Alert AlertConfig Providers []ProviderConfig } // ServerConfig 服务配置 type ServerConfig struct { - Host string - Port int - ReadTimeout time.Duration + Host string + Port int + ReadTimeout time.Duration WriteTimeout time.Duration - IdleTimeout time.Duration + IdleTimeout time.Duration +} + +// AuthConfig 鉴权运行时配置 +type AuthConfig struct { + Env string + TokenRuntimeMode string + TokenRuntimeURL string } // DatabaseConfig 数据库配置 type DatabaseConfig struct { - Host string - Port int - User string - Password string // 兼容旧版本,仍可直接使用明文密码(不推荐) - EncryptedPassword string // 加密后的密码,优先级高于Password字段 - Database string - MaxConns int + Host string + Port int + User string + Password string // 兼容旧版本,仍可直接使用明文密码(不推荐) + EncryptedPassword string // 加密后的密码,优先级高于Password字段 + Database string + MaxConns int } // GetPassword 返回解密后的数据库密码 @@ -62,12 +72,12 @@ func (c *DatabaseConfig) GetPassword() string { // RedisConfig Redis配置 type RedisConfig struct { - Host string - Port int - Password string // 兼容旧版本 - EncryptedPassword string // 加密后的密码 - DB int - PoolSize int + Host string + Port int + Password string // 兼容旧版本 + EncryptedPassword string // 加密后的密码 + DB int + PoolSize int } // GetPassword 返回解密后的Redis密码 @@ -84,28 +94,28 @@ func (c *RedisConfig) GetPassword() string { // RouterConfig 路由配置 type RouterConfig struct { - Strategy string // "latency", "cost", "availability", "weighted" - Timeout time.Duration - MaxRetries int - RetryDelay time.Duration + Strategy string // "latency", "cost", "availability", "weighted" + Timeout time.Duration + MaxRetries int + RetryDelay time.Duration HealthCheckInterval time.Duration } // RateLimitConfig 限流配置 type RateLimitConfig struct { - Enabled bool - Algorithm string // "token_bucket", "sliding_window", "fixed_window" - DefaultRPM int // 请求数/分钟 - DefaultTPM int // Token数/分钟 + Enabled bool + Algorithm string // "token_bucket", "sliding_window", "fixed_window" + DefaultRPM int // 请求数/分钟 + DefaultTPM int // Token数/分钟 BurstMultiplier float64 } // AlertConfig 告警配置 type AlertConfig struct { - Enabled bool - Email EmailConfig - DingTalk DingTalkConfig - Feishu FeishuConfig + Enabled bool + Email EmailConfig + DingTalk DingTalkConfig + Feishu FeishuConfig } // EmailConfig 邮件配置 @@ -135,11 +145,11 @@ type FeishuConfig struct { // ProviderConfig Provider配置 type ProviderConfig struct { - Name string - Type string // "openai", "anthropic", "google", "custom" - BaseURL string - APIKey string - Models []string + Name string + Type string // "openai", "anthropic", "google", "custom" + BaseURL string + APIKey string + Models []string Priority int Weight float64 } @@ -155,26 +165,31 @@ func LoadConfig(path string) (*Config, error) { WriteTimeout: 30 * time.Second, IdleTimeout: 120 * time.Second, }, + Auth: AuthConfig{ + Env: strings.ToLower(getEnv("GATEWAY_ENV", "dev")), + TokenRuntimeMode: strings.ToLower(getEnv("GATEWAY_TOKEN_RUNTIME_MODE", "inmemory")), + TokenRuntimeURL: strings.TrimSpace(getEnv("GATEWAY_TOKEN_RUNTIME_URL", "")), + }, Router: RouterConfig{ - Strategy: "latency", - Timeout: 30 * time.Second, - MaxRetries: 3, - RetryDelay: 1 * time.Second, + Strategy: "latency", + Timeout: 30 * time.Second, + MaxRetries: 3, + RetryDelay: 1 * time.Second, HealthCheckInterval: 10 * time.Second, }, RateLimit: RateLimitConfig{ - Enabled: true, - Algorithm: "token_bucket", - DefaultRPM: 60, - DefaultTPM: 60000, + Enabled: true, + Algorithm: "token_bucket", + DefaultRPM: 60, + DefaultTPM: 60000, BurstMultiplier: 1.5, }, Alert: AlertConfig{ Enabled: true, Email: EmailConfig{ Enabled: false, - Host: getEnv("SMTP_HOST", "smtp.example.com"), - Port: 587, + Host: getEnv("SMTP_HOST", "smtp.example.com"), + Port: 587, }, DingTalk: DingTalkConfig{ Enabled: getEnv("DINGTALK_ENABLED", "false") == "true", @@ -189,9 +204,33 @@ func LoadConfig(path string) (*Config, error) { }, } + if err := validateAuthConfig(cfg.Auth); err != nil { + return nil, err + } + return cfg, nil } +func validateAuthConfig(cfg AuthConfig) error { + mode := strings.ToLower(strings.TrimSpace(cfg.TokenRuntimeMode)) + env := strings.ToLower(strings.TrimSpace(cfg.Env)) + + switch mode { + case "inmemory", "remote_introspection": + default: + return fmt.Errorf("unsupported token runtime mode %q", cfg.TokenRuntimeMode) + } + + if (env == "prod" || env == "staging") && mode == "inmemory" { + return fmt.Errorf("inmemory token runtime is not allowed in %s, use remote_introspection", env) + } + if mode == "remote_introspection" && strings.TrimSpace(cfg.TokenRuntimeURL) == "" { + return errors.New("GATEWAY_TOKEN_RUNTIME_URL is required when token runtime mode is remote_introspection") + } + + return nil +} + func getEnv(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value diff --git a/gateway/internal/config/config_test.go b/gateway/internal/config/config_test.go index 06eec45b..83d1f5de 100644 --- a/gateway/internal/config/config_test.go +++ b/gateway/internal/config/config_test.go @@ -2,6 +2,7 @@ package config import ( "os" + "strings" "testing" "time" ) @@ -110,7 +111,7 @@ func TestAlertConfig_Struct(t *testing.T) { cfg := AlertConfig{ Enabled: true, Email: EmailConfig{ - Enabled: false, + Enabled: false, Host: "smtp.example.com", Port: 587, From: "alert@example.com", @@ -344,10 +345,10 @@ func TestConfig_AllFields(t *testing.T) { PoolSize: 10, }, Router: RouterConfig{ - Strategy: "latency", - Timeout: 30 * time.Second, - MaxRetries: 3, - RetryDelay: 1 * time.Second, + Strategy: "latency", + Timeout: 30 * time.Second, + MaxRetries: 3, + RetryDelay: 1 * time.Second, HealthCheckInterval: 10 * time.Second, }, RateLimit: RateLimitConfig{ @@ -360,7 +361,7 @@ func TestConfig_AllFields(t *testing.T) { Alert: AlertConfig{ Enabled: true, Email: EmailConfig{ - Enabled: false, + Enabled: false, Host: "smtp.example.com", Port: 587, }, @@ -405,3 +406,30 @@ func TestLoadConfig_EnvOverrides(t *testing.T) { t.Errorf("expected custom.smtp.com, got %s", cfg.Alert.Email.Host) } } + +func TestLoadConfig_ProdRejectsInMemoryTokenRuntime(t *testing.T) { + t.Setenv("GATEWAY_ENV", "prod") + t.Setenv("GATEWAY_TOKEN_RUNTIME_MODE", "inmemory") + + _, err := LoadConfig("") + if err == nil { + t.Fatal("expected prod config with in-memory token runtime to return error") + } + if !strings.Contains(err.Error(), "inmemory") { + t.Fatalf("expected error to mention inmemory runtime, got %v", err) + } +} + +func TestLoadConfig_ProdRequiresTokenRuntimeURL(t *testing.T) { + t.Setenv("GATEWAY_ENV", "prod") + t.Setenv("GATEWAY_TOKEN_RUNTIME_MODE", "remote_introspection") + t.Setenv("GATEWAY_TOKEN_RUNTIME_URL", "") + + _, err := LoadConfig("") + if err == nil { + t.Fatal("expected prod config without token runtime URL to return error") + } + if !strings.Contains(err.Error(), "TOKEN_RUNTIME_URL") { + t.Fatalf("expected error to mention TOKEN_RUNTIME_URL, got %v", err) + } +} diff --git a/gateway/internal/middleware/remote_runtime.go b/gateway/internal/middleware/remote_runtime.go new file mode 100644 index 00000000..6da5272e --- /dev/null +++ b/gateway/internal/middleware/remote_runtime.go @@ -0,0 +1,117 @@ +package middleware + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" +) + +type RemoteTokenRuntime struct { + baseURL string + httpClient *http.Client + now func() time.Time + + mu sync.RWMutex + records map[string]remoteResolvedToken +} + +type remoteResolvedToken struct { + status TokenStatus + expiresAt time.Time +} + +type remoteIntrospectResponse struct { + Data struct { + TokenID string `json:"token_id"` + SubjectID string `json:"subject_id"` + Role string `json:"role"` + Status string `json:"status"` + Scope []string `json:"scope"` + IssuedAt time.Time `json:"issued_at"` + ExpiresAt time.Time `json:"expires_at"` + } `json:"data"` +} + +func NewRemoteTokenRuntime(baseURL string, httpClient *http.Client, now func() time.Time) *RemoteTokenRuntime { + if httpClient == nil { + httpClient = http.DefaultClient + } + if now == nil { + now = time.Now + } + return &RemoteTokenRuntime{ + baseURL: strings.TrimRight(baseURL, "/"), + httpClient: httpClient, + now: now, + records: make(map[string]remoteResolvedToken), + } +} + +func (r *RemoteTokenRuntime) Verify(ctx context.Context, rawToken string) (VerifiedToken, error) { + payload := map[string]string{"token": rawToken} + body, err := json.Marshal(payload) + if err != nil { + return VerifiedToken{}, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, r.baseURL+"/api/v1/platform/tokens/introspect", bytes.NewReader(body)) + if err != nil { + return VerifiedToken{}, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Request-Id", fmt.Sprintf("gateway-introspect-%d", r.now().UnixNano())) + + resp, err := r.httpClient.Do(req) + if err != nil { + return VerifiedToken{}, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return VerifiedToken{}, fmt.Errorf("token introspection failed with status %d", resp.StatusCode) + } + + var result remoteIntrospectResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return VerifiedToken{}, err + } + if strings.TrimSpace(result.Data.TokenID) == "" { + return VerifiedToken{}, errors.New("token introspection response missing token_id") + } + + status := TokenStatus(strings.ToLower(strings.TrimSpace(result.Data.Status))) + r.mu.Lock() + r.records[result.Data.TokenID] = remoteResolvedToken{ + status: status, + expiresAt: result.Data.ExpiresAt, + } + r.mu.Unlock() + + return VerifiedToken{ + TokenID: result.Data.TokenID, + SubjectID: result.Data.SubjectID, + Role: result.Data.Role, + Scope: append([]string(nil), result.Data.Scope...), + IssuedAt: result.Data.IssuedAt, + ExpiresAt: result.Data.ExpiresAt, + }, nil +} + +func (r *RemoteTokenRuntime) Resolve(ctx context.Context, tokenID string) (TokenStatus, error) { + r.mu.RLock() + record, ok := r.records[tokenID] + r.mu.RUnlock() + if !ok { + return "", errors.New("token status not cached; verify must run before resolve") + } + if !record.expiresAt.IsZero() && !r.now().Before(record.expiresAt) && record.status == TokenStatusActive { + return TokenStatusExpired, nil + } + return record.status, nil +} diff --git a/gateway/internal/middleware/remote_runtime_test.go b/gateway/internal/middleware/remote_runtime_test.go new file mode 100644 index 00000000..2e94c1d6 --- /dev/null +++ b/gateway/internal/middleware/remote_runtime_test.go @@ -0,0 +1,67 @@ +package middleware + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + "time" +) + +func TestRemoteTokenRuntime_VerifyAndResolve(t *testing.T) { + httpClient := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.Path != "/api/v1/platform/tokens/introspect" { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader("not found")), + Header: make(http.Header), + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{ + "request_id":"req-1", + "data":{ + "token_id":"tok-1", + "subject_id":"user-1", + "role":"org_admin", + "status":"active", + "scope":["gateway:invoke"], + "issued_at":"2026-04-10T10:00:00Z", + "expires_at":"2026-04-10T11:00:00Z" + } + }`)), + Header: make(http.Header), + }, nil + }), + } + + runtime := NewRemoteTokenRuntime("http://token-runtime.internal", httpClient, func() time.Time { + return time.Date(2026, 4, 10, 10, 30, 0, 0, time.UTC) + }) + + claims, err := runtime.Verify(context.Background(), "raw-token") + if err != nil { + t.Fatalf("Verify returned error: %v", err) + } + if claims.TokenID != "tok-1" { + t.Fatalf("expected token id tok-1, got %s", claims.TokenID) + } + + status, err := runtime.Resolve(context.Background(), claims.TokenID) + if err != nil { + t.Fatalf("Resolve returned error: %v", err) + } + if status != TokenStatusActive { + t.Fatalf("expected active status, got %s", status) + } +} + +type roundTripFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/platform-token-runtime/cmd/platform-token-runtime/main.go b/platform-token-runtime/cmd/platform-token-runtime/main.go index 9ab7d72f..93c591c1 100644 --- a/platform-token-runtime/cmd/platform-token-runtime/main.go +++ b/platform-token-runtime/cmd/platform-token-runtime/main.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "os/signal" + "strings" "syscall" "time" @@ -15,6 +16,10 @@ import ( func main() { addr := envOrDefault("TOKEN_RUNTIME_ADDR", ":18081") + env := strings.ToLower(envOrDefault("TOKEN_RUNTIME_ENV", "dev")) + if env == "prod" || env == "staging" { + log.Fatalf("in-memory token runtime is not allowed in %s", env) + } runtime := service.NewInMemoryTokenRuntime(nil) auditor := service.NewMemoryAuditEmitter() diff --git a/platform-token-runtime/cmd/platform-token-runtime/main_test.go b/platform-token-runtime/cmd/platform-token-runtime/main_test.go new file mode 100644 index 00000000..ee208e3f --- /dev/null +++ b/platform-token-runtime/cmd/platform-token-runtime/main_test.go @@ -0,0 +1,41 @@ +package main + +import ( + "context" + "os" + "os/exec" + "strings" + "testing" + "time" +) + +func TestMain_ProdRejectsInMemoryRuntime(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestMainHelperProcess") + cmd.Env = append(os.Environ(), + "GO_WANT_HELPER_PROCESS=1", + "TOKEN_RUNTIME_ENV=prod", + "TOKEN_RUNTIME_ADDR=127.0.0.1:0", + ) + output, err := cmd.CombinedOutput() + + if ctx.Err() == context.DeadlineExceeded { + t.Fatalf("expected prod startup to fail fast, but process timed out. output=%s", string(output)) + } + if err == nil { + t.Fatalf("expected prod startup to fail, but process exited successfully. output=%s", string(output)) + } + if !strings.Contains(string(output), "in-memory token runtime is not allowed") { + t.Fatalf("expected startup failure output to mention in-memory token runtime is not allowed, got: %s", string(output)) + } +} + +func TestMainHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + main() + os.Exit(0) +} diff --git a/reports/gates/backend_verify_2026-04-11_091249.md b/reports/gates/backend_verify_2026-04-11_091249.md new file mode 100644 index 00000000..42930b0a --- /dev/null +++ b/reports/gates/backend_verify_2026-04-11_091249.md @@ -0,0 +1,12 @@ +# Backend Verify Report + +- 时间戳:2026-04-11_091249 +- 结果:**PASS** +- 说明:all backend release gates passed + +| 步骤 | 结果 | 说明 | 证据 | +|---|---|---|---| +| STEP-01 | PASS | supply-api critical regression suite | /home/long/project/立交桥/reports/gates/step-01_2026-04-11_091249.out.log | +| STEP-02 | PASS | gateway critical regression suite | /home/long/project/立交桥/reports/gates/step-02_2026-04-11_091249.out.log | +| STEP-03 | PASS | platform-token-runtime critical regression suite | /home/long/project/立交桥/reports/gates/step-03_2026-04-11_091249.out.log | +| STEP-04 | PASS | supply-api E2E gate must not contain placeholder skip | /home/long/project/立交桥/reports/gates/step-04_2026-04-11_091249.out.log | diff --git a/reports/gates/superpowers_stage_validation_2026-04-11_091525.md b/reports/gates/superpowers_stage_validation_2026-04-11_091525.md new file mode 100644 index 00000000..18fa400a --- /dev/null +++ b/reports/gates/superpowers_stage_validation_2026-04-11_091525.md @@ -0,0 +1,29 @@ +# Superpowers 阶段验证报告 + +- 时间戳:2026-04-11_091525 +- 执行脚本:`scripts/ci/superpowers_stage_validate.sh` +- 决策:**CONDITIONAL_GO** +- 决策依据:all executable phases passed but real staging phase is deferred + +## 阶段结果 + +| 阶段 | 结果 | 说明 | 证据 | +|---|---|---|---| +| PHASE-00 | PASS | Backend critical verification gate | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase00_backend_verify.log | +| PHASE-01 | PASS | TOK runtime code tests | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase01_go_test.log | +| PHASE-02 | PASS | SUP local-mock run_all execution | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase02_sup_run_all_mock.log | +| PHASE-03 | PASS | TOK-005 boundary dry-run on local-mock env | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase03_tok005_dryrun_mock.log | +| PHASE-04 | PASS | TOK-006 gate bundle aggregation | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase04_tok006_bundle.log | +| PHASE-05 | PASS | Dependency audit gate validation | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase05_dependency_audit.log | +| PHASE-06 | PASS | Stage gate rollback drill | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase06_stage_gate_drill.log | +| PHASE-07 | DEFERRED | Real staging precheck (expected deferred before real secrets) | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase07_staging_precheck.log | +| PHASE-08 | PASS | Daily metrics snapshot for M-017/M-018/M-019 | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase08_metrics_snapshot.log | +| PHASE-09 | PASS | 7-day metrics trend report generation | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase09_metrics_trend.log | +| PHASE-10 | PASS | Token runtime readiness check (M-021) | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase10_token_runtime_readiness.log | + +## 说明 + +1. PHASE-07 为真实 staging 验证阶段,在占位凭证场景下允许 DEFERRED,不得伪造 PASS。 +2. PHASE-08/09 负责 M-017/M-018/M-019 的每日快照与趋势证据生成。 +3. PHASE-10 负责 M-021 token 运行态就绪度计算。 +4. 其余阶段均为可执行验证,必须以命令返回码与证据文件为准。 diff --git a/scripts/ci/backend-verify.sh b/scripts/ci/backend-verify.sh new file mode 100755 index 00000000..586ee775 --- /dev/null +++ b/scripts/ci/backend-verify.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +OUT_DIR="${ROOT_DIR}/reports/gates" +TS="$(date +%F_%H%M%S)" +LOG_FILE="${OUT_DIR}/backend_verify_${TS}.log" +REPORT_FILE="${OUT_DIR}/backend_verify_${TS}.md" +GO_BIN="${ROOT_DIR}/.tools/go-current/bin/go" +DEFAULT_GOPATH="" +DEFAULT_GOMODCACHE="" + +mkdir -p "${OUT_DIR}" +: > "${LOG_FILE}" + +if [[ ! -x "${GO_BIN}" ]]; then + GO_BIN="$(command -v go || true)" +fi +if [[ -z "${GO_BIN}" ]]; then + echo "[FAIL] go binary not found" | tee -a "${LOG_FILE}" + exit 1 +fi + +export PATH="$(dirname "${GO_BIN}"):${PATH}" +export GOCACHE="${ROOT_DIR}/.tools/go-cache" +DEFAULT_GOPATH="$("${GO_BIN}" env GOPATH 2>/dev/null || true)" +DEFAULT_GOMODCACHE="$("${GO_BIN}" env GOMODCACHE 2>/dev/null || true)" +if [[ -n "${DEFAULT_GOPATH}" ]]; then + export GOPATH="${DEFAULT_GOPATH}" +fi +if [[ -n "${DEFAULT_GOMODCACHE}" ]]; then + export GOMODCACHE="${DEFAULT_GOMODCACHE}" +fi + +STEP_RESULTS=() + +log() { + echo "$1" | tee -a "${LOG_FILE}" +} + +run_step() { + local step_id="$1" + local title="$2" + local cmd="$3" + local out_file="${OUT_DIR}/${step_id,,}_${TS}.out.log" + + log "[INFO] ${step_id} ${title} start" + set +e + bash -lc "${cmd}" > "${out_file}" 2>&1 + local rc=$? + set -e + + if [[ "${rc}" -eq 0 ]]; then + log "[PASS] ${step_id} rc=${rc}" + STEP_RESULTS+=("${step_id}|PASS|${title}|${out_file}") + else + log "[FAIL] ${step_id} rc=${rc}" + STEP_RESULTS+=("${step_id}|FAIL|${title}|${out_file}") + fi +} + +run_e2e_skip_gate() { + local step_id="$1" + local title="$2" + local out_file="${OUT_DIR}/${step_id,,}_${TS}.out.log" + + log "[INFO] ${step_id} ${title} start" + set +e + bash -lc "cd \"${ROOT_DIR}/supply-api\" && \"${GO_BIN}\" test -tags=e2e -v ./e2e/..." > "${out_file}" 2>&1 + local rc=$? + set -e + + if grep -Eiq 'SKIP|需要完整环境运行 E2E 测试|Skipping E2E test' "${out_file}"; then + log "[FAIL] ${step_id} placeholder E2E detected" + STEP_RESULTS+=("${step_id}|FAIL|${title}|${out_file}") + return + fi + + if [[ "${rc}" -eq 0 ]]; then + log "[PASS] ${step_id} rc=${rc}" + STEP_RESULTS+=("${step_id}|PASS|${title}|${out_file}") + else + log "[FAIL] ${step_id} rc=${rc}" + STEP_RESULTS+=("${step_id}|FAIL|${title}|${out_file}") + fi +} + +run_step \ + "STEP-01" \ + "supply-api critical regression suite" \ + "cd \"${ROOT_DIR}/supply-api\" && \"${GO_BIN}\" test ./cmd/supply-api ./internal/config ./internal/httpapi ./internal/middleware ./internal/outbox ./internal/repository" + +run_step \ + "STEP-02" \ + "gateway critical regression suite" \ + "cd \"${ROOT_DIR}/gateway\" && \"${GO_BIN}\" test ./cmd/gateway ./internal/config ./internal/middleware" + +run_step \ + "STEP-03" \ + "platform-token-runtime critical regression suite" \ + "cd \"${ROOT_DIR}/platform-token-runtime\" && \"${GO_BIN}\" test ./cmd/platform-token-runtime ./internal/httpapi ./internal/token ./internal/auth/..." + +run_e2e_skip_gate \ + "STEP-04" \ + "supply-api E2E gate must not contain placeholder skip" + +HAS_FAIL=0 +for row in "${STEP_RESULTS[@]}"; do + status="$(echo "${row}" | awk -F'|' '{print $2}')" + if [[ "${status}" == "FAIL" ]]; then + HAS_FAIL=1 + fi +done + +RESULT="PASS" +NOTE="all backend release gates passed" +if [[ "${HAS_FAIL}" -eq 1 ]]; then + RESULT="FAIL" + NOTE="at least one backend release gate failed" +fi + +{ + echo "# Backend Verify Report" + echo + echo "- 时间戳:${TS}" + echo "- 结果:**${RESULT}**" + echo "- 说明:${NOTE}" + echo + echo "| 步骤 | 结果 | 说明 | 证据 |" + echo "|---|---|---|---|" + for row in "${STEP_RESULTS[@]}"; do + step_id="$(echo "${row}" | awk -F'|' '{print $1}')" + status="$(echo "${row}" | awk -F'|' '{print $2}')" + title="$(echo "${row}" | awk -F'|' '{print $3}')" + evidence="$(echo "${row}" | awk -F'|' '{print $4}')" + echo "| ${step_id} | ${status} | ${title} | ${evidence} |" + done +} > "${REPORT_FILE}" + +log "[INFO] report generated: ${REPORT_FILE}" +log "[RESULT] ${RESULT}" + +if [[ "${RESULT}" != "PASS" ]]; then + exit 1 +fi diff --git a/scripts/ci/superpowers_stage_validate.sh b/scripts/ci/superpowers_stage_validate.sh index 7ad8001e..6c78f6a3 100755 --- a/scripts/ci/superpowers_stage_validate.sh +++ b/scripts/ci/superpowers_stage_validate.sh @@ -129,6 +129,12 @@ if [[ -z "${GO_BIN}" ]]; then exit 1 fi +run_step \ + "PHASE-00" \ + "Backend critical verification gate" \ + "cd \"${ROOT_DIR}\" && bash \"scripts/ci/backend-verify.sh\"" \ + "${ART_DIR}/phase00_backend_verify.log" + run_step \ "PHASE-01" \ "TOK runtime code tests" \ @@ -170,7 +176,7 @@ run_step_allow_deferred \ "Real staging precheck (expected deferred before real secrets)" \ "cd \"${ROOT_DIR}\" && bash \"scripts/supply-gate/staging_precheck_and_run.sh\" \"${STAGING_ENV_FILE}\"" \ "${ART_DIR}/phase07_staging_precheck.log" \ - "placeholder token detected|placeholder API_BASE_URL|missing env var" + "placeholder token detected|placeholder API_BASE_URL|missing env var|API_BASE_URL unreachable" run_step \ "PHASE-08" \ diff --git a/supply-api/cmd/supply-api/main.go b/supply-api/cmd/supply-api/main.go index 2dbba9df..004675b6 100644 --- a/supply-api/cmd/supply-api/main.go +++ b/supply-api/cmd/supply-api/main.go @@ -37,22 +37,29 @@ func main() { } // 加载配置 - cfg, err := config.Load(*env) + cfg, err := config.LoadFromPath(*env, *configPath) if err != nil { log.Fatalf("failed to load config: %v", err) } log.Printf("starting supply-api in %s mode", *env) + isProd := *env == "prod" // P1-010修复: 初始化结构化日志 jsonLogger := logging.NewLogger("supply-api", logging.LogLevelInfo) - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() + rootCtx, stop := context.WithCancel(context.Background()) + defer stop() + + initCtx, initCancel := context.WithTimeout(rootCtx, 30*time.Second) + defer initCancel() // 初始化数据库连接 - db, err := repository.NewDB(ctx, cfg.Database) + db, err := repository.NewDB(initCtx, cfg.Database) if err != nil { + if isProd { + log.Fatalf("production startup requirement failed: database unavailable: %v", err) + } log.Printf("warning: failed to connect to database: %v (using in-memory store)", err) db = nil } else { @@ -63,7 +70,11 @@ func main() { // 初始化Redis缓存 redisCache, err := cache.NewRedisCache(cfg.Redis) if err != nil { - log.Printf("warning: failed to connect to redis: %v (caching disabled)", err) + if isProd { + log.Printf("warning: redis unavailable at startup: %v", err) + } else { + log.Printf("warning: failed to connect to redis: %v (caching disabled)", err) + } redisCache = nil } else { log.Printf("connected to redis at %s:%d", cfg.Redis.Host, cfg.Redis.Port) @@ -154,7 +165,7 @@ func main() { // 启动主动吊销订阅机制(仅在Redis可用时) if redisCache != nil { if dbTokenBackend, ok := tokenBackend.(*middleware.DBTokenStatusBackend); ok { - if err := dbTokenBackend.StartRevocationSubscriber(ctx); err != nil { + if err := dbTokenBackend.StartRevocationSubscriber(rootCtx); err != nil { log.Printf("警告: 启动主动吊销订阅失败: %v", err) } else { log.Println("主动吊销机制: 已启动 (Redis Pub/Sub)") @@ -172,6 +183,8 @@ func main() { // 初始化鉴权中间件 authConfig := middleware.AuthConfig{ SecretKey: cfg.Token.SecretKey, + PublicKey: cfg.Token.PublicKey, + Algorithm: cfg.Token.Algorithm, Issuer: cfg.Token.Issuer, CacheTTL: cfg.Token.RevocationCacheTTL, Enabled: *env != "dev", // 开发模式禁用鉴权 @@ -210,6 +223,7 @@ func main() { cfg.Server.StatementBaseURL, time.Now, ) + api.SetWithdrawEnabled(cfg.Settlement.WithdrawEnabled) // 创建路由器 mux := http.NewServeMux() @@ -254,14 +268,12 @@ func main() { // 生产环境启用安全中间件 if *env != "dev" { - // 5. QueryKeyReject - 拒绝外部query key - handler = authMiddleware.QueryKeyRejectMiddleware(handler) - // 6. BearerExtract - handler = authMiddleware.BearerExtractMiddleware(handler) - // 7. TokenVerify - handler = authMiddleware.TokenVerifyMiddleware(handler) - // 8. RateLimit - 限流 (使用中间件包装器) + // 包装顺序与请求执行顺序相反,这里从内到外构建,保证实际执行顺序为: + // QueryKeyReject -> BearerExtract -> TokenVerify -> RateLimit handler = middleware.NewRateLimitHandler(rateLimitConfig, handler) + handler = authMiddleware.TokenVerifyMiddleware(handler) + handler = authMiddleware.BearerExtractMiddleware(handler) + handler = authMiddleware.QueryKeyRejectMiddleware(handler) } // 创建HTTP服务器 @@ -274,6 +286,14 @@ func main() { IdleTimeout: cfg.Server.IdleTimeout, } + serverErrCh := make(chan error, 1) + go func() { + log.Printf("starting HTTP server on %s", cfg.Server.Addr) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + serverErrCh <- err + } + }() + // P0-06修复: 启动OutboxProcessor(仅在DB可用时) var outboxProcessor *outbox.OutboxProcessorRunner if db != nil { @@ -284,14 +304,21 @@ func main() { redisClient := redisCache.GetClient() msgBroker = messaging.NewOutboxMessageBroker(redisClient, "supply:outbox:stream", "outbox-processor") } - stats := &messaging.NoOpOutboxStats{} - outboxProcessor = outbox.NewOutboxProcessorRunner(outboxRepo, msgBroker, stats) - go outboxProcessor.Start(ctx) - log.Println("OutboxProcessor已启动") + if msgBroker == nil { + if isProd { + log.Fatalf("production startup requirement failed: outbox message broker unavailable") + } + log.Println("警告: OutboxProcessor未启动 (message broker不可用)") + } else { + stats := &messaging.NoOpOutboxStats{} + outboxProcessor = outbox.NewOutboxProcessorRunner(outboxRepo, msgBroker, stats) + go outboxProcessor.Start(rootCtx) + log.Println("OutboxProcessor已启动") + } // 分区维护:确保未来分区已创建 partitionManager := repository.NewPartitionManager(db.Pool) - if err := partitionManager.EnsureFuturePartitions(ctx); err != nil { + if err := partitionManager.EnsureFuturePartitions(initCtx); err != nil { log.Printf("警告: 预创建未来分区失败: %v", err) } else { log.Println("分区管理: 未来分区已确保存在") @@ -303,7 +330,7 @@ func main() { defer ticker.Stop() for { select { - case <-ctx.Done(): + case <-rootCtx.Done(): return case <-ticker.C: if err := partitionManager.EnsureFuturePartitions(context.Background()); err != nil { @@ -327,14 +354,21 @@ func main() { log.Println("批量补偿处理器: 已初始化") // 启动后台补偿处理goroutine - compensationProcessor.StartBackgroundWorker(ctx, 5*time.Minute) + compensationProcessor.StartBackgroundWorker(rootCtx, 5*time.Minute) log.Println("批量补偿处理器: 后台worker已启动 (每5分钟检查一次)") } // 优雅关闭 sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - <-sigCh + select { + case sig := <-sigCh: + log.Printf("received signal %s", sig) + case err := <-serverErrCh: + log.Fatalf("server failed: %v", err) + } + + stop() log.Println("shutting down...") diff --git a/supply-api/cmd/supply-api/main_test.go b/supply-api/cmd/supply-api/main_test.go new file mode 100644 index 00000000..2922f358 --- /dev/null +++ b/supply-api/cmd/supply-api/main_test.go @@ -0,0 +1,71 @@ +package main + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestMain_ProdStartupFailsWhenDatabaseUnavailable(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.prod.yaml") + content := []byte(` +server: + addr: "127.0.0.1:0" + shutdown_timeout: 1s + default_supplier_id: 0 +database: + host: "127.0.0.1" + port: 1 + user: "postgres" + password: "secret" + database: "supply_db" +redis: + host: "127.0.0.1" + port: 1 +token: + issuer: "prod-issuer" + secret_key: "prod-secret" + algorithm: "HS256" +`) + if err := os.WriteFile(configPath, content, 0o600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestMainHelperProcess", "--", "-env", "prod", "-config", configPath) + cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1") + output, err := cmd.CombinedOutput() + + if ctx.Err() == context.DeadlineExceeded { + t.Fatalf("expected prod startup to fail fast, but process timed out. output=%s", string(output)) + } + if err == nil { + t.Fatalf("expected prod startup to fail, but process exited successfully. output=%s", string(output)) + } + if !strings.Contains(string(output), "production startup requirement failed") { + t.Fatalf("expected startup failure output to mention production startup requirement failed, got: %s", string(output)) + } +} + +func TestMainHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + + for i, arg := range os.Args { + if arg != "--" { + continue + } + os.Args = append([]string{os.Args[0]}, os.Args[i+1:]...) + break + } + + main() + os.Exit(0) +} diff --git a/supply-api/e2e/e2e_test.go b/supply-api/e2e/e2e_test.go index 28514979..e14a9221 100644 --- a/supply-api/e2e/e2e_test.go +++ b/supply-api/e2e/e2e_test.go @@ -4,133 +4,418 @@ package e2e import ( - "os" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" "testing" "time" + + "github.com/golang-jwt/jwt/v5" + + "lijiaoqiao/supply-api/internal/adapter" + "lijiaoqiao/supply-api/internal/audit" + "lijiaoqiao/supply-api/internal/domain" + "lijiaoqiao/supply-api/internal/httpapi" + "lijiaoqiao/supply-api/internal/middleware" + "lijiaoqiao/supply-api/internal/pkg/logging" ) -// E2E 测试配置 -type E2EConfig struct { - BaseURL string - APIKey string - SupplierID int64 - Timeout time.Duration - RetryAttempts int +type e2eOptions struct { + withdrawEnabled bool } -// getE2EConfig 从环境变量获取 E2E 测试配置 -func getE2EConfig() *E2EConfig { - return &E2EConfig{ - BaseURL: getEnv("E2E_BASE_URL", "http://localhost:8080"), - APIKey: getEnv("E2E_API_KEY", "test-api-key"), - SupplierID: 1001, - Timeout: 30 * time.Second, - RetryAttempts: 3, +type e2eSystem struct { + handler http.Handler + accountSvc *e2eAccountService + auditStore *audit.MemoryAuditStore + secretKey string + tokenIssuer string +} + +type e2eAccountService struct { + verifyResult *domain.VerifyResult + lastVerifySupplierID int64 +} + +func (s *e2eAccountService) Verify(ctx context.Context, supplierID int64, provider domain.Provider, accountType domain.AccountType, credential string) (*domain.VerifyResult, error) { + s.lastVerifySupplierID = supplierID + return s.verifyResult, nil +} + +func (s *e2eAccountService) Create(ctx context.Context, req *domain.CreateAccountRequest) (*domain.Account, error) { + return &domain.Account{ID: 1, SupplierID: req.SupplierID, Provider: req.Provider, AccountType: req.AccountType, Status: domain.AccountStatusActive}, nil +} + +func (s *e2eAccountService) Activate(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) { + return &domain.Account{ID: accountID, SupplierID: supplierID, Status: domain.AccountStatusActive}, nil +} + +func (s *e2eAccountService) Suspend(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) { + return &domain.Account{ID: accountID, SupplierID: supplierID, Status: domain.AccountStatusSuspended}, nil +} + +func (s *e2eAccountService) Delete(ctx context.Context, supplierID, accountID int64) error { + return nil +} + +func (s *e2eAccountService) GetByID(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) { + return &domain.Account{ID: accountID, SupplierID: supplierID, Status: domain.AccountStatusActive}, nil +} + +type e2ePackageService struct{} + +func (s *e2ePackageService) CreateDraft(ctx context.Context, supplierID int64, req *domain.CreatePackageDraftRequest) (*domain.Package, error) { + return &domain.Package{ID: 1, SupplierID: supplierID, Model: req.Model, Status: domain.PackageStatusDraft}, nil +} + +func (s *e2ePackageService) Publish(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + return &domain.Package{ID: packageID, SupplierID: supplierID, Status: domain.PackageStatusActive}, nil +} + +func (s *e2ePackageService) Pause(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + return &domain.Package{ID: packageID, SupplierID: supplierID, Status: domain.PackageStatusPaused}, nil +} + +func (s *e2ePackageService) Unlist(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + return &domain.Package{ID: packageID, SupplierID: supplierID, Status: domain.PackageStatusExpired}, nil +} + +func (s *e2ePackageService) Clone(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + return &domain.Package{ID: packageID + 1, SupplierID: supplierID, Status: domain.PackageStatusDraft}, nil +} + +func (s *e2ePackageService) BatchUpdatePrice(ctx context.Context, supplierID int64, req *domain.BatchUpdatePriceRequest) (*domain.BatchUpdatePriceResponse, error) { + return &domain.BatchUpdatePriceResponse{Total: len(req.Items), SuccessCount: len(req.Items)}, nil +} + +func (s *e2ePackageService) GetByID(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + return &domain.Package{ID: packageID, SupplierID: supplierID, Status: domain.PackageStatusActive}, nil +} + +type e2eSettlementService struct{} + +func (s *e2eSettlementService) Withdraw(ctx context.Context, supplierID int64, req *domain.WithdrawRequest) (*domain.Settlement, error) { + now := time.Now().UTC() + return &domain.Settlement{ + ID: 1, + SupplierID: supplierID, + SettlementNo: "SET-001", + Status: domain.SettlementStatusPending, + TotalAmount: req.Amount, + NetAmount: req.Amount, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +func (s *e2eSettlementService) Cancel(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) { + now := time.Now().UTC() + return &domain.Settlement{ID: settlementID, SupplierID: supplierID, Status: domain.SettlementStatusFailed, CreatedAt: now, UpdatedAt: now}, nil +} + +func (s *e2eSettlementService) GetByID(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) { + now := time.Now().UTC() + return &domain.Settlement{ID: settlementID, SupplierID: supplierID, Status: domain.SettlementStatusPending, CreatedAt: now, UpdatedAt: now}, nil +} + +func (s *e2eSettlementService) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) { + now := time.Now().UTC() + return []*domain.Settlement{{ID: 1, SupplierID: supplierID, Status: domain.SettlementStatusPending, CreatedAt: now, UpdatedAt: now}}, nil +} + +type e2eEarningService struct{} + +func (s *e2eEarningService) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) { + return []*domain.EarningRecord{ + {ID: 1, SupplierID: supplierID, EarningsType: "usage", Amount: 100, Status: "available", EarnedAt: time.Now().UTC()}, + }, 1, nil +} + +func (s *e2eEarningService) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) { + return &domain.BillingSummary{ + Period: domain.BillingPeriod{Start: startDate, End: endDate}, + Summary: domain.BillingTotal{TotalRevenue: 100, TotalOrders: 1, TotalUsage: 1000, TotalRequests: 10, AvgSuccessRate: 1, NetEarnings: 95}, + }, nil +} + +type staticTokenBackend struct { + statusByTokenID map[string]string +} + +func (b *staticTokenBackend) CheckTokenStatus(ctx context.Context, tokenID string) (string, error) { + if status, ok := b.statusByTokenID[tokenID]; ok { + return status, nil + } + return "active", nil +} + +func newE2ESystem(t *testing.T, opts e2eOptions) *e2eSystem { + t.Helper() + + accountSvc := &e2eAccountService{ + verifyResult: &domain.VerifyResult{ + VerifyStatus: "pass", + AvailableQuota: 2048, + RiskScore: 0, + CheckItems: []domain.CheckItem{ + {Item: "credential_format", Result: "pass", Message: "ok"}, + }, + }, + } + auditStore := audit.NewMemoryAuditStore() + + api := httpapi.NewSupplyAPI( + accountSvc, + &e2ePackageService{}, + &e2eSettlementService{}, + &e2eEarningService{}, + nil, + auditStore, + nil, + 0, + "https://statements.example.com", + func() time.Time { return time.Unix(1712800000, 0).UTC() }, + ) + api.SetWithdrawEnabled(opts.withdrawEnabled) + + mux := http.NewServeMux() + healthHandler := httpapi.NewHealthHandlerWithDefaults(nil, nil) + mux.HandleFunc("/actuator/health", healthHandler.ServeHealth) + mux.HandleFunc("/actuator/health/live", healthHandler.ServeLiveness) + mux.HandleFunc("/actuator/health/ready", healthHandler.ServeReadiness) + api.Register(mux) + + authMiddleware := middleware.NewAuthMiddleware( + middleware.AuthConfig{ + SecretKey: "e2e-secret-key-should-be-long", + Algorithm: "HS256", + Issuer: "supply-api-e2e", + Enabled: true, + }, + middleware.NewTokenCache(), + &staticTokenBackend{statusByTokenID: map[string]string{}}, + adapter.NewAuditEmitterAdapter(auditStore), + ) + + logger := logging.NewLogger("supply-api-e2e", logging.LogLevelError) + + var handler http.Handler = mux + handler = middleware.RequestID(handler) + handler = middleware.Recovery(handler) + handler = middleware.Logging(handler, logger) + handler = middleware.TracingMiddleware(handler) + handler = authMiddleware.TokenVerifyMiddleware(handler) + handler = authMiddleware.BearerExtractMiddleware(handler) + handler = authMiddleware.QueryKeyRejectMiddleware(handler) + + return &e2eSystem{ + handler: handler, + accountSvc: accountSvc, + auditStore: auditStore, + secretKey: "e2e-secret-key-should-be-long", + tokenIssuer: "supply-api-e2e", } } -func getEnv(key, defaultValue string) string { - if value := os.Getenv(key); value != "" { - return value - } - return defaultValue -} +func (s *e2eSystem) tokenForTenant(t *testing.T, tokenID string, tenantID int64) string { + t.Helper() -// TestE2E_HealthCheck E2E 测试:健康检查 -func TestE2E_HealthCheck(t *testing.T) { - if testing.Short() { - t.Skip("Skipping E2E test in short mode") + claims := &middleware.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ID: tokenID, + Issuer: s.tokenIssuer, + Subject: "subject-42", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + SubjectID: "42", + Role: "org_admin", + Scope: []string{"supply:write", "supply:read"}, + TenantID: tenantID, } - cfg := getE2EConfig() - _, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数 - - // 验证服务健康状态 - // 在真实的 E2E 测试中,这里会使用 HTTP 客户端调用真实服务 - t.Logf("E2E_BASE_URL: %s", cfg.BaseURL) - t.Skip("需要完整环境运行 E2E 测试") + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte(s.secretKey)) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + return tokenString } -// TestE2E_AccountLifecycle E2E 测试:账号完整生命周期 -func TestE2E_AccountLifecycle(t *testing.T) { - if testing.Short() { - t.Skip("Skipping E2E test in short mode") +func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var payload map[string]any + if err := json.Unmarshal(recorder.Body.Bytes(), &payload); err != nil { + t.Fatalf("failed to decode response: %v, body=%s", err, recorder.Body.String()) + } + return payload +} + +func TestE2E_HealthProbe_IsPublicAndHealthy(t *testing.T) { + system := newE2ESystem(t, e2eOptions{withdrawEnabled: false}) + + req := httptest.NewRequest(http.MethodGet, "/actuator/health", nil) + req.Header.Set("X-Request-Id", "health-req-001") + recorder := httptest.NewRecorder() + + system.handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d, body=%s", recorder.Code, recorder.Body.String()) + } + if recorder.Header().Get("X-Request-Id") != "health-req-001" { + t.Fatalf("expected X-Request-Id response header to round-trip") } - cfg := getE2EConfig() - _, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数 - - t.Log("E2E 测试:账号完整生命周期") - t.Log("步骤:创建 -> 验证 -> 暂停 -> 恢复 -> 禁用") - t.Skip("需要完整环境运行 E2E 测试") + payload := decodeJSONBody(t, recorder) + if payload["status"] != "healthy" { + t.Fatalf("expected healthy status, got %v", payload["status"]) + } } -// TestE2E_PackageLifecycle E2E 测试:套餐完整生命周期 -func TestE2E_PackageLifecycle(t *testing.T) { - if testing.Short() { - t.Skip("Skipping E2E test in short mode") +func TestE2E_ProtectedRoute_RejectsMissingBearer(t *testing.T) { + system := newE2ESystem(t, e2eOptions{withdrawEnabled: false}) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/supply/billing?start_date=2026-04-01&end_date=2026-04-11", nil) + recorder := httptest.NewRecorder() + + system.handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("expected status 401, got %d, body=%s", recorder.Code, recorder.Body.String()) } - cfg := getE2EConfig() - _, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数 - - t.Log("E2E 测试:套餐完整生命周期") - t.Log("步骤:创建草稿 -> 发布 -> 暂停 -> 售罄 -> 过期") - t.Skip("需要完整环境运行 E2E 测试") + payload := decodeJSONBody(t, recorder) + errBody, ok := payload["error"].(map[string]any) + if !ok { + t.Fatalf("expected error body, got %v", payload["error"]) + } + if errBody["code"] != "AUTH_MISSING_BEARER" { + t.Fatalf("expected AUTH_MISSING_BEARER, got %v", errBody["code"]) + } } -// TestE2E_SettlementWorkflow E2E 测试:结算完整流程 -func TestE2E_SettlementWorkflow(t *testing.T) { - if testing.Short() { - t.Skip("Skipping E2E test in short mode") +func TestE2E_ProtectedRoute_RejectsQueryCredentialLeak(t *testing.T) { + system := newE2ESystem(t, e2eOptions{withdrawEnabled: false}) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/supply/billing?api_key=abcdefghijklmnopqrstuvwxyz123456", nil) + recorder := httptest.NewRecorder() + + system.handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("expected status 401, got %d, body=%s", recorder.Code, recorder.Body.String()) } - cfg := getE2EConfig() - _, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数 - - t.Log("E2E 测试:结算完整流程") - t.Log("步骤:创建结算单 -> 处理中 -> 完成/失败 -> 提现") - t.Skip("需要完整环境运行 E2E 测试") + payload := decodeJSONBody(t, recorder) + errBody, ok := payload["error"].(map[string]any) + if !ok { + t.Fatalf("expected error body, got %v", payload["error"]) + } + if errBody["code"] != "QUERY_KEY_NOT_ALLOWED" { + t.Fatalf("expected QUERY_KEY_NOT_ALLOWED, got %v", errBody["code"]) + } } -// TestE2E_WithdrawFlow E2E 测试:提现流程 -func TestE2E_WithdrawFlow(t *testing.T) { - if testing.Short() { - t.Skip("Skipping E2E test in short mode") +func TestE2E_VerifyAccount_UsesTenantIDFromVerifiedToken(t *testing.T) { + system := newE2ESystem(t, e2eOptions{withdrawEnabled: false}) + token := system.tokenForTenant(t, "tok-e2e-verify", 2001) + + body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/supply/accounts/verify", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Request-Id", "verify-req-001") + recorder := httptest.NewRecorder() + + system.handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d, body=%s", recorder.Code, recorder.Body.String()) + } + if system.accountSvc.lastVerifySupplierID != 2001 { + t.Fatalf("expected supplierID 2001 from token tenant, got %d", system.accountSvc.lastVerifySupplierID) } - cfg := getE2EConfig() - _, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数 - - t.Log("E2E 测试:提现流程") - t.Log("步骤:验证余额 -> 发起提现 -> 短信验证 -> 确认 -> 到账") - t.Skip("需要完整环境运行 E2E 测试") + payload := decodeJSONBody(t, recorder) + if payload["request_id"] != "verify-req-001" { + t.Fatalf("expected request_id verify-req-001, got %v", payload["request_id"]) + } } -// TestE2E_AuditLogTrace E2E 测试:审计日志追踪 -func TestE2E_AuditLogTrace(t *testing.T) { - if testing.Short() { - t.Skip("Skipping E2E test in short mode") +func TestE2E_Withdraw_DisabledBeforeSMSIntegration(t *testing.T) { + system := newE2ESystem(t, e2eOptions{withdrawEnabled: false}) + token := system.tokenForTenant(t, "tok-e2e-withdraw", 2002) + + body := `{"withdraw_amount":100,"payment_method":"bank","payment_account":"13800000000","sms_code":"123456"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/supply/settlements/withdraw", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + system.handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusServiceUnavailable { + t.Fatalf("expected status 503, got %d, body=%s", recorder.Code, recorder.Body.String()) } - cfg := getE2EConfig() - _, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数 - - t.Log("E2E 测试:审计日志追踪") - t.Log("验证操作被正确记录到审计日志") - t.Skip("需要完整环境运行 E2E 测试") + payload := decodeJSONBody(t, recorder) + errBody, ok := payload["error"].(map[string]any) + if !ok { + t.Fatalf("expected error body, got %v", payload["error"]) + } + if errBody["code"] != "FEATURE_DISABLED" { + t.Fatalf("expected FEATURE_DISABLED, got %v", errBody["code"]) + } } -// TestE2E_ConcurrentWithdraw E2E 测试:并发提现 -func TestE2E_ConcurrentWithdraw(t *testing.T) { - if testing.Short() { - t.Skip("Skipping E2E test in short mode") +func TestE2E_AuditEvent_CanBeReadBackThroughAPI(t *testing.T) { + system := newE2ESystem(t, e2eOptions{withdrawEnabled: false}) + token := system.tokenForTenant(t, "tok-e2e-audit", 3001) + + if err := system.auditStore.Emit(context.Background(), audit.Event{ + TenantID: 3001, + ObjectType: "supply_account", + ObjectID: 77, + Action: "verify", + RequestID: "audit-req-001", + ResultCode: "OK", + SourceIP: "127.0.0.1", + }); err != nil { + t.Fatalf("failed to seed audit event: %v", err) } - cfg := getE2EConfig() - _, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数 + events, err := system.auditStore.Query(context.Background(), audit.EventFilter{TenantID: 3001, Limit: 1}) + if err != nil { + t.Fatalf("failed to query seeded audit event: %v", err) + } + if len(events) != 1 { + t.Fatalf("expected 1 audit event, got %d", len(events)) + } - t.Log("E2E 测试:并发提现") - t.Log("验证乐观锁正确处理并发请求") - t.Skip("需要完整环境运行 E2E 测试") + req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/events/"+events[0].EventID, nil) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("X-Request-Id", "audit-read-001") + recorder := httptest.NewRecorder() + + system.handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d, body=%s", recorder.Code, recorder.Body.String()) + } + + payload := decodeJSONBody(t, recorder) + data, ok := payload["data"].(map[string]any) + if !ok { + t.Fatalf("expected data object, got %v", payload["data"]) + } + if data["event_id"] != events[0].EventID { + t.Fatalf("expected event_id %s, got %v", events[0].EventID, data["event_id"]) + } + if data["request_id"] != "audit-req-001" { + t.Fatalf("expected seeded request_id audit-req-001, got %v", data["request_id"]) + } } diff --git a/supply-api/internal/config/config.go b/supply-api/internal/config/config.go index 1e06277e..65893b61 100644 --- a/supply-api/internal/config/config.go +++ b/supply-api/internal/config/config.go @@ -12,11 +12,12 @@ import ( // Config 应用配置 type Config struct { - Server ServerConfig - Database DatabaseConfig - Redis RedisConfig - Token TokenConfig - Audit AuditConfig + Server ServerConfig + Database DatabaseConfig + Redis RedisConfig + Token TokenConfig + Settlement SettlementConfig + Audit AuditConfig } // ServerConfig HTTP服务配置 @@ -63,6 +64,11 @@ type TokenConfig struct { RevocationCacheTTL time.Duration } +// SettlementConfig 结算与提现能力配置 +type SettlementConfig struct { + WithdrawEnabled bool +} + // AuditConfig 审计配置 type AuditConfig struct { BufferSize int @@ -90,6 +96,15 @@ func (r *RedisConfig) Addr() string { // Load 加载配置 func Load(env string) (*Config, error) { + return load(env, "") +} + +// LoadFromPath 从指定路径加载配置 +func LoadFromPath(env, configPath string) (*Config, error) { + return load(env, configPath) +} + +func load(env, configPath string) (*Config, error) { v := viper.New() // 设置环境变量前缀 @@ -100,17 +115,24 @@ func Load(env string) (*Config, error) { setDefaults(v) // 加载配置文件 - configFile := fmt.Sprintf("config.%s.yaml", env) - v.SetConfigName(configFile) - v.SetConfigType("yaml") - v.AddConfigPath(".") - v.AddConfigPath("./config") + if strings.TrimSpace(configPath) != "" { + v.SetConfigFile(configPath) + } else { + configFile := fmt.Sprintf("config.%s.yaml", env) + v.SetConfigName(configFile) + v.SetConfigType("yaml") + v.AddConfigPath(".") + v.AddConfigPath("./config") + } // 允许环境变量覆盖 v.AutomaticEnv() // 读取配置文件 if err := v.ReadInConfig(); err != nil { + if strings.TrimSpace(configPath) != "" { + return nil, fmt.Errorf("failed to read config: %w", err) + } if _, ok := err.(viper.ConfigFileNotFoundError); !ok { return nil, fmt.Errorf("failed to read config: %w", err) } @@ -163,6 +185,13 @@ func Load(env string) (*Config, error) { cfg.Audit.FlushInterval = v.GetDuration("audit.flush_interval") cfg.Audit.ExportTimeout = v.GetDuration("audit.export_timeout") + // Settlement配置 + cfg.Settlement.WithdrawEnabled = v.GetBool("settlement.withdraw_enabled") + + if err := validateForEnv(env, &cfg); err != nil { + return nil, err + } + return &cfg, nil } @@ -202,6 +231,9 @@ func setDefaults(v *viper.Viper) { v.SetDefault("token.revocation_cache_ttl", 30*time.Second) v.SetDefault("token.algorithm", "HS256") // 默认HS256,可配置RS256 + // Settlement defaults + v.SetDefault("settlement.withdraw_enabled", false) + // Audit defaults v.SetDefault("audit.buffer_size", 1000) v.SetDefault("audit.flush_interval", 5*time.Second) @@ -228,6 +260,10 @@ func bindEnvVars(v *viper.Viper) { _ = v.BindEnv("redis.db", "SUPPLY_REDIS_DB") _ = v.BindEnv("token.secret_key", "SUPPLY_TOKEN_SECRET_KEY") + _ = v.BindEnv("token.public_key", "SUPPLY_TOKEN_PUBLIC_KEY") + _ = v.BindEnv("token.algorithm", "SUPPLY_TOKEN_ALGORITHM") + _ = v.BindEnv("token.issuer", "SUPPLY_TOKEN_ISSUER") + _ = v.BindEnv("settlement.withdraw_enabled", "SUPPLY_SETTLEMENT_WITHDRAW_ENABLED") } // MustLoad 加载配置,失败时panic @@ -249,6 +285,46 @@ func GetEnvInt(key string, defaultVal int) int { return defaultVal } +func validateForEnv(env string, cfg *Config) error { + if cfg == nil { + return fmt.Errorf("config is nil") + } + + cfg.Token.Algorithm = strings.ToUpper(strings.TrimSpace(cfg.Token.Algorithm)) + if cfg.Token.Algorithm == "" { + cfg.Token.Algorithm = "HS256" + } + + if env != "prod" { + return nil + } + + if cfg.Server.DefaultSupplierID != 0 { + return fmt.Errorf("invalid prod config: server.default_supplier_id must be 0 to disable static supplier fallback") + } + if strings.TrimSpace(cfg.Token.Issuer) == "" { + return fmt.Errorf("invalid prod config: token.issuer is required") + } + if cfg.Settlement.WithdrawEnabled { + return fmt.Errorf("invalid prod config: settlement.withdraw_enabled cannot be true until SMS integration is production-ready") + } + + switch cfg.Token.Algorithm { + case "HS256", "HS384", "HS512": + if strings.TrimSpace(cfg.Token.SecretKey) == "" { + return fmt.Errorf("invalid prod config: token.secret_key is required for %s", cfg.Token.Algorithm) + } + case "RS256", "RS384", "RS512": + if strings.TrimSpace(cfg.Token.PublicKey) == "" { + return fmt.Errorf("invalid prod config: token.public_key is required for %s", cfg.Token.Algorithm) + } + default: + return fmt.Errorf("invalid prod config: unsupported token.algorithm %q", cfg.Token.Algorithm) + } + + return nil +} + // GetEnvDuration 获取环境变量duration值 func GetEnvDuration(key string, defaultVal time.Duration) time.Duration { if v := os.Getenv(key); v != "" { diff --git a/supply-api/internal/config/config_test.go b/supply-api/internal/config/config_test.go new file mode 100644 index 00000000..696f6556 --- /dev/null +++ b/supply-api/internal/config/config_test.go @@ -0,0 +1,135 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLoadFromPath(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "custom.yaml") + content := []byte(` +server: + addr: ":19090" +database: + host: "db.internal" +token: + issuer: "custom-issuer" +`) + if err := os.WriteFile(configPath, content, 0o600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + cfg, err := LoadFromPath("dev", configPath) + if err != nil { + t.Fatalf("LoadFromPath returned error: %v", err) + } + + if cfg.Server.Addr != ":19090" { + t.Fatalf("expected addr :19090, got %s", cfg.Server.Addr) + } + if cfg.Database.Host != "db.internal" { + t.Fatalf("expected database host db.internal, got %s", cfg.Database.Host) + } + if cfg.Token.Issuer != "custom-issuer" { + t.Fatalf("expected token issuer custom-issuer, got %s", cfg.Token.Issuer) + } +} + +func TestLoadFromPath_MissingFile(t *testing.T) { + _, err := LoadFromPath("dev", filepath.Join(t.TempDir(), "missing.yaml")) + if err == nil { + t.Fatal("expected missing config file to return error") + } +} + +func TestLoadFromPath_ProdRejectsDefaultSupplierIDFallback(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "prod.yaml") + content := []byte(` +server: + addr: ":19090" +database: + host: "db.internal" + user: "postgres" + password: "secret" + database: "supply_db" +token: + issuer: "prod-issuer" + secret_key: "prod-secret" +`) + if err := os.WriteFile(configPath, content, 0o600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + _, err := LoadFromPath("prod", configPath) + if err == nil { + t.Fatal("expected prod config with default supplier fallback to return error") + } + if !strings.Contains(err.Error(), "default_supplier_id") { + t.Fatalf("expected error to mention default_supplier_id, got %v", err) + } +} + +func TestLoadFromPath_ProdRejectsMissingHS256SecretKey(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "prod.yaml") + content := []byte(` +server: + addr: ":19090" + default_supplier_id: 0 +database: + host: "db.internal" + user: "postgres" + password: "secret" + database: "supply_db" +token: + issuer: "prod-issuer" + algorithm: "HS256" +`) + if err := os.WriteFile(configPath, content, 0o600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + _, err := LoadFromPath("prod", configPath) + if err == nil { + t.Fatal("expected prod config without HS256 secret key to return error") + } + if !strings.Contains(err.Error(), "token.secret_key") { + t.Fatalf("expected error to mention token.secret_key, got %v", err) + } +} + +func TestLoadFromPath_ProdRejectsWithdrawEnabledUntilSMSIntegrated(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "prod.yaml") + content := []byte(` +server: + addr: ":19090" + default_supplier_id: 0 +database: + host: "db.internal" + user: "postgres" + password: "secret" + database: "supply_db" +token: + issuer: "prod-issuer" + algorithm: "HS256" + secret_key: "prod-secret" +settlement: + withdraw_enabled: true +`) + if err := os.WriteFile(configPath, content, 0o600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + _, err := LoadFromPath("prod", configPath) + if err == nil { + t.Fatal("expected prod config with withdraw enabled but no integrated SMS path to return error") + } + if !strings.Contains(err.Error(), "withdraw_enabled") { + t.Fatalf("expected error to mention withdraw_enabled, got %v", err) + } +} diff --git a/supply-api/internal/httpapi/supply_api.go b/supply-api/internal/httpapi/supply_api.go index 89b52f3f..511c63bf 100644 --- a/supply-api/internal/httpapi/supply_api.go +++ b/supply-api/internal/httpapi/supply_api.go @@ -18,14 +18,15 @@ import ( // SupplyAPI 处理器 type SupplyAPI struct { - accountService domain.AccountService - packageService domain.PackageService - settlementService domain.SettlementService + accountService domain.AccountService + packageService domain.PackageService + settlementService domain.SettlementService earningService domain.EarningService idempotencyMw *middleware.IdempotencyMiddleware // P0-P4修复: 使用DB-backed幂等中间件 - auditStore audit.AuditStore // P0-R08修复: 使用接口支持DB-backed实现 - fkValidator *repository.ForeignKeyValidator // P0-09修复: 外键校验器 + auditStore audit.AuditStore // P0-R08修复: 使用接口支持DB-backed实现 + fkValidator *repository.ForeignKeyValidator // P0-09修复: 外键校验器 supplierID int64 + withdrawEnabled bool statementBaseURL string now func() time.Time } @@ -51,11 +52,16 @@ func NewSupplyAPI( auditStore: auditStore, fkValidator: fkValidator, supplierID: supplierID, + withdrawEnabled: true, statementBaseURL: statementBaseURL, now: now, } } +func (a *SupplyAPI) SetWithdrawEnabled(enabled bool) { + a.withdrawEnabled = enabled +} + func (a *SupplyAPI) Register(mux *http.ServeMux) { // Supply Accounts mux.HandleFunc("/api/v1/supply/accounts/verify", a.handleVerifyAccount) @@ -82,6 +88,25 @@ func (a *SupplyAPI) Register(mux *http.ServeMux) { mux.HandleFunc("/api/v1/audit/events/", a.handleAuditEvent) } +func (a *SupplyAPI) resolveSupplierID(ctx context.Context) (int64, error) { + if tenantID := middleware.GetTenantID(ctx); tenantID > 0 { + return tenantID, nil + } + if a.supplierID > 0 { + return a.supplierID, nil + } + return 0, fmt.Errorf("supplier context is missing") +} + +func (a *SupplyAPI) requireSupplierID(w http.ResponseWriter, r *http.Request) (int64, bool) { + supplierID, err := a.resolveSupplierID(r.Context()) + if err != nil { + writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error()) + return 0, false + } + return supplierID, true +} + // ==================== Account Handlers ==================== type VerifyAccountRequest struct { @@ -110,7 +135,12 @@ func (a *SupplyAPI) handleVerifyAccount(w http.ResponseWriter, r *http.Request) return } - result, err := a.accountService.Verify(r.Context(), a.supplierID, + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + result, err := a.accountService.Verify(r.Context(), supplierID, domain.Provider(req.Provider), domain.AccountType(req.AccountType), req.CredentialInput) @@ -139,12 +169,17 @@ func (a *SupplyAPI) handleCreateAccount(w http.ResponseWriter, r *http.Request) } // 降级:使用内联幂等逻辑(仅在幂等中间件未启用时) - a.createAccountHandler(context.Background(), w, r, nil) + a.createAccountHandler(r.Context(), w, r, nil) } // createAccountHandler 创建账号的业务逻辑(供幂等中间件包装) func (a *SupplyAPI) createAccountHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, _ *repository.IdempotencyRecord) error { requestID := r.Header.Get("X-Request-Id") + supplierID, err := a.resolveSupplierID(ctx) + if err != nil { + writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error()) + return err + } body, err := io.ReadAll(r.Body) if err != nil { @@ -169,14 +204,14 @@ func (a *SupplyAPI) createAccountHandler(ctx context.Context, w http.ResponseWri // P0-09修复: 创建账户前校验外键引用 if a.fkValidator != nil { - if err := a.fkValidator.ValidateSupplyAccountOwner(ctx, a.supplierID); err != nil { + if err := a.fkValidator.ValidateSupplyAccountOwner(ctx, supplierID); err != nil { writeError(w, http.StatusUnprocessableEntity, "FK_VALIDATION_FAILED", "supplier does not exist") return err } } createReq := &domain.CreateAccountRequest{ - SupplierID: a.supplierID, + SupplierID: supplierID, Provider: domain.Provider(rawReq.Provider), AccountType: domain.AccountType(rawReq.AccountType), Credential: rawReq.CredentialInput, @@ -252,7 +287,12 @@ func (a *SupplyAPI) handleAccountActions(w http.ResponseWriter, r *http.Request) } func (a *SupplyAPI) handleActivateAccount(w http.ResponseWriter, r *http.Request, accountID int64) { - account, err := a.accountService.Activate(r.Context(), a.supplierID, accountID) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + account, err := a.accountService.Activate(r.Context(), supplierID, accountID) if err != nil { if strings.Contains(err.Error(), "SUP_ACC") { writeError(w, http.StatusConflict, "CONFLICT", err.Error()) @@ -273,7 +313,12 @@ func (a *SupplyAPI) handleActivateAccount(w http.ResponseWriter, r *http.Request } func (a *SupplyAPI) handleSuspendAccount(w http.ResponseWriter, r *http.Request, accountID int64) { - account, err := a.accountService.Suspend(r.Context(), a.supplierID, accountID) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + account, err := a.accountService.Suspend(r.Context(), supplierID, accountID) if err != nil { if strings.Contains(err.Error(), "SUP_ACC") { writeError(w, http.StatusConflict, "CONFLICT", err.Error()) @@ -294,7 +339,12 @@ func (a *SupplyAPI) handleSuspendAccount(w http.ResponseWriter, r *http.Request, } func (a *SupplyAPI) handleDeleteAccount(w http.ResponseWriter, r *http.Request, accountID int64) { - err := a.accountService.Delete(r.Context(), a.supplierID, accountID) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + err := a.accountService.Delete(r.Context(), supplierID, accountID) if err != nil { if strings.Contains(err.Error(), "SUP_ACC") { writeError(w, http.StatusConflict, "CONFLICT", err.Error()) @@ -308,11 +358,27 @@ func (a *SupplyAPI) handleDeleteAccount(w http.ResponseWriter, r *http.Request, } func (a *SupplyAPI) handleAccountAuditLogs(w http.ResponseWriter, r *http.Request, accountID int64) { + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + page := getQueryInt(r, "page", 1) pageSize := getQueryInt(r, "page_size", 20) + // 分页参数边界验证 + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 20 + } + if pageSize > 1000 { + pageSize = 1000 + } + events, total, err := a.auditStore.QueryWithTotal(r.Context(), audit.EventFilter{ - TenantID: a.supplierID, + TenantID: supplierID, ObjectType: "supply_account", ObjectID: accountID, Limit: pageSize, @@ -378,6 +444,11 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ return } + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + // P0-09修复: 创建套餐前校验外键引用 if a.fkValidator != nil { if err := a.fkValidator.ValidatePackageSupplyAccount(r.Context(), req.SupplyAccountID); err != nil { @@ -387,7 +458,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ } createReq := &domain.CreatePackageDraftRequest{ - SupplierID: a.supplierID, + SupplierID: supplierID, AccountID: req.SupplyAccountID, Model: req.Model, TotalQuota: req.TotalQuota, @@ -398,7 +469,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ RateLimitRPM: req.RateLimitRPM, } - pkg, err := a.packageService.CreateDraft(r.Context(), a.supplierID, createReq) + pkg, err := a.packageService.CreateDraft(r.Context(), supplierID, createReq) if err != nil { writeError(w, http.StatusUnprocessableEntity, "CREATE_FAILED", err.Error()) return @@ -477,7 +548,12 @@ func (a *SupplyAPI) handlePackageActions(w http.ResponseWriter, r *http.Request) } func (a *SupplyAPI) handlePublishPackage(w http.ResponseWriter, r *http.Request, packageID int64) { - pkg, err := a.packageService.Publish(r.Context(), a.supplierID, packageID) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + pkg, err := a.packageService.Publish(r.Context(), supplierID, packageID) if err != nil { if strings.Contains(err.Error(), "SUP_PKG") { writeError(w, http.StatusConflict, "CONFLICT", err.Error()) @@ -498,7 +574,12 @@ func (a *SupplyAPI) handlePublishPackage(w http.ResponseWriter, r *http.Request, } func (a *SupplyAPI) handlePausePackage(w http.ResponseWriter, r *http.Request, packageID int64) { - pkg, err := a.packageService.Pause(r.Context(), a.supplierID, packageID) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + pkg, err := a.packageService.Pause(r.Context(), supplierID, packageID) if err != nil { if strings.Contains(err.Error(), "SUP_PKG") { writeError(w, http.StatusConflict, "CONFLICT", err.Error()) @@ -519,7 +600,12 @@ func (a *SupplyAPI) handlePausePackage(w http.ResponseWriter, r *http.Request, p } func (a *SupplyAPI) handleUnlistPackage(w http.ResponseWriter, r *http.Request, packageID int64) { - pkg, err := a.packageService.Unlist(r.Context(), a.supplierID, packageID) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + pkg, err := a.packageService.Unlist(r.Context(), supplierID, packageID) if err != nil { if strings.Contains(err.Error(), "SUP_PKG") { writeError(w, http.StatusConflict, "CONFLICT", err.Error()) @@ -540,7 +626,12 @@ func (a *SupplyAPI) handleUnlistPackage(w http.ResponseWriter, r *http.Request, } func (a *SupplyAPI) handleClonePackage(w http.ResponseWriter, r *http.Request, packageID int64) { - pkg, err := a.packageService.Clone(r.Context(), a.supplierID, packageID) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + pkg, err := a.packageService.Clone(r.Context(), supplierID, packageID) if err != nil { writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error()) return @@ -595,7 +686,12 @@ func (a *SupplyAPI) handleBatchUpdatePrice(w http.ResponseWriter, r *http.Reques } } - resp, err := a.packageService.BatchUpdatePrice(r.Context(), a.supplierID, req) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + resp, err := a.packageService.BatchUpdatePrice(r.Context(), supplierID, req) if err != nil { writeError(w, http.StatusUnprocessableEntity, "BATCH_UPDATE_FAILED", err.Error()) return @@ -618,7 +714,12 @@ func (a *SupplyAPI) handleGetBilling(w http.ResponseWriter, r *http.Request) { startDate := r.URL.Query().Get("start_date") endDate := r.URL.Query().Get("end_date") - summary, err := a.earningService.GetBillingSummary(r.Context(), a.supplierID, startDate, endDate) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + summary, err := a.earningService.GetBillingSummary(r.Context(), supplierID, startDate, endDate) if err != nil { writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error()) return @@ -637,6 +738,10 @@ func (a *SupplyAPI) handleWithdraw(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed") return } + if !a.withdrawEnabled { + writeError(w, http.StatusServiceUnavailable, "FEATURE_DISABLED", "withdraw is disabled until SMS verification is integrated") + return + } // P0-P4修复: 使用DB-backed幂等中间件 if a.idempotencyMw != nil { @@ -645,12 +750,17 @@ func (a *SupplyAPI) handleWithdraw(w http.ResponseWriter, r *http.Request) { } // 降级:使用内联幂等逻辑(仅在幂等中间件未启用时) - a.withdrawHandler(context.Background(), w, r, nil) + a.withdrawHandler(r.Context(), w, r, nil) } // withdrawHandler 提现的业务逻辑(供幂等中间件包装) func (a *SupplyAPI) withdrawHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, _ *repository.IdempotencyRecord) error { requestID := r.Header.Get("X-Request-Id") + supplierID, err := a.resolveSupplierID(ctx) + if err != nil { + writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error()) + return err + } body, err := io.ReadAll(r.Body) if err != nil { @@ -678,7 +788,7 @@ func (a *SupplyAPI) withdrawHandler(ctx context.Context, w http.ResponseWriter, SMSCode: req.SMSCode, } - settlement, err := a.settlementService.Withdraw(ctx, a.supplierID, withdrawReq) + settlement, err := a.settlementService.Withdraw(ctx, supplierID, withdrawReq) if err != nil { if strings.Contains(err.Error(), "SUP_SET") { writeError(w, http.StatusConflict, "WITHDRAW_FAILED", err.Error()) @@ -740,7 +850,12 @@ func (a *SupplyAPI) handleSettlementActions(w http.ResponseWriter, r *http.Reque } func (a *SupplyAPI) handleCancelSettlement(w http.ResponseWriter, r *http.Request, settlementID int64) { - settlement, err := a.settlementService.Cancel(r.Context(), a.supplierID, settlementID) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + settlement, err := a.settlementService.Cancel(r.Context(), supplierID, settlementID) if err != nil { if strings.Contains(err.Error(), "SUP_SET") { writeError(w, http.StatusConflict, "CONFLICT", err.Error()) @@ -761,7 +876,12 @@ func (a *SupplyAPI) handleCancelSettlement(w http.ResponseWriter, r *http.Reques } func (a *SupplyAPI) handleGetStatement(w http.ResponseWriter, r *http.Request, settlementID int64) { - settlement, err := a.settlementService.GetByID(r.Context(), a.supplierID, settlementID) + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + settlement, err := a.settlementService.GetByID(r.Context(), supplierID, settlementID) if err != nil { writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error()) return @@ -791,7 +911,23 @@ func (a *SupplyAPI) handleGetEarningRecords(w http.ResponseWriter, r *http.Reque page := getQueryInt(r, "page", 1) pageSize := getQueryInt(r, "page_size", 20) - records, total, err := a.earningService.ListRecords(r.Context(), a.supplierID, startDate, endDate, page, pageSize) + // 分页参数边界验证 + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 20 + } + if pageSize > 1000 { + pageSize = 1000 + } + + supplierID, ok := a.requireSupplierID(w, r) + if !ok { + return + } + + records, total, err := a.earningService.ListRecords(r.Context(), supplierID, startDate, endDate, page, pageSize) if err != nil { writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error()) return diff --git a/supply-api/internal/httpapi/supply_api_test.go b/supply-api/internal/httpapi/supply_api_test.go new file mode 100644 index 00000000..140766e1 --- /dev/null +++ b/supply-api/internal/httpapi/supply_api_test.go @@ -0,0 +1,1399 @@ +package httpapi + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "lijiaoqiao/supply-api/internal/audit" + "lijiaoqiao/supply-api/internal/domain" + "lijiaoqiao/supply-api/internal/middleware" +) + +// ==================== Mock Implementations ==================== + +// mockAccountService Mock账户服务 +type mockAccountService struct { + verifyResult *domain.VerifyResult + verifyErr error + account *domain.Account + createErr error + activateErr error + suspendErr error + deleteErr error + lastVerifySupplierID int64 +} + +func (m *mockAccountService) Verify(ctx context.Context, supplierID int64, provider domain.Provider, accountType domain.AccountType, credential string) (*domain.VerifyResult, error) { + m.lastVerifySupplierID = supplierID + if m.verifyErr != nil { + return nil, m.verifyErr + } + return m.verifyResult, nil +} + +func (m *mockAccountService) Create(ctx context.Context, req *domain.CreateAccountRequest) (*domain.Account, error) { + if m.createErr != nil { + return nil, m.createErr + } + return m.account, nil +} + +func (m *mockAccountService) Activate(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) { + if m.activateErr != nil { + return nil, m.activateErr + } + return m.account, nil +} + +func (m *mockAccountService) Suspend(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) { + if m.suspendErr != nil { + return nil, m.suspendErr + } + return m.account, nil +} + +func (m *mockAccountService) Delete(ctx context.Context, supplierID, accountID int64) error { + return m.deleteErr +} + +func (m *mockAccountService) GetByID(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) { + return m.account, nil +} + +// mockPackageService Mock套餐服务 +type mockPackageService struct { + pkg *domain.Package + createDraftErr error + publishErr error + pauseErr error + unlistErr error + cloneErr error + batchResp *domain.BatchUpdatePriceResponse + batchErr error +} + +func (m *mockPackageService) CreateDraft(ctx context.Context, supplierID int64, req *domain.CreatePackageDraftRequest) (*domain.Package, error) { + if m.createDraftErr != nil { + return nil, m.createDraftErr + } + return m.pkg, nil +} + +func (m *mockPackageService) Publish(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + if m.publishErr != nil { + return nil, m.publishErr + } + return m.pkg, nil +} + +func (m *mockPackageService) Pause(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + if m.pauseErr != nil { + return nil, m.pauseErr + } + return m.pkg, nil +} + +func (m *mockPackageService) Unlist(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + if m.unlistErr != nil { + return nil, m.unlistErr + } + return m.pkg, nil +} + +func (m *mockPackageService) Clone(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + if m.cloneErr != nil { + return nil, m.cloneErr + } + return m.pkg, nil +} + +func (m *mockPackageService) BatchUpdatePrice(ctx context.Context, supplierID int64, req *domain.BatchUpdatePriceRequest) (*domain.BatchUpdatePriceResponse, error) { + if m.batchErr != nil { + return nil, m.batchErr + } + return m.batchResp, nil +} + +func (m *mockPackageService) GetByID(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) { + return m.pkg, nil +} + +// mockSettlementService Mock结算服务 +type mockSettlementService struct { + settlement *domain.Settlement + withdrawErr error + cancelErr error + getErr error +} + +func (m *mockSettlementService) Withdraw(ctx context.Context, supplierID int64, req *domain.WithdrawRequest) (*domain.Settlement, error) { + if m.withdrawErr != nil { + return nil, m.withdrawErr + } + return m.settlement, nil +} + +func (m *mockSettlementService) Cancel(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) { + if m.cancelErr != nil { + return nil, m.cancelErr + } + return m.settlement, nil +} + +func (m *mockSettlementService) GetByID(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) { + if m.getErr != nil { + return nil, m.getErr + } + return m.settlement, nil +} + +func (m *mockSettlementService) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) { + if m.settlement != nil { + return []*domain.Settlement{m.settlement}, nil + } + return nil, nil +} + +// mockEarningService Mock收益服务 +type mockEarningService struct { + records []*domain.EarningRecord + total int + billingSummary *domain.BillingSummary + listErr error + billingErr error +} + +func (m *mockEarningService) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) { + if m.listErr != nil { + return nil, 0, m.listErr + } + return m.records, m.total, nil +} + +func (m *mockEarningService) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) { + if m.billingErr != nil { + return nil, m.billingErr + } + return m.billingSummary, nil +} + +// mockAuditStore Mock审计存储 +type mockAuditStore struct { + events []audit.Event + event audit.Event + err error +} + +func (m *mockAuditStore) Emit(ctx context.Context, event audit.Event) error { + return m.err +} + +func (m *mockAuditStore) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) { + if m.err != nil { + return nil, m.err + } + return m.events, nil +} + +func (m *mockAuditStore) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) { + if m.err != nil { + return nil, 0, m.err + } + return m.events, int64(len(m.events)), nil +} + +func (m *mockAuditStore) GetByID(ctx context.Context, eventID string) (audit.Event, error) { + if m.err != nil { + return audit.Event{}, m.err + } + return m.event, nil +} + +// ==================== Test Helpers ==================== + +func newTestAPI() (*SupplyAPI, *mockAccountService, *mockPackageService, *mockSettlementService, *mockEarningService, *mockAuditStore) { + accountSvc := &mockAccountService{ + account: &domain.Account{ + ID: 1, + SupplierID: 100, + Provider: domain.ProviderOpenAI, + AccountType: domain.AccountTypeAPIKey, + Status: domain.AccountStatusActive, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + verifyResult: &domain.VerifyResult{ + VerifyStatus: "pass", + AvailableQuota: 1000, + RiskScore: 0, + }, + } + + packageSvc := &mockPackageService{ + pkg: &domain.Package{ + ID: 1, + SupplierID: 100, + Model: "gpt-4", + Status: domain.PackageStatusActive, + TotalQuota: 10000, + AvailableQuota: 8000, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + } + + settlementSvc := &mockSettlementService{ + settlement: &domain.Settlement{ + ID: 1, + SupplierID: 100, + Status: domain.SettlementStatusPending, + TotalAmount: 1000, + NetAmount: 950, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + } + + earningSvc := &mockEarningService{ + records: []*domain.EarningRecord{ + { + ID: 1, + Amount: 100, + Status: "available", + }, + }, + total: 1, + billingSummary: &domain.BillingSummary{}, + } + + auditSvc := &mockAuditStore{ + events: []audit.Event{ + { + EventID: "evt_123", + TenantID: 100, + ObjectType: "supply_account", + ObjectID: 1, + Action: "create", + CreatedAt: time.Now(), + }, + }, + event: audit.Event{ + EventID: "evt_123", + TenantID: 100, + ObjectType: "supply_account", + ObjectID: 1, + Action: "create", + CreatedAt: time.Now(), + }, + } + + api := NewSupplyAPI( + accountSvc, + packageSvc, + settlementSvc, + earningSvc, + nil, // idempotencyMw + auditSvc, + nil, // fkValidator + 100, // supplierID + "https://statements.example.com", + time.Now, + ) + + return api, accountSvc, packageSvc, settlementSvc, earningSvc, auditSvc +} + +// ==================== Account Handler Tests ==================== + +func TestSupplyAPI_VerifyAccount_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test123"}` + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Request-Id", "test-req-001") + w := httptest.NewRecorder() + + api.handleVerifyAccount(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["request_id"] != "test-req-001" { + t.Errorf("expected request_id test-req-001, got %v", resp["request_id"]) + } +} + +func TestSupplyAPI_VerifyAccount_MethodNotAllowed(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/accounts/verify", nil) + w := httptest.NewRecorder() + + api.handleVerifyAccount(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_VerifyAccount_InvalidJSON(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + body := `{invalid json}` + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleVerifyAccount(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } +} + +func TestSupplyAPI_VerifyAccount_VerifyFailed(t *testing.T) { + api, accountSvc, _, _, _, _ := newTestAPI() + accountSvc.verifyErr = errors.New("SUP_ACC_4001: verification failed") + + body := `{"provider":"openai","account_type":"resource","credential_input":"invalid"}` + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleVerifyAccount(w, req) + + if w.Code != http.StatusUnprocessableEntity { + t.Errorf("expected status 422, got %d", w.Code) + } +} + +func TestSupplyAPI_VerifyAccount_UsesTenantIDFromContext(t *testing.T) { + api, accountSvc, _, _, _, _ := newTestAPI() + + body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test123"}` + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body)) + req = req.WithContext(middleware.WithTenantID(req.Context(), 200)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleVerifyAccount(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + if accountSvc.lastVerifySupplierID != 200 { + t.Fatalf("expected tenant supplier ID 200, got %d", accountSvc.lastVerifySupplierID) + } +} + +func TestSupplyAPI_VerifyAccount_RejectsMissingTenantContextWithoutDefaultSupplier(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + api.supplierID = 0 + + body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test123"}` + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleVerifyAccount(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected status 401, got %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSupplyAPI_CreateAccount_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test","account_alias":"test","risk_ack":true}` + req := httptest.NewRequest("POST", "/api/v1/supply/accounts", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleCreateAccount(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("expected status 201, got %d", w.Code) + } +} + +func TestSupplyAPI_CreateAccount_MethodNotAllowed(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/accounts", nil) + w := httptest.NewRecorder() + + api.handleCreateAccount(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_ActivateAccount_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/activate", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestSupplyAPI_ActivateAccount_NotFound(t *testing.T) { + api, accountSvc, _, _, _, _ := newTestAPI() + accountSvc.activateErr = errors.New("account not found") + + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/activate", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +func TestSupplyAPI_SuspendAccount_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/suspend", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestSupplyAPI_SuspendAccount_Conflict(t *testing.T) { + api, accountSvc, _, _, _, _ := newTestAPI() + accountSvc.suspendErr = errors.New("SUP_ACC_4091: account state conflict") + + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/suspend", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("expected status 409, got %d", w.Code) + } +} + +func TestSupplyAPI_SuspendAccount_WrongMethod(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/accounts/1/suspend", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_DeleteAccount_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("DELETE", "/api/v1/supply/accounts/1/delete", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusNoContent { + t.Errorf("expected status 204, got %d", w.Code) + } +} + +func TestSupplyAPI_DeleteAccount_Conflict(t *testing.T) { + api, accountSvc, _, _, _, _ := newTestAPI() + accountSvc.deleteErr = errors.New("SUP_ACC_4092: cannot delete account with active packages") + + req := httptest.NewRequest("DELETE", "/api/v1/supply/accounts/1/delete", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("expected status 409, got %d", w.Code) + } +} + +func TestSupplyAPI_DeleteAccount_WrongMethod(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/delete", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_AccountAuditLogs_Success(t *testing.T) { + api, _, _, _, _, auditSvc := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/accounts/1/audit-logs", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + data, ok := resp["data"].([]any) + if !ok { + t.Fatal("expected data array in response") + } + if len(data) != 1 { + t.Errorf("expected 1 event, got %d", len(data)) + } + + auditSvc.err = errors.New("query failed") + req = httptest.NewRequest("GET", "/api/v1/supply/accounts/1/audit-logs", nil) + w = httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", w.Code) + } +} + +func TestSupplyAPI_AccountActions_InvalidID(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/invalid/activate", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } +} + +func TestSupplyAPI_AccountActions_UnknownRoute(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/unknown", nil) + w := httptest.NewRecorder() + + api.handleAccountActions(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +// ==================== Package Handler Tests ==================== + +func TestSupplyAPI_CreatePackageDraft_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + body := `{"supply_account_id":1,"model":"gpt-4","total_quota":10000,"price_per_1m_input":0.1,"price_per_1m_output":0.2,"valid_days":30,"max_concurrent":10,"rate_limit_rpm":1000}` + req := httptest.NewRequest("POST", "/api/v1/supply/packages/draft", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleCreatePackageDraft(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("expected status 201, got %d", w.Code) + } +} + +func TestSupplyAPI_CreatePackageDraft_MethodNotAllowed(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/packages/draft", nil) + w := httptest.NewRecorder() + + api.handleCreatePackageDraft(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_PublishPackage_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/publish", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestSupplyAPI_PublishPackage_NotFound(t *testing.T) { + api, _, packageSvc, _, _, _ := newTestAPI() + packageSvc.publishErr = errors.New("package not found") + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/publish", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +func TestSupplyAPI_PausePackage_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/pause", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestSupplyAPI_PausePackage_Conflict(t *testing.T) { + api, _, packageSvc, _, _, _ := newTestAPI() + packageSvc.pauseErr = errors.New("SUP_PKG_4092: cannot pause active package") + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/pause", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("expected status 409, got %d", w.Code) + } +} + +func TestSupplyAPI_PausePackage_WrongMethod(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/packages/1/pause", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_UnlistPackage_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/unlist", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestSupplyAPI_UnlistPackage_Conflict(t *testing.T) { + api, _, packageSvc, _, _, _ := newTestAPI() + packageSvc.unlistErr = errors.New("SUP_PKG_4093: cannot unlist package") + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/unlist", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("expected status 409, got %d", w.Code) + } +} + +func TestSupplyAPI_UnlistPackage_WrongMethod(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/packages/1/unlist", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_ClonePackage_WrongMethod(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/packages/1/clone", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_ClonePackage_NotFound(t *testing.T) { + api, _, packageSvc, _, _, _ := newTestAPI() + packageSvc.cloneErr = errors.New("package not found") + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/clone", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +func TestSupplyAPI_PublishPackage_WrongMethod(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/packages/1/publish", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_ClonePackage_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/clone", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("expected status 201, got %d", w.Code) + } +} + +func TestSupplyAPI_BatchUpdatePrice_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + api.packageService.(*mockPackageService).batchResp = &domain.BatchUpdatePriceResponse{ + Total: 2, + SuccessCount: 2, + FailedCount: 0, + } + + body := `{"items":[{"package_id":1,"price_per_1m_input":0.15,"price_per_1m_output":0.25},{"package_id":2,"price_per_1m_input":0.12,"price_per_1m_output":0.22}]}` + req := httptest.NewRequest("POST", "/api/v1/supply/packages/batch-price", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleBatchUpdatePrice(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestSupplyAPI_BatchUpdatePrice_MethodNotAllowed(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/packages/batch-price", nil) + w := httptest.NewRecorder() + + api.handleBatchUpdatePrice(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_BatchUpdatePrice_InvalidJSON(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + body := `{invalid}` + req := httptest.NewRequest("POST", "/api/v1/supply/packages/batch-price", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleBatchUpdatePrice(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } +} + +func TestSupplyAPI_BatchUpdatePrice_BatchFailed(t *testing.T) { + api, _, packageSvc, _, _, _ := newTestAPI() + packageSvc.batchErr = errors.New("batch update failed") + + body := `{"items":[{"package_id":1,"price_per_1m_input":0.15,"price_per_1m_output":0.25}]}` + req := httptest.NewRequest("POST", "/api/v1/supply/packages/batch-price", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleBatchUpdatePrice(w, req) + + if w.Code != http.StatusUnprocessableEntity { + t.Errorf("expected status 422, got %d", w.Code) + } +} + +// ==================== Billing Handler Tests ==================== + +func TestSupplyAPI_GetBilling_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/billing?start_date=2024-01-01&end_date=2024-01-31", nil) + w := httptest.NewRecorder() + + api.handleGetBilling(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestSupplyAPI_GetBilling_MethodNotAllowed(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/billing", nil) + w := httptest.NewRecorder() + + api.handleGetBilling(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_GetBilling_QueryFailed(t *testing.T) { + api, _, _, _, earningSvc, _ := newTestAPI() + earningSvc.billingErr = errors.New("query failed") + + req := httptest.NewRequest("GET", "/api/v1/supply/billing", nil) + w := httptest.NewRecorder() + + api.handleGetBilling(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", w.Code) + } +} + +// ==================== Settlement Handler Tests ==================== + +func TestSupplyAPI_Withdraw_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + body := `{"withdraw_amount":1000,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}` + req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleWithdraw(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("expected status 201, got %d", w.Code) + } +} + +func TestSupplyAPI_Withdraw_MethodNotAllowed(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/settlements/withdraw", nil) + w := httptest.NewRecorder() + + api.handleWithdraw(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_Withdraw_InvalidJSON(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + body := `{invalid}` + req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleWithdraw(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } +} + +func TestSupplyAPI_Withdraw_WithdrawFailed(t *testing.T) { + api, _, _, settlementSvc, _, _ := newTestAPI() + settlementSvc.withdrawErr = errors.New("SUP_SET_4001: insufficient balance") + + body := `{"withdraw_amount":1000000,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}` + req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleWithdraw(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("expected status 409, got %d", w.Code) + } +} + +func TestSupplyAPI_CancelSettlement_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/settlements/1/cancel", nil) + w := httptest.NewRecorder() + + api.handleSettlementActions(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestSupplyAPI_CancelSettlement_NotFound(t *testing.T) { + api, _, _, settlementSvc, _, _ := newTestAPI() + settlementSvc.cancelErr = errors.New("settlement not found") + + req := httptest.NewRequest("POST", "/api/v1/supply/settlements/1/cancel", nil) + w := httptest.NewRecorder() + + api.handleSettlementActions(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +func TestSupplyAPI_GetStatement_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/settlements/1/statement", nil) + w := httptest.NewRecorder() + + api.handleSettlementActions(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatal("expected data in response") + } + + if data["file_name"] == nil { + t.Error("expected file_name in data") + } + if data["download_url"] == nil { + t.Error("expected download_url in data") + } + if data["expires_at"] == nil { + t.Error("expected expires_at in data") + } +} + +func TestSupplyAPI_GetStatement_NotFound(t *testing.T) { + api, _, _, settlementSvc, _, _ := newTestAPI() + settlementSvc.getErr = errors.New("settlement not found") + + req := httptest.NewRequest("GET", "/api/v1/supply/settlements/1/statement", nil) + w := httptest.NewRecorder() + + api.handleSettlementActions(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +func TestSupplyAPI_SettlementActions_InvalidID(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/settlements/invalid/cancel", nil) + w := httptest.NewRecorder() + + api.handleSettlementActions(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } +} + +func TestSupplyAPI_SettlementActions_UnknownAction(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/settlements/1/unknown", nil) + w := httptest.NewRecorder() + + api.handleSettlementActions(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +// ==================== Earning Handler Tests ==================== + +func TestSupplyAPI_GetEarningRecords_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/supply/earnings/records?start_date=2024-01-01&end_date=2024-01-31&page=1&page_size=20", nil) + w := httptest.NewRecorder() + + api.handleGetEarningRecords(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + data, ok := resp["data"].([]any) + if !ok { + t.Fatal("expected data array in response") + } + if len(data) != 1 { + t.Errorf("expected 1 record, got %d", len(data)) + } + + pagination, ok := resp["pagination"].(map[string]any) + if !ok { + t.Fatal("expected pagination in response") + } + if pagination["total"] != float64(1) { + t.Errorf("expected total 1, got %v", pagination["total"]) + } +} + +func TestSupplyAPI_GetEarningRecords_MethodNotAllowed(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/supply/earnings/records", nil) + w := httptest.NewRecorder() + + api.handleGetEarningRecords(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestSupplyAPI_GetEarningRecords_QueryFailed(t *testing.T) { + api, _, _, _, earningSvc, _ := newTestAPI() + earningSvc.listErr = errors.New("query failed") + + req := httptest.NewRequest("GET", "/api/v1/supply/earnings/records", nil) + w := httptest.NewRecorder() + + api.handleGetEarningRecords(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", w.Code) + } +} + +// ==================== Audit Event Handler Tests ==================== + +func TestSupplyAPI_GetAuditEvent_Success(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/audit/events/evt_123", nil) + w := httptest.NewRecorder() + + api.handleAuditEvent(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatal("expected data in response") + } + if data["event_id"] != "evt_123" { + t.Errorf("expected event_id evt_123, got %v", data["event_id"]) + } +} + +func TestSupplyAPI_GetAuditEvent_NotFound(t *testing.T) { + api, _, _, _, _, auditSvc := newTestAPI() + auditSvc.err = errors.New("not found") + + req := httptest.NewRequest("GET", "/api/v1/audit/events/evt_999", nil) + w := httptest.NewRecorder() + + api.handleAuditEvent(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +func TestSupplyAPI_GetAuditEvent_MissingID(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("GET", "/api/v1/audit/events/", nil) + w := httptest.NewRecorder() + + api.handleAuditEvent(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } +} + +func TestSupplyAPI_GetAuditEvent_MethodNotAllowed(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + + req := httptest.NewRequest("POST", "/api/v1/audit/events/evt_123", nil) + w := httptest.NewRecorder() + + api.handleAuditEvent(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +// ==================== Helper Function Tests ==================== + +func TestGetRequestID(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Request-Id", "req-123") + + id := getRequestID(req) + if id != "req-123" { + t.Errorf("expected req-123, got %s", id) + } + + req = httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Request-ID", "req-456") + + id = getRequestID(req) + if id != "req-456" { + t.Errorf("expected req-456, got %s", id) + } + + req = httptest.NewRequest("GET", "/", nil) + id = getRequestID(req) + if id != "" { + t.Errorf("expected empty string, got %s", id) + } +} + +func TestGetQueryInt(t *testing.T) { + req := httptest.NewRequest("GET", "/?page=5&page_size=100", nil) + + if getQueryInt(req, "page", 1) != 5 { + t.Error("expected page 5") + } + if getQueryInt(req, "page_size", 20) != 100 { + t.Error("expected page_size 100") + } + if getQueryInt(req, "missing", 10) != 10 { + t.Error("expected default 10 for missing param") + } + if getQueryInt(req, "invalid", 1) != 1 { + t.Error("expected default 1 for invalid value") + } +} + +func TestWriteJSON(t *testing.T) { + w := httptest.NewRecorder() + + writeJSON(w, http.StatusOK, map[string]any{"key": "value"}) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + if w.Header().Get("Content-Type") != "application/json" { + t.Error("expected Content-Type application/json") + } +} + +func TestWriteError(t *testing.T) { + w := httptest.NewRecorder() + + writeError(w, http.StatusBadRequest, "TEST_ERROR", "test message") + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + var resp map[string]any + json.Unmarshal(w.Body.Bytes(), &resp) + + errObj, ok := resp["error"].(map[string]any) + if !ok { + t.Fatal("expected error object in response") + } + if errObj["code"] != "TEST_ERROR" { + t.Errorf("expected code TEST_ERROR, got %v", errObj["code"]) + } + if errObj["message"] != "test message" { + t.Errorf("expected message 'test message', got %v", errObj["message"]) + } +} + +// ==================== Integration Tests ==================== + +func TestSupplyAPI_Register(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + mux := http.NewServeMux() + + api.Register(mux) + + // 验证路由已注册(不会panic) + _ = mux +} + +func TestSupplyAPI_EndToEnd_Withdraw(t *testing.T) { + api, _, _, settlementSvc, _, _ := newTestAPI() + settlementSvc.settlement = &domain.Settlement{ + ID: 1, + SupplierID: 100, + SettlementNo: "SET_20240101_001", + Status: domain.SettlementStatusPending, + TotalAmount: 1000, + NetAmount: 950, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + body := `{"withdraw_amount":500,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}` + req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Request-Id", "test-req-001") + w := httptest.NewRecorder() + + api.handleWithdraw(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("expected status 201, got %d. Body: %s", w.Code, w.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["request_id"] != "test-req-001" { + t.Errorf("expected request_id test-req-001, got %v", resp["request_id"]) + } + + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatal("expected data in response") + } + if data["settlement_id"] != float64(1) { + t.Errorf("expected settlement_id 1, got %v", data["settlement_id"]) + } + if data["status"] != "pending" { + t.Errorf("expected status pending, got %v", data["status"]) + } +} + +func TestSupplyAPI_WithdrawDisabled_ReturnsServiceUnavailable(t *testing.T) { + api, _, _, _, _, _ := newTestAPI() + api.withdrawEnabled = false + + body := `{"withdraw_amount":500,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}` + req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + api.handleWithdraw(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("expected status 503, got %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSupplyAPI_EndToEnd_BillingSummary(t *testing.T) { + api, _, _, _, earningSvc, _ := newTestAPI() + earningSvc.billingSummary = &domain.BillingSummary{ + Period: domain.BillingPeriod{ + Start: "2024-01-01", + End: "2024-01-31", + }, + Summary: domain.BillingTotal{ + TotalRevenue: 10000, + TotalOrders: 100, + TotalUsage: 1000000, + TotalRequests: 5000000, + AvgSuccessRate: 99.5, + }, + } + + req := httptest.NewRequest("GET", "/api/v1/supply/billing?start_date=2024-01-01&end_date=2024-01-31", nil) + w := httptest.NewRecorder() + + api.handleGetBilling(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatal("expected data in response") + } + + summary, ok := data["summary"].(map[string]any) + if !ok { + t.Fatal("expected summary in data") + } + if summary["total_revenue"] != float64(10000) { + t.Errorf("expected total_revenue 10000, got %v", summary["total_revenue"]) + } +} diff --git a/supply-api/internal/middleware/auth.go b/supply-api/internal/middleware/auth.go index e157c2e3..b02a0fba 100644 --- a/supply-api/internal/middleware/auth.go +++ b/supply-api/internal/middleware/auth.go @@ -2,9 +2,12 @@ package middleware import ( "context" + "crypto/rsa" "crypto/sha256" + "crypto/x509" "encoding/hex" "encoding/json" + "encoding/pem" "errors" "fmt" "log" @@ -30,21 +33,23 @@ type TokenClaims struct { // AuthConfig 鉴权中间件配置 type AuthConfig struct { - SecretKey string - Issuer string - CacheTTL time.Duration // token状态缓存TTL - Enabled bool // 是否启用鉴权 - TrustedProxies []string // 可信代理IP列表CIDR,如 "10.0.0.0/8" + SecretKey string + PublicKey string + Algorithm string + Issuer string + CacheTTL time.Duration // token状态缓存TTL + Enabled bool // 是否启用鉴权 + TrustedProxies []string // 可信代理IP列表CIDR,如 "10.0.0.0/8" } // AuthMiddleware 鉴权中间件 type AuthMiddleware struct { - config AuthConfig - tokenCache *TokenCache - tokenBackend TokenStatusBackend - auditEmitter AuditEmitter - bruteForce *BruteForceProtection // 暴力破解保护 - trustedProxies []string // 可信代理列表 + config AuthConfig + tokenCache *TokenCache + tokenBackend TokenStatusBackend + auditEmitter AuditEmitter + bruteForce *BruteForceProtection // 暴力破解保护 + trustedProxies []string // 可信代理列表 } // TokenStatusBackend Token状态后端查询接口 @@ -200,6 +205,11 @@ func (b *BruteForceProtection) Len() int { // 对应M-016指标 func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if shouldBypassAuth(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + // 检查query string中的可疑参数 queryParams := r.URL.Query() @@ -257,6 +267,11 @@ func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handle // BearerExtractMiddleware 提取Bearer Token func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if shouldBypassAuth(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + authHeader := r.Header.Get("Authorization") if authHeader == "" { @@ -299,6 +314,11 @@ func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler // MED-12: 添加暴力破解保护 func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if shouldBypassAuth(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + // 如果鉴权被禁用(仅用于开发环境),直接跳过验证 if !m.config.Enabled { // 在开发模式下,虽然跳过JWT验证,但仍记录警告日志 @@ -356,7 +376,25 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler { // 检查token状态(是否被吊销) status, err := m.checkTokenStatus(r.Context(), claims.ID) - if err == nil && status != "active" { + if err != nil { + if m.auditEmitter != nil { + m.auditEmitter.Emit(r.Context(), AuditEvent{ + EventName: "token.authn.fail", + RequestID: getRequestID(r), + TokenID: claims.ID, + SubjectID: claims.SubjectID, + Route: sanitizeRoute(r.URL.Path), + ResultCode: "AUTH_TOKEN_STATUS_UNAVAILABLE", + ClientIP: getClientIP(r), + CreatedAt: time.Now(), + }) + } + + writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_STATUS_UNAVAILABLE", + "token status backend is unavailable") + return + } + if status != "active" { if m.auditEmitter != nil { m.auditEmitter.Emit(r.Context(), AuditEvent{ EventName: "token.authn.fail", @@ -435,10 +473,10 @@ func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(htt // viewer: level 10, operator: level 30, org_admin: level 50 routeRoles := map[string]string{ "/api/v1/supply/accounts": "org_admin", - "/api/v1/supply/packages": "org_admin", - "/api/v1/supply/settlements": "org_admin", - "/api/v1/supply/billing": "viewer", - "/api/v1/supplier/billing": "viewer", + "/api/v1/supply/packages": "org_admin", + "/api/v1/supply/settlements": "org_admin", + "/api/v1/supply/billing": "viewer", + "/api/v1/supplier/billing": "viewer", } for path, requiredRole := range routeRoles { @@ -458,12 +496,16 @@ func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(htt // verifyToken 校验JWT token func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) { + expectedAlgorithm := strings.ToUpper(strings.TrimSpace(m.config.Algorithm)) + if expectedAlgorithm == "" { + expectedAlgorithm = jwt.SigningMethodHS256.Alg() + } + token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) { - // 严格验证算法:只接受HS256 - if token.Method.Alg() != jwt.SigningMethodHS256.Alg() { + if token.Method.Alg() != expectedAlgorithm { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } - return []byte(m.config.SecretKey), nil + return m.signingKey(expectedAlgorithm) }) if err != nil { @@ -492,6 +534,53 @@ func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) { return nil, errors.New("invalid token") } +func (m *AuthMiddleware) signingKey(algorithm string) (interface{}, error) { + switch algorithm { + case jwt.SigningMethodHS256.Alg(), jwt.SigningMethodHS384.Alg(), jwt.SigningMethodHS512.Alg(): + if strings.TrimSpace(m.config.SecretKey) == "" { + return nil, errors.New("missing token secret key") + } + return []byte(m.config.SecretKey), nil + case jwt.SigningMethodRS256.Alg(), jwt.SigningMethodRS384.Alg(), jwt.SigningMethodRS512.Alg(): + return parseRSAPublicKey(m.config.PublicKey) + default: + return nil, fmt.Errorf("unsupported signing method: %s", algorithm) + } +} + +func parseRSAPublicKey(publicKeyPEM string) (*rsa.PublicKey, error) { + block, _ := pem.Decode([]byte(strings.TrimSpace(publicKeyPEM))) + if block == nil { + return nil, errors.New("invalid RSA public key: PEM decode failed") + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err == nil { + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + return nil, errors.New("invalid RSA public key type") + } + return rsaPub, nil + } + + cert, certErr := x509.ParseCertificate(block.Bytes) + if certErr == nil { + rsaPub, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return nil, errors.New("invalid RSA certificate public key type") + } + return rsaPub, nil + } + + return nil, fmt.Errorf("invalid RSA public key: %w", err) +} + +func shouldBypassAuth(path string) bool { + return path == "/actuator/health" || + path == "/actuator/health/live" || + path == "/actuator/health/ready" +} + // checkTokenStatus 检查token状态(从缓存或数据库) func (m *AuthMiddleware) checkTokenStatus(ctx context.Context, tokenID string) (string, error) { if m.tokenCache != nil { diff --git a/supply-api/internal/middleware/auth_test.go b/supply-api/internal/middleware/auth_test.go index 7bb95feb..b7d4aefb 100644 --- a/supply-api/internal/middleware/auth_test.go +++ b/supply-api/internal/middleware/auth_test.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "net/http" "net/http/httptest" "strings" @@ -13,6 +14,15 @@ import ( "lijiaoqiao/supply-api/internal/iam/model" ) +type stubTokenStatusBackend struct { + status string + err error +} + +func (b *stubTokenStatusBackend) CheckTokenStatus(ctx context.Context, tokenID string) (string, error) { + return b.status, b.err +} + func TestTokenVerify(t *testing.T) { secretKey := "test-secret-key-12345678901234567890" issuer := "test-issuer" @@ -431,6 +441,38 @@ func TestMED02_TokenCacheMiss_ShouldNotAssumeActive(t *testing.T) { } } +func TestTokenVerifyMiddleware_BackendErrorShouldReject(t *testing.T) { + secretKey := "test-secret-key-12345678901234567890" + issuer := "test-issuer" + + authMiddleware := NewAuthMiddleware(AuthConfig{ + SecretKey: secretKey, + Issuer: issuer, + Enabled: true, + }, NewTokenCache(), &stubTokenStatusBackend{err: errors.New("database unavailable")}, nil) + + nextCalled := false + handler := authMiddleware.TokenVerifyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + })) + + req := httptest.NewRequest("GET", "/api/v1/supply/accounts", nil) + req = req.WithContext(context.WithValue(req.Context(), bearerTokenKey, createTestToken(secretKey, issuer, "subject:1", "org_admin", time.Now().Add(time.Hour)))) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if nextCalled { + t.Fatal("expected request to be rejected when token backend is unavailable") + } + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected status 401, got %d", w.Code) + } + if !strings.Contains(w.Body.String(), "AUTH_TOKEN_STATUS_UNAVAILABLE") { + t.Fatalf("expected response to contain AUTH_TOKEN_STATUS_UNAVAILABLE, got %s", w.Body.String()) + } +} + // Helper functions func createTestToken(secretKey, issuer, subject, role string, expiresAt time.Time) string { diff --git a/supply-api/internal/middleware/db_token_backend.go b/supply-api/internal/middleware/db_token_backend.go index 91f91608..692304c5 100644 --- a/supply-api/internal/middleware/db_token_backend.go +++ b/supply-api/internal/middleware/db_token_backend.go @@ -34,9 +34,9 @@ type TokenCacheBackend interface { // DBTokenStatusBackend DB-backed Token状态后端(P0-03修复) // 同时实现 TokenStatusBackend 和 TokenRevocationBackend 接口 type DBTokenStatusBackend struct { - repo TokenRepository - redisCache TokenCacheBackend - cacheTTL time.Duration + repo TokenRepository + redisCache TokenCacheBackend + cacheTTL time.Duration } // NewDBTokenStatusBackend 创建DB-backed Token状态后端 @@ -119,6 +119,17 @@ func (b *DBTokenStatusBackend) GetTokenStatus(ctx context.Context, tokenID strin // RevokeBySubjectID 根据SubjectID吊销所有Token func (b *DBTokenStatusBackend) RevokeBySubjectID(ctx context.Context, subjectID int64, reason string) error { + var tokenIDs []string + if b.redisCache != nil { + records, err := b.repo.ListActiveBySubjectID(ctx, subjectID) + if err == nil { + tokenIDs = make([]string, 0, len(records)) + for _, record := range records { + tokenIDs = append(tokenIDs, record.TokenID) + } + } + } + // 1. 批量更新数据库 count, err := b.repo.RevokeBySubjectID(ctx, subjectID, reason) if err != nil { @@ -132,13 +143,8 @@ func (b *DBTokenStatusBackend) RevokeBySubjectID(ctx context.Context, subjectID // 2. 失效所有相关缓存(这里需要查询后逐个失效) // 注意:生产环境建议使用Redis的pattern删除或发布事件通知 if b.redisCache != nil { - // 查询所有活跃token并失效 - records, err := b.repo.ListActiveBySubjectID(ctx, subjectID) - if err != nil { - return nil // 不影响主流程 - } - for _, record := range records { - _ = b.redisCache.InvalidateToken(ctx, record.TokenID) + for _, tokenID := range tokenIDs { + _ = b.redisCache.InvalidateToken(ctx, tokenID) } } diff --git a/supply-api/internal/middleware/db_token_backend_test.go b/supply-api/internal/middleware/db_token_backend_test.go index 6c76b7e2..066bcce5 100644 --- a/supply-api/internal/middleware/db_token_backend_test.go +++ b/supply-api/internal/middleware/db_token_backend_test.go @@ -14,11 +14,11 @@ import ( // MockTokenStatusRepository mock Token状态仓储 type MockTokenStatusRepository struct { - mu sync.RWMutex - tokenStatuses map[string]string - tokenReasons map[string]string - verificationCounts map[string]int - subjectTokens map[int64][]string + mu sync.RWMutex + tokenStatuses map[string]string + tokenReasons map[string]string + verificationCounts map[string]int + subjectTokens map[int64][]string } func NewMockTokenStatusRepository() *MockTokenStatusRepository { @@ -84,9 +84,9 @@ func (m *MockTokenStatusRepository) ListActiveBySubjectID(ctx context.Context, s // MockRedisCache mock Redis缓存 type MockRedisCache struct { - mu sync.RWMutex - tokenCache map[string]*cache.TokenStatus - subscribers []func(event *cache.TokenRevokedCacheEvent) + mu sync.RWMutex + tokenCache map[string]*cache.TokenStatus + subscribers []func(event *cache.TokenRevokedCacheEvent) } func NewMockRedisCache() *MockRedisCache { @@ -136,7 +136,7 @@ func (m *MockRedisCache) PublishRevocation(tokenID string, reason string) { for _, handler := range handlers { handler(&cache.TokenRevokedCacheEvent{ TokenID: tokenID, - Reason: reason, + Reason: reason, }) } } @@ -165,9 +165,9 @@ type TokenStatusRepositoryInterface interface { // DBTokenStatusBackendForTest 用于测试的DBTokenStatusBackend type DBTokenStatusBackendForTest struct { - repo TokenStatusRepositoryInterface - redisCache *MockRedisCache - cacheTTL time.Duration + repo TokenStatusRepositoryInterface + redisCache *MockRedisCache + cacheTTL time.Duration } func NewDBTokenStatusBackendForTest(repo TokenStatusRepositoryInterface, redisCache *MockRedisCache, cacheTTL time.Duration) *DBTokenStatusBackendForTest { @@ -373,6 +373,31 @@ func TestDBTokenStatusBackend_RevokeBySubjectID(t *testing.T) { repo.mu.RUnlock() } +func TestDBTokenStatusBackend_RevokeBySubjectID_InvalidatesCachedTokens(t *testing.T) { + repo := NewMockTokenStatusRepository() + redisCache := NewMockRedisCache() + + repo.subjectTokens[123] = []string{"token1", "token2"} + redisCache.tokenCache["token1"] = &cache.TokenStatus{TokenID: "token1", Status: "active"} + redisCache.tokenCache["token2"] = &cache.TokenStatus{TokenID: "token2", Status: "active"} + + backend := NewDBTokenStatusBackend(repo, redisCache, 10*time.Second) + + err := backend.RevokeBySubjectID(context.Background(), 123, "bulk revocation") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + redisCache.mu.RLock() + defer redisCache.mu.RUnlock() + if _, ok := redisCache.tokenCache["token1"]; ok { + t.Fatal("expected token1 cache to be invalidated") + } + if _, ok := redisCache.tokenCache["token2"]; ok { + t.Fatal("expected token2 cache to be invalidated") + } +} + func TestDBTokenStatusBackend_RevokeBySubjectID_NoTokens(t *testing.T) { repo := NewMockTokenStatusRepository() redisCache := NewMockRedisCache() @@ -412,10 +437,10 @@ func TestDBTokenStatusBackend_InterfaceCompliance(t *testing.T) { // 测试各种状态转换 tests := []struct { - name string - tokenID string - initialStatus string - action func() error + name string + tokenID string + initialStatus string + action func() error expectedStatus string }{ { @@ -780,7 +805,7 @@ func TestDBTokenStatusBackend_StartRevocationSubscriber_NoRedisCache(t *testing. // MockTokenRevocationBackend mock TokenRevocationBackend type MockTokenRevocationBackend struct { - mu sync.RWMutex + mu sync.RWMutex revokedTokens map[string]string } diff --git a/supply-api/internal/middleware/idempotency.go b/supply-api/internal/middleware/idempotency.go index 686838c8..89d183a8 100644 --- a/supply-api/internal/middleware/idempotency.go +++ b/supply-api/internal/middleware/idempotency.go @@ -278,6 +278,11 @@ func getTenantID(ctx context.Context) int64 { return 0 } +// GetTenantID 公开函数,从context获取租户ID +func GetTenantID(ctx context.Context) int64 { + return getTenantID(ctx) +} + func getOperatorID(ctx context.Context) int64 { if v := ctx.Value(operatorIDKey); v != nil { if id, ok := v.(int64); ok { diff --git a/supply-api/internal/outbox/outbox.go b/supply-api/internal/outbox/outbox.go index 36db8f60..931bcd89 100644 --- a/supply-api/internal/outbox/outbox.go +++ b/supply-api/internal/outbox/outbox.go @@ -2,6 +2,7 @@ package outbox import ( "context" + "fmt" "log" "math" "time" @@ -13,7 +14,7 @@ import ( // OutboxProcessorRunner Outbox处理器运行器 type OutboxProcessorRunner struct { - repo *repository.OutboxRepository + repo outboxRepository msgBroker messaging.MessageBroker stats messaging.OutboxStats stopCh chan struct{} @@ -21,9 +22,16 @@ type OutboxProcessorRunner struct { interval time.Duration } +type outboxRepository interface { + FetchAndLock(ctx context.Context, limit int) ([]*repository.OutboxEvent, error) + MarkCompleted(ctx context.Context, eventID string) error + MarkFailed(ctx context.Context, eventID string, errorMsg string, nextRetryAt *time.Time) error + MoveToDeadLetter(ctx context.Context, event *repository.OutboxEvent, errorMsg string) error +} + // NewOutboxProcessorRunner 创建Outbox处理器运行器 func NewOutboxProcessorRunner( - repo *repository.OutboxRepository, + repo outboxRepository, msgBroker messaging.MessageBroker, stats messaging.OutboxStats, ) *OutboxProcessorRunner { @@ -66,6 +74,10 @@ func (r *OutboxProcessorRunner) Stop() { // process 处理一批Outbox事件 func (r *OutboxProcessorRunner) process(ctx context.Context) error { + if r.msgBroker == nil { + return fmt.Errorf("outbox message broker is unavailable") + } + // 获取待处理事件 events, err := r.repo.FetchAndLock(ctx, r.batchSize) if err != nil { @@ -85,7 +97,7 @@ func (r *OutboxProcessorRunner) process(ctx context.Context) error { EventType: event.EventType, EventID: event.EventID, Payload: event.Payload, - Status: string(event.Status), + Status: string(event.Status), RetryCount: event.RetryCount, MaxRetries: event.MaxRetries, ErrorMessage: event.ErrorMessage, diff --git a/supply-api/internal/outbox/outbox_test.go b/supply-api/internal/outbox/outbox_test.go new file mode 100644 index 00000000..c7cc106a --- /dev/null +++ b/supply-api/internal/outbox/outbox_test.go @@ -0,0 +1,58 @@ +package outbox + +import ( + "context" + "encoding/json" + "strings" + "testing" + "time" + + "lijiaoqiao/supply-api/internal/messaging" + "lijiaoqiao/supply-api/internal/repository" +) + +type stubRunnerRepo struct { + events []*repository.OutboxEvent +} + +func (r *stubRunnerRepo) FetchAndLock(ctx context.Context, limit int) ([]*repository.OutboxEvent, error) { + return r.events, nil +} + +func (r *stubRunnerRepo) MarkCompleted(ctx context.Context, eventID string) error { + return nil +} + +func (r *stubRunnerRepo) MarkFailed(ctx context.Context, eventID string, errorMsg string, nextRetryAt *time.Time) error { + return nil +} + +func (r *stubRunnerRepo) MoveToDeadLetter(ctx context.Context, event *repository.OutboxEvent, errorMsg string) error { + return nil +} + +func TestOutboxProcessorRunner_ProcessRejectsNilMessageBroker(t *testing.T) { + payload := json.RawMessage(`{"event":"created"}`) + runner := NewOutboxProcessorRunner(&stubRunnerRepo{ + events: []*repository.OutboxEvent{ + { + ID: 1, + AggregateType: "account", + AggregateID: "acc-1", + EventType: "created", + EventID: "evt-1", + Payload: payload, + Status: repository.OutboxStatusProcessing, + MaxRetries: 5, + }, + }, + }, nil, &messaging.NoOpOutboxStats{}) + + err := runner.process(context.Background()) + if err == nil { + t.Fatal("expected nil message broker to return error") + } + if !strings.Contains(err.Error(), "message broker") { + t.Fatalf("expected error to mention message broker, got %v", err) + } +} diff --git a/supply-api/internal/repository/outbox.go b/supply-api/internal/repository/outbox.go new file mode 100644 index 00000000..fb2629f7 --- /dev/null +++ b/supply-api/internal/repository/outbox.go @@ -0,0 +1,343 @@ +package repository + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" +) + +// OutboxStatus Outbox事件状态 +type OutboxStatus string + +const ( + OutboxStatusPending OutboxStatus = "pending" + OutboxStatusProcessing OutboxStatus = "processing" + OutboxStatusCompleted OutboxStatus = "completed" + OutboxStatusFailed OutboxStatus = "failed" + OutboxStatusDeadLetter OutboxStatus = "dead_letter" +) + +// OutboxEvent Outbox事件 +type OutboxEvent struct { + ID int64 `json:"id"` + AggregateType string `json:"aggregate_type"` + AggregateID string `json:"aggregate_id"` + EventType string `json:"event_type"` + EventID string `json:"event_id"` + Payload json.RawMessage `json:"payload"` + Status OutboxStatus `json:"status"` + RetryCount int `json:"retry_count"` + MaxRetries int `json:"max_retries"` + ErrorMessage string `json:"error_message,omitempty"` + CreatedAt time.Time `json:"created_at"` + ProcessedAt *time.Time `json:"processed_at,omitempty"` + NextRetryAt *time.Time `json:"next_retry_at,omitempty"` + DeadLetterReason string `json:"dead_letter_reason,omitempty"` + Version int64 `json:"version"` +} + +// OutboxDeadLetter 死信记录 +type OutboxDeadLetter struct { + ID int64 `json:"id"` + OriginalEventID string `json:"original_event_id"` + OriginalAggregateType string `json:"original_aggregate_type"` + OriginalAggregateID string `json:"original_aggregate_id"` + EventType string `json:"event_type"` + Payload json.RawMessage `json:"payload"` + ErrorMessage string `json:"error_message,omitempty"` + RetryCount int `json:"retry_count"` + FirstFailedAt time.Time `json:"first_failed_at"` + DeadLetterAt time.Time `json:"dead_letter_at"` + Handled bool `json:"handled"` + HandledAt *time.Time `json:"handled_at,omitempty"` + HandlerNotes string `json:"handler_notes,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// OutboxRepository Outbox仓储 +type OutboxRepository struct { + db outboxDB +} + +type outboxDB interface { + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) + Begin(ctx context.Context) (pgx.Tx, error) +} + +// NewOutboxRepository 创建Outbox仓储 +func NewOutboxRepository(pool *pgxpool.Pool) *OutboxRepository { + return &OutboxRepository{db: pool} +} + +// Create 创建Outbox事件 +func (r *OutboxRepository) Create(ctx context.Context, event *OutboxEvent) error { + query := ` + INSERT INTO supply_outbox ( + aggregate_type, aggregate_id, event_type, event_id, + payload, status, retry_count, max_retries + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + RETURNING id, created_at + ` + + err := r.db.QueryRow(ctx, query, + event.AggregateType, event.AggregateID, event.EventType, event.EventID, + event.Payload, event.Status, event.RetryCount, event.MaxRetries, + ).Scan(&event.ID, &event.CreatedAt) + + if err != nil { + return fmt.Errorf("failed to create outbox event: %w", err) + } + return nil +} + +// FetchAndLock 获取并锁定待处理事件(使用FOR UPDATE SKIP LOCKED实现分布式锁) +func (r *OutboxRepository) FetchAndLock(ctx context.Context, limit int) ([]*OutboxEvent, error) { + query := ` + WITH claimed AS ( + SELECT id + FROM supply_outbox + WHERE status IN ('pending', 'failed') + AND (next_retry_at IS NULL OR next_retry_at <= CURRENT_TIMESTAMP) + ORDER BY created_at ASC + LIMIT $1 + FOR UPDATE SKIP LOCKED + ) + UPDATE supply_outbox AS o + SET status = 'processing', + version = o.version + 1 + FROM claimed + WHERE o.id = claimed.id + RETURNING o.id, o.aggregate_type, o.aggregate_id, o.event_type, o.event_id, + o.payload, o.status, o.retry_count, o.max_retries, o.error_message, + o.created_at, o.processed_at, o.next_retry_at, o.version + ` + + rows, err := r.db.Query(ctx, query, limit) + if err != nil { + return nil, fmt.Errorf("failed to fetch outbox events: %w", err) + } + defer rows.Close() + + var events []*OutboxEvent + for rows.Next() { + event := &OutboxEvent{} + err := rows.Scan( + &event.ID, &event.AggregateType, &event.AggregateID, &event.EventType, + &event.EventID, &event.Payload, &event.Status, &event.RetryCount, + &event.MaxRetries, &event.ErrorMessage, &event.CreatedAt, + &event.ProcessedAt, &event.NextRetryAt, &event.Version, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan outbox event: %w", err) + } + + events = append(events, event) + } + + return events, nil +} + +// MarkCompleted 标记事件为已完成 +func (r *OutboxRepository) MarkCompleted(ctx context.Context, eventID string) error { + query := ` + UPDATE supply_outbox SET + status = 'completed', + processed_at = CURRENT_TIMESTAMP, + version = version + 1 + WHERE event_id = $1 AND status = 'processing' + ` + + result, err := r.db.Exec(ctx, query, eventID) + if err != nil { + return fmt.Errorf("failed to mark outbox event completed: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("event not found or not in processing status") + } + + return nil +} + +// MarkFailed 标记事件为失败(并更新重试信息) +func (r *OutboxRepository) MarkFailed(ctx context.Context, eventID string, errorMsg string, nextRetryAt *time.Time) error { + query := ` + UPDATE supply_outbox SET + status = 'failed', + error_message = $2, + next_retry_at = $3, + version = version + 1 + WHERE event_id = $1 AND status = 'processing' + ` + + result, err := r.db.Exec(ctx, query, eventID, errorMsg, nextRetryAt) + if err != nil { + return fmt.Errorf("failed to mark outbox event failed: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("event not found or not in processing status") + } + + return nil +} + +// MoveToDeadLetter 将事件移入死信队列 +func (r *OutboxRepository) MoveToDeadLetter(ctx context.Context, event *OutboxEvent, errorMsg string) error { + tx, err := r.db.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback(ctx) + + // 获取事件详情 + selectQuery := ` + SELECT id, aggregate_type, aggregate_id, event_type, event_id, + payload, retry_count, created_at + FROM supply_outbox + WHERE event_id = $1 + FOR UPDATE + ` + + var originalEvent OutboxEvent + err = tx.QueryRow(ctx, selectQuery, event.EventID).Scan( + &originalEvent.ID, &originalEvent.AggregateType, &originalEvent.AggregateID, + &originalEvent.EventType, &originalEvent.EventID, &originalEvent.Payload, + &originalEvent.RetryCount, &originalEvent.CreatedAt, + ) + if err != nil { + return fmt.Errorf("failed to get original event: %w", err) + } + + // 插入死信记录 + insertQuery := ` + INSERT INTO supply_outbox_dead_letter ( + original_event_id, original_aggregate_type, original_aggregate_id, + event_type, payload, error_message, retry_count, first_failed_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ` + + _, err = tx.Exec(ctx, insertQuery, + originalEvent.EventID, originalEvent.AggregateType, originalEvent.AggregateID, + originalEvent.EventType, originalEvent.Payload, errorMsg, + originalEvent.RetryCount, originalEvent.CreatedAt, + ) + if err != nil { + return fmt.Errorf("failed to insert dead letter: %w", err) + } + + // 删除原始事件 + deleteQuery := `DELETE FROM supply_outbox WHERE event_id = $1` + _, err = tx.Exec(ctx, deleteQuery, event.EventID) + if err != nil { + return fmt.Errorf("failed to delete original event: %w", err) + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + return nil +} + +// GetDeadLetterByEventID 根据原始事件ID获取死信记录 +func (r *OutboxRepository) GetDeadLetterByEventID(ctx context.Context, originalEventID string) (*OutboxDeadLetter, error) { + query := ` + SELECT id, original_event_id, original_aggregate_type, original_aggregate_id, + event_type, payload, error_message, retry_count, first_failed_at, + dead_letter_at, handled, handled_at, handler_notes, created_at + FROM supply_outbox_dead_letter + WHERE original_event_id = $1 + ` + + dl := &OutboxDeadLetter{} + err := r.db.QueryRow(ctx, query, originalEventID).Scan( + &dl.ID, &dl.OriginalEventID, &dl.OriginalAggregateType, &dl.OriginalAggregateID, + &dl.EventType, &dl.Payload, &dl.ErrorMessage, &dl.RetryCount, + &dl.FirstFailedAt, &dl.DeadLetterAt, &dl.Handled, &dl.HandledAt, + &dl.HandlerNotes, &dl.CreatedAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get dead letter: %w", err) + } + + return dl, nil +} + +// ListUnhandledDeadLetters 列出未处理的死信记录 +func (r *OutboxRepository) ListUnhandledDeadLetters(ctx context.Context, limit int) ([]*OutboxDeadLetter, error) { + query := ` + SELECT id, original_event_id, original_aggregate_type, original_aggregate_id, + event_type, payload, error_message, retry_count, first_failed_at, + dead_letter_at, handled, handled_at, handler_notes, created_at + FROM supply_outbox_dead_letter + WHERE handled = FALSE + ORDER BY dead_letter_at ASC + LIMIT $1 + ` + + rows, err := r.db.Query(ctx, query, limit) + if err != nil { + return nil, fmt.Errorf("failed to list dead letters: %w", err) + } + defer rows.Close() + + var dls []*OutboxDeadLetter + for rows.Next() { + dl := &OutboxDeadLetter{} + err := rows.Scan( + &dl.ID, &dl.OriginalEventID, &dl.OriginalAggregateType, &dl.OriginalAggregateID, + &dl.EventType, &dl.Payload, &dl.ErrorMessage, &dl.RetryCount, + &dl.FirstFailedAt, &dl.DeadLetterAt, &dl.Handled, &dl.HandledAt, + &dl.HandlerNotes, &dl.CreatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan dead letter: %w", err) + } + dls = append(dls, dl) + } + + return dls, nil +} + +// MarkDeadLetterHandled 标记死信已处理 +func (r *OutboxRepository) MarkDeadLetterHandled(ctx context.Context, id int64, notes string) error { + query := ` + UPDATE supply_outbox_dead_letter SET + handled = TRUE, + handled_at = CURRENT_TIMESTAMP, + handler_notes = $2 + WHERE id = $1 + ` + + _, err := r.db.Exec(ctx, query, id, notes) + if err != nil { + return fmt.Errorf("failed to mark dead letter handled: %w", err) + } + + return nil +} + +// DeleteCompleted 删除已完成的旧事件(定时清理) +func (r *OutboxRepository) DeleteCompleted(ctx context.Context, before time.Time) (int64, error) { + query := `DELETE FROM supply_outbox WHERE status = 'completed' AND processed_at < $1` + + result, err := r.db.Exec(ctx, query, before) + if err != nil { + return 0, fmt.Errorf("failed to delete completed events: %w", err) + } + + return result.RowsAffected(), nil +} diff --git a/supply-api/internal/repository/outbox_test.go b/supply-api/internal/repository/outbox_test.go new file mode 100644 index 00000000..7ecc7033 --- /dev/null +++ b/supply-api/internal/repository/outbox_test.go @@ -0,0 +1,114 @@ +package repository + +import ( + "context" + "encoding/json" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type stubOutboxDB struct { + querySQL string + rows pgx.Rows +} + +func (s *stubOutboxDB) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + s.querySQL = sql + return s.rows, nil +} + +func (s *stubOutboxDB) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + panic("unexpected QueryRow call") +} + +func (s *stubOutboxDB) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + panic("unexpected Exec call") +} + +func (s *stubOutboxDB) Begin(ctx context.Context) (pgx.Tx, error) { + panic("unexpected Begin call") +} + +type stubOutboxRows struct { + events []*OutboxEvent + index int +} + +func (r *stubOutboxRows) Close() {} +func (r *stubOutboxRows) Err() error { return nil } +func (r *stubOutboxRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } +func (r *stubOutboxRows) FieldDescriptions() []pgconn.FieldDescription { return nil } +func (r *stubOutboxRows) RawValues() [][]byte { return nil } +func (r *stubOutboxRows) Values() ([]any, error) { return nil, nil } +func (r *stubOutboxRows) Conn() *pgx.Conn { return nil } + +func (r *stubOutboxRows) Next() bool { + if r.index >= len(r.events) { + return false + } + r.index++ + return true +} + +func (r *stubOutboxRows) Scan(dest ...any) error { + event := r.events[r.index-1] + *(dest[0].(*int64)) = event.ID + *(dest[1].(*string)) = event.AggregateType + *(dest[2].(*string)) = event.AggregateID + *(dest[3].(*string)) = event.EventType + *(dest[4].(*string)) = event.EventID + *(dest[5].(*json.RawMessage)) = event.Payload + *(dest[6].(*OutboxStatus)) = event.Status + *(dest[7].(*int)) = event.RetryCount + *(dest[8].(*int)) = event.MaxRetries + *(dest[9].(*string)) = event.ErrorMessage + *(dest[10].(*time.Time)) = event.CreatedAt + *(dest[11].(**time.Time)) = event.ProcessedAt + *(dest[12].(**time.Time)) = event.NextRetryAt + *(dest[13].(*int64)) = event.Version + return nil +} + +func TestFetchAndLock_ClaimsEventsAsProcessingInDatabase(t *testing.T) { + now := time.Now() + payload := json.RawMessage(`{"event":"created"}`) + db := &stubOutboxDB{ + rows: &stubOutboxRows{ + events: []*OutboxEvent{ + { + ID: 1, + AggregateType: "account", + AggregateID: "acc-1", + EventType: "created", + EventID: "evt-1", + Payload: payload, + Status: OutboxStatusProcessing, + RetryCount: 0, + MaxRetries: 5, + CreatedAt: now, + Version: 2, + }, + }, + }, + } + repo := &OutboxRepository{db: db} + + events, err := repo.FetchAndLock(context.Background(), 1) + if err != nil { + t.Fatalf("FetchAndLock returned error: %v", err) + } + + if len(events) != 1 { + t.Fatalf("expected 1 event, got %d", len(events)) + } + if events[0].Status != OutboxStatusProcessing { + t.Fatalf("expected event status processing, got %s", events[0].Status) + } + if !strings.Contains(db.querySQL, "SET status = 'processing'") { + t.Fatalf("expected claim query to persist processing status, got SQL: %s", db.querySQL) + } +}