fix: harden handler context and rate limit isolation

This commit is contained in:
Your Name
2026-05-28 20:30:24 +08:00
parent e46567678f
commit caad1aba0c
6 changed files with 311 additions and 37 deletions

View File

@@ -759,6 +759,15 @@ func getUserIDFromContext(c *gin.Context) (int64, bool) {
return id, ok
}
func getUsernameFromContext(c *gin.Context) (string, bool) {
username, exists := c.Get("username")
if !exists {
return "", false
}
usernameStr, ok := username.(string)
return usernameStr, ok
}
// handleError 将 error 转换为对应的 HTTP 响应。
// 优先识别 ApplicationError其次通过关键词推断业务错误类型兜底返回 500。
func handleError(c *gin.Context, err error) {

View File

@@ -0,0 +1,95 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestSSOHandlerAuthorize_InvalidContextTypes_ReturnsUnauthorized(t *testing.T) {
h := &SSOHandler{}
engine := gin.New()
engine.GET("/authorize", func(c *gin.Context) {
c.Set("user_id", "not-int64")
c.Set("username", 123)
h.Authorize(c)
})
req := httptest.NewRequest(http.MethodGet, "/authorize?client_id=test-client&redirect_uri=https://example.com/callback&response_type=code", nil)
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestSSOHandlerUserInfo_InvalidContextTypes_ReturnsUnauthorized(t *testing.T) {
h := &SSOHandler{}
engine := gin.New()
engine.GET("/userinfo", func(c *gin.Context) {
c.Set("user_id", "not-int64")
c.Set("username", 123)
h.UserInfo(c)
})
req := httptest.NewRequest(http.MethodGet, "/userinfo", nil)
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestWebhookHandlerCreateWebhook_InvalidContextType_ReturnsUnauthorized(t *testing.T) {
h := &WebhookHandler{}
engine := gin.New()
engine.POST("/webhooks", func(c *gin.Context) {
c.Set("user_id", "not-int64")
h.CreateWebhook(c)
})
body, err := json.Marshal(map[string]any{
"name": "test",
"url": "https://example.com/webhook",
"events": []string{"user.created"},
})
if err != nil {
t.Fatalf("marshal request: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/webhooks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestWebhookHandlerListWebhooks_InvalidContextType_ReturnsUnauthorized(t *testing.T) {
h := &WebhookHandler{}
engine := gin.New()
engine.GET("/webhooks", func(c *gin.Context) {
c.Set("user_id", "not-int64")
h.ListWebhooks(c)
})
req := httptest.NewRequest(http.MethodGet, "/webhooks?page=1&page_size=20", nil)
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}

View File

@@ -72,13 +72,17 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
}
// 获取当前登录用户(从 auth middleware 设置的 context
userID, exists := c.Get("user_id")
if !exists {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return
}
username, _ := c.Get("username")
username, ok := getUsernameFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return
}
// 生成授权码或 access token
if req.ResponseType == "code" {
@@ -86,8 +90,8 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
req.ClientID,
req.RedirectURI,
req.Scope,
userID.(int64),
username.(string),
userID,
username,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate code"})
@@ -106,8 +110,8 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
req.ClientID,
req.RedirectURI,
req.Scope,
userID.(int64),
username.(string),
userID,
username,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate code"})
@@ -312,20 +316,24 @@ type UserInfoResponse struct {
// @Failure 500 {object} Response "服务器错误"
// @Router /api/v1/sso/userinfo [get]
func (h *SSOHandler) UserInfo(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return
}
username, _ := c.Get("username")
username, ok := getUsernameFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "success",
"data": UserInfoResponse{
UserID: userID.(int64),
Username: username.(string),
UserID: userID,
Username: username,
},
})
}

View File

@@ -40,8 +40,11 @@ func (h *WebhookHandler) CreateWebhook(c *gin.Context) {
return
}
userID, _ := c.Get("user_id")
creatorID, _ := userID.(int64)
creatorID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return
}
webhook, err := h.webhookService.CreateWebhook(c.Request.Context(), &req, creatorID)
if err != nil {
@@ -76,8 +79,11 @@ func (h *WebhookHandler) ListWebhooks(c *gin.Context) {
}
offset := (page - 1) * pageSize
userID, _ := c.Get("user_id")
creatorID, _ := userID.(int64)
creatorID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return
}
webhooks, total, err := h.webhookService.ListWebhooksPaginated(c.Request.Context(), creatorID, offset, pageSize)
if err != nil {