fix: enforce resource ownership checks

This commit is contained in:
Your Name
2026-05-28 17:28:08 +08:00
parent 7eb5f9c7d4
commit 11232177d9
4 changed files with 209 additions and 22 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/gin-gonic/gin"
apimiddleware "github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service"
)
@@ -118,9 +119,8 @@ func (h *DeviceHandler) GetDevice(c *gin.Context) {
return
}
device, err := h.deviceService.GetDevice(c.Request.Context(), id)
if err != nil {
handleError(c, err)
device, ok := h.authorizeDeviceAccess(c, id)
if !ok {
return
}
@@ -151,6 +151,10 @@ func (h *DeviceHandler) UpdateDevice(c *gin.Context) {
return
}
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
return
}
var req service.UpdateDeviceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
@@ -187,6 +191,10 @@ func (h *DeviceHandler) DeleteDevice(c *gin.Context) {
return
}
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
return
}
if err := h.deviceService.DeleteDevice(c.Request.Context(), id); err != nil {
handleError(c, err)
return
@@ -218,6 +226,10 @@ func (h *DeviceHandler) UpdateDeviceStatus(c *gin.Context) {
return
}
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
return
}
var req struct {
Status string `json:"status" binding:"required"`
}
@@ -269,27 +281,14 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
return
}
// 检查是否为管理员
roleCodes, _ := c.Get("role_codes")
isAdmin := false
if roles, ok := roleCodes.([]string); ok {
for _, role := range roles {
if role == "admin" {
isAdmin = true
break
}
}
}
userIDParam := c.Param("id")
userID, err := strconv.ParseInt(userIDParam, 10, 64)
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid user id"})
return
}
// 非管理员只能查看自己的设备
if !isAdmin && userID != currentUserID {
if !apimiddleware.IsAdmin(c) && userID != currentUserID {
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "无权访问该用户的设备列表"})
return
}
@@ -396,6 +395,10 @@ func (h *DeviceHandler) TrustDevice(c *gin.Context) {
return
}
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
return
}
var req TrustDeviceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
@@ -478,6 +481,10 @@ func (h *DeviceHandler) UntrustDevice(c *gin.Context) {
return
}
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
return
}
if err := h.deviceService.UntrustDevice(c.Request.Context(), id); err != nil {
handleError(c, err)
return
@@ -555,6 +562,27 @@ func (h *DeviceHandler) LogoutAllOtherDevices(c *gin.Context) {
})
}
func (h *DeviceHandler) authorizeDeviceAccess(c *gin.Context, deviceID int64) (*domain.Device, bool) {
currentUserID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return nil, false
}
device, err := h.deviceService.GetDevice(c.Request.Context(), deviceID)
if err != nil {
handleError(c, err)
return nil, false
}
if device.UserID != currentUserID && !apimiddleware.IsAdmin(c) {
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
return nil, false
}
return device, true
}
// parseDuration 解析duration字符串如 "30d" -> 30天的time.Duration
func parseDuration(s string) time.Duration {
if s == "" {

View File

@@ -118,6 +118,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
deviceSvc := service.NewDeviceService(deviceRepo, userRepo)
loginLogSvc := service.NewLoginLogService(loginLogRepo)
opLogSvc := service.NewOperationLogService(opLogRepo)
webhookSvc := service.NewWebhookService(db)
captchaSvc := service.NewCaptchaService(cacheManager)
totpSvc := service.NewTOTPService(userRepo)
pwdResetCfg := service.DefaultPasswordResetConfig()
@@ -141,6 +142,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
permHandler := handler.NewPermissionHandler(permSvc)
deviceHandler := handler.NewDeviceHandler(deviceSvc)
logHandler := handler.NewLogHandler(loginLogSvc, opLogSvc)
webhookHandler := handler.NewWebhookHandler(webhookSvc)
captchaHandler := handler.NewCaptchaHandler(captchaSvc)
totpHandler := handler.NewTOTPHandler(authSvc, totpSvc)
pwdResetHandler := handler.NewPasswordResetHandler(pwdResetSvc)
@@ -149,7 +151,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
r := router.NewRouter(
authHandler, userHandler, roleHandler, permHandler, deviceHandler,
logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware,
pwdResetHandler, captchaHandler, totpHandler, nil,
pwdResetHandler, captchaHandler, totpHandler, webhookHandler,
nil, nil, nil, nil, nil, themeHandler, nil, nil, nil, avatarH,
)
engine := r.Setup()
@@ -233,6 +235,62 @@ func registerUser(baseURL, username, email, password string) bool {
return resp.StatusCode == http.StatusCreated
}
func createDeviceAndGetID(t *testing.T, baseURL, token, deviceID string) int64 {
t.Helper()
resp, body := doPost(baseURL+"/api/v1/devices", token, map[string]interface{}{
"device_id": deviceID,
"device_name": "Owned Device",
"device_type": 3,
"device_os": "Linux",
"device_browser": "Chrome",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
t.Fatalf("create device failed: status=%d body=%s", resp.StatusCode, body)
}
var result struct {
Data struct {
ID int64 `json:"id"`
} `json:"data"`
}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("decode create device response failed: %v body=%s", err, body)
}
if result.Data.ID == 0 {
t.Fatalf("expected non-zero device id, body=%s", body)
}
return result.Data.ID
}
func createWebhookAndGetID(t *testing.T, baseURL, token, name string) int64 {
t.Helper()
resp, body := doPost(baseURL+"/api/v1/webhooks", token, map[string]interface{}{
"name": name,
"url": "https://example.com/webhook",
"events": []string{"user.created"},
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
t.Fatalf("create webhook failed: status=%d body=%s", resp.StatusCode, body)
}
var result struct {
Data struct {
ID int64 `json:"id"`
} `json:"data"`
}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("decode create webhook response failed: %v body=%s", err, body)
}
if result.Data.ID == 0 {
t.Fatalf("expected non-zero webhook id, body=%s", body)
}
return result.Data.ID
}
func bootstrapAdminToken(baseURL, username, email, password string) string {
payload, _ := json.Marshal(map[string]interface{}{
"username": username,
@@ -876,6 +934,73 @@ func TestDeviceHandler_CreateDevice_Success(t *testing.T) {
}
}
func TestDeviceHandler_DeviceByIDRoutes_ForbiddenForOtherUser(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "device-owner", "device-owner@test.com", "UserPass123!")
registerUser(server.URL, "device-attacker", "device-attacker@test.com", "UserPass123!")
ownerToken := getToken(server.URL, "device-owner", "UserPass123!")
attackerToken := getToken(server.URL, "device-attacker", "UserPass123!")
deviceID := createDeviceAndGetID(t, server.URL, ownerToken, "device-owner-001")
tests := []struct {
name string
method string
url string
body map[string]interface{}
}{
{name: "get", method: http.MethodGet, url: fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID)},
{name: "update", method: http.MethodPut, url: fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), body: map[string]interface{}{"device_name": "hijacked"}},
{name: "delete", method: http.MethodDelete, url: fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID)},
{name: "status", method: http.MethodPut, url: fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), body: map[string]interface{}{"status": "inactive"}},
{name: "trust", method: http.MethodPost, url: fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), body: map[string]interface{}{"trust_duration": "30d"}},
{name: "untrust", method: http.MethodDelete, url: fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID)},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
resp, body := doRequest(tc.method, tc.url, attackerToken, tc.body)
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected 403 for %s, got %d body=%s", tc.name, resp.StatusCode, body)
}
})
}
}
func TestWebhookHandler_OtherUserCannotManageForeignWebhook(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "webhook-owner", "webhook-owner@test.com", "UserPass123!")
registerUser(server.URL, "webhook-attacker", "webhook-attacker@test.com", "UserPass123!")
ownerToken := getToken(server.URL, "webhook-owner", "UserPass123!")
attackerToken := getToken(server.URL, "webhook-attacker", "UserPass123!")
webhookID := createWebhookAndGetID(t, server.URL, ownerToken, "owner-webhook")
tests := []struct {
name string
method string
url string
body map[string]interface{}
}{
{name: "update", method: http.MethodPut, url: fmt.Sprintf("%s/api/v1/webhooks/%d", server.URL, webhookID), body: map[string]interface{}{"name": "hijacked"}},
{name: "delete", method: http.MethodDelete, url: fmt.Sprintf("%s/api/v1/webhooks/%d", server.URL, webhookID)},
{name: "deliveries", method: http.MethodGet, url: fmt.Sprintf("%s/api/v1/webhooks/%d/deliveries", server.URL, webhookID)},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
resp, body := doRequest(tc.method, tc.url, attackerToken, tc.body)
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected 403 for webhook %s, got %d body=%s", tc.name, resp.StatusCode, body)
}
})
}
}
// =============================================================================
// Role Handler Tests
// =============================================================================

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
apimiddleware "github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/service"
)
@@ -117,6 +118,10 @@ func (h *WebhookHandler) UpdateWebhook(c *gin.Context) {
return
}
if _, ok := h.authorizeWebhookAccess(c, id); !ok {
return
}
var req service.UpdateWebhookRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
@@ -150,6 +155,10 @@ func (h *WebhookHandler) DeleteWebhook(c *gin.Context) {
return
}
if _, ok := h.authorizeWebhookAccess(c, id); !ok {
return
}
if err := h.webhookService.DeleteWebhook(c.Request.Context(), id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "删除 Webhook 失败"})
return
@@ -178,6 +187,10 @@ func (h *WebhookHandler) GetWebhookDeliveries(c *gin.Context) {
return
}
if _, ok := h.authorizeWebhookAccess(c, id); !ok {
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
if limit < 1 || limit > 100 {
limit = 20
@@ -191,3 +204,24 @@ func (h *WebhookHandler) GetWebhookDeliveries(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "success", "data": gin.H{"deliveries": deliveries}})
}
func (h *WebhookHandler) authorizeWebhookAccess(c *gin.Context, webhookID int64) (int64, bool) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return 0, false
}
webhook, err := h.webhookService.GetWebhook(c.Request.Context(), webhookID)
if err != nil {
handleError(c, err)
return 0, false
}
if webhook.CreatedBy != userID && !apimiddleware.IsAdmin(c) {
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
return 0, false
}
return userID, true
}

View File

@@ -359,9 +359,9 @@ func TestWebhookHandler_DeleteWebhook_NotFound(t *testing.T) {
resp := doRequestWithCheck(t, "DELETE", server.URL+"/api/v1/webhooks/99999", token, nil)
defer resp.Body.Close()
// Delete is idempotent - returns 200 even if not found
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.StatusCode)
// 先做归属/存在性校验,不存在的 webhook 返回 404
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected status 404, got %d", resp.StatusCode)
}
}