fix: harden handler context and rate limit isolation
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -10,11 +11,20 @@ import (
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
// 使用 endpoint + subject(IP 或 user_id) 作为限流键,并对空闲条目做 TTL 清理,
|
||||
// 避免单一全局限流器误伤所有用户,也避免历史客户端条目无限增长。
|
||||
type RateLimitMiddleware struct {
|
||||
cfg config.RateLimitConfig
|
||||
limiters map[string]*SlidingWindowLimiter
|
||||
mu sync.RWMutex
|
||||
cleanupInt time.Duration
|
||||
cfg config.RateLimitConfig
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
cleanupInt time.Duration
|
||||
lastCleanup time.Time
|
||||
}
|
||||
|
||||
type limiterEntry struct {
|
||||
limiter *SlidingWindowLimiter
|
||||
window time.Duration
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// SlidingWindowLimiter 滑动窗口限流器
|
||||
@@ -43,7 +53,7 @@ func (l *SlidingWindowLimiter) Allow() bool {
|
||||
cutoff := now - l.window.Milliseconds()
|
||||
|
||||
// 清理过期请求
|
||||
var validRequests []int64
|
||||
validRequests := l.requests[:0]
|
||||
for _, t := range l.requests {
|
||||
if t > cutoff {
|
||||
validRequests = append(validRequests, t)
|
||||
@@ -63,9 +73,10 @@ func (l *SlidingWindowLimiter) Allow() bool {
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
cfg: cfg,
|
||||
limiters: make(map[string]*SlidingWindowLimiter),
|
||||
cleanupInt: 5 * time.Minute,
|
||||
cfg: cfg,
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
cleanupInt: 5 * time.Minute,
|
||||
lastCleanup: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,16 +100,18 @@ func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
|
||||
return m.limitForKey("refresh", 60, 10)
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
func (m *RateLimitMiddleware) limitForKey(scope string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
if os.Getenv("DISABLE_RATE_LIMIT") == "1" {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
|
||||
window := time.Duration(windowSeconds) * time.Second
|
||||
|
||||
return func(c *gin.Context) {
|
||||
limiterKey := m.buildLimiterKey(scope, c)
|
||||
limiter := m.getOrCreateLimiter(limiterKey, window, capacity)
|
||||
if !limiter.Allow() {
|
||||
c.JSON(429, gin.H{
|
||||
"code": 429,
|
||||
@@ -111,24 +124,60 @@ func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacit
|
||||
}
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
m.mu.RLock()
|
||||
limiter, exists := m.limiters[key]
|
||||
m.mu.RUnlock()
|
||||
func (m *RateLimitMiddleware) buildLimiterKey(scope string, c *gin.Context) string {
|
||||
if userID, ok := c.Get("user_id"); ok {
|
||||
return fmt.Sprintf("%s:user:%v", scope, userID)
|
||||
}
|
||||
return fmt.Sprintf("%s:ip:%s", scope, c.ClientIP())
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
now := time.Now()
|
||||
m.maybeCleanup(now)
|
||||
|
||||
m.mu.RLock()
|
||||
entry, exists := m.limiters[key]
|
||||
m.mu.RUnlock()
|
||||
if exists {
|
||||
return limiter
|
||||
m.mu.Lock()
|
||||
entry.lastSeen = now
|
||||
m.mu.Unlock()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if limiter, exists = m.limiters[key]; exists {
|
||||
return limiter
|
||||
if entry, exists = m.limiters[key]; exists {
|
||||
entry.lastSeen = now
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
limiter = NewSlidingWindowLimiter(window, capacity)
|
||||
m.limiters[key] = limiter
|
||||
return limiter
|
||||
entry = &limiterEntry{
|
||||
limiter: NewSlidingWindowLimiter(window, capacity),
|
||||
window: window,
|
||||
lastSeen: now,
|
||||
}
|
||||
m.limiters[key] = entry
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) maybeCleanup(now time.Time) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if now.Sub(m.lastCleanup) < m.cleanupInt {
|
||||
return
|
||||
}
|
||||
|
||||
for key, entry := range m.limiters {
|
||||
idleTTL := entry.window
|
||||
if idleTTL < m.cleanupInt {
|
||||
idleTTL = m.cleanupInt
|
||||
}
|
||||
if now.Sub(entry.lastSeen) > idleTTL {
|
||||
delete(m.limiters, key)
|
||||
}
|
||||
}
|
||||
m.lastCleanup = now
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user