diff --git a/gateway/internal/middleware/cors.go b/gateway/internal/middleware/cors.go index 724ee97f..e4c3830a 100644 --- a/gateway/internal/middleware/cors.go +++ b/gateway/internal/middleware/cors.go @@ -2,6 +2,8 @@ package middleware import ( "net/http" + "net/url" + "strconv" "strings" ) @@ -18,12 +20,12 @@ type CORSConfig struct { // DefaultCORSConfig 返回默认CORS配置 func DefaultCORSConfig() CORSConfig { return CORSConfig{ - AllowOrigins: []string{"*"}, // 生产环境应限制具体域名 - AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowHeaders: []string{"Authorization", "Content-Type", "X-Request-ID", "X-Request-Key"}, - ExposeHeaders: []string{"X-Request-ID"}, + AllowOrigins: []string{"*"}, // 生产环境应限制具体域名 + AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowHeaders: []string{"Authorization", "Content-Type", "X-Request-ID", "X-Request-Key"}, + ExposeHeaders: []string{"X-Request-ID"}, AllowCredentials: false, - MaxAge: 86400, // 24小时 + MaxAge: 86400, // 24小时 } } @@ -31,13 +33,17 @@ func DefaultCORSConfig() CORSConfig { func CORSMiddleware(config CORSConfig) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 处理CORS预检请求 if r.Method == http.MethodOptions { handleCORSPreflight(w, r, config) return } - // 处理实际请求的CORS头 + origin := r.Header.Get("Origin") + if origin != "" && !isOriginAllowed(origin, config.AllowOrigins) { + w.WriteHeader(http.StatusForbidden) + return + } + setCORSHeaders(w, r, config) next.ServeHTTP(w, r) }) @@ -48,17 +54,15 @@ func CORSMiddleware(config CORSConfig) func(http.Handler) http.Handler { func handleCORSPreflight(w http.ResponseWriter, r *http.Request, config CORSConfig) { origin := r.Header.Get("Origin") - // 检查origin是否被允许 if !isOriginAllowed(origin, config.AllowOrigins) { w.WriteHeader(http.StatusForbidden) return } - // 设置预检响应头 w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ", ")) w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ", ")) - w.Header().Set("Access-Control-Max-Age", string(rune(config.MaxAge))) + w.Header().Set("Access-Control-Max-Age", strconv.Itoa(config.MaxAge)) if config.AllowCredentials { w.Header().Set("Access-Control-Allow-Credentials", "true") @@ -71,8 +75,7 @@ func handleCORSPreflight(w http.ResponseWriter, r *http.Request, config CORSConf func setCORSHeaders(w http.ResponseWriter, r *http.Request, config CORSConfig) { origin := r.Header.Get("Origin") - // 检查origin是否被允许 - if !isOriginAllowed(origin, config.AllowOrigins) { + if origin == "" || !isOriginAllowed(origin, config.AllowOrigins) { return } @@ -100,13 +103,24 @@ func isOriginAllowed(origin string, allowedOrigins []string) bool { if strings.EqualFold(allowed, origin) { return true } - // 支持通配符子域名 *.example.com if strings.HasPrefix(allowed, "*.") { domain := allowed[2:] - if strings.HasSuffix(origin, domain) { + host := originHost(origin) + if host == "" { + continue + } + if strings.EqualFold(host, domain) || strings.HasSuffix(host, "."+domain) { return true } } } return false -} \ No newline at end of file +} + +func originHost(origin string) string { + parsed, err := url.Parse(origin) + if err != nil { + return "" + } + return strings.ToLower(parsed.Hostname()) +}