package middleware import ( "bytes" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "io" "sync" "time" "github.com/gin-gonic/gin" "github.com/user-management-system/internal/config" ) // RateLimitMiddleware provides simple in-memory sliding-window rate limiting. type RateLimitMiddleware struct { cfg config.RateLimitConfig limiters map[string]*SlidingWindowLimiter mu sync.RWMutex cleanupInt time.Duration } // SlidingWindowLimiter enforces a fixed-capacity sliding window. type SlidingWindowLimiter struct { mu sync.Mutex window time.Duration capacity int64 requests []int64 } func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter { return &SlidingWindowLimiter{ window: window, capacity: capacity, requests: make([]int64, 0), } } func (l *SlidingWindowLimiter) Allow() bool { l.mu.Lock() defer l.mu.Unlock() now := time.Now().UnixMilli() cutoff := now - l.window.Milliseconds() validRequests := make([]int64, 0, len(l.requests)) for _, ts := range l.requests { if ts > cutoff { validRequests = append(validRequests, ts) } } l.requests = validRequests if int64(len(l.requests)) >= l.capacity { return false } l.requests = append(l.requests, now) return true } func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware { return &RateLimitMiddleware{ cfg: cfg, limiters: make(map[string]*SlidingWindowLimiter), cleanupInt: 5 * time.Minute, } } func (m *RateLimitMiddleware) Register() gin.HandlerFunc { return m.limitForKey("register", 60, 10) } func (m *RateLimitMiddleware) Login() gin.HandlerFunc { return m.limitForKey("login", 60, 5) } func (m *RateLimitMiddleware) API() gin.HandlerFunc { return m.limitForKey("api", 60, 100) } func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc { return m.limitForKey("refresh", 60, 10) } func (m *RateLimitMiddleware) limitForKey(bucket string, windowSeconds int, capacity int64) gin.HandlerFunc { window := time.Duration(windowSeconds) * time.Second return func(c *gin.Context) { limiterKey := m.resolveLimiterKey(c, bucket) limiter := m.getOrCreateLimiter(limiterKey, window, capacity) if !limiter.Allow() { c.JSON(429, gin.H{ "code": 429, "message": "请求过于频繁,请稍后再试", }) c.Abort() return } c.Next() } } func (m *RateLimitMiddleware) resolveLimiterKey(c *gin.Context, bucket string) string { if bucket == "refresh" { if refreshToken := extractRefreshToken(c); refreshToken != "" { return fmt.Sprintf("%s:token:%s", bucket, fingerprintValue(refreshToken)) } } identity := "anonymous" if c != nil { if userID, ok := c.Get("user_id"); ok { identity = fmt.Sprintf("user:%v", userID) } else if ip := c.ClientIP(); ip != "" { identity = "ip:" + ip } } if bucket == "api" { method := "" route := "" if c != nil { if c.Request != nil { method = c.Request.Method if c.Request.URL != nil { route = c.Request.URL.Path } } if fullPath := c.FullPath(); fullPath != "" { route = fullPath } } return fmt.Sprintf("%s:%s:%s:%s", bucket, method, route, identity) } return fmt.Sprintf("%s:%s", bucket, identity) } func extractRefreshToken(c *gin.Context) string { if c == nil { return "" } if refreshToken, err := c.Cookie("ums_refresh_token"); err == nil && refreshToken != "" { return refreshToken } if c.Request == nil || c.Request.Body == nil { return "" } body, err := io.ReadAll(c.Request.Body) if err != nil { return "" } c.Request.Body = io.NopCloser(bytes.NewReader(body)) if len(bytes.TrimSpace(body)) == 0 { return "" } var payload struct { RefreshToken string `json:"refresh_token"` } if err := json.Unmarshal(body, &payload); err != nil { return "" } return payload.RefreshToken } func fingerprintValue(value string) string { sum := sha256.Sum256([]byte(value)) return hex.EncodeToString(sum[:12]) } func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter { m.mu.RLock() limiter, exists := m.limiters[key] m.mu.RUnlock() if exists { return limiter } m.mu.Lock() defer m.mu.Unlock() if limiter, exists = m.limiters[key]; exists { return limiter } limiter = NewSlidingWindowLimiter(window, capacity) m.limiters[key] = limiter return limiter } // Cleanup 清理过期的不活跃 limiter,防止 map 无界增长(P0 资源泄漏修复) func (m *RateLimitMiddleware) Cleanup() { m.mu.Lock() defer m.mu.Unlock() now := time.Now().UnixMilli() for key, limiter := range m.limiters { limiter.mu.Lock() cutoff := now - limiter.window.Milliseconds() // 只保留仍在窗口内的请求时间戳 validRequests := make([]int64, 0, len(limiter.requests)) for _, ts := range limiter.requests { if ts > cutoff { validRequests = append(validRequests, ts) } } limiter.requests = validRequests isEmpty := len(limiter.requests) == 0 limiter.mu.Unlock() if isEmpty { delete(m.limiters, key) } } } // StartCleanup 启动后台定期清理 goroutine,返回停止函数(P0 资源泄漏修复) func (m *RateLimitMiddleware) StartCleanup() func() { ticker := time.NewTicker(m.cleanupInt) done := make(chan struct{}) go func() { for { select { case <-ticker.C: m.Cleanup() case <-done: ticker.Stop() return } } }() return func() { close(done) } }