package middleware

import (
	"bytes"
	"net/http"
	"net/http/httptest"
	"strconv"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/user-management-system/internal/config"
)

// performRateLimitedRequest issues a GET for path through router and returns
// the recorder that captured the response. The request always originates from
// 127.0.0.1:12345 and carries userID in the X-Test-User-ID header.
func performRateLimitedRequest(router *gin.Engine, path string, userID int64) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodGet, path, nil)
	req.RemoteAddr = "127.0.0.1:12345"
	req.Header.Set("X-Test-User-ID", strconv.FormatInt(userID, 10))

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	return rec
}

// performRefreshRateLimitedRequestWithCookie issues a body-less POST to
// /auth/refresh from 127.0.0.1:12345. A non-empty refreshToken is attached as
// the ums_refresh_token cookie; an empty one sends no cookie at all.
func performRefreshRateLimitedRequestWithCookie(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodPost, "/auth/refresh", nil)
	req.RemoteAddr = "127.0.0.1:12345"
	if refreshToken != "" {
		cookie := &http.Cookie{Name: "ums_refresh_token", Value: refreshToken}
		req.AddCookie(cookie)
	}

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	return rec
}

// performRefreshRateLimitedRequestWithBody issues a JSON POST to /auth/refresh
// from 127.0.0.1:12345 whose body is {"refresh_token":"<refreshToken>"}.
// The token is spliced into the JSON verbatim, without escaping.
func performRefreshRateLimitedRequestWithBody(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
	payload := bytes.NewBufferString(`{"refresh_token":"` + refreshToken + `"}`)

	req := httptest.NewRequest(http.MethodPost, "/auth/refresh", payload)
	req.RemoteAddr = "127.0.0.1:12345"
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	return rec
}

// TestRateLimitMiddleware_API_ScopesBudgetByRouteForAuthenticatedUser checks
// that, for one user, 100 requests to /users succeed, the 101st is rejected
// with 429, and a request to /roles afterwards still succeeds.
func TestRateLimitMiddleware_API_ScopesBudgetByRouteForAuthenticatedUser(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})
	router := gin.New()

	// Stand-in for authentication: copy a parseable X-Test-User-ID header
	// into the gin context under "user_id".
	router.Use(func(c *gin.Context) {
		if raw := c.GetHeader("X-Test-User-ID"); raw != "" {
			if id, err := strconv.ParseInt(raw, 10, 64); err == nil {
				c.Set("user_id", id)
			}
		}
		c.Next()
	})

	api := router.Group("")
	api.Use(rl.API())
	api.GET("/users", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})
	api.GET("/roles", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	for attempt := 1; attempt <= 100; attempt++ {
		if rec := performRateLimitedRequest(router, "/users", 1); rec.Code != http.StatusOK {
			t.Fatalf("request %d to /users returned %d, want %d", attempt, rec.Code, http.StatusOK)
		}
	}

	if rec := performRateLimitedRequest(router, "/users", 1); rec.Code != http.StatusTooManyRequests {
		t.Fatalf("overflow request to /users returned %d, want %d", rec.Code, http.StatusTooManyRequests)
	}

	if rec := performRateLimitedRequest(router, "/roles", 1); rec.Code != http.StatusOK {
		t.Fatalf("request to /roles after exhausting /users budget returned %d, want %d", rec.Code, http.StatusOK)
	}
}

// TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshCookie checks that ten
// refreshes with cookie token "refresh-token-a" pass, the eleventh gets 429,
// and a request with "refresh-token-b" from the same address still passes.
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshCookie(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})
	router := gin.New()
	router.POST("/auth/refresh", rl.Refresh(), func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	for attempt := 1; attempt <= 10; attempt++ {
		if rec := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a"); rec.Code != http.StatusOK {
			t.Fatalf("request %d for refresh-token-a returned %d, want %d", attempt, rec.Code, http.StatusOK)
		}
	}

	if rec := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a"); rec.Code != http.StatusTooManyRequests {
		t.Fatalf("overflow request for refresh-token-a returned %d, want %d", rec.Code, http.StatusTooManyRequests)
	}

	if rec := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-b"); rec.Code != http.StatusOK {
		t.Fatalf("request for refresh-token-b after exhausting refresh-token-a budget returned %d, want %d", rec.Code, http.StatusOK)
	}
}

// TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshTokenBody mirrors the
// cookie variant but supplies the refresh token in the JSON request body.
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshTokenBody(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})
	router := gin.New()
	router.POST("/auth/refresh", rl.Refresh(), func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	for attempt := 1; attempt <= 10; attempt++ {
		if rec := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a"); rec.Code != http.StatusOK {
			t.Fatalf("request %d for refresh-token-a body returned %d, want %d", attempt, rec.Code, http.StatusOK)
		}
	}

	if rec := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a"); rec.Code != http.StatusTooManyRequests {
		t.Fatalf("overflow request for refresh-token-a body returned %d, want %d", rec.Code, http.StatusTooManyRequests)
	}

	if rec := performRefreshRateLimitedRequestWithBody(router, "refresh-token-b"); rec.Code != http.StatusOK {
		t.Fatalf("request for refresh-token-b body after exhausting refresh-token-a budget returned %d, want %d", rec.Code, http.StatusOK)
	}
}

// TestExtractRefreshToken_PreservesRequestBody checks that extractRefreshToken
// returns the token from a JSON body and that the request body can still be
// read in full afterwards.
func TestExtractRefreshToken_PreservesRequestBody(t *testing.T) {
	gin.SetMode(gin.TestMode)

	req := httptest.NewRequest(
		http.MethodPost,
		"/auth/refresh",
		bytes.NewBufferString(`{"refresh_token":"refresh-token-a"}`),
	)
	req.Header.Set("Content-Type", "application/json")

	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	c.Request = req

	if got := extractRefreshToken(c); got != "refresh-token-a" {
		t.Fatalf("extractRefreshToken() = %q, want refresh-token-a", got)
	}

	// Drain whatever body the context now exposes and compare it to the
	// bytes that were originally sent.
	var remaining bytes.Buffer
	if _, err := remaining.ReadFrom(c.Request.Body); err != nil {
		t.Fatalf("re-read body failed: %v", err)
	}
	if got := remaining.String(); got != `{"refresh_token":"refresh-token-a"}` {
		t.Fatalf("request body after extraction = %q, want original JSON", got)
	}
}

// TestRateLimitMiddleware_CleanupRemovesExpiredLimiters creates a limiter with
// a one-millisecond window, backdates its only recorded request by a second,
// and expects Cleanup to drop the limiter from the map.
func TestRateLimitMiddleware_CleanupRemovesExpiredLimiters(t *testing.T) {
	const key = "login:ip:127.0.0.1"

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})

	entry := rl.getOrCreateLimiter(key, time.Millisecond, 1)
	staleStamp := time.Now().Add(-time.Second).UnixMilli()
	entry.requests = []int64{staleStamp}

	rl.Cleanup()

	if _, exists := rl.limiters[key]; exists {
		t.Fatal("expected expired limiter to be removed")
	}
}

// TestRateLimitMiddleware_ResolveLimiterKeyPrefersUserIDForAPI checks that the
// "api" scope yields a key built from method, path and the user_id stored in
// the context.
func TestRateLimitMiddleware_ResolveLimiterKeyPrefersUserIDForAPI(t *testing.T) {
	gin.SetMode(gin.TestMode)

	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	c.Request = httptest.NewRequest(http.MethodGet, "/users/1", nil)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	c.Set("user_id", int64(99))

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})

	if got := rl.resolveLimiterKey(c, "api"); got != "api:GET:/users/1:user:99" {
		t.Fatalf("resolveLimiterKey() = %q, want api:GET:/users/1:user:99", got)
	}
}

// TestSlidingWindowLimiter_EnforcesCapacityWithinWindow checks that a limiter
// with capacity two over one second admits two calls and rejects the third.
func TestSlidingWindowLimiter_EnforcesCapacityWithinWindow(t *testing.T) {
	window := NewSlidingWindowLimiter(time.Second, 2)

	admitted := []string{
		"expected first request to pass",
		"expected second request to pass",
	}
	for _, failure := range admitted {
		if !window.Allow() {
			t.Fatal(failure)
		}
	}

	if window.Allow() {
		t.Fatal("expected third request to be rejected")
	}
}

// TestRateLimitMiddleware_StartCleanupStopsSafely starts the cleanup routine
// with a 10ms interval, waits 25ms, and then calls the returned stop function.
// The test makes no assertions beyond completing.
func TestRateLimitMiddleware_StartCleanupStopsSafely(t *testing.T) {
	rl := NewRateLimitMiddleware(config.RateLimitConfig{})
	rl.cleanupInt = 10 * time.Millisecond

	halt := rl.StartCleanup()
	time.Sleep(25 * time.Millisecond)
	halt()
}

// TestRateLimitMiddleware_ResolveLimiterKeyRefreshFallsBackToIP checks that
// the "refresh" scope keys on the client IP when the request carries an empty
// JSON object and no refresh cookie.
func TestRateLimitMiddleware_ResolveLimiterKeyRefreshFallsBackToIP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	c.Request = httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString(`{}`))
	c.Request.RemoteAddr = "127.0.0.1:12345"

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})

	if got := rl.resolveLimiterKey(c, "refresh"); got != "refresh:ip:127.0.0.1" {
		t.Fatalf("resolveLimiterKey() = %q, want refresh:ip:127.0.0.1", got)
	}
}

// TestFingerprintValue_IsDeterministic checks that fingerprintValue returns
// equal results for equal inputs and different results for two distinct
// sample inputs.
func TestFingerprintValue_IsDeterministic(t *testing.T) {
	tokenA := fingerprintValue("refresh-token-a")
	tokenAAgain := fingerprintValue("refresh-token-a")
	tokenB := fingerprintValue("refresh-token-b")

	if tokenA != tokenAAgain {
		t.Fatalf("expected same input fingerprint to match: %q vs %q", tokenA, tokenAAgain)
	}
	if tokenA == tokenB {
		t.Fatalf("expected different inputs to produce different fingerprints: %q vs %q", tokenA, tokenB)
	}
}

// TestRateLimitMiddleware_RegisterAndLoginLimiters checks that /register
// admits ten requests and rejects the eleventh, and that /login admits five
// and rejects the sixth. The two routes are exercised from different remote
// ports on 127.0.0.1.
func TestRateLimitMiddleware_RegisterAndLoginLimiters(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rl := NewRateLimitMiddleware(config.RateLimitConfig{})
	router := gin.New()
	router.POST("/register", rl.Register(), func(c *gin.Context) {
		c.Status(http.StatusOK)
	})
	router.POST("/login", rl.Login(), func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	// post sends a body-less POST to path from remoteAddr and reports the
	// HTTP status code the router produced.
	post := func(path, remoteAddr string) int {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodPost, path, nil)
		req.RemoteAddr = remoteAddr
		router.ServeHTTP(rec, req)
		return rec.Code
	}

	for attempt := 1; attempt <= 10; attempt++ {
		if code := post("/register", "127.0.0.1:12345"); code != http.StatusOK {
			t.Fatalf("register request %d returned %d, want %d", attempt, code, http.StatusOK)
		}
	}
	if code := post("/register", "127.0.0.1:12345"); code != http.StatusTooManyRequests {
		t.Fatalf("register overflow returned %d, want %d", code, http.StatusTooManyRequests)
	}

	for attempt := 1; attempt <= 5; attempt++ {
		if code := post("/login", "127.0.0.1:54321"); code != http.StatusOK {
			t.Fatalf("login request %d returned %d, want %d", attempt, code, http.StatusOK)
		}
	}
	if code := post("/login", "127.0.0.1:54321"); code != http.StatusTooManyRequests {
		t.Fatalf("login overflow returned %d, want %d", code, http.StatusTooManyRequests)
	}
}