122 lines
4.1 KiB
Go
122 lines
4.1 KiB
Go
|
|
package middleware
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"net/http"
|
|||
|
|
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|||
|
|
"github.com/gin-gonic/gin"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// ContextKey is the string-based type used for the context key constants
// declared by this package.
type ContextKey string
|
|||
|
|
|
|||
|
|
const (
|
|||
|
|
// ContextKeyUser 用户上下文键
|
|||
|
|
ContextKeyUser ContextKey = "user"
|
|||
|
|
// ContextKeyUserRole 当前用户角色(string)
|
|||
|
|
ContextKeyUserRole ContextKey = "user_role"
|
|||
|
|
// ContextKeyAPIKey API密钥上下文键
|
|||
|
|
ContextKeyAPIKey ContextKey = "api_key"
|
|||
|
|
// ContextKeySubscription 订阅上下文键
|
|||
|
|
ContextKeySubscription ContextKey = "subscription"
|
|||
|
|
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
|
|||
|
|
ContextKeyForcePlatform ContextKey = "force_platform"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// ForcePlatform 返回设置强制平台的中间件
|
|||
|
|
// 同时设置 request.Context(供 Service 使用)和 gin.Context(供 Handler 快速检查)
|
|||
|
|
func ForcePlatform(platform string) gin.HandlerFunc {
|
|||
|
|
return func(c *gin.Context) {
|
|||
|
|
// 设置到 request.Context,使用 ctxkey.ForcePlatform 供 Service 层读取
|
|||
|
|
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, platform)
|
|||
|
|
c.Request = c.Request.WithContext(ctx)
|
|||
|
|
// 同时设置到 gin.Context,供 Handler 快速检查
|
|||
|
|
c.Set(string(ContextKeyForcePlatform), platform)
|
|||
|
|
c.Next()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查)
|
|||
|
|
func HasForcePlatform(c *gin.Context) bool {
|
|||
|
|
_, exists := c.Get(string(ContextKeyForcePlatform))
|
|||
|
|
return exists
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetForcePlatformFromContext 从 gin.Context 获取强制平台
|
|||
|
|
func GetForcePlatformFromContext(c *gin.Context) (string, bool) {
|
|||
|
|
value, exists := c.Get(string(ContextKeyForcePlatform))
|
|||
|
|
if !exists {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
platform, ok := value.(string)
|
|||
|
|
return platform, ok
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ErrorResponse is the standard JSON error payload.
type ErrorResponse struct {
	// Code is serialized as the "code" field.
	Code string `json:"code"`
	// Message is serialized as the "message" field.
	Message string `json:"message"`
}
|
|||
|
|
|
|||
|
|
// NewErrorResponse 创建错误响应
|
|||
|
|
func NewErrorResponse(code, message string) ErrorResponse {
|
|||
|
|
return ErrorResponse{
|
|||
|
|
Code: code,
|
|||
|
|
Message: message,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AbortWithError 中断请求并返回JSON错误
|
|||
|
|
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
|
|||
|
|
c.JSON(statusCode, NewErrorResponse(code, message))
|
|||
|
|
c.Abort()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ──────────────────────────────────────────────────────────
// RequireGroupAssignment — middleware that blocks API keys not assigned to a group
// ──────────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式)
|
|||
|
|
type GatewayErrorWriter func(c *gin.Context, status int, message string)
|
|||
|
|
|
|||
|
|
// AnthropicErrorWriter 按 Anthropic API 规范输出错误
|
|||
|
|
func AnthropicErrorWriter(c *gin.Context, status int, message string) {
|
|||
|
|
c.JSON(status, gin.H{
|
|||
|
|
"type": "error",
|
|||
|
|
"error": gin.H{"type": "permission_error", "message": message},
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GoogleErrorWriter 按 Google API 规范输出错误
|
|||
|
|
func GoogleErrorWriter(c *gin.Context, status int, message string) {
|
|||
|
|
c.JSON(status, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"code": status,
|
|||
|
|
"message": message,
|
|||
|
|
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RequireGroupAssignment 检查 API Key 是否已分配到分组,
|
|||
|
|
// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。
|
|||
|
|
func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc {
|
|||
|
|
return func(c *gin.Context) {
|
|||
|
|
apiKey, ok := GetAPIKeyFromContext(c)
|
|||
|
|
if !ok || apiKey.GroupID != nil {
|
|||
|
|
c.Next()
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
// 未分组 Key — 检查系统设置
|
|||
|
|
if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) {
|
|||
|
|
c.Next()
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.")
|
|||
|
|
c.Abort()
|
|||
|
|
}
|
|||
|
|
}
|