From a097055deb5c35ed8890f76513c94ab4346608e7 Mon Sep 17 00:00:00 2001 From: phamnazage-jpg Date: Tue, 31 Mar 2026 11:44:54 +0800 Subject: [PATCH] fix: resolve P0/P1 code quality issues P0 fixes: - ModelError.Is(): use exact matching instead of substring contains() - shouldClearStickySession: add context param for cancellation/tracing P1 fixes: - TODO stubs: return 501 Not Implemented errors - validateInstanceSignature: deduplicate to shared validateCodeSignature() - Error messages: standardize to English only - http.go: remove pseudo if-else with duplicate branches --- backend/internal/pkg/models/interface.go | 256 + backend/internal/server/http.go | 103 + backend/internal/service/admin_service.go | 2661 ++++++ backend/internal/service/api_key_service.go | 881 ++ backend/internal/service/gateway_service.go | 8510 +++++++++++++++++ .../service/gemini_messages_compat_service.go | 3307 +++++++ .../internal/service/instance_signature.go | 38 + .../service/openai_account_scheduler.go | 922 ++ .../service/openai_gateway_service.go | 4709 +++++++++ .../internal/service/openai_ws_forwarder.go | 4077 ++++++++ backend/internal/service/promo_service.go | 268 + backend/internal/service/proxy_service.go | 193 + backend/internal/service/redeem_service.go | 468 + 13 files changed, 26393 insertions(+) create mode 100644 backend/internal/pkg/models/interface.go create mode 100644 backend/internal/server/http.go create mode 100644 backend/internal/service/admin_service.go create mode 100644 backend/internal/service/api_key_service.go create mode 100644 backend/internal/service/gateway_service.go create mode 100644 backend/internal/service/gemini_messages_compat_service.go create mode 100644 backend/internal/service/instance_signature.go create mode 100644 backend/internal/service/openai_account_scheduler.go create mode 100644 backend/internal/service/openai_gateway_service.go create mode 100644 backend/internal/service/openai_ws_forwarder.go create mode 100644 backend/internal/service/promo_service.go create mode 100644 backend/internal/service/proxy_service.go create mode 100644 backend/internal/service/redeem_service.go diff --git a/backend/internal/pkg/models/interface.go b/backend/internal/pkg/models/interface.go new file mode 100644 index 00000000..487dc0a1 --- /dev/null +++ b/backend/internal/pkg/models/interface.go @@ -0,0 +1,256 @@ +package models + +import ( + "context" + "fmt" + "io" + "time" +) + +// ============================================================================= +// Type Definitions +// ============================================================================= + +// Provider 模型提供商接口 +// 所有模型提供商都必须实现此接口 +type Provider interface { + // Name 返回提供商名称 (如 "deepseek", "qwen", "baidu") + Name() string + + // BaseURL 返回 API 基础地址 + BaseURL() string + + // Models 返回支持的模型列表 + Models() []Model + + // Chat 发起聊天请求 (非流式) + Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) + + // ChatStream 发起流式聊天请求 + ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) + + // Embeddings 获取嵌入向量 + Embeddings(ctx context.Context, req *EmbeddingsRequest) (*EmbeddingsResponse, error) + + // ValidateKey 验证 API 密钥有效性 + ValidateKey(ctx context.Context, key string) error + + // Close 关闭 provider,释放资源 + Close() error +} + +// Model 模型信息 +type Model struct { + ID string `json:"id"` // 模型 ID (如 "deepseek-chat") + Name string `json:"name"` // 显示名称 (如 "DeepSeek Chat") + Provider string `json:"provider"` // 提供商名称 + Type string `json:"type"` // 类型: "chat", "embedding", "image" + ContextSize int `json:"context_size"` // 上下文长度 (tokens) + MaxTokens int `json:"max_tokens"` // 最大输出 tokens + Capabilities []string `json:"capabilities"` // 能力列表: "streaming", "vision", "function_call" +} + +// ChatRequest 聊天请求 +type ChatRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Stop []string `json:"stop,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` + // Provider 特定参数 + Extra map[string]interface{} `json:"-"` +} + +// ChatMessage 聊天消息 +type ChatMessage struct { + Role string `json:"role"` // "system", "user", "assistant", "tool" + Content string `json:"content"` + Name string `json:"name,omitempty"` + // Tool call 相关 + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +// ToolCall 工具调用 +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` // "function" + Function FunctionCall `json:"function"` +} + +// FunctionCall 函数调用 +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// Tool 工具定义 +type Tool struct { + Type string `json:"type"` // "function" + Function FunctionDefinition `json:"function"` +} + +// FunctionDefinition 函数定义 +type FunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} + +// ChatResponse 聊天响应 (非流式) +type ChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` +} + +// Choice 选项 +type Choice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// Usage 用量统计 +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + // Provider 特定 + PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` +} + +type PromptTokensDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` +} + +type CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` +} + +// EmbeddingsRequest 嵌入请求 +type EmbeddingsRequest struct { + Model string `json:"model"` + Input []string `json:"input"` + InputType string `json:"input_type,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` +} + +// EmbeddingsResponse 嵌入响应 +type EmbeddingsResponse struct { + Object string `json:"object"` + Data []Embedding `json:"data"` + Model string `json:"model"` + Usage Usage `json:"usage"` +} + +// Embedding 嵌入向量 +type Embedding struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +// ProviderConfig 提供商配置 +type ProviderConfig struct { + APIKey string + BaseURL string + Organization string + Timeout time.Duration + // 认证相关 + AuthType string // "bearer", "api_key", "oauth" + // Provider 特定 + Extra map[string]interface{} +} + +// HTTPClientConfig HTTP 客户端配置 +type HTTPConfig struct { + Timeout time.Duration + MaxRetries int + RetryDelay time.Duration + MaxKeepAlive int +} + +// PoolConfig 连接池配置 +type PoolConfig struct { + MaxIdle int + MaxActive int + IdleTimeout time.Duration +} + +// Errors +var ( + ErrUnknownProvider = NewError("unknown provider: %s") + ErrInvalidAPIKey = NewError("invalid API key") + ErrRateLimited = NewError("rate limited") + ErrInsufficientQuota = NewError("insufficient quota") + ErrModelNotFound = NewError("model not found: %s") + ErrInvalidRequest = NewError("invalid request: %s") + ErrContextCancelled = NewError("context cancelled") + ErrTimeout = NewError("request timeout") +) + +// NewError 创建错误 +func NewError(format string, args ...interface{}) error { + return &ModelError{Message: format, Args: args} +} + +// Errorf 创建格式化错误 +func Errorf(format string, args ...interface{}) error { + return &ModelError{Message: format, Args: args} +} + +// ModelError 模型错误 +type ModelError struct { + Message string + Args []interface{} +} + +func (e *ModelError) Error() string { + if len(e.Args) > 0 { + return e.Message + ": " + fmt.Sprint(e.Args...) + } + return e.Message +} + +func (e *ModelError) Unwrap() error { + return nil +} + +// Is performs error identity comparison. +// Only returns true when target is the exact same ModelError instance +// (including format string and arguments), not for substring matches. +// This ensures errors.Is(ErrRateLimited, ErrInvalidRequest) returns false +// even when message strings happen to be substrings of each other. +func (e *ModelError) Is(target error) bool { + targetME, ok := target.(*ModelError) + if !ok { + return false + } + if e == targetME { + return true + } + if e.Message != targetME.Message { + return false + } + if len(e.Args) != len(targetME.Args) { + return false + } + for i, a := range e.Args { + if a != targetME.Args[i] { + return false + } + } + return true +} \ No newline at end of file diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go new file mode 100644 index 00000000..5583cf4b --- /dev/null +++ b/backend/internal/server/http.go @@ -0,0 +1,103 @@ +// Package server provides HTTP server initialization and configuration. +package server + +import ( + "log" + "net/http" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/google/wire" + "github.com/redis/go-redis/v9" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +// ProviderSet 提供服务器层的依赖 +var ProviderSet = wire.NewSet( + ProvideRouter, + ProvideHTTPServer, +) + +// ProvideRouter 提供路由器 +func ProvideRouter( + cfg *config.Config, + handlers *handler.Handlers, + jwtAuth middleware2.JWTAuthMiddleware, + adminAuth middleware2.AdminAuthMiddleware, + apiKeyAuth middleware2.APIKeyAuthMiddleware, + apiKeyService *service.APIKeyService, + subscriptionService *service.SubscriptionService, + opsService *service.OpsService, + settingService *service.SettingService, + redisClient *redis.Client, +) *gin.Engine { + // Set GIN mode to release to avoid debug logging leaking sensitive information + gin.SetMode(gin.ReleaseMode) + + r := gin.New() + r.Use(middleware2.Recovery()) + if len(cfg.Server.TrustedProxies) > 0 { + if err := r.SetTrustedProxies(cfg.Server.TrustedProxies); err != nil { + log.Printf("Failed to set trusted proxies: %v", err) + } + } else { + if err := r.SetTrustedProxies(nil); err != nil { + log.Printf("Failed to disable trusted proxies: %v", err) + } + if cfg.Server.Mode == "release" { + log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled") + } + } + + return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) +} + +// ProvideHTTPServer 提供 HTTP 服务器 +func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { + httpHandler := http.Handler(router) + + globalMaxSize := cfg.Server.MaxRequestBodySize + if globalMaxSize <= 0 { + globalMaxSize = cfg.Gateway.MaxBodySize + } + if globalMaxSize > 0 { + httpHandler = http.MaxBytesHandler(httpHandler, globalMaxSize) + log.Printf("Global max request body size: %d bytes (%.2f MB)", globalMaxSize, float64(globalMaxSize)/(1<<20)) + } + + // 根据配置决定是否启用 H2C + if cfg.Server.H2C.Enabled { + h2cConfig := cfg.Server.H2C + httpHandler = h2c.NewHandler(router, &http2.Server{ + MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams, + IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second, + MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize), + MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection), + MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream), + }) + log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d", + h2cConfig.MaxConcurrentStreams, + h2cConfig.IdleTimeout, + h2cConfig.MaxReadFrameSize, + h2cConfig.MaxUploadBufferPerConnection, + h2cConfig.MaxUploadBufferPerStream, + ) + } + + return &http.Server{ + Addr: cfg.Server.Address(), + Handler: httpHandler, + // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 + ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, + // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 + IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second, + // 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟 + // 不设置 ReadTimeout,因为大请求体可能需要较长时间读取 + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go new file mode 100644 index 00000000..b69dfdc4 --- /dev/null +++ b/backend/internal/service/admin_service.go @@ -0,0 +1,2661 @@ +package service + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" +) + +// AdminService interface defines admin management operations +type AdminService interface { + // User management + ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) + GetUser(ctx context.Context, id int64) (*User, error) + CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) + UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) + DeleteUser(ctx context.Context, id int64) error + UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) + GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) + GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) + // GetUserBalanceHistory returns paginated balance/concurrency change records for a user. + // codeType is optional - pass empty string to return all types. + // Also returns totalRecharged (sum of all positive balance top-ups). + GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) + + // Group management + ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) + GetAllGroups(ctx context.Context) ([]Group, error) + GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) + GetGroup(ctx context.Context, id int64) (*Group, error) + CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) + UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) + DeleteGroup(ctx context.Context, id int64) error + GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) + GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) + ClearGroupRateMultipliers(ctx context.Context, groupID int64) error + BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error + UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error + + // API Key management (admin) + AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) + + // ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限 + ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) + + // Account management + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) + GetAccount(ctx context.Context, id int64) (*Account, error) + GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) + CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) + UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) + DeleteAccount(ctx context.Context, id int64) error + RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) + ClearAccountError(ctx context.Context, id int64) (*Account, error) + SetAccountError(ctx context.Context, id int64, errorMsg string) error + // EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号 privacy_mode,未设置则尝试关闭训练数据共享并持久化。 + EnsureOpenAIPrivacy(ctx context.Context, account *Account) string + SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) + BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) + CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error + + // Proxy management + ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) + ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) + GetAllProxies(ctx context.Context) ([]Proxy, error) + GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) + GetProxy(ctx context.Context, id int64) (*Proxy, error) + GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error) + CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) + UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) + DeleteProxy(ctx context.Context, id int64) error + BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) + GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) + CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) + TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) + CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) + + // Redeem code management + ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) + GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) + GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) + DeleteRedeemCode(ctx context.Context, id int64) error + BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) + ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) + ResetAccountQuota(ctx context.Context, id int64) error +} + +// CreateUserInput represents input for creating a new user via admin operations. +type CreateUserInput struct { + Email string + Password string + Username string + Notes string + Balance float64 + Concurrency int + AllowedGroups []int64 + SoraStorageQuotaBytes int64 +} + +type UpdateUserInput struct { + Email string + Password string + Username *string + Notes *string + Balance *float64 // 使用指针区分"未提供"和"设置为0" + Concurrency *int // 使用指针区分"未提供"和"设置为0" + Status string + AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" + // GroupRates 用户专属分组倍率配置 + // map[groupID]*rate,nil 表示删除该分组的专属倍率 + GroupRates map[int64]*float64 + SoraStorageQuotaBytes *int64 +} + +type CreateGroupInput struct { + Name string + Description string + Platform string + RateMultiplier float64 + IsExclusive bool + SubscriptionType string // standard/subscription + DailyLimitUSD *float64 // 日限额 (USD) + WeeklyLimitUSD *float64 // 周限额 (USD) + MonthlyLimitUSD *float64 // 月限额 (USD) + // 图片生成计费配置(仅 antigravity 平台使用) + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 + // 模型路由配置(仅 anthropic 平台使用) + ModelRouting map[string][]int64 + ModelRoutingEnabled bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string + // Sora 存储配额 + SoraStorageQuotaBytes int64 + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool + DefaultMappedModel string + // 从指定分组复制账号(创建分组后在同一事务内绑定) + CopyAccountsFromGroupIDs []int64 +} + +type UpdateGroupInput struct { + Name string + Description string + Platform string + RateMultiplier *float64 // 使用指针以支持设置为0 + IsExclusive *bool + Status string + SubscriptionType string // standard/subscription + DailyLimitUSD *float64 // 日限额 (USD) + WeeklyLimitUSD *float64 // 周限额 (USD) + MonthlyLimitUSD *float64 // 月限额 (USD) + // 图片生成计费配置(仅 antigravity 平台使用) + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 + // 模型路由配置(仅 anthropic 平台使用) + ModelRouting map[string][]int64 + ModelRoutingEnabled *bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes *[]string + // Sora 存储配额 + SoraStorageQuotaBytes *int64 + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch *bool + DefaultMappedModel *string + // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) + CopyAccountsFromGroupIDs []int64 +} + +type CreateAccountInput struct { + Name string + Notes *string + Platform string + Type string + Credentials map[string]any + Extra map[string]any + ProxyID *int64 + Concurrency int + Priority int + RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int + GroupIDs []int64 + ExpiresAt *int64 + AutoPauseOnExpired *bool + // SkipDefaultGroupBind prevents auto-binding to platform default group when GroupIDs is empty. + SkipDefaultGroupBind bool + // SkipMixedChannelCheck skips the mixed channel risk check when binding groups. + // This should only be set when the caller has explicitly confirmed the risk. + SkipMixedChannelCheck bool +} + +type UpdateAccountInput struct { + Name string + Notes *string + Type string // Account type: oauth, setup-token, apikey + Credentials map[string]any + Extra map[string]any + ProxyID *int64 + Concurrency *int // 使用指针区分"未提供"和"设置为0" + Priority *int // 使用指针区分"未提供"和"设置为0" + RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int + Status string + GroupIDs *[]int64 + ExpiresAt *int64 + AutoPauseOnExpired *bool + SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险) +} + +// BulkUpdateAccountsInput describes the payload for bulk updating accounts. +type BulkUpdateAccountsInput struct { + AccountIDs []int64 + Name string + ProxyID *int64 + Concurrency *int + Priority *int + RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int + Status string + Schedulable *bool + GroupIDs *[]int64 + Credentials map[string]any + Extra map[string]any + // SkipMixedChannelCheck skips the mixed channel risk check when binding groups. + // This should only be set when the caller has explicitly confirmed the risk. + SkipMixedChannelCheck bool +} + +// BulkUpdateAccountResult captures the result for a single account update. +type BulkUpdateAccountResult struct { + AccountID int64 `json:"account_id"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID. +type AdminUpdateAPIKeyGroupIDResult struct { + APIKey *APIKey + AutoGrantedGroupAccess bool // true if a new exclusive group permission was auto-added + GrantedGroupID *int64 // the group ID that was auto-granted + GrantedGroupName string // the group name that was auto-granted +} + +// ReplaceUserGroupResult 分组替换操作的结果 +type ReplaceUserGroupResult struct { + MigratedKeys int64 // 迁移的 Key 数量 +} + +// BulkUpdateAccountsResult is the aggregated response for bulk updates. +type BulkUpdateAccountsResult struct { + Success int `json:"success"` + Failed int `json:"failed"` + SuccessIDs []int64 `json:"success_ids"` + FailedIDs []int64 `json:"failed_ids"` + Results []BulkUpdateAccountResult `json:"results"` +} + +type CreateProxyInput struct { + Name string + Protocol string + Host string + Port int + Username string + Password string +} + +type UpdateProxyInput struct { + Name string + Protocol string + Host string + Port int + Username string + Password string + Status string +} + +type GenerateRedeemCodesInput struct { + Count int + Type string + Value float64 + GroupID *int64 // 订阅类型专用:关联的分组ID + ValidityDays int // 订阅类型专用:有效天数 +} + +type ProxyBatchDeleteResult struct { + DeletedIDs []int64 `json:"deleted_ids"` + Skipped []ProxyBatchDeleteSkipped `json:"skipped"` +} + +type ProxyBatchDeleteSkipped struct { + ID int64 `json:"id"` + Reason string `json:"reason"` +} + +// ProxyTestResult represents the result of testing a proxy +type ProxyTestResult struct { + Success bool `json:"success"` + Message string `json:"message"` + LatencyMs int64 `json:"latency_ms,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + City string `json:"city,omitempty"` + Region string `json:"region,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` +} + +type ProxyQualityCheckResult struct { + ProxyID int64 `json:"proxy_id"` + Score int `json:"score"` + Grade string `json:"grade"` + Summary string `json:"summary"` + ExitIP string `json:"exit_ip,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + BaseLatencyMs int64 `json:"base_latency_ms,omitempty"` + PassedCount int `json:"passed_count"` + WarnCount int `json:"warn_count"` + FailedCount int `json:"failed_count"` + ChallengeCount int `json:"challenge_count"` + CheckedAt int64 `json:"checked_at"` + Items []ProxyQualityCheckItem `json:"items"` +} + +type ProxyQualityCheckItem struct { + Target string `json:"target"` + Status string `json:"status"` // pass/warn/fail/challenge + HTTPStatus int `json:"http_status,omitempty"` + LatencyMs int64 `json:"latency_ms,omitempty"` + Message string `json:"message,omitempty"` + CFRay string `json:"cf_ray,omitempty"` +} + +// ProxyExitInfo represents proxy exit information from ip-api.com +type ProxyExitInfo struct { + IP string + City string + Region string + Country string + CountryCode string +} + +// ProxyExitInfoProber tests proxy connectivity and retrieves exit information +type ProxyExitInfoProber interface { + ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) +} + +type groupExistenceBatchReader interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) +} + +type proxyQualityTarget struct { + Target string + URL string + Method string + AllowedStatuses map[int]struct{} +} + +var proxyQualityTargets = []proxyQualityTarget{ + { + Target: "openai", + URL: "https://api.openai.com/v1/models", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, + { + Target: "anthropic", + URL: "https://api.anthropic.com/v1/messages", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + http.StatusMethodNotAllowed: {}, + http.StatusNotFound: {}, + http.StatusBadRequest: {}, + }, + }, + { + Target: "gemini", + URL: "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusOK: {}, + }, + }, + { + Target: "sora", + URL: "https://sora.chatgpt.com/backend/me", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, +} + +const ( + proxyQualityRequestTimeout = 15 * time.Second + proxyQualityResponseHeaderTimeout = 10 * time.Second + proxyQualityMaxBodyBytes = int64(8 * 1024) + proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" +) + +// adminServiceImpl implements AdminService +type adminServiceImpl struct { + userRepo UserRepository + groupRepo GroupRepository + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 + proxyRepo ProxyRepository + apiKeyRepo APIKeyRepository + redeemCodeRepo RedeemCodeRepository + userGroupRateRepo UserGroupRateRepository + billingCacheService *BillingCacheService + proxyProber ProxyExitInfoProber + proxyLatencyCache ProxyLatencyCache + authCacheInvalidator APIKeyAuthCacheInvalidator + entClient *dbent.Client // 用于开启数据库事务 + settingService *SettingService + defaultSubAssigner DefaultSubscriptionAssigner + userSubRepo UserSubscriptionRepository + privacyClientFactory PrivacyClientFactory +} + +type userGroupRateBatchReader interface { + GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) +} + +// NewAdminService creates a new AdminService +func NewAdminService( + userRepo UserRepository, + groupRepo GroupRepository, + accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, + proxyRepo ProxyRepository, + apiKeyRepo APIKeyRepository, + redeemCodeRepo RedeemCodeRepository, + userGroupRateRepo UserGroupRateRepository, + billingCacheService *BillingCacheService, + proxyProber ProxyExitInfoProber, + proxyLatencyCache ProxyLatencyCache, + authCacheInvalidator APIKeyAuthCacheInvalidator, + entClient *dbent.Client, + settingService *SettingService, + defaultSubAssigner DefaultSubscriptionAssigner, + userSubRepo UserSubscriptionRepository, + privacyClientFactory PrivacyClientFactory, +) AdminService { + return &adminServiceImpl{ + userRepo: userRepo, + groupRepo: groupRepo, + accountRepo: accountRepo, + soraAccountRepo: soraAccountRepo, + proxyRepo: proxyRepo, + apiKeyRepo: apiKeyRepo, + redeemCodeRepo: redeemCodeRepo, + userGroupRateRepo: userGroupRateRepo, + billingCacheService: billingCacheService, + proxyProber: proxyProber, + proxyLatencyCache: proxyLatencyCache, + authCacheInvalidator: authCacheInvalidator, + entClient: entClient, + settingService: settingService, + defaultSubAssigner: defaultSubAssigner, + userSubRepo: userSubRepo, + privacyClientFactory: privacyClientFactory, + } +} + +// User management implementations +func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) + if err != nil { + return nil, 0, err + } + // 批量加载用户专属分组倍率 + if s.userGroupRateRepo != nil && len(users) > 0 { + if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { + userIDs := make([]int64, 0, len(users)) + for i := range users { + userIDs = append(userIDs, users[i].ID) + } + ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err) + s.loadUserGroupRatesOneByOne(ctx, users) + } else { + for i := range users { + if rates, ok := ratesByUser[users[i].ID]; ok { + users[i].GroupRates = rates + } + } + } + } else { + s.loadUserGroupRatesOneByOne(ctx, users) + } + } + return users, result.Total, nil +} + +func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) { + if s.userGroupRateRepo == nil { + return + } + for i := range users { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) + continue + } + users[i].GroupRates = rates + } +} + +func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + // 加载用户专属分组倍率 + if s.userGroupRateRepo != nil { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", id, err) + } else { + user.GroupRates = rates + } + } + return user, nil +} + +func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { + user := &User{ + Email: input.Email, + Username: input.Username, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, + } + if err := user.SetPassword(input.Password); err != nil { + return nil, err + } + if err := s.userRepo.Create(ctx, user); err != nil { + return nil, err + } + s.assignDefaultSubscriptions(ctx, user.ID) + return user, nil +} + +func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { + return + } + items := s.settingService.GetDefaultSubscriptions(ctx) + for _, item := range items { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by default user subscriptions setting", + }); err != nil { + logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + } + } +} + +func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + // Protect admin users: cannot disable admin accounts + if user.Role == "admin" && input.Status == "disabled" { + return nil, errors.New("cannot disable admin user") + } + + oldConcurrency := user.Concurrency + oldStatus := user.Status + oldRole := user.Role + + if input.Email != "" { + user.Email = input.Email + } + if input.Password != "" { + if err := user.SetPassword(input.Password); err != nil { + return nil, err + } + } + + if input.Username != nil { + user.Username = *input.Username + } + if input.Notes != nil { + user.Notes = *input.Notes + } + + if input.Status != "" { + user.Status = input.Status + } + + if input.Concurrency != nil { + user.Concurrency = *input.Concurrency + } + + if input.AllowedGroups != nil { + user.AllowedGroups = *input.AllowedGroups + } + + if input.SoraStorageQuotaBytes != nil { + user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } + + if err := s.userRepo.Update(ctx, user); err != nil { + return nil, err + } + + // 同步用户专属分组倍率 + if input.GroupRates != nil && s.userGroupRateRepo != nil { + if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { + logger.LegacyPrintf("service.admin", "failed to sync user group rates: user_id=%d err=%v", user.ID, err) + } + } + + if s.authCacheInvalidator != nil { + if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) + } + } + + concurrencyDiff := user.Concurrency - oldConcurrency + if concurrencyDiff != 0 { + code, err := GenerateRedeemCode() + if err != nil { + logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) + return user, nil + } + adjustmentRecord := &RedeemCode{ + Code: code, + Type: AdjustmentTypeAdminConcurrency, + Value: float64(concurrencyDiff), + Status: StatusUsed, + UsedBy: &user.ID, + } + now := time.Now() + adjustmentRecord.UsedAt = &now + if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { + logger.LegacyPrintf("service.admin", "failed to create concurrency adjustment redeem code: %v", err) + } + } + + return user, nil +} + +func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { + // Protect admin users: cannot delete admin accounts + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return err + } + if user.Role == "admin" { + return errors.New("cannot delete admin user") + } + if err := s.userRepo.Delete(ctx, id); err != nil { + logger.LegacyPrintf("service.admin", "delete user failed: user_id=%d err=%v", id, err) + return err + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id) + } + return nil +} + +func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) { + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, err + } + + oldBalance := user.Balance + + switch operation { + case "set": + user.Balance = balance + case "add": + user.Balance += balance + case "subtract": + user.Balance -= balance + } + + if user.Balance < 0 { + return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance) + } + + if err := s.userRepo.Update(ctx, user); err != nil { + return nil, err + } + balanceDiff := user.Balance - oldBalance + if s.authCacheInvalidator != nil && balanceDiff != 0 { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + + if s.billingCacheService != nil { + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil { + logger.LegacyPrintf("service.admin", "invalidate user balance cache failed: user_id=%d err=%v", userID, err) + } + }() + } + + if balanceDiff != 0 { + code, err := GenerateRedeemCode() + if err != nil { + logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) + return user, nil + } + + adjustmentRecord := &RedeemCode{ + Code: code, + Type: AdjustmentTypeAdminBalance, + Value: balanceDiff, + Status: StatusUsed, + UsedBy: &user.ID, + Notes: notes, + } + now := time.Now() + adjustmentRecord.UsedAt = &now + + if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { + logger.LegacyPrintf("service.admin", "failed to create balance adjustment redeem code: %v", err) + } + } + + return user, nil +} + +func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) + if err != nil { + return nil, 0, err + } + return keys, result.Total, nil +} + +func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { + // Return mock data for now + return map[string]any{ + "period": period, + "total_requests": 0, + "total_cost": 0.0, + "total_tokens": 0, + "avg_duration_ms": 0, + }, nil +} + +// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. +func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) + if err != nil { + return nil, 0, 0, err + } + // Aggregate total recharged amount (only once, regardless of type filter) + totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) + if err != nil { + return nil, 0, 0, err + } + return codes, result.Total, totalRecharged, nil +} + +// Group management implementations +func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) + if err != nil { + return nil, 0, err + } + return groups, result.Total, nil +} + +func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) { + return s.groupRepo.ListActive(ctx) +} + +func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) { + return s.groupRepo.ListActiveByPlatform(ctx, platform) +} + +func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) { + return s.groupRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { + platform := input.Platform + if platform == "" { + platform = PlatformAnthropic + } + + subscriptionType := input.SubscriptionType + if subscriptionType == "" { + subscriptionType = SubscriptionTypeStandard + } + + // 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额 + dailyLimit := normalizeLimit(input.DailyLimitUSD) + weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) + monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) + + // 图片价格:负数表示清除(使用默认价格),0 保留(表示免费) + imagePrice1K := normalizePrice(input.ImagePrice1K) + imagePrice2K := normalizePrice(input.ImagePrice2K) + imagePrice4K := normalizePrice(input.ImagePrice4K) + soraImagePrice360 := normalizePrice(input.SoraImagePrice360) + soraImagePrice540 := normalizePrice(input.SoraImagePrice540) + soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest) + soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD) + + // 校验降级分组 + if input.FallbackGroupID != nil { + if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil { + return nil, err + } + } + fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest + if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 { + fallbackOnInvalidRequest = nil + } + // 校验无效请求兜底分组 + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + + // MCPXMLInject:默认为 true,仅当显式传入 false 时关闭 + mcpXMLInject := true + if input.MCPXMLInject != nil { + mcpXMLInject = *input.MCPXMLInject + } + + // 如果指定了复制账号的源分组,先获取账号 ID 列表 + var accountIDsToCopy []int64 + if len(input.CopyAccountsFromGroupIDs) > 0 { + // 去重源分组 IDs + seen := make(map[int64]struct{}) + uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) + for _, srcGroupID := range input.CopyAccountsFromGroupIDs { + if _, exists := seen[srcGroupID]; !exists { + seen[srcGroupID] = struct{}{} + uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) + } + } + + // 校验源分组的平台是否与新分组一致 + for _, srcGroupID := range uniqueSourceGroupIDs { + srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) + if err != nil { + return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) + } + if srcGroup.Platform != platform { + return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform) + } + } + + // 获取所有源分组的账号(去重) + var err error + accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) + if err != nil { + return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) + } + } + + group := &Group{ + Name: input.Name, + Description: input.Description, + Platform: platform, + RateMultiplier: input.RateMultiplier, + IsExclusive: input.IsExclusive, + Status: StatusActive, + SubscriptionType: subscriptionType, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, + ImagePrice1K: imagePrice1K, + ImagePrice2K: imagePrice2K, + ImagePrice4K: imagePrice4K, + SoraImagePrice360: soraImagePrice360, + SoraImagePrice540: soraImagePrice540, + SoraVideoPricePerRequest: soraVideoPrice, + SoraVideoPricePerRequestHD: soraVideoPriceHD, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, + ModelRouting: input.ModelRouting, + MCPXMLInject: mcpXMLInject, + SupportedModelScopes: input.SupportedModelScopes, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, + AllowMessagesDispatch: input.AllowMessagesDispatch, + DefaultMappedModel: input.DefaultMappedModel, + } + if err := s.groupRepo.Create(ctx, group); err != nil { + return nil, err + } + + // 如果有需要复制的账号,绑定到新分组 + if len(accountIDsToCopy) > 0 { + if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil { + return nil, fmt.Errorf("failed to bind accounts to new group: %w", err) + } + group.AccountCount = int64(len(accountIDsToCopy)) + } + + return group, nil +} + +// normalizeLimit 将负数转换为 nil(表示无限制),0 保留(表示限额为零) +func normalizeLimit(limit *float64) *float64 { + if limit == nil || *limit < 0 { + return nil + } + return limit +} + +// normalizePrice 将负数转换为 nil(表示使用默认价格),0 保留(表示免费) +func normalizePrice(price *float64) *float64 { + if price == nil || *price < 0 { + return nil + } + return price +} + +// validateFallbackGroup 校验降级分组的有效性 +// currentGroupID: 当前分组 ID(新建时为 0) +// fallbackGroupID: 降级分组 ID +func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error { + // 不能将自己设置为降级分组 + if currentGroupID > 0 && currentGroupID == fallbackGroupID { + return fmt.Errorf("cannot set self as fallback group") + } + + visited := map[int64]struct{}{} + nextID := fallbackGroupID + for { + if _, seen := visited[nextID]; seen { + return fmt.Errorf("fallback group cycle detected") + } + visited[nextID] = struct{}{} + if currentGroupID > 0 && nextID == currentGroupID { + return fmt.Errorf("fallback group cycle detected") + } + + // 检查降级分组是否存在 + fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID) + if err != nil { + return fmt.Errorf("fallback group not found: %w", err) + } + + // 降级分组不能启用 claude_code_only,否则会造成死循环 + if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly { + return fmt.Errorf("fallback group cannot have claude_code_only enabled") + } + + if fallbackGroup.FallbackGroupID == nil { + return nil + } + nextID = *fallbackGroup.FallbackGroupID + } +} + +// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性 +// currentGroupID: 当前分组 ID(新建时为 0) +// platform/subscriptionType: 当前分组的有效平台/订阅类型 +// fallbackGroupID: 兜底分组 ID +func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error { + if platform != PlatformAnthropic && platform != PlatformAntigravity { + return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups") + } + if subscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("subscription groups cannot set invalid request fallback") + } + if currentGroupID > 0 && currentGroupID == fallbackGroupID { + return fmt.Errorf("cannot set self as invalid request fallback group") + } + + fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID) + if err != nil { + return fmt.Errorf("fallback group not found: %w", err) + } + if fallbackGroup.Platform != PlatformAnthropic { + return fmt.Errorf("fallback group must be anthropic platform") + } + if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("fallback group cannot be subscription type") + } + if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + return fmt.Errorf("fallback group cannot have invalid request fallback configured") + } + return nil +} + +func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { + group, err := s.groupRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + if input.Name != "" { + group.Name = input.Name + } + if input.Description != "" { + group.Description = input.Description + } + if input.Platform != "" { + group.Platform = input.Platform + } + if input.RateMultiplier != nil { + group.RateMultiplier = *input.RateMultiplier + } + if input.IsExclusive != nil { + group.IsExclusive = *input.IsExclusive + } + if input.Status != "" { + group.Status = input.Status + } + + // 订阅相关字段 + if input.SubscriptionType != "" { + group.SubscriptionType = input.SubscriptionType + } + // 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额 + // 前端始终发送这三个字段,无需 nil 守卫 + group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) + group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) + group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) + // 图片生成计费配置:负数表示清除(使用默认价格) + if input.ImagePrice1K != nil { + group.ImagePrice1K = normalizePrice(input.ImagePrice1K) + } + if input.ImagePrice2K != nil { + group.ImagePrice2K = normalizePrice(input.ImagePrice2K) + } + if input.ImagePrice4K != nil { + group.ImagePrice4K = normalizePrice(input.ImagePrice4K) + } + if input.SoraImagePrice360 != nil { + group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360) + } + if input.SoraImagePrice540 != nil { + group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540) + } + if input.SoraVideoPricePerRequest != nil { + group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest) + } + if input.SoraVideoPricePerRequestHD != nil { + group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) + } + if input.SoraStorageQuotaBytes != nil { + group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } + + // Claude Code 客户端限制 + if input.ClaudeCodeOnly != nil { + group.ClaudeCodeOnly = *input.ClaudeCodeOnly + } + if input.FallbackGroupID != nil { + // 校验降级分组 + if *input.FallbackGroupID > 0 { + if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil { + return nil, err + } + group.FallbackGroupID = input.FallbackGroupID + } else { + // 传入 0 或负数表示清除降级分组 + group.FallbackGroupID = nil + } + } + fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest + if input.FallbackGroupIDOnInvalidRequest != nil { + if *input.FallbackGroupIDOnInvalidRequest > 0 { + fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest + } else { + fallbackOnInvalidRequest = nil + } + } + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest + + // 模型路由配置 + if input.ModelRouting != nil { + group.ModelRouting = input.ModelRouting + } + if input.ModelRoutingEnabled != nil { + group.ModelRoutingEnabled = *input.ModelRoutingEnabled + } + if input.MCPXMLInject != nil { + group.MCPXMLInject = *input.MCPXMLInject + } + + // 支持的模型系列(仅 antigravity 平台使用) + if input.SupportedModelScopes != nil { + group.SupportedModelScopes = *input.SupportedModelScopes + } + + // OpenAI Messages 调度配置 + if input.AllowMessagesDispatch != nil { + group.AllowMessagesDispatch = *input.AllowMessagesDispatch + } + if input.DefaultMappedModel != nil { + group.DefaultMappedModel = *input.DefaultMappedModel + } + + if err := s.groupRepo.Update(ctx, group); err != nil { + return nil, err + } + + // 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号) + if len(input.CopyAccountsFromGroupIDs) > 0 { + // 去重源分组 IDs + seen := make(map[int64]struct{}) + uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) + for _, srcGroupID := range input.CopyAccountsFromGroupIDs { + // 校验:源分组不能是自身 + if srcGroupID == id { + return nil, fmt.Errorf("cannot copy accounts from self") + } + // 去重 + if _, exists := seen[srcGroupID]; !exists { + seen[srcGroupID] = struct{}{} + uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) + } + } + + // 校验源分组的平台是否与当前分组一致 + for _, srcGroupID := range uniqueSourceGroupIDs { + srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) + if err != nil { + return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) + } + if srcGroup.Platform != group.Platform { + return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform) + } + } + + // 获取所有源分组的账号(去重) + accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) + if err != nil { + return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) + } + + // 先清空当前分组的所有账号绑定 + if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil { + return nil, fmt.Errorf("failed to clear existing account bindings: %w", err) + } + + // 再绑定源分组的账号 + if len(accountIDsToCopy) > 0 { + if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil { + return nil, fmt.Errorf("failed to bind accounts to group: %w", err) + } + } + } + + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } + return group, nil +} + +func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { + var groupKeys []string + if s.authCacheInvalidator != nil { + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id) + if err == nil { + groupKeys = keys + } + } + + affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) + if err != nil { + return err + } + // 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理 + + // 事务成功后,异步失效受影响用户的订阅缓存 + if len(affectedUserIDs) > 0 && s.billingCacheService != nil { + groupID := id + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + for _, userID := range affectedUserIDs { + if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { + logger.LegacyPrintf("service.admin", "invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) + } + } + }() + } + if s.authCacheInvalidator != nil { + for _, key := range groupKeys { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key) + } + } + + return nil +} + +func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) + if err != nil { + return nil, 0, err + } + return keys, result.Total, nil +} + +func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) { + if s.userGroupRateRepo == nil { + return nil, nil + } + return s.userGroupRateRepo.GetByGroupID(ctx, groupID) +} + +func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error { + if s.userGroupRateRepo == nil { + return nil + } + return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID) +} + +func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error { + if s.userGroupRateRepo == nil { + return nil + } + return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) +} + +func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return s.groupRepo.UpdateSortOrders(ctx, updates) +} + +// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定 +// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组 +func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) { + apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) + if err != nil { + return nil, err + } + + if groupID == nil { + // nil 表示不修改,直接返回 + return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil + } + + if *groupID < 0 { + return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative") + } + + result := &AdminUpdateAPIKeyGroupIDResult{} + + if *groupID == 0 { + // 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key) + apiKey.GroupID = nil + apiKey.Group = nil + } else { + // 验证目标分组存在且状态为 active + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return nil, err + } + if group.Status != StatusActive { + return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") + } + // 订阅类型分组:用户须持有该分组的有效订阅才可绑定 + if group.IsSubscriptionType() { + if s.userSubRepo == nil { + return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured") + } + if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil { + if errors.Is(err, ErrSubscriptionNotFound) { + return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group") + } + return nil, err + } + } + + gid := *groupID + apiKey.GroupID = &gid + apiKey.Group = group + + // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 + if group.IsExclusive && !group.IsSubscriptionType() { + opCtx := ctx + var tx *dbent.Tx + if s.entClient == nil { + logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding") + } else { + var txErr error + tx, txErr = s.entClient.Tx(ctx) + if txErr != nil { + return nil, fmt.Errorf("begin transaction: %w", txErr) + } + defer func() { _ = tx.Rollback() }() + opCtx = dbent.NewTxContext(ctx, tx) + } + + if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil { + return nil, fmt.Errorf("add group to user allowed groups: %w", addErr) + } + if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil { + return nil, fmt.Errorf("update api key: %w", err) + } + if tx != nil { + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + } + + result.AutoGrantedGroupAccess = true + result.GrantedGroupID = &gid + result.GrantedGroupName = group.Name + + // 失效认证缓存(在事务提交后执行) + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + result.APIKey = apiKey + return result, nil + } + } + + // 非专属分组 / 解绑:无需事务,单步更新即可 + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil, fmt.Errorf("update api key: %w", err) + } + + // 失效认证缓存 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + result.APIKey = apiKey + return result, nil +} + +// ReplaceUserGroup 替换用户的专属分组 +func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) { + if oldGroupID == newGroupID { + return nil, infraerrors.BadRequest("SAME_GROUP", "old and new group must be different") + } + + // 验证新分组存在且为活跃的专属标准分组 + newGroup, err := s.groupRepo.GetByID(ctx, newGroupID) + if err != nil { + return nil, err + } + if newGroup.Status != StatusActive { + return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") + } + if !newGroup.IsExclusive { + return nil, infraerrors.BadRequest("GROUP_NOT_EXCLUSIVE", "target group is not exclusive") + } + if newGroup.IsSubscriptionType() { + return nil, infraerrors.BadRequest("GROUP_IS_SUBSCRIPTION", "subscription groups are not supported for replacement") + } + + // 事务保证原子性 + if s.entClient == nil { + return nil, fmt.Errorf("entClient is nil, cannot perform group replacement") + } + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + opCtx := dbent.NewTxContext(ctx, tx) + + // 1. 授予新分组权限 + if err := s.userRepo.AddGroupToAllowedGroups(opCtx, userID, newGroupID); err != nil { + return nil, fmt.Errorf("add new group to allowed groups: %w", err) + } + + // 2. 迁移绑定旧分组的 Key 到新分组 + migrated, err := s.apiKeyRepo.UpdateGroupIDByUserAndGroup(opCtx, userID, oldGroupID, newGroupID) + if err != nil { + return nil, fmt.Errorf("migrate api keys: %w", err) + } + + // 3. 移除旧分组权限 + if err := s.userRepo.RemoveGroupFromUserAllowedGroups(opCtx, userID, oldGroupID); err != nil { + return nil, fmt.Errorf("remove old group from allowed groups: %w", err) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + + // 失效该用户所有 Key 的认证缓存 + if s.authCacheInvalidator != nil { + keys, keyErr := s.apiKeyRepo.ListKeysByUserID(ctx, userID) + if keyErr == nil { + for _, k := range keys { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, k) + } + } + } + + return &ReplaceUserGroupResult{MigratedKeys: migrated}, nil +} + +// Account management implementations +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) + if err != nil { + return nil, 0, err + } + now := time.Now() + for i := range accounts { + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now) + } + return accounts, result.Total, nil +} + +func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) { + return s.accountRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) { + if len(ids) == 0 { + return []*Account{}, nil + } + + accounts, err := s.accountRepo.GetByIDs(ctx, ids) + if err != nil { + return nil, fmt.Errorf("failed to get accounts by IDs: %w", err) + } + + return accounts, nil +} + +func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { + // 绑定分组 + groupIDs := input.GroupIDs + // 如果没有指定分组,自动绑定对应平台的默认分组 + if len(groupIDs) == 0 && !input.SkipDefaultGroupBind { + defaultGroupName := input.Platform + "-default" + groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform) + if err == nil { + for _, g := range groups { + if g.Name == defaultGroupName { + groupIDs = []int64{g.ID} + break + } + } + } + } + + // 检查混合渠道风险(除非用户已确认) + if len(groupIDs) > 0 && !input.SkipMixedChannelCheck { + if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil { + return nil, err + } + } + + // Sora apikey 账号的 base_url 必填校验 + if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey { + baseURL, _ := input.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + + account := &Account{ + Name: input.Name, + Notes: normalizeAccountNotes(input.Notes), + Platform: input.Platform, + Type: input.Type, + Credentials: input.Credentials, + Extra: input.Extra, + ProxyID: input.ProxyID, + Concurrency: input.Concurrency, + Priority: input.Priority, + Status: StatusActive, + Schedulable: true, + } + // 预计算固定时间重置的下次重置时间 + if account.Extra != nil { + if err := ValidateQuotaResetConfig(account.Extra); err != nil { + return nil, err + } + ComputeQuotaResetAt(account.Extra) + } + if input.ExpiresAt != nil && *input.ExpiresAt > 0 { + expiresAt := time.Unix(*input.ExpiresAt, 0) + account.ExpiresAt = &expiresAt + } + if input.AutoPauseOnExpired != nil { + account.AutoPauseOnExpired = *input.AutoPauseOnExpired + } else { + account.AutoPauseOnExpired = true + } + if input.RateMultiplier != nil { + if *input.RateMultiplier < 0 { + return nil, errors.New("rate_multiplier must be >= 0") + } + account.RateMultiplier = input.RateMultiplier + } + if input.LoadFactor != nil && *input.LoadFactor > 0 { + if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } + account.LoadFactor = input.LoadFactor + } + if err := s.accountRepo.Create(ctx, account); err != nil { + return nil, err + } + + // 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录 + if account.Platform == PlatformSora && s.soraAccountRepo != nil { + soraUpdates := map[string]any{ + "access_token": account.GetCredential("access_token"), + "refresh_token": account.GetCredential("refresh_token"), + } + if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil { + // 只记录警告日志,不阻塞账号创建 + logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err) + } + } + + // 绑定分组 + if len(groupIDs) > 0 { + if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { + return nil, err + } + } + + return account, nil +} + +func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) { + account, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + wasOveragesEnabled := account.IsOveragesEnabled() + + if input.Name != "" { + account.Name = input.Name + } + if input.Type != "" { + account.Type = input.Type + } + if input.Notes != nil { + account.Notes = normalizeAccountNotes(input.Notes) + } + if len(input.Credentials) > 0 { + account.Credentials = input.Credentials + } + // Extra 使用 map:需要区分“未提供(nil)”与“显式清空({})”。 + // 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。 + if input.Extra != nil { + // 保留配额用量字段,防止编辑账号时意外重置 + for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} { + if v, ok := account.Extra[key]; ok { + input.Extra[key] = v + } + } + account.Extra = input.Extra + if account.Platform == PlatformAntigravity && wasOveragesEnabled && !account.IsOveragesEnabled() { + delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态 + // 清除 AICredits 限流 key + if rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any); ok { + delete(rawLimits, creditsExhaustedKey) + } + } + if account.Platform == PlatformAntigravity && !wasOveragesEnabled && account.IsOveragesEnabled() { + delete(account.Extra, modelRateLimitsKey) + delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态 + } + // 校验并预计算固定时间重置的下次重置时间 + if err := ValidateQuotaResetConfig(account.Extra); err != nil { + return nil, err + } + ComputeQuotaResetAt(account.Extra) + } + if input.ProxyID != nil { + // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图) + if *input.ProxyID == 0 { + account.ProxyID = nil + } else { + account.ProxyID = input.ProxyID + } + account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID + } + // 只在指针非 nil 时更新 Concurrency(支持设置为 0) + if input.Concurrency != nil { + account.Concurrency = *input.Concurrency + } + // 只在指针非 nil 时更新 Priority(支持设置为 0) + if input.Priority != nil { + account.Priority = *input.Priority + } + if input.RateMultiplier != nil { + if *input.RateMultiplier < 0 { + return nil, errors.New("rate_multiplier must be >= 0") + } + account.RateMultiplier = input.RateMultiplier + } + if input.LoadFactor != nil { + if *input.LoadFactor <= 0 { + account.LoadFactor = nil // 0 或负数表示清除 + } else if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } else { + account.LoadFactor = input.LoadFactor + } + } + if input.Status != "" { + account.Status = input.Status + } + if input.ExpiresAt != nil { + if *input.ExpiresAt <= 0 { + account.ExpiresAt = nil + } else { + expiresAt := time.Unix(*input.ExpiresAt, 0) + account.ExpiresAt = &expiresAt + } + } + if input.AutoPauseOnExpired != nil { + account.AutoPauseOnExpired = *input.AutoPauseOnExpired + } + + // Sora apikey 账号的 base_url 必填校验 + if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey { + baseURL, _ := account.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + + // 先验证分组是否存在(在任何写操作之前) + if input.GroupIDs != nil { + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err + } + + // 检查混合渠道风险(除非用户已确认) + if !input.SkipMixedChannelCheck { + if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil { + return nil, err + } + } + } + + if err := s.accountRepo.Update(ctx, account); err != nil { + return nil, err + } + + // 绑定分组 + if input.GroupIDs != nil { + if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil { + return nil, err + } + } + + // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + return updated, nil +} + +// BulkUpdateAccounts updates multiple accounts in one request. +// It merges credentials/extra keys instead of overwriting the whole object. +func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { + result := &BulkUpdateAccountsResult{ + SuccessIDs: make([]int64, 0, len(input.AccountIDs)), + FailedIDs: make([]int64, 0, len(input.AccountIDs)), + Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)), + } + + if len(input.AccountIDs) == 0 { + return result, nil + } + if input.GroupIDs != nil { + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err + } + } + + needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck + + // 预加载账号平台信息(混合渠道检查需要)。 + platformByID := map[int64]string{} + if needMixedChannelCheck { + accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) + if err != nil { + return nil, err + } + for _, account := range accounts { + if account != nil { + platformByID[account.ID] = account.Platform + } + } + } + + // 预检查混合渠道风险:在任何写操作之前,若发现风险立即返回错误。 + if needMixedChannelCheck { + for _, accountID := range input.AccountIDs { + platform := platformByID[accountID] + if platform == "" { + continue + } + if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { + return nil, err + } + } + } + + if input.RateMultiplier != nil { + if *input.RateMultiplier < 0 { + return nil, errors.New("rate_multiplier must be >= 0") + } + } + + // Prepare bulk updates for columns and JSONB fields. + repoUpdates := AccountBulkUpdate{ + Credentials: input.Credentials, + Extra: input.Extra, + } + if input.Name != "" { + repoUpdates.Name = &input.Name + } + if input.ProxyID != nil { + repoUpdates.ProxyID = input.ProxyID + } + if input.Concurrency != nil { + repoUpdates.Concurrency = input.Concurrency + } + if input.Priority != nil { + repoUpdates.Priority = input.Priority + } + if input.RateMultiplier != nil { + repoUpdates.RateMultiplier = input.RateMultiplier + } + if input.LoadFactor != nil { + if *input.LoadFactor <= 0 { + repoUpdates.LoadFactor = nil // 0 或负数表示清除 + } else if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } else { + repoUpdates.LoadFactor = input.LoadFactor + } + } + if input.Status != "" { + repoUpdates.Status = &input.Status + } + if input.Schedulable != nil { + repoUpdates.Schedulable = input.Schedulable + } + + // Run bulk update for column/jsonb fields first. + if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { + return nil, err + } + + // Handle group bindings per account (requires individual operations). + for _, accountID := range input.AccountIDs { + entry := BulkUpdateAccountResult{AccountID: accountID} + + if input.GroupIDs != nil { + if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { + entry.Success = false + entry.Error = err.Error() + result.Failed++ + result.FailedIDs = append(result.FailedIDs, accountID) + result.Results = append(result.Results, entry) + continue + } + } + + entry.Success = true + result.Success++ + result.SuccessIDs = append(result.SuccessIDs, accountID) + result.Results = append(result.Results, entry) + } + + return result, nil +} + +func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { + if err := s.accountRepo.Delete(ctx, id); err != nil { + return err + } + return nil +} + +func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { + if _, err := s.accountRepo.GetByID(ctx, id); err != nil { + return nil, err + } + return nil, infraerrors.New(http.StatusNotImplemented, "REFRESH_CREDENTIALS_NOT_IMPLEMENTED", "account credential refresh is not yet implemented") +} + +func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) { + if err := s.accountRepo.ClearError(ctx, id); err != nil { + return nil, err + } + return s.accountRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error { + return s.accountRepo.SetError(ctx, id, errorMsg) +} + +func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { + if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { + return nil, err + } + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + return updated, nil +} + +// Proxy management implementations +func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) + if err != nil { + return nil, 0, err + } + return proxies, result.Total, nil +} + +func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) + if err != nil { + return nil, 0, err + } + s.attachProxyLatency(ctx, proxies) + return proxies, result.Total, nil +} + +func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { + return s.proxyRepo.ListActive(ctx) +} + +func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx) + if err != nil { + return nil, err + } + s.attachProxyLatency(ctx, proxies) + return proxies, nil +} + +func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) { + return s.proxyRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + return s.proxyRepo.ListByIDs(ctx, ids) +} + +func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) { + proxy := &Proxy{ + Name: input.Name, + Protocol: input.Protocol, + Host: input.Host, + Port: input.Port, + Username: input.Username, + Password: input.Password, + Status: StatusActive, + } + if err := s.proxyRepo.Create(ctx, proxy); err != nil { + return nil, err + } + // Probe latency asynchronously so creation isn't blocked by network timeout. + go s.probeProxyLatency(context.Background(), proxy) + return proxy, nil +} + +func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + if input.Name != "" { + proxy.Name = input.Name + } + if input.Protocol != "" { + proxy.Protocol = input.Protocol + } + if input.Host != "" { + proxy.Host = input.Host + } + if input.Port != 0 { + proxy.Port = input.Port + } + if input.Username != "" { + proxy.Username = input.Username + } + if input.Password != "" { + proxy.Password = input.Password + } + if input.Status != "" { + proxy.Status = input.Status + } + + if err := s.proxyRepo.Update(ctx, proxy); err != nil { + return nil, err + } + return proxy, nil +} + +func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error { + count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id) + if err != nil { + return err + } + if count > 0 { + return ErrProxyInUse + } + return s.proxyRepo.Delete(ctx, id) +} + +func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) { + result := &ProxyBatchDeleteResult{} + if len(ids) == 0 { + return result, nil + } + + for _, id := range ids { + count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id) + if err != nil { + result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ + ID: id, + Reason: err.Error(), + }) + continue + } + if count > 0 { + result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ + ID: id, + Reason: ErrProxyInUse.Error(), + }) + continue + } + if err := s.proxyRepo.Delete(ctx, id); err != nil { + result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ + ID: id, + Reason: err.Error(), + }) + continue + } + result.DeletedIDs = append(result.DeletedIDs, id) + } + + return result, nil +} + +func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { + return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID) +} + +func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { + return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password) +} + +// Redeem code management implementations +func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) + if err != nil { + return nil, 0, err + } + return codes, result.Total, nil +} + +func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { + return s.redeemCodeRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) { + // 如果是订阅类型,验证必须有 GroupID + if input.Type == RedeemTypeSubscription { + if input.GroupID == nil { + return nil, errors.New("group_id is required for subscription type") + } + // 验证分组存在且为订阅类型 + group, err := s.groupRepo.GetByID(ctx, *input.GroupID) + if err != nil { + return nil, fmt.Errorf("group not found: %w", err) + } + if !group.IsSubscriptionType() { + return nil, errors.New("group must be subscription type") + } + } + + codes := make([]RedeemCode, 0, input.Count) + for i := 0; i < input.Count; i++ { + codeValue, err := GenerateRedeemCode() + if err != nil { + return nil, err + } + code := RedeemCode{ + Code: codeValue, + Type: input.Type, + Value: input.Value, + Status: StatusUnused, + } + // 订阅类型专用字段 + if input.Type == RedeemTypeSubscription { + code.GroupID = input.GroupID + code.ValidityDays = input.ValidityDays + if code.ValidityDays <= 0 { + code.ValidityDays = 30 // 默认30天 + } + } + if err := s.redeemCodeRepo.Create(ctx, &code); err != nil { + return nil, err + } + codes = append(codes, code) + } + return codes, nil +} + +func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error { + return s.redeemCodeRepo.Delete(ctx, id) +} + +func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) { + var deleted int64 + for _, id := range ids { + if err := s.redeemCodeRepo.Delete(ctx, id); err == nil { + deleted++ + } + } + return deleted, nil +} + +func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { + code, err := s.redeemCodeRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + code.Status = StatusExpired + if err := s.redeemCodeRepo.Update(ctx, code); err != nil { + return nil, err + } + return code, nil +} + +func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + proxyURL := proxy.URL() + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) + if err != nil { + s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ + Success: false, + Message: err.Error(), + UpdatedAt: time.Now(), + }) + return &ProxyTestResult{ + Success: false, + Message: err.Error(), + }, nil + } + + latency := latencyMs + s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ + Success: true, + LatencyMs: &latency, + Message: "Proxy is accessible", + IPAddress: exitInfo.IP, + Country: exitInfo.Country, + CountryCode: exitInfo.CountryCode, + Region: exitInfo.Region, + City: exitInfo.City, + UpdatedAt: time.Now(), + }) + return &ProxyTestResult{ + Success: true, + Message: "Proxy is accessible", + LatencyMs: latencyMs, + IPAddress: exitInfo.IP, + City: exitInfo.City, + Region: exitInfo.Region, + Country: exitInfo.Country, + CountryCode: exitInfo.CountryCode, + }, nil +} + +func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + result := &ProxyQualityCheckResult{ + ProxyID: id, + Score: 100, + Grade: "A", + CheckedAt: time.Now().Unix(), + Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1), + } + + proxyURL := proxy.URL() + if s.proxyProber == nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + Message: "代理探测服务未配置", + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + LatencyMs: latencyMs, + Message: err.Error(), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + result.ExitIP = exitInfo.IP + result.Country = exitInfo.Country + result.CountryCode = exitInfo.CountryCode + result.BaseLatencyMs = latencyMs + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "pass", + LatencyMs: latencyMs, + Message: "代理出口连通正常", + }) + result.PassedCount++ + + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: proxyQualityRequestTimeout, + ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, + }) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "http_client", + Status: "fail", + Message: fmt.Sprintf("创建检测客户端失败: %v", err), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil + } + + for _, target := range proxyQualityTargets { + item := runProxyQualityTarget(ctx, client, target) + result.Items = append(result.Items, item) + switch item.Status { + case "pass": + result.PassedCount++ + case "warn": + result.WarnCount++ + case "challenge": + result.ChallengeCount++ + default: + result.FailedCount++ + } + } + + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil +} + +func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem { + item := ProxyQualityCheckItem{ + Target: target.Target, + } + + req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil) + if err != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("构建请求失败: %v", err) + return item + } + req.Header.Set("Accept", "application/json,text/html,*/*") + req.Header.Set("User-Agent", proxyQualityClientUserAgent) + + start := time.Now() + resp, err := client.Do(req) + if err != nil { + item.Status = "fail" + item.LatencyMs = time.Since(start).Milliseconds() + item.Message = fmt.Sprintf("请求失败: %v", err) + return item + } + defer func() { _ = resp.Body.Close() }() + item.LatencyMs = time.Since(start).Milliseconds() + item.HTTPStatus = resp.StatusCode + + body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1)) + if readErr != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("读取响应失败: %v", readErr) + return item + } + if int64(len(body)) > proxyQualityMaxBodyBytes { + body = body[:proxyQualityMaxBodyBytes] + } + + if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + item.Status = "challenge" + item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body) + item.Message = "Sora 命中 Cloudflare challenge" + return item + } + + if _, ok := target.AllowedStatuses[resp.StatusCode]; ok { + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + item.Status = "pass" + item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode) + } else { + item.Status = "warn" + item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode) + } + return item + } + + if resp.StatusCode == http.StatusTooManyRequests { + item.Status = "warn" + item.Message = "目标返回 429,可能存在频控" + return item + } + + item.Status = "fail" + item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode) + return item +} + +func finalizeProxyQualityResult(result *ProxyQualityCheckResult) { + if result == nil { + return + } + score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30 + if score < 0 { + score = 0 + } + result.Score = score + result.Grade = proxyQualityGrade(score) + result.Summary = fmt.Sprintf( + "通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项", + result.PassedCount, + result.WarnCount, + result.FailedCount, + result.ChallengeCount, + ) +} + +func proxyQualityGrade(score int) string { + switch { + case score >= 90: + return "A" + case score >= 75: + return "B" + case score >= 60: + return "C" + case score >= 40: + return "D" + default: + return "F" + } +} + +func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string { + if result == nil { + return "" + } + if result.ChallengeCount > 0 { + return "challenge" + } + if result.FailedCount > 0 { + return "failed" + } + if result.WarnCount > 0 { + return "warn" + } + if result.PassedCount > 0 { + return "healthy" + } + return "failed" +} + +func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string { + if result == nil { + return "" + } + for _, item := range result.Items { + if item.CFRay != "" { + return item.CFRay + } + } + return "" +} + +func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool { + if result == nil { + return false + } + for _, item := range result.Items { + if item.Target == "base_connectivity" { + return item.Status == "pass" + } + } + return false +} + +func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) { + if result == nil { + return + } + score := result.Score + checkedAt := result.CheckedAt + info := &ProxyLatencyInfo{ + Success: proxyQualityBaseConnectivityPass(result), + Message: result.Summary, + QualityStatus: proxyQualityOverallStatus(result), + QualityScore: &score, + QualityGrade: result.Grade, + QualitySummary: result.Summary, + QualityCheckedAt: &checkedAt, + QualityCFRay: proxyQualityFirstCFRay(result), + UpdatedAt: time.Now(), + } + if result.BaseLatencyMs > 0 { + latency := result.BaseLatencyMs + info.LatencyMs = &latency + } + if exitInfo != nil { + info.IPAddress = exitInfo.IP + info.Country = exitInfo.Country + info.CountryCode = exitInfo.CountryCode + info.Region = exitInfo.Region + info.City = exitInfo.City + } + s.saveProxyLatency(ctx, proxyID, info) +} + +func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { + if s.proxyProber == nil || proxy == nil { + return + } + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL()) + if err != nil { + s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ + Success: false, + Message: err.Error(), + UpdatedAt: time.Now(), + }) + return + } + + latency := latencyMs + s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ + Success: true, + LatencyMs: &latency, + Message: "Proxy is accessible", + IPAddress: exitInfo.IP, + Country: exitInfo.Country, + CountryCode: exitInfo.CountryCode, + Region: exitInfo.Region, + City: exitInfo.City, + UpdatedAt: time.Now(), + }) +} + +// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic) +// 如果存在混合,返回错误提示用户确认 +func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { + // 判断当前账号的渠道类型(基于 platform 字段,而不是 type 字段) + currentPlatform := getAccountPlatform(currentAccountPlatform) + if currentPlatform == "" { + // 不是 Antigravity 或 Anthropic,无需检查 + return nil + } + + // 检查每个分组中的其他账号 + for _, groupID := range groupIDs { + accounts, err := s.accountRepo.ListByGroup(ctx, groupID) + if err != nil { + return fmt.Errorf("get accounts in group %d: %w", groupID, err) + } + + // 检查是否存在不同渠道的账号 + for _, account := range accounts { + if currentAccountID > 0 && account.ID == currentAccountID { + continue // 跳过当前账号 + } + + otherPlatform := getAccountPlatform(account.Platform) + if otherPlatform == "" { + continue // 不是 Antigravity 或 Anthropic,跳过 + } + + // 检测混合渠道 + if currentPlatform != otherPlatform { + group, _ := s.groupRepo.GetByID(ctx, groupID) + groupName := fmt.Sprintf("Group %d", groupID) + if group != nil { + groupName = group.Name + } + + return &MixedChannelError{ + GroupID: groupID, + GroupName: groupName, + CurrentPlatform: currentPlatform, + OtherPlatform: otherPlatform, + } + } + } + } + + return nil +} + +func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + if s.groupRepo == nil { + return errors.New("group repository not configured") + } + + if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok { + existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs) + if err != nil { + return fmt.Errorf("check groups exists: %w", err) + } + for _, groupID := range groupIDs { + if groupID <= 0 || !existsByID[groupID] { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + } + return nil + } + + for _, groupID := range groupIDs { + if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { + return fmt.Errorf("get group: %w", err) + } + } + return nil +} + +// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform. +func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { + return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) +} + +func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) { + if s.proxyLatencyCache == nil || len(proxies) == 0 { + return + } + + ids := make([]int64, 0, len(proxies)) + for i := range proxies { + ids = append(ids, proxies[i].ID) + } + + latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids) + if err != nil { + logger.LegacyPrintf("service.admin", "Warning: load proxy latency cache failed: %v", err) + return + } + + for i := range proxies { + info := latencies[proxies[i].ID] + if info == nil { + continue + } + if info.Success { + proxies[i].LatencyStatus = "success" + proxies[i].LatencyMs = info.LatencyMs + } else { + proxies[i].LatencyStatus = "failed" + } + proxies[i].LatencyMessage = info.Message + proxies[i].IPAddress = info.IPAddress + proxies[i].Country = info.Country + proxies[i].CountryCode = info.CountryCode + proxies[i].Region = info.Region + proxies[i].City = info.City + proxies[i].QualityStatus = info.QualityStatus + proxies[i].QualityScore = info.QualityScore + proxies[i].QualityGrade = info.QualityGrade + proxies[i].QualitySummary = info.QualitySummary + proxies[i].QualityChecked = info.QualityCheckedAt + } +} + +func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) { + if s.proxyLatencyCache == nil || info == nil { + return + } + + merged := *info + if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil { + if existing := latencies[proxyID]; existing != nil { + if merged.QualityCheckedAt == nil && + merged.QualityScore == nil && + merged.QualityGrade == "" && + merged.QualityStatus == "" && + merged.QualitySummary == "" && + merged.QualityCFRay == "" { + merged.QualityStatus = existing.QualityStatus + merged.QualityScore = existing.QualityScore + merged.QualityGrade = existing.QualityGrade + merged.QualitySummary = existing.QualitySummary + merged.QualityCheckedAt = existing.QualityCheckedAt + merged.QualityCFRay = existing.QualityCFRay + } + } + } + + if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil { + logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err) + } +} + +// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识 +func getAccountPlatform(accountPlatform string) string { + switch strings.ToLower(strings.TrimSpace(accountPlatform)) { + case PlatformAntigravity: + return "Antigravity" + case PlatformAnthropic, "claude": + return "Anthropic" + default: + return "" + } +} + +// MixedChannelError 混合渠道错误 +type MixedChannelError struct { + GroupID int64 + GroupName string + CurrentPlatform string + OtherPlatform string +} + +func (e *MixedChannelError) Error() string { + return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.", + e.GroupName, e.CurrentPlatform, e.OtherPlatform) +} + +func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error { + return s.accountRepo.ResetQuotaUsed(ctx, id) +} + +// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode, +// 未设置则调用 disableOpenAITraining 并持久化到 Extra,返回设置的 mode 值。 +func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string { + if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { + return "" + } + if s.privacyClientFactory == nil { + return "" + } + if account.Extra != nil { + if _, ok := account.Extra["privacy_mode"]; ok { + return "" + } + } + + token, _ := account.Credentials["access_token"].(string) + if token == "" { + return "" + } + + var proxyURL string + if account.ProxyID != nil { + if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil { + proxyURL = p.URL() + } + } + + mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL) + if mode == "" { + return "" + } + + _ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}) + return mode +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go new file mode 100644 index 00000000..41ce2a41 --- /dev/null +++ b/backend/internal/service/api_key_service.go @@ -0,0 +1,881 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/dgraph-io/ristretto" + "golang.org/x/sync/singleflight" +) + +var ( + ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found") + ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group") + ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists") + ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") + ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") + ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") + ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern") + ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired") + ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted") + + // Rate limit errors + ErrAPIKeyRateLimit5hExceeded = infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5h rate limit exceeded") + ErrAPIKeyRateLimit1dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key daily rate limit exceeded") + ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7d rate limit exceeded") +) + +const ( + apiKeyMaxErrorsPerHour = 20 + apiKeyLastUsedMinTouch = 30 * time.Second + // DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。 + apiKeyLastUsedFailBackoff = 5 * time.Second +) + +type APIKeyRepository interface { + Create(ctx context.Context, key *APIKey) error + GetByID(ctx context.Context, id int64) (*APIKey, error) + // GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景 + GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) + GetByKey(ctx context.Context, key string) (*APIKey, error) + // GetByKeyForAuth 认证专用查询,返回最小字段集 + GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) + Update(ctx context.Context, key *APIKey) error + Delete(ctx context.Context, id int64) error + + ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) + VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) + CountByUserID(ctx context.Context, userID int64) (int64, error) + ExistsByKey(ctx context.Context, key string) (bool, error) + ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) + SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) + ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) + // UpdateGroupIDByUserAndGroup 将用户下绑定 oldGroupID 的所有 Key 迁移到 newGroupID + UpdateGroupIDByUserAndGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (int64, error) + CountByGroupID(ctx context.Context, groupID int64) (int64, error) + ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) + ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) + + // Quota methods + IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) + UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error + + // Rate limit methods + IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error + ResetRateLimitWindows(ctx context.Context, id int64) error + GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) +} + +// APIKeyRateLimitData holds rate limit usage and window state for an API key. +type APIKeyRateLimitData struct { + Usage5h float64 + Usage1d float64 + Usage7d float64 + Window5hStart *time.Time + Window1dStart *time.Time + Window7dStart *time.Time +} + +// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 { + if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) { + return 0 + } + return d.Usage5h +} + +// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 { + if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) { + return 0 + } + return d.Usage1d +} + +// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 { + if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) { + return 0 + } + return d.Usage7d +} + +// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update. +// It is intentionally small so repositories can return it from a single SQL statement. +type APIKeyQuotaUsageState struct { + QuotaUsed float64 + Quota float64 + Key string + Status string +} + +// APIKeyCache defines cache operations for API key service +type APIKeyCache interface { + GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) + IncrementCreateAttemptCount(ctx context.Context, userID int64) error + DeleteCreateAttemptCount(ctx context.Context, userID int64) error + + IncrementDailyUsage(ctx context.Context, apiKey string) error + SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error + + GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) + SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error + DeleteAuthCache(ctx context.Context, key string) error + + // Pub/Sub for L1 cache invalidation across instances + PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error + SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error +} + +// APIKeyAuthCacheInvalidator 提供认证缓存失效能力 +type APIKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) + InvalidateAuthCacheByUserID(ctx context.Context, userID int64) + InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) +} + +// CreateAPIKeyRequest 创建API Key请求 +type CreateAPIKeyRequest struct { + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + CustomKey *string `json:"custom_key"` // 可选的自定义key + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + + // Quota fields + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires) + + // Rate limit fields (0 = unlimited) + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` +} + +// UpdateAPIKeyRequest 更新API Key请求 +type UpdateAPIKeyRequest struct { + Name *string `json:"name"` + GroupID *int64 `json:"group_id"` + Status *string `json:"status"` + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空) + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空) + + // Quota fields + Quota *float64 `json:"quota"` // Quota limit in USD (nil = no change, 0 = unlimited) + ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change) + ClearExpiration bool `json:"-"` // Clear expiration (internal use) + ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0 + + // Rate limit fields (nil = no change, 0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` + ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0 +} + +// APIKeyService API Key服务 +// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset. +type RateLimitCacheInvalidator interface { + InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error +} + +type APIKeyService struct { + apiKeyRepo APIKeyRepository + userRepo UserRepository + groupRepo GroupRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache APIKeyCache + rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache + cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group + lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time) + lastUsedTouchSF singleflight.Group +} + +// NewAPIKeyService 创建API Key服务实例 +func NewAPIKeyService( + apiKeyRepo APIKeyRepository, + userRepo UserRepository, + groupRepo GroupRepository, + userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, + cache APIKeyCache, + cfg *config.Config, +) *APIKeyService { + svc := &APIKeyService{ + apiKeyRepo: apiKeyRepo, + userRepo: userRepo, + groupRepo: groupRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + cfg: cfg, + } + svc.initAuthCache(cfg) + return svc +} + +// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator. +// Called after construction (e.g. in wire) to avoid circular dependencies. +func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) { + s.rateLimitCacheInvalid = inv +} + +func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) { + if apiKey == nil { + return + } + apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist) + apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist) +} + +// GenerateKey 生成随机API Key +func (s *APIKeyService) GenerateKey() (string, error) { + // 生成32字节随机数据 + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("generate random bytes: %w", err) + } + + // 转换为十六进制字符串并添加前缀 + prefix := s.cfg.Default.APIKeyPrefix + if prefix == "" { + prefix = "sk-" + } + + key := prefix + hex.EncodeToString(bytes) + return key, nil +} + +// ValidateCustomKey 验证自定义API Key格式 +func (s *APIKeyService) ValidateCustomKey(key string) error { + // 检查长度 + if len(key) < 16 { + return ErrAPIKeyTooShort + } + + // 检查字符:只允许字母、数字、下划线、连字符 + for _, c := range key { + if (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || + c == '_' || c == '-' { + continue + } + return ErrAPIKeyInvalidChars + } + + return nil +} + +// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限 +func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error { + if s.cache == nil { + return nil + } + + count, err := s.cache.GetCreateAttemptCount(ctx, userID) + if err != nil { + // Redis 出错时不阻止用户操作 + return nil + } + + if count >= apiKeyMaxErrorsPerHour { + return ErrAPIKeyRateLimited + } + + return nil +} + +// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数 +func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) { + if s.cache == nil { + return + } + + _ = s.cache.IncrementCreateAttemptCount(ctx, userID) +} + +// canUserBindGroup 检查用户是否可以绑定指定分组 +// 对于订阅类型分组:检查用户是否有有效订阅 +// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑 +func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool { + // 订阅类型分组:需要有效订阅 + if group.IsSubscriptionType() { + _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID) + return err == nil // 有有效订阅则允许 + } + // 标准类型分组:使用原有逻辑 + return user.CanBindGroup(group.ID, group.IsExclusive) +} + +// Create 创建API Key +func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) { + // 验证用户存在 + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + + // 验证 IP 白名单格式 + if len(req.IPWhitelist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + + // 验证 IP 黑名单格式 + if len(req.IPBlacklist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + + // 验证分组权限(如果指定了分组) + if req.GroupID != nil { + group, err := s.groupRepo.GetByID(ctx, *req.GroupID) + if err != nil { + return nil, fmt.Errorf("get group: %w", err) + } + + // 检查用户是否可以绑定该分组 + if !s.canUserBindGroup(ctx, user, group) { + return nil, ErrGroupNotAllowed + } + } + + var key string + + // 判断是否使用自定义Key + if req.CustomKey != nil && *req.CustomKey != "" { + // 检查限流(仅对自定义key进行限流) + if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil { + return nil, err + } + + // 验证自定义Key格式 + if err := s.ValidateCustomKey(*req.CustomKey); err != nil { + return nil, err + } + + // 检查Key是否已存在 + exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey) + if err != nil { + return nil, fmt.Errorf("check key exists: %w", err) + } + if exists { + // Key已存在,增加错误计数 + s.incrementAPIKeyErrorCount(ctx, userID) + return nil, ErrAPIKeyExists + } + + key = *req.CustomKey + } else { + // 生成随机API Key + var err error + key, err = s.GenerateKey() + if err != nil { + return nil, fmt.Errorf("generate key: %w", err) + } + } + + // 创建API Key记录 + apiKey := &APIKey{ + UserID: userID, + Key: key, + Name: req.Name, + GroupID: req.GroupID, + Status: StatusActive, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + QuotaUsed: 0, + RateLimit5h: req.RateLimit5h, + RateLimit1d: req.RateLimit1d, + RateLimit7d: req.RateLimit7d, + } + + // Set expiration time if specified + if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 { + expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays) + apiKey.ExpiresAt = &expiresAt + } + + if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { + return nil, fmt.Errorf("create api key: %w", err) + } + + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) + + return apiKey, nil +} + +// List 获取用户的API Key列表 +func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters) + if err != nil { + return nil, nil, fmt.Errorf("list api keys: %w", err) + } + return keys, pagination, nil +} + +func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + if len(apiKeyIDs) == 0 { + return []int64{}, nil + } + + validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs) + if err != nil { + return nil, fmt.Errorf("verify api key ownership: %w", err) + } + return validIDs, nil +} + +// GetByID 根据ID获取API Key +func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) { + apiKey, err := s.apiKeyRepo.GetByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + s.compileAPIKeyIPRules(apiKey) + return apiKey, nil +} + +// GetByKey 根据Key字符串获取API Key(用于认证) +func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) { + cacheKey := s.authCacheKey(key) + + if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok { + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + s.compileAPIKeyIPRules(apiKey) + return apiKey, nil + } + } + + if s.authCfg.singleflight { + value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) { + return s.loadAuthCacheEntry(ctx, key, cacheKey) + }) + if err != nil { + return nil, err + } + entry, _ := value.(*APIKeyAuthCacheEntry) + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + s.compileAPIKeyIPRules(apiKey) + return apiKey, nil + } + } else { + entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey) + if err != nil { + return nil, err + } + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + s.compileAPIKeyIPRules(apiKey) + return apiKey, nil + } + } + + apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + apiKey.Key = key + s.compileAPIKeyIPRules(apiKey) + return apiKey, nil +} + +// Update 更新API Key +func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) { + apiKey, err := s.apiKeyRepo.GetByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + + // 验证所有权 + if apiKey.UserID != userID { + return nil, ErrInsufficientPerms + } + + // 验证 IP 白名单格式 + if len(req.IPWhitelist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + + // 验证 IP 黑名单格式 + if len(req.IPBlacklist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + + // 更新字段 + if req.Name != nil { + apiKey.Name = *req.Name + } + + if req.GroupID != nil { + // 验证分组权限 + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + + group, err := s.groupRepo.GetByID(ctx, *req.GroupID) + if err != nil { + return nil, fmt.Errorf("get group: %w", err) + } + + if !s.canUserBindGroup(ctx, user, group) { + return nil, ErrGroupNotAllowed + } + + apiKey.GroupID = req.GroupID + } + + if req.Status != nil { + apiKey.Status = *req.Status + // 如果状态改变,清除Redis缓存 + if s.cache != nil { + _ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID) + } + } + + // Update quota fields + if req.Quota != nil { + apiKey.Quota = *req.Quota + // If quota is increased and status was quota_exhausted, reactivate + if apiKey.Status == StatusAPIKeyQuotaExhausted && *req.Quota > apiKey.QuotaUsed { + apiKey.Status = StatusActive + } + } + if req.ResetQuota != nil && *req.ResetQuota { + apiKey.QuotaUsed = 0 + // If resetting quota and status was quota_exhausted, reactivate + if apiKey.Status == StatusAPIKeyQuotaExhausted { + apiKey.Status = StatusActive + } + } + if req.ClearExpiration { + apiKey.ExpiresAt = nil + // If clearing expiry and status was expired, reactivate + if apiKey.Status == StatusAPIKeyExpired { + apiKey.Status = StatusActive + } + } else if req.ExpiresAt != nil { + apiKey.ExpiresAt = req.ExpiresAt + // If extending expiry and status was expired, reactivate + if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) { + apiKey.Status = StatusActive + } + } + + // 更新 IP 限制(空数组会清空设置) + apiKey.IPWhitelist = req.IPWhitelist + apiKey.IPBlacklist = req.IPBlacklist + + // Update rate limit configuration + if req.RateLimit5h != nil { + apiKey.RateLimit5h = *req.RateLimit5h + } + if req.RateLimit1d != nil { + apiKey.RateLimit1d = *req.RateLimit1d + } + if req.RateLimit7d != nil { + apiKey.RateLimit7d = *req.RateLimit7d + } + resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage + if resetRateLimit { + apiKey.Usage5h = 0 + apiKey.Usage1d = 0 + apiKey.Usage7d = 0 + apiKey.Window5hStart = nil + apiKey.Window1dStart = nil + apiKey.Window7dStart = nil + } + + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil, fmt.Errorf("update api key: %w", err) + } + + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) + + // Invalidate Redis rate limit cache so reset takes effect immediately + if resetRateLimit && s.rateLimitCacheInvalid != nil { + _ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID) + } + + return apiKey, nil +} + +// Delete 删除API Key +func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error { + key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id) + if err != nil { + return fmt.Errorf("get api key: %w", err) + } + + // 验证当前用户是否为该 API Key 的所有者 + if ownerID != userID { + return ErrInsufficientPerms + } + + // 清除Redis缓存(使用 userID 而非 apiKey.UserID) + if s.cache != nil { + _ = s.cache.DeleteCreateAttemptCount(ctx, userID) + } + s.InvalidateAuthCacheByKey(ctx, key) + + if err := s.apiKeyRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete api key: %w", err) + } + s.lastUsedTouchL1.Delete(id) + + return nil +} + +// ValidateKey 验证API Key是否有效(用于认证中间件) +func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) { + // 获取API Key + apiKey, err := s.GetByKey(ctx, key) + if err != nil { + return nil, nil, err + } + + // 检查API Key状态 + if !apiKey.IsActive() { + return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active") + } + + // 获取用户信息 + user, err := s.userRepo.GetByID(ctx, apiKey.UserID) + if err != nil { + return nil, nil, fmt.Errorf("get user: %w", err) + } + + // 检查用户状态 + if !user.IsActive() { + return nil, nil, ErrUserNotActive + } + + return apiKey, user, nil +} + +// TouchLastUsed 通过防抖更新 api_keys.last_used_at,减少高频写放大。 +// 该操作为尽力而为,不应阻塞主请求链路。 +func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error { + if keyID <= 0 { + return nil + } + + now := time.Now() + if v, ok := s.lastUsedTouchL1.Load(keyID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) { + return nil + } + } + + _, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) { + latest := time.Now() + if v, ok := s.lastUsedTouchL1.Load(keyID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) { + return nil, nil + } + } + + if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil { + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff)) + return nil, fmt.Errorf("touch api key last used: %w", err) + } + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch)) + return nil, nil + }) + return err +} + +// IncrementUsage 增加API Key使用次数(可选:用于统计) +func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error { + // 使用Redis计数器 + if s.cache != nil { + cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02")) + if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil { + return fmt.Errorf("increment usage: %w", err) + } + // 设置24小时过期 + _ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour) + } + return nil +} + +// GetAvailableGroups 获取用户有权限绑定的分组列表 +// 返回用户可以选择的分组: +// - 标准类型分组:公开的(非专属)或用户被明确允许的 +// - 订阅类型分组:用户有有效订阅的 +func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) { + // 获取用户信息 + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + + // 获取所有活跃分组 + allGroups, err := s.groupRepo.ListActive(ctx) + if err != nil { + return nil, fmt.Errorf("list active groups: %w", err) + } + + // 获取用户的所有有效订阅 + activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("list active subscriptions: %w", err) + } + + // 构建订阅分组 ID 集合 + subscribedGroupIDs := make(map[int64]bool) + for _, sub := range activeSubscriptions { + subscribedGroupIDs[sub.GroupID] = true + } + + // 过滤出用户有权限的分组 + availableGroups := make([]Group, 0) + for _, group := range allGroups { + if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) { + availableGroups = append(availableGroups, group) + } + } + + return availableGroups, nil +} + +// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据) +func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool { + // 订阅类型分组:需要有效订阅 + if group.IsSubscriptionType() { + return subscribedGroupIDs[group.ID] + } + // 标准类型分组:使用原有逻辑 + return user.CanBindGroup(group.ID, group.IsExclusive) +} + +func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) { + keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit) + if err != nil { + return nil, fmt.Errorf("search api keys: %w", err) + } + return keys, nil +} + +// GetUserGroupRates 获取用户的专属分组倍率配置 +// 返回 map[groupID]rateMultiplier +func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) { + if s.userGroupRateRepo == nil { + return nil, nil + } + rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user group rates: %w", err) + } + return rates, nil +} + +// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted) +// Returns nil if valid, error if invalid +func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error { + // Check expiration + if apiKey.IsExpired() { + return ErrAPIKeyExpired + } + + // Check quota + if apiKey.IsQuotaExhausted() { + return ErrAPIKeyQuotaExhausted + } + + return nil +} + +// UpdateQuotaUsed updates the quota_used field after a request +// Also checks if quota is exhausted and updates status accordingly +func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { + if cost <= 0 { + return nil + } + + type quotaStateReader interface { + IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) + } + + if repo, ok := s.apiKeyRepo.(quotaStateReader); ok { + state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" { + s.InvalidateAuthCacheByKey(ctx, state.Key) + } + return nil + } + + // Use repository to atomically increment quota_used + newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + + // Check if quota is now exhausted and update status if needed + apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID) + if err != nil { + return nil // Don't fail the request, just log + } + + // If quota is set and now exhausted, update status + if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota { + apiKey.Status = StatusAPIKeyQuotaExhausted + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil // Don't fail the request + } + // Invalidate cache so next request sees the new status + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + return nil +} + +// GetRateLimitData returns rate limit usage and window state for an API key. +func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) { + return s.apiKeyRepo.GetRateLimitData(ctx, id) +} + +// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB. +func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { + if cost <= 0 { + return nil + } + return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go new file mode 100644 index 00000000..dfc790cf --- /dev/null +++ b/backend/internal/service/gateway_service.go @@ -0,0 +1,8510 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + mathrand "math/rand" + "net/http" + "os" + "regexp" + "sort" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/cespare/xxhash/v2" + "github.com/google/uuid" + gocache "github.com/patrickmn/go-cache" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "golang.org/x/sync/singleflight" + + "github.com/gin-gonic/gin" +) + +const ( + claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" + claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" + stickySessionTTL = time.Hour // 粘性会话TTL + defaultMaxLineSize = 500 * 1024 * 1024 + // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) + // to match real Claude CLI traffic as closely as possible. When we need a visual + // separator between system blocks, we add "\n\n" at concatenation time. + claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." + maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 + + defaultUserGroupRateCacheTTL = 30 * time.Second + defaultModelsListCacheTTL = 15 * time.Second + postUsageBillingTimeout = 15 * time.Second + debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY" +) + +const ( + claudeMimicDebugInfoKey = "claude_mimic_debug_info" +) + +// ForceCacheBillingContextKey 强制缓存计费上下文键 +// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 +type forceCacheBillingKeyType struct{} + +// accountWithLoad 账号与负载信息的组合,用于负载感知调度 +type accountWithLoad struct { + account *Account + loadInfo *AccountLoadInfo +} + +var ForceCacheBillingContextKey = forceCacheBillingKeyType{} + +var ( + windowCostPrefetchCacheHitTotal atomic.Int64 + windowCostPrefetchCacheMissTotal atomic.Int64 + windowCostPrefetchBatchSQLTotal atomic.Int64 + windowCostPrefetchFallbackTotal atomic.Int64 + windowCostPrefetchErrorTotal atomic.Int64 + + userGroupRateCacheHitTotal atomic.Int64 + userGroupRateCacheMissTotal atomic.Int64 + userGroupRateCacheLoadTotal atomic.Int64 + userGroupRateCacheSFSharedTotal atomic.Int64 + userGroupRateCacheFallbackTotal atomic.Int64 + + modelsListCacheHitTotal atomic.Int64 + modelsListCacheMissTotal atomic.Int64 + modelsListCacheStoreTotal atomic.Int64 +) + +func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { + return windowCostPrefetchCacheHitTotal.Load(), + windowCostPrefetchCacheMissTotal.Load(), + windowCostPrefetchBatchSQLTotal.Load(), + windowCostPrefetchFallbackTotal.Load(), + windowCostPrefetchErrorTotal.Load() +} + +func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) { + return userGroupRateCacheHitTotal.Load(), + userGroupRateCacheMissTotal.Load(), + userGroupRateCacheLoadTotal.Load(), + userGroupRateCacheSFSharedTotal.Load(), + userGroupRateCacheFallbackTotal.Load() +} + +func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { + return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() +} + +func openAIStreamEventIsTerminal(data string) bool { + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + switch gjson.Get(trimmed, "type").String() { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +func anthropicStreamEventIsTerminal(eventName, data string) bool { + if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") { + return true + } + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + return gjson.Get(trimmed, "type").String() == "message_stop" +} + +func cloneStringSlice(src []string) []string { + if len(src) == 0 { + return nil + } + dst := make([]string, len(src)) + copy(dst, src) + return dst +} + +// IsForceCacheBilling 检查是否启用强制缓存计费 +func IsForceCacheBilling(ctx context.Context) bool { + v, _ := ctx.Value(ForceCacheBillingContextKey).(bool) + return v +} + +// WithForceCacheBilling 返回带有强制缓存计费标记的上下文 +func WithForceCacheBilling(ctx context.Context) context.Context { + return context.WithValue(ctx, ForceCacheBillingContextKey, true) +} + +func (s *GatewayService) debugModelRoutingEnabled() bool { + if s == nil { + return false + } + return s.debugModelRouting.Load() +} + +func (s *GatewayService) debugClaudeMimicEnabled() bool { + if s == nil { + return false + } + return s.debugClaudeMimic.Load() +} + +func parseDebugEnvBool(raw string) bool { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +func shortSessionHash(sessionHash string) string { + if sessionHash == "" { + return "" + } + if len(sessionHash) <= 8 { + return sessionHash + } + return sessionHash[:8] +} + +func redactAuthHeaderValue(v string) string { + v = strings.TrimSpace(v) + if v == "" { + return "" + } + // Keep scheme for debugging, redact secret. + if strings.HasPrefix(strings.ToLower(v), "bearer ") { + return "Bearer [redacted]" + } + return "[redacted]" +} + +func safeHeaderValueForLog(key string, v string) string { + key = strings.ToLower(strings.TrimSpace(key)) + switch key { + case "authorization", "x-api-key": + return redactAuthHeaderValue(v) + default: + return strings.TrimSpace(v) + } +} + +func extractSystemPreviewFromBody(body []byte) string { + if len(body) == 0 { + return "" + } + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return "" + } + + switch { + case sys.IsArray(): + for _, item := range sys.Array() { + if !item.IsObject() { + continue + } + if strings.EqualFold(item.Get("type").String(), "text") { + if t := item.Get("text").String(); strings.TrimSpace(t) != "" { + return t + } + } + } + return "" + case sys.Type == gjson.String: + return sys.String() + default: + return "" + } +} + +func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) string { + if req == nil { + return "" + } + + // Only log a minimal fingerprint to avoid leaking user content. + interesting := []string{ + "user-agent", + "x-app", + "anthropic-dangerous-direct-browser-access", + "anthropic-version", + "anthropic-beta", + "x-stainless-lang", + "x-stainless-package-version", + "x-stainless-os", + "x-stainless-arch", + "x-stainless-runtime", + "x-stainless-runtime-version", + "x-stainless-retry-count", + "x-stainless-timeout", + "authorization", + "x-api-key", + "content-type", + "accept", + "x-stainless-helper-method", + } + + h := make([]string, 0, len(interesting)) + for _, k := range interesting { + if v := req.Header.Get(k); v != "" { + h = append(h, fmt.Sprintf("%s=%q", k, safeHeaderValueForLog(k, v))) + } + } + + metaUserID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()) + sysPreview := strings.TrimSpace(extractSystemPreviewFromBody(body)) + + // Truncate preview to keep logs sane. + if len(sysPreview) > 300 { + sysPreview = sysPreview[:300] + "..." + } + sysPreview = strings.ReplaceAll(sysPreview, "\n", "\\n") + sysPreview = strings.ReplaceAll(sysPreview, "\r", "\\r") + + aid := int64(0) + aname := "" + if account != nil { + aid = account.ID + aname = account.Name + } + + return fmt.Sprintf( + "url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}", + req.URL.String(), + aid, + aname, + tokenType, + mimicClaudeCode, + metaUserID, + sysPreview, + strings.Join(h, " "), + ) +} + +func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) { + line := buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode) + if line == "" { + return + } + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebug] %s", line) +} + +func isClaudeCodeCredentialScopeError(msg string) bool { + m := strings.ToLower(strings.TrimSpace(msg)) + if m == "" { + return false + } + return strings.Contains(m, "only authorized for use with claude code") && + strings.Contains(m, "cannot be used for other api requests") +} + +// sseDataRe matches SSE data lines with optional whitespace after colon. +// Some upstream APIs return non-standard "data:" without space (should be "data: "). +var ( + sseDataRe = regexp.MustCompile(`^data:\s*`) + claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) + + // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 + // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 + // 注意:前缀之间不应存在包含关系,否则会导致冗余匹配 + claudeCodePromptPrefixes = []string{ + "You are Claude Code, Anthropic's official CLI for Claude", // 标准版 & Agent SDK 版(含 running within...) + "You are a Claude agent, built on Anthropic's Claude Agent SDK", // Agent SDK 变体 + "You are a file search specialist for Claude Code", // Explore Agent 版 + "You are a helpful AI assistant tasked with summarizing conversations", // Compact 版 + } +) + +// ErrNoAvailableAccounts 表示没有可用的账号 +var ErrNoAvailableAccounts = errors.New("no available accounts") + +// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 +var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") + +// allowedHeaders 白名单headers(参考CRS项目) +var allowedHeaders = map[string]bool{ + "accept": true, + "x-stainless-retry-count": true, + "x-stainless-timeout": true, + "x-stainless-lang": true, + "x-stainless-package-version": true, + "x-stainless-os": true, + "x-stainless-arch": true, + "x-stainless-runtime": true, + "x-stainless-runtime-version": true, + "x-stainless-helper-method": true, + "anthropic-dangerous-direct-browser-access": true, + "anthropic-version": true, + "x-app": true, + "anthropic-beta": true, + "accept-language": true, + "sec-fetch-mode": true, + "user-agent": true, + "content-type": true, +} + +// GatewayCache 定义网关服务的缓存操作接口。 +// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。 +// +// GatewayCache defines cache operations for gateway service. +// Provides sticky session storage, retrieval, refresh and deletion capabilities. +type GatewayCache interface { + // GetSessionAccountID 获取粘性会话绑定的账号 ID + // Get the account ID bound to a sticky session + GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) + // SetSessionAccountID 设置粘性会话与账号的绑定关系 + // Set the binding between sticky session and account + SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error + // RefreshSessionTTL 刷新粘性会话的过期时间 + // Refresh the expiration time of a sticky session + RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error + // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 + // Delete sticky session binding, used to proactively clean up when account becomes unavailable + DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error +} + +// derefGroupID safely dereferences *int64 to int64, returning 0 if nil +func derefGroupID(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID +} + +func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return defaultUserGroupRateCacheTTL + } + return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second +} + +func resolveModelsListCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 { + return defaultModelsListCacheTTL + } + return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second +} + +func modelsListCacheKey(groupID *int64, platform string) string { + return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform)) +} + +func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { + return PrefetchedStickyGroupIDFromContext(ctx) +} + +func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 { + prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx) + if !ok || prefetchedGroupID != derefGroupID(groupID) { + return 0 + } + if accountID, ok := PrefetchedStickyAccountIDFromContext(ctx); ok && accountID > 0 { + return accountID + } + return 0 +} + +// shouldClearStickySession checks if an account is in an unschedulable state +// and the sticky session binding should be cleared. +// Returns true when account status is error/disabled, schedulable is false, +// within temporary unschedulable period, or the requested model is rate-limited. +// This ensures subsequent requests won't continue using unavailable accounts. +func shouldClearStickySession(ctx context.Context, account *Account, requestedModel string) bool { + if account == nil { + return false + } + if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable { + return true + } + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + return true + } + if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel); remaining > 0 { + return true + } + return false +} + +type AccountWaitPlan struct { + AccountID int64 + MaxConcurrency int + Timeout time.Duration + MaxWaiting int +} + +type AccountSelectionResult struct { + Account *Account + Acquired bool + ReleaseFunc func() + WaitPlan *AccountWaitPlan // nil means no wait allowed +} + +// ClaudeUsage 表示Claude API返回的usage信息 +type ClaudeUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象) + CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象) +} + +// ForwardResult 转发结果 +type ForwardResult struct { + RequestID string + Usage ClaudeUsage + Model string + // UpstreamModel is the actual upstream model after mapping. + // Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings. + UpstreamModel string + Stream bool + Duration time.Duration + FirstTokenMs *int // 首字时间(流式请求) + ClientDisconnect bool // 客户端是否在流式传输过程中断开 + ReasoningEffort *string + + // 图片生成计费字段(图片生成模型使用) + ImageCount int // 生成的图片数量 + ImageSize string // 图片尺寸 "1K", "2K", "4K" + + // Sora 媒体字段 + MediaType string // image / video / prompt + MediaURL string // 生成后的媒体地址(可选) +} + +// UpstreamFailoverError indicates an upstream error that should trigger account failover. +type UpstreamFailoverError struct { + StatusCode int + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 + ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息 + ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 +} + +func (e *UpstreamFailoverError) Error() string { + return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) +} + +// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。 +// 由 handler 层在同账号重试全部用尽、切换账号时调用。 +func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) { + if failoverErr == nil || !failoverErr.RetryableOnSameAccount { + return + } + // 根据状态码选择封禁策略 + switch failoverErr.StatusCode { + case http.StatusBadRequest: + tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]") + case http.StatusBadGateway: + tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]") + } +} + +// GatewayService handles API gateway operations +type GatewayService struct { + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache GatewayCache + digestStore *DigestSessionStore + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService + claudeTokenProvider *ClaudeTokenProvider + sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateResolver *userGroupRateResolver + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration + settingService *SettingService + responseHeaderFilter *responseheaders.CompiledHeaderFilter + debugModelRouting atomic.Bool + debugClaudeMimic atomic.Bool +} + +// NewGatewayService creates a new GatewayService +func NewGatewayService( + accountRepo AccountRepository, + groupRepo GroupRepository, + usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, + userRepo UserRepository, + userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, + cache GatewayCache, + cfg *config.Config, + schedulerSnapshot *SchedulerSnapshotService, + concurrencyService *ConcurrencyService, + billingService *BillingService, + rateLimitService *RateLimitService, + billingCacheService *BillingCacheService, + identityService *IdentityService, + httpUpstream HTTPUpstream, + deferredService *DeferredService, + claudeTokenProvider *ClaudeTokenProvider, + sessionLimitCache SessionLimitCache, + rpmCache RPMCache, + digestStore *DigestSessionStore, + settingService *SettingService, +) *GatewayService { + userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) + modelsListTTL := resolveModelsListCacheTTL(cfg) + + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + digestStore: digestStore, + cfg: cfg, + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + identityService: identityService, + httpUpstream: httpUpstream, + deferredService: deferredService, + claudeTokenProvider: claudeTokenProvider, + sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, + userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + settingService: settingService, + modelsListCache: gocache.New(modelsListTTL, time.Minute), + modelsListCacheTTL: modelsListTTL, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + } + svc.userGroupRateResolver = newUserGroupRateResolver( + userGroupRateRepo, + svc.userGroupRateCache, + userGroupRateTTL, + &svc.userGroupRateSF, + "service.gateway", + ) + svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) + svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) + return svc +} + +// GenerateSessionHash 从预解析请求计算粘性会话 hash +func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } + + // 1. 最高优先级:从 metadata.user_id 提取 session_xxx + if parsed.MetadataUserID != "" { + if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" { + return uid.SessionID + } + } + + // 2. 提取带 cache_control: {type: "ephemeral"} 的内容 + cacheableContent := s.extractCacheableContent(parsed) + if cacheableContent != "" { + return s.hashContent(cacheableContent) + } + + // 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串 + var combined strings.Builder + // 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash + if parsed.SessionContext != nil { + _, _ = combined.WriteString(parsed.SessionContext.ClientIP) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(parsed.SessionContext.UserAgent) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) + _, _ = combined.WriteString("|") + } + if parsed.System != nil { + systemText := s.extractTextFromSystem(parsed.System) + if systemText != "" { + _, _ = combined.WriteString(systemText) + } + } + for _, msg := range parsed.Messages { + if m, ok := msg.(map[string]any); ok { + if content, exists := m["content"]; exists { + // Anthropic: messages[].content + if msgText := s.extractTextFromContent(content); msgText != "" { + _, _ = combined.WriteString(msgText) + } + } else if parts, ok := m["parts"].([]any); ok { + // Gemini: contents[].parts[].text + for _, part := range parts { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok { + _, _ = combined.WriteString(text) + } + } + } + } + } + } + if combined.Len() > 0 { + return s.hashContent(combined.String()) + } + + return "" +} + +// BindStickySession sets session -> account binding with standard TTL. +func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { + if sessionHash == "" || accountID <= 0 || s.cache == nil { + return nil + } + return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL) +} + +// GetCachedSessionAccountID retrieves the account ID bound to a sticky session. +// Returns 0 if no binding exists or on error. +func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { + if sessionHash == "" || s.cache == nil { + return 0, nil + } + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err != nil { + return 0, err + } + return accountID, nil +} + +// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) +// 返回最长匹配的会话信息(uuid, accountID) +func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false + } + return s.digestStore.Find(groupID, prefixHash, digestChain) +} + +// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 +func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { + return nil + } + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil +} + +// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配) +func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false + } + return s.digestStore.Find(groupID, prefixHash, digestChain) +} + +// SaveAnthropicSession 保存 Anthropic 会话 +func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { + return nil + } + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil +} + +func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } + + var builder strings.Builder + + // 检查 system 中的 cacheable 内容 + if system, ok := parsed.System.([]any); ok { + for _, part := range system { + if partMap, ok := part.(map[string]any); ok { + if cc, ok := partMap["cache_control"].(map[string]any); ok { + if cc["type"] == "ephemeral" { + if text, ok := partMap["text"].(string); ok { + _, _ = builder.WriteString(text) + } + } + } + } + } + } + systemText := builder.String() + + // 检查 messages 中的 cacheable 内容 + for _, msg := range parsed.Messages { + if msgMap, ok := msg.(map[string]any); ok { + if msgContent, ok := msgMap["content"].([]any); ok { + for _, part := range msgContent { + if partMap, ok := part.(map[string]any); ok { + if cc, ok := partMap["cache_control"].(map[string]any); ok { + if cc["type"] == "ephemeral" { + return s.extractTextFromContent(msgMap["content"]) + } + } + } + } + } + } + } + + return systemText +} + +func (s *GatewayService) extractTextFromSystem(system any) string { + switch v := system.(type) { + case string: + return v + case []any: + var texts []string + for _, part := range v { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok { + texts = append(texts, text) + } + } + } + return strings.Join(texts, "") + } + return "" +} + +func (s *GatewayService) extractTextFromContent(content any) string { + switch v := content.(type) { + case string: + return v + case []any: + var texts []string + for _, part := range v { + if partMap, ok := part.(map[string]any); ok { + if partMap["type"] == "text" { + if text, ok := partMap["text"].(string); ok { + texts = append(texts, text) + } + } + } + } + return strings.Join(texts, "") + } + return "" +} + +func (s *GatewayService) hashContent(content string) string { + h := xxhash.Sum64String(content) + return strconv.FormatUint(h, 36) +} + +type anthropicCacheControlPayload struct { + Type string `json:"type"` +} + +type anthropicSystemTextBlockPayload struct { + Type string `json:"type"` + Text string `json:"text"` + CacheControl *anthropicCacheControlPayload `json:"cache_control,omitempty"` +} + +type anthropicMetadataPayload struct { + UserID string `json:"user_id"` +} + +// replaceModelInBody 替换请求体中的model字段 +// 优先使用定点修改,尽量保持客户端原始字段顺序。 +func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { + if len(body) == 0 { + return body + } + if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { + return body + } + newBody, err := sjson.SetBytes(body, "model", newModel) + if err != nil { + return body + } + return newBody +} + +type claudeOAuthNormalizeOptions struct { + injectMetadata bool + metadataUserID string + stripSystemCacheControl bool +} + +// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). +// We intentionally avoid broad keyword replacement in system prompts to prevent +// accidentally changing user-provided instructions. +func sanitizeSystemText(text string) string { + if text == "" { + return text + } + // Some clients include a fixed OpenCode identity sentence. Anthropic may treat + // this as a non-Claude-Code fingerprint, so rewrite it to the canonical + // Claude Code banner before generic "OpenCode"/"opencode" replacements. + text = strings.ReplaceAll( + text, + "You are OpenCode, the best coding agent on the planet.", + strings.TrimSpace(claudeCodeSystemPrompt), + ) + return text +} + +func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]byte, error) { + block := anthropicSystemTextBlockPayload{ + Type: "text", + Text: text, + } + if includeCacheControl { + block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"} + } + return json.Marshal(block) +} + +func marshalAnthropicMetadata(userID string) ([]byte, error) { + return json.Marshal(anthropicMetadataPayload{UserID: userID}) +} + +func buildJSONArrayRaw(items [][]byte) []byte { + if len(items) == 0 { + return []byte("[]") + } + + total := 2 + for _, item := range items { + total += len(item) + } + total += len(items) - 1 + + buf := make([]byte, 0, total) + buf = append(buf, '[') + for i, item := range items { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, item...) + } + buf = append(buf, ']') + return buf +} + +func setJSONValueBytes(body []byte, path string, value any) ([]byte, bool) { + next, err := sjson.SetBytes(body, path, value) + if err != nil { + return body, false + } + return next, true +} + +func setJSONRawBytes(body []byte, path string, raw []byte) ([]byte, bool) { + next, err := sjson.SetRawBytes(body, path, raw) + if err != nil { + return body, false + } + return next, true +} + +func deleteJSONPathBytes(body []byte, path string) ([]byte, bool) { + next, err := sjson.DeleteBytes(body, path) + if err != nil { + return body, false + } + return next, true +} + +func normalizeClaudeOAuthSystemBody(body []byte, opts claudeOAuthNormalizeOptions) ([]byte, bool) { + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return body, false + } + + out := body + modified := false + + switch { + case sys.Type == gjson.String: + sanitized := sanitizeSystemText(sys.String()) + if sanitized != sys.String() { + if next, ok := setJSONValueBytes(out, "system", sanitized); ok { + out = next + modified = true + } + } + case sys.IsArray(): + index := 0 + sys.ForEach(func(_, item gjson.Result) bool { + if item.Get("type").String() == "text" { + textResult := item.Get("text") + if textResult.Exists() && textResult.Type == gjson.String { + text := textResult.String() + sanitized := sanitizeSystemText(text) + if sanitized != text { + if next, ok := setJSONValueBytes(out, fmt.Sprintf("system.%d.text", index), sanitized); ok { + out = next + modified = true + } + } + } + } + + if opts.stripSystemCacheControl && item.Get("cache_control").Exists() { + if next, ok := deleteJSONPathBytes(out, fmt.Sprintf("system.%d.cache_control", index)); ok { + out = next + modified = true + } + } + + index++ + return true + }) + } + + return out, modified +} + +func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) { + if strings.TrimSpace(userID) == "" { + return body, false + } + + metadata := gjson.GetBytes(body, "metadata") + if !metadata.Exists() || metadata.Type == gjson.Null { + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) + } + + trimmedRaw := strings.TrimSpace(metadata.Raw) + if strings.HasPrefix(trimmedRaw, "{") { + existing := metadata.Get("user_id") + if existing.Exists() && existing.Type == gjson.String && existing.String() != "" { + return body, false + } + return setJSONValueBytes(body, "metadata.user_id", userID) + } + + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) +} + +func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { + if len(body) == 0 { + return body, modelID + } + + out := body + modified := false + + if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed { + out = next + modified = true + } + + rawModel := gjson.GetBytes(out, "model") + if rawModel.Exists() && rawModel.Type == gjson.String { + normalized := claude.NormalizeModelID(rawModel.String()) + if normalized != rawModel.String() { + if next, ok := setJSONValueBytes(out, "model", normalized); ok { + out = next + modified = true + } + modelID = normalized + } + } + + // 确保 tools 字段存在(即使为空数组) + if !gjson.GetBytes(out, "tools").Exists() { + if next, ok := setJSONRawBytes(out, "tools", []byte("[]")); ok { + out = next + modified = true + } + } + + if opts.injectMetadata && opts.metadataUserID != "" { + if next, changed := ensureClaudeOAuthMetadataUserID(out, opts.metadataUserID); changed { + out = next + modified = true + } + } + + if gjson.GetBytes(out, "temperature").Exists() { + if next, ok := deleteJSONPathBytes(out, "temperature"); ok { + out = next + modified = true + } + } + if gjson.GetBytes(out, "tool_choice").Exists() { + if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok { + out = next + modified = true + } + } + + if !modified { + return body, modelID + } + + return out, modelID +} + +func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { + if parsed == nil || account == nil { + return "" + } + if parsed.MetadataUserID != "" { + return "" + } + + userID := strings.TrimSpace(account.GetClaudeUserID()) + if userID == "" && fp != nil { + userID = fp.ClientID + } + if userID == "" { + // Fall back to a random, well-formed client id so we can still satisfy + // Claude Code OAuth requirements when account metadata is incomplete. + userID = generateClientID() + } + + sessionHash := s.GenerateSessionHash(parsed) + sessionID := uuid.NewString() + if sessionHash != "" { + seed := fmt.Sprintf("%d::%s", account.ID, sessionHash) + sessionID = generateSessionUUID(seed) + } + + // 根据指纹 UA 版本选择输出格式 + var uaVersion string + if fp != nil { + uaVersion = ExtractCLIVersion(fp.UserAgent) + } + accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) + return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion) +} + +// GenerateSessionUUID creates a deterministic UUID4 from a seed string. +func GenerateSessionUUID(seed string) string { + return generateSessionUUID(seed) +} + +func generateSessionUUID(seed string) string { + if seed == "" { + return uuid.NewString() + } + hash := sha256.Sum256([]byte(seed)) + bytes := hash[:16] + bytes[6] = (bytes[6] & 0x0f) | 0x40 + bytes[8] = (bytes[8] & 0x3f) | 0x80 + return fmt.Sprintf("%x-%x-%x-%x-%x", + bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) +} + +// SelectAccount 选择账号(粘性会话+优先级) +func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { + return s.SelectAccountForModel(ctx, groupID, sessionHash, "") +} + +// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) +func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { + return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) +} + +// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. +func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { + // 优先检查 context 中的强制平台(/antigravity 路由) + var platform string + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + platform = forcePlatform + } else if groupID != nil { + group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID) + if err != nil { + return nil, err + } + groupID = resolvedGroupID + ctx = s.withGroupContext(ctx, group) + platform = group.Platform + } else { + // 无分组时只使用原生 anthropic 平台 + platform = PlatformAnthropic + } + + // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) + // 注意:强制平台模式不走混合调度 + if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { + return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + } + + // antigravity 分组、强制平台模式或无分组使用单平台选择 + // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 + return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) +} + +// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. +// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash +func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { + // 调试日志:记录调度入口参数 + excludedIDsList := make([]int64, 0, len(excludedIDs)) + for id := range excludedIDs { + excludedIDsList = append(excludedIDsList, id) + } + slog.Debug("account_scheduling_starting", + "group_id", derefGroupID(groupID), + "model", requestedModel, + "session", shortSessionHash(sessionHash), + "excluded_ids", excludedIDsList) + + cfg := s.schedulingConfig() + + // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) + group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) + if err != nil { + return nil, err + } + ctx = s.withGroupContext(ctx, group) + + var stickyAccountID int64 + if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { + stickyAccountID = prefetch + } else if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { + stickyAccountID = accountID + } + } + + if s.debugModelRoutingEnabled() && requestedModel != "" { + groupPlatform := "" + if group != nil { + groupPlatform = group.Platform + } + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v", + derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil) + } + + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { + // 复制排除列表,用于会话限制拒绝时的重试 + localExcluded := make(map[int64]struct{}) + for k, v := range excludedIDs { + localExcluded[k] = v + } + + for { + account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded) + if err != nil { + return nil, err + } + + result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) + if err == nil && result.Acquired { + // 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符) + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + result.ReleaseFunc() // 释放槽位 + localExcluded[account.ID] = struct{}{} // 排除此账号 + continue // 重新选择 + } + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + // 对于等待计划的情况,也需要先检查会话限制 + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + localExcluded[account.ID] = struct{}{} + continue + } + + if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + } + + platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group) + if err != nil { + return nil, err + } + preferOAuth := platform == PlatformGemini + if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) + } + + accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, err + } + if len(accounts) == 0 { + return nil, ErrNoAvailableAccounts + } + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + + isExcluded := func(accountID int64) bool { + if excludedIDs == nil { + return false + } + _, excluded := excludedIDs[accountID] + return excluded + } + + // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用) + accountByID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + accountByID[accounts[i].ID] = &accounts[i] + } + + // 获取模型路由配置(仅 anthropic 平台) + var routingAccountIDs []int64 + if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic { + routingAccountIDs = group.GetRoutingAccountIDs(requestedModel) + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", + group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID) + if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 { + keys := make([]string, 0, len(group.ModelRouting)) + for k := range group.ModelRouting { + keys = append(keys, k) + } + sort.Strings(keys) + const maxKeys = 20 + if len(keys) > maxKeys { + keys = keys[:maxKeys] + } + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) + } + } + } + + // ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============ + if len(routingAccountIDs) > 0 && s.concurrencyService != nil { + // 1. 过滤出路由列表中可调度的账号 + var routingCandidates []*Account + var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int + var modelScopeSkippedIDs []int64 // 记录因模型限流被跳过的账号 ID + for _, routingAccountID := range routingAccountIDs { + if isExcluded(routingAccountID) { + filteredExcluded++ + continue + } + account, ok := accountByID[routingAccountID] + if !ok || !s.isAccountSchedulableForSelection(account) { + if !ok { + filteredMissing++ + } else { + filteredUnsched++ + } + continue + } + if !s.isAccountAllowedForPlatform(account, platform, useMixed) { + filteredPlatform++ + continue + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) { + filteredModelMapping++ + continue + } + if !s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { + filteredModelScope++ + modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) + continue + } + // 配额检查 + if !s.isAccountSchedulableForQuota(account) { + continue + } + // 窗口费用检查(非粘性会话路径) + if !s.isAccountSchedulableForWindowCost(ctx, account, false) { + filteredWindowCost++ + continue + } + // RPM 检查(非粘性会话路径) + if !s.isAccountSchedulableForRPM(ctx, account, false) { + continue + } + routingCandidates = append(routingCandidates, account) + } + + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", + derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), + filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost) + if len(modelScopeSkippedIDs) > 0 { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", + derefGroupID(groupID), requestedModel, modelScopeSkippedIDs) + } + } + + if len(routingCandidates) > 0 { + // 1.5. 在路由账号范围内检查粘性会话 + if sessionHash != "" && stickyAccountID > 0 { + if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { + // 粘性账号在路由列表中,优先使用 + if stickyAccount, ok := accountByID[stickyAccountID]; ok { + if s.isAccountSchedulableForSelection(stickyAccount) && + s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && + (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && + s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && + s.isAccountSchedulableForQuota(stickyAccount) && + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && + + s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 + result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) + if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + result.ReleaseFunc() // 释放槽位 + // 继续到负载感知选择 + } else { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) + } + return &AccountSelectionResult{ + Account: stickyAccount, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) + if waitingCount < cfg.StickySessionMaxWaiting { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + // 会话限制已满,继续到负载感知选择 + } else { + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 + } + } else { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + } + } + + // 2. 批量获取负载信息 + routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates)) + for _, acc := range routingCandidates { + routingLoads = append(routingLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.EffectiveLoadFactor(), + }) + } + routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads) + + // 3. 按负载感知排序 + var routingAvailable []accountWithLoad + for _, acc := range routingCandidates { + loadInfo := routingLoadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo}) + } + } + + if len(routingAvailable) > 0 { + // 排序:优先级 > 负载率 > 最后使用时间 + sort.SliceStable(routingAvailable, func(i, j int) bool { + a, b := routingAvailable[i], routingAvailable[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + shuffleWithinSortGroups(routingAvailable) + + // 4. 尝试获取槽位 + for _, item := range routingAvailable { + result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + continue + } + if sessionHash != "" && s.cache != nil { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) + } + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + } + return &AccountSelectionResult{ + Account: item.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + + // 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的) + // 遍历找到第一个满足会话限制的账号 + for _, item := range routingAvailable { + if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { + continue // 会话限制已满,尝试下一个 + } + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + } + return &AccountSelectionResult{ + Account: item.account, + WaitPlan: &AccountWaitPlan{ + AccountID: item.account.ID, + MaxConcurrency: item.account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + // 所有路由账号会话限制都已满,继续到 Layer 2 回退 + } + // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 + logger.LegacyPrintf("service.gateway", "[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) + } + } + + // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ + if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { + account, ok := accountByID[accountID] + if ok { + // 检查账户是否需要清理粘性会话绑定 + // Check if the account needs sticky session cleanup + clearSticky := shouldClearStickySession(ctx, account, requestedModel) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && + s.isAccountAllowedForPlatform(account, platform, useMixed) && + (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && + s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && + s.isAccountSchedulableForQuota(account) && + s.isAccountSchedulableForWindowCost(ctx, account, true) && + + s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + // 会话数量限制检查 + // Session count limit check + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + result.ReleaseFunc() // 释放槽位,继续到 Layer 2 + } else { + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + // 会话数量限制检查(等待计划也需要占用会话配额) + // Session count limit check (wait plan also requires session quota) + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + // 会话限制已满,继续到 Layer 2 + // Session limit full, continue to Layer 2 + } else { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + } + } + } + } + + // ============ Layer 2: 负载感知选择 ============ + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + acc := &accounts[i] + if isExcluded(acc.ID) { + continue + } + // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); + // re-check schedulability here so recently rate-limited/overloaded accounts + // are not selected again before the bucket is rebuilt. + if !s.isAccountSchedulableForSelection(acc) { + continue + } + if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + // 配额检查 + if !s.isAccountSchedulableForQuota(acc) { + continue + } + // 窗口费用检查(非粘性会话路径) + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + // RPM 检查(非粘性会话路径) + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } + candidates = append(candidates, acc) + } + + if len(candidates) == 0 { + return nil, ErrNoAvailableAccounts + } + + accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) + for _, acc := range candidates { + accountLoads = append(accountLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.EffectiveLoadFactor(), + }) + } + + loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) + if err != nil { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { + return result, nil + } + } else { + var available []accountWithLoad + for _, acc := range candidates { + loadInfo := loadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + available = append(available, accountWithLoad{ + account: acc, + loadInfo: loadInfo, + }) + } + } + + // 分层过滤选择:优先级 → 负载率 → LRU + for len(available) > 0 { + // 1. 取优先级最小的集合 + candidates := filterByMinPriority(available) + // 2. 取负载率最低的集合 + candidates = filterByMinLoadRate(candidates) + // 3. LRU 选择最久未用的账号 + selected := selectByLRU(candidates, preferOAuth) + if selected == nil { + break + } + + result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) + if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + } else { + if sessionHash != "" && s.cache != nil { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: selected.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + + // 移除已尝试的账号,重新进行分层过滤 + selectedID := selected.account.ID + newAvailable := make([]accountWithLoad, 0, len(available)-1) + for _, acc := range available { + if acc.account.ID != selectedID { + newAvailable = append(newAvailable, acc) + } + } + available = newAvailable + } + } + + // ============ Layer 3: 兜底排队 ============ + s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode) + for _, acc := range candidates { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, acc, sessionHash) { + continue // 会话限制已满,尝试下一个账号 + } + return &AccountSelectionResult{ + Account: acc, + WaitPlan: &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + return nil, ErrNoAvailableAccounts +} + +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { + ordered := append([]*Account(nil), candidates...) + sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) + + for _, acc := range ordered { + result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) + if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, acc, sessionHash) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + continue + } + if sessionHash != "" && s.cache != nil { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: acc, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, true + } + } + + return nil, false +} + +func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { + if s.cfg != nil { + return s.cfg.Gateway.Scheduling + } + return config.GatewaySchedulingConfig{ + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: 45 * time.Second, + FallbackWaitTimeout: 30 * time.Second, + FallbackMaxWaiting: 100, + LoadBatchEnabled: true, + SlotCleanupInterval: 30 * time.Second, + } +} + +func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context { + if !IsGroupContextValid(group) { + return ctx + } + if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) { + return ctx + } + return context.WithValue(ctx, ctxkey.Group, group) +} + +func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group { + if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID { + return group + } + return nil +} + +func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { + if group := s.groupFromContext(ctx, groupID); group != nil { + return group, nil + } + group, err := s.groupRepo.GetByIDLite(ctx, groupID) + if err != nil { + return nil, fmt.Errorf("get group failed: %w", err) + } + return group, nil +} + +func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { + return s.resolveGroupByID(ctx, groupID) +} + +func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { + if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { + return nil + } + group, err := s.resolveGroupByID(ctx, *groupID) + if err != nil || group == nil { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) + } + return nil + } + // Preserve existing behavior: model routing only applies to anthropic groups. + if group.Platform != PlatformAnthropic { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) + } + return nil + } + ids := group.GetRoutingAccountIDs(requestedModel) + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", + group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids) + } + return ids +} + +func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) { + if groupID == nil { + return nil, nil, nil + } + + currentID := *groupID + visited := map[int64]struct{}{} + for { + if _, seen := visited[currentID]; seen { + return nil, nil, fmt.Errorf("fallback group cycle detected") + } + visited[currentID] = struct{}{} + + group, err := s.resolveGroupByID(ctx, currentID) + if err != nil { + return nil, nil, err + } + + if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) { + return group, ¤tID, nil + } + + if group.FallbackGroupID == nil { + return nil, nil, ErrClaudeCodeOnly + } + currentID = *group.FallbackGroupID + } +} + +// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制 +// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端: +// - 有降级分组:返回降级分组的 ID +// - 无降级分组:返回 ErrClaudeCodeOnly 错误 +func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) { + if groupID == nil { + return nil, groupID, nil + } + + // 强制平台模式不检查 Claude Code 限制 + if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" { + return nil, groupID, nil + } + + group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID) + if err != nil { + return nil, nil, err + } + + return group, resolvedID, nil +} + +func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) { + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + return forcePlatform, true, nil + } + if group != nil { + return group.Platform, false, nil + } + if groupID != nil { + group, err := s.resolveGroupByID(ctx, *groupID) + if err != nil { + return "", false, err + } + return group.Platform, false, nil + } + return PlatformAnthropic, false, nil +} + +func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + if platform == PlatformSora { + return s.listSoraSchedulableAccounts(ctx, groupID) + } + if s.schedulerSnapshot != nil { + accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err == nil { + slog.Debug("account_scheduling_list_snapshot", + "group_id", derefGroupID(groupID), + "platform", platform, + "use_mixed", useMixed, + "count", len(accounts)) + for _, acc := range accounts { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + } + return accounts, useMixed, err + } + useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform + if useMixed { + platforms := []string{platform, PlatformAntigravity} + var accounts []Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) + } else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", platform, + "error", err) + return nil, useMixed, err + } + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { + continue + } + filtered = append(filtered, acc) + } + slog.Debug("account_scheduling_list_mixed", + "group_id", derefGroupID(groupID), + "platform", platform, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return filtered, useMixed, nil + } + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) + // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", platform, + "error", err) + return nil, useMixed, err + } + slog.Debug("account_scheduling_list_single", + "group_id", derefGroupID(groupID), + "platform", platform, + "count", len(accounts)) + for _, acc := range accounts { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return accounts, useMixed, nil +} + +func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { + const useMixed = false + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } else if groupID != nil { + accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) + } else { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "error", err) + return nil, useMixed, err + } + + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform != PlatformSora { + continue + } + if !s.isSoraAccountSchedulable(&acc) { + continue + } + filtered = append(filtered, acc) + } + slog.Debug("account_scheduling_list_sora", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return filtered, useMixed, nil +} + +// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 +// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, +// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 +func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, groupID *int64) bool { + accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAntigravity, true) + if err != nil { + return false + } + return len(accounts) == 1 +} + +func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { + if account == nil { + return false + } + if useMixed { + if account.Platform == platform { + return true + } + return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() + } + return account.Platform == platform +} + +func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { + return s.soraUnschedulableReason(account) == "" +} + +func (s *GatewayService) soraUnschedulableReason(account *Account) string { + if account == nil { + return "account_nil" + } + if account.Status != StatusActive { + return fmt.Sprintf("status=%s", account.Status) + } + if !account.Schedulable { + return "schedulable=false" + } + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) + } + return "" +} + +func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { + if account == nil { + return false + } + if account.Platform == PlatformSora { + return s.isSoraAccountSchedulable(account) + } + return account.IsSchedulable() +} + +func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool { + if account == nil { + return false + } + if account.Platform == PlatformSora { + if !s.isSoraAccountSchedulable(account) { + return false + } + return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 + } + return account.IsSchedulableForModelWithContext(ctx, requestedModel) +} + +// isAccountInGroup checks if the account belongs to the specified group. +// When groupID is nil, returns true only for ungrouped accounts (no group assignments). +func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { + if account == nil { + return false + } + if groupID == nil { + // 无分组的 API Key 只能使用未分组的账号 + return len(account.AccountGroups) == 0 + } + for _, ag := range account.AccountGroups { + if ag.GroupID == *groupID { + return true + } + } + return false +} + +func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { + if s.concurrencyService == nil { + return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil + } + return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) +} + +type usageLogWindowStatsBatchProvider interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} + +type windowCostPrefetchContextKeyType struct{} + +var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{} + +func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) { + if ctx == nil || accountID <= 0 { + return 0, false + } + m, ok := ctx.Value(windowCostPrefetchContextKey).(map[int64]float64) + if !ok || len(m) == 0 { + return 0, false + } + v, exists := m[accountID] + return v, exists +} + +func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context { + if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil { + return ctx + } + + accountByID := make(map[int64]*Account) + accountIDs := make([]int64, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if account == nil || !account.IsAnthropicOAuthOrSetupToken() { + continue + } + if account.GetWindowCostLimit() <= 0 { + continue + } + accountByID[account.ID] = account + accountIDs = append(accountIDs, account.ID) + } + if len(accountIDs) == 0 { + return ctx + } + + costs := make(map[int64]float64, len(accountIDs)) + cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs) + if err == nil { + for accountID, cost := range cacheValues { + costs[accountID] = cost + } + windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues))) + } else { + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err) + } + cacheMissCount := len(accountIDs) - len(costs) + if cacheMissCount < 0 { + cacheMissCount = 0 + } + windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount)) + + missingByStart := make(map[int64][]int64) + startTimes := make(map[int64]time.Time) + for _, accountID := range accountIDs { + if _, ok := costs[accountID]; ok { + continue + } + account := accountByID[accountID] + if account == nil { + continue + } + startTime := account.GetCurrentWindowStartTime() + startKey := startTime.Unix() + missingByStart[startKey] = append(missingByStart[startKey], accountID) + startTimes[startKey] = startTime + } + if len(missingByStart) == 0 { + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) + } + + batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider) + for startKey, ids := range missingByStart { + startTime := startTimes[startKey] + + if hasBatch { + windowCostPrefetchBatchSQLTotal.Add(1) + queryStart := time.Now() + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime) + if err == nil { + slog.Debug("window_cost_batch_query_ok", + "accounts", len(ids), + "window_start", startTime.Format(time.RFC3339), + "duration_ms", time.Since(queryStart).Milliseconds()) + for _, accountID := range ids { + stats := statsByAccount[accountID] + cost := 0.0 + if stats != nil { + cost = stats.StandardCost + } + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + continue + } + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err) + } + + // 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。 + windowCostPrefetchFallbackTotal.Add(int64(len(ids))) + for _, accountID := range ids { + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime) + if err != nil { + windowCostPrefetchErrorTotal.Add(1) + continue + } + cost := stats.StandardCost + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + } + + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) +} + +// isAccountSchedulableForQuota 检查账号是否在配额限制内 +// 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号 +func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool { + if !account.IsAPIKeyOrBedrock() { + return true + } + return !account.IsQuotaExceeded() +} + +// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +// 返回 true 表示可调度,false 表示不可调度 +func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool { + // 只检查 Anthropic OAuth/SetupToken 账号 + if !account.IsAnthropicOAuthOrSetupToken() { + return true + } + + limit := account.GetWindowCostLimit() + if limit <= 0 { + return true // 未启用窗口费用限制 + } + + // 尝试从缓存获取窗口费用 + var currentCost float64 + if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok { + currentCost = cost + goto checkSchedulability + } + if s.sessionLimitCache != nil { + if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit { + currentCost = cost + goto checkSchedulability + } + } + + // 缓存未命中,从数据库查询 + { + // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) + startTime := account.GetCurrentWindowStartTime() + + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) + if err != nil { + // 失败开放:查询失败时允许调度 + return true + } + + // 使用标准费用(不含账号倍率) + currentCost = stats.StandardCost + + // 设置缓存(忽略错误) + if s.sessionLimitCache != nil { + _ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost) + } + } + +checkSchedulability: + schedulability := account.CheckWindowCostSchedulability(currentCost) + + switch schedulability { + case WindowCostSchedulable: + return true + case WindowCostStickyOnly: + return isSticky + case WindowCostNotSchedulable: + return false + } + return true +} + +// rpmPrefetchContextKey is the context key for prefetched RPM counts. +type rpmPrefetchContextKeyType struct{} + +var rpmPrefetchContextKey = rpmPrefetchContextKeyType{} + +func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) { + if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok { + count, found := v[accountID] + return count, found + } + return 0, false +} + +// withRPMPrefetch 批量预取所有候选账号的 RPM 计数 +func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context { + if s.rpmCache == nil { + return ctx + } + + var ids []int64 + for i := range accounts { + if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 { + ids = append(ids, accounts[i].ID) + } + } + if len(ids) == 0 { + return ctx + } + + counts, err := s.rpmCache.GetRPMBatch(ctx, ids) + if err != nil { + return ctx // 失败开放 + } + return context.WithValue(ctx, rpmPrefetchContextKey, counts) +} + +// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool { + if !account.IsAnthropicOAuthOrSetupToken() { + return true + } + baseRPM := account.GetBaseRPM() + if baseRPM <= 0 { + return true + } + + // 尝试从预取缓存获取 + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok { + currentRPM = count + } else if s.rpmCache != nil { + if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil { + currentRPM = count + } + // 失败开放:GetRPM 错误时允许调度 + } + + schedulability := account.CheckRPMSchedulability(currentRPM) + switch schedulability { + case WindowCostSchedulable: + return true + case WindowCostStickyOnly: + return isSticky + case WindowCostNotSchedulable: + return false + } + return true +} + +// IncrementAccountRPM increments the RPM counter for the given account. +// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口, +// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit +// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。 +func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error { + if s.rpmCache == nil { + return nil + } + _, err := s.rpmCache.IncrementRPM(ctx, accountID) + return err +} + +// checkAndRegisterSession 检查并注册会话,用于会话数量限制 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +// sessionID: 会话标识符(使用粘性会话的 hash) +// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) +func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool { + // 只检查 Anthropic OAuth/SetupToken 账号 + if !account.IsAnthropicOAuthOrSetupToken() { + return true + } + + maxSessions := account.GetMaxSessions() + if maxSessions <= 0 || sessionID == "" { + return true // 未启用会话限制或无会话ID + } + + if s.sessionLimitCache == nil { + return true // 缓存不可用时允许通过 + } + + idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute + + allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout) + if err != nil { + // 失败开放:缓存错误时允许通过 + return true + } + return allowed +} + +func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { + if s.schedulerSnapshot != nil { + return s.schedulerSnapshot.GetAccount(ctx, accountID) + } + return s.accountRepo.GetByID(ctx, accountID) +} + +// filterByMinPriority 过滤出优先级最小的账号集合 +func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts + } + minPriority := accounts[0].account.Priority + for _, acc := range accounts[1:] { + if acc.account.Priority < minPriority { + minPriority = acc.account.Priority + } + } + result := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + if acc.account.Priority == minPriority { + result = append(result, acc) + } + } + return result +} + +// filterByMinLoadRate 过滤出负载率最低的账号集合 +func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts + } + minLoadRate := accounts[0].loadInfo.LoadRate + for _, acc := range accounts[1:] { + if acc.loadInfo.LoadRate < minLoadRate { + minLoadRate = acc.loadInfo.LoadRate + } + } + result := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + if acc.loadInfo.LoadRate == minLoadRate { + result = append(result, acc) + } + } + return result +} + +// selectByLRU 从集合中选择最久未用的账号 +// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 +func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad { + if len(accounts) == 0 { + return nil + } + if len(accounts) == 1 { + return &accounts[0] + } + + // 1. 找到最小的 LastUsedAt(nil 被视为最小) + var minTime *time.Time + hasNil := false + for _, acc := range accounts { + if acc.account.LastUsedAt == nil { + hasNil = true + break + } + if minTime == nil || acc.account.LastUsedAt.Before(*minTime) { + minTime = acc.account.LastUsedAt + } + } + + // 2. 收集所有具有最小 LastUsedAt 的账号索引 + var candidateIdxs []int + for i, acc := range accounts { + if hasNil { + if acc.account.LastUsedAt == nil { + candidateIdxs = append(candidateIdxs, i) + } + } else { + if acc.account.LastUsedAt != nil && acc.account.LastUsedAt.Equal(*minTime) { + candidateIdxs = append(candidateIdxs, i) + } + } + } + + // 3. 如果只有一个候选,直接返回 + if len(candidateIdxs) == 1 { + return &accounts[candidateIdxs[0]] + } + + // 4. 如果有多个候选且 preferOAuth,优先选择 OAuth 类型 + if preferOAuth { + var oauthIdxs []int + for _, idx := range candidateIdxs { + if accounts[idx].account.Type == AccountTypeOAuth { + oauthIdxs = append(oauthIdxs, idx) + } + } + if len(oauthIdxs) > 0 { + candidateIdxs = oauthIdxs + } + } + + // 5. 随机选择一个 + selectedIdx := candidateIdxs[mathrand.Intn(len(candidateIdxs))] + return &accounts[selectedIdx] +} + +func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { + sort.SliceStable(accounts, func(i, j int) bool { + a, b := accounts[i], accounts[j] + if a.Priority != b.Priority { + return a.Priority < b.Priority + } + switch { + case a.LastUsedAt == nil && b.LastUsedAt != nil: + return true + case a.LastUsedAt != nil && b.LastUsedAt == nil: + return false + case a.LastUsedAt == nil && b.LastUsedAt == nil: + if preferOAuth && a.Type != b.Type { + return a.Type == AccountTypeOAuth + } + return false + default: + return a.LastUsedAt.Before(*b.LastUsedAt) + } + }) + shuffleWithinPriorityAndLastUsed(accounts, preferOAuth) +} + +// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。 +// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。 +func shuffleWithinSortGroups(accounts []accountWithLoad) { + if len(accounts) <= 1 { + return + } + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) { + j++ + } + if j-i > 1 { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) + } + i = j + } +} + +// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组 +func sameAccountWithLoadGroup(a, b accountWithLoad) bool { + if a.account.Priority != b.account.Priority { + return false + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return false + } + return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt) +} + +// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 +// +// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 +// 因此这里采用"组内分区 + 分区内 shuffle"的方式: +// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; +// - 再分别在各段内随机打散,避免热点。 +func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { + if len(accounts) <= 1 { + return + } + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) { + j++ + } + if j-i > 1 { + if preferOAuth { + oauth := make([]*Account, 0, j-i) + others := make([]*Account, 0, j-i) + for _, acc := range accounts[i:j] { + if acc.Type == AccountTypeOAuth { + oauth = append(oauth, acc) + } else { + others = append(others, acc) + } + } + if len(oauth) > 1 { + mathrand.Shuffle(len(oauth), func(a, b int) { oauth[a], oauth[b] = oauth[b], oauth[a] }) + } + if len(others) > 1 { + mathrand.Shuffle(len(others), func(a, b int) { others[a], others[b] = others[b], others[a] }) + } + copy(accounts[i:], oauth) + copy(accounts[i+len(oauth):], others) + } else { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) + } + } + i = j + } +} + +// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt) +func sameAccountGroup(a, b *Account) bool { + if a.Priority != b.Priority { + return false + } + return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt) +} + +// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒) +func sameLastUsedAt(a, b *time.Time) bool { + switch { + case a == nil && b == nil: + return true + case a == nil || b == nil: + return false + default: + return a.Unix() == b.Unix() + } +} + +// sortCandidatesForFallback 根据配置选择排序策略 +// mode: "last_used"(按最后使用时间) 或 "random"(随机) +func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) { + if mode == "random" { + // 先按优先级排序,然后在同优先级内随机打乱 + sortAccountsByPriorityOnly(accounts, preferOAuth) + shuffleWithinPriority(accounts) + } else { + // 默认按最后使用时间排序 + sortAccountsByPriorityAndLastUsed(accounts, preferOAuth) + } +} + +// sortAccountsByPriorityOnly 仅按优先级排序 +func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) { + sort.SliceStable(accounts, func(i, j int) bool { + a, b := accounts[i], accounts[j] + if a.Priority != b.Priority { + return a.Priority < b.Priority + } + if preferOAuth && a.Type != b.Type { + return a.Type == AccountTypeOAuth + } + return false + }) +} + +// shuffleWithinPriority 在同优先级内随机打乱顺序 +func shuffleWithinPriority(accounts []*Account) { + if len(accounts) <= 1 { + return + } + r := mathrand.New(mathrand.NewSource(time.Now().UnixNano())) + start := 0 + for start < len(accounts) { + priority := accounts[start].Priority + end := start + 1 + for end < len(accounts) && accounts[end].Priority == priority { + end++ + } + // 对 [start, end) 范围内的账户随机打乱 + if end-start > 1 { + r.Shuffle(end-start, func(i, j int) { + accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i] + }) + } + start = end + } +} + +// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) +func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { + preferOAuth := platform == PlatformGemini + routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) + + var accounts []Account + accountsLoaded := false + + // ============ Model Routing (legacy path): apply before sticky session ============ + // When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing + // so switching model can switch upstream account within the same sticky session. + if len(routingAccountIDs) > 0 { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs) + } + // 1) Sticky session only applies if the bound account is within the routing set. + if sessionHash != "" && s.cache != nil { + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.getSchedulableAccount(ctx, accountID) + // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) + if err == nil { + clearSticky := shouldClearStickySession(ctx, account, requestedModel) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + } + return account, nil + } + } + } + } + } + + // 2) Select an account from the routed candidates. + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform == "" { + hasForcePlatform = false + } + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + accountsLoaded = true + + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) + for _, id := range routingAccountIDs { + if id > 0 { + routingSet[id] = struct{}{} + } + } + + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, ok := routingSet[acc.ID]; !ok { + continue + } + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + // Scheduler snapshots can be temporarily stale; re-check schedulability here to + // avoid selecting accounts that were recently rate-limited/overloaded. + if !s.isAccountSchedulableForSelection(acc) { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForQuota(acc) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected != nil { + if sessionHash != "" && s.cache != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } + } + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + } + return selected, nil + } + logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + } + + // 1. 查询粘性会话 + if sessionHash != "" && s.cache != nil { + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err == nil && accountID > 0 { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.getSchedulableAccount(ctx, accountID) + // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) + if err == nil { + clearSticky := shouldClearStickySession(ctx, account, requestedModel) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + return account, nil + } + } + } + } + } + + // 2. 获取可调度账号列表(单平台) + if !accountsLoaded { + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform == "" { + hasForcePlatform = false + } + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + } + + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + + // 3. 按优先级+最久未用选择(考虑模型支持) + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + // Scheduler snapshots can be temporarily stale; re-check schedulability here to + // avoid selecting accounts that were recently rate-limited/overloaded. + if !s.isAccountSchedulableForSelection(acc) { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForQuota(acc) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false) + if requestedModel != "" { + return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats)) + } + return nil, ErrNoAvailableAccounts + } + + // 4. 建立粘性绑定 + if sessionHash != "" && s.cache != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } + } + + return selected, nil +} + +// selectAccountWithMixedScheduling 选择账户(支持混合调度) +// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 +func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { + preferOAuth := nativePlatform == PlatformGemini + routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) + + var accounts []Account + accountsLoaded := false + + // ============ Model Routing (legacy path): apply before sticky session ============ + if len(routingAccountIDs) > 0 { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs) + } + // 1) Sticky session only applies if the bound account is within the routing set. + if sessionHash != "" && s.cache != nil { + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.getSchedulableAccount(ctx, accountID) + // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 + if err == nil { + clearSticky := shouldClearStickySession(ctx, account, requestedModel) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + } + return account, nil + } + } + } + } + } + } + + // 2) Select an account from the routed candidates. + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + accountsLoaded = true + + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) + for _, id := range routingAccountIDs { + if id > 0 { + routingSet[id] = struct{}{} + } + } + + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, ok := routingSet[acc.ID]; !ok { + continue + } + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + // Scheduler snapshots can be temporarily stale; re-check schedulability here to + // avoid selecting accounts that were recently rate-limited/overloaded. + if !s.isAccountSchedulableForSelection(acc) { + continue + } + // 过滤:原生平台直接通过,antigravity 需要启用混合调度 + if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForQuota(acc) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected != nil { + if sessionHash != "" && s.cache != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } + } + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + } + return selected, nil + } + logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + } + + // 1. 查询粘性会话 + if sessionHash != "" && s.cache != nil { + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err == nil && accountID > 0 { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.getSchedulableAccount(ctx, accountID) + // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 + if err == nil { + clearSticky := shouldClearStickySession(ctx, account, requestedModel) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { + return account, nil + } + } + } + } + } + } + + // 2. 获取可调度账号列表 + if !accountsLoaded { + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + } + + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + + // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + // Scheduler snapshots can be temporarily stale; re-check schedulability here to + // avoid selecting accounts that were recently rate-limited/overloaded. + if !s.isAccountSchedulableForSelection(acc) { + continue + } + // 过滤:原生平台直接通过,antigravity 需要启用混合调度 + if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForQuota(acc) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true) + if requestedModel != "" { + return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats)) + } + return nil, ErrNoAvailableAccounts + } + + // 4. 建立粘性绑定 + if sessionHash != "" && s.cache != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } + } + + return selected, nil +} + +type selectionFailureStats struct { + Total int + Eligible int + Excluded int + Unschedulable int + PlatformFiltered int + ModelUnsupported int + ModelRateLimited int + SamplePlatformIDs []int64 + SampleMappingIDs []int64 + SampleRateLimitIDs []string +} + +type selectionFailureDiagnosis struct { + Category string + Detail string +} + +func (s *GatewayService) logDetailedSelectionFailure( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + platform string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := s.collectSelectionFailureStats(ctx, accounts, requestedModel, platform, excludedIDs, allowMixedScheduling) + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed] group_id=%v model=%s platform=%s session=%s total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d sample_platform_filtered=%v sample_model_unsupported=%v sample_model_rate_limited=%v", + derefGroupID(groupID), + requestedModel, + platform, + shortSessionHash(sessionHash), + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + stats.SamplePlatformIDs, + stats.SampleMappingIDs, + stats.SampleRateLimitIDs, + ) + if platform == PlatformSora { + s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) + } + return stats +} + +func (s *GatewayService) collectSelectionFailureStats( + ctx context.Context, + accounts []Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := selectionFailureStats{ + Total: len(accounts), + } + + for i := range accounts { + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, platform, excludedIDs, allowMixedScheduling) + switch diagnosis.Category { + case "excluded": + stats.Excluded++ + case "unschedulable": + stats.Unschedulable++ + case "platform_filtered": + stats.PlatformFiltered++ + stats.SamplePlatformIDs = appendSelectionFailureSampleID(stats.SamplePlatformIDs, acc.ID) + case "model_unsupported": + stats.ModelUnsupported++ + stats.SampleMappingIDs = appendSelectionFailureSampleID(stats.SampleMappingIDs, acc.ID) + case "model_rate_limited": + stats.ModelRateLimited++ + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + stats.SampleRateLimitIDs = appendSelectionFailureRateSample(stats.SampleRateLimitIDs, acc.ID, remaining) + default: + stats.Eligible++ + } + } + + return stats +} + +func (s *GatewayService) diagnoseSelectionFailure( + ctx context.Context, + acc *Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureDiagnosis { + if acc == nil { + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "account_nil"} + } + if _, excluded := excludedIDs[acc.ID]; excluded { + return selectionFailureDiagnosis{Category: "excluded"} + } + if !s.isAccountSchedulableForSelection(acc) { + detail := "generic_unschedulable" + if acc.Platform == PlatformSora { + detail = s.soraUnschedulableReason(acc) + } + return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} + } + if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { + return selectionFailureDiagnosis{ + Category: "platform_filtered", + Detail: fmt.Sprintf("account_platform=%s requested_platform=%s", acc.Platform, strings.TrimSpace(platform)), + } + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + return selectionFailureDiagnosis{ + Category: "model_unsupported", + Detail: fmt.Sprintf("model=%s", requestedModel), + } + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + return selectionFailureDiagnosis{ + Category: "model_rate_limited", + Detail: fmt.Sprintf("remaining=%s", remaining), + } + } + return selectionFailureDiagnosis{Category: "eligible"} +} + +func (s *GatewayService) logSoraSelectionFailureDetails( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) { + const maxLines = 30 + logged := 0 + + for i := range accounts { + if logged >= maxLines { + break + } + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) + if diagnosis.Category == "eligible" { + continue + } + detail := diagnosis.Detail + if detail == "" { + detail = "-" + } + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + acc.ID, + acc.Platform, + diagnosis.Category, + detail, + ) + logged++ + } + if len(accounts) > maxLines { + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + len(accounts), + logged, + ) + } +} + +func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { + if acc == nil { + return true + } + if allowMixedScheduling { + if acc.Platform == PlatformAntigravity { + return !acc.IsMixedSchedulingEnabled() + } + return acc.Platform != platform + } + if strings.TrimSpace(platform) == "" { + return false + } + return acc.Platform != platform +} + +func appendSelectionFailureSampleID(samples []int64, id int64) []int64 { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, id) +} + +func appendSelectionFailureRateSample(samples []string, accountID int64, remaining time.Duration) []string { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, fmt.Sprintf("%d(%s)", accountID, remaining)) +} + +func summarizeSelectionFailureStats(stats selectionFailureStats) string { + return fmt.Sprintf( + "total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d", + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + ) +} + +// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context) +// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持 +func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool { + if account.Platform == PlatformAntigravity { + if strings.TrimSpace(requestedModel) == "" { + return true + } + // 使用与转发阶段一致的映射逻辑:自定义映射优先 → 默认映射兜底 + mapped := mapAntigravityModel(account, requestedModel) + if mapped == "" { + return false + } + // 应用 thinking 后缀后检查最终模型是否在账号映射中 + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { + finalModel := applyThinkingModelSuffix(mapped, enabled) + if finalModel == mapped { + return true // thinking 后缀未改变模型名,映射已通过 + } + return account.IsModelSupported(finalModel) + } + return true + } + return s.isModelSupportedByAccount(account, requestedModel) +} + +// isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台) +func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { + if account.Platform == PlatformAntigravity { + if strings.TrimSpace(requestedModel) == "" { + return true + } + return mapAntigravityModel(account, requestedModel) != "" + } + if account.Platform == PlatformSora { + return s.isSoraModelSupportedByAccount(account, requestedModel) + } + if account.IsBedrock() { + _, ok := ResolveBedrockModelID(account, requestedModel) + return ok + } + // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) + if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + requestedModel = claude.NormalizeModelID(requestedModel) + } + // 其他平台使用账户的模型支持检查 + return account.IsModelSupported(requestedModel) +} + +func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { + if account == nil { + return false + } + if strings.TrimSpace(requestedModel) == "" { + return true + } + + // 先走原始精确/通配符匹配。 + mapping := account.GetModelMapping() + if len(mapping) == 0 || account.IsModelSupported(requestedModel) { + return true + } + + aliases := buildSoraModelAliases(requestedModel) + if len(aliases) == 0 { + return false + } + + hasSoraSelector := false + for pattern := range mapping { + if !isSoraModelSelector(pattern) { + continue + } + hasSoraSelector = true + if matchPatternAnyAlias(pattern, aliases) { + return true + } + } + + // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), + // 此时不应误拦截 Sora 模型请求。 + if !hasSoraSelector { + return true + } + + return false +} + +func matchPatternAnyAlias(pattern string, aliases []string) bool { + normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) + if normalizedPattern == "" { + return false + } + for _, alias := range aliases { + if matchWildcard(normalizedPattern, alias) { + return true + } + } + return false +} + +func isSoraModelSelector(pattern string) bool { + p := strings.ToLower(strings.TrimSpace(pattern)) + if p == "" { + return false + } + + switch { + case strings.HasPrefix(p, "sora"), + strings.HasPrefix(p, "gpt-image"), + strings.HasPrefix(p, "prompt-enhance"), + strings.HasPrefix(p, "sy_"): + return true + } + + return p == "video" || p == "image" +} + +func buildSoraModelAliases(requestedModel string) []string { + modelID := strings.ToLower(strings.TrimSpace(requestedModel)) + if modelID == "" { + return nil + } + + aliases := make([]string, 0, 8) + addAlias := func(value string) { + v := strings.ToLower(strings.TrimSpace(value)) + if v == "" { + return + } + for _, existing := range aliases { + if existing == v { + return + } + } + aliases = append(aliases, v) + } + + addAlias(modelID) + cfg, ok := GetSoraModelConfig(modelID) + if ok { + addAlias(cfg.Model) + switch cfg.Type { + case "video": + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case "image": + addAlias("image") + addAlias("gpt-image") + case "prompt_enhance": + addAlias("prompt-enhance") + } + return aliases + } + + switch { + case strings.HasPrefix(modelID, "sora"): + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case strings.HasPrefix(modelID, "gpt-image"): + addAlias("image") + addAlias("gpt-image") + case strings.HasPrefix(modelID, "prompt-enhance"): + addAlias("prompt-enhance") + default: + return nil + } + + return aliases +} + +func soraVideoFamilyAlias(modelID string) string { + switch { + case strings.HasPrefix(modelID, "sora2pro-hd"): + return "sora2pro-hd" + case strings.HasPrefix(modelID, "sora2pro"): + return "sora2pro" + case strings.HasPrefix(modelID, "sora2"): + return "sora2" + default: + return "" + } +} + +// GetAccessToken 获取账号凭证 +func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { + switch account.Type { + case AccountTypeOAuth, AccountTypeSetupToken: + // Both oauth and setup-token use OAuth token flow + return s.getOAuthToken(ctx, account) + case AccountTypeAPIKey: + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return "", "", errors.New("api_key not found in credentials") + } + return apiKey, "apikey", nil + case AccountTypeBedrock: + return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理 + default: + return "", "", fmt.Errorf("unsupported account type: %s", account.Type) + } +} + +func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) { + // 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token + if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil { + accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", err + } + return accessToken, "oauth", nil + } + + // 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取 + accessToken := account.GetCredential("access_token") + if accessToken == "" { + return "", "", errors.New("access_token not found in credentials") + } + // Token刷新由后台 TokenRefreshService 处理,此处只返回当前token + return accessToken, "oauth", nil +} + +// 重试相关常量 +const ( + // 最大尝试次数(包含首次请求)。过多重试会导致请求堆积与资源耗尽。 + maxRetryAttempts = 5 + + // 指数退避:第 N 次失败后的等待 = retryBaseDelay * 2^(N-1),并且上限为 retryMaxDelay。 + retryBaseDelay = 300 * time.Millisecond + retryMaxDelay = 3 * time.Second + + // 最大重试耗时(包含请求本身耗时 + 退避等待时间)。 + // 用于防止极端情况下 goroutine 长时间堆积导致资源耗尽。 + maxRetryElapsed = 10 * time.Second +) + +func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool { + // OAuth/Setup Token 账号:仅 403 重试 + if account.IsOAuth() { + return statusCode == 403 + } + + // API Key 账号:未配置的错误码重试 + return !account.ShouldHandleErrorCode(statusCode) +} + +// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover. +func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +func retryBackoffDelay(attempt int) time.Duration { + // attempt 从 1 开始,表示第 attempt 次请求刚失败,需要等待后进行第 attempt+1 次请求。 + if attempt <= 0 { + return retryBaseDelay + } + delay := retryBaseDelay * time.Duration(1<<(attempt-1)) + if delay > retryMaxDelay { + return retryMaxDelay + } + return delay +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + if d <= 0 { + return nil + } + timer := time.NewTimer(d) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端 +// 简化判断:User-Agent 匹配 + metadata.user_id 存在 +func isClaudeCodeClient(userAgent string, metadataUserID string) bool { + if metadataUserID == "" { + return false + } + return claudeCliUserAgentRe.MatchString(userAgent) +} + +func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool { + if IsClaudeCodeClient(ctx) { + return true + } + if parsed == nil || c == nil { + return false + } + return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) +} + +// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 +// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) +func systemIncludesClaudeCodePrompt(system any) bool { + switch v := system.(type) { + case string: + return hasClaudeCodePrefix(v) + case []any: + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) { + return true + } + } + } + } + return false +} + +// hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头 +func hasClaudeCodePrefix(text string) bool { + for _, prefix := range claudeCodePromptPrefixes { + if strings.HasPrefix(text, prefix) { + return true + } + } + return false +} + +// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 +// 处理 null、字符串、数组三种格式 +func injectClaudeCodePrompt(body []byte, system any) []byte { + claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) + if err != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err) + return body + } + // Opencode plugin applies an extra safeguard: it not only prepends the Claude Code + // banner, it also prefixes the next system instruction with the same banner plus + // a blank line. This helps when upstream concatenates system instructions. + claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt) + + var items [][]byte + + switch v := system.(type) { + case nil: + items = [][]byte{claudeCodeBlock} + case string: + // Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines. + if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) { + items = [][]byte{claudeCodeBlock} + } else { + // Mirror opencode behavior: keep the banner as a separate system entry, + // but also prefix the next system text with the banner. + merged := v + if !strings.HasPrefix(v, claudeCodePrefix) { + merged = claudeCodePrefix + "\n\n" + v + } + nextBlock, buildErr := marshalAnthropicSystemTextBlock(merged, false) + if buildErr != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build prefixed Claude Code system block: %v", buildErr) + return body + } + items = [][]byte{claudeCodeBlock, nextBlock} + } + case []any: + items = make([][]byte, 0, len(v)+1) + items = append(items, claudeCodeBlock) + prefixedNext := false + systemResult := gjson.GetBytes(body, "system") + if systemResult.IsArray() { + systemResult.ForEach(func(_, item gjson.Result) bool { + textResult := item.Get("text") + if textResult.Exists() && textResult.Type == gjson.String && + strings.TrimSpace(textResult.String()) == strings.TrimSpace(claudeCodeSystemPrompt) { + return true + } + + raw := []byte(item.Raw) + // Prefix the first subsequent text system block once. + if !prefixedNext && item.Get("type").String() == "text" && textResult.Exists() && textResult.Type == gjson.String { + text := textResult.String() + if strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { + next, setErr := sjson.SetBytes(raw, "text", claudeCodePrefix+"\n\n"+text) + if setErr == nil { + raw = next + prefixedNext = true + } + } + } + items = append(items, raw) + return true + }) + } else { + for _, item := range v { + m, ok := item.(map[string]any) + if !ok { + raw, marshalErr := json.Marshal(item) + if marshalErr == nil { + items = append(items, raw) + } + continue + } + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) { + continue + } + if !prefixedNext { + if blockType, _ := m["type"].(string); blockType == "text" { + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { + m["text"] = claudeCodePrefix + "\n\n" + text + prefixedNext = true + } + } + } + raw, marshalErr := json.Marshal(m) + if marshalErr == nil { + items = append(items, raw) + } + } + } + default: + items = [][]byte{claudeCodeBlock} + } + + result, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw(items)) + if !ok { + logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt") + return body + } + return result +} + +type cacheControlPath struct { + path string + log string +} + +func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, systemPaths []string) { + system := gjson.GetBytes(body, "system") + if system.IsArray() { + sysIndex := 0 + system.ForEach(func(_, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + path := fmt.Sprintf("system.%d.cache_control", sysIndex) + if item.Get("type").String() == "thinking" { + invalidThinking = append(invalidThinking, cacheControlPath{ + path: path, + log: "[Warning] Removed illegal cache_control from thinking block in system", + }) + } else { + systemPaths = append(systemPaths, path) + } + } + sysIndex++ + return true + }) + } + + messages := gjson.GetBytes(body, "messages") + if messages.IsArray() { + msgIndex := 0 + messages.ForEach(func(_, msg gjson.Result) bool { + content := msg.Get("content") + if content.IsArray() { + contentIndex := 0 + content.ForEach(func(_, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIndex, contentIndex) + if item.Get("type").String() == "thinking" { + invalidThinking = append(invalidThinking, cacheControlPath{ + path: path, + log: fmt.Sprintf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIndex, contentIndex), + }) + } else { + messagePaths = append(messagePaths, path) + } + } + contentIndex++ + return true + }) + } + msgIndex++ + return true + }) + } + + return invalidThinking, messagePaths, systemPaths +} + +// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个) +// 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制 +func enforceCacheControlLimit(body []byte) []byte { + if len(body) == 0 { + return body + } + + invalidThinking, messagePaths, systemPaths := collectCacheControlPaths(body) + out := body + modified := false + + // 先清理 thinking 块中的非法 cache_control(thinking 块不支持该字段) + for _, item := range invalidThinking { + if !gjson.GetBytes(out, item.path).Exists() { + continue + } + next, ok := deleteJSONPathBytes(out, item.path) + if !ok { + continue + } + out = next + modified = true + logger.LegacyPrintf("service.gateway", "%s", item.log) + } + + count := len(messagePaths) + len(systemPaths) + if count <= maxCacheControlBlocks { + if modified { + return out + } + return body + } + + // 超限:优先从 messages 中移除,再从 system 中移除 + remaining := count - maxCacheControlBlocks + for _, path := range messagePaths { + if remaining <= 0 { + break + } + if !gjson.GetBytes(out, path).Exists() { + continue + } + next, ok := deleteJSONPathBytes(out, path) + if !ok { + continue + } + out = next + modified = true + remaining-- + } + + for i := len(systemPaths) - 1; i >= 0 && remaining > 0; i-- { + path := systemPaths[i] + if !gjson.GetBytes(out, path).Exists() { + continue + } + next, ok := deleteJSONPathBytes(out, path) + if !ok { + continue + } + out = next + modified = true + remaining-- + } + + if modified { + return out + } + return body +} + +// Forward 转发请求到Claude API +func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { + startTime := time.Now() + if parsed == nil { + return nil, fmt.Errorf("parse request: empty request") + } + + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + passthroughBody := parsed.Body + passthroughModel := parsed.Model + if passthroughModel != "" { + if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + passthroughModel = mappedModel + } + } + return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{ + Body: passthroughBody, + RequestModel: passthroughModel, + OriginalModel: parsed.Model, + RequestStream: parsed.Stream, + StartTime: startTime, + }) + } + + if account != nil && account.IsBedrock() { + return s.forwardBedrock(ctx, c, account, parsed, startTime) + } + + // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. + // Always overwrite the cache to prevent stale values from a previous retry with a different account. + if account.Platform == PlatformAnthropic && c != nil { + policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account) + if policy.blockErr != nil { + return nil, policy.blockErr + } + filterSet := policy.filterSet + if filterSet == nil { + filterSet = map[string]struct{}{} + } + c.Set(betaPolicyFilterSetKey, filterSet) + } + + body := parsed.Body + reqModel := parsed.Model + reqStream := parsed.Stream + originalModel := reqModel + + // === DEBUG: 打印客户端原始请求 body === + debugLogRequestBody("CLIENT_ORIGINAL", body) + + isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) + // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 + if !strings.Contains(strings.ToLower(reqModel), "haiku") && + !systemIncludesClaudeCodePrompt(parsed.System) { + body = injectClaudeCodePrompt(body, parsed.System) + } + + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + if s.identityService != nil { + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + if err == nil && fp != nil { + if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { + normalizeOpts.injectMetadata = true + normalizeOpts.metadataUserID = metadataUserID + } + } + } + + body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + } + + // 强制执行 cache_control 块数量限制(最多 4 个) + body = enforceCacheControlLimit(body) + + // 应用模型映射: + // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 + // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) + mappedModel := reqModel + mappingSource := "" + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(reqModel) + if mappedModel != reqModel { + mappingSource = "account" + } + } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + // 替换请求体中的模型名 + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) + } + + // 获取凭证 + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + // 获取代理URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 调试日志:记录即将转发的账号信息 + logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", + account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 + setOpsUpstreamRequestBody(c, body) + + // 重试循环 + var resp *http.Response + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() + if err != nil { + return nil, err + } + + // 发送请求 + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + // 优先检测thinking block签名错误(400)并重试一次 + if resp.StatusCode == 400 { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr == nil { + _ = resp.Body.Close() + + if s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "signature_error", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + + looksLikeToolSignatureError := func(msg string) bool { + m := strings.ToLower(msg) + return strings.Contains(m, "tool_use") || + strings.Contains(m, "tool_result") || + strings.Contains(m, "functioncall") || + strings.Contains(m, "function_call") || + strings.Contains(m, "functionresponse") || + strings.Contains(m, "function_response") + } + + // 避免在重试预算已耗尽时再发起额外请求 + if time.Since(retryStart) >= maxRetryElapsed { + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + break + } + logger.LegacyPrintf("service.gateway", "[warn] Account %d: thinking blocks have invalid signature, retrying with filtered blocks", account.ID) + + // Conservative two-stage fallback: + // 1) Disable thinking + thinking->text (preserve content) + // 2) Only if upstream still errors AND error message points to tool/function signature issues: + // also downgrade tool_use/tool_result blocks to text. + + filteredBody := FilterThinkingBlocksForRetry(body) + retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx() + if buildErr == nil { + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if retryErr == nil { + if retryResp.StatusCode < 400 { + logger.LegacyPrintf("service.gateway", "Account %d: thinking block retry succeeded (blocks downgraded)", account.ID) + resp = retryResp + break + } + + retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + _ = retryResp.Body.Close() + if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: retryResp.StatusCode, + UpstreamRequestID: retryResp.Header.Get("x-request-id"), + Kind: "signature_retry_thinking", + Message: extractUpstreamErrorMessage(retryRespBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(retryRespBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + msg2 := extractUpstreamErrorMessage(retryRespBody) + if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { + logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) + filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) + retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream) + retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx2() + if buildErr2 == nil { + retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if retryErr2 == nil { + resp = retryResp2 + break + } + if retryResp2 != nil && retryResp2.Body != nil { + _ = retryResp2.Body.Close() + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "signature_retry_tools_request_error", + Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), + }) + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2) + } + } + } + + // Fall back to the original retry response context. + resp = &http.Response{ + StatusCode: retryResp.StatusCode, + Header: retryResp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryRespBody)), + } + break + } + if retryResp != nil && retryResp.Body != nil { + _ = retryResp.Body.Close() + } + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry failed: %v", account.ID, retryErr) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry build request failed: %v", account.ID, buildErr) + } + + // Retry failed: restore original response body and continue handling. + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + break + } + // 不是签名错误(或整流器已关闭),继续检查 budget 约束 + errMsg := extractUpstreamErrorMessage(respBody) + if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "budget_constraint_error", + Message: errMsg, + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + + rectifiedBody, applied := RectifyThinkingBudget(body) + if applied && time.Since(retryStart) < maxRetryElapsed { + logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) + budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseBudgetRetryCtx() + if buildErr == nil { + budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if retryErr == nil { + resp = budgetRetryResp + break + } + if budgetRetryResp != nil && budgetRetryResp.Body != nil { + _ = budgetRetryResp.Body.Close() + } + logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry failed: %v", account.ID, retryErr) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry build failed: %v", account.ID, buildErr) + } + } + } + + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + } + } + + // 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了) + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } + + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + // 最后一次尝试也失败,跳出循环处理重试耗尽 + break + } + + // 不需要重试(成功或不可重试的错误),跳出循环 + // DEBUG: 输出响应 headers(用于检测 rate limit 信息) + if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID) + for k, v := range resp.Header { + logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v) + } + } + break + } + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") + } + defer func() { _ = resp.Body.Close() }() + + // 处理重试耗尽的情况 + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + // 调试日志:打印重试耗尽后的错误响应 + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + // 处理可切换账号的错误 + if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + // 调试日志:打印上游错误响应 + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + if resp.StatusCode >= 400 { + // 可选:对部分 400 触发 failover(默认关闭以保持语义) + if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr != nil { + // ReadAll failed, fall back to normal error handling without consuming the stream + return s.handleErrorResponse(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + if s.shouldFailoverOn400(respBody) { + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover_on_400", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.gateway", + "Account %d: 400 error, attempting failover: %s", + account.ID, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID) + } + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } + return s.handleErrorResponse(ctx, resp, c, account) + } + + // 处理正常响应 + + // 触发上游接受回调(提前释放串行锁,不等流完成) + if parsed.OnUpstreamAccepted != nil { + parsed.OnUpstreamAccepted() + } + + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if reqStream { + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode) + if err != nil { + if err.Error() == "have error in stream" { + return nil, &UpstreamFailoverError{ + StatusCode: 403, + } + } + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) + if err != nil { + return nil, err + } + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, // 使用原始模型用于计费和日志 + UpstreamModel: mappedModel, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} + +type anthropicPassthroughForwardInput struct { + Body []byte + RequestModel string + OriginalModel string + RequestStream bool + StartTime time.Time +} + +func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + originalModel string, + reqStream bool, + startTime time.Time, +) (*ForwardResult, error) { + return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{ + Body: body, + RequestModel: reqModel, + OriginalModel: originalModel, + RequestStream: reqStream, + StartTime: startTime, + }) +} + +func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( + ctx context.Context, + c *gin.Context, + account *Account, + input anthropicPassthroughForwardInput, +) (*ForwardResult, error) { + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + if tokenType != "apikey" { + return nil, fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", + account.ID, account.Name, input.RequestModel, input.RequestStream) + + if c != nil { + c.Set("anthropic_passthrough", true) + } + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 + setOpsUpstreamRequestBody(c, input.Body) + + var resp *http.Response + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream) + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token) + releaseUpstreamCtx() + if err != nil { + return nil, err + } + + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + // 透传分支禁止 400 请求体降级重试(该重试会改写请求体) + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } + + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "Anthropic passthrough account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + break + } + + break + } + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + if resp.StatusCode >= 400 { + return s.handleErrorResponse(ctx, resp, c, account) + } + + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if input.RequestStream { + streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleNonStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account) + if err != nil { + return nil, err + } + } + if usage == nil { + usage = &ClaudeUsage{} + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: input.OriginalModel, + UpstreamModel: input.RequestModel, + Stream: input.RequestStream, + Duration: time.Since(input.StartTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} + +func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := claudeAPIURL + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages?beta=true" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // 覆盖入站鉴权残留,并注入上游认证 + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Set("x-api-key", token) + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + return req, nil +} + +func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, + model string, +) (*streamingResult, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + } + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "text/event-stream" + } + c.Header("Content-Type", contentType) + if c.Writer.Header().Get("Cache-Control") == "" { + c.Header("Cache-Control", "no-cache") + } + if c.Writer.Header().Get("Connection") == "" { + c.Header("Connection", "keep-alive") + } + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + clientDisconnected := false + sawTerminalEvent := false + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + for { + select { + case ev, ok := <-events: + if !ok { + if !clientDisconnected { + // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 + flusher.Flush() + } + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) + } + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) + } + if errors.Is(ev.err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } + + line := ev.line + if data, ok := extractAnthropicSSEDataLine(line); ok { + trimmed := strings.TrimSpace(data) + if anthropicStreamEventIsTerminal("", trimmed) { + sawTerminalEvent = true + } + if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsagePassthrough(data, usage) + } else { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") { + sawTerminalEvent = true + } + } + + if !clientDisconnected { + if _, err := io.WriteString(w, line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if _, err := io.WriteString(w, "\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if line == "" { + // 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。 + flusher.Flush() + } + } + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") + } + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, model) + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + } + } +} + +func extractAnthropicSSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, "data:") { + return "", false + } + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != '\t' { + break + } + start++ + } + return line[start:], true +} + +func (s *GatewayService) parseSSEUsagePassthrough(data string, usage *ClaudeUsage) { + if usage == nil || data == "" || data == "[DONE]" { + return + } + + parsed := gjson.Parse(data) + switch parsed.Get("type").String() { + case "message_start": + msgUsage := parsed.Get("message.usage") + if msgUsage.Exists() { + usage.InputTokens = int(msgUsage.Get("input_tokens").Int()) + usage.CacheCreationInputTokens = int(msgUsage.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(msgUsage.Get("cache_read_input_tokens").Int()) + + // 保持与通用解析一致:message_start 允许覆盖 5m/1h 明细(包括 0)。 + cc5m := msgUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := msgUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + usage.CacheCreation5mTokens = int(cc5m.Int()) + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } + case "message_delta": + deltaUsage := parsed.Get("usage") + if deltaUsage.Exists() { + if v := deltaUsage.Get("input_tokens").Int(); v > 0 { + usage.InputTokens = int(v) + } + if v := deltaUsage.Get("output_tokens").Int(); v > 0 { + usage.OutputTokens = int(v) + } + if v := deltaUsage.Get("cache_creation_input_tokens").Int(); v > 0 { + usage.CacheCreationInputTokens = int(v) + } + if v := deltaUsage.Get("cache_read_input_tokens").Int(); v > 0 { + usage.CacheReadInputTokens = int(v) + } + + cc5m := deltaUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := deltaUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() && cc5m.Int() > 0 { + usage.CacheCreation5mTokens = int(cc5m.Int()) + } + if cc1h.Exists() && cc1h.Int() > 0 { + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } + } + + if usage.CacheReadInputTokens == 0 { + if cached := parsed.Get("message.usage.cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + if cached := parsed.Get("usage.cached_tokens").Int(); usage.CacheReadInputTokens == 0 && cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + } + if usage.CacheCreationInputTokens == 0 { + cc5m := parsed.Get("message.usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := parsed.Get("message.usage.cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m == 0 && cc1h == 0 { + cc5m = parsed.Get("usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h = parsed.Get("usage.cache_creation.ephemeral_1h_input_tokens").Int() + } + total := cc5m + cc1h + if total > 0 { + usage.CacheCreationInputTokens = int(total) + } + } +} + +func parseClaudeUsageFromResponseBody(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + if len(body) == 0 { + return usage + } + + parsed := gjson.ParseBytes(body) + usageNode := parsed.Get("usage") + if !usageNode.Exists() { + return usage + } + + usage.InputTokens = int(usageNode.Get("input_tokens").Int()) + usage.OutputTokens = int(usageNode.Get("output_tokens").Int()) + usage.CacheCreationInputTokens = int(usageNode.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(usageNode.Get("cache_read_input_tokens").Int()) + + cc5m := usageNode.Get("cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := usageNode.Get("cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m > 0 || cc1h > 0 { + usage.CacheCreation5mTokens = int(cc5m) + usage.CacheCreation1hTokens = int(cc1h) + } + if usage.CacheCreationInputTokens == 0 && (cc5m > 0 || cc1h > 0) { + usage.CacheCreationInputTokens = int(cc5m + cc1h) + } + if usage.CacheReadInputTokens == 0 { + if cached := usageNode.Get("cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + } + return usage +} + +func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ClaudeUsage, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + } + + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + usage := parseClaudeUsageFromResponseBody(body) + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + return usage, nil +} + +func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { + if dst == nil || src == nil { + return + } + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) + return + } + if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { + dst.Set("Content-Type", v) + } + if v := strings.TrimSpace(src.Get("x-request-id")); v != "" { + dst.Set("x-request-id", v) + } +} + +// forwardBedrock 转发请求到 AWS Bedrock +func (s *GatewayService) forwardBedrock( + ctx context.Context, + c *gin.Context, + account *Account, + parsed *ParsedRequest, + startTime time.Time, +) (*ForwardResult, error) { + reqModel := parsed.Model + reqStream := parsed.Stream + body := parsed.Body + + region := bedrockRuntimeRegion(account) + mappedModel, ok := ResolveBedrockModelID(account, reqModel) + if !ok { + return nil, fmt.Errorf("unsupported bedrock model: %s", reqModel) + } + if mappedModel != reqModel { + logger.LegacyPrintf("service.gateway", "[Bedrock] Model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + } + + betaHeader := "" + if c != nil && c.Request != nil { + betaHeader = c.GetHeader("anthropic-beta") + } + + // 准备请求体(注入 anthropic_version/anthropic_beta,移除 Bedrock 不支持的字段,清理 cache_control) + betaTokens, err := s.resolveBedrockBetaTokensForRequest(ctx, account, betaHeader, body, mappedModel) + if err != nil { + return nil, err + } + + bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens) + if err != nil { + return nil, fmt.Errorf("prepare bedrock request body: %w", err) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + logger.LegacyPrintf("service.gateway", "[Bedrock] 命中 Bedrock 分支: account=%d name=%s model=%s->%s stream=%v", + account.ID, account.Name, reqModel, mappedModel, reqStream) + + // 根据账号类型选择认证方式 + var signer *BedrockSigner + var bedrockAPIKey string + if account.IsBedrockAPIKey() { + bedrockAPIKey = account.GetCredential("api_key") + if bedrockAPIKey == "" { + return nil, fmt.Errorf("api_key not found in bedrock credentials") + } + } else { + signer, err = NewBedrockSignerFromAccount(account) + if err != nil { + return nil, fmt.Errorf("create bedrock signer: %w", err) + } + } + + // 执行上游请求(含重试) + resp, err := s.executeBedrockUpstream(ctx, c, account, bedrockBody, mappedModel, region, reqStream, signer, bedrockAPIKey, proxyURL) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + // 将 Bedrock 的 x-amzn-requestid 映射到 x-request-id, + // 使通用错误处理函数(handleErrorResponse、handleRetryExhaustedError)能正确提取 AWS request ID。 + if awsReqID := resp.Header.Get("x-amzn-requestid"); awsReqID != "" && resp.Header.Get("x-request-id") == "" { + resp.Header.Set("x-request-id", awsReqID) + } + + // 错误/failover 处理 + if resp.StatusCode >= 400 { + return s.handleBedrockUpstreamErrors(ctx, resp, c, account) + } + + // 响应处理 + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if reqStream { + streamResult, err := s.handleBedrockStreamingResponse(ctx, resp, c, account, startTime, reqModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleBedrockNonStreamingResponse(ctx, resp, c, account) + if err != nil { + return nil, err + } + } + if usage == nil { + usage = &ClaudeUsage{} + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-amzn-requestid"), + Usage: *usage, + Model: reqModel, + UpstreamModel: mappedModel, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} + +// executeBedrockUpstream 执行 Bedrock 上游请求(含重试逻辑) +func (s *GatewayService) executeBedrockUpstream( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + modelID string, + region string, + stream bool, + signer *BedrockSigner, + apiKey string, + proxyURL string, +) (*http.Response, error) { + var resp *http.Response + var err error + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + var upstreamReq *http.Request + if account.IsBedrockAPIKey() { + upstreamReq, err = s.buildUpstreamRequestBedrockAPIKey(ctx, body, modelID, region, stream, apiKey) + } else { + upstreamReq, err = s.buildUpstreamRequestBedrock(ctx, body, modelID, region, stream, signer) + } + if err != nil { + return nil, err + } + + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, false) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } + + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "[Bedrock] account %d: upstream error %d, retry %d/%d after %v", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + break + } + + break + } + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") + } + return resp, nil +} + +// handleBedrockUpstreamErrors 处理 Bedrock 上游 4xx/5xx 错误(failover + 错误响应) +func (s *GatewayService) handleBedrockUpstreamErrors( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ForwardResult, error) { + // retry exhausted + failover + if s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Bedrock] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d Body=%s", + account.ID, account.Name, resp.StatusCode, truncateString(string(respBody), 1000)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + // non-retryable failover + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + // other errors + return s.handleErrorResponse(ctx, resp, c, account) +} + +// buildUpstreamRequestBedrock 构建 Bedrock 上游请求 +func (s *GatewayService) buildUpstreamRequestBedrock( + ctx context.Context, + body []byte, + modelID string, + region string, + stream bool, + signer *BedrockSigner, +) (*http.Request, error) { + targetURL := BuildBedrockURL(region, modelID, stream) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + // SigV4 签名 + if err := signer.SignRequest(ctx, req, body); err != nil { + return nil, fmt.Errorf("sign bedrock request: %w", err) + } + + return req, nil +} + +// buildUpstreamRequestBedrockAPIKey 构建 Bedrock API Key (Bearer Token) 上游请求 +func (s *GatewayService) buildUpstreamRequestBedrockAPIKey( + ctx context.Context, + body []byte, + modelID string, + region string, + stream bool, + apiKey string, +) (*http.Request, error) { + targetURL := BuildBedrockURL(region, modelID, stream) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + return req, nil +} + +// handleBedrockNonStreamingResponse 处理 Bedrock 非流式响应 +// Bedrock InvokeModel 非流式响应的 body 格式与 Claude API 兼容 +func (s *GatewayService) handleBedrockNonStreamingResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ClaudeUsage, error) { + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + // 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式 + // 并移除该字段避免透传给客户端 + body = transformBedrockInvocationMetrics(body) + + usage := parseClaudeUsageFromResponseBody(body) + + c.Header("Content-Type", "application/json") + if v := resp.Header.Get("x-amzn-requestid"); v != "" { + c.Header("x-request-id", v) + } + c.Data(resp.StatusCode, "application/json", body) + return usage, nil +} + +func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { + // 确定目标URL + targetURL := claudeAPIURL + if account.Type == AccountTypeAPIKey { + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages?beta=true" + } + } + + clientHeaders := http.Header{} + if c != nil && c.Request != nil { + clientHeaders = c.Request.Header + } + + // OAuth账号:应用统一指纹 + var fingerprint *Fingerprint + if account.IsOAuth() && s.identityService != nil { + // 1. 获取或创建指纹(包含随机生成的ClientID) + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) + if err != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err) + // 失败时降级为透传原始headers + } else { + fingerprint = fp + + // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) + // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 + accountUUID := account.GetExtraString("account_uuid") + if accountUUID != "" && fp.ClientID != "" { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { + body = newBody + } + } + } + } + + // === DEBUG: 打印转发给上游的 body(metadata 已重写) === + debugLogRequestBody("UPSTREAM_FORWARD", body) + + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // 设置认证头 + if tokenType == "oauth" { + req.Header.Set("authorization", "Bearer "+token) + } else { + req.Header.Set("x-api-key", token) + } + + // 白名单透传headers + for key, values := range clientHeaders { + lowerKey := strings.ToLower(key) + if allowedHeaders[lowerKey] { + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // OAuth账号:应用缓存的指纹到请求头(覆盖白名单透传的头) + if fingerprint != nil { + s.identityService.ApplyFingerprint(req, fingerprint) + } + + // 确保必要的headers存在 + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + if tokenType == "oauth" { + applyClaudeOAuthHeaderDefaults(req, reqStream) + } + + // Build effective drop set: merge static defaults with dynamic beta policy filter rules + policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) + effectiveDropSet := mergeDropSets(policyFilterSet) + effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) + + // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) + if tokenType == "oauth" { + if mimicClaudeCode { + // 非 Claude Code 客户端:按 opencode 的策略处理: + // - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app) + // - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在 + applyClaudeCodeMimicHeaders(req, reqStream) + + incomingBeta := req.Header.Get("anthropic-beta") + // Match real Claude CLI traffic (per mitmproxy reports): + // messages requests typically use only oauth + interleaved-thinking. + // Also drop claude-code beta if a downstream client added it. + requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) + } else { + // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta + clientBetaHeader := req.Header.Get("anthropic-beta") + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet)) + } + } else { + // API-key accounts: apply beta policy filter to strip controlled tokens + if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" { + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet)) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { + // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultAPIKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } + } + } + + // Always capture a compact fingerprint line for later error diagnostics. + // We only print it when needed (or when the explicit debug flag is enabled). + if c != nil && tokenType == "oauth" { + c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) + } + if s.debugClaudeMimicEnabled() { + logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) + } + + return req, nil +} + +// getBetaHeader 处理anthropic-beta header +// 对于OAuth账号,需要确保包含oauth-2025-04-20 +func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string { + // 如果客户端传了anthropic-beta + if clientBetaHeader != "" { + // 已包含oauth beta则直接返回 + if strings.Contains(clientBetaHeader, claude.BetaOAuth) { + return clientBetaHeader + } + + // 需要添加oauth beta + parts := strings.Split(clientBetaHeader, ",") + for i, p := range parts { + parts[i] = strings.TrimSpace(p) + } + + // 在claude-code-20250219后面插入oauth beta + claudeCodeIdx := -1 + for i, p := range parts { + if p == claude.BetaClaudeCode { + claudeCodeIdx = i + break + } + } + + if claudeCodeIdx >= 0 { + // 在claude-code后面插入 + newParts := make([]string, 0, len(parts)+1) + newParts = append(newParts, parts[:claudeCodeIdx+1]...) + newParts = append(newParts, claude.BetaOAuth) + newParts = append(newParts, parts[claudeCodeIdx+1:]...) + return strings.Join(newParts, ",") + } + + // 没有claude-code,放在第一位 + return claude.BetaOAuth + "," + clientBetaHeader + } + + // 客户端没传,根据模型生成 + // haiku 模型不需要 claude-code beta + if strings.Contains(strings.ToLower(modelID), "haiku") { + return claude.HaikuBetaHeader + } + + return claude.DefaultBetaHeader +} + +func requestNeedsBetaFeatures(body []byte) bool { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { + return true + } + thinkingType := gjson.GetBytes(body, "thinking.type").String() + if strings.EqualFold(thinkingType, "enabled") || strings.EqualFold(thinkingType, "adaptive") { + return true + } + return false +} + +func defaultAPIKeyBetaHeader(body []byte) string { + modelID := gjson.GetBytes(body, "model").String() + if strings.Contains(strings.ToLower(modelID), "haiku") { + return claude.APIKeyHaikuBetaHeader + } + return claude.APIKeyBetaHeader +} + +func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) { + if req == nil { + return + } + if req.Header.Get("accept") == "" { + req.Header.Set("accept", "application/json") + } + for key, value := range claude.DefaultHeaders { + if value == "" { + continue + } + if req.Header.Get(key) == "" { + req.Header.Set(key, value) + } + } + if isStream && req.Header.Get("x-stainless-helper-method") == "" { + req.Header.Set("x-stainless-helper-method", "stream") + } +} + +func mergeAnthropicBeta(required []string, incoming string) string { + seen := make(map[string]struct{}, len(required)+8) + out := make([]string, 0, len(required)+8) + + add := func(v string) { + v = strings.TrimSpace(v) + if v == "" { + return + } + if _, ok := seen[v]; ok { + return + } + seen[v] = struct{}{} + out = append(out, v) + } + + for _, r := range required { + add(r) + } + for _, p := range strings.Split(incoming, ",") { + add(p) + } + return strings.Join(out, ",") +} + +func mergeAnthropicBetaDropping(required []string, incoming string, drop map[string]struct{}) string { + merged := mergeAnthropicBeta(required, incoming) + if merged == "" || len(drop) == 0 { + return merged + } + out := make([]string, 0, 8) + for _, p := range strings.Split(merged, ",") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if _, ok := drop[p]; ok { + continue + } + out = append(out, p) + } + return strings.Join(out, ",") +} + +// stripBetaTokens removes the given beta tokens from a comma-separated header value. +func stripBetaTokens(header string, tokens []string) string { + if header == "" || len(tokens) == 0 { + return header + } + return stripBetaTokensWithSet(header, buildBetaTokenSet(tokens)) +} + +func stripBetaTokensWithSet(header string, drop map[string]struct{}) string { + if header == "" || len(drop) == 0 { + return header + } + parts := strings.Split(header, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if _, ok := drop[p]; ok { + continue + } + out = append(out, p) + } + if len(out) == len(parts) { + return header // no change, avoid allocation + } + return strings.Join(out, ",") +} + +// BetaBlockedError indicates a request was blocked by a beta policy rule. +type BetaBlockedError struct { + Message string +} + +func (e *BetaBlockedError) Error() string { return e.Message } + +// betaPolicyResult holds the evaluated result of beta policy rules for a single request. +type betaPolicyResult struct { + blockErr *BetaBlockedError // non-nil if a block rule matched + filterSet map[string]struct{} // tokens to filter (may be nil) +} + +// evaluateBetaPolicy loads settings once and evaluates all rules against the given request. +func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult { + if s.settingService == nil { + return betaPolicyResult{} + } + settings, err := s.settingService.GetBetaPolicySettings(ctx) + if err != nil || settings == nil { + return betaPolicyResult{} + } + isOAuth := account.IsOAuth() + isBedrock := account.IsBedrock() + var result betaPolicyResult + for _, rule := range settings.Rules { + if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { + continue + } + switch rule.Action { + case BetaPolicyActionBlock: + if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { + msg := rule.ErrorMessage + if msg == "" { + msg = "beta feature " + rule.BetaToken + " is not allowed" + } + result.blockErr = &BetaBlockedError{Message: msg} + } + case BetaPolicyActionFilter: + if result.filterSet == nil { + result.filterSet = make(map[string]struct{}) + } + result.filterSet[rule.BetaToken] = struct{}{} + } + } + return result +} + +// mergeDropSets merges the static defaultDroppedBetasSet with dynamic policy filter tokens. +// Returns defaultDroppedBetasSet directly when policySet is empty (zero allocation). +func mergeDropSets(policySet map[string]struct{}, extra ...string) map[string]struct{} { + if len(policySet) == 0 && len(extra) == 0 { + return defaultDroppedBetasSet + } + m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(policySet)+len(extra)) + for t := range defaultDroppedBetasSet { + m[t] = struct{}{} + } + for t := range policySet { + m[t] = struct{}{} + } + for _, t := range extra { + m[t] = struct{}{} + } + return m +} + +// betaPolicyFilterSetKey is the gin.Context key for caching the policy filter set within a request. +const betaPolicyFilterSetKey = "betaPolicyFilterSet" + +// getBetaPolicyFilterSet returns the beta policy filter set, using the gin context cache if available. +// In the /v1/messages path, Forward() evaluates the policy first and caches the result; +// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this +// evaluates on demand (one DB call). +func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} { + if c != nil { + if v, ok := c.Get(betaPolicyFilterSetKey); ok { + if fs, ok := v.(map[string]struct{}); ok { + return fs + } + } + } + return s.evaluateBetaPolicy(ctx, "", account).filterSet +} + +// betaPolicyScopeMatches checks whether a rule's scope matches the current account type. +func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { + switch scope { + case BetaPolicyScopeAll: + return true + case BetaPolicyScopeOAuth: + return isOAuth + case BetaPolicyScopeAPIKey: + return !isOAuth && !isBedrock + case BetaPolicyScopeBedrock: + return isBedrock + default: + return true // unknown scope → match all (fail-open) + } +} + +// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. +func droppedBetaSet(extra ...string) map[string]struct{} { + m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) + for t := range defaultDroppedBetasSet { + m[t] = struct{}{} + } + for _, t := range extra { + m[t] = struct{}{} + } + return m +} + +// containsBetaToken checks if a comma-separated header value contains the given token. +func containsBetaToken(header, token string) bool { + if header == "" || token == "" { + return false + } + for _, p := range strings.Split(header, ",") { + if strings.TrimSpace(p) == token { + return true + } + } + return false +} + +func filterBetaTokens(tokens []string, filterSet map[string]struct{}) []string { + if len(tokens) == 0 || len(filterSet) == 0 { + return tokens + } + kept := make([]string, 0, len(tokens)) + for _, token := range tokens { + if _, filtered := filterSet[token]; !filtered { + kept = append(kept, token) + } + } + return kept +} + +func (s *GatewayService) resolveBedrockBetaTokensForRequest( + ctx context.Context, + account *Account, + betaHeader string, + body []byte, + modelID string, +) ([]string, error) { + // 1. 对原始 header 中的 beta token 做 block 检查(快速失败) + policy := s.evaluateBetaPolicy(ctx, betaHeader, account) + if policy.blockErr != nil { + return nil, policy.blockErr + } + + // 2. 解析 header + body 自动注入 + Bedrock 转换/过滤 + betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID) + + // 3. 对最终 token 列表再做 block 检查,捕获通过 body 自动注入绕过 header block 的情况。 + // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, + // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → + // 如果不做此检查,block 规则会被绕过。 + if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil { + return nil, blockErr + } + + return filterBetaTokens(betaTokens, policy.filterSet), nil +} + +// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 +// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 +func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError { + if s.settingService == nil || len(tokens) == 0 { + return nil + } + settings, err := s.settingService.GetBetaPolicySettings(ctx) + if err != nil || settings == nil { + return nil + } + isOAuth := account.IsOAuth() + isBedrock := account.IsBedrock() + tokenSet := buildBetaTokenSet(tokens) + for _, rule := range settings.Rules { + if rule.Action != BetaPolicyActionBlock { + continue + } + if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { + continue + } + if _, present := tokenSet[rule.BetaToken]; present { + msg := rule.ErrorMessage + if msg == "" { + msg = "beta feature " + rule.BetaToken + " is not allowed" + } + return &BetaBlockedError{Message: msg} + } + } + return nil +} + +func buildBetaTokenSet(tokens []string) map[string]struct{} { + m := make(map[string]struct{}, len(tokens)) + for _, t := range tokens { + if t == "" { + continue + } + m[t] = struct{}{} + } + return m +} + +var defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) + +// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. +// This mirrors opencode-anthropic-auth behavior: do not trust downstream +// headers when using Claude Code-scoped OAuth credentials. +func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { + if req == nil { + return + } + // Start with the standard defaults (fill missing). + applyClaudeOAuthHeaderDefaults(req, isStream) + // Then force key headers to match Claude Code fingerprint regardless of what the client sent. + for key, value := range claude.DefaultHeaders { + if value == "" { + continue + } + req.Header.Set(key, value) + } + // Real Claude CLI uses Accept: application/json (even for streaming). + req.Header.Set("accept", "application/json") + if isStream { + req.Header.Set("x-stainless-helper-method", "stream") + } +} + +func truncateForLog(b []byte, maxBytes int) string { + if maxBytes <= 0 { + maxBytes = 2048 + } + if len(b) > maxBytes { + b = b[:maxBytes] + } + s := string(b) + // 保持一行,避免污染日志格式 + s = strings.ReplaceAll(s, "\n", "\\n") + s = strings.ReplaceAll(s, "\r", "\\r") + return s +} + +// isThinkingBlockSignatureError 检测是否是thinking block相关错误 +// 这类错误可以通过过滤thinking blocks并重试来解决 +func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if msg == "" { + return false + } + + // 检测signature相关的错误(更宽松的匹配) + // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 + if strings.Contains(msg, "signature") { + return true + } + + // 检测 thinking block 顺序/类型错误 + // 例如: "Expected `thinking` or `redacted_thinking`, but found `text`" + if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block type error") + return true + } + + // 检测 thinking block 被修改的错误 + // 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" + if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block modification error") + return true + } + + // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的,或客户端发送了空 text block) + // 例如: "all messages must have non-empty content" + // "messages: text content blocks must be non-empty" + if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") || + strings.Contains(msg, "content blocks must be non-empty") { + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") + return true + } + + return false +} + +func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { + // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。 + // 默认保守:无法识别则不切换。 + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if msg == "" { + return false + } + + // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 + // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 + if strings.Contains(msg, "anthropic-beta") || + strings.Contains(msg, "beta feature") || + strings.Contains(msg, "requires beta") { + return true + } + + // thinking/tool streaming 等兼容性约束(常见于中间转换链路) + if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { + return true + } + if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { + return true + } + + return false +} + +// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 +// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} +func ExtractUpstreamErrorMessage(body []byte) string { + return extractUpstreamErrorMessage(body) +} + +func extractUpstreamErrorMessage(body []byte) string { + // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} + if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { + inner := strings.TrimSpace(m) + // 有些上游会把完整 JSON 作为字符串塞进 message + if strings.HasPrefix(inner, "{") { + if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + } + return m + } + + // ChatGPT 内部 API 风格:{"detail":"..."} + if d := gjson.GetBytes(body, "detail").String(); strings.TrimSpace(d) != "" { + return d + } + + // 兜底:尝试顶层 message + return gjson.GetBytes(body, "message").String() +} + +func extractUpstreamErrorCode(body []byte) string { + if code := strings.TrimSpace(gjson.GetBytes(body, "error.code").String()); code != "" { + return code + } + + inner := strings.TrimSpace(gjson.GetBytes(body, "error.message").String()) + if !strings.HasPrefix(inner, "{") { + return "" + } + + if code := strings.TrimSpace(gjson.Get(inner, "error.code").String()); code != "" { + return code + } + + if lastBrace := strings.LastIndex(inner, "}"); lastBrace >= 0 { + if code := strings.TrimSpace(gjson.Get(inner[:lastBrace+1], "error.code").String()); code != "" { + return code + } + } + + return "" +} + +func isCountTokensUnsupported404(statusCode int, body []byte) bool { + if statusCode != http.StatusNotFound { + return false + } + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body))) + if msg == "" { + return false + } + if strings.Contains(msg, "/v1/messages/count_tokens") { + return true + } + return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found") +} + +func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // 调试日志:打印上游错误响应 + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + // Print a compact upstream request fingerprint when we hit the Claude Code OAuth + // credential scope error. This avoids requiring env-var tweaks in a fixed deploy. + if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { + if v, ok := c.Get(claudeMimicDebugInfoKey); ok { + if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + resp.StatusCode, + resp.Header.Get("x-request-id"), + line, + ) + } + } + } + + // Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet. + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + // 处理上游错误,标记账号状态 + shouldDisable := false + if s.rateLimitService != nil { + shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } + if shouldDisable { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} + } + + // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.gateway", + "Upstream error %d (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + + // 非 failover 错误也支持错误透传规则匹配。 + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary) + } + + // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) + var errType, errMsg string + var statusCode int + + switch resp.StatusCode { + case 400: + c.Data(http.StatusBadRequest, "application/json", body) + summary := upstreamMsg + if summary == "" { + summary = truncateForLog(body, 512) + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, summary) + case 401: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream authentication failed, please contact administrator" + case 403: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream access forbidden, please contact administrator" + case 429: + statusCode = http.StatusTooManyRequests + errType = "rate_limit_error" + errMsg = "Upstream rate limit exceeded, please retry later" + case 529: + statusCode = http.StatusServiceUnavailable + errType = "overloaded_error" + errMsg = "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream service temporarily unavailable" + default: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream request failed" + } + + // 返回自定义错误响应 + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) +} + +func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + statusCode := resp.StatusCode + + // OAuth/Setup Token 账号的 403:标记账号异常 + if account.IsOAuth() && statusCode == 403 { + s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) + logger.LegacyPrintf("service.gateway", "Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) + } else { + // API Key 未配置错误码:不标记账号状态 + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) + } +} + +func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) +} + +// handleRetryExhaustedError 处理重试耗尽后的错误 +// OAuth 403:标记账号异常 +// API Key 未配置错误码:仅返回错误,不标记账号 +func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { + // Capture upstream error body before side-effects consume the stream. + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { + if v, ok := c.Get(claudeMimicDebugInfoKey); ok { + if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + resp.StatusCode, + resp.Header.Get("x-request-id"), + line, + ) + } + } + } + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry_exhausted", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.gateway", + "Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + respBody, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed after retries", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary) + } + + // 返回统一的重试耗尽错误响应 + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed after retries", + }, + }) + + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (retries exhausted) message=%s", resp.StatusCode, upstreamMsg) +} + +// streamingResult 流式响应结果 +type streamingResult struct { + usage *ClaudeUsage + firstTokenMs *int + clientDisconnect bool // 客户端是否在流式传输过程中断开 +} + +func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) { + // 更新5h窗口状态 + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + + // 设置SSE响应头 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + // 透传其他响应头 + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + scanner := bufio.NewScanner(resp.Body) + // 设置更大的buffer以处理长行 + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + // 仅监控上游数据间隔超时,避免下游写入阻塞导致误判 + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + // 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开 + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + lastDataAt := time.Now() + + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } + + needModelReplace := originalModel != mappedModel + clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + sawTerminalEvent := false + + pendingEventLines := make([]string, 0, 4) + + processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) { + if len(lines) == 0 { + return nil, "", nil, nil + } + + eventName := "" + dataLine := "" + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") { + eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")) + continue + } + if dataLine == "" && sseDataRe.MatchString(trimmed) { + dataLine = sseDataRe.ReplaceAllString(trimmed, "") + } + } + + if eventName == "error" { + return nil, dataLine, nil, errors.New("have error in stream") + } + + if dataLine == "" { + return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil, nil + } + + if dataLine == "[DONE]" { + sawTerminalEvent = true + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, nil, nil + } + + var event map[string]any + if err := json.Unmarshal([]byte(dataLine), &event); err != nil { + // JSON 解析失败,直接透传原始数据 + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, nil, nil + } + + eventType, _ := event["type"].(string) + if eventName == "" { + eventName = eventType + } + eventChanged := false + + // 兼容 Kimi cached_tokens → cache_read_input_tokens + if eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + eventChanged = reconcileCachedTokens(u) || eventChanged + } + } + } + if eventType == "message_delta" { + if u, ok := event["usage"].(map[string]any); ok { + eventChanged = reconcileCachedTokens(u) || eventChanged + } + } + + // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged + } + } + } + if eventType == "message_delta" { + if u, ok := event["usage"].(map[string]any); ok { + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged + } + } + } + + if needModelReplace { + if msg, ok := event["message"].(map[string]any); ok { + if model, ok := msg["model"].(string); ok && model == mappedModel { + msg["model"] = originalModel + eventChanged = true + } + } + } + + usagePatch := s.extractSSEUsagePatch(event) + if anthropicStreamEventIsTerminal(eventName, dataLine) { + sawTerminalEvent = true + } + if !eventChanged { + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, usagePatch, nil + } + + newData, err := json.Marshal(event) + if err != nil { + // 序列化失败,直接透传原始数据 + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, usagePatch, nil + } + + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + string(newData) + "\n\n" + return []string{block}, string(newData), usagePatch, nil + } + + for { + select { + case ev, ok := <-events: + if !ok { + // 上游完成,返回结果 + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) + } + // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) + } + // 客户端未断开,正常的错误处理 + if errors.Is(ev.err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } + line := ev.line + trimmed := strings.TrimSpace(line) + + if trimmed == "" { + if len(pendingEventLines) == 0 { + continue + } + + outputBlocks, data, usagePatch, err := processSSEEvent(pendingEventLines) + pendingEventLines = pendingEventLines[:0] + if err != nil { + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + return nil, err + } + + for _, block := range outputBlocks { + if !clientDisconnected { + if _, werr := fmt.Fprint(w, block); werr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + break + } + flusher.Flush() + lastDataAt = time.Now() + } + if data != "" { + if firstTokenMs == nil && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if usagePatch != nil { + mergeSSEUsagePatch(usage, usagePatch) + } + } + } + continue + } + + pendingEventLines = append(pendingEventLines, line) + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") + } + logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + // 处理流超时,可能标记账户为临时不可调度或错误状态 + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) + } + sendErrorEvent("stream_timeout") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // SSE ping 事件:Anthropic 原生格式,客户端会正确处理, + // 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开 + if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing") + continue + } + flusher.Flush() + } + } + +} + +func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { + if usage == nil { + return + } + + var event map[string]any + if err := json.Unmarshal([]byte(data), &event); err != nil { + return + } + + if patch := s.extractSSEUsagePatch(event); patch != nil { + mergeSSEUsagePatch(usage, patch) + } +} + +type sseUsagePatch struct { + inputTokens int + hasInputTokens bool + outputTokens int + hasOutputTokens bool + cacheCreationInputTokens int + hasCacheCreationInput bool + cacheReadInputTokens int + hasCacheReadInput bool + cacheCreation5mTokens int + hasCacheCreation5m bool + cacheCreation1hTokens int + hasCacheCreation1h bool +} + +func (s *GatewayService) extractSSEUsagePatch(event map[string]any) *sseUsagePatch { + if len(event) == 0 { + return nil + } + + eventType, _ := event["type"].(string) + switch eventType { + case "message_start": + msg, _ := event["message"].(map[string]any) + usageObj, _ := msg["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil + } + + patch := &sseUsagePatch{} + patch.hasInputTokens = true + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok { + patch.inputTokens = v + } + patch.hasCacheCreationInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok { + patch.cacheCreationInputTokens = v + } + patch.hasCacheReadInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok { + patch.cacheReadInputTokens = v + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } + } + return patch + + case "message_delta": + usageObj, _ := event["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil + } + + patch := &sseUsagePatch{} + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok && v > 0 { + patch.inputTokens = v + patch.hasInputTokens = true + } + if v, ok := parseSSEUsageInt(usageObj["output_tokens"]); ok && v > 0 { + patch.outputTokens = v + patch.hasOutputTokens = true + } + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok && v > 0 { + patch.cacheCreationInputTokens = v + patch.hasCacheCreationInput = true + } + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok && v > 0 { + patch.cacheReadInputTokens = v + patch.hasCacheReadInput = true + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists && v > 0 { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists && v > 0 { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } + } + return patch + } + + return nil +} + +func mergeSSEUsagePatch(usage *ClaudeUsage, patch *sseUsagePatch) { + if usage == nil || patch == nil { + return + } + + if patch.hasInputTokens { + usage.InputTokens = patch.inputTokens + } + if patch.hasCacheCreationInput { + usage.CacheCreationInputTokens = patch.cacheCreationInputTokens + } + if patch.hasCacheReadInput { + usage.CacheReadInputTokens = patch.cacheReadInputTokens + } + if patch.hasOutputTokens { + usage.OutputTokens = patch.outputTokens + } + if patch.hasCacheCreation5m { + usage.CacheCreation5mTokens = patch.cacheCreation5mTokens + } + if patch.hasCacheCreation1h { + usage.CacheCreation1hTokens = patch.cacheCreation1hTokens + } +} + +func parseSSEUsageInt(value any) (int, bool) { + switch v := value.(type) { + case float64: + return int(v), true + case float32: + return int(v), true + case int: + return v, true + case int64: + return int(v), true + case int32: + return int(v), true + case json.Number: + if i, err := v.Int64(); err == nil { + return int(i), true + } + if f, err := v.Float64(); err == nil { + return int(f), true + } + case string: + if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + return parsed, true + } + } + return 0, false +} + +// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。 +// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。 +func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool { + // Fallback: 如果只有聚合字段但无 5m/1h 明细,将聚合字段归入 5m 默认类别 + if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 { + usage.CacheCreation5mTokens = usage.CacheCreationInputTokens + } + + total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens + if total == 0 { + return false + } + switch target { + case "1h": + if usage.CacheCreation1hTokens == total { + return false // 已经全是 1h + } + usage.CacheCreation1hTokens = total + usage.CacheCreation5mTokens = 0 + default: // "5m" + if usage.CacheCreation5mTokens == total { + return false // 已经全是 5m + } + usage.CacheCreation5mTokens = total + usage.CacheCreation1hTokens = 0 + } + return true +} + +// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。 +// usageObj 是 usage JSON 对象(map[string]any)。 +func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool { + ccObj, ok := usageObj["cache_creation"].(map[string]any) + if !ok { + return false + } + v5m, _ := parseSSEUsageInt(ccObj["ephemeral_5m_input_tokens"]) + v1h, _ := parseSSEUsageInt(ccObj["ephemeral_1h_input_tokens"]) + total := v5m + v1h + if total == 0 { + return false + } + switch target { + case "1h": + if v1h == total { + return false + } + ccObj["ephemeral_1h_input_tokens"] = float64(total) + ccObj["ephemeral_5m_input_tokens"] = float64(0) + default: // "5m" + if v5m == total { + return false + } + ccObj["ephemeral_5m_input_tokens"] = float64(total) + ccObj["ephemeral_1h_input_tokens"] = float64(0) + } + return true +} + +func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { + // 更新5h窗口状态 + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + // 解析usage + var response struct { + Usage ClaudeUsage `json:"usage"` + } + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens") + cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + response.Usage.CacheCreation5mTokens = int(cc5m.Int()) + response.Usage.CacheCreation1hTokens = int(cc1h.Int()) + } + + // 兼容 Kimi cached_tokens → cache_read_input_tokens + if response.Usage.CacheReadInputTokens == 0 { + cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() + if cachedTokens > 0 { + response.Usage.CacheReadInputTokens = int(cachedTokens) + if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { + body = newBody + } + } + } + + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if applyCacheTTLOverride(&response.Usage, overrideTarget) { + // 同步更新 body JSON 中的嵌套 cache_creation 对象 + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil { + body = newBody + } + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil { + body = newBody + } + } + } + + // 如果有模型映射,替换响应中的model字段 + if originalModel != mappedModel { + body = s.replaceModelInResponseBody(body, mappedModel, originalModel) + } + + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := "application/json" + if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { + if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { + contentType = upstreamType + } + } + + // 写入响应 + c.Data(resp.StatusCode, contentType, body) + + return &response.Usage, nil +} + +// replaceModelInResponseBody 替换响应体中的model字段 +// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化 +func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody + } + return body +} + +func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { + if s == nil { + return groupDefaultMultiplier + } + resolver := s.userGroupRateResolver + if resolver == nil { + resolver = newUserGroupRateResolver( + s.userGroupRateRepo, + s.userGroupRateCache, + resolveUserGroupRateCacheTTL(s.cfg), + &s.userGroupRateSF, + "service.gateway", + ) + } + return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier) +} + +// RecordUsageInput 记录使用量的输入参数 +type RecordUsageInput struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + InboundEndpoint string // 入站端点(客户端请求路径) + UpstreamEndpoint string // 上游端点(标准化后的上游路径) + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 +} + +// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage +type APIKeyQuotaUpdater interface { + UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error + UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error +} + +type apiKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) +} + +type usageLogBestEffortWriter interface { + CreateBestEffort(ctx context.Context, log *UsageLog) error +} + +// postUsageBillingParams 统一扣费所需的参数 +type postUsageBillingParams struct { + Cost *CostBreakdown + User *User + APIKey *APIKey + Account *Account + Subscription *UserSubscription + RequestPayloadHash string + IsSubscriptionBill bool + AccountRateMultiplier float64 + APIKeyService APIKeyQuotaUpdater +} + +// postUsageBilling 统一处理使用量记录后的扣费逻辑: +// - 订阅/余额扣费 +// - API Key 配额更新 +// - API Key 限速用量更新 +// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) +func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + cost := p.Cost + + // 1. 订阅 / 余额扣费 + if p.IsSubscriptionBill { + if cost.TotalCost > 0 { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { + slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) + } + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) + } + } else { + if cost.ActualCost > 0 { + if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { + slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) + } + deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) + } + } + + // 2. API Key 配额 + if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) + } + } + + // 3. API Key 限速用量 + if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) + } + } + + // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) + if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + accountCost := cost.TotalCost * p.AccountRateMultiplier + if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { + slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) + } + } + + finalizePostUsageBilling(p, deps) +} + +func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + strings.TrimSpace(clientRequestID) + } + if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { + return requestID + } + return "generated:" + generateRequestID() +} + +func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string { + if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" { + return payloadHash + } + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + strings.TrimSpace(clientRequestID) + } + if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + return "" +} + +func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand { + if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil { + return nil + } + + cmd := &UsageBillingCommand{ + RequestID: requestID, + APIKeyID: p.APIKey.ID, + UserID: p.User.ID, + AccountID: p.Account.ID, + AccountType: p.Account.Type, + RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash), + } + if usageLog != nil { + cmd.Model = usageLog.Model + cmd.BillingType = usageLog.BillingType + cmd.InputTokens = usageLog.InputTokens + cmd.OutputTokens = usageLog.OutputTokens + cmd.CacheCreationTokens = usageLog.CacheCreationTokens + cmd.CacheReadTokens = usageLog.CacheReadTokens + cmd.ImageCount = usageLog.ImageCount + if usageLog.MediaType != nil { + cmd.MediaType = *usageLog.MediaType + } + if usageLog.ServiceTier != nil { + cmd.ServiceTier = *usageLog.ServiceTier + } + if usageLog.ReasoningEffort != nil { + cmd.ReasoningEffort = *usageLog.ReasoningEffort + } + if usageLog.SubscriptionID != nil { + cmd.SubscriptionID = usageLog.SubscriptionID + } + } + + if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { + cmd.SubscriptionID = &p.Subscription.ID + cmd.SubscriptionCost = p.Cost.TotalCost + } else if p.Cost.ActualCost > 0 { + cmd.BalanceCost = p.Cost.ActualCost + } + + if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + cmd.APIKeyQuotaCost = p.Cost.ActualCost + } + if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + cmd.APIKeyRateLimitCost = p.Cost.ActualCost + } + if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier + } + + cmd.Normalize() + return cmd +} + +func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) { + if p == nil || deps == nil { + return false, nil + } + + cmd := buildUsageBillingCommand(requestID, usageLog, p) + if cmd == nil || cmd.RequestID == "" || repo == nil { + postUsageBilling(ctx, p, deps) + return true, nil + } + + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + result, err := repo.Apply(billingCtx, cmd) + if err != nil { + return false, err + } + + if result == nil || !result.Applied { + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) + return false, nil + } + + if result.APIKeyQuotaExhausted { + if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" { + invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key) + } + } + + finalizePostUsageBilling(p, deps) + return true, nil +} + +func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { + if p == nil || p.Cost == nil || deps == nil { + return + } + + if p.IsSubscriptionBill { + if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) + } + } else if p.Cost.ActualCost > 0 && p.User != nil { + deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) + } + + if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() { + deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost) + } + + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) +} + +func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + return context.WithTimeout(base, postUsageBillingTimeout) +} + +func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if !stream { + return ctx, func() {} + } + if ctx == nil { + return context.Background(), func() {} + } + return context.WithoutCancel(ctx), func() {} +} + +// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) +type billingDeps struct { + accountRepo AccountRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCacheService *BillingCacheService + deferredService *DeferredService +} + +func (s *GatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + +func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) { + if repo == nil || usageLog == nil { + return + } + usageCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + if writer, ok := repo.(usageLogBestEffortWriter); ok { + if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + if IsUsageLogCreateDropped(err) { + return + } + if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil { + logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr) + } + } + return + } + + if _, err := repo.Create(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + } +} + +// RecordUsage 记录使用量并扣费(或更新订阅用量) +func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { + result := input.Result + apiKey := input.APIKey + user := input.User + account := input.Account + subscription := input.Subscription + + // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens + // 用于粘性会话切换时的特殊计费处理 + if input.ForceCacheBilling && result.Usage.InputTokens > 0 { + logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + result.Usage.InputTokens, account.ID) + result.Usage.CacheReadInputTokens += result.Usage.InputTokens + result.Usage.InputTokens = 0 + } + + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } + + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } + if apiKey.GroupID != nil && apiKey.Group != nil { + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) + } + + var cost *CostBreakdown + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + + // 根据请求类型选择计费方式 + if result.MediaType == "image" || result.MediaType == "video" { + var soraConfig *SoraPriceConfig + if apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == "image" { + cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } else { + cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) + } + } else if result.MediaType == "prompt" { + cost = &CostBreakdown{} + } else if result.ImageCount > 0 { + // 图片生成计费 + var groupConfig *ImagePriceConfig + if apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) + } else { + // Token 计费 + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, + } + var err error + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) + if err != nil { + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) + cost = &CostBreakdown{ActualCost: 0} + } + } + + // 判断计费方式:订阅模式 vs 余额模式 + isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() + billingType := BillingTypeBalance + if isSubscriptionBilling { + billingType = BillingTypeSubscription + } + + // 创建使用日志 + durationMs := int(result.Duration.Milliseconds()) + var imageSize *string + if result.ImageSize != "" { + imageSize = &result.ImageSize + } + var mediaType *string + if strings.TrimSpace(result.MediaType) != "" { + mediaType = &result.MediaType + } + accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) + usageLog := &UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: result.Model, + RequestedModel: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), + ReasoningEffort: result.ReasoningEffort, + InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), + UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, + InputCost: cost.InputCost, + OutputCost: cost.OutputCost, + CacheCreationCost: cost.CacheCreationCost, + CacheReadCost: cost.CacheReadCost, + TotalCost: cost.TotalCost, + ActualCost: cost.ActualCost, + RateMultiplier: multiplier, + AccountRateMultiplier: &accountRateMultiplier, + BillingType: billingType, + Stream: result.Stream, + DurationMs: &durationMs, + FirstTokenMs: result.FirstTokenMs, + ImageCount: result.ImageCount, + ImageSize: imageSize, + MediaType: mediaType, + CacheTTLOverridden: cacheTTLOverridden, + CreatedAt: time.Now(), + } + + // 添加 UserAgent + if input.UserAgent != "" { + usageLog.UserAgent = &input.UserAgent + } + + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + + // 添加分组和订阅关联 + if apiKey.GroupID != nil { + usageLog.GroupID = apiKey.GroupID + } + if subscription != nil { + usageLog.SubscriptionID = &subscription.ID + } + + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr + } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + + return nil +} + +// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) +type RecordUsageLongContextInput struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + InboundEndpoint string // 入站端点(客户端请求路径) + UpstreamEndpoint string // 上游端点(标准化后的上游路径) + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) +} + +// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) +func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { + result := input.Result + apiKey := input.APIKey + user := input.User + account := input.Account + subscription := input.Subscription + + // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens + // 用于粘性会话切换时的特殊计费处理 + if input.ForceCacheBilling && result.Usage.InputTokens > 0 { + logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + result.Usage.InputTokens, account.ID) + result.Usage.CacheReadInputTokens += result.Usage.InputTokens + result.Usage.InputTokens = 0 + } + + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } + + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } + if apiKey.GroupID != nil && apiKey.Group != nil { + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) + } + + var cost *CostBreakdown + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + + // 根据请求类型选择计费方式 + if result.ImageCount > 0 { + // 图片生成计费 + var groupConfig *ImagePriceConfig + if apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) + } else { + // Token 计费(使用长上下文计费方法) + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, + } + var err error + cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + if err != nil { + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) + cost = &CostBreakdown{ActualCost: 0} + } + } + + // 判断计费方式:订阅模式 vs 余额模式 + isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() + billingType := BillingTypeBalance + if isSubscriptionBilling { + billingType = BillingTypeSubscription + } + + // 创建使用日志 + durationMs := int(result.Duration.Milliseconds()) + var imageSize *string + if result.ImageSize != "" { + imageSize = &result.ImageSize + } + accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) + usageLog := &UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: result.Model, + RequestedModel: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), + ReasoningEffort: result.ReasoningEffort, + InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), + UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, + InputCost: cost.InputCost, + OutputCost: cost.OutputCost, + CacheCreationCost: cost.CacheCreationCost, + CacheReadCost: cost.CacheReadCost, + TotalCost: cost.TotalCost, + ActualCost: cost.ActualCost, + RateMultiplier: multiplier, + AccountRateMultiplier: &accountRateMultiplier, + BillingType: billingType, + Stream: result.Stream, + DurationMs: &durationMs, + FirstTokenMs: result.FirstTokenMs, + ImageCount: result.ImageCount, + ImageSize: imageSize, + CacheTTLOverridden: cacheTTLOverridden, + CreatedAt: time.Now(), + } + + // 添加 UserAgent + if input.UserAgent != "" { + usageLog.UserAgent = &input.UserAgent + } + + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + + // 添加分组和订阅关联 + if apiKey.GroupID != nil { + usageLog.GroupID = apiKey.GroupID + } + if subscription != nil { + usageLog.SubscriptionID = &subscription.ID + } + + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr + } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + + return nil +} + +// ForwardCountTokens 转发 count_tokens 请求到上游 API +// 特点:不记录使用量、仅支持非流式响应 +func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { + if parsed == nil { + s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return fmt.Errorf("parse request: empty request") + } + + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + passthroughBody := parsed.Body + if reqModel := parsed.Model; reqModel != "" { + if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + } + } + return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody) + } + + // Bedrock 不支持 count_tokens 端点 + if account != nil && account.IsBedrock() { + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for Bedrock") + return nil + } + + body := parsed.Body + reqModel := parsed.Model + + isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + } + + // Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。 + // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 + if account.Platform == PlatformAntigravity { + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform") + return nil + } + + // 应用模型映射: + // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 + // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) + if reqModel != "" { + mappedModel := reqModel + mappingSource := "" + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(reqModel) + if mappedModel != reqModel { + mappingSource = "account" + } + } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + logger.LegacyPrintf("service.gateway", "CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) + } + } + + // 获取凭证 + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") + return err + } + + // 构建上游请求 + upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode) + if err != nil { + s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + return err + } + + // 获取代理URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") + return fmt.Errorf("upstream request failed: %w", err) + } + + // 读取响应体 + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + _ = resp.Body.Close() + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + return err + } + + // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) + if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { + logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) + + filteredBody := FilterThinkingBlocksForRetry(body) + retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) + if buildErr == nil { + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if retryErr == nil { + resp = retryResp + respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + _ = resp.Body.Close() + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + return err + } + } + } + } + + // 处理错误响应 + if resp.StatusCode >= 400 { + // 标记账号状态(429/529等) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + // 记录上游错误摘要便于排障(不回显请求内容) + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.gateway", + "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + + // 返回简化的错误响应 + errMsg := "Upstream request failed" + switch resp.StatusCode { + case 429: + errMsg = "Rate limit exceeded" + case 529: + errMsg = "Service overloaded" + } + s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + // 透传成功响应 + c.Data(resp.StatusCode, "application/json", respBody) + return nil +} + +func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx context.Context, c *gin.Context, account *Account, body []byte) error { + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") + return err + } + if tokenType != "apikey" { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Invalid account token type") + return fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) + } + + upstreamReq, err := s.buildCountTokensRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + if err != nil { + s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + return err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: sanitizeUpstreamErrorMessage(err.Error()), + }) + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") + return fmt.Errorf("upstream request failed: %w", err) + } + + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + _ = resp.Body.Close() + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + return err + } + + if resp.StatusCode >= 400 { + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + // 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。 + // 仅在错误消息明确指向 count_tokens endpoint 不存在时生效,避免误吞其他 404(如错误 base_url)。 + // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 + if isCountTokensUnsupported404(resp.StatusCode, respBody) { + logger.LegacyPrintf("service.gateway", + "[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s", + account.ID, account.Name, truncateString(upstreamMsg, 512)) + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported by upstream") + return nil + } + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + errMsg := "Upstream request failed" + switch resp.StatusCode { + case 429: + errMsg = "Rate limit exceeded" + case 529: + errMsg = "Service overloaded" + } + s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, respBody) + return nil +} + +func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := claudeAPICountTokensURL + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Set("x-api-key", token) + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + return req, nil +} + +// buildCountTokensRequest 构建 count_tokens 上游请求 +func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) { + // 确定目标 URL + targetURL := claudeAPICountTokensURL + if account.Type == AccountTypeAPIKey { + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" + } + } + + clientHeaders := http.Header{} + if c != nil && c.Request != nil { + clientHeaders = c.Request.Header + } + + // OAuth 账号:应用统一指纹和重写 userID + // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 + if account.IsOAuth() && s.identityService != nil { + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) + if err == nil { + accountUUID := account.GetExtraString("account_uuid") + if accountUUID != "" && fp.ClientID != "" { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { + body = newBody + } + } + } + } + + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // 设置认证头 + if tokenType == "oauth" { + req.Header.Set("authorization", "Bearer "+token) + } else { + req.Header.Set("x-api-key", token) + } + + // 白名单透传 headers + for key, values := range clientHeaders { + lowerKey := strings.ToLower(key) + if allowedHeaders[lowerKey] { + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // OAuth 账号:应用指纹到请求头 + if account.IsOAuth() && s.identityService != nil { + fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) + if fp != nil { + s.identityService.ApplyFingerprint(req, fp) + } + } + + // 确保必要的 headers 存在 + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + if tokenType == "oauth" { + applyClaudeOAuthHeaderDefaults(req, false) + } + + // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules + ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account)) + + // OAuth 账号:处理 anthropic-beta header + if tokenType == "oauth" { + if mimicClaudeCode { + applyClaudeCodeMimicHeaders(req, false) + + incomingBeta := req.Header.Get("anthropic-beta") + requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet)) + } else { + clientBetaHeader := req.Header.Get("anthropic-beta") + if clientBetaHeader == "" { + req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader) + } else { + beta := s.getBetaHeader(modelID, clientBetaHeader) + if !strings.Contains(beta, claude.BetaTokenCounting) { + beta = beta + "," + claude.BetaTokenCounting + } + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet)) + } + } + } else { + // API-key accounts: apply beta policy filter to strip controlled tokens + if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" { + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet)) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { + // API-key:与 messages 同步的按需 beta 注入(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultAPIKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } + } + } + + if c != nil && tokenType == "oauth" { + c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) + } + if s.debugClaudeMimicEnabled() { + logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) + } + + return req, nil +} + +// countTokensError 返回 count_tokens 错误响应 +func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { + if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { + normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil + } + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ + AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, + RequireAllowlist: true, + AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, + }) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil +} + +// GetAvailableModels returns the list of models available for a group +// It aggregates model_mapping keys from all schedulable accounts in the group +func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { + cacheKey := modelsListCacheKey(groupID, platform) + if s.modelsListCache != nil { + if cached, found := s.modelsListCache.Get(cacheKey); found { + if models, ok := cached.([]string); ok { + modelsListCacheHitTotal.Add(1) + return cloneStringSlice(models) + } + } + } + modelsListCacheMissTotal.Add(1) + + var accounts []Account + var err error + + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID) + } else { + accounts, err = s.accountRepo.ListSchedulable(ctx) + } + + if err != nil || len(accounts) == 0 { + return nil + } + + // Filter by platform if specified + if platform != "" { + filtered := make([]Account, 0) + for _, acc := range accounts { + if acc.Platform == platform { + filtered = append(filtered, acc) + } + } + accounts = filtered + } + + // Collect unique models from all accounts + modelSet := make(map[string]struct{}) + hasAnyMapping := false + + for _, acc := range accounts { + mapping := acc.GetModelMapping() + if len(mapping) > 0 { + hasAnyMapping = true + for model := range mapping { + modelSet[model] = struct{}{} + } + } + } + + // If no account has model_mapping, return nil (use default) + if !hasAnyMapping { + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } + return nil + } + + // Convert to slice + models := make([]string, 0, len(modelSet)) + for model := range modelSet { + models = append(models, model) + } + sort.Strings(models) + + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } + return cloneStringSlice(models) +} + +func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) { + if s == nil || s.modelsListCache == nil { + return + } + + normalizedPlatform := strings.TrimSpace(platform) + // 完整匹配时精准失效;否则按维度批量失效。 + if groupID != nil && normalizedPlatform != "" { + s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform)) + return + } + + targetGroup := derefGroupID(groupID) + for key := range s.modelsListCache.Items() { + parts := strings.SplitN(key, "|", 2) + if len(parts) != 2 { + continue + } + groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64) + if parseErr != nil { + continue + } + if groupID != nil && groupPart != targetGroup { + continue + } + if normalizedPlatform != "" && parts[1] != normalizedPlatform { + continue + } + s.modelsListCache.Delete(key) + } +} + +// reconcileCachedTokens 兼容 Kimi 等上游: +// 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens +func reconcileCachedTokens(usage map[string]any) bool { + if usage == nil { + return false + } + cacheRead, _ := usage["cache_read_input_tokens"].(float64) + if cacheRead > 0 { + return false // 已有标准字段,无需处理 + } + cached, _ := usage["cached_tokens"].(float64) + if cached <= 0 { + return false + } + usage["cache_read_input_tokens"] = cached + return true +} + +func debugGatewayBodyLoggingEnabled() bool { + raw := strings.TrimSpace(os.Getenv(debugGatewayBodyEnv)) + if raw == "" { + return false + } + + switch strings.ToLower(raw) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +// debugLogRequestBody 打印请求 body 用于调试 metadata.user_id 重写。 +// 默认关闭,仅在设置环境变量时启用: +// +// SUB2API_DEBUG_GATEWAY_BODY=1 +func debugLogRequestBody(tag string, body []byte) { + if !debugGatewayBodyLoggingEnabled() { + return + } + + if len(body) == 0 { + logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body is empty", tag) + return + } + + // 提取 metadata 字段完整打印 + metadataResult := gjson.GetBytes(body, "metadata") + if metadataResult.Exists() { + logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata = %s", tag, metadataResult.Raw) + } else { + logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata field not found", tag) + } + + // 全量打印 body + logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body (%d bytes) = %s", tag, len(body), string(body)) +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go new file mode 100644 index 00000000..6ada08bf --- /dev/null +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -0,0 +1,3307 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math" + mathrand "math/rand" + "net/http" + "regexp" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +const geminiStickySessionTTL = time.Hour + +const ( + geminiMaxRetries = 5 + geminiRetryBaseDelay = 1 * time.Second + geminiRetryMaxDelay = 16 * time.Second +) + +// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`. +// Many clients don't send it; we inject a known dummy signature to satisfy the validator. +// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures +const geminiDummyThoughtSignature = "skip_thought_signature_validator" + +type GeminiMessagesCompatService struct { + accountRepo AccountRepository + groupRepo GroupRepository + cache GatewayCache + schedulerSnapshot *SchedulerSnapshotService + tokenProvider *GeminiTokenProvider + rateLimitService *RateLimitService + httpUpstream HTTPUpstream + antigravityGatewayService *AntigravityGatewayService + cfg *config.Config + responseHeaderFilter *responseheaders.CompiledHeaderFilter +} + +func NewGeminiMessagesCompatService( + accountRepo AccountRepository, + groupRepo GroupRepository, + cache GatewayCache, + schedulerSnapshot *SchedulerSnapshotService, + tokenProvider *GeminiTokenProvider, + rateLimitService *RateLimitService, + httpUpstream HTTPUpstream, + antigravityGatewayService *AntigravityGatewayService, + cfg *config.Config, +) *GeminiMessagesCompatService { + return &GeminiMessagesCompatService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + cache: cache, + schedulerSnapshot: schedulerSnapshot, + tokenProvider: tokenProvider, + rateLimitService: rateLimitService, + httpUpstream: httpUpstream, + antigravityGatewayService: antigravityGatewayService, + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + } +} + +// GetTokenProvider returns the token provider for OAuth accounts +func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider { + return s.tokenProvider +} + +func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { + return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) +} + +func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { + // 1. 确定目标平台和调度模式 + // Determine target platform and scheduling mode + platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID) + if err != nil { + return nil, err + } + + cacheKey := "gemini:" + sessionHash + + // 2. 尝试粘性会话命中 + // Try sticky session hit + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil { + return account, nil + } + + // 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部) + // Query schedulable accounts (force platform mode: try group first, fallback to all) + accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + // 强制平台模式下,分组中找不到账户时回退查询全部 + if len(accounts) == 0 && groupID != nil && hasForcePlatform { + accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + } + + // 4. 按优先级 + LRU 选择最佳账号 + // Select best account by priority + LRU + selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling) + + if selected == nil { + if requestedModel != "" { + return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel) + } + return nil, errors.New("no available Gemini accounts") + } + + // 5. 设置粘性会话绑定 + // Set sticky session binding + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) + } + + return selected, nil +} + +// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。 +// 返回:平台名称、是否使用混合调度、是否强制平台、错误。 +// +// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode. +// Returns: platform name, whether to use mixed scheduling, whether force platform, error. +func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) { + // 优先检查 context 中的强制平台(/antigravity 路由) + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + return forcePlatform, false, true, nil + } + + if groupID != nil { + // 根据分组 platform 决定查询哪种账号 + var group *Group + if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID { + group = ctxGroup + } else { + group, err = s.groupRepo.GetByIDLite(ctx, *groupID) + if err != nil { + return "", false, false, fmt.Errorf("get group failed: %w", err) + } + } + // gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) + return group.Platform, group.Platform == PlatformGemini, false, nil + } + + // 无分组时只使用原生 gemini 平台 + return PlatformGemini, true, false, nil +} + +// tryStickySessionHit 尝试从粘性会话获取账号。 +// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。 +// +// tryStickySessionHit attempts to get account from sticky session. +// Returns account if hit and usable; clears session and returns nil if account unavailable. +func (s *GeminiMessagesCompatService) tryStickySessionHit( + ctx context.Context, + groupID *int64, + sessionHash, cacheKey, requestedModel string, + excludedIDs map[int64]struct{}, + platform string, + useMixedScheduling bool, +) *Account { + if sessionHash == "" { + return nil + } + + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + if err != nil || accountID <= 0 { + return nil + } + + if _, excluded := excludedIDs[accountID]; excluded { + return nil + } + + account, err := s.getSchedulableAccount(ctx, accountID) + if err != nil { + return nil + } + + // 检查账号是否需要清理粘性会话 + // Check if sticky session should be cleared + if shouldClearStickySession(ctx, account, requestedModel) { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + return nil + } + + // 验证账号是否可用于当前请求 + // Verify account is usable for current request + if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) { + return nil + } + + // 刷新会话 TTL 并返回账号 + // Refresh session TTL and return account + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL) + return account +} + +// isAccountUsableForRequest 检查账号是否可用于当前请求。 +// 验证:模型调度、模型支持、平台匹配、速率限制预检。 +// +// isAccountUsableForRequest checks if account is usable for current request. +// Validates: model scheduling, model support, platform matching, rate limit precheck. +func (s *GeminiMessagesCompatService) isAccountUsableForRequest( + ctx context.Context, + account *Account, + requestedModel, platform string, + useMixedScheduling bool, +) bool { + return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil) +} + +func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck( + ctx context.Context, + account *Account, + requestedModel, platform string, + useMixedScheduling bool, + precheckResult map[int64]bool, +) bool { + // 检查模型调度能力 + // Check model scheduling capability + if !account.IsSchedulableForModelWithContext(ctx, requestedModel) { + return false + } + + // 检查模型支持 + // Check model support + if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) { + return false + } + + // 检查平台匹配 + // Check platform matching + if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) { + return false + } + + // 速率限制预检 + // Rate limit precheck + if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) { + return false + } + + return true +} + +// isAccountValidForPlatform 检查账号是否匹配目标平台。 +// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。 +// +// isAccountValidForPlatform checks if account matches target platform. +// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling. +func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool { + if account.Platform == platform { + return true + } + if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() { + return true + } + return false +} + +func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool { + if s.rateLimitService == nil || requestedModel == "" { + return true + } + + if precheckResult != nil { + if ok, exists := precheckResult[account.ID]; exists { + return ok + } + } + + ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel) + if err != nil { + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) + } + return ok +} + +// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。 +// 返回 nil 表示无可用账号。 +// +// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred). +// Returns nil if no available account. +func (s *GeminiMessagesCompatService) selectBestGeminiAccount( + ctx context.Context, + accounts []Account, + requestedModel string, + excludedIDs map[int64]struct{}, + platform string, + useMixedScheduling bool, +) *Account { + var selected *Account + precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel) + + for i := range accounts { + acc := &accounts[i] + + // 跳过被排除的账号 + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + + // 检查账号是否可用于当前请求 + if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) { + continue + } + + // 选择最佳账号 + if selected == nil { + selected = acc + continue + } + + if s.isBetterGeminiAccount(acc, selected) { + selected = acc + } + } + + return selected +} + +func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool { + if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 { + return nil + } + + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + candidates = append(candidates, &accounts[i]) + } + + result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel) + if err != nil { + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err) + } + return result +} + +// isBetterGeminiAccount 判断 candidate 是否比 current 更优。 +// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。 +// +// isBetterGeminiAccount checks if candidate is better than current. +// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used. +func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool { + // 优先级更高(数值更小) + if candidate.Priority < current.Priority { + return true + } + if candidate.Priority > current.Priority { + return false + } + + // 同优先级,比较最后使用时间 + switch { + case candidate.LastUsedAt == nil && current.LastUsedAt != nil: + // candidate 从未使用,优先 + return true + case candidate.LastUsedAt != nil && current.LastUsedAt == nil: + // current 从未使用,保持 + return false + case candidate.LastUsedAt == nil && current.LastUsedAt == nil: + // 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程) + return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth + default: + // 都使用过,选择最久未使用的 + return candidate.LastUsedAt.Before(*current.LastUsedAt) + } +} + +// isModelSupportedByAccount 根据账户平台检查模型支持 +func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool { + if account.Platform == PlatformAntigravity { + if strings.TrimSpace(requestedModel) == "" { + return true + } + return mapAntigravityModel(account, requestedModel) != "" + } + return account.IsModelSupported(requestedModel) +} + +// GetAntigravityGatewayService 返回 AntigravityGatewayService +func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService { + return s.antigravityGatewayService +} + +func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { + if s.schedulerSnapshot != nil { + return s.schedulerSnapshot.GetAccount(ctx, accountID) + } + return s.accountRepo.GetByID(ctx, accountID) +} + +func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) { + if s.schedulerSnapshot != nil { + accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + return accounts, err + } + + useMixedScheduling := platform == PlatformGemini && !hasForcePlatform + queryPlatforms := []string{platform} + if useMixedScheduling { + queryPlatforms = []string{platform, PlatformAntigravity} + } + + if groupID != nil { + return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms) + } + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + } + return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms) +} + +func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) { + if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { + normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil + } + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ + AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, + RequireAllowlist: true, + AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, + }) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil +} + +// HasAntigravityAccounts 检查是否有可用的 antigravity 账户 +func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) { + accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformAntigravity, false) + if err != nil { + return false, err + } + return len(accounts) > 0, nil +} + +// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against +// generativelanguage.googleapis.com (e.g. GET /v1beta/models). +// +// Preference order: +// 1) API key accounts (AI Studio) +// 2) OAuth accounts without project_id (AI Studio OAuth) +// 3) OAuth accounts explicitly marked as ai_studio +// 4) Any remaining Gemini accounts (fallback) +func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) { + accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformGemini, true) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + if len(accounts) == 0 { + return nil, errors.New("no available Gemini accounts") + } + + rank := func(a *Account) int { + if a == nil { + return 999 + } + switch a.Type { + case AccountTypeAPIKey: + if strings.TrimSpace(a.GetCredential("api_key")) != "" { + return 0 + } + return 9 + case AccountTypeOAuth: + if strings.TrimSpace(a.GetCredential("project_id")) == "" { + return 1 + } + if strings.TrimSpace(a.GetCredential("oauth_type")) == "ai_studio" { + return 2 + } + // Code Assist OAuth tokens often lack AI Studio scopes for models listing. + return 3 + default: + return 10 + } + } + + var selected *Account + for i := range accounts { + acc := &accounts[i] + if selected == nil { + selected = acc + continue + } + + r1, r2 := rank(acc), rank(selected) + if r1 < r2 { + selected = acc + continue + } + if r1 > r2 { + continue + } + + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected == nil { + return nil, errors.New("no available Gemini accounts") + } + return selected, nil +} + +func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + + var req struct { + Model string `json:"model"` + Stream bool `json:"stream"` + } + if err := json.Unmarshal(body, &req); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + if strings.TrimSpace(req.Model) == "" { + return nil, fmt.Errorf("missing model") + } + + originalModel := req.Model + mappedModel := req.Model + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(req.Model) + } + + geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(body) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq) + originalClaudeBody := body + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + var requestIDHeader string + var buildReq func(ctx context.Context) (*http.Request, string, error) + useUpstreamStream := req.Stream + if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" { + // Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate. + useUpstreamStream = true + } + + switch account.Type { + case AccountTypeAPIKey: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + apiKey := account.GetCredential("api_key") + if strings.TrimSpace(apiKey) == "" { + return nil, "", errors.New("gemini api_key not configured") + } + + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } + + action := "generateContent" + if req.Stream { + action = "streamGenerateContent" + } + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action) + if req.Stream { + fullURL += "?alt=sse" + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("x-goog-api-key", apiKey) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + + case AccountTypeOAuth: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + + action := "generateContent" + if useUpstreamStream { + action = "streamGenerateContent" + } + + // Two modes for OAuth: + // 1. With project_id -> Code Assist API (wrapped request) + // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) + if projectID != "" { + // Mode 1: Code Assist API + baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL) + if err != nil { + return nil, "", err + } + fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + wrapped := map[string]any{ + "model": mappedModel, + "project": projectID, + } + var inner any + if err := json.Unmarshal(geminiReq, &inner); err != nil { + return nil, "", fmt.Errorf("failed to parse gemini request: %w", err) + } + wrapped["request"] = inner + wrappedBytes, _ := json.Marshal(wrapped) + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + return upstreamReq, "x-request-id", nil + } else { + // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } + + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + } + } + requestIDHeader = "x-request-id" + + default: + return nil, fmt.Errorf("unsupported account type: %s", account.Type) + } + + var resp *http.Response + signatureRetryStage := 0 + for attempt := 1; attempt <= geminiMaxRetries; attempt++ { + upstreamReq, idHeader, err := buildReq(ctx) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + // Local build error: don't retry. + if strings.Contains(err.Error(), "missing project_id") { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + } + requestIDHeader = idHeader + + // Capture upstream request body for ops retry of this attempt. + if c != nil { + // In this code path `body` is already the JSON sent to upstream. + c.Set(OpsUpstreamRequestBodyKey, string(body)) + } + + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + if attempt < geminiMaxRetries { + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + sleepGeminiBackoff(attempt) + continue + } + setOpsUpstreamError(c, 0, safeErr, "") + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+safeErr) + } + + // Special-case: signature/thought_signature validation errors are not transient, but may be fixed by + // downgrading Claude thinking/tool history to plain text (conservative two-stage retry). + if resp.StatusCode == http.StatusBadRequest && signatureRetryStage < 2 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if isGeminiSignatureRelatedError(respBody) { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "signature_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + var strippedClaudeBody []byte + stageName := "" + switch signatureRetryStage { + case 0: + // Stage 1: disable thinking + thinking->text + strippedClaudeBody = FilterThinkingBlocksForRetry(originalClaudeBody) + stageName = "thinking-only" + signatureRetryStage = 1 + default: + // Stage 2: additionally downgrade tool_use/tool_result blocks to text + strippedClaudeBody = FilterSignatureSensitiveBlocksForRetry(originalClaudeBody) + stageName = "thinking+tools" + signatureRetryStage = 2 + } + retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody) + if txErr == nil { + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName) + geminiReq = retryGeminiReq + // Consume one retry budget attempt and continue with the updated request payload. + sleepGeminiBackoff(1) + continue + } + } + + // Restore body for downstream error handling. + resp = &http.Response{ + StatusCode: http.StatusBadRequest, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + + // 错误策略优先:匹配则跳过重试直接处理。 + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + // Don't treat insufficient-scope as transient. + if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) { + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + if resp.StatusCode == 429 { + // Mark as rate-limited early so concurrent requests avoid this account. + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + if attempt < geminiMaxRetries { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "retry", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + sleepGeminiBackoff(attempt) + continue + } + // Final attempt: surface the upstream error body (mapped below) instead of a generic retry error. + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + + break + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + // 统一错误策略:自定义错误码 + 临时不可调度 + if s.rateLimitService != nil { + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + // 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁 + if resp.StatusCode == http.StatusBadRequest { + msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if isGoogleProjectConfigError(msg400) { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true} + } + } + if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody) + } + + requestID := resp.Header.Get(requestIDHeader) + if requestID == "" { + requestID = resp.Header.Get("x-goog-request-id") + } + if requestID != "" { + c.Header("x-request-id", requestID) + } + + var usage *ClaudeUsage + var firstTokenMs *int + if req.Stream { + streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel) + if err != nil { + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + if useUpstreamStream { + collected, usageObj, err := collectGeminiSSE(resp.Body, true) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream") + } + collectedBytes, _ := json.Marshal(collected) + claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel, collectedBytes) + c.JSON(http.StatusOK, claudeResp) + usage = usageObj2 + if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) { + usage = usageObj + } + } else { + usage, err = s.handleNonStreamingResponse(c, resp, originalModel) + if err != nil { + return nil, err + } + } + } + + // 图片生成计费 + imageCount := 0 + imageSize := s.extractImageSize(body) + if isImageGenerationModel(originalModel) { + imageCount = 1 + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: req.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, + }, nil +} + +func isGeminiSignatureRelatedError(respBody []byte) bool { + msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) + if msg == "" { + msg = strings.ToLower(string(respBody)) + } + return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") +} + +func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { + startTime := time.Now() + + if strings.TrimSpace(originalModel) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") + } + if strings.TrimSpace(action) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL") + } + if len(body) == 0 { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") + } + + // 过滤掉 parts 为空的消息(Gemini API 不接受空 parts) + if filteredBody, err := filterEmptyPartsFromGeminiRequest(body); err == nil { + body = filteredBody + } + + switch action { + case "generateContent", "streamGenerateContent", "countTokens": + // ok + default: + return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) + } + + // Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a + // `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s. + body = ensureGeminiFunctionCallThoughtSignatures(body) + + mappedModel := originalModel + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(originalModel) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + useUpstreamStream := stream + upstreamAction := action + if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" { + // Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate. + useUpstreamStream = true + upstreamAction = "streamGenerateContent" + } + forceAIStudio := action == "countTokens" + + var requestIDHeader string + var buildReq func(ctx context.Context) (*http.Request, string, error) + + switch account.Type { + case AccountTypeAPIKey: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + apiKey := account.GetCredential("api_key") + if strings.TrimSpace(apiKey) == "" { + return nil, "", errors.New("gemini api_key not configured") + } + + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } + + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("x-goog-api-key", apiKey) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + + case AccountTypeOAuth: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + + // Two modes for OAuth: + // 1. With project_id -> Code Assist API (wrapped request) + // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) + if projectID != "" && !forceAIStudio { + // Mode 1: Code Assist API + baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL) + if err != nil { + return nil, "", err + } + fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + wrapped := map[string]any{ + "model": mappedModel, + "project": projectID, + } + var inner any + if err := json.Unmarshal(body, &inner); err != nil { + return nil, "", fmt.Errorf("failed to parse gemini request: %w", err) + } + wrapped["request"] = inner + wrappedBytes, _ := json.Marshal(wrapped) + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + return upstreamReq, "x-request-id", nil + } else { + // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } + + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + } + } + requestIDHeader = "x-request-id" + + default: + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type) + } + + var resp *http.Response + for attempt := 1; attempt <= geminiMaxRetries; attempt++ { + upstreamReq, idHeader, err := buildReq(ctx) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + // Local build error: don't retry. + if strings.Contains(err.Error(), "missing project_id") { + return nil, s.writeGoogleError(c, http.StatusBadRequest, err.Error()) + } + return nil, s.writeGoogleError(c, http.StatusBadGateway, err.Error()) + } + requestIDHeader = idHeader + + // Capture upstream request body for ops retry of this attempt. + if c != nil { + // In this code path `body` is already the JSON sent to upstream. + c.Set(OpsUpstreamRequestBodyKey, string(body)) + } + + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + if attempt < geminiMaxRetries { + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + sleepGeminiBackoff(attempt) + continue + } + if action == "countTokens" { + estimated := estimateGeminiCountTokens(body) + c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) + return &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + setOpsUpstreamError(c, 0, safeErr, "") + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr) + } + + // 错误策略优先:匹配则跳过重试直接处理。 + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + // Don't treat insufficient-scope as transient. + if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) { + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + if resp.StatusCode == 429 { + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + if attempt < geminiMaxRetries { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "retry", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + sleepGeminiBackoff(attempt) + continue + } + if action == "countTokens" { + estimated := estimateGeminiCountTokens(body) + c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) + return &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + // Final attempt: surface the upstream error body (passed through below) instead of a generic retry error. + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + + break + } + defer func() { _ = resp.Body.Close() }() + + requestID := resp.Header.Get(requestIDHeader) + if requestID == "" { + requestID = resp.Header.Get("x-goog-request-id") + } + if requestID != "" { + c.Header("x-request-id", requestID) + } + + isOAuth := account.Type == AccountTypeOAuth + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + // Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens. + // This avoids Gemini SDKs failing hard during preflight token counting. + // Checked before error policy so it always works regardless of custom error codes. + if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) { + estimated := estimateGeminiCountTokens(body) + c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) + return &ForwardResult{ + RequestID: requestID, + Usage: ClaudeUsage{}, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + + // 统一错误策略:自定义错误码 + 临时不可调度 + if s.rateLimitService != nil { + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + respBody = unwrapIfNeeded(isOAuth, respBody) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(http.StatusInternalServerError, contentType, respBody) + return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + // 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁 + if resp.StatusCode == http.StatusBadRequest { + msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if isGoogleProjectConfigError(msg400) { + evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody))) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody, RetryableOnSameAccount: true} + } + } + if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { + evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody} + } + + respBody = unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, respBody) + if upstreamMsg == "" { + return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("gemini upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + var usage *ClaudeUsage + var firstTokenMs *int + + if stream { + streamRes, err := s.handleNativeStreamingResponse(c, resp, startTime, isOAuth) + if err != nil { + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + if useUpstreamStream { + collected, usageObj, err := collectGeminiSSE(resp.Body, isOAuth) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to read upstream stream") + } + b, _ := json.Marshal(collected) + c.Data(http.StatusOK, "application/json", b) + usage = usageObj + } else { + usageResp, err := s.handleNativeNonStreamingResponse(c, resp, isOAuth) + if err != nil { + return nil, err + } + usage = usageResp + } + } + + if usage == nil { + usage = &ClaudeUsage{} + } + + // 图片生成计费 + imageCount := 0 + imageSize := s.extractImageSize(body) + if isImageGenerationModel(originalModel) { + imageCount = 1 + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, + }, nil +} + +// checkErrorPolicyInLoop 在重试循环内预检查错误策略。 +// 返回 true 表示策略已匹配(调用者应 break),resp 已重建可直接使用。 +// 返回 false 表示 ErrorPolicyNone,resp 已重建,调用者继续走重试逻辑。 +func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop( + ctx context.Context, account *Account, resp *http.Response, +) (matched bool, rebuilt *http.Response) { + if resp.StatusCode < 400 || s.rateLimitService == nil { + return false, resp + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + rebuilt = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(body)), + } + policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body) + return policy != ErrorPolicyNone, rebuilt +} + +func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool { + switch statusCode { + case 429, 500, 502, 503, 504, 529: + return true + case 403: + // GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry. + if account == nil || account.Type != AccountTypeOAuth { + return false + } + oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type"))) + if oauthType == "" && strings.TrimSpace(account.GetCredential("project_id")) != "" { + // Legacy/implicit Code Assist OAuth accounts. + oauthType = "code_assist" + } + return oauthType == "code_assist" + default: + return false + } +} + +func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +func sleepGeminiBackoff(attempt int) { + delay := geminiRetryBaseDelay * time.Duration(1< geminiRetryMaxDelay { + delay = geminiRetryMaxDelay + } + + // +/- 20% jitter + r := mathrand.New(mathrand.NewSource(time.Now().UnixNano())) + jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1)) + sleepFor := delay + jitter + if sleepFor < 0 { + sleepFor = 0 + } + time.Sleep(sleepFor) +} + +var ( + sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`) + retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`) +) + +func sanitizeUpstreamErrorMessage(msg string) string { + if msg == "" { + return msg + } + return sensitiveQueryParamRegex.ReplaceAllString(msg, `$1***`) +} + +func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error { + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: upstreamStatus, + UpstreamRequestID: upstreamRequestID, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) + } + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformGemini, + upstreamStatus, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": errMsg}, + }) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d (passthrough rule matched)", upstreamStatus) + } + return fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", upstreamStatus, upstreamMsg) + } + + var statusCode int + var errType, errMsg string + + if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil { + errType = mapped.Type + if mapped.Message != "" { + errMsg = mapped.Message + } + if mapped.StatusCode > 0 { + statusCode = mapped.StatusCode + } + } + + switch upstreamStatus { + case 400: + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + if errType == "" { + errType = "invalid_request_error" + } + if errMsg == "" { + errMsg = "Invalid request" + } + case 401: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + errType = "authentication_error" + } + if errMsg == "" { + errMsg = "Upstream authentication failed, please contact administrator" + } + case 403: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + errType = "permission_error" + } + if errMsg == "" { + errMsg = "Upstream access forbidden, please contact administrator" + } + case 404: + if statusCode == 0 { + statusCode = http.StatusNotFound + } + if errType == "" { + errType = "not_found_error" + } + if errMsg == "" { + errMsg = "Resource not found" + } + case 429: + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + if errType == "" { + errType = "rate_limit_error" + } + if errMsg == "" { + errMsg = "Upstream rate limit exceeded, please retry later" + } + case 529: + if statusCode == 0 { + statusCode = http.StatusServiceUnavailable + } + if errType == "" { + errType = "overloaded_error" + } + if errMsg == "" { + errMsg = "Upstream service overloaded, please retry later" + } + case 500, 502, 503, 504: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + switch upstreamStatus { + case 504: + errType = "timeout_error" + case 503: + errType = "overloaded_error" + default: + errType = "api_error" + } + } + if errMsg == "" { + errMsg = "Upstream service temporarily unavailable" + } + default: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + errType = "upstream_error" + } + if errMsg == "" { + errMsg = "Upstream request failed" + } + } + + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": errMsg}, + }) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", upstreamStatus) + } + return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg) +} + +type claudeErrorMapping struct { + Type string + Message string + StatusCode int +} + +func mapGeminiErrorBodyToClaudeError(body []byte) *claudeErrorMapping { + if len(body) == 0 { + return nil + } + + var parsed struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + } `json:"error"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return nil + } + if strings.TrimSpace(parsed.Error.Status) == "" && parsed.Error.Code == 0 && strings.TrimSpace(parsed.Error.Message) == "" { + return nil + } + + mapped := &claudeErrorMapping{ + Type: mapGeminiStatusToClaudeErrorType(parsed.Error.Status), + Message: "", + } + if mapped.Type == "" { + mapped.Type = "upstream_error" + } + + switch strings.ToUpper(strings.TrimSpace(parsed.Error.Status)) { + case "INVALID_ARGUMENT": + mapped.StatusCode = http.StatusBadRequest + case "NOT_FOUND": + mapped.StatusCode = http.StatusNotFound + case "RESOURCE_EXHAUSTED": + mapped.StatusCode = http.StatusTooManyRequests + default: + // Keep StatusCode unset and let HTTP status mapping decide. + } + + // Keep messages generic by default; upstream error message can be long or include sensitive fragments. + return mapped +} + +func mapGeminiStatusToClaudeErrorType(status string) string { + switch strings.ToUpper(strings.TrimSpace(status)) { + case "INVALID_ARGUMENT": + return "invalid_request_error" + case "PERMISSION_DENIED": + return "permission_error" + case "NOT_FOUND": + return "not_found_error" + case "RESOURCE_EXHAUSTED": + return "rate_limit_error" + case "UNAUTHENTICATED": + return "authentication_error" + case "UNAVAILABLE": + return "overloaded_error" + case "INTERNAL": + return "api_error" + case "DEADLINE_EXCEEDED": + return "timeout_error" + default: + return "" + } +} + +type geminiStreamResult struct { + usage *ClaudeUsage + firstTokenMs *int +} + +func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) { + body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") + } + + unwrappedBody, err := unwrapGeminiResponse(body) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + var geminiResp map[string]any + if err := json.Unmarshal(unwrappedBody, &geminiResp); err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, unwrappedBody) + c.JSON(http.StatusOK, claudeResp) + + return usage, nil +} + +func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*geminiStreamResult, error) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + messageID := "msg_" + randomHex(12) + messageStart := map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": originalModel, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": 0, + "output_tokens": 0, + }, + }, + } + writeSSE(c.Writer, "message_start", messageStart) + flusher.Flush() + + var firstTokenMs *int + var usage ClaudeUsage + finishReason := "" + sawToolUse := false + + nextBlockIndex := 0 + openBlockIndex := -1 + openBlockType := "" + seenText := "" + openToolIndex := -1 + openToolID := "" + openToolName := "" + seenToolJSON := "" + + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("stream read error: %w", err) + } + + if !strings.HasPrefix(line, "data:") { + if errors.Is(err, io.EOF) { + break + } + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" || payload == "[DONE]" { + if errors.Is(err, io.EOF) { + break + } + continue + } + + unwrappedBytes, err := unwrapGeminiResponse([]byte(payload)) + if err != nil { + continue + } + + var geminiResp map[string]any + if err := json.Unmarshal(unwrappedBytes, &geminiResp); err != nil { + continue + } + + if fr := extractGeminiFinishReason(geminiResp); fr != "" { + finishReason = fr + } + + parts := extractGeminiParts(geminiResp) + for _, part := range parts { + if text, ok := part["text"].(string); ok && text != "" { + delta, newSeen := computeGeminiTextDelta(seenText, text) + seenText = newSeen + if delta == "" { + continue + } + + if openBlockType != "text" { + if openBlockIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openBlockIndex, + }) + } + openBlockType = "text" + openBlockIndex = nextBlockIndex + nextBlockIndex++ + writeSSE(c.Writer, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": openBlockIndex, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + writeSSE(c.Writer, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": openBlockIndex, + "delta": map[string]any{ + "type": "text_delta", + "text": delta, + }, + }) + flusher.Flush() + continue + } + + if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil { + name, _ := fc["name"].(string) + args := fc["args"] + if strings.TrimSpace(name) == "" { + name = "tool" + } + + // Close any open text block before tool_use. + if openBlockIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openBlockIndex, + }) + openBlockIndex = -1 + openBlockType = "" + } + + // If we receive streamed tool args in pieces, keep a single tool block open and emit deltas. + if openToolIndex >= 0 && openToolName != name { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openToolIndex, + }) + openToolIndex = -1 + openToolName = "" + seenToolJSON = "" + } + + if openToolIndex < 0 { + openToolID = "toolu_" + randomHex(8) + openToolIndex = nextBlockIndex + openToolName = name + nextBlockIndex++ + sawToolUse = true + + writeSSE(c.Writer, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": openToolIndex, + "content_block": map[string]any{ + "type": "tool_use", + "id": openToolID, + "name": name, + "input": map[string]any{}, + }, + }) + } + + argsJSONText := "{}" + switch v := args.(type) { + case nil: + // keep default "{}" + case string: + if strings.TrimSpace(v) != "" { + argsJSONText = v + } + default: + if b, err := json.Marshal(args); err == nil && len(b) > 0 { + argsJSONText = string(b) + } + } + + delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText) + seenToolJSON = newSeen + if delta != "" { + writeSSE(c.Writer, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": openToolIndex, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": delta, + }, + }) + } + flusher.Flush() + } + } + + if u := extractGeminiUsage(unwrappedBytes); u != nil { + usage = *u + } + + // Process the final unterminated line at EOF as well. + if errors.Is(err, io.EOF) { + break + } + } + + if openBlockIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openBlockIndex, + }) + } + if openToolIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openToolIndex, + }) + } + + stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason) + if sawToolUse { + stopReason = "tool_use" + } + + usageObj := map[string]any{ + "output_tokens": usage.OutputTokens, + } + if usage.InputTokens > 0 { + usageObj["input_tokens"] = usage.InputTokens + } + writeSSE(c.Writer, "message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": usageObj, + }) + writeSSE(c.Writer, "message_stop", map[string]any{ + "type": "message_stop", + }) + flusher.Flush() + + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil +} + +func writeSSE(w io.Writer, event string, data any) { + if event != "" { + _, _ = fmt.Fprintf(w, "event: %s\n", event) + } + b, _ := json.Marshal(data) + _, _ = fmt.Fprintf(w, "data: %s\n\n", string(b)) +} + +func randomHex(nBytes int) string { + b := make([]byte, nBytes) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + +func (s *GeminiMessagesCompatService) writeClaudeError(c *gin.Context, status int, errType, message string) error { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": message}, + }) + return fmt.Errorf("%s", message) +} + +func (s *GeminiMessagesCompatService) writeGoogleError(c *gin.Context, status int, message string) error { + c.JSON(status, gin.H{ + "error": gin.H{ + "code": status, + "message": message, + "status": googleapi.HTTPStatusToGoogleStatus(status), + }, + }) + return fmt.Errorf("%s", message) +} + +func unwrapIfNeeded(isOAuth bool, raw []byte) []byte { + if !isOAuth { + return raw + } + inner, err := unwrapGeminiResponse(raw) + if err != nil { + return raw + } + return inner +} + +func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) { + reader := bufio.NewReader(body) + + var last map[string]any + var lastWithParts map[string]any + var collectedTextParts []string // Collect all text parts for aggregation + usage := &ClaudeUsage{} + + for { + line, err := reader.ReadString('\n') + if len(line) > 0 { + trimmed := strings.TrimRight(line, "\r\n") + if strings.HasPrefix(trimmed, "data:") { + payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + switch payload { + case "", "[DONE]": + if payload == "[DONE]" { + return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil + } + default: + var parsed map[string]any + var rawBytes []byte + if isOAuth { + innerBytes, err := unwrapGeminiResponse([]byte(payload)) + if err == nil { + rawBytes = innerBytes + _ = json.Unmarshal(innerBytes, &parsed) + } + } else { + rawBytes = []byte(payload) + _ = json.Unmarshal(rawBytes, &parsed) + } + if parsed != nil { + last = parsed + if u := extractGeminiUsage(rawBytes); u != nil { + usage = u + } + if parts := extractGeminiParts(parsed); len(parts) > 0 { + lastWithParts = parsed + // Collect text from each part for aggregation + for _, part := range parts { + if text, ok := part["text"].(string); ok && text != "" { + collectedTextParts = append(collectedTextParts, text) + } + } + } + } + } + } + } + + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, nil, err + } + } + + return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil +} + +func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any { + if lastWithParts != nil { + return lastWithParts + } + if last != nil { + return last + } + return map[string]any{} +} + +// mergeCollectedTextParts merges all collected text chunks into the final response. +// This fixes the issue where non-streaming responses only returned the last chunk +// instead of the complete aggregated text. +func mergeCollectedTextParts(response map[string]any, textParts []string) map[string]any { + if len(textParts) == 0 { + return response + } + + // Join all text parts + mergedText := strings.Join(textParts, "") + + // Deep copy response + result := make(map[string]any) + for k, v := range response { + result[k] = v + } + + // Get or create candidates + candidates, ok := result["candidates"].([]any) + if !ok || len(candidates) == 0 { + candidates = []any{map[string]any{}} + } + + // Get first candidate + candidate, ok := candidates[0].(map[string]any) + if !ok { + candidate = make(map[string]any) + candidates[0] = candidate + } + + // Get or create content + content, ok := candidate["content"].(map[string]any) + if !ok { + content = map[string]any{"role": "model"} + candidate["content"] = content + } + + // Get existing parts + existingParts, ok := content["parts"].([]any) + if !ok { + existingParts = []any{} + } + + // Find and update first text part, or create new one + newParts := make([]any, 0, len(existingParts)+1) + textUpdated := false + + for _, p := range existingParts { + pm, ok := p.(map[string]any) + if !ok { + newParts = append(newParts, p) + continue + } + if _, hasText := pm["text"]; hasText && !textUpdated { + // Replace with merged text + newPart := make(map[string]any) + for k, v := range pm { + newPart[k] = v + } + newPart["text"] = mergedText + newParts = append(newParts, newPart) + textUpdated = true + } else { + newParts = append(newParts, pm) + } + } + + if !textUpdated { + newParts = append([]any{map[string]any{"text": mergedText}}, newParts...) + } + + content["parts"] = newParts + result["candidates"] = candidates + + return result +} + +type geminiNativeStreamResult struct { + usage *ClaudeUsage + firstTokenMs *int +} + +func isGeminiInsufficientScope(headers http.Header, body []byte) bool { + if strings.Contains(strings.ToLower(headers.Get("Www-Authenticate")), "insufficient_scope") { + return true + } + lower := strings.ToLower(string(body)) + return strings.Contains(lower, "insufficient authentication scopes") || strings.Contains(lower, "access_token_scope_insufficient") +} + +func estimateGeminiCountTokens(reqBody []byte) int { + total := 0 + + // systemInstruction.parts[].text + gjson.GetBytes(reqBody, "systemInstruction.parts").ForEach(func(_, part gjson.Result) bool { + if t := strings.TrimSpace(part.Get("text").String()); t != "" { + total += estimateTokensForText(t) + } + return true + }) + + // contents[].parts[].text + gjson.GetBytes(reqBody, "contents").ForEach(func(_, content gjson.Result) bool { + content.Get("parts").ForEach(func(_, part gjson.Result) bool { + if t := strings.TrimSpace(part.Get("text").String()); t != "" { + total += estimateTokensForText(t) + } + return true + }) + return true + }) + + if total < 0 { + return 0 + } + return total +} + +func estimateTokensForText(s string) int { + s = strings.TrimSpace(s) + if s == "" { + return 0 + } + runes := []rune(s) + if len(runes) == 0 { + return 0 + } + ascii := 0 + for _, r := range runes { + if r <= 0x7f { + ascii++ + } + } + asciiRatio := float64(ascii) / float64(len(runes)) + if asciiRatio >= 0.8 { + // Roughly 4 chars per token for English-like text. + return (len(runes) + 3) / 4 + } + // For CJK-heavy text, approximate 1 rune per token. + return len(runes) +} + +type UpstreamHTTPResult struct { + StatusCode int + Headers http.Header + Body []byte +} + +func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) { + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } + } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") + } + + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + if isOAuth { + unwrappedBody, uwErr := unwrapGeminiResponse(respBody) + if uwErr == nil { + respBody = unwrappedBody + } + } + + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, respBody) + + if u := extractGeminiUsage(respBody); u != nil { + return u, nil + } + return &ClaudeUsage{}, nil +} + +func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) { + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } + } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") + } + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + + c.Status(resp.StatusCode) + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "text/event-stream; charset=utf-8" + } + c.Header("Content-Type", contentType) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + reader := bufio.NewReader(resp.Body) + usage := &ClaudeUsage{} + var firstTokenMs *int + + for { + line, err := reader.ReadString('\n') + if len(line) > 0 { + trimmed := strings.TrimRight(line, "\r\n") + if strings.HasPrefix(trimmed, "data:") { + payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + // Keepalive / done markers + if payload == "" || payload == "[DONE]" { + _, _ = io.WriteString(c.Writer, line) + flusher.Flush() + } else { + var rawToWrite string + rawToWrite = payload + + var rawBytes []byte + if isOAuth { + innerBytes, err := unwrapGeminiResponse([]byte(payload)) + if err == nil { + rawToWrite = string(innerBytes) + rawBytes = innerBytes + } + } else { + rawBytes = []byte(payload) + } + + if u := extractGeminiUsage(rawBytes); u != nil { + usage = u + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + if isOAuth { + // SSE format requires double newline (\n\n) to separate events + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", rawToWrite) + } else { + // Pass-through for AI Studio responses. + _, _ = io.WriteString(c.Writer, line) + } + flusher.Flush() + } + } else { + _, _ = io.WriteString(c.Writer, line) + flusher.Flush() + } + } + + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, err + } + } + + return &geminiNativeStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil +} + +// ForwardAIStudioGET forwards a GET request to AI Studio (generativelanguage.googleapis.com) for +// endpoints like /v1beta/models and /v1beta/models/{model}. +// +// This is used to support Gemini SDKs that call models listing endpoints before generation. +func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) { + if account == nil { + return nil, errors.New("account is nil") + } + path = strings.TrimSpace(path) + if path == "" || !strings.HasPrefix(path, "/") { + return nil, errors.New("invalid path") + } + + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + fullURL := strings.TrimRight(normalizedBaseURL, "/") + path + + var proxyURL string + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + return nil, err + } + + switch account.Type { + case AccountTypeAPIKey: + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, errors.New("gemini api_key not configured") + } + req.Header.Set("x-goog-api-key", apiKey) + case AccountTypeOAuth: + if s.tokenProvider == nil { + return nil, errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + default: + return nil, fmt.Errorf("unsupported account type: %s", account.Type) + } + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) + wwwAuthenticate := resp.Header.Get("Www-Authenticate") + filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter) + if wwwAuthenticate != "" { + filteredHeaders.Set("Www-Authenticate", wwwAuthenticate) + } + return &UpstreamHTTPResult{ + StatusCode: resp.StatusCode, + Headers: filteredHeaders, + Body: body, + }, nil +} + +// unwrapGeminiResponse 解包 Gemini OAuth 响应中的 response 字段 +// 使用 gjson 零拷贝提取,避免完整 Unmarshal+Marshal +func unwrapGeminiResponse(raw []byte) ([]byte, error) { + result := gjson.GetBytes(raw, "response") + if result.Exists() && result.Type == gjson.JSON { + return []byte(result.Raw), nil + } + return raw, nil +} + +func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string, rawData []byte) (map[string]any, *ClaudeUsage) { + usage := extractGeminiUsage(rawData) + if usage == nil { + usage = &ClaudeUsage{} + } + + contentBlocks := make([]any, 0) + sawToolUse := false + if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { + if cand, ok := candidates[0].(map[string]any); ok { + if content, ok := cand["content"].(map[string]any); ok { + if parts, ok := content["parts"].([]any); ok { + for _, part := range parts { + pm, ok := part.(map[string]any) + if !ok { + continue + } + if text, ok := pm["text"].(string); ok && text != "" { + contentBlocks = append(contentBlocks, map[string]any{ + "type": "text", + "text": text, + }) + } + if fc, ok := pm["functionCall"].(map[string]any); ok { + name, _ := fc["name"].(string) + if strings.TrimSpace(name) == "" { + name = "tool" + } + args := fc["args"] + sawToolUse = true + contentBlocks = append(contentBlocks, map[string]any{ + "type": "tool_use", + "id": "toolu_" + randomHex(8), + "name": name, + "input": args, + }) + } + } + } + } + } + } + + stopReason := mapGeminiFinishReasonToClaudeStopReason(extractGeminiFinishReason(geminiResp)) + if sawToolUse { + stopReason = "tool_use" + } + + resp := map[string]any{ + "id": "msg_" + randomHex(12), + "type": "message", + "role": "assistant", + "model": originalModel, + "content": contentBlocks, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": usage.InputTokens, + "output_tokens": usage.OutputTokens, + }, + } + + return resp, usage +} + +func extractGeminiUsage(data []byte) *ClaudeUsage { + usage := gjson.GetBytes(data, "usageMetadata") + if !usage.Exists() { + return nil + } + prompt := int(usage.Get("promptTokenCount").Int()) + cand := int(usage.Get("candidatesTokenCount").Int()) + cached := int(usage.Get("cachedContentTokenCount").Int()) + thoughts := int(usage.Get("thoughtsTokenCount").Int()) + // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, + // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 + return &ClaudeUsage{ + InputTokens: prompt - cached, + OutputTokens: cand + thoughts, + CacheReadInputTokens: cached, + } +} + +func asInt(v any) (int, bool) { + switch t := v.(type) { + case float64: + return int(t), true + case int: + return t, true + case int64: + return int(t), true + case json.Number: + i, err := t.Int64() + if err != nil { + return 0, false + } + return int(i), true + default: + return 0, false + } +} + +func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) { + // 遵守自定义错误码策略:未命中则跳过所有限流处理 + if !account.ShouldHandleErrorCode(statusCode) { + return + } + if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { + s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) + return + } + if statusCode != 429 { + return + } + + oauthType := account.GeminiOAuthType() + tierID := account.GeminiTierID() + projectID := strings.TrimSpace(account.GetCredential("project_id")) + isCodeAssist := account.IsGeminiCodeAssist() + + resetAt := ParseGeminiRateLimitResetTime(body) + if resetAt == nil { + // 根据账号类型使用不同的默认重置时间 + var ra time.Time + if isCodeAssist { + // Code Assist: fallback cooldown by tier + cooldown := geminiCooldownForTier(tierID) + if s.rateLimitService != nil { + cooldown = s.rateLimitService.GeminiCooldown(ctx, account) + } + ra = time.Now().Add(cooldown) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) + } else { + // API Key / AI Studio OAuth: PST 午夜 + if ts := nextGeminiDailyResetUnix(); ts != nil { + ra = time.Unix(*ts, 0) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra) + } else { + // 兜底:5 分钟 + ra = time.Now().Add(5 * time.Minute) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited, fallback to 5min", account.ID) + } + } + _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) + return + } + + // 使用解析到的重置时间 + resetTime := time.Unix(*resetAt, 0) + _ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)", + account.ID, resetTime, oauthType, tierID) +} + +// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳 +func ParseGeminiRateLimitResetTime(body []byte) *int64 { + // 第一阶段:gjson 结构化提取 + errMsg := gjson.GetBytes(body, "error.message").String() + if looksLikeGeminiDailyQuota(errMsg) { + if ts := nextGeminiDailyResetUnix(); ts != nil { + return ts + } + } + + // 遍历 error.details 查找 quotaResetDelay + var found *int64 + gjson.GetBytes(body, "error.details").ForEach(func(_, detail gjson.Result) bool { + v := detail.Get("metadata.quotaResetDelay").String() + if v == "" { + return true + } + if dur, err := time.ParseDuration(v); err == nil { + // Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s), + // which can affect scheduling decisions around thresholds (like 10s). + ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) + found = &ts + return false + } + return true + }) + if found != nil { + return found + } + + // 第二阶段:regex 回退匹配 "Please retry in Xs" + matches := retryInRegex.FindStringSubmatch(string(body)) + if len(matches) == 2 { + if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { + ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) + return &ts + } + } + + return nil +} + +func looksLikeGeminiDailyQuota(message string) bool { + m := strings.ToLower(message) + if strings.Contains(m, "per day") || strings.Contains(m, "requests per day") || strings.Contains(m, "quota") && strings.Contains(m, "per day") { + return true + } + return false +} + +func nextGeminiDailyResetUnix() *int64 { + reset := geminiDailyResetTime(time.Now()) + ts := reset.Unix() + return &ts +} + +func ensureGeminiFunctionCallThoughtSignatures(body []byte) []byte { + // Fast path: only run when functionCall is present. + if !bytes.Contains(body, []byte(`"functionCall"`)) { + return body + } + + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return body + } + + contentsAny, ok := payload["contents"].([]any) + if !ok || len(contentsAny) == 0 { + return body + } + + modified := false + for _, c := range contentsAny { + cm, ok := c.(map[string]any) + if !ok { + continue + } + partsAny, ok := cm["parts"].([]any) + if !ok || len(partsAny) == 0 { + continue + } + for _, p := range partsAny { + pm, ok := p.(map[string]any) + if !ok || pm == nil { + continue + } + if fc, ok := pm["functionCall"].(map[string]any); !ok || fc == nil { + continue + } + ts, _ := pm["thoughtSignature"].(string) + if strings.TrimSpace(ts) == "" { + pm["thoughtSignature"] = geminiDummyThoughtSignature + modified = true + } + } + } + + if !modified { + return body + } + b, err := json.Marshal(payload) + if err != nil { + return body + } + return b +} + +func extractGeminiFinishReason(geminiResp map[string]any) string { + if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { + if cand, ok := candidates[0].(map[string]any); ok { + if fr, ok := cand["finishReason"].(string); ok { + return fr + } + } + } + return "" +} + +func extractGeminiParts(geminiResp map[string]any) []map[string]any { + if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { + if cand, ok := candidates[0].(map[string]any); ok { + if content, ok := cand["content"].(map[string]any); ok { + if partsAny, ok := content["parts"].([]any); ok && len(partsAny) > 0 { + out := make([]map[string]any, 0, len(partsAny)) + for _, p := range partsAny { + pm, ok := p.(map[string]any) + if !ok { + continue + } + out = append(out, pm) + } + return out + } + } + } + } + return nil +} + +func computeGeminiTextDelta(seen, incoming string) (delta, newSeen string) { + incoming = strings.TrimSuffix(incoming, "\u0000") + if incoming == "" { + return "", seen + } + + // Cumulative mode: incoming contains full text so far. + if strings.HasPrefix(incoming, seen) { + return strings.TrimPrefix(incoming, seen), incoming + } + // Duplicate/rewind: ignore. + if strings.HasPrefix(seen, incoming) { + return "", seen + } + // Delta mode: treat incoming as incremental chunk. + return incoming, seen + incoming +} + +func mapGeminiFinishReasonToClaudeStopReason(finishReason string) string { + switch strings.ToUpper(strings.TrimSpace(finishReason)) { + case "MAX_TOKENS": + return "max_tokens" + case "STOP": + return "end_turn" + default: + return "end_turn" + } +} + +func convertClaudeMessagesToGeminiGenerateContent(body []byte) ([]byte, error) { + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + toolUseIDToName := make(map[string]string) + + systemText := extractClaudeSystemText(req["system"]) + contents, err := convertClaudeMessagesToGeminiContents(req["messages"], toolUseIDToName) + if err != nil { + return nil, err + } + + out := make(map[string]any) + if systemText != "" { + out["systemInstruction"] = map[string]any{ + "parts": []any{map[string]any{"text": systemText}}, + } + } + out["contents"] = contents + + if tools := convertClaudeToolsToGeminiTools(req["tools"]); tools != nil { + out["tools"] = tools + } + + generationConfig := convertClaudeGenerationConfig(req) + if generationConfig != nil { + out["generationConfig"] = generationConfig + } + + stripGeminiFunctionIDs(out) + return json.Marshal(out) +} + +func stripGeminiFunctionIDs(req map[string]any) { + // Defensive cleanup: some upstreams reject unexpected `id` fields in functionCall/functionResponse. + contents, ok := req["contents"].([]any) + if !ok { + return + } + for _, c := range contents { + cm, ok := c.(map[string]any) + if !ok { + continue + } + contentParts, ok := cm["parts"].([]any) + if !ok { + continue + } + for _, p := range contentParts { + pm, ok := p.(map[string]any) + if !ok { + continue + } + if fc, ok := pm["functionCall"].(map[string]any); ok && fc != nil { + delete(fc, "id") + } + if fr, ok := pm["functionResponse"].(map[string]any); ok && fr != nil { + delete(fr, "id") + } + } + } +} + +func extractClaudeSystemText(system any) string { + switch v := system.(type) { + case string: + return strings.TrimSpace(v) + case []any: + var parts []string + for _, p := range v { + pm, ok := p.(map[string]any) + if !ok { + continue + } + if t, _ := pm["type"].(string); t != "text" { + continue + } + if text, ok := pm["text"].(string); ok && strings.TrimSpace(text) != "" { + parts = append(parts, text) + } + } + return strings.TrimSpace(strings.Join(parts, "\n")) + default: + return "" + } +} + +func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[string]string) ([]any, error) { + arr, ok := messages.([]any) + if !ok { + return nil, errors.New("messages must be an array") + } + + out := make([]any, 0, len(arr)) + for _, m := range arr { + mm, ok := m.(map[string]any) + if !ok { + continue + } + role, _ := mm["role"].(string) + role = strings.ToLower(strings.TrimSpace(role)) + gRole := "user" + if role == "assistant" { + gRole = "model" + } + + parts := make([]any, 0) + switch content := mm["content"].(type) { + case string: + // 字符串形式的 content,保留所有内容(包括空白) + parts = append(parts, map[string]any{"text": content}) + case []any: + // 如果只有一个 block,不过滤空白(让上游 API 报错) + singleBlock := len(content) == 1 + + for _, block := range content { + bm, ok := block.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + switch bt { + case "text": + if text, ok := bm["text"].(string); ok { + // 单个 block 时保留所有内容(包括空白) + // 多个 blocks 时过滤掉空白 + if singleBlock || strings.TrimSpace(text) != "" { + parts = append(parts, map[string]any{"text": text}) + } + } + case "tool_use": + id, _ := bm["id"].(string) + name, _ := bm["name"].(string) + if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" { + toolUseIDToName[id] = name + } + signature, _ := bm["signature"].(string) + signature = strings.TrimSpace(signature) + if signature == "" { + signature = geminiDummyThoughtSignature + } + parts = append(parts, map[string]any{ + "thoughtSignature": signature, + "functionCall": map[string]any{ + "name": name, + "args": bm["input"], + }, + }) + case "tool_result": + toolUseID, _ := bm["tool_use_id"].(string) + name := toolUseIDToName[toolUseID] + if name == "" { + name = "tool" + } + parts = append(parts, map[string]any{ + "functionResponse": map[string]any{ + "name": name, + "response": map[string]any{ + "content": extractClaudeContentText(bm["content"]), + }, + }, + }) + case "image": + if src, ok := bm["source"].(map[string]any); ok { + if srcType, _ := src["type"].(string); srcType == "base64" { + mediaType, _ := src["media_type"].(string) + data, _ := src["data"].(string) + if mediaType != "" && data != "" { + parts = append(parts, map[string]any{ + "inlineData": map[string]any{ + "mimeType": mediaType, + "data": data, + }, + }) + } + } + } + default: + // best-effort: preserve unknown blocks as text + if b, err := json.Marshal(bm); err == nil { + parts = append(parts, map[string]any{"text": string(b)}) + } + } + } + default: + // ignore + } + + out = append(out, map[string]any{ + "role": gRole, + "parts": parts, + }) + } + return out, nil +} + +func extractClaudeContentText(v any) string { + switch t := v.(type) { + case string: + return t + case []any: + var sb strings.Builder + for _, part := range t { + pm, ok := part.(map[string]any) + if !ok { + continue + } + if pm["type"] == "text" { + if text, ok := pm["text"].(string); ok { + _, _ = sb.WriteString(text) + } + } + } + return sb.String() + default: + b, _ := json.Marshal(t) + return string(b) + } +} + +func convertClaudeToolsToGeminiTools(tools any) []any { + arr, ok := tools.([]any) + if !ok || len(arr) == 0 { + return nil + } + + funcDecls := make([]any, 0, len(arr)) + for _, t := range arr { + tm, ok := t.(map[string]any) + if !ok { + continue + } + + var name, desc string + var params any + + // 检查是否为 custom 类型工具 (MCP) + toolType, _ := tm["type"].(string) + if toolType == "custom" { + // Custom 格式: 从 custom 字段获取 description 和 input_schema + custom, ok := tm["custom"].(map[string]any) + if !ok { + continue + } + name, _ = tm["name"].(string) + desc, _ = custom["description"].(string) + params = custom["input_schema"] + } else { + // 标准格式: 从顶层字段获取 + name, _ = tm["name"].(string) + desc, _ = tm["description"].(string) + params = tm["input_schema"] + } + + if name == "" { + continue + } + + // 为 nil params 提供默认值 + if params == nil { + params = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + // 清理 JSON Schema + cleanedParams := cleanToolSchema(params) + + funcDecls = append(funcDecls, map[string]any{ + "name": name, + "description": desc, + "parameters": cleanedParams, + }) + } + + if len(funcDecls) == 0 { + return nil + } + return []any{ + map[string]any{ + "functionDeclarations": funcDecls, + }, + } +} + +// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段 +func cleanToolSchema(schema any) any { + if schema == nil { + return nil + } + + switch v := schema.(type) { + case map[string]any: + cleaned := make(map[string]any) + for key, value := range v { + // 跳过不支持的字段 + if key == "$schema" || key == "$id" || key == "$ref" || + key == "additionalProperties" || key == "patternProperties" || key == "minLength" || + key == "maxLength" || key == "minItems" || key == "maxItems" { + continue + } + // 递归清理嵌套对象 + cleaned[key] = cleanToolSchema(value) + } + // 规范化 type 字段为大写 + if typeVal, ok := cleaned["type"].(string); ok { + cleaned["type"] = strings.ToUpper(typeVal) + } + return cleaned + case []any: + cleaned := make([]any, len(v)) + for i, item := range v { + cleaned[i] = cleanToolSchema(item) + } + return cleaned + default: + return v + } +} + +func convertClaudeGenerationConfig(req map[string]any) map[string]any { + out := make(map[string]any) + if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { + out["maxOutputTokens"] = mt + } + if temp, ok := req["temperature"].(float64); ok { + out["temperature"] = temp + } + if topP, ok := req["top_p"].(float64); ok { + out["topP"] = topP + } + if stopSeq, ok := req["stop_sequences"].([]any); ok && len(stopSeq) > 0 { + out["stopSequences"] = stopSeq + } + if len(out) == 0 { + return nil + } + return out +} + +// extractImageSize 从 Gemini 请求中提取 image_size 参数 +func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string { + var req struct { + GenerationConfig *struct { + ImageConfig *struct { + ImageSize string `json:"imageSize"` + } `json:"imageConfig"` + } `json:"generationConfig"` + } + if err := json.Unmarshal(body, &req); err != nil { + return "2K" + } + + if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil { + size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize)) + if size == "1K" || size == "2K" || size == "4K" { + return size + } + } + + return "2K" +} diff --git a/backend/internal/service/instance_signature.go b/backend/internal/service/instance_signature.go new file mode 100644 index 00000000..db0777ad --- /dev/null +++ b/backend/internal/service/instance_signature.go @@ -0,0 +1,38 @@ +package service + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +// generateCodeHMAC generates HMAC-SHA256 signature, returning the first 8 hex characters. +// Used for RedeemCode and PromoCode. +func generateCodeHMAC(message, secret string) string { + h := hmac.New(sha256.New, []byte(secret)) + h.Write([]byte(message)) + return hex.EncodeToString(h.Sum(nil))[:8] +} + +// ValidateCodeSignature validates an instance signature appended to a code. +// Returns the raw code without the signature, or an error if validation fails. +// The code format is: rawCode.signature +func validateCodeSignature(codeWithSig string, secret string) (string, error) { + parts := strings.Split(codeWithSig, ".") + if len(parts) != 2 { + return "", infraerrors.BadRequest("INVALID_CODE_FORMAT", "invalid code format: missing instance signature") + } + + rawCode := parts[0] + signature := parts[1] + + expectedSig := generateCodeHMAC(rawCode, secret) + if !hmac.Equal([]byte(signature), []byte(expectedSig)) { + return "", infraerrors.BadRequest("INVALID_CODE_SIGNATURE", "invalid code signature") + } + + return rawCode, nil +} \ No newline at end of file diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go new file mode 100644 index 00000000..923c410d --- /dev/null +++ b/backend/internal/service/openai_account_scheduler.go @@ -0,0 +1,922 @@ +package service + +import ( + "container/heap" + "context" + "errors" + "hash/fnv" + "math" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +const ( + openAIAccountScheduleLayerPreviousResponse = "previous_response_id" + openAIAccountScheduleLayerSessionSticky = "session_hash" + openAIAccountScheduleLayerLoadBalance = "load_balance" +) + +type OpenAIAccountScheduleRequest struct { + GroupID *int64 + SessionHash string + StickyAccountID int64 + PreviousResponseID string + RequestedModel string + RequiredTransport OpenAIUpstreamTransport + ExcludedIDs map[int64]struct{} +} + +type OpenAIAccountScheduleDecision struct { + Layer string + StickyPreviousHit bool + StickySessionHit bool + CandidateCount int + TopK int + LatencyMs int64 + LoadSkew float64 + SelectedAccountID int64 + SelectedAccountType string +} + +type OpenAIAccountSchedulerMetricsSnapshot struct { + SelectTotal int64 + StickyPreviousHitTotal int64 + StickySessionHitTotal int64 + LoadBalanceSelectTotal int64 + AccountSwitchTotal int64 + SchedulerLatencyMsTotal int64 + SchedulerLatencyMsAvg float64 + StickyHitRatio float64 + AccountSwitchRate float64 + LoadSkewAvg float64 + RuntimeStatsAccountCount int +} + +type OpenAIAccountScheduler interface { + Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) + ReportResult(accountID int64, success bool, firstTokenMs *int) + ReportSwitch() + SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot +} + +type openAIAccountSchedulerMetrics struct { + selectTotal atomic.Int64 + stickyPreviousHitTotal atomic.Int64 + stickySessionHitTotal atomic.Int64 + loadBalanceSelectTotal atomic.Int64 + accountSwitchTotal atomic.Int64 + latencyMsTotal atomic.Int64 + loadSkewMilliTotal atomic.Int64 +} + +func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) { + if m == nil { + return + } + m.selectTotal.Add(1) + m.latencyMsTotal.Add(decision.LatencyMs) + m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000))) + if decision.StickyPreviousHit { + m.stickyPreviousHitTotal.Add(1) + } + if decision.StickySessionHit { + m.stickySessionHitTotal.Add(1) + } + if decision.Layer == openAIAccountScheduleLayerLoadBalance { + m.loadBalanceSelectTotal.Add(1) + } +} + +func (m *openAIAccountSchedulerMetrics) recordSwitch() { + if m == nil { + return + } + m.accountSwitchTotal.Add(1) +} + +type openAIAccountRuntimeStats struct { + accounts sync.Map + accountCount atomic.Int64 +} + +type openAIAccountRuntimeStat struct { + errorRateEWMABits atomic.Uint64 + ttftEWMABits atomic.Uint64 +} + +func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats { + return &openAIAccountRuntimeStats{} +} + +func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat { + if value, ok := s.accounts.Load(accountID); ok { + stat, _ := value.(*openAIAccountRuntimeStat) + if stat != nil { + return stat + } + } + + stat := &openAIAccountRuntimeStat{} + stat.ttftEWMABits.Store(math.Float64bits(math.NaN())) + actual, loaded := s.accounts.LoadOrStore(accountID, stat) + if !loaded { + s.accountCount.Add(1) + return stat + } + existing, _ := actual.(*openAIAccountRuntimeStat) + if existing != nil { + return existing + } + return stat +} + +func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) { + for { + oldBits := target.Load() + oldValue := math.Float64frombits(oldBits) + newValue := alpha*sample + (1-alpha)*oldValue + if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + return + } + } +} + +func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) { + if s == nil || accountID <= 0 { + return + } + const alpha = 0.2 + stat := s.loadOrCreate(accountID) + + errorSample := 1.0 + if success { + errorSample = 0.0 + } + updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha) + + if firstTokenMs != nil && *firstTokenMs > 0 { + ttft := float64(*firstTokenMs) + ttftBits := math.Float64bits(ttft) + for { + oldBits := stat.ttftEWMABits.Load() + oldValue := math.Float64frombits(oldBits) + if math.IsNaN(oldValue) { + if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) { + break + } + continue + } + newValue := alpha*ttft + (1-alpha)*oldValue + if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + break + } + } + } +} + +func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) { + if s == nil || accountID <= 0 { + return 0, 0, false + } + value, ok := s.accounts.Load(accountID) + if !ok { + return 0, 0, false + } + stat, _ := value.(*openAIAccountRuntimeStat) + if stat == nil { + return 0, 0, false + } + errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load())) + ttftValue := math.Float64frombits(stat.ttftEWMABits.Load()) + if math.IsNaN(ttftValue) { + return errorRate, 0, false + } + return errorRate, ttftValue, true +} + +func (s *openAIAccountRuntimeStats) size() int { + if s == nil { + return 0 + } + return int(s.accountCount.Load()) +} + +type defaultOpenAIAccountScheduler struct { + service *OpenAIGatewayService + metrics openAIAccountSchedulerMetrics + stats *openAIAccountRuntimeStats +} + +func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler { + if stats == nil { + stats = newOpenAIAccountRuntimeStats() + } + return &defaultOpenAIAccountScheduler{ + service: service, + stats: stats, + } +} + +func (s *defaultOpenAIAccountScheduler) Select( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + decision := OpenAIAccountScheduleDecision{} + start := time.Now() + defer func() { + decision.LatencyMs = time.Since(start).Milliseconds() + s.metrics.recordSelect(decision) + }() + + previousResponseID := strings.TrimSpace(req.PreviousResponseID) + if previousResponseID != "" { + selection, err := s.service.SelectAccountByPreviousResponseID( + ctx, + req.GroupID, + previousResponseID, + req.RequestedModel, + req.ExcludedIDs, + ) + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) { + selection = nil + } + } + if selection != nil && selection.Account != nil { + decision.Layer = openAIAccountScheduleLayerPreviousResponse + decision.StickyPreviousHit = true + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID) + } + return selection, decision, nil + } + } + + selection, err := s.selectBySessionHash(ctx, req) + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + decision.Layer = openAIAccountScheduleLayerSessionSticky + decision.StickySessionHit = true + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + return selection, decision, nil + } + + selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req) + decision.Layer = openAIAccountScheduleLayerLoadBalance + decision.CandidateCount = candidateCount + decision.TopK = topK + decision.LoadSkew = loadSkew + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + } + return selection, decision, nil +} + +func (s *defaultOpenAIAccountScheduler) selectBySessionHash( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, error) { + sessionHash := strings.TrimSpace(req.SessionHash) + if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil { + return nil, nil + } + + accountID := req.StickyAccountID + if accountID <= 0 { + var err error + accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash) + if err != nil || accountID <= 0 { + return nil, nil + } + } + if accountID <= 0 { + return nil, nil + } + if req.ExcludedIDs != nil { + if _, excluded := req.ExcludedIDs[accountID]; excluded { + return nil, nil + } + } + + account, err := s.service.getSchedulableAccount(ctx, accountID) + if err != nil || account == nil { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + if shouldClearStickySession(ctx, account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + return nil, nil + } + if !s.isAccountTransportCompatible(account, req.RequiredTransport) { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if acquireErr == nil && result.Acquired { + _ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL()) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + cfg := s.service.schedulingConfig() + // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 + if s.service.concurrencyService != nil { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + return nil, nil +} + +type openAIAccountCandidateScore struct { + account *Account + loadInfo *AccountLoadInfo + score float64 + errorRate float64 + ttft float64 + hasTTFT bool +} + +type openAIAccountCandidateHeap []openAIAccountCandidateScore + +func (h openAIAccountCandidateHeap) Len() int { + return len(h) +} + +func (h openAIAccountCandidateHeap) Less(i, j int) bool { + // 最小堆根节点保存“最差”候选,便于 O(log k) 维护 topK。 + return isOpenAIAccountCandidateBetter(h[j], h[i]) +} + +func (h openAIAccountCandidateHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *openAIAccountCandidateHeap) Push(x any) { + candidate, ok := x.(openAIAccountCandidateScore) + if !ok { + panic("openAIAccountCandidateHeap: invalid element type") + } + *h = append(*h, candidate) +} + +func (h *openAIAccountCandidateHeap) Pop() any { + old := *h + n := len(old) + last := old[n-1] + *h = old[:n-1] + return last +} + +func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool { + if left.score != right.score { + return left.score > right.score + } + if left.account.Priority != right.account.Priority { + return left.account.Priority < right.account.Priority + } + if left.loadInfo.LoadRate != right.loadInfo.LoadRate { + return left.loadInfo.LoadRate < right.loadInfo.LoadRate + } + if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount { + return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount + } + return left.account.ID < right.account.ID +} + +func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore { + if len(candidates) == 0 { + return nil + } + if topK <= 0 { + topK = 1 + } + if topK >= len(candidates) { + ranked := append([]openAIAccountCandidateScore(nil), candidates...) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + return ranked + } + + best := make(openAIAccountCandidateHeap, 0, topK) + for _, candidate := range candidates { + if len(best) < topK { + heap.Push(&best, candidate) + continue + } + if isOpenAIAccountCandidateBetter(candidate, best[0]) { + best[0] = candidate + heap.Fix(&best, 0) + } + } + + ranked := make([]openAIAccountCandidateScore, len(best)) + copy(ranked, best) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + return ranked +} + +type openAISelectionRNG struct { + state uint64 +} + +func newOpenAISelectionRNG(seed uint64) openAISelectionRNG { + if seed == 0 { + seed = 0x9e3779b97f4a7c15 + } + return openAISelectionRNG{state: seed} +} + +func (r *openAISelectionRNG) nextUint64() uint64 { + // xorshift64* + x := r.state + x ^= x >> 12 + x ^= x << 25 + x ^= x >> 27 + r.state = x + return x * 2685821657736338717 +} + +func (r *openAISelectionRNG) nextFloat64() float64 { + // [0,1) + return float64(r.nextUint64()>>11) / (1 << 53) +} + +func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 { + hasher := fnv.New64a() + writeValue := func(value string) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return + } + _, _ = hasher.Write([]byte(trimmed)) + _, _ = hasher.Write([]byte{0}) + } + + writeValue(req.SessionHash) + writeValue(req.PreviousResponseID) + writeValue(req.RequestedModel) + if req.GroupID != nil { + _, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10))) + } + + seed := hasher.Sum64() + // 对“无会话锚点”的纯负载均衡请求引入时间熵,避免固定命中同一账号。 + if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" { + seed ^= uint64(time.Now().UnixNano()) + } + if seed == 0 { + seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15 + } + return seed +} + +func buildOpenAIWeightedSelectionOrder( + candidates []openAIAccountCandidateScore, + req OpenAIAccountScheduleRequest, +) []openAIAccountCandidateScore { + if len(candidates) <= 1 { + return append([]openAIAccountCandidateScore(nil), candidates...) + } + + pool := append([]openAIAccountCandidateScore(nil), candidates...) + weights := make([]float64, len(pool)) + minScore := pool[0].score + for i := 1; i < len(pool); i++ { + if pool[i].score < minScore { + minScore = pool[i].score + } + } + for i := range pool { + // 将 top-K 分值平移到正区间,避免“单一最高分账号”长期垄断。 + weight := (pool[i].score - minScore) + 1.0 + if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 { + weight = 1.0 + } + weights[i] = weight + } + + order := make([]openAIAccountCandidateScore, 0, len(pool)) + rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req)) + for len(pool) > 0 { + total := 0.0 + for _, w := range weights { + total += w + } + + selectedIdx := 0 + if total > 0 { + r := rng.nextFloat64() * total + acc := 0.0 + for i, w := range weights { + acc += w + if r <= acc { + selectedIdx = i + break + } + } + } else { + selectedIdx = int(rng.nextUint64() % uint64(len(pool))) + } + + order = append(order, pool[selectedIdx]) + pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...) + weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...) + } + return order +} + +func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, int, int, float64, error) { + accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID) + if err != nil { + return nil, 0, 0, 0, err + } + if len(accounts) == 0 { + return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + } + + filtered := make([]*Account, 0, len(accounts)) + loadReq := make([]AccountWithConcurrency, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if req.ExcludedIDs != nil { + if _, excluded := req.ExcludedIDs[account.ID]; excluded { + continue + } + } + if !account.IsSchedulable() || !account.IsOpenAI() { + continue + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + continue + } + if !s.isAccountTransportCompatible(account, req.RequiredTransport) { + continue + } + filtered = append(filtered, account) + loadReq = append(loadReq, AccountWithConcurrency{ + ID: account.ID, + MaxConcurrency: account.EffectiveLoadFactor(), + }) + } + if len(filtered) == 0 { + return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + } + + loadMap := map[int64]*AccountLoadInfo{} + if s.service.concurrencyService != nil { + if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil { + loadMap = batchLoad + } + } + + minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority + maxWaiting := 1 + loadRateSum := 0.0 + loadRateSumSquares := 0.0 + minTTFT, maxTTFT := 0.0, 0.0 + hasTTFTSample := false + candidates := make([]openAIAccountCandidateScore, 0, len(filtered)) + for _, account := range filtered { + loadInfo := loadMap[account.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: account.ID} + } + if account.Priority < minPriority { + minPriority = account.Priority + } + if account.Priority > maxPriority { + maxPriority = account.Priority + } + if loadInfo.WaitingCount > maxWaiting { + maxWaiting = loadInfo.WaitingCount + } + errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID) + if hasTTFT && ttft > 0 { + if !hasTTFTSample { + minTTFT, maxTTFT = ttft, ttft + hasTTFTSample = true + } else { + if ttft < minTTFT { + minTTFT = ttft + } + if ttft > maxTTFT { + maxTTFT = ttft + } + } + } + loadRate := float64(loadInfo.LoadRate) + loadRateSum += loadRate + loadRateSumSquares += loadRate * loadRate + candidates = append(candidates, openAIAccountCandidateScore{ + account: account, + loadInfo: loadInfo, + errorRate: errorRate, + ttft: ttft, + hasTTFT: hasTTFT, + }) + } + loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) + + weights := s.service.openAIWSSchedulerWeights() + for i := range candidates { + item := &candidates[i] + priorityFactor := 1.0 + if maxPriority > minPriority { + priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) + } + loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) + queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) + errorFactor := 1 - clamp01(item.errorRate) + ttftFactor := 0.5 + if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { + ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) + } + + item.score = weights.Priority*priorityFactor + + weights.Load*loadFactor + + weights.Queue*queueFactor + + weights.ErrorRate*errorFactor + + weights.TTFT*ttftFactor + } + + topK := s.service.openAIWSLBTopK() + if topK > len(candidates) { + topK = len(candidates) + } + if topK <= 0 { + topK = 1 + } + rankedCandidates := selectTopKOpenAICandidates(candidates, topK) + selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req) + + for i := 0; i < len(selectionOrder); i++ { + candidate := selectionOrder[i] + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) + if acquireErr != nil { + return nil, len(candidates), topK, loadSkew, acquireErr + } + if result != nil && result.Acquired { + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID) + } + return &AccountSelectionResult{ + Account: fresh, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, len(candidates), topK, loadSkew, nil + } + } + + cfg := s.service.schedulingConfig() + // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 + for _, candidate := range selectionOrder { + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } + return &AccountSelectionResult{ + Account: fresh, + WaitPlan: &AccountWaitPlan{ + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, len(candidates), topK, loadSkew, nil + } + + return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts +} + +func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { + // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。 + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + return true + } + if s == nil || s.service == nil || account == nil { + return false + } + return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport +} + +func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { + if s == nil || s.stats == nil { + return + } + s.stats.report(accountID, success, firstTokenMs) +} + +func (s *defaultOpenAIAccountScheduler) ReportSwitch() { + if s == nil { + return + } + s.metrics.recordSwitch() +} + +func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot { + if s == nil { + return OpenAIAccountSchedulerMetricsSnapshot{} + } + + selectTotal := s.metrics.selectTotal.Load() + prevHit := s.metrics.stickyPreviousHitTotal.Load() + sessionHit := s.metrics.stickySessionHitTotal.Load() + switchTotal := s.metrics.accountSwitchTotal.Load() + latencyTotal := s.metrics.latencyMsTotal.Load() + loadSkewTotal := s.metrics.loadSkewMilliTotal.Load() + + snapshot := OpenAIAccountSchedulerMetricsSnapshot{ + SelectTotal: selectTotal, + StickyPreviousHitTotal: prevHit, + StickySessionHitTotal: sessionHit, + LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(), + AccountSwitchTotal: switchTotal, + SchedulerLatencyMsTotal: latencyTotal, + RuntimeStatsAccountCount: s.stats.size(), + } + if selectTotal > 0 { + snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal) + snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal) + snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal) + snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal) + } + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { + if s == nil { + return nil + } + s.openaiSchedulerOnce.Do(func() { + if s.openaiAccountStats == nil { + s.openaiAccountStats = newOpenAIAccountRuntimeStats() + } + if s.openaiScheduler == nil { + s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats) + } + }) + return s.openaiScheduler +} + +func (s *OpenAIGatewayService) SelectAccountWithScheduler( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport OpenAIUpstreamTransport, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + decision := OpenAIAccountScheduleDecision{} + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) + decision.Layer = openAIAccountScheduleLayerLoadBalance + return selection, decision, err + } + + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 { + stickyAccountID = accountID + } + } + + return scheduler.Select(ctx, OpenAIAccountScheduleRequest{ + GroupID: groupID, + SessionHash: sessionHash, + StickyAccountID: stickyAccountID, + PreviousResponseID: previousResponseID, + RequestedModel: requestedModel, + RequiredTransport: requiredTransport, + ExcludedIDs: excludedIDs, + }) +} + +func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return + } + scheduler.ReportResult(accountID, success, firstTokenMs) +} + +func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return + } + scheduler.ReportSwitch() +} + +func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return OpenAIAccountSchedulerMetricsSnapshot{} + } + return scheduler.SnapshotMetrics() +} + +func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return openaiStickySessionTTL +} + +func (s *OpenAIGatewayService) openAIWSLBTopK() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 { + return s.cfg.Gateway.OpenAIWS.LBTopK + } + return 7 +} + +func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView { + if s != nil && s.cfg != nil { + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority, + Load: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load, + Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue, + ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate, + TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT, + } + } + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: 1.0, + Load: 1.0, + Queue: 0.7, + ErrorRate: 0.8, + TTFT: 0.5, + } +} + +type GatewayOpenAIWSSchedulerScoreWeightsView struct { + Priority float64 + Load float64 + Queue float64 + ErrorRate float64 + TTFT float64 +} + +func clamp01(value float64) float64 { + switch { + case value < 0: + return 0 + case value > 1: + return 1 + default: + return value + } +} + +func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 { + if count <= 1 { + return 0 + } + mean := sum / float64(count) + variance := sumSquares/float64(count) - mean*mean + if variance < 0 { + variance = 0 + } + return math.Sqrt(variance) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go new file mode 100644 index 00000000..85a5d89c --- /dev/null +++ b/backend/internal/service/openai_gateway_service.go @@ -0,0 +1,4709 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/cespare/xxhash/v2" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +const ( + // ChatGPT internal API for OAuth accounts + chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses" + // OpenAI Platform API for API Key accounts (fallback) + openaiPlatformAPIURL = "https://api.openai.com/v1/responses" + openaiStickySessionTTL = time.Hour // 粘性会话TTL + codexCLIUserAgent = "codex_cli_rs/0.104.0" + // codex_cli_only 拒绝时单个请求头日志长度上限(字符) + codexCLIOnlyHeaderValueMaxBytes = 256 + + // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。 + OpenAIParsedRequestBodyKey = "openai_parsed_request_body" + // OpenAI WS Mode 失败后的重连次数上限(不含首次尝试)。 + // 与 Codex 客户端保持一致:失败后最多重连 5 次。 + openAIWSReconnectRetryLimit = 5 + // OpenAI WS Mode 重连退避默认值(可由配置覆盖)。 + openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond + openAIWSRetryBackoffMaxDefault = 2 * time.Second + openAIWSRetryJitterRatioDefault = 0.2 + openAICompactSessionSeedKey = "openai_compact_session_seed" + codexCLIVersion = "0.104.0" + // Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。 + openAICodexSnapshotPersistMinInterval = 30 * time.Second +) + +// OpenAI allowed headers whitelist (for non-passthrough). +var openaiAllowedHeaders = map[string]bool{ + "accept-language": true, + "content-type": true, + "conversation_id": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, +} + +// OpenAI passthrough allowed headers whitelist. +// 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。 +var openaiPassthroughAllowedHeaders = map[string]bool{ + "accept": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "openai-beta": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, +} + +// codex_cli_only 拒绝时记录的请求头白名单(仅用于诊断日志,不参与上游透传) +var codexCLIOnlyDebugHeaderWhitelist = []string{ + "User-Agent", + "Content-Type", + "Accept", + "Accept-Language", + "OpenAI-Beta", + "Originator", + "Session_ID", + "Conversation_ID", + "X-Request-ID", + "X-Client-Request-ID", + "X-Forwarded-For", + "X-Real-IP", +} + +// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers +type OpenAICodexUsageSnapshot struct { + PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"` + PrimaryResetAfterSeconds *int `json:"primary_reset_after_seconds,omitempty"` + PrimaryWindowMinutes *int `json:"primary_window_minutes,omitempty"` + SecondaryUsedPercent *float64 `json:"secondary_used_percent,omitempty"` + SecondaryResetAfterSeconds *int `json:"secondary_reset_after_seconds,omitempty"` + SecondaryWindowMinutes *int `json:"secondary_window_minutes,omitempty"` + PrimaryOverSecondaryPercent *float64 `json:"primary_over_secondary_percent,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +// NormalizedCodexLimits contains normalized 5h/7d rate limit data +type NormalizedCodexLimits struct { + Used5hPercent *float64 + Reset5hSeconds *int + Window5hMinutes *int + Used7dPercent *float64 + Reset7dSeconds *int + Window7dMinutes *int +} + +// Normalize converts primary/secondary fields to canonical 5h/7d fields. +// Strategy: Compare window_minutes to determine which is 5h vs 7d. +// Returns nil if snapshot is nil or has no useful data. +func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits { + if s == nil { + return nil + } + + result := &NormalizedCodexLimits{} + + primaryMins := 0 + secondaryMins := 0 + hasPrimaryWindow := false + hasSecondaryWindow := false + + if s.PrimaryWindowMinutes != nil { + primaryMins = *s.PrimaryWindowMinutes + hasPrimaryWindow = true + } + if s.SecondaryWindowMinutes != nil { + secondaryMins = *s.SecondaryWindowMinutes + hasSecondaryWindow = true + } + + // Determine mapping based on window_minutes + use5hFromPrimary := false + use7dFromPrimary := false + + if hasPrimaryWindow && hasSecondaryWindow { + // Both known: smaller window is 5h, larger is 7d + if primaryMins < secondaryMins { + use5hFromPrimary = true + } else { + use7dFromPrimary = true + } + } else if hasPrimaryWindow { + // Only primary known: classify by threshold (<=360 min = 6h -> 5h window) + if primaryMins <= 360 { + use5hFromPrimary = true + } else { + use7dFromPrimary = true + } + } else if hasSecondaryWindow { + // Only secondary known: classify by threshold + if secondaryMins <= 360 { + // 5h from secondary, so primary (if any data) is 7d + use7dFromPrimary = true + } else { + // 7d from secondary, so primary (if any data) is 5h + use5hFromPrimary = true + } + } else { + // No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h) + use7dFromPrimary = true + } + + // Assign values + if use5hFromPrimary { + result.Used5hPercent = s.PrimaryUsedPercent + result.Reset5hSeconds = s.PrimaryResetAfterSeconds + result.Window5hMinutes = s.PrimaryWindowMinutes + result.Used7dPercent = s.SecondaryUsedPercent + result.Reset7dSeconds = s.SecondaryResetAfterSeconds + result.Window7dMinutes = s.SecondaryWindowMinutes + } else if use7dFromPrimary { + result.Used7dPercent = s.PrimaryUsedPercent + result.Reset7dSeconds = s.PrimaryResetAfterSeconds + result.Window7dMinutes = s.PrimaryWindowMinutes + result.Used5hPercent = s.SecondaryUsedPercent + result.Reset5hSeconds = s.SecondaryResetAfterSeconds + result.Window5hMinutes = s.SecondaryWindowMinutes + } + + return result +} + +// OpenAIUsage represents OpenAI API response usage +type OpenAIUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` + CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` +} + +// OpenAIForwardResult represents the result of forwarding +type OpenAIForwardResult struct { + RequestID string + Usage OpenAIUsage + Model string // 原始模型(用于响应和日志显示) + // BillingModel is the model used for cost calculation. + // When non-empty, CalculateCost uses this instead of Model. + // This is set by the Anthropic Messages conversion path where + // the mapped upstream model differs from the client-facing model. + BillingModel string + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Empty when no mapping was applied (requested model was used as-is). + UpstreamModel string + // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". + // Nil means the request did not specify a recognized tier. + ServiceTier *string + // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. + // Stored for usage records display; nil means not provided / not applicable. + ReasoningEffort *string + Stream bool + OpenAIWSMode bool + ResponseHeaders http.Header + Duration time.Duration + FirstTokenMs *int +} + +type OpenAIWSRetryMetricsSnapshot struct { + RetryAttemptsTotal int64 `json:"retry_attempts_total"` + RetryBackoffMsTotal int64 `json:"retry_backoff_ms_total"` + RetryExhaustedTotal int64 `json:"retry_exhausted_total"` + NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"` +} + +type OpenAICompatibilityFallbackMetricsSnapshot struct { + SessionHashLegacyReadFallbackTotal int64 `json:"session_hash_legacy_read_fallback_total"` + SessionHashLegacyReadFallbackHit int64 `json:"session_hash_legacy_read_fallback_hit"` + SessionHashLegacyDualWriteTotal int64 `json:"session_hash_legacy_dual_write_total"` + SessionHashLegacyReadHitRate float64 `json:"session_hash_legacy_read_hit_rate"` + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal int64 `json:"metadata_legacy_fallback_is_max_tokens_one_haiku_total"` + MetadataLegacyFallbackThinkingEnabledTotal int64 `json:"metadata_legacy_fallback_thinking_enabled_total"` + MetadataLegacyFallbackPrefetchedStickyAccount int64 `json:"metadata_legacy_fallback_prefetched_sticky_account_total"` + MetadataLegacyFallbackPrefetchedStickyGroup int64 `json:"metadata_legacy_fallback_prefetched_sticky_group_total"` + MetadataLegacyFallbackSingleAccountRetryTotal int64 `json:"metadata_legacy_fallback_single_account_retry_total"` + MetadataLegacyFallbackAccountSwitchCountTotal int64 `json:"metadata_legacy_fallback_account_switch_count_total"` + MetadataLegacyFallbackTotal int64 `json:"metadata_legacy_fallback_total"` +} + +type openAIWSRetryMetrics struct { + retryAttempts atomic.Int64 + retryBackoffMs atomic.Int64 + retryExhausted atomic.Int64 + nonRetryableFastFallback atomic.Int64 +} + +type accountWriteThrottle struct { + minInterval time.Duration + mu sync.Mutex + lastByID map[int64]time.Time +} + +func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle { + return &accountWriteThrottle{ + minInterval: minInterval, + lastByID: make(map[int64]time.Time), + } +} + +func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool { + if t == nil || id <= 0 || t.minInterval <= 0 { + return true + } + + t.mu.Lock() + defer t.mu.Unlock() + + if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval { + return false + } + t.lastByID[id] = now + + if len(t.lastByID) > 4096 { + cutoff := now.Add(-4 * t.minInterval) + for accountID, writtenAt := range t.lastByID { + if writtenAt.Before(cutoff) { + delete(t.lastByID, accountID) + } + } + } + + return true +} + +var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval) + +// OpenAIGatewayService handles OpenAI API gateway operations +type OpenAIGatewayService struct { + accountRepo AccountRepository + usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + cache GatewayCache + cfg *config.Config + codexDetector CodexClientRestrictionDetector + schedulerSnapshot *SchedulerSnapshotService + concurrencyService *ConcurrencyService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + userGroupRateResolver *userGroupRateResolver + httpUpstream HTTPUpstream + deferredService *DeferredService + openAITokenProvider *OpenAITokenProvider + toolCorrector *CodexToolCorrector + openaiWSResolver OpenAIWSProtocolResolver + + openaiWSPoolOnce sync.Once + openaiWSStateStoreOnce sync.Once + openaiSchedulerOnce sync.Once + openaiWSPassthroughDialerOnce sync.Once + openaiWSPool *openAIWSConnPool + openaiWSStateStore OpenAIWSStateStore + openaiScheduler OpenAIAccountScheduler + openaiWSPassthroughDialer openAIWSClientDialer + openaiAccountStats *openAIAccountRuntimeStats + + openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time + openaiWSRetryMetrics openAIWSRetryMetrics + responseHeaderFilter *responseheaders.CompiledHeaderFilter + codexSnapshotThrottle *accountWriteThrottle +} + +// NewOpenAIGatewayService creates a new OpenAIGatewayService +func NewOpenAIGatewayService( + accountRepo AccountRepository, + usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, + userRepo UserRepository, + userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, + cache GatewayCache, + cfg *config.Config, + schedulerSnapshot *SchedulerSnapshotService, + concurrencyService *ConcurrencyService, + billingService *BillingService, + rateLimitService *RateLimitService, + billingCacheService *BillingCacheService, + httpUpstream HTTPUpstream, + deferredService *DeferredService, + openAITokenProvider *OpenAITokenProvider, +) *OpenAIGatewayService { + svc := &OpenAIGatewayService{ + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + userGroupRateResolver: newUserGroupRateResolver( + userGroupRateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway", + ), + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: compileResponseHeaderFilter(cfg), + codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), + } + svc.logOpenAIWSModeBootstrap() + return svc +} + +func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { + if s != nil && s.codexSnapshotThrottle != nil { + return s.codexSnapshotThrottle + } + return defaultOpenAICodexSnapshotPersistThrottle +} + +func (s *OpenAIGatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + +// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。 +// 应在应用优雅关闭时调用。 +func (s *OpenAIGatewayService) CloseOpenAIWSPool() { + if s != nil && s.openaiWSPool != nil { + s.openaiWSPool.Close() + } +} + +func (s *OpenAIGatewayService) logOpenAIWSModeBootstrap() { + if s == nil || s.cfg == nil { + return + } + wsCfg := s.cfg.Gateway.OpenAIWS + logOpenAIWSModeInfo( + "bootstrap enabled=%v oauth_enabled=%v apikey_enabled=%v force_http=%v responses_websockets_v2=%v responses_websockets=%v payload_log_sample_rate=%.3f event_flush_batch_size=%d event_flush_interval_ms=%d prewarm_cooldown_ms=%d retry_backoff_initial_ms=%d retry_backoff_max_ms=%d retry_jitter_ratio=%.3f retry_total_budget_ms=%d ws_read_limit_bytes=%d", + wsCfg.Enabled, + wsCfg.OAuthEnabled, + wsCfg.APIKeyEnabled, + wsCfg.ForceHTTP, + wsCfg.ResponsesWebsocketsV2, + wsCfg.ResponsesWebsockets, + wsCfg.PayloadLogSampleRate, + wsCfg.EventFlushBatchSize, + wsCfg.EventFlushIntervalMS, + wsCfg.PrewarmCooldownMS, + wsCfg.RetryBackoffInitialMS, + wsCfg.RetryBackoffMaxMS, + wsCfg.RetryJitterRatio, + wsCfg.RetryTotalBudgetMS, + openAIWSMessageReadLimitBytes, + ) +} + +func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector { + if s != nil && s.codexDetector != nil { + return s.codexDetector + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return NewOpenAICodexClientRestrictionDetector(cfg) +} + +func (s *OpenAIGatewayService) getOpenAIWSProtocolResolver() OpenAIWSProtocolResolver { + if s != nil && s.openaiWSResolver != nil { + return s.openaiWSResolver + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return NewOpenAIWSProtocolResolver(cfg) +} + +func classifyOpenAIWSReconnectReason(err error) (string, bool) { + if err == nil { + return "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return "", false + } + reason := strings.TrimSpace(fallbackErr.Reason) + if reason == "" { + return "", false + } + + baseReason := strings.TrimPrefix(reason, "prewarm_") + + switch baseReason { + case "policy_violation", + "message_too_big", + "upgrade_required", + "ws_unsupported", + "auth_failed", + "invalid_encrypted_content", + "previous_response_not_found": + return reason, false + } + + switch baseReason { + case "read_event", + "write_request", + "write", + "acquire_timeout", + "acquire_conn", + "conn_queue_full", + "dial_failed", + "upstream_5xx", + "event_error", + "error_event", + "upstream_error_event", + "ws_connection_limit_reached", + "missing_final_response": + return reason, true + default: + return reason, false + } +} + +func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType string, clientMessage string, upstreamMessage string, ok bool) { + if err == nil { + return 0, "", "", "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return 0, "", "", "", false + } + + reason := strings.TrimSpace(fallbackErr.Reason) + reason = strings.TrimPrefix(reason, "prewarm_") + if reason == "" { + return 0, "", "", "", false + } + + var dialErr *openAIWSDialError + if fallbackErr.Err != nil && errors.As(fallbackErr.Err, &dialErr) && dialErr != nil { + if dialErr.StatusCode > 0 { + statusCode = dialErr.StatusCode + } + if dialErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(dialErr.Err.Error())) + } + } + + switch reason { + case "invalid_encrypted_content": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + errType = "invalid_request_error" + if upstreamMessage == "" { + upstreamMessage = "encrypted content could not be verified" + } + case "previous_response_not_found": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + errType = "invalid_request_error" + if upstreamMessage == "" { + upstreamMessage = "previous response not found" + } + case "upgrade_required": + if statusCode == 0 { + statusCode = http.StatusUpgradeRequired + } + case "ws_unsupported": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + case "auth_failed": + if statusCode == 0 { + statusCode = http.StatusUnauthorized + } + case "upstream_rate_limited": + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + default: + if statusCode == 0 { + return 0, "", "", "", false + } + } + + if upstreamMessage == "" && fallbackErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(fallbackErr.Err.Error())) + } + if upstreamMessage == "" { + switch reason { + case "upgrade_required": + upstreamMessage = "upstream websocket upgrade required" + case "ws_unsupported": + upstreamMessage = "upstream websocket not supported" + case "auth_failed": + upstreamMessage = "upstream authentication failed" + case "upstream_rate_limited": + upstreamMessage = "upstream rate limit exceeded, please retry later" + default: + upstreamMessage = "Upstream request failed" + } + } + + if errType == "" { + if statusCode == http.StatusTooManyRequests { + errType = "rate_limit_error" + } else { + errType = "upstream_error" + } + } + clientMessage = upstreamMessage + return statusCode, errType, clientMessage, upstreamMessage, true +} + +func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context, account *Account, wsErr error) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(wsErr) + if !ok { + return false + } + if strings.TrimSpace(clientMessage) == "" { + clientMessage = "Upstream request failed" + } + if strings.TrimSpace(upstreamMessage) == "" { + upstreamMessage = clientMessage + } + + setOpsUpstreamError(c, statusCode, upstreamMessage, "") + if account != nil { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: statusCode, + Kind: "ws_error", + Message: upstreamMessage, + }) + } + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": clientMessage, + }, + }) + return true +} + +func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration { + if attempt <= 0 { + return 0 + } + + initial := openAIWSRetryBackoffInitialDefault + maxBackoff := openAIWSRetryBackoffMaxDefault + jitterRatio := openAIWSRetryJitterRatioDefault + if s != nil && s.cfg != nil { + wsCfg := s.cfg.Gateway.OpenAIWS + if wsCfg.RetryBackoffInitialMS > 0 { + initial = time.Duration(wsCfg.RetryBackoffInitialMS) * time.Millisecond + } + if wsCfg.RetryBackoffMaxMS > 0 { + maxBackoff = time.Duration(wsCfg.RetryBackoffMaxMS) * time.Millisecond + } + if wsCfg.RetryJitterRatio >= 0 { + jitterRatio = wsCfg.RetryJitterRatio + } + } + if initial <= 0 { + return 0 + } + if maxBackoff <= 0 { + maxBackoff = initial + } + if maxBackoff < initial { + maxBackoff = initial + } + if jitterRatio < 0 { + jitterRatio = 0 + } + if jitterRatio > 1 { + jitterRatio = 1 + } + + shift := attempt - 1 + if shift < 0 { + shift = 0 + } + backoff := initial + if shift > 0 { + backoff = initial * time.Duration(1< maxBackoff { + backoff = maxBackoff + } + if jitterRatio <= 0 { + return backoff + } + jitter := time.Duration(float64(backoff) * jitterRatio) + if jitter <= 0 { + return backoff + } + delta := time.Duration(rand.Int63n(int64(jitter)*2+1)) - jitter + withJitter := backoff + delta + if withJitter < 0 { + return 0 + } + return withJitter +} + +func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration { + if s != nil && s.cfg != nil { + ms := s.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS + if ms <= 0 { + return 0 + } + return time.Duration(ms) * time.Millisecond + } + return 0 +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) { + if s == nil { + return + } + s.openaiWSRetryMetrics.retryAttempts.Add(1) + if backoff > 0 { + s.openaiWSRetryMetrics.retryBackoffMs.Add(backoff.Milliseconds()) + } +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryExhausted() { + if s == nil { + return + } + s.openaiWSRetryMetrics.retryExhausted.Add(1) +} + +func (s *OpenAIGatewayService) recordOpenAIWSNonRetryableFastFallback() { + if s == nil { + return + } + s.openaiWSRetryMetrics.nonRetryableFastFallback.Add(1) +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetricsSnapshot { + if s == nil { + return OpenAIWSRetryMetricsSnapshot{} + } + return OpenAIWSRetryMetricsSnapshot{ + RetryAttemptsTotal: s.openaiWSRetryMetrics.retryAttempts.Load(), + RetryBackoffMsTotal: s.openaiWSRetryMetrics.retryBackoffMs.Load(), + RetryExhaustedTotal: s.openaiWSRetryMetrics.retryExhausted.Load(), + NonRetryableFastFallbackTotal: s.openaiWSRetryMetrics.nonRetryableFastFallback.Load(), + } +} + +func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot { + legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats() + isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats() + + readHitRate := float64(0) + if legacyReadFallbackTotal > 0 { + readHitRate = float64(legacyReadFallbackHit) / float64(legacyReadFallbackTotal) + } + metadataFallbackTotal := isMaxTokensOneHaiku + thinkingEnabled + prefetchedStickyAccount + prefetchedStickyGroup + singleAccountRetry + accountSwitchCount + + return OpenAICompatibilityFallbackMetricsSnapshot{ + SessionHashLegacyReadFallbackTotal: legacyReadFallbackTotal, + SessionHashLegacyReadFallbackHit: legacyReadFallbackHit, + SessionHashLegacyDualWriteTotal: legacyDualWriteTotal, + SessionHashLegacyReadHitRate: readHitRate, + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal: isMaxTokensOneHaiku, + MetadataLegacyFallbackThinkingEnabledTotal: thinkingEnabled, + MetadataLegacyFallbackPrefetchedStickyAccount: prefetchedStickyAccount, + MetadataLegacyFallbackPrefetchedStickyGroup: prefetchedStickyGroup, + MetadataLegacyFallbackSingleAccountRetryTotal: singleAccountRetry, + MetadataLegacyFallbackAccountSwitchCountTotal: accountSwitchCount, + MetadataLegacyFallbackTotal: metadataFallbackTotal, + } +} + +func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { + return s.getCodexClientRestrictionDetector().Detect(c, account) +} + +func getAPIKeyIDFromContext(c *gin.Context) int64 { + if c == nil { + return 0 + } + v, exists := c.Get("api_key") + if !exists { + return 0 + } + apiKey, ok := v.(*APIKey) + if !ok || apiKey == nil { + return 0 + } + return apiKey.ID +} + +// isolateOpenAISessionID 将 apiKeyID 混入 session 标识符, +// 确保不同 API Key 的用户即使使用相同的原始 session_id/conversation_id, +// 到达上游的标识符也不同,防止跨用户会话碰撞。 +func isolateOpenAISessionID(apiKeyID int64, raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + h := xxhash.New() + _, _ = fmt.Fprintf(h, "k%d:", apiKeyID) + _, _ = h.WriteString(raw) + return fmt.Sprintf("%016x", h.Sum64()) +} + +func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) { + if !result.Enabled { + return + } + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + if account != nil { + accountID = account.ID + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.Bool("codex_cli_only_enabled", result.Enabled), + zap.Bool("codex_official_client_match", result.Matched), + zap.String("reject_reason", result.Reason), + } + if apiKeyID > 0 { + fields = append(fields, zap.Int64("api_key_id", apiKeyID)) + } + if !result.Matched { + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) + } + log := logger.FromContext(ctx).With(fields...) + if result.Matched { + return + } + log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求") +} + +func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context, body []byte) []zap.Field { + if c == nil || c.Request == nil { + return fields + } + + req := c.Request + requestModel, requestStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + fields = append(fields, + zap.String("request_method", strings.TrimSpace(req.Method)), + zap.String("request_path", strings.TrimSpace(req.URL.Path)), + zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)), + zap.String("request_host", strings.TrimSpace(req.Host)), + zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())), + zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)), + zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))), + zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))), + zap.Int64("request_content_length", req.ContentLength), + zap.Bool("request_stream", requestStream), + ) + if requestModel != "" { + fields = append(fields, zap.String("request_model", requestModel)) + } + if promptCacheKey != "" { + fields = append(fields, zap.String("request_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey))) + } + + if headers := snapshotCodexCLIOnlyHeaders(req.Header); len(headers) > 0 { + fields = append(fields, zap.Any("request_headers", headers)) + } + fields = append(fields, zap.Int("request_body_size", len(body))) + return fields +} + +func snapshotCodexCLIOnlyHeaders(header http.Header) map[string]string { + if len(header) == 0 { + return nil + } + result := make(map[string]string, len(codexCLIOnlyDebugHeaderWhitelist)) + for _, key := range codexCLIOnlyDebugHeaderWhitelist { + value := strings.TrimSpace(header.Get(key)) + if value == "" { + continue + } + result[strings.ToLower(key)] = truncateString(value, codexCLIOnlyHeaderValueMaxBytes) + } + return result +} + +func hashSensitiveValueForLog(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + sum := sha256.Sum256([]byte(value)) + return hex.EncodeToString(sum[:8]) +} + +func logOpenAIInstructionsRequiredDebug( + ctx context.Context, + c *gin.Context, + account *Account, + upstreamStatusCode int, + upstreamMsg string, + requestBody []byte, + upstreamBody []byte, +) { + msg := strings.TrimSpace(upstreamMsg) + if !isOpenAIInstructionsRequiredError(upstreamStatusCode, msg, upstreamBody) { + return + } + if ctx == nil { + ctx = context.Background() + } + + accountID := int64(0) + accountName := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + } + + userAgent := "" + originator := "" + if c != nil { + userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) + originator = strings.TrimSpace(c.GetHeader("originator")) + } + + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.Int("upstream_status_code", upstreamStatusCode), + zap.String("upstream_error_message", msg), + zap.String("request_user_agent", userAgent), + zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) + + logger.FromContext(ctx).With(fields...).Warn("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查") +} + +func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { + if upstreamStatusCode != http.StatusBadRequest { + return false + } + + hasInstructionRequired := func(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.Contains(lower, "instructions are required") { + return true + } + if strings.Contains(lower, "required parameter: 'instructions'") { + return true + } + if strings.Contains(lower, "required parameter: instructions") { + return true + } + if strings.Contains(lower, "missing required parameter") && strings.Contains(lower, "instructions") { + return true + } + return strings.Contains(lower, "instruction") && strings.Contains(lower, "required") + } + + if hasInstructionRequired(upstreamMsg) { + return true + } + if len(upstreamBody) == 0 { + return false + } + + errMsg := gjson.GetBytes(upstreamBody, "error.message").String() + errMsgLower := strings.ToLower(strings.TrimSpace(errMsg)) + errCode := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.code").String())) + errParam := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.param").String())) + errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.type").String())) + + if errParam == "instructions" { + return true + } + if hasInstructionRequired(errMsg) { + return true + } + if strings.Contains(errCode, "missing_required_parameter") && strings.Contains(errMsgLower, "instructions") { + return true + } + if strings.Contains(errType, "invalid_request") && strings.Contains(errMsgLower, "instructions") && strings.Contains(errMsgLower, "required") { + return true + } + + return false +} + +func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { + if upstreamStatusCode != http.StatusBadRequest { + return false + } + + match := func(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.Contains(lower, "an error occurred while processing your request") { + return true + } + return strings.Contains(lower, "you can retry your request") && + strings.Contains(lower, "help.openai.com") && + strings.Contains(lower, "request id") + } + + if match(upstreamMsg) { + return true + } + if len(upstreamBody) == 0 { + return false + } + if match(gjson.GetBytes(upstreamBody, "error.message").String()) { + return true + } + return match(string(upstreamBody)) +} + +// ExtractSessionID extracts the raw session ID from headers or body without hashing. +// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache. +func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + return sessionID +} + +// GenerateSessionHash generates a sticky-session hash for OpenAI requests. +// +// Priority: +// 1. Header: session_id +// 2. Header: conversation_id +// 3. Body: prompt_cache_key (opencode) +func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + if sessionID == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + +// GenerateSessionHashWithFallback 先按常规信号生成会话哈希; +// 当未携带 session_id/conversation_id/prompt_cache_key 时,使用 fallbackSeed 生成稳定哈希。 +// 该方法用于 WS ingress,避免会话信号缺失时发生跨账号漂移。 +func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, body []byte, fallbackSeed string) string { + sessionHash := s.GenerateSessionHash(c, body) + if sessionHash != "" { + return sessionHash + } + + seed := strings.TrimSpace(fallbackSeed) + if seed == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(seed) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + +func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string { + if c != nil { + if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" { + return originator + } + } + if isOfficialClient { + return "codex_cli_rs" + } + return "opencode" +} + +// BindStickySession sets session -> account binding with standard TTL. +func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { + if sessionHash == "" || accountID <= 0 { + return nil + } + ttl := openaiStickySessionTTL + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + ttl = time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return s.setStickySessionAccountID(ctx, groupID, sessionHash, accountID, ttl) +} + +// SelectAccount selects an OpenAI account with sticky session support +func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { + return s.SelectAccountForModel(ctx, groupID, sessionHash, "") +} + +// SelectAccountForModel selects an account supporting the requested model +func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { + return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) +} + +// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. +// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 +func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { + return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0) +} + +func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { + // 1. 尝试粘性会话命中 + // Try sticky session hit + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { + return account, nil + } + + // 2. 获取可调度的 OpenAI 账号 + // Get schedulable OpenAI accounts + accounts, err := s.listSchedulableAccounts(ctx, groupID) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + + // 3. 按优先级 + LRU 选择最佳账号 + // Select by priority + LRU + selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) + + if selected == nil { + if requestedModel != "" { + return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel) + } + return nil, errors.New("no available OpenAI accounts") + } + + // 4. 设置粘性会话绑定 + // Set sticky session binding + if sessionHash != "" { + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) + } + + return selected, nil +} + +// tryStickySessionHit 尝试从粘性会话获取账号。 +// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。 +// +// tryStickySessionHit attempts to get account from sticky session. +// Returns account if hit and usable; clears session and returns nil if account is unavailable. +func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account { + if sessionHash == "" { + return nil + } + + accountID := stickyAccountID + if accountID <= 0 { + var err error + accountID, err = s.getStickySessionAccountID(ctx, groupID, sessionHash) + if err != nil || accountID <= 0 { + return nil + } + } + + if _, excluded := excludedIDs[accountID]; excluded { + return nil + } + + account, err := s.getSchedulableAccount(ctx, accountID) + if err != nil { + return nil + } + + // 检查账号是否需要清理粘性会话 + // Check if sticky session should be cleared + if shouldClearStickySession(ctx, account, requestedModel) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + return nil + } + + // 验证账号是否可用于当前请求 + // Verify account is usable for current request + if !account.IsSchedulable() || !account.IsOpenAI() { + return nil + } + if requestedModel != "" && !account.IsModelSupported(requestedModel) { + return nil + } + + // 刷新会话 TTL 并返回账号 + // Refresh session TTL and return account + _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) + return account +} + +// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。 +// 返回 nil 表示无可用账号。 +// +// selectBestAccount selects the best account from candidates (priority + LRU). +// Returns nil if no available account. +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { + var selected *Account + + for i := range accounts { + acc := &accounts[i] + + // 跳过被排除的账号 + // Skip excluded accounts + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { + continue + } + + // 选择优先级最高且最久未使用的账号 + // Select highest priority and least recently used + if selected == nil { + selected = fresh + continue + } + + if s.isBetterAccount(fresh, selected) { + selected = fresh + } + } + + return selected +} + +// isBetterAccount 判断 candidate 是否比 current 更优。 +// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。 +// +// isBetterAccount checks if candidate is better than current. +// Rules: higher priority (lower value) wins; same priority: never used > least recently used. +func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool { + // 优先级更高(数值更小) + // Higher priority (lower value) + if candidate.Priority < current.Priority { + return true + } + if candidate.Priority > current.Priority { + return false + } + + // 同优先级,比较最后使用时间 + // Same priority, compare last used time + switch { + case candidate.LastUsedAt == nil && current.LastUsedAt != nil: + // candidate 从未使用,优先 + return true + case candidate.LastUsedAt != nil && current.LastUsedAt == nil: + // current 从未使用,保持 + return false + case candidate.LastUsedAt == nil && current.LastUsedAt == nil: + // 都未使用,保持 + return false + default: + // 都使用过,选择最久未使用的 + return candidate.LastUsedAt.Before(*current.LastUsedAt) + } +} + +// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. +func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + cfg := s.schedulingConfig() + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { + account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID) + if err != nil { + return nil, err + } + result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) + if err == nil && result.Acquired { + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + + accounts, err := s.listSchedulableAccounts(ctx, groupID) + if err != nil { + return nil, err + } + if len(accounts) == 0 { + return nil, ErrNoAvailableAccounts + } + + isExcluded := func(accountID int64) bool { + if excludedIDs == nil { + return false + } + _, excluded := excludedIDs[accountID] + return excluded + } + + // ============ Layer 1: Sticky session ============ + if sessionHash != "" { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { + account, err := s.getSchedulableAccount(ctx, accountID) + if err == nil { + clearSticky := shouldClearStickySession(ctx, account, requestedModel) + if clearSticky { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + } + if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && + (requestedModel == "" || account.IsModelSupported(requestedModel)) { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + } + } + } + + // ============ Layer 2: Load-aware selection ============ + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + acc := &accounts[i] + if isExcluded(acc.ID) { + continue + } + // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); + // re-check schedulability here so recently rate-limited/overloaded accounts + // are not selected again before the bucket is rebuilt. + if !acc.IsSchedulable() { + continue + } + if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + continue + } + candidates = append(candidates, acc) + } + + if len(candidates) == 0 { + return nil, ErrNoAvailableAccounts + } + + accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) + for _, acc := range candidates { + accountLoads = append(accountLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.EffectiveLoadFactor(), + }) + } + + loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) + if err != nil { + ordered := append([]*Account(nil), candidates...) + sortAccountsByPriorityAndLastUsed(ordered, false) + for _, acc := range ordered { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) + } + return &AccountSelectionResult{ + Account: fresh, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + } else { + var available []accountWithLoad + for _, acc := range candidates { + loadInfo := loadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + available = append(available, accountWithLoad{ + account: acc, + loadInfo: loadInfo, + }) + } + } + + if len(available) > 0 { + sort.SliceStable(available, func(i, j int) bool { + a, b := available[i], available[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + shuffleWithinSortGroups(available) + + for _, item := range available { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel) + if fresh == nil { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) + } + return &AccountSelectionResult{ + Account: fresh, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + } + } + + // ============ Layer 3: Fallback wait ============ + sortAccountsByPriorityAndLastUsed(candidates, false) + for _, acc := range candidates { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { + continue + } + return &AccountSelectionResult{ + Account: fresh, + WaitPlan: &AccountWaitPlan{ + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + + return nil, ErrNoAvailableAccounts +} + +func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { + if s.schedulerSnapshot != nil { + accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false) + return accounts, err + } + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI) + } + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + return accounts, nil +} + +func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { + if s.concurrencyService == nil { + return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil + } + return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) +} + +func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account { + if account == nil { + return nil + } + + fresh := account + if s.schedulerSnapshot != nil { + current, err := s.getSchedulableAccount(ctx, account.ID) + if err != nil || current == nil { + return nil + } + fresh = current + } + + if !fresh.IsSchedulable() || !fresh.IsOpenAI() { + return nil + } + if requestedModel != "" && !fresh.IsModelSupported(requestedModel) { + return nil + } + return fresh +} + +func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { + var ( + account *Account + err error + ) + if s.schedulerSnapshot != nil { + account, err = s.schedulerSnapshot.GetAccount(ctx, accountID) + } else { + account, err = s.accountRepo.GetByID(ctx, accountID) + } + if err != nil || account == nil { + return account, err + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now()) + return account, nil +} + +func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { + if s.cfg != nil { + return s.cfg.Gateway.Scheduling + } + return config.GatewaySchedulingConfig{ + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: 45 * time.Second, + FallbackWaitTimeout: 30 * time.Second, + FallbackMaxWaiting: 100, + LoadBatchEnabled: true, + SlotCleanupInterval: 30 * time.Second, + } +} + +// GetAccessToken gets the access token for an OpenAI account +func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { + switch account.Type { + case AccountTypeOAuth: + // 使用 TokenProvider 获取缓存的 token + if s.openAITokenProvider != nil { + accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", err + } + return accessToken, "oauth", nil + } + // 降级:TokenProvider 未配置时直接从账号读取 + accessToken := account.GetOpenAIAccessToken() + if accessToken == "" { + return "", "", errors.New("access_token not found in credentials") + } + return accessToken, "oauth", nil + case AccountTypeAPIKey: + apiKey := account.GetOpenAIApiKey() + if apiKey == "" { + return "", "", errors.New("api_key not found in credentials") + } + return apiKey, "apikey", nil + default: + return "", "", fmt.Errorf("unsupported account type: %s", account.Type) + } +} + +func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 402, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool { + if s.shouldFailoverUpstreamError(statusCode) { + return true + } + return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody) +} + +func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) +} + +// Forward forwards request to OpenAI API +func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { + startTime := time.Now() + + restrictionResult := s.detectCodexClientRestriction(c, account) + apiKeyID := getAPIKeyIDFromContext(c) + logCodexCLIOnlyDetection(ctx, c, account, apiKeyID, restrictionResult, body) + if restrictionResult.Enabled && !restrictionResult.Matched { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": "This account only allows Codex official clients", + }, + }) + return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed") + } + + originalBody := body + reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + originalModel := reqModel + + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + clientTransport := GetOpenAIClientTransport(c) + // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 + wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport) + if c != nil { + c.Set("openai_ws_transport_decision", string(wsDecision.Transport)) + c.Set("openai_ws_transport_reason", wsDecision.Reason) + } + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeDebug( + "selected account_id=%d account_type=%s transport=%s reason=%s model=%s stream=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + normalizeOpenAIWSLogValue(wsDecision.Reason), + reqModel, + reqStream, + ) + } + // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + if c != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.", + }, + }) + } + return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2") + } + passthroughEnabled := account.IsOpenAIPassthroughEnabled() + if passthroughEnabled { + // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 + reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel) + return s.forwardOpenAIPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) + } + + reqBody, err := getOpenAIRequestBodyMap(c, body) + if err != nil { + return nil, err + } + + if v, ok := reqBody["model"].(string); ok { + reqModel = v + originalModel = reqModel + } + if v, ok := reqBody["stream"].(bool); ok { + reqStream = v + } + if promptCacheKey == "" { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + promptCacheKey = strings.TrimSpace(v) + } + } + + // Track if body needs re-serialization + bodyModified := false + // 单字段补丁快速路径:只要整个变更集最终可归约为同一路径的 set/delete,就避免全量 Marshal。 + patchDisabled := false + patchHasOp := false + patchDelete := false + patchPath := "" + var patchValue any + markPatchSet := func(path string, value any) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = false + patchPath = path + patchValue = value + return + } + if patchDelete || patchPath != path { + patchDisabled = true + return + } + patchValue = value + } + markPatchDelete := func(path string) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = true + patchPath = path + return + } + if !patchDelete || patchPath != path { + patchDisabled = true + } + } + disablePatch := func() { + patchDisabled = true + } + + // 非透传模式下,instructions 为空时注入默认指令。 + if isInstructionsEmpty(reqBody) { + reqBody["instructions"] = "You are a helpful coding assistant." + bodyModified = true + markPatchSet("instructions", "You are a helpful coding assistant.") + } + + // 对所有请求执行模型映射(包含 Codex CLI)。 + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) + reqBody["model"] = mappedModel + bodyModified = true + markPatchSet("model", mappedModel) + } + + // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 + if model, ok := reqBody["model"].(string); ok { + normalizedModel := normalizeCodexModel(model) + if normalizedModel != "" && normalizedModel != model { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + model, normalizedModel, account.Name, account.Type, isCodexCLI) + reqBody["model"] = normalizedModel + mappedModel = normalizedModel + bodyModified = true + markPatchSet("model", normalizedModel) + } + + // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 + // 确保高版本模型向低版本模型映射不报错 + if !SupportsVerbosity(normalizedModel) { + if text, ok := reqBody["text"].(map[string]any); ok { + delete(text, "verbosity") + } + } + } + + // 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。 + if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { + if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { + reasoning["effort"] = "none" + bodyModified = true + markPatchSet("reasoning.effort", "none") + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) + } + } + + if account.Type == AccountTypeOAuth { + codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isOpenAIResponsesCompactPath(c)) + if codexResult.Modified { + bodyModified = true + disablePatch() + } + if codexResult.NormalizedModel != "" { + mappedModel = codexResult.NormalizedModel + } + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } + } + + // Handle max_output_tokens based on platform and account type + if !isCodexCLI { + if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens { + switch account.Platform { + case PlatformOpenAI: + // For OpenAI API Key, remove max_output_tokens (not supported) + // For OpenAI OAuth (Responses API), keep it (supported) + if account.Type == AccountTypeAPIKey { + delete(reqBody, "max_output_tokens") + bodyModified = true + markPatchDelete("max_output_tokens") + } + case PlatformAnthropic: + // For Anthropic (Claude), convert to max_tokens + delete(reqBody, "max_output_tokens") + markPatchDelete("max_output_tokens") + if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens { + reqBody["max_tokens"] = maxOutputTokens + disablePatch() + } + bodyModified = true + case PlatformGemini: + // For Gemini, remove (will be handled by Gemini-specific transform) + delete(reqBody, "max_output_tokens") + bodyModified = true + markPatchDelete("max_output_tokens") + default: + // For unknown platforms, remove to be safe + delete(reqBody, "max_output_tokens") + bodyModified = true + markPatchDelete("max_output_tokens") + } + } + + // Also handle max_completion_tokens (similar logic) + if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens { + if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI { + delete(reqBody, "max_completion_tokens") + bodyModified = true + markPatchDelete("max_completion_tokens") + } + } + + // Remove unsupported fields (not supported by upstream OpenAI API) + unsupportedFields := []string{"prompt_cache_retention", "safety_identifier"} + for _, unsupportedField := range unsupportedFields { + if _, has := reqBody[unsupportedField]; has { + delete(reqBody, unsupportedField) + bodyModified = true + markPatchDelete(unsupportedField) + } + } + } + + // 仅在 WSv2 模式保留 previous_response_id,其他模式(HTTP/WSv1)统一过滤。 + // 注意:该规则同样适用于 Codex CLI 请求,避免 WSv1 向上游透传不支持字段。 + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + if _, has := reqBody["previous_response_id"]; has { + delete(reqBody, "previous_response_id") + bodyModified = true + markPatchDelete("previous_response_id") + } + } + + // Re-serialize body only if modified + if bodyModified { + serializedByPatch := false + if !patchDisabled && patchHasOp { + var patchErr error + if patchDelete { + body, patchErr = sjson.DeleteBytes(body, patchPath) + } else { + body, patchErr = sjson.SetBytes(body, patchPath, patchValue) + } + if patchErr == nil { + serializedByPatch = true + } + } + if !serializedByPatch { + var marshalErr error + body, marshalErr = json.Marshal(reqBody) + if marshalErr != nil { + return nil, fmt.Errorf("serialize request body: %w", marshalErr) + } + } + } + + // Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + // Capture upstream request body for ops retry of this attempt. + setOpsUpstreamRequestBody(c, body) + + // 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + wsReqBody := reqBody + if len(reqBody) > 0 { + wsReqBody = make(map[string]any, len(reqBody)) + for k, v := range reqBody { + wsReqBody[k] = v + } + } + _, hasPreviousResponseID := wsReqBody["previous_response_id"] + logOpenAIWSModeDebug( + "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", + account.ID, + account.Type, + mappedModel, + reqStream, + hasPreviousResponseID, + ) + maxAttempts := openAIWSReconnectRetryLimit + 1 + wsAttempts := 0 + var wsResult *OpenAIForwardResult + var wsErr error + wsLastFailureReason := "" + wsPrevResponseRecoveryTried := false + wsInvalidEncryptedContentRecoveryTried := false + recoverPrevResponseNotFound := func(attempt int) bool { + if wsPrevResponseRecoveryTried { + return false + } + previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") + if previousResponseID == "" { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=missing_previous_response_id previous_response_id_present=false", + account.ID, + attempt, + ) + return false + } + if HasFunctionCallOutput(wsReqBody) { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=has_function_call_output previous_response_id_present=true", + account.ID, + attempt, + ) + return false + } + delete(wsReqBody, "previous_response_id") + wsPrevResponseRecoveryTried = true + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery account_id=%d attempt=%d action=drop_previous_response_id retry=1 previous_response_id=%s previous_response_id_kind=%s", + account.ID, + attempt, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), + ) + return true + } + recoverInvalidEncryptedContent := func(attempt int) bool { + if wsInvalidEncryptedContentRecoveryTried { + return false + } + removedReasoningItems := trimOpenAIEncryptedReasoningItems(wsReqBody) + if !removedReasoningItems { + logOpenAIWSModeInfo( + "reconnect_invalid_encrypted_content_recovery_skip account_id=%d attempt=%d reason=missing_encrypted_reasoning_items", + account.ID, + attempt, + ) + return false + } + previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") + hasFunctionCallOutput := HasFunctionCallOutput(wsReqBody) + if previousResponseID != "" && !hasFunctionCallOutput { + delete(wsReqBody, "previous_response_id") + } + wsInvalidEncryptedContentRecoveryTried = true + logOpenAIWSModeInfo( + "reconnect_invalid_encrypted_content_recovery account_id=%d attempt=%d action=drop_encrypted_reasoning_items retry=1 previous_response_id_present=%v previous_response_id=%s previous_response_id_kind=%s has_function_call_output=%v dropped_previous_response_id=%v", + account.ID, + attempt, + previousResponseID != "", + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), + hasFunctionCallOutput, + previousResponseID != "" && !hasFunctionCallOutput, + ) + return true + } + retryBudget := s.openAIWSRetryTotalBudget() + retryStartedAt := time.Now() + wsRetryLoop: + for attempt := 1; attempt <= maxAttempts; attempt++ { + wsAttempts = attempt + wsResult, wsErr = s.forwardOpenAIWSV2( + ctx, + c, + account, + wsReqBody, + token, + wsDecision, + isCodexCLI, + reqStream, + originalModel, + mappedModel, + startTime, + attempt, + wsLastFailureReason, + ) + if wsErr == nil { + break + } + if c != nil && c.Writer != nil && c.Writer.Written() { + break + } + + reason, retryable := classifyOpenAIWSReconnectReason(wsErr) + if reason != "" { + wsLastFailureReason = reason + } + // previous_response_not_found 说明续链锚点不可用: + // 对非 function_call_output 场景,允许一次“去掉 previous_response_id 后重放”。 + if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) { + continue + } + if reason == "invalid_encrypted_content" && recoverInvalidEncryptedContent(attempt) { + continue + } + if retryable && attempt < maxAttempts { + backoff := s.openAIWSRetryBackoff(attempt) + if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_budget_exhausted account_id=%d attempts=%d max_retries=%d reason=%s elapsed_ms=%d budget_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + time.Since(retryStartedAt).Milliseconds(), + retryBudget.Milliseconds(), + ) + break + } + s.recordOpenAIWSRetryAttempt(backoff) + logOpenAIWSModeInfo( + "reconnect_retry account_id=%d retry=%d max_retries=%d reason=%s backoff_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + backoff.Milliseconds(), + ) + if backoff > 0 { + timer := time.NewTimer(backoff) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + wsErr = wrapOpenAIWSFallback("retry_backoff_canceled", ctx.Err()) + break wsRetryLoop + case <-timer.C: + } + } + continue + } + if retryable { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_exhausted account_id=%d attempts=%d max_retries=%d reason=%s", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + ) + } else if reason != "" { + s.recordOpenAIWSNonRetryableFastFallback() + logOpenAIWSModeInfo( + "reconnect_stop account_id=%d attempt=%d reason=%s", + account.ID, + attempt, + normalizeOpenAIWSLogValue(reason), + ) + } + break + } + if wsErr == nil { + firstTokenMs := int64(0) + hasFirstTokenMs := wsResult != nil && wsResult.FirstTokenMs != nil + if hasFirstTokenMs { + firstTokenMs = int64(*wsResult.FirstTokenMs) + } + requestID := "" + if wsResult != nil { + requestID = strings.TrimSpace(wsResult.RequestID) + } + logOpenAIWSModeDebug( + "forward_succeeded account_id=%d request_id=%s stream=%v has_first_token_ms=%v first_token_ms=%d ws_attempts=%d", + account.ID, + requestID, + reqStream, + hasFirstTokenMs, + firstTokenMs, + wsAttempts, + ) + wsResult.UpstreamModel = mappedModel + return wsResult, nil + } + s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) + return nil, wsErr + } + + httpInvalidEncryptedContentRetryTried := false + for { + // Build upstream request + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + releaseUpstreamCtx() + if err != nil { + return nil, err + } + + // Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // Send request + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + // Handle error response + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamCode := extractUpstreamErrorCode(respBody) + if !httpInvalidEncryptedContentRetryTried && resp.StatusCode == http.StatusBadRequest && upstreamCode == "invalid_encrypted_content" { + if trimOpenAIEncryptedReasoningItems(reqBody) { + body, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err) + } + setOpsUpstreamRequestBody(c, body) + httpInvalidEncryptedContentRetryTried = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name) + continue + } + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Skip non-WSv2 invalid_encrypted_content retry because encrypted reasoning items are missing (account: %s)", account.Name) + } + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + return s.handleErrorResponse(ctx, resp, c, account, body) + } + defer func() { _ = resp.Body.Close() }() + + // Handle normal response + var usage *OpenAIUsage + var firstTokenMs *int + if reqStream { + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel) + if err != nil { + return nil, err + } + } + + // Extract and save Codex usage snapshot from response headers (for OAuth accounts) + if account.Type == AccountTypeOAuth { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + } + + if usage == nil { + usage = &OpenAIUsage{} + } + + reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) + serviceTier := extractOpenAIServiceTier(reqBody) + + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + ServiceTier: serviceTier, + ReasoningEffort: reasoningEffort, + Stream: reqStream, + OpenAIWSMode: false, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil + } +} + +func (s *OpenAIGatewayService) forwardOpenAIPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + reasoningEffort *string, + reqStream bool, + startTime time.Time, +) (*OpenAIForwardResult, error) { + if account != nil && account.Type == AccountTypeOAuth { + if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" { + rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field" + setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusForbidden, + Passthrough: true, + Kind: "request_error", + Message: rejectMsg, + Detail: rejectReason, + }) + logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body) + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": rejectMsg, + }, + }) + return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason) + } + + normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body, isOpenAIResponsesCompactPath(c)) + if err != nil { + return nil, err + } + if normalized { + body = normalizedBody + } + reqStream = gjson.GetBytes(body, "stream").Bool() + } + + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", + account.ID, + account.Name, + account.Type, + reqModel, + reqStream, + ) + if reqStream && c != nil && c.Request != nil { + if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 { + streamWarnLogger := logger.FromContext(ctx).With( + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", account.ID), + zap.Strings("timeout_headers", timeoutHeaders), + ) + if s.isOpenAIPassthroughTimeoutHeadersAllowed() { + streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流") + } else { + streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险") + } + } + } + + // Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() + if err != nil { + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + setOpsUpstreamRequestBody(c, body) + if c != nil { + c.Set("openai_passthrough", true) + } + + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + // 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。 + return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body) + } + + var usage *OpenAIUsage + var firstTokenMs *int + if reqStream { + result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime) + if err != nil { + return nil, err + } + usage = result.usage + firstTokenMs = result.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c) + if err != nil { + return nil, err + } + } + + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + + if usage == nil { + usage = &OpenAIUsage{} + } + + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: reqModel, + ServiceTier: extractOpenAIServiceTierFromBody(body), + ReasoningEffort: reasoningEffort, + Stream: reqStream, + OpenAIWSMode: false, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func logOpenAIPassthroughInstructionsRejected( + ctx context.Context, + c *gin.Context, + account *Account, + reqModel string, + rejectReason string, + body []byte, +) { + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + accountName := "" + accountType := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + accountType = strings.TrimSpace(string(account.Type)) + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.String("account_type", accountType), + zap.String("request_model", strings.TrimSpace(reqModel)), + zap.String("reject_reason", strings.TrimSpace(rejectReason)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) + logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions") +} + +func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := openaiPlatformAPIURL + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + } + targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // 透传客户端请求头(安全白名单)。 + allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lower := strings.ToLower(strings.TrimSpace(key)) + if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // 覆盖入站鉴权残留,并注入上游认证 + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Set("authorization", "Bearer "+token) + + // OAuth 透传到 ChatGPT internal API 时补齐必要头。 + if account.Type == AccountTypeOAuth { + promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + req.Host = "chatgpt.com" + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + apiKeyID := getAPIKeyIDFromContext(c) + // 先保存客户端原始值,再做 compact 补充,避免后续统一隔离时读到已处理的值。 + clientSessionID := strings.TrimSpace(req.Header.Get("session_id")) + clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id")) + if isOpenAIResponsesCompactPath(c) { + req.Header.Set("accept", "application/json") + if req.Header.Get("version") == "" { + req.Header.Set("version", codexCLIVersion) + } + if clientSessionID == "" { + clientSessionID = resolveOpenAICompactSessionID(c) + } + } else if req.Header.Get("accept") == "" { + req.Header.Set("accept", "text/event-stream") + } + if req.Header.Get("OpenAI-Beta") == "" { + req.Header.Set("OpenAI-Beta", "responses=experimental") + } + if req.Header.Get("originator") == "" { + req.Header.Set("originator", "codex_cli_rs") + } + // 用隔离后的 session 标识符覆盖客户端透传值,防止跨用户会话碰撞。 + if clientSessionID == "" { + clientSessionID = promptCacheKey + } + if clientConversationID == "" { + clientConversationID = promptCacheKey + } + if clientSessionID != "" { + req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, clientSessionID)) + } + if clientConversationID != "" { + req.Header.Set("conversation_id", isolateOpenAISessionID(apiKeyID, clientConversationID)) + } + } + + // 透传模式也支持账户自定义 User-Agent 与 ForceCodexCLI 兜底。 + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + req.Header.Set("user-agent", customUA) + } + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", codexCLIUserAgent) + } + // OAuth 安全透传:对非 Codex UA 统一兜底,降低被上游风控拦截概率。 + if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) { + req.Header.Set("user-agent", codexCLIUserAgent) + } + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + + return req, nil +} + +func (s *OpenAIGatewayService) handleErrorResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) error { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + UpstreamResponseBody: upstreamDetail, + }) + + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) +} + +func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool { + if lowerKey == "" { + return false + } + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + return allowTimeoutHeaders + } + return openaiPassthroughAllowedHeaders[lowerKey] +} + +func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool { + switch lowerKey { + case "x-stainless-timeout", "x-stainless-read-timeout", "x-stainless-connect-timeout", "x-request-timeout", "request-timeout", "grpc-timeout": + return true + default: + return false + } +} + +func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders +} + +func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string { + if h == nil { + return nil + } + var matched []string + for key, values := range h { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + entry := lowerKey + if len(values) > 0 { + entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|")) + } + matched = append(matched, entry) + } + } + sort.Strings(matched) + return matched +} + +type openaiStreamingResultPassthrough struct { + usage *OpenAIUsage + firstTokenMs *int +} + +func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, +) (*openaiStreamingResultPassthrough, error) { + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + // SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &OpenAIUsage{} + var firstTokenMs *int + clientDisconnected := false + sawDone := false + sawTerminalEvent := false + upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + defer putSSEScannerBuf64K(scanBuf) + + for scanner.Scan() { + line := scanner.Text() + if data, ok := extractOpenAISSEDataLine(line); ok { + dataBytes := []byte(data) + trimmedData := strings.TrimSpace(data) + if trimmedData == "[DONE]" { + sawDone = true + } + if openAIStreamEventIsTerminal(trimmedData) { + sawTerminalEvent = true + } + if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsageBytes(dataBytes, usage) + } + + if !clientDisconnected { + if _, err := fmt.Fprintln(w, line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else { + flusher.Flush() + } + } + } + if err := scanner.Err(); err != nil { + if sawTerminalEvent { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if clientDisconnected { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) + } + if errors.Is(err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err + } + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", + account.ID, + upstreamRequestID, + err, + ) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) + } + if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { + logger.FromContext(ctx).With( + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", account.ID), + zap.String("upstream_request_id", upstreamRequestID), + ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event") + } + + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil +} + +func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, +) (*OpenAIUsage, error) { + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + usage := &OpenAIUsage{} + usageParsed := false + if len(body) > 0 { + if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(body); ok { + *usage = parsedUsage + usageParsed = true + } + } + if !usageParsed { + // 兜底:尝试从 SSE 文本中解析 usage + usage = s.parseSSEUsageFromBody(string(body)) + } + + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + return usage, nil +} + +func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { + if dst == nil || src == nil { + return + } + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) + } else { + // 兜底:尽量保留最基础的 content-type + if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { + dst.Set("Content-Type", v) + } + } + // 透传模式强制放行 x-codex-* 响应头(若上游返回)。 + // 注意:真实 http.Response.Header 的 key 一般会被 canonicalize;但为了兼容测试/自建响应, + // 这里用 EqualFold 做一次大小写不敏感的查找。 + getCaseInsensitiveValues := func(h http.Header, want string) []string { + if h == nil { + return nil + } + for k, vals := range h { + if strings.EqualFold(k, want) { + return vals + } + } + return nil + } + + for _, rawKey := range []string{ + "x-codex-primary-used-percent", + "x-codex-primary-reset-after-seconds", + "x-codex-primary-window-minutes", + "x-codex-secondary-used-percent", + "x-codex-secondary-reset-after-seconds", + "x-codex-secondary-window-minutes", + "x-codex-primary-over-secondary-limit-percent", + } { + vals := getCaseInsensitiveValues(src, rawKey) + if len(vals) == 0 { + continue + } + key := http.CanonicalHeaderKey(rawKey) + dst.Del(key) + for _, v := range vals { + dst.Add(key, v) + } + } +} + +func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) { + // Determine target URL based on account type + var targetURL string + switch account.Type { + case AccountTypeOAuth: + // OAuth accounts use ChatGPT internal API + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + // API Key accounts use Platform API or custom base URL + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + targetURL = openaiPlatformAPIURL + } else { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + default: + targetURL = openaiPlatformAPIURL + } + targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c)) + + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // Set authentication header + req.Header.Set("authorization", "Bearer "+token) + + // Set headers specific to OAuth accounts (ChatGPT internal API) + if account.Type == AccountTypeOAuth { + // Required: set Host for ChatGPT API (must use req.Host, not Header.Set) + req.Host = "chatgpt.com" + // Required: set chatgpt-account-id header + chatgptAccountID := account.GetChatGPTAccountID() + if chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + } + + // Whitelist passthrough headers + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(key) + if openaiAllowedHeaders[lowerKey] { + for _, v := range values { + req.Header.Add(key, v) + } + } + } + if account.Type == AccountTypeOAuth { + // 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。 + req.Header.Del("conversation_id") + req.Header.Del("session_id") + + req.Header.Set("OpenAI-Beta", "responses=experimental") + req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) + apiKeyID := getAPIKeyIDFromContext(c) + if isOpenAIResponsesCompactPath(c) { + req.Header.Set("accept", "application/json") + if req.Header.Get("version") == "" { + req.Header.Set("version", codexCLIVersion) + } + compactSession := resolveOpenAICompactSessionID(c) + req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, compactSession)) + } else { + req.Header.Set("accept", "text/event-stream") + } + if promptCacheKey != "" { + isolated := isolateOpenAISessionID(apiKeyID, promptCacheKey) + req.Header.Set("conversation_id", isolated) + req.Header.Set("session_id", isolated) + } + } + + // Apply custom User-Agent if configured + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + req.Header.Set("user-agent", customUA) + } + + // 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。 + // 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。 + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", codexCLIUserAgent) + } + + // Ensure required headers exist + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + + return req, nil +} + +func (s *OpenAIGatewayService) handleErrorResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) (*OpenAIForwardResult, error) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) + + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.openai_gateway", + "OpenAI upstream error %d (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformOpenAI, + resp.StatusCode, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + // Check custom error codes + if !account.ShouldHandleErrorCode(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream gateway error", + }, + }) + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg) + } + + // Handle upstream error (mark account status) + shouldDisable := false + if s.rateLimitService != nil { + shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } + kind := "http_error" + if shouldDisable { + kind = "failover" + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: kind, + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if shouldDisable { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + // Return appropriate error response + var errType, errMsg string + var statusCode int + + switch resp.StatusCode { + case 401: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream authentication failed, please contact administrator" + case 402: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream payment required: insufficient balance or billing issue" + case 403: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream access forbidden, please contact administrator" + case 429: + statusCode = http.StatusTooManyRequests + errType = "rate_limit_error" + errMsg = "Upstream rate limit exceeded, please retry later" + default: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream request failed" + } + + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) +} + +// compatErrorWriter is the signature for format-specific error writers used by +// the compat paths (Chat Completions and Anthropic Messages). +type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string) + +// handleCompatErrorResponse is the shared non-failover error handler for the +// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of +// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit +// tracking, secondary failover) but delegates the final error write to the +// format-specific writer function. +func (s *OpenAIGatewayService) handleCompatErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, + writeError compatErrorWriter, +) (*OpenAIForwardResult, error) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if upstreamMsg == "" { + upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode) + } + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + // Apply error passthrough rules + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, account.Platform, resp.StatusCode, body, + http.StatusBadGateway, "api_error", "Upstream request failed", + ); matched { + writeError(c, status, errType, errMsg) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + // Check custom error codes — if the account does not handle this status, + // return a generic error without exposing upstream details. + if !account.ShouldHandleErrorCode(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error") + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg) + } + + // Track rate limits and decide whether to trigger secondary failover. + shouldDisable := false + if s.rateLimitService != nil { + shouldDisable = s.rateLimitService.HandleUpstreamError( + c.Request.Context(), account, resp.StatusCode, resp.Header, body, + ) + } + kind := "http_error" + if shouldDisable { + kind = "failover" + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: kind, + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if shouldDisable { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + // Map status code to error type and write response + errType := "api_error" + switch { + case resp.StatusCode == 400: + errType = "invalid_request_error" + case resp.StatusCode == 404: + errType = "not_found_error" + case resp.StatusCode == 429: + errType = "rate_limit_error" + case resp.StatusCode >= 500: + errType = "api_error" + } + + writeError(c, resp.StatusCode, errType, upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) +} + +// openaiStreamingResult streaming response result +type openaiStreamingResult struct { + usage *OpenAIUsage + firstTokenMs *int +} + +func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + + // Set SSE response headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + // Pass through other headers + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + bufferedWriter := bufio.NewWriterSize(w, 4*1024) + flushBuffered := func() error { + if err := bufferedWriter.Flush(); err != nil { + return err + } + flusher.Flush() + return nil + } + + usage := &OpenAIUsage{} + var firstTokenMs *int + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + // 仅监控上游数据间隔超时,不被下游写入阻塞影响 + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + // 下游 keepalive 仅用于防止代理空闲断开 + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 + lastDataAt := time.Now() + + // 仅发送一次错误事件,避免多次写入导致协议混乱。 + // 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema; + // 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。 + errorEventSent := false + clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage + sawTerminalEvent := false + sendErrorEvent := func(reason string) { + if errorEventSent || clientDisconnected { + return + } + errorEventSent = true + payload := `{"type":"error","sequence_number":0,"error":{"type":"upstream_error","message":` + strconv.Quote(reason) + `,"code":` + strconv.Quote(reason) + `}}` + if err := flushBuffered(); err != nil { + clientDisconnected = true + return + } + if _, err := bufferedWriter.WriteString("data: " + payload + "\n\n"); err != nil { + clientDisconnected = true + return + } + if err := flushBuffered(); err != nil { + clientDisconnected = true + } + } + + needModelReplace := originalModel != mappedModel + resultWithUsage := func() *openaiStreamingResult { + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} + } + finalizeStream := func() (*openaiStreamingResult, error) { + if !clientDisconnected { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") + } + } + if !sawTerminalEvent { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } + return resultWithUsage(), nil + } + handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { + if scanErr == nil { + return nil, nil, false + } + if sawTerminalEvent { + logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr) + return resultWithUsage(), nil, true + } + // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 + // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 + if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true + } + // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true + } + if errors.Is(scanErr, bufio.ErrTooLong) { + logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) + sendErrorEvent("response_too_large") + return resultWithUsage(), scanErr, true + } + sendErrorEvent("stream_read_error") + return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true + } + processSSELine := func(line string, queueDrained bool) { + lastDataAt = time.Now() + + // Extract data from SSE line (supports both "data: " and "data:" formats) + if data, ok := extractOpenAISSEDataLine(line); ok { + + // Replace model in response if needed. + // Fast path: most events do not contain model field values. + if needModelReplace && mappedModel != "" && strings.Contains(data, mappedModel) { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + dataBytes := []byte(data) + if openAIStreamEventIsTerminal(data) { + sawTerminalEvent = true + } + + // Correct Codex tool calls if needed (apply_patch -> edit, etc.) + if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { + dataBytes = correctedData + data = string(correctedData) + line = "data: " + data + } + + // 写入客户端(客户端断开后继续 drain 上游) + if !clientDisconnected { + shouldFlush := queueDrained + if firstTokenMs == nil && data != "" && data != "[DONE]" { + // 保证首个 token 事件尽快出站,避免影响 TTFT。 + shouldFlush = true + } + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if shouldFlush { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } + } + } + + // Record first token time + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsageBytes(dataBytes, usage) + return + } + + // Forward non-data lines as-is + if !clientDisconnected { + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if queueDrained { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } + } + } + } + + // 无超时/无 keepalive 的常见路径走同步扫描,减少 goroutine 与 channel 开销。 + if streamInterval <= 0 && keepaliveInterval <= 0 { + defer putSSEScannerBuf64K(scanBuf) + for scanner.Scan() { + processSSELine(scanner.Text(), true) + } + if result, err, done := handleScanErr(scanner.Err()); done { + return result, err + } + return finalizeStream() + } + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) + + for { + select { + case ev, ok := <-events: + if !ok { + return finalizeStream() + } + if result, err, done := handleScanErr(ev.err); done { + return result, err + } + processSSELine(ev.line, len(events) == 0) + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") + } + logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + // 处理流超时,可能标记账户为临时不可调度或错误状态 + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) + } + sendErrorEvent("stream_timeout") + return resultWithUsage(), fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + if _, err := bufferedWriter.WriteString(":\n\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + continue + } + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing") + } + } + } + +} + +// extractOpenAISSEDataLine 低开销提取 SSE `data:` 行内容。 +// 兼容 `data: xxx` 与 `data:xxx` 两种格式。 +func extractOpenAISSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, "data:") { + return "", false + } + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != ' ' { + break + } + start++ + } + return line[start:], true +} + +func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { + data, ok := extractOpenAISSEDataLine(line) + if !ok { + return line + } + if data == "" || data == "[DONE]" { + return line + } + + // 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化 + if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "model", toModel) + if err != nil { + return line + } + return "data: " + newData + } + + // 检查嵌套的 response.model 字段 + if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "response.model", toModel) + if err != nil { + return line + } + return "data: " + newData + } + + return line +} + +// correctToolCallsInResponseBody 修正响应体中的工具调用 +func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte { + if len(body) == 0 { + return body + } + + corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(body) + if changed { + return corrected + } + return body +} + +func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { + s.parseSSEUsageBytes([]byte(data), usage) +} + +func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) { + if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { + return + } + // 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。 + if len(data) < 72 { + return + } + eventType := gjson.GetBytes(data, "type").String() + if eventType != "response.completed" && eventType != "response.done" { + return + } + + usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) + usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) + usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) +} + +func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + if len(body) == 0 || !gjson.ValidBytes(body) { + return OpenAIUsage{}, false + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + return OpenAIUsage{ + InputTokens: int(values[0].Int()), + OutputTokens: int(values[1].Int()), + CacheReadInputTokens: int(values[2].Int()), + }, true +} + +func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + if account.Type == AccountTypeOAuth { + bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:")) + if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE { + return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel) + } + } + + usageValue, usageOK := extractOpenAIUsageFromJSONBytes(body) + if !usageOK { + return nil, fmt.Errorf("parse response: invalid json response") + } + usage := &usageValue + + // Replace model in response if needed + if originalModel != mappedModel { + body = s.replaceModelInResponseBody(body, mappedModel, originalModel) + } + + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := "application/json" + if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { + if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { + contentType = upstreamType + } + } + + c.Data(resp.StatusCode, contentType, body) + + return usage, nil +} + +func isEventStreamResponse(header http.Header) bool { + contentType := strings.ToLower(header.Get("Content-Type")) + return strings.Contains(contentType, "text/event-stream") +} + +func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) { + bodyText := string(body) + finalResponse, ok := extractCodexFinalResponse(bodyText) + + usage := &OpenAIUsage{} + if ok { + if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed { + *usage = parsedUsage + } + body = finalResponse + if originalModel != mappedModel { + body = s.replaceModelInResponseBody(body, mappedModel, originalModel) + } + // Correct tool calls in final response + body = s.correctToolCallsInResponseBody(body) + } else { + terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText) + if terminalOK && terminalType == "response.failed" { + msg := extractOpenAISSEErrorMessage(terminalPayload) + if msg == "" { + msg = "Upstream compact response failed" + } + return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg) + } + usage = s.parseSSEUsageFromBody(bodyText) + if originalModel != mappedModel { + bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) + } + body = []byte(bodyText) + } + + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := "application/json; charset=utf-8" + if !ok { + contentType = resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "text/event-stream" + } + } + c.Data(resp.StatusCode, contentType, body) + + return usage, nil +} + +func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) { + lines := strings.Split(body, "\n") + for _, line := range lines { + data, ok := extractOpenAISSEDataLine(line) + if !ok || data == "" || data == "[DONE]" { + continue + } + eventType := strings.TrimSpace(gjson.Get(data, "type").String()) + switch eventType { + case "response.completed", "response.done", "response.failed": + return eventType, []byte(data), true + } + } + return "", nil, false +} + +func extractOpenAISSEErrorMessage(payload []byte) string { + if len(payload) == 0 { + return "" + } + for _, path := range []string{"response.error.message", "error.message", "message"} { + if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" { + return sanitizeUpstreamErrorMessage(msg) + } + } + return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload))) +} + +func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error { + message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message)) + if message == "" { + message = "Upstream returned an invalid non-streaming response" + } + setOpsUpstreamError(c, http.StatusBadGateway, message, "") + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": message, + }, + }) + return fmt.Errorf("non-streaming openai protocol error: %s", message) +} + +func extractCodexFinalResponse(body string) ([]byte, bool) { + lines := strings.Split(body, "\n") + for _, line := range lines { + data, ok := extractOpenAISSEDataLine(line) + if !ok { + continue + } + if data == "" || data == "[DONE]" { + continue + } + eventType := gjson.Get(data, "type").String() + if eventType == "response.done" || eventType == "response.completed" { + if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { + return []byte(response.Raw), true + } + } + } + return nil, false +} + +func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { + usage := &OpenAIUsage{} + lines := strings.Split(body, "\n") + for _, line := range lines { + data, ok := extractOpenAISSEDataLine(line) + if !ok { + continue + } + if data == "" || data == "[DONE]" { + continue + } + s.parseSSEUsageBytes([]byte(data), usage) + } + return usage +} + +func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { + lines := strings.Split(body, "\n") + for i, line := range lines { + if _, ok := extractOpenAISSEDataLine(line); !ok { + continue + } + lines[i] = s.replaceModelInSSELine(line, fromModel, toModel) + } + return strings.Join(lines, "\n") +} + +func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) { + if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { + normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil + } + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ + AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, + RequireAllowlist: true, + AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, + }) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil +} + +// buildOpenAIResponsesURL 组装 OpenAI Responses 端点。 +// - base 以 /v1 结尾:追加 /responses +// - base 已是 /responses:原样返回 +// - 其他情况:追加 /v1/responses +func buildOpenAIResponsesURL(base string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + if strings.HasSuffix(normalized, "/responses") { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + "/responses" + } + return normalized + "/v1/responses" +} + +func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool { + if len(reqBody) == 0 { + return false + } + + inputValue, has := reqBody["input"] + if !has { + return false + } + + switch input := inputValue.(type) { + case []any: + filtered := input[:0] + changed := false + for _, item := range input { + nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item) + if itemChanged { + changed = true + } + if !keep { + continue + } + filtered = append(filtered, nextItem) + } + if !changed { + return false + } + if len(filtered) == 0 { + delete(reqBody, "input") + return true + } + reqBody["input"] = filtered + return true + case []map[string]any: + filtered := input[:0] + changed := false + for _, item := range input { + nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item) + if itemChanged { + changed = true + } + if !keep { + continue + } + nextMap, ok := nextItem.(map[string]any) + if !ok { + filtered = append(filtered, item) + continue + } + filtered = append(filtered, nextMap) + } + if !changed { + return false + } + if len(filtered) == 0 { + delete(reqBody, "input") + return true + } + reqBody["input"] = filtered + return true + case map[string]any: + nextItem, changed, keep := sanitizeEncryptedReasoningInputItem(input) + if !changed { + return false + } + if !keep { + delete(reqBody, "input") + return true + } + nextMap, ok := nextItem.(map[string]any) + if !ok { + return false + } + reqBody["input"] = nextMap + return true + default: + return false + } +} + +func sanitizeEncryptedReasoningInputItem(item any) (next any, changed bool, keep bool) { + inputItem, ok := item.(map[string]any) + if !ok { + return item, false, true + } + + itemType, _ := inputItem["type"].(string) + if strings.TrimSpace(itemType) != "reasoning" { + return item, false, true + } + + _, hasEncryptedContent := inputItem["encrypted_content"] + if !hasEncryptedContent { + return item, false, true + } + + delete(inputItem, "encrypted_content") + if len(inputItem) == 1 { + return nil, true, false + } + return inputItem, true, true +} + +func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool { + return isOpenAIResponsesCompactPath(c) +} + +func OpenAICompactSessionSeedKeyForTest() string { + return openAICompactSessionSeedKey +} + +func NormalizeOpenAICompactRequestBodyForTest(body []byte) ([]byte, bool, error) { + return normalizeOpenAICompactRequestBody(body) +} + +func isOpenAIResponsesCompactPath(c *gin.Context) bool { + suffix := strings.TrimSpace(openAIResponsesRequestPathSuffix(c)) + return suffix == "/compact" || strings.HasPrefix(suffix, "/compact/") +} + +func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) { + if len(body) == 0 { + return body, false, nil + } + + normalized := []byte(`{}`) + for _, field := range []string{"model", "input", "instructions", "previous_response_id"} { + value := gjson.GetBytes(body, field) + if !value.Exists() { + continue + } + next, err := sjson.SetRawBytes(normalized, field, []byte(value.Raw)) + if err != nil { + return body, false, fmt.Errorf("normalize compact body %s: %w", field, err) + } + normalized = next + } + + if bytes.Equal(bytes.TrimSpace(body), bytes.TrimSpace(normalized)) { + return body, false, nil + } + return normalized, true, nil +} + +func resolveOpenAICompactSessionID(c *gin.Context) string { + if c != nil { + if sessionID := strings.TrimSpace(c.GetHeader("session_id")); sessionID != "" { + return sessionID + } + if conversationID := strings.TrimSpace(c.GetHeader("conversation_id")); conversationID != "" { + return conversationID + } + if seed, ok := c.Get(openAICompactSessionSeedKey); ok { + if seedStr, ok := seed.(string); ok && strings.TrimSpace(seedStr) != "" { + return strings.TrimSpace(seedStr) + } + } + } + return uuid.NewString() +} + +func openAIResponsesRequestPathSuffix(c *gin.Context) string { + if c == nil || c.Request == nil || c.Request.URL == nil { + return "" + } + normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/") + if normalizedPath == "" { + return "" + } + idx := strings.LastIndex(normalizedPath, "/responses") + if idx < 0 { + return "" + } + suffix := normalizedPath[idx+len("/responses"):] + if suffix == "" || suffix == "/" { + return "" + } + if !strings.HasPrefix(suffix, "/") { + return "" + } + return suffix +} + +func appendOpenAIResponsesRequestPathSuffix(baseURL, suffix string) string { + trimmedBase := strings.TrimRight(strings.TrimSpace(baseURL), "/") + trimmedSuffix := strings.TrimSpace(suffix) + if trimmedBase == "" || trimmedSuffix == "" { + return trimmedBase + } + return trimmedBase + trimmedSuffix +} + +func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { + // 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化 + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody + } + return body +} + +// OpenAIRecordUsageInput input for recording usage +type OpenAIRecordUsageInput struct { + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + InboundEndpoint string + UpstreamEndpoint string + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string + APIKeyService APIKeyQuotaUpdater +} + +// RecordUsage records usage and deducts balance +func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { + result := input.Result + + // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 + if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && + result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 { + return nil + } + + apiKey := input.APIKey + user := input.User + account := input.Account + subscription := input.Subscription + + // 计算实际的新输入token(减去缓存读取的token) + // 因为 input_tokens 包含了 cache_read_tokens,而缓存读取的token不应按输入价格计费 + actualInputTokens := result.Usage.InputTokens - result.Usage.CacheReadInputTokens + if actualInputTokens < 0 { + actualInputTokens = 0 + } + + // Calculate cost + tokens := UsageTokens{ + InputTokens: actualInputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + } + + // Get rate multiplier + multiplier := s.cfg.Default.RateMultiplier + if apiKey.GroupID != nil && apiKey.Group != nil { + resolver := s.userGroupRateResolver + if resolver == nil { + resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway") + } + multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) + } + + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + if result.BillingModel != "" { + billingModel = strings.TrimSpace(result.BillingModel) + } + serviceTier := "" + if result.ServiceTier != nil { + serviceTier = strings.TrimSpace(*result.ServiceTier) + } + cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) + if err != nil { + cost = &CostBreakdown{ActualCost: 0} + } + + // Determine billing type + isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() + billingType := BillingTypeBalance + if isSubscriptionBilling { + billingType = BillingTypeSubscription + } + + // Create usage log + durationMs := int(result.Duration.Milliseconds()) + accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) + usageLog := &UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: result.Model, + RequestedModel: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), + ServiceTier: result.ServiceTier, + ReasoningEffort: result.ReasoningEffort, + InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), + UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), + InputTokens: actualInputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + InputCost: cost.InputCost, + OutputCost: cost.OutputCost, + CacheCreationCost: cost.CacheCreationCost, + CacheReadCost: cost.CacheReadCost, + TotalCost: cost.TotalCost, + ActualCost: cost.ActualCost, + RateMultiplier: multiplier, + AccountRateMultiplier: &accountRateMultiplier, + BillingType: billingType, + Stream: result.Stream, + OpenAIWSMode: result.OpenAIWSMode, + DurationMs: &durationMs, + FirstTokenMs: result.FirstTokenMs, + CreatedAt: time.Now(), + } + // 添加 UserAgent + if input.UserAgent != "" { + usageLog.UserAgent = &input.UserAgent + } + + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + + if apiKey.GroupID != nil { + usageLog.GroupID = apiKey.GroupID + } + if subscription != nil { + usageLog.SubscriptionID = &subscription.ID + } + + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") + logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr + } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") + + return nil +} + +// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers. +// Exported for use in ratelimit_service when handling OpenAI 429 responses. +func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { + snapshot := &OpenAICodexUsageSnapshot{} + hasData := false + + // Helper to parse float64 from header + parseFloat := func(key string) *float64 { + if v := headers.Get(key); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return &f + } + } + return nil + } + + // Helper to parse int from header + parseInt := func(key string) *int { + if v := headers.Get(key); v != "" { + if i, err := strconv.Atoi(v); err == nil { + return &i + } + } + return nil + } + + // Primary (weekly) limits + if v := parseFloat("x-codex-primary-used-percent"); v != nil { + snapshot.PrimaryUsedPercent = v + hasData = true + } + if v := parseInt("x-codex-primary-reset-after-seconds"); v != nil { + snapshot.PrimaryResetAfterSeconds = v + hasData = true + } + if v := parseInt("x-codex-primary-window-minutes"); v != nil { + snapshot.PrimaryWindowMinutes = v + hasData = true + } + + // Secondary (5h) limits + if v := parseFloat("x-codex-secondary-used-percent"); v != nil { + snapshot.SecondaryUsedPercent = v + hasData = true + } + if v := parseInt("x-codex-secondary-reset-after-seconds"); v != nil { + snapshot.SecondaryResetAfterSeconds = v + hasData = true + } + if v := parseInt("x-codex-secondary-window-minutes"); v != nil { + snapshot.SecondaryWindowMinutes = v + hasData = true + } + + // Overflow ratio + if v := parseFloat("x-codex-primary-over-secondary-limit-percent"); v != nil { + snapshot.PrimaryOverSecondaryPercent = v + hasData = true + } + + if !hasData { + return nil + } + + snapshot.UpdatedAt = time.Now().Format(time.RFC3339) + return snapshot +} + +func codexSnapshotBaseTime(snapshot *OpenAICodexUsageSnapshot, fallback time.Time) time.Time { + if snapshot == nil { + return fallback + } + if snapshot.UpdatedAt == "" { + return fallback + } + base, err := time.Parse(time.RFC3339, snapshot.UpdatedAt) + if err != nil { + return fallback + } + return base +} + +func codexResetAtRFC3339(base time.Time, resetAfterSeconds *int) *string { + if resetAfterSeconds == nil { + return nil + } + sec := *resetAfterSeconds + if sec < 0 { + sec = 0 + } + resetAt := base.Add(time.Duration(sec) * time.Second).Format(time.RFC3339) + return &resetAt +} + +func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) map[string]any { + if snapshot == nil { + return nil + } + + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) + updates := make(map[string]any) + + // 保存原始 primary/secondary 字段,便于排查问题 + if snapshot.PrimaryUsedPercent != nil { + updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent + } + if snapshot.PrimaryResetAfterSeconds != nil { + updates["codex_primary_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds + } + if snapshot.PrimaryWindowMinutes != nil { + updates["codex_primary_window_minutes"] = *snapshot.PrimaryWindowMinutes + } + if snapshot.SecondaryUsedPercent != nil { + updates["codex_secondary_used_percent"] = *snapshot.SecondaryUsedPercent + } + if snapshot.SecondaryResetAfterSeconds != nil { + updates["codex_secondary_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds + } + if snapshot.SecondaryWindowMinutes != nil { + updates["codex_secondary_window_minutes"] = *snapshot.SecondaryWindowMinutes + } + if snapshot.PrimaryOverSecondaryPercent != nil { + updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent + } + updates["codex_usage_updated_at"] = baseTime.Format(time.RFC3339) + + // 归一化到 5h/7d 规范字段 + if normalized := snapshot.Normalize(); normalized != nil { + if normalized.Used5hPercent != nil { + updates["codex_5h_used_percent"] = *normalized.Used5hPercent + } + if normalized.Reset5hSeconds != nil { + updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds + } + if normalized.Window5hMinutes != nil { + updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes + } + if normalized.Used7dPercent != nil { + updates["codex_7d_used_percent"] = *normalized.Used7dPercent + } + if normalized.Reset7dSeconds != nil { + updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds + } + if normalized.Window7dMinutes != nil { + updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes + } + if reset5hAt := codexResetAtRFC3339(baseTime, normalized.Reset5hSeconds); reset5hAt != nil { + updates["codex_5h_reset_at"] = *reset5hAt + } + if reset7dAt := codexResetAtRFC3339(baseTime, normalized.Reset7dSeconds); reset7dAt != nil { + updates["codex_7d_reset_at"] = *reset7dAt + } + } + + return updates +} + +func codexUsagePercentExhausted(value *float64) bool { + return value != nil && *value >= 100-1e-9 +} + +func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time { + if snapshot == nil { + return nil + } + normalized := snapshot.Normalize() + if normalized == nil { + return nil + } + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) + if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second) + return &resetAt + } + if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second) + return &resetAt + } + return nil +} + +func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time { + if len(extra) == 0 { + return nil + } + if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + return nil +} + +func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) { + if account == nil || !account.IsOpenAI() { + return nil, false + } + resetAt := codexRateLimitResetAtFromExtra(account.Extra, now) + if resetAt == nil { + return nil, false + } + if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) { + return account.RateLimitResetAt, false + } + account.RateLimitResetAt = resetAt + return resetAt, true +} + +func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time { + resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now) + if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 { + return resetAt + } + _ = repo.SetRateLimited(ctx, account.ID, *resetAt) + return resetAt +} + +// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field +func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { + if snapshot == nil { + return + } + if s == nil || s.accountRepo == nil { + return + } + + now := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, now) + resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now) + if len(updates) == 0 && resetAt == nil { + return + } + shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now) + if !shouldPersistUpdates && resetAt == nil { + return + } + + go func() { + updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if shouldPersistUpdates { + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + } + if resetAt != nil { + _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) + } + }() +} + +func (s *OpenAIGatewayService) UpdateCodexUsageSnapshotFromHeaders(ctx context.Context, accountID int64, headers http.Header) { + if accountID <= 0 || headers == nil { + return + } + if snapshot := ParseCodexRateLimitHeaders(headers); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, accountID, snapshot) + } +} + +func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) { + if reqBody == nil { + return "", false + } + + // Primary: reasoning.effort + if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { + if effort, ok := reasoning["effort"].(string); ok { + return normalizeOpenAIReasoningEffort(effort), true + } + } + + // Fallback: some clients may use a flat field. + if effort, ok := reqBody["reasoning_effort"].(string); ok { + return normalizeOpenAIReasoningEffort(effort), true + } + + return "", false +} + +func deriveOpenAIReasoningEffortFromModel(model string) string { + if strings.TrimSpace(model) == "" { + return "" + } + + modelID := strings.TrimSpace(model) + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + + parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool { + switch r { + case '-', '_', ' ': + return true + default: + return false + } + }) + if len(parts) == 0 { + return "" + } + + return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) +} + +func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) { + if len(body) == 0 { + return "", false, "" + } + + model = strings.TrimSpace(gjson.GetBytes(body, "model").String()) + stream = gjson.GetBytes(body, "stream").Bool() + promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + return model, stream, promptCacheKey +} + +// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为: +// 1) store=false 2) 非 compact 保持 stream=true;compact 强制 stream=false +func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) { + if len(body) == 0 { + return body, false, nil + } + + normalized := body + changed := false + + if compact { + if store := gjson.GetBytes(normalized, "store"); store.Exists() { + next, err := sjson.DeleteBytes(normalized, "store") + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body delete store: %w", err) + } + normalized = next + changed = true + } + if stream := gjson.GetBytes(normalized, "stream"); stream.Exists() { + next, err := sjson.DeleteBytes(normalized, "stream") + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body delete stream: %w", err) + } + normalized = next + changed = true + } + } else { + if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False { + next, err := sjson.SetBytes(normalized, "store", false) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err) + } + normalized = next + changed = true + } + if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True { + next, err := sjson.SetBytes(normalized, "stream", true) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err) + } + normalized = next + changed = true + } + } + + return normalized, changed, nil +} + +func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string { + model := strings.ToLower(strings.TrimSpace(reqModel)) + if !strings.Contains(model, "codex") { + return "" + } + + instructions := gjson.GetBytes(body, "instructions") + if !instructions.Exists() { + return "instructions_missing" + } + if instructions.Type != gjson.String { + return "instructions_not_string" + } + if strings.TrimSpace(instructions.String()) == "" { + return "instructions_empty" + } + return "" +} + +func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { + reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if reasoningEffort == "" { + reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if reasoningEffort != "" { + normalized := normalizeOpenAIReasoningEffort(reasoningEffort) + if normalized == "" { + return nil + } + return &normalized + } + + value := deriveOpenAIReasoningEffortFromModel(requestedModel) + if value == "" { + return nil + } + return &value +} + +func extractOpenAIServiceTier(reqBody map[string]any) *string { + if reqBody == nil { + return nil + } + raw, ok := reqBody["service_tier"].(string) + if !ok { + return nil + } + return normalizeOpenAIServiceTier(raw) +} + +func extractOpenAIServiceTierFromBody(body []byte) *string { + if len(body) == 0 { + return nil + } + return normalizeOpenAIServiceTier(gjson.GetBytes(body, "service_tier").String()) +} + +func normalizeOpenAIServiceTier(raw string) *string { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return nil + } + if value == "fast" { + value = "priority" + } + switch value { + case "priority", "flex": + return &value + default: + return nil + } +} + +func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) { + if c != nil { + if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok { + if reqBody, ok := cached.(map[string]any); ok && reqBody != nil { + return reqBody, nil + } + } + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + if c != nil { + c.Set(OpenAIParsedRequestBodyKey, reqBody) + } + return reqBody, nil +} + +func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string { + if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present { + if value == "" { + return nil + } + return &value + } + + value := deriveOpenAIReasoningEffortFromModel(requestedModel) + if value == "" { + return nil + } + return &value +} + +func normalizeOpenAIReasoningEffort(raw string) string { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return "" + } + + // Normalize separators for "x-high"/"x_high" variants. + value = strings.NewReplacer("-", "", "_", "", " ", "").Replace(value) + + switch value { + case "none", "minimal": + return "" + case "low", "medium", "high": + return value + case "xhigh", "extrahigh": + return "xhigh" + default: + // Only store known effort levels for now to keep UI consistent. + return "" + } +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go new file mode 100644 index 00000000..f45a0ea0 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder.go @@ -0,0 +1,4077 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +const ( + openAIWSBetaV1Value = "responses_websockets=2026-02-04" + openAIWSBetaV2Value = "responses_websockets=2026-02-06" + + openAIWSTurnStateHeader = "x-codex-turn-state" + openAIWSTurnMetadataHeader = "x-codex-turn-metadata" + + openAIWSLogValueMaxLen = 160 + openAIWSHeaderValueMaxLen = 120 + openAIWSIDValueMaxLen = 64 + openAIWSEventLogHeadLimit = 20 + openAIWSEventLogEveryN = 50 + openAIWSBufferLogHeadLimit = 8 + openAIWSBufferLogEveryN = 20 + openAIWSPrewarmEventLogHead = 10 + openAIWSPayloadKeySizeTopN = 6 + + openAIWSPayloadSizeEstimateDepth = 3 + openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024 + openAIWSPayloadSizeEstimateMaxItems = 16 + + openAIWSEventFlushBatchSizeDefault = 4 + openAIWSEventFlushIntervalDefault = 25 * time.Millisecond + openAIWSPayloadLogSampleDefault = 0.2 + openAIWSPassthroughIdleTimeoutDefault = time.Hour + + openAIWSStoreDisabledConnModeStrict = "strict" + openAIWSStoreDisabledConnModeAdaptive = "adaptive" + openAIWSStoreDisabledConnModeOff = "off" + + openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found" + openAIWSMaxPrevResponseIDDeletePasses = 8 +) + +var openAIWSLogValueReplacer = strings.NewReplacer( + "error", "err", + "fallback", "fb", + "warning", "warnx", + "failed", "fail", +) + +var openAIWSIngressPreflightPingIdle = 20 * time.Second + +// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。 +type openAIWSFallbackError struct { + Reason string + Err error +} + +func (e *openAIWSFallbackError) Error() string { + if e == nil { + return "" + } + if e.Err == nil { + return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason)) + } + return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err) +} + +func (e *openAIWSFallbackError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +func wrapOpenAIWSFallback(reason string, err error) error { + return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err} +} + +// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。 +type OpenAIWSClientCloseError struct { + statusCode coderws.StatusCode + reason string + err error +} + +type openAIWSIngressTurnError struct { + stage string + cause error + wroteDownstream bool +} + +func (e *openAIWSIngressTurnError) Error() string { + if e == nil { + return "" + } + if e.cause == nil { + return strings.TrimSpace(e.stage) + } + return e.cause.Error() +} + +func (e *openAIWSIngressTurnError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + +func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error { + if cause == nil { + return nil + } + return &openAIWSIngressTurnError{ + stage: strings.TrimSpace(stage), + cause: cause, + wroteDownstream: wroteDownstream, + } +} + +func isOpenAIWSIngressTurnRetryable(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) { + return false + } + if turnErr.wroteDownstream { + return false + } + switch turnErr.stage { + case "write_upstream", "read_upstream": + return true + default: + return false + } +} + +func openAIWSIngressTurnRetryReason(err error) string { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return "unknown" + } + if turnErr.stage == "" { + return "unknown" + } + return turnErr.stage +} + +func isOpenAIWSIngressPreviousResponseNotFound(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound { + return false + } + return !turnErr.wroteDownstream +} + +// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。 +func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error { + return &OpenAIWSClientCloseError{ + statusCode: statusCode, + reason: strings.TrimSpace(reason), + err: err, + } +} + +func (e *OpenAIWSClientCloseError) Error() string { + if e == nil { + return "" + } + if e.err == nil { + return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason)) + } + return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err) +} + +func (e *OpenAIWSClientCloseError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode { + if e == nil { + return coderws.StatusInternalError + } + return e.statusCode +} + +func (e *OpenAIWSClientCloseError) Reason() string { + if e == nil { + return "" + } + return strings.TrimSpace(e.reason) +} + +// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 +type OpenAIWSIngressHooks struct { + BeforeTurn func(turn int) error + AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) +} + +func normalizeOpenAIWSLogValue(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "-" + } + return openAIWSLogValueReplacer.Replace(trimmed) +} + +func truncateOpenAIWSLogValue(value string, maxLen int) string { + normalized := normalizeOpenAIWSLogValue(value) + if normalized == "-" || maxLen <= 0 { + return normalized + } + if len(normalized) <= maxLen { + return normalized + } + return normalized[:maxLen] + "..." +} + +func openAIWSHeaderValueForLog(headers http.Header, key string) string { + if headers == nil { + return "-" + } + return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen) +} + +func hasOpenAIWSHeader(headers http.Header, key string) bool { + if headers == nil { + return false + } + return strings.TrimSpace(headers.Get(key)) != "" +} + +type openAIWSSessionHeaderResolution struct { + SessionID string + ConversationID string + SessionSource string + ConversationSource string +} + +func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { + resolution := openAIWSSessionHeaderResolution{ + SessionSource: "none", + ConversationSource: "none", + } + if c != nil && c.Request != nil { + if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { + resolution.SessionID = sessionID + resolution.SessionSource = "header_session_id" + } + if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { + resolution.ConversationID = conversationID + resolution.ConversationSource = "header_conversation_id" + if resolution.SessionID == "" { + resolution.SessionID = conversationID + resolution.SessionSource = "header_conversation_id" + } + } + } + + cacheKey := strings.TrimSpace(promptCacheKey) + if cacheKey != "" { + if resolution.SessionID == "" { + resolution.SessionID = cacheKey + resolution.SessionSource = "prompt_cache_key" + } + } + return resolution +} + +func shouldLogOpenAIWSEvent(idx int, eventType string) bool { + if idx <= openAIWSEventLogHeadLimit { + return true + } + if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { + return true + } + if eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + return true + } + return false +} + +func shouldLogOpenAIWSBufferedEvent(idx int) bool { + if idx <= openAIWSBufferLogHeadLimit { + return true + } + if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 { + return true + } + return false +} + +func openAIWSEventMayContainModel(eventType string) bool { + switch eventType { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + trimmed := strings.TrimSpace(eventType) + if trimmed == eventType { + return false + } + switch trimmed { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + return false + } + } +} + +func openAIWSEventMayContainToolCalls(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return false + } + if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { + return true + } + switch eventType { + case "response.output_item.added", "response.output_item.done", "response.completed", "response.done": + return true + default: + return false + } +} + +func openAIWSEventShouldParseUsage(eventType string) bool { + return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed" +} + +func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { + if len(message) == 0 { + return "", "", gjson.Result{} + } + values := gjson.GetManyBytes(message, "type", "response.id", "id", "response") + eventType = strings.TrimSpace(values[0].String()) + if id := strings.TrimSpace(values[1].String()); id != "" { + responseID = id + } else { + responseID = strings.TrimSpace(values[2].String()) + } + return eventType, responseID, values[3] +} + +func openAIWSMessageLikelyContainsToolCalls(message []byte) bool { + if len(message) == 0 { + return false + } + return bytes.Contains(message, []byte(`"tool_calls"`)) || + bytes.Contains(message, []byte(`"tool_call"`)) || + bytes.Contains(message, []byte(`"function_call"`)) +} + +func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) { + if usage == nil || len(message) == 0 { + return + } + values := gjson.GetManyBytes( + message, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "", "", "" + } + values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message") + return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String()) +} + +func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) { + code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen) + errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen) + errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen) + return code, errType, errMessage +} + +func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "-", "-", "-" + } + return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string { + if len(payload) == 0 { + return "-" + } + type keySize struct { + Key string + Size int + } + sizes := make([]keySize, 0, len(payload)) + for key, value := range payload { + size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth) + sizes = append(sizes, keySize{Key: key, Size: size}) + } + sort.Slice(sizes, func(i, j int) bool { + if sizes[i].Size == sizes[j].Size { + return sizes[i].Key < sizes[j].Key + } + return sizes[i].Size > sizes[j].Size + }) + + if topN <= 0 || topN > len(sizes) { + topN = len(sizes) + } + parts := make([]string, 0, topN) + for idx := 0; idx < topN; idx++ { + item := sizes[idx] + parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size)) + } + return strings.Join(parts, ",") +} + +func estimateOpenAIWSPayloadValueSize(value any, depth int) int { + if depth <= 0 { + return -1 + } + switch v := value.(type) { + case nil: + return 0 + case string: + return len(v) + case []byte: + return len(v) + case bool: + return 1 + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return 8 + case float32, float64: + return 8 + case map[string]any: + if len(v) == 0 { + return 2 + } + total := 2 + count := 0 + for key, item := range v { + count++ + if count > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1) + if itemSize < 0 { + return -1 + } + total += len(key) + itemSize + 3 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + case []any: + if len(v) == 0 { + return 2 + } + total := 2 + limit := len(v) + if limit > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + for i := 0; i < limit; i++ { + itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1) + if itemSize < 0 { + return -1 + } + total += itemSize + 1 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + default: + raw, err := json.Marshal(v) + if err != nil { + return -1 + } + if len(raw) > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + return len(raw) + } +} + +func openAIWSPayloadString(payload map[string]any, key string) string { + if len(payload) == 0 { + return "" + } + raw, ok := payload[key] + if !ok { + return "" + } + switch v := raw.(type) { + case nil: + return "" + case string: + return strings.TrimSpace(v) + case []byte: + return strings.TrimSpace(string(v)) + default: + return "" + } +} + +func openAIWSPayloadStringFromRaw(payload []byte, key string) string { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, key).String()) +} + +func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return defaultValue + } + value := gjson.GetBytes(payload, key) + if !value.Exists() { + return defaultValue + } + if value.Type != gjson.True && value.Type != gjson.False { + return defaultValue + } + return value.Bool() +} + +func openAIWSSessionHashesFromID(sessionID string) (string, string) { + return deriveOpenAISessionHashes(sessionID) +} + +func extractOpenAIWSImageURL(value any) string { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if raw, ok := v["url"].(string); ok { + return strings.TrimSpace(raw) + } + } + return "" +} + +func summarizeOpenAIWSInput(input any) string { + items, ok := input.([]any) + if !ok || len(items) == 0 { + return "-" + } + + itemCount := len(items) + textChars := 0 + imageDataURLs := 0 + imageDataURLChars := 0 + imageRemoteURLs := 0 + + handleContentItem := func(contentItem map[string]any) { + contentType, _ := contentItem["type"].(string) + switch strings.TrimSpace(contentType) { + case "input_text", "output_text", "text": + if text, ok := contentItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := extractOpenAIWSImageURL(contentItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + handleInputItem := func(inputItem map[string]any) { + if content, ok := inputItem["content"].([]any); ok { + for _, rawContent := range content { + contentItem, ok := rawContent.(map[string]any) + if !ok { + continue + } + handleContentItem(contentItem) + } + return + } + + itemType, _ := inputItem["type"].(string) + switch strings.TrimSpace(itemType) { + case "input_text", "output_text", "text": + if text, ok := inputItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := extractOpenAIWSImageURL(inputItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + for _, rawItem := range items { + inputItem, ok := rawItem.(map[string]any) + if !ok { + continue + } + handleInputItem(inputItem) + } + + return fmt.Sprintf( + "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d", + itemCount, + textChars, + imageDataURLs, + imageDataURLChars, + imageRemoteURLs, + ) +} + +func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return + } + if _, exists := payload[key]; !exists { + return + } + delete(payload, key) + *removed = append(*removed, key) +} + +// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段, +// 避免重试成功却改变原始请求语义。 +// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。 +func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) { + if len(payload) == 0 { + return "empty", nil + } + if attempt <= 1 { + return "full", nil + } + + removed := make([]string, 0, 2) + if attempt >= 2 { + dropOpenAIWSPayloadKey(payload, "include", &removed) + } + + if len(removed) == 0 { + return "full", nil + } + sort.Strings(removed) + return "trim_optional_fields", removed +} + +func logOpenAIWSModeInfo(format string, args ...any) { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...) +} + +func isOpenAIWSModeDebugEnabled() bool { + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func logOpenAIWSModeDebug(format string, args ...any) { + if !isOpenAIWSModeDebugEnabled() { + return + } + logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...) +} + +func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) { + if err == nil { + return + } + logger.L().Warn( + "openai.ws_bind_response_account_failed", + zap.Int64("group_id", groupID), + zap.Int64("account_id", accountID), + zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)), + zap.Error(err), + ) +} + +func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reasonText := strings.TrimSpace(closeErr.Reason) + if reasonText != "" { + closeReason = normalizeOpenAIWSLogValue(reasonText) + } + } + return normalizeOpenAIWSLogValue(closeStatus), closeReason +} + +func unwrapOpenAIWSDialBaseError(err error) error { + if err == nil { + return nil + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil { + return dialErr.Err + } + return err +} + +func openAIWSDialRespHeaderForLog(err error, key string) string { + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil { + return "-" + } + return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen) +} + +func classifyOpenAIWSDialError(err error) string { + if err == nil { + return "-" + } + baseErr := unwrapOpenAIWSDialBaseError(err) + if baseErr == nil { + return "-" + } + if errors.Is(baseErr, context.DeadlineExceeded) { + return "ctx_deadline_exceeded" + } + if errors.Is(baseErr, context.Canceled) { + return "ctx_canceled" + } + var netErr net.Error + if errors.As(baseErr, &netErr) && netErr.Timeout() { + return "net_timeout" + } + if status := coderws.CloseStatus(baseErr); status != -1 { + return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status))) + } + message := strings.ToLower(strings.TrimSpace(baseErr.Error())) + switch { + case strings.Contains(message, "handshake not finished"): + return "handshake_not_finished" + case strings.Contains(message, "bad handshake"): + return "bad_handshake" + case strings.Contains(message, "connection refused"): + return "connection_refused" + case strings.Contains(message, "no such host"): + return "dns_not_found" + case strings.Contains(message, "tls"): + return "tls_error" + case strings.Contains(message, "i/o timeout"): + return "io_timeout" + case strings.Contains(message, "context deadline exceeded"): + return "ctx_deadline_exceeded" + default: + return "dial_error" + } +} + +func summarizeOpenAIWSDialError(err error) ( + statusCode int, + dialClass string, + closeStatus string, + closeReason string, + respServer string, + respVia string, + respCFRay string, + respRequestID string, +) { + dialClass = "-" + closeStatus = "-" + closeReason = "-" + respServer = "-" + respVia = "-" + respCFRay = "-" + respRequestID = "-" + if err == nil { + return + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil { + statusCode = dialErr.StatusCode + respServer = openAIWSDialRespHeaderForLog(err, "server") + respVia = openAIWSDialRespHeaderForLog(err, "via") + respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray") + respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id") + } + dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err)) + closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err)) + return +} + +func isOpenAIWSClientDisconnectError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return true + } + switch coderws.CloseStatus(err) { + case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if message == "" { + return false + } + return strings.Contains(message, "failed to read frame header: eof") || + strings.Contains(message, "unexpected eof") || + strings.Contains(message, "use of closed network connection") || + strings.Contains(message, "connection reset by peer") || + strings.Contains(message, "broken pipe") || + strings.Contains(message, "an established connection was aborted") +} + +func classifyOpenAIWSReadFallbackReason(err error) string { + if err == nil { + return "read_event" + } + switch coderws.CloseStatus(err) { + case coderws.StatusPolicyViolation: + return "policy_violation" + case coderws.StatusMessageTooBig: + return "message_too_big" + default: + return "read_event" + } +} + +func sortedKeys(m map[string]any) []string { + if len(m) == 0 { + return nil + } + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { + if s == nil { + return nil + } + s.openaiWSPoolOnce.Do(func() { + if s.openaiWSPool == nil { + s.openaiWSPool = newOpenAIWSConnPool(s.cfg) + } + }) + return s.openaiWSPool +} + +func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer { + if s == nil { + return nil + } + s.openaiWSPassthroughDialerOnce.Do(func() { + if s.openaiWSPassthroughDialer == nil { + s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer() + } + }) + return s.openaiWSPassthroughDialer +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { + pool := s.getOpenAIWSConnPool() + if pool == nil { + return OpenAIWSPoolMetricsSnapshot{} + } + return pool.SnapshotMetrics() +} + +type OpenAIWSPerformanceMetricsSnapshot struct { + Pool OpenAIWSPoolMetricsSnapshot `json:"pool"` + Retry OpenAIWSRetryMetricsSnapshot `json:"retry"` + Transport OpenAIWSTransportMetricsSnapshot `json:"transport"` +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot { + pool := s.getOpenAIWSConnPool() + snapshot := OpenAIWSPerformanceMetricsSnapshot{ + Retry: s.SnapshotOpenAIWSRetryMetrics(), + } + if pool == nil { + return snapshot + } + snapshot.Pool = pool.SnapshotMetrics() + snapshot.Transport = pool.SnapshotTransportMetrics() + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore { + if s == nil { + return nil + } + s.openaiWSStateStoreOnce.Do(func() { + if s.openaiWSStateStore == nil { + s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache) + } + }) + return s.openaiWSStateStore +} + +func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration { + if s != nil && s.cfg != nil { + seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds + if seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + return time.Hour +} + +func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool { + if s != nil && s.cfg != nil { + return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled + } + return true +} + +func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second + } + return 15 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration { + if timeout := s.openAIWSReadTimeout(); timeout > 0 { + return timeout + } + return openAIWSPassthroughIdleTimeoutDefault +} + +func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second + } + return 2 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 { + return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize + } + return openAIWSEventFlushBatchSizeDefault +} + +func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 { + if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 { + return 0 + } + return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond + } + return openAIWSEventFlushIntervalDefault +} + +func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 { + if s != nil && s.cfg != nil { + rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate + if rate < 0 { + return 0 + } + if rate > 1 { + return 1 + } + return rate + } + return openAIWSPayloadLogSampleDefault +} + +func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool { + // 首次尝试保留一条完整 payload_schema 便于排障。 + if attempt <= 1 { + return true + } + rate := s.openAIWSPayloadLogSampleRate() + if rate <= 0 { + return false + } + if rate >= 1 { + return true + } + return rand.Float64() < rate +} + +func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool { + if !s.shouldLogOpenAIWSPayloadSchema(attempt) { + return false + } + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second + } + return 10 * time.Second +} + +func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration { + // Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。 + // 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。 + dial := s.openAIWSDialTimeout() + if dial <= 0 { + dial = 10 * time.Second + } + return dial + 2*time.Second +} + +func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + var targetURL string + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + targetURL = openaiPlatformAPIURL + } else { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return "", err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + default: + targetURL = openaiPlatformAPIURL + } + + parsed, err := url.Parse(strings.TrimSpace(targetURL)) + if err != nil { + return "", fmt.Errorf("invalid target url: %w", err) + } + switch strings.ToLower(parsed.Scheme) { + case "https": + parsed.Scheme = "wss" + case "http": + parsed.Scheme = "ws" + case "wss", "ws": + // 保持不变 + default: + return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme) + } + return parsed.String(), nil +} + +func (s *OpenAIGatewayService) buildOpenAIWSHeaders( + c *gin.Context, + account *Account, + token string, + decision OpenAIWSProtocolDecision, + isCodexCLI bool, + turnState string, + turnMetadata string, + promptCacheKey string, +) (http.Header, openAIWSSessionHeaderResolution) { + headers := make(http.Header) + headers.Set("authorization", "Bearer "+token) + + sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey) + if c != nil && c.Request != nil { + if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" { + headers.Set("accept-language", v) + } + } + // OAuth 账号:将 apiKeyID 混入 session 标识符,防止跨用户会话碰撞。 + if account != nil && account.Type == AccountTypeOAuth { + apiKeyID := getAPIKeyIDFromContext(c) + if sessionResolution.SessionID != "" { + headers.Set("session_id", isolateOpenAISessionID(apiKeyID, sessionResolution.SessionID)) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", isolateOpenAISessionID(apiKeyID, sessionResolution.ConversationID)) + } + } else { + if sessionResolution.SessionID != "" { + headers.Set("session_id", sessionResolution.SessionID) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", sessionResolution.ConversationID) + } + } + if state := strings.TrimSpace(turnState); state != "" { + headers.Set(openAIWSTurnStateHeader, state) + } + if metadata := strings.TrimSpace(turnMetadata); metadata != "" { + headers.Set(openAIWSTurnMetadataHeader, metadata) + } + + if account != nil && account.Type == AccountTypeOAuth { + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + headers.Set("chatgpt-account-id", chatgptAccountID) + } + headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) + } + + betaValue := openAIWSBetaV2Value + if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + betaValue = openAIWSBetaV1Value + } + headers.Set("OpenAI-Beta", betaValue) + + customUA := "" + if account != nil { + customUA = account.GetOpenAIUserAgent() + } + if strings.TrimSpace(customUA) != "" { + headers.Set("user-agent", customUA) + } else if c != nil { + if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { + headers.Set("user-agent", ua) + } + } + if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + headers.Set("user-agent", codexCLIUserAgent) + } + if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) { + headers.Set("user-agent", codexCLIUserAgent) + } + + return headers, sessionResolution +} + +func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any { + // OpenAI WS Mode 协议:response.create 字段与 HTTP /responses 基本一致。 + // 保留 stream 字段(与 Codex CLI 一致),仅移除 background。 + payload := make(map[string]any, len(reqBody)+1) + for k, v := range reqBody { + payload[k] = v + } + + delete(payload, "background") + if _, exists := payload["stream"]; !exists { + payload["stream"] = true + } + payload["type"] = "response.create" + + // OAuth 默认保持 store=false,避免误依赖服务端历史。 + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + payload["store"] = false + } + return payload +} + +func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) { + if len(payload) == 0 { + return + } + metadata := strings.TrimSpace(turnMetadata) + if metadata == "" { + return + } + + switch existing := payload["client_metadata"].(type) { + case map[string]any: + existing[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = existing + case map[string]string: + next := make(map[string]any, len(existing)+1) + for k, v := range existing { + next[k] = v + } + next[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = next + default: + payload["client_metadata"] = map[string]any{ + openAIWSTurnMetadataHeader: metadata, + } + } +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool { + if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() { + return true + } + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery { + return true + } + return false +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true + } + if len(reqBody) == 0 { + return false + } + rawStore, ok := reqBody["store"] + if !ok { + return false + } + storeEnabled, ok := rawStore.(bool) + if !ok { + return false + } + return !storeEnabled +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true + } + if len(reqBody) == 0 { + return false + } + storeValue := gjson.GetBytes(reqBody, "store") + if !storeValue.Exists() { + return false + } + if storeValue.Type != gjson.True && storeValue.Type != gjson.False { + return false + } + return !storeValue.Bool() +} + +func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string { + if s == nil || s.cfg == nil { + return openAIWSStoreDisabledConnModeStrict + } + mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode)) + switch mode { + case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff: + return mode + case "": + // 兼容旧配置:仅配置了布尔开关时按旧语义推导。 + if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + return openAIWSStoreDisabledConnModeStrict + } + return openAIWSStoreDisabledConnModeOff + default: + return openAIWSStoreDisabledConnModeStrict + } +} + +func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool { + switch mode { + case openAIWSStoreDisabledConnModeOff: + return false + case openAIWSStoreDisabledConnModeAdaptive: + reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_") + switch reason { + case "policy_violation", "message_too_big", "auth_failed", "write_request", "write": + return true + default: + return false + } + default: + return true + } +} + +func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) { + return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes) +} + +func dropPreviousResponseIDFromRawPayloadWithDeleteFn( + payload []byte, + deleteFn func([]byte, string) ([]byte, error), +) ([]byte, bool, error) { + if len(payload) == 0 { + return payload, false, nil + } + if !gjson.GetBytes(payload, "previous_response_id").Exists() { + return payload, false, nil + } + if deleteFn == nil { + deleteFn = sjson.DeleteBytes + } + + updated := payload + for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses && + gjson.GetBytes(updated, "previous_response_id").Exists(); i++ { + next, err := deleteFn(updated, "previous_response_id") + if err != nil { + return payload, false, err + } + updated = next + } + return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil +} + +func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) { + normalizedPrevID := strings.TrimSpace(previousResponseID) + if len(payload) == 0 || normalizedPrevID == "" { + return payload, nil + } + updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID) + if err == nil { + return updated, nil + } + + var reqBody map[string]any + if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil { + return nil, err + } + reqBody["previous_response_id"] = normalizedPrevID + rebuilt, marshalErr := json.Marshal(reqBody) + if marshalErr != nil { + return nil, marshalErr + } + return rebuilt, nil +} + +func shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled bool, + turn int, + hasFunctionCallOutput bool, + currentPreviousResponseID string, + expectedPreviousResponseID string, +) bool { + if !storeDisabled || turn <= 1 || !hasFunctionCallOutput { + return false + } + if strings.TrimSpace(currentPreviousResponseID) != "" { + return false + } + return strings.TrimSpace(expectedPreviousResponseID) != "" +} + +func alignStoreDisabledPreviousResponseID( + payload []byte, + expectedPreviousResponseID string, +) ([]byte, bool, error) { + if len(payload) == 0 { + return payload, false, nil + } + expected := strings.TrimSpace(expectedPreviousResponseID) + if expected == "" { + return payload, false, nil + } + current := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + if current == "" || current == expected { + return payload, false, nil + } + + withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) + if dropErr != nil { + return payload, false, dropErr + } + if !removed { + return payload, false, nil + } + updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected) + if setErr != nil { + return payload, false, setErr + } + return updated, true, nil +} + +func cloneOpenAIWSPayloadBytes(payload []byte) []byte { + if len(payload) == 0 { + return nil + } + cloned := make([]byte, len(payload)) + copy(cloned, payload) + return cloned +} + +func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage { + if items == nil { + return nil + } + cloned := make([]json.RawMessage, 0, len(items)) + for idx := range items { + cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx]))) + } + return cloned +} + +func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return nil, errors.New("json is empty") + } + var decoded any + if err := json.Unmarshal(trimmed, &decoded); err != nil { + return nil, err + } + return json.Marshal(decoded) +} + +func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte { + normalized, err := normalizeOpenAIWSJSONForCompare(raw) + if err != nil { + return bytes.TrimSpace(raw) + } + return normalized +} + +func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) { + if len(payload) == 0 { + return nil, errors.New("payload is empty") + } + var decoded map[string]any + if err := json.Unmarshal(payload, &decoded); err != nil { + return nil, err + } + delete(decoded, "input") + delete(decoded, "previous_response_id") + return json.Marshal(decoded) +} + +func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) { + if len(payload) == 0 { + return nil, false, nil + } + inputValue := gjson.GetBytes(payload, "input") + if !inputValue.Exists() { + return nil, false, nil + } + if inputValue.Type == gjson.JSON { + raw := strings.TrimSpace(inputValue.Raw) + if strings.HasPrefix(raw, "[") { + var items []json.RawMessage + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil, true, err + } + return items, true, nil + } + return []json.RawMessage{json.RawMessage(raw)}, true, nil + } + if inputValue.Type == gjson.String { + encoded, _ := json.Marshal(inputValue.String()) + return []json.RawMessage{encoded}, true, nil + } + return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil +} + +func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) { + previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload) + if prevErr != nil { + return false, prevErr + } + currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) + if currentErr != nil { + return false, currentErr + } + if !previousExists && !currentExists { + return true, nil + } + if !previousExists { + return len(currentItems) == 0, nil + } + if !currentExists { + return len(previousItems) == 0, nil + } + if len(currentItems) < len(previousItems) { + return false, nil + } + + for idx := range previousItems { + previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx]) + currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx]) + if !bytes.Equal(previousNormalized, currentNormalized) { + return false, nil + } + } + return true, nil +} + +func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool { + if len(prefix) == 0 { + return true + } + if len(items) < len(prefix) { + return false + } + for idx := range prefix { + previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx]) + currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx]) + if !bytes.Equal(previousNormalized, currentNormalized) { + return false + } + } + return true +} + +func buildOpenAIWSReplayInputSequence( + previousFullInput []json.RawMessage, + previousFullInputExists bool, + currentPayload []byte, + hasPreviousResponseID bool, +) ([]json.RawMessage, bool, error) { + currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) + if currentErr != nil { + return nil, false, currentErr + } + if !hasPreviousResponseID { + return cloneOpenAIWSRawMessages(currentItems), currentExists, nil + } + if !previousFullInputExists { + return cloneOpenAIWSRawMessages(currentItems), currentExists, nil + } + if !currentExists || len(currentItems) == 0 { + return cloneOpenAIWSRawMessages(previousFullInput), true, nil + } + if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) { + return cloneOpenAIWSRawMessages(currentItems), true, nil + } + merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems)) + merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...) + merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...) + return merged, true, nil +} + +func setOpenAIWSPayloadInputSequence( + payload []byte, + fullInput []json.RawMessage, + fullInputExists bool, +) ([]byte, error) { + if !fullInputExists { + return payload, nil + } + // Preserve [] vs null semantics when input exists but is empty. + inputForMarshal := fullInput + if inputForMarshal == nil { + inputForMarshal = []json.RawMessage{} + } + inputRaw, marshalErr := json.Marshal(inputForMarshal) + if marshalErr != nil { + return nil, marshalErr + } + return sjson.SetRawBytes(payload, "input", inputRaw) +} + +func shouldKeepIngressPreviousResponseID( + previousPayload []byte, + currentPayload []byte, + lastTurnResponseID string, + hasFunctionCallOutput bool, +) (bool, string, error) { + if hasFunctionCallOutput { + return true, "has_function_call_output", nil + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + if currentPreviousResponseID == "" { + return false, "missing_previous_response_id", nil + } + expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) + if expectedPreviousResponseID == "" { + return false, "missing_last_turn_response_id", nil + } + if currentPreviousResponseID != expectedPreviousResponseID { + return false, "previous_response_id_mismatch", nil + } + if len(previousPayload) == 0 { + return false, "missing_previous_turn_payload", nil + } + + previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload) + if previousComparableErr != nil { + return false, "non_input_compare_error", previousComparableErr + } + currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) + if currentComparableErr != nil { + return false, "non_input_compare_error", currentComparableErr + } + if !bytes.Equal(previousComparable, currentComparable) { + return false, "non_input_changed", nil + } + return true, "strict_incremental_ok", nil +} + +type openAIWSIngressPreviousTurnStrictState struct { + nonInputComparable []byte +} + +func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) { + if len(payload) == 0 { + return nil, nil + } + nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload) + if nonInputErr != nil { + return nil, nonInputErr + } + return &openAIWSIngressPreviousTurnStrictState{ + nonInputComparable: nonInputComparable, + }, nil +} + +func shouldKeepIngressPreviousResponseIDWithStrictState( + previousState *openAIWSIngressPreviousTurnStrictState, + currentPayload []byte, + lastTurnResponseID string, + hasFunctionCallOutput bool, +) (bool, string, error) { + if hasFunctionCallOutput { + return true, "has_function_call_output", nil + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + if currentPreviousResponseID == "" { + return false, "missing_previous_response_id", nil + } + expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) + if expectedPreviousResponseID == "" { + return false, "missing_last_turn_response_id", nil + } + if currentPreviousResponseID != expectedPreviousResponseID { + return false, "previous_response_id_mismatch", nil + } + if previousState == nil { + return false, "missing_previous_turn_payload", nil + } + + currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) + if currentComparableErr != nil { + return false, "non_input_compare_error", currentComparableErr + } + if !bytes.Equal(previousState.nonInputComparable, currentComparable) { + return false, "non_input_changed", nil + } + return true, "strict_incremental_ok", nil +} + +func (s *OpenAIGatewayService) forwardOpenAIWSV2( + ctx context.Context, + c *gin.Context, + account *Account, + reqBody map[string]any, + token string, + decision OpenAIWSProtocolDecision, + isCodexCLI bool, + reqStream bool, + originalModel string, + mappedModel string, + startTime time.Time, + attempt int, + lastFailureReason string, +) (*OpenAIForwardResult, error) { + if s == nil || account == nil { + return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil")) + } + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return nil, wrapOpenAIWSFallback("build_ws_url", err) + } + wsHost := "-" + wsPath := "-" + if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil { + if h := strings.TrimSpace(parsed.Host); h != "" { + wsHost = normalizeOpenAIWSLogValue(h) + } + if p := strings.TrimSpace(parsed.Path); p != "" { + wsPath = normalizeOpenAIWSLogValue(p) + } + } + logOpenAIWSModeDebug( + "dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s", + account.ID, + account.Type, + wsHost, + wsPath, + ) + + payload := s.buildOpenAIWSCreatePayload(reqBody, account) + payloadStrategy, removedKeys := applyOpenAIWSRetryPayloadStrategy(payload, attempt) + previousResponseID := openAIWSPayloadString(payload, "previous_response_id") + previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + promptCacheKey := openAIWSPayloadString(payload, "prompt_cache_key") + _, hasTools := payload["tools"] + debugEnabled := isOpenAIWSModeDebugEnabled() + payloadBytes := -1 + resolvePayloadBytes := func() int { + if payloadBytes >= 0 { + return payloadBytes + } + payloadBytes = len(payloadAsJSONBytes(payload)) + return payloadBytes + } + streamValue := "-" + if raw, ok := payload["stream"]; ok { + streamValue = normalizeOpenAIWSLogValue(strings.TrimSpace(fmt.Sprintf("%v", raw))) + } + turnState := "" + turnMetadata := "" + if c != nil && c.Request != nil { + turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)) + } + setOpenAIWSTurnMetadata(payload, turnMetadata) + payloadEventType := openAIWSPayloadString(payload, "type") + if payloadEventType == "" { + payloadEventType = "response.create" + } + if s.shouldEmitOpenAIWSPayloadSchema(attempt) { + logOpenAIWSModeInfo( + "[debug] payload_schema account_id=%d attempt=%d event=%s payload_keys=%s payload_bytes=%d payload_key_sizes=%s input_summary=%s stream=%s payload_strategy=%s removed_keys=%s has_previous_response_id=%v has_prompt_cache_key=%v has_tools=%v", + account.ID, + attempt, + payloadEventType, + normalizeOpenAIWSLogValue(strings.Join(sortedKeys(payload), ",")), + resolvePayloadBytes(), + normalizeOpenAIWSLogValue(summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)), + normalizeOpenAIWSLogValue(summarizeOpenAIWSInput(payload["input"])), + streamValue, + normalizeOpenAIWSLogValue(payloadStrategy), + normalizeOpenAIWSLogValue(strings.Join(removedKeys, ",")), + previousResponseID != "", + promptCacheKey != "", + hasTools, + ) + } + + stateStore := s.getOpenAIWSStateStore() + groupID := getOpenAIGroupIDFromContext(c) + sessionHash := s.GenerateSessionHash(c, nil) + if sessionHash == "" { + var legacySessionHash string + sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey) + attachOpenAILegacySessionHashToGin(c, legacySessionHash) + } + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + turnState = savedTurnState + } + } + preferredConnID := "" + if stateStore != nil && previousResponseID != "" { + if connID, ok := stateStore.GetResponseConn(previousResponseID); ok { + preferredConnID = connID + } + } + storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account) + if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" { + if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { + preferredConnID = connID + } + } + storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() + forceNewConnByPolicy := shouldForceNewConnOnStoreDisabled(storeDisabledConnMode, lastFailureReason) + forceNewConn := forceNewConnByPolicy && storeDisabled && previousResponseID == "" && sessionHash != "" && preferredConnID == "" + wsHeaders, sessionResolution := s.buildOpenAIWSHeaders(c, account, token, decision, isCodexCLI, turnState, turnMetadata, promptCacheKey) + logOpenAIWSModeDebug( + "acquire_start account_id=%d account_type=%s transport=%s preferred_conn_id=%s has_previous_response_id=%v session_hash=%s has_turn_state=%v turn_state_len=%d has_turn_metadata=%v turn_metadata_len=%d store_disabled=%v store_disabled_conn_mode=%s retry_last_reason=%s force_new_conn=%v header_user_agent=%s header_openai_beta=%s header_originator=%s header_accept_language=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_prompt_cache_key=%v has_chatgpt_account_id=%v has_authorization=%v has_session_id=%v has_conversation_id=%v proxy_enabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + previousResponseID != "", + truncateOpenAIWSLogValue(sessionHash, 12), + turnState != "", + len(turnState), + turnMetadata != "", + len(turnMetadata), + storeDisabled, + normalizeOpenAIWSLogValue(storeDisabledConnMode), + truncateOpenAIWSLogValue(lastFailureReason, openAIWSLogValueMaxLen), + forceNewConn, + openAIWSHeaderValueForLog(wsHeaders, "user-agent"), + openAIWSHeaderValueForLog(wsHeaders, "openai-beta"), + openAIWSHeaderValueForLog(wsHeaders, "originator"), + openAIWSHeaderValueForLog(wsHeaders, "accept-language"), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + promptCacheKey != "", + hasOpenAIWSHeader(wsHeaders, "chatgpt-account-id"), + hasOpenAIWSHeader(wsHeaders, "authorization"), + hasOpenAIWSHeader(wsHeaders, "session_id"), + hasOpenAIWSHeader(wsHeaders, "conversation_id"), + account.ProxyID != nil && account.Proxy != nil, + ) + + acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout()) + defer acquireCancel() + + lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{ + Account: account, + WSURL: wsURL, + Headers: wsHeaders, + PreferredConnID: preferredConnID, + ForceNewConn: forceNewConn, + ProxyURL: func() string { + if account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" + }(), + }) + if err != nil { + dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err) + logOpenAIWSModeInfo( + "acquire_fail account_id=%d account_type=%s transport=%s reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_new_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(err)), + dialStatus, + dialClass, + dialCloseStatus, + truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), + dialRespServer, + dialRespVia, + dialRespCFRay, + dialRespReqID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + forceNewConn, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error())) + } + return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) + } + // cleanExit 标记正常终端事件退出,此时上游不会再发送帧,连接可安全归还复用。 + // 所有异常路径(读写错误、error 事件等)已在各自分支中提前调用 MarkBroken, + // 因此 defer 中只需处理正常退出时不 MarkBroken 即可。 + cleanExit := false + defer func() { + if !cleanExit { + lease.MarkBroken() + } + lease.Release() + }() + connID := strings.TrimSpace(lease.ConnID()) + logOpenAIWSModeDebug( + "connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + connID, + lease.Reused(), + lease.ConnPickDuration().Milliseconds(), + lease.QueueWaitDuration().Milliseconds(), + previousResponseID != "", + ) + if previousResponseID != "" { + logOpenAIWSModeInfo( + "continuation_probe account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s conn_reused=%v store_disabled=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v", + account.ID, + account.Type, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(previousResponseIDKind), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + lease.Reused(), + storeDisabled, + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + turnState != "", + len(turnState), + promptCacheKey != "", + ) + } + if c != nil { + SetOpsLatencyMs(c, OpsOpenAIWSConnPickMsKey, lease.ConnPickDuration().Milliseconds()) + SetOpsLatencyMs(c, OpsOpenAIWSQueueWaitMsKey, lease.QueueWaitDuration().Milliseconds()) + c.Set(OpsOpenAIWSConnReusedKey, lease.Reused()) + if connID != "" { + c.Set(OpsOpenAIWSConnIDKey, connID) + } + } + + handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)) + logOpenAIWSModeDebug( + "handshake account_id=%d conn_id=%s has_turn_state=%v turn_state_len=%d", + account.ID, + connID, + handshakeTurnState != "", + len(handshakeTurnState), + ) + if handshakeTurnState != "" { + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) + } + if c != nil { + c.Header(http.CanonicalHeaderKey(openAIWSTurnStateHeader), handshakeTurnState) + } + } + + if err := s.performOpenAIWSGeneratePrewarm( + ctx, + lease, + decision, + payload, + previousResponseID, + reqBody, + account, + stateStore, + groupID, + ); err != nil { + return nil, err + } + + if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + "write_request_fail account_id=%d conn_id=%s cause=%s payload_bytes=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + resolvePayloadBytes(), + ) + return nil, wrapOpenAIWSFallback("write_request", err) + } + if debugEnabled { + logOpenAIWSModeDebug( + "write_request_sent account_id=%d conn_id=%s stream=%v payload_bytes=%d previous_response_id=%s", + account.ID, + connID, + reqStream, + resolvePayloadBytes(), + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + } + + usage := &OpenAIUsage{} + var firstTokenMs *int + responseID := "" + var finalResponse []byte + wroteDownstream := false + needModelReplace := originalModel != mappedModel + var mappedModelBytes []byte + if needModelReplace && mappedModel != "" { + mappedModelBytes = []byte(mappedModel) + } + bufferedStreamEvents := make([][]byte, 0, 4) + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + bufferedEventCount := 0 + flushedBufferedEventCount := 0 + firstEventType := "" + lastEventType := "" + + var flusher http.Flusher + if reqStream { + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), http.Header{}, s.responseHeaderFilter) + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + f, ok := c.Writer.(http.Flusher) + if !ok { + lease.MarkBroken() + return nil, wrapOpenAIWSFallback("streaming_not_supported", errors.New("streaming not supported")) + } + flusher = f + } + + clientDisconnected := false + flushBatchSize := s.openAIWSEventFlushBatchSize() + flushInterval := s.openAIWSEventFlushInterval() + pendingFlushEvents := 0 + lastFlushAt := time.Now() + flushStreamWriter := func(force bool) { + if clientDisconnected || flusher == nil || pendingFlushEvents <= 0 { + return + } + if !force && flushBatchSize > 1 && pendingFlushEvents < flushBatchSize { + if flushInterval <= 0 || time.Since(lastFlushAt) < flushInterval { + return + } + } + flusher.Flush() + pendingFlushEvents = 0 + lastFlushAt = time.Now() + } + emitStreamMessage := func(message []byte, forceFlush bool) { + if clientDisconnected { + return + } + frame := make([]byte, 0, len(message)+8) + frame = append(frame, "data: "...) + frame = append(frame, message...) + frame = append(frame, '\n', '\n') + _, wErr := c.Writer.Write(frame) + if wErr == nil { + wroteDownstream = true + pendingFlushEvents++ + flushStreamWriter(forceFlush) + return + } + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID) + } + flushBufferedStreamEvents := func(reason string) { + if len(bufferedStreamEvents) == 0 { + return + } + flushed := len(bufferedStreamEvents) + for _, buffered := range bufferedStreamEvents { + emitStreamMessage(buffered, false) + } + bufferedStreamEvents = bufferedStreamEvents[:0] + flushStreamWriter(true) + flushedBufferedEventCount += flushed + if debugEnabled { + logOpenAIWSModeDebug( + "buffer_flush account_id=%d conn_id=%s reason=%s flushed=%d total_flushed=%d client_disconnected=%v", + account.ID, + connID, + truncateOpenAIWSLogValue(reason, openAIWSLogValueMaxLen), + flushed, + flushedBufferedEventCount, + clientDisconnected, + ) + } + } + + readTimeout := s.openAIWSReadTimeout() + + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, readTimeout) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "read_fail account_id=%d conn_id=%s wrote_downstream=%v close_status=%s close_reason=%s cause=%s events=%d token_events=%d terminal_events=%d buffered_pending=%d buffered_flushed=%d first_event=%s last_event=%s", + account.ID, + connID, + wroteDownstream, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + eventCount, + tokenEventCount, + terminalEventCount, + len(bufferedStreamEvents), + flushedBufferedEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + ) + if !wroteDownstream { + return nil, wrapOpenAIWSFallback(classifyOpenAIWSReadFallbackReason(readErr), readErr) + } + if clientDisconnected { + break + } + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(readErr.Error()), "") + return nil, fmt.Errorf("openai ws read event: %w", readErr) + } + + eventType, eventResponseID, responseField := parseOpenAIWSEventEnvelope(message) + if eventType == "" { + continue + } + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + + isTokenEvent := isOpenAIWSTokenEvent(eventType) + if isTokenEvent { + tokenEventCount++ + } + isTerminalEvent := isOpenAIWSTerminalEvent(eventType) + if isTerminalEvent { + terminalEventCount++ + } + if firstTokenMs == nil && isTokenEvent { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if debugEnabled && shouldLogOpenAIWSEvent(eventCount, eventType) { + logOpenAIWSModeDebug( + "event_received account_id=%d conn_id=%s idx=%d type=%s bytes=%d token=%v terminal=%v buffered_pending=%d", + account.ID, + connID, + eventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + isTokenEvent, + isTerminalEvent, + len(bufferedStreamEvents), + ) + } + + if !clientDisconnected { + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(message, mappedModelBytes) { + message = replaceOpenAIWSMessageModel(message, mappedModel, originalModel) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(message) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(message); changed { + message = corrected + } + } + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(message, usage) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "Upstream websocket error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSModeInfo( + "error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + if fallbackReason == "previous_response_not_found" { + logOpenAIWSModeInfo( + "previous_response_not_found_diag account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s event_idx=%d req_stream=%v store_disabled=%v conn_reused=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v err_code=%s err_type=%s err_message=%s", + account.ID, + account.Type, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(previousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + eventCount, + reqStream, + storeDisabled, + lease.Reused(), + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + turnState != "", + len(turnState), + promptCacheKey != "", + errCode, + errType, + errMessage, + ) + } + // error 事件后连接不再可复用,避免回池后污染下一请求。 + lease.MarkBroken() + if !wroteDownstream && canFallback { + return nil, wrapOpenAIWSFallback(fallbackReason, errors.New(errMsg)) + } + statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw) + setOpsUpstreamError(c, statusCode, errMsg, "") + if reqStream && !clientDisconnected { + flushBufferedStreamEvents("error_event") + emitStreamMessage(message, true) + } + if !reqStream { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": errMsg, + }, + }) + } + return nil, fmt.Errorf("openai ws error event: %s", errMsg) + } + + if reqStream { + // 在首个 token 前先缓冲事件(如 response.created), + // 以便上游早期断连时仍可安全回退到 HTTP,不给下游发送半截流。 + shouldBuffer := firstTokenMs == nil && !isTokenEvent && !isTerminalEvent + if shouldBuffer { + buffered := make([]byte, len(message)) + copy(buffered, message) + bufferedStreamEvents = append(bufferedStreamEvents, buffered) + bufferedEventCount++ + if debugEnabled && shouldLogOpenAIWSBufferedEvent(bufferedEventCount) { + logOpenAIWSModeDebug( + "buffer_enqueue account_id=%d conn_id=%s idx=%d event_idx=%d event_type=%s buffer_size=%d", + account.ID, + connID, + bufferedEventCount, + eventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(bufferedStreamEvents), + ) + } + } else { + flushBufferedStreamEvents(eventType) + emitStreamMessage(message, isTerminalEvent) + } + } else { + if responseField.Exists() && responseField.Type == gjson.JSON { + finalResponse = []byte(responseField.Raw) + } + } + + if isTerminalEvent { + cleanExit = true + break + } + } + + if !reqStream { + if len(finalResponse) == 0 { + logOpenAIWSModeInfo( + "missing_final_response account_id=%d conn_id=%s events=%d token_events=%d terminal_events=%d wrote_downstream=%v", + account.ID, + connID, + eventCount, + tokenEventCount, + terminalEventCount, + wroteDownstream, + ) + if !wroteDownstream { + return nil, wrapOpenAIWSFallback("missing_final_response", errors.New("no terminal response payload")) + } + return nil, errors.New("ws finished without final response") + } + + if needModelReplace { + finalResponse = s.replaceModelInResponseBody(finalResponse, mappedModel, originalModel) + } + finalResponse = s.correctToolCallsInResponseBody(finalResponse) + populateOpenAIUsageFromResponseJSON(finalResponse, usage) + if responseID == "" { + responseID = strings.TrimSpace(gjson.GetBytes(finalResponse, "id").String()) + } + + c.Data(http.StatusOK, "application/json", finalResponse) + } else { + flushStreamWriter(true) + } + + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + stateStore.BindResponseConn(responseID, lease.ConnID(), ttl) + } + if stateStore != nil && storeDisabled && sessionHash != "" { + stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL()) + } + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + logOpenAIWSModeDebug( + "completed account_id=%d conn_id=%s response_id=%s stream=%v duration_ms=%d events=%d token_events=%d terminal_events=%d buffered_events=%d buffered_flushed=%d first_event=%s last_event=%s first_token_ms=%d wrote_downstream=%v client_disconnected=%v", + account.ID, + connID, + truncateOpenAIWSLogValue(strings.TrimSpace(responseID), openAIWSIDValueMaxLen), + reqStream, + time.Since(startTime).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + bufferedEventCount, + flushedBufferedEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + wroteDownstream, + clientDisconnected, + ) + + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + ServiceTier: extractOpenAIServiceTier(reqBody), + ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + ResponseHeaders: lease.HandshakeHeaders(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +// ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocket(OpenAI Responses WS Mode)并转发到上游。 +// 当前实现按“单请求 -> 终止事件 -> 下一请求”的顺序代理,适配 Codex CLI 的 turn 模式。 +func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + token string, + firstClientMessage []byte, + hooks *OpenAIWSIngressHooks, +) error { + if s == nil { + return errors.New("service is nil") + } + if c == nil { + return errors.New("gin context is nil") + } + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + if strings.TrimSpace(token) == "" { + return errors.New("token is empty") + } + + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + ingressMode := OpenAIWSIngressModeCtxPool + if modeRouterV2Enabled { + ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) + if ingressMode == OpenAIWSIngressModeOff { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode is disabled for this account", + nil, + ) + } + switch ingressMode { + case OpenAIWSIngressModePassthrough: + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) + } + return s.proxyResponsesWebSocketV2Passthrough( + ctx, + c, + clientConn, + account, + token, + firstClientMessage, + hooks, + wsDecision, + ) + case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // continue + default: + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode only supports ctx_pool/passthrough", + nil, + ) + } + } + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) + } + dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + wsHost := "-" + wsPath := "-" + if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { + wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) + wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) + } + debugEnabled := isOpenAIWSModeDebugEnabled() + + type openAIWSClientPayload struct { + payloadRaw []byte + rawForHash []byte + promptCacheKey string + previousResponseID string + originalModel string + payloadBytes int + } + + applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) { + next, err := sjson.SetBytes(current, path, value) + if err == nil { + return next, nil + } + + // 仅在确实需要修改 payload 且 sjson 失败时,退回 map 路径确保兼容性。 + payload := make(map[string]any) + if unmarshalErr := json.Unmarshal(current, &payload); unmarshalErr != nil { + return nil, err + } + switch path { + case "type", "model": + payload[path] = value + case "client_metadata." + openAIWSTurnMetadataHeader: + setOpenAIWSTurnMetadata(payload, fmt.Sprintf("%v", value)) + default: + return nil, err + } + rebuilt, marshalErr := json.Marshal(payload) + if marshalErr != nil { + return nil, marshalErr + } + return rebuilt, nil + } + + parseClientPayload := func(raw []byte) (openAIWSClientPayload, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "empty websocket request payload", nil) + } + if !gjson.ValidBytes(trimmed) { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + + values := gjson.GetManyBytes(trimmed, "type", "model", "prompt_cache_key", "previous_response_id") + eventType := strings.TrimSpace(values[0].String()) + normalized := trimmed + switch eventType { + case "": + eventType = "response.create" + next, setErr := applyPayloadMutation(normalized, "type", eventType) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + case "response.create": + case "response.append": + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "response.append is not supported in ws v2; use response.create with previous_response_id", + nil, + ) + default: + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket request type: %s", eventType), + nil, + ) + } + + originalModel := strings.TrimSpace(values[1].String()) + if originalModel == "" { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "model is required in response.create payload", + nil, + ) + } + promptCacheKey := strings.TrimSpace(values[2].String()) + previousResponseID := strings.TrimSpace(values[3].String()) + previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + if previousResponseID != "" && previousResponseIDKind == OpenAIPreviousResponseIDKindMessageID { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "previous_response_id must be a response.id (resp_*), not a message id", + nil, + ) + } + if turnMetadata := strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)); turnMetadata != "" { + next, setErr := applyPayloadMutation(normalized, "client_metadata."+openAIWSTurnMetadataHeader, turnMetadata) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + mappedModel := account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + if mappedModel != originalModel { + next, setErr := applyPayloadMutation(normalized, "model", mappedModel) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + + return openAIWSClientPayload{ + payloadRaw: normalized, + rawForHash: trimmed, + promptCacheKey: promptCacheKey, + previousResponseID: previousResponseID, + originalModel: originalModel, + payloadBytes: len(normalized), + }, nil + } + + firstPayload, err := parseClientPayload(firstClientMessage) + if err != nil { + return err + } + + turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + stateStore := s.getOpenAIWSStateStore() + groupID := getOpenAIGroupIDFromContext(c) + sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash) + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + turnState = savedTurnState + } + } + + preferredConnID := "" + if stateStore != nil && firstPayload.previousResponseID != "" { + if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok { + preferredConnID = connID + } + } + + storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account) + storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() + if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" { + if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { + preferredConnID = connID + } + } + + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) + baseAcquireReq := openAIWSAcquireRequest{ + Account: account, + WSURL: wsURL, + Headers: wsHeaders, + ProxyURL: func() string { + if account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" + }(), + ForceNewConn: false, + } + pool := s.getOpenAIWSConnPool() + if pool == nil { + return errors.New("openai ws conn pool is nil") + } + + logOpenAIWSModeInfo( + "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + wsPath, + normalizeOpenAIWSLogValue(ingressMode), + storeDisabled, + sessionHash != "", + firstPayload.previousResponseID != "", + ) + + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + sessionHash != "", + firstPayload.previousResponseID != "", + storeDisabled, + ) + } + if firstPayload.previousResponseID != "" { + firstPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(firstPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_continuation_probe account_id=%d turn=%d previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s session_hash=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + 1, + truncateOpenAIWSLogValue(firstPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(firstPreviousResponseIDKind), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + firstPayload.promptCacheKey != "", + storeDisabled, + ) + } + + acquireTimeout := s.openAIWSAcquireTimeout() + if acquireTimeout <= 0 { + acquireTimeout = 30 * time.Second + } + + acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) { + req := cloneOpenAIWSAcquireRequest(baseAcquireReq) + req.PreferredConnID = strings.TrimSpace(preferred) + req.ForcePreferredConn = forcePreferredConn + // dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。 + req.ForceNewConn = dedicatedMode + acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout) + lease, acquireErr := pool.Acquire(acquireCtx, req) + acquireCancel() + if acquireErr != nil { + dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr) + logOpenAIWSModeInfo( + "ingress_ws_upstream_acquire_fail account_id=%d turn=%d reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_preferred_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + turn, + normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(acquireErr)), + dialStatus, + dialClass, + dialCloseStatus, + truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), + dialRespServer, + dialRespVia, + dialRespCFRay, + dialRespReqID, + truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + forcePreferredConn, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + var dialErr *openAIWSDialError + if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error())) + } + if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream continuation connection is unavailable; please restart the conversation", + acquireErr, + ) + } + if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket is busy, please retry later", + acquireErr, + ) + } + return nil, acquireErr + } + connID := strings.TrimSpace(lease.ConnID()) + if handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)); handshakeTurnState != "" { + turnState = handshakeTurnState + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) + } + updatedHeaders := cloneHeader(baseAcquireReq.Headers) + if updatedHeaders == nil { + updatedHeaders = make(http.Header) + } + updatedHeaders.Set(openAIWSTurnStateHeader, handshakeTurnState) + baseAcquireReq.Headers = updatedHeaders + } + logOpenAIWSModeInfo( + "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + lease.Reused(), + lease.ConnPickDuration().Milliseconds(), + lease.QueueWaitDuration().Milliseconds(), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + ) + return lease, nil + } + + writeClientMessage := func(message []byte) error { + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + defer cancel() + return clientConn.Write(writeCtx, coderws.MessageText, message) + } + + readClientMessage := func() ([]byte, error) { + msgType, payload, readErr := clientConn.Read(ctx) + if readErr != nil { + return nil, readErr + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()), + nil, + ) + } + return payload, nil + } + + sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) { + if lease == nil { + return nil, errors.New("upstream websocket lease is nil") + } + turnStart := time.Now() + wroteDownstream := false + if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil { + return nil, wrapOpenAIWSIngressTurnError( + "write_upstream", + fmt.Errorf("write upstream websocket request: %w", err), + false, + ) + } + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_turn_request_sent account_id=%d turn=%d conn_id=%s payload_bytes=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + payloadBytes, + ) + } + + responseID := "" + usage := OpenAIUsage{} + var firstTokenMs *int + reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) + turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) + turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") + turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) + turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + firstEventType := "" + lastEventType := "" + needModelReplace := false + clientDisconnected := false + mappedModel := "" + var mappedModelBytes []byte + if originalModel != "" { + mappedModel = account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + needModelReplace = mappedModel != "" && mappedModel != originalModel + if needModelReplace { + mappedModelBytes = []byte(mappedModel) + } + } + for { + upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnError( + "read_upstream", + fmt.Errorf("read upstream websocket event: %w", readErr), + wroteDownstream, + ) + } + + eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage) + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + if eventType != "" { + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + } + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw) + fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && + turnPreviousResponseID != "" && + !turnHasFunctionCallOutput && + s.openAIWSIngressPreviousResponseRecoveryEnabled() && + !wroteDownstream + if recoverablePrevNotFound { + // 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + turnStoreDisabled, + turnPromptCacheKey != "", + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + turnStoreDisabled, + turnPromptCacheKey != "", + ) + } + // previous_response_not_found 在 ingress 模式支持单次恢复重试: + // 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。 + if recoverablePrevNotFound { + lease.MarkBroken() + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "previous response not found" + } + return nil, wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New(errMsg), + false, + ) + } + } + isTokenEvent := isOpenAIWSTokenEvent(eventType) + if isTokenEvent { + tokenEventCount++ + } + isTerminalEvent := isOpenAIWSTerminalEvent(eventType) + if isTerminalEvent { + terminalEventCount++ + } + if firstTokenMs == nil && isTokenEvent { + ms := int(time.Since(turnStart).Milliseconds()) + firstTokenMs = &ms + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) + } + + if !clientDisconnected { + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { + upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed { + upstreamMessage = corrected + } + } + if err := writeClientMessage(upstreamMessage); err != nil { + if isOpenAIWSClientDisconnectError(err) { + clientDisconnected = true + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err) + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + } else { + return nil, wrapOpenAIWSIngressTurnError( + "write_client", + fmt.Errorf("write client websocket event: %w", err), + wroteDownstream, + ) + } + } else { + wroteDownstream = true + } + } + if isTerminalEvent { + // 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。 + if clientDisconnected { + lease.MarkBroken() + } + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + time.Since(turnStart).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + clientDisconnected, + ) + } + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ServiceTier: extractOpenAIServiceTierFromBody(payload), + ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + ResponseHeaders: lease.HandshakeHeaders(), + Duration: time.Since(turnStart), + FirstTokenMs: firstTokenMs, + }, nil + } + } + } + + currentPayload := firstPayload.payloadRaw + currentOriginalModel := firstPayload.originalModel + currentPayloadBytes := firstPayload.payloadBytes + isStrictAffinityTurn := func(payload []byte) bool { + if !storeDisabled { + return false + } + return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != "" + } + var sessionLease *openAIWSConnLease + sessionConnID := "" + pinnedSessionConnID := "" + unpinSessionConn := func(connID string) { + connID = strings.TrimSpace(connID) + if connID == "" || pinnedSessionConnID != connID { + return + } + pool.UnpinConn(account.ID, connID) + pinnedSessionConnID = "" + } + pinSessionConn := func(connID string) { + if !storeDisabled { + return + } + connID = strings.TrimSpace(connID) + if connID == "" || pinnedSessionConnID == connID { + return + } + if pinnedSessionConnID != "" { + pool.UnpinConn(account.ID, pinnedSessionConnID) + pinnedSessionConnID = "" + } + if pool.PinConn(account.ID, connID) { + pinnedSessionConnID = connID + } + } + // lastTurnClean 标记最后一轮 sendAndRelay 是否正常完成(收到终端事件且客户端未断连)。 + // 所有异常路径(读写错误、error 事件、客户端断连)已在各自分支或上层(L3403)中 MarkBroken, + // 因此 releaseSessionLease 中只需在非正常结束时 MarkBroken。 + lastTurnClean := false + releaseSessionLease := func() { + if sessionLease == nil { + return + } + if !lastTurnClean { + sessionLease.MarkBroken() + } + unpinSessionConn(sessionConnID) + sessionLease.Release() + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_upstream_released account_id=%d conn_id=%s", + account.ID, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + ) + } + } + defer releaseSessionLease() + + turn := 1 + turnRetry := 0 + turnPrevRecoveryTried := false + lastTurnFinishedAt := time.Time{} + lastTurnResponseID := "" + lastTurnPayload := []byte(nil) + var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState + lastTurnReplayInput := []json.RawMessage(nil) + lastTurnReplayInputExists := false + currentTurnReplayInput := []json.RawMessage(nil) + currentTurnReplayInputExists := false + skipBeforeTurn := false + resetSessionLease := func(markBroken bool) { + if sessionLease == nil { + return + } + if markBroken { + sessionLease.MarkBroken() + } + releaseSessionLease() + sessionLease = nil + sessionConnID = "" + preferredConnID = "" + } + recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool { + if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) { + return false + } + if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() { + return false + } + if isStrictAffinityTurn(currentPayload) { + // Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。 + // 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_layer2 account_id=%d turn=%d conn_id=%s store_disabled_conn_mode=%s action=drop_previous_response_id_retry", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(storeDisabledConnMode), + ) + } + turnPrevRecoveryTried = true + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + reason := "not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + ) + return false + } + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + return false + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id retry=1", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + retryIngressTurn := func(relayErr error, turn int, connID string) bool { + if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 { + return false + } + if isStrictAffinityTurn(currentPayload) { + logOpenAIWSModeInfo( + "ingress_ws_turn_retry_skip account_id=%d turn=%d conn_id=%s reason=strict_affinity", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + return false + } + turnRetry++ + logOpenAIWSModeInfo( + "ingress_ws_turn_retry account_id=%d turn=%d retry=%d reason=%s conn_id=%s", + account.ID, + turn, + turnRetry, + truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + for { + if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil { + if err := hooks.BeforeTurn(turn); err != nil { + return err + } + } + skipBeforeTurn = false + currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") + expectedPrev := strings.TrimSpace(lastTurnResponseID) + hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + // store=false + function_call_output 场景必须有续链锚点。 + // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。 + if shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled, + turn, + hasFunctionCallOutput, + currentPreviousResponseID, + expectedPrev, + ) { + updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) + if setPrevErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + currentPayload = updatedPayload + currentPayloadBytes = len(updatedPayload) + currentPreviousResponseID = expectedPrev + logOpenAIWSModeInfo( + "ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } + } + nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence( + lastTurnReplayInput, + lastTurnReplayInputExists, + currentPayload, + currentPreviousResponseID != "", + ) + if replayInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_replay_input_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(replayInputErr.Error(), openAIWSLogValueMaxLen), + ) + currentTurnReplayInput = nil + currentTurnReplayInputExists = false + } else { + currentTurnReplayInput = nextReplayInput + currentTurnReplayInputExists = nextReplayInputExists + } + if storeDisabled && turn > 1 && currentPreviousResponseID != "" { + shouldKeepPreviousResponseID := false + strictReason := "" + var strictErr error + if lastTurnStrictState != nil { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseIDWithStrictState( + lastTurnStrictState, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + ) + } else { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID( + lastTurnPayload, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + ) + } + if strictErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s cause=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(strictErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } else if !shouldKeepPreviousResponseID { + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + dropReason := "not_removed" + if dropErr != nil { + dropReason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + normalizeOpenAIWSLogValue(dropReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=set_full_input_error previous_response_id=%s expected_previous_response_id=%s cause=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + hasFunctionCallOutput, + ) + } else { + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_full_create reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + currentPreviousResponseID = "" + } + } + } + } + forcePreferredConn := isStrictAffinityTurn(currentPayload) + if sessionLease == nil { + acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + if acquireErr != nil { + return fmt.Errorf("acquire upstream websocket: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } else { + unpinSessionConn(sessionConnID) + } + } + shouldPreflightPing := turn > 1 && sessionLease != nil && turnRetry == 0 + if shouldPreflightPing && openAIWSIngressPreflightPingIdle > 0 && !lastTurnFinishedAt.IsZero() { + if time.Since(lastTurnFinishedAt) < openAIWSIngressPreflightPingIdle { + shouldPreflightPing = false + } + } + if shouldPreflightPing { + if pingErr := sessionLease.PingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_upstream_preflight_ping_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen), + ) + if forcePreferredConn { + if !turnPrevRecoveryTried && currentPreviousResponseID != "" { + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + reason := "not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + turnPrevRecoveryTried = true + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + continue + } + } + } + resetSessionLease(true) + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream continuation connection is unavailable; please restart the conversation", + pingErr, + ) + } + resetSessionLease(true) + + acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + if acquireErr != nil { + return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } + } + } + connID := sessionConnID + if currentPreviousResponseID != "" { + chainedFromLast := expectedPrev != "" && currentPreviousResponseID == expectedPrev + currentPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(currentPreviousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_turn_chain account_id=%d turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v preferred_conn_id=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(currentPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + openAIWSPayloadStringFromRaw(currentPayload, "prompt_cache_key") != "", + storeDisabled, + ) + } + + result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) + if relayErr != nil { + lastTurnClean = false + if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { + continue + } + if retryIngressTurn(relayErr, turn, connID) { + continue + } + finalErr := relayErr + if unwrapped := errors.Unwrap(relayErr); unwrapped != nil { + finalErr = unwrapped + } + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, nil, finalErr) + } + sessionLease.MarkBroken() + return finalErr + } + turnRetry = 0 + turnPrevRecoveryTried = false + lastTurnFinishedAt = time.Now() + lastTurnClean = true + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, result, nil) + } + if result == nil { + return errors.New("websocket turn result is nil") + } + responseID := strings.TrimSpace(result.RequestID) + lastTurnResponseID = responseID + lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload) + lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput) + lastTurnReplayInputExists = currentTurnReplayInputExists + nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload) + if strictStateErr != nil { + lastTurnStrictState = nil + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_state_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(strictStateErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + lastTurnStrictState = nextStrictState + } + + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + stateStore.BindResponseConn(responseID, connID, ttl) + } + if stateStore != nil && storeDisabled && sessionHash != "" { + stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL()) + } + if connID != "" { + preferredConnID = connID + } + + nextClientMessage, readErr := readClientMessage() + if readErr != nil { + if isOpenAIWSClientDisconnectError(readErr) { + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + return nil + } + return fmt.Errorf("read client websocket request: %w", readErr) + } + + nextPayload, parseErr := parseClientPayload(nextClientMessage) + if parseErr != nil { + return parseErr + } + if nextPayload.promptCacheKey != "" { + // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接; + // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。 + updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey) + baseAcquireReq.Headers = updatedHeaders + } + if nextPayload.previousResponseID != "" { + expectedPrev := strings.TrimSpace(lastTurnResponseID) + chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev + nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + turn+1, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(nextPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + nextPayload.promptCacheKey != "", + storeDisabled, + ) + } + if stateStore != nil && nextPayload.previousResponseID != "" { + if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok { + if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID { + logOpenAIWSModeInfo( + "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + ) + } else { + preferredConnID = stickyConnID + } + } + } + currentPayload = nextPayload.payloadRaw + currentOriginalModel = nextPayload.originalModel + currentPayloadBytes = nextPayload.payloadBytes + storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account) + if !storeDisabled { + unpinSessionConn(sessionConnID) + } + turn++ + } +} + +func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled +} + +// performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。 +// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。 +func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( + ctx context.Context, + lease *openAIWSConnLease, + decision OpenAIWSProtocolDecision, + payload map[string]any, + previousResponseID string, + reqBody map[string]any, + account *Account, + stateStore OpenAIWSStateStore, + groupID int64, +) error { + if s == nil { + return nil + } + if lease == nil || account == nil { + logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil) + return nil + } + connID := strings.TrimSpace(lease.ConnID()) + if !s.isOpenAIWSGeneratePrewarmEnabled() { + return nil + } + if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s", + account.ID, + connID, + normalizeOpenAIWSLogValue(string(decision.Transport)), + ) + return nil + } + if strings.TrimSpace(previousResponseID) != "" { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=has_previous_response_id previous_response_id=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + return nil + } + if lease.IsPrewarmed() { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID) + return nil + } + if NeedsToolContinuation(reqBody) { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID) + return nil + } + prewarmStart := time.Now() + logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID) + + prewarmPayload := make(map[string]any, len(payload)+1) + for k, v := range payload { + prewarmPayload[k] = v + } + prewarmPayload["generate"] = false + prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload) + + if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + "prewarm_write_fail account_id=%d conn_id=%s cause=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return wrapOpenAIWSFallback("prewarm_write", err) + } + logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON)) + + prewarmResponseID := "" + prewarmEventCount := 0 + prewarmTerminalCount := 0 + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d", + account.ID, + connID, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + prewarmEventCount, + ) + return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr) + } + + eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message) + if eventType == "" { + continue + } + prewarmEventCount++ + if prewarmResponseID == "" && eventResponseID != "" { + prewarmResponseID = eventResponseID + } + if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + logOpenAIWSModeInfo( + "prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + ) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "OpenAI websocket prewarm error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSModeInfo( + "prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + lease.MarkBroken() + if canFallback { + return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg)) + } + return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg)) + } + + if isOpenAIWSTerminalEvent(eventType) { + prewarmTerminalCount++ + break + } + } + + lease.MarkPrewarmed() + if prewarmResponseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl)) + stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl) + } + logOpenAIWSModeInfo( + "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen), + prewarmEventCount, + prewarmTerminalCount, + time.Since(prewarmStart).Milliseconds(), + ) + return nil +} + +func payloadAsJSON(payload map[string]any) string { + return string(payloadAsJSONBytes(payload)) +} + +func payloadAsJSONBytes(payload map[string]any) []byte { + if len(payload) == 0 { + return []byte("{}") + } + body, err := json.Marshal(payload) + if err != nil { + return []byte("{}") + } + return body +} + +func isOpenAIWSTerminalEvent(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func isOpenAIWSTokenEvent(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return false + } + switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return eventType == "response.completed" || eventType == "response.done" +} + +func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte { + if len(message) == 0 { + return message + } + if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel { + return message + } + if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) { + return message + } + modelValues := gjson.GetManyBytes(message, "model", "response.model") + replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel + replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel + if !replaceModel && !replaceResponseModel { + return message + } + updated := message + if replaceModel { + if next, err := sjson.SetBytes(updated, "model", toModel); err == nil { + updated = next + } + } + if replaceResponseModel { + if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil { + updated = next + } + } + return updated +} + +func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) { + if usage == nil || len(body) == 0 { + return + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func getOpenAIGroupIDFromContext(c *gin.Context) int64 { + if c == nil { + return 0 + } + value, exists := c.Get("api_key") + if !exists { + return 0 + } + apiKey, ok := value.(*APIKey) + if !ok || apiKey == nil || apiKey.GroupID == nil { + return 0 + } + return *apiKey.GroupID +} + +// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。 +// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。 +func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( + ctx context.Context, + groupID *int64, + previousResponseID string, + requestedModel string, + excludedIDs map[int64]struct{}, +) (*AccountSelectionResult, error) { + if s == nil { + return nil, nil + } + responseID := strings.TrimSpace(previousResponseID) + if responseID == "" { + return nil, nil + } + store := s.getOpenAIWSStateStore() + if store == nil { + return nil, nil + } + + accountID, err := store.GetResponseAccount(ctx, derefGroupID(groupID), responseID) + if err != nil || accountID <= 0 { + return nil, nil + } + if excludedIDs != nil { + if _, excluded := excludedIDs[accountID]; excluded { + return nil, nil + } + } + + account, err := s.getSchedulableAccount(ctx, accountID) + if err != nil || account == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + // 非 WSv2 场景(如 force_http/全局关闭)不应使用 previous_response_id 粘连, + // 以保持“回滚到 HTTP”后的历史行为一致性。 + if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return nil, nil + } + if shouldClearStickySession(ctx, account, requestedModel) || !account.IsOpenAI() || !account.IsSchedulable() { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + if requestedModel != "" && !account.IsModelSupported(requestedModel) { + return nil, nil + } + + result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if acquireErr == nil && result.Acquired { + logOpenAIWSBindResponseAccountWarn( + derefGroupID(groupID), + accountID, + responseID, + store.BindResponseAccount(ctx, derefGroupID(groupID), responseID, accountID, s.openAIWSResponseStickyTTL()), + ) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + cfg := s.schedulingConfig() + if s.concurrencyService != nil { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + return nil, nil +} + +func classifyOpenAIWSAcquireError(err error) string { + if err == nil { + return "acquire_conn" + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) { + switch dialErr.StatusCode { + case 426: + return "upgrade_required" + case 401, 403: + return "auth_failed" + case 429: + return "upstream_rate_limited" + } + if dialErr.StatusCode >= 500 { + return "upstream_5xx" + } + return "dial_failed" + } + if errors.Is(err, errOpenAIWSConnQueueFull) { + return "conn_queue_full" + } + if errors.Is(err, errOpenAIWSPreferredConnUnavailable) { + return "preferred_conn_unavailable" + } + if errors.Is(err, context.DeadlineExceeded) { + return "acquire_timeout" + } + return "acquire_conn" +} + +func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") { + return true + } + if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") { + return true + } + if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") { + return true + } + if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) { + return true + } + return false +} + +func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) { + if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI { + return + } + if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return + } + s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) +} + +func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + switch code { + case "upgrade_required": + return "upgrade_required", true + case "websocket_not_supported", "websocket_unsupported": + return "ws_unsupported", true + case "websocket_connection_limit_reached": + return "ws_connection_limit_reached", true + case "invalid_encrypted_content": + return "invalid_encrypted_content", true + case "previous_response_not_found": + return "previous_response_not_found", true + } + if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return "upstream_rate_limited", false + } + if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { + return "upgrade_required", true + } + if strings.Contains(errType, "upgrade") { + return "upgrade_required", true + } + if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") { + return "ws_unsupported", true + } + if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { + return "ws_connection_limit_reached", true + } + if strings.Contains(msg, "invalid_encrypted_content") || + (strings.Contains(msg, "encrypted content") && strings.Contains(msg, "could not be verified")) { + return "invalid_encrypted_content", true + } + if strings.Contains(msg, "previous_response_not_found") || + (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { + return "previous_response_not_found", true + } + if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") { + return "upstream_error_event", true + } + return "event_error", false +} + +func classifyOpenAIWSErrorEvent(message []byte) (string, bool) { + if len(message) == 0 { + return "event_error", false + } + return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + switch { + case strings.Contains(errType, "invalid_request"), + strings.Contains(code, "invalid_request"), + strings.Contains(code, "bad_request"), + code == "invalid_encrypted_content", + code == "previous_response_not_found": + return http.StatusBadRequest + case strings.Contains(errType, "authentication"), + strings.Contains(code, "invalid_api_key"), + strings.Contains(code, "unauthorized"): + return http.StatusUnauthorized + case strings.Contains(errType, "permission"), + strings.Contains(code, "forbidden"): + return http.StatusForbidden + case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""): + return http.StatusTooManyRequests + default: + return http.StatusBadGateway + } +} + +func openAIWSErrorHTTPStatus(message []byte) int { + if len(message) == 0 { + return http.StatusBadGateway + } + codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message) + return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) +} + +func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration { + if s == nil || s.cfg == nil { + return 30 * time.Second + } + seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + +func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool { + if s == nil || accountID <= 0 { + return false + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return false + } + rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID) + if !ok || rawUntil == nil { + return false + } + until, ok := rawUntil.(time.Time) + if !ok || until.IsZero() { + s.openaiWSFallbackUntil.Delete(accountID) + return false + } + if time.Now().Before(until) { + return true + } + s.openaiWSFallbackUntil.Delete(accountID) + return false +} + +func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) { + if s == nil || accountID <= 0 { + return + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return + } + s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown)) +} + +func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) { + if s == nil || accountID <= 0 { + return + } + s.openaiWSFallbackUntil.Delete(accountID) +} diff --git a/backend/internal/service/promo_service.go b/backend/internal/service/promo_service.go new file mode 100644 index 00000000..5ff63bdc --- /dev/null +++ b/backend/internal/service/promo_service.go @@ -0,0 +1,268 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +var ( + ErrPromoCodeNotFound = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found") + ErrPromoCodeExpired = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired") + ErrPromoCodeDisabled = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled") + ErrPromoCodeMaxUsed = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses") + ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code") + ErrPromoCodeInvalid = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code") +) + +// PromoService 优惠码服务 +type PromoService struct { + promoRepo PromoCodeRepository + userRepo UserRepository + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator +} + +// NewPromoService 创建优惠码服务实例 +func NewPromoService( + promoRepo PromoCodeRepository, + userRepo UserRepository, + billingCacheService *BillingCacheService, + entClient *dbent.Client, + authCacheInvalidator APIKeyAuthCacheInvalidator, +) *PromoService { + return &PromoService{ + promoRepo: promoRepo, + userRepo: userRepo, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, + } +} + +// ValidatePromoCode 验证优惠码(注册前调用) +// 返回 nil, nil 表示空码(不报错) +func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) { + code = strings.TrimSpace(code) + if code == "" { + return nil, nil // 空码不报错,直接返回 + } + + promoCode, err := s.promoRepo.GetByCode(ctx, code) + if err != nil { + // 保留原始错误类型,不要统一映射为 NotFound + return nil, err + } + + if err := s.validatePromoCodeStatus(promoCode); err != nil { + return nil, err + } + + return promoCode, nil +} + +// validatePromoCodeStatus 验证优惠码状态 +func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error { + if !promoCode.CanUse() { + if promoCode.IsExpired() { + return ErrPromoCodeExpired + } + if promoCode.Status == PromoCodeStatusDisabled { + return ErrPromoCodeDisabled + } + if promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses { + return ErrPromoCodeMaxUsed + } + return ErrPromoCodeInvalid + } + return nil +} + +// ApplyPromoCode 应用优惠码(注册成功后调用) +// 使用事务和行锁确保并发安全 +func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error { + code = strings.TrimSpace(code) + if code == "" { + return nil + } + + // 开启事务 + tx, err := s.entClient.Tx(ctx) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + + // 在事务中获取并锁定优惠码记录(FOR UPDATE) + promoCode, err := s.promoRepo.GetByCodeForUpdate(txCtx, code) + if err != nil { + return err + } + + // 在事务中验证优惠码状态 + if err := s.validatePromoCodeStatus(promoCode); err != nil { + return err + } + + // 在事务中检查用户是否已使用过此优惠码 + existing, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promoCode.ID, userID) + if err != nil { + return fmt.Errorf("check existing usage: %w", err) + } + if existing != nil { + return ErrPromoCodeAlreadyUsed + } + + // 增加用户余额 + if err := s.userRepo.UpdateBalance(txCtx, userID, promoCode.BonusAmount); err != nil { + return fmt.Errorf("update user balance: %w", err) + } + + // 创建使用记录 + usage := &PromoCodeUsage{ + PromoCodeID: promoCode.ID, + UserID: userID, + BonusAmount: promoCode.BonusAmount, + UsedAt: time.Now(), + } + if err := s.promoRepo.CreateUsage(txCtx, usage); err != nil { + return fmt.Errorf("create usage record: %w", err) + } + + // 增加使用次数 + if err := s.promoRepo.IncrementUsedCount(txCtx, promoCode.ID); err != nil { + return fmt.Errorf("increment used count: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + + s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount) + + // 失效余额缓存 + if s.billingCacheService != nil { + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) + }() + } + + return nil +} + +func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) { + if bonusAmount == 0 || s.authCacheInvalidator == nil { + return + } + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) +} + +// GenerateRandomCode 生成随机优惠码 +func (s *PromoService) GenerateRandomCode() (string, error) { + bytes := make([]byte, 8) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("generate random bytes: %w", err) + } + return strings.ToUpper(hex.EncodeToString(bytes)), nil +} + +// Create 创建优惠码 +func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) { + code := strings.TrimSpace(input.Code) + if code == "" { + // 自动生成 + var err error + code, err = s.GenerateRandomCode() + if err != nil { + return nil, err + } + } + + promoCode := &PromoCode{ + Code: strings.ToUpper(code), + BonusAmount: input.BonusAmount, + MaxUses: input.MaxUses, + UsedCount: 0, + Status: PromoCodeStatusActive, + ExpiresAt: input.ExpiresAt, + Notes: input.Notes, + } + + if err := s.promoRepo.Create(ctx, promoCode); err != nil { + return nil, fmt.Errorf("create promo code: %w", err) + } + + return promoCode, nil +} + +// GetByID 根据ID获取优惠码 +func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) { + code, err := s.promoRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + return code, nil +} + +// Update 更新优惠码 +func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) { + promoCode, err := s.promoRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + if input.Code != nil { + promoCode.Code = strings.ToUpper(strings.TrimSpace(*input.Code)) + } + if input.BonusAmount != nil { + promoCode.BonusAmount = *input.BonusAmount + } + if input.MaxUses != nil { + promoCode.MaxUses = *input.MaxUses + } + if input.Status != nil { + promoCode.Status = *input.Status + } + if input.ExpiresAt != nil { + promoCode.ExpiresAt = input.ExpiresAt + } + if input.Notes != nil { + promoCode.Notes = *input.Notes + } + + if err := s.promoRepo.Update(ctx, promoCode); err != nil { + return nil, fmt.Errorf("update promo code: %w", err) + } + + return promoCode, nil +} + +// Delete 删除优惠码 +func (s *PromoService) Delete(ctx context.Context, id int64) error { + if err := s.promoRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete promo code: %w", err) + } + return nil +} + +// List 获取优惠码列表 +func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) { + return s.promoRepo.ListWithFilters(ctx, params, status, search) +} + +// ListUsages 获取使用记录 +func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) { + return s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params) +} diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go new file mode 100644 index 00000000..f23117b5 --- /dev/null +++ b/backend/internal/service/proxy_service.go @@ -0,0 +1,193 @@ +package service + +import ( + "context" + "fmt" + "net/http" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +var ( + ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found") + ErrProxyInUse = infraerrors.Conflict("PROXY_IN_USE", "proxy is in use by accounts") +) + +type ProxyRepository interface { + Create(ctx context.Context, proxy *Proxy) error + GetByID(ctx context.Context, id int64) (*Proxy, error) + ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) + Update(ctx context.Context, proxy *Proxy) error + Delete(ctx context.Context, id int64) error + + List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) + ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) + ListActive(ctx context.Context) ([]Proxy, error) + ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) + + ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) + CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) + ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) +} + +// CreateProxyRequest 创建代理请求 +type CreateProxyRequest struct { + Name string `json:"name"` + Protocol string `json:"protocol"` + Host string `json:"host"` + Port int `json:"port"` + Username string `json:"username"` + Password string `json:"password"` +} + +// UpdateProxyRequest 更新代理请求 +type UpdateProxyRequest struct { + Name *string `json:"name"` + Protocol *string `json:"protocol"` + Host *string `json:"host"` + Port *int `json:"port"` + Username *string `json:"username"` + Password *string `json:"password"` + Status *string `json:"status"` +} + +// ProxyService 代理管理服务 +type ProxyService struct { + proxyRepo ProxyRepository +} + +// NewProxyService 创建代理服务实例 +func NewProxyService(proxyRepo ProxyRepository) *ProxyService { + return &ProxyService{ + proxyRepo: proxyRepo, + } +} + +// Create 创建代理 +func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*Proxy, error) { + // 创建代理 + proxy := &Proxy{ + Name: req.Name, + Protocol: req.Protocol, + Host: req.Host, + Port: req.Port, + Username: req.Username, + Password: req.Password, + Status: StatusActive, + } + + if err := s.proxyRepo.Create(ctx, proxy); err != nil { + return nil, fmt.Errorf("create proxy: %w", err) + } + + return proxy, nil +} + +// GetByID 根据ID获取代理 +func (s *ProxyService) GetByID(ctx context.Context, id int64) (*Proxy, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get proxy: %w", err) + } + return proxy, nil +} + +// List 获取代理列表 +func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { + proxies, pagination, err := s.proxyRepo.List(ctx, params) + if err != nil { + return nil, nil, fmt.Errorf("list proxies: %w", err) + } + return proxies, pagination, nil +} + +// ListActive 获取活跃代理列表 +func (s *ProxyService) ListActive(ctx context.Context) ([]Proxy, error) { + proxies, err := s.proxyRepo.ListActive(ctx) + if err != nil { + return nil, fmt.Errorf("list active proxies: %w", err) + } + return proxies, nil +} + +// Update 更新代理 +func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*Proxy, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get proxy: %w", err) + } + + // 更新字段 + if req.Name != nil { + proxy.Name = *req.Name + } + + if req.Protocol != nil { + proxy.Protocol = *req.Protocol + } + + if req.Host != nil { + proxy.Host = *req.Host + } + + if req.Port != nil { + proxy.Port = *req.Port + } + + if req.Username != nil { + proxy.Username = *req.Username + } + + if req.Password != nil { + proxy.Password = *req.Password + } + + if req.Status != nil { + proxy.Status = *req.Status + } + + if err := s.proxyRepo.Update(ctx, proxy); err != nil { + return nil, fmt.Errorf("update proxy: %w", err) + } + + return proxy, nil +} + +// Delete 删除代理 +func (s *ProxyService) Delete(ctx context.Context, id int64) error { + // 检查代理是否存在 + _, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return fmt.Errorf("get proxy: %w", err) + } + + if err := s.proxyRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete proxy: %w", err) + } + + return nil +} + +// TestConnection tests proxy connectivity by attempting a request through the proxy. +// Returns ErrNotImplemented if the proxy test feature is not yet implemented. +func (s *ProxyService) TestConnection(ctx context.Context, id int64) error { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return fmt.Errorf("get proxy: %w", err) + } + _ = proxy + + return infraerrors.New(http.StatusNotImplemented, "PROXY_TEST_NOT_IMPLEMENTED", "proxy connection testing is not yet implemented") +} + +// GetURL 获取代理URL +func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return "", fmt.Errorf("get proxy: %w", err) + } + + return proxy.URL(), nil +} diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go new file mode 100644 index 00000000..5b3310d6 --- /dev/null +++ b/backend/internal/service/redeem_service.go @@ -0,0 +1,468 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "net/http" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +var ( + ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found") + ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used") + ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance") + ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later") + ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again") +) + +const ( + redeemMaxErrorsPerHour = 20 + redeemRateLimitDuration = time.Hour + redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁 +) + +// RedeemCache defines cache operations for redeem service +type RedeemCache interface { + GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) + IncrementRedeemAttemptCount(ctx context.Context, userID int64) error + + AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) + ReleaseRedeemLock(ctx context.Context, code string) error +} + +type RedeemCodeRepository interface { + Create(ctx context.Context, code *RedeemCode) error + CreateBatch(ctx context.Context, codes []RedeemCode) error + GetByID(ctx context.Context, id int64) (*RedeemCode, error) + GetByCode(ctx context.Context, code string) (*RedeemCode, error) + Update(ctx context.Context, code *RedeemCode) error + Delete(ctx context.Context, id int64) error + Use(ctx context.Context, id, userID int64) error + + List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) + ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) + // ListByUserPaginated returns paginated balance/concurrency history for a specific user. + // codeType filter is optional - pass empty string to return all types. + ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) + // SumPositiveBalanceByUser returns the total recharged amount (sum of positive balance values) for a user. + SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) +} + +// GenerateCodesRequest 生成兑换码请求 +type GenerateCodesRequest struct { + Count int `json:"count"` + Value float64 `json:"value"` + Type string `json:"type"` +} + +// RedeemCodeResponse 兑换码响应 +type RedeemCodeResponse struct { + Code string `json:"code"` + Value float64 `json:"value"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` +} + +// RedeemService 兑换码服务 +type RedeemService struct { + redeemRepo RedeemCodeRepository + userRepo UserRepository + subscriptionService *SubscriptionService + cache RedeemCache + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator +} + +// NewRedeemService 创建兑换码服务实例 +func NewRedeemService( + redeemRepo RedeemCodeRepository, + userRepo UserRepository, + subscriptionService *SubscriptionService, + cache RedeemCache, + billingCacheService *BillingCacheService, + entClient *dbent.Client, + authCacheInvalidator APIKeyAuthCacheInvalidator, +) *RedeemService { + return &RedeemService{ + redeemRepo: redeemRepo, + userRepo: userRepo, + subscriptionService: subscriptionService, + cache: cache, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, + } +} + +// GenerateRandomCode 生成随机兑换码 +func (s *RedeemService) GenerateRandomCode() (string, error) { + // 生成16字节随机数据 + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("generate random bytes: %w", err) + } + + // 转换为十六进制字符串 + code := hex.EncodeToString(bytes) + + // 格式化为 XXXX-XXXX-XXXX-XXXX 格式 + parts := []string{ + strings.ToUpper(code[0:8]), + strings.ToUpper(code[8:16]), + strings.ToUpper(code[16:24]), + strings.ToUpper(code[24:32]), + } + + return strings.Join(parts, "-"), nil +} + +// GenerateCodes 批量生成兑换码 +func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]RedeemCode, error) { + if req.Count <= 0 { + return nil, errors.New("count must be greater than 0") + } + + // 邀请码类型不需要数值,其他类型需要 + if req.Type != RedeemTypeInvitation && req.Value <= 0 { + return nil, errors.New("value must be greater than 0") + } + + if req.Count > 1000 { + return nil, errors.New("cannot generate more than 1000 codes at once") + } + + codeType := req.Type + if codeType == "" { + codeType = RedeemTypeBalance + } + + // 邀请码类型的 value 设为 0 + value := req.Value + if codeType == RedeemTypeInvitation { + value = 0 + } + + codes := make([]RedeemCode, 0, req.Count) + for i := 0; i < req.Count; i++ { + code, err := s.GenerateRandomCode() + if err != nil { + return nil, fmt.Errorf("generate code: %w", err) + } + + codes = append(codes, RedeemCode{ + Code: code, + Type: codeType, + Value: value, + Status: StatusUnused, + }) + } + + // 批量插入 + if err := s.redeemRepo.CreateBatch(ctx, codes); err != nil { + return nil, fmt.Errorf("create batch codes: %w", err) + } + + return codes, nil +} + +// CreateCode creates a redeem code with caller-provided code value. +// It is primarily used by admin integrations that require an external order ID +// to be mapped to a deterministic redeem code. +func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error { + if code == nil { + return errors.New("redeem code is required") + } + code.Code = strings.TrimSpace(code.Code) + if code.Code == "" { + return errors.New("code is required") + } + if code.Type == "" { + code.Type = RedeemTypeBalance + } + if code.Type != RedeemTypeInvitation && code.Value <= 0 { + return errors.New("value must be greater than 0") + } + if code.Status == "" { + code.Status = StatusUnused + } + + if err := s.redeemRepo.Create(ctx, code); err != nil { + return fmt.Errorf("create redeem code: %w", err) + } + return nil +} + +// checkRedeemRateLimit 检查用户兑换错误次数是否超限 +func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error { + if s.cache == nil { + return nil + } + + count, err := s.cache.GetRedeemAttemptCount(ctx, userID) + if err != nil { + // Redis 出错时不阻止用户操作 + return nil + } + + if count >= redeemMaxErrorsPerHour { + return ErrRedeemRateLimited + } + + return nil +} + +// incrementRedeemErrorCount 增加用户兑换错误计数 +func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) { + if s.cache == nil { + return + } + + _ = s.cache.IncrementRedeemAttemptCount(ctx, userID) +} + +// acquireRedeemLock 尝试获取兑换码的分布式锁 +// 返回 true 表示获取成功,false 表示锁已被占用 +func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool { + if s.cache == nil { + return true // 无 Redis 时降级为不加锁 + } + + ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration) + if err != nil { + // Redis 出错时不阻止操作,依赖数据库层面的状态检查 + return true + } + return ok +} + +// releaseRedeemLock 释放兑换码的分布式锁 +func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) { + if s.cache == nil { + return + } + + _ = s.cache.ReleaseRedeemLock(ctx, code) +} + +// Redeem 使用兑换码 +func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*RedeemCode, error) { + // 检查限流 + if err := s.checkRedeemRateLimit(ctx, userID); err != nil { + return nil, err + } + + // 获取分布式锁,防止同一兑换码并发使用 + if !s.acquireRedeemLock(ctx, code) { + return nil, ErrRedeemCodeLocked + } + defer s.releaseRedeemLock(ctx, code) + + // 查找兑换码 + redeemCode, err := s.redeemRepo.GetByCode(ctx, code) + if err != nil { + if errors.Is(err, ErrRedeemCodeNotFound) { + s.incrementRedeemErrorCount(ctx, userID) + return nil, ErrRedeemCodeNotFound + } + return nil, fmt.Errorf("get redeem code: %w", err) + } + + // 检查兑换码状态 + if !redeemCode.CanUse() { + s.incrementRedeemErrorCount(ctx, userID) + return nil, ErrRedeemCodeUsed + } + + // 验证兑换码类型的前置条件 + if redeemCode.Type == RedeemTypeSubscription && redeemCode.GroupID == nil { + return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id") + } + + // 获取用户信息 + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + _ = user // 使用变量避免未使用错误 + + // 使用数据库事务保证兑换码标记与权益发放的原子性 + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // 将事务放入 context,使 repository 方法能够使用同一事务 + txCtx := dbent.NewTxContext(ctx, tx) + + // 【关键】先标记兑换码为已使用,确保并发安全 + // 利用数据库乐观锁(WHERE status = 'unused')保证原子性 + if err := s.redeemRepo.Use(txCtx, redeemCode.ID, userID); err != nil { + if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) { + return nil, ErrRedeemCodeUsed + } + return nil, fmt.Errorf("mark code as used: %w", err) + } + + // 执行兑换逻辑(兑换码已被锁定,此时可安全操作) + switch redeemCode.Type { + case RedeemTypeBalance: + // 增加用户余额 + if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil { + return nil, fmt.Errorf("update user balance: %w", err) + } + + case RedeemTypeConcurrency: + // 增加用户并发数 + if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil { + return nil, fmt.Errorf("update user concurrency: %w", err) + } + + case RedeemTypeSubscription: + validityDays := redeemCode.ValidityDays + if validityDays <= 0 { + validityDays = 30 + } + _, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: *redeemCode.GroupID, + ValidityDays: validityDays, + AssignedBy: 0, // 系统分配 + Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code), + }) + if err != nil { + return nil, fmt.Errorf("assign or extend subscription: %w", err) + } + + default: + return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type) + } + + // 提交事务 + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + + // 事务提交成功后失效缓存 + s.invalidateRedeemCaches(ctx, userID, redeemCode) + + // 重新获取更新后的兑换码 + redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID) + if err != nil { + return nil, fmt.Errorf("get updated redeem code: %w", err) + } + + return redeemCode, nil +} + +// invalidateRedeemCaches 失效兑换相关的缓存 +func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) { + switch redeemCode.Type { + case RedeemTypeBalance: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) + }() + case RedeemTypeConcurrency: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } + case RedeemTypeSubscription: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } + if redeemCode.GroupID != nil { + groupID := *redeemCode.GroupID + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + }() + } + } +} + +// GetByID 根据ID获取兑换码 +func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) { + code, err := s.redeemRepo.GetByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get redeem code: %w", err) + } + return code, nil +} + +// GetByCode 根据Code获取兑换码 +func (s *RedeemService) GetByCode(ctx context.Context, code string) (*RedeemCode, error) { + redeemCode, err := s.redeemRepo.GetByCode(ctx, code) + if err != nil { + return nil, fmt.Errorf("get redeem code: %w", err) + } + return redeemCode, nil +} + +// List 获取兑换码列表(管理员功能) +func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) { + codes, pagination, err := s.redeemRepo.List(ctx, params) + if err != nil { + return nil, nil, fmt.Errorf("list redeem codes: %w", err) + } + return codes, pagination, nil +} + +// Delete 删除兑换码(管理员功能) +func (s *RedeemService) Delete(ctx context.Context, id int64) error { + // 检查兑换码是否存在 + code, err := s.redeemRepo.GetByID(ctx, id) + if err != nil { + return fmt.Errorf("get redeem code: %w", err) + } + + // 不允许删除已使用的兑换码 + if code.IsUsed() { + return infraerrors.Conflict("REDEEM_CODE_DELETE_USED", "cannot delete used redeem code") + } + + if err := s.redeemRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete redeem code: %w", err) + } + + return nil +} + +// GetStats returns redeem code statistics (unused, used counts, total value). +// Returns ErrNotImplemented as the stats query is not yet implemented. +func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) { + return nil, infraerrors.New(http.StatusNotImplemented, "REDEEM_STATS_NOT_IMPLEMENTED", "redeem code statistics is not yet implemented") +} + +// GetUserHistory 获取用户的兑换历史 +func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) { + codes, err := s.redeemRepo.ListByUser(ctx, userID, limit) + if err != nil { + return nil, fmt.Errorf("get user redeem history: %w", err) + } + return codes, nil +}