package middleware

import (
	"fmt"
	"os"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/user-management-system/internal/config"
)

// RateLimitMiddleware provides per-endpoint rate limiting for gin routes.
//
// Each limiter is keyed by endpoint scope plus subject (the "user_id" value
// from the gin context when present, otherwise the client IP), so one noisy
// client does not exhaust a limit shared by everyone. Entries that have been
// idle for long enough are evicted so the map does not grow without bound.
type RateLimitMiddleware struct {
	// cfg is stored at construction; nothing in this file reads it.
	cfg config.RateLimitConfig
	// limiters maps a limiter key (see buildLimiterKey) to its entry.
	limiters map[string]*limiterEntry
	// mu guards limiters, lastCleanup, and the lastSeen field of every entry.
	mu sync.RWMutex
	// cleanupInt is the minimum time between two eviction sweeps and also the
	// minimum idle time before an entry may be evicted.
	cleanupInt time.Duration
	// lastCleanup is when the most recent eviction sweep ran.
	lastCleanup time.Time
}

// limiterEntry pairs a limiter with the bookkeeping needed to evict it.
type limiterEntry struct {
	limiter *SlidingWindowLimiter
	// window is the limiter's window, kept here so the sweep can compute the
	// idle TTL without touching the limiter's own lock.
	window time.Duration
	// lastSeen is the last time the entry was looked up or created.
	lastSeen time.Time
}

// SlidingWindowLimiter is a sliding-window log limiter: it remembers the
// timestamp of every admitted request and admits a new one only while fewer
// than capacity timestamps fall inside the trailing window.
type SlidingWindowLimiter struct {
	mu       sync.Mutex
	window   time.Duration
	capacity int64
	// requests holds Unix-millisecond timestamps of admitted requests.
	requests []int64
}

// NewSlidingWindowLimiter returns a limiter that admits at most capacity
// requests per window.
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
	return &SlidingWindowLimiter{
		window:   window,
		capacity: capacity,
		requests: make([]int64, 0),
	}
}

// Allow reports whether a request arriving now may proceed. An admitted
// request is recorded; a rejected request is not, so rejected traffic does
// not extend the time a client stays blocked.
func (l *SlidingWindowLimiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now().UnixMilli()
	cutoff := now - l.window.Milliseconds()

	// Drop expired timestamps in place, reusing the backing array. Only
	// timestamps strictly newer than the cutoff still count.
	live := l.requests[:0]
	for _, ts := range l.requests {
		if ts > cutoff {
			live = append(live, ts)
		}
	}
	l.requests = live

	if int64(len(l.requests)) >= l.capacity {
		return false
	}

	l.requests = append(l.requests, now)
	return true
}

// NewRateLimitMiddleware creates a middleware with an empty limiter table
// and a five-minute eviction interval.
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
	return &RateLimitMiddleware{
		cfg:         cfg,
		limiters:    make(map[string]*limiterEntry),
		cleanupInt:  5 * time.Minute,
		lastCleanup: time.Now(),
	}
}

// Register returns the limiter for the registration endpoint:
// 10 requests per 60-second window per subject.
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
	return m.limitForKey("register", 60, 10)
}

// Login returns the limiter for the login endpoint:
// 5 requests per 60-second window per subject.
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
	return m.limitForKey("login", 60, 5)
}

// API returns the limiter for general API endpoints:
// 100 requests per 60-second window per subject.
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
	return m.limitForKey("api", 60, 100)
}

// Refresh returns the limiter for the token-refresh endpoint:
// 10 requests per 60-second window per subject.
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
	return m.limitForKey("refresh", 60, 10)
}

// limitForKey builds a handler that admits at most capacity requests per
// windowSeconds for each (scope, subject) pair and answers the rest with a
// 429 JSON body.
//
// DISABLE_RATE_LIMIT is read once, when the handler is built, not on every
// request: if it equals "1" at that moment the returned handler is a
// pass-through.
func (m *RateLimitMiddleware) limitForKey(scope string, windowSeconds int, capacity int64) gin.HandlerFunc {
	if os.Getenv("DISABLE_RATE_LIMIT") == "1" {
		return func(c *gin.Context) {
			c.Next()
		}
	}

	window := time.Duration(windowSeconds) * time.Second

	return func(c *gin.Context) {
		key := m.buildLimiterKey(scope, c)
		if !m.getOrCreateLimiter(key, window, capacity).Allow() {
			c.JSON(429, gin.H{
				"code":    429,
				"message": "请求过于频繁,请稍后再试",
			})
			c.Abort()
			return
		}
		c.Next()
	}
}

// buildLimiterKey derives the limiter key for a request. When the gin
// context carries a "user_id" value the key is "<scope>:user:<id>";
// otherwise it falls back to "<scope>:ip:<client IP>".
func (m *RateLimitMiddleware) buildLimiterKey(scope string, c *gin.Context) string {
	if userID, ok := c.Get("user_id"); ok {
		return fmt.Sprintf("%s:user:%v", scope, userID)
	}
	return fmt.Sprintf("%s:ip:%s", scope, c.ClientIP())
}

// getOrCreateLimiter returns the limiter registered under key, creating it
// with the given window and capacity on first use, and marks the entry as
// seen now.
//
// The eviction sweep, the lookup, and the lastSeen update all happen inside
// one critical section. That costs a single lock acquisition per request and
// guarantees that an entry cannot be evicted between being found and being
// touched, which would leave the caller holding a limiter that is no longer
// in the map.
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
	now := time.Now()

	m.mu.Lock()
	defer m.mu.Unlock()

	// Sweep first so an entry that has already gone stale is replaced by a
	// fresh limiter rather than revived.
	m.cleanupLocked(now)

	if entry, ok := m.limiters[key]; ok {
		entry.lastSeen = now
		return entry.limiter
	}

	entry := &limiterEntry{
		limiter:  NewSlidingWindowLimiter(window, capacity),
		window:   window,
		lastSeen: now,
	}
	m.limiters[key] = entry
	return entry.limiter
}

// maybeCleanup acquires the lock and runs an eviction sweep if one is due.
// The request path sweeps through cleanupLocked directly; this wrapper keeps
// the original self-locking entry point available.
func (m *RateLimitMiddleware) maybeCleanup(now time.Time) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.cleanupLocked(now)
}

// cleanupLocked evicts idle entries if at least cleanupInt has elapsed since
// the previous sweep. An entry is evicted once it has been idle for longer
// than the larger of its own window and cleanupInt. The caller must hold
// m.mu for writing.
func (m *RateLimitMiddleware) cleanupLocked(now time.Time) {
	if now.Sub(m.lastCleanup) < m.cleanupInt {
		return
	}

	for key, entry := range m.limiters {
		idleTTL := entry.window
		if idleTTL < m.cleanupInt {
			idleTTL = m.cleanupInt
		}
		if now.Sub(entry.lastSeen) > idleTTL {
			delete(m.limiters, key)
		}
	}
	m.lastCleanup = now
}