113 lines
3.1 KiB
Go
113 lines
3.1 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"net/http"
	"strings"

	"github.com/company/ai-ops/internal/config"
	"github.com/company/ai-ops/pkg/errors"
	"github.com/company/ai-ops/pkg/response"
	"github.com/golang-jwt/jwt/v5"
)
|
|
|
|
// Auth 中间件检查认证
|
|
func Auth(cfg config.ServerConfig) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// 白名单路径免认证
|
|
if isPublicPath(r.URL.Path) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// API Key 检查(用于 /metrics 等机器对机器接口)
|
|
if strings.HasPrefix(r.URL.Path, "/metrics") {
|
|
apiKey := r.Header.Get("X-API-Key")
|
|
if apiKey == "" {
|
|
apiKey = r.URL.Query().Get("api_key")
|
|
}
|
|
if apiKey == cfg.MetricsAuth {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
}
|
|
|
|
// JWT 检查
|
|
tokenStr := r.Header.Get("Authorization")
|
|
if tokenStr == "" {
|
|
response.Error(w, errors.ErrUnauthorized)
|
|
return
|
|
}
|
|
tokenStr = strings.TrimPrefix(tokenStr, "Bearer ")
|
|
|
|
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(cfg.JWTSecret), nil
|
|
}, jwt.WithValidMethods([]string{"HS256"}))
|
|
if err != nil || !token.Valid {
|
|
response.Error(w, errors.ErrUnauthorized)
|
|
return
|
|
}
|
|
|
|
// 将用户ID和角色写入上下文
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
|
if userID, ok := claims["user_id"].(string); ok {
|
|
r = r.WithContext(context.WithValue(r.Context(), "user_id", userID))
|
|
}
|
|
if role, ok := claims["role"].(string); ok {
|
|
r = r.WithContext(context.WithValue(r.Context(), "role", role))
|
|
}
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// RequireRole 角色权限中间件
|
|
func RequireRole(roles ...string) func(http.Handler) http.Handler {
|
|
roleSet := make(map[string]bool)
|
|
for _, r := range roles {
|
|
roleSet[r] = true
|
|
}
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
role, _ := r.Context().Value("role").(string)
|
|
if !roleSet[role] {
|
|
response.Error(w, errors.ErrForbidden.WithDetail(map[string]any{
|
|
"error": "insufficient permissions",
|
|
"code": "OPS_AUTH_1001",
|
|
"required": roles,
|
|
"current": role,
|
|
}))
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// RequireWrite 允许 GET 或需要写权限
|
|
func RequireWrite(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == "GET" || r.Method == "HEAD" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
role, _ := r.Context().Value("role").(string)
|
|
if role != "operator" && role != "admin" {
|
|
response.Error(w, errors.ErrForbidden.WithDetail(map[string]any{
|
|
"error": "write permission required",
|
|
"code": "OPS_AUTH_1001",
|
|
"current": role,
|
|
}))
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// isPublicPath reports whether path may be served without authentication.
//
// "/health", "/api/v1/ai-ops/login" and "/openapi.json" match exactly.
// "/actuator/health" and "/ops/dashboard" match themselves and anything below
// them, but only on a "/" segment boundary. This way unrelated routes such as
// "/ops/dashboard-admin" or "/actuator/healthz" are not accidentally public.
// Any path containing a ".." segment is never public. The path is checked
// before any router normalisation, so "/ops/dashboard/../<x>" must not be
// able to reach <x> unauthenticated.
func isPublicPath(path string) bool {
	for _, seg := range strings.Split(path, "/") {
		if seg == ".." {
			return false
		}
	}

	switch path {
	case "/health", "/api/v1/ai-ops/login", "/openapi.json":
		return true
	}

	for _, prefix := range [...]string{"/actuator/health", "/ops/dashboard"} {
		if path == prefix || strings.HasPrefix(path, prefix+"/") {
			return true
		}
	}
	return false
}
|