Add comprehensive tests for three middleware components:
- cache_control: NoStoreSensitiveResponses, shouldDisableCaching
- security_headers: SecurityHeaders, shouldAttachCSP, isHTTPSRequest
- trace_id: TraceID, GetTraceID, generateTraceID

Coverage: middleware 35.7% → 36.4%
161 lines
3.6 KiB
Go
161 lines
3.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestSecurityHeaders(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
router := gin.New()
|
|
router.Use(SecurityHeaders())
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.String(200, "OK")
|
|
})
|
|
router.GET("/swagger/index.html", func(c *gin.Context) {
|
|
c.String(200, "Swagger UI")
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
path string
|
|
wantCSP bool
|
|
wantSTS bool // Strict-Transport-Security (only for HTTPS)
|
|
}{
|
|
{
|
|
name: "regular API endpoint",
|
|
path: "/test",
|
|
wantCSP: true,
|
|
wantSTS: false, // HTTP request
|
|
},
|
|
{
|
|
name: "swagger endpoint no CSP",
|
|
path: "/swagger/index.html",
|
|
wantCSP: false,
|
|
wantSTS: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", tt.path, nil)
|
|
router.ServeHTTP(w, req)
|
|
|
|
// 基础安全头
|
|
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
|
|
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
|
|
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
|
|
assert.Equal(t, "camera=(), microphone=(), geolocation=()", w.Header().Get("Permissions-Policy"))
|
|
assert.Equal(t, "same-origin", w.Header().Get("Cross-Origin-Opener-Policy"))
|
|
assert.Equal(t, "none", w.Header().Get("X-Permitted-Cross-Domain-Policies"))
|
|
|
|
// CSP 头
|
|
if tt.wantCSP {
|
|
assert.NotEmpty(t, w.Header().Get("Content-Security-Policy"))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestShouldAttachCSP(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
routePath string
|
|
requestPath string
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "non-swagger path",
|
|
routePath: "/api/v1/users",
|
|
requestPath: "/api/v1/users",
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "swagger path",
|
|
routePath: "/swagger/index.html",
|
|
requestPath: "/swagger/index.html",
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "swagger subpath",
|
|
routePath: "/swagger/api-docs",
|
|
requestPath: "/swagger/api-docs",
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "empty routePath uses requestPath",
|
|
routePath: "",
|
|
requestPath: "/swagger/",
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "trimmed spaces",
|
|
routePath: " /api/v1/users ",
|
|
requestPath: "/api/v1/users",
|
|
expected: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := shouldAttachCSP(tt.routePath, tt.requestPath)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsHTTPSRequest(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
tests := []struct {
|
|
name string
|
|
setup func(*http.Request)
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "plain HTTP request",
|
|
setup: func(req *http.Request) {},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "X-Forwarded-Proto is https",
|
|
setup: func(req *http.Request) {
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "X-Forwarded-Proto is HTTPS (uppercase)",
|
|
setup: func(req *http.Request) {
|
|
req.Header.Set("X-Forwarded-Proto", "HTTPS")
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "X-Forwarded-Proto is http",
|
|
setup: func(req *http.Request) {
|
|
req.Header.Set("X-Forwarded-Proto", "http")
|
|
},
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
|
tt.setup(req)
|
|
c.Request = req
|
|
|
|
result := isHTTPSRequest(c)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|