package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/user-management-system/internal/config"
)

// init puts gin into test mode for every test in this package.
func init() {
	gin.SetMode(gin.TestMode)
}

// newRateLimitTestEngine builds a gin engine that runs mw in front of a single
// GET /ping route, which answers 200 with the JSON body {"ok": true}.
func newRateLimitTestEngine(mw gin.HandlerFunc) *gin.Engine {
	pong := func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	}

	r := gin.New()
	r.Use(mw)
	r.GET("/ping", pong)
	return r
}

// performRateLimitRequest sends GET /ping through engine as though it came
// from remoteAddr, lets setup (when non-nil) customize the request first, and
// returns the status code the engine wrote.
func performRateLimitRequest(engine *gin.Engine, remoteAddr string, setup func(*http.Request)) int {
	rec := httptest.NewRecorder()

	r := httptest.NewRequest(http.MethodGet, "/ping", nil)
	r.RemoteAddr = remoteAddr
	if setup != nil {
		setup(r)
	}

	engine.ServeHTTP(rec, r)
	return rec.Code
}

// TestRateLimitMiddleware_LoginUsesIndependentIPBuckets checks the login
// limiter built from a zero-value config. It expects five requests from one
// address to pass, the sixth to be rejected with 429, and a request from a
// different address to still be admitted.
func TestRateLimitMiddleware_LoginUsesIndependentIPBuckets(t *testing.T) {
	const (
		busyAddr  = "1.1.1.1:1234"
		quietAddr = "2.2.2.2:1234"
	)
	engine := newRateLimitTestEngine(NewRateLimitMiddleware(config.RateLimitConfig{}).Login())

	for n := 1; n <= 5; n++ {
		code := performRateLimitRequest(engine, busyAddr, nil)
		if code != http.StatusOK {
			t.Fatalf("ip1 request %d expected 200, got %d", n, code)
		}
	}

	// One more attempt from the same address must now be throttled...
	if code := performRateLimitRequest(engine, busyAddr, nil); code != http.StatusTooManyRequests {
		t.Fatalf("ip1 sixth request expected 429, got %d", code)
	}

	// ...while a different address is judged against its own bucket.
	if code := performRateLimitRequest(engine, quietAddr, nil); code != http.StatusOK {
		t.Fatalf("independent ip should not be throttled, got %d", code)
	}
}

// TestRateLimitMiddleware_APIPrefersUserIDOverSharedIP checks that two users
// behind the same remote address are limited separately. The limiter allows
// user 101 one request, rejects the next with 429, and still lets user 202
// through from that same address.
func TestRateLimitMiddleware_APIPrefersUserIDOverSharedIP(t *testing.T) {
	const sharedAddr = "9.9.9.9:1234"
	mw := NewRateLimitMiddleware(config.RateLimitConfig{})

	engine := gin.New()
	// Stand-in for an upstream middleware: copy the test header into the
	// "user_id" context key before the limiter runs.
	engine.Use(func(c *gin.Context) {
		userID := c.GetHeader("X-Test-User-ID")
		if userID != "" {
			c.Set("user_id", userID)
		}
		c.Next()
	})
	engine.Use(mw.limitForKey("api-test", 60, 1))
	engine.GET("/ping", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	asUser := func(id string) func(*http.Request) {
		return func(req *http.Request) {
			req.Header.Set("X-Test-User-ID", id)
		}
	}
	setupUser1, setupUser2 := asUser("101"), asUser("202")

	// Steps run in order against the same engine; each depends on the
	// limiter state left by the previous ones.
	steps := []struct {
		desc  string
		setup func(*http.Request)
		want  int
	}{
		{"user1 first request expected 200", setupUser1, http.StatusOK},
		{"user1 second request expected 429", setupUser1, http.StatusTooManyRequests},
		{"user2 should have independent bucket on shared ip", setupUser2, http.StatusOK},
	}
	for _, s := range steps {
		if got := performRateLimitRequest(engine, sharedAddr, s.setup); got != s.want {
			t.Fatalf("%s, got %d", s.desc, got)
		}
	}
}

// TestRateLimitMiddleware_CleansUpIdleLimiters checks that a limiter left idle
// is dropped from mw.limiters. The cleanup is triggered by a later request
// from another address, after which only that new entry should remain.
func TestRateLimitMiddleware_CleansUpIdleLimiters(t *testing.T) {
	// NOTE(review): 1.1s is presumably just past the idle threshold implied by
	// the limitForKey arguments below — confirm against limitForKey.
	const idleWait = 1100 * time.Millisecond

	mw := NewRateLimitMiddleware(config.RateLimitConfig{})
	// Shorten the cleanup interval so the sweep becomes due during the test.
	mw.cleanupInt = 10 * time.Millisecond
	engine := newRateLimitTestEngine(mw.limitForKey("cleanup", 1, 2))

	if code := performRateLimitRequest(engine, "3.3.3.3:1234", nil); code != http.StatusOK {
		t.Fatalf("seed request expected 200, got %d", code)
	}
	// NOTE(review): mw.limiters is read without taking whatever lock guards it;
	// this is only race-free if nothing mutates the map in the background.
	if got := len(mw.limiters); got != 1 {
		t.Fatalf("expected 1 limiter after seed request, got %d", got)
	}

	time.Sleep(idleWait)

	if code := performRateLimitRequest(engine, "4.4.4.4:1234", nil); code != http.StatusOK {
		t.Fatalf("cleanup trigger request expected 200, got %d", code)
	}
	if got := len(mw.limiters); got != 1 {
		t.Fatalf("expected stale limiter to be cleaned up, got %d entries", got)
	}
}