fix: harden handler context and rate limit isolation
This commit is contained in:
@@ -759,6 +759,15 @@ func getUserIDFromContext(c *gin.Context) (int64, bool) {
|
||||
return id, ok
|
||||
}
|
||||
|
||||
func getUsernameFromContext(c *gin.Context) (string, bool) {
|
||||
username, exists := c.Get("username")
|
||||
if !exists {
|
||||
return "", false
|
||||
}
|
||||
usernameStr, ok := username.(string)
|
||||
return usernameStr, ok
|
||||
}
|
||||
|
||||
// handleError 将 error 转换为对应的 HTTP 响应。
|
||||
// 优先识别 ApplicationError,其次通过关键词推断业务错误类型,兜底返回 500。
|
||||
func handleError(c *gin.Context, err error) {
|
||||
|
||||
95
internal/api/handler/context_guard_test.go
Normal file
95
internal/api/handler/context_guard_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestSSOHandlerAuthorize_InvalidContextTypes_ReturnsUnauthorized(t *testing.T) {
|
||||
h := &SSOHandler{}
|
||||
engine := gin.New()
|
||||
engine.GET("/authorize", func(c *gin.Context) {
|
||||
c.Set("user_id", "not-int64")
|
||||
c.Set("username", 123)
|
||||
h.Authorize(c)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/authorize?client_id=test-client&redirect_uri=https://example.com/callback&response_type=code", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandlerUserInfo_InvalidContextTypes_ReturnsUnauthorized(t *testing.T) {
|
||||
h := &SSOHandler{}
|
||||
engine := gin.New()
|
||||
engine.GET("/userinfo", func(c *gin.Context) {
|
||||
c.Set("user_id", "not-int64")
|
||||
c.Set("username", 123)
|
||||
h.UserInfo(c)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/userinfo", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookHandlerCreateWebhook_InvalidContextType_ReturnsUnauthorized(t *testing.T) {
|
||||
h := &WebhookHandler{}
|
||||
engine := gin.New()
|
||||
engine.POST("/webhooks", func(c *gin.Context) {
|
||||
c.Set("user_id", "not-int64")
|
||||
h.CreateWebhook(c)
|
||||
})
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"name": "test",
|
||||
"url": "https://example.com/webhook",
|
||||
"events": []string{"user.created"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal request: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/webhooks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookHandlerListWebhooks_InvalidContextType_ReturnsUnauthorized(t *testing.T) {
|
||||
h := &WebhookHandler{}
|
||||
engine := gin.New()
|
||||
engine.GET("/webhooks", func(c *gin.Context) {
|
||||
c.Set("user_id", "not-int64")
|
||||
h.ListWebhooks(c)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/webhooks?page=1&page_size=20", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -72,13 +72,17 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取当前登录用户(从 auth middleware 设置的 context)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
userID, ok := getUserIDFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
username, _ := c.Get("username")
|
||||
username, ok := getUsernameFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成授权码或 access token
|
||||
if req.ResponseType == "code" {
|
||||
@@ -86,8 +90,8 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
|
||||
req.ClientID,
|
||||
req.RedirectURI,
|
||||
req.Scope,
|
||||
userID.(int64),
|
||||
username.(string),
|
||||
userID,
|
||||
username,
|
||||
)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate code"})
|
||||
@@ -106,8 +110,8 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
|
||||
req.ClientID,
|
||||
req.RedirectURI,
|
||||
req.Scope,
|
||||
userID.(int64),
|
||||
username.(string),
|
||||
userID,
|
||||
username,
|
||||
)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate code"})
|
||||
@@ -312,20 +316,24 @@ type UserInfoResponse struct {
|
||||
// @Failure 500 {object} Response "服务器错误"
|
||||
// @Router /api/v1/sso/userinfo [get]
|
||||
func (h *SSOHandler) UserInfo(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
userID, ok := getUserIDFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
username, _ := c.Get("username")
|
||||
username, ok := getUsernameFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": UserInfoResponse{
|
||||
UserID: userID.(int64),
|
||||
Username: username.(string),
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -40,8 +40,11 @@ func (h *WebhookHandler) CreateWebhook(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := c.Get("user_id")
|
||||
creatorID, _ := userID.(int64)
|
||||
creatorID, ok := getUserIDFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
webhook, err := h.webhookService.CreateWebhook(c.Request.Context(), &req, creatorID)
|
||||
if err != nil {
|
||||
@@ -76,8 +79,11 @@ func (h *WebhookHandler) ListWebhooks(c *gin.Context) {
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
userID, _ := c.Get("user_id")
|
||||
creatorID, _ := userID.(int64)
|
||||
creatorID, ok := getUserIDFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
webhooks, total, err := h.webhookService.ListWebhooksPaginated(c.Request.Context(), creatorID, offset, pageSize)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -10,11 +11,20 @@ import (
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
// 使用 endpoint + subject(IP 或 user_id) 作为限流键,并对空闲条目做 TTL 清理,
|
||||
// 避免单一全局限流器误伤所有用户,也避免历史客户端条目无限增长。
|
||||
type RateLimitMiddleware struct {
|
||||
cfg config.RateLimitConfig
|
||||
limiters map[string]*SlidingWindowLimiter
|
||||
mu sync.RWMutex
|
||||
cleanupInt time.Duration
|
||||
cfg config.RateLimitConfig
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
cleanupInt time.Duration
|
||||
lastCleanup time.Time
|
||||
}
|
||||
|
||||
type limiterEntry struct {
|
||||
limiter *SlidingWindowLimiter
|
||||
window time.Duration
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// SlidingWindowLimiter 滑动窗口限流器
|
||||
@@ -43,7 +53,7 @@ func (l *SlidingWindowLimiter) Allow() bool {
|
||||
cutoff := now - l.window.Milliseconds()
|
||||
|
||||
// 清理过期请求
|
||||
var validRequests []int64
|
||||
validRequests := l.requests[:0]
|
||||
for _, t := range l.requests {
|
||||
if t > cutoff {
|
||||
validRequests = append(validRequests, t)
|
||||
@@ -63,9 +73,10 @@ func (l *SlidingWindowLimiter) Allow() bool {
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
cfg: cfg,
|
||||
limiters: make(map[string]*SlidingWindowLimiter),
|
||||
cleanupInt: 5 * time.Minute,
|
||||
cfg: cfg,
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
cleanupInt: 5 * time.Minute,
|
||||
lastCleanup: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,16 +100,18 @@ func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
|
||||
return m.limitForKey("refresh", 60, 10)
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
func (m *RateLimitMiddleware) limitForKey(scope string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
if os.Getenv("DISABLE_RATE_LIMIT") == "1" {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
|
||||
window := time.Duration(windowSeconds) * time.Second
|
||||
|
||||
return func(c *gin.Context) {
|
||||
limiterKey := m.buildLimiterKey(scope, c)
|
||||
limiter := m.getOrCreateLimiter(limiterKey, window, capacity)
|
||||
if !limiter.Allow() {
|
||||
c.JSON(429, gin.H{
|
||||
"code": 429,
|
||||
@@ -111,24 +124,60 @@ func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacit
|
||||
}
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
m.mu.RLock()
|
||||
limiter, exists := m.limiters[key]
|
||||
m.mu.RUnlock()
|
||||
func (m *RateLimitMiddleware) buildLimiterKey(scope string, c *gin.Context) string {
|
||||
if userID, ok := c.Get("user_id"); ok {
|
||||
return fmt.Sprintf("%s:user:%v", scope, userID)
|
||||
}
|
||||
return fmt.Sprintf("%s:ip:%s", scope, c.ClientIP())
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
now := time.Now()
|
||||
m.maybeCleanup(now)
|
||||
|
||||
m.mu.RLock()
|
||||
entry, exists := m.limiters[key]
|
||||
m.mu.RUnlock()
|
||||
if exists {
|
||||
return limiter
|
||||
m.mu.Lock()
|
||||
entry.lastSeen = now
|
||||
m.mu.Unlock()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if limiter, exists = m.limiters[key]; exists {
|
||||
return limiter
|
||||
if entry, exists = m.limiters[key]; exists {
|
||||
entry.lastSeen = now
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
limiter = NewSlidingWindowLimiter(window, capacity)
|
||||
m.limiters[key] = limiter
|
||||
return limiter
|
||||
entry = &limiterEntry{
|
||||
limiter: NewSlidingWindowLimiter(window, capacity),
|
||||
window: window,
|
||||
lastSeen: now,
|
||||
}
|
||||
m.limiters[key] = entry
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) maybeCleanup(now time.Time) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if now.Sub(m.lastCleanup) < m.cleanupInt {
|
||||
return
|
||||
}
|
||||
|
||||
for key, entry := range m.limiters {
|
||||
idleTTL := entry.window
|
||||
if idleTTL < m.cleanupInt {
|
||||
idleTTL = m.cleanupInt
|
||||
}
|
||||
if now.Sub(entry.lastSeen) > idleTTL {
|
||||
delete(m.limiters, key)
|
||||
}
|
||||
}
|
||||
m.lastCleanup = now
|
||||
}
|
||||
|
||||
107
internal/api/middleware/ratelimit_test.go
Normal file
107
internal/api/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func newRateLimitTestEngine(mw gin.HandlerFunc) *gin.Engine {
|
||||
engine := gin.New()
|
||||
engine.Use(mw)
|
||||
engine.GET("/ping", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
return engine
|
||||
}
|
||||
|
||||
func performRateLimitRequest(engine *gin.Engine, remoteAddr string, setup func(*http.Request)) int {
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.RemoteAddr = remoteAddr
|
||||
if setup != nil {
|
||||
setup(req)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
return w.Code
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_LoginUsesIndependentIPBuckets(t *testing.T) {
|
||||
mw := NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
engine := newRateLimitTestEngine(mw.Login())
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if code := performRateLimitRequest(engine, "1.1.1.1:1234", nil); code != http.StatusOK {
|
||||
t.Fatalf("ip1 request %d expected 200, got %d", i+1, code)
|
||||
}
|
||||
}
|
||||
if code := performRateLimitRequest(engine, "1.1.1.1:1234", nil); code != http.StatusTooManyRequests {
|
||||
t.Fatalf("ip1 sixth request expected 429, got %d", code)
|
||||
}
|
||||
|
||||
if code := performRateLimitRequest(engine, "2.2.2.2:1234", nil); code != http.StatusOK {
|
||||
t.Fatalf("independent ip should not be throttled, got %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_APIPrefersUserIDOverSharedIP(t *testing.T) {
|
||||
mw := NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
engine := gin.New()
|
||||
engine.Use(func(c *gin.Context) {
|
||||
if userID := c.GetHeader("X-Test-User-ID"); userID != "" {
|
||||
c.Set("user_id", userID)
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
engine.Use(mw.limitForKey("api-test", 60, 1))
|
||||
engine.GET("/ping", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
setupUser1 := func(req *http.Request) {
|
||||
req.Header.Set("X-Test-User-ID", "101")
|
||||
}
|
||||
setupUser2 := func(req *http.Request) {
|
||||
req.Header.Set("X-Test-User-ID", "202")
|
||||
}
|
||||
|
||||
if code := performRateLimitRequest(engine, "9.9.9.9:1234", setupUser1); code != http.StatusOK {
|
||||
t.Fatalf("user1 first request expected 200, got %d", code)
|
||||
}
|
||||
if code := performRateLimitRequest(engine, "9.9.9.9:1234", setupUser1); code != http.StatusTooManyRequests {
|
||||
t.Fatalf("user1 second request expected 429, got %d", code)
|
||||
}
|
||||
if code := performRateLimitRequest(engine, "9.9.9.9:1234", setupUser2); code != http.StatusOK {
|
||||
t.Fatalf("user2 should have independent bucket on shared ip, got %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_CleansUpIdleLimiters(t *testing.T) {
|
||||
mw := NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
mw.cleanupInt = 10 * time.Millisecond
|
||||
engine := newRateLimitTestEngine(mw.limitForKey("cleanup", 1, 2))
|
||||
|
||||
if code := performRateLimitRequest(engine, "3.3.3.3:1234", nil); code != http.StatusOK {
|
||||
t.Fatalf("seed request expected 200, got %d", code)
|
||||
}
|
||||
if got := len(mw.limiters); got != 1 {
|
||||
t.Fatalf("expected 1 limiter after seed request, got %d", got)
|
||||
}
|
||||
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
if code := performRateLimitRequest(engine, "4.4.4.4:1234", nil); code != http.StatusOK {
|
||||
t.Fatalf("cleanup trigger request expected 200, got %d", code)
|
||||
}
|
||||
|
||||
if got := len(mw.limiters); got != 1 {
|
||||
t.Fatalf("expected stale limiter to be cleaned up, got %d entries", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user