fix(gateway): harden cors origin validation

Reject non-whitelisted origins on actual requests and format Access-Control-Max-Age correctly. This keeps wildcard subdomain matching explicit and avoids silently serving blocked origins.
This commit is contained in:
Your Name
2026-04-11 09:33:33 +08:00
parent 4adeee2e06
commit dfa8a891ab

View File

@@ -2,6 +2,8 @@ package middleware
import (
"net/http"
"net/url"
"strconv"
"strings"
)
@@ -18,12 +20,12 @@ type CORSConfig struct {
// DefaultCORSConfig 返回默认CORS配置
func DefaultCORSConfig() CORSConfig {
return CORSConfig{
AllowOrigins: []string{"*"}, // 生产环境应限制具体域名
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Authorization", "Content-Type", "X-Request-ID", "X-Request-Key"},
ExposeHeaders: []string{"X-Request-ID"},
AllowOrigins: []string{"*"}, // 生产环境应限制具体域名
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Authorization", "Content-Type", "X-Request-ID", "X-Request-Key"},
ExposeHeaders: []string{"X-Request-ID"},
AllowCredentials: false,
MaxAge: 86400, // 24小时
MaxAge: 86400, // 24小时
}
}
@@ -31,13 +33,17 @@ func DefaultCORSConfig() CORSConfig {
func CORSMiddleware(config CORSConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 处理CORS预检请求
if r.Method == http.MethodOptions {
handleCORSPreflight(w, r, config)
return
}
// 处理实际请求的CORS头
origin := r.Header.Get("Origin")
if origin != "" && !isOriginAllowed(origin, config.AllowOrigins) {
w.WriteHeader(http.StatusForbidden)
return
}
setCORSHeaders(w, r, config)
next.ServeHTTP(w, r)
})
@@ -48,17 +54,15 @@ func CORSMiddleware(config CORSConfig) func(http.Handler) http.Handler {
func handleCORSPreflight(w http.ResponseWriter, r *http.Request, config CORSConfig) {
origin := r.Header.Get("Origin")
// 检查origin是否被允许
if !isOriginAllowed(origin, config.AllowOrigins) {
w.WriteHeader(http.StatusForbidden)
return
}
// 设置预检响应头
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ", "))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ", "))
w.Header().Set("Access-Control-Max-Age", string(rune(config.MaxAge)))
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(config.MaxAge))
if config.AllowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
@@ -71,8 +75,7 @@ func handleCORSPreflight(w http.ResponseWriter, r *http.Request, config CORSConf
func setCORSHeaders(w http.ResponseWriter, r *http.Request, config CORSConfig) {
origin := r.Header.Get("Origin")
// 检查origin是否被允许
if !isOriginAllowed(origin, config.AllowOrigins) {
if origin == "" || !isOriginAllowed(origin, config.AllowOrigins) {
return
}
@@ -100,13 +103,24 @@ func isOriginAllowed(origin string, allowedOrigins []string) bool {
if strings.EqualFold(allowed, origin) {
return true
}
// 支持通配符子域名 *.example.com
if strings.HasPrefix(allowed, "*.") {
domain := allowed[2:]
if strings.HasSuffix(origin, domain) {
host := originHost(origin)
if host == "" {
continue
}
if strings.EqualFold(host, domain) || strings.HasSuffix(host, "."+domain) {
return true
}
}
}
return false
}
}
func originHost(origin string) string {
parsed, err := url.Parse(origin)
if err != nil {
return ""
}
return strings.ToLower(parsed.Hostname())
}