fix: close auth, permission, contract and e2e review blockers
This commit is contained in:
@@ -30,11 +30,74 @@ type AuthHandler struct {
|
||||
authService *service.AuthService
|
||||
}
|
||||
|
||||
const (
|
||||
refreshTokenCookieName = "ums_refresh_token"
|
||||
sessionPresenceCookieName = "ums_session_present"
|
||||
)
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
|
||||
return &AuthHandler{authService: authService}
|
||||
}
|
||||
|
||||
func isSecureRequest(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil {
|
||||
return false
|
||||
}
|
||||
if c.Request.TLS != nil {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(c.GetHeader("X-Forwarded-Proto"), "https")
|
||||
}
|
||||
|
||||
func (h *AuthHandler) setSessionCookies(c *gin.Context, resp *service.LoginResponse) {
|
||||
if c == nil || resp == nil || strings.TrimSpace(resp.RefreshToken) == "" || h == nil || h.authService == nil {
|
||||
return
|
||||
}
|
||||
|
||||
maxAge := int(h.authService.RefreshTokenTTLSeconds())
|
||||
secure := isSecureRequest(c)
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: refreshTokenCookieName,
|
||||
Value: resp.RefreshToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: maxAge,
|
||||
})
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: sessionPresenceCookieName,
|
||||
Value: "1",
|
||||
Path: "/",
|
||||
HttpOnly: false,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: maxAge,
|
||||
})
|
||||
}
|
||||
|
||||
func clearCookie(c *gin.Context, name string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: name == refreshTokenCookieName,
|
||||
Secure: isSecureRequest(c),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: -1,
|
||||
Expires: time.Unix(0, 0),
|
||||
})
|
||||
}
|
||||
|
||||
func clearSessionCookies(c *gin.Context) {
|
||||
clearCookie(c, refreshTokenCookieName)
|
||||
clearCookie(c, sessionPresenceCookieName)
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
// @Summary 用户注册
|
||||
// @Description 用户注册新账号,支持用户名+密码或手机号注册
|
||||
@@ -130,6 +193,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -150,21 +214,23 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
// @Router /api/v1/auth/login/totp-verify [post]
|
||||
func (h *AuthHandler) VerifyTOTPAfterPasswordLogin(c *gin.Context) {
|
||||
var req struct {
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
DeviceID string `json:"device_id"`
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
DeviceID string `json:"device_id"`
|
||||
TempToken string `json:"temp_token" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.authService.VerifyTOTPAfterPasswordLogin(c.Request.Context(), req.UserID, req.Code, req.DeviceID)
|
||||
resp, err := h.authService.VerifyTOTPAfterPasswordLogin(c.Request.Context(), req.UserID, req.Code, req.DeviceID, req.TempToken)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -197,6 +263,12 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if req.RefreshToken == "" {
|
||||
if cookie, err := c.Request.Cookie(refreshTokenCookieName); err == nil {
|
||||
req.RefreshToken = cookie.Value
|
||||
}
|
||||
}
|
||||
|
||||
username, _ := c.Get("username")
|
||||
usernameStr, _ := username.(string)
|
||||
|
||||
@@ -204,7 +276,11 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
AccessToken: req.AccessToken,
|
||||
RefreshToken: req.RefreshToken,
|
||||
}
|
||||
_ = h.authService.Logout(c.Request.Context(), usernameStr, logoutReq)
|
||||
if err := h.authService.Logout(c.Request.Context(), usernameStr, logoutReq); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
clearSessionCookies(c)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "logged out"})
|
||||
}
|
||||
@@ -222,20 +298,28 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
// @Router /api/v1/auth/refresh-token [post]
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
if strings.TrimSpace(req.RefreshToken) == "" {
|
||||
if cookie, err := c.Request.Cookie(refreshTokenCookieName); err == nil {
|
||||
req.RefreshToken = cookie.Value
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(req.RefreshToken) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "refresh_token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
clearSessionCookies(c)
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -315,7 +399,7 @@ func (h *AuthHandler) GetAuthCapabilities(c *gin.Context) {
|
||||
// @Router /api/v1/auth/oauth/{provider} [get]
|
||||
func (h *AuthHandler) OAuthLogin(c *gin.Context) {
|
||||
provider := c.Param("provider")
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "OAuth not configured", "data": gin.H{"provider": provider}})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "OAuth login is not configured", "data": gin.H{"provider": provider}})
|
||||
}
|
||||
|
||||
// OAuthCallback OAuth回调
|
||||
@@ -327,7 +411,7 @@ func (h *AuthHandler) OAuthLogin(c *gin.Context) {
|
||||
// @Success 200 {object} Response "OAuth未配置"
|
||||
// @Router /api/v1/auth/oauth/{provider}/callback [get]
|
||||
func (h *AuthHandler) OAuthCallback(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "OAuth not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "OAuth callback is not configured"})
|
||||
}
|
||||
|
||||
// OAuthExchange OAuth令牌交换
|
||||
@@ -340,7 +424,7 @@ func (h *AuthHandler) OAuthCallback(c *gin.Context) {
|
||||
// @Success 200 {object} Response "OAuth未配置"
|
||||
// @Router /api/v1/auth/oauth/{provider}/exchange [post]
|
||||
func (h *AuthHandler) OAuthExchange(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "OAuth not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "OAuth exchange is not configured"})
|
||||
}
|
||||
|
||||
// GetEnabledOAuthProviders 获取已启用的OAuth提供商
|
||||
@@ -481,6 +565,7 @@ func (h *AuthHandler) LoginByEmailCode(c *gin.Context) {
|
||||
}()
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -545,6 +630,7 @@ func (h *AuthHandler) BootstrapAdmin(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -561,7 +647,7 @@ func (h *AuthHandler) BootstrapAdmin(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/email/bind/send [post]
|
||||
func (h *AuthHandler) SendEmailBindCode(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "email bind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "email binding is not configured"})
|
||||
}
|
||||
|
||||
// BindEmail 绑定邮箱
|
||||
@@ -573,7 +659,7 @@ func (h *AuthHandler) SendEmailBindCode(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/email/bind [post]
|
||||
func (h *AuthHandler) BindEmail(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "email bind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "email binding is not configured"})
|
||||
}
|
||||
|
||||
// UnbindEmail 解绑邮箱
|
||||
@@ -585,7 +671,7 @@ func (h *AuthHandler) BindEmail(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/email/unbind [post]
|
||||
func (h *AuthHandler) UnbindEmail(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "email unbind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "email binding is not configured"})
|
||||
}
|
||||
|
||||
// SendPhoneBindCode 发送手机绑定验证码
|
||||
@@ -597,7 +683,7 @@ func (h *AuthHandler) UnbindEmail(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/phone/bind/send [post]
|
||||
func (h *AuthHandler) SendPhoneBindCode(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "phone bind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "phone binding is not configured"})
|
||||
}
|
||||
|
||||
// BindPhone 绑定手机号
|
||||
@@ -609,7 +695,7 @@ func (h *AuthHandler) SendPhoneBindCode(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/phone/bind [post]
|
||||
func (h *AuthHandler) BindPhone(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "phone bind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "phone binding is not configured"})
|
||||
}
|
||||
|
||||
// UnbindPhone 解绑手机号
|
||||
@@ -621,7 +707,7 @@ func (h *AuthHandler) BindPhone(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/phone/unbind [post]
|
||||
func (h *AuthHandler) UnbindPhone(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "phone unbind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "phone binding is not configured"})
|
||||
}
|
||||
|
||||
// GetSocialAccounts 获取社交账号列表
|
||||
@@ -645,7 +731,7 @@ func (h *AuthHandler) GetSocialAccounts(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/social/bind [post]
|
||||
func (h *AuthHandler) BindSocialAccount(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "social binding not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "social binding is not configured"})
|
||||
}
|
||||
|
||||
// UnbindSocialAccount 解绑社交账号
|
||||
@@ -657,7 +743,7 @@ func (h *AuthHandler) BindSocialAccount(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/social/unbind [post]
|
||||
func (h *AuthHandler) UnbindSocialAccount(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "social unbinding not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "social binding is not configured"})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) SupportsEmailCodeLogin() bool {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
apimiddleware "github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
@@ -33,10 +34,12 @@ func NewAvatarHandler(userRepo avatarUserRepository) *AvatarHandler {
|
||||
}
|
||||
|
||||
// generateSecureToken generates a secure random token
|
||||
func generateSecureToken(length int) string {
|
||||
func generateSecureToken(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)[:length]
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes)[:length], nil
|
||||
}
|
||||
|
||||
// UploadAvatar 上传用户头像
|
||||
@@ -70,17 +73,7 @@ func (h *AvatarHandler) UploadAvatar(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Check permission: user can only update their own avatar, or admin can update any
|
||||
isAdmin := false
|
||||
if roles, ok := c.Get("user_roles"); ok {
|
||||
for _, role := range roles.([]*domain.Role) {
|
||||
if role.Code == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if currentUserID != userID && !isAdmin {
|
||||
if currentUserID != userID && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
}
|
||||
@@ -140,7 +133,12 @@ func (h *AvatarHandler) UploadAvatar(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Generate unique filename
|
||||
avatarFilename := fmt.Sprintf("avatar_%d_%s%s", userID, generateSecureToken(8), ext)
|
||||
token, err := generateSecureToken(8)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate avatar token"})
|
||||
return
|
||||
}
|
||||
avatarFilename := fmt.Sprintf("avatar_%d_%s%s", userID, token, ext)
|
||||
uploadDir := "./uploads/avatars"
|
||||
|
||||
// Create upload directory if not exists
|
||||
|
||||
@@ -7,7 +7,9 @@ import (
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -35,6 +37,11 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
previousBootstrapSecret, hadBootstrapSecret := os.LookupEnv("BOOTSTRAP_SECRET")
|
||||
if err := os.Setenv("BOOTSTRAP_SECRET", "test-bootstrap-secret"); err != nil {
|
||||
t.Fatalf("set bootstrap secret failed: %v", err)
|
||||
}
|
||||
|
||||
id := atomic.AddInt64(&handlerDbCounter, 1)
|
||||
dsn := fmt.Sprintf("file:handlerdb_%d_%s?mode=memory&cache=shared", id, t.Name())
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
@@ -64,6 +71,20 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
t.Fatalf("db migration failed: %v", err)
|
||||
}
|
||||
|
||||
adminRole := &domain.Role{Code: "admin", Name: "管理员", Status: domain.RoleStatusEnabled}
|
||||
if err := db.Create(adminRole).Error; err != nil {
|
||||
t.Fatalf("seed admin role failed: %v", err)
|
||||
}
|
||||
for _, permission := range domain.DefaultPermissions() {
|
||||
perm := permission
|
||||
if err := db.Create(&perm).Error; err != nil {
|
||||
t.Fatalf("seed permission %s failed: %v", perm.Code, err)
|
||||
}
|
||||
if err := db.Create(&domain.RolePermission{RoleID: adminRole.ID, PermissionID: perm.ID}).Error; err != nil {
|
||||
t.Fatalf("seed role permission %s failed: %v", perm.Code, err)
|
||||
}
|
||||
}
|
||||
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "test-handler-secret-key",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
@@ -136,6 +157,11 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
server := httptest.NewServer(engine)
|
||||
return server, func() {
|
||||
server.Close()
|
||||
if hadBootstrapSecret {
|
||||
_ = os.Setenv("BOOTSTRAP_SECRET", previousBootstrapSecret)
|
||||
} else {
|
||||
_ = os.Unsetenv("BOOTSTRAP_SECRET")
|
||||
}
|
||||
if sqlDB, _ := db.DB(); sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
@@ -207,6 +233,35 @@ func registerUser(baseURL, username, email, password string) bool {
|
||||
return resp.StatusCode == http.StatusCreated
|
||||
}
|
||||
|
||||
func bootstrapAdminToken(baseURL, username, email, password string) string {
|
||||
payload, _ := json.Marshal(map[string]interface{}{
|
||||
"username": username,
|
||||
"email": email,
|
||||
"password": password,
|
||||
})
|
||||
req, _ := http.NewRequest("POST", baseURL+"/api/v1/auth/bootstrap-admin", bytes.NewReader(payload))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Bootstrap-Secret", "test-bootstrap-secret")
|
||||
resp, err := (&http.Client{}).Do(req)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
return ""
|
||||
}
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &result); err != nil {
|
||||
return ""
|
||||
}
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok || data["access_token"] == nil {
|
||||
return ""
|
||||
}
|
||||
return data["access_token"].(string)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Auth Handler Tests
|
||||
// =============================================================================
|
||||
@@ -292,6 +347,89 @@ func TestAuthHandler_Login_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_Login_SetsSessionCookies(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "cookieuser", "cookie@example.com", "Password123!")
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "cookieuser",
|
||||
"password": "Password123!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
cookies := resp.Cookies()
|
||||
var hasRefreshCookie bool
|
||||
var hasPresenceCookie bool
|
||||
for _, cookie := range cookies {
|
||||
switch cookie.Name {
|
||||
case "ums_refresh_token":
|
||||
hasRefreshCookie = cookie.HttpOnly && cookie.Value != ""
|
||||
case "ums_session_present":
|
||||
hasPresenceCookie = !cookie.HttpOnly && cookie.Value == "1"
|
||||
}
|
||||
}
|
||||
if !hasRefreshCookie {
|
||||
t.Fatalf("expected login response to set ums_refresh_token cookie, got %#v", cookies)
|
||||
}
|
||||
if !hasPresenceCookie {
|
||||
t.Fatalf("expected login response to set ums_session_present cookie, got %#v", cookies)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_RefreshToken_UsesCookieFallback(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "refreshcookieuser", "refreshcookie@example.com", "Password123!")
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("cookiejar.New() error: %v", err)
|
||||
}
|
||||
client := &http.Client{Jar: jar}
|
||||
|
||||
loginBody, _ := json.Marshal(map[string]interface{}{
|
||||
"account": "refreshcookieuser",
|
||||
"password": "Password123!",
|
||||
})
|
||||
loginReq, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/login", bytes.NewReader(loginBody))
|
||||
loginReq.Header.Set("Content-Type", "application/json")
|
||||
loginResp, err := client.Do(loginReq)
|
||||
if err != nil {
|
||||
t.Fatalf("login request failed: %v", err)
|
||||
}
|
||||
defer loginResp.Body.Close()
|
||||
if loginResp.StatusCode != http.StatusOK {
|
||||
payload, _ := io.ReadAll(loginResp.Body)
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, string(payload))
|
||||
}
|
||||
|
||||
refreshReq, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/refresh", nil)
|
||||
refreshReq.Header.Set("Content-Type", "application/json")
|
||||
refreshResp, err := client.Do(refreshReq)
|
||||
if err != nil {
|
||||
t.Fatalf("refresh request failed: %v", err)
|
||||
}
|
||||
defer refreshResp.Body.Close()
|
||||
refreshPayload, _ := io.ReadAll(refreshResp.Body)
|
||||
if refreshResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, refreshResp.StatusCode, string(refreshPayload))
|
||||
}
|
||||
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(refreshPayload, &parsed); err != nil {
|
||||
t.Fatalf("refresh response json unmarshal failed: %v", err)
|
||||
}
|
||||
data, _ := parsed["data"].(map[string]interface{})
|
||||
if data == nil || data["access_token"] == nil || data["refresh_token"] == nil {
|
||||
t.Fatalf("expected refresh response to include token pair, got %v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_Login_WrongPassword(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -336,33 +474,61 @@ func TestAuthHandler_BootstrapAdmin_MissingSecret(t *testing.T) {
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Without BOOTSTRAP_SECRET env var set, should get forbidden
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d for missing bootstrap secret, got %d", http.StatusForbidden, resp.StatusCode)
|
||||
// P0 修复后:已配置 BOOTSTRAP_SECRET 但未提供 header,应返回 401
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d for missing bootstrap secret header, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_GetAuthCapabilities(t *testing.T) {
|
||||
func TestAuthHandler_VerifyTOTPAfterPasswordLogin_RequiresTempToken(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/auth/capabilities", "")
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||
"user_id": 1,
|
||||
"code": "123456",
|
||||
"device_id": "device-1",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// User Handler Tests
|
||||
// =============================================================================
|
||||
func TestAuthHandler_UnconfiguredOAuthAndBindingsFailClosed(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "failclosed", "failclosed@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "failclosed", "AdminPass123!")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
body map[string]interface{}
|
||||
}{
|
||||
{name: "oauth login", url: server.URL + "/api/v1/auth/oauth/github"},
|
||||
{name: "email bind code", url: server.URL + "/api/v1/users/me/bind-email/code", body: map[string]interface{}{"email": "bind@example.com"}},
|
||||
{name: "social bind", url: server.URL + "/api/v1/users/me/bind-social", body: map[string]interface{}{"provider": "github"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var resp *http.Response
|
||||
var body string
|
||||
if tc.body == nil {
|
||||
resp, body = doGet(tc.url, token)
|
||||
} else {
|
||||
resp, body = doPost(tc.url, token, tc.body)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusServiceUnavailable, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_CreateUser_RequiresAdmin(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
@@ -400,39 +566,33 @@ func TestUserHandler_CreateUser_Unauthorized(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_ListUsers_Success(t *testing.T) {
|
||||
func TestUserHandler_ListUsers_ForbiddenForRegularUser(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "listadmin", "listadmin@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "listadmin", "AdminPass123!")
|
||||
registerUser(server.URL, "listuser", "listuser@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "listuser", "AdminPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/users?page=1&page_size=10", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_GetUser_Success(t *testing.T) {
|
||||
func TestUserHandler_GetUser_ForbiddenForRegularUser(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "getadmin", "getadmin@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "getadmin", "AdminPass123!")
|
||||
registerUser(server.URL, "getuser", "getuser@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "getuser", "AdminPass123!")
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/users/1", token)
|
||||
resp, body := doGet(server.URL+"/api/v1/users/1", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -440,8 +600,8 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "updateadmin", "updateadmin@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "updateadmin", "AdminPass123!")
|
||||
registerUser(server.URL, "updateuser", "update@example.com", "UserPass123!")
|
||||
token := getToken(server.URL, "updateuser", "UserPass123!")
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/1", token, map[string]string{"nickname": "Updated Nickname"})
|
||||
defer resp.Body.Close()
|
||||
@@ -451,6 +611,43 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_UpdateUser_AdminCanUpdateOther(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
token := bootstrapAdminToken(server.URL, "updateadmin", "updateadmin@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
registerUser(server.URL, "manageduser", "manageduser@test.com", "UserPass123!")
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/2", token, map[string]string{"nickname": "Admin Updated"})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_UpdatePassword_NonAdminCannotUpdateOther(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "pwd-user-1", "pwd-user-1@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "pwd-user-1", "UserPass123!")
|
||||
registerUser(server.URL, "pwd-user-2", "pwd-user-2@test.com", "TargetPass123!")
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/2/password", token, map[string]string{
|
||||
"old_password": "TargetPass123!",
|
||||
"new_password": "TargetNew456!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -471,8 +668,10 @@ func TestUserHandler_SearchUsers_Success(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "searchadmin", "searchadmin@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "searchadmin", "AdminPass123!")
|
||||
token := bootstrapAdminToken(server.URL, "searchadmin", "searchadmin@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/users/1", token)
|
||||
defer resp.Body.Close()
|
||||
@@ -515,6 +714,24 @@ func TestUserHandler_GetUserRoles_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_GetUserRoles_AdminCanViewOther(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
token := bootstrapAdminToken(server.URL, "rolesbootstrap", "rolesbootstrap@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
registerUser(server.URL, "role-target", "role-target@test.com", "UserPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/users/2/roles", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_AssignRoles_RequiresAdmin(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -974,8 +1191,10 @@ func TestInvalidUserID_ReturnsBadRequest(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "invalidid", "invalidid@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "invalidid", "AdminPass123!")
|
||||
token := bootstrapAdminToken(server.URL, "invalidid", "invalidid@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/users/invalid", token)
|
||||
defer resp.Body.Close()
|
||||
@@ -989,8 +1208,10 @@ func TestNonExistentUserID_ReturnsNotFound(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "notfound", "notfound@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "notfound", "AdminPass123!")
|
||||
token := bootstrapAdminToken(server.URL, "notfound", "notfound@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/users/99999", token)
|
||||
defer resp.Body.Close()
|
||||
@@ -1350,6 +1571,29 @@ func TestAvatarHandler_UploadAvatar_NonAdminCannotUpdateOther(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAvatarHandler_UploadAvatar_AdminCanUpdateOther(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
token := bootstrapAdminToken(server.URL, "avataradmin", "avataradmin@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
registerUser(server.URL, "avatar-target", "avatar-target@test.com", "UserPass123!")
|
||||
|
||||
fileContent := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
|
||||
resp, err := doUploadFile(server.URL+"/api/v1/users/2/avatar", token, "avatar", "test.png", fileContent)
|
||||
if err != nil {
|
||||
t.Fatalf("upload request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected status %d for admin updating other's avatar, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAvatarHandler_UploadAvatar_UserNotFoundOrForbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
apimiddleware "github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/service"
|
||||
@@ -187,16 +188,7 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
|
||||
|
||||
// Authorization: only self or admin can update user profile
|
||||
currentUserID := c.GetInt64("user_id")
|
||||
isAdmin := false
|
||||
if roles, ok := c.Get("user_roles"); ok {
|
||||
for _, role := range roles.([]*domain.Role) {
|
||||
if role.Code == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentUserID != id && !isAdmin {
|
||||
if currentUserID != id && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
}
|
||||
@@ -289,6 +281,12 @@ func (h *UserHandler) UpdatePassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
currentUserID := c.GetInt64("user_id")
|
||||
if currentUserID != id && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.ChangePassword(c.Request.Context(), id, req.OldPassword, req.NewPassword); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
@@ -370,16 +368,7 @@ func (h *UserHandler) GetUserRoles(c *gin.Context) {
|
||||
|
||||
// Authorization: only self or admin can view user roles
|
||||
currentUserID := c.GetInt64("user_id")
|
||||
isAdmin := false
|
||||
if roles, ok := c.Get("user_roles"); ok {
|
||||
for _, role := range roles.([]*domain.Role) {
|
||||
if role.Code == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentUserID != id && !isAdmin {
|
||||
if currentUserID != id && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user