测试增强: - handler_test.go: 大幅增强 handler 集成测试(+1284/-98 行) - theme_handler_test.go: 增强主题管理测试(+174/-22 行) - auth_bootstrap_test.go: 新增 bootstrap 认证测试(+329 行) - ratelimit_test.go: 新增限流中间件测试(+153 行) - runtime_test.go: 新增运行时中间件测试(+351 行) 错误处理: - auth_handler.go: classifyErrorMessage 增加 TOTP 错误码和 2FA 状态字分类 清理: - 删除覆盖率报告残留文件(coverage_issue, handler, middleware 等) - 归档 docs/superpowers/plans/2026-05-09-middleware-test-backfill-phase1.md
294 lines
10 KiB
Go
294 lines
10 KiB
Go
package middleware
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strconv"
	"testing"
	"time"

	"github.com/gin-gonic/gin"

	"github.com/user-management-system/internal/config"
)
|
|
|
|
// performRateLimitedRequest sends a GET for path through router and returns
// the recorded response. The request comes from the fixed client address
// 127.0.0.1:12345. userID is passed as a decimal string in the
// X-Test-User-ID header, which the test harness turns into "user_id".
func performRateLimitedRequest(router *gin.Engine, path string, userID int64) *httptest.ResponseRecorder {
	request := httptest.NewRequest(http.MethodGet, path, nil)
	request.RemoteAddr = "127.0.0.1:12345"
	request.Header.Set("X-Test-User-ID", strconv.FormatInt(userID, 10))

	w := httptest.NewRecorder()
	router.ServeHTTP(w, request)
	return w
}
|
|
|
|
// performRefreshRateLimitedRequestWithCookie POSTs an empty body to
// /auth/refresh from 127.0.0.1:12345 and returns the recorded response.
// A non-empty refreshToken is attached as the ums_refresh_token cookie.
// An empty refreshToken sends the request without any cookie.
func performRefreshRateLimitedRequestWithCookie(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
	request := httptest.NewRequest(http.MethodPost, "/auth/refresh", nil)
	request.RemoteAddr = "127.0.0.1:12345"
	if len(refreshToken) > 0 {
		request.AddCookie(&http.Cookie{
			Name:  "ums_refresh_token",
			Value: refreshToken,
		})
	}

	w := httptest.NewRecorder()
	router.ServeHTTP(w, request)
	return w
}
|
|
|
|
// performRefreshRateLimitedRequestWithBody POSTs a JSON body of the form
// {"refresh_token": refreshToken} to /auth/refresh and returns the recorded
// response. The request comes from 127.0.0.1:12345 with
// Content-Type: application/json.
//
// The body is encoded with encoding/json rather than string concatenation.
// Tokens containing quotes, backslashes or control characters therefore
// still produce valid JSON that decodes back to the same token. For plain
// tokens such as "refresh-token-a" the bytes are identical to
// `{"refresh_token":"refresh-token-a"}`.
func performRefreshRateLimitedRequestWithBody(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
	// Marshalling a struct with a single string field cannot return an error,
	// so the error is deliberately discarded.
	payload, _ := json.Marshal(struct {
		RefreshToken string `json:"refresh_token"`
	}{RefreshToken: refreshToken})

	recorder := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewReader(payload))
	req.RemoteAddr = "127.0.0.1:12345"
	req.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(recorder, req)
	return recorder
}
|
|
|
|
// TestRateLimitMiddleware_API_ScopesBudgetByRouteForAuthenticatedUser checks
// that the API limiter keeps a separate budget per route for one user.
// After 100 successful GETs to /users, the 101st is rejected with 429.
// The same user (ID 1) can still reach /roles afterwards.
func TestRateLimitMiddleware_API_ScopesBudgetByRouteForAuthenticatedUser(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})

	// Stand-in for the real auth middleware. It copies a numeric
	// X-Test-User-ID header into the "user_id" context key and silently
	// ignores missing or unparsable values.
	injectUser := func(c *gin.Context) {
		if raw := c.GetHeader("X-Test-User-ID"); raw != "" {
			if id, err := strconv.ParseInt(raw, 10, 64); err == nil {
				c.Set("user_id", id)
			}
		}
		c.Next()
	}
	ok := func(c *gin.Context) { c.Status(http.StatusOK) }

	engine := gin.New()
	engine.Use(injectUser)

	api := engine.Group("")
	api.Use(rl.API())
	api.GET("/users", ok)
	api.GET("/roles", ok)

	for n := 1; n <= 100; n++ {
		if code := performRateLimitedRequest(engine, "/users", 1).Code; code != http.StatusOK {
			t.Fatalf("request %d to /users returned %d, want %d", n, code, http.StatusOK)
		}
	}

	if code := performRateLimitedRequest(engine, "/users", 1).Code; code != http.StatusTooManyRequests {
		t.Fatalf("overflow request to /users returned %d, want %d", code, http.StatusTooManyRequests)
	}

	if code := performRateLimitedRequest(engine, "/roles", 1).Code; code != http.StatusOK {
		t.Fatalf("request to /roles after exhausting /users budget returned %d, want %d", code, http.StatusOK)
	}
}
|
|
|
|
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshCookie(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
|
router := gin.New()
|
|
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
for i := 0; i < 10; i++ {
|
|
recorder := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
|
|
if recorder.Code != http.StatusOK {
|
|
t.Fatalf("request %d for refresh-token-a returned %d, want %d", i+1, recorder.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
sameTokenOverflow := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
|
|
if sameTokenOverflow.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("overflow request for refresh-token-a returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
|
|
}
|
|
|
|
differentToken := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-b")
|
|
if differentToken.Code != http.StatusOK {
|
|
t.Fatalf("request for refresh-token-b after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshTokenBody(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
|
router := gin.New()
|
|
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
for i := 0; i < 10; i++ {
|
|
recorder := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
|
|
if recorder.Code != http.StatusOK {
|
|
t.Fatalf("request %d for refresh-token-a body returned %d, want %d", i+1, recorder.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
sameTokenOverflow := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
|
|
if sameTokenOverflow.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("overflow request for refresh-token-a body returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
|
|
}
|
|
|
|
differentToken := performRefreshRateLimitedRequestWithBody(router, "refresh-token-b")
|
|
if differentToken.Code != http.StatusOK {
|
|
t.Fatalf("request for refresh-token-b body after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestExtractRefreshToken_PreservesRequestBody(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
body := bytes.NewBufferString(`{"refresh_token":"refresh-token-a"}`)
|
|
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", body)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = req
|
|
|
|
if got := extractRefreshToken(c); got != "refresh-token-a" {
|
|
t.Fatalf("extractRefreshToken() = %q, want refresh-token-a", got)
|
|
}
|
|
|
|
readBack := new(bytes.Buffer)
|
|
if _, err := readBack.ReadFrom(c.Request.Body); err != nil {
|
|
t.Fatalf("re-read body failed: %v", err)
|
|
}
|
|
if got := readBack.String(); got != `{"refresh_token":"refresh-token-a"}` {
|
|
t.Fatalf("request body after extraction = %q, want original JSON", got)
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware_CleanupRemovesExpiredLimiters(t *testing.T) {
|
|
middleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
|
limiter := middleware.getOrCreateLimiter("login:ip:127.0.0.1", time.Millisecond, 1)
|
|
limiter.requests = []int64{time.Now().Add(-time.Second).UnixMilli()}
|
|
|
|
middleware.Cleanup()
|
|
|
|
if _, exists := middleware.limiters["login:ip:127.0.0.1"]; exists {
|
|
t.Fatal("expected expired limiter to be removed")
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware_ResolveLimiterKeyPrefersUserIDForAPI(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/users/1", nil)
|
|
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
|
c.Set("user_id", int64(99))
|
|
|
|
middleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
|
key := middleware.resolveLimiterKey(c, "api")
|
|
|
|
if key != "api:GET:/users/1:user:99" {
|
|
t.Fatalf("resolveLimiterKey() = %q, want api:GET:/users/1:user:99", key)
|
|
}
|
|
}
|
|
|
|
// TestSlidingWindowLimiter_EnforcesCapacityWithinWindow builds a limiter with
// NewSlidingWindowLimiter(time.Second, 2). It checks that two back-to-back
// Allow calls succeed and the third is rejected. The steps are evaluated in
// order and the test stops at the first mismatch.
func TestSlidingWindowLimiter_EnforcesCapacityWithinWindow(t *testing.T) {
	sw := NewSlidingWindowLimiter(time.Second, 2)

	steps := []struct {
		want bool
		msg  string
	}{
		{want: true, msg: "expected first request to pass"},
		{want: true, msg: "expected second request to pass"},
		{want: false, msg: "expected third request to be rejected"},
	}

	for _, step := range steps {
		if sw.Allow() != step.want {
			t.Fatal(step.msg)
		}
	}
}
|
|
|
|
// TestRateLimitMiddleware_StartCleanupStopsSafely is a smoke test with no
// assertions. It sets the cleanup interval to 10ms, calls StartCleanup,
// waits 25ms, then invokes the returned stop function. It passes as long as
// neither call panics or blocks.
func TestRateLimitMiddleware_StartCleanupStopsSafely(t *testing.T) {
	rl := NewRateLimitMiddleware(config.RateLimitConfig{})
	rl.cleanupInt = 10 * time.Millisecond

	stopCleanup := rl.StartCleanup()
	time.Sleep(25 * time.Millisecond)
	stopCleanup()
}
|
|
|
|
// TestRateLimitMiddleware_ResolveLimiterKeyRefreshFallsBackToIP checks the
// "refresh" limiter key when no token is supplied. The request has no cookie
// and a body of `{}`, so the key must fall back to the client IP:
// "refresh:ip:127.0.0.1".
func TestRateLimitMiddleware_ResolveLimiterKeyRefreshFallsBackToIP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	req := httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString(`{}`))
	req.RemoteAddr = "127.0.0.1:12345"

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = req

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})
	if key := rl.resolveLimiterKey(ctx, "refresh"); key != "refresh:ip:127.0.0.1" {
		t.Fatalf("resolveLimiterKey() = %q, want refresh:ip:127.0.0.1", key)
	}
}
|
|
|
|
// TestFingerprintValue_IsDeterministic checks that fingerprintValue returns
// the same result for repeated calls with the same input. It also checks
// that two different tokens produce different fingerprints.
func TestFingerprintValue_IsDeterministic(t *testing.T) {
	a1 := fingerprintValue("refresh-token-a")
	a2 := fingerprintValue("refresh-token-a")
	b := fingerprintValue("refresh-token-b")

	switch {
	case a1 != a2:
		t.Fatalf("expected same input fingerprint to match: %q vs %q", a1, a2)
	case a1 == b:
		t.Fatalf("expected different inputs to produce different fingerprints: %q vs %q", a1, b)
	}
}
|
|
|
|
// TestRateLimitMiddleware_RegisterAndLoginLimiters exercises the Register and
// Login limiters on one router, both from the 127.0.0.1 host.
// /register must allow ten POSTs and reject the eleventh with 429.
// /login must then allow five POSTs and reject the sixth with 429.
func TestRateLimitMiddleware_RegisterAndLoginLimiters(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})
	engine := gin.New()
	engine.POST("/register", rl.Register(), func(c *gin.Context) {
		c.Status(http.StatusOK)
	})
	engine.POST("/login", rl.Login(), func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	// post sends a body-less POST to path from remoteAddr through engine and
	// returns the response status code.
	post := func(path, remoteAddr string) int {
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, path, nil)
		r.RemoteAddr = remoteAddr
		engine.ServeHTTP(w, r)
		return w.Code
	}

	for n := 1; n <= 10; n++ {
		if code := post("/register", "127.0.0.1:12345"); code != http.StatusOK {
			t.Fatalf("register request %d returned %d, want %d", n, code, http.StatusOK)
		}
	}
	if code := post("/register", "127.0.0.1:12345"); code != http.StatusTooManyRequests {
		t.Fatalf("register overflow returned %d, want %d", code, http.StatusTooManyRequests)
	}

	for n := 1; n <= 5; n++ {
		if code := post("/login", "127.0.0.1:54321"); code != http.StatusOK {
			t.Fatalf("login request %d returned %d, want %d", n, code, http.StatusOK)
		}
	}
	if code := post("/login", "127.0.0.1:54321"); code != http.StatusTooManyRequests {
		t.Fatalf("login overflow returned %d, want %d", code, http.StatusTooManyRequests)
	}
}
|