fix: v6 code review P0 auth/IDOR fixes + frontend regression patches
Backend fixes: - auth_handler: P0 认证逻辑修复 - ratelimit: 限速中间件增强 + 新增单元测试 - auth_service: 认证服务逻辑完善 + 新增测试 - server: server 配置增强 + 新增测试 - handler_test: 新增 handler 层集成测试 - auth_bootstrap_test: bootstrap 路径测试 Frontend patches: - LoginPage/RegisterPage: CSRF + 表单交互修复 - BootstrapAdminPage: 引导流程修复 - DevicesPage: 设备管理页修复 - auth/social-accounts/users/webhooks services: 类型修正 - csrf.ts: CSRF token 处理修正 - E2E 脚本: CDP smoke + auth e2e 增强 Docs: - FULL_CODE_REVIEW_REPORT_2026-04-20 - report-v6 执行计划 - REAL_PROJECT_STATUS 更新 - .gitignore: 新增 .gocache-*/config.yaml 排除 验证: go build/vet 0错误, go test 42/42 PASS, 0 FAIL
This commit is contained in:
103
internal/api/middleware/auth_bootstrap_test.go
Normal file
103
internal/api/middleware/auth_bootstrap_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/service"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestAuthMiddleware_AcceptsBootstrapAdminTokenImmediately(t *testing.T) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: "file:middleware_bootstrap_test?mode=memory&cache=shared",
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite failed: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.User{}, &domain.Role{}, &domain.UserRole{}); err != nil {
|
||||
t.Fatalf("migrate failed: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Create(&domain.Role{
|
||||
Name: "管理员",
|
||||
Code: "admin",
|
||||
IsSystem: true,
|
||||
Status: domain.RoleStatusEnabled,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("seed admin role failed: %v", err)
|
||||
}
|
||||
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "test-bootstrap-token-secret-at-least-32-chars",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
l1Cache := cache.NewL1Cache()
|
||||
l2Cache := cache.NewRedisCache(false)
|
||||
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||||
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
roleRepo := repository.NewRoleRepository(db)
|
||||
userRoleRepo := repository.NewUserRoleRepository(db)
|
||||
|
||||
authService := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||
authService.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||
|
||||
loginResponse, err := authService.BootstrapAdmin(context.Background(), &service.BootstrapAdminRequest{
|
||||
Username: "bootstrap_admin",
|
||||
Email: "bootstrap_admin@example.com",
|
||||
Password: "AdminPass123!",
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("bootstrap admin failed: %v", err)
|
||||
}
|
||||
if loginResponse == nil || loginResponse.AccessToken == "" {
|
||||
t.Fatalf("expected bootstrap access token, got %+v", loginResponse)
|
||||
}
|
||||
|
||||
if _, err := jwtManager.ValidateAccessToken(loginResponse.AccessToken); err != nil {
|
||||
t.Fatalf("bootstrap access token should validate immediately: %v", err)
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, l1Cache)
|
||||
authMiddleware.SetCacheManager(cacheManager)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, engine := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
ctx.Request.Header.Set("Authorization", "Bearer "+loginResponse.AccessToken)
|
||||
|
||||
engine.Use(authMiddleware.Required())
|
||||
engine.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0})
|
||||
})
|
||||
|
||||
engine.ServeHTTP(recorder, ctx.Request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected bootstrap token to pass auth middleware immediately, got %d body: %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
// RateLimitMiddleware provides simple in-memory sliding-window rate limiting.
|
||||
type RateLimitMiddleware struct {
|
||||
cfg config.RateLimitConfig
|
||||
limiters map[string]*SlidingWindowLimiter
|
||||
@@ -16,7 +23,7 @@ type RateLimitMiddleware struct {
|
||||
cleanupInt time.Duration
|
||||
}
|
||||
|
||||
// SlidingWindowLimiter 滑动窗口限流器
|
||||
// SlidingWindowLimiter enforces a fixed-capacity sliding window.
|
||||
type SlidingWindowLimiter struct {
|
||||
mu sync.Mutex
|
||||
window time.Duration
|
||||
@@ -24,7 +31,6 @@ type SlidingWindowLimiter struct {
|
||||
requests []int64
|
||||
}
|
||||
|
||||
// NewSlidingWindowLimiter 创建滑动窗口限流器
|
||||
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
return &SlidingWindowLimiter{
|
||||
window: window,
|
||||
@@ -33,7 +39,6 @@ func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindo
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查是否允许请求
|
||||
func (l *SlidingWindowLimiter) Allow() bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
@@ -41,16 +46,14 @@ func (l *SlidingWindowLimiter) Allow() bool {
|
||||
now := time.Now().UnixMilli()
|
||||
cutoff := now - l.window.Milliseconds()
|
||||
|
||||
// 清理过期请求
|
||||
var validRequests []int64
|
||||
for _, t := range l.requests {
|
||||
if t > cutoff {
|
||||
validRequests = append(validRequests, t)
|
||||
validRequests := make([]int64, 0, len(l.requests))
|
||||
for _, ts := range l.requests {
|
||||
if ts > cutoff {
|
||||
validRequests = append(validRequests, ts)
|
||||
}
|
||||
}
|
||||
l.requests = validRequests
|
||||
|
||||
// 检查容量
|
||||
if int64(len(l.requests)) >= l.capacity {
|
||||
return false
|
||||
}
|
||||
@@ -59,7 +62,6 @@ func (l *SlidingWindowLimiter) Allow() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
cfg: cfg,
|
||||
@@ -68,30 +70,28 @@ func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
|
||||
}
|
||||
}
|
||||
|
||||
// Register 返回注册接口的限流中间件
|
||||
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
|
||||
return m.limitForKey("register", 60, 10)
|
||||
}
|
||||
|
||||
// Login 返回登录接口的限流中间件
|
||||
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
|
||||
return m.limitForKey("login", 60, 5)
|
||||
}
|
||||
|
||||
// API 返回 API 接口的限流中间件
|
||||
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
|
||||
return m.limitForKey("api", 60, 100)
|
||||
}
|
||||
|
||||
// Refresh 返回刷新令牌的限流中间件
|
||||
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
|
||||
return m.limitForKey("refresh", 60, 10)
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
|
||||
func (m *RateLimitMiddleware) limitForKey(bucket string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
window := time.Duration(windowSeconds) * time.Second
|
||||
|
||||
return func(c *gin.Context) {
|
||||
limiterKey := m.resolveLimiterKey(c, bucket)
|
||||
limiter := m.getOrCreateLimiter(limiterKey, window, capacity)
|
||||
if !limiter.Allow() {
|
||||
c.JSON(429, gin.H{
|
||||
"code": 429,
|
||||
@@ -104,6 +104,81 @@ func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacit
|
||||
}
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) resolveLimiterKey(c *gin.Context, bucket string) string {
|
||||
if bucket == "refresh" {
|
||||
if refreshToken := extractRefreshToken(c); refreshToken != "" {
|
||||
return fmt.Sprintf("%s:token:%s", bucket, fingerprintValue(refreshToken))
|
||||
}
|
||||
}
|
||||
|
||||
identity := "anonymous"
|
||||
if c != nil {
|
||||
if userID, ok := c.Get("user_id"); ok {
|
||||
identity = fmt.Sprintf("user:%v", userID)
|
||||
} else if ip := c.ClientIP(); ip != "" {
|
||||
identity = "ip:" + ip
|
||||
}
|
||||
}
|
||||
|
||||
if bucket == "api" {
|
||||
method := ""
|
||||
route := ""
|
||||
if c != nil {
|
||||
if c.Request != nil {
|
||||
method = c.Request.Method
|
||||
if c.Request.URL != nil {
|
||||
route = c.Request.URL.Path
|
||||
}
|
||||
}
|
||||
if fullPath := c.FullPath(); fullPath != "" {
|
||||
route = fullPath
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s:%s:%s:%s", bucket, method, route, identity)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%s", bucket, identity)
|
||||
}
|
||||
|
||||
func extractRefreshToken(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if refreshToken, err := c.Cookie("ums_refresh_token"); err == nil && refreshToken != "" {
|
||||
return refreshToken
|
||||
}
|
||||
|
||||
if c.Request == nil || c.Request.Body == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
if len(bytes.TrimSpace(body)) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return payload.RefreshToken
|
||||
}
|
||||
|
||||
func fingerprintValue(value string) string {
|
||||
sum := sha256.Sum256([]byte(value))
|
||||
return hex.EncodeToString(sum[:12])
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
m.mu.RLock()
|
||||
limiter, exists := m.limiters[key]
|
||||
@@ -116,7 +191,6 @@ func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duratio
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if limiter, exists = m.limiters[key]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
140
internal/api/middleware/ratelimit_test.go
Normal file
140
internal/api/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
func performRateLimitedRequest(router *gin.Engine, path string, userID int64) *httptest.ResponseRecorder {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
req.Header.Set("X-Test-User-ID", strconv.FormatInt(userID, 10))
|
||||
router.ServeHTTP(recorder, req)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func performRefreshRateLimitedRequestWithCookie(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", nil)
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
if refreshToken != "" {
|
||||
req.AddCookie(&http.Cookie{Name: "ums_refresh_token", Value: refreshToken})
|
||||
}
|
||||
router.ServeHTTP(recorder, req)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func performRefreshRateLimitedRequestWithBody(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
|
||||
recorder := httptest.NewRecorder()
|
||||
body := bytes.NewBufferString(`{"refresh_token":"` + refreshToken + `"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", body)
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(recorder, req)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_API_ScopesBudgetByRouteForAuthenticatedUser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
rawUserID := c.GetHeader("X-Test-User-ID")
|
||||
if rawUserID != "" {
|
||||
userID, err := strconv.ParseInt(rawUserID, 10, 64)
|
||||
if err == nil {
|
||||
c.Set("user_id", userID)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
|
||||
protected := router.Group("")
|
||||
protected.Use(rateLimitMiddleware.API())
|
||||
protected.GET("/users", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
protected.GET("/roles", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
recorder := performRateLimitedRequest(router, "/users", 1)
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("request %d to /users returned %d, want %d", i+1, recorder.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
sameRouteOverflow := performRateLimitedRequest(router, "/users", 1)
|
||||
if sameRouteOverflow.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("overflow request to /users returned %d, want %d", sameRouteOverflow.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
differentRoute := performRateLimitedRequest(router, "/roles", 1)
|
||||
if differentRoute.Code != http.StatusOK {
|
||||
t.Fatalf("request to /roles after exhausting /users budget returned %d, want %d", differentRoute.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshCookie(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
router := gin.New()
|
||||
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
recorder := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("request %d for refresh-token-a returned %d, want %d", i+1, recorder.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
sameTokenOverflow := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
|
||||
if sameTokenOverflow.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("overflow request for refresh-token-a returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
differentToken := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-b")
|
||||
if differentToken.Code != http.StatusOK {
|
||||
t.Fatalf("request for refresh-token-b after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshTokenBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
router := gin.New()
|
||||
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
recorder := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("request %d for refresh-token-a body returned %d, want %d", i+1, recorder.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
sameTokenOverflow := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
|
||||
if sameTokenOverflow.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("overflow request for refresh-token-a body returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
differentToken := performRefreshRateLimitedRequestWithBody(router, "refresh-token-b")
|
||||
if differentToken.Code != http.StatusOK {
|
||||
t.Fatalf("request for refresh-token-b body after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user