test: add middleware tests for cache_control, security_headers, trace_id

Add comprehensive tests for three middleware components:
- cache_control: NoStoreSensitiveResponses, shouldDisableCaching
- security_headers: SecurityHeaders, shouldAttachCSP, isHTTPSRequest
- trace_id: TraceID, GetTraceID, generateTraceID

Coverage: middleware 35.7% → 36.4%
This commit is contained in:
Your Name
2026-05-29 20:11:26 +08:00
parent 17a46c2770
commit 707d35fb74
3 changed files with 425 additions and 0 deletions

View File

@@ -0,0 +1,148 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestTraceID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(TraceID())
router.GET("/test", func(c *gin.Context) {
// 返回 trace ID 供验证
traceID := GetTraceID(c)
c.String(200, traceID)
})
tests := []struct {
name string
incomingTraceID string
expectNewGenerated bool
}{
{
name: "generate new trace ID",
incomingTraceID: "",
expectNewGenerated: true,
},
{
name: "reuse incoming trace ID",
incomingTraceID: "abc123xyz",
expectNewGenerated: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
if tt.incomingTraceID != "" {
req.Header.Set(TraceIDHeader, tt.incomingTraceID)
}
router.ServeHTTP(w, req)
// 检查响应头中的 trace ID
responseTraceID := w.Header().Get(TraceIDHeader)
assert.NotEmpty(t, responseTraceID)
if tt.expectNewGenerated {
// 新生成的 trace ID 应该包含日期格式
assert.True(t, strings.Contains(responseTraceID, "-"))
// 格式: YYYYMMDD-xxxxxxxx
parts := strings.Split(responseTraceID, "-")
assert.Equal(t, 2, len(parts))
assert.Equal(t, 8, len(parts[0])) // YYYYMMDD
assert.Equal(t, 16, len(parts[1])) // hex
} else {
assert.Equal(t, tt.incomingTraceID, responseTraceID)
}
// 响应体应该包含 trace ID
body := w.Body.String()
assert.Equal(t, responseTraceID, body)
})
}
}
func TestTraceID_SetInContext(t *testing.T) {
gin.SetMode(gin.TestMode)
var capturedTraceID string
router := gin.New()
router.Use(TraceID())
router.GET("/test", func(c *gin.Context) {
capturedTraceID = GetTraceID(c)
c.String(200, "OK")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set(TraceIDHeader, "custom-trace-123")
router.ServeHTTP(w, req)
assert.Equal(t, "custom-trace-123", capturedTraceID)
}
func TestGetTraceID(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
setupContext func(*gin.Context)
expected string
}{
{
name: "trace ID exists",
setupContext: func(c *gin.Context) {
c.Set(TraceIDKey, "existing-trace")
},
expected: "existing-trace",
},
{
name: "trace ID not exists",
setupContext: func(c *gin.Context) {},
expected: "",
},
{
name: "trace ID is not string",
setupContext: func(c *gin.Context) {
c.Set(TraceIDKey, 12345)
},
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
tt.setupContext(c)
result := GetTraceID(c)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGenerateTraceID(t *testing.T) {
// 生成多个 trace ID验证格式
traceIDs := make(map[string]bool)
for i := 0; i < 100; i++ {
id := generateTraceID()
traceIDs[id] = true
// 验证格式
parts := strings.Split(id, "-")
assert.Equal(t, 2, len(parts), "trace ID should have 2 parts separated by -")
assert.Equal(t, 8, len(parts[0]), "date part should be 8 characters (YYYYMMDD)")
assert.Equal(t, 16, len(parts[1]), "random part should be 16 hex characters")
}
// 验证唯一性100个应该都不同
assert.Equal(t, 100, len(traceIDs), "generated trace IDs should be unique")
}