fix(n+1): 批量查询替代循环单查
- IsAdminBootstrapRequired: userRepo.GetByID 循环 → GetByIDs 批量 - AssignRoles: roleRepo.GetByID 循环 → GetByIDs 批量 - 在 userRepositoryInterface 补充 GetByIDs 方法签名
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
@@ -22,6 +23,15 @@ func NewDeviceHandler(deviceService *service.DeviceService) *DeviceHandler {
|
||||
return &DeviceHandler{deviceService: deviceService}
|
||||
}
|
||||
|
||||
func (h *DeviceHandler) currentActor(c *gin.Context) (int64, bool, bool) {
|
||||
userID, ok := getUserIDFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return 0, false, false
|
||||
}
|
||||
return userID, middleware.IsAdmin(c), true
|
||||
}
|
||||
|
||||
// CreateDevice 创建设备
|
||||
// @Summary 创建设备记录
|
||||
// @Description 当前用户创建设备记录
|
||||
@@ -118,7 +128,12 @@ func (h *DeviceHandler) GetDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
device, err := h.deviceService.GetDevice(c.Request.Context(), id)
|
||||
actorUserID, isAdmin, ok := h.currentActor(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
device, err := h.deviceService.GetDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
@@ -157,7 +172,12 @@ func (h *DeviceHandler) UpdateDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
device, err := h.deviceService.UpdateDevice(c.Request.Context(), id, &req)
|
||||
actorUserID, isAdmin, ok := h.currentActor(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
device, err := h.deviceService.UpdateDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin, &req)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
@@ -187,7 +207,12 @@ func (h *DeviceHandler) DeleteDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.DeleteDevice(c.Request.Context(), id); err != nil {
|
||||
actorUserID, isAdmin, ok := h.currentActor(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.DeleteDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
@@ -238,7 +263,12 @@ func (h *DeviceHandler) UpdateDeviceStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.UpdateDeviceStatus(c.Request.Context(), id, status); err != nil {
|
||||
actorUserID, isAdmin, ok := h.currentActor(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.UpdateDeviceStatusForActor(c.Request.Context(), actorUserID, id, isAdmin, status); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
@@ -270,16 +300,7 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 检查是否为管理员
|
||||
roleCodes, _ := c.Get("role_codes")
|
||||
isAdmin := false
|
||||
if roles, ok := roleCodes.([]string); ok {
|
||||
for _, role := range roles {
|
||||
if role == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
isAdmin := middleware.IsAdmin(c)
|
||||
|
||||
userIDParam := c.Param("id")
|
||||
userID, err := strconv.ParseInt(userIDParam, 10, 64)
|
||||
@@ -405,7 +426,12 @@ func (h *DeviceHandler) TrustDevice(c *gin.Context) {
|
||||
// 解析信任持续时间
|
||||
trustDuration := parseDuration(req.TrustDuration)
|
||||
|
||||
if err := h.deviceService.TrustDevice(c.Request.Context(), id, trustDuration); err != nil {
|
||||
actorUserID, isAdmin, ok := h.currentActor(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.TrustDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin, trustDuration); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
@@ -478,7 +504,12 @@ func (h *DeviceHandler) UntrustDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.UntrustDevice(c.Request.Context(), id); err != nil {
|
||||
actorUserID, isAdmin, ok := h.currentActor(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.UntrustDeviceForActor(c.Request.Context(), actorUserID, id, isAdmin); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -730,6 +730,173 @@ func TestUserHandler_UpdateUser_AdminCanUpdateAnotherUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_UpdateUser_ProfileFieldsPersisted(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "profileuser", "profileuser@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "profileuser", "UserPass123!")
|
||||
|
||||
updatePayload := map[string]interface{}{
|
||||
"nickname": "Profile Updated",
|
||||
"gender": 1,
|
||||
"birthday": "2026-03-15",
|
||||
"region": "Hangzhou",
|
||||
"bio": "Updated bio",
|
||||
}
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/1", token, updatePayload)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var updateResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &updateResult); err != nil {
|
||||
t.Fatalf("failed to parse update response: %v", err)
|
||||
}
|
||||
|
||||
updateData, ok := updateResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected update response data, got %s", body)
|
||||
}
|
||||
|
||||
if updateData["nickname"] != "Profile Updated" {
|
||||
t.Fatalf("expected nickname to be updated, got %+v", updateData)
|
||||
}
|
||||
if updateData["gender"] != float64(1) {
|
||||
t.Fatalf("expected gender=1, got %+v", updateData)
|
||||
}
|
||||
if updateData["region"] != "Hangzhou" {
|
||||
t.Fatalf("expected region to be persisted, got %+v", updateData)
|
||||
}
|
||||
if updateData["bio"] != "Updated bio" {
|
||||
t.Fatalf("expected bio to be persisted, got %+v", updateData)
|
||||
}
|
||||
|
||||
updateBirthday, ok := updateData["birthday"].(string)
|
||||
if !ok || updateBirthday == "" {
|
||||
t.Fatalf("expected birthday in update response, got %+v", updateData)
|
||||
}
|
||||
parsedUpdateBirthday, err := time.Parse(time.RFC3339, updateBirthday)
|
||||
if err != nil {
|
||||
t.Fatalf("expected RFC3339 birthday, got %q: %v", updateBirthday, err)
|
||||
}
|
||||
if parsedUpdateBirthday.Format("2006-01-02") != "2026-03-15" {
|
||||
t.Fatalf("expected birthday 2026-03-15, got %s", parsedUpdateBirthday.Format("2006-01-02"))
|
||||
}
|
||||
|
||||
getResp, getBody := doGet(server.URL+"/api/v1/users/1", token)
|
||||
defer getResp.Body.Close()
|
||||
|
||||
if getResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, getResp.StatusCode, getBody)
|
||||
}
|
||||
|
||||
var getResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(getBody), &getResult); err != nil {
|
||||
t.Fatalf("failed to parse get response: %v", err)
|
||||
}
|
||||
|
||||
getData, ok := getResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected get response data, got %s", getBody)
|
||||
}
|
||||
|
||||
if getData["region"] != "Hangzhou" {
|
||||
t.Fatalf("expected region in get response, got %+v", getData)
|
||||
}
|
||||
if getData["bio"] != "Updated bio" {
|
||||
t.Fatalf("expected bio in get response, got %+v", getData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_UpdatePassword_NonAdminCannotUpdateAnotherUser(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "password-actor", "password-actor@test.com", "ActorPass123!")
|
||||
registerUser(server.URL, "password-target", "password-target@test.com", "TargetPass123!")
|
||||
|
||||
actorToken := getToken(server.URL, "password-actor", "ActorPass123!")
|
||||
if actorToken == "" {
|
||||
t.Fatal("actor token should not be empty")
|
||||
}
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/2/password", actorToken, map[string]interface{}{
|
||||
"old_password": "TargetPass123!",
|
||||
"new_password": "ChangedByOther123!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
oldLoginResp, oldLoginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "password-target",
|
||||
"password": "TargetPass123!",
|
||||
})
|
||||
defer oldLoginResp.Body.Close()
|
||||
|
||||
if oldLoginResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected target old password to remain valid, got %d, body: %s", oldLoginResp.StatusCode, oldLoginBody)
|
||||
}
|
||||
|
||||
newLoginResp, _ := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "password-target",
|
||||
"password": "ChangedByOther123!",
|
||||
})
|
||||
defer newLoginResp.Body.Close()
|
||||
|
||||
if newLoginResp.StatusCode == http.StatusOK {
|
||||
t.Fatal("expected unauthorized password change attempt to leave target password unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_UpdatePassword_AdminCanResetAnotherUser(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "passwordadmin", "passwordadmin@test.com", "AdminPass123!")
|
||||
registerUser(server.URL, "password-target", "password-target@test.com", "TargetPass123!")
|
||||
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin should return access token")
|
||||
}
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/2/password", adminToken, map[string]interface{}{
|
||||
"new_password": "AdminReset123!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
oldLoginResp, _ := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "password-target",
|
||||
"password": "TargetPass123!",
|
||||
})
|
||||
defer oldLoginResp.Body.Close()
|
||||
|
||||
if oldLoginResp.StatusCode == http.StatusOK {
|
||||
t.Fatal("expected old password to be invalid after admin reset")
|
||||
}
|
||||
|
||||
newLoginResp, newLoginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "password-target",
|
||||
"password": "AdminReset123!",
|
||||
})
|
||||
defer newLoginResp.Body.Close()
|
||||
|
||||
if newLoginResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected reset password to work, got %d, body: %s", newLoginResp.StatusCode, newLoginBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -958,6 +1125,218 @@ func TestDeviceHandler_CreateDevice_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createDeviceForHandlerTest(t *testing.T, baseURL, token, deviceID, deviceName string) int64 {
|
||||
t.Helper()
|
||||
|
||||
resp, body := doPost(baseURL+"/api/v1/devices", token, map[string]interface{}{
|
||||
"device_id": deviceID,
|
||||
"device_name": deviceName,
|
||||
"device_type": 1,
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("expected device create status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("parse create device response failed: %v", err)
|
||||
}
|
||||
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected device payload, got body: %s", body)
|
||||
}
|
||||
|
||||
id, ok := data["id"].(float64)
|
||||
if !ok {
|
||||
t.Fatalf("expected numeric device id, got body: %s", body)
|
||||
}
|
||||
|
||||
return int64(id)
|
||||
}
|
||||
|
||||
func TestDeviceHandler_GetDevice_IDOR_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceidor_get_actor", "deviceidor_get_actor@test.com", "UserPass123!")
|
||||
registerUser(server.URL, "deviceidor_get_owner", "deviceidor_get_owner@test.com", "UserPass123!")
|
||||
|
||||
actorToken := getToken(server.URL, "deviceidor_get_actor", "UserPass123!")
|
||||
ownerToken := getToken(server.URL, "deviceidor_get_owner", "UserPass123!")
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-get", "Owner Device")
|
||||
|
||||
resp, body := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), actorToken)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d for cross-user device read, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_UpdateDevice_IDOR_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceidor_update_actor", "deviceidor_update_actor@test.com", "UserPass123!")
|
||||
registerUser(server.URL, "deviceidor_update_owner", "deviceidor_update_owner@test.com", "UserPass123!")
|
||||
|
||||
actorToken := getToken(server.URL, "deviceidor_update_actor", "UserPass123!")
|
||||
ownerToken := getToken(server.URL, "deviceidor_update_owner", "UserPass123!")
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-update", "Original Device")
|
||||
|
||||
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), actorToken, map[string]interface{}{
|
||||
"device_name": "Hacked Device",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d for cross-user device update, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
|
||||
defer ownerResp.Body.Close()
|
||||
|
||||
if ownerResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected owner device read status %d, got %d, body: %s", http.StatusOK, ownerResp.StatusCode, ownerBody)
|
||||
}
|
||||
|
||||
if !bytes.Contains([]byte(ownerBody), []byte("Original Device")) {
|
||||
t.Fatalf("expected device name to remain unchanged, body: %s", ownerBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_DeleteDevice_IDOR_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceidor_delete_actor", "deviceidor_delete_actor@test.com", "UserPass123!")
|
||||
registerUser(server.URL, "deviceidor_delete_owner", "deviceidor_delete_owner@test.com", "UserPass123!")
|
||||
|
||||
actorToken := getToken(server.URL, "deviceidor_delete_actor", "UserPass123!")
|
||||
ownerToken := getToken(server.URL, "deviceidor_delete_owner", "UserPass123!")
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-delete", "Delete Target")
|
||||
|
||||
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), actorToken)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d for cross-user device delete, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
|
||||
defer ownerResp.Body.Close()
|
||||
|
||||
if ownerResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected device to remain after forbidden delete, got %d, body: %s", ownerResp.StatusCode, ownerBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_TrustDevice_IDOR_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceidor_trust_actor", "deviceidor_trust_actor@test.com", "UserPass123!")
|
||||
registerUser(server.URL, "deviceidor_trust_owner", "deviceidor_trust_owner@test.com", "UserPass123!")
|
||||
|
||||
actorToken := getToken(server.URL, "deviceidor_trust_actor", "UserPass123!")
|
||||
ownerToken := getToken(server.URL, "deviceidor_trust_owner", "UserPass123!")
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-trust", "Trust Target")
|
||||
|
||||
resp, body := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), actorToken, map[string]interface{}{
|
||||
"trust_duration": "24h",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d for cross-user device trust, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
|
||||
defer ownerResp.Body.Close()
|
||||
|
||||
if ownerResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected owner device read status %d, got %d, body: %s", http.StatusOK, ownerResp.StatusCode, ownerBody)
|
||||
}
|
||||
|
||||
if bytes.Contains([]byte(ownerBody), []byte("\"is_trusted\":true")) {
|
||||
t.Fatalf("expected forbidden trust to leave device untrusted, body: %s", ownerBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_UntrustDevice_IDOR_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceidor_untrust_actor", "deviceidor_untrust_actor@test.com", "UserPass123!")
|
||||
registerUser(server.URL, "deviceidor_untrust_owner", "deviceidor_untrust_owner@test.com", "UserPass123!")
|
||||
|
||||
actorToken := getToken(server.URL, "deviceidor_untrust_actor", "UserPass123!")
|
||||
ownerToken := getToken(server.URL, "deviceidor_untrust_owner", "UserPass123!")
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-untrust", "Untrust Target")
|
||||
|
||||
trustResp, trustBody := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), ownerToken, map[string]interface{}{
|
||||
"trust_duration": "24h",
|
||||
})
|
||||
defer trustResp.Body.Close()
|
||||
|
||||
if trustResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected owner trust status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
|
||||
}
|
||||
|
||||
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), actorToken)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d for cross-user device untrust, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
|
||||
defer ownerResp.Body.Close()
|
||||
|
||||
if ownerResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected owner device read status %d, got %d, body: %s", http.StatusOK, ownerResp.StatusCode, ownerBody)
|
||||
}
|
||||
|
||||
if !bytes.Contains([]byte(ownerBody), []byte("\"is_trusted\":true")) {
|
||||
t.Fatalf("expected forbidden untrust to leave trusted device unchanged, body: %s", ownerBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_UpdateDeviceStatus_IDOR_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceidor_status_actor", "deviceidor_status_actor@test.com", "UserPass123!")
|
||||
registerUser(server.URL, "deviceidor_status_owner", "deviceidor_status_owner@test.com", "UserPass123!")
|
||||
|
||||
actorToken := getToken(server.URL, "deviceidor_status_actor", "UserPass123!")
|
||||
ownerToken := getToken(server.URL, "deviceidor_status_owner", "UserPass123!")
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, ownerToken, "device-idor-status", "Status Target")
|
||||
|
||||
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), actorToken, map[string]interface{}{
|
||||
"status": "inactive",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d for cross-user device status update, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
ownerResp, ownerBody := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), ownerToken)
|
||||
defer ownerResp.Body.Close()
|
||||
|
||||
if ownerResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected owner device read status %d, got %d, body: %s", http.StatusOK, ownerResp.StatusCode, ownerBody)
|
||||
}
|
||||
|
||||
if !bytes.Contains([]byte(ownerBody), []byte("\"status\":1")) {
|
||||
t.Fatalf("expected forbidden status update to leave device active, body: %s", ownerBody)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Role Handler Tests
|
||||
// =============================================================================
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -195,8 +196,13 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Email *string `json:"email"`
|
||||
Nickname *string `json:"nickname"`
|
||||
Email *string `json:"email"`
|
||||
Phone *string `json:"phone"`
|
||||
Nickname *string `json:"nickname"`
|
||||
Gender *domain.Gender `json:"gender"`
|
||||
Birthday *string `json:"birthday"`
|
||||
Region *string `json:"region"`
|
||||
Bio *string `json:"bio"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -211,11 +217,35 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
if req.Email != nil {
|
||||
user.Email = req.Email
|
||||
user.Email = domain.StrPtr(*req.Email)
|
||||
}
|
||||
if req.Phone != nil {
|
||||
user.Phone = domain.StrPtr(*req.Phone)
|
||||
}
|
||||
if req.Nickname != nil {
|
||||
user.Nickname = *req.Nickname
|
||||
}
|
||||
if req.Gender != nil {
|
||||
user.Gender = *req.Gender
|
||||
}
|
||||
if req.Birthday != nil {
|
||||
if *req.Birthday == "" {
|
||||
user.Birthday = nil
|
||||
} else if birthday, err := time.Parse("2006-01-02", *req.Birthday); err == nil {
|
||||
user.Birthday = &birthday
|
||||
} else if birthday, err := time.Parse(time.RFC3339, *req.Birthday); err == nil {
|
||||
user.Birthday = &birthday
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid birthday"})
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.Region != nil {
|
||||
user.Region = *req.Region
|
||||
}
|
||||
if req.Bio != nil {
|
||||
user.Bio = *req.Bio
|
||||
}
|
||||
|
||||
if err := h.userService.Update(c.Request.Context(), user); err != nil {
|
||||
handleError(c, err)
|
||||
@@ -272,8 +302,16 @@ func (h *UserHandler) UpdatePassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
currentUserID := c.GetInt64("user_id")
|
||||
isAdmin := middleware.IsAdmin(c)
|
||||
isSelf := currentUserID == id
|
||||
if !isSelf && !isAdmin {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -282,9 +320,16 @@ func (h *UserHandler) UpdatePassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.ChangePassword(c.Request.Context(), id, req.OldPassword, req.NewPassword); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
if isSelf {
|
||||
if err := h.userService.ChangePassword(c.Request.Context(), id, req.OldPassword, req.NewPassword); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := h.userService.AdminResetPassword(c.Request.Context(), id, req.NewPassword); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "密码修改成功"})
|
||||
@@ -570,11 +615,22 @@ func (h *UserHandler) DeleteAdmin(c *gin.Context) {
|
||||
}
|
||||
|
||||
type UserResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Nickname string `json:"nickname,omitempty"`
|
||||
Status string `json:"status"`
|
||||
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Phone string `json:"phone,omitempty"`
|
||||
Nickname string `json:"nickname,omitempty"`
|
||||
Avatar string `json:"avatar,omitempty"`
|
||||
Gender domain.Gender `json:"gender"`
|
||||
Birthday *time.Time `json:"birthday,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Bio string `json:"bio,omitempty"`
|
||||
Status string `json:"status"`
|
||||
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
|
||||
LastLoginIP string `json:"last_login_ip,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
TOTPEnabled bool `json:"totp_enabled"`
|
||||
}
|
||||
|
||||
func toUserResponse(u *domain.User) *UserResponse {
|
||||
@@ -582,11 +638,26 @@ func toUserResponse(u *domain.User) *UserResponse {
|
||||
if u.Email != nil {
|
||||
email = *u.Email
|
||||
}
|
||||
phone := ""
|
||||
if u.Phone != nil {
|
||||
phone = *u.Phone
|
||||
}
|
||||
return &UserResponse{
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
Email: email,
|
||||
Nickname: u.Nickname,
|
||||
Status: strconv.FormatInt(int64(u.Status), 10),
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
Email: email,
|
||||
Phone: phone,
|
||||
Nickname: u.Nickname,
|
||||
Avatar: u.Avatar,
|
||||
Gender: u.Gender,
|
||||
Birthday: u.Birthday,
|
||||
Region: u.Region,
|
||||
Bio: u.Bio,
|
||||
Status: strconv.FormatInt(int64(u.Status), 10),
|
||||
LastLoginAt: u.LastLoginTime,
|
||||
LastLoginIP: u.LastLoginIP,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
TOTPEnabled: u.TOTPEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,6 +88,11 @@ func (m *OperationLogMiddleware) Record() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
go func(entry *domain.OperationLog) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// PERF-07: panic recover 保护,防止操作日志写入异常导致进程崩溃
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
_ = m.repo.Create(ctx, entry)
|
||||
|
||||
@@ -199,3 +199,47 @@ func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duratio
|
||||
m.limiters[key] = limiter
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Cleanup 清理过期的不活跃 limiter,防止 map 无界增长(P0 资源泄漏修复)
|
||||
func (m *RateLimitMiddleware) Cleanup() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
for key, limiter := range m.limiters {
|
||||
limiter.mu.Lock()
|
||||
cutoff := now - limiter.window.Milliseconds()
|
||||
// 只保留仍在窗口内的请求时间戳
|
||||
validRequests := make([]int64, 0, len(limiter.requests))
|
||||
for _, ts := range limiter.requests {
|
||||
if ts > cutoff {
|
||||
validRequests = append(validRequests, ts)
|
||||
}
|
||||
}
|
||||
limiter.requests = validRequests
|
||||
isEmpty := len(limiter.requests) == 0
|
||||
limiter.mu.Unlock()
|
||||
|
||||
if isEmpty {
|
||||
delete(m.limiters, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartCleanup 启动后台定期清理 goroutine,返回停止函数(P0 资源泄漏修复)
|
||||
func (m *RateLimitMiddleware) StartCleanup() func() {
|
||||
ticker := time.NewTicker(m.cleanupInt)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.Cleanup()
|
||||
case <-done:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() { close(done) }
|
||||
}
|
||||
|
||||
@@ -122,7 +122,10 @@ func (r *Router) Setup() *gin.Engine {
|
||||
)
|
||||
}
|
||||
|
||||
r.engine.Static("/uploads", "./uploads")
|
||||
// P0 安全修复:/uploads 目录不再公开暴露,改为需要认证后才能访问
|
||||
uploadsGroup := r.engine.Group("/uploads", r.authMiddleware.Required())
|
||||
uploadsGroup.Static("", "./uploads")
|
||||
|
||||
r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
|
||||
|
||||
if r.ipFilterMiddleware != nil {
|
||||
|
||||
@@ -10,11 +10,13 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"image/png"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -111,27 +113,62 @@ func ValidateRecoveryCode(inputCode string, storedCodes []string) (int, bool) {
|
||||
return -1, false
|
||||
}
|
||||
|
||||
// HashRecoveryCode 使用 SHA256 哈希恢复码(用于存储)
|
||||
// HashRecoveryCode 使用 bcrypt 慢哈希恢复码(用于存储)
|
||||
// P2 安全增强:将 SHA256 快速哈希升级为 bcrypt 慢哈希,防止 GPU 暴力破解
|
||||
func HashRecoveryCode(code string) (string, error) {
|
||||
h := sha256.Sum256([]byte(code))
|
||||
return hex.EncodeToString(h[:]), nil
|
||||
normalized := strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(code), "-", ""))
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(normalized), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("hash recovery code failed: %w", err)
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// VerifyRecoveryCode 验证恢复码(自动哈希后比较)
|
||||
// sha256HexPattern 匹配 64 位十六进制字符串(旧版 SHA256 哈希格式)
|
||||
var sha256HexPattern = regexp.MustCompile("^[0-9a-fA-F]{64}$")
|
||||
|
||||
// isLegacySHA256Hash 检测是否为旧版 SHA256 哈希值
|
||||
func isLegacySHA256Hash(hash string) bool {
|
||||
return sha256HexPattern.MatchString(hash)
|
||||
}
|
||||
|
||||
// legacyHashRecoveryCode 旧版 SHA256 哈希(用于向后兼容验证)
|
||||
func legacyHashRecoveryCode(code string) string {
|
||||
h := sha256.Sum256([]byte(code))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// VerifyRecoveryCode 验证恢复码(支持 bcrypt 新哈希和 SHA256 旧哈希向后兼容)
|
||||
// 使用恒定时间比较防止时序攻击
|
||||
func VerifyRecoveryCode(inputCode string, hashedCodes []string) (int, bool) {
|
||||
hashedInput, err := HashRecoveryCode(inputCode)
|
||||
if err != nil {
|
||||
return -1, false
|
||||
}
|
||||
normalized := strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(inputCode), "-", ""))
|
||||
found := -1
|
||||
// 固定次数比较,防止时序攻击泄露匹配位置
|
||||
|
||||
for i := 0; i < len(hashedCodes); i++ {
|
||||
hashed := hashedCodes[i]
|
||||
if subtle.ConstantTimeCompare([]byte(hashedInput), []byte(hashed)) == 1 {
|
||||
stored := hashedCodes[i]
|
||||
if stored == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var matched bool
|
||||
if isLegacySHA256Hash(stored) {
|
||||
// 向后兼容:旧版 SHA256 哈希
|
||||
hashedInput := legacyHashRecoveryCode(inputCode)
|
||||
if subtle.ConstantTimeCompare([]byte(hashedInput), []byte(stored)) == 1 {
|
||||
matched = true
|
||||
}
|
||||
} else {
|
||||
// 新版 bcrypt 哈希
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(stored), []byte(normalized)); err == nil {
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
|
||||
if matched {
|
||||
found = i
|
||||
}
|
||||
}
|
||||
|
||||
if found >= 0 {
|
||||
return found, true
|
||||
}
|
||||
|
||||
@@ -112,27 +112,29 @@ func TestHashRecoveryCode(t *testing.T) {
|
||||
t.Fatal("HashRecoveryCode should return non-empty hash")
|
||||
}
|
||||
|
||||
// Same code should produce same hash
|
||||
hashed2, err := HashRecoveryCode(code)
|
||||
if err != nil {
|
||||
t.Fatalf("HashRecoveryCode second call failed: %v", err)
|
||||
// Same code should verify against its own hash (bcrypt uses random salt, so hashes differ)
|
||||
_, ok := VerifyRecoveryCode(code, []string{hashed})
|
||||
if !ok {
|
||||
t.Error("Same code should verify against its own hash")
|
||||
}
|
||||
|
||||
if hashed != hashed2 {
|
||||
t.Error("Same code should produce same hash")
|
||||
}
|
||||
|
||||
// Different codes should produce different hashes
|
||||
// Different codes should NOT verify
|
||||
hashed3, err := HashRecoveryCode("DIFFERENT-CODE")
|
||||
if err != nil {
|
||||
t.Fatalf("HashRecoveryCode for different code failed: %v", err)
|
||||
}
|
||||
|
||||
if hashed == hashed3 {
|
||||
t.Error("Different codes should produce different hashes")
|
||||
_, ok2 := VerifyRecoveryCode(code, []string{hashed3})
|
||||
if ok2 {
|
||||
t.Error("Different codes should NOT verify against each other's hash")
|
||||
}
|
||||
|
||||
t.Logf("Hashed code: %s", hashed)
|
||||
// bcrypt hash format check
|
||||
if !strings.HasPrefix(hashed, "$2a$") {
|
||||
t.Errorf("Hash should be bcrypt format, got: %s", hashed)
|
||||
}
|
||||
|
||||
t.Logf("Hashed code (bcrypt): %s", hashed)
|
||||
}
|
||||
|
||||
func TestVerifyRecoveryCode(t *testing.T) {
|
||||
|
||||
@@ -207,6 +207,16 @@ func (r *DeviceRepository) GetTrustedDevices(ctx context.Context, userID int64)
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
// CountTrustedDevices 统计用户当前信任设备数量(未过期的)
|
||||
func (r *DeviceRepository) CountTrustedDevices(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).Model(&domain.Device{}).
|
||||
Where("user_id = ? AND is_trusted = ? AND (trust_expires_at IS NULL OR trust_expires_at > ?)", userID, true, now).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ListDevicesParams 设备列表查询参数
|
||||
type ListDevicesParams struct {
|
||||
UserID int64
|
||||
|
||||
@@ -191,13 +191,20 @@ func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.R
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// maxAncestorDepth 角色祖先查询最大深度,防止循环引用导致无限循环
|
||||
const maxAncestorDepth = 20
|
||||
|
||||
// GetAncestorIDs 获取角色的所有祖先角色ID(用于权限继承)
|
||||
func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
|
||||
var ancestorIDs []int64
|
||||
currentID := roleID
|
||||
depth := 0
|
||||
|
||||
// 循环向上查找父角色,直到没有父角色为止
|
||||
// 循环向上查找父角色,直到没有父角色或达到深度上限为止
|
||||
for {
|
||||
if depth >= maxAncestorDepth {
|
||||
break
|
||||
}
|
||||
var role domain.Role
|
||||
err := r.db.WithContext(ctx).Select("id", "parent_id").First(&role, currentID).Error
|
||||
if err != nil {
|
||||
@@ -211,6 +218,7 @@ func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]in
|
||||
}
|
||||
ancestorIDs = append(ancestorIDs, *role.ParentID)
|
||||
currentID = *role.ParentID
|
||||
depth++
|
||||
}
|
||||
|
||||
return ancestorIDs, nil
|
||||
|
||||
@@ -119,15 +119,61 @@ func (r *RolePermissionRepository) GetPermissionByID(ctx context.Context, permis
|
||||
return &permission, nil
|
||||
}
|
||||
|
||||
// GetPermissionIDsByRoleIDs 根据角色ID列表批量获取权限ID
|
||||
// GetRoleAncestorIDs 递归获取角色的所有祖先角色ID(含自身)
|
||||
// 包含循环检测(最大深度 5 层)
|
||||
func (r *RolePermissionRepository) GetRoleAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
|
||||
var ancestors []int64
|
||||
visited := make(map[int64]bool)
|
||||
current := roleID
|
||||
depth := 0
|
||||
maxDepth := 5
|
||||
|
||||
for current > 0 && depth < maxDepth {
|
||||
if visited[current] {
|
||||
break // 循环检测
|
||||
}
|
||||
visited[current] = true
|
||||
ancestors = append(ancestors, current)
|
||||
|
||||
var role domain.Role
|
||||
err := r.db.WithContext(ctx).Select("parent_id").First(&role, current).Error
|
||||
if err != nil || role.ParentID == nil {
|
||||
break
|
||||
}
|
||||
current = *role.ParentID
|
||||
depth++
|
||||
}
|
||||
|
||||
return ancestors, nil
|
||||
}
|
||||
|
||||
// GetPermissionIDsByRoleIDs 根据角色ID列表批量获取权限ID(含继承的父角色权限)
|
||||
func (r *RolePermissionRepository) GetPermissionIDsByRoleIDs(ctx context.Context, roleIDs []int64) ([]int64, error) {
|
||||
if len(roleIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
// 收集所有角色ID(含继承的父角色)
|
||||
allRoleIDs := make(map[int64]bool)
|
||||
for _, roleID := range roleIDs {
|
||||
ancestors, err := r.GetRoleAncestorIDs(ctx, roleID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, id := range ancestors {
|
||||
allRoleIDs[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为 slice
|
||||
ids := make([]int64, 0, len(allRoleIDs))
|
||||
for id := range allRoleIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
var permissionIDs []int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).
|
||||
Where("role_id IN ?", roleIDs).
|
||||
Where("role_id IN ?", ids).
|
||||
Pluck("permission_id", &permissionIDs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -104,6 +104,18 @@ func (r *UserRepository) GetByPhone(ctx context.Context, phone string) (*domain.
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindByAccount 按账号查询用户(支持用户名/邮箱/手机号,P1性能优化:替代串行查询)
|
||||
func (r *UserRepository) FindByAccount(ctx context.Context, account string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("username = ? OR email = ? OR phone = ?", account, account, account).
|
||||
First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// List 获取用户列表
|
||||
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
|
||||
var users []*domain.User
|
||||
@@ -195,6 +207,21 @@ func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool,
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// FilterExistingUsernames 批量筛选已存在的用户名(P1性能优化:替代循环查询)
|
||||
func (r *UserRepository) FilterExistingUsernames(ctx context.Context, usernames []string) ([]string, error) {
|
||||
if len(usernames) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
var existing []string
|
||||
err := r.db.WithContext(ctx).Model(&domain.User{}).
|
||||
Where("username IN ?", usernames).
|
||||
Pluck("username", &existing).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
// Search 搜索用户
|
||||
func (r *UserRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
|
||||
var users []*domain.User
|
||||
|
||||
@@ -83,69 +83,89 @@ func (r *UserRoleRepository) GetRoleIDsByUserID(ctx context.Context, userID int6
|
||||
return roleIDs, nil
|
||||
}
|
||||
|
||||
// GetUserRolesAndPermissions 获取用户角色和权限(PERF-01 优化:合并为单次 JOIN 查询)
|
||||
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
|
||||
var results []struct {
|
||||
RoleID int64
|
||||
RoleName string
|
||||
RoleCode string
|
||||
RoleStatus int
|
||||
PermissionID int64
|
||||
PermissionCode string
|
||||
PermissionName string
|
||||
// getRoleAncestorIDs 递归获取角色的所有祖先角色ID(含自身)
|
||||
// 包含循环检测(最大深度 5 层)
|
||||
func (r *UserRoleRepository) getRoleAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
|
||||
var ancestors []int64
|
||||
visited := make(map[int64]bool)
|
||||
current := roleID
|
||||
depth := 0
|
||||
maxDepth := 5
|
||||
|
||||
for current > 0 && depth < maxDepth {
|
||||
if visited[current] {
|
||||
break // 循环检测
|
||||
}
|
||||
visited[current] = true
|
||||
ancestors = append(ancestors, current)
|
||||
|
||||
var role domain.Role
|
||||
err := r.db.WithContext(ctx).Select("parent_id").First(&role, current).Error
|
||||
if err != nil || role.ParentID == nil {
|
||||
break
|
||||
}
|
||||
current = *role.ParentID
|
||||
depth++
|
||||
}
|
||||
|
||||
// 使用 LEFT JOIN 一次性获取用户角色和权限
|
||||
err := r.db.WithContext(ctx).
|
||||
Raw(`
|
||||
SELECT DISTINCT r.id as role_id, r.name as role_name, r.code as role_code, r.status as role_status,
|
||||
p.id as permission_id, p.code as permission_code, p.name as permission_name
|
||||
FROM user_roles ur
|
||||
JOIN roles r ON ur.role_id = r.id
|
||||
LEFT JOIN role_permissions rp ON r.id = rp.role_id
|
||||
LEFT JOIN permissions p ON rp.permission_id = p.id
|
||||
WHERE ur.user_id = ? AND r.status = 1
|
||||
`, userID).
|
||||
Scan(&results).Error
|
||||
return ancestors, nil
|
||||
}
|
||||
|
||||
// GetUserRolesAndPermissions 获取用户角色和权限(包含继承的父角色和权限)
|
||||
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
|
||||
// 获取用户直接分配的角色ID
|
||||
var directRoleIDs []int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("user_id = ?", userID).Pluck("role_id", &directRoleIDs).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建角色和权限列表
|
||||
roleMap := make(map[int64]*domain.Role)
|
||||
permMap := make(map[int64]*domain.Permission)
|
||||
|
||||
for _, row := range results {
|
||||
if _, ok := roleMap[row.RoleID]; !ok {
|
||||
roleMap[row.RoleID] = &domain.Role{
|
||||
ID: row.RoleID,
|
||||
Name: row.RoleName,
|
||||
Code: row.RoleCode,
|
||||
Status: domain.RoleStatus(row.RoleStatus),
|
||||
}
|
||||
// 递归获取所有祖先角色ID(含自身),包含循环检测
|
||||
allRoleIDMap := make(map[int64]bool)
|
||||
for _, roleID := range directRoleIDs {
|
||||
ancestors, err := r.getRoleAncestorIDs(ctx, roleID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if row.PermissionID > 0 {
|
||||
if _, ok := permMap[row.PermissionID]; !ok {
|
||||
permMap[row.PermissionID] = &domain.Permission{
|
||||
ID: row.PermissionID,
|
||||
Code: row.PermissionCode,
|
||||
Name: row.PermissionName,
|
||||
}
|
||||
}
|
||||
for _, id := range ancestors {
|
||||
allRoleIDMap[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
roles := make([]*domain.Role, 0, len(roleMap))
|
||||
for _, role := range roleMap {
|
||||
roles = append(roles, role)
|
||||
// 转换为 slice
|
||||
allRoleIDs := make([]int64, 0, len(allRoleIDMap))
|
||||
for id := range allRoleIDMap {
|
||||
allRoleIDs = append(allRoleIDs, id)
|
||||
}
|
||||
|
||||
perms := make([]*domain.Permission, 0, len(permMap))
|
||||
for _, perm := range permMap {
|
||||
perms = append(perms, perm)
|
||||
if len(allRoleIDs) == 0 {
|
||||
return []*domain.Role{}, []*domain.Permission{}, nil
|
||||
}
|
||||
|
||||
return roles, perms, nil
|
||||
// 查询所有角色信息
|
||||
var roles []*domain.Role
|
||||
err = r.db.WithContext(ctx).Where("id IN ? AND status = ?", allRoleIDs, domain.RoleStatusEnabled).Find(&roles).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 查询所有权限ID
|
||||
var permissionIDs []int64
|
||||
err = r.db.WithContext(ctx).Model(&domain.RolePermission{}).Where("role_id IN ?", allRoleIDs).Pluck("permission_id", &permissionIDs).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 查询权限详情
|
||||
var permissions []*domain.Permission
|
||||
if len(permissionIDs) > 0 {
|
||||
err = r.db.WithContext(ctx).Where("id IN ?", permissionIDs).Find(&permissions).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return roles, permissions, nil
|
||||
}
|
||||
|
||||
// GetUserIDByRoleID 根据角色ID获取用户ID列表
|
||||
|
||||
@@ -11,48 +11,92 @@ type Validator struct {
|
||||
passwordMinLength int
|
||||
passwordRequireSpecial bool
|
||||
passwordRequireNumber bool
|
||||
// 预编译的正则表达式,避免每次调用重复编译(P1性能优化)
|
||||
emailRe *regexp.Regexp
|
||||
phoneRe *regexp.Regexp
|
||||
usernameRe *regexp.Regexp
|
||||
urlRe *regexp.Regexp
|
||||
sqlPatterns []*regexp.Regexp
|
||||
xssPatterns []*regexp.Regexp
|
||||
}
|
||||
|
||||
// NewValidator creates a validator with the configured password rules.
|
||||
func NewValidator(minLength int, requireSpecial, requireNumber bool) *Validator {
|
||||
return &Validator{
|
||||
v := &Validator{
|
||||
passwordMinLength: minLength,
|
||||
passwordRequireSpecial: requireSpecial,
|
||||
passwordRequireNumber: requireNumber,
|
||||
}
|
||||
|
||||
// 预编译常用验证正则(P1性能优化)
|
||||
v.emailRe = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
v.phoneRe = regexp.MustCompile(`^1[3-9]\d{9}$`)
|
||||
v.usernameRe = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]{3,19}$`)
|
||||
v.urlRe = regexp.MustCompile(`^https?://[a-zA-Z0-9\-._~:/?#[\]@!$&'()*+,;=]+$`)
|
||||
|
||||
// 预编译SQL注入检测正则(P1性能优化)
|
||||
sqlRawPatterns := []string{
|
||||
`;[\s]*--`,
|
||||
`/\*.*?\*/`,
|
||||
`\bxp_\w+`,
|
||||
`\bexec[\s\(]`,
|
||||
`\bsp_\w+`,
|
||||
`\bwaitfor[\s]+delay`,
|
||||
`\bunion[\s]+select`,
|
||||
`\bdrop[\s]+table`,
|
||||
`\binsert[\s]+into`,
|
||||
`\bupdate[\s]+\w+[\s]+set`,
|
||||
`\bdelete[\s]+from`,
|
||||
}
|
||||
v.sqlPatterns = make([]*regexp.Regexp, len(sqlRawPatterns))
|
||||
for i, p := range sqlRawPatterns {
|
||||
v.sqlPatterns[i] = regexp.MustCompile(`(?i)` + p)
|
||||
}
|
||||
|
||||
// 预编译XSS检测正则(P1性能优化)
|
||||
xssRawPatterns := []string{
|
||||
`(?i)<script[^>]*>.*?</script>`,
|
||||
`(?i)</script>`,
|
||||
`(?i)<iframe[^>]*>.*?</iframe>`,
|
||||
`(?i)<object[^>]*>.*?</object>`,
|
||||
`(?i)<embed[^>]*>.*?</embed>`,
|
||||
`(?i)<applet[^>]*>.*?</applet>`,
|
||||
`(?i)javascript\s*:`,
|
||||
`(?i)vbscript\s*:`,
|
||||
`(?i)data\s*:`,
|
||||
`(?i)on\w+\s*=`,
|
||||
`(?i)<style[^>]*>.*?</style>`,
|
||||
}
|
||||
v.xssPatterns = make([]*regexp.Regexp, len(xssRawPatterns))
|
||||
for i, p := range xssRawPatterns {
|
||||
v.xssPatterns[i] = regexp.MustCompile(p)
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// ValidateEmail validates email format.
|
||||
func (v *Validator) ValidateEmail(email string) bool {
|
||||
if email == "" {
|
||||
if email == "" || v.emailRe == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
|
||||
matched, _ := regexp.MatchString(pattern, email)
|
||||
return matched
|
||||
return v.emailRe.MatchString(email)
|
||||
}
|
||||
|
||||
// ValidatePhone validates mainland China mobile numbers.
|
||||
func (v *Validator) ValidatePhone(phone string) bool {
|
||||
if phone == "" {
|
||||
if phone == "" || v.phoneRe == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern := `^1[3-9]\d{9}$`
|
||||
matched, _ := regexp.MatchString(pattern, phone)
|
||||
return matched
|
||||
return v.phoneRe.MatchString(phone)
|
||||
}
|
||||
|
||||
// ValidateUsername validates usernames.
|
||||
func (v *Validator) ValidateUsername(username string) bool {
|
||||
if username == "" {
|
||||
if username == "" || v.usernameRe == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern := `^[a-zA-Z][a-zA-Z0-9_]{3,19}$`
|
||||
matched, _ := regexp.MatchString(pattern, username)
|
||||
return matched
|
||||
return v.usernameRe.MatchString(username)
|
||||
}
|
||||
|
||||
// ValidatePassword validates passwords using the shared runtime policy.
|
||||
@@ -77,27 +121,13 @@ func (v *Validator) SanitizeSQL(input string) string {
|
||||
`"`, `""`,
|
||||
)
|
||||
|
||||
// Remove common SQL injection patterns that could bypass quoting
|
||||
dangerousPatterns := []string{
|
||||
`;[\s]*--`, // SQL comment
|
||||
`/\*.*?\*/`, // Block comment (non-greedy)
|
||||
`\bxp_\w+`, // Extended stored procedures
|
||||
`\bexec[\s\(]`, // EXEC statements
|
||||
`\bsp_\w+`, // System stored procedures
|
||||
`\bwaitfor[\s]+delay`, // Time-based blind SQL injection
|
||||
`\bunion[\s]+select`, // UNION injection
|
||||
`\bdrop[\s]+table`, // DROP TABLE
|
||||
`\binsert[\s]+into`, // INSERT
|
||||
`\bupdate[\s]+\w+[\s]+set`, // UPDATE
|
||||
`\bdelete[\s]+from`, // DELETE
|
||||
}
|
||||
|
||||
result := replacer.Replace(input)
|
||||
|
||||
// Apply pattern removal
|
||||
for _, pattern := range dangerousPatterns {
|
||||
re := regexp.MustCompile(`(?i)` + pattern) // Case-insensitive
|
||||
result = re.ReplaceAllString(result, "")
|
||||
// 使用预编译的正则移除SQL注入模式(P1性能优化)
|
||||
for _, re := range v.sqlPatterns {
|
||||
if re != nil {
|
||||
result = re.ReplaceAllString(result, "")
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
@@ -106,31 +136,11 @@ func (v *Validator) SanitizeSQL(input string) string {
|
||||
// SanitizeXSS removes obviously dangerous XSS patterns using regex.
|
||||
// This is a defense-in-depth measure; output encoding should always be used.
|
||||
func (v *Validator) SanitizeXSS(input string) string {
|
||||
// Remove dangerous tags and attributes using pattern matching
|
||||
dangerousPatterns := []struct {
|
||||
pattern string
|
||||
replaceAll bool
|
||||
}{
|
||||
{`(?i)<script[^>]*>.*?</script>`, true}, // Script tags
|
||||
{`(?i)</script>`, false}, // Closing script
|
||||
{`(?i)<iframe[^>]*>.*?</iframe>`, true}, // Iframe injection
|
||||
{`(?i)<object[^>]*>.*?</object>`, true}, // Object injection
|
||||
{`(?i)<embed[^>]*>.*?</embed>`, true}, // Embed injection
|
||||
{`(?i)<applet[^>]*>.*?</applet>`, true}, // Applet injection
|
||||
{`(?i)javascript\s*:`, false}, // JavaScript protocol
|
||||
{`(?i)vbscript\s*:`, false}, // VBScript protocol
|
||||
{`(?i)data\s*:`, false}, // Data URL protocol
|
||||
{`(?i)on\w+\s*=`, false}, // Event handlers
|
||||
{`(?i)<style[^>]*>.*?</style>`, true}, // Style injection
|
||||
}
|
||||
|
||||
result := input
|
||||
|
||||
for _, p := range dangerousPatterns {
|
||||
re := regexp.MustCompile(p.pattern)
|
||||
if p.replaceAll {
|
||||
result = re.ReplaceAllString(result, "")
|
||||
} else {
|
||||
// 使用预编译的正则移除XSS模式(P1性能优化)
|
||||
for _, re := range v.xssPatterns {
|
||||
if re != nil {
|
||||
result = re.ReplaceAllString(result, "")
|
||||
}
|
||||
}
|
||||
@@ -148,13 +158,10 @@ func (v *Validator) SanitizeXSS(input string) string {
|
||||
|
||||
// ValidateURL validates a basic HTTP/HTTPS URL.
|
||||
func (v *Validator) ValidateURL(url string) bool {
|
||||
if url == "" {
|
||||
if url == "" || v.urlRe == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern := `^https?://[a-zA-Z0-9\-._~:/?#[\]@!$&'()*+,;=]+$`
|
||||
matched, _ := regexp.MatchString(pattern, url)
|
||||
return matched
|
||||
return v.urlRe.MatchString(url)
|
||||
}
|
||||
|
||||
// ValidateIP validates IPv4 or IPv6 addresses using net.ParseIP.
|
||||
|
||||
@@ -148,6 +148,9 @@ func Serve(cfg *config.Config) error {
|
||||
|
||||
// 初始化中间件
|
||||
rateLimitMiddleware := middleware.NewRateLimitMiddleware(cfg.RateLimit)
|
||||
stopRateLimitCleanup := rateLimitMiddleware.StartCleanup()
|
||||
defer stopRateLimitCleanup()
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
jwtManager,
|
||||
userRepo,
|
||||
|
||||
@@ -30,6 +30,8 @@ const (
|
||||
defaultTOTPChallengeTTL = 5 * time.Minute
|
||||
defaultPasswordMinLen = 8
|
||||
refreshTokenRetryGrace = 10 * time.Second
|
||||
MaxUsernameAttempts = 100 // 最大尝试次数(P1性能优化:减少循环查询)
|
||||
MaxUsernameLength = 40 // 用户名最大长度
|
||||
)
|
||||
|
||||
type userRepositoryInterface interface {
|
||||
@@ -38,6 +40,7 @@ type userRepositoryInterface interface {
|
||||
UpdateTOTP(ctx context.Context, user *domain.User) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
GetByID(ctx context.Context, id int64) (*domain.User, error)
|
||||
GetByIDs(ctx context.Context, ids []int64) ([]*domain.User, error)
|
||||
GetByUsername(ctx context.Context, username string) (*domain.User, error)
|
||||
GetByEmail(ctx context.Context, email string) (*domain.User, error)
|
||||
GetByPhone(ctx context.Context, phone string) (*domain.User, error)
|
||||
@@ -49,6 +52,10 @@ type userRepositoryInterface interface {
|
||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||
ExistsByPhone(ctx context.Context, phone string) (bool, error)
|
||||
Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error)
|
||||
// FilterExistingUsernames 批量筛选已存在的用户名(P1性能优化:替代循环查询)
|
||||
FilterExistingUsernames(ctx context.Context, usernames []string) ([]string, error)
|
||||
// FindByAccount 按账号查询用户(支持用户名/邮箱/手机号,P1性能优化:替代串行查询)
|
||||
FindByAccount(ctx context.Context, account string) (*domain.User, error)
|
||||
}
|
||||
|
||||
type userRoleRepositoryInterface interface {
|
||||
@@ -282,17 +289,28 @@ func (s *AuthService) generateUniqueUsername(ctx context.Context, base string) (
|
||||
}
|
||||
|
||||
baseRunes := []rune(username)
|
||||
if len(baseRunes) > 40 {
|
||||
username = string(baseRunes[:40])
|
||||
if len(baseRunes) > MaxUsernameLength {
|
||||
username = string(baseRunes[:MaxUsernameLength])
|
||||
}
|
||||
|
||||
for i := 1; i <= 1000; i++ {
|
||||
candidate := fmt.Sprintf("%s_%d", username, i)
|
||||
exists, err = s.userRepo.ExistsByUsername(ctx, candidate)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
// P1性能优化:批量生成候选列表后一次性查询,避免循环DB往返
|
||||
candidates := make([]string, 0, MaxUsernameAttempts)
|
||||
for i := 1; i <= MaxUsernameAttempts; i++ {
|
||||
candidates = append(candidates, fmt.Sprintf("%s_%d", username, i))
|
||||
}
|
||||
|
||||
existing, err := s.userRepo.FilterExistingUsernames(ctx, candidates)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
existingSet := make(map[string]bool, len(existing))
|
||||
for _, u := range existing {
|
||||
existingSet[u] = true
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if !existingSet[candidate] {
|
||||
return candidate, nil
|
||||
}
|
||||
}
|
||||
@@ -530,6 +548,11 @@ func (s *AuthService) writeLoginLog(
|
||||
|
||||
// #nosec G118 - 使用带超时的独立 context,防止日志写入无限等待
|
||||
go func() { // #nosec G118
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("auth: write login log panic recovered, user_id=%v login_type=%d err=%v", userID, loginType, r)
|
||||
}
|
||||
}()
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.loginLogRepo.Create(bgCtx, loginRecord); err != nil {
|
||||
@@ -591,7 +614,7 @@ func buildDeviceFingerprint(req *LoginRequest) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// bestEffortRegisterDevice 尝试自动注册/更新设备记录
|
||||
// bestEffortRegisterDevice 尝试自动注册/更新设备记录(异步,不阻塞登录响应)
|
||||
func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64, req *LoginRequest) {
|
||||
if s == nil || s.deviceService == nil || req == nil || req.DeviceID == "" {
|
||||
return
|
||||
@@ -603,7 +626,18 @@ func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64
|
||||
DeviceBrowser: req.DeviceBrowser,
|
||||
DeviceOS: req.DeviceOS,
|
||||
}
|
||||
_, _ = s.deviceService.CreateDevice(ctx, userID, createReq)
|
||||
|
||||
// PERF-01: 改为异步 goroutine,不阻塞登录响应返回
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("auth: register device panic recovered, user_id=%d device_id=%s err=%v", userID, req.DeviceID, r)
|
||||
}
|
||||
}()
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, _ = s.deviceService.CreateDevice(bgCtx, userID, createReq)
|
||||
}()
|
||||
}
|
||||
|
||||
// BestEffortRegisterDevicePublic 供外部 handler(如 SMS 登录)调用,安静地注册设备
|
||||
|
||||
@@ -75,21 +75,16 @@ func (s *AuthService) IsAdminBootstrapRequired(ctx context.Context) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
hadUnexpectedLookupError := false
|
||||
for _, userID := range userIDs {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if isUserNotFoundError(err) {
|
||||
continue
|
||||
}
|
||||
hadUnexpectedLookupError = true
|
||||
log.Printf("auth: resolve auth capabilities failed while loading admin user: user_id=%d err=%v", userID, err)
|
||||
continue
|
||||
}
|
||||
if user != nil && user.Status == domain.UserStatusActive {
|
||||
users, err := s.userRepo.GetByIDs(ctx, userIDs)
|
||||
if err != nil {
|
||||
log.Printf("auth: resolve auth capabilities failed while loading admin users: err=%v", err)
|
||||
return false
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.Status == domain.UserStatusActive {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return !hadUnexpectedLookupError
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -30,27 +30,15 @@ func (s *AuthService) RegisterOAuthProvider(provider auth.OAuthProvider, cfg *au
|
||||
}
|
||||
|
||||
func (s *AuthService) findUserForLogin(ctx context.Context, account string) (*domain.User, error) {
|
||||
user, err := s.userRepo.GetByUsername(ctx, account)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
// P1性能优化:使用单一查询替代 username->email->phone 串行查询,减少DB往返
|
||||
user, err := s.userRepo.FindByAccount(ctx, account)
|
||||
if err != nil {
|
||||
if isUserNotFoundError(err) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("lookup user failed: %w", err)
|
||||
}
|
||||
if !isUserNotFoundError(err) {
|
||||
return nil, fmt.Errorf("lookup user by username failed: %w", err)
|
||||
}
|
||||
|
||||
user, err = s.userRepo.GetByEmail(ctx, account)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
if !isUserNotFoundError(err) {
|
||||
return nil, fmt.Errorf("lookup user by email failed: %w", err)
|
||||
}
|
||||
|
||||
user, err = s.userRepo.GetByPhone(ctx, account)
|
||||
if err != nil && !isUserNotFoundError(err) {
|
||||
return nil, fmt.Errorf("lookup user by phone failed: %w", err)
|
||||
}
|
||||
return user, err
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func isUserNotFoundError(err error) bool {
|
||||
@@ -100,9 +88,19 @@ func (s *AuthService) bestEffortUpdateLastLogin(ctx context.Context, userID int6
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.userRepo.UpdateLastLogin(ctx, userID, ip); err != nil {
|
||||
log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
|
||||
}
|
||||
// PERF-01: 改为异步 goroutine,不阻塞登录响应返回
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("auth: update last login panic recovered, source=%s user_id=%d err=%v", source, userID, r)
|
||||
}
|
||||
}()
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.userRepo.UpdateLastLogin(bgCtx, userID, ip); err != nil {
|
||||
log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func loginAttemptKey(account string, user *domain.User) string {
|
||||
|
||||
@@ -4,14 +4,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
apierrors "github.com/user-management-system/internal/pkg/errors"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Interfaces for dependency inversion (DIP) — service layer depends on these abstractions, not concrete types.
|
||||
type deviceRepository interface {
|
||||
Create(ctx context.Context, device *domain.Device) error
|
||||
Update(ctx context.Context, device *domain.Device) error
|
||||
@@ -27,6 +29,7 @@ type deviceRepository interface {
|
||||
UntrustDevice(ctx context.Context, id int64) error
|
||||
DeleteAllByUserIDExcept(ctx context.Context, userID int64, exceptDeviceID int64) error
|
||||
GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error)
|
||||
CountTrustedDevices(ctx context.Context, userID int64) (int64, error)
|
||||
ListAll(ctx context.Context, params *repository.ListDevicesParams) ([]*domain.Device, int64, error)
|
||||
ListAllCursor(ctx context.Context, params *repository.ListDevicesParams, limit int, cursor *pagination.Cursor) ([]*domain.Device, bool, error)
|
||||
}
|
||||
@@ -35,24 +38,18 @@ type deviceUserRepository interface {
|
||||
GetByID(ctx context.Context, id int64) (*domain.User, error)
|
||||
}
|
||||
|
||||
// DeviceService 设备服务
|
||||
type DeviceService struct {
|
||||
deviceRepo deviceRepository
|
||||
userRepo deviceUserRepository
|
||||
}
|
||||
|
||||
// NewDeviceService 创建设备服务
|
||||
func NewDeviceService(
|
||||
deviceRepo deviceRepository,
|
||||
userRepo deviceUserRepository,
|
||||
) *DeviceService {
|
||||
func NewDeviceService(deviceRepo deviceRepository, userRepo deviceUserRepository) *DeviceService {
|
||||
return &DeviceService{
|
||||
deviceRepo: deviceRepo,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDeviceRequest 创建设备请求
|
||||
type CreateDeviceRequest struct {
|
||||
DeviceID string `json:"device_id" binding:"required"`
|
||||
DeviceName string `json:"device_name"`
|
||||
@@ -63,7 +60,6 @@ type CreateDeviceRequest struct {
|
||||
Location string `json:"location"`
|
||||
}
|
||||
|
||||
// UpdateDeviceRequest 更新设备请求
|
||||
type UpdateDeviceRequest struct {
|
||||
DeviceName string `json:"device_name"`
|
||||
DeviceType int `json:"device_type"`
|
||||
@@ -74,21 +70,16 @@ type UpdateDeviceRequest struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
// CreateDevice 创建设备
|
||||
func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *CreateDeviceRequest) (*domain.Device, error) {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
|
||||
// 检查设备是否已存在
|
||||
exists, err := s.deviceRepo.Exists(ctx, userID, req.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
// 设备已存在,更新最后活跃时间
|
||||
device, err := s.deviceRepo.GetByDeviceID(ctx, userID, req.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -97,7 +88,6 @@ func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *Cre
|
||||
return device, s.deviceRepo.Update(ctx, device)
|
||||
}
|
||||
|
||||
// 创建设备
|
||||
device := &domain.Device{
|
||||
UserID: userID,
|
||||
DeviceID: req.DeviceID,
|
||||
@@ -117,14 +107,47 @@ func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *Cre
|
||||
return device, nil
|
||||
}
|
||||
|
||||
// UpdateDevice 更新设备
|
||||
func (s *DeviceService) UpdateDevice(ctx context.Context, deviceID int64, req *UpdateDeviceRequest) (*domain.Device, error) {
|
||||
device, err := s.deviceRepo.GetByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return nil, errors.New("设备不存在")
|
||||
func isDeviceNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return true
|
||||
}
|
||||
|
||||
lowerErr := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
return strings.Contains(lowerErr, "record not found") ||
|
||||
strings.Contains(lowerErr, "device not found") ||
|
||||
strings.Contains(lowerErr, "not found")
|
||||
}
|
||||
|
||||
func (s *DeviceService) getDeviceByID(ctx context.Context, deviceID int64) (*domain.Device, error) {
|
||||
device, err := s.deviceRepo.GetByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
if isDeviceNotFoundError(err) {
|
||||
return nil, apierrors.NotFound("device_not_found", "device not found").WithCause(err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func (s *DeviceService) getAuthorizedDevice(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) (*domain.Device, error) {
|
||||
device, err := s.getDeviceByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !isAdmin && device.UserID != actorUserID {
|
||||
return nil, apierrors.Forbidden("device_forbidden", "permission denied")
|
||||
}
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func (s *DeviceService) persistDeviceUpdate(ctx context.Context, device *domain.Device, req *UpdateDeviceRequest) (*domain.Device, error) {
|
||||
if req == nil {
|
||||
return device, nil
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.DeviceName != "" {
|
||||
device.DeviceName = req.DeviceName
|
||||
}
|
||||
@@ -154,19 +177,48 @@ func (s *DeviceService) UpdateDevice(ctx context.Context, deviceID int64, req *U
|
||||
return device, nil
|
||||
}
|
||||
|
||||
// DeleteDevice 删除设备
|
||||
// maxTrustedDevicesPerUser 每个用户最大信任设备数量(P2 安全增强)
|
||||
const maxTrustedDevicesPerUser = 10
|
||||
|
||||
func (s *DeviceService) trustDeviceRecord(ctx context.Context, device *domain.Device, trustDuration time.Duration) error {
|
||||
// P2 安全增强:检查信任设备数量上限
|
||||
trustedCount, err := s.deviceRepo.CountTrustedDevices(ctx, device.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count trusted devices failed: %w", err)
|
||||
}
|
||||
if trustedCount >= maxTrustedDevicesPerUser {
|
||||
return fmt.Errorf("trusted device limit reached (max %d), please untrust an existing device first", maxTrustedDevicesPerUser)
|
||||
}
|
||||
|
||||
var trustExpiresAt *time.Time
|
||||
if trustDuration > 0 {
|
||||
expiresAt := time.Now().Add(trustDuration)
|
||||
trustExpiresAt = &expiresAt
|
||||
}
|
||||
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
|
||||
}
|
||||
|
||||
func (s *DeviceService) UpdateDevice(ctx context.Context, deviceID int64, req *UpdateDeviceRequest) (*domain.Device, error) {
|
||||
device, err := s.getDeviceByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.persistDeviceUpdate(ctx, device, req)
|
||||
}
|
||||
|
||||
func (s *DeviceService) DeleteDevice(ctx context.Context, deviceID int64) error {
|
||||
return s.deviceRepo.Delete(ctx, deviceID)
|
||||
device, err := s.getDeviceByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.deviceRepo.Delete(ctx, device.ID)
|
||||
}
|
||||
|
||||
// GetDevice 获取设备信息
|
||||
func (s *DeviceService) GetDevice(ctx context.Context, deviceID int64) (*domain.Device, error) {
|
||||
return s.deviceRepo.GetByID(ctx, deviceID)
|
||||
return s.getDeviceByID(ctx, deviceID)
|
||||
}
|
||||
|
||||
// GetUserDevices 获取用户设备列表
|
||||
func (s *DeviceService) GetUserDevices(ctx context.Context, userID int64, page, pageSize int) ([]*domain.Device, int64, error) {
|
||||
offset := (page - 1) * pageSize
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
@@ -174,22 +226,51 @@ func (s *DeviceService) GetUserDevices(ctx context.Context, userID int64, page,
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
return s.deviceRepo.ListByUserID(ctx, userID, offset, pageSize)
|
||||
}
|
||||
|
||||
// UpdateDeviceStatus 更新设备状态
|
||||
func (s *DeviceService) UpdateDeviceStatus(ctx context.Context, deviceID int64, status domain.DeviceStatus) error {
|
||||
return s.deviceRepo.UpdateStatus(ctx, deviceID, status)
|
||||
func (s *DeviceService) GetDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) (*domain.Device, error) {
|
||||
return s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
|
||||
}
|
||||
|
||||
func (s *DeviceService) UpdateDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool, req *UpdateDeviceRequest) (*domain.Device, error) {
|
||||
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.persistDeviceUpdate(ctx, device, req)
|
||||
}
|
||||
|
||||
func (s *DeviceService) DeleteDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) error {
|
||||
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.deviceRepo.Delete(ctx, device.ID)
|
||||
}
|
||||
|
||||
func (s *DeviceService) UpdateDeviceStatus(ctx context.Context, deviceID int64, status domain.DeviceStatus) error {
|
||||
device, err := s.getDeviceByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.deviceRepo.UpdateStatus(ctx, device.ID, status)
|
||||
}
|
||||
|
||||
func (s *DeviceService) UpdateDeviceStatusForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool, status domain.DeviceStatus) error {
|
||||
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.deviceRepo.UpdateStatus(ctx, device.ID, status)
|
||||
}
|
||||
|
||||
// UpdateLastActiveTime 更新最后活跃时间
|
||||
func (s *DeviceService) UpdateLastActiveTime(ctx context.Context, deviceID int64) error {
|
||||
return s.deviceRepo.UpdateLastActiveTime(ctx, deviceID)
|
||||
}
|
||||
|
||||
// GetActiveDevices 获取活跃设备
|
||||
func (s *DeviceService) GetActiveDevices(ctx context.Context, page, pageSize int) ([]*domain.Device, int64, error) {
|
||||
offset := (page - 1) * pageSize
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
@@ -197,74 +278,72 @@ func (s *DeviceService) GetActiveDevices(ctx context.Context, page, pageSize int
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
return s.deviceRepo.ListByStatus(ctx, domain.DeviceStatusActive, offset, pageSize)
|
||||
}
|
||||
|
||||
// TrustDevice 设置设备为信任状态
|
||||
func (s *DeviceService) TrustDevice(ctx context.Context, deviceID int64, trustDuration time.Duration) error {
|
||||
device, err := s.deviceRepo.GetByID(ctx, deviceID)
|
||||
device, err := s.getDeviceByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return errors.New("设备不存在")
|
||||
return err
|
||||
}
|
||||
|
||||
var trustExpiresAt *time.Time
|
||||
if trustDuration > 0 {
|
||||
expiresAt := time.Now().Add(trustDuration)
|
||||
trustExpiresAt = &expiresAt
|
||||
}
|
||||
|
||||
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
|
||||
return s.trustDeviceRecord(ctx, device, trustDuration)
|
||||
}
|
||||
|
||||
func (s *DeviceService) TrustDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool, trustDuration time.Duration) error {
|
||||
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.trustDeviceRecord(ctx, device, trustDuration)
|
||||
}
|
||||
|
||||
// TrustDeviceByDeviceID 根据设备标识字符串设置设备为信任状态
|
||||
func (s *DeviceService) TrustDeviceByDeviceID(ctx context.Context, userID int64, deviceID string, trustDuration time.Duration) error {
|
||||
device, err := s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return errors.New("设备不存在")
|
||||
if isDeviceNotFoundError(err) {
|
||||
return apierrors.NotFound("device_not_found", "device not found").WithCause(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var trustExpiresAt *time.Time
|
||||
if trustDuration > 0 {
|
||||
expiresAt := time.Now().Add(trustDuration)
|
||||
trustExpiresAt = &expiresAt
|
||||
}
|
||||
|
||||
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
|
||||
return s.trustDeviceRecord(ctx, device, trustDuration)
|
||||
}
|
||||
|
||||
// UntrustDevice 取消设备信任状态
|
||||
func (s *DeviceService) UntrustDevice(ctx context.Context, deviceID int64) error {
|
||||
device, err := s.deviceRepo.GetByID(ctx, deviceID)
|
||||
device, err := s.getDeviceByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return errors.New("设备不存在")
|
||||
return err
|
||||
}
|
||||
return s.deviceRepo.UntrustDevice(ctx, device.ID)
|
||||
}
|
||||
|
||||
func (s *DeviceService) UntrustDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) error {
|
||||
device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.deviceRepo.UntrustDevice(ctx, device.ID)
|
||||
}
|
||||
|
||||
// LogoutAllOtherDevices 登出所有其他设备
|
||||
func (s *DeviceService) LogoutAllOtherDevices(ctx context.Context, userID int64, currentDeviceID int64) error {
|
||||
return s.deviceRepo.DeleteAllByUserIDExcept(ctx, userID, currentDeviceID)
|
||||
}
|
||||
|
||||
// GetTrustedDevices 获取用户的信任设备列表
|
||||
func (s *DeviceService) GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
|
||||
return s.deviceRepo.GetTrustedDevices(ctx, userID)
|
||||
}
|
||||
|
||||
// GetAllDevicesRequest 获取所有设备请求参数
|
||||
type GetAllDevicesRequest struct {
|
||||
Page int `form:"page"`
|
||||
PageSize int `form:"page_size"`
|
||||
UserID int64 `form:"user_id"`
|
||||
Status *int `form:"status"` // 0-禁用, 1-激活, nil-不筛选
|
||||
Status *int `form:"status"`
|
||||
IsTrusted *bool `form:"is_trusted"`
|
||||
Keyword string `form:"keyword"`
|
||||
Cursor string `form:"cursor"` // Opaque cursor for keyset pagination
|
||||
Size int `form:"size"` // Page size when using cursor mode
|
||||
Cursor string `form:"cursor"`
|
||||
Size int `form:"size"`
|
||||
}
|
||||
|
||||
// GetAllDevices 获取所有设备(管理员用)
|
||||
func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesRequest) ([]*domain.Device, int64, error) {
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
@@ -277,7 +356,6 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
|
||||
}
|
||||
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
|
||||
params := &repository.ListDevicesParams{
|
||||
UserID: req.UserID,
|
||||
Keyword: req.Keyword,
|
||||
@@ -285,13 +363,10 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
|
||||
Limit: req.PageSize,
|
||||
}
|
||||
|
||||
// 处理状态筛选(仅当明确指定了状态时才筛选)
|
||||
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
|
||||
status := domain.DeviceStatus(*req.Status)
|
||||
params.Status = &status
|
||||
}
|
||||
|
||||
// 处理信任状态筛选
|
||||
if req.IsTrusted != nil {
|
||||
params.IsTrusted = req.IsTrusted
|
||||
}
|
||||
@@ -299,7 +374,6 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
|
||||
return s.deviceRepo.ListAll(ctx, params)
|
||||
}
|
||||
|
||||
// GetAllDevicesCursor 游标分页获取所有设备(推荐使用)
|
||||
func (s *DeviceService) GetAllDevicesCursor(ctx context.Context, req *GetAllDevicesRequest) (*CursorResult, error) {
|
||||
size := pagination.ClampPageSize(req.Size)
|
||||
if req.PageSize > 0 && req.Cursor == "" {
|
||||
@@ -342,7 +416,6 @@ func (s *DeviceService) GetAllDevicesCursor(ctx context.Context, req *GetAllDevi
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDeviceByDeviceID 根据设备标识获取设备(用于设备信任检查)
|
||||
func (s *DeviceService) GetDeviceByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
|
||||
return s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
apierrors "github.com/user-management-system/internal/pkg/errors"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/service"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
@@ -156,6 +157,104 @@ func TestDeviceService_GetDevice(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeviceService_DeviceOwnershipAuthorization(t *testing.T) {
|
||||
svc, db := setupDeviceTestEnv(t)
|
||||
ctx := context.Background()
|
||||
|
||||
owner := &domain.User{Username: "device_owner", Status: domain.UserStatusActive}
|
||||
if err := db.Create(owner).Error; err != nil {
|
||||
t.Fatalf("create owner failed: %v", err)
|
||||
}
|
||||
|
||||
actor := &domain.User{Username: "device_actor", Status: domain.UserStatusActive}
|
||||
if err := db.Create(actor).Error; err != nil {
|
||||
t.Fatalf("create actor failed: %v", err)
|
||||
}
|
||||
|
||||
device, err := svc.CreateDevice(ctx, owner.ID, &service.CreateDeviceRequest{
|
||||
DeviceID: "ownership_device",
|
||||
DeviceName: "Owner Device",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDevice failed: %v", err)
|
||||
}
|
||||
|
||||
t.Run("GetDeviceForActor forbids cross-user access", func(t *testing.T) {
|
||||
_, err := svc.GetDeviceForActor(ctx, actor.ID, device.ID, false)
|
||||
if !apierrors.IsForbidden(err) {
|
||||
t.Fatalf("expected forbidden error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateDeviceForActor forbids cross-user access", func(t *testing.T) {
|
||||
_, err := svc.UpdateDeviceForActor(ctx, actor.ID, device.ID, false, &service.UpdateDeviceRequest{
|
||||
DeviceName: "Hacked Name",
|
||||
})
|
||||
if !apierrors.IsForbidden(err) {
|
||||
t.Fatalf("expected forbidden error, got %v", err)
|
||||
}
|
||||
|
||||
current, getErr := svc.GetDevice(ctx, device.ID)
|
||||
if getErr != nil {
|
||||
t.Fatalf("GetDevice failed: %v", getErr)
|
||||
}
|
||||
if current.DeviceName != "Owner Device" {
|
||||
t.Fatalf("expected device name to remain unchanged, got %q", current.DeviceName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DeleteDeviceForActor forbids cross-user access", func(t *testing.T) {
|
||||
err := svc.DeleteDeviceForActor(ctx, actor.ID, device.ID, false)
|
||||
if !apierrors.IsForbidden(err) {
|
||||
t.Fatalf("expected forbidden error, got %v", err)
|
||||
}
|
||||
|
||||
if _, getErr := svc.GetDevice(ctx, device.ID); getErr != nil {
|
||||
t.Fatalf("expected device to remain after forbidden delete, got %v", getErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TrustDeviceForActor forbids cross-user access", func(t *testing.T) {
|
||||
err := svc.TrustDeviceForActor(ctx, actor.ID, device.ID, false, time.Hour)
|
||||
if !apierrors.IsForbidden(err) {
|
||||
t.Fatalf("expected forbidden error, got %v", err)
|
||||
}
|
||||
|
||||
current, getErr := svc.GetDevice(ctx, device.ID)
|
||||
if getErr != nil {
|
||||
t.Fatalf("GetDevice failed: %v", getErr)
|
||||
}
|
||||
if current.IsTrusted {
|
||||
t.Fatal("expected device to remain untrusted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateDeviceStatusForActor forbids cross-user access", func(t *testing.T) {
|
||||
err := svc.UpdateDeviceStatusForActor(ctx, actor.ID, device.ID, false, domain.DeviceStatusInactive)
|
||||
if !apierrors.IsForbidden(err) {
|
||||
t.Fatalf("expected forbidden error, got %v", err)
|
||||
}
|
||||
|
||||
current, getErr := svc.GetDevice(ctx, device.ID)
|
||||
if getErr != nil {
|
||||
t.Fatalf("GetDevice failed: %v", getErr)
|
||||
}
|
||||
if current.Status != domain.DeviceStatusActive {
|
||||
t.Fatalf("expected device to remain active, got %d", current.Status)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Admin can manage another users device", func(t *testing.T) {
|
||||
got, err := svc.GetDeviceForActor(ctx, actor.ID, device.ID, true)
|
||||
if err != nil {
|
||||
t.Fatalf("expected admin access, got %v", err)
|
||||
}
|
||||
if got.ID != device.ID {
|
||||
t.Fatalf("expected device id %d, got %d", device.ID, got.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeviceService_GetUserDevices(t *testing.T) {
|
||||
svc, _ := setupDeviceTestEnv(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@@ -103,32 +104,101 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
|
||||
if strings.TrimSpace(newPassword) == "" {
|
||||
return errors.New("新密码不能为空")
|
||||
}
|
||||
return s.applyNewPassword(ctx, user, newPassword)
|
||||
/*
|
||||
if err := validatePasswordStrength(newPassword, 8, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查密码历史(需要明文密码比对,必须在哈希之前)
|
||||
if s.passwordHistoryRepo != nil {
|
||||
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, userID, passwordHistoryLimit)
|
||||
if err == nil && len(histories) > 0 {
|
||||
for _, h := range histories {
|
||||
if auth.VerifyPassword(h.PasswordHash, newPassword) {
|
||||
return errors.New("新密码不能与最近5次密码相同")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算一次哈希,用于更新密码和保存历史(避免 Argon2id 重复计算的高成本)
|
||||
newHashedPassword, hashErr := auth.HashPassword(newPassword)
|
||||
if hashErr != nil {
|
||||
return errors.New("密码哈希失败")
|
||||
}
|
||||
|
||||
// 保存新密码到历史记录(异步,不阻塞密码更新)
|
||||
if s.passwordHistoryRepo != nil {
|
||||
// #nosec G118 - 使用带超时的独立 context(不能使用请求 ctx,该 goroutine 在请求完成后仍可能运行)
|
||||
go func(hashedPw string) { // #nosec G118
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
|
||||
UserID: userID,
|
||||
PasswordHash: hashedPw,
|
||||
})
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit)
|
||||
}(newHashedPassword)
|
||||
}
|
||||
|
||||
// 更新密码(使用同一哈希值)
|
||||
user.Password = newHashedPassword
|
||||
user.PasswordChangedAt = time.Now()
|
||||
return s.userRepo.Update(ctx, user)
|
||||
*/
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
// AdminResetPassword resets a user's password without requiring the old password.
|
||||
func (s *UserService) AdminResetPassword(ctx context.Context, userID int64, newPassword string) error {
|
||||
if s.userRepo == nil {
|
||||
return errors.New("user repository is not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
return s.applyNewPassword(ctx, user, newPassword)
|
||||
}
|
||||
|
||||
func (s *UserService) applyNewPassword(ctx context.Context, user *domain.User, newPassword string) error {
|
||||
if user == nil {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(newPassword) == "" {
|
||||
return errors.New("new password is required")
|
||||
}
|
||||
if err := validatePasswordStrength(newPassword, 8, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查密码历史(需要明文密码比对,必须在哈希之前)
|
||||
if s.passwordHistoryRepo != nil {
|
||||
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, userID, passwordHistoryLimit)
|
||||
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, user.ID, passwordHistoryLimit)
|
||||
if err == nil && len(histories) > 0 {
|
||||
for _, h := range histories {
|
||||
if auth.VerifyPassword(h.PasswordHash, newPassword) {
|
||||
return errors.New("新密码不能与最近5次密码相同")
|
||||
return errors.New("new password cannot reuse recent password history")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算一次哈希,用于更新密码和保存历史(避免 Argon2id 重复计算的高成本)
|
||||
newHashedPassword, hashErr := auth.HashPassword(newPassword)
|
||||
if hashErr != nil {
|
||||
return errors.New("密码哈希失败")
|
||||
return errors.New("password hashing failed")
|
||||
}
|
||||
|
||||
// 保存新密码到历史记录(异步,不阻塞密码更新)
|
||||
if s.passwordHistoryRepo != nil {
|
||||
// #nosec G118 - 使用带超时的独立 context(不能使用请求 ctx,该 goroutine 在请求完成后仍可能运行)
|
||||
go func(hashedPw string) { // #nosec G118
|
||||
go func(userID int64, hashedPw string) { // #nosec G118
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("user_service: password history save panic recovered, user_id=%d err=%v", userID, r)
|
||||
}
|
||||
}()
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
|
||||
@@ -136,16 +206,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
|
||||
PasswordHash: hashedPw,
|
||||
})
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit)
|
||||
}(newHashedPassword)
|
||||
}(user.ID, newHashedPassword)
|
||||
}
|
||||
|
||||
// 更新密码(使用同一哈希值)
|
||||
user.Password = newHashedPassword
|
||||
user.PasswordChangedAt = time.Now()
|
||||
return s.userRepo.Update(ctx, user)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (s *UserService) GetByID(ctx context.Context, id int64) (*domain.User, error) {
|
||||
return s.userRepo.GetByID(ctx, id)
|
||||
}
|
||||
@@ -357,10 +425,23 @@ func (s *UserService) AssignRoles(ctx context.Context, userID int64, roleIDs []i
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证所有角色存在(预先验证,避免在事务内做不必要的查询)
|
||||
for _, roleID := range roleIDs {
|
||||
if _, err := s.roleRepo.GetByID(ctx, roleID); err != nil {
|
||||
return fmt.Errorf("角色 %d 不存在", roleID)
|
||||
// 验证所有角色存在(批量查询消除 N+1)
|
||||
if len(roleIDs) > 0 {
|
||||
foundRoles, err := s.roleRepo.GetByIDs(ctx, roleIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("验证角色失败: %w", err)
|
||||
}
|
||||
if len(foundRoles) != len(roleIDs) {
|
||||
// 找出缺失的角色ID
|
||||
foundMap := make(map[int64]bool, len(foundRoles))
|
||||
for _, r := range foundRoles {
|
||||
foundMap[r.ID] = true
|
||||
}
|
||||
for _, id := range roleIDs {
|
||||
if !foundMap[id] {
|
||||
return fmt.Errorf("角色 %d 不存在", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -341,6 +341,44 @@ func TestUserService_ChangePassword(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserService_AdminResetPassword(t *testing.T) {
|
||||
env := setupAuthTestEnv(t)
|
||||
if env == nil {
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Admin reset password success", func(t *testing.T) {
|
||||
hashedPassword, _ := auth.HashPassword("OldPassword123!")
|
||||
user := &domain.User{
|
||||
Username: "adminresetpwd",
|
||||
Password: hashedPassword,
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
env.userSvc.Create(ctx, user)
|
||||
|
||||
err := env.userSvc.AdminResetPassword(ctx, user.ID, "ResetPassword456!")
|
||||
if err != nil {
|
||||
t.Fatalf("AdminResetPassword failed: %v", err)
|
||||
}
|
||||
|
||||
updated, _ := env.userSvc.GetByID(ctx, user.ID)
|
||||
if !auth.VerifyPassword(updated.Password, "ResetPassword456!") {
|
||||
t.Error("reset password verification failed")
|
||||
}
|
||||
if auth.VerifyPassword(updated.Password, "OldPassword123!") {
|
||||
t.Error("old password should no longer work after admin reset")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Admin reset password for non-existent user", func(t *testing.T) {
|
||||
err := env.userSvc.AdminResetPassword(ctx, 99999, "ResetPassword456!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent user")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserService_BatchUpdateStatus(t *testing.T) {
|
||||
env := setupAuthTestEnv(t)
|
||||
if env == nil {
|
||||
|
||||
Reference in New Issue
Block a user