test: 增强 handler/middleware 测试覆盖并优化错误分类

测试增强:
- handler_test.go: 大幅增强 handler 集成测试(+1284/-98 行)
- theme_handler_test.go: 增强主题管理测试(+174/-22 行)
- auth_bootstrap_test.go: 新增 bootstrap 认证测试(+329 行)
- ratelimit_test.go: 新增限流中间件测试(+153 行)
- runtime_test.go: 新增运行时中间件测试(+351 行)

错误处理:
- auth_handler.go: classifyErrorMessage 增加 TOTP 错误码和 2FA 状态字分类

清理:
- 删除覆盖率报告残留文件(coverage_issue, handler, middleware 等)
- 归档 docs/superpowers/plans/2026-05-09-middleware-test-backfill-phase1.md
This commit is contained in:
2026-05-10 13:46:29 +08:00
parent f050c60a09
commit b77412b47f
8 changed files with 2205 additions and 34 deletions

View File

@@ -784,13 +784,17 @@ func classifyErrorMessage(msg string) int {
return http.StatusNotFound
case contains(lower, "already exists", "已存在", "已注册", "duplicate"):
return http.StatusConflict
case contains(lower, "验证码错误", "验证码或恢复码错误", "verification code", "recovery code"):
return http.StatusUnauthorized
case contains(lower, "unauthorized", "invalid token", "token", "令牌", "未认证"):
return http.StatusUnauthorized
case contains(lower, "forbidden", "permission", "权限", "禁止"):
return http.StatusForbidden
case contains(lower, "2fa 已", "2fa 未", "请先初始化 2fa", "已启用", "未启用"):
return http.StatusBadRequest
case contains(lower, "invalid", "required", "must", "cannot be empty", "不能为空",
"格式", "参数", "密码不正确", "incorrect", "wrong", "too short", "too long",
"已失效", "expired", "验证码不正确", "不能与"):
"已失效", "expired", "验证码不正确", "不能与", "不能删除自己", "不能删除最后一个管理员"):
return http.StatusBadRequest
case contains(lower, "locked", "too many", "账号已被锁定", "rate limit"):
return http.StatusTooManyRequests

File diff suppressed because it is too large Load Diff

View File

@@ -17,10 +17,6 @@ import (
"gorm.io/gorm/logger"
)
// =============================================================================
// Theme Handler Tests - TDD approach
// =============================================================================
func setupThemeTestEnv(t *testing.T) (*handler.ThemeHandler, *gorm.DB) {
t.Helper()
gin.SetMode(gin.TestMode)
@@ -45,10 +41,22 @@ func setupThemeTestEnv(t *testing.T) (*handler.ThemeHandler, *gorm.DB) {
return handler.NewThemeHandler(themeSvc), db
}
func createThemeForTest(t *testing.T, h *handler.ThemeHandler, body string) {
t.Helper()
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/v1/themes", bytes.NewReader([]byte(body)))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateTheme(c)
if w.Code != http.StatusCreated {
t.Fatalf("create theme failed: %d %s", w.Code, w.Body.String())
}
}
func TestThemeHandler_CreateTheme(t *testing.T) {
h, _ := setupThemeTestEnv(t)
t.Run("创建主题成功", func(t *testing.T) {
t.Run("create success", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"name":"test-theme","primary_color":"#1976d2"}`
@@ -58,20 +66,19 @@ func TestThemeHandler_CreateTheme(t *testing.T) {
h.CreateTheme(c)
if w.Code != http.StatusCreated {
t.Errorf("期望状态码 %d, 得到 %d", http.StatusCreated, w.Code)
t.Fatalf("expected status %d, got %d", http.StatusCreated, w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("解析响应失败: %v", err)
t.Fatalf("decode response failed: %v", err)
}
if resp["code"].(float64) != 0 {
t.Errorf("期望 code=0, 得到 %v", resp["code"])
t.Fatalf("expected code=0, got %v", resp["code"])
}
})
t.Run("创建主题失败-缺少名称", func(t *testing.T) {
t.Run("create missing name", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"primary_color":"#1976d2"}`
@@ -81,31 +88,30 @@ func TestThemeHandler_CreateTheme(t *testing.T) {
h.CreateTheme(c)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码 %d, 得到 %d", http.StatusBadRequest, w.Code)
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
}
func TestThemeHandler_ListThemes(t *testing.T) {
h, _ := setupThemeTestEnv(t)
createThemeForTest(t, h, `{"name":"list-theme","primary_color":"#1976d2"}`)
t.Run("获取主题列表", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/v1/themes", nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/v1/themes", nil)
h.ListThemes(c)
h.ListThemes(c)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 %d, 得到 %d", http.StatusOK, w.Code)
}
})
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
func TestThemeHandler_GetTheme(t *testing.T) {
h, _ := setupThemeTestEnv(t)
t.Run("获取主题失败-无效ID", func(t *testing.T) {
t.Run("get invalid id", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "invalid"}}
@@ -114,7 +120,70 @@ func TestThemeHandler_GetTheme(t *testing.T) {
h.GetTheme(c)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码 %d, 得到 %d", http.StatusBadRequest, w.Code)
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("get success", func(t *testing.T) {
createThemeForTest(t, h, `{"name":"get-theme","primary_color":"#1976d2"}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "1"}}
c.Request = httptest.NewRequest("GET", "/api/v1/themes/1", nil)
h.GetTheme(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d, body=%s", http.StatusOK, w.Code, w.Body.String())
}
})
}
func TestThemeHandler_UpdateTheme(t *testing.T) {
h, _ := setupThemeTestEnv(t)
createThemeForTest(t, h, `{"name":"theme-update","primary_color":"#111111"}`)
t.Run("update success", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "1"}}
body := `{"primary_color":"#222222","enabled":true}`
c.Request = httptest.NewRequest("PUT", "/api/v1/themes/1", bytes.NewReader([]byte(body)))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateTheme(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d, body=%s", http.StatusOK, w.Code, w.Body.String())
}
})
t.Run("update invalid id", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "invalid"}}
c.Request = httptest.NewRequest("PUT", "/api/v1/themes/invalid", bytes.NewReader([]byte(`{}`)))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateTheme(c)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("update invalid json", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "1"}}
c.Request = httptest.NewRequest("PUT", "/api/v1/themes/1", bytes.NewReader([]byte(`{"primary_color":`)))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateTheme(c)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
}
@@ -122,7 +191,7 @@ func TestThemeHandler_GetTheme(t *testing.T) {
func TestThemeHandler_DeleteTheme(t *testing.T) {
h, _ := setupThemeTestEnv(t)
t.Run("删除主题失败-无效ID", func(t *testing.T) {
t.Run("delete invalid id", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "invalid"}}
@@ -131,7 +200,90 @@ func TestThemeHandler_DeleteTheme(t *testing.T) {
h.DeleteTheme(c)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码 %d, 得到 %d", http.StatusBadRequest, w.Code)
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("delete success", func(t *testing.T) {
createThemeForTest(t, h, `{"name":"theme-delete","primary_color":"#1976d2"}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "1"}}
c.Request = httptest.NewRequest("DELETE", "/api/v1/themes/1", nil)
h.DeleteTheme(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d, body=%s", http.StatusOK, w.Code, w.Body.String())
}
})
}
func TestThemeHandler_DefaultAndActiveFlows(t *testing.T) {
h, _ := setupThemeTestEnv(t)
createThemeForTest(t, h, `{"name":"default-theme","primary_color":"#111111","is_default":true}`)
createThemeForTest(t, h, `{"name":"other-theme","primary_color":"#222222"}`)
t.Run("list all themes", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/v1/themes/all", nil)
h.ListAllThemes(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
})
t.Run("get default theme", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/v1/themes/default", nil)
h.GetDefaultTheme(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
})
t.Run("set default invalid id", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "bad"}}
c.Request = httptest.NewRequest("PUT", "/api/v1/themes/bad/default", nil)
h.SetDefaultTheme(c)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("set default success", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "2"}}
c.Request = httptest.NewRequest("PUT", "/api/v1/themes/2/default", nil)
h.SetDefaultTheme(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d, body=%s", http.StatusOK, w.Code, w.Body.String())
}
})
t.Run("get active theme", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/v1/themes/active", nil)
h.GetActiveTheme(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
})
}

View File

@@ -2,6 +2,7 @@ package middleware
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
@@ -19,6 +20,68 @@ import (
_ "modernc.org/sqlite"
)
type authStubUserRepo struct {
user *domain.User
err error
}
func (s authStubUserRepo) GetByID(_ context.Context, _ int64) (*domain.User, error) {
return s.user, s.err
}
type authStubUserRoleRepo struct {
roles []*domain.Role
perms []*domain.Permission
err error
}
func (s authStubUserRoleRepo) GetUserRolesAndPermissions(_ context.Context, _ int64) ([]*domain.Role, []*domain.Permission, error) {
return s.roles, s.perms, s.err
}
func newTestJWT(t *testing.T) *auth.JWT {
t.Helper()
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: "test-middleware-secret-at-least-32-chars",
AccessTokenExpire: 15 * time.Minute,
RefreshTokenExpire: 7 * 24 * time.Hour,
})
if err != nil {
t.Fatalf("create jwt manager failed: %v", err)
}
return jwtManager
}
func newAuthMiddlewareForTest(t *testing.T, user *domain.User, roles []*domain.Role, perms []*domain.Permission) (*AuthMiddleware, *auth.JWT, *cache.L1Cache) {
t.Helper()
jwtManager := newTestJWT(t)
l1Cache := cache.NewL1Cache()
middleware := NewAuthMiddleware(jwtManager, authStubUserRepo{user: user}, authStubUserRoleRepo{roles: roles, perms: perms}, l1Cache)
return middleware, jwtManager, l1Cache
}
func performMiddlewareRequest(t *testing.T, middleware gin.HandlerFunc, authHeader string) *httptest.ResponseRecorder {
t.Helper()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(middleware)
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"code": 0})
})
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
if authHeader != "" {
req.Header.Set("Authorization", authHeader)
}
router.ServeHTTP(recorder, req)
return recorder
}
func TestAuthMiddleware_AcceptsBootstrapAdminTokenImmediately(t *testing.T) {
t.Helper()
gin.SetMode(gin.TestMode)
@@ -101,3 +164,269 @@ func TestAuthMiddleware_AcceptsBootstrapAdminTokenImmediately(t *testing.T) {
t.Fatalf("expected bootstrap token to pass auth middleware immediately, got %d body: %s", recorder.Code, recorder.Body.String())
}
}
func TestAuthMiddleware_RequiredRejectsMissingToken(t *testing.T) {
middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)
recorder := performMiddlewareRequest(t, middleware.Required(), "")
if recorder.Code != http.StatusUnauthorized {
t.Fatalf("expected 401 for missing token, got %d", recorder.Code)
}
}
func TestAuthMiddleware_RequiredRejectsInvalidToken(t *testing.T) {
middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)
recorder := performMiddlewareRequest(t, middleware.Required(), "Bearer not-a-jwt")
if recorder.Code != http.StatusUnauthorized {
t.Fatalf("expected 401 for invalid token, got %d", recorder.Code)
}
}
func TestAuthMiddleware_RequiredRejectsBlacklistedToken(t *testing.T) {
user := &domain.User{ID: 7, Username: "alice", Status: domain.UserStatusActive}
middleware, jwtManager, l1Cache := newAuthMiddlewareForTest(t, user, nil, nil)
token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0)
if err != nil {
t.Fatalf("generate access token failed: %v", err)
}
claims, err := jwtManager.ValidateAccessToken(token)
if err != nil {
t.Fatalf("validate access token failed: %v", err)
}
l1Cache.Set("jwt_blacklist:"+claims.JTI, true, time.Minute)
recorder := performMiddlewareRequest(t, middleware.Required(), "Bearer "+token)
if recorder.Code != http.StatusUnauthorized {
t.Fatalf("expected 401 for blacklisted token, got %d", recorder.Code)
}
}
func TestAuthMiddleware_RequiredRejectsInactiveUser(t *testing.T) {
user := &domain.User{ID: 8, Username: "disabled", Status: domain.UserStatusDisabled}
middleware, jwtManager, _ := newAuthMiddlewareForTest(t, user, nil, nil)
token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0)
if err != nil {
t.Fatalf("generate access token failed: %v", err)
}
recorder := performMiddlewareRequest(t, middleware.Required(), "Bearer "+token)
if recorder.Code != http.StatusUnauthorized {
t.Fatalf("expected 401 for inactive user, got %d", recorder.Code)
}
}
func TestAuthMiddleware_RequiredInjectsIdentityAndAuthorizations(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &domain.User{ID: 9, Username: "admin", Status: domain.UserStatusActive}
roles := []*domain.Role{{Code: "admin"}, {Code: "auditor"}}
perms := []*domain.Permission{{Code: "users:read"}, {Code: "users:write"}}
middleware, jwtManager, _ := newAuthMiddlewareForTest(t, user, roles, perms)
token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0)
if err != nil {
t.Fatalf("generate access token failed: %v", err)
}
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(middleware.Required())
router.GET("/protected", func(c *gin.Context) {
if got := c.GetInt64("user_id"); got != user.ID {
t.Fatalf("user_id = %d, want %d", got, user.ID)
}
if got := c.GetString("username"); got != user.Username {
t.Fatalf("username = %q, want %q", got, user.Username)
}
roleCodes := GetRoleCodes(c)
if len(roleCodes) != 2 || roleCodes[0] != "admin" || roleCodes[1] != "auditor" {
t.Fatalf("unexpected role codes: %#v", roleCodes)
}
permCodes := GetPermissionCodes(c)
if len(permCodes) != 2 || permCodes[0] != "users:read" || permCodes[1] != "users:write" {
t.Fatalf("unexpected permission codes: %#v", permCodes)
}
c.JSON(http.StatusOK, gin.H{"code": 0})
})
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200 for valid token, got %d body: %s", recorder.Code, recorder.Body.String())
}
}
func TestAuthMiddleware_OptionalAllowsAnonymousRequest(t *testing.T) {
middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)
recorder := performMiddlewareRequest(t, middleware.Optional(), "")
if recorder.Code != http.StatusOK {
t.Fatalf("expected optional middleware to allow anonymous request, got %d", recorder.Code)
}
}
func TestAuthMiddleware_OptionalInjectsIdentityForValidToken(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &domain.User{ID: 21, Username: "optional-user", Status: domain.UserStatusActive}
roles := []*domain.Role{{Code: "viewer"}}
perms := []*domain.Permission{{Code: "users:read"}}
middleware, jwtManager, _ := newAuthMiddlewareForTest(t, user, roles, perms)
token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0)
if err != nil {
t.Fatalf("generate access token failed: %v", err)
}
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(middleware.Optional())
router.GET("/optional", func(c *gin.Context) {
if got := c.GetInt64("user_id"); got != user.ID {
t.Fatalf("user_id = %d, want %d", got, user.ID)
}
if got := c.GetString("username"); got != user.Username {
t.Fatalf("username = %q, want %q", got, user.Username)
}
if got := GetRoleCodes(c); len(got) != 1 || got[0] != "viewer" {
t.Fatalf("role_codes = %#v, want [viewer]", got)
}
if got := GetPermissionCodes(c); len(got) != 1 || got[0] != "users:read" {
t.Fatalf("permission_codes = %#v, want [users:read]", got)
}
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/optional", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected valid optional auth request to pass, got %d", recorder.Code)
}
}
func TestAuthMiddleware_ExtractTokenCases(t *testing.T) {
gin.SetMode(gin.TestMode)
middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)
testCases := []struct {
name string
header string
want string
}{
{name: "missing header", header: "", want: ""},
{name: "valid bearer", header: "Bearer abc.def", want: "abc.def"},
{name: "lowercase bearer rejected", header: "bearer abc", want: ""},
{name: "missing token value", header: "Bearer", want: ""},
{name: "wrong scheme", header: "Basic abc", want: ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
if tc.header != "" {
c.Request.Header.Set("Authorization", tc.header)
}
if got := middleware.extractToken(c); got != tc.want {
t.Fatalf("extractToken() = %q, want %q", got, tc.want)
}
})
}
}
func TestAuthMiddleware_ValidateUserStateAndCacheInvalidation(t *testing.T) {
user := &domain.User{
ID: 11,
Username: "cached-user",
Status: domain.UserStatusActive,
PasswordChangedAt: time.Unix(200, 0),
}
middleware, _, l1Cache := newAuthMiddlewareForTest(t, user, nil, nil)
if got := middleware.validateUserState(context.Background(), user.ID, 150); got == "" {
t.Fatal("expected password-changed denial for stale token")
}
if _, ok := l1Cache.Get("user_state:11"); !ok {
t.Fatal("expected user state to be cached")
}
middleware.InvalidateUserStateCache(user.ID)
if _, ok := l1Cache.Get("user_state:11"); ok {
t.Fatal("expected user state cache to be cleared")
}
}
func TestAuthMiddleware_LoadUserRolesAndPermsCachesAndInvalidates(t *testing.T) {
user := &domain.User{ID: 12, Username: "role-user", Status: domain.UserStatusActive}
roles := []*domain.Role{{Code: "admin"}}
perms := []*domain.Permission{{Code: "users:read"}}
middleware, _, l1Cache := newAuthMiddlewareForTest(t, user, roles, perms)
roleCodes, permCodes := middleware.loadUserRolesAndPerms(context.Background(), user.ID)
if len(roleCodes) != 1 || roleCodes[0] != "admin" {
t.Fatalf("unexpected role codes: %#v", roleCodes)
}
if len(permCodes) != 1 || permCodes[0] != "users:read" {
t.Fatalf("unexpected permission codes: %#v", permCodes)
}
if _, ok := l1Cache.Get("user_perms:12"); !ok {
t.Fatal("expected user permissions to be cached")
}
middleware.InvalidateUserPermCache(user.ID)
if _, ok := l1Cache.Get("user_perms:12"); ok {
t.Fatal("expected user permission cache to be cleared")
}
}
func TestAuthMiddleware_AddToBlacklistAndUserHelpers(t *testing.T) {
activeUser := &domain.User{ID: 13, Username: "active", Status: domain.UserStatusActive}
middleware, _, l1Cache := newAuthMiddlewareForTest(t, activeUser, nil, nil)
middleware.AddToBlacklist("jti-1", time.Minute)
if _, ok := l1Cache.Get("jwt_blacklist:jti-1"); !ok {
t.Fatal("expected blacklist entry in cache")
}
if !middleware.isUserActive(context.Background(), activeUser.ID) {
t.Fatal("expected active user to be active")
}
if middleware.isPasswordChangedSinceTokenIssued(context.Background(), activeUser.ID, 0) {
t.Fatal("expected zero token pce to skip password change check")
}
changedUser := &domain.User{
ID: 14,
Username: "changed",
Status: domain.UserStatusActive,
PasswordChangedAt: time.Unix(300, 0),
}
changedMiddleware, _, _ := newAuthMiddlewareForTest(t, changedUser, nil, nil)
if !changedMiddleware.isPasswordChangedSinceTokenIssued(context.Background(), changedUser.ID, 200) {
t.Fatal("expected password-changed helper to return true")
}
}
func TestAuthMiddleware_UserHelpersHandleRepoFailures(t *testing.T) {
middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)
middleware.userRepo = authStubUserRepo{err: errors.New("db down")}
if middleware.isUserActive(context.Background(), 99) {
t.Fatal("expected repo failure to mark user inactive")
}
if got := middleware.validateUserState(context.Background(), 99, 0); got == "" {
t.Fatal("expected validateUserState to deny on repo failure")
}
}

View File

@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/gin-gonic/gin"
@@ -138,3 +139,155 @@ func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshTokenBody(t *testing.T
t.Fatalf("request for refresh-token-b body after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
}
}
func TestExtractRefreshToken_PreservesRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
body := bytes.NewBufferString(`{"refresh_token":"refresh-token-a"}`)
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", body)
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = req
if got := extractRefreshToken(c); got != "refresh-token-a" {
t.Fatalf("extractRefreshToken() = %q, want refresh-token-a", got)
}
readBack := new(bytes.Buffer)
if _, err := readBack.ReadFrom(c.Request.Body); err != nil {
t.Fatalf("re-read body failed: %v", err)
}
if got := readBack.String(); got != `{"refresh_token":"refresh-token-a"}` {
t.Fatalf("request body after extraction = %q, want original JSON", got)
}
}
func TestRateLimitMiddleware_CleanupRemovesExpiredLimiters(t *testing.T) {
middleware := NewRateLimitMiddleware(config.RateLimitConfig{})
limiter := middleware.getOrCreateLimiter("login:ip:127.0.0.1", time.Millisecond, 1)
limiter.requests = []int64{time.Now().Add(-time.Second).UnixMilli()}
middleware.Cleanup()
if _, exists := middleware.limiters["login:ip:127.0.0.1"]; exists {
t.Fatal("expected expired limiter to be removed")
}
}
func TestRateLimitMiddleware_ResolveLimiterKeyPrefersUserIDForAPI(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/users/1", nil)
c.Params = gin.Params{{Key: "id", Value: "1"}}
c.Set("user_id", int64(99))
middleware := NewRateLimitMiddleware(config.RateLimitConfig{})
key := middleware.resolveLimiterKey(c, "api")
if key != "api:GET:/users/1:user:99" {
t.Fatalf("resolveLimiterKey() = %q, want api:GET:/users/1:user:99", key)
}
}
func TestSlidingWindowLimiter_EnforcesCapacityWithinWindow(t *testing.T) {
limiter := NewSlidingWindowLimiter(time.Second, 2)
if !limiter.Allow() {
t.Fatal("expected first request to pass")
}
if !limiter.Allow() {
t.Fatal("expected second request to pass")
}
if limiter.Allow() {
t.Fatal("expected third request to be rejected")
}
}
func TestRateLimitMiddleware_StartCleanupStopsSafely(t *testing.T) {
middleware := NewRateLimitMiddleware(config.RateLimitConfig{})
middleware.cleanupInt = 10 * time.Millisecond
stop := middleware.StartCleanup()
time.Sleep(25 * time.Millisecond)
stop()
}
func TestRateLimitMiddleware_ResolveLimiterKeyRefreshFallsBackToIP(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString(`{}`))
c.Request.RemoteAddr = "127.0.0.1:12345"
middleware := NewRateLimitMiddleware(config.RateLimitConfig{})
key := middleware.resolveLimiterKey(c, "refresh")
if key != "refresh:ip:127.0.0.1" {
t.Fatalf("resolveLimiterKey() = %q, want refresh:ip:127.0.0.1", key)
}
}
func TestFingerprintValue_IsDeterministic(t *testing.T) {
first := fingerprintValue("refresh-token-a")
second := fingerprintValue("refresh-token-a")
third := fingerprintValue("refresh-token-b")
if first != second {
t.Fatalf("expected same input fingerprint to match: %q vs %q", first, second)
}
if first == third {
t.Fatalf("expected different inputs to produce different fingerprints: %q vs %q", first, third)
}
}
func TestRateLimitMiddleware_RegisterAndLoginLimiters(t *testing.T) {
gin.SetMode(gin.TestMode)
middleware := NewRateLimitMiddleware(config.RateLimitConfig{})
router := gin.New()
router.POST("/register", middleware.Register(), func(c *gin.Context) {
c.Status(http.StatusOK)
})
router.POST("/login", middleware.Login(), func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 10; i++ {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/register", nil)
req.RemoteAddr = "127.0.0.1:12345"
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("register request %d returned %d, want %d", i+1, recorder.Code, http.StatusOK)
}
}
registerOverflow := httptest.NewRecorder()
registerReq := httptest.NewRequest(http.MethodPost, "/register", nil)
registerReq.RemoteAddr = "127.0.0.1:12345"
router.ServeHTTP(registerOverflow, registerReq)
if registerOverflow.Code != http.StatusTooManyRequests {
t.Fatalf("register overflow returned %d, want %d", registerOverflow.Code, http.StatusTooManyRequests)
}
for i := 0; i < 5; i++ {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/login", nil)
req.RemoteAddr = "127.0.0.1:54321"
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("login request %d returned %d, want %d", i+1, recorder.Code, http.StatusOK)
}
}
loginOverflow := httptest.NewRecorder()
loginReq := httptest.NewRequest(http.MethodPost, "/login", nil)
loginReq.RemoteAddr = "127.0.0.1:54321"
router.ServeHTTP(loginOverflow, loginReq)
if loginOverflow.Code != http.StatusTooManyRequests {
t.Fatalf("login overflow returned %d, want %d", loginOverflow.Code, http.StatusTooManyRequests)
}
}

View File

@@ -1,15 +1,21 @@
package middleware
import (
"bytes"
"encoding/json"
"errors"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
apierrors "github.com/user-management-system/internal/pkg/errors"
"github.com/user-management-system/internal/security"
)
func TestCORS_UsesConfiguredOrigins(t *testing.T) {
@@ -44,6 +50,31 @@ func TestCORS_UsesConfiguredOrigins(t *testing.T) {
}
}
func TestCORS_RejectsDisallowedOrigin(t *testing.T) {
gin.SetMode(gin.TestMode)
SetCORSConfig(config.CORSConfig{
AllowedOrigins: []string{"https://app.example.com"},
AllowCredentials: false,
})
t.Cleanup(func() {
SetCORSConfig(config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
})
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
CORS()(c)
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", recorder.Code)
}
}
func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
sanitized := sanitizeQuery(raw)
@@ -180,6 +211,23 @@ func TestTraceID_ExtractsExistingTraceID(t *testing.T) {
}
}
func TestTraceID_GetTraceIDHandlesMissingAndPresentValue(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
if got := GetTraceID(c); got != "" {
t.Fatalf("GetTraceID() = %q, want empty string", got)
}
c.Set(TraceIDKey, "trace-123")
if got := GetTraceID(c); got != "trace-123" {
t.Fatalf("GetTraceID() = %q, want trace-123", got)
}
}
// ---------- Error handling middleware ----------
func TestErrorHandler_HandlesErrors(t *testing.T) {
@@ -198,6 +246,35 @@ func TestErrorHandler_HandlesErrors(t *testing.T) {
}
}
func TestErrorHandler_ApplicationErrorPreservesStatusAndReason(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ErrorHandler())
router.GET("/users", func(c *gin.Context) {
_ = c.Error(apierrors.Forbidden("FORBIDDEN", "denied"))
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected status 403, got %d", recorder.Code)
}
var body map[string]any
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal body failed: %v", err)
}
if got := body["reason"]; got != "FORBIDDEN" {
t.Fatalf("reason = %#v, want FORBIDDEN", got)
}
if got := body["message"]; got != "denied" {
t.Fatalf("message = %#v, want denied", got)
}
}
func TestRecover_HandlesPanic(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -216,3 +293,277 @@ func TestRecover_HandlesPanic(t *testing.T) {
t.Fatalf("expected status 500 after panic, got %d", recorder.Code)
}
}
func TestRecover_ReturnsInternalServerErrorPayload(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(Recover())
router.GET("/panic", func(c *gin.Context) {
panic("boom")
})
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusInternalServerError {
t.Fatalf("expected status 500 after panic, got %d", recorder.Code)
}
var body map[string]any
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal body failed: %v", err)
}
if got := body["code"]; got != float64(http.StatusInternalServerError) {
t.Fatalf("code = %#v, want %d", got, http.StatusInternalServerError)
}
}
func TestLogger_WritesSanitizedQueryAndErrorContext(t *testing.T) {
gin.SetMode(gin.TestMode)
var buf bytes.Buffer
originalWriter := log.Writer()
log.SetOutput(&buf)
t.Cleanup(func() {
log.SetOutput(originalWriter)
})
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(TraceID())
router.Use(Logger())
router.GET("/users", func(c *gin.Context) {
c.Set("user_id", int64(7))
_ = c.Error(errors.New("boom"))
c.Status(http.StatusAccepted)
})
req := httptest.NewRequest(http.MethodGet, "/users?token=secret&name=alice", nil)
req.RemoteAddr = "203.0.113.5:1234"
req.Header.Set("User-Agent", "logger-test")
router.ServeHTTP(recorder, req)
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) && !strings.Contains(buf.String(), "[Query] /users?name=alice&token=%2A%2A%2A") {
time.Sleep(10 * time.Millisecond)
}
logOutput := buf.String()
if !strings.Contains(logOutput, "[API]") {
t.Fatalf("expected API log entry, got %q", logOutput)
}
if !strings.Contains(logOutput, "user_id: 7") {
t.Fatalf("expected user id in logs, got %q", logOutput)
}
if !strings.Contains(logOutput, "[Error]") || !strings.Contains(logOutput, "boom") {
t.Fatalf("expected error log entry, got %q", logOutput)
}
if strings.Contains(logOutput, "token=secret") {
t.Fatalf("expected sanitized query string, got %q", logOutput)
}
}
func TestLogger_DropsMalformedQueryString(t *testing.T) {
gin.SetMode(gin.TestMode)
var buf bytes.Buffer
originalWriter := log.Writer()
log.SetOutput(&buf)
t.Cleanup(func() {
log.SetOutput(originalWriter)
})
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(Logger())
router.GET("/users", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/users?bad=%zz", nil)
router.ServeHTTP(recorder, req)
time.Sleep(25 * time.Millisecond)
if strings.Contains(buf.String(), "[Query]") {
t.Fatalf("expected malformed query to be skipped, got %q", buf.String())
}
}
func TestResponseWrapper_SkipsSSEAndBinaryResponses(t *testing.T) {
gin.SetMode(gin.TestMode)
testCases := []struct {
name string
path string
contentType string
}{
{name: "sse", path: "/stream", contentType: "text/event-stream"},
{name: "binary", path: "/download", contentType: "application/octet-stream"},
{name: "swagger", path: "/swagger/index.html", contentType: ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ResponseWrapper())
router.GET(tc.path, func(c *gin.Context) {
c.Header("Content-Type", "application/json")
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
if tc.contentType != "" {
req.Header.Set("Content-Type", tc.contentType)
}
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
if got := recorder.Body.String(); got != `{"ok":true}` {
t.Fatalf("body = %s, want raw payload", got)
}
})
}
}
func TestResponseWrapper_BufferMethodsTrackStatusAndBody(t *testing.T) {
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &responseWrapper{
ResponseWriter: c.Writer,
body: bytes.NewBuffer(nil),
statusCode: http.StatusOK,
}
if _, err := wrapper.Write([]byte("abc")); err != nil {
t.Fatalf("Write() error = %v", err)
}
if _, err := wrapper.WriteString("def"); err != nil {
t.Fatalf("WriteString() error = %v", err)
}
wrapper.WriteHeader(http.StatusAccepted)
if got := wrapper.body.String(); got != "abcdef" {
t.Fatalf("buffered body = %q, want abcdef", got)
}
if wrapper.statusCode != http.StatusAccepted {
t.Fatalf("statusCode = %d, want %d", wrapper.statusCode, http.StatusAccepted)
}
}
func TestIPFilter_RealIPAndInternalOnly(t *testing.T) {
gin.SetMode(gin.TestMode)
filter := security.NewIPFilter()
middleware := NewIPFilterMiddleware(filter, IPFilterConfig{
TrustProxy: true,
TrustedProxies: []string{"10.0.0.2"},
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
c.Request.RemoteAddr = "10.0.0.2:8080"
c.Request.Header.Set("X-Forwarded-For", "198.51.100.10, 10.0.0.2")
if got := middleware.realIP(c); got != "198.51.100.10" {
t.Fatalf("realIP() = %q, want 198.51.100.10", got)
}
if !middleware.isTrustedProxy("10.0.0.2") {
t.Fatal("expected trusted proxy match")
}
if middleware.isTrustedProxy("10.0.0.3") {
t.Fatal("unexpected trusted proxy match")
}
if !isPrivateIP("127.0.0.1") {
t.Fatal("expected loopback to be private")
}
if isPrivateIP("198.51.100.10") {
t.Fatal("expected public address to be non-private")
}
allowed := httptest.NewRecorder()
allowedRouter := gin.New()
allowedRouter.Use(InternalOnly())
allowedRouter.GET("/metrics", func(c *gin.Context) {
c.Status(http.StatusOK)
})
allowedReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)
allowedReq.RemoteAddr = "127.0.0.1:12345"
allowedRouter.ServeHTTP(allowed, allowedReq)
if allowed.Code != http.StatusOK {
t.Fatalf("expected private IP to pass, got %d", allowed.Code)
}
blocked := httptest.NewRecorder()
blockedRouter := gin.New()
blockedRouter.Use(InternalOnly())
blockedRouter.GET("/metrics", func(c *gin.Context) {
c.Status(http.StatusOK)
})
blockedReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)
blockedReq.RemoteAddr = "198.51.100.10:12345"
blockedRouter.ServeHTTP(blocked, blockedReq)
if blocked.Code != http.StatusForbidden {
t.Fatalf("expected public IP to be rejected, got %d", blocked.Code)
}
}
func TestIPFilter_FilterAndFallbacks(t *testing.T) {
gin.SetMode(gin.TestMode)
filter := security.NewIPFilter()
if err := filter.AddToBlacklist("198.51.100.10", "manual", time.Minute); err != nil {
t.Fatalf("AddToBlacklist() error = %v", err)
}
middleware := NewIPFilterMiddleware(filter, IPFilterConfig{})
if middleware.GetFilter() != filter {
t.Fatal("expected GetFilter() to expose the original filter")
}
blockedRecorder := httptest.NewRecorder()
blockedRouter := gin.New()
blockedRouter.Use(middleware.Filter())
blockedRouter.GET("/protected", func(c *gin.Context) {
c.Status(http.StatusOK)
})
blockedReq := httptest.NewRequest(http.MethodGet, "/protected", nil)
blockedReq.RemoteAddr = "198.51.100.10:12345"
blockedRouter.ServeHTTP(blockedRecorder, blockedReq)
if blockedRecorder.Code != http.StatusForbidden {
t.Fatalf("expected blocked IP to be rejected, got %d", blockedRecorder.Code)
}
allowedRecorder := httptest.NewRecorder()
allowedRouter := gin.New()
allowedRouter.Use(middleware.Filter())
allowedRouter.GET("/protected", func(c *gin.Context) {
if got := c.GetString("client_ip"); got != "127.0.0.1" {
t.Fatalf("client_ip = %q, want 127.0.0.1", got)
}
c.Status(http.StatusOK)
})
allowedReq := httptest.NewRequest(http.MethodGet, "/protected", nil)
allowedReq.RemoteAddr = "127.0.0.1:54321"
allowedRouter.ServeHTTP(allowedRecorder, allowedReq)
if allowedRecorder.Code != http.StatusOK {
t.Fatalf("expected allowed IP to pass, got %d", allowedRecorder.Code)
}
trustedProxyMiddleware := NewIPFilterMiddleware(filter, IPFilterConfig{
TrustProxy: true,
})
proxyRecorder := httptest.NewRecorder()
proxyCtx, _ := gin.CreateTestContext(proxyRecorder)
proxyCtx.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
proxyCtx.Request.RemoteAddr = "10.0.0.2:8080"
proxyCtx.Request.Header.Set("X-Real-IP", "203.0.113.9")
if got := trustedProxyMiddleware.realIP(proxyCtx); got != "203.0.113.9" {
t.Fatalf("realIP() X-Real-IP fallback = %q, want 203.0.113.9", got)
}
}

View File

@@ -1,5 +1,3 @@
//go:build unit
package errors
import (

View File

@@ -1,5 +1,3 @@
//go:build unit
package ip
import (