Files
lijiaoqiao/platform-token-runtime/internal/auth/middleware/token_auth_middleware.go
Your Name ad776e4079 fix: P0/P1 security fixes across gateway, token-runtime, and supply-api
P0 fixes:
- platform-token-runtime: Add store.Save() after Refresh token update (P0-3)
- platform-token-runtime: Add sync.RWMutex to InMemoryRuntimeStore (P0-4)
- platform-token-runtime: Add bearer token auth to /audit-events endpoint (P0-5)
- gateway: Fail startup in production if PASSWORD_ENCRYPTION_KEY uses default (P0-1)
- gateway: Require explicit CORS_ALLOW_ORIGINS in production (P0-2)

P1 fixes:
- gateway: Add TrustedProxies config field + env var GATEWAY_TRUSTED_PROXIES (P1-5)
- gateway: Sanitize X-Request-ID header to prevent log injection (P1-6)
- gateway: Strip internal error details from error responses to clients (P1-7)
- supply-api: Upgrade deriveDEK from trivial byte-rotation to HKDF-SHA256 (P1-1)
- supply-api: Reject HS256/HS384/HS512 in production, require RSA (P1-2)

Code quality fixes:
- supply-api: Add BruteForceMaxAttempts + BruteForceLockoutDuration to AuthConfig (MED-12)
- supply-api: Add TrustedProxies to token_auth_middleware (IP spoofing protection)
- supply-api: Use shared pathutil.SplitPath instead of duplicate splitPath
- supply-api: Fix query_key_reject_middleware call sites with trustedProxies param
- gateway: Wire TrustedProxies into AuthMiddlewareConfig and extractClientIP
- gateway: Add CORSAllowOrigins to AuthConfig, wire into CORSMiddleware
- gateway: Fix CompletionsHandle to have context and RecordResult like ChatCompletions
- gateway: Add sanitizeRequestID helper for X-Request-ID log injection prevention
- gateway: Add os import for PASSWORD_ENCRYPTION_KEY check
- gateway: Add strings import to handler.go for sanitizeRequestID

Environment issues documented in TEST_ENVIRONMENT_ISSUES.md
2026-04-17 14:36:02 +08:00

290 lines
8.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"strings"
"time"
"lijiaoqiao/platform-token-runtime/internal/auth/model"
"lijiaoqiao/platform-token-runtime/internal/auth/service"
)
// requestIDHeader is the HTTP header used both to read a caller-supplied
// request ID and to echo the effective request ID back on the response.
const requestIDHeader = "X-Request-Id"

// defaultNowFunc is the clock used whenever a caller does not supply one.
var defaultNowFunc = time.Now

// contextKey is a private type for context keys so that values stored by this
// package cannot collide with keys defined by other packages.
type contextKey string

// Context keys under which the middleware stores per-request values.
const (
	// requestIDKey holds the request ID string.
	requestIDKey = contextKey("request_id")
	// principalKey holds the authenticated model.Principal.
	principalKey = contextKey("principal")
)
// AuthMiddlewareConfig bundles the collaborators and settings used by the
// token authentication middleware chain.
type AuthMiddlewareConfig struct {
	// Verifier validates the raw bearer token and yields its claims.
	Verifier service.TokenVerifier
	// StatusResolver looks up the current status of a token by its ID.
	StatusResolver service.TokenStatusResolver
	// Authorizer decides whether the token's scope and role may access the
	// requested path and method.
	Authorizer service.RouteAuthorizer
	// Auditor receives authentication/authorization audit events. It may be
	// nil, in which case no events are emitted.
	Auditor service.AuditEmitter
	// ProtectedPrefixes lists the path prefixes that require authentication.
	// When empty, withDefaults substitutes the built-in list.
	ProtectedPrefixes []string
	// ExcludedPrefixes lists the path prefixes that are never authenticated,
	// taking precedence over ProtectedPrefixes. When empty, withDefaults
	// substitutes the built-in list.
	ExcludedPrefixes []string
	// Now supplies the current time; nil falls back to the package default.
	Now func() time.Time
	// TrustedProxies lists the trusted proxy IPs. X-Forwarded-For is honored
	// only for requests that arrive from one of these addresses.
	TrustedProxies []string
}
// BuildTokenAuthChain wraps next with the full authentication chain.
//
// Requests pass through the layers in this order: request-ID assignment,
// query-key rejection, then token authentication, and finally next.
func BuildTokenAuthChain(cfg AuthMiddlewareConfig, next http.Handler) http.Handler {
	// Built inside-out: the innermost wrapper is created first.
	authenticated := TokenAuthMiddleware(cfg)(next)
	keyGuarded := QueryKeyRejectMiddleware(authenticated, cfg.Auditor, cfg.Now, cfg.TrustedProxies)
	return RequestIDMiddleware(keyGuarded, cfg.Now)
}
// RequestIDMiddleware guarantees that every request has a request ID and
// echoes that ID in the X-Request-Id response header before calling next.
//
// A nil next yields a handler that does nothing. A nil now falls back to the
// package default clock.
func RequestIDMiddleware(next http.Handler, now func() time.Time) http.Handler {
	if next == nil {
		return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	}

	clock := now
	if clock == nil {
		clock = defaultNowFunc
	}

	serve := func(w http.ResponseWriter, r *http.Request) {
		// ensureRequestID also stores the ID in r's context for downstream use.
		id := ensureRequestID(r, clock)
		w.Header().Set(requestIDHeader, id)
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(serve)
}
// TokenAuthMiddleware returns a middleware that authenticates and authorizes
// requests carrying a bearer token.
//
// For every request whose path is protected (see shouldProtect) it:
//  1. responds 503 if Verifier, StatusResolver or Authorizer is missing;
//  2. responds 401 if the Authorization header holds no bearer token;
//  3. responds 401 if the verifier rejects the token;
//  4. responds 401 if the token status cannot be resolved or is not active;
//  5. responds 403 if the authorizer denies the path/method for the token's
//     scope and role;
//  6. otherwise stores the principal and request ID in the request context
//     and invokes next.
//
// Steps 2-6 each emit an audit event through cfg.Auditor. Unprotected paths
// are passed straight to next. A nil next is replaced by a no-op handler.
func TokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handler {
	cfg = cfg.withDefaults()

	return func(next http.Handler) http.Handler {
		if next == nil {
			next = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
		}

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			route := r.URL.Path
			if !cfg.shouldProtect(route) {
				next.ServeHTTP(w, r)
				return
			}

			requestID := ensureRequestID(r, cfg.Now)

			// audit completes the fields shared by every event and emits it.
			audit := func(ctx context.Context, evt service.AuditEvent) {
				evt.RequestID = requestID
				evt.Route = route
				evt.ClientIP = extractClientIP(r, cfg.TrustedProxies)
				evt.CreatedAt = cfg.Now()
				emitAuditEvent(ctx, cfg.Auditor, evt)
			}

			dependenciesReady := cfg.Verifier != nil && cfg.StatusResolver != nil && cfg.Authorizer != nil
			if !dependenciesReady {
				writeError(w, http.StatusServiceUnavailable, requestID, service.CodeAuthNotReady, "auth middleware dependencies are not ready")
				return
			}

			rawToken, hasToken := extractBearerToken(r.Header.Get("Authorization"))
			if !hasToken {
				audit(r.Context(), service.AuditEvent{
					EventName:  service.EventTokenAuthnFail,
					ResultCode: service.CodeAuthMissingBearer,
				})
				writeError(w, http.StatusUnauthorized, requestID, service.CodeAuthMissingBearer, "missing bearer token")
				return
			}

			claims, verifyErr := cfg.Verifier.Verify(r.Context(), rawToken)
			if verifyErr != nil {
				audit(r.Context(), service.AuditEvent{
					EventName:  service.EventTokenAuthnFail,
					ResultCode: service.CodeAuthInvalidToken,
				})
				writeError(w, http.StatusUnauthorized, requestID, service.CodeAuthInvalidToken, "invalid bearer token")
				return
			}

			// A resolver failure is reported to the caller the same way as a
			// non-active token.
			status, resolveErr := cfg.StatusResolver.Resolve(r.Context(), claims.TokenID)
			if resolveErr != nil || status != service.TokenStatusActive {
				audit(r.Context(), service.AuditEvent{
					EventName:  service.EventTokenAuthnFail,
					TokenID:    claims.TokenID,
					SubjectID:  claims.SubjectID,
					ResultCode: service.CodeAuthTokenInactive,
				})
				writeError(w, http.StatusUnauthorized, requestID, service.CodeAuthTokenInactive, "token is inactive")
				return
			}

			allowed := cfg.Authorizer.Authorize(route, r.Method, claims.Scope, claims.Role)
			if !allowed {
				audit(r.Context(), service.AuditEvent{
					EventName:  service.EventTokenAuthzDenied,
					TokenID:    claims.TokenID,
					SubjectID:  claims.SubjectID,
					ResultCode: service.CodeAuthScopeDenied,
				})
				writeError(w, http.StatusForbidden, requestID, service.CodeAuthScopeDenied, "scope denied")
				return
			}

			// Copy the scope so the principal does not alias the claims' slice.
			scope := append([]string(nil), claims.Scope...)
			principal := model.Principal{
				RequestID: requestID,
				TokenID:   claims.TokenID,
				SubjectID: claims.SubjectID,
				Role:      claims.Role,
				Scope:     scope,
			}

			authCtx := context.WithValue(r.Context(), principalKey, principal)
			authCtx = context.WithValue(authCtx, requestIDKey, requestID)

			audit(authCtx, service.AuditEvent{
				EventName:  service.EventTokenAuthnSuccess,
				TokenID:    claims.TokenID,
				SubjectID:  claims.SubjectID,
				ResultCode: "OK",
			})
			next.ServeHTTP(w, r.WithContext(authCtx))
		})
	}
}
// RequestIDFromContext returns the request ID stored in ctx by this package.
// The boolean is false when ctx is nil or carries no string request ID.
func RequestIDFromContext(ctx context.Context) (string, bool) {
	if ctx != nil {
		if id, found := ctx.Value(requestIDKey).(string); found {
			return id, true
		}
	}
	return "", false
}
// PrincipalFromContext returns the authenticated principal stored in ctx by
// TokenAuthMiddleware. The boolean is false, and the principal is the zero
// value, when ctx is nil or holds no principal.
func PrincipalFromContext(ctx context.Context) (model.Principal, bool) {
	if ctx != nil {
		if principal, found := ctx.Value(principalKey).(model.Principal); found {
			return principal, true
		}
	}
	return model.Principal{}, false
}
// withDefaults returns a copy of cfg in which unset optional settings are
// replaced by the package defaults: the default clock, the default protected
// prefixes and the default excluded prefixes. The receiver is not modified.
func (cfg AuthMiddlewareConfig) withDefaults() AuthMiddlewareConfig {
	resolved := cfg

	if resolved.Now == nil {
		resolved.Now = defaultNowFunc
	}
	if len(resolved.ProtectedPrefixes) == 0 {
		resolved.ProtectedPrefixes = []string{"/api/v1/supply", "/api/v1/platform"}
	}
	if len(resolved.ExcludedPrefixes) == 0 {
		resolved.ExcludedPrefixes = []string{"/healthz", "/metrics", "/readyz"}
	}

	return resolved
}
// shouldProtect reports whether path requires authentication: it must start
// with one of ProtectedPrefixes and with none of ExcludedPrefixes. Exclusions
// take precedence. Matching is a plain string-prefix comparison.
func (cfg AuthMiddlewareConfig) shouldProtect(path string) bool {
	hasAnyPrefix := func(prefixes []string) bool {
		for i := range prefixes {
			if strings.HasPrefix(path, prefixes[i]) {
				return true
			}
		}
		return false
	}

	if hasAnyPrefix(cfg.ExcludedPrefixes) {
		return false
	}
	return hasAnyPrefix(cfg.ProtectedPrefixes)
}
func ensureRequestID(r *http.Request, now func() time.Time) string {
if now == nil {
now = defaultNowFunc
}
if requestID, ok := RequestIDFromContext(r.Context()); ok && requestID != "" {
return requestID
}
requestID := strings.TrimSpace(r.Header.Get(requestIDHeader))
if requestID == "" {
requestID = fmt.Sprintf("req-%d", now().UnixNano())
}
ctx := context.WithValue(r.Context(), requestIDKey, requestID)
*r = *r.WithContext(ctx)
return requestID
}
// extractBearerToken parses an Authorization header value of the form
// "Bearer <token>" and returns the token with surrounding whitespace removed.
//
// The scheme name is compared case-insensitively, because HTTP authentication
// scheme names are case-insensitive; "bearer abc" and "BEARER abc" are
// accepted just like "Bearer abc". The scheme must be the very first thing in
// the header and must be followed by a space.
//
// The boolean is false, and the token empty, when the header is empty, uses a
// different scheme, or carries no token after the scheme.
func extractBearerToken(authHeader string) (string, bool) {
	const bearerScheme = "Bearer"

	scheme, credentials, found := strings.Cut(authHeader, " ")
	if !found || !strings.EqualFold(scheme, bearerScheme) {
		return "", false
	}

	token := strings.TrimSpace(credentials)
	if token == "" {
		return "", false
	}
	return token, true
}
// emitAuditEvent forwards event to auditor. It does nothing when auditor is
// nil.
func emitAuditEvent(ctx context.Context, auditor service.AuditEmitter, event service.AuditEvent) {
	if auditor != nil {
		// Auditing is best-effort here: the result of Emit is deliberately
		// discarded so that it never affects request handling.
		_ = auditor.Emit(ctx, event)
	}
}
// errorResponse is the JSON envelope written for every error produced by the
// middleware.
type errorResponse struct {
	// RequestID is the ID of the request that failed.
	RequestID string `json:"request_id"`
	// Error carries the error code and message.
	Error errorPayload `json:"error"`
}

// errorPayload is the "error" object inside errorResponse.
type errorPayload struct {
	// Code is the error code string reported to the client.
	Code string `json:"code"`
	// Message is the human-readable description.
	Message string `json:"message"`
	// Details holds optional extra data; it is omitted from the JSON when empty.
	Details map[string]any `json:"details,omitempty"`
}
// writeError sends a JSON error response with the given HTTP status.
//
// The body is an errorResponse carrying requestID, code and message, with the
// Content-Type header set to application/json.
func writeError(w http.ResponseWriter, status int, requestID, code, message string) {
	body := errorResponse{RequestID: requestID}
	body.Error.Code = code
	body.Error.Message = message

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	// The status line is already written, so an encoding or write failure
	// cannot be reported to the client; the error is intentionally dropped.
	_ = json.NewEncoder(w).Encode(body)
}
// extractClientIP returns the client address to record for r while resisting
// IP spoofing through the X-Forwarded-For header.
//
// Rules:
//   - If r.RemoteAddr is not in host:port form it is returned verbatim.
//   - If the directly connected peer is not a trusted proxy, its host is
//     returned and X-Forwarded-For is ignored entirely.
//   - If the peer is a trusted proxy, the X-Forwarded-For hops (from all
//     header lines) are scanned from right to left and the first hop that is
//     not itself a trusted proxy is returned. The right-most entries are the
//     ones appended by our own proxies; anything further left was supplied by
//     the client and could be forged, so the left-most entry is never trusted
//     blindly.
//   - If the header is absent, empty, or lists only trusted proxies, the
//     peer's host is returned.
//
// Each trustedProxies entry may be an exact host string, an IP address
// (compared numerically, so equivalent IPv6 spellings match) or a CIDR range
// such as "10.0.0.0/8". Empty entries are ignored.
func extractClientIP(r *http.Request, trustedProxies []string) string {
	remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// Without a parseable peer address the request cannot be attributed
		// to a trusted proxy, so forwarded headers are not consulted.
		return r.RemoteAddr
	}

	// isTrusted reports whether host matches any configured trusted proxy.
	isTrusted := func(host string) bool {
		ip := net.ParseIP(host)
		for _, entry := range trustedProxies {
			entry = strings.TrimSpace(entry)
			if entry == "" {
				continue
			}
			if entry == host {
				return true
			}
			if ip == nil {
				continue
			}
			if _, network, cidrErr := net.ParseCIDR(entry); cidrErr == nil {
				if network.Contains(ip) {
					return true
				}
				continue
			}
			if candidate := net.ParseIP(entry); candidate != nil && candidate.Equal(ip) {
				return true
			}
		}
		return false
	}

	if !isTrusted(remoteHost) {
		return remoteHost
	}

	// A request may carry several X-Forwarded-For lines; together they form
	// one comma-separated list in header order.
	forwarded := strings.Join(r.Header.Values("X-Forwarded-For"), ",")
	hops := strings.Split(forwarded, ",")
	for i := len(hops) - 1; i >= 0; i-- {
		hop := strings.TrimSpace(hops[i])
		if hop == "" {
			continue
		}
		if !isTrusted(hop) {
			return hop
		}
	}

	return remoteHost
}