From 707d35fb7447d9eecf3d273bb0390b4cc2ea860d Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 29 May 2026 20:11:26 +0800 Subject: [PATCH] test: add middleware tests for cache_control, security_headers, trace_id MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add comprehensive tests for three middleware components: - cache_control: NoStoreSensitiveResponses, shouldDisableCaching - security_headers: SecurityHeaders, shouldAttachCSP, isHTTPSRequest - trace_id: TraceID, GetTraceID, generateTraceID Coverage: middleware 35.7% → 36.4% --- internal/api/middleware/cache_control_test.go | 117 +++++++++++++ .../api/middleware/security_headers_test.go | 160 ++++++++++++++++++ internal/api/middleware/trace_id_test.go | 148 ++++++++++++++++ 3 files changed, 425 insertions(+) create mode 100644 internal/api/middleware/cache_control_test.go create mode 100644 internal/api/middleware/security_headers_test.go create mode 100644 internal/api/middleware/trace_id_test.go diff --git a/internal/api/middleware/cache_control_test.go b/internal/api/middleware/cache_control_test.go new file mode 100644 index 0000000..05f0728 --- /dev/null +++ b/internal/api/middleware/cache_control_test.go @@ -0,0 +1,117 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestNoStoreSensitiveResponses(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + path string + fullPath string + wantNoCache bool + }{ + { + name: "auth login path", + path: "/api/v1/auth/login", + fullPath: "/api/v1/auth/login", + wantNoCache: true, + }, + { + name: "auth register path", + path: "/api/v1/auth/register", + fullPath: "/api/v1/auth/register", + wantNoCache: true, + }, + { + name: "non-auth path", + path: "/api/v1/users", + fullPath: "/api/v1/users", + wantNoCache: false, + }, + { + name: "empty fullPath uses request path", + path: "/api/v1/auth/refresh", + fullPath: "", + wantNoCache: true, + }, + { + name: "subpath of auth", + path: "/api/v1/auth/oauth/callback", + fullPath: "/api/v1/auth/oauth/callback", + wantNoCache: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router := gin.New() + router.Use(NoStoreSensitiveResponses()) + router.GET(tt.path, func(c *gin.Context) { + c.String(200, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", tt.path, nil) + router.ServeHTTP(w, req) + + if tt.wantNoCache { + assert.Equal(t, "no-store, no-cache, must-revalidate, max-age=0", w.Header().Get("Cache-Control")) + assert.Equal(t, "no-cache", w.Header().Get("Pragma")) + assert.Equal(t, "0", w.Header().Get("Expires")) + assert.Equal(t, "no-store", w.Header().Get("Surrogate-Control")) + } else { + assert.Empty(t, w.Header().Get("Cache-Control")) + assert.Empty(t, w.Header().Get("Pragma")) + } + }) + } +} + +func TestShouldDisableCaching(t *testing.T) { + tests := []struct { + name string + routePath string + requestPath string + expected bool + }{ + { + name: "auth prefix match", + routePath: "/api/v1/auth/login", + requestPath: "/api/v1/auth/login", + expected: true, + }, + { + name: "no auth prefix", + routePath: "/api/v1/users", + requestPath: "/api/v1/users", + expected: false, + }, + { + name: "empty routePath uses requestPath", + routePath: "", + requestPath: "/api/v1/auth/logout", + expected: true, + }, + { + name: "trimmed spaces", + routePath: " /api/v1/auth/login ", + requestPath: "/api/v1/auth/login", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldDisableCaching(tt.routePath, tt.requestPath) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/api/middleware/security_headers_test.go b/internal/api/middleware/security_headers_test.go new file mode 100644 index 0000000..9509995 --- /dev/null +++ b/internal/api/middleware/security_headers_test.go @@ -0,0 +1,160 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestSecurityHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(SecurityHeaders()) + router.GET("/test", func(c *gin.Context) { + c.String(200, "OK") + }) + router.GET("/swagger/index.html", func(c *gin.Context) { + c.String(200, "Swagger UI") + }) + + tests := []struct { + name string + path string + wantCSP bool + wantSTS bool // Strict-Transport-Security (only for HTTPS) + }{ + { + name: "regular API endpoint", + path: "/test", + wantCSP: true, + wantSTS: false, // HTTP request + }, + { + name: "swagger endpoint no CSP", + path: "/swagger/index.html", + wantCSP: false, + wantSTS: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", tt.path, nil) + router.ServeHTTP(w, req) + + // 基础安全头 + assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options")) + assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options")) + assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy")) + assert.Equal(t, "camera=(), microphone=(), geolocation=()", w.Header().Get("Permissions-Policy")) + assert.Equal(t, "same-origin", w.Header().Get("Cross-Origin-Opener-Policy")) + assert.Equal(t, "none", w.Header().Get("X-Permitted-Cross-Domain-Policies")) + + // CSP 头 + if tt.wantCSP { + assert.NotEmpty(t, w.Header().Get("Content-Security-Policy")) + } + }) + } +} + +func TestShouldAttachCSP(t *testing.T) { + tests := []struct { + name string + routePath string + requestPath string + expected bool + }{ + { + name: "non-swagger path", + routePath: "/api/v1/users", + requestPath: "/api/v1/users", + expected: true, + }, + { + name: "swagger path", + routePath: "/swagger/index.html", + requestPath: "/swagger/index.html", + expected: false, + }, + { + name: "swagger subpath", + routePath: "/swagger/api-docs", + requestPath: "/swagger/api-docs", + expected: false, + }, + { + name: "empty routePath uses requestPath", + routePath: "", + requestPath: "/swagger/", + expected: false, + }, + { + name: "trimmed spaces", + routePath: " /api/v1/users ", + requestPath: "/api/v1/users", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldAttachCSP(tt.routePath, tt.requestPath) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsHTTPSRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + setup func(*http.Request) + expected bool + }{ + { + name: "plain HTTP request", + setup: func(req *http.Request) {}, + expected: false, + }, + { + name: "X-Forwarded-Proto is https", + setup: func(req *http.Request) { + req.Header.Set("X-Forwarded-Proto", "https") + }, + expected: true, + }, + { + name: "X-Forwarded-Proto is HTTPS (uppercase)", + setup: func(req *http.Request) { + req.Header.Set("X-Forwarded-Proto", "HTTPS") + }, + expected: true, + }, + { + name: "X-Forwarded-Proto is http", + setup: func(req *http.Request) { + req.Header.Set("X-Forwarded-Proto", "http") + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + req, _ := http.NewRequest("GET", "/test", nil) + tt.setup(req) + c.Request = req + + result := isHTTPSRequest(c) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/api/middleware/trace_id_test.go b/internal/api/middleware/trace_id_test.go new file mode 100644 index 0000000..d5360e3 --- /dev/null +++ b/internal/api/middleware/trace_id_test.go @@ -0,0 +1,148 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestTraceID(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(TraceID()) + router.GET("/test", func(c *gin.Context) { + // 返回 trace ID 供验证 + traceID := GetTraceID(c) + c.String(200, traceID) + }) + + tests := []struct { + name string + incomingTraceID string + expectNewGenerated bool + }{ + { + name: "generate new trace ID", + incomingTraceID: "", + expectNewGenerated: true, + }, + { + name: "reuse incoming trace ID", + incomingTraceID: "abc123xyz", + expectNewGenerated: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + if tt.incomingTraceID != "" { + req.Header.Set(TraceIDHeader, tt.incomingTraceID) + } + router.ServeHTTP(w, req) + + // 检查响应头中的 trace ID + responseTraceID := w.Header().Get(TraceIDHeader) + assert.NotEmpty(t, responseTraceID) + + if tt.expectNewGenerated { + // 新生成的 trace ID 应该包含日期格式 + assert.True(t, strings.Contains(responseTraceID, "-")) + // 格式: YYYYMMDD-xxxxxxxx + parts := strings.Split(responseTraceID, "-") + assert.Equal(t, 2, len(parts)) + assert.Equal(t, 8, len(parts[0])) // YYYYMMDD + assert.Equal(t, 16, len(parts[1])) // hex + } else { + assert.Equal(t, tt.incomingTraceID, responseTraceID) + } + + // 响应体应该包含 trace ID + body := w.Body.String() + assert.Equal(t, responseTraceID, body) + }) + } +} + +func TestTraceID_SetInContext(t *testing.T) { + gin.SetMode(gin.TestMode) + + var capturedTraceID string + router := gin.New() + router.Use(TraceID()) + router.GET("/test", func(c *gin.Context) { + capturedTraceID = GetTraceID(c) + c.String(200, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set(TraceIDHeader, "custom-trace-123") + router.ServeHTTP(w, req) + + assert.Equal(t, "custom-trace-123", capturedTraceID) +} + +func TestGetTraceID(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + setupContext func(*gin.Context) + expected string + }{ + { + name: "trace ID exists", + setupContext: func(c *gin.Context) { + c.Set(TraceIDKey, "existing-trace") + }, + expected: "existing-trace", + }, + { + name: "trace ID not exists", + setupContext: func(c *gin.Context) {}, + expected: "", + }, + { + name: "trace ID is not string", + setupContext: func(c *gin.Context) { + c.Set(TraceIDKey, 12345) + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + tt.setupContext(c) + + result := GetTraceID(c) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGenerateTraceID(t *testing.T) { + // 生成多个 trace ID,验证格式 + traceIDs := make(map[string]bool) + for i := 0; i < 100; i++ { + id := generateTraceID() + traceIDs[id] = true + + // 验证格式 + parts := strings.Split(id, "-") + assert.Equal(t, 2, len(parts), "trace ID should have 2 parts separated by -") + assert.Equal(t, 8, len(parts[0]), "date part should be 8 characters (YYYYMMDD)") + assert.Equal(t, 16, len(parts[1]), "random part should be 16 hex characters") + } + + // 验证唯一性(100个应该都不同) + assert.Equal(t, 100, len(traceIDs), "generated trace IDs should be unique") +}