fix: harden handler context and rate limit isolation

This commit is contained in:
Your Name
2026-05-28 20:30:24 +08:00
parent e46567678f
commit caad1aba0c
6 changed files with 311 additions and 37 deletions

View File

@@ -1,6 +1,7 @@
package middleware
import (
"fmt"
"os"
"sync"
"time"
@@ -10,11 +11,20 @@ import (
)
// RateLimitMiddleware 限流中间件
// 使用 endpoint + subject(IP 或 user_id) 作为限流键,并对空闲条目做 TTL 清理,
// 避免单一全局限流器误伤所有用户,也避免历史客户端条目无限增长。
type RateLimitMiddleware struct {
cfg config.RateLimitConfig
limiters map[string]*SlidingWindowLimiter
mu sync.RWMutex
cleanupInt time.Duration
cfg config.RateLimitConfig
limiters map[string]*limiterEntry
mu sync.RWMutex
cleanupInt time.Duration
lastCleanup time.Time
}
type limiterEntry struct {
limiter *SlidingWindowLimiter
window time.Duration
lastSeen time.Time
}
// SlidingWindowLimiter 滑动窗口限流器
@@ -43,7 +53,7 @@ func (l *SlidingWindowLimiter) Allow() bool {
cutoff := now - l.window.Milliseconds()
// 清理过期请求
var validRequests []int64
validRequests := l.requests[:0]
for _, t := range l.requests {
if t > cutoff {
validRequests = append(validRequests, t)
@@ -63,9 +73,10 @@ func (l *SlidingWindowLimiter) Allow() bool {
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
return &RateLimitMiddleware{
cfg: cfg,
limiters: make(map[string]*SlidingWindowLimiter),
cleanupInt: 5 * time.Minute,
cfg: cfg,
limiters: make(map[string]*limiterEntry),
cleanupInt: 5 * time.Minute,
lastCleanup: time.Now(),
}
}
@@ -89,16 +100,18 @@ func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
return m.limitForKey("refresh", 60, 10)
}
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
func (m *RateLimitMiddleware) limitForKey(scope string, windowSeconds int, capacity int64) gin.HandlerFunc {
if os.Getenv("DISABLE_RATE_LIMIT") == "1" {
return func(c *gin.Context) {
c.Next()
}
}
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
window := time.Duration(windowSeconds) * time.Second
return func(c *gin.Context) {
limiterKey := m.buildLimiterKey(scope, c)
limiter := m.getOrCreateLimiter(limiterKey, window, capacity)
if !limiter.Allow() {
c.JSON(429, gin.H{
"code": 429,
@@ -111,24 +124,60 @@ func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacit
}
}
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
m.mu.RLock()
limiter, exists := m.limiters[key]
m.mu.RUnlock()
func (m *RateLimitMiddleware) buildLimiterKey(scope string, c *gin.Context) string {
if userID, ok := c.Get("user_id"); ok {
return fmt.Sprintf("%s:user:%v", scope, userID)
}
return fmt.Sprintf("%s:ip:%s", scope, c.ClientIP())
}
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
now := time.Now()
m.maybeCleanup(now)
m.mu.RLock()
entry, exists := m.limiters[key]
m.mu.RUnlock()
if exists {
return limiter
m.mu.Lock()
entry.lastSeen = now
m.mu.Unlock()
return entry.limiter
}
m.mu.Lock()
defer m.mu.Unlock()
// 双重检查
if limiter, exists = m.limiters[key]; exists {
return limiter
if entry, exists = m.limiters[key]; exists {
entry.lastSeen = now
return entry.limiter
}
limiter = NewSlidingWindowLimiter(window, capacity)
m.limiters[key] = limiter
return limiter
entry = &limiterEntry{
limiter: NewSlidingWindowLimiter(window, capacity),
window: window,
lastSeen: now,
}
m.limiters[key] = entry
return entry.limiter
}
func (m *RateLimitMiddleware) maybeCleanup(now time.Time) {
m.mu.Lock()
defer m.mu.Unlock()
if now.Sub(m.lastCleanup) < m.cleanupInt {
return
}
for key, entry := range m.limiters {
idleTTL := entry.window
if idleTTL < m.cleanupInt {
idleTTL = m.cleanupInt
}
if now.Sub(entry.lastSeen) > idleTTL {
delete(m.limiters, key)
}
}
m.lastCleanup = now
}