Files
user-system/internal/api/middleware/ratelimit.go
long-agent 3f3bb82f1d fix: v6 code review P0 auth/IDOR fixes + frontend regression patches
Backend fixes:
- auth_handler: P0 认证逻辑修复
- ratelimit: 限速中间件增强 + 新增单元测试
- auth_service: 认证服务逻辑完善 + 新增测试
- server: server 配置增强 + 新增测试
- handler_test: 新增 handler 层集成测试
- auth_bootstrap_test: bootstrap 路径测试

Frontend patches:
- LoginPage/RegisterPage: CSRF + 表单交互修复
- BootstrapAdminPage: 引导流程修复
- DevicesPage: 设备管理页修复
- auth/social-accounts/users/webhooks services: 类型修正
- csrf.ts: CSRF token 处理修正
- E2E 脚本: CDP smoke + auth e2e 增强

Docs:
- FULL_CODE_REVIEW_REPORT_2026-04-20
- report-v6 执行计划
- REAL_PROJECT_STATUS 更新
- .gitignore: 新增 .gocache-*/config.yaml 排除

验证: go build/vet 0错误, go test 42/42 PASS, 0 FAIL
2026-04-23 07:14:12 +08:00

202 lines
4.3 KiB
Go

package middleware
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
// RateLimitMiddleware provides simple in-memory sliding-window rate limiting.
//
// It keeps one SlidingWindowLimiter per limiter key in a map. The Register,
// Login, API and Refresh handlers each count requests under their own bucket
// name, so their limits are independent.
type RateLimitMiddleware struct {
	// cfg is stored by NewRateLimitMiddleware.
	// NOTE(review): nothing in this file reads cfg; the windows and
	// capacities are hard-coded in Register/Login/API/Refresh. Confirm
	// whether config.RateLimitConfig was meant to supply them.
	cfg config.RateLimitConfig

	// limiters maps a limiter key (built by resolveLimiterKey) to its limiter.
	limiters map[string]*SlidingWindowLimiter

	// mu guards limiters.
	mu sync.RWMutex

	// cleanupInt is set to 5 minutes by NewRateLimitMiddleware.
	// NOTE(review): nothing in this file reads it, and no periodic cleanup
	// of limiters is visible here.
	cleanupInt time.Duration
}
// SlidingWindowLimiter admits at most capacity requests within any trailing
// window of time.
//
// requests holds the Unix-millisecond timestamps of admitted requests, in
// the order they were admitted. The zero value has capacity 0 and therefore
// rejects every request; construct instances with NewSlidingWindowLimiter.
type SlidingWindowLimiter struct {
	mu       sync.Mutex
	window   time.Duration
	capacity int64
	requests []int64
}

// NewSlidingWindowLimiter returns a limiter that admits at most capacity
// requests per window. A capacity of 0 or less rejects every request.
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
	return &SlidingWindowLimiter{
		window:   window,
		capacity: capacity,
		requests: make([]int64, 0),
	}
}

// Allow reports whether a request arriving now fits within the limit and,
// if it does, records it.
//
// Before counting, it drops timestamps at or before now-window, so a
// request exactly one window old no longer counts. It is safe for
// concurrent use.
func (l *SlidingWindowLimiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	// NOTE(review): UnixMilli is wall-clock time, not monotonic; a backwards
	// clock step can keep old entries counted for longer than window.
	now := time.Now().UnixMilli()
	cutoff := now - l.window.Milliseconds()

	// Compact unexpired timestamps in place. The write position never passes
	// the read position, so reusing the backing array is safe. This avoids
	// allocating a new slice on every request, including rejected ones.
	kept := l.requests[:0]
	for _, ts := range l.requests {
		if ts > cutoff {
			kept = append(kept, ts)
		}
	}
	l.requests = kept

	if int64(len(l.requests)) >= l.capacity {
		return false
	}
	l.requests = append(l.requests, now)
	return true
}
// NewRateLimitMiddleware returns a middleware with an empty limiter registry.
// It keeps cfg on the returned value and presets cleanupInt to five minutes.
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
	m := new(RateLimitMiddleware)
	m.cfg = cfg
	m.limiters = map[string]*SlidingWindowLimiter{}
	m.cleanupInt = 5 * time.Minute
	return m
}
// Register returns a handler that admits at most 10 requests per 60-second
// window for each limiter key in the "register" bucket.
//
// NOTE(review): the limits for all four buckets below are hard-coded; the
// middleware's stored config is not consulted.
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
	const windowSeconds, capacity = 60, 10
	return m.limitForKey("register", windowSeconds, capacity)
}

// Login returns a handler that admits at most 5 requests per 60-second
// window for each limiter key in the "login" bucket.
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
	const windowSeconds, capacity = 60, 5
	return m.limitForKey("login", windowSeconds, capacity)
}

// API returns a handler that admits at most 100 requests per 60-second
// window for each limiter key in the "api" bucket. Keys in this bucket
// also include the HTTP method and route.
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
	const windowSeconds, capacity = 60, 100
	return m.limitForKey("api", windowSeconds, capacity)
}

// Refresh returns a handler that admits at most 10 requests per 60-second
// window for each limiter key in the "refresh" bucket. When the request
// carries a refresh token, its fingerprint is the key.
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
	const windowSeconds, capacity = 60, 10
	return m.limitForKey("refresh", windowSeconds, capacity)
}
// limitForKey builds a gin handler that admits at most capacity requests
// per windowSeconds for each key resolveLimiterKey derives under bucket.
//
// An admitted request continues down the chain. A rejected one is aborted
// with status 429 and a JSON body carrying the same code and a message.
func (m *RateLimitMiddleware) limitForKey(bucket string, windowSeconds int, capacity int64) gin.HandlerFunc {
	window := time.Duration(windowSeconds) * time.Second
	return func(c *gin.Context) {
		key := m.resolveLimiterKey(c, bucket)
		if m.getOrCreateLimiter(key, window, capacity).Allow() {
			c.Next()
			return
		}
		c.AbortWithStatusJSON(429, gin.H{
			"code":    429,
			"message": "请求过于频繁,请稍后再试",
		})
	}
}
// resolveLimiterKey builds the key that selects which limiter a request is
// counted against.
//
// Key shapes:
//   - "refresh" bucket with a refresh token present: "refresh:token:<fingerprint>"
//   - "api" bucket: "api:<method>:<route>:<identity>"
//   - any other case: "<bucket>:<identity>"
//
// identity is "user:<user_id>" when a "user_id" value has been set on the
// gin context, otherwise "ip:<ClientIP>", falling back to "anonymous".
//
// NOTE(review): the refresh key is derived from a client-supplied token, so
// a client that varies the token gets a new limiter per value. Consider
// also limiting refresh by IP.
// NOTE(review): ClientIP may honour forwarding headers depending on gin's
// trusted-proxy configuration; confirm that is restricted in the server setup.
func (m *RateLimitMiddleware) resolveLimiterKey(c *gin.Context, bucket string) string {
	// Route label shared by all requests that matched no registered route.
	const unmatchedRoute = "<unmatched>"

	if bucket == "refresh" {
		if refreshToken := extractRefreshToken(c); refreshToken != "" {
			return fmt.Sprintf("%s:token:%s", bucket, fingerprintValue(refreshToken))
		}
	}

	identity := "anonymous"
	if c != nil {
		if userID, ok := c.Get("user_id"); ok {
			identity = fmt.Sprintf("user:%v", userID)
		} else if ip := c.ClientIP(); ip != "" {
			identity = "ip:" + ip
		}
	}

	if bucket != "api" {
		return fmt.Sprintf("%s:%s", bucket, identity)
	}

	// Scope API limits by the matched route pattern (gin's FullPath) and its
	// method. For requests that matched no route, the raw URL path and the
	// method are unconstrained client input. Keying on them would let a
	// client create a new limiter per distinct value, which evades the limit
	// and grows the limiter registry without bound. Such requests therefore
	// share one route label and an empty method.
	method, route := "", unmatchedRoute
	if c != nil {
		if fullPath := c.FullPath(); fullPath != "" {
			route = fullPath
			if c.Request != nil {
				method = c.Request.Method
			}
		}
	}
	return fmt.Sprintf("%s:%s:%s:%s", bucket, method, route, identity)
}
// extractRefreshToken returns the refresh token presented by the request,
// or "" if none is found.
//
// The "ums_refresh_token" cookie takes precedence. Otherwise the body is
// parsed as JSON and its "refresh_token" field is returned.
//
// This runs before the handler, so the request body is always restored.
// The handler reads the bytes consumed here followed by the unread
// remainder. At most maxBodyPeek bytes are buffered. Larger bodies are not
// parsed, so the rate limiter never loads an arbitrarily large upload into
// memory.
func extractRefreshToken(c *gin.Context) string {
	// Refresh payloads are expected to be tiny JSON documents.
	const maxBodyPeek = 64 << 10

	if c == nil {
		return ""
	}
	if refreshToken, err := c.Cookie("ums_refresh_token"); err == nil && refreshToken != "" {
		return refreshToken
	}
	if c.Request == nil || c.Request.Body == nil {
		return ""
	}

	original := c.Request.Body
	// Read one byte past the limit so an oversized body can be detected.
	peeked, err := io.ReadAll(io.LimitReader(original, maxBodyPeek+1))
	// Restore before any early return, including on a read error. Otherwise
	// the handler would see a body with its prefix missing. Close still
	// reaches the original body.
	c.Request.Body = struct {
		io.Reader
		io.Closer
	}{io.MultiReader(bytes.NewReader(peeked), original), original}
	if err != nil || len(peeked) > maxBodyPeek {
		return ""
	}
	if len(bytes.TrimSpace(peeked)) == 0 {
		return ""
	}

	var payload struct {
		RefreshToken string `json:"refresh_token"`
	}
	if err := json.Unmarshal(peeked, &payload); err != nil {
		return ""
	}
	return payload.RefreshToken
}
// fingerprintValue returns a short, stable identifier for value: the first
// 12 bytes of its SHA-256 digest, hex-encoded (24 characters). Using it in a
// limiter key avoids putting the raw value, such as a refresh token, into
// the key.
func fingerprintValue(value string) string {
	h := sha256.New()
	// Write on a hash.Hash never returns an error.
	h.Write([]byte(value))
	return hex.EncodeToString(h.Sum(nil)[:12])
}
// getOrCreateLimiter returns the limiter registered under key, creating one
// with window and capacity on first use. window and capacity apply only at
// creation; an existing limiter is returned unchanged.
//
// Lookups take the read lock. Creation re-checks under the write lock, so
// concurrent first requests for the same key share one limiter.
//
// The registry gains one entry per distinct key, and keys come from client
// IPs, user IDs and tokens. Without eviction it would only ever grow. So
// whenever a new key is inserted into a registry of at least pruneThreshold
// entries, up to pruneSample existing limiters are inspected. Any with no
// request inside their window is removed. The extra work per insertion is
// bounded by pruneSample limiters.
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
	const (
		pruneThreshold = 1024
		pruneSample    = 16
	)

	m.mu.RLock()
	limiter, exists := m.limiters[key]
	m.mu.RUnlock()
	if exists {
		return limiter
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	// Another goroutine may have created the limiter between RUnlock and Lock.
	if limiter, exists = m.limiters[key]; exists {
		return limiter
	}

	if len(m.limiters) >= pruneThreshold {
		now := time.Now().UnixMilli()
		inspected := 0
		// Go map iteration order is unspecified, and the runtime randomizes
		// where it starts. Successive insertions therefore inspect different
		// entries. Deleting during range is permitted by the language spec.
		for k, candidate := range m.limiters {
			if inspected == pruneSample {
				break
			}
			inspected++

			candidate.mu.Lock()
			cutoff := now - candidate.window.Milliseconds()
			idle := true
			for _, ts := range candidate.requests {
				if ts > cutoff {
					idle = false
					break
				}
			}
			candidate.mu.Unlock()

			if idle {
				// A caller that fetched this limiter just before removal may
				// still record requests on it. Its replacement will not see
				// those requests. That is acceptable for a limiter that was
				// empty when inspected.
				delete(m.limiters, k)
			}
		}
	}

	limiter = NewSlidingWindowLimiter(window, capacity)
	m.limiters[key] = limiter
	return limiter
}