Files
lijiaoqiao/supply-api/internal/middleware/ratelimit.go
Your Name 2012e23278 feat: 更新TDD任务清单并验证所有安全问题
- TASK-25: domain覆盖率已达72.0% (目标70%+)
- TASK-27: DSN密码设计安全验证完成
- 确认请求超时中间件已正确实现
- 所有go vet问题已修复

剩余未解决项:
- SEC-005: 开发模式鉴权禁用(设计决定)
- SEC-010: TokenCache多实例(需Redis)
2026-04-09 20:44:11 +08:00

259 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"fmt"
"net/http"
"strconv"
"sync"
"time"
)
// ==================== P0-05 Rate-limiting strategy implementation ====================

// RateLimitConfig configures RateLimitMiddleware.
type RateLimitConfig struct {
	Enabled    bool          // when false, requests are passed straight to the next handler
	Requests   int           // maximum number of requests per rate-limit key (used as the bucket capacity)
	Window     time.Duration // length of the rate-limit window; also reported in reset / Retry-After headers
	HeaderType string        // response-header family; documented values: "X-RateLimit" | "RateLimit-Policy"

	// Multi-dimensional limiting: every enabled dimension is added to the
	// rate-limit key, so each one makes the buckets more fine-grained.
	LimitByTenant   bool // include the tenant ID taken from the request's token claims
	LimitByUser     bool // include the user (subject) ID taken from the request's token claims
	LimitByIP       bool // include the client IP
	LimitByEndpoint bool // include the request URL path

	// Degradation: what happens to a request once its key has exceeded the limit.
	DegradationEnabled bool         // when true and DegradationHandler is non-nil, that handler serves rejected requests
	DegradationHandler http.Handler // handler for rejected requests when degradation is enabled
	FallbackCode       int          // HTTP status of the plain "rate limit exceeded" reply used otherwise
}
// DefaultRateLimitConfig returns the stock limiter settings: enabled,
// 1000 requests per one-minute window, keyed on tenant, user, client IP and
// endpoint together, with "X-RateLimit" headers. Degradation is switched on
// but no DegradationHandler is set, so rejected requests receive a plain
// 429 Too Many Requests.
func DefaultRateLimitConfig() *RateLimitConfig {
	cfg := new(RateLimitConfig)
	cfg.Enabled = true
	cfg.Requests, cfg.Window = 1000, time.Minute
	cfg.HeaderType = "X-RateLimit"

	// Every key dimension is on by default.
	cfg.LimitByTenant, cfg.LimitByUser = true, true
	cfg.LimitByIP, cfg.LimitByEndpoint = true, true

	cfg.DegradationEnabled = true
	cfg.FallbackCode = http.StatusTooManyRequests
	return cfg
}
// TokenBucket is a mutex-guarded token bucket. It holds at most capacity
// tokens, each admitted request consumes one token, and tokens are
// replenished at rate tokens per second. A rate <= 0 disables replenishment,
// so only Reset refills the bucket.
type TokenBucket struct {
	mu         sync.Mutex
	capacity   int       // maximum number of tokens the bucket can hold
	rate       int       // tokens added per second; <= 0 means never refill
	tokens     int       // tokens currently available
	lastRefill time.Time // instant up to which refill has already been credited
}

// NewTokenBucket returns a full bucket holding capacity tokens that refills
// at rate tokens per second (rate <= 0: no automatic refill).
func NewTokenBucket(capacity int, rate int) *TokenBucket {
	return &TokenBucket{
		capacity:   capacity,
		rate:       rate,
		tokens:     capacity,
		lastRefill: time.Now(),
	}
}

// Allow reports whether a request may proceed, consuming one token if so.
func (tb *TokenBucket) Allow() bool {
	tb.mu.Lock()
	defer tb.mu.Unlock()
	tb.refill()
	if tb.tokens > 0 {
		tb.tokens--
		return true
	}
	return false
}

// refill credits the tokens earned since lastRefill. The caller must hold tb.mu.
//
// Only whole tokens are credited, and lastRefill advances by exactly the time
// those tokens took to earn, so any fractional remainder carries over to the
// next call. Moving lastRefill to "now" on every call would throw that
// remainder away, and a bucket polled more often than once per 1/rate seconds
// would never refill at all.
func (tb *TokenBucket) refill() {
	now := time.Now()
	if tb.rate <= 0 || tb.tokens >= tb.capacity {
		// Nothing can accrue. Restart the clock so that time spent full (or
		// with refilling disabled) is not banked as future tokens. A negative
		// rate is treated as "no refill" rather than draining the bucket.
		tb.lastRefill = now
		return
	}
	elapsed := now.Sub(tb.lastRefill)
	if elapsed <= 0 {
		return
	}
	earned := elapsed.Seconds() * float64(tb.rate)
	missing := tb.capacity - tb.tokens
	// Compare in float64 before converting, so a long idle period cannot
	// overflow the int conversion.
	if earned >= float64(missing) {
		tb.tokens = tb.capacity
		tb.lastRefill = now
		return
	}
	whole := int(earned)
	if whole == 0 {
		return // keep accumulating the fraction
	}
	tb.tokens += whole
	tb.lastRefill = tb.lastRefill.Add(time.Duration(float64(whole) * float64(time.Second) / float64(tb.rate)))
}

// Remaining returns the number of tokens available right now (after crediting
// any refill) without consuming one.
func (tb *TokenBucket) Remaining() int {
	tb.mu.Lock()
	defer tb.mu.Unlock()
	tb.refill()
	return tb.tokens
}

// Reset refills the bucket to capacity and restarts its refill clock.
func (tb *TokenBucket) Reset() {
	tb.mu.Lock()
	defer tb.mu.Unlock()
	tb.tokens = tb.capacity
	tb.lastRefill = time.Now()
}
// RateLimitMiddleware is an http.Handler that limits every rate-limit key
// (see getRateLimitKey) to config.Requests requests per config.Window. It uses
// fixed windows: a key's bucket is refilled in full when its window ends. With
// Window <= 0 no window is tracked and buckets are never refilled.
type RateLimitMiddleware struct {
	config  *RateLimitConfig
	buckets map[string]*TokenBucket // request budget per rate-limit key
	mu      sync.RWMutex            // guards buckets, windowEnds and lastSweep
	next    http.Handler

	// windowEnds records when each bucket's current window closes. It is
	// allocated lazily, so instances that initialise only buckets (such as
	// those built by NewRateLimitHandler) keep working.
	windowEnds map[*TokenBucket]time.Time
	// lastSweep is when expired buckets were last evicted from buckets.
	lastSweep time.Time
}

// ServeHTTP implements http.Handler. It charges the request to its key's
// bucket, then either forwards it to next with rate-limit headers attached or
// rejects it via handleRateLimitExceeded.
func (rl *RateLimitMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if !rl.config.Enabled {
		rl.next.ServeHTTP(w, r)
		return
	}
	bucket := rl.getBucket(rl.getRateLimitKey(r))
	if !bucket.Allow() {
		rl.handleRateLimitExceeded(w, r, bucket)
		return
	}
	rl.setRateLimitHeaders(w, bucket)
	rl.next.ServeHTTP(w, r)
}

// getRateLimitKey builds the bucket key from the dimensions enabled in the
// config (tenant, user, client IP, URL path), joined with ":". It returns
// "default" when no dimension is enabled.
func (rl *RateLimitMiddleware) getRateLimitKey(r *http.Request) string {
	parts := make([]string, 0, 4)
	if rl.config.LimitByTenant {
		parts = append(parts, fmt.Sprintf("tenant:%d", getTenantIDFromRequest(r)))
	}
	if rl.config.LimitByUser {
		parts = append(parts, fmt.Sprintf("user:%s", getUserIDFromRequest(r)))
	}
	if rl.config.LimitByIP {
		parts = append(parts, fmt.Sprintf("ip:%s", getClientIP(r)))
	}
	if rl.config.LimitByEndpoint {
		parts = append(parts, r.URL.Path)
	}
	if len(parts) == 0 {
		return "default"
	}
	key := parts[0]
	for _, p := range parts[1:] {
		key += ":" + p
	}
	return key
}

// getBucket returns the bucket for key. It creates the bucket on first use and
// refills it once its window has ended.
//
// Buckets are created with refill rate 0, so the Reset at the window boundary
// is their only source of new tokens. Without that reset, a key that spent its
// budget would stay blocked for the life of the process.
func (rl *RateLimitMiddleware) getBucket(key string) *TokenBucket {
	now := time.Now()

	rl.mu.RLock()
	bucket, ok := rl.buckets[key]
	current := ok && !rl.windowExpiredLocked(bucket, now)
	rl.mu.RUnlock()
	if current {
		return bucket
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.evictExpiredLocked(now)
	// Re-check under the write lock. Another goroutine may have created or
	// refilled the bucket, or eviction may have just removed it.
	bucket, ok = rl.buckets[key]
	switch {
	case !ok:
		bucket = NewTokenBucket(rl.config.Requests, 0)
		rl.buckets[key] = bucket
	case rl.windowExpiredLocked(bucket, now):
		bucket.Reset()
	default:
		return bucket
	}
	rl.startWindowLocked(bucket, now)
	return bucket
}

// windowExpiredLocked reports whether bucket's window has ended at now. It is
// always false when Window <= 0. The caller must hold rl.mu (read or write).
func (rl *RateLimitMiddleware) windowExpiredLocked(bucket *TokenBucket, now time.Time) bool {
	return rl.config.Window > 0 && !now.Before(rl.windowEnds[bucket])
}

// startWindowLocked opens a new window for bucket starting at now. The caller
// must hold rl.mu for writing.
func (rl *RateLimitMiddleware) startWindowLocked(bucket *TokenBucket, now time.Time) {
	if rl.config.Window <= 0 {
		return
	}
	if rl.windowEnds == nil {
		rl.windowEnds = make(map[*TokenBucket]time.Time)
	}
	rl.windowEnds[bucket] = now.Add(rl.config.Window)
}

// evictExpiredLocked drops keys whose window has ended, so the maps do not
// grow without bound as new clients and endpoints appear. A dropped key gets
// a fresh, full bucket on its next request, which is exactly what a window
// reset would have given it. It runs at most once per Window. The caller must
// hold rl.mu for writing.
func (rl *RateLimitMiddleware) evictExpiredLocked(now time.Time) {
	window := rl.config.Window
	if window <= 0 || now.Sub(rl.lastSweep) < window {
		return
	}
	rl.lastSweep = now
	for key, bucket := range rl.buckets {
		if rl.windowExpiredLocked(bucket, now) {
			delete(rl.buckets, key)
			delete(rl.windowEnds, bucket)
		}
	}
}

// windowResetTime returns when bucket's current window ends. It returns now
// when no window is tracked for the bucket (Window <= 0, or already evicted).
func (rl *RateLimitMiddleware) windowResetTime(bucket *TokenBucket, now time.Time) time.Time {
	rl.mu.RLock()
	end, ok := rl.windowEnds[bucket]
	rl.mu.RUnlock()
	if !ok {
		return now
	}
	return end
}

// handleRateLimitExceeded rejects a request whose bucket is empty. It sets the
// rate-limit headers plus Retry-After (whole seconds until the bucket's window
// ends, rounded up). It then delegates to DegradationHandler when degradation
// is enabled and a handler is configured; otherwise it replies with
// FallbackCode.
func (rl *RateLimitMiddleware) handleRateLimitExceeded(w http.ResponseWriter, r *http.Request, bucket *TokenBucket) {
	now := time.Now()
	resetAt := rl.windowResetTime(bucket, now)
	rl.writeRateLimitHeaders(w, bucket, resetAt, now)
	w.Header().Set("Retry-After", strconv.FormatInt(ceilWholeSeconds(resetAt.Sub(now)), 10))

	if rl.config.DegradationEnabled && rl.config.DegradationHandler != nil {
		rl.config.DegradationHandler.ServeHTTP(w, r)
		return
	}
	code := rl.config.FallbackCode
	if code < 100 || code > 999 {
		// net/http panics on status codes outside 100-999, for example a
		// FallbackCode left at its zero value.
		code = http.StatusTooManyRequests
	}
	http.Error(w, "rate limit exceeded", code)
}

// setRateLimitHeaders adds the limit, remaining and reset headers for an
// admitted request.
func (rl *RateLimitMiddleware) setRateLimitHeaders(w http.ResponseWriter, bucket *TokenBucket) {
	now := time.Now()
	rl.writeRateLimitHeaders(w, bucket, rl.windowResetTime(bucket, now), now)
}

// writeRateLimitHeaders writes the header family selected by config.HeaderType:
//   - "RateLimit-Policy" (or "RateLimit"): IETF draft style. It sets
//     RateLimit-Limit, RateLimit-Remaining, RateLimit-Reset (seconds until
//     reset) and RateLimit-Policy "<requests>;w=<window seconds>".
//   - anything else: <prefix>-Limit, <prefix>-Remaining and <prefix>-Reset
//     (Unix timestamp), with an empty HeaderType treated as "X-RateLimit".
func (rl *RateLimitMiddleware) writeRateLimitHeaders(w http.ResponseWriter, bucket *TokenBucket, resetAt, now time.Time) {
	h := w.Header()
	prefix := rl.config.HeaderType
	switch prefix {
	case "RateLimit-Policy", "RateLimit":
		h.Set("RateLimit-Limit", strconv.Itoa(rl.config.Requests))
		h.Set("RateLimit-Remaining", strconv.Itoa(bucket.Remaining()))
		h.Set("RateLimit-Reset", strconv.FormatInt(ceilWholeSeconds(resetAt.Sub(now)), 10))
		h.Set("RateLimit-Policy", fmt.Sprintf("%d;w=%d", rl.config.Requests, ceilWholeSeconds(rl.config.Window)))
		return
	case "":
		prefix = "X-RateLimit"
	}
	h.Set(prefix+"-Limit", strconv.Itoa(rl.config.Requests))
	h.Set(prefix+"-Remaining", strconv.Itoa(bucket.Remaining()))
	h.Set(prefix+"-Reset", strconv.FormatInt(resetAt.Unix(), 10))
}

// ceilWholeSeconds rounds d up to whole seconds. Non-positive durations yield 0.
func ceilWholeSeconds(d time.Duration) int64 {
	if d <= 0 {
		return 0
	}
	s := int64(d / time.Second)
	if d%time.Second != 0 {
		s++
	}
	return s
}
// ---- Helper functions ----

// getTenantIDFromRequest returns the tenant ID carried in the token claims
// attached to the request context, or 0 when no claims are present.
func getTenantIDFromRequest(r *http.Request) int64 {
	claims := GetTokenClaims(r.Context())
	if claims == nil {
		return 0
	}
	return claims.TenantID
}

// getUserIDFromRequest returns the subject (user) ID carried in the token
// claims attached to the request context, or "unknown" when no claims are
// present.
func getUserIDFromRequest(r *http.Request) string {
	claims := GetTokenClaims(r.Context())
	if claims == nil {
		return "unknown"
	}
	return claims.SubjectID
}
// NewRateLimitHandler wraps next in a RateLimitMiddleware driven by config.
// It is a convenience for inserting the limiter into a middleware chain. The
// returned middleware starts with an empty set of per-key buckets.
func NewRateLimitHandler(config *RateLimitConfig, next http.Handler) *RateLimitMiddleware {
	rl := new(RateLimitMiddleware)
	rl.config = config
	rl.next = next
	rl.buckets = make(map[string]*TokenBucket)
	return rl
}