Files
Developer da36506b89 fix: resolve P0/P1 code quality issues
P0 fixes:
- ModelError.Is(): use exact matching instead of substring contains()
- shouldClearStickySession: add context param for cancellation/tracing

P1 fixes:
- TODO stubs: return 501 Not Implemented errors
- validateInstanceSignature: deduplicate to shared validateCodeSignature()
- Error messages: standardize to English only
- http.go: remove pseudo if-else with duplicate branches
2026-03-31 11:39:18 +08:00

309 lines
9.0 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func init() {
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
gin.SetMode(gin.TestMode)
}
// --- Task 8.2: 验证 CORS 条件化头部 ---
func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
origin string
}{
{
name: "preflight_disallowed_origin",
method: http.MethodOptions,
origin: "https://evil.example.com",
},
{
name: "get_disallowed_origin",
method: http.MethodGet,
origin: "https://evil.example.com",
},
{
name: "post_disallowed_origin",
method: http.MethodPost,
origin: "https://attacker.example.com",
},
{
name: "preflight_no_origin",
method: http.MethodOptions,
origin: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
if tt.origin != "" {
c.Request.Header.Set("Origin", tt.origin)
}
middleware(c)
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"),
"不允许的 origin 不应收到 Allow-Headers")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"),
"不允许的 origin 不应收到 Allow-Methods")
assert.Empty(t, w.Header().Get("Access-Control-Max-Age"),
"不允许的 origin 不应收到 Max-Age")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
"不允许的 origin 不应收到 Allow-Origin")
})
}
}
func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
}{
{name: "preflight_OPTIONS", method: http.MethodOptions},
{name: "normal_GET", method: http.MethodGet},
{name: "normal_POST", method: http.MethodPost},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"允许的 origin 应收到 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"允许的 origin 应收到 Allow-Methods")
assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"),
"允许的 origin 应收到 Max-Age=86400")
assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"),
"允许的 origin 应收到 Allow-Origin")
})
}
}
func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Equal(t, http.StatusForbidden, w.Code,
"不允许的 origin 的 preflight 请求应返回 403")
}
func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, http.StatusNoContent, w.Code,
"允许的 origin 的 preflight 请求应返回 204")
}
func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any-origin.example.com")
middleware(c)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"),
"通配符配置应返回 *")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"通配符 origin 应设置 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"通配符 origin 应设置 Allow-Methods")
}
func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: true,
}
middleware := CORS(cfg)
t.Run("allowed_origin_gets_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"),
"允许的 origin 且开启 credentials 应设置 Allow-Credentials")
})
t.Run("disallowed_origin_no_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"不允许的 origin 不应收到 Allow-Credentials")
})
}
func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any.example.com")
middleware(c)
// 通配符 + credentials 不兼容credentials 应被禁用
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"通配符 origin 应禁用 Allow-Credentials")
}
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{
"https://app1.example.com",
"https://app2.example.com",
},
AllowCredentials: false,
}
middleware := CORS(cfg)
t.Run("first_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app1.example.com")
middleware(c)
assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("second_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app2.example.com")
middleware(c)
assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("unlisted_origin_rejected", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app3.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
}
func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Contains(t, w.Header().Values("Vary"), "Origin",
"非通配符允许的 origin 应设置 Vary: Origin")
}
func TestNormalizeOrigins(t *testing.T) {
tests := []struct {
name string
input []string
expect []string
}{
{name: "nil_input", input: nil, expect: nil},
{name: "empty_input", input: []string{}, expect: nil},
{name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}},
{name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := normalizeOrigins(tt.input)
assert.Equal(t, tt.expect, result)
})
}
}