Add comprehensive tests for three middleware components:
- cache_control: NoStoreSensitiveResponses, shouldDisableCaching
- security_headers: SecurityHeaders, shouldAttachCSP, isHTTPSRequest
- trace_id: TraceID, GetTraceID, generateTraceID

Coverage: middleware 35.7% → 36.4%
149 lines
3.5 KiB
Go
149 lines
3.5 KiB
Go
package middleware
|
||
|
||
import (
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"strings"
|
||
"testing"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/stretchr/testify/assert"
|
||
)
|
||
|
||
func TestTraceID(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
router := gin.New()
|
||
router.Use(TraceID())
|
||
router.GET("/test", func(c *gin.Context) {
|
||
// 返回 trace ID 供验证
|
||
traceID := GetTraceID(c)
|
||
c.String(200, traceID)
|
||
})
|
||
|
||
tests := []struct {
|
||
name string
|
||
incomingTraceID string
|
||
expectNewGenerated bool
|
||
}{
|
||
{
|
||
name: "generate new trace ID",
|
||
incomingTraceID: "",
|
||
expectNewGenerated: true,
|
||
},
|
||
{
|
||
name: "reuse incoming trace ID",
|
||
incomingTraceID: "abc123xyz",
|
||
expectNewGenerated: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
w := httptest.NewRecorder()
|
||
req, _ := http.NewRequest("GET", "/test", nil)
|
||
if tt.incomingTraceID != "" {
|
||
req.Header.Set(TraceIDHeader, tt.incomingTraceID)
|
||
}
|
||
router.ServeHTTP(w, req)
|
||
|
||
// 检查响应头中的 trace ID
|
||
responseTraceID := w.Header().Get(TraceIDHeader)
|
||
assert.NotEmpty(t, responseTraceID)
|
||
|
||
if tt.expectNewGenerated {
|
||
// 新生成的 trace ID 应该包含日期格式
|
||
assert.True(t, strings.Contains(responseTraceID, "-"))
|
||
// 格式: YYYYMMDD-xxxxxxxx
|
||
parts := strings.Split(responseTraceID, "-")
|
||
assert.Equal(t, 2, len(parts))
|
||
assert.Equal(t, 8, len(parts[0])) // YYYYMMDD
|
||
assert.Equal(t, 16, len(parts[1])) // hex
|
||
} else {
|
||
assert.Equal(t, tt.incomingTraceID, responseTraceID)
|
||
}
|
||
|
||
// 响应体应该包含 trace ID
|
||
body := w.Body.String()
|
||
assert.Equal(t, responseTraceID, body)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestTraceID_SetInContext(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
var capturedTraceID string
|
||
router := gin.New()
|
||
router.Use(TraceID())
|
||
router.GET("/test", func(c *gin.Context) {
|
||
capturedTraceID = GetTraceID(c)
|
||
c.String(200, "OK")
|
||
})
|
||
|
||
w := httptest.NewRecorder()
|
||
req, _ := http.NewRequest("GET", "/test", nil)
|
||
req.Header.Set(TraceIDHeader, "custom-trace-123")
|
||
router.ServeHTTP(w, req)
|
||
|
||
assert.Equal(t, "custom-trace-123", capturedTraceID)
|
||
}
|
||
|
||
func TestGetTraceID(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
tests := []struct {
|
||
name string
|
||
setupContext func(*gin.Context)
|
||
expected string
|
||
}{
|
||
{
|
||
name: "trace ID exists",
|
||
setupContext: func(c *gin.Context) {
|
||
c.Set(TraceIDKey, "existing-trace")
|
||
},
|
||
expected: "existing-trace",
|
||
},
|
||
{
|
||
name: "trace ID not exists",
|
||
setupContext: func(c *gin.Context) {},
|
||
expected: "",
|
||
},
|
||
{
|
||
name: "trace ID is not string",
|
||
setupContext: func(c *gin.Context) {
|
||
c.Set(TraceIDKey, 12345)
|
||
},
|
||
expected: "",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||
tt.setupContext(c)
|
||
|
||
result := GetTraceID(c)
|
||
assert.Equal(t, tt.expected, result)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestGenerateTraceID(t *testing.T) {
|
||
// 生成多个 trace ID,验证格式
|
||
traceIDs := make(map[string]bool)
|
||
for i := 0; i < 100; i++ {
|
||
id := generateTraceID()
|
||
traceIDs[id] = true
|
||
|
||
// 验证格式
|
||
parts := strings.Split(id, "-")
|
||
assert.Equal(t, 2, len(parts), "trace ID should have 2 parts separated by -")
|
||
assert.Equal(t, 8, len(parts[0]), "date part should be 8 characters (YYYYMMDD)")
|
||
assert.Equal(t, 16, len(parts[1]), "random part should be 16 hex characters")
|
||
}
|
||
|
||
// 验证唯一性(100个应该都不同)
|
||
assert.Equal(t, 100, len(traceIDs), "generated trace IDs should be unique")
|
||
}
|