package middleware

import (
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/user-management-system/internal/config"
)

// RateLimitMiddleware produces gin handlers that rate-limit requests per
// endpoint and per client IP.
//
// Every (endpoint, client IP) pair gets its own SlidingWindowLimiter, created
// on first use. Limiters that no longer hold any request inside their window
// are evicted lazily, at most once per cleanupInt, so the map does not grow
// without bound as new clients appear.
type RateLimitMiddleware struct {
	// NOTE(review): cfg is not read anywhere in this file; the per-endpoint
	// limits are hard-coded below. Consider sourcing them from cfg.
	cfg         config.RateLimitConfig
	limiters    map[string]*SlidingWindowLimiter // keyed by "<endpoint>:<client IP>"
	mu          sync.RWMutex                     // guards limiters and lastCleanup
	cleanupInt  time.Duration                    // minimum interval between eviction sweeps
	lastCleanup time.Time                        // when the last eviction sweep ran
}

// SlidingWindowLimiter admits at most capacity requests within any trailing
// window of the configured length. Allow serializes on mu, so a limiter may
// be shared between goroutines.
type SlidingWindowLimiter struct {
	mu       sync.Mutex
	window   time.Duration
	capacity int64
	// origin anchors the offsets in requests. Offsets are measured with
	// time.Since(origin), which uses Go's monotonic clock, so wall-clock
	// adjustments cannot stretch or shrink the window.
	origin time.Time
	// requests holds the offsets from origin of admitted requests, oldest
	// first. The order holds because offsets are taken under mu from a
	// monotonic clock.
	requests []time.Duration
}

// NewSlidingWindowLimiter returns a limiter that admits at most capacity
// requests in any trailing window. A capacity <= 0 rejects every request.
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
	return &SlidingWindowLimiter{
		window:   window,
		capacity: capacity,
		origin:   time.Now(),
	}
}

// Allow reports whether one more request fits in the current window and, if
// it does, records that request.
func (l *SlidingWindowLimiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Since(l.origin)
	l.dropExpiredLocked(now)

	if int64(len(l.requests)) >= l.capacity {
		return false
	}
	l.requests = append(l.requests, now)
	return true
}

// dropExpiredLocked discards requests that are at least one window older than
// now (matching the original "keep t > now-window" rule). Because requests is
// sorted oldest first, the expired entries form a prefix; it is shifted out in
// place so the backing array is reused instead of reallocated on every call.
// The caller must hold l.mu.
func (l *SlidingWindowLimiter) dropExpiredLocked(now time.Duration) {
	expired := 0
	for expired < len(l.requests) && now-l.requests[expired] >= l.window {
		expired++
	}
	if expired == 0 {
		return
	}
	kept := copy(l.requests, l.requests[expired:])
	l.requests = l.requests[:kept]
}

// idle reports whether the limiter holds no request inside its window. Such a
// limiter behaves exactly like a freshly created one, so discarding it loses
// no rate-limit state.
func (l *SlidingWindowLimiter) idle() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	n := len(l.requests)
	return n == 0 || time.Since(l.origin)-l.requests[n-1] >= l.window
}

// NewRateLimitMiddleware creates a RateLimitMiddleware with an empty limiter
// set and a 5-minute minimum interval between idle-limiter sweeps.
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
	return &RateLimitMiddleware{
		cfg:        cfg,
		limiters:   make(map[string]*SlidingWindowLimiter),
		cleanupInt: 5 * time.Minute,
	}
}

// Register returns the rate-limit handler for the registration endpoint:
// at most 10 requests per client IP per 60 seconds.
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
	return m.limitForKey("register", 60, 10)
}

// Login returns the rate-limit handler for the login endpoint:
// at most 5 requests per client IP per 60 seconds.
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
	return m.limitForKey("login", 60, 5)
}

// API returns the rate-limit handler for general API endpoints:
// at most 100 requests per client IP per 60 seconds.
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
	return m.limitForKey("api", 60, 100)
}

// Refresh returns the rate-limit handler for the token-refresh endpoint:
// at most 10 requests per client IP per 60 seconds.
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
	return m.limitForKey("refresh", 60, 10)
}

// limitForKey returns a handler that allows at most capacity requests per
// windowSeconds for each client IP on the endpoint identified by key. Rejected
// requests are aborted with HTTP 429 and a JSON error body.
//
// The limiter is looked up per request (not once at construction) so every
// client gets its own quota; previously all clients shared a single limiter,
// letting one client exhaust the endpoint for everyone.
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
	window := time.Duration(windowSeconds) * time.Second
	return func(c *gin.Context) {
		// NOTE(review): c.ClientIP() may derive the address from forwarding
		// headers depending on gin's trusted-proxy configuration — confirm it
		// is configured for the deployment, or clients can pick their own key.
		limiter := m.getOrCreateLimiter(key+":"+c.ClientIP(), window, capacity)
		if !limiter.Allow() {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"code":    429,
				"message": "请求过于频繁,请稍后再试",
			})
			return
		}
		c.Next()
	}
}

// getOrCreateLimiter returns the limiter stored under key, creating it with
// the given window and capacity if absent. Lookups take only the read lock;
// creation re-checks under the write lock so concurrent callers for the same
// key end up sharing one limiter.
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
	m.mu.RLock()
	limiter, exists := m.limiters[key]
	m.mu.RUnlock()
	if exists {
		return limiter
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	// Another goroutine may have created it between RUnlock and Lock.
	if limiter, exists = m.limiters[key]; exists {
		return limiter
	}

	// New keys are the only thing that grows the map, so this is the natural
	// place to reclaim limiters left behind by clients that went quiet.
	m.evictIdleLocked()

	limiter = NewSlidingWindowLimiter(window, capacity)
	m.limiters[key] = limiter
	return limiter
}

// evictIdleLocked removes limiters that hold no request inside their window.
// It does nothing if the previous sweep ran less than cleanupInt ago.
//
// A handler that fetched a limiter just before it was evicted still records
// its request in that orphaned limiter, so the key may be briefly
// over-admitted; this is accepted to keep lookups on the read lock.
// The caller must hold m.mu for writing.
func (m *RateLimitMiddleware) evictIdleLocked() {
	now := time.Now()
	if now.Sub(m.lastCleanup) < m.cleanupInt {
		return
	}
	m.lastCleanup = now

	for k, l := range m.limiters {
		if l.idle() {
			delete(m.limiters, k)
		}
	}
}