package middleware

import (
	"encoding/hex"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

// assertGeneratedTraceIDFormat checks that id has the shape of a freshly
// generated trace ID: "YYYYMMDD-<16 hex characters>".
//
// The split parts are only indexed after the length assertion succeeds, so a
// malformed ID (e.g. an empty string) fails the test instead of panicking
// with an index-out-of-range error.
func assertGeneratedTraceIDFormat(t *testing.T, id string) {
	t.Helper()

	parts := strings.Split(id, "-")
	if !assert.Len(t, parts, 2, "trace ID should have 2 parts separated by -: %q", id) {
		return
	}

	datePart, randomPart := parts[0], parts[1]

	assert.Len(t, datePart, 8, "date part should be 8 characters (YYYYMMDD)")
	_, err := time.Parse("20060102", datePart)
	assert.NoError(t, err, "date part should be a valid YYYYMMDD date: %q", datePart)

	assert.Len(t, randomPart, 16, "random part should be 16 hex characters")
	_, err = hex.DecodeString(randomPart)
	assert.NoError(t, err, "random part should be hex-encoded: %q", randomPart)
}

// TestTraceID verifies that the TraceID middleware either generates a new
// trace ID or reuses the one supplied in the TraceIDHeader request header,
// writes it to the response header, and exposes the same value to handlers
// through GetTraceID.
func TestTraceID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(TraceID())
	router.GET("/test", func(c *gin.Context) {
		// Echo the trace ID seen by the handler so the test can compare it
		// with the one returned in the response header.
		c.String(http.StatusOK, GetTraceID(c))
	})

	tests := []struct {
		name               string
		incomingTraceID    string
		expectNewGenerated bool
	}{
		{
			name:               "generate new trace ID",
			incomingTraceID:    "",
			expectNewGenerated: true,
		},
		{
			name:               "reuse incoming trace ID",
			incomingTraceID:    "abc123xyz",
			expectNewGenerated: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			if tt.incomingTraceID != "" {
				req.Header.Set(TraceIDHeader, tt.incomingTraceID)
			}

			router.ServeHTTP(w, req)

			assert.Equal(t, http.StatusOK, w.Code)

			// The trace ID must always be present in the response header.
			responseTraceID := w.Header().Get(TraceIDHeader)
			assert.NotEmpty(t, responseTraceID)

			if tt.expectNewGenerated {
				assertGeneratedTraceIDFormat(t, responseTraceID)
			} else {
				assert.Equal(t, tt.incomingTraceID, responseTraceID)
			}

			// The handler must have seen the same trace ID as the header.
			assert.Equal(t, responseTraceID, w.Body.String())
		})
	}
}

// TestTraceID_SetInContext verifies that an incoming trace ID is stored in
// the gin context and is readable by downstream handlers via GetTraceID.
func TestTraceID_SetInContext(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var capturedTraceID string

	router := gin.New()
	router.Use(TraceID())
	router.GET("/test", func(c *gin.Context) {
		capturedTraceID = GetTraceID(c)
		c.String(http.StatusOK, "OK")
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set(TraceIDHeader, "custom-trace-123")

	router.ServeHTTP(w, req)

	assert.Equal(t, "custom-trace-123", capturedTraceID)
}

// TestGetTraceID verifies that GetTraceID returns the string stored under
// TraceIDKey, and returns "" when the key is absent or holds a non-string.
func TestGetTraceID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	tests := []struct {
		name         string
		setupContext func(*gin.Context)
		expected     string
	}{
		{
			name: "trace ID exists",
			setupContext: func(c *gin.Context) {
				c.Set(TraceIDKey, "existing-trace")
			},
			expected: "existing-trace",
		},
		{
			name:         "trace ID not exists",
			setupContext: func(*gin.Context) {},
			expected:     "",
		},
		{
			name: "trace ID is not string",
			setupContext: func(c *gin.Context) {
				c.Set(TraceIDKey, 12345)
			},
			expected: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c, _ := gin.CreateTestContext(httptest.NewRecorder())
			tt.setupContext(c)

			assert.Equal(t, tt.expected, GetTraceID(c))
		})
	}
}

// TestGenerateTraceID generates a batch of trace IDs and checks that each
// has the expected "YYYYMMDD-<16 hex>" format and that none repeat.
func TestGenerateTraceID(t *testing.T) {
	const count = 100

	seen := make(map[string]struct{}, count)
	for i := 0; i < count; i++ {
		id := generateTraceID()
		seen[id] = struct{}{}
		assertGeneratedTraceIDFormat(t, id)
	}

	// Every generated ID should be distinct.
	assert.Len(t, seen, count, "generated trace IDs should be unique")
}