fix: close p0 auth and release gate gaps
This commit is contained in:
@@ -77,18 +77,36 @@ func main() {
|
||||
auditor = middleware.NewMemoryAuditEmitter()
|
||||
}
|
||||
|
||||
// 初始化 token 运行时(内存实现)
|
||||
tokenRuntime := middleware.NewInMemoryTokenRuntime(time.Now)
|
||||
// 初始化 token 运行时
|
||||
var tokenRuntime interface {
|
||||
middleware.TokenVerifier
|
||||
middleware.TokenStatusResolver
|
||||
}
|
||||
switch cfg.Auth.TokenRuntimeMode {
|
||||
case "inmemory":
|
||||
tokenRuntime = middleware.NewInMemoryTokenRuntime(time.Now)
|
||||
case "remote_introspection":
|
||||
tokenRuntime = middleware.NewRemoteTokenRuntime(cfg.Auth.TokenRuntimeURL, http.DefaultClient, time.Now)
|
||||
default:
|
||||
log.Fatalf("unsupported token runtime mode: %s", cfg.Auth.TokenRuntimeMode)
|
||||
}
|
||||
|
||||
// 构建认证中间件配置
|
||||
authMiddlewareConfig := middleware.AuthMiddlewareConfig{
|
||||
Verifier: tokenRuntime,
|
||||
StatusResolver: tokenRuntime,
|
||||
Authorizer: middleware.NewScopeRoleAuthorizer(),
|
||||
Auditor: auditor,
|
||||
ProtectedPrefixes: []string{"/api/v1/supply", "/api/v1/platform"},
|
||||
ExcludedPrefixes: []string{"/health", "/healthz", "/metrics", "/readyz"},
|
||||
Now: time.Now,
|
||||
Verifier: tokenRuntime,
|
||||
StatusResolver: tokenRuntime,
|
||||
Authorizer: middleware.NewScopeRoleAuthorizer(),
|
||||
Auditor: auditor,
|
||||
ProtectedPrefixes: []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/api/v1/chat/completions",
|
||||
"/api/v1/completions",
|
||||
"/api/v1/supply",
|
||||
"/api/v1/platform",
|
||||
},
|
||||
ExcludedPrefixes: []string{"/health", "/healthz", "/metrics", "/readyz"},
|
||||
Now: time.Now,
|
||||
}
|
||||
|
||||
// 初始化Handler
|
||||
@@ -132,33 +150,38 @@ func main() {
|
||||
func createMux(h *handler.Handler, limiter *ratelimit.Middleware, authConfig middleware.AuthMiddlewareConfig) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// 创建认证处理链
|
||||
authHandler := middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h.ChatCompletionsHandle(w, r)
|
||||
}))
|
||||
chatHandler := middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.ChatCompletionsHandle))
|
||||
completionsHandler := middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.CompletionsHandle))
|
||||
|
||||
// Chat Completions - 应用限流和认证
|
||||
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
limiter.Limit(authHandler.ServeHTTP)(w, r)
|
||||
limitHandler(limiter, chatHandler).ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
// Completions - 应用限流和认证
|
||||
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
limiter.Limit(authHandler.ServeHTTP)(w, r)
|
||||
limitHandler(limiter, completionsHandler).ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
// Models - 公开接口
|
||||
mux.HandleFunc("/v1/models", h.ModelsHandle)
|
||||
|
||||
// 旧版路径兼容
|
||||
mux.HandleFunc("/api/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
h.ChatCompletionsHandle(w, r)
|
||||
limitHandler(limiter, chatHandler).ServeHTTP(w, r)
|
||||
})
|
||||
mux.HandleFunc("/api/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
limitHandler(limiter, completionsHandler).ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
// Health - 排除认证
|
||||
mux.HandleFunc("/health", h.HealthHandle)
|
||||
mux.HandleFunc("/healthz", h.HealthHandle)
|
||||
mux.HandleFunc("/readyz", h.HealthHandle)
|
||||
|
||||
return mux
|
||||
return middleware.CORSMiddleware(middleware.DefaultCORSConfig())(mux)
|
||||
}
|
||||
|
||||
func limitHandler(limiter *ratelimit.Middleware, next http.Handler) http.Handler {
|
||||
if limiter == nil {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
limiter.Limit(next.ServeHTTP)(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
173
gateway/cmd/gateway/main_test.go
Normal file
173
gateway/cmd/gateway/main_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/gateway/internal/adapter"
|
||||
"lijiaoqiao/gateway/internal/handler"
|
||||
"lijiaoqiao/gateway/internal/middleware"
|
||||
"lijiaoqiao/gateway/internal/ratelimit"
|
||||
"lijiaoqiao/gateway/internal/router"
|
||||
)
|
||||
|
||||
type testProvider struct {
|
||||
name string
|
||||
models []string
|
||||
}
|
||||
|
||||
func (p *testProvider) ChatCompletion(_ context.Context, model string, messages []adapter.Message, _ adapter.CompletionOptions) (*adapter.CompletionResponse, error) {
|
||||
content := ""
|
||||
if len(messages) > 0 {
|
||||
content = messages[0].Content
|
||||
}
|
||||
return &adapter.CompletionResponse{
|
||||
ID: "resp-1",
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: model,
|
||||
Choices: []adapter.Choice{{Index: 0, Message: &adapter.Message{Role: "assistant", Content: content}, FinishReason: "stop"}},
|
||||
Usage: adapter.Usage{PromptTokens: 1, CompletionTokens: 1, TotalTokens: 2},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *testProvider) ChatCompletionStream(_ context.Context, _ string, _ []adapter.Message, _ adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
|
||||
ch := make(chan *adapter.StreamChunk)
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (p *testProvider) GetUsage(response *adapter.CompletionResponse) adapter.Usage {
|
||||
return response.Usage
|
||||
}
|
||||
|
||||
func (p *testProvider) MapError(err error) adapter.ProviderError {
|
||||
return adapter.ProviderError{Code: "provider_error", Message: err.Error(), HTTPStatus: http.StatusBadGateway}
|
||||
}
|
||||
|
||||
func (p *testProvider) HealthCheck(context.Context) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *testProvider) ProviderName() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *testProvider) SupportedModels() []string {
|
||||
return p.models
|
||||
}
|
||||
|
||||
func TestCreateMux_ProtectsCompletionRoutes(t *testing.T) {
|
||||
now := time.Now()
|
||||
tokenRuntime := middleware.NewInMemoryTokenRuntime(func() time.Time { return now })
|
||||
r := router.NewRouter(router.StrategyLatency)
|
||||
h := handler.NewHandler(r)
|
||||
limiter := ratelimit.NewMiddleware(ratelimit.NewTokenBucketLimiter(60, 60000, 1.5))
|
||||
|
||||
authConfig := middleware.AuthMiddlewareConfig{
|
||||
Verifier: tokenRuntime,
|
||||
StatusResolver: tokenRuntime,
|
||||
Authorizer: middleware.NewScopeRoleAuthorizer(),
|
||||
Auditor: middleware.NewMemoryAuditEmitter(),
|
||||
ProtectedPrefixes: []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/api/v1/chat/completions",
|
||||
"/api/v1/completions",
|
||||
},
|
||||
ExcludedPrefixes: []string{"/health", "/healthz", "/readyz"},
|
||||
Now: func() time.Time { return now },
|
||||
}
|
||||
|
||||
mux := createMux(h, limiter, authConfig)
|
||||
|
||||
for _, path := range []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/api/v1/chat/completions",
|
||||
"/api/v1/completions",
|
||||
} {
|
||||
req := httptest.NewRequest(http.MethodPost, path, bytes.NewBufferString(`{}`))
|
||||
rr := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected %s to return 401, got %d", path, rr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateMux_HealthRoutesRemainOpen(t *testing.T) {
|
||||
now := time.Now()
|
||||
tokenRuntime := middleware.NewInMemoryTokenRuntime(func() time.Time { return now })
|
||||
r := router.NewRouter(router.StrategyLatency)
|
||||
h := handler.NewHandler(r)
|
||||
limiter := ratelimit.NewMiddleware(ratelimit.NewTokenBucketLimiter(60, 60000, 1.5))
|
||||
|
||||
authConfig := middleware.AuthMiddlewareConfig{
|
||||
Verifier: tokenRuntime,
|
||||
StatusResolver: tokenRuntime,
|
||||
Authorizer: middleware.NewScopeRoleAuthorizer(),
|
||||
ProtectedPrefixes: []string{"/v1/chat/completions"},
|
||||
ExcludedPrefixes: []string{"/health", "/healthz", "/readyz"},
|
||||
Now: func() time.Time { return now },
|
||||
}
|
||||
|
||||
mux := createMux(h, limiter, authConfig)
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected /health to return 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateMux_CompletionsRouteUsesCompletionsHandler(t *testing.T) {
|
||||
now := time.Now()
|
||||
tokenRuntime := middleware.NewInMemoryTokenRuntime(func() time.Time { return now })
|
||||
token, err := tokenRuntime.Issue(context.Background(), "user1", "user", []string{"gateway:invoke"}, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to issue token: %v", err)
|
||||
}
|
||||
|
||||
r := router.NewRouter(router.StrategyLatency)
|
||||
r.RegisterProvider("test", &testProvider{name: "test", models: []string{"gpt-4"}})
|
||||
h := handler.NewHandler(r)
|
||||
limiter := ratelimit.NewMiddleware(ratelimit.NewTokenBucketLimiter(60, 60000, 1.5))
|
||||
|
||||
authConfig := middleware.AuthMiddlewareConfig{
|
||||
Verifier: tokenRuntime,
|
||||
StatusResolver: tokenRuntime,
|
||||
Authorizer: middleware.NewScopeRoleAuthorizer(),
|
||||
Auditor: middleware.NewMemoryAuditEmitter(),
|
||||
ProtectedPrefixes: []string{
|
||||
"/v1/completions",
|
||||
},
|
||||
ExcludedPrefixes: []string{"/health", "/healthz", "/readyz"},
|
||||
Now: func() time.Time { return now },
|
||||
}
|
||||
|
||||
mux := createMux(h, limiter, authConfig)
|
||||
reqBody := `{"model":"gpt-4","prompt":"hello","max_tokens":16}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
mux.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
body, _ := io.ReadAll(rr.Result().Body)
|
||||
t.Fatalf("expected 200, got %d: %s", rr.Code, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(rr.Result().Body)
|
||||
if !strings.Contains(string(body), `"object":"text_completion"`) {
|
||||
t.Fatalf("expected completions response, got %s", string(body))
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -17,33 +19,41 @@ var encryptionKey = []byte(getEnv("PASSWORD_ENCRYPTION_KEY", "default-key-32-byt
|
||||
|
||||
// Config 网关配置
|
||||
type Config struct {
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
Redis RedisConfig
|
||||
Router RouterConfig
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
Redis RedisConfig
|
||||
Auth AuthConfig
|
||||
Router RouterConfig
|
||||
RateLimit RateLimitConfig
|
||||
Alert AlertConfig
|
||||
Alert AlertConfig
|
||||
Providers []ProviderConfig
|
||||
}
|
||||
|
||||
// ServerConfig 服务配置
|
||||
type ServerConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
ReadTimeout time.Duration
|
||||
Host string
|
||||
Port int
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// AuthConfig 鉴权运行时配置
|
||||
type AuthConfig struct {
|
||||
Env string
|
||||
TokenRuntimeMode string
|
||||
TokenRuntimeURL string
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string // 兼容旧版本,仍可直接使用明文密码(不推荐)
|
||||
EncryptedPassword string // 加密后的密码,优先级高于Password字段
|
||||
Database string
|
||||
MaxConns int
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string // 兼容旧版本,仍可直接使用明文密码(不推荐)
|
||||
EncryptedPassword string // 加密后的密码,优先级高于Password字段
|
||||
Database string
|
||||
MaxConns int
|
||||
}
|
||||
|
||||
// GetPassword 返回解密后的数据库密码
|
||||
@@ -62,12 +72,12 @@ func (c *DatabaseConfig) GetPassword() string {
|
||||
|
||||
// RedisConfig Redis配置
|
||||
type RedisConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string // 兼容旧版本
|
||||
EncryptedPassword string // 加密后的密码
|
||||
DB int
|
||||
PoolSize int
|
||||
Host string
|
||||
Port int
|
||||
Password string // 兼容旧版本
|
||||
EncryptedPassword string // 加密后的密码
|
||||
DB int
|
||||
PoolSize int
|
||||
}
|
||||
|
||||
// GetPassword 返回解密后的Redis密码
|
||||
@@ -84,28 +94,28 @@ func (c *RedisConfig) GetPassword() string {
|
||||
|
||||
// RouterConfig 路由配置
|
||||
type RouterConfig struct {
|
||||
Strategy string // "latency", "cost", "availability", "weighted"
|
||||
Timeout time.Duration
|
||||
MaxRetries int
|
||||
RetryDelay time.Duration
|
||||
Strategy string // "latency", "cost", "availability", "weighted"
|
||||
Timeout time.Duration
|
||||
MaxRetries int
|
||||
RetryDelay time.Duration
|
||||
HealthCheckInterval time.Duration
|
||||
}
|
||||
|
||||
// RateLimitConfig 限流配置
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool
|
||||
Algorithm string // "token_bucket", "sliding_window", "fixed_window"
|
||||
DefaultRPM int // 请求数/分钟
|
||||
DefaultTPM int // Token数/分钟
|
||||
Enabled bool
|
||||
Algorithm string // "token_bucket", "sliding_window", "fixed_window"
|
||||
DefaultRPM int // 请求数/分钟
|
||||
DefaultTPM int // Token数/分钟
|
||||
BurstMultiplier float64
|
||||
}
|
||||
|
||||
// AlertConfig 告警配置
|
||||
type AlertConfig struct {
|
||||
Enabled bool
|
||||
Email EmailConfig
|
||||
DingTalk DingTalkConfig
|
||||
Feishu FeishuConfig
|
||||
Enabled bool
|
||||
Email EmailConfig
|
||||
DingTalk DingTalkConfig
|
||||
Feishu FeishuConfig
|
||||
}
|
||||
|
||||
// EmailConfig 邮件配置
|
||||
@@ -135,11 +145,11 @@ type FeishuConfig struct {
|
||||
|
||||
// ProviderConfig Provider配置
|
||||
type ProviderConfig struct {
|
||||
Name string
|
||||
Type string // "openai", "anthropic", "google", "custom"
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Models []string
|
||||
Name string
|
||||
Type string // "openai", "anthropic", "google", "custom"
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Models []string
|
||||
Priority int
|
||||
Weight float64
|
||||
}
|
||||
@@ -155,26 +165,31 @@ func LoadConfig(path string) (*Config, error) {
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
Env: strings.ToLower(getEnv("GATEWAY_ENV", "dev")),
|
||||
TokenRuntimeMode: strings.ToLower(getEnv("GATEWAY_TOKEN_RUNTIME_MODE", "inmemory")),
|
||||
TokenRuntimeURL: strings.TrimSpace(getEnv("GATEWAY_TOKEN_RUNTIME_URL", "")),
|
||||
},
|
||||
Router: RouterConfig{
|
||||
Strategy: "latency",
|
||||
Timeout: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 1 * time.Second,
|
||||
Strategy: "latency",
|
||||
Timeout: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 1 * time.Second,
|
||||
HealthCheckInterval: 10 * time.Second,
|
||||
},
|
||||
RateLimit: RateLimitConfig{
|
||||
Enabled: true,
|
||||
Algorithm: "token_bucket",
|
||||
DefaultRPM: 60,
|
||||
DefaultTPM: 60000,
|
||||
Enabled: true,
|
||||
Algorithm: "token_bucket",
|
||||
DefaultRPM: 60,
|
||||
DefaultTPM: 60000,
|
||||
BurstMultiplier: 1.5,
|
||||
},
|
||||
Alert: AlertConfig{
|
||||
Enabled: true,
|
||||
Email: EmailConfig{
|
||||
Enabled: false,
|
||||
Host: getEnv("SMTP_HOST", "smtp.example.com"),
|
||||
Port: 587,
|
||||
Host: getEnv("SMTP_HOST", "smtp.example.com"),
|
||||
Port: 587,
|
||||
},
|
||||
DingTalk: DingTalkConfig{
|
||||
Enabled: getEnv("DINGTALK_ENABLED", "false") == "true",
|
||||
@@ -189,9 +204,33 @@ func LoadConfig(path string) (*Config, error) {
|
||||
},
|
||||
}
|
||||
|
||||
if err := validateAuthConfig(cfg.Auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func validateAuthConfig(cfg AuthConfig) error {
|
||||
mode := strings.ToLower(strings.TrimSpace(cfg.TokenRuntimeMode))
|
||||
env := strings.ToLower(strings.TrimSpace(cfg.Env))
|
||||
|
||||
switch mode {
|
||||
case "inmemory", "remote_introspection":
|
||||
default:
|
||||
return fmt.Errorf("unsupported token runtime mode %q", cfg.TokenRuntimeMode)
|
||||
}
|
||||
|
||||
if (env == "prod" || env == "staging") && mode == "inmemory" {
|
||||
return fmt.Errorf("inmemory token runtime is not allowed in %s, use remote_introspection", env)
|
||||
}
|
||||
if mode == "remote_introspection" && strings.TrimSpace(cfg.TokenRuntimeURL) == "" {
|
||||
return errors.New("GATEWAY_TOKEN_RUNTIME_URL is required when token runtime mode is remote_introspection")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -110,7 +111,7 @@ func TestAlertConfig_Struct(t *testing.T) {
|
||||
cfg := AlertConfig{
|
||||
Enabled: true,
|
||||
Email: EmailConfig{
|
||||
Enabled: false,
|
||||
Enabled: false,
|
||||
Host: "smtp.example.com",
|
||||
Port: 587,
|
||||
From: "alert@example.com",
|
||||
@@ -344,10 +345,10 @@ func TestConfig_AllFields(t *testing.T) {
|
||||
PoolSize: 10,
|
||||
},
|
||||
Router: RouterConfig{
|
||||
Strategy: "latency",
|
||||
Timeout: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 1 * time.Second,
|
||||
Strategy: "latency",
|
||||
Timeout: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 1 * time.Second,
|
||||
HealthCheckInterval: 10 * time.Second,
|
||||
},
|
||||
RateLimit: RateLimitConfig{
|
||||
@@ -360,7 +361,7 @@ func TestConfig_AllFields(t *testing.T) {
|
||||
Alert: AlertConfig{
|
||||
Enabled: true,
|
||||
Email: EmailConfig{
|
||||
Enabled: false,
|
||||
Enabled: false,
|
||||
Host: "smtp.example.com",
|
||||
Port: 587,
|
||||
},
|
||||
@@ -405,3 +406,30 @@ func TestLoadConfig_EnvOverrides(t *testing.T) {
|
||||
t.Errorf("expected custom.smtp.com, got %s", cfg.Alert.Email.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_ProdRejectsInMemoryTokenRuntime(t *testing.T) {
|
||||
t.Setenv("GATEWAY_ENV", "prod")
|
||||
t.Setenv("GATEWAY_TOKEN_RUNTIME_MODE", "inmemory")
|
||||
|
||||
_, err := LoadConfig("")
|
||||
if err == nil {
|
||||
t.Fatal("expected prod config with in-memory token runtime to return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "inmemory") {
|
||||
t.Fatalf("expected error to mention inmemory runtime, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_ProdRequiresTokenRuntimeURL(t *testing.T) {
|
||||
t.Setenv("GATEWAY_ENV", "prod")
|
||||
t.Setenv("GATEWAY_TOKEN_RUNTIME_MODE", "remote_introspection")
|
||||
t.Setenv("GATEWAY_TOKEN_RUNTIME_URL", "")
|
||||
|
||||
_, err := LoadConfig("")
|
||||
if err == nil {
|
||||
t.Fatal("expected prod config without token runtime URL to return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "TOKEN_RUNTIME_URL") {
|
||||
t.Fatalf("expected error to mention TOKEN_RUNTIME_URL, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
117
gateway/internal/middleware/remote_runtime.go
Normal file
117
gateway/internal/middleware/remote_runtime.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RemoteTokenRuntime struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
now func() time.Time
|
||||
|
||||
mu sync.RWMutex
|
||||
records map[string]remoteResolvedToken
|
||||
}
|
||||
|
||||
type remoteResolvedToken struct {
|
||||
status TokenStatus
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type remoteIntrospectResponse struct {
|
||||
Data struct {
|
||||
TokenID string `json:"token_id"`
|
||||
SubjectID string `json:"subject_id"`
|
||||
Role string `json:"role"`
|
||||
Status string `json:"status"`
|
||||
Scope []string `json:"scope"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func NewRemoteTokenRuntime(baseURL string, httpClient *http.Client, now func() time.Time) *RemoteTokenRuntime {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
return &RemoteTokenRuntime{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
httpClient: httpClient,
|
||||
now: now,
|
||||
records: make(map[string]remoteResolvedToken),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RemoteTokenRuntime) Verify(ctx context.Context, rawToken string) (VerifiedToken, error) {
|
||||
payload := map[string]string{"token": rawToken}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return VerifiedToken{}, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, r.baseURL+"/api/v1/platform/tokens/introspect", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return VerifiedToken{}, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Request-Id", fmt.Sprintf("gateway-introspect-%d", r.now().UnixNano()))
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return VerifiedToken{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return VerifiedToken{}, fmt.Errorf("token introspection failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result remoteIntrospectResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return VerifiedToken{}, err
|
||||
}
|
||||
if strings.TrimSpace(result.Data.TokenID) == "" {
|
||||
return VerifiedToken{}, errors.New("token introspection response missing token_id")
|
||||
}
|
||||
|
||||
status := TokenStatus(strings.ToLower(strings.TrimSpace(result.Data.Status)))
|
||||
r.mu.Lock()
|
||||
r.records[result.Data.TokenID] = remoteResolvedToken{
|
||||
status: status,
|
||||
expiresAt: result.Data.ExpiresAt,
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
return VerifiedToken{
|
||||
TokenID: result.Data.TokenID,
|
||||
SubjectID: result.Data.SubjectID,
|
||||
Role: result.Data.Role,
|
||||
Scope: append([]string(nil), result.Data.Scope...),
|
||||
IssuedAt: result.Data.IssuedAt,
|
||||
ExpiresAt: result.Data.ExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *RemoteTokenRuntime) Resolve(ctx context.Context, tokenID string) (TokenStatus, error) {
|
||||
r.mu.RLock()
|
||||
record, ok := r.records[tokenID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return "", errors.New("token status not cached; verify must run before resolve")
|
||||
}
|
||||
if !record.expiresAt.IsZero() && !r.now().Before(record.expiresAt) && record.status == TokenStatusActive {
|
||||
return TokenStatusExpired, nil
|
||||
}
|
||||
return record.status, nil
|
||||
}
|
||||
67
gateway/internal/middleware/remote_runtime_test.go
Normal file
67
gateway/internal/middleware/remote_runtime_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRemoteTokenRuntime_VerifyAndResolve(t *testing.T) {
|
||||
httpClient := &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Path != "/api/v1/platform/tokens/introspect" {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusNotFound,
|
||||
Body: io.NopCloser(strings.NewReader("not found")),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{
|
||||
"request_id":"req-1",
|
||||
"data":{
|
||||
"token_id":"tok-1",
|
||||
"subject_id":"user-1",
|
||||
"role":"org_admin",
|
||||
"status":"active",
|
||||
"scope":["gateway:invoke"],
|
||||
"issued_at":"2026-04-10T10:00:00Z",
|
||||
"expires_at":"2026-04-10T11:00:00Z"
|
||||
}
|
||||
}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
runtime := NewRemoteTokenRuntime("http://token-runtime.internal", httpClient, func() time.Time {
|
||||
return time.Date(2026, 4, 10, 10, 30, 0, 0, time.UTC)
|
||||
})
|
||||
|
||||
claims, err := runtime.Verify(context.Background(), "raw-token")
|
||||
if err != nil {
|
||||
t.Fatalf("Verify returned error: %v", err)
|
||||
}
|
||||
if claims.TokenID != "tok-1" {
|
||||
t.Fatalf("expected token id tok-1, got %s", claims.TokenID)
|
||||
}
|
||||
|
||||
status, err := runtime.Resolve(context.Background(), claims.TokenID)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve returned error: %v", err)
|
||||
}
|
||||
if status != TokenStatusActive {
|
||||
t.Fatalf("expected active status, got %s", status)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -15,6 +16,10 @@ import (
|
||||
|
||||
func main() {
|
||||
addr := envOrDefault("TOKEN_RUNTIME_ADDR", ":18081")
|
||||
env := strings.ToLower(envOrDefault("TOKEN_RUNTIME_ENV", "dev"))
|
||||
if env == "prod" || env == "staging" {
|
||||
log.Fatalf("in-memory token runtime is not allowed in %s", env)
|
||||
}
|
||||
|
||||
runtime := service.NewInMemoryTokenRuntime(nil)
|
||||
auditor := service.NewMemoryAuditEmitter()
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMain_ProdRejectsInMemoryRuntime(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestMainHelperProcess")
|
||||
cmd.Env = append(os.Environ(),
|
||||
"GO_WANT_HELPER_PROCESS=1",
|
||||
"TOKEN_RUNTIME_ENV=prod",
|
||||
"TOKEN_RUNTIME_ADDR=127.0.0.1:0",
|
||||
)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
t.Fatalf("expected prod startup to fail fast, but process timed out. output=%s", string(output))
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("expected prod startup to fail, but process exited successfully. output=%s", string(output))
|
||||
}
|
||||
if !strings.Contains(string(output), "in-memory token runtime is not allowed") {
|
||||
t.Fatalf("expected startup failure output to mention in-memory token runtime is not allowed, got: %s", string(output))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMainHelperProcess(t *testing.T) {
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
||||
return
|
||||
}
|
||||
main()
|
||||
os.Exit(0)
|
||||
}
|
||||
12
reports/gates/backend_verify_2026-04-11_091249.md
Normal file
12
reports/gates/backend_verify_2026-04-11_091249.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# Backend Verify Report
|
||||
|
||||
- 时间戳:2026-04-11_091249
|
||||
- 结果:**PASS**
|
||||
- 说明:all backend release gates passed
|
||||
|
||||
| 步骤 | 结果 | 说明 | 证据 |
|
||||
|---|---|---|---|
|
||||
| STEP-01 | PASS | supply-api critical regression suite | /home/long/project/立交桥/reports/gates/step-01_2026-04-11_091249.out.log |
|
||||
| STEP-02 | PASS | gateway critical regression suite | /home/long/project/立交桥/reports/gates/step-02_2026-04-11_091249.out.log |
|
||||
| STEP-03 | PASS | platform-token-runtime critical regression suite | /home/long/project/立交桥/reports/gates/step-03_2026-04-11_091249.out.log |
|
||||
| STEP-04 | PASS | supply-api E2E gate must not contain placeholder skip | /home/long/project/立交桥/reports/gates/step-04_2026-04-11_091249.out.log |
|
||||
@@ -0,0 +1,29 @@
|
||||
# Superpowers 阶段验证报告
|
||||
|
||||
- 时间戳:2026-04-11_091525
|
||||
- 执行脚本:`scripts/ci/superpowers_stage_validate.sh`
|
||||
- 决策:**CONDITIONAL_GO**
|
||||
- 决策依据:all executable phases passed but real staging phase is deferred
|
||||
|
||||
## 阶段结果
|
||||
|
||||
| 阶段 | 结果 | 说明 | 证据 |
|
||||
|---|---|---|---|
|
||||
| PHASE-00 | PASS | Backend critical verification gate | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase00_backend_verify.log |
|
||||
| PHASE-01 | PASS | TOK runtime code tests | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase01_go_test.log |
|
||||
| PHASE-02 | PASS | SUP local-mock run_all execution | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase02_sup_run_all_mock.log |
|
||||
| PHASE-03 | PASS | TOK-005 boundary dry-run on local-mock env | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase03_tok005_dryrun_mock.log |
|
||||
| PHASE-04 | PASS | TOK-006 gate bundle aggregation | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase04_tok006_bundle.log |
|
||||
| PHASE-05 | PASS | Dependency audit gate validation | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase05_dependency_audit.log |
|
||||
| PHASE-06 | PASS | Stage gate rollback drill | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase06_stage_gate_drill.log |
|
||||
| PHASE-07 | DEFERRED | Real staging precheck (expected deferred before real secrets) | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase07_staging_precheck.log |
|
||||
| PHASE-08 | PASS | Daily metrics snapshot for M-017/M-018/M-019 | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase08_metrics_snapshot.log |
|
||||
| PHASE-09 | PASS | 7-day metrics trend report generation | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase09_metrics_trend.log |
|
||||
| PHASE-10 | PASS | Token runtime readiness check (M-021) | /home/long/project/立交桥/tests/supply/artifacts/superpowers_stage_validation_2026-04-11_091525/phase10_token_runtime_readiness.log |
|
||||
|
||||
## 说明
|
||||
|
||||
1. PHASE-07 为真实 staging 验证阶段,在占位凭证场景下允许 DEFERRED,不得伪造 PASS。
|
||||
2. PHASE-08/09 负责 M-017/M-018/M-019 的每日快照与趋势证据生成。
|
||||
3. PHASE-10 负责 M-021 token 运行态就绪度计算。
|
||||
4. 其余阶段均为可执行验证,必须以命令返回码与证据文件为准。
|
||||
145
scripts/ci/backend-verify.sh
Executable file
145
scripts/ci/backend-verify.sh
Executable file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
|
||||
OUT_DIR="${ROOT_DIR}/reports/gates"
|
||||
TS="$(date +%F_%H%M%S)"
|
||||
LOG_FILE="${OUT_DIR}/backend_verify_${TS}.log"
|
||||
REPORT_FILE="${OUT_DIR}/backend_verify_${TS}.md"
|
||||
GO_BIN="${ROOT_DIR}/.tools/go-current/bin/go"
|
||||
DEFAULT_GOPATH=""
|
||||
DEFAULT_GOMODCACHE=""
|
||||
|
||||
mkdir -p "${OUT_DIR}"
|
||||
: > "${LOG_FILE}"
|
||||
|
||||
if [[ ! -x "${GO_BIN}" ]]; then
|
||||
GO_BIN="$(command -v go || true)"
|
||||
fi
|
||||
if [[ -z "${GO_BIN}" ]]; then
|
||||
echo "[FAIL] go binary not found" | tee -a "${LOG_FILE}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export PATH="$(dirname "${GO_BIN}"):${PATH}"
|
||||
export GOCACHE="${ROOT_DIR}/.tools/go-cache"
|
||||
DEFAULT_GOPATH="$("${GO_BIN}" env GOPATH 2>/dev/null || true)"
|
||||
DEFAULT_GOMODCACHE="$("${GO_BIN}" env GOMODCACHE 2>/dev/null || true)"
|
||||
if [[ -n "${DEFAULT_GOPATH}" ]]; then
|
||||
export GOPATH="${DEFAULT_GOPATH}"
|
||||
fi
|
||||
if [[ -n "${DEFAULT_GOMODCACHE}" ]]; then
|
||||
export GOMODCACHE="${DEFAULT_GOMODCACHE}"
|
||||
fi
|
||||
|
||||
STEP_RESULTS=()
|
||||
|
||||
log() {
|
||||
echo "$1" | tee -a "${LOG_FILE}"
|
||||
}
|
||||
|
||||
run_step() {
|
||||
local step_id="$1"
|
||||
local title="$2"
|
||||
local cmd="$3"
|
||||
local out_file="${OUT_DIR}/${step_id,,}_${TS}.out.log"
|
||||
|
||||
log "[INFO] ${step_id} ${title} start"
|
||||
set +e
|
||||
bash -lc "${cmd}" > "${out_file}" 2>&1
|
||||
local rc=$?
|
||||
set -e
|
||||
|
||||
if [[ "${rc}" -eq 0 ]]; then
|
||||
log "[PASS] ${step_id} rc=${rc}"
|
||||
STEP_RESULTS+=("${step_id}|PASS|${title}|${out_file}")
|
||||
else
|
||||
log "[FAIL] ${step_id} rc=${rc}"
|
||||
STEP_RESULTS+=("${step_id}|FAIL|${title}|${out_file}")
|
||||
fi
|
||||
}
|
||||
|
||||
run_e2e_skip_gate() {
|
||||
local step_id="$1"
|
||||
local title="$2"
|
||||
local out_file="${OUT_DIR}/${step_id,,}_${TS}.out.log"
|
||||
|
||||
log "[INFO] ${step_id} ${title} start"
|
||||
set +e
|
||||
bash -lc "cd \"${ROOT_DIR}/supply-api\" && \"${GO_BIN}\" test -tags=e2e -v ./e2e/..." > "${out_file}" 2>&1
|
||||
local rc=$?
|
||||
set -e
|
||||
|
||||
if grep -Eiq 'SKIP|需要完整环境运行 E2E 测试|Skipping E2E test' "${out_file}"; then
|
||||
log "[FAIL] ${step_id} placeholder E2E detected"
|
||||
STEP_RESULTS+=("${step_id}|FAIL|${title}|${out_file}")
|
||||
return
|
||||
fi
|
||||
|
||||
if [[ "${rc}" -eq 0 ]]; then
|
||||
log "[PASS] ${step_id} rc=${rc}"
|
||||
STEP_RESULTS+=("${step_id}|PASS|${title}|${out_file}")
|
||||
else
|
||||
log "[FAIL] ${step_id} rc=${rc}"
|
||||
STEP_RESULTS+=("${step_id}|FAIL|${title}|${out_file}")
|
||||
fi
|
||||
}
|
||||
|
||||
run_step \
|
||||
"STEP-01" \
|
||||
"supply-api critical regression suite" \
|
||||
"cd \"${ROOT_DIR}/supply-api\" && \"${GO_BIN}\" test ./cmd/supply-api ./internal/config ./internal/httpapi ./internal/middleware ./internal/outbox ./internal/repository"
|
||||
|
||||
run_step \
|
||||
"STEP-02" \
|
||||
"gateway critical regression suite" \
|
||||
"cd \"${ROOT_DIR}/gateway\" && \"${GO_BIN}\" test ./cmd/gateway ./internal/config ./internal/middleware"
|
||||
|
||||
run_step \
|
||||
"STEP-03" \
|
||||
"platform-token-runtime critical regression suite" \
|
||||
"cd \"${ROOT_DIR}/platform-token-runtime\" && \"${GO_BIN}\" test ./cmd/platform-token-runtime ./internal/httpapi ./internal/token ./internal/auth/..."
|
||||
|
||||
run_e2e_skip_gate \
|
||||
"STEP-04" \
|
||||
"supply-api E2E gate must not contain placeholder skip"
|
||||
|
||||
HAS_FAIL=0
|
||||
for row in "${STEP_RESULTS[@]}"; do
|
||||
status="$(echo "${row}" | awk -F'|' '{print $2}')"
|
||||
if [[ "${status}" == "FAIL" ]]; then
|
||||
HAS_FAIL=1
|
||||
fi
|
||||
done
|
||||
|
||||
RESULT="PASS"
|
||||
NOTE="all backend release gates passed"
|
||||
if [[ "${HAS_FAIL}" -eq 1 ]]; then
|
||||
RESULT="FAIL"
|
||||
NOTE="at least one backend release gate failed"
|
||||
fi
|
||||
|
||||
{
|
||||
echo "# Backend Verify Report"
|
||||
echo
|
||||
echo "- 时间戳:${TS}"
|
||||
echo "- 结果:**${RESULT}**"
|
||||
echo "- 说明:${NOTE}"
|
||||
echo
|
||||
echo "| 步骤 | 结果 | 说明 | 证据 |"
|
||||
echo "|---|---|---|---|"
|
||||
for row in "${STEP_RESULTS[@]}"; do
|
||||
step_id="$(echo "${row}" | awk -F'|' '{print $1}')"
|
||||
status="$(echo "${row}" | awk -F'|' '{print $2}')"
|
||||
title="$(echo "${row}" | awk -F'|' '{print $3}')"
|
||||
evidence="$(echo "${row}" | awk -F'|' '{print $4}')"
|
||||
echo "| ${step_id} | ${status} | ${title} | ${evidence} |"
|
||||
done
|
||||
} > "${REPORT_FILE}"
|
||||
|
||||
log "[INFO] report generated: ${REPORT_FILE}"
|
||||
log "[RESULT] ${RESULT}"
|
||||
|
||||
if [[ "${RESULT}" != "PASS" ]]; then
|
||||
exit 1
|
||||
fi
|
||||
@@ -129,6 +129,12 @@ if [[ -z "${GO_BIN}" ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
run_step \
|
||||
"PHASE-00" \
|
||||
"Backend critical verification gate" \
|
||||
"cd \"${ROOT_DIR}\" && bash \"scripts/ci/backend-verify.sh\"" \
|
||||
"${ART_DIR}/phase00_backend_verify.log"
|
||||
|
||||
run_step \
|
||||
"PHASE-01" \
|
||||
"TOK runtime code tests" \
|
||||
@@ -170,7 +176,7 @@ run_step_allow_deferred \
|
||||
"Real staging precheck (expected deferred before real secrets)" \
|
||||
"cd \"${ROOT_DIR}\" && bash \"scripts/supply-gate/staging_precheck_and_run.sh\" \"${STAGING_ENV_FILE}\"" \
|
||||
"${ART_DIR}/phase07_staging_precheck.log" \
|
||||
"placeholder token detected|placeholder API_BASE_URL|missing env var"
|
||||
"placeholder token detected|placeholder API_BASE_URL|missing env var|API_BASE_URL unreachable"
|
||||
|
||||
run_step \
|
||||
"PHASE-08" \
|
||||
|
||||
@@ -37,22 +37,29 @@ func main() {
|
||||
}
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.Load(*env)
|
||||
cfg, err := config.LoadFromPath(*env, *configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("starting supply-api in %s mode", *env)
|
||||
isProd := *env == "prod"
|
||||
|
||||
// P1-010修复: 初始化结构化日志
|
||||
jsonLogger := logging.NewLogger("supply-api", logging.LogLevelInfo)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
rootCtx, stop := context.WithCancel(context.Background())
|
||||
defer stop()
|
||||
|
||||
initCtx, initCancel := context.WithTimeout(rootCtx, 30*time.Second)
|
||||
defer initCancel()
|
||||
|
||||
// 初始化数据库连接
|
||||
db, err := repository.NewDB(ctx, cfg.Database)
|
||||
db, err := repository.NewDB(initCtx, cfg.Database)
|
||||
if err != nil {
|
||||
if isProd {
|
||||
log.Fatalf("production startup requirement failed: database unavailable: %v", err)
|
||||
}
|
||||
log.Printf("warning: failed to connect to database: %v (using in-memory store)", err)
|
||||
db = nil
|
||||
} else {
|
||||
@@ -63,7 +70,11 @@ func main() {
|
||||
// 初始化Redis缓存
|
||||
redisCache, err := cache.NewRedisCache(cfg.Redis)
|
||||
if err != nil {
|
||||
log.Printf("warning: failed to connect to redis: %v (caching disabled)", err)
|
||||
if isProd {
|
||||
log.Printf("warning: redis unavailable at startup: %v", err)
|
||||
} else {
|
||||
log.Printf("warning: failed to connect to redis: %v (caching disabled)", err)
|
||||
}
|
||||
redisCache = nil
|
||||
} else {
|
||||
log.Printf("connected to redis at %s:%d", cfg.Redis.Host, cfg.Redis.Port)
|
||||
@@ -154,7 +165,7 @@ func main() {
|
||||
// 启动主动吊销订阅机制(仅在Redis可用时)
|
||||
if redisCache != nil {
|
||||
if dbTokenBackend, ok := tokenBackend.(*middleware.DBTokenStatusBackend); ok {
|
||||
if err := dbTokenBackend.StartRevocationSubscriber(ctx); err != nil {
|
||||
if err := dbTokenBackend.StartRevocationSubscriber(rootCtx); err != nil {
|
||||
log.Printf("警告: 启动主动吊销订阅失败: %v", err)
|
||||
} else {
|
||||
log.Println("主动吊销机制: 已启动 (Redis Pub/Sub)")
|
||||
@@ -172,6 +183,8 @@ func main() {
|
||||
// 初始化鉴权中间件
|
||||
authConfig := middleware.AuthConfig{
|
||||
SecretKey: cfg.Token.SecretKey,
|
||||
PublicKey: cfg.Token.PublicKey,
|
||||
Algorithm: cfg.Token.Algorithm,
|
||||
Issuer: cfg.Token.Issuer,
|
||||
CacheTTL: cfg.Token.RevocationCacheTTL,
|
||||
Enabled: *env != "dev", // 开发模式禁用鉴权
|
||||
@@ -210,6 +223,7 @@ func main() {
|
||||
cfg.Server.StatementBaseURL,
|
||||
time.Now,
|
||||
)
|
||||
api.SetWithdrawEnabled(cfg.Settlement.WithdrawEnabled)
|
||||
|
||||
// 创建路由器
|
||||
mux := http.NewServeMux()
|
||||
@@ -254,14 +268,12 @@ func main() {
|
||||
|
||||
// 生产环境启用安全中间件
|
||||
if *env != "dev" {
|
||||
// 5. QueryKeyReject - 拒绝外部query key
|
||||
handler = authMiddleware.QueryKeyRejectMiddleware(handler)
|
||||
// 6. BearerExtract
|
||||
handler = authMiddleware.BearerExtractMiddleware(handler)
|
||||
// 7. TokenVerify
|
||||
handler = authMiddleware.TokenVerifyMiddleware(handler)
|
||||
// 8. RateLimit - 限流 (使用中间件包装器)
|
||||
// 包装顺序与请求执行顺序相反,这里从内到外构建,保证实际执行顺序为:
|
||||
// QueryKeyReject -> BearerExtract -> TokenVerify -> RateLimit
|
||||
handler = middleware.NewRateLimitHandler(rateLimitConfig, handler)
|
||||
handler = authMiddleware.TokenVerifyMiddleware(handler)
|
||||
handler = authMiddleware.BearerExtractMiddleware(handler)
|
||||
handler = authMiddleware.QueryKeyRejectMiddleware(handler)
|
||||
}
|
||||
|
||||
// 创建HTTP服务器
|
||||
@@ -274,6 +286,14 @@ func main() {
|
||||
IdleTimeout: cfg.Server.IdleTimeout,
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
log.Printf("starting HTTP server on %s", cfg.Server.Addr)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
serverErrCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// P0-06修复: 启动OutboxProcessor(仅在DB可用时)
|
||||
var outboxProcessor *outbox.OutboxProcessorRunner
|
||||
if db != nil {
|
||||
@@ -284,14 +304,21 @@ func main() {
|
||||
redisClient := redisCache.GetClient()
|
||||
msgBroker = messaging.NewOutboxMessageBroker(redisClient, "supply:outbox:stream", "outbox-processor")
|
||||
}
|
||||
stats := &messaging.NoOpOutboxStats{}
|
||||
outboxProcessor = outbox.NewOutboxProcessorRunner(outboxRepo, msgBroker, stats)
|
||||
go outboxProcessor.Start(ctx)
|
||||
log.Println("OutboxProcessor已启动")
|
||||
if msgBroker == nil {
|
||||
if isProd {
|
||||
log.Fatalf("production startup requirement failed: outbox message broker unavailable")
|
||||
}
|
||||
log.Println("警告: OutboxProcessor未启动 (message broker不可用)")
|
||||
} else {
|
||||
stats := &messaging.NoOpOutboxStats{}
|
||||
outboxProcessor = outbox.NewOutboxProcessorRunner(outboxRepo, msgBroker, stats)
|
||||
go outboxProcessor.Start(rootCtx)
|
||||
log.Println("OutboxProcessor已启动")
|
||||
}
|
||||
|
||||
// 分区维护:确保未来分区已创建
|
||||
partitionManager := repository.NewPartitionManager(db.Pool)
|
||||
if err := partitionManager.EnsureFuturePartitions(ctx); err != nil {
|
||||
if err := partitionManager.EnsureFuturePartitions(initCtx); err != nil {
|
||||
log.Printf("警告: 预创建未来分区失败: %v", err)
|
||||
} else {
|
||||
log.Println("分区管理: 未来分区已确保存在")
|
||||
@@ -303,7 +330,7 @@ func main() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-rootCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := partitionManager.EnsureFuturePartitions(context.Background()); err != nil {
|
||||
@@ -327,14 +354,21 @@ func main() {
|
||||
log.Println("批量补偿处理器: 已初始化")
|
||||
|
||||
// 启动后台补偿处理goroutine
|
||||
compensationProcessor.StartBackgroundWorker(ctx, 5*time.Minute)
|
||||
compensationProcessor.StartBackgroundWorker(rootCtx, 5*time.Minute)
|
||||
log.Println("批量补偿处理器: 后台worker已启动 (每5分钟检查一次)")
|
||||
}
|
||||
|
||||
// 优雅关闭
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
select {
|
||||
case sig := <-sigCh:
|
||||
log.Printf("received signal %s", sig)
|
||||
case err := <-serverErrCh:
|
||||
log.Fatalf("server failed: %v", err)
|
||||
}
|
||||
|
||||
stop()
|
||||
|
||||
log.Println("shutting down...")
|
||||
|
||||
|
||||
71
supply-api/cmd/supply-api/main_test.go
Normal file
71
supply-api/cmd/supply-api/main_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMain_ProdStartupFailsWhenDatabaseUnavailable(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.prod.yaml")
|
||||
content := []byte(`
|
||||
server:
|
||||
addr: "127.0.0.1:0"
|
||||
shutdown_timeout: 1s
|
||||
default_supplier_id: 0
|
||||
database:
|
||||
host: "127.0.0.1"
|
||||
port: 1
|
||||
user: "postgres"
|
||||
password: "secret"
|
||||
database: "supply_db"
|
||||
redis:
|
||||
host: "127.0.0.1"
|
||||
port: 1
|
||||
token:
|
||||
issuer: "prod-issuer"
|
||||
secret_key: "prod-secret"
|
||||
algorithm: "HS256"
|
||||
`)
|
||||
if err := os.WriteFile(configPath, content, 0o600); err != nil {
|
||||
t.Fatalf("failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestMainHelperProcess", "--", "-env", "prod", "-config", configPath)
|
||||
cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1")
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
t.Fatalf("expected prod startup to fail fast, but process timed out. output=%s", string(output))
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("expected prod startup to fail, but process exited successfully. output=%s", string(output))
|
||||
}
|
||||
if !strings.Contains(string(output), "production startup requirement failed") {
|
||||
t.Fatalf("expected startup failure output to mention production startup requirement failed, got: %s", string(output))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMainHelperProcess(t *testing.T) {
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
for i, arg := range os.Args {
|
||||
if arg != "--" {
|
||||
continue
|
||||
}
|
||||
os.Args = append([]string{os.Args[0]}, os.Args[i+1:]...)
|
||||
break
|
||||
}
|
||||
|
||||
main()
|
||||
os.Exit(0)
|
||||
}
|
||||
@@ -4,133 +4,418 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"os"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/adapter"
|
||||
"lijiaoqiao/supply-api/internal/audit"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
"lijiaoqiao/supply-api/internal/httpapi"
|
||||
"lijiaoqiao/supply-api/internal/middleware"
|
||||
"lijiaoqiao/supply-api/internal/pkg/logging"
|
||||
)
|
||||
|
||||
// E2E 测试配置
|
||||
type E2EConfig struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
SupplierID int64
|
||||
Timeout time.Duration
|
||||
RetryAttempts int
|
||||
type e2eOptions struct {
|
||||
withdrawEnabled bool
|
||||
}
|
||||
|
||||
// getE2EConfig 从环境变量获取 E2E 测试配置
|
||||
func getE2EConfig() *E2EConfig {
|
||||
return &E2EConfig{
|
||||
BaseURL: getEnv("E2E_BASE_URL", "http://localhost:8080"),
|
||||
APIKey: getEnv("E2E_API_KEY", "test-api-key"),
|
||||
SupplierID: 1001,
|
||||
Timeout: 30 * time.Second,
|
||||
RetryAttempts: 3,
|
||||
type e2eSystem struct {
|
||||
handler http.Handler
|
||||
accountSvc *e2eAccountService
|
||||
auditStore *audit.MemoryAuditStore
|
||||
secretKey string
|
||||
tokenIssuer string
|
||||
}
|
||||
|
||||
type e2eAccountService struct {
|
||||
verifyResult *domain.VerifyResult
|
||||
lastVerifySupplierID int64
|
||||
}
|
||||
|
||||
func (s *e2eAccountService) Verify(ctx context.Context, supplierID int64, provider domain.Provider, accountType domain.AccountType, credential string) (*domain.VerifyResult, error) {
|
||||
s.lastVerifySupplierID = supplierID
|
||||
return s.verifyResult, nil
|
||||
}
|
||||
|
||||
func (s *e2eAccountService) Create(ctx context.Context, req *domain.CreateAccountRequest) (*domain.Account, error) {
|
||||
return &domain.Account{ID: 1, SupplierID: req.SupplierID, Provider: req.Provider, AccountType: req.AccountType, Status: domain.AccountStatusActive}, nil
|
||||
}
|
||||
|
||||
func (s *e2eAccountService) Activate(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) {
|
||||
return &domain.Account{ID: accountID, SupplierID: supplierID, Status: domain.AccountStatusActive}, nil
|
||||
}
|
||||
|
||||
func (s *e2eAccountService) Suspend(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) {
|
||||
return &domain.Account{ID: accountID, SupplierID: supplierID, Status: domain.AccountStatusSuspended}, nil
|
||||
}
|
||||
|
||||
func (s *e2eAccountService) Delete(ctx context.Context, supplierID, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *e2eAccountService) GetByID(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) {
|
||||
return &domain.Account{ID: accountID, SupplierID: supplierID, Status: domain.AccountStatusActive}, nil
|
||||
}
|
||||
|
||||
type e2ePackageService struct{}
|
||||
|
||||
func (s *e2ePackageService) CreateDraft(ctx context.Context, supplierID int64, req *domain.CreatePackageDraftRequest) (*domain.Package, error) {
|
||||
return &domain.Package{ID: 1, SupplierID: supplierID, Model: req.Model, Status: domain.PackageStatusDraft}, nil
|
||||
}
|
||||
|
||||
func (s *e2ePackageService) Publish(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
|
||||
return &domain.Package{ID: packageID, SupplierID: supplierID, Status: domain.PackageStatusActive}, nil
|
||||
}
|
||||
|
||||
func (s *e2ePackageService) Pause(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
|
||||
return &domain.Package{ID: packageID, SupplierID: supplierID, Status: domain.PackageStatusPaused}, nil
|
||||
}
|
||||
|
||||
func (s *e2ePackageService) Unlist(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
|
||||
return &domain.Package{ID: packageID, SupplierID: supplierID, Status: domain.PackageStatusExpired}, nil
|
||||
}
|
||||
|
||||
func (s *e2ePackageService) Clone(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
|
||||
return &domain.Package{ID: packageID + 1, SupplierID: supplierID, Status: domain.PackageStatusDraft}, nil
|
||||
}
|
||||
|
||||
func (s *e2ePackageService) BatchUpdatePrice(ctx context.Context, supplierID int64, req *domain.BatchUpdatePriceRequest) (*domain.BatchUpdatePriceResponse, error) {
|
||||
return &domain.BatchUpdatePriceResponse{Total: len(req.Items), SuccessCount: len(req.Items)}, nil
|
||||
}
|
||||
|
||||
func (s *e2ePackageService) GetByID(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
|
||||
return &domain.Package{ID: packageID, SupplierID: supplierID, Status: domain.PackageStatusActive}, nil
|
||||
}
|
||||
|
||||
type e2eSettlementService struct{}
|
||||
|
||||
func (s *e2eSettlementService) Withdraw(ctx context.Context, supplierID int64, req *domain.WithdrawRequest) (*domain.Settlement, error) {
|
||||
now := time.Now().UTC()
|
||||
return &domain.Settlement{
|
||||
ID: 1,
|
||||
SupplierID: supplierID,
|
||||
SettlementNo: "SET-001",
|
||||
Status: domain.SettlementStatusPending,
|
||||
TotalAmount: req.Amount,
|
||||
NetAmount: req.Amount,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *e2eSettlementService) Cancel(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) {
|
||||
now := time.Now().UTC()
|
||||
return &domain.Settlement{ID: settlementID, SupplierID: supplierID, Status: domain.SettlementStatusFailed, CreatedAt: now, UpdatedAt: now}, nil
|
||||
}
|
||||
|
||||
func (s *e2eSettlementService) GetByID(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) {
|
||||
now := time.Now().UTC()
|
||||
return &domain.Settlement{ID: settlementID, SupplierID: supplierID, Status: domain.SettlementStatusPending, CreatedAt: now, UpdatedAt: now}, nil
|
||||
}
|
||||
|
||||
func (s *e2eSettlementService) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
|
||||
now := time.Now().UTC()
|
||||
return []*domain.Settlement{{ID: 1, SupplierID: supplierID, Status: domain.SettlementStatusPending, CreatedAt: now, UpdatedAt: now}}, nil
|
||||
}
|
||||
|
||||
type e2eEarningService struct{}
|
||||
|
||||
func (s *e2eEarningService) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
|
||||
return []*domain.EarningRecord{
|
||||
{ID: 1, SupplierID: supplierID, EarningsType: "usage", Amount: 100, Status: "available", EarnedAt: time.Now().UTC()},
|
||||
}, 1, nil
|
||||
}
|
||||
|
||||
func (s *e2eEarningService) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
|
||||
return &domain.BillingSummary{
|
||||
Period: domain.BillingPeriod{Start: startDate, End: endDate},
|
||||
Summary: domain.BillingTotal{TotalRevenue: 100, TotalOrders: 1, TotalUsage: 1000, TotalRequests: 10, AvgSuccessRate: 1, NetEarnings: 95},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type staticTokenBackend struct {
|
||||
statusByTokenID map[string]string
|
||||
}
|
||||
|
||||
func (b *staticTokenBackend) CheckTokenStatus(ctx context.Context, tokenID string) (string, error) {
|
||||
if status, ok := b.statusByTokenID[tokenID]; ok {
|
||||
return status, nil
|
||||
}
|
||||
return "active", nil
|
||||
}
|
||||
|
||||
func newE2ESystem(t *testing.T, opts e2eOptions) *e2eSystem {
|
||||
t.Helper()
|
||||
|
||||
accountSvc := &e2eAccountService{
|
||||
verifyResult: &domain.VerifyResult{
|
||||
VerifyStatus: "pass",
|
||||
AvailableQuota: 2048,
|
||||
RiskScore: 0,
|
||||
CheckItems: []domain.CheckItem{
|
||||
{Item: "credential_format", Result: "pass", Message: "ok"},
|
||||
},
|
||||
},
|
||||
}
|
||||
auditStore := audit.NewMemoryAuditStore()
|
||||
|
||||
api := httpapi.NewSupplyAPI(
|
||||
accountSvc,
|
||||
&e2ePackageService{},
|
||||
&e2eSettlementService{},
|
||||
&e2eEarningService{},
|
||||
nil,
|
||||
auditStore,
|
||||
nil,
|
||||
0,
|
||||
"https://statements.example.com",
|
||||
func() time.Time { return time.Unix(1712800000, 0).UTC() },
|
||||
)
|
||||
api.SetWithdrawEnabled(opts.withdrawEnabled)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
healthHandler := httpapi.NewHealthHandlerWithDefaults(nil, nil)
|
||||
mux.HandleFunc("/actuator/health", healthHandler.ServeHealth)
|
||||
mux.HandleFunc("/actuator/health/live", healthHandler.ServeLiveness)
|
||||
mux.HandleFunc("/actuator/health/ready", healthHandler.ServeReadiness)
|
||||
api.Register(mux)
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
middleware.AuthConfig{
|
||||
SecretKey: "e2e-secret-key-should-be-long",
|
||||
Algorithm: "HS256",
|
||||
Issuer: "supply-api-e2e",
|
||||
Enabled: true,
|
||||
},
|
||||
middleware.NewTokenCache(),
|
||||
&staticTokenBackend{statusByTokenID: map[string]string{}},
|
||||
adapter.NewAuditEmitterAdapter(auditStore),
|
||||
)
|
||||
|
||||
logger := logging.NewLogger("supply-api-e2e", logging.LogLevelError)
|
||||
|
||||
var handler http.Handler = mux
|
||||
handler = middleware.RequestID(handler)
|
||||
handler = middleware.Recovery(handler)
|
||||
handler = middleware.Logging(handler, logger)
|
||||
handler = middleware.TracingMiddleware(handler)
|
||||
handler = authMiddleware.TokenVerifyMiddleware(handler)
|
||||
handler = authMiddleware.BearerExtractMiddleware(handler)
|
||||
handler = authMiddleware.QueryKeyRejectMiddleware(handler)
|
||||
|
||||
return &e2eSystem{
|
||||
handler: handler,
|
||||
accountSvc: accountSvc,
|
||||
auditStore: auditStore,
|
||||
secretKey: "e2e-secret-key-should-be-long",
|
||||
tokenIssuer: "supply-api-e2e",
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
func (s *e2eSystem) tokenForTenant(t *testing.T, tokenID string, tenantID int64) string {
|
||||
t.Helper()
|
||||
|
||||
// TestE2E_HealthCheck E2E 测试:健康检查
|
||||
func TestE2E_HealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping E2E test in short mode")
|
||||
claims := &middleware.TokenClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: tokenID,
|
||||
Issuer: s.tokenIssuer,
|
||||
Subject: "subject-42",
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
SubjectID: "42",
|
||||
Role: "org_admin",
|
||||
Scope: []string{"supply:write", "supply:read"},
|
||||
TenantID: tenantID,
|
||||
}
|
||||
|
||||
cfg := getE2EConfig()
|
||||
_, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数
|
||||
|
||||
// 验证服务健康状态
|
||||
// 在真实的 E2E 测试中,这里会使用 HTTP 客户端调用真实服务
|
||||
t.Logf("E2E_BASE_URL: %s", cfg.BaseURL)
|
||||
t.Skip("需要完整环境运行 E2E 测试")
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(s.secretKey))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign token: %v", err)
|
||||
}
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// TestE2E_AccountLifecycle E2E 测试:账号完整生命周期
|
||||
func TestE2E_AccountLifecycle(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping E2E test in short mode")
|
||||
func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("failed to decode response: %v, body=%s", err, recorder.Body.String())
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func TestE2E_HealthProbe_IsPublicAndHealthy(t *testing.T) {
|
||||
system := newE2ESystem(t, e2eOptions{withdrawEnabled: false})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/actuator/health", nil)
|
||||
req.Header.Set("X-Request-Id", "health-req-001")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
system.handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d, body=%s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if recorder.Header().Get("X-Request-Id") != "health-req-001" {
|
||||
t.Fatalf("expected X-Request-Id response header to round-trip")
|
||||
}
|
||||
|
||||
cfg := getE2EConfig()
|
||||
_, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数
|
||||
|
||||
t.Log("E2E 测试:账号完整生命周期")
|
||||
t.Log("步骤:创建 -> 验证 -> 暂停 -> 恢复 -> 禁用")
|
||||
t.Skip("需要完整环境运行 E2E 测试")
|
||||
payload := decodeJSONBody(t, recorder)
|
||||
if payload["status"] != "healthy" {
|
||||
t.Fatalf("expected healthy status, got %v", payload["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_PackageLifecycle E2E 测试:套餐完整生命周期
|
||||
func TestE2E_PackageLifecycle(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping E2E test in short mode")
|
||||
func TestE2E_ProtectedRoute_RejectsMissingBearer(t *testing.T) {
|
||||
system := newE2ESystem(t, e2eOptions{withdrawEnabled: false})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/supply/billing?start_date=2026-04-01&end_date=2026-04-11", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
system.handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status 401, got %d, body=%s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
|
||||
cfg := getE2EConfig()
|
||||
_, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数
|
||||
|
||||
t.Log("E2E 测试:套餐完整生命周期")
|
||||
t.Log("步骤:创建草稿 -> 发布 -> 暂停 -> 售罄 -> 过期")
|
||||
t.Skip("需要完整环境运行 E2E 测试")
|
||||
payload := decodeJSONBody(t, recorder)
|
||||
errBody, ok := payload["error"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected error body, got %v", payload["error"])
|
||||
}
|
||||
if errBody["code"] != "AUTH_MISSING_BEARER" {
|
||||
t.Fatalf("expected AUTH_MISSING_BEARER, got %v", errBody["code"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_SettlementWorkflow E2E 测试:结算完整流程
|
||||
func TestE2E_SettlementWorkflow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping E2E test in short mode")
|
||||
func TestE2E_ProtectedRoute_RejectsQueryCredentialLeak(t *testing.T) {
|
||||
system := newE2ESystem(t, e2eOptions{withdrawEnabled: false})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/supply/billing?api_key=abcdefghijklmnopqrstuvwxyz123456", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
system.handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status 401, got %d, body=%s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
|
||||
cfg := getE2EConfig()
|
||||
_, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数
|
||||
|
||||
t.Log("E2E 测试:结算完整流程")
|
||||
t.Log("步骤:创建结算单 -> 处理中 -> 完成/失败 -> 提现")
|
||||
t.Skip("需要完整环境运行 E2E 测试")
|
||||
payload := decodeJSONBody(t, recorder)
|
||||
errBody, ok := payload["error"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected error body, got %v", payload["error"])
|
||||
}
|
||||
if errBody["code"] != "QUERY_KEY_NOT_ALLOWED" {
|
||||
t.Fatalf("expected QUERY_KEY_NOT_ALLOWED, got %v", errBody["code"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_WithdrawFlow E2E 测试:提现流程
|
||||
func TestE2E_WithdrawFlow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping E2E test in short mode")
|
||||
func TestE2E_VerifyAccount_UsesTenantIDFromVerifiedToken(t *testing.T) {
|
||||
system := newE2ESystem(t, e2eOptions{withdrawEnabled: false})
|
||||
token := system.tokenForTenant(t, "tok-e2e-verify", 2001)
|
||||
|
||||
body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/supply/accounts/verify", strings.NewReader(body))
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Request-Id", "verify-req-001")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
system.handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d, body=%s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if system.accountSvc.lastVerifySupplierID != 2001 {
|
||||
t.Fatalf("expected supplierID 2001 from token tenant, got %d", system.accountSvc.lastVerifySupplierID)
|
||||
}
|
||||
|
||||
cfg := getE2EConfig()
|
||||
_, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数
|
||||
|
||||
t.Log("E2E 测试:提现流程")
|
||||
t.Log("步骤:验证余额 -> 发起提现 -> 短信验证 -> 确认 -> 到账")
|
||||
t.Skip("需要完整环境运行 E2E 测试")
|
||||
payload := decodeJSONBody(t, recorder)
|
||||
if payload["request_id"] != "verify-req-001" {
|
||||
t.Fatalf("expected request_id verify-req-001, got %v", payload["request_id"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_AuditLogTrace E2E 测试:审计日志追踪
|
||||
func TestE2E_AuditLogTrace(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping E2E test in short mode")
|
||||
func TestE2E_Withdraw_DisabledBeforeSMSIntegration(t *testing.T) {
|
||||
system := newE2ESystem(t, e2eOptions{withdrawEnabled: false})
|
||||
token := system.tokenForTenant(t, "tok-e2e-withdraw", 2002)
|
||||
|
||||
body := `{"withdraw_amount":100,"payment_method":"bank","payment_account":"13800000000","sms_code":"123456"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/supply/settlements/withdraw", strings.NewReader(body))
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
system.handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected status 503, got %d, body=%s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
|
||||
cfg := getE2EConfig()
|
||||
_, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数
|
||||
|
||||
t.Log("E2E 测试:审计日志追踪")
|
||||
t.Log("验证操作被正确记录到审计日志")
|
||||
t.Skip("需要完整环境运行 E2E 测试")
|
||||
payload := decodeJSONBody(t, recorder)
|
||||
errBody, ok := payload["error"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected error body, got %v", payload["error"])
|
||||
}
|
||||
if errBody["code"] != "FEATURE_DISABLED" {
|
||||
t.Fatalf("expected FEATURE_DISABLED, got %v", errBody["code"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_ConcurrentWithdraw E2E 测试:并发提现
|
||||
func TestE2E_ConcurrentWithdraw(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping E2E test in short mode")
|
||||
func TestE2E_AuditEvent_CanBeReadBackThroughAPI(t *testing.T) {
|
||||
system := newE2ESystem(t, e2eOptions{withdrawEnabled: false})
|
||||
token := system.tokenForTenant(t, "tok-e2e-audit", 3001)
|
||||
|
||||
if err := system.auditStore.Emit(context.Background(), audit.Event{
|
||||
TenantID: 3001,
|
||||
ObjectType: "supply_account",
|
||||
ObjectID: 77,
|
||||
Action: "verify",
|
||||
RequestID: "audit-req-001",
|
||||
ResultCode: "OK",
|
||||
SourceIP: "127.0.0.1",
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to seed audit event: %v", err)
|
||||
}
|
||||
|
||||
cfg := getE2EConfig()
|
||||
_, _ = cfg.Timeout, cfg.RetryAttempts // 使用配置参数
|
||||
events, err := system.auditStore.Query(context.Background(), audit.EventFilter{TenantID: 3001, Limit: 1})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query seeded audit event: %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected 1 audit event, got %d", len(events))
|
||||
}
|
||||
|
||||
t.Log("E2E 测试:并发提现")
|
||||
t.Log("验证乐观锁正确处理并发请求")
|
||||
t.Skip("需要完整环境运行 E2E 测试")
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/events/"+events[0].EventID, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("X-Request-Id", "audit-read-001")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
system.handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d, body=%s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
|
||||
payload := decodeJSONBody(t, recorder)
|
||||
data, ok := payload["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected data object, got %v", payload["data"])
|
||||
}
|
||||
if data["event_id"] != events[0].EventID {
|
||||
t.Fatalf("expected event_id %s, got %v", events[0].EventID, data["event_id"])
|
||||
}
|
||||
if data["request_id"] != "audit-req-001" {
|
||||
t.Fatalf("expected seeded request_id audit-req-001, got %v", data["request_id"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
|
||||
// Config 应用配置
|
||||
type Config struct {
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
Redis RedisConfig
|
||||
Token TokenConfig
|
||||
Audit AuditConfig
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
Redis RedisConfig
|
||||
Token TokenConfig
|
||||
Settlement SettlementConfig
|
||||
Audit AuditConfig
|
||||
}
|
||||
|
||||
// ServerConfig HTTP服务配置
|
||||
@@ -63,6 +64,11 @@ type TokenConfig struct {
|
||||
RevocationCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// SettlementConfig 结算与提现能力配置
|
||||
type SettlementConfig struct {
|
||||
WithdrawEnabled bool
|
||||
}
|
||||
|
||||
// AuditConfig 审计配置
|
||||
type AuditConfig struct {
|
||||
BufferSize int
|
||||
@@ -90,6 +96,15 @@ func (r *RedisConfig) Addr() string {
|
||||
|
||||
// Load 加载配置
|
||||
func Load(env string) (*Config, error) {
|
||||
return load(env, "")
|
||||
}
|
||||
|
||||
// LoadFromPath 从指定路径加载配置
|
||||
func LoadFromPath(env, configPath string) (*Config, error) {
|
||||
return load(env, configPath)
|
||||
}
|
||||
|
||||
func load(env, configPath string) (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
// 设置环境变量前缀
|
||||
@@ -100,17 +115,24 @@ func Load(env string) (*Config, error) {
|
||||
setDefaults(v)
|
||||
|
||||
// 加载配置文件
|
||||
configFile := fmt.Sprintf("config.%s.yaml", env)
|
||||
v.SetConfigName(configFile)
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(".")
|
||||
v.AddConfigPath("./config")
|
||||
if strings.TrimSpace(configPath) != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
} else {
|
||||
configFile := fmt.Sprintf("config.%s.yaml", env)
|
||||
v.SetConfigName(configFile)
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(".")
|
||||
v.AddConfigPath("./config")
|
||||
}
|
||||
|
||||
// 允许环境变量覆盖
|
||||
v.AutomaticEnv()
|
||||
|
||||
// 读取配置文件
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if strings.TrimSpace(configPath) != "" {
|
||||
return nil, fmt.Errorf("failed to read config: %w", err)
|
||||
}
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("failed to read config: %w", err)
|
||||
}
|
||||
@@ -163,6 +185,13 @@ func Load(env string) (*Config, error) {
|
||||
cfg.Audit.FlushInterval = v.GetDuration("audit.flush_interval")
|
||||
cfg.Audit.ExportTimeout = v.GetDuration("audit.export_timeout")
|
||||
|
||||
// Settlement配置
|
||||
cfg.Settlement.WithdrawEnabled = v.GetBool("settlement.withdraw_enabled")
|
||||
|
||||
if err := validateForEnv(env, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
@@ -202,6 +231,9 @@ func setDefaults(v *viper.Viper) {
|
||||
v.SetDefault("token.revocation_cache_ttl", 30*time.Second)
|
||||
v.SetDefault("token.algorithm", "HS256") // 默认HS256,可配置RS256
|
||||
|
||||
// Settlement defaults
|
||||
v.SetDefault("settlement.withdraw_enabled", false)
|
||||
|
||||
// Audit defaults
|
||||
v.SetDefault("audit.buffer_size", 1000)
|
||||
v.SetDefault("audit.flush_interval", 5*time.Second)
|
||||
@@ -228,6 +260,10 @@ func bindEnvVars(v *viper.Viper) {
|
||||
_ = v.BindEnv("redis.db", "SUPPLY_REDIS_DB")
|
||||
|
||||
_ = v.BindEnv("token.secret_key", "SUPPLY_TOKEN_SECRET_KEY")
|
||||
_ = v.BindEnv("token.public_key", "SUPPLY_TOKEN_PUBLIC_KEY")
|
||||
_ = v.BindEnv("token.algorithm", "SUPPLY_TOKEN_ALGORITHM")
|
||||
_ = v.BindEnv("token.issuer", "SUPPLY_TOKEN_ISSUER")
|
||||
_ = v.BindEnv("settlement.withdraw_enabled", "SUPPLY_SETTLEMENT_WITHDRAW_ENABLED")
|
||||
}
|
||||
|
||||
// MustLoad 加载配置,失败时panic
|
||||
@@ -249,6 +285,46 @@ func GetEnvInt(key string, defaultVal int) int {
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
func validateForEnv(env string, cfg *Config) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
cfg.Token.Algorithm = strings.ToUpper(strings.TrimSpace(cfg.Token.Algorithm))
|
||||
if cfg.Token.Algorithm == "" {
|
||||
cfg.Token.Algorithm = "HS256"
|
||||
}
|
||||
|
||||
if env != "prod" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.Server.DefaultSupplierID != 0 {
|
||||
return fmt.Errorf("invalid prod config: server.default_supplier_id must be 0 to disable static supplier fallback")
|
||||
}
|
||||
if strings.TrimSpace(cfg.Token.Issuer) == "" {
|
||||
return fmt.Errorf("invalid prod config: token.issuer is required")
|
||||
}
|
||||
if cfg.Settlement.WithdrawEnabled {
|
||||
return fmt.Errorf("invalid prod config: settlement.withdraw_enabled cannot be true until SMS integration is production-ready")
|
||||
}
|
||||
|
||||
switch cfg.Token.Algorithm {
|
||||
case "HS256", "HS384", "HS512":
|
||||
if strings.TrimSpace(cfg.Token.SecretKey) == "" {
|
||||
return fmt.Errorf("invalid prod config: token.secret_key is required for %s", cfg.Token.Algorithm)
|
||||
}
|
||||
case "RS256", "RS384", "RS512":
|
||||
if strings.TrimSpace(cfg.Token.PublicKey) == "" {
|
||||
return fmt.Errorf("invalid prod config: token.public_key is required for %s", cfg.Token.Algorithm)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("invalid prod config: unsupported token.algorithm %q", cfg.Token.Algorithm)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEnvDuration 获取环境变量duration值
|
||||
func GetEnvDuration(key string, defaultVal time.Duration) time.Duration {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
|
||||
135
supply-api/internal/config/config_test.go
Normal file
135
supply-api/internal/config/config_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadFromPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "custom.yaml")
|
||||
content := []byte(`
|
||||
server:
|
||||
addr: ":19090"
|
||||
database:
|
||||
host: "db.internal"
|
||||
token:
|
||||
issuer: "custom-issuer"
|
||||
`)
|
||||
if err := os.WriteFile(configPath, content, 0o600); err != nil {
|
||||
t.Fatalf("failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadFromPath("dev", configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromPath returned error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Addr != ":19090" {
|
||||
t.Fatalf("expected addr :19090, got %s", cfg.Server.Addr)
|
||||
}
|
||||
if cfg.Database.Host != "db.internal" {
|
||||
t.Fatalf("expected database host db.internal, got %s", cfg.Database.Host)
|
||||
}
|
||||
if cfg.Token.Issuer != "custom-issuer" {
|
||||
t.Fatalf("expected token issuer custom-issuer, got %s", cfg.Token.Issuer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromPath_MissingFile(t *testing.T) {
|
||||
_, err := LoadFromPath("dev", filepath.Join(t.TempDir(), "missing.yaml"))
|
||||
if err == nil {
|
||||
t.Fatal("expected missing config file to return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromPath_ProdRejectsDefaultSupplierIDFallback(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "prod.yaml")
|
||||
content := []byte(`
|
||||
server:
|
||||
addr: ":19090"
|
||||
database:
|
||||
host: "db.internal"
|
||||
user: "postgres"
|
||||
password: "secret"
|
||||
database: "supply_db"
|
||||
token:
|
||||
issuer: "prod-issuer"
|
||||
secret_key: "prod-secret"
|
||||
`)
|
||||
if err := os.WriteFile(configPath, content, 0o600); err != nil {
|
||||
t.Fatalf("failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadFromPath("prod", configPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected prod config with default supplier fallback to return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "default_supplier_id") {
|
||||
t.Fatalf("expected error to mention default_supplier_id, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromPath_ProdRejectsMissingHS256SecretKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "prod.yaml")
|
||||
content := []byte(`
|
||||
server:
|
||||
addr: ":19090"
|
||||
default_supplier_id: 0
|
||||
database:
|
||||
host: "db.internal"
|
||||
user: "postgres"
|
||||
password: "secret"
|
||||
database: "supply_db"
|
||||
token:
|
||||
issuer: "prod-issuer"
|
||||
algorithm: "HS256"
|
||||
`)
|
||||
if err := os.WriteFile(configPath, content, 0o600); err != nil {
|
||||
t.Fatalf("failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadFromPath("prod", configPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected prod config without HS256 secret key to return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "token.secret_key") {
|
||||
t.Fatalf("expected error to mention token.secret_key, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromPath_ProdRejectsWithdrawEnabledUntilSMSIntegrated(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "prod.yaml")
|
||||
content := []byte(`
|
||||
server:
|
||||
addr: ":19090"
|
||||
default_supplier_id: 0
|
||||
database:
|
||||
host: "db.internal"
|
||||
user: "postgres"
|
||||
password: "secret"
|
||||
database: "supply_db"
|
||||
token:
|
||||
issuer: "prod-issuer"
|
||||
algorithm: "HS256"
|
||||
secret_key: "prod-secret"
|
||||
settlement:
|
||||
withdraw_enabled: true
|
||||
`)
|
||||
if err := os.WriteFile(configPath, content, 0o600); err != nil {
|
||||
t.Fatalf("failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadFromPath("prod", configPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected prod config with withdraw enabled but no integrated SMS path to return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "withdraw_enabled") {
|
||||
t.Fatalf("expected error to mention withdraw_enabled, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -18,14 +18,15 @@ import (
|
||||
|
||||
// SupplyAPI 处理器
|
||||
type SupplyAPI struct {
|
||||
accountService domain.AccountService
|
||||
packageService domain.PackageService
|
||||
settlementService domain.SettlementService
|
||||
accountService domain.AccountService
|
||||
packageService domain.PackageService
|
||||
settlementService domain.SettlementService
|
||||
earningService domain.EarningService
|
||||
idempotencyMw *middleware.IdempotencyMiddleware // P0-P4修复: 使用DB-backed幂等中间件
|
||||
auditStore audit.AuditStore // P0-R08修复: 使用接口支持DB-backed实现
|
||||
fkValidator *repository.ForeignKeyValidator // P0-09修复: 外键校验器
|
||||
auditStore audit.AuditStore // P0-R08修复: 使用接口支持DB-backed实现
|
||||
fkValidator *repository.ForeignKeyValidator // P0-09修复: 外键校验器
|
||||
supplierID int64
|
||||
withdrawEnabled bool
|
||||
statementBaseURL string
|
||||
now func() time.Time
|
||||
}
|
||||
@@ -51,11 +52,16 @@ func NewSupplyAPI(
|
||||
auditStore: auditStore,
|
||||
fkValidator: fkValidator,
|
||||
supplierID: supplierID,
|
||||
withdrawEnabled: true,
|
||||
statementBaseURL: statementBaseURL,
|
||||
now: now,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) SetWithdrawEnabled(enabled bool) {
|
||||
a.withdrawEnabled = enabled
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) Register(mux *http.ServeMux) {
|
||||
// Supply Accounts
|
||||
mux.HandleFunc("/api/v1/supply/accounts/verify", a.handleVerifyAccount)
|
||||
@@ -82,6 +88,25 @@ func (a *SupplyAPI) Register(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/v1/audit/events/", a.handleAuditEvent)
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) resolveSupplierID(ctx context.Context) (int64, error) {
|
||||
if tenantID := middleware.GetTenantID(ctx); tenantID > 0 {
|
||||
return tenantID, nil
|
||||
}
|
||||
if a.supplierID > 0 {
|
||||
return a.supplierID, nil
|
||||
}
|
||||
return 0, fmt.Errorf("supplier context is missing")
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) requireSupplierID(w http.ResponseWriter, r *http.Request) (int64, bool) {
|
||||
supplierID, err := a.resolveSupplierID(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error())
|
||||
return 0, false
|
||||
}
|
||||
return supplierID, true
|
||||
}
|
||||
|
||||
// ==================== Account Handlers ====================
|
||||
|
||||
type VerifyAccountRequest struct {
|
||||
@@ -110,7 +135,12 @@ func (a *SupplyAPI) handleVerifyAccount(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := a.accountService.Verify(r.Context(), a.supplierID,
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := a.accountService.Verify(r.Context(), supplierID,
|
||||
domain.Provider(req.Provider),
|
||||
domain.AccountType(req.AccountType),
|
||||
req.CredentialInput)
|
||||
@@ -139,12 +169,17 @@ func (a *SupplyAPI) handleCreateAccount(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// 降级:使用内联幂等逻辑(仅在幂等中间件未启用时)
|
||||
a.createAccountHandler(context.Background(), w, r, nil)
|
||||
a.createAccountHandler(r.Context(), w, r, nil)
|
||||
}
|
||||
|
||||
// createAccountHandler 创建账号的业务逻辑(供幂等中间件包装)
|
||||
func (a *SupplyAPI) createAccountHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, _ *repository.IdempotencyRecord) error {
|
||||
requestID := r.Header.Get("X-Request-Id")
|
||||
supplierID, err := a.resolveSupplierID(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
@@ -169,14 +204,14 @@ func (a *SupplyAPI) createAccountHandler(ctx context.Context, w http.ResponseWri
|
||||
|
||||
// P0-09修复: 创建账户前校验外键引用
|
||||
if a.fkValidator != nil {
|
||||
if err := a.fkValidator.ValidateSupplyAccountOwner(ctx, a.supplierID); err != nil {
|
||||
if err := a.fkValidator.ValidateSupplyAccountOwner(ctx, supplierID); err != nil {
|
||||
writeError(w, http.StatusUnprocessableEntity, "FK_VALIDATION_FAILED", "supplier does not exist")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
createReq := &domain.CreateAccountRequest{
|
||||
SupplierID: a.supplierID,
|
||||
SupplierID: supplierID,
|
||||
Provider: domain.Provider(rawReq.Provider),
|
||||
AccountType: domain.AccountType(rawReq.AccountType),
|
||||
Credential: rawReq.CredentialInput,
|
||||
@@ -252,7 +287,12 @@ func (a *SupplyAPI) handleAccountActions(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleActivateAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
|
||||
account, err := a.accountService.Activate(r.Context(), a.supplierID, accountID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := a.accountService.Activate(r.Context(), supplierID, accountID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_ACC") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -273,7 +313,12 @@ func (a *SupplyAPI) handleActivateAccount(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleSuspendAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
|
||||
account, err := a.accountService.Suspend(r.Context(), a.supplierID, accountID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := a.accountService.Suspend(r.Context(), supplierID, accountID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_ACC") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -294,7 +339,12 @@ func (a *SupplyAPI) handleSuspendAccount(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleDeleteAccount(w http.ResponseWriter, r *http.Request, accountID int64) {
|
||||
err := a.accountService.Delete(r.Context(), a.supplierID, accountID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
err := a.accountService.Delete(r.Context(), supplierID, accountID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_ACC") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -308,11 +358,27 @@ func (a *SupplyAPI) handleDeleteAccount(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleAccountAuditLogs(w http.ResponseWriter, r *http.Request, accountID int64) {
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
page := getQueryInt(r, "page", 1)
|
||||
pageSize := getQueryInt(r, "page_size", 20)
|
||||
|
||||
// 分页参数边界验证
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 20
|
||||
}
|
||||
if pageSize > 1000 {
|
||||
pageSize = 1000
|
||||
}
|
||||
|
||||
events, total, err := a.auditStore.QueryWithTotal(r.Context(), audit.EventFilter{
|
||||
TenantID: a.supplierID,
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_account",
|
||||
ObjectID: accountID,
|
||||
Limit: pageSize,
|
||||
@@ -378,6 +444,11 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// P0-09修复: 创建套餐前校验外键引用
|
||||
if a.fkValidator != nil {
|
||||
if err := a.fkValidator.ValidatePackageSupplyAccount(r.Context(), req.SupplyAccountID); err != nil {
|
||||
@@ -387,7 +458,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
|
||||
createReq := &domain.CreatePackageDraftRequest{
|
||||
SupplierID: a.supplierID,
|
||||
SupplierID: supplierID,
|
||||
AccountID: req.SupplyAccountID,
|
||||
Model: req.Model,
|
||||
TotalQuota: req.TotalQuota,
|
||||
@@ -398,7 +469,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
|
||||
RateLimitRPM: req.RateLimitRPM,
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.CreateDraft(r.Context(), a.supplierID, createReq)
|
||||
pkg, err := a.packageService.CreateDraft(r.Context(), supplierID, createReq)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnprocessableEntity, "CREATE_FAILED", err.Error())
|
||||
return
|
||||
@@ -477,7 +548,12 @@ func (a *SupplyAPI) handlePackageActions(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handlePublishPackage(w http.ResponseWriter, r *http.Request, packageID int64) {
|
||||
pkg, err := a.packageService.Publish(r.Context(), a.supplierID, packageID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.Publish(r.Context(), supplierID, packageID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_PKG") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -498,7 +574,12 @@ func (a *SupplyAPI) handlePublishPackage(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handlePausePackage(w http.ResponseWriter, r *http.Request, packageID int64) {
|
||||
pkg, err := a.packageService.Pause(r.Context(), a.supplierID, packageID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.Pause(r.Context(), supplierID, packageID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_PKG") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -519,7 +600,12 @@ func (a *SupplyAPI) handlePausePackage(w http.ResponseWriter, r *http.Request, p
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleUnlistPackage(w http.ResponseWriter, r *http.Request, packageID int64) {
|
||||
pkg, err := a.packageService.Unlist(r.Context(), a.supplierID, packageID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.Unlist(r.Context(), supplierID, packageID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_PKG") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -540,7 +626,12 @@ func (a *SupplyAPI) handleUnlistPackage(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleClonePackage(w http.ResponseWriter, r *http.Request, packageID int64) {
|
||||
pkg, err := a.packageService.Clone(r.Context(), a.supplierID, packageID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pkg, err := a.packageService.Clone(r.Context(), supplierID, packageID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
|
||||
return
|
||||
@@ -595,7 +686,12 @@ func (a *SupplyAPI) handleBatchUpdatePrice(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := a.packageService.BatchUpdatePrice(r.Context(), a.supplierID, req)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := a.packageService.BatchUpdatePrice(r.Context(), supplierID, req)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnprocessableEntity, "BATCH_UPDATE_FAILED", err.Error())
|
||||
return
|
||||
@@ -618,7 +714,12 @@ func (a *SupplyAPI) handleGetBilling(w http.ResponseWriter, r *http.Request) {
|
||||
startDate := r.URL.Query().Get("start_date")
|
||||
endDate := r.URL.Query().Get("end_date")
|
||||
|
||||
summary, err := a.earningService.GetBillingSummary(r.Context(), a.supplierID, startDate, endDate)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
summary, err := a.earningService.GetBillingSummary(r.Context(), supplierID, startDate, endDate)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
|
||||
return
|
||||
@@ -637,6 +738,10 @@ func (a *SupplyAPI) handleWithdraw(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
return
|
||||
}
|
||||
if !a.withdrawEnabled {
|
||||
writeError(w, http.StatusServiceUnavailable, "FEATURE_DISABLED", "withdraw is disabled until SMS verification is integrated")
|
||||
return
|
||||
}
|
||||
|
||||
// P0-P4修复: 使用DB-backed幂等中间件
|
||||
if a.idempotencyMw != nil {
|
||||
@@ -645,12 +750,17 @@ func (a *SupplyAPI) handleWithdraw(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// 降级:使用内联幂等逻辑(仅在幂等中间件未启用时)
|
||||
a.withdrawHandler(context.Background(), w, r, nil)
|
||||
a.withdrawHandler(r.Context(), w, r, nil)
|
||||
}
|
||||
|
||||
// withdrawHandler 提现的业务逻辑(供幂等中间件包装)
|
||||
func (a *SupplyAPI) withdrawHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, _ *repository.IdempotencyRecord) error {
|
||||
requestID := r.Header.Get("X-Request-Id")
|
||||
supplierID, err := a.resolveSupplierID(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
@@ -678,7 +788,7 @@ func (a *SupplyAPI) withdrawHandler(ctx context.Context, w http.ResponseWriter,
|
||||
SMSCode: req.SMSCode,
|
||||
}
|
||||
|
||||
settlement, err := a.settlementService.Withdraw(ctx, a.supplierID, withdrawReq)
|
||||
settlement, err := a.settlementService.Withdraw(ctx, supplierID, withdrawReq)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_SET") {
|
||||
writeError(w, http.StatusConflict, "WITHDRAW_FAILED", err.Error())
|
||||
@@ -740,7 +850,12 @@ func (a *SupplyAPI) handleSettlementActions(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleCancelSettlement(w http.ResponseWriter, r *http.Request, settlementID int64) {
|
||||
settlement, err := a.settlementService.Cancel(r.Context(), a.supplierID, settlementID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
settlement, err := a.settlementService.Cancel(r.Context(), supplierID, settlementID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SUP_SET") {
|
||||
writeError(w, http.StatusConflict, "CONFLICT", err.Error())
|
||||
@@ -761,7 +876,12 @@ func (a *SupplyAPI) handleCancelSettlement(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
func (a *SupplyAPI) handleGetStatement(w http.ResponseWriter, r *http.Request, settlementID int64) {
|
||||
settlement, err := a.settlementService.GetByID(r.Context(), a.supplierID, settlementID)
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
settlement, err := a.settlementService.GetByID(r.Context(), supplierID, settlementID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "NOT_FOUND", err.Error())
|
||||
return
|
||||
@@ -791,7 +911,23 @@ func (a *SupplyAPI) handleGetEarningRecords(w http.ResponseWriter, r *http.Reque
|
||||
page := getQueryInt(r, "page", 1)
|
||||
pageSize := getQueryInt(r, "page_size", 20)
|
||||
|
||||
records, total, err := a.earningService.ListRecords(r.Context(), a.supplierID, startDate, endDate, page, pageSize)
|
||||
// 分页参数边界验证
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 20
|
||||
}
|
||||
if pageSize > 1000 {
|
||||
pageSize = 1000
|
||||
}
|
||||
|
||||
supplierID, ok := a.requireSupplierID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
records, total, err := a.earningService.ListRecords(r.Context(), supplierID, startDate, endDate, page, pageSize)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
|
||||
return
|
||||
|
||||
1399
supply-api/internal/httpapi/supply_api_test.go
Normal file
1399
supply-api/internal/httpapi/supply_api_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,9 +2,12 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -30,21 +33,23 @@ type TokenClaims struct {
|
||||
|
||||
// AuthConfig 鉴权中间件配置
|
||||
type AuthConfig struct {
|
||||
SecretKey string
|
||||
Issuer string
|
||||
CacheTTL time.Duration // token状态缓存TTL
|
||||
Enabled bool // 是否启用鉴权
|
||||
TrustedProxies []string // 可信代理IP列表CIDR,如 "10.0.0.0/8"
|
||||
SecretKey string
|
||||
PublicKey string
|
||||
Algorithm string
|
||||
Issuer string
|
||||
CacheTTL time.Duration // token状态缓存TTL
|
||||
Enabled bool // 是否启用鉴权
|
||||
TrustedProxies []string // 可信代理IP列表CIDR,如 "10.0.0.0/8"
|
||||
}
|
||||
|
||||
// AuthMiddleware 鉴权中间件
|
||||
type AuthMiddleware struct {
|
||||
config AuthConfig
|
||||
tokenCache *TokenCache
|
||||
tokenBackend TokenStatusBackend
|
||||
auditEmitter AuditEmitter
|
||||
bruteForce *BruteForceProtection // 暴力破解保护
|
||||
trustedProxies []string // 可信代理列表
|
||||
config AuthConfig
|
||||
tokenCache *TokenCache
|
||||
tokenBackend TokenStatusBackend
|
||||
auditEmitter AuditEmitter
|
||||
bruteForce *BruteForceProtection // 暴力破解保护
|
||||
trustedProxies []string // 可信代理列表
|
||||
}
|
||||
|
||||
// TokenStatusBackend Token状态后端查询接口
|
||||
@@ -200,6 +205,11 @@ func (b *BruteForceProtection) Len() int {
|
||||
// 对应M-016指标
|
||||
func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if shouldBypassAuth(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查query string中的可疑参数
|
||||
queryParams := r.URL.Query()
|
||||
|
||||
@@ -257,6 +267,11 @@ func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handle
|
||||
// BearerExtractMiddleware 提取Bearer Token
|
||||
func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if shouldBypassAuth(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
|
||||
if authHeader == "" {
|
||||
@@ -299,6 +314,11 @@ func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler
|
||||
// MED-12: 添加暴力破解保护
|
||||
func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if shouldBypassAuth(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果鉴权被禁用(仅用于开发环境),直接跳过验证
|
||||
if !m.config.Enabled {
|
||||
// 在开发模式下,虽然跳过JWT验证,但仍记录警告日志
|
||||
@@ -356,7 +376,25 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
|
||||
|
||||
// 检查token状态(是否被吊销)
|
||||
status, err := m.checkTokenStatus(r.Context(), claims.ID)
|
||||
if err == nil && status != "active" {
|
||||
if err != nil {
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
EventName: "token.authn.fail",
|
||||
RequestID: getRequestID(r),
|
||||
TokenID: claims.ID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: sanitizeRoute(r.URL.Path),
|
||||
ResultCode: "AUTH_TOKEN_STATUS_UNAVAILABLE",
|
||||
ClientIP: getClientIP(r),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_STATUS_UNAVAILABLE",
|
||||
"token status backend is unavailable")
|
||||
return
|
||||
}
|
||||
if status != "active" {
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
EventName: "token.authn.fail",
|
||||
@@ -435,10 +473,10 @@ func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(htt
|
||||
// viewer: level 10, operator: level 30, org_admin: level 50
|
||||
routeRoles := map[string]string{
|
||||
"/api/v1/supply/accounts": "org_admin",
|
||||
"/api/v1/supply/packages": "org_admin",
|
||||
"/api/v1/supply/settlements": "org_admin",
|
||||
"/api/v1/supply/billing": "viewer",
|
||||
"/api/v1/supplier/billing": "viewer",
|
||||
"/api/v1/supply/packages": "org_admin",
|
||||
"/api/v1/supply/settlements": "org_admin",
|
||||
"/api/v1/supply/billing": "viewer",
|
||||
"/api/v1/supplier/billing": "viewer",
|
||||
}
|
||||
|
||||
for path, requiredRole := range routeRoles {
|
||||
@@ -458,12 +496,16 @@ func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(htt
|
||||
|
||||
// verifyToken 校验JWT token
|
||||
func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) {
|
||||
expectedAlgorithm := strings.ToUpper(strings.TrimSpace(m.config.Algorithm))
|
||||
if expectedAlgorithm == "" {
|
||||
expectedAlgorithm = jwt.SigningMethodHS256.Alg()
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// 严格验证算法:只接受HS256
|
||||
if token.Method.Alg() != jwt.SigningMethodHS256.Alg() {
|
||||
if token.Method.Alg() != expectedAlgorithm {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(m.config.SecretKey), nil
|
||||
return m.signingKey(expectedAlgorithm)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -492,6 +534,53 @@ func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) signingKey(algorithm string) (interface{}, error) {
|
||||
switch algorithm {
|
||||
case jwt.SigningMethodHS256.Alg(), jwt.SigningMethodHS384.Alg(), jwt.SigningMethodHS512.Alg():
|
||||
if strings.TrimSpace(m.config.SecretKey) == "" {
|
||||
return nil, errors.New("missing token secret key")
|
||||
}
|
||||
return []byte(m.config.SecretKey), nil
|
||||
case jwt.SigningMethodRS256.Alg(), jwt.SigningMethodRS384.Alg(), jwt.SigningMethodRS512.Alg():
|
||||
return parseRSAPublicKey(m.config.PublicKey)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported signing method: %s", algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
func parseRSAPublicKey(publicKeyPEM string) (*rsa.PublicKey, error) {
|
||||
block, _ := pem.Decode([]byte(strings.TrimSpace(publicKeyPEM)))
|
||||
if block == nil {
|
||||
return nil, errors.New("invalid RSA public key: PEM decode failed")
|
||||
}
|
||||
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err == nil {
|
||||
rsaPub, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid RSA public key type")
|
||||
}
|
||||
return rsaPub, nil
|
||||
}
|
||||
|
||||
cert, certErr := x509.ParseCertificate(block.Bytes)
|
||||
if certErr == nil {
|
||||
rsaPub, ok := cert.PublicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid RSA certificate public key type")
|
||||
}
|
||||
return rsaPub, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid RSA public key: %w", err)
|
||||
}
|
||||
|
||||
func shouldBypassAuth(path string) bool {
|
||||
return path == "/actuator/health" ||
|
||||
path == "/actuator/health/live" ||
|
||||
path == "/actuator/health/ready"
|
||||
}
|
||||
|
||||
// checkTokenStatus 检查token状态(从缓存或数据库)
|
||||
func (m *AuthMiddleware) checkTokenStatus(ctx context.Context, tokenID string) (string, error) {
|
||||
if m.tokenCache != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -13,6 +14,15 @@ import (
|
||||
"lijiaoqiao/supply-api/internal/iam/model"
|
||||
)
|
||||
|
||||
type stubTokenStatusBackend struct {
|
||||
status string
|
||||
err error
|
||||
}
|
||||
|
||||
func (b *stubTokenStatusBackend) CheckTokenStatus(ctx context.Context, tokenID string) (string, error) {
|
||||
return b.status, b.err
|
||||
}
|
||||
|
||||
func TestTokenVerify(t *testing.T) {
|
||||
secretKey := "test-secret-key-12345678901234567890"
|
||||
issuer := "test-issuer"
|
||||
@@ -431,6 +441,38 @@ func TestMED02_TokenCacheMiss_ShouldNotAssumeActive(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenVerifyMiddleware_BackendErrorShouldReject(t *testing.T) {
|
||||
secretKey := "test-secret-key-12345678901234567890"
|
||||
issuer := "test-issuer"
|
||||
|
||||
authMiddleware := NewAuthMiddleware(AuthConfig{
|
||||
SecretKey: secretKey,
|
||||
Issuer: issuer,
|
||||
Enabled: true,
|
||||
}, NewTokenCache(), &stubTokenStatusBackend{err: errors.New("database unavailable")}, nil)
|
||||
|
||||
nextCalled := false
|
||||
handler := authMiddleware.TokenVerifyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/supply/accounts", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), bearerTokenKey, createTestToken(secretKey, issuer, "subject:1", "org_admin", time.Now().Add(time.Hour))))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if nextCalled {
|
||||
t.Fatal("expected request to be rejected when token backend is unavailable")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status 401, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "AUTH_TOKEN_STATUS_UNAVAILABLE") {
|
||||
t.Fatalf("expected response to contain AUTH_TOKEN_STATUS_UNAVAILABLE, got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createTestToken(secretKey, issuer, subject, role string, expiresAt time.Time) string {
|
||||
|
||||
@@ -34,9 +34,9 @@ type TokenCacheBackend interface {
|
||||
// DBTokenStatusBackend DB-backed Token状态后端(P0-03修复)
|
||||
// 同时实现 TokenStatusBackend 和 TokenRevocationBackend 接口
|
||||
type DBTokenStatusBackend struct {
|
||||
repo TokenRepository
|
||||
redisCache TokenCacheBackend
|
||||
cacheTTL time.Duration
|
||||
repo TokenRepository
|
||||
redisCache TokenCacheBackend
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewDBTokenStatusBackend 创建DB-backed Token状态后端
|
||||
@@ -119,6 +119,17 @@ func (b *DBTokenStatusBackend) GetTokenStatus(ctx context.Context, tokenID strin
|
||||
|
||||
// RevokeBySubjectID 根据SubjectID吊销所有Token
|
||||
func (b *DBTokenStatusBackend) RevokeBySubjectID(ctx context.Context, subjectID int64, reason string) error {
|
||||
var tokenIDs []string
|
||||
if b.redisCache != nil {
|
||||
records, err := b.repo.ListActiveBySubjectID(ctx, subjectID)
|
||||
if err == nil {
|
||||
tokenIDs = make([]string, 0, len(records))
|
||||
for _, record := range records {
|
||||
tokenIDs = append(tokenIDs, record.TokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 1. 批量更新数据库
|
||||
count, err := b.repo.RevokeBySubjectID(ctx, subjectID, reason)
|
||||
if err != nil {
|
||||
@@ -132,13 +143,8 @@ func (b *DBTokenStatusBackend) RevokeBySubjectID(ctx context.Context, subjectID
|
||||
// 2. 失效所有相关缓存(这里需要查询后逐个失效)
|
||||
// 注意:生产环境建议使用Redis的pattern删除或发布事件通知
|
||||
if b.redisCache != nil {
|
||||
// 查询所有活跃token并失效
|
||||
records, err := b.repo.ListActiveBySubjectID(ctx, subjectID)
|
||||
if err != nil {
|
||||
return nil // 不影响主流程
|
||||
}
|
||||
for _, record := range records {
|
||||
_ = b.redisCache.InvalidateToken(ctx, record.TokenID)
|
||||
for _, tokenID := range tokenIDs {
|
||||
_ = b.redisCache.InvalidateToken(ctx, tokenID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
|
||||
// MockTokenStatusRepository mock Token状态仓储
|
||||
type MockTokenStatusRepository struct {
|
||||
mu sync.RWMutex
|
||||
tokenStatuses map[string]string
|
||||
tokenReasons map[string]string
|
||||
verificationCounts map[string]int
|
||||
subjectTokens map[int64][]string
|
||||
mu sync.RWMutex
|
||||
tokenStatuses map[string]string
|
||||
tokenReasons map[string]string
|
||||
verificationCounts map[string]int
|
||||
subjectTokens map[int64][]string
|
||||
}
|
||||
|
||||
func NewMockTokenStatusRepository() *MockTokenStatusRepository {
|
||||
@@ -84,9 +84,9 @@ func (m *MockTokenStatusRepository) ListActiveBySubjectID(ctx context.Context, s
|
||||
|
||||
// MockRedisCache mock Redis缓存
|
||||
type MockRedisCache struct {
|
||||
mu sync.RWMutex
|
||||
tokenCache map[string]*cache.TokenStatus
|
||||
subscribers []func(event *cache.TokenRevokedCacheEvent)
|
||||
mu sync.RWMutex
|
||||
tokenCache map[string]*cache.TokenStatus
|
||||
subscribers []func(event *cache.TokenRevokedCacheEvent)
|
||||
}
|
||||
|
||||
func NewMockRedisCache() *MockRedisCache {
|
||||
@@ -136,7 +136,7 @@ func (m *MockRedisCache) PublishRevocation(tokenID string, reason string) {
|
||||
for _, handler := range handlers {
|
||||
handler(&cache.TokenRevokedCacheEvent{
|
||||
TokenID: tokenID,
|
||||
Reason: reason,
|
||||
Reason: reason,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -165,9 +165,9 @@ type TokenStatusRepositoryInterface interface {
|
||||
|
||||
// DBTokenStatusBackendForTest 用于测试的DBTokenStatusBackend
|
||||
type DBTokenStatusBackendForTest struct {
|
||||
repo TokenStatusRepositoryInterface
|
||||
redisCache *MockRedisCache
|
||||
cacheTTL time.Duration
|
||||
repo TokenStatusRepositoryInterface
|
||||
redisCache *MockRedisCache
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
func NewDBTokenStatusBackendForTest(repo TokenStatusRepositoryInterface, redisCache *MockRedisCache, cacheTTL time.Duration) *DBTokenStatusBackendForTest {
|
||||
@@ -373,6 +373,31 @@ func TestDBTokenStatusBackend_RevokeBySubjectID(t *testing.T) {
|
||||
repo.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestDBTokenStatusBackend_RevokeBySubjectID_InvalidatesCachedTokens(t *testing.T) {
|
||||
repo := NewMockTokenStatusRepository()
|
||||
redisCache := NewMockRedisCache()
|
||||
|
||||
repo.subjectTokens[123] = []string{"token1", "token2"}
|
||||
redisCache.tokenCache["token1"] = &cache.TokenStatus{TokenID: "token1", Status: "active"}
|
||||
redisCache.tokenCache["token2"] = &cache.TokenStatus{TokenID: "token2", Status: "active"}
|
||||
|
||||
backend := NewDBTokenStatusBackend(repo, redisCache, 10*time.Second)
|
||||
|
||||
err := backend.RevokeBySubjectID(context.Background(), 123, "bulk revocation")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
redisCache.mu.RLock()
|
||||
defer redisCache.mu.RUnlock()
|
||||
if _, ok := redisCache.tokenCache["token1"]; ok {
|
||||
t.Fatal("expected token1 cache to be invalidated")
|
||||
}
|
||||
if _, ok := redisCache.tokenCache["token2"]; ok {
|
||||
t.Fatal("expected token2 cache to be invalidated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBTokenStatusBackend_RevokeBySubjectID_NoTokens(t *testing.T) {
|
||||
repo := NewMockTokenStatusRepository()
|
||||
redisCache := NewMockRedisCache()
|
||||
@@ -412,10 +437,10 @@ func TestDBTokenStatusBackend_InterfaceCompliance(t *testing.T) {
|
||||
|
||||
// 测试各种状态转换
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenID string
|
||||
initialStatus string
|
||||
action func() error
|
||||
name string
|
||||
tokenID string
|
||||
initialStatus string
|
||||
action func() error
|
||||
expectedStatus string
|
||||
}{
|
||||
{
|
||||
@@ -780,7 +805,7 @@ func TestDBTokenStatusBackend_StartRevocationSubscriber_NoRedisCache(t *testing.
|
||||
|
||||
// MockTokenRevocationBackend mock TokenRevocationBackend
|
||||
type MockTokenRevocationBackend struct {
|
||||
mu sync.RWMutex
|
||||
mu sync.RWMutex
|
||||
revokedTokens map[string]string
|
||||
}
|
||||
|
||||
|
||||
@@ -278,6 +278,11 @@ func getTenantID(ctx context.Context) int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetTenantID 公开函数,从context获取租户ID
|
||||
func GetTenantID(ctx context.Context) int64 {
|
||||
return getTenantID(ctx)
|
||||
}
|
||||
|
||||
func getOperatorID(ctx context.Context) int64 {
|
||||
if v := ctx.Value(operatorIDKey); v != nil {
|
||||
if id, ok := v.(int64); ok {
|
||||
|
||||
@@ -2,6 +2,7 @@ package outbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"time"
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
|
||||
// OutboxProcessorRunner Outbox处理器运行器
|
||||
type OutboxProcessorRunner struct {
|
||||
repo *repository.OutboxRepository
|
||||
repo outboxRepository
|
||||
msgBroker messaging.MessageBroker
|
||||
stats messaging.OutboxStats
|
||||
stopCh chan struct{}
|
||||
@@ -21,9 +22,16 @@ type OutboxProcessorRunner struct {
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
type outboxRepository interface {
|
||||
FetchAndLock(ctx context.Context, limit int) ([]*repository.OutboxEvent, error)
|
||||
MarkCompleted(ctx context.Context, eventID string) error
|
||||
MarkFailed(ctx context.Context, eventID string, errorMsg string, nextRetryAt *time.Time) error
|
||||
MoveToDeadLetter(ctx context.Context, event *repository.OutboxEvent, errorMsg string) error
|
||||
}
|
||||
|
||||
// NewOutboxProcessorRunner 创建Outbox处理器运行器
|
||||
func NewOutboxProcessorRunner(
|
||||
repo *repository.OutboxRepository,
|
||||
repo outboxRepository,
|
||||
msgBroker messaging.MessageBroker,
|
||||
stats messaging.OutboxStats,
|
||||
) *OutboxProcessorRunner {
|
||||
@@ -66,6 +74,10 @@ func (r *OutboxProcessorRunner) Stop() {
|
||||
|
||||
// process 处理一批Outbox事件
|
||||
func (r *OutboxProcessorRunner) process(ctx context.Context) error {
|
||||
if r.msgBroker == nil {
|
||||
return fmt.Errorf("outbox message broker is unavailable")
|
||||
}
|
||||
|
||||
// 获取待处理事件
|
||||
events, err := r.repo.FetchAndLock(ctx, r.batchSize)
|
||||
if err != nil {
|
||||
@@ -85,7 +97,7 @@ func (r *OutboxProcessorRunner) process(ctx context.Context) error {
|
||||
EventType: event.EventType,
|
||||
EventID: event.EventID,
|
||||
Payload: event.Payload,
|
||||
Status: string(event.Status),
|
||||
Status: string(event.Status),
|
||||
RetryCount: event.RetryCount,
|
||||
MaxRetries: event.MaxRetries,
|
||||
ErrorMessage: event.ErrorMessage,
|
||||
|
||||
58
supply-api/internal/outbox/outbox_test.go
Normal file
58
supply-api/internal/outbox/outbox_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package outbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/messaging"
|
||||
"lijiaoqiao/supply-api/internal/repository"
|
||||
)
|
||||
|
||||
type stubRunnerRepo struct {
|
||||
events []*repository.OutboxEvent
|
||||
}
|
||||
|
||||
func (r *stubRunnerRepo) FetchAndLock(ctx context.Context, limit int) ([]*repository.OutboxEvent, error) {
|
||||
return r.events, nil
|
||||
}
|
||||
|
||||
func (r *stubRunnerRepo) MarkCompleted(ctx context.Context, eventID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubRunnerRepo) MarkFailed(ctx context.Context, eventID string, errorMsg string, nextRetryAt *time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubRunnerRepo) MoveToDeadLetter(ctx context.Context, event *repository.OutboxEvent, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOutboxProcessorRunner_ProcessRejectsNilMessageBroker(t *testing.T) {
|
||||
payload := json.RawMessage(`{"event":"created"}`)
|
||||
runner := NewOutboxProcessorRunner(&stubRunnerRepo{
|
||||
events: []*repository.OutboxEvent{
|
||||
{
|
||||
ID: 1,
|
||||
AggregateType: "account",
|
||||
AggregateID: "acc-1",
|
||||
EventType: "created",
|
||||
EventID: "evt-1",
|
||||
Payload: payload,
|
||||
Status: repository.OutboxStatusProcessing,
|
||||
MaxRetries: 5,
|
||||
},
|
||||
},
|
||||
}, nil, &messaging.NoOpOutboxStats{})
|
||||
|
||||
err := runner.process(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected nil message broker to return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "message broker") {
|
||||
t.Fatalf("expected error to mention message broker, got %v", err)
|
||||
}
|
||||
}
|
||||
343
supply-api/internal/repository/outbox.go
Normal file
343
supply-api/internal/repository/outbox.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// OutboxStatus Outbox事件状态
|
||||
type OutboxStatus string
|
||||
|
||||
const (
|
||||
OutboxStatusPending OutboxStatus = "pending"
|
||||
OutboxStatusProcessing OutboxStatus = "processing"
|
||||
OutboxStatusCompleted OutboxStatus = "completed"
|
||||
OutboxStatusFailed OutboxStatus = "failed"
|
||||
OutboxStatusDeadLetter OutboxStatus = "dead_letter"
|
||||
)
|
||||
|
||||
// OutboxEvent Outbox事件
|
||||
type OutboxEvent struct {
|
||||
ID int64 `json:"id"`
|
||||
AggregateType string `json:"aggregate_type"`
|
||||
AggregateID string `json:"aggregate_id"`
|
||||
EventType string `json:"event_type"`
|
||||
EventID string `json:"event_id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
Status OutboxStatus `json:"status"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ProcessedAt *time.Time `json:"processed_at,omitempty"`
|
||||
NextRetryAt *time.Time `json:"next_retry_at,omitempty"`
|
||||
DeadLetterReason string `json:"dead_letter_reason,omitempty"`
|
||||
Version int64 `json:"version"`
|
||||
}
|
||||
|
||||
// OutboxDeadLetter 死信记录
|
||||
type OutboxDeadLetter struct {
|
||||
ID int64 `json:"id"`
|
||||
OriginalEventID string `json:"original_event_id"`
|
||||
OriginalAggregateType string `json:"original_aggregate_type"`
|
||||
OriginalAggregateID string `json:"original_aggregate_id"`
|
||||
EventType string `json:"event_type"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
FirstFailedAt time.Time `json:"first_failed_at"`
|
||||
DeadLetterAt time.Time `json:"dead_letter_at"`
|
||||
Handled bool `json:"handled"`
|
||||
HandledAt *time.Time `json:"handled_at,omitempty"`
|
||||
HandlerNotes string `json:"handler_notes,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// OutboxRepository Outbox仓储
|
||||
type OutboxRepository struct {
|
||||
db outboxDB
|
||||
}
|
||||
|
||||
type outboxDB interface {
|
||||
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
||||
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||
Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
|
||||
Begin(ctx context.Context) (pgx.Tx, error)
|
||||
}
|
||||
|
||||
// NewOutboxRepository 创建Outbox仓储
|
||||
func NewOutboxRepository(pool *pgxpool.Pool) *OutboxRepository {
|
||||
return &OutboxRepository{db: pool}
|
||||
}
|
||||
|
||||
// Create 创建Outbox事件
|
||||
func (r *OutboxRepository) Create(ctx context.Context, event *OutboxEvent) error {
|
||||
query := `
|
||||
INSERT INTO supply_outbox (
|
||||
aggregate_type, aggregate_id, event_type, event_id,
|
||||
payload, status, retry_count, max_retries
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at
|
||||
`
|
||||
|
||||
err := r.db.QueryRow(ctx, query,
|
||||
event.AggregateType, event.AggregateID, event.EventType, event.EventID,
|
||||
event.Payload, event.Status, event.RetryCount, event.MaxRetries,
|
||||
).Scan(&event.ID, &event.CreatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create outbox event: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FetchAndLock 获取并锁定待处理事件(使用FOR UPDATE SKIP LOCKED实现分布式锁)
|
||||
func (r *OutboxRepository) FetchAndLock(ctx context.Context, limit int) ([]*OutboxEvent, error) {
|
||||
query := `
|
||||
WITH claimed AS (
|
||||
SELECT id
|
||||
FROM supply_outbox
|
||||
WHERE status IN ('pending', 'failed')
|
||||
AND (next_retry_at IS NULL OR next_retry_at <= CURRENT_TIMESTAMP)
|
||||
ORDER BY created_at ASC
|
||||
LIMIT $1
|
||||
FOR UPDATE SKIP LOCKED
|
||||
)
|
||||
UPDATE supply_outbox AS o
|
||||
SET status = 'processing',
|
||||
version = o.version + 1
|
||||
FROM claimed
|
||||
WHERE o.id = claimed.id
|
||||
RETURNING o.id, o.aggregate_type, o.aggregate_id, o.event_type, o.event_id,
|
||||
o.payload, o.status, o.retry_count, o.max_retries, o.error_message,
|
||||
o.created_at, o.processed_at, o.next_retry_at, o.version
|
||||
`
|
||||
|
||||
rows, err := r.db.Query(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch outbox events: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var events []*OutboxEvent
|
||||
for rows.Next() {
|
||||
event := &OutboxEvent{}
|
||||
err := rows.Scan(
|
||||
&event.ID, &event.AggregateType, &event.AggregateID, &event.EventType,
|
||||
&event.EventID, &event.Payload, &event.Status, &event.RetryCount,
|
||||
&event.MaxRetries, &event.ErrorMessage, &event.CreatedAt,
|
||||
&event.ProcessedAt, &event.NextRetryAt, &event.Version,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan outbox event: %w", err)
|
||||
}
|
||||
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// MarkCompleted 标记事件为已完成
|
||||
func (r *OutboxRepository) MarkCompleted(ctx context.Context, eventID string) error {
|
||||
query := `
|
||||
UPDATE supply_outbox SET
|
||||
status = 'completed',
|
||||
processed_at = CURRENT_TIMESTAMP,
|
||||
version = version + 1
|
||||
WHERE event_id = $1 AND status = 'processing'
|
||||
`
|
||||
|
||||
result, err := r.db.Exec(ctx, query, eventID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark outbox event completed: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("event not found or not in processing status")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkFailed 标记事件为失败(并更新重试信息)
|
||||
func (r *OutboxRepository) MarkFailed(ctx context.Context, eventID string, errorMsg string, nextRetryAt *time.Time) error {
|
||||
query := `
|
||||
UPDATE supply_outbox SET
|
||||
status = 'failed',
|
||||
error_message = $2,
|
||||
next_retry_at = $3,
|
||||
version = version + 1
|
||||
WHERE event_id = $1 AND status = 'processing'
|
||||
`
|
||||
|
||||
result, err := r.db.Exec(ctx, query, eventID, errorMsg, nextRetryAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark outbox event failed: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("event not found or not in processing status")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MoveToDeadLetter 将事件移入死信队列
|
||||
func (r *OutboxRepository) MoveToDeadLetter(ctx context.Context, event *OutboxEvent, errorMsg string) error {
|
||||
tx, err := r.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// 获取事件详情
|
||||
selectQuery := `
|
||||
SELECT id, aggregate_type, aggregate_id, event_type, event_id,
|
||||
payload, retry_count, created_at
|
||||
FROM supply_outbox
|
||||
WHERE event_id = $1
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
var originalEvent OutboxEvent
|
||||
err = tx.QueryRow(ctx, selectQuery, event.EventID).Scan(
|
||||
&originalEvent.ID, &originalEvent.AggregateType, &originalEvent.AggregateID,
|
||||
&originalEvent.EventType, &originalEvent.EventID, &originalEvent.Payload,
|
||||
&originalEvent.RetryCount, &originalEvent.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get original event: %w", err)
|
||||
}
|
||||
|
||||
// 插入死信记录
|
||||
insertQuery := `
|
||||
INSERT INTO supply_outbox_dead_letter (
|
||||
original_event_id, original_aggregate_type, original_aggregate_id,
|
||||
event_type, payload, error_message, retry_count, first_failed_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`
|
||||
|
||||
_, err = tx.Exec(ctx, insertQuery,
|
||||
originalEvent.EventID, originalEvent.AggregateType, originalEvent.AggregateID,
|
||||
originalEvent.EventType, originalEvent.Payload, errorMsg,
|
||||
originalEvent.RetryCount, originalEvent.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert dead letter: %w", err)
|
||||
}
|
||||
|
||||
// 删除原始事件
|
||||
deleteQuery := `DELETE FROM supply_outbox WHERE event_id = $1`
|
||||
_, err = tx.Exec(ctx, deleteQuery, event.EventID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete original event: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDeadLetterByEventID 根据原始事件ID获取死信记录
|
||||
func (r *OutboxRepository) GetDeadLetterByEventID(ctx context.Context, originalEventID string) (*OutboxDeadLetter, error) {
|
||||
query := `
|
||||
SELECT id, original_event_id, original_aggregate_type, original_aggregate_id,
|
||||
event_type, payload, error_message, retry_count, first_failed_at,
|
||||
dead_letter_at, handled, handled_at, handler_notes, created_at
|
||||
FROM supply_outbox_dead_letter
|
||||
WHERE original_event_id = $1
|
||||
`
|
||||
|
||||
dl := &OutboxDeadLetter{}
|
||||
err := r.db.QueryRow(ctx, query, originalEventID).Scan(
|
||||
&dl.ID, &dl.OriginalEventID, &dl.OriginalAggregateType, &dl.OriginalAggregateID,
|
||||
&dl.EventType, &dl.Payload, &dl.ErrorMessage, &dl.RetryCount,
|
||||
&dl.FirstFailedAt, &dl.DeadLetterAt, &dl.Handled, &dl.HandledAt,
|
||||
&dl.HandlerNotes, &dl.CreatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dead letter: %w", err)
|
||||
}
|
||||
|
||||
return dl, nil
|
||||
}
|
||||
|
||||
// ListUnhandledDeadLetters 列出未处理的死信记录
|
||||
func (r *OutboxRepository) ListUnhandledDeadLetters(ctx context.Context, limit int) ([]*OutboxDeadLetter, error) {
|
||||
query := `
|
||||
SELECT id, original_event_id, original_aggregate_type, original_aggregate_id,
|
||||
event_type, payload, error_message, retry_count, first_failed_at,
|
||||
dead_letter_at, handled, handled_at, handler_notes, created_at
|
||||
FROM supply_outbox_dead_letter
|
||||
WHERE handled = FALSE
|
||||
ORDER BY dead_letter_at ASC
|
||||
LIMIT $1
|
||||
`
|
||||
|
||||
rows, err := r.db.Query(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list dead letters: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var dls []*OutboxDeadLetter
|
||||
for rows.Next() {
|
||||
dl := &OutboxDeadLetter{}
|
||||
err := rows.Scan(
|
||||
&dl.ID, &dl.OriginalEventID, &dl.OriginalAggregateType, &dl.OriginalAggregateID,
|
||||
&dl.EventType, &dl.Payload, &dl.ErrorMessage, &dl.RetryCount,
|
||||
&dl.FirstFailedAt, &dl.DeadLetterAt, &dl.Handled, &dl.HandledAt,
|
||||
&dl.HandlerNotes, &dl.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan dead letter: %w", err)
|
||||
}
|
||||
dls = append(dls, dl)
|
||||
}
|
||||
|
||||
return dls, nil
|
||||
}
|
||||
|
||||
// MarkDeadLetterHandled 标记死信已处理
|
||||
func (r *OutboxRepository) MarkDeadLetterHandled(ctx context.Context, id int64, notes string) error {
|
||||
query := `
|
||||
UPDATE supply_outbox_dead_letter SET
|
||||
handled = TRUE,
|
||||
handled_at = CURRENT_TIMESTAMP,
|
||||
handler_notes = $2
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
_, err := r.db.Exec(ctx, query, id, notes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark dead letter handled: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteCompleted 删除已完成的旧事件(定时清理)
|
||||
func (r *OutboxRepository) DeleteCompleted(ctx context.Context, before time.Time) (int64, error) {
|
||||
query := `DELETE FROM supply_outbox WHERE status = 'completed' AND processed_at < $1`
|
||||
|
||||
result, err := r.db.Exec(ctx, query, before)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete completed events: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
114
supply-api/internal/repository/outbox_test.go
Normal file
114
supply-api/internal/repository/outbox_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type stubOutboxDB struct {
|
||||
querySQL string
|
||||
rows pgx.Rows
|
||||
}
|
||||
|
||||
func (s *stubOutboxDB) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
|
||||
s.querySQL = sql
|
||||
return s.rows, nil
|
||||
}
|
||||
|
||||
func (s *stubOutboxDB) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
|
||||
panic("unexpected QueryRow call")
|
||||
}
|
||||
|
||||
func (s *stubOutboxDB) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
|
||||
panic("unexpected Exec call")
|
||||
}
|
||||
|
||||
func (s *stubOutboxDB) Begin(ctx context.Context) (pgx.Tx, error) {
|
||||
panic("unexpected Begin call")
|
||||
}
|
||||
|
||||
type stubOutboxRows struct {
|
||||
events []*OutboxEvent
|
||||
index int
|
||||
}
|
||||
|
||||
func (r *stubOutboxRows) Close() {}
|
||||
func (r *stubOutboxRows) Err() error { return nil }
|
||||
func (r *stubOutboxRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} }
|
||||
func (r *stubOutboxRows) FieldDescriptions() []pgconn.FieldDescription { return nil }
|
||||
func (r *stubOutboxRows) RawValues() [][]byte { return nil }
|
||||
func (r *stubOutboxRows) Values() ([]any, error) { return nil, nil }
|
||||
func (r *stubOutboxRows) Conn() *pgx.Conn { return nil }
|
||||
|
||||
func (r *stubOutboxRows) Next() bool {
|
||||
if r.index >= len(r.events) {
|
||||
return false
|
||||
}
|
||||
r.index++
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *stubOutboxRows) Scan(dest ...any) error {
|
||||
event := r.events[r.index-1]
|
||||
*(dest[0].(*int64)) = event.ID
|
||||
*(dest[1].(*string)) = event.AggregateType
|
||||
*(dest[2].(*string)) = event.AggregateID
|
||||
*(dest[3].(*string)) = event.EventType
|
||||
*(dest[4].(*string)) = event.EventID
|
||||
*(dest[5].(*json.RawMessage)) = event.Payload
|
||||
*(dest[6].(*OutboxStatus)) = event.Status
|
||||
*(dest[7].(*int)) = event.RetryCount
|
||||
*(dest[8].(*int)) = event.MaxRetries
|
||||
*(dest[9].(*string)) = event.ErrorMessage
|
||||
*(dest[10].(*time.Time)) = event.CreatedAt
|
||||
*(dest[11].(**time.Time)) = event.ProcessedAt
|
||||
*(dest[12].(**time.Time)) = event.NextRetryAt
|
||||
*(dest[13].(*int64)) = event.Version
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestFetchAndLock_ClaimsEventsAsProcessingInDatabase(t *testing.T) {
|
||||
now := time.Now()
|
||||
payload := json.RawMessage(`{"event":"created"}`)
|
||||
db := &stubOutboxDB{
|
||||
rows: &stubOutboxRows{
|
||||
events: []*OutboxEvent{
|
||||
{
|
||||
ID: 1,
|
||||
AggregateType: "account",
|
||||
AggregateID: "acc-1",
|
||||
EventType: "created",
|
||||
EventID: "evt-1",
|
||||
Payload: payload,
|
||||
Status: OutboxStatusProcessing,
|
||||
RetryCount: 0,
|
||||
MaxRetries: 5,
|
||||
CreatedAt: now,
|
||||
Version: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
repo := &OutboxRepository{db: db}
|
||||
|
||||
events, err := repo.FetchAndLock(context.Background(), 1)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchAndLock returned error: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected 1 event, got %d", len(events))
|
||||
}
|
||||
if events[0].Status != OutboxStatusProcessing {
|
||||
t.Fatalf("expected event status processing, got %s", events[0].Status)
|
||||
}
|
||||
if !strings.Contains(db.querySQL, "SET status = 'processing'") {
|
||||
t.Fatalf("expected claim query to persist processing status, got SQL: %s", db.querySQL)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user