fix(n+1): 批量查询替代循环单查

- IsAdminBootstrapRequired: userRepo.GetByID 循环 → GetByIDs 批量
- AssignRoles: roleRepo.GetByID 循环 → GetByIDs 批量
- 在 userRepositoryInterface 补充 GetByIDs 方法签名
This commit is contained in:
2026-05-08 08:05:26 +08:00
parent 9b1cea246e
commit 2a18a6fb47
39 changed files with 3169 additions and 393 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service"
)
@@ -22,6 +23,15 @@ func NewDeviceHandler(deviceService *service.DeviceService) *DeviceHandler {
return &DeviceHandler{deviceService: deviceService}
}
func (h *DeviceHandler) currentActor(c *gin.Context) (int64, bool, bool) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return 0, false, false
}
return userID, middleware.IsAdmin(c), true
}
// CreateDevice 创建设备
// @Summary 创建设备记录
// @Description 当前用户创建设备记录
@@ -118,7 +128,12 @@ func (h *DeviceHandler) GetDevice(c *gin.Context) {
return
}
device, err := h.deviceService.GetDevice(c.Request.Context(), id)
actorUserID, isAdmin, ok := h.currentActor(c)
if !ok {
return
}
device, err := h.deviceService.GetDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin)
if err != nil {
handleError(c, err)
return
@@ -157,7 +172,12 @@ func (h *DeviceHandler) UpdateDevice(c *gin.Context) {
return
}
device, err := h.deviceService.UpdateDevice(c.Request.Context(), id, &req)
actorUserID, isAdmin, ok := h.currentActor(c)
if !ok {
return
}
device, err := h.deviceService.UpdateDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin, &req)
if err != nil {
handleError(c, err)
return
@@ -187,7 +207,12 @@ func (h *DeviceHandler) DeleteDevice(c *gin.Context) {
return
}
if err := h.deviceService.DeleteDevice(c.Request.Context(), id); err != nil {
actorUserID, isAdmin, ok := h.currentActor(c)
if !ok {
return
}
if err := h.deviceService.DeleteDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin); err != nil {
handleError(c, err)
return
}
@@ -238,7 +263,12 @@ func (h *DeviceHandler) UpdateDeviceStatus(c *gin.Context) {
return
}
if err := h.deviceService.UpdateDeviceStatus(c.Request.Context(), id, status); err != nil {
actorUserID, isAdmin, ok := h.currentActor(c)
if !ok {
return
}
if err := h.deviceService.UpdateDeviceStatusForActor(c.Request.Context(), actorUserID, id, isAdmin, status); err != nil {
handleError(c, err)
return
}
@@ -270,16 +300,7 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
}
// 检查是否为管理员
roleCodes, _ := c.Get("role_codes")
isAdmin := false
if roles, ok := roleCodes.([]string); ok {
for _, role := range roles {
if role == "admin" {
isAdmin = true
break
}
}
}
isAdmin := middleware.IsAdmin(c)
userIDParam := c.Param("id")
userID, err := strconv.ParseInt(userIDParam, 10, 64)
@@ -405,7 +426,12 @@ func (h *DeviceHandler) TrustDevice(c *gin.Context) {
// 解析信任持续时间
trustDuration := parseDuration(req.TrustDuration)
if err := h.deviceService.TrustDevice(c.Request.Context(), id, trustDuration); err != nil {
actorUserID, isAdmin, ok := h.currentActor(c)
if !ok {
return
}
if err := h.deviceService.TrustDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin, trustDuration); err != nil {
handleError(c, err)
return
}
@@ -478,7 +504,12 @@ func (h *DeviceHandler) UntrustDevice(c *gin.Context) {
return
}
if err := h.deviceService.UntrustDevice(c.Request.Context(), id); err != nil {
actorUserID, isAdmin, ok := h.currentActor(c)
if !ok {
return
}
if err := h.deviceService.UntrustDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin); err != nil {
handleError(c, err)
return
}

View File

@@ -730,6 +730,173 @@ func TestUserHandler_UpdateUser_AdminCanUpdateAnotherUser(t *testing.T) {
}
}
func TestUserHandler_UpdateUser_ProfileFieldsPersisted(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "profileuser", "profileuser@test.com", "UserPass123!")
token := getToken(server.URL, "profileuser", "UserPass123!")
updatePayload := map[string]interface{}{
"nickname": "Profile Updated",
"gender": 1,
"birthday": "2026-03-15",
"region": "Hangzhou",
"bio": "Updated bio",
}
resp, body := doPut(server.URL+"/api/v1/users/1", token, updatePayload)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var updateResult map[string]interface{}
if err := json.Unmarshal([]byte(body), &updateResult); err != nil {
t.Fatalf("failed to parse update response: %v", err)
}
updateData, ok := updateResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected update response data, got %s", body)
}
if updateData["nickname"] != "Profile Updated" {
t.Fatalf("expected nickname to be updated, got %+v", updateData)
}
if updateData["gender"] != float64(1) {
t.Fatalf("expected gender=1, got %+v", updateData)
}
if updateData["region"] != "Hangzhou" {
t.Fatalf("expected region to be persisted, got %+v", updateData)
}
if updateData["bio"] != "Updated bio" {
t.Fatalf("expected bio to be persisted, got %+v", updateData)
}
updateBirthday, ok := updateData["birthday"].(string)
if !ok || updateBirthday == "" {
t.Fatalf("expected birthday in update response, got %+v", updateData)
}
parsedUpdateBirthday, err := time.Parse(time.RFC3339, updateBirthday)
if err != nil {
t.Fatalf("expected RFC3339 birthday, got %q: %v", updateBirthday, err)
}
if parsedUpdateBirthday.Format("2006-01-02") != "2026-03-15" {
t.Fatalf("expected birthday 2026-03-15, got %s", parsedUpdateBirthday.Format("2006-01-02"))
}
getResp, getBody := doGet(server.URL+"/api/v1/users/1", token)
defer getResp.Body.Close()
if getResp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, getResp.StatusCode, getBody)
}
var getResult map[string]interface{}
if err := json.Unmarshal([]byte(getBody), &getResult); err != nil {
t.Fatalf("failed to parse get response: %v", err)
}
getData, ok := getResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected get response data, got %s", getBody)
}
if getData["region"] != "Hangzhou" {
t.Fatalf("expected region in get response, got %+v", getData)
}
if getData["bio"] != "Updated bio" {
t.Fatalf("expected bio in get response, got %+v", getData)
}
}
func TestUserHandler_UpdatePassword_NonAdminCannotUpdateAnotherUser(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "password-actor", "password-actor@test.com", "ActorPass123!")
registerUser(server.URL, "password-target", "password-target@test.com", "TargetPass123!")
actorToken := getToken(server.URL, "password-actor", "ActorPass123!")
if actorToken == "" {
t.Fatal("actor token should not be empty")
}
resp, body := doPut(server.URL+"/api/v1/users/2/password", actorToken, map[string]interface{}{
"old_password": "TargetPass123!",
"new_password": "ChangedByOther123!",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
}
oldLoginResp, oldLoginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "password-target",
"password": "TargetPass123!",
})
defer oldLoginResp.Body.Close()
if oldLoginResp.StatusCode != http.StatusOK {
t.Fatalf("expected target old password to remain valid, got %d, body: %s", oldLoginResp.StatusCode, oldLoginBody)
}
newLoginResp, _ := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "password-target",
"password": "ChangedByOther123!",
})
defer newLoginResp.Body.Close()
if newLoginResp.StatusCode == http.StatusOK {
t.Fatal("expected unauthorized password change attempt to leave target password unchanged")
}
}
func TestUserHandler_UpdatePassword_AdminCanResetAnotherUser(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "passwordadmin", "passwordadmin@test.com", "AdminPass123!")
registerUser(server.URL, "password-target", "password-target@test.com", "TargetPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin should return access token")
}
resp, body := doPut(server.URL+"/api/v1/users/2/password", adminToken, map[string]interface{}{
"new_password": "AdminReset123!",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
oldLoginResp, _ := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "password-target",
"password": "TargetPass123!",
})
defer oldLoginResp.Body.Close()
if oldLoginResp.StatusCode == http.StatusOK {
t.Fatal("expected old password to be invalid after admin reset")
}
newLoginResp, newLoginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "password-target",
"password": "AdminReset123!",
})
defer newLoginResp.Body.Close()
if newLoginResp.StatusCode != http.StatusOK {
t.Fatalf("expected reset password to work, got %d, body: %s", newLoginResp.StatusCode, newLoginBody)
}
}
func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
@@ -958,6 +1125,218 @@ func TestDeviceHandler_CreateDevice_Success(t *testing.T) {
}
}
func createDeviceForHandlerTest(t *testing.T, baseURL, token, deviceID, deviceName string) int64 {
t.Helper()
resp, body := doPost(baseURL+"/api/v1/devices", token, map[string]interface{}{
"device_id": deviceID,
"device_name": deviceName,
"device_type": 1,
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected device create status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("parse create device response failed: %v", err)
}
data, ok := result["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected device payload, got body: %s", body)
}
id, ok := data["id"].(float64)
if !ok {
t.Fatalf("expected numeric device id, got body: %s", body)
}
return int64(id)
}
func TestDeviceHandler_GetDevice_IDOR_Forbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceidor_get_actor", "deviceidor_get_actor@test.com", "UserPass123!")
registerUser(server.URL, "deviceidor_get_owner", "deviceidor_get_owner@test.com", "UserPass123!")
actorToken := getToken(server.URL, "deviceidor_get_actor", "UserPass123!")
ownerToken := getToken(server.URL, "deviceidor_get_owner", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-get", "Owner Device")
resp, body := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), actorToken)
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected status %d for cross-user device read, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
}
}
func TestDeviceHandler_UpdateDevice_IDOR_Forbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceidor_update_actor", "deviceidor_update_actor@test.com", "UserPass123!")
registerUser(server.URL, "deviceidor_update_owner", "deviceidor_update_owner@test.com", "UserPass123!")
actorToken := getToken(server.URL, "deviceidor_update_actor", "UserPass123!")
ownerToken := getToken(server.URL, "deviceidor_update_owner", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-update", "Original Device")
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), actorToken, map[string]interface{}{
"device_name": "Hacked Device",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected status %d for cross-user device update, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
}
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
defer ownerResp.Body.Close()
if ownerResp.StatusCode != http.StatusOK {
t.Fatalf("expected owner device read status %d, got %d, body: %s", http.StatusOK, ownerResp.StatusCode, ownerBody)
}
if !bytes.Contains([]byte(ownerBody), []byte("Original Device")) {
t.Fatalf("expected device name to remain unchanged, body: %s", ownerBody)
}
}
func TestDeviceHandler_DeleteDevice_IDOR_Forbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceidor_delete_actor", "deviceidor_delete_actor@test.com", "UserPass123!")
registerUser(server.URL, "deviceidor_delete_owner", "deviceidor_delete_owner@test.com", "UserPass123!")
actorToken := getToken(server.URL, "deviceidor_delete_actor", "UserPass123!")
ownerToken := getToken(server.URL, "deviceidor_delete_owner", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-delete", "Delete Target")
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), actorToken)
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected status %d for cross-user device delete, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
}
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
defer ownerResp.Body.Close()
if ownerResp.StatusCode != http.StatusOK {
t.Fatalf("expected device to remain after forbidden delete, got %d, body: %s", ownerResp.StatusCode, ownerBody)
}
}
func TestDeviceHandler_TrustDevice_IDOR_Forbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceidor_trust_actor", "deviceidor_trust_actor@test.com", "UserPass123!")
registerUser(server.URL, "deviceidor_trust_owner", "deviceidor_trust_owner@test.com", "UserPass123!")
actorToken := getToken(server.URL, "deviceidor_trust_actor", "UserPass123!")
ownerToken := getToken(server.URL, "deviceidor_trust_owner", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-trust", "Trust Target")
resp, body := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), actorToken, map[string]interface{}{
"trust_duration": "24h",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected status %d for cross-user device trust, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
}
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
defer ownerResp.Body.Close()
if ownerResp.StatusCode != http.StatusOK {
t.Fatalf("expected owner device read status %d, got %d, body: %s", http.StatusOK, ownerResp.StatusCode, ownerBody)
}
if bytes.Contains([]byte(ownerBody), []byte("\"is_trusted\":true")) {
t.Fatalf("expected forbidden trust to leave device untrusted, body: %s", ownerBody)
}
}
func TestDeviceHandler_UntrustDevice_IDOR_Forbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceidor_untrust_actor", "deviceidor_untrust_actor@test.com", "UserPass123!")
registerUser(server.URL, "deviceidor_untrust_owner", "deviceidor_untrust_owner@test.com", "UserPass123!")
actorToken := getToken(server.URL, "deviceidor_untrust_actor", "UserPass123!")
ownerToken := getToken(server.URL, "deviceidor_untrust_owner", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-untrust", "Untrust Target")
trustResp, trustBody := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), ownerToken, map[string]interface{}{
"trust_duration": "24h",
})
defer trustResp.Body.Close()
if trustResp.StatusCode != http.StatusOK {
t.Fatalf("expected owner trust status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
}
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), actorToken)
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected status %d for cross-user device untrust, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
}
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
defer ownerResp.Body.Close()
if ownerResp.StatusCode != http.StatusOK {
t.Fatalf("expected owner device read status %d, got %d, body: %s", http.StatusOK, ownerResp.StatusCode, ownerBody)
}
if !bytes.Contains([]byte(ownerBody), []byte("\"is_trusted\":true")) {
t.Fatalf("expected forbidden untrust to leave trusted device unchanged, body: %s", ownerBody)
}
}
func TestDeviceHandler_UpdateDeviceStatus_IDOR_Forbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceidor_status_actor", "deviceidor_status_actor@test.com", "UserPass123!")
registerUser(server.URL, "deviceidor_status_owner", "deviceidor_status_owner@test.com", "UserPass123!")
actorToken := getToken(server.URL, "deviceidor_status_actor", "UserPass123!")
ownerToken := getToken(server.URL, "deviceidor_status_owner", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-status", "Status Target")
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), actorToken, map[string]interface{}{
"status": "inactive",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected status %d for cross-user device status update, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
}
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
defer ownerResp.Body.Close()
if ownerResp.StatusCode != http.StatusOK {
t.Fatalf("expected owner device read status %d, got %d, body: %s", http.StatusOK, ownerResp.StatusCode, ownerBody)
}
if !bytes.Contains([]byte(ownerBody), []byte("\"status\":1")) {
t.Fatalf("expected forbidden status update to leave device active, body: %s", ownerBody)
}
}
// =============================================================================
// Role Handler Tests
// =============================================================================

View File

@@ -3,6 +3,7 @@ package handler
import (
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
@@ -195,8 +196,13 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
}
var req struct {
Email *string `json:"email"`
Nickname *string `json:"nickname"`
Email *string `json:"email"`
Phone *string `json:"phone"`
Nickname *string `json:"nickname"`
Gender *domain.Gender `json:"gender"`
Birthday *string `json:"birthday"`
Region *string `json:"region"`
Bio *string `json:"bio"`
}
if err := c.ShouldBindJSON(&req); err != nil {
@@ -211,11 +217,35 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
}
if req.Email != nil {
user.Email = req.Email
user.Email = domain.StrPtr(*req.Email)
}
if req.Phone != nil {
user.Phone = domain.StrPtr(*req.Phone)
}
if req.Nickname != nil {
user.Nickname = *req.Nickname
}
if req.Gender != nil {
user.Gender = *req.Gender
}
if req.Birthday != nil {
if *req.Birthday == "" {
user.Birthday = nil
} else if birthday, err := time.Parse("2006-01-02", *req.Birthday); err == nil {
user.Birthday = &birthday
} else if birthday, err := time.Parse(time.RFC3339, *req.Birthday); err == nil {
user.Birthday = &birthday
} else {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid birthday"})
return
}
}
if req.Region != nil {
user.Region = *req.Region
}
if req.Bio != nil {
user.Bio = *req.Bio
}
if err := h.userService.Update(c.Request.Context(), user); err != nil {
handleError(c, err)
@@ -272,8 +302,16 @@ func (h *UserHandler) UpdatePassword(c *gin.Context) {
return
}
currentUserID := c.GetInt64("user_id")
isAdmin := middleware.IsAdmin(c)
isSelf := currentUserID == id
if !isSelf && !isAdmin {
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
return
}
var req struct {
OldPassword string `json:"old_password" binding:"required"`
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password" binding:"required"`
}
@@ -282,9 +320,16 @@ func (h *UserHandler) UpdatePassword(c *gin.Context) {
return
}
if err := h.userService.ChangePassword(c.Request.Context(), id, req.OldPassword, req.NewPassword); err != nil {
handleError(c, err)
return
if isSelf {
if err := h.userService.ChangePassword(c.Request.Context(), id, req.OldPassword, req.NewPassword); err != nil {
handleError(c, err)
return
}
} else {
if err := h.userService.AdminResetPassword(c.Request.Context(), id, req.NewPassword); err != nil {
handleError(c, err)
return
}
}
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "密码修改成功"})
@@ -570,11 +615,22 @@ func (h *UserHandler) DeleteAdmin(c *gin.Context) {
}
type UserResponse struct {
ID int64 `json:"id"`
Username string `json:"username"`
Email string `json:"email,omitempty"`
Nickname string `json:"nickname,omitempty"`
Status string `json:"status"`
ID int64 `json:"id"`
Username string `json:"username"`
Email string `json:"email,omitempty"`
Phone string `json:"phone,omitempty"`
Nickname string `json:"nickname,omitempty"`
Avatar string `json:"avatar,omitempty"`
Gender domain.Gender `json:"gender"`
Birthday *time.Time `json:"birthday,omitempty"`
Region string `json:"region,omitempty"`
Bio string `json:"bio,omitempty"`
Status string `json:"status"`
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
LastLoginIP string `json:"last_login_ip,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
TOTPEnabled bool `json:"totp_enabled"`
}
func toUserResponse(u *domain.User) *UserResponse {
@@ -582,11 +638,26 @@ func toUserResponse(u *domain.User) *UserResponse {
if u.Email != nil {
email = *u.Email
}
phone := ""
if u.Phone != nil {
phone = *u.Phone
}
return &UserResponse{
ID: u.ID,
Username: u.Username,
Email: email,
Nickname: u.Nickname,
Status: strconv.FormatInt(int64(u.Status), 10),
ID: u.ID,
Username: u.Username,
Email: email,
Phone: phone,
Nickname: u.Nickname,
Avatar: u.Avatar,
Gender: u.Gender,
Birthday: u.Birthday,
Region: u.Region,
Bio: u.Bio,
Status: strconv.FormatInt(int64(u.Status), 10),
LastLoginAt: u.LastLoginTime,
LastLoginIP: u.LastLoginIP,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
TOTPEnabled: u.TOTPEnabled,
}
}

View File

@@ -88,6 +88,11 @@ func (m *OperationLogMiddleware) Record() gin.HandlerFunc {
}
go func(entry *domain.OperationLog) {
defer func() {
if r := recover(); r != nil {
// PERF-07: panic recover 保护,防止操作日志写入异常导致进程崩溃
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = m.repo.Create(ctx, entry)

View File

@@ -199,3 +199,47 @@ func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duratio
m.limiters[key] = limiter
return limiter
}
// Cleanup 清理过期的不活跃 limiter防止 map 无界增长P0 资源泄漏修复)
func (m *RateLimitMiddleware) Cleanup() {
m.mu.Lock()
defer m.mu.Unlock()
now := time.Now().UnixMilli()
for key, limiter := range m.limiters {
limiter.mu.Lock()
cutoff := now - limiter.window.Milliseconds()
// 只保留仍在窗口内的请求时间戳
validRequests := make([]int64, 0, len(limiter.requests))
for _, ts := range limiter.requests {
if ts > cutoff {
validRequests = append(validRequests, ts)
}
}
limiter.requests = validRequests
isEmpty := len(limiter.requests) == 0
limiter.mu.Unlock()
if isEmpty {
delete(m.limiters, key)
}
}
}
// StartCleanup 启动后台定期清理 goroutine返回停止函数P0 资源泄漏修复)
func (m *RateLimitMiddleware) StartCleanup() func() {
ticker := time.NewTicker(m.cleanupInt)
done := make(chan struct{})
go func() {
for {
select {
case <-ticker.C:
m.Cleanup()
case <-done:
ticker.Stop()
return
}
}
}()
return func() { close(done) }
}

View File

@@ -122,7 +122,10 @@ func (r *Router) Setup() *gin.Engine {
)
}
r.engine.Static("/uploads", "./uploads")
// P0 安全修复:/uploads 目录不再公开暴露,改为需要认证后才能访问
uploadsGroup := r.engine.Group("/uploads", r.authMiddleware.Required())
uploadsGroup.Static("", "./uploads")
r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
if r.ipFilterMiddleware != nil {

View File

@@ -10,11 +10,13 @@ import (
"encoding/hex"
"fmt"
"image/png"
"regexp"
"strings"
"time"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"golang.org/x/crypto/bcrypt"
)
const (
@@ -111,27 +113,62 @@ func ValidateRecoveryCode(inputCode string, storedCodes []string) (int, bool) {
return -1, false
}
// HashRecoveryCode 使用 SHA256 哈希恢复码(用于存储)
// HashRecoveryCode 使用 bcrypt 慢哈希恢复码(用于存储)
// P2 安全增强:将 SHA256 快速哈希升级为 bcrypt 慢哈希,防止 GPU 暴力破解
func HashRecoveryCode(code string) (string, error) {
h := sha256.Sum256([]byte(code))
return hex.EncodeToString(h[:]), nil
normalized := strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(code), "-", ""))
hash, err := bcrypt.GenerateFromPassword([]byte(normalized), bcrypt.DefaultCost)
if err != nil {
return "", fmt.Errorf("hash recovery code failed: %w", err)
}
return string(hash), nil
}
// VerifyRecoveryCode 验证恢复码(自动哈希后比较
// sha256HexPattern 匹配 64 位十六进制字符串(旧版 SHA256 哈希格式
var sha256HexPattern = regexp.MustCompile("^[0-9a-fA-F]{64}$")
// isLegacySHA256Hash 检测是否为旧版 SHA256 哈希值
func isLegacySHA256Hash(hash string) bool {
return sha256HexPattern.MatchString(hash)
}
// legacyHashRecoveryCode 旧版 SHA256 哈希(用于向后兼容验证)
func legacyHashRecoveryCode(code string) string {
h := sha256.Sum256([]byte(code))
return hex.EncodeToString(h[:])
}
// VerifyRecoveryCode 验证恢复码(支持 bcrypt 新哈希和 SHA256 旧哈希向后兼容)
// 使用恒定时间比较防止时序攻击
func VerifyRecoveryCode(inputCode string, hashedCodes []string) (int, bool) {
hashedInput, err := HashRecoveryCode(inputCode)
if err != nil {
return -1, false
}
normalized := strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(inputCode), "-", ""))
found := -1
// 固定次数比较,防止时序攻击泄露匹配位置
for i := 0; i < len(hashedCodes); i++ {
hashed := hashedCodes[i]
if subtle.ConstantTimeCompare([]byte(hashedInput), []byte(hashed)) == 1 {
stored := hashedCodes[i]
if stored == "" {
continue
}
var matched bool
if isLegacySHA256Hash(stored) {
// 向后兼容:旧版 SHA256 哈希
hashedInput := legacyHashRecoveryCode(inputCode)
if subtle.ConstantTimeCompare([]byte(hashedInput), []byte(stored)) == 1 {
matched = true
}
} else {
// 新版 bcrypt 哈希
if err := bcrypt.CompareHashAndPassword([]byte(stored), []byte(normalized)); err == nil {
matched = true
}
}
if matched {
found = i
}
}
if found >= 0 {
return found, true
}

View File

@@ -112,27 +112,29 @@ func TestHashRecoveryCode(t *testing.T) {
t.Fatal("HashRecoveryCode should return non-empty hash")
}
// Same code should produce same hash
hashed2, err := HashRecoveryCode(code)
if err != nil {
t.Fatalf("HashRecoveryCode second call failed: %v", err)
// Same code should verify against its own hash (bcrypt uses random salt, so hashes differ)
_, ok := VerifyRecoveryCode(code, []string{hashed})
if !ok {
t.Error("Same code should verify against its own hash")
}
if hashed != hashed2 {
t.Error("Same code should produce same hash")
}
// Different codes should produce different hashes
// Different codes should NOT verify
hashed3, err := HashRecoveryCode("DIFFERENT-CODE")
if err != nil {
t.Fatalf("HashRecoveryCode for different code failed: %v", err)
}
if hashed == hashed3 {
t.Error("Different codes should produce different hashes")
_, ok2 := VerifyRecoveryCode(code, []string{hashed3})
if ok2 {
t.Error("Different codes should NOT verify against each other's hash")
}
t.Logf("Hashed code: %s", hashed)
// bcrypt hash format check
if !strings.HasPrefix(hashed, "$2a$") {
t.Errorf("Hash should be bcrypt format, got: %s", hashed)
}
t.Logf("Hashed code (bcrypt): %s", hashed)
}
func TestVerifyRecoveryCode(t *testing.T) {

View File

@@ -207,6 +207,16 @@ func (r *DeviceRepository) GetTrustedDevices(ctx context.Context, userID int64)
return devices, nil
}
// CountTrustedDevices 统计用户当前信任设备数量(未过期的)
func (r *DeviceRepository) CountTrustedDevices(ctx context.Context, userID int64) (int64, error) {
var count int64
now := time.Now()
err := r.db.WithContext(ctx).Model(&domain.Device{}).
Where("user_id = ? AND is_trusted = ? AND (trust_expires_at IS NULL OR trust_expires_at > ?)", userID, true, now).
Count(&count).Error
return count, err
}
// ListDevicesParams 设备列表查询参数
type ListDevicesParams struct {
UserID int64

View File

@@ -191,13 +191,20 @@ func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.R
return roles, nil
}
// maxAncestorDepth 角色祖先查询最大深度,防止循环引用导致无限循环
const maxAncestorDepth = 20
// GetAncestorIDs 获取角色的所有祖先角色ID用于权限继承
func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
var ancestorIDs []int64
currentID := roleID
depth := 0
// 循环向上查找父角色,直到没有父角色为止
// 循环向上查找父角色,直到没有父角色或达到深度上限为止
for {
if depth >= maxAncestorDepth {
break
}
var role domain.Role
err := r.db.WithContext(ctx).Select("id", "parent_id").First(&role, currentID).Error
if err != nil {
@@ -211,6 +218,7 @@ func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]in
}
ancestorIDs = append(ancestorIDs, *role.ParentID)
currentID = *role.ParentID
depth++
}
return ancestorIDs, nil

View File

@@ -119,15 +119,61 @@ func (r *RolePermissionRepository) GetPermissionByID(ctx context.Context, permis
return &permission, nil
}
// GetPermissionIDsByRoleIDs 根据角色ID列表批量获取权限ID
// GetRoleAncestorIDs 递归获取角色的所有祖先角色ID含自身
// 包含循环检测(最大深度 5 层)
func (r *RolePermissionRepository) GetRoleAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
var ancestors []int64
visited := make(map[int64]bool)
current := roleID
depth := 0
maxDepth := 5
for current > 0 && depth < maxDepth {
if visited[current] {
break // 循环检测
}
visited[current] = true
ancestors = append(ancestors, current)
var role domain.Role
err := r.db.WithContext(ctx).Select("parent_id").First(&role, current).Error
if err != nil || role.ParentID == nil {
break
}
current = *role.ParentID
depth++
}
return ancestors, nil
}
// GetPermissionIDsByRoleIDs 根据角色ID列表批量获取权限ID含继承的父角色权限
func (r *RolePermissionRepository) GetPermissionIDsByRoleIDs(ctx context.Context, roleIDs []int64) ([]int64, error) {
if len(roleIDs) == 0 {
return []int64{}, nil
}
// 收集所有角色ID含继承的父角色
allRoleIDs := make(map[int64]bool)
for _, roleID := range roleIDs {
ancestors, err := r.GetRoleAncestorIDs(ctx, roleID)
if err != nil {
return nil, err
}
for _, id := range ancestors {
allRoleIDs[id] = true
}
}
// 转换为 slice
ids := make([]int64, 0, len(allRoleIDs))
for id := range allRoleIDs {
ids = append(ids, id)
}
var permissionIDs []int64
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).
Where("role_id IN ?", roleIDs).
Where("role_id IN ?", ids).
Pluck("permission_id", &permissionIDs).Error
if err != nil {
return nil, err

View File

@@ -104,6 +104,18 @@ func (r *UserRepository) GetByPhone(ctx context.Context, phone string) (*domain.
return &user, nil
}
// FindByAccount 按账号查询用户(支持用户名/邮箱/手机号P1性能优化替代串行查询
func (r *UserRepository) FindByAccount(ctx context.Context, account string) (*domain.User, error) {
var user domain.User
err := r.db.WithContext(ctx).
Where("username = ? OR email = ? OR phone = ?", account, account, account).
First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
// List 获取用户列表
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
var users []*domain.User
@@ -195,6 +207,21 @@ func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool,
return count > 0, err
}
// FilterExistingUsernames 批量筛选已存在的用户名P1性能优化替代循环查询
func (r *UserRepository) FilterExistingUsernames(ctx context.Context, usernames []string) ([]string, error) {
if len(usernames) == 0 {
return []string{}, nil
}
var existing []string
err := r.db.WithContext(ctx).Model(&domain.User{}).
Where("username IN ?", usernames).
Pluck("username", &existing).Error
if err != nil {
return nil, err
}
return existing, nil
}
// Search 搜索用户
func (r *UserRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
var users []*domain.User

View File

@@ -83,69 +83,89 @@ func (r *UserRoleRepository) GetRoleIDsByUserID(ctx context.Context, userID int6
return roleIDs, nil
}
// GetUserRolesAndPermissions 获取用户角色和权限PERF-01 优化:合并为单次 JOIN 查询
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
var results []struct {
RoleID int64
RoleName string
RoleCode string
RoleStatus int
PermissionID int64
PermissionCode string
PermissionName string
// getRoleAncestorIDs 递归获取角色的所有祖先角色ID含自身
// 包含循环检测(最大深度 5 层)
func (r *UserRoleRepository) getRoleAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
var ancestors []int64
visited := make(map[int64]bool)
current := roleID
depth := 0
maxDepth := 5
for current > 0 && depth < maxDepth {
if visited[current] {
break // 循环检测
}
visited[current] = true
ancestors = append(ancestors, current)
var role domain.Role
err := r.db.WithContext(ctx).Select("parent_id").First(&role, current).Error
if err != nil || role.ParentID == nil {
break
}
current = *role.ParentID
depth++
}
// 使用 LEFT JOIN 一次性获取用户角色和权限
err := r.db.WithContext(ctx).
Raw(`
SELECT DISTINCT r.id as role_id, r.name as role_name, r.code as role_code, r.status as role_status,
p.id as permission_id, p.code as permission_code, p.name as permission_name
FROM user_roles ur
JOIN roles r ON ur.role_id = r.id
LEFT JOIN role_permissions rp ON r.id = rp.role_id
LEFT JOIN permissions p ON rp.permission_id = p.id
WHERE ur.user_id = ? AND r.status = 1
`, userID).
Scan(&results).Error
return ancestors, nil
}
// GetUserRolesAndPermissions 获取用户角色和权限(包含继承的父角色和权限)
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
// 获取用户直接分配的角色ID
var directRoleIDs []int64
err := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("user_id = ?", userID).Pluck("role_id", &directRoleIDs).Error
if err != nil {
return nil, nil, err
}
// 构建角色和权限列表
roleMap := make(map[int64]*domain.Role)
permMap := make(map[int64]*domain.Permission)
for _, row := range results {
if _, ok := roleMap[row.RoleID]; !ok {
roleMap[row.RoleID] = &domain.Role{
ID: row.RoleID,
Name: row.RoleName,
Code: row.RoleCode,
Status: domain.RoleStatus(row.RoleStatus),
}
// 递归获取所有祖先角色ID含自身包含循环检测
allRoleIDMap := make(map[int64]bool)
for _, roleID := range directRoleIDs {
ancestors, err := r.getRoleAncestorIDs(ctx, roleID)
if err != nil {
return nil, nil, err
}
if row.PermissionID > 0 {
if _, ok := permMap[row.PermissionID]; !ok {
permMap[row.PermissionID] = &domain.Permission{
ID: row.PermissionID,
Code: row.PermissionCode,
Name: row.PermissionName,
}
}
for _, id := range ancestors {
allRoleIDMap[id] = true
}
}
roles := make([]*domain.Role, 0, len(roleMap))
for _, role := range roleMap {
roles = append(roles, role)
// 转换为 slice
allRoleIDs := make([]int64, 0, len(allRoleIDMap))
for id := range allRoleIDMap {
allRoleIDs = append(allRoleIDs, id)
}
perms := make([]*domain.Permission, 0, len(permMap))
for _, perm := range permMap {
perms = append(perms, perm)
if len(allRoleIDs) == 0 {
return []*domain.Role{}, []*domain.Permission{}, nil
}
return roles, perms, nil
// 查询所有角色信息
var roles []*domain.Role
err = r.db.WithContext(ctx).Where("id IN ? AND status = ?", allRoleIDs, domain.RoleStatusEnabled).Find(&roles).Error
if err != nil {
return nil, nil, err
}
// 查询所有权限ID
var permissionIDs []int64
err = r.db.WithContext(ctx).Model(&domain.RolePermission{}).Where("role_id IN ?", allRoleIDs).Pluck("permission_id", &permissionIDs).Error
if err != nil {
return nil, nil, err
}
// 查询权限详情
var permissions []*domain.Permission
if len(permissionIDs) > 0 {
err = r.db.WithContext(ctx).Where("id IN ?", permissionIDs).Find(&permissions).Error
if err != nil {
return nil, nil, err
}
}
return roles, permissions, nil
}
// GetUserIDByRoleID 根据角色ID获取用户ID列表

View File

@@ -11,48 +11,92 @@ type Validator struct {
passwordMinLength int
passwordRequireSpecial bool
passwordRequireNumber bool
// 预编译的正则表达式避免每次调用重复编译P1性能优化
emailRe *regexp.Regexp
phoneRe *regexp.Regexp
usernameRe *regexp.Regexp
urlRe *regexp.Regexp
sqlPatterns []*regexp.Regexp
xssPatterns []*regexp.Regexp
}
// NewValidator creates a validator with the configured password rules.
func NewValidator(minLength int, requireSpecial, requireNumber bool) *Validator {
return &Validator{
v := &Validator{
passwordMinLength: minLength,
passwordRequireSpecial: requireSpecial,
passwordRequireNumber: requireNumber,
}
// 预编译常用验证正则P1性能优化
v.emailRe = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
v.phoneRe = regexp.MustCompile(`^1[3-9]\d{9}$`)
v.usernameRe = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]{3,19}$`)
v.urlRe = regexp.MustCompile(`^https?://[a-zA-Z0-9\-._~:/?#[\]@!$&'()*+,;=]+$`)
// 预编译SQL注入检测正则P1性能优化
sqlRawPatterns := []string{
`;[\s]*--`,
`/\*.*?\*/`,
`\bxp_\w+`,
`\bexec[\s\(]`,
`\bsp_\w+`,
`\bwaitfor[\s]+delay`,
`\bunion[\s]+select`,
`\bdrop[\s]+table`,
`\binsert[\s]+into`,
`\bupdate[\s]+\w+[\s]+set`,
`\bdelete[\s]+from`,
}
v.sqlPatterns = make([]*regexp.Regexp, len(sqlRawPatterns))
for i, p := range sqlRawPatterns {
v.sqlPatterns[i] = regexp.MustCompile(`(?i)` + p)
}
// 预编译XSS检测正则P1性能优化
xssRawPatterns := []string{
`(?i)<script[^>]*>.*?</script>`,
`(?i)</script>`,
`(?i)<iframe[^>]*>.*?</iframe>`,
`(?i)<object[^>]*>.*?</object>`,
`(?i)<embed[^>]*>.*?</embed>`,
`(?i)<applet[^>]*>.*?</applet>`,
`(?i)javascript\s*:`,
`(?i)vbscript\s*:`,
`(?i)data\s*:`,
`(?i)on\w+\s*=`,
`(?i)<style[^>]*>.*?</style>`,
}
v.xssPatterns = make([]*regexp.Regexp, len(xssRawPatterns))
for i, p := range xssRawPatterns {
v.xssPatterns[i] = regexp.MustCompile(p)
}
return v
}
// ValidateEmail validates email format.
func (v *Validator) ValidateEmail(email string) bool {
if email == "" {
if email == "" || v.emailRe == nil {
return false
}
pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
matched, _ := regexp.MatchString(pattern, email)
return matched
return v.emailRe.MatchString(email)
}
// ValidatePhone validates mainland China mobile numbers.
func (v *Validator) ValidatePhone(phone string) bool {
if phone == "" {
if phone == "" || v.phoneRe == nil {
return false
}
pattern := `^1[3-9]\d{9}$`
matched, _ := regexp.MatchString(pattern, phone)
return matched
return v.phoneRe.MatchString(phone)
}
// ValidateUsername validates usernames.
func (v *Validator) ValidateUsername(username string) bool {
if username == "" {
if username == "" || v.usernameRe == nil {
return false
}
pattern := `^[a-zA-Z][a-zA-Z0-9_]{3,19}$`
matched, _ := regexp.MatchString(pattern, username)
return matched
return v.usernameRe.MatchString(username)
}
// ValidatePassword validates passwords using the shared runtime policy.
@@ -77,27 +121,13 @@ func (v *Validator) SanitizeSQL(input string) string {
`"`, `""`,
)
// Remove common SQL injection patterns that could bypass quoting
dangerousPatterns := []string{
`;[\s]*--`, // SQL comment
`/\*.*?\*/`, // Block comment (non-greedy)
`\bxp_\w+`, // Extended stored procedures
`\bexec[\s\(]`, // EXEC statements
`\bsp_\w+`, // System stored procedures
`\bwaitfor[\s]+delay`, // Time-based blind SQL injection
`\bunion[\s]+select`, // UNION injection
`\bdrop[\s]+table`, // DROP TABLE
`\binsert[\s]+into`, // INSERT
`\bupdate[\s]+\w+[\s]+set`, // UPDATE
`\bdelete[\s]+from`, // DELETE
}
result := replacer.Replace(input)
// Apply pattern removal
for _, pattern := range dangerousPatterns {
re := regexp.MustCompile(`(?i)` + pattern) // Case-insensitive
result = re.ReplaceAllString(result, "")
// 使用预编译的正则移除SQL注入模式P1性能优化
for _, re := range v.sqlPatterns {
if re != nil {
result = re.ReplaceAllString(result, "")
}
}
return result
@@ -106,31 +136,11 @@ func (v *Validator) SanitizeSQL(input string) string {
// SanitizeXSS removes obviously dangerous XSS patterns using regex.
// This is a defense-in-depth measure; output encoding should always be used.
func (v *Validator) SanitizeXSS(input string) string {
// Remove dangerous tags and attributes using pattern matching
dangerousPatterns := []struct {
pattern string
replaceAll bool
}{
{`(?i)<script[^>]*>.*?</script>`, true}, // Script tags
{`(?i)</script>`, false}, // Closing script
{`(?i)<iframe[^>]*>.*?</iframe>`, true}, // Iframe injection
{`(?i)<object[^>]*>.*?</object>`, true}, // Object injection
{`(?i)<embed[^>]*>.*?</embed>`, true}, // Embed injection
{`(?i)<applet[^>]*>.*?</applet>`, true}, // Applet injection
{`(?i)javascript\s*:`, false}, // JavaScript protocol
{`(?i)vbscript\s*:`, false}, // VBScript protocol
{`(?i)data\s*:`, false}, // Data URL protocol
{`(?i)on\w+\s*=`, false}, // Event handlers
{`(?i)<style[^>]*>.*?</style>`, true}, // Style injection
}
result := input
for _, p := range dangerousPatterns {
re := regexp.MustCompile(p.pattern)
if p.replaceAll {
result = re.ReplaceAllString(result, "")
} else {
// 使用预编译的正则移除XSS模式P1性能优化
for _, re := range v.xssPatterns {
if re != nil {
result = re.ReplaceAllString(result, "")
}
}
@@ -148,13 +158,10 @@ func (v *Validator) SanitizeXSS(input string) string {
// ValidateURL validates a basic HTTP/HTTPS URL.
func (v *Validator) ValidateURL(url string) bool {
if url == "" {
if url == "" || v.urlRe == nil {
return false
}
pattern := `^https?://[a-zA-Z0-9\-._~:/?#[\]@!$&'()*+,;=]+$`
matched, _ := regexp.MatchString(pattern, url)
return matched
return v.urlRe.MatchString(url)
}
// ValidateIP validates IPv4 or IPv6 addresses using net.ParseIP.

View File

@@ -148,6 +148,9 @@ func Serve(cfg *config.Config) error {
// 初始化中间件
rateLimitMiddleware := middleware.NewRateLimitMiddleware(cfg.RateLimit)
stopRateLimitCleanup := rateLimitMiddleware.StartCleanup()
defer stopRateLimitCleanup()
authMiddleware := middleware.NewAuthMiddleware(
jwtManager,
userRepo,

View File

@@ -30,6 +30,8 @@ const (
defaultTOTPChallengeTTL = 5 * time.Minute
defaultPasswordMinLen = 8
refreshTokenRetryGrace = 10 * time.Second
MaxUsernameAttempts = 100 // 最大尝试次数P1性能优化减少循环查询
MaxUsernameLength = 40 // 用户名最大长度
)
type userRepositoryInterface interface {
@@ -38,6 +40,7 @@ type userRepositoryInterface interface {
UpdateTOTP(ctx context.Context, user *domain.User) error
Delete(ctx context.Context, id int64) error
GetByID(ctx context.Context, id int64) (*domain.User, error)
GetByIDs(ctx context.Context, ids []int64) ([]*domain.User, error)
GetByUsername(ctx context.Context, username string) (*domain.User, error)
GetByEmail(ctx context.Context, email string) (*domain.User, error)
GetByPhone(ctx context.Context, phone string) (*domain.User, error)
@@ -49,6 +52,10 @@ type userRepositoryInterface interface {
ExistsByEmail(ctx context.Context, email string) (bool, error)
ExistsByPhone(ctx context.Context, phone string) (bool, error)
Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error)
// FilterExistingUsernames 批量筛选已存在的用户名P1性能优化替代循环查询
FilterExistingUsernames(ctx context.Context, usernames []string) ([]string, error)
// FindByAccount 按账号查询用户(支持用户名/邮箱/手机号P1性能优化替代串行查询
FindByAccount(ctx context.Context, account string) (*domain.User, error)
}
type userRoleRepositoryInterface interface {
@@ -282,17 +289,28 @@ func (s *AuthService) generateUniqueUsername(ctx context.Context, base string) (
}
baseRunes := []rune(username)
if len(baseRunes) > 40 {
username = string(baseRunes[:40])
if len(baseRunes) > MaxUsernameLength {
username = string(baseRunes[:MaxUsernameLength])
}
for i := 1; i <= 1000; i++ {
candidate := fmt.Sprintf("%s_%d", username, i)
exists, err = s.userRepo.ExistsByUsername(ctx, candidate)
if err != nil {
return "", err
}
if !exists {
// P1性能优化批量生成候选列表后一次性查询避免循环DB往返
candidates := make([]string, 0, MaxUsernameAttempts)
for i := 1; i <= MaxUsernameAttempts; i++ {
candidates = append(candidates, fmt.Sprintf("%s_%d", username, i))
}
existing, err := s.userRepo.FilterExistingUsernames(ctx, candidates)
if err != nil {
return "", err
}
existingSet := make(map[string]bool, len(existing))
for _, u := range existing {
existingSet[u] = true
}
for _, candidate := range candidates {
if !existingSet[candidate] {
return candidate, nil
}
}
@@ -530,6 +548,11 @@ func (s *AuthService) writeLoginLog(
// #nosec G118 - 使用带超时的独立 context防止日志写入无限等待
go func() { // #nosec G118
defer func() {
if r := recover(); r != nil {
log.Printf("auth: write login log panic recovered, user_id=%v login_type=%d err=%v", userID, loginType, r)
}
}()
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.loginLogRepo.Create(bgCtx, loginRecord); err != nil {
@@ -591,7 +614,7 @@ func buildDeviceFingerprint(req *LoginRequest) string {
return result
}
// bestEffortRegisterDevice 尝试自动注册/更新设备记录
// bestEffortRegisterDevice 尝试自动注册/更新设备记录(异步,不阻塞登录响应)
func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64, req *LoginRequest) {
if s == nil || s.deviceService == nil || req == nil || req.DeviceID == "" {
return
@@ -603,7 +626,18 @@ func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64
DeviceBrowser: req.DeviceBrowser,
DeviceOS: req.DeviceOS,
}
_, _ = s.deviceService.CreateDevice(ctx, userID, createReq)
// PERF-01: 改为异步 goroutine不阻塞登录响应返回
go func() {
defer func() {
if r := recover(); r != nil {
log.Printf("auth: register device panic recovered, user_id=%d device_id=%s err=%v", userID, req.DeviceID, r)
}
}()
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _ = s.deviceService.CreateDevice(bgCtx, userID, createReq)
}()
}
// BestEffortRegisterDevicePublic 供外部 handler如 SMS 登录)调用,安静地注册设备

View File

@@ -75,21 +75,16 @@ func (s *AuthService) IsAdminBootstrapRequired(ctx context.Context) bool {
return true
}
hadUnexpectedLookupError := false
for _, userID := range userIDs {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if isUserNotFoundError(err) {
continue
}
hadUnexpectedLookupError = true
log.Printf("auth: resolve auth capabilities failed while loading admin user: user_id=%d err=%v", userID, err)
continue
}
if user != nil && user.Status == domain.UserStatusActive {
users, err := s.userRepo.GetByIDs(ctx, userIDs)
if err != nil {
log.Printf("auth: resolve auth capabilities failed while loading admin users: err=%v", err)
return false
}
for _, user := range users {
if user.Status == domain.UserStatusActive {
return false
}
}
return !hadUnexpectedLookupError
return true
}

View File

@@ -30,27 +30,15 @@ func (s *AuthService) RegisterOAuthProvider(provider auth.OAuthProvider, cfg *au
}
func (s *AuthService) findUserForLogin(ctx context.Context, account string) (*domain.User, error) {
user, err := s.userRepo.GetByUsername(ctx, account)
if err == nil {
return user, nil
// P1性能优化使用单一查询替代 username->email->phone 串行查询减少DB往返
user, err := s.userRepo.FindByAccount(ctx, account)
if err != nil {
if isUserNotFoundError(err) {
return nil, err
}
return nil, fmt.Errorf("lookup user failed: %w", err)
}
if !isUserNotFoundError(err) {
return nil, fmt.Errorf("lookup user by username failed: %w", err)
}
user, err = s.userRepo.GetByEmail(ctx, account)
if err == nil {
return user, nil
}
if !isUserNotFoundError(err) {
return nil, fmt.Errorf("lookup user by email failed: %w", err)
}
user, err = s.userRepo.GetByPhone(ctx, account)
if err != nil && !isUserNotFoundError(err) {
return nil, fmt.Errorf("lookup user by phone failed: %w", err)
}
return user, err
return user, nil
}
func isUserNotFoundError(err error) bool {
@@ -100,9 +88,19 @@ func (s *AuthService) bestEffortUpdateLastLogin(ctx context.Context, userID int6
return
}
if err := s.userRepo.UpdateLastLogin(ctx, userID, ip); err != nil {
log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
}
// PERF-01: 改为异步 goroutine不阻塞登录响应返回
go func() {
defer func() {
if r := recover(); r != nil {
log.Printf("auth: update last login panic recovered, source=%s user_id=%d err=%v", source, userID, r)
}
}()
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.userRepo.UpdateLastLogin(bgCtx, userID, ip); err != nil {
log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
}
}()
}
func loginAttemptKey(account string, user *domain.User) string {

View File

@@ -4,14 +4,16 @@ import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
apierrors "github.com/user-management-system/internal/pkg/errors"
"github.com/user-management-system/internal/repository"
"gorm.io/gorm"
)
// Interfaces for dependency inversion (DIP) — service layer depends on these abstractions, not concrete types.
type deviceRepository interface {
Create(ctx context.Context, device *domain.Device) error
Update(ctx context.Context, device *domain.Device) error
@@ -27,6 +29,7 @@ type deviceRepository interface {
UntrustDevice(ctx context.Context, id int64) error
DeleteAllByUserIDExcept(ctx context.Context, userID int64, exceptDeviceID int64) error
GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error)
CountTrustedDevices(ctx context.Context, userID int64) (int64, error)
ListAll(ctx context.Context, params *repository.ListDevicesParams) ([]*domain.Device, int64, error)
ListAllCursor(ctx context.Context, params *repository.ListDevicesParams, limit int, cursor *pagination.Cursor) ([]*domain.Device, bool, error)
}
@@ -35,24 +38,18 @@ type deviceUserRepository interface {
GetByID(ctx context.Context, id int64) (*domain.User, error)
}
// DeviceService 设备服务
type DeviceService struct {
deviceRepo deviceRepository
userRepo deviceUserRepository
}
// NewDeviceService 创建设备服务
func NewDeviceService(
deviceRepo deviceRepository,
userRepo deviceUserRepository,
) *DeviceService {
func NewDeviceService(deviceRepo deviceRepository, userRepo deviceUserRepository) *DeviceService {
return &DeviceService{
deviceRepo: deviceRepo,
userRepo: userRepo,
}
}
// CreateDeviceRequest 创建设备请求
type CreateDeviceRequest struct {
DeviceID string `json:"device_id" binding:"required"`
DeviceName string `json:"device_name"`
@@ -63,7 +60,6 @@ type CreateDeviceRequest struct {
Location string `json:"location"`
}
// UpdateDeviceRequest 更新设备请求
type UpdateDeviceRequest struct {
DeviceName string `json:"device_name"`
DeviceType int `json:"device_type"`
@@ -74,21 +70,16 @@ type UpdateDeviceRequest struct {
Status int `json:"status"`
}
// CreateDevice 创建设备
func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *CreateDeviceRequest) (*domain.Device, error) {
// 检查用户是否存在
_, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, errors.New("用户不存在")
if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
return nil, errors.New("user not found")
}
// 检查设备是否已存在
exists, err := s.deviceRepo.Exists(ctx, userID, req.DeviceID)
if err != nil {
return nil, err
}
if exists {
// 设备已存在,更新最后活跃时间
device, err := s.deviceRepo.GetByDeviceID(ctx, userID, req.DeviceID)
if err != nil {
return nil, err
@@ -97,7 +88,6 @@ func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *Cre
return device, s.deviceRepo.Update(ctx, device)
}
// 创建设备
device := &domain.Device{
UserID: userID,
DeviceID: req.DeviceID,
@@ -117,14 +107,47 @@ func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *Cre
return device, nil
}
// UpdateDevice 更新设备
func (s *DeviceService) UpdateDevice(ctx context.Context, deviceID int64, req *UpdateDeviceRequest) (*domain.Device, error) {
device, err := s.deviceRepo.GetByID(ctx, deviceID)
if err != nil {
return nil, errors.New("设备不存在")
func isDeviceNotFoundError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return true
}
lowerErr := strings.ToLower(strings.TrimSpace(err.Error()))
return strings.Contains(lowerErr, "record not found") ||
strings.Contains(lowerErr, "device not found") ||
strings.Contains(lowerErr, "not found")
}
func (s *DeviceService) getDeviceByID(ctx context.Context, deviceID int64) (*domain.Device, error) {
device, err := s.deviceRepo.GetByID(ctx, deviceID)
if err != nil {
if isDeviceNotFoundError(err) {
return nil, apierrors.NotFound("device_not_found", "device not found").WithCause(err)
}
return nil, err
}
return device, nil
}
func (s *DeviceService) getAuthorizedDevice(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) (*domain.Device, error) {
device, err := s.getDeviceByID(ctx, deviceID)
if err != nil {
return nil, err
}
if !isAdmin && device.UserID != actorUserID {
return nil, apierrors.Forbidden("device_forbidden", "permission denied")
}
return device, nil
}
func (s *DeviceService) persistDeviceUpdate(ctx context.Context, device *domain.Device, req *UpdateDeviceRequest) (*domain.Device, error) {
if req == nil {
return device, nil
}
// 更新字段
if req.DeviceName != "" {
device.DeviceName = req.DeviceName
}
@@ -154,19 +177,48 @@ func (s *DeviceService) UpdateDevice(ctx context.Context, deviceID int64, req *U
return device, nil
}
// DeleteDevice 删除设备
// maxTrustedDevicesPerUser 每个用户最大信任设备数量P2 安全增强)
const maxTrustedDevicesPerUser = 10
func (s *DeviceService) trustDeviceRecord(ctx context.Context, device *domain.Device, trustDuration time.Duration) error {
// P2 安全增强:检查信任设备数量上限
trustedCount, err := s.deviceRepo.CountTrustedDevices(ctx, device.UserID)
if err != nil {
return fmt.Errorf("count trusted devices failed: %w", err)
}
if trustedCount >= maxTrustedDevicesPerUser {
return fmt.Errorf("trusted device limit reached (max %d), please untrust an existing device first", maxTrustedDevicesPerUser)
}
var trustExpiresAt *time.Time
if trustDuration > 0 {
expiresAt := time.Now().Add(trustDuration)
trustExpiresAt = &expiresAt
}
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
}
func (s *DeviceService) UpdateDevice(ctx context.Context, deviceID int64, req *UpdateDeviceRequest) (*domain.Device, error) {
device, err := s.getDeviceByID(ctx, deviceID)
if err != nil {
return nil, err
}
return s.persistDeviceUpdate(ctx, device, req)
}
func (s *DeviceService) DeleteDevice(ctx context.Context, deviceID int64) error {
return s.deviceRepo.Delete(ctx, deviceID)
device, err := s.getDeviceByID(ctx, deviceID)
if err != nil {
return err
}
return s.deviceRepo.Delete(ctx, device.ID)
}
// GetDevice 获取设备信息
func (s *DeviceService) GetDevice(ctx context.Context, deviceID int64) (*domain.Device, error) {
return s.deviceRepo.GetByID(ctx, deviceID)
return s.getDeviceByID(ctx, deviceID)
}
// GetUserDevices 获取用户设备列表
func (s *DeviceService) GetUserDevices(ctx context.Context, userID int64, page, pageSize int) ([]*domain.Device, int64, error) {
offset := (page - 1) * pageSize
if page <= 0 {
page = 1
}
@@ -174,22 +226,51 @@ func (s *DeviceService) GetUserDevices(ctx context.Context, userID int64, page,
pageSize = 20
}
offset := (page - 1) * pageSize
return s.deviceRepo.ListByUserID(ctx, userID, offset, pageSize)
}
// UpdateDeviceStatus 更新设备状态
func (s *DeviceService) UpdateDeviceStatus(ctx context.Context, deviceID int64, status domain.DeviceStatus) error {
return s.deviceRepo.UpdateStatus(ctx, deviceID, status)
func (s *DeviceService) GetDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) (*domain.Device, error) {
return s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
}
func (s *DeviceService) UpdateDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool, req *UpdateDeviceRequest) (*domain.Device, error) {
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
if err != nil {
return nil, err
}
return s.persistDeviceUpdate(ctx, device, req)
}
func (s *DeviceService) DeleteDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) error {
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
if err != nil {
return err
}
return s.deviceRepo.Delete(ctx, device.ID)
}
func (s *DeviceService) UpdateDeviceStatus(ctx context.Context, deviceID int64, status domain.DeviceStatus) error {
device, err := s.getDeviceByID(ctx, deviceID)
if err != nil {
return err
}
return s.deviceRepo.UpdateStatus(ctx, device.ID, status)
}
func (s *DeviceService) UpdateDeviceStatusForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool, status domain.DeviceStatus) error {
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
if err != nil {
return err
}
return s.deviceRepo.UpdateStatus(ctx, device.ID, status)
}
// UpdateLastActiveTime 更新最后活跃时间
func (s *DeviceService) UpdateLastActiveTime(ctx context.Context, deviceID int64) error {
return s.deviceRepo.UpdateLastActiveTime(ctx, deviceID)
}
// GetActiveDevices 获取活跃设备
func (s *DeviceService) GetActiveDevices(ctx context.Context, page, pageSize int) ([]*domain.Device, int64, error) {
offset := (page - 1) * pageSize
if page <= 0 {
page = 1
}
@@ -197,74 +278,72 @@ func (s *DeviceService) GetActiveDevices(ctx context.Context, page, pageSize int
pageSize = 20
}
offset := (page - 1) * pageSize
return s.deviceRepo.ListByStatus(ctx, domain.DeviceStatusActive, offset, pageSize)
}
// TrustDevice 设置设备为信任状态
func (s *DeviceService) TrustDevice(ctx context.Context, deviceID int64, trustDuration time.Duration) error {
device, err := s.deviceRepo.GetByID(ctx, deviceID)
device, err := s.getDeviceByID(ctx, deviceID)
if err != nil {
return errors.New("设备不存在")
return err
}
var trustExpiresAt *time.Time
if trustDuration > 0 {
expiresAt := time.Now().Add(trustDuration)
trustExpiresAt = &expiresAt
}
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
return s.trustDeviceRecord(ctx, device, trustDuration)
}
func (s *DeviceService) TrustDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool, trustDuration time.Duration) error {
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
if err != nil {
return err
}
return s.trustDeviceRecord(ctx, device, trustDuration)
}
// TrustDeviceByDeviceID 根据设备标识字符串设置设备为信任状态
func (s *DeviceService) TrustDeviceByDeviceID(ctx context.Context, userID int64, deviceID string, trustDuration time.Duration) error {
device, err := s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
if err != nil {
return errors.New("设备不存在")
if isDeviceNotFoundError(err) {
return apierrors.NotFound("device_not_found", "device not found").WithCause(err)
}
return err
}
var trustExpiresAt *time.Time
if trustDuration > 0 {
expiresAt := time.Now().Add(trustDuration)
trustExpiresAt = &expiresAt
}
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
return s.trustDeviceRecord(ctx, device, trustDuration)
}
// UntrustDevice 取消设备信任状态
func (s *DeviceService) UntrustDevice(ctx context.Context, deviceID int64) error {
device, err := s.deviceRepo.GetByID(ctx, deviceID)
device, err := s.getDeviceByID(ctx, deviceID)
if err != nil {
return errors.New("设备不存在")
return err
}
return s.deviceRepo.UntrustDevice(ctx, device.ID)
}
func (s *DeviceService) UntrustDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) error {
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
if err != nil {
return err
}
return s.deviceRepo.UntrustDevice(ctx, device.ID)
}
// LogoutAllOtherDevices 登出所有其他设备
func (s *DeviceService) LogoutAllOtherDevices(ctx context.Context, userID int64, currentDeviceID int64) error {
return s.deviceRepo.DeleteAllByUserIDExcept(ctx, userID, currentDeviceID)
}
// GetTrustedDevices 获取用户的信任设备列表
func (s *DeviceService) GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
return s.deviceRepo.GetTrustedDevices(ctx, userID)
}
// GetAllDevicesRequest 获取所有设备请求参数
type GetAllDevicesRequest struct {
Page int `form:"page"`
PageSize int `form:"page_size"`
UserID int64 `form:"user_id"`
Status *int `form:"status"` // 0-禁用, 1-激活, nil-不筛选
Status *int `form:"status"`
IsTrusted *bool `form:"is_trusted"`
Keyword string `form:"keyword"`
Cursor string `form:"cursor"` // Opaque cursor for keyset pagination
Size int `form:"size"` // Page size when using cursor mode
Cursor string `form:"cursor"`
Size int `form:"size"`
}
// GetAllDevices 获取所有设备(管理员用)
func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesRequest) ([]*domain.Device, int64, error) {
if req.Page <= 0 {
req.Page = 1
@@ -277,7 +356,6 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
}
offset := (req.Page - 1) * req.PageSize
params := &repository.ListDevicesParams{
UserID: req.UserID,
Keyword: req.Keyword,
@@ -285,13 +363,10 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
Limit: req.PageSize,
}
// 处理状态筛选(仅当明确指定了状态时才筛选)
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
status := domain.DeviceStatus(*req.Status)
params.Status = &status
}
// 处理信任状态筛选
if req.IsTrusted != nil {
params.IsTrusted = req.IsTrusted
}
@@ -299,7 +374,6 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
return s.deviceRepo.ListAll(ctx, params)
}
// GetAllDevicesCursor 游标分页获取所有设备(推荐使用)
func (s *DeviceService) GetAllDevicesCursor(ctx context.Context, req *GetAllDevicesRequest) (*CursorResult, error) {
size := pagination.ClampPageSize(req.Size)
if req.PageSize > 0 && req.Cursor == "" {
@@ -342,7 +416,6 @@ func (s *DeviceService) GetAllDevicesCursor(ctx context.Context, req *GetAllDevi
}, nil
}
// GetDeviceByDeviceID 根据设备标识获取设备(用于设备信任检查)
func (s *DeviceService) GetDeviceByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
return s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
}

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/user-management-system/internal/domain"
apierrors "github.com/user-management-system/internal/pkg/errors"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/service"
gormsqlite "gorm.io/driver/sqlite"
@@ -156,6 +157,104 @@ func TestDeviceService_GetDevice(t *testing.T) {
})
}
func TestDeviceService_DeviceOwnershipAuthorization(t *testing.T) {
svc, db := setupDeviceTestEnv(t)
ctx := context.Background()
owner := &domain.User{Username: "device_owner", Status: domain.UserStatusActive}
if err := db.Create(owner).Error; err != nil {
t.Fatalf("create owner failed: %v", err)
}
actor := &domain.User{Username: "device_actor", Status: domain.UserStatusActive}
if err := db.Create(actor).Error; err != nil {
t.Fatalf("create actor failed: %v", err)
}
device, err := svc.CreateDevice(ctx, owner.ID, &service.CreateDeviceRequest{
DeviceID: "ownership_device",
DeviceName: "Owner Device",
})
if err != nil {
t.Fatalf("CreateDevice failed: %v", err)
}
t.Run("GetDeviceForActor forbids cross-user access", func(t *testing.T) {
_, err := svc.GetDeviceForActor(ctx, actor.ID, device.ID, false)
if !apierrors.IsForbidden(err) {
t.Fatalf("expected forbidden error, got %v", err)
}
})
t.Run("UpdateDeviceForActor forbids cross-user access", func(t *testing.T) {
_, err := svc.UpdateDeviceForActor(ctx, actor.ID, device.ID, false, &service.UpdateDeviceRequest{
DeviceName: "Hacked Name",
})
if !apierrors.IsForbidden(err) {
t.Fatalf("expected forbidden error, got %v", err)
}
current, getErr := svc.GetDevice(ctx, device.ID)
if getErr != nil {
t.Fatalf("GetDevice failed: %v", getErr)
}
if current.DeviceName != "Owner Device" {
t.Fatalf("expected device name to remain unchanged, got %q", current.DeviceName)
}
})
t.Run("DeleteDeviceForActor forbids cross-user access", func(t *testing.T) {
err := svc.DeleteDeviceForActor(ctx, actor.ID, device.ID, false)
if !apierrors.IsForbidden(err) {
t.Fatalf("expected forbidden error, got %v", err)
}
if _, getErr := svc.GetDevice(ctx, device.ID); getErr != nil {
t.Fatalf("expected device to remain after forbidden delete, got %v", getErr)
}
})
t.Run("TrustDeviceForActor forbids cross-user access", func(t *testing.T) {
err := svc.TrustDeviceForActor(ctx, actor.ID, device.ID, false, time.Hour)
if !apierrors.IsForbidden(err) {
t.Fatalf("expected forbidden error, got %v", err)
}
current, getErr := svc.GetDevice(ctx, device.ID)
if getErr != nil {
t.Fatalf("GetDevice failed: %v", getErr)
}
if current.IsTrusted {
t.Fatal("expected device to remain untrusted")
}
})
t.Run("UpdateDeviceStatusForActor forbids cross-user access", func(t *testing.T) {
err := svc.UpdateDeviceStatusForActor(ctx, actor.ID, device.ID, false, domain.DeviceStatusInactive)
if !apierrors.IsForbidden(err) {
t.Fatalf("expected forbidden error, got %v", err)
}
current, getErr := svc.GetDevice(ctx, device.ID)
if getErr != nil {
t.Fatalf("GetDevice failed: %v", getErr)
}
if current.Status != domain.DeviceStatusActive {
t.Fatalf("expected device to remain active, got %d", current.Status)
}
})
t.Run("Admin can manage another users device", func(t *testing.T) {
got, err := svc.GetDeviceForActor(ctx, actor.ID, device.ID, true)
if err != nil {
t.Fatalf("expected admin access, got %v", err)
}
if got.ID != device.ID {
t.Fatalf("expected device id %d, got %d", device.ID, got.ID)
}
})
}
func TestDeviceService_GetUserDevices(t *testing.T) {
svc, _ := setupDeviceTestEnv(t)
ctx := context.Background()

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log"
"strings"
"time"
"unicode/utf8"
@@ -103,32 +104,101 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
if strings.TrimSpace(newPassword) == "" {
return errors.New("新密码不能为空")
}
return s.applyNewPassword(ctx, user, newPassword)
/*
if err := validatePasswordStrength(newPassword, 8, false); err != nil {
return err
}
// 检查密码历史(需要明文密码比对,必须在哈希之前)
if s.passwordHistoryRepo != nil {
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, userID, passwordHistoryLimit)
if err == nil && len(histories) > 0 {
for _, h := range histories {
if auth.VerifyPassword(h.PasswordHash, newPassword) {
return errors.New("新密码不能与最近5次密码相同")
}
}
}
}
// 计算一次哈希,用于更新密码和保存历史(避免 Argon2id 重复计算的高成本)
newHashedPassword, hashErr := auth.HashPassword(newPassword)
if hashErr != nil {
return errors.New("密码哈希失败")
}
// 保存新密码到历史记录(异步,不阻塞密码更新)
if s.passwordHistoryRepo != nil {
// #nosec G118 - 使用带超时的独立 context不能使用请求 ctx该 goroutine 在请求完成后仍可能运行)
go func(hashedPw string) { // #nosec G118
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
UserID: userID,
PasswordHash: hashedPw,
})
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit)
}(newHashedPassword)
}
// 更新密码(使用同一哈希值)
user.Password = newHashedPassword
user.PasswordChangedAt = time.Now()
return s.userRepo.Update(ctx, user)
*/
}
// GetByID 根据ID获取用户
// AdminResetPassword resets a user's password without requiring the old password.
func (s *UserService) AdminResetPassword(ctx context.Context, userID int64, newPassword string) error {
if s.userRepo == nil {
return errors.New("user repository is not configured")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return errors.New("user not found")
}
return s.applyNewPassword(ctx, user, newPassword)
}
func (s *UserService) applyNewPassword(ctx context.Context, user *domain.User, newPassword string) error {
if user == nil {
return errors.New("user not found")
}
if strings.TrimSpace(newPassword) == "" {
return errors.New("new password is required")
}
if err := validatePasswordStrength(newPassword, 8, false); err != nil {
return err
}
// 检查密码历史(需要明文密码比对,必须在哈希之前)
if s.passwordHistoryRepo != nil {
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, userID, passwordHistoryLimit)
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, user.ID, passwordHistoryLimit)
if err == nil && len(histories) > 0 {
for _, h := range histories {
if auth.VerifyPassword(h.PasswordHash, newPassword) {
return errors.New("新密码不能与最近5次密码相同")
return errors.New("new password cannot reuse recent password history")
}
}
}
}
// 计算一次哈希,用于更新密码和保存历史(避免 Argon2id 重复计算的高成本)
newHashedPassword, hashErr := auth.HashPassword(newPassword)
if hashErr != nil {
return errors.New("密码哈希失败")
return errors.New("password hashing failed")
}
// 保存新密码到历史记录(异步,不阻塞密码更新)
if s.passwordHistoryRepo != nil {
// #nosec G118 - 使用带超时的独立 context不能使用请求 ctx该 goroutine 在请求完成后仍可能运行)
go func(hashedPw string) { // #nosec G118
go func(userID int64, hashedPw string) { // #nosec G118
defer func() {
if r := recover(); r != nil {
log.Printf("user_service: password history save panic recovered, user_id=%d err=%v", userID, r)
}
}()
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
@@ -136,16 +206,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
PasswordHash: hashedPw,
})
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit)
}(newHashedPassword)
}(user.ID, newHashedPassword)
}
// 更新密码(使用同一哈希值)
user.Password = newHashedPassword
user.PasswordChangedAt = time.Now()
return s.userRepo.Update(ctx, user)
}
// GetByID 根据ID获取用户
func (s *UserService) GetByID(ctx context.Context, id int64) (*domain.User, error) {
return s.userRepo.GetByID(ctx, id)
}
@@ -357,10 +425,23 @@ func (s *UserService) AssignRoles(ctx context.Context, userID int64, roleIDs []i
return err
}
// 验证所有角色存在(预先验证,避免在事务内做不必要的查询
for _, roleID := range roleIDs {
if _, err := s.roleRepo.GetByID(ctx, roleID); err != nil {
return fmt.Errorf("角色 %d 不存在", roleID)
// 验证所有角色存在(批量查询消除 N+1
if len(roleIDs) > 0 {
foundRoles, err := s.roleRepo.GetByIDs(ctx, roleIDs)
if err != nil {
return fmt.Errorf("验证角色失败: %w", err)
}
if len(foundRoles) != len(roleIDs) {
// 找出缺失的角色ID
foundMap := make(map[int64]bool, len(foundRoles))
for _, r := range foundRoles {
foundMap[r.ID] = true
}
for _, id := range roleIDs {
if !foundMap[id] {
return fmt.Errorf("角色 %d 不存在", id)
}
}
}
}

View File

@@ -341,6 +341,44 @@ func TestUserService_ChangePassword(t *testing.T) {
})
}
func TestUserService_AdminResetPassword(t *testing.T) {
env := setupAuthTestEnv(t)
if env == nil {
return
}
ctx := context.Background()
t.Run("Admin reset password success", func(t *testing.T) {
hashedPassword, _ := auth.HashPassword("OldPassword123!")
user := &domain.User{
Username: "adminresetpwd",
Password: hashedPassword,
Status: domain.UserStatusActive,
}
env.userSvc.Create(ctx, user)
err := env.userSvc.AdminResetPassword(ctx, user.ID, "ResetPassword456!")
if err != nil {
t.Fatalf("AdminResetPassword failed: %v", err)
}
updated, _ := env.userSvc.GetByID(ctx, user.ID)
if !auth.VerifyPassword(updated.Password, "ResetPassword456!") {
t.Error("reset password verification failed")
}
if auth.VerifyPassword(updated.Password, "OldPassword123!") {
t.Error("old password should no longer work after admin reset")
}
})
t.Run("Admin reset password for non-existent user", func(t *testing.T) {
err := env.userSvc.AdminResetPassword(ctx, 99999, "ResetPassword456!")
if err == nil {
t.Error("Expected error for non-existent user")
}
})
}
func TestUserService_BatchUpdateStatus(t *testing.T) {
env := setupAuthTestEnv(t)
if env == nil {