Files
user-system/internal/api/middleware/ratelimit.go
long-agent 2a18a6fb47 fix(n+1): 批量查询替代循环单查
- IsAdminBootstrapRequired: userRepo.GetByID 循环 → GetByIDs 批量
- AssignRoles: roleRepo.GetByID 循环 → GetByIDs 批量
- 在 userRepositoryInterface 补充 GetByIDs 方法签名
2026-05-08 08:05:26 +08:00

246 lines
5.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
// RateLimitMiddleware builds gin handlers that throttle requests using
// in-memory sliding-window limiters, keeping one limiter per resolved key.
type RateLimitMiddleware struct {
	// cfg is the configuration handed to NewRateLimitMiddleware.
	// NOTE(review): nothing in this file reads it; the per-endpoint limits
	// are hard-coded in Register/Login/API/Refresh — confirm that is intended.
	cfg config.RateLimitConfig
	// limiters maps a limiter key to its limiter. Guarded by mu.
	limiters map[string]*SlidingWindowLimiter
	// mu protects limiters.
	mu sync.RWMutex
	// cleanupInt is the period between Cleanup runs scheduled by StartCleanup.
	cleanupInt time.Duration
}
// SlidingWindowLimiter admits at most capacity requests within any trailing
// window of the configured length. It is safe for concurrent use.
type SlidingWindowLimiter struct {
	// mu guards requests.
	mu sync.Mutex
	// window is the length of the trailing interval that is counted.
	window time.Duration
	// capacity is the maximum number of requests admitted per window.
	capacity int64
	// requests holds the Unix-millisecond timestamps of admitted requests
	// that were still inside the window at the last prune.
	requests []int64
}

// NewSlidingWindowLimiter returns a limiter that admits up to capacity
// requests per window. A capacity of zero or less rejects every request.
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
	// requests starts nil; append allocates it on the first admitted request.
	return &SlidingWindowLimiter{
		window:   window,
		capacity: capacity,
	}
}

// Allow reports whether a request arriving now fits in the current window
// and, if it does, records it. Timestamps that have fallen out of the window
// are discarded first.
func (l *SlidingWindowLimiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now().UnixMilli()
	cutoff := now - l.window.Milliseconds()

	// Prune in place, reusing the backing array. This runs on every request
	// while the lock is held, so it must not allocate; rejected requests in
	// particular cost no heap allocation at all.
	live := l.requests[:0]
	for _, ts := range l.requests {
		if ts > cutoff {
			live = append(live, ts)
		}
	}
	l.requests = live

	if int64(len(l.requests)) >= l.capacity {
		return false
	}
	l.requests = append(l.requests, now)
	return true
}
// NewRateLimitMiddleware returns a middleware with an empty limiter table and
// a five-minute cleanup interval. cfg is stored on the returned value.
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
	m := new(RateLimitMiddleware)
	m.cfg = cfg
	m.limiters = map[string]*SlidingWindowLimiter{}
	m.cleanupInt = 5 * time.Minute
	return m
}
// Register returns a handler for the "register" bucket: at most 10 requests
// per 60-second window for each limiter key.
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 10
	)
	return m.limitForKey("register", windowSeconds, maxRequests)
}
// Login returns a handler for the "login" bucket: at most 5 requests per
// 60-second window for each limiter key.
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 5
	)
	return m.limitForKey("login", windowSeconds, maxRequests)
}
// API returns a handler for the "api" bucket: at most 100 requests per
// 60-second window for each limiter key. For this bucket the key also
// includes the HTTP method and route, so each endpoint is counted separately.
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 100
	)
	return m.limitForKey("api", windowSeconds, maxRequests)
}
// Refresh returns a handler for the "refresh" bucket: at most 10 requests per
// 60-second window for each limiter key. When the request carries a refresh
// token, the key is derived from a fingerprint of that token.
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 10
	)
	return m.limitForKey("refresh", windowSeconds, maxRequests)
}
// limitForKey builds a gin handler that admits at most capacity requests per
// windowSeconds for each limiter key resolved within bucket. Requests over
// the limit are answered with HTTP 429 and a JSON body, and the handler chain
// is aborted; admitted requests continue to the next handler.
func (m *RateLimitMiddleware) limitForKey(bucket string, windowSeconds int, capacity int64) gin.HandlerFunc {
	window := time.Second * time.Duration(windowSeconds)

	handler := func(c *gin.Context) {
		key := m.resolveLimiterKey(c, bucket)
		if m.getOrCreateLimiter(key, window, capacity).Allow() {
			c.Next()
			return
		}

		// Over the limit: report it and stop the remaining handlers.
		c.JSON(429, gin.H{
			"code":    429,
			"message": "请求过于频繁,请稍后再试",
		})
		c.Abort()
	}
	return handler
}
// resolveLimiterKey derives the limiter key for a request in the given bucket.
//
//   - "refresh" requests that carry a refresh token are keyed by a fingerprint
//     of that token: "refresh:token:<fingerprint>".
//   - Otherwise the caller identity is "user:<user_id>" when the context holds
//     a "user_id" value, else "ip:<client ip>", else "anonymous".
//   - The "api" bucket yields "api:<method>:<route>:<identity>", preferring the
//     matched route pattern over the raw URL path.
//   - Every other bucket yields "<bucket>:<identity>".
//
// NOTE(review): a client that sends a different refresh token on each request
// gets a different key each time — confirm that is acceptable for the refresh
// endpoint.
func (m *RateLimitMiddleware) resolveLimiterKey(c *gin.Context, bucket string) string {
	if bucket == "refresh" {
		if token := extractRefreshToken(c); token != "" {
			return bucket + ":token:" + fingerprintValue(token)
		}
	}

	who := "anonymous"
	if c != nil {
		if uid, found := c.Get("user_id"); found {
			who = "user:" + fmt.Sprint(uid)
		} else if addr := c.ClientIP(); addr != "" {
			who = "ip:" + addr
		}
	}

	if bucket != "api" {
		return bucket + ":" + who
	}

	var method, route string
	if c != nil {
		if req := c.Request; req != nil {
			method = req.Method
			if req.URL != nil {
				route = req.URL.Path
			}
		}
		// The registered route pattern, when there is one, replaces the raw path.
		if pattern := c.FullPath(); pattern != "" {
			route = pattern
		}
	}
	return bucket + ":" + method + ":" + route + ":" + who
}
// extractRefreshToken returns the refresh token carried by the request, or ""
// when none is found.
//
// The "ums_refresh_token" cookie takes precedence. Otherwise the request body
// is decoded as JSON and its "refresh_token" field is returned. The body is
// always handed back to c.Request.Body so later handlers can read it in full.
// Bodies larger than 1 MiB are not inspected and yield "".
func extractRefreshToken(c *gin.Context) string {
	// Upper bound on how much of the body is buffered while looking for the
	// token. This code runs before the limiter has decided anything, so an
	// unbounded read would let any client force arbitrarily large allocations.
	const maxBodyBytes = 1 << 20

	if c == nil {
		return ""
	}
	if token, err := c.Cookie("ums_refresh_token"); err == nil && token != "" {
		return token
	}
	if c.Request == nil || c.Request.Body == nil {
		return ""
	}

	original := c.Request.Body
	// Read one byte past the limit so an oversized body can be told apart
	// from one that is exactly at the limit.
	body, err := io.ReadAll(io.LimitReader(original, maxBodyBytes+1))

	// Restore the body unconditionally — the bytes consumed here followed by
	// whatever is still unread — so a read error or an oversized payload does
	// not leave later handlers with a truncated stream.
	c.Request.Body = struct {
		io.Reader
		io.Closer
	}{
		Reader: io.MultiReader(bytes.NewReader(body), original),
		Closer: original,
	}

	if err != nil || len(body) > maxBodyBytes {
		return ""
	}
	if len(bytes.TrimSpace(body)) == 0 {
		return ""
	}

	var payload struct {
		RefreshToken string `json:"refresh_token"`
	}
	if err := json.Unmarshal(body, &payload); err != nil {
		return ""
	}
	return payload.RefreshToken
}
// fingerprintValue returns the first 12 bytes of the SHA-256 digest of value,
// hex-encoded (24 characters), so the raw value never appears in a limiter key.
func fingerprintValue(value string) string {
	const keep = 12

	digest := sha256.Sum256([]byte(value))
	out := make([]byte, hex.EncodedLen(keep))
	hex.Encode(out, digest[:keep])
	return string(out)
}
// getOrCreateLimiter returns the limiter registered under key, creating one
// with the given window and capacity when none exists yet. The lookup is done
// under the read lock first; only a miss takes the write lock, where the map
// is checked again in case another goroutine inserted the key in between.
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
	m.mu.RLock()
	existing, ok := m.limiters[key]
	m.mu.RUnlock()
	if ok {
		return existing
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	if raced, ok := m.limiters[key]; ok {
		return raced
	}
	created := NewSlidingWindowLimiter(window, capacity)
	m.limiters[key] = created
	return created
}
// Cleanup prunes expired timestamps from every limiter and removes limiters
// that have no request left inside their window, so the limiter map does not
// grow without bound.
func (m *RateLimitMiddleware) Cleanup() {
	m.mu.Lock()
	defer m.mu.Unlock()

	nowMs := time.Now().UnixMilli()
	for key, lim := range m.limiters {
		// Prune under the limiter's own lock and report whether it is idle.
		idle := func() bool {
			lim.mu.Lock()
			defer lim.mu.Unlock()

			oldest := nowMs - lim.window.Milliseconds()
			// Keep only the timestamps that are still inside the window.
			live := make([]int64, 0, len(lim.requests))
			for i := range lim.requests {
				if ts := lim.requests[i]; ts > oldest {
					live = append(live, ts)
				}
			}
			lim.requests = live
			return len(live) == 0
		}()

		if !idle {
			continue
		}
		delete(m.limiters, key)
	}
}
// StartCleanup launches a background goroutine that calls Cleanup once per
// cleanup interval and returns a function that stops it. The stop function
// may be called any number of times; only the first call has an effect.
func (m *RateLimitMiddleware) StartCleanup() func() {
	interval := m.cleanupInt
	if interval <= 0 {
		// time.NewTicker panics on a non-positive duration, which is what a
		// middleware not built by NewRateLimitMiddleware would carry. Fall
		// back to the same default the constructor uses.
		interval = 5 * time.Minute
	}

	ticker := time.NewTicker(interval)
	done := make(chan struct{})
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				m.Cleanup()
			case <-done:
				return
			}
		}
	}()

	// Closing an already-closed channel panics, so guard the close to make
	// the stop function idempotent.
	var once sync.Once
	return func() { once.Do(func() { close(done) }) }
}