package middleware

import (
	"bytes"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/gin-gonic/gin"

	"github.com/user-management-system/internal/config"
	apierrors "github.com/user-management-system/internal/pkg/errors"
	"github.com/user-management-system/internal/security"
)

// lockedLogBuffer is an io.Writer backed by a bytes.Buffer guarded by a mutex.
//
// The Logger tests read captured log output while log writes may still be in
// flight after ServeHTTP returns. A bare bytes.Buffer is not safe for
// concurrent use, so both writes and reads go through the mutex.
type lockedLogBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends p to the underlying buffer while holding the lock.
func (b *lockedLogBuffer) Write(p []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.buf.Write(p)
}

// String returns a snapshot of everything written so far, taken under the lock.
func (b *lockedLogBuffer) String() string {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.buf.String()
}

// captureStandardLog redirects the standard library logger into a
// lockedLogBuffer for the rest of the test. It restores the previous writer
// on cleanup. This mutates process-global state, so callers must not run in
// parallel.
func captureStandardLog(t *testing.T) *lockedLogBuffer {
	t.Helper()
	captured := &lockedLogBuffer{}
	originalWriter := log.Writer()
	log.SetOutput(captured)
	t.Cleanup(func() { log.SetOutput(originalWriter) })
	return captured
}

// waitForLogOutput polls captured every 10ms until it contains want or the
// timeout elapses. It returns the last snapshot it observed either way.
func waitForLogOutput(captured *lockedLogBuffer, want string, timeout time.Duration) string {
	deadline := time.Now().Add(timeout)
	for {
		out := captured.String()
		if strings.Contains(out, want) || !time.Now().Before(deadline) {
			return out
		}
		time.Sleep(10 * time.Millisecond)
	}
}

func TestCORS_UsesConfiguredOrigins(t *testing.T) {
	gin.SetMode(gin.TestMode)
	SetCORSConfig(config.CORSConfig{
		AllowedOrigins:   []string{"https://app.example.com"},
		AllowCredentials: true,
	})
	t.Cleanup(func() {
		SetCORSConfig(config.CORSConfig{
			AllowedOrigins:   []string{"*"},
			AllowCredentials: true,
		})
	})

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil)
	c.Request.Header.Set("Origin", "https://app.example.com")
	c.Request.Header.Set("Access-Control-Request-Headers", "Authorization")

	CORS()(c)

	if recorder.Code != http.StatusNoContent {
		t.Fatalf("expected 204, got %d", recorder.Code)
	}
	if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != "https://app.example.com" {
		t.Fatalf("unexpected allow origin: %s", got)
	}
	if got := recorder.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
		t.Fatalf("expected credentials header to be 'true', got %q", got)
	}
}

func TestCORS_RejectsDisallowedOrigin(t *testing.T) {
	gin.SetMode(gin.TestMode)
	SetCORSConfig(config.CORSConfig{
		AllowedOrigins:   []string{"https://app.example.com"},
		AllowCredentials: false,
	})
	t.Cleanup(func() {
		SetCORSConfig(config.CORSConfig{
			AllowedOrigins:   []string{"*"},
			AllowCredentials: true,
		})
	})

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
	c.Request.Header.Set("Origin", "https://evil.example.com")

	CORS()(c)

	if recorder.Code != http.StatusForbidden {
		t.Fatalf("expected 403, got %d", recorder.Code)
	}
}

func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
	raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
	sanitized := sanitizeQuery(raw)
	if sanitized == "" {
		t.Fatal("expected sanitized query")
	}
	if sanitized == raw {
		t.Fatal("expected query to be sanitized")
	}
	for _, value := range []string{"abc123", "xyz", "s1"} {
		if strings.Contains(sanitized, value) {
			t.Fatalf("expected sensitive value %q to be masked in %q", value, sanitized)
		}
	}
	if sanitizeQuery("") != "" {
		t.Fatal("expected empty query to stay empty")
	}
}

func TestSecurityHeaders_AttachesExpectedHeaders(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)

	SecurityHeaders()(c)

	if got := recorder.Header().Get("X-Content-Type-Options"); got != "nosniff" {
		t.Fatalf("unexpected nosniff header: %q", got)
	}
	if got := recorder.Header().Get("X-Frame-Options"); got != "DENY" {
		t.Fatalf("unexpected frame options: %q", got)
	}
	if got := recorder.Header().Get("Content-Security-Policy"); got == "" {
		t.Fatal("expected content security policy header")
	}
	if got := recorder.Header().Get("Strict-Transport-Security"); got != "" {
		t.Fatalf("did not expect hsts header for http request, got %q", got)
	}
}

func TestSecurityHeaders_AttachesHSTSForForwardedHTTPS(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
	c.Request.Header.Set("X-Forwarded-Proto", "https")

	SecurityHeaders()(c)

	if got := recorder.Header().Get("Strict-Transport-Security"); !strings.Contains(got, "max-age=31536000") {
		t.Fatalf("expected hsts header, got %q", got)
	}
}

func TestNoStoreSensitiveResponses_AttachesExpectedHeadersToAuthRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/capabilities", nil)

	NoStoreSensitiveResponses()(c)

	if got := recorder.Header().Get("Cache-Control"); got != sensitiveNoStoreCacheControl {
		t.Fatalf("unexpected cache-control header: %q", got)
	}
	if got := recorder.Header().Get("Pragma"); got != "no-cache" {
		t.Fatalf("unexpected pragma header: %q", got)
	}
	if got := recorder.Header().Get("Expires"); got != "0" {
		t.Fatalf("unexpected expires header: %q", got)
	}
	if got := recorder.Header().Get("Surrogate-Control"); got != "no-store" {
		t.Fatalf("unexpected surrogate-control header: %q", got)
	}
}

func TestNoStoreSensitiveResponses_DoesNotAttachHeadersToNonAuthRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)

	NoStoreSensitiveResponses()(c)

	if got := recorder.Header().Get("Cache-Control"); got != "" {
		t.Fatalf("did not expect cache-control header, got %q", got)
	}
}

// ---------- TraceID middleware ----------

func TestTraceID_GeneratesAndAttachesTraceID(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)

	TraceID()(c)

	traceID := c.GetString("trace_id")
	if traceID == "" {
		t.Fatal("expected trace_id to be set")
	}
	if len(traceID) < 8 {
		t.Fatalf("trace_id should be reasonably long, got %q", traceID)
	}
	if got := recorder.Header().Get("X-Trace-ID"); got != traceID {
		t.Fatalf("expected X-Trace-ID header to match trace_id, got %q", got)
	}
}

func TestTraceID_ExtractsExistingTraceID(t *testing.T) {
	gin.SetMode(gin.TestMode)
	existingTraceID := "existing-trace-id-12345"
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
	c.Request.Header.Set("X-Trace-ID", existingTraceID)

	TraceID()(c)

	traceID := c.GetString("trace_id")
	if traceID != existingTraceID {
		t.Fatalf("expected trace_id to be extracted from header, got %q", traceID)
	}
}

func TestTraceID_GetTraceIDHandlesMissingAndPresentValue(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)

	if got := GetTraceID(c); got != "" {
		t.Fatalf("GetTraceID() = %q, want empty string", got)
	}

	c.Set(TraceIDKey, "trace-123")
	if got := GetTraceID(c); got != "trace-123" {
		t.Fatalf("GetTraceID() = %q, want trace-123", got)
	}
}

// ---------- Error handling middleware ----------

func TestErrorHandler_HandlesErrors(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
	// The returned *gin.Error is not needed; the error is recorded on c.Errors.
	_ = c.Error(errors.New("test error"))

	ErrorHandler()(c)

	if recorder.Code != http.StatusInternalServerError {
		t.Fatalf("expected status 500, got %d", recorder.Code)
	}
}

func TestErrorHandler_ApplicationErrorPreservesStatusAndReason(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	router := gin.New()
	router.Use(ErrorHandler())
	router.GET("/users", func(c *gin.Context) {
		_ = c.Error(apierrors.Forbidden("FORBIDDEN", "denied"))
	})

	req := httptest.NewRequest(http.MethodGet, "/users", nil)
	router.ServeHTTP(recorder, req)

	if recorder.Code != http.StatusForbidden {
		t.Fatalf("expected status 403, got %d", recorder.Code)
	}

	var body map[string]any
	if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
		t.Fatalf("unmarshal body failed: %v", err)
	}
	if got := body["reason"]; got != "FORBIDDEN" {
		t.Fatalf("reason = %#v, want FORBIDDEN", got)
	}
	if got := body["message"]; got != "denied" {
		t.Fatalf("message = %#v, want denied", got)
	}
}

func TestRecover_HandlesPanic(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, router := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/panic", nil)

	router.Use(Recover())
	router.GET("/panic", func(c *gin.Context) {
		panic("test panic")
	})

	router.ServeHTTP(recorder, c.Request)

	if recorder.Code != http.StatusInternalServerError {
		t.Fatalf("expected status 500 after panic, got %d", recorder.Code)
	}
}

func TestRecover_ReturnsInternalServerErrorPayload(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	router := gin.New()
	router.Use(Recover())
	router.GET("/panic", func(c *gin.Context) {
		panic("boom")
	})

	req := httptest.NewRequest(http.MethodGet, "/panic", nil)
	router.ServeHTTP(recorder, req)

	if recorder.Code != http.StatusInternalServerError {
		t.Fatalf("expected status 500 after panic, got %d", recorder.Code)
	}

	var body map[string]any
	if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
		t.Fatalf("unmarshal body failed: %v", err)
	}
	if got := body["code"]; got != float64(http.StatusInternalServerError) {
		t.Fatalf("code = %#v, want %d", got, http.StatusInternalServerError)
	}
}

func TestLogger_WritesSanitizedQueryAndErrorContext(t *testing.T) {
	gin.SetMode(gin.TestMode)
	captured := captureStandardLog(t)

	recorder := httptest.NewRecorder()
	router := gin.New()
	router.Use(TraceID())
	router.Use(Logger())
	router.GET("/users", func(c *gin.Context) {
		c.Set("user_id", int64(7))
		_ = c.Error(errors.New("boom"))
		c.Status(http.StatusAccepted)
	})

	req := httptest.NewRequest(http.MethodGet, "/users?token=secret&name=alice", nil)
	req.RemoteAddr = "203.0.113.5:1234"
	req.Header.Set("User-Agent", "logger-test")
	router.ServeHTTP(recorder, req)

	// Log lines may arrive after ServeHTTP returns, so poll (through the
	// mutex-guarded buffer) until the sanitized query line shows up.
	logOutput := waitForLogOutput(captured, "[Query] /users?name=alice&token=%2A%2A%2A", time.Second)
	if !strings.Contains(logOutput, "[API]") {
		t.Fatalf("expected API log entry, got %q", logOutput)
	}
	if !strings.Contains(logOutput, "user_id: 7") {
		t.Fatalf("expected user id in logs, got %q", logOutput)
	}
	if !strings.Contains(logOutput, "[Error]") || !strings.Contains(logOutput, "boom") {
		t.Fatalf("expected error log entry, got %q", logOutput)
	}
	if strings.Contains(logOutput, "token=secret") {
		t.Fatalf("expected sanitized query string, got %q", logOutput)
	}
}

func TestLogger_DropsMalformedQueryString(t *testing.T) {
	gin.SetMode(gin.TestMode)
	captured := captureStandardLog(t)

	recorder := httptest.NewRecorder()
	router := gin.New()
	router.Use(Logger())
	router.GET("/users", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/users?bad=%zz", nil)
	router.ServeHTTP(recorder, req)

	// Asserting absence: give any late log write a moment to land first.
	time.Sleep(25 * time.Millisecond)
	if out := captured.String(); strings.Contains(out, "[Query]") {
		t.Fatalf("expected malformed query to be skipped, got %q", out)
	}
}

func TestResponseWrapper_SkipsSSEAndBinaryResponses(t *testing.T) {
	gin.SetMode(gin.TestMode)

	testCases := []struct {
		name        string
		path        string
		contentType string
	}{
		{name: "sse", path: "/stream", contentType: "text/event-stream"},
		{name: "binary", path: "/download", contentType: "application/octet-stream"},
		{name: "swagger", path: "/swagger/index.html", contentType: ""},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			recorder := httptest.NewRecorder()
			router := gin.New()
			router.Use(ResponseWrapper())
			router.GET(tc.path, func(c *gin.Context) {
				c.Header("Content-Type", "application/json")
				c.JSON(http.StatusOK, gin.H{"ok": true})
			})

			req := httptest.NewRequest(http.MethodGet, tc.path, nil)
			if tc.contentType != "" {
				req.Header.Set("Content-Type", tc.contentType)
			}
			router.ServeHTTP(recorder, req)

			if recorder.Code != http.StatusOK {
				t.Fatalf("expected 200, got %d", recorder.Code)
			}
			if got := recorder.Body.String(); got != `{"ok":true}` {
				t.Fatalf("body = %s, want raw payload", got)
			}
		})
	}
}

func TestResponseWrapper_BufferMethodsTrackStatusAndBody(t *testing.T) {
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	wrapper := &responseWrapper{
		ResponseWriter: c.Writer,
		body:           bytes.NewBuffer(nil),
		statusCode:     http.StatusOK,
	}

	if _, err := wrapper.Write([]byte("abc")); err != nil {
		t.Fatalf("Write() error = %v", err)
	}
	if _, err := wrapper.WriteString("def"); err != nil {
		t.Fatalf("WriteString() error = %v", err)
	}
	wrapper.WriteHeader(http.StatusAccepted)

	if got := wrapper.body.String(); got != "abcdef" {
		t.Fatalf("buffered body = %q, want abcdef", got)
	}
	if wrapper.statusCode != http.StatusAccepted {
		t.Fatalf("statusCode = %d, want %d", wrapper.statusCode, http.StatusAccepted)
	}
}

func TestIPFilter_RealIPAndInternalOnly(t *testing.T) {
	gin.SetMode(gin.TestMode)

	filter := security.NewIPFilter()
	middleware := NewIPFilterMiddleware(filter, IPFilterConfig{
		TrustProxy:     true,
		TrustedProxies: []string{"10.0.0.2"},
	})

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
	c.Request.RemoteAddr = "10.0.0.2:8080"
	c.Request.Header.Set("X-Forwarded-For", "198.51.100.10, 10.0.0.2")

	if got := middleware.realIP(c); got != "198.51.100.10" {
		t.Fatalf("realIP() = %q, want 198.51.100.10", got)
	}
	if !middleware.isTrustedProxy("10.0.0.2") {
		t.Fatal("expected trusted proxy match")
	}
	if middleware.isTrustedProxy("10.0.0.3") {
		t.Fatal("unexpected trusted proxy match")
	}
	if !isPrivateIP("127.0.0.1") {
		t.Fatal("expected loopback to be private")
	}
	if isPrivateIP("198.51.100.10") {
		t.Fatal("expected public address to be non-private")
	}

	allowed := httptest.NewRecorder()
	allowedRouter := gin.New()
	allowedRouter.Use(InternalOnly())
	allowedRouter.GET("/metrics", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})
	allowedReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)
	allowedReq.RemoteAddr = "127.0.0.1:12345"
	allowedRouter.ServeHTTP(allowed, allowedReq)
	if allowed.Code != http.StatusOK {
		t.Fatalf("expected private IP to pass, got %d", allowed.Code)
	}

	blocked := httptest.NewRecorder()
	blockedRouter := gin.New()
	blockedRouter.Use(InternalOnly())
	blockedRouter.GET("/metrics", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})
	blockedReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)
	blockedReq.RemoteAddr = "198.51.100.10:12345"
	blockedRouter.ServeHTTP(blocked, blockedReq)
	if blocked.Code != http.StatusForbidden {
		t.Fatalf("expected public IP to be rejected, got %d", blocked.Code)
	}
}

func TestIPFilter_FilterAndFallbacks(t *testing.T) {
	gin.SetMode(gin.TestMode)

	filter := security.NewIPFilter()
	if err := filter.AddToBlacklist("198.51.100.10", "manual", time.Minute); err != nil {
		t.Fatalf("AddToBlacklist() error = %v", err)
	}

	middleware := NewIPFilterMiddleware(filter, IPFilterConfig{})
	if middleware.GetFilter() != filter {
		t.Fatal("expected GetFilter() to expose the original filter")
	}

	blockedRecorder := httptest.NewRecorder()
	blockedRouter := gin.New()
	blockedRouter.Use(middleware.Filter())
	blockedRouter.GET("/protected", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})
	blockedReq := httptest.NewRequest(http.MethodGet, "/protected", nil)
	blockedReq.RemoteAddr = "198.51.100.10:12345"
	blockedRouter.ServeHTTP(blockedRecorder, blockedReq)
	if blockedRecorder.Code != http.StatusForbidden {
		t.Fatalf("expected blocked IP to be rejected, got %d", blockedRecorder.Code)
	}

	allowedRecorder := httptest.NewRecorder()
	allowedRouter := gin.New()
	allowedRouter.Use(middleware.Filter())
	allowedRouter.GET("/protected", func(c *gin.Context) {
		if got := c.GetString("client_ip"); got != "127.0.0.1" {
			t.Fatalf("client_ip = %q, want 127.0.0.1", got)
		}
		c.Status(http.StatusOK)
	})
	allowedReq := httptest.NewRequest(http.MethodGet, "/protected", nil)
	allowedReq.RemoteAddr = "127.0.0.1:54321"
	allowedRouter.ServeHTTP(allowedRecorder, allowedReq)
	if allowedRecorder.Code != http.StatusOK {
		t.Fatalf("expected allowed IP to pass, got %d", allowedRecorder.Code)
	}

	trustedProxyMiddleware := NewIPFilterMiddleware(filter, IPFilterConfig{
		TrustProxy: true,
	})
	proxyRecorder := httptest.NewRecorder()
	proxyCtx, _ := gin.CreateTestContext(proxyRecorder)
	proxyCtx.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
	proxyCtx.Request.RemoteAddr = "10.0.0.2:8080"
	proxyCtx.Request.Header.Set("X-Real-IP", "203.0.113.9")
	if got := trustedProxyMiddleware.realIP(proxyCtx); got != "203.0.113.9" {
		t.Fatalf("realIP() X-Real-IP fallback = %q, want 203.0.113.9", got)
	}
}