fix: v6 code review P0 auth/IDOR fixes + frontend regression patches

Backend fixes:
- auth_handler: P0 认证逻辑修复
- ratelimit: 限速中间件增强 + 新增单元测试
- auth_service: 认证服务逻辑完善 + 新增测试
- server: server 配置增强 + 新增测试
- handler_test: 新增 handler 层集成测试
- auth_bootstrap_test: bootstrap 路径测试

Frontend patches:
- LoginPage/RegisterPage: CSRF + 表单交互修复
- BootstrapAdminPage: 引导流程修复
- DevicesPage: 设备管理页修复
- auth/social-accounts/users/webhooks services: 类型修正
- csrf.ts: CSRF token 处理修正
- E2E 脚本: CDP smoke + auth e2e 增强

Docs:
- FULL_CODE_REVIEW_REPORT_2026-04-20
- report-v6 执行计划
- REAL_PROJECT_STATUS 更新
- .gitignore: 新增 .gocache-*/config.yaml 排除

验证: go build/vet 0错误, go test 42/42 PASS, 0 FAIL
This commit is contained in:
2026-04-23 07:14:12 +08:00
parent 82109ec216
commit 3f3bb82f1d
41 changed files with 2681 additions and 283 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/subtle"
"errors"
"io"
"net/http"
"os"
"strings"
@@ -15,6 +16,11 @@ import (
"github.com/user-management-system/internal/service"
)
const (
refreshTokenCookieName = "ums_refresh_token"
sessionPresenceCookieName = "ums_session_present"
)
// newBackgroundCtx 创建用于后台 goroutine 的带超时独立 context与请求 context 无关)
func newBackgroundCtx(timeoutSec int) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
@@ -129,6 +135,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
handleError(c, err)
return
}
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{
"code": 0,
@@ -150,20 +157,28 @@ func (h *AuthHandler) Login(c *gin.Context) {
// @Router /api/v1/auth/login/totp-verify [post]
func (h *AuthHandler) VerifyTOTPAfterPasswordLogin(c *gin.Context) {
var req struct {
UserID int64 `json:"user_id" binding:"required"`
Code string `json:"code" binding:"required"`
DeviceID string `json:"device_id"`
UserID int64 `json:"user_id" binding:"required"`
Code string `json:"code" binding:"required"`
DeviceID string `json:"device_id"`
TempToken string `json:"temp_token"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
return
}
resp, err := h.authService.VerifyTOTPAfterPasswordLogin(c.Request.Context(), req.UserID, req.Code, req.DeviceID)
resp, err := h.authService.VerifyTOTPAfterPasswordLogin(
c.Request.Context(),
req.UserID,
req.Code,
req.DeviceID,
req.TempToken,
)
if err != nil {
handleError(c, err)
return
}
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{
"code": 0,
@@ -197,6 +212,10 @@ func (h *AuthHandler) Logout(c *gin.Context) {
}
}
if req.RefreshToken == "" {
req.RefreshToken, _ = c.Cookie(refreshTokenCookieName)
}
username, _ := c.Get("username")
usernameStr, _ := username.(string)
@@ -206,6 +225,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
}
_ = h.authService.Logout(c.Request.Context(), usernameStr, logoutReq)
clearSessionCookies(c)
c.JSON(http.StatusOK, gin.H{"message": "logged out"})
}
@@ -222,19 +243,27 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// @Router /api/v1/auth/refresh-token [post]
func (h *AuthHandler) RefreshToken(c *gin.Context) {
var req struct {
RefreshToken string `json:"refresh_token" binding:"required"`
RefreshToken string `json:"refresh_token"`
}
if err := c.ShouldBindJSON(&req); err != nil {
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.RefreshToken == "" {
req.RefreshToken, _ = c.Cookie(refreshTokenCookieName)
}
if req.RefreshToken == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "refresh_token is required"})
return
}
resp, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken)
if err != nil {
handleError(c, err)
return
}
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{
"code": 0,
@@ -480,6 +509,7 @@ func (h *AuthHandler) LoginByEmailCode(c *gin.Context) {
h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq)
}()
}
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{
"code": 0,
@@ -544,6 +574,7 @@ func (h *AuthHandler) BootstrapAdmin(c *gin.Context) {
handleError(c, err)
return
}
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusCreated, gin.H{
"code": 0,
@@ -673,6 +704,46 @@ func getUserIDFromContext(c *gin.Context) (int64, bool) {
return id, ok
}
func setSessionCookies(c *gin.Context, authService *service.AuthService, refreshToken string) {
if c == nil || strings.TrimSpace(refreshToken) == "" {
return
}
maxAge := 0
if authService != nil {
if ttl := authService.RefreshTokenTTLSeconds(); ttl > 0 {
maxAge = int(ttl)
}
}
secure := requestUsesHTTPS(c)
c.SetSameSite(http.SameSiteLaxMode)
c.SetCookie(refreshTokenCookieName, refreshToken, maxAge, "/", "", secure, true)
c.SetCookie(sessionPresenceCookieName, "1", maxAge, "/", "", secure, false)
}
func clearSessionCookies(c *gin.Context) {
if c == nil {
return
}
secure := requestUsesHTTPS(c)
c.SetSameSite(http.SameSiteLaxMode)
c.SetCookie(refreshTokenCookieName, "", -1, "/", "", secure, true)
c.SetCookie(sessionPresenceCookieName, "", -1, "/", "", secure, false)
}
func requestUsesHTTPS(c *gin.Context) bool {
if c == nil || c.Request == nil {
return false
}
if c.Request.TLS != nil {
return true
}
return strings.EqualFold(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")), "https")
}
// handleError 将 error 转换为对应的 HTTP 响应。
// 优先识别 ApplicationError其次通过关键词推断业务错误类型兜底返回 500。
func handleError(c *gin.Context, err error) {

View File

@@ -31,6 +31,46 @@ import (
var handlerDbCounter int64
func seedHandlerAuthzData(t *testing.T, db *gorm.DB) {
t.Helper()
roleIDs := make(map[string]int64)
for _, predefined := range domain.PredefinedRoles {
role := predefined
if err := db.Create(&role).Error; err != nil {
t.Fatalf("seed role %s failed: %v", role.Code, err)
}
roleIDs[role.Code] = role.ID
}
permissionIDs := make(map[string]int64)
for _, predefined := range domain.DefaultPermissions() {
permission := predefined
if err := db.Create(&permission).Error; err != nil {
t.Fatalf("seed permission %s failed: %v", permission.Code, err)
}
permissionIDs[permission.Code] = permission.ID
}
adminRoleID := roleIDs["admin"]
for _, permissionID := range permissionIDs {
if err := db.Create(&domain.RolePermission{RoleID: adminRoleID, PermissionID: permissionID}).Error; err != nil {
t.Fatalf("assign admin permission %d failed: %v", permissionID, err)
}
}
userRoleID := roleIDs["user"]
for _, code := range []string{"profile:view", "profile:edit", "log:view_own"} {
permissionID, ok := permissionIDs[code]
if !ok {
t.Fatalf("seeded permissions missing %s", code)
}
if err := db.Create(&domain.RolePermission{RoleID: userRoleID, PermissionID: permissionID}).Error; err != nil {
t.Fatalf("assign user permission %s failed: %v", code, err)
}
}
}
func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
t.Helper()
gin.SetMode(gin.TestMode)
@@ -64,6 +104,8 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
t.Fatalf("db migration failed: %v", err)
}
seedHandlerAuthzData(t, db)
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: "test-handler-secret-key",
AccessTokenExpire: 15 * time.Minute,
@@ -176,6 +218,18 @@ func doDelete(url, token string) (*http.Response, string) {
return doRequest("DELETE", url, token, nil)
}
func getCookie(resp *http.Response, name string) *http.Cookie {
if resp == nil {
return nil
}
for _, cookie := range resp.Cookies() {
if cookie.Name == name {
return cookie
}
}
return nil
}
func getToken(baseURL, username, password string) string {
resp, body := doPost(baseURL+"/api/v1/auth/login", "", map[string]interface{}{
"account": username,
@@ -207,6 +261,111 @@ func registerUser(baseURL, username, email, password string) bool {
return resp.StatusCode == http.StatusCreated
}
func bootstrapAdmin(baseURL, secret, username, email, password string) string {
payload, _ := json.Marshal(map[string]interface{}{
"username": username,
"email": email,
"password": password,
})
req, _ := http.NewRequest(http.MethodPost, baseURL+"/api/v1/auth/bootstrap-admin", bytes.NewReader(payload))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Bootstrap-Secret", secret)
resp, err := (&http.Client{}).Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
return ""
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return ""
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return ""
}
data, ok := result["data"].(map[string]interface{})
if !ok || data["access_token"] == nil {
return ""
}
token, _ := data["access_token"].(string)
return token
}
func setupEnabledTOTPUser(t *testing.T, baseURL, username, email, password string) (int64, string) {
t.Helper()
if ok := registerUser(baseURL, username, email, password); !ok {
t.Fatalf("registration failed for %s", username)
}
token := getToken(baseURL, username, password)
if token == "" {
t.Fatalf("failed to get token for %s", username)
}
userInfoResp, userInfoBody := doGet(baseURL+"/api/v1/auth/userinfo", token)
defer userInfoResp.Body.Close()
if userInfoResp.StatusCode != http.StatusOK {
t.Fatalf("userinfo failed: status=%d body=%s", userInfoResp.StatusCode, userInfoBody)
}
var userInfoResult map[string]interface{}
if err := json.Unmarshal([]byte(userInfoBody), &userInfoResult); err != nil {
t.Fatalf("failed to parse userinfo response: %v", err)
}
userData, ok := userInfoResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("userinfo response missing data: %s", userInfoBody)
}
userID, ok := userData["id"].(float64)
if !ok {
t.Fatalf("userinfo response missing id: %s", userInfoBody)
}
setupResp, setupBody := doGet(baseURL+"/api/v1/auth/2fa/setup", token)
defer setupResp.Body.Close()
if setupResp.StatusCode != http.StatusOK {
t.Fatalf("2fa setup failed: status=%d body=%s", setupResp.StatusCode, setupBody)
}
var setupResult map[string]interface{}
if err := json.Unmarshal([]byte(setupBody), &setupResult); err != nil {
t.Fatalf("failed to parse 2fa setup response: %v", err)
}
setupData, ok := setupResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("2fa setup response missing data: %s", setupBody)
}
secret, ok := setupData["secret"].(string)
if !ok || secret == "" {
t.Fatalf("2fa setup response missing secret: %s", setupBody)
}
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
if err != nil {
t.Fatalf("failed to generate TOTP code: %v", err)
}
enableResp, enableBody := doPost(baseURL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{
"code": code,
})
defer enableResp.Body.Close()
if enableResp.StatusCode != http.StatusOK {
t.Fatalf("2fa enable failed: status=%d body=%s", enableResp.StatusCode, enableBody)
}
return int64(userID), secret
}
// =============================================================================
// Auth Handler Tests
// =============================================================================
@@ -292,6 +451,38 @@ func TestAuthHandler_Login_Success(t *testing.T) {
}
}
func TestAuthHandler_Login_SetsSessionCookies(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "logincookieuser", "logincookie@example.com", "Password123!")
resp, body := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "logincookieuser",
"password": "Password123!",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
refreshCookie := getCookie(resp, "ums_refresh_token")
if refreshCookie == nil || refreshCookie.Value == "" {
t.Fatalf("login response missing refresh cookie, cookies=%v", resp.Cookies())
}
if !refreshCookie.HttpOnly {
t.Fatalf("refresh cookie should be HttpOnly, got %+v", refreshCookie)
}
presenceCookie := getCookie(resp, "ums_session_present")
if presenceCookie == nil || presenceCookie.Value != "1" {
t.Fatalf("login response missing presence cookie, cookies=%v", resp.Cookies())
}
if presenceCookie.HttpOnly {
t.Fatalf("presence cookie should be readable by the frontend, got %+v", presenceCookie)
}
}
func TestAuthHandler_Login_WrongPassword(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
@@ -360,6 +551,66 @@ func TestAuthHandler_GetAuthCapabilities(t *testing.T) {
}
}
func TestAuthHandler_Login_WithTOTPEnabled_ReturnsChallengeToken(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
_, _ = setupEnabledTOTPUser(t, server.URL, "totplogin", "totplogin@example.com", "Password123!")
resp, body := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "totplogin",
"password": "Password123!",
"device_id": "device-login-1",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse login response: %v", err)
}
data, ok := result["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected login response data, got %s", body)
}
if data["requires_totp"] != true {
t.Fatalf("expected requires_totp=true, got %+v", data)
}
tempToken, ok := data["temp_token"].(string)
if !ok || tempToken == "" {
t.Fatalf("expected temp_token in TOTP challenge response, got %+v", data)
}
}
func TestAuthHandler_VerifyTOTPAfterPasswordLogin_RequiresTempToken(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpreverify", "totpreverify@example.com", "Password123!")
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
if err != nil {
t.Fatalf("failed to generate TOTP code: %v", err)
}
resp, body := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
"user_id": userID,
"code": code,
"device_id": "device-login-1",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected status %d when temp_token is missing, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
}
}
// =============================================================================
// User Handler Tests
// =============================================================================
@@ -451,6 +702,26 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) {
}
}
func TestUserHandler_UpdateUser_AdminCanUpdateAnotherUser(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "updateadmin", "updateadmin@test.com", "AdminPass123!")
registerUser(server.URL, "targetuser", "targetuser@test.com", "UserPass123!")
if token == "" {
t.Fatal("bootstrap admin should return access token")
}
resp, body := doPut(server.URL+"/api/v1/users/2", token, map[string]string{"nickname": "Updated By Admin"})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
@@ -515,6 +786,26 @@ func TestUserHandler_GetUserRoles_Success(t *testing.T) {
}
}
func TestUserHandler_GetUserRoles_AdminCanViewAnotherUser(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "rolesadmin2", "rolesadmin2@test.com", "AdminPass123!")
registerUser(server.URL, "roles-target", "roles-target@test.com", "UserPass123!")
if token == "" {
t.Fatal("bootstrap admin should return access token")
}
resp, body := doGet(server.URL+"/api/v1/users/2/roles", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestUserHandler_AssignRoles_RequiresAdmin(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
@@ -1253,6 +1544,187 @@ func TestAuthHandler_RefreshToken_Success(t *testing.T) {
}
}
func TestAuthHandler_RefreshToken_AcceptsRefreshCookie(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "refreshcookieuser", "refreshcookie@example.com", "Password123!")
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "refreshcookieuser",
"password": "Password123!",
})
defer loginResp.Body.Close()
if loginResp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody)
}
refreshCookie := getCookie(loginResp, "ums_refresh_token")
if refreshCookie == nil || refreshCookie.Value == "" {
t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies())
}
req, err := http.NewRequest("POST", server.URL+"/api/v1/auth/refresh", nil)
if err != nil {
t.Fatalf("create refresh request failed: %v", err)
}
req.AddCookie(refreshCookie)
req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"})
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("refresh request failed: %v", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read refresh response failed: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
}
rotatedCookie := getCookie(resp, "ums_refresh_token")
if rotatedCookie == nil || rotatedCookie.Value == "" {
t.Fatalf("refresh response missing rotated refresh cookie, cookies=%v", resp.Cookies())
}
if rotatedCookie.Value == refreshCookie.Value {
t.Fatalf("refresh should rotate cookie value, old=%q new=%q", refreshCookie.Value, rotatedCookie.Value)
}
presenceCookie := getCookie(resp, "ums_session_present")
if presenceCookie == nil || presenceCookie.Value != "1" {
t.Fatalf("refresh response missing presence cookie, cookies=%v", resp.Cookies())
}
}
func TestAuthHandler_RefreshToken_AllowsImmediateRetryWithPreviousCookie(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "refreshretryuser", "refreshretry@example.com", "Password123!")
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "refreshretryuser",
"password": "Password123!",
})
defer loginResp.Body.Close()
if loginResp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody)
}
refreshCookie := getCookie(loginResp, "ums_refresh_token")
if refreshCookie == nil || refreshCookie.Value == "" {
t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies())
}
newRefreshRequest := func(cookie *http.Cookie) *http.Response {
req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/auth/refresh", nil)
if err != nil {
t.Fatalf("create refresh request failed: %v", err)
}
req.AddCookie(cookie)
req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"})
resp, err := (&http.Client{}).Do(req)
if err != nil {
t.Fatalf("refresh request failed: %v", err)
}
return resp
}
firstResp := newRefreshRequest(refreshCookie)
defer firstResp.Body.Close()
firstBody, err := io.ReadAll(firstResp.Body)
if err != nil {
t.Fatalf("read first refresh response failed: %v", err)
}
if firstResp.StatusCode != http.StatusOK {
t.Fatalf("expected first refresh status %d, got %d, body: %s", http.StatusOK, firstResp.StatusCode, string(firstBody))
}
retryResp := newRefreshRequest(refreshCookie)
defer retryResp.Body.Close()
retryBody, err := io.ReadAll(retryResp.Body)
if err != nil {
t.Fatalf("read retry refresh response failed: %v", err)
}
if retryResp.StatusCode != http.StatusOK {
t.Fatalf("expected retry refresh status %d, got %d, body: %s", http.StatusOK, retryResp.StatusCode, string(retryBody))
}
}
func TestAuthHandler_Logout_ClearsSessionCookies(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "logoutcookieuser", "logoutcookie@example.com", "Password123!")
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "logoutcookieuser",
"password": "Password123!",
})
defer loginResp.Body.Close()
if loginResp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody)
}
var loginResult map[string]interface{}
if err := json.Unmarshal([]byte(loginBody), &loginResult); err != nil {
t.Fatalf("parse login response failed: %v", err)
}
loginData, ok := loginResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("login response missing data: %s", loginBody)
}
accessToken, ok := loginData["access_token"].(string)
if !ok || accessToken == "" {
t.Fatalf("login response missing access token: %s", loginBody)
}
refreshCookie := getCookie(loginResp, "ums_refresh_token")
if refreshCookie == nil || refreshCookie.Value == "" {
t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies())
}
req, err := http.NewRequest("POST", server.URL+"/api/v1/auth/logout", nil)
if err != nil {
t.Fatalf("create logout request failed: %v", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.AddCookie(refreshCookie)
req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"})
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("logout request failed: %v", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read logout response failed: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
}
clearedRefreshCookie := getCookie(resp, "ums_refresh_token")
if clearedRefreshCookie == nil || clearedRefreshCookie.Value != "" {
t.Fatalf("logout response should clear refresh cookie, cookies=%v", resp.Cookies())
}
clearedPresenceCookie := getCookie(resp, "ums_session_present")
if clearedPresenceCookie == nil || clearedPresenceCookie.Value != "" {
t.Fatalf("logout response should clear presence cookie, cookies=%v", resp.Cookies())
}
}
func TestAuthHandler_RefreshToken_InvalidToken(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()

View File

@@ -116,6 +116,7 @@ func (h *SMSHandler) LoginByCode(c *gin.Context) {
h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq)
}()
}
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{
"code": 0,

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service"
@@ -187,15 +188,7 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
// Authorization: only self or admin can update user profile
currentUserID := c.GetInt64("user_id")
isAdmin := false
if roles, ok := c.Get("user_roles"); ok {
for _, role := range roles.([]*domain.Role) {
if role.Code == "admin" {
isAdmin = true
break
}
}
}
isAdmin := middleware.IsAdmin(c)
if currentUserID != id && !isAdmin {
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
return
@@ -370,15 +363,7 @@ func (h *UserHandler) GetUserRoles(c *gin.Context) {
// Authorization: only self or admin can view user roles
currentUserID := c.GetInt64("user_id")
isAdmin := false
if roles, ok := c.Get("user_roles"); ok {
for _, role := range roles.([]*domain.Role) {
if role.Code == "admin" {
isAdmin = true
break
}
}
}
isAdmin := middleware.IsAdmin(c)
if currentUserID != id && !isAdmin {
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
return

View File

@@ -0,0 +1,103 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/service"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
)
func TestAuthMiddleware_AcceptsBootstrapAdminTokenImmediately(t *testing.T) {
t.Helper()
gin.SetMode(gin.TestMode)
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: "file:middleware_bootstrap_test?mode=memory&cache=shared",
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite failed: %v", err)
}
if err := db.AutoMigrate(&domain.User{}, &domain.Role{}, &domain.UserRole{}); err != nil {
t.Fatalf("migrate failed: %v", err)
}
if err := db.Create(&domain.Role{
Name: "管理员",
Code: "admin",
IsSystem: true,
Status: domain.RoleStatusEnabled,
}).Error; err != nil {
t.Fatalf("seed admin role failed: %v", err)
}
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: "test-bootstrap-token-secret-at-least-32-chars",
AccessTokenExpire: 15 * time.Minute,
RefreshTokenExpire: 7 * 24 * time.Hour,
})
if err != nil {
t.Fatalf("create jwt manager failed: %v", err)
}
l1Cache := cache.NewL1Cache()
l2Cache := cache.NewRedisCache(false)
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
userRepo := repository.NewUserRepository(db)
roleRepo := repository.NewRoleRepository(db)
userRoleRepo := repository.NewUserRoleRepository(db)
authService := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
authService.SetRoleRepositories(userRoleRepo, roleRepo)
loginResponse, err := authService.BootstrapAdmin(context.Background(), &service.BootstrapAdminRequest{
Username: "bootstrap_admin",
Email: "bootstrap_admin@example.com",
Password: "AdminPass123!",
}, "127.0.0.1")
if err != nil {
t.Fatalf("bootstrap admin failed: %v", err)
}
if loginResponse == nil || loginResponse.AccessToken == "" {
t.Fatalf("expected bootstrap access token, got %+v", loginResponse)
}
if _, err := jwtManager.ValidateAccessToken(loginResponse.AccessToken); err != nil {
t.Fatalf("bootstrap access token should validate immediately: %v", err)
}
authMiddleware := NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, l1Cache)
authMiddleware.SetCacheManager(cacheManager)
recorder := httptest.NewRecorder()
ctx, engine := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
ctx.Request.Header.Set("Authorization", "Bearer "+loginResponse.AccessToken)
engine.Use(authMiddleware.Required())
engine.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"code": 0})
})
engine.ServeHTTP(recorder, ctx.Request)
if recorder.Code != http.StatusOK {
t.Fatalf("expected bootstrap token to pass auth middleware immediately, got %d body: %s", recorder.Code, recorder.Body.String())
}
}

View File

@@ -1,14 +1,21 @@
package middleware
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
// RateLimitMiddleware 限流中间件
// RateLimitMiddleware provides simple in-memory sliding-window rate limiting.
type RateLimitMiddleware struct {
cfg config.RateLimitConfig
limiters map[string]*SlidingWindowLimiter
@@ -16,7 +23,7 @@ type RateLimitMiddleware struct {
cleanupInt time.Duration
}
// SlidingWindowLimiter 滑动窗口限流器
// SlidingWindowLimiter enforces a fixed-capacity sliding window.
type SlidingWindowLimiter struct {
mu sync.Mutex
window time.Duration
@@ -24,7 +31,6 @@ type SlidingWindowLimiter struct {
requests []int64
}
// NewSlidingWindowLimiter 创建滑动窗口限流器
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
return &SlidingWindowLimiter{
window: window,
@@ -33,7 +39,6 @@ func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindo
}
}
// Allow 检查是否允许请求
func (l *SlidingWindowLimiter) Allow() bool {
l.mu.Lock()
defer l.mu.Unlock()
@@ -41,16 +46,14 @@ func (l *SlidingWindowLimiter) Allow() bool {
now := time.Now().UnixMilli()
cutoff := now - l.window.Milliseconds()
// 清理过期请求
var validRequests []int64
for _, t := range l.requests {
if t > cutoff {
validRequests = append(validRequests, t)
validRequests := make([]int64, 0, len(l.requests))
for _, ts := range l.requests {
if ts > cutoff {
validRequests = append(validRequests, ts)
}
}
l.requests = validRequests
// 检查容量
if int64(len(l.requests)) >= l.capacity {
return false
}
@@ -59,7 +62,6 @@ func (l *SlidingWindowLimiter) Allow() bool {
return true
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
return &RateLimitMiddleware{
cfg: cfg,
@@ -68,30 +70,28 @@ func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
}
}
// Register 返回注册接口的限流中间件
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
return m.limitForKey("register", 60, 10)
}
// Login 返回登录接口的限流中间件
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
return m.limitForKey("login", 60, 5)
}
// API 返回 API 接口的限流中间件
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
return m.limitForKey("api", 60, 100)
}
// Refresh 返回刷新令牌的限流中间件
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
return m.limitForKey("refresh", 60, 10)
}
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
func (m *RateLimitMiddleware) limitForKey(bucket string, windowSeconds int, capacity int64) gin.HandlerFunc {
window := time.Duration(windowSeconds) * time.Second
return func(c *gin.Context) {
limiterKey := m.resolveLimiterKey(c, bucket)
limiter := m.getOrCreateLimiter(limiterKey, window, capacity)
if !limiter.Allow() {
c.JSON(429, gin.H{
"code": 429,
@@ -104,6 +104,81 @@ func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacit
}
}
func (m *RateLimitMiddleware) resolveLimiterKey(c *gin.Context, bucket string) string {
if bucket == "refresh" {
if refreshToken := extractRefreshToken(c); refreshToken != "" {
return fmt.Sprintf("%s:token:%s", bucket, fingerprintValue(refreshToken))
}
}
identity := "anonymous"
if c != nil {
if userID, ok := c.Get("user_id"); ok {
identity = fmt.Sprintf("user:%v", userID)
} else if ip := c.ClientIP(); ip != "" {
identity = "ip:" + ip
}
}
if bucket == "api" {
method := ""
route := ""
if c != nil {
if c.Request != nil {
method = c.Request.Method
if c.Request.URL != nil {
route = c.Request.URL.Path
}
}
if fullPath := c.FullPath(); fullPath != "" {
route = fullPath
}
}
return fmt.Sprintf("%s:%s:%s:%s", bucket, method, route, identity)
}
return fmt.Sprintf("%s:%s", bucket, identity)
}
func extractRefreshToken(c *gin.Context) string {
if c == nil {
return ""
}
if refreshToken, err := c.Cookie("ums_refresh_token"); err == nil && refreshToken != "" {
return refreshToken
}
if c.Request == nil || c.Request.Body == nil {
return ""
}
body, err := io.ReadAll(c.Request.Body)
if err != nil {
return ""
}
c.Request.Body = io.NopCloser(bytes.NewReader(body))
if len(bytes.TrimSpace(body)) == 0 {
return ""
}
var payload struct {
RefreshToken string `json:"refresh_token"`
}
if err := json.Unmarshal(body, &payload); err != nil {
return ""
}
return payload.RefreshToken
}
func fingerprintValue(value string) string {
sum := sha256.Sum256([]byte(value))
return hex.EncodeToString(sum[:12])
}
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
m.mu.RLock()
limiter, exists := m.limiters[key]
@@ -116,7 +191,6 @@ func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duratio
m.mu.Lock()
defer m.mu.Unlock()
// 双重检查
if limiter, exists = m.limiters[key]; exists {
return limiter
}

View File

@@ -0,0 +1,140 @@
package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
func performRateLimitedRequest(router *gin.Engine, path string, userID int64) *httptest.ResponseRecorder {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, path, nil)
req.RemoteAddr = "127.0.0.1:12345"
req.Header.Set("X-Test-User-ID", strconv.FormatInt(userID, 10))
router.ServeHTTP(recorder, req)
return recorder
}
func performRefreshRateLimitedRequestWithCookie(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", nil)
req.RemoteAddr = "127.0.0.1:12345"
if refreshToken != "" {
req.AddCookie(&http.Cookie{Name: "ums_refresh_token", Value: refreshToken})
}
router.ServeHTTP(recorder, req)
return recorder
}
func performRefreshRateLimitedRequestWithBody(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
recorder := httptest.NewRecorder()
body := bytes.NewBufferString(`{"refresh_token":"` + refreshToken + `"}`)
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", body)
req.RemoteAddr = "127.0.0.1:12345"
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
return recorder
}
func TestRateLimitMiddleware_API_ScopesBudgetByRouteForAuthenticatedUser(t *testing.T) {
gin.SetMode(gin.TestMode)
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
router := gin.New()
router.Use(func(c *gin.Context) {
rawUserID := c.GetHeader("X-Test-User-ID")
if rawUserID != "" {
userID, err := strconv.ParseInt(rawUserID, 10, 64)
if err == nil {
c.Set("user_id", userID)
}
}
c.Next()
})
protected := router.Group("")
protected.Use(rateLimitMiddleware.API())
protected.GET("/users", func(c *gin.Context) {
c.Status(http.StatusOK)
})
protected.GET("/roles", func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 100; i++ {
recorder := performRateLimitedRequest(router, "/users", 1)
if recorder.Code != http.StatusOK {
t.Fatalf("request %d to /users returned %d, want %d", i+1, recorder.Code, http.StatusOK)
}
}
sameRouteOverflow := performRateLimitedRequest(router, "/users", 1)
if sameRouteOverflow.Code != http.StatusTooManyRequests {
t.Fatalf("overflow request to /users returned %d, want %d", sameRouteOverflow.Code, http.StatusTooManyRequests)
}
differentRoute := performRateLimitedRequest(router, "/roles", 1)
if differentRoute.Code != http.StatusOK {
t.Fatalf("request to /roles after exhausting /users budget returned %d, want %d", differentRoute.Code, http.StatusOK)
}
}
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshCookie(t *testing.T) {
gin.SetMode(gin.TestMode)
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
router := gin.New()
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 10; i++ {
recorder := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
if recorder.Code != http.StatusOK {
t.Fatalf("request %d for refresh-token-a returned %d, want %d", i+1, recorder.Code, http.StatusOK)
}
}
sameTokenOverflow := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
if sameTokenOverflow.Code != http.StatusTooManyRequests {
t.Fatalf("overflow request for refresh-token-a returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
}
differentToken := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-b")
if differentToken.Code != http.StatusOK {
t.Fatalf("request for refresh-token-b after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
}
}
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshTokenBody(t *testing.T) {
gin.SetMode(gin.TestMode)
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
router := gin.New()
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 10; i++ {
recorder := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
if recorder.Code != http.StatusOK {
t.Fatalf("request %d for refresh-token-a body returned %d, want %d", i+1, recorder.Code, http.StatusOK)
}
}
sameTokenOverflow := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
if sameTokenOverflow.Code != http.StatusTooManyRequests {
t.Fatalf("overflow request for refresh-token-a body returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
}
differentToken := performRefreshRateLimitedRequestWithBody(router, "refresh-token-b")
if differentToken.Code != http.StatusOK {
t.Fatalf("request for refresh-token-b body after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
}
}