Files
user-system/internal/api/middleware/ratelimit.go

184 lines
4.2 KiB
Go

package middleware
import (
"fmt"
"os"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
// RateLimitMiddleware is the rate-limiting middleware.
//
// Limiters are keyed by endpoint scope plus subject (user ID or client IP)
// instead of one global limiter, so a single client cannot throttle every
// other user. Idle entries are evicted after a TTL so that the table does not
// keep growing with every client ever seen.
//
// NOTE(review): cfg is stored by the constructor but no method in the visible
// code reads it; the per-endpoint limits are hard-coded. Confirm whether it is
// meant to drive them.
type RateLimitMiddleware struct {
	cfg config.RateLimitConfig

	// limiters maps "<scope>:user:<id>" / "<scope>:ip:<addr>" keys to their
	// entries. It is guarded by mu, as are the entries' lastSeen and lastCleanup.
	limiters map[string]*limiterEntry
	mu       sync.RWMutex

	cleanupInt  time.Duration // sweep interval; also the minimum idle TTL of an entry
	lastCleanup time.Time     // when the most recent idle-entry sweep ran
}
// limiterEntry pairs a per-key limiter with the bookkeeping the idle sweep needs.
type limiterEntry struct {
	limiter  *SlidingWindowLimiter
	window   time.Duration // the limiter's window; used as a lower bound on its idle TTL
	lastSeen time.Time     // last time this key was looked up
}
// SlidingWindowLimiter is a sliding-window-log limiter: it admits at most
// capacity requests within any trailing window.
//
// Admission times are stored as time.Time values obtained from time.Now, so
// window comparisons use the monotonic clock reading. Wall-clock steps (NTP
// corrections, manual changes) therefore cannot make recorded requests look
// newer or older than they are.
type SlidingWindowLimiter struct {
	mu       sync.Mutex
	window   time.Duration
	capacity int64
	// requests holds the admission times still inside the window, oldest first
	// (they are appended in call order under mu).
	requests []time.Time
}

// NewSlidingWindowLimiter creates a limiter that admits at most capacity
// requests per window. A capacity <= 0 rejects every request.
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
	return &SlidingWindowLimiter{
		window:   window,
		capacity: capacity,
	}
}

// Allow reports whether a request arriving now may proceed and, if so,
// records it. Entries at or before now-window are discarded first. Rejected
// requests are not recorded.
func (l *SlidingWindowLimiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now()
	// now carries a monotonic reading, and Add preserves it, so After below
	// compares monotonic readings rather than wall-clock values.
	cutoff := now.Add(-l.window)

	// Drop expired entries in place, reusing the backing array.
	kept := l.requests[:0]
	for _, t := range l.requests {
		if t.After(cutoff) {
			kept = append(kept, t)
		}
	}
	l.requests = kept

	if int64(len(l.requests)) >= l.capacity {
		return false
	}
	l.requests = append(l.requests, now)
	return true
}
// NewRateLimitMiddleware builds a RateLimitMiddleware with an empty limiter
// table and a five-minute idle-entry sweep interval. The sweep clock starts
// at construction time.
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
	m := &RateLimitMiddleware{cfg: cfg}
	m.limiters = make(map[string]*limiterEntry)
	m.cleanupInt = 5 * time.Minute
	m.lastCleanup = time.Now()
	return m
}
// Register returns the rate-limit handler for the registration endpoint:
// at most 10 requests per 60-second sliding window per user or client IP.
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 10
	)
	return m.limitForKey("register", windowSeconds, maxRequests)
}
// Login returns the rate-limit handler for the login endpoint:
// at most 5 requests per 60-second sliding window per user or client IP.
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 5
	)
	return m.limitForKey("login", windowSeconds, maxRequests)
}
// API returns the rate-limit handler for general API routes:
// at most 100 requests per 60-second sliding window per user or client IP.
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 100
	)
	return m.limitForKey("api", windowSeconds, maxRequests)
}
// Refresh returns the rate-limit handler for the token-refresh endpoint:
// at most 10 requests per 60-second sliding window per user or client IP.
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 10
	)
	return m.limitForKey("refresh", windowSeconds, maxRequests)
}
// limitForKey returns a gin handler that allows at most capacity requests per
// windowSeconds-second sliding window for each key that buildLimiterKey
// derives under scope. Over-limit requests are aborted with status 429 and a
// JSON body of the form {"code": 429, "message": ...}.
//
// If the DISABLE_RATE_LIMIT environment variable equals "1" when the handler
// is built, the returned handler only calls c.Next(). The variable is read
// once here, not per request.
func (m *RateLimitMiddleware) limitForKey(scope string, windowSeconds int, capacity int64) gin.HandlerFunc {
	disabled := os.Getenv("DISABLE_RATE_LIMIT") == "1"
	if disabled {
		return func(c *gin.Context) { c.Next() }
	}

	window := time.Second * time.Duration(windowSeconds)
	return func(c *gin.Context) {
		key := m.buildLimiterKey(scope, c)
		if m.getOrCreateLimiter(key, window, capacity).Allow() {
			c.Next()
			return
		}
		// AbortWithStatusJSON stops the handler chain and writes the body in one call.
		c.AbortWithStatusJSON(429, gin.H{
			"code":    429,
			"message": "请求过于频繁,请稍后再试",
		})
	}
}
// buildLimiterKey returns the limiter key for the request in c under scope.
// If a "user_id" value has been set on the context, the key is
// "<scope>:user:<user_id>". Otherwise it is "<scope>:ip:<client IP>".
func (m *RateLimitMiddleware) buildLimiterKey(scope string, c *gin.Context) string {
	userID, authenticated := c.Get("user_id")
	if !authenticated {
		return fmt.Sprintf("%s:ip:%s", scope, c.ClientIP())
	}
	return fmt.Sprintf("%s:user:%v", scope, userID)
}
// getOrCreateLimiter returns the limiter stored under key, creating it with
// the given window and capacity on first use. It stamps the entry's lastSeen
// with the current time so the idle sweep does not evict it.
//
// window and capacity only take effect when the entry is created; an existing
// limiter is returned unchanged.
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
	now := time.Now()
	m.maybeCleanup(now)

	// Both the hit path (writes lastSeen) and the miss path (inserts) mutate
	// shared state, so the exclusive lock is needed either way. Taking it once
	// avoids an RLock/RUnlock/Lock round trip on every request.
	m.mu.Lock()
	defer m.mu.Unlock()

	if entry, ok := m.limiters[key]; ok {
		entry.lastSeen = now
		return entry.limiter
	}

	entry := &limiterEntry{
		limiter:  NewSlidingWindowLimiter(window, capacity),
		window:   window,
		lastSeen: now,
	}
	m.limiters[key] = entry
	return entry.limiter
}
// maybeCleanup evicts idle limiter entries, running a sweep at most once per
// m.cleanupInt. An entry is evicted once it has gone unseen for longer than
// the larger of its own window and m.cleanupInt.
func (m *RateLimitMiddleware) maybeCleanup(now time.Time) {
	// Fast path: this is called on every request, so check whether a sweep is
	// due under the shared lock. Only take the exclusive lock when it is.
	m.mu.RLock()
	due := now.Sub(m.lastCleanup) >= m.cleanupInt
	m.mu.RUnlock()
	if !due {
		return
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	// Re-check: another goroutine may have swept between RUnlock and Lock.
	if now.Sub(m.lastCleanup) < m.cleanupInt {
		return
	}
	for key, entry := range m.limiters {
		idleTTL := entry.window
		if idleTTL < m.cleanupInt {
			idleTTL = m.cleanupInt
		}
		if now.Sub(entry.lastSeen) > idleTTL {
			delete(m.limiters, key) // deleting during range is permitted in Go
		}
	}
	m.lastCleanup = now
}