fix(gateway): harden cors origin validation
Reject non-whitelisted origins on actual requests and format Access-Control-Max-Age correctly. This keeps wildcard subdomain matching explicit and avoids silently serving blocked origins.
This commit is contained in:
@@ -2,6 +2,8 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -18,12 +20,12 @@ type CORSConfig struct {
|
||||
// DefaultCORSConfig 返回默认CORS配置
|
||||
func DefaultCORSConfig() CORSConfig {
|
||||
return CORSConfig{
|
||||
AllowOrigins: []string{"*"}, // 生产环境应限制具体域名
|
||||
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowHeaders: []string{"Authorization", "Content-Type", "X-Request-ID", "X-Request-Key"},
|
||||
ExposeHeaders: []string{"X-Request-ID"},
|
||||
AllowOrigins: []string{"*"}, // 生产环境应限制具体域名
|
||||
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowHeaders: []string{"Authorization", "Content-Type", "X-Request-ID", "X-Request-Key"},
|
||||
ExposeHeaders: []string{"X-Request-ID"},
|
||||
AllowCredentials: false,
|
||||
MaxAge: 86400, // 24小时
|
||||
MaxAge: 86400, // 24小时
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,13 +33,17 @@ func DefaultCORSConfig() CORSConfig {
|
||||
func CORSMiddleware(config CORSConfig) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 处理CORS预检请求
|
||||
if r.Method == http.MethodOptions {
|
||||
handleCORSPreflight(w, r, config)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理实际请求的CORS头
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" && !isOriginAllowed(origin, config.AllowOrigins) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
setCORSHeaders(w, r, config)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
@@ -48,17 +54,15 @@ func CORSMiddleware(config CORSConfig) func(http.Handler) http.Handler {
|
||||
func handleCORSPreflight(w http.ResponseWriter, r *http.Request, config CORSConfig) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// 检查origin是否被允许
|
||||
if !isOriginAllowed(origin, config.AllowOrigins) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置预检响应头
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ", "))
|
||||
w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ", "))
|
||||
w.Header().Set("Access-Control-Max-Age", string(rune(config.MaxAge)))
|
||||
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(config.MaxAge))
|
||||
|
||||
if config.AllowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
@@ -71,8 +75,7 @@ func handleCORSPreflight(w http.ResponseWriter, r *http.Request, config CORSConf
|
||||
func setCORSHeaders(w http.ResponseWriter, r *http.Request, config CORSConfig) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// 检查origin是否被允许
|
||||
if !isOriginAllowed(origin, config.AllowOrigins) {
|
||||
if origin == "" || !isOriginAllowed(origin, config.AllowOrigins) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -100,13 +103,24 @@ func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
if strings.EqualFold(allowed, origin) {
|
||||
return true
|
||||
}
|
||||
// 支持通配符子域名 *.example.com
|
||||
if strings.HasPrefix(allowed, "*.") {
|
||||
domain := allowed[2:]
|
||||
if strings.HasSuffix(origin, domain) {
|
||||
host := originHost(origin)
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(host, domain) || strings.HasSuffix(host, "."+domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func originHost(origin string) string {
|
||||
parsed, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(parsed.Hostname())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user