4711 lines
148 KiB
Go
4711 lines
148 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"bufio"
|
|||
|
|
"bytes"
|
|||
|
|
"context"
|
|||
|
|
"crypto/sha256"
|
|||
|
|
"encoding/hex"
|
|||
|
|
"encoding/json"
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
"io"
|
|||
|
|
"math/rand"
|
|||
|
|
"net/http"
|
|||
|
|
"sort"
|
|||
|
|
"strconv"
|
|||
|
|
"strings"
|
|||
|
|
"sync"
|
|||
|
|
"sync/atomic"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
|||
|
|
"github.com/cespare/xxhash/v2"
|
|||
|
|
"github.com/gin-gonic/gin"
|
|||
|
|
"github.com/google/uuid"
|
|||
|
|
"github.com/tidwall/gjson"
|
|||
|
|
"github.com/tidwall/sjson"
|
|||
|
|
"go.uber.org/zap"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
const (
	// chatgptCodexURL is the ChatGPT internal API endpoint used for OAuth accounts.
	chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
	// openaiPlatformAPIURL is the OpenAI Platform API endpoint used for API Key accounts (fallback).
	openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
	// openaiStickySessionTTL is the TTL of sticky session bindings.
	openaiStickySessionTTL = time.Hour

	// codexCLIVersion is the Codex CLI version string used by this service.
	codexCLIVersion = "0.104.0"
	// codexCLIUserAgent embeds codexCLIVersion so that bumping the version
	// cannot leave the User-Agent pointing at a stale release.
	codexCLIUserAgent = "codex_cli_rs/" + codexCLIVersion

	// codexCLIOnlyHeaderValueMaxBytes caps the logged length of a single request
	// header value when codex_cli_only rejects a request.
	// NOTE(review): the original comment said "characters" while the name says
	// bytes — confirm which unit the truncation code actually applies.
	codexCLIOnlyHeaderValueMaxBytes = 256

	// OpenAIParsedRequestBodyKey caches the request body already parsed on the
	// handler side, so it does not have to be parsed again.
	OpenAIParsedRequestBodyKey = "openai_parsed_request_body"

	// openAIWSReconnectRetryLimit is the maximum number of reconnects after an
	// OpenAI WS Mode failure (the first attempt is not counted). This matches
	// the Codex client, which reconnects at most 5 times after a failure.
	openAIWSReconnectRetryLimit = 5

	// Default OpenAI WS Mode reconnect backoff settings (overridable via config).
	openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond
	openAIWSRetryBackoffMaxDefault     = 2 * time.Second
	openAIWSRetryJitterRatioDefault    = 0.2

	openAICompactSessionSeedKey = "openai_compact_session_seed"

	// openAICodexSnapshotPersistMinInterval is the minimum interval between
	// persisted Codex rate-limit snapshots for the same account. Snapshots are
	// only used for admin display/diagnostics, so they do not need to be
	// written on every successful request.
	openAICodexSnapshotPersistMinInterval = 30 * time.Second
)
|
|||
|
|
|
|||
|
|
// openaiAllowedHeaders is the set of lowercase request header names that may
// be forwarded upstream in non-passthrough mode. Entries are kept sorted.
var openaiAllowedHeaders = map[string]bool{
	"accept-language":       true,
	"content-type":          true,
	"conversation_id":       true,
	"originator":            true,
	"session_id":            true,
	"user-agent":            true,
	"x-codex-turn-metadata": true,
	"x-codex-turn-state":    true,
}
|
|||
|
|
|
|||
|
|
// openaiPassthroughAllowedHeaders is the set of lowercase request header names
// forwarded upstream in passthrough mode. Only these low-risk headers are let
// through so that non-standard or environment-noise headers are not sent
// upstream, where they could trigger risk controls. Entries are kept sorted.
var openaiPassthroughAllowedHeaders = map[string]bool{
	"accept":                true,
	"accept-language":       true,
	"content-type":          true,
	"conversation_id":       true,
	"openai-beta":           true,
	"originator":            true,
	"session_id":            true,
	"user-agent":            true,
	"x-codex-turn-metadata": true,
	"x-codex-turn-state":    true,
}
|
|||
|
|
|
|||
|
|
// codexCLIOnlyDebugHeaderWhitelist lists the request headers recorded when
// codex_cli_only rejects a request. It is used for diagnostic logging only
// and does not take part in upstream passthrough.
var codexCLIOnlyDebugHeaderWhitelist = []string{
	"User-Agent",
	"Content-Type",
	"Accept",
	"Accept-Language",
	"OpenAI-Beta",
	"Originator",
	"Session_ID",
	"Conversation_ID",
	"X-Request-ID",
	"X-Client-Request-ID",
	"X-Forwarded-For",
	"X-Real-IP",
}
|
|||
|
|
|
|||
|
|
// OpenAICodexUsageSnapshot represents Codex API usage limits taken from
// response headers. Upstream reports a "primary" and a "secondary" window;
// which of them is the 5h window and which is the 7d window is resolved by
// Normalize.
type OpenAICodexUsageSnapshot struct {
	PrimaryUsedPercent          *float64 `json:"primary_used_percent,omitempty"`
	PrimaryResetAfterSeconds    *int     `json:"primary_reset_after_seconds,omitempty"`
	PrimaryWindowMinutes        *int     `json:"primary_window_minutes,omitempty"`
	SecondaryUsedPercent        *float64 `json:"secondary_used_percent,omitempty"`
	SecondaryResetAfterSeconds  *int     `json:"secondary_reset_after_seconds,omitempty"`
	SecondaryWindowMinutes      *int     `json:"secondary_window_minutes,omitempty"`
	PrimaryOverSecondaryPercent *float64 `json:"primary_over_secondary_percent,omitempty"`
	UpdatedAt                   string   `json:"updated_at,omitempty"`
}

// NormalizedCodexLimits contains rate limit data mapped onto canonical
// 5h/7d windows. Every field is copied by pointer from the source snapshot
// and may be nil.
type NormalizedCodexLimits struct {
	Used5hPercent   *float64
	Reset5hSeconds  *int
	Window5hMinutes *int
	Used7dPercent   *float64
	Reset7dSeconds  *int
	Window7dMinutes *int
}

// Normalize converts the primary/secondary fields into canonical 5h/7d fields.
//
// The mapping is chosen from window_minutes:
//   - both windows known: the strictly smaller window is 5h and the other is
//     7d (equal windows map primary to 7d);
//   - only one window known: a window of at most 360 minutes (6h) is the 5h
//     window and the opposite side is taken as 7d;
//   - no window known: legacy assumption, primary = 7d and secondary = 5h.
//
// Normalize returns nil only when s is nil. A non-nil snapshot always yields a
// non-nil result, even when all of its fields are nil; callers that need
// "has data" semantics must inspect the returned fields.
func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits {
	if s == nil {
		return nil
	}

	// Windows up to this length (6h) are classified as the 5h window when only
	// one side reports window_minutes.
	const shortWindowMaxMinutes = 360

	primaryIs5h := false
	switch {
	case s.PrimaryWindowMinutes != nil && s.SecondaryWindowMinutes != nil:
		primaryIs5h = *s.PrimaryWindowMinutes < *s.SecondaryWindowMinutes
	case s.PrimaryWindowMinutes != nil:
		primaryIs5h = *s.PrimaryWindowMinutes <= shortWindowMaxMinutes
	case s.SecondaryWindowMinutes != nil:
		// A long secondary window means primary is the 5h one, and vice versa.
		primaryIs5h = *s.SecondaryWindowMinutes > shortWindowMaxMinutes
	}

	if primaryIs5h {
		return &NormalizedCodexLimits{
			Used5hPercent:   s.PrimaryUsedPercent,
			Reset5hSeconds:  s.PrimaryResetAfterSeconds,
			Window5hMinutes: s.PrimaryWindowMinutes,
			Used7dPercent:   s.SecondaryUsedPercent,
			Reset7dSeconds:  s.SecondaryResetAfterSeconds,
			Window7dMinutes: s.SecondaryWindowMinutes,
		}
	}
	return &NormalizedCodexLimits{
		Used7dPercent:   s.PrimaryUsedPercent,
		Reset7dSeconds:  s.PrimaryResetAfterSeconds,
		Window7dMinutes: s.PrimaryWindowMinutes,
		Used5hPercent:   s.SecondaryUsedPercent,
		Reset5hSeconds:  s.SecondaryResetAfterSeconds,
		Window5hMinutes: s.SecondaryWindowMinutes,
	}
}
|
|||
|
|
|
|||
|
|
// OpenAIUsage represents the token usage reported in an OpenAI API response.
// The cache token counters are omitted from JSON output when zero.
type OpenAIUsage struct {
	InputTokens              int `json:"input_tokens"`
	OutputTokens             int `json:"output_tokens"`
	CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
	CacheReadInputTokens     int `json:"cache_read_input_tokens,omitempty"`
}
|
|||
|
|
|
|||
|
|
// OpenAIForwardResult represents the result of forwarding a request upstream.
type OpenAIForwardResult struct {
	RequestID string
	Usage     OpenAIUsage
	Model     string // original (client-facing) model, used for the response and for log display
	// BillingModel is the model used for cost calculation.
	// When non-empty, CalculateCost uses this instead of Model.
	// This is set by the Anthropic Messages conversion path where
	// the mapped upstream model differs from the client-facing model.
	BillingModel string
	// ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex".
	// Nil means the request did not specify a recognized tier.
	ServiceTier *string
	// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
	// Stored for usage records display; nil means not provided / not applicable.
	ReasoningEffort *string
	Stream          bool
	OpenAIWSMode    bool
	ResponseHeaders http.Header
	Duration        time.Duration
	FirstTokenMs    *int
}
|
|||
|
|
|
|||
|
|
// OpenAIWSRetryMetricsSnapshot is a point-in-time copy of the OpenAI WS retry
// counters, as returned by SnapshotOpenAIWSRetryMetrics.
type OpenAIWSRetryMetricsSnapshot struct {
	RetryAttemptsTotal            int64 `json:"retry_attempts_total"`
	RetryBackoffMsTotal           int64 `json:"retry_backoff_ms_total"`
	RetryExhaustedTotal           int64 `json:"retry_exhausted_total"`
	NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"`
}
|
|||
|
|
|
|||
|
|
// OpenAICompatibilityFallbackMetricsSnapshot aggregates counters for legacy
// compatibility fallbacks, as returned by
// SnapshotOpenAICompatibilityFallbackMetrics.
type OpenAICompatibilityFallbackMetricsSnapshot struct {
	// Legacy session-hash fallback counters. SessionHashLegacyReadHitRate is
	// hit/total, or 0 when there were no legacy reads.
	SessionHashLegacyReadFallbackTotal int64   `json:"session_hash_legacy_read_fallback_total"`
	SessionHashLegacyReadFallbackHit   int64   `json:"session_hash_legacy_read_fallback_hit"`
	SessionHashLegacyDualWriteTotal    int64   `json:"session_hash_legacy_dual_write_total"`
	SessionHashLegacyReadHitRate       float64 `json:"session_hash_legacy_read_hit_rate"`

	// Request-metadata legacy fallback counters, one per metadata field;
	// MetadataLegacyFallbackTotal is their sum.
	// NOTE(review): the two PrefetchedSticky* field names lack the "Total"
	// suffix that their JSON tags carry; kept as-is for API compatibility.
	MetadataLegacyFallbackIsMaxTokensOneHaikuTotal int64 `json:"metadata_legacy_fallback_is_max_tokens_one_haiku_total"`
	MetadataLegacyFallbackThinkingEnabledTotal     int64 `json:"metadata_legacy_fallback_thinking_enabled_total"`
	MetadataLegacyFallbackPrefetchedStickyAccount  int64 `json:"metadata_legacy_fallback_prefetched_sticky_account_total"`
	MetadataLegacyFallbackPrefetchedStickyGroup    int64 `json:"metadata_legacy_fallback_prefetched_sticky_group_total"`
	MetadataLegacyFallbackSingleAccountRetryTotal  int64 `json:"metadata_legacy_fallback_single_account_retry_total"`
	MetadataLegacyFallbackAccountSwitchCountTotal  int64 `json:"metadata_legacy_fallback_account_switch_count_total"`
	MetadataLegacyFallbackTotal                    int64 `json:"metadata_legacy_fallback_total"`
}
|
|||
|
|
|
|||
|
|
// openAIWSRetryMetrics holds lock-free counters for OpenAI WS retries. They
// are incremented by the recordOpenAIWS* methods and read by
// SnapshotOpenAIWSRetryMetrics.
type openAIWSRetryMetrics struct {
	retryAttempts            atomic.Int64
	retryBackoffMs           atomic.Int64 // cumulative backoff in milliseconds
	retryExhausted           atomic.Int64
	nonRetryableFastFallback atomic.Int64
}
|
|||
|
|
|
|||
|
|
// accountWriteThrottle admits at most one write per account ID within
// minInterval. Allow guards the map with mu, so it may be called concurrently.
type accountWriteThrottle struct {
	minInterval time.Duration

	mu       sync.Mutex
	lastByID map[int64]time.Time // last admitted write time per account ID; guarded by mu
}

// newAccountWriteThrottle builds a throttle with the given per-account
// interval. A non-positive interval disables throttling entirely.
func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle {
	t := &accountWriteThrottle{minInterval: minInterval}
	t.lastByID = make(map[int64]time.Time)
	return t
}

// Allow reports whether a write for account id at time now may proceed. When
// it returns true, now is recorded as that account's last write time.
//
// A nil throttle, a non-positive id or a non-positive minInterval always
// allows the write without recording anything. Once more than 4096 accounts
// are tracked, entries older than 4*minInterval are evicted.
func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
	disabled := t == nil || id <= 0 || t.minInterval <= 0
	if disabled {
		return true
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if last, seen := t.lastByID[id]; seen {
		if now.Sub(last) < t.minInterval {
			return false
		}
	}
	t.lastByID[id] = now

	const maxTrackedAccounts = 4096
	if len(t.lastByID) <= maxTrackedAccounts {
		return true
	}
	staleBefore := now.Add(-4 * t.minInterval)
	for accountID, writtenAt := range t.lastByID {
		if writtenAt.Before(staleBefore) {
			delete(t.lastByID, accountID)
		}
	}
	return true
}
|
|||
|
|
|
|||
|
|
// defaultOpenAICodexSnapshotPersistThrottle is the package-level fallback
// returned by getCodexSnapshotThrottle when the service (or its own
// codexSnapshotThrottle field) is nil.
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
|
|||
|
|
|
|||
|
|
// OpenAIGatewayService handles OpenAI API gateway operations.
type OpenAIGatewayService struct {
	accountRepo           AccountRepository
	usageLogRepo          UsageLogRepository
	usageBillingRepo      UsageBillingRepository
	userRepo              UserRepository
	userSubRepo           UserSubscriptionRepository
	cache                 GatewayCache
	cfg                   *config.Config
	codexDetector         CodexClientRestrictionDetector
	schedulerSnapshot     *SchedulerSnapshotService
	concurrencyService    *ConcurrencyService
	billingService        *BillingService
	rateLimitService      *RateLimitService
	billingCacheService   *BillingCacheService
	userGroupRateResolver *userGroupRateResolver
	httpUpstream          HTTPUpstream
	deferredService       *DeferredService
	openAITokenProvider   *OpenAITokenProvider
	toolCorrector         *CodexToolCorrector
	openaiWSResolver      OpenAIWSProtocolResolver

	// NOTE(review): these sync.Once values presumably guard lazy initialization
	// of the matching fields below; the initialization code is outside this view.
	openaiWSPoolOnce              sync.Once
	openaiWSStateStoreOnce        sync.Once
	openaiSchedulerOnce           sync.Once
	openaiWSPassthroughDialerOnce sync.Once
	openaiWSPool                  *openAIWSConnPool
	openaiWSStateStore            OpenAIWSStateStore
	openaiScheduler               OpenAIAccountScheduler
	openaiWSPassthroughDialer     openAIWSClientDialer
	openaiAccountStats            *openAIAccountRuntimeStats

	openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
	openaiWSRetryMetrics  openAIWSRetryMetrics
	responseHeaderFilter  *responseheaders.CompiledHeaderFilter
	codexSnapshotThrottle *accountWriteThrottle // per-instance Codex snapshot persist throttle; see getCodexSnapshotThrottle
}
|
|||
|
|
|
|||
|
|
// NewOpenAIGatewayService creates a new OpenAIGatewayService.
//
// Besides storing the injected dependencies, it derives from cfg the Codex
// client restriction detector, the user-group rate resolver, the WS protocol
// resolver and the response header filter. It also creates a Codex tool
// corrector and a per-instance Codex snapshot write throttle, then logs the
// effective OpenAI WS mode configuration.
func NewOpenAIGatewayService(
	accountRepo AccountRepository,
	usageLogRepo UsageLogRepository,
	usageBillingRepo UsageBillingRepository,
	userRepo UserRepository,
	userSubRepo UserSubscriptionRepository,
	userGroupRateRepo UserGroupRateRepository,
	cache GatewayCache,
	cfg *config.Config,
	schedulerSnapshot *SchedulerSnapshotService,
	concurrencyService *ConcurrencyService,
	billingService *BillingService,
	rateLimitService *RateLimitService,
	billingCacheService *BillingCacheService,
	httpUpstream HTTPUpstream,
	deferredService *DeferredService,
	openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
	svc := &OpenAIGatewayService{
		accountRepo:         accountRepo,
		usageLogRepo:        usageLogRepo,
		usageBillingRepo:    usageBillingRepo,
		userRepo:            userRepo,
		userSubRepo:         userSubRepo,
		cache:               cache,
		cfg:                 cfg,
		codexDetector:       NewOpenAICodexClientRestrictionDetector(cfg),
		schedulerSnapshot:   schedulerSnapshot,
		concurrencyService:  concurrencyService,
		billingService:      billingService,
		rateLimitService:    rateLimitService,
		billingCacheService: billingCacheService,
		// NOTE(review): the meaning of the two nil arguments is defined by
		// newUserGroupRateResolver, which is outside this view — confirm there.
		userGroupRateResolver: newUserGroupRateResolver(
			userGroupRateRepo,
			nil,
			resolveUserGroupRateCacheTTL(cfg),
			nil,
			"service.openai_gateway",
		),
		httpUpstream:          httpUpstream,
		deferredService:       deferredService,
		openAITokenProvider:   openAITokenProvider,
		toolCorrector:         NewCodexToolCorrector(),
		openaiWSResolver:      NewOpenAIWSProtocolResolver(cfg),
		responseHeaderFilter:  compileResponseHeaderFilter(cfg),
		codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
	}
	svc.logOpenAIWSModeBootstrap()
	return svc
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
|||
|
|
if s != nil && s.codexSnapshotThrottle != nil {
|
|||
|
|
return s.codexSnapshotThrottle
|
|||
|
|
}
|
|||
|
|
return defaultOpenAICodexSnapshotPersistThrottle
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
|||
|
|
return &billingDeps{
|
|||
|
|
accountRepo: s.accountRepo,
|
|||
|
|
userRepo: s.userRepo,
|
|||
|
|
userSubRepo: s.userSubRepo,
|
|||
|
|
billingCacheService: s.billingCacheService,
|
|||
|
|
deferredService: s.deferredService,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
|
|||
|
|
// 应在应用优雅关闭时调用。
|
|||
|
|
func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
|
|||
|
|
if s != nil && s.openaiWSPool != nil {
|
|||
|
|
s.openaiWSPool.Close()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// logOpenAIWSModeBootstrap logs the effective OpenAI WS mode configuration
// once at construction time. It does nothing when the service or its config
// is nil. The format verbs below must stay in the same order as the arguments.
func (s *OpenAIGatewayService) logOpenAIWSModeBootstrap() {
	if s == nil || s.cfg == nil {
		return
	}
	wsCfg := s.cfg.Gateway.OpenAIWS
	logOpenAIWSModeInfo(
		"bootstrap enabled=%v oauth_enabled=%v apikey_enabled=%v force_http=%v responses_websockets_v2=%v responses_websockets=%v payload_log_sample_rate=%.3f event_flush_batch_size=%d event_flush_interval_ms=%d prewarm_cooldown_ms=%d retry_backoff_initial_ms=%d retry_backoff_max_ms=%d retry_jitter_ratio=%.3f retry_total_budget_ms=%d ws_read_limit_bytes=%d",
		wsCfg.Enabled,
		wsCfg.OAuthEnabled,
		wsCfg.APIKeyEnabled,
		wsCfg.ForceHTTP,
		wsCfg.ResponsesWebsocketsV2,
		wsCfg.ResponsesWebsockets,
		wsCfg.PayloadLogSampleRate,
		wsCfg.EventFlushBatchSize,
		wsCfg.EventFlushIntervalMS,
		wsCfg.PrewarmCooldownMS,
		wsCfg.RetryBackoffInitialMS,
		wsCfg.RetryBackoffMaxMS,
		wsCfg.RetryJitterRatio,
		wsCfg.RetryTotalBudgetMS,
		openAIWSMessageReadLimitBytes,
	)
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector {
|
|||
|
|
if s != nil && s.codexDetector != nil {
|
|||
|
|
return s.codexDetector
|
|||
|
|
}
|
|||
|
|
var cfg *config.Config
|
|||
|
|
if s != nil {
|
|||
|
|
cfg = s.cfg
|
|||
|
|
}
|
|||
|
|
return NewOpenAICodexClientRestrictionDetector(cfg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) getOpenAIWSProtocolResolver() OpenAIWSProtocolResolver {
|
|||
|
|
if s != nil && s.openaiWSResolver != nil {
|
|||
|
|
return s.openaiWSResolver
|
|||
|
|
}
|
|||
|
|
var cfg *config.Config
|
|||
|
|
if s != nil {
|
|||
|
|
cfg = s.cfg
|
|||
|
|
}
|
|||
|
|
return NewOpenAIWSProtocolResolver(cfg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func classifyOpenAIWSReconnectReason(err error) (string, bool) {
|
|||
|
|
if err == nil {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
var fallbackErr *openAIWSFallbackError
|
|||
|
|
if !errors.As(err, &fallbackErr) || fallbackErr == nil {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
reason := strings.TrimSpace(fallbackErr.Reason)
|
|||
|
|
if reason == "" {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
baseReason := strings.TrimPrefix(reason, "prewarm_")
|
|||
|
|
|
|||
|
|
switch baseReason {
|
|||
|
|
case "policy_violation",
|
|||
|
|
"message_too_big",
|
|||
|
|
"upgrade_required",
|
|||
|
|
"ws_unsupported",
|
|||
|
|
"auth_failed",
|
|||
|
|
"invalid_encrypted_content",
|
|||
|
|
"previous_response_not_found":
|
|||
|
|
return reason, false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
switch baseReason {
|
|||
|
|
case "read_event",
|
|||
|
|
"write_request",
|
|||
|
|
"write",
|
|||
|
|
"acquire_timeout",
|
|||
|
|
"acquire_conn",
|
|||
|
|
"conn_queue_full",
|
|||
|
|
"dial_failed",
|
|||
|
|
"upstream_5xx",
|
|||
|
|
"event_error",
|
|||
|
|
"error_event",
|
|||
|
|
"upstream_error_event",
|
|||
|
|
"ws_connection_limit_reached",
|
|||
|
|
"missing_final_response":
|
|||
|
|
return reason, true
|
|||
|
|
default:
|
|||
|
|
return reason, false
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType string, clientMessage string, upstreamMessage string, ok bool) {
|
|||
|
|
if err == nil {
|
|||
|
|
return 0, "", "", "", false
|
|||
|
|
}
|
|||
|
|
var fallbackErr *openAIWSFallbackError
|
|||
|
|
if !errors.As(err, &fallbackErr) || fallbackErr == nil {
|
|||
|
|
return 0, "", "", "", false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
reason := strings.TrimSpace(fallbackErr.Reason)
|
|||
|
|
reason = strings.TrimPrefix(reason, "prewarm_")
|
|||
|
|
if reason == "" {
|
|||
|
|
return 0, "", "", "", false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var dialErr *openAIWSDialError
|
|||
|
|
if fallbackErr.Err != nil && errors.As(fallbackErr.Err, &dialErr) && dialErr != nil {
|
|||
|
|
if dialErr.StatusCode > 0 {
|
|||
|
|
statusCode = dialErr.StatusCode
|
|||
|
|
}
|
|||
|
|
if dialErr.Err != nil {
|
|||
|
|
upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(dialErr.Err.Error()))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
switch reason {
|
|||
|
|
case "invalid_encrypted_content":
|
|||
|
|
if statusCode == 0 {
|
|||
|
|
statusCode = http.StatusBadRequest
|
|||
|
|
}
|
|||
|
|
errType = "invalid_request_error"
|
|||
|
|
if upstreamMessage == "" {
|
|||
|
|
upstreamMessage = "encrypted content could not be verified"
|
|||
|
|
}
|
|||
|
|
case "previous_response_not_found":
|
|||
|
|
if statusCode == 0 {
|
|||
|
|
statusCode = http.StatusBadRequest
|
|||
|
|
}
|
|||
|
|
errType = "invalid_request_error"
|
|||
|
|
if upstreamMessage == "" {
|
|||
|
|
upstreamMessage = "previous response not found"
|
|||
|
|
}
|
|||
|
|
case "upgrade_required":
|
|||
|
|
if statusCode == 0 {
|
|||
|
|
statusCode = http.StatusUpgradeRequired
|
|||
|
|
}
|
|||
|
|
case "ws_unsupported":
|
|||
|
|
if statusCode == 0 {
|
|||
|
|
statusCode = http.StatusBadRequest
|
|||
|
|
}
|
|||
|
|
case "auth_failed":
|
|||
|
|
if statusCode == 0 {
|
|||
|
|
statusCode = http.StatusUnauthorized
|
|||
|
|
}
|
|||
|
|
case "upstream_rate_limited":
|
|||
|
|
if statusCode == 0 {
|
|||
|
|
statusCode = http.StatusTooManyRequests
|
|||
|
|
}
|
|||
|
|
default:
|
|||
|
|
if statusCode == 0 {
|
|||
|
|
return 0, "", "", "", false
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if upstreamMessage == "" && fallbackErr.Err != nil {
|
|||
|
|
upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(fallbackErr.Err.Error()))
|
|||
|
|
}
|
|||
|
|
if upstreamMessage == "" {
|
|||
|
|
switch reason {
|
|||
|
|
case "upgrade_required":
|
|||
|
|
upstreamMessage = "upstream websocket upgrade required"
|
|||
|
|
case "ws_unsupported":
|
|||
|
|
upstreamMessage = "upstream websocket not supported"
|
|||
|
|
case "auth_failed":
|
|||
|
|
upstreamMessage = "upstream authentication failed"
|
|||
|
|
case "upstream_rate_limited":
|
|||
|
|
upstreamMessage = "upstream rate limit exceeded, please retry later"
|
|||
|
|
default:
|
|||
|
|
upstreamMessage = "Upstream request failed"
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if errType == "" {
|
|||
|
|
if statusCode == http.StatusTooManyRequests {
|
|||
|
|
errType = "rate_limit_error"
|
|||
|
|
} else {
|
|||
|
|
errType = "upstream_error"
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
clientMessage = upstreamMessage
|
|||
|
|
return statusCode, errType, clientMessage, upstreamMessage, true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context, account *Account, wsErr error) bool {
|
|||
|
|
if c == nil || c.Writer == nil || c.Writer.Written() {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(wsErr)
|
|||
|
|
if !ok {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if strings.TrimSpace(clientMessage) == "" {
|
|||
|
|
clientMessage = "Upstream request failed"
|
|||
|
|
}
|
|||
|
|
if strings.TrimSpace(upstreamMessage) == "" {
|
|||
|
|
upstreamMessage = clientMessage
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
setOpsUpstreamError(c, statusCode, upstreamMessage, "")
|
|||
|
|
if account != nil {
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: statusCode,
|
|||
|
|
Kind: "ws_error",
|
|||
|
|
Message: upstreamMessage,
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
c.JSON(statusCode, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": errType,
|
|||
|
|
"message": clientMessage,
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration {
|
|||
|
|
if attempt <= 0 {
|
|||
|
|
return 0
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
initial := openAIWSRetryBackoffInitialDefault
|
|||
|
|
maxBackoff := openAIWSRetryBackoffMaxDefault
|
|||
|
|
jitterRatio := openAIWSRetryJitterRatioDefault
|
|||
|
|
if s != nil && s.cfg != nil {
|
|||
|
|
wsCfg := s.cfg.Gateway.OpenAIWS
|
|||
|
|
if wsCfg.RetryBackoffInitialMS > 0 {
|
|||
|
|
initial = time.Duration(wsCfg.RetryBackoffInitialMS) * time.Millisecond
|
|||
|
|
}
|
|||
|
|
if wsCfg.RetryBackoffMaxMS > 0 {
|
|||
|
|
maxBackoff = time.Duration(wsCfg.RetryBackoffMaxMS) * time.Millisecond
|
|||
|
|
}
|
|||
|
|
if wsCfg.RetryJitterRatio >= 0 {
|
|||
|
|
jitterRatio = wsCfg.RetryJitterRatio
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if initial <= 0 {
|
|||
|
|
return 0
|
|||
|
|
}
|
|||
|
|
if maxBackoff <= 0 {
|
|||
|
|
maxBackoff = initial
|
|||
|
|
}
|
|||
|
|
if maxBackoff < initial {
|
|||
|
|
maxBackoff = initial
|
|||
|
|
}
|
|||
|
|
if jitterRatio < 0 {
|
|||
|
|
jitterRatio = 0
|
|||
|
|
}
|
|||
|
|
if jitterRatio > 1 {
|
|||
|
|
jitterRatio = 1
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
shift := attempt - 1
|
|||
|
|
if shift < 0 {
|
|||
|
|
shift = 0
|
|||
|
|
}
|
|||
|
|
backoff := initial
|
|||
|
|
if shift > 0 {
|
|||
|
|
backoff = initial * time.Duration(1<<shift)
|
|||
|
|
}
|
|||
|
|
if backoff > maxBackoff {
|
|||
|
|
backoff = maxBackoff
|
|||
|
|
}
|
|||
|
|
if jitterRatio <= 0 {
|
|||
|
|
return backoff
|
|||
|
|
}
|
|||
|
|
jitter := time.Duration(float64(backoff) * jitterRatio)
|
|||
|
|
if jitter <= 0 {
|
|||
|
|
return backoff
|
|||
|
|
}
|
|||
|
|
delta := time.Duration(rand.Int63n(int64(jitter)*2+1)) - jitter
|
|||
|
|
withJitter := backoff + delta
|
|||
|
|
if withJitter < 0 {
|
|||
|
|
return 0
|
|||
|
|
}
|
|||
|
|
return withJitter
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration {
|
|||
|
|
if s != nil && s.cfg != nil {
|
|||
|
|
ms := s.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS
|
|||
|
|
if ms <= 0 {
|
|||
|
|
return 0
|
|||
|
|
}
|
|||
|
|
return time.Duration(ms) * time.Millisecond
|
|||
|
|
}
|
|||
|
|
return 0
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) {
|
|||
|
|
if s == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
s.openaiWSRetryMetrics.retryAttempts.Add(1)
|
|||
|
|
if backoff > 0 {
|
|||
|
|
s.openaiWSRetryMetrics.retryBackoffMs.Add(backoff.Milliseconds())
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) recordOpenAIWSRetryExhausted() {
|
|||
|
|
if s == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
s.openaiWSRetryMetrics.retryExhausted.Add(1)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) recordOpenAIWSNonRetryableFastFallback() {
|
|||
|
|
if s == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
s.openaiWSRetryMetrics.nonRetryableFastFallback.Add(1)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetricsSnapshot {
|
|||
|
|
if s == nil {
|
|||
|
|
return OpenAIWSRetryMetricsSnapshot{}
|
|||
|
|
}
|
|||
|
|
return OpenAIWSRetryMetricsSnapshot{
|
|||
|
|
RetryAttemptsTotal: s.openaiWSRetryMetrics.retryAttempts.Load(),
|
|||
|
|
RetryBackoffMsTotal: s.openaiWSRetryMetrics.retryBackoffMs.Load(),
|
|||
|
|
RetryExhaustedTotal: s.openaiWSRetryMetrics.retryExhausted.Load(),
|
|||
|
|
NonRetryableFastFallbackTotal: s.openaiWSRetryMetrics.nonRetryableFastFallback.Load(),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot {
|
|||
|
|
legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats()
|
|||
|
|
isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats()
|
|||
|
|
|
|||
|
|
readHitRate := float64(0)
|
|||
|
|
if legacyReadFallbackTotal > 0 {
|
|||
|
|
readHitRate = float64(legacyReadFallbackHit) / float64(legacyReadFallbackTotal)
|
|||
|
|
}
|
|||
|
|
metadataFallbackTotal := isMaxTokensOneHaiku + thinkingEnabled + prefetchedStickyAccount + prefetchedStickyGroup + singleAccountRetry + accountSwitchCount
|
|||
|
|
|
|||
|
|
return OpenAICompatibilityFallbackMetricsSnapshot{
|
|||
|
|
SessionHashLegacyReadFallbackTotal: legacyReadFallbackTotal,
|
|||
|
|
SessionHashLegacyReadFallbackHit: legacyReadFallbackHit,
|
|||
|
|
SessionHashLegacyDualWriteTotal: legacyDualWriteTotal,
|
|||
|
|
SessionHashLegacyReadHitRate: readHitRate,
|
|||
|
|
|
|||
|
|
MetadataLegacyFallbackIsMaxTokensOneHaikuTotal: isMaxTokensOneHaiku,
|
|||
|
|
MetadataLegacyFallbackThinkingEnabledTotal: thinkingEnabled,
|
|||
|
|
MetadataLegacyFallbackPrefetchedStickyAccount: prefetchedStickyAccount,
|
|||
|
|
MetadataLegacyFallbackPrefetchedStickyGroup: prefetchedStickyGroup,
|
|||
|
|
MetadataLegacyFallbackSingleAccountRetryTotal: singleAccountRetry,
|
|||
|
|
MetadataLegacyFallbackAccountSwitchCountTotal: accountSwitchCount,
|
|||
|
|
MetadataLegacyFallbackTotal: metadataFallbackTotal,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
|
|||
|
|
return s.getCodexClientRestrictionDetector().Detect(c, account)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func getAPIKeyIDFromContext(c *gin.Context) int64 {
|
|||
|
|
if c == nil {
|
|||
|
|
return 0
|
|||
|
|
}
|
|||
|
|
v, exists := c.Get("api_key")
|
|||
|
|
if !exists {
|
|||
|
|
return 0
|
|||
|
|
}
|
|||
|
|
apiKey, ok := v.(*APIKey)
|
|||
|
|
if !ok || apiKey == nil {
|
|||
|
|
return 0
|
|||
|
|
}
|
|||
|
|
return apiKey.ID
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// isolateOpenAISessionID 将 apiKeyID 混入 session 标识符,
|
|||
|
|
// 确保不同 API Key 的用户即使使用相同的原始 session_id/conversation_id,
|
|||
|
|
// 到达上游的标识符也不同,防止跨用户会话碰撞。
|
|||
|
|
func isolateOpenAISessionID(apiKeyID int64, raw string) string {
|
|||
|
|
raw = strings.TrimSpace(raw)
|
|||
|
|
if raw == "" {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
h := xxhash.New()
|
|||
|
|
_, _ = fmt.Fprintf(h, "k%d:", apiKeyID)
|
|||
|
|
_, _ = h.WriteString(raw)
|
|||
|
|
return fmt.Sprintf("%016x", h.Sum64())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) {
|
|||
|
|
if !result.Enabled {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if ctx == nil {
|
|||
|
|
ctx = context.Background()
|
|||
|
|
}
|
|||
|
|
accountID := int64(0)
|
|||
|
|
if account != nil {
|
|||
|
|
accountID = account.ID
|
|||
|
|
}
|
|||
|
|
fields := []zap.Field{
|
|||
|
|
zap.String("component", "service.openai_gateway"),
|
|||
|
|
zap.Int64("account_id", accountID),
|
|||
|
|
zap.Bool("codex_cli_only_enabled", result.Enabled),
|
|||
|
|
zap.Bool("codex_official_client_match", result.Matched),
|
|||
|
|
zap.String("reject_reason", result.Reason),
|
|||
|
|
}
|
|||
|
|
if apiKeyID > 0 {
|
|||
|
|
fields = append(fields, zap.Int64("api_key_id", apiKeyID))
|
|||
|
|
}
|
|||
|
|
if !result.Matched {
|
|||
|
|
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body)
|
|||
|
|
}
|
|||
|
|
log := logger.FromContext(ctx).With(fields...)
|
|||
|
|
if result.Matched {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context, body []byte) []zap.Field {
|
|||
|
|
if c == nil || c.Request == nil {
|
|||
|
|
return fields
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
req := c.Request
|
|||
|
|
requestModel, requestStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
|||
|
|
fields = append(fields,
|
|||
|
|
zap.String("request_method", strings.TrimSpace(req.Method)),
|
|||
|
|
zap.String("request_path", strings.TrimSpace(req.URL.Path)),
|
|||
|
|
zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)),
|
|||
|
|
zap.String("request_host", strings.TrimSpace(req.Host)),
|
|||
|
|
zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())),
|
|||
|
|
zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)),
|
|||
|
|
zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))),
|
|||
|
|
zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))),
|
|||
|
|
zap.Int64("request_content_length", req.ContentLength),
|
|||
|
|
zap.Bool("request_stream", requestStream),
|
|||
|
|
)
|
|||
|
|
if requestModel != "" {
|
|||
|
|
fields = append(fields, zap.String("request_model", requestModel))
|
|||
|
|
}
|
|||
|
|
if promptCacheKey != "" {
|
|||
|
|
fields = append(fields, zap.String("request_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if headers := snapshotCodexCLIOnlyHeaders(req.Header); len(headers) > 0 {
|
|||
|
|
fields = append(fields, zap.Any("request_headers", headers))
|
|||
|
|
}
|
|||
|
|
fields = append(fields, zap.Int("request_body_size", len(body)))
|
|||
|
|
return fields
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func snapshotCodexCLIOnlyHeaders(header http.Header) map[string]string {
|
|||
|
|
if len(header) == 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
result := make(map[string]string, len(codexCLIOnlyDebugHeaderWhitelist))
|
|||
|
|
for _, key := range codexCLIOnlyDebugHeaderWhitelist {
|
|||
|
|
value := strings.TrimSpace(header.Get(key))
|
|||
|
|
if value == "" {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
result[strings.ToLower(key)] = truncateString(value, codexCLIOnlyHeaderValueMaxBytes)
|
|||
|
|
}
|
|||
|
|
return result
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// hashSensitiveValueForLog returns a short, log-safe fingerprint of raw: the
// first 8 bytes of the SHA-256 digest of the whitespace-trimmed value, encoded
// as 16 lowercase hex characters. Blank input yields "" so that absent values
// stay absent in logs.
func hashSensitiveValueForLog(raw string) string {
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		h := sha256.New()
		_, _ = io.WriteString(h, trimmed) // hash.Hash writes never fail
		// 64 bits is enough to correlate log lines without exposing the value.
		return hex.EncodeToString(h.Sum(nil)[:8])
	}
	return ""
}
|
|||
|
|
|
|||
|
|
func logOpenAIInstructionsRequiredDebug(
|
|||
|
|
ctx context.Context,
|
|||
|
|
c *gin.Context,
|
|||
|
|
account *Account,
|
|||
|
|
upstreamStatusCode int,
|
|||
|
|
upstreamMsg string,
|
|||
|
|
requestBody []byte,
|
|||
|
|
upstreamBody []byte,
|
|||
|
|
) {
|
|||
|
|
msg := strings.TrimSpace(upstreamMsg)
|
|||
|
|
if !isOpenAIInstructionsRequiredError(upstreamStatusCode, msg, upstreamBody) {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if ctx == nil {
|
|||
|
|
ctx = context.Background()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
accountID := int64(0)
|
|||
|
|
accountName := ""
|
|||
|
|
if account != nil {
|
|||
|
|
accountID = account.ID
|
|||
|
|
accountName = strings.TrimSpace(account.Name)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
userAgent := ""
|
|||
|
|
originator := ""
|
|||
|
|
if c != nil {
|
|||
|
|
userAgent = strings.TrimSpace(c.GetHeader("User-Agent"))
|
|||
|
|
originator = strings.TrimSpace(c.GetHeader("originator"))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fields := []zap.Field{
|
|||
|
|
zap.String("component", "service.openai_gateway"),
|
|||
|
|
zap.Int64("account_id", accountID),
|
|||
|
|
zap.String("account_name", accountName),
|
|||
|
|
zap.Int("upstream_status_code", upstreamStatusCode),
|
|||
|
|
zap.String("upstream_error_message", msg),
|
|||
|
|
zap.String("request_user_agent", userAgent),
|
|||
|
|
zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)),
|
|||
|
|
}
|
|||
|
|
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody)
|
|||
|
|
|
|||
|
|
logger.FromContext(ctx).With(fields...).Warn("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
|||
|
|
if upstreamStatusCode != http.StatusBadRequest {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
hasInstructionRequired := func(text string) bool {
|
|||
|
|
lower := strings.ToLower(strings.TrimSpace(text))
|
|||
|
|
if lower == "" {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if strings.Contains(lower, "instructions are required") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
if strings.Contains(lower, "required parameter: 'instructions'") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
if strings.Contains(lower, "required parameter: instructions") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
if strings.Contains(lower, "missing required parameter") && strings.Contains(lower, "instructions") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
return strings.Contains(lower, "instruction") && strings.Contains(lower, "required")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if hasInstructionRequired(upstreamMsg) {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
if len(upstreamBody) == 0 {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
errMsg := gjson.GetBytes(upstreamBody, "error.message").String()
|
|||
|
|
errMsgLower := strings.ToLower(strings.TrimSpace(errMsg))
|
|||
|
|
errCode := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.code").String()))
|
|||
|
|
errParam := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.param").String()))
|
|||
|
|
errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.type").String()))
|
|||
|
|
|
|||
|
|
if errParam == "instructions" {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
if hasInstructionRequired(errMsg) {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
if strings.Contains(errCode, "missing_required_parameter") && strings.Contains(errMsgLower, "instructions") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
if strings.Contains(errType, "invalid_request") && strings.Contains(errMsgLower, "instructions") && strings.Contains(errMsgLower, "required") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
|||
|
|
if upstreamStatusCode != http.StatusBadRequest {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
match := func(text string) bool {
|
|||
|
|
lower := strings.ToLower(strings.TrimSpace(text))
|
|||
|
|
if lower == "" {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if strings.Contains(lower, "an error occurred while processing your request") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
return strings.Contains(lower, "you can retry your request") &&
|
|||
|
|
strings.Contains(lower, "help.openai.com") &&
|
|||
|
|
strings.Contains(lower, "request id")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if match(upstreamMsg) {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
if len(upstreamBody) == 0 {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if match(gjson.GetBytes(upstreamBody, "error.message").String()) {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
return match(string(upstreamBody))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ExtractSessionID extracts the raw session ID from headers or body without hashing.
|
|||
|
|
// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache.
|
|||
|
|
func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string {
|
|||
|
|
if c == nil {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
|||
|
|
if sessionID == "" {
|
|||
|
|
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
|||
|
|
}
|
|||
|
|
if sessionID == "" && len(body) > 0 {
|
|||
|
|
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
|||
|
|
}
|
|||
|
|
return sessionID
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
|
|||
|
|
//
|
|||
|
|
// Priority:
|
|||
|
|
// 1. Header: session_id
|
|||
|
|
// 2. Header: conversation_id
|
|||
|
|
// 3. Body: prompt_cache_key (opencode)
|
|||
|
|
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string {
|
|||
|
|
if c == nil {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
|||
|
|
if sessionID == "" {
|
|||
|
|
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
|||
|
|
}
|
|||
|
|
if sessionID == "" && len(body) > 0 {
|
|||
|
|
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
|||
|
|
}
|
|||
|
|
if sessionID == "" {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
|
|||
|
|
attachOpenAILegacySessionHashToGin(c, legacyHash)
|
|||
|
|
return currentHash
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GenerateSessionHashWithFallback 先按常规信号生成会话哈希;
|
|||
|
|
// 当未携带 session_id/conversation_id/prompt_cache_key 时,使用 fallbackSeed 生成稳定哈希。
|
|||
|
|
// 该方法用于 WS ingress,避免会话信号缺失时发生跨账号漂移。
|
|||
|
|
func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, body []byte, fallbackSeed string) string {
|
|||
|
|
sessionHash := s.GenerateSessionHash(c, body)
|
|||
|
|
if sessionHash != "" {
|
|||
|
|
return sessionHash
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
seed := strings.TrimSpace(fallbackSeed)
|
|||
|
|
if seed == "" {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
currentHash, legacyHash := deriveOpenAISessionHashes(seed)
|
|||
|
|
attachOpenAILegacySessionHashToGin(c, legacyHash)
|
|||
|
|
return currentHash
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string {
|
|||
|
|
if c != nil {
|
|||
|
|
if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" {
|
|||
|
|
return originator
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if isOfficialClient {
|
|||
|
|
return "codex_cli_rs"
|
|||
|
|
}
|
|||
|
|
return "opencode"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// BindStickySession sets session -> account binding with standard TTL.
|
|||
|
|
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
|||
|
|
if sessionHash == "" || accountID <= 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
ttl := openaiStickySessionTTL
|
|||
|
|
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
|
|||
|
|
ttl = time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
|
|||
|
|
}
|
|||
|
|
return s.setStickySessionAccountID(ctx, groupID, sessionHash, accountID, ttl)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SelectAccount selects an OpenAI account with sticky session support
|
|||
|
|
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
|||
|
|
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SelectAccountForModel selects an account supporting the requested model
|
|||
|
|
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
|||
|
|
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
|||
|
|
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
|
|||
|
|
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
|||
|
|
return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
|
|||
|
|
// 1. 尝试粘性会话命中
|
|||
|
|
// Try sticky session hit
|
|||
|
|
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
|
|||
|
|
return account, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 2. 获取可调度的 OpenAI 账号
|
|||
|
|
// Get schedulable OpenAI accounts
|
|||
|
|
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 3. 按优先级 + LRU 选择最佳账号
|
|||
|
|
// Select by priority + LRU
|
|||
|
|
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs)
|
|||
|
|
|
|||
|
|
if selected == nil {
|
|||
|
|
if requestedModel != "" {
|
|||
|
|
return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
|
|||
|
|
}
|
|||
|
|
return nil, errors.New("no available OpenAI accounts")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 4. 设置粘性会话绑定
|
|||
|
|
// Set sticky session binding
|
|||
|
|
if sessionHash != "" {
|
|||
|
|
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return selected, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// tryStickySessionHit 尝试从粘性会话获取账号。
|
|||
|
|
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
|
|||
|
|
//
|
|||
|
|
// tryStickySessionHit attempts to get account from sticky session.
|
|||
|
|
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
|
|||
|
|
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account {
|
|||
|
|
if sessionHash == "" {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
accountID := stickyAccountID
|
|||
|
|
if accountID <= 0 {
|
|||
|
|
var err error
|
|||
|
|
accountID, err = s.getStickySessionAccountID(ctx, groupID, sessionHash)
|
|||
|
|
if err != nil || accountID <= 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if _, excluded := excludedIDs[accountID]; excluded {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查账号是否需要清理粘性会话
|
|||
|
|
// Check if sticky session should be cleared
|
|||
|
|
if shouldClearStickySession(account, requestedModel) {
|
|||
|
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证账号是否可用于当前请求
|
|||
|
|
// Verify account is usable for current request
|
|||
|
|
if !account.IsSchedulable() || !account.IsOpenAI() {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 刷新会话 TTL 并返回账号
|
|||
|
|
// Refresh session TTL and return account
|
|||
|
|
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
|
|||
|
|
return account
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。
|
|||
|
|
// 返回 nil 表示无可用账号。
|
|||
|
|
//
|
|||
|
|
// selectBestAccount selects the best account from candidates (priority + LRU).
|
|||
|
|
// Returns nil if no available account.
|
|||
|
|
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
|||
|
|
var selected *Account
|
|||
|
|
|
|||
|
|
for i := range accounts {
|
|||
|
|
acc := &accounts[i]
|
|||
|
|
|
|||
|
|
// 跳过被排除的账号
|
|||
|
|
// Skip excluded accounts
|
|||
|
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
|||
|
|
if fresh == nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 选择优先级最高且最久未使用的账号
|
|||
|
|
// Select highest priority and least recently used
|
|||
|
|
if selected == nil {
|
|||
|
|
selected = fresh
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if s.isBetterAccount(fresh, selected) {
|
|||
|
|
selected = fresh
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return selected
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// isBetterAccount 判断 candidate 是否比 current 更优。
|
|||
|
|
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
|
|||
|
|
//
|
|||
|
|
// isBetterAccount checks if candidate is better than current.
|
|||
|
|
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
|
|||
|
|
func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
|
|||
|
|
// 优先级更高(数值更小)
|
|||
|
|
// Higher priority (lower value)
|
|||
|
|
if candidate.Priority < current.Priority {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
if candidate.Priority > current.Priority {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 同优先级,比较最后使用时间
|
|||
|
|
// Same priority, compare last used time
|
|||
|
|
switch {
|
|||
|
|
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
|
|||
|
|
// candidate 从未使用,优先
|
|||
|
|
return true
|
|||
|
|
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
|
|||
|
|
// current 从未使用,保持
|
|||
|
|
return false
|
|||
|
|
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
|
|||
|
|
// 都未使用,保持
|
|||
|
|
return false
|
|||
|
|
default:
|
|||
|
|
// 都使用过,选择最久未使用的
|
|||
|
|
return candidate.LastUsedAt.Before(*current.LastUsedAt)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
|||
|
|
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
|||
|
|
cfg := s.schedulingConfig()
|
|||
|
|
var stickyAccountID int64
|
|||
|
|
if sessionHash != "" && s.cache != nil {
|
|||
|
|
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
|
|||
|
|
stickyAccountID = accountID
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
|||
|
|
account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
|||
|
|
if err == nil && result.Acquired {
|
|||
|
|
return &AccountSelectionResult{
|
|||
|
|
Account: account,
|
|||
|
|
Acquired: true,
|
|||
|
|
ReleaseFunc: result.ReleaseFunc,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
|||
|
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
|||
|
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
|||
|
|
return &AccountSelectionResult{
|
|||
|
|
Account: account,
|
|||
|
|
WaitPlan: &AccountWaitPlan{
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
MaxConcurrency: account.Concurrency,
|
|||
|
|
Timeout: cfg.StickySessionWaitTimeout,
|
|||
|
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
|||
|
|
},
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return &AccountSelectionResult{
|
|||
|
|
Account: account,
|
|||
|
|
WaitPlan: &AccountWaitPlan{
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
MaxConcurrency: account.Concurrency,
|
|||
|
|
Timeout: cfg.FallbackWaitTimeout,
|
|||
|
|
MaxWaiting: cfg.FallbackMaxWaiting,
|
|||
|
|
},
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
if len(accounts) == 0 {
|
|||
|
|
return nil, ErrNoAvailableAccounts
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
isExcluded := func(accountID int64) bool {
|
|||
|
|
if excludedIDs == nil {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
_, excluded := excludedIDs[accountID]
|
|||
|
|
return excluded
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ============ Layer 1: Sticky session ============
|
|||
|
|
if sessionHash != "" {
|
|||
|
|
accountID := stickyAccountID
|
|||
|
|
if accountID > 0 && !isExcluded(accountID) {
|
|||
|
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
|||
|
|
if err == nil {
|
|||
|
|
clearSticky := shouldClearStickySession(account, requestedModel)
|
|||
|
|
if clearSticky {
|
|||
|
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
|||
|
|
}
|
|||
|
|
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
|
|||
|
|
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
|||
|
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
|||
|
|
if err == nil && result.Acquired {
|
|||
|
|
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
|
|||
|
|
return &AccountSelectionResult{
|
|||
|
|
Account: account,
|
|||
|
|
Acquired: true,
|
|||
|
|
ReleaseFunc: result.ReleaseFunc,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
|||
|
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
|||
|
|
return &AccountSelectionResult{
|
|||
|
|
Account: account,
|
|||
|
|
WaitPlan: &AccountWaitPlan{
|
|||
|
|
AccountID: accountID,
|
|||
|
|
MaxConcurrency: account.Concurrency,
|
|||
|
|
Timeout: cfg.StickySessionWaitTimeout,
|
|||
|
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
|||
|
|
},
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ============ Layer 2: Load-aware selection ============
|
|||
|
|
candidates := make([]*Account, 0, len(accounts))
|
|||
|
|
for i := range accounts {
|
|||
|
|
acc := &accounts[i]
|
|||
|
|
if isExcluded(acc.ID) {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
|
|||
|
|
// re-check schedulability here so recently rate-limited/overloaded accounts
|
|||
|
|
// are not selected again before the bucket is rebuilt.
|
|||
|
|
if !acc.IsSchedulable() {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
candidates = append(candidates, acc)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(candidates) == 0 {
|
|||
|
|
return nil, ErrNoAvailableAccounts
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
|||
|
|
for _, acc := range candidates {
|
|||
|
|
accountLoads = append(accountLoads, AccountWithConcurrency{
|
|||
|
|
ID: acc.ID,
|
|||
|
|
MaxConcurrency: acc.EffectiveLoadFactor(),
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
|||
|
|
if err != nil {
|
|||
|
|
ordered := append([]*Account(nil), candidates...)
|
|||
|
|
sortAccountsByPriorityAndLastUsed(ordered, false)
|
|||
|
|
for _, acc := range ordered {
|
|||
|
|
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
|||
|
|
if fresh == nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
|||
|
|
if err == nil && result.Acquired {
|
|||
|
|
if sessionHash != "" {
|
|||
|
|
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
|||
|
|
}
|
|||
|
|
return &AccountSelectionResult{
|
|||
|
|
Account: fresh,
|
|||
|
|
Acquired: true,
|
|||
|
|
ReleaseFunc: result.ReleaseFunc,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
var available []accountWithLoad
|
|||
|
|
for _, acc := range candidates {
|
|||
|
|
loadInfo := loadMap[acc.ID]
|
|||
|
|
if loadInfo == nil {
|
|||
|
|
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
|||
|
|
}
|
|||
|
|
if loadInfo.LoadRate < 100 {
|
|||
|
|
available = append(available, accountWithLoad{
|
|||
|
|
account: acc,
|
|||
|
|
loadInfo: loadInfo,
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(available) > 0 {
|
|||
|
|
sort.SliceStable(available, func(i, j int) bool {
|
|||
|
|
a, b := available[i], available[j]
|
|||
|
|
if a.account.Priority != b.account.Priority {
|
|||
|
|
return a.account.Priority < b.account.Priority
|
|||
|
|
}
|
|||
|
|
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
|||
|
|
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
|||
|
|
}
|
|||
|
|
switch {
|
|||
|
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
|||
|
|
return true
|
|||
|
|
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
|||
|
|
return false
|
|||
|
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
|||
|
|
return false
|
|||
|
|
default:
|
|||
|
|
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
shuffleWithinSortGroups(available)
|
|||
|
|
|
|||
|
|
for _, item := range available {
|
|||
|
|
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel)
|
|||
|
|
if fresh == nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
|||
|
|
if err == nil && result.Acquired {
|
|||
|
|
if sessionHash != "" {
|
|||
|
|
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
|||
|
|
}
|
|||
|
|
return &AccountSelectionResult{
|
|||
|
|
Account: fresh,
|
|||
|
|
Acquired: true,
|
|||
|
|
ReleaseFunc: result.ReleaseFunc,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ============ Layer 3: Fallback wait ============
|
|||
|
|
sortAccountsByPriorityAndLastUsed(candidates, false)
|
|||
|
|
for _, acc := range candidates {
|
|||
|
|
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
|||
|
|
if fresh == nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
return &AccountSelectionResult{
|
|||
|
|
Account: fresh,
|
|||
|
|
WaitPlan: &AccountWaitPlan{
|
|||
|
|
AccountID: fresh.ID,
|
|||
|
|
MaxConcurrency: fresh.Concurrency,
|
|||
|
|
Timeout: cfg.FallbackWaitTimeout,
|
|||
|
|
MaxWaiting: cfg.FallbackMaxWaiting,
|
|||
|
|
},
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil, ErrNoAvailableAccounts
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
|||
|
|
if s.schedulerSnapshot != nil {
|
|||
|
|
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
|
|||
|
|
return accounts, err
|
|||
|
|
}
|
|||
|
|
var accounts []Account
|
|||
|
|
var err error
|
|||
|
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
|||
|
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
|||
|
|
} else if groupID != nil {
|
|||
|
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
|||
|
|
} else {
|
|||
|
|
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI)
|
|||
|
|
}
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
|||
|
|
}
|
|||
|
|
return accounts, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
|||
|
|
if s.concurrencyService == nil {
|
|||
|
|
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
|||
|
|
}
|
|||
|
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account {
|
|||
|
|
if account == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fresh := account
|
|||
|
|
if s.schedulerSnapshot != nil {
|
|||
|
|
current, err := s.getSchedulableAccount(ctx, account.ID)
|
|||
|
|
if err != nil || current == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
fresh = current
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if !fresh.IsSchedulable() || !fresh.IsOpenAI() {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
if requestedModel != "" && !fresh.IsModelSupported(requestedModel) {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return fresh
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
|||
|
|
var (
|
|||
|
|
account *Account
|
|||
|
|
err error
|
|||
|
|
)
|
|||
|
|
if s.schedulerSnapshot != nil {
|
|||
|
|
account, err = s.schedulerSnapshot.GetAccount(ctx, accountID)
|
|||
|
|
} else {
|
|||
|
|
account, err = s.accountRepo.GetByID(ctx, accountID)
|
|||
|
|
}
|
|||
|
|
if err != nil || account == nil {
|
|||
|
|
return account, err
|
|||
|
|
}
|
|||
|
|
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now())
|
|||
|
|
return account, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
|||
|
|
if s.cfg != nil {
|
|||
|
|
return s.cfg.Gateway.Scheduling
|
|||
|
|
}
|
|||
|
|
return config.GatewaySchedulingConfig{
|
|||
|
|
StickySessionMaxWaiting: 3,
|
|||
|
|
StickySessionWaitTimeout: 45 * time.Second,
|
|||
|
|
FallbackWaitTimeout: 30 * time.Second,
|
|||
|
|
FallbackMaxWaiting: 100,
|
|||
|
|
LoadBatchEnabled: true,
|
|||
|
|
SlotCleanupInterval: 30 * time.Second,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetAccessToken gets the access token for an OpenAI account
|
|||
|
|
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
|||
|
|
switch account.Type {
|
|||
|
|
case AccountTypeOAuth:
|
|||
|
|
// 使用 TokenProvider 获取缓存的 token
|
|||
|
|
if s.openAITokenProvider != nil {
|
|||
|
|
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
|
|||
|
|
if err != nil {
|
|||
|
|
return "", "", err
|
|||
|
|
}
|
|||
|
|
return accessToken, "oauth", nil
|
|||
|
|
}
|
|||
|
|
// 降级:TokenProvider 未配置时直接从账号读取
|
|||
|
|
accessToken := account.GetOpenAIAccessToken()
|
|||
|
|
if accessToken == "" {
|
|||
|
|
return "", "", errors.New("access_token not found in credentials")
|
|||
|
|
}
|
|||
|
|
return accessToken, "oauth", nil
|
|||
|
|
case AccountTypeAPIKey:
|
|||
|
|
apiKey := account.GetOpenAIApiKey()
|
|||
|
|
if apiKey == "" {
|
|||
|
|
return "", "", errors.New("api_key not found in credentials")
|
|||
|
|
}
|
|||
|
|
return apiKey, "apikey", nil
|
|||
|
|
default:
|
|||
|
|
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
|||
|
|
switch statusCode {
|
|||
|
|
case 401, 402, 403, 429, 529:
|
|||
|
|
return true
|
|||
|
|
default:
|
|||
|
|
return statusCode >= 500
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
|||
|
|
if s.shouldFailoverUpstreamError(statusCode) {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
|||
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|||
|
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Forward forwards request to OpenAI API
|
|||
|
|
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
|
|||
|
|
startTime := time.Now()
|
|||
|
|
|
|||
|
|
restrictionResult := s.detectCodexClientRestriction(c, account)
|
|||
|
|
apiKeyID := getAPIKeyIDFromContext(c)
|
|||
|
|
logCodexCLIOnlyDetection(ctx, c, account, apiKeyID, restrictionResult, body)
|
|||
|
|
if restrictionResult.Enabled && !restrictionResult.Matched {
|
|||
|
|
c.JSON(http.StatusForbidden, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": "forbidden_error",
|
|||
|
|
"message": "This account only allows Codex official clients",
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
originalBody := body
|
|||
|
|
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
|||
|
|
originalModel := reqModel
|
|||
|
|
|
|||
|
|
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
|||
|
|
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
|||
|
|
clientTransport := GetOpenAIClientTransport(c)
|
|||
|
|
// 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。
|
|||
|
|
wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport)
|
|||
|
|
if c != nil {
|
|||
|
|
c.Set("openai_ws_transport_decision", string(wsDecision.Transport))
|
|||
|
|
c.Set("openai_ws_transport_reason", wsDecision.Reason)
|
|||
|
|
}
|
|||
|
|
if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 {
|
|||
|
|
logOpenAIWSModeDebug(
|
|||
|
|
"selected account_id=%d account_type=%s transport=%s reason=%s model=%s stream=%v",
|
|||
|
|
account.ID,
|
|||
|
|
account.Type,
|
|||
|
|
normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
|
|||
|
|
normalizeOpenAIWSLogValue(wsDecision.Reason),
|
|||
|
|
reqModel,
|
|||
|
|
reqStream,
|
|||
|
|
)
|
|||
|
|
}
|
|||
|
|
// 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。
|
|||
|
|
if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket {
|
|||
|
|
if c != nil {
|
|||
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": "invalid_request_error",
|
|||
|
|
"message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.",
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2")
|
|||
|
|
}
|
|||
|
|
passthroughEnabled := account.IsOpenAIPassthroughEnabled()
|
|||
|
|
if passthroughEnabled {
|
|||
|
|
// 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。
|
|||
|
|
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel)
|
|||
|
|
return s.forwardOpenAIPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
reqBody, err := getOpenAIRequestBodyMap(c, body)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if v, ok := reqBody["model"].(string); ok {
|
|||
|
|
reqModel = v
|
|||
|
|
originalModel = reqModel
|
|||
|
|
}
|
|||
|
|
if v, ok := reqBody["stream"].(bool); ok {
|
|||
|
|
reqStream = v
|
|||
|
|
}
|
|||
|
|
if promptCacheKey == "" {
|
|||
|
|
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
|||
|
|
promptCacheKey = strings.TrimSpace(v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Track if body needs re-serialization
|
|||
|
|
bodyModified := false
|
|||
|
|
// 单字段补丁快速路径:只要整个变更集最终可归约为同一路径的 set/delete,就避免全量 Marshal。
|
|||
|
|
patchDisabled := false
|
|||
|
|
patchHasOp := false
|
|||
|
|
patchDelete := false
|
|||
|
|
patchPath := ""
|
|||
|
|
var patchValue any
|
|||
|
|
markPatchSet := func(path string, value any) {
|
|||
|
|
if strings.TrimSpace(path) == "" {
|
|||
|
|
patchDisabled = true
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if patchDisabled {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if !patchHasOp {
|
|||
|
|
patchHasOp = true
|
|||
|
|
patchDelete = false
|
|||
|
|
patchPath = path
|
|||
|
|
patchValue = value
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if patchDelete || patchPath != path {
|
|||
|
|
patchDisabled = true
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
patchValue = value
|
|||
|
|
}
|
|||
|
|
markPatchDelete := func(path string) {
|
|||
|
|
if strings.TrimSpace(path) == "" {
|
|||
|
|
patchDisabled = true
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if patchDisabled {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if !patchHasOp {
|
|||
|
|
patchHasOp = true
|
|||
|
|
patchDelete = true
|
|||
|
|
patchPath = path
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if !patchDelete || patchPath != path {
|
|||
|
|
patchDisabled = true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
disablePatch := func() {
|
|||
|
|
patchDisabled = true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 非透传模式下,instructions 为空时注入默认指令。
|
|||
|
|
if isInstructionsEmpty(reqBody) {
|
|||
|
|
reqBody["instructions"] = "You are a helpful coding assistant."
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchSet("instructions", "You are a helpful coding assistant.")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 对所有请求执行模型映射(包含 Codex CLI)。
|
|||
|
|
mappedModel := account.GetMappedModel(reqModel)
|
|||
|
|
if mappedModel != reqModel {
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
|||
|
|
reqBody["model"] = mappedModel
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchSet("model", mappedModel)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
|
|||
|
|
if model, ok := reqBody["model"].(string); ok {
|
|||
|
|
normalizedModel := normalizeCodexModel(model)
|
|||
|
|
if normalizedModel != "" && normalizedModel != model {
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
|||
|
|
model, normalizedModel, account.Name, account.Type, isCodexCLI)
|
|||
|
|
reqBody["model"] = normalizedModel
|
|||
|
|
mappedModel = normalizedModel
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchSet("model", normalizedModel)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
|
|||
|
|
// 确保高版本模型向低版本模型映射不报错
|
|||
|
|
if !SupportsVerbosity(normalizedModel) {
|
|||
|
|
if text, ok := reqBody["text"].(map[string]any); ok {
|
|||
|
|
delete(text, "verbosity")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
|||
|
|
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
|||
|
|
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
|
|||
|
|
reasoning["effort"] = "none"
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchSet("reasoning.effort", "none")
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if account.Type == AccountTypeOAuth {
|
|||
|
|
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isOpenAIResponsesCompactPath(c))
|
|||
|
|
if codexResult.Modified {
|
|||
|
|
bodyModified = true
|
|||
|
|
disablePatch()
|
|||
|
|
}
|
|||
|
|
if codexResult.NormalizedModel != "" {
|
|||
|
|
mappedModel = codexResult.NormalizedModel
|
|||
|
|
}
|
|||
|
|
if codexResult.PromptCacheKey != "" {
|
|||
|
|
promptCacheKey = codexResult.PromptCacheKey
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Handle max_output_tokens based on platform and account type
|
|||
|
|
if !isCodexCLI {
|
|||
|
|
if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens {
|
|||
|
|
switch account.Platform {
|
|||
|
|
case PlatformOpenAI:
|
|||
|
|
// For OpenAI API Key, remove max_output_tokens (not supported)
|
|||
|
|
// For OpenAI OAuth (Responses API), keep it (supported)
|
|||
|
|
if account.Type == AccountTypeAPIKey {
|
|||
|
|
delete(reqBody, "max_output_tokens")
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchDelete("max_output_tokens")
|
|||
|
|
}
|
|||
|
|
case PlatformAnthropic:
|
|||
|
|
// For Anthropic (Claude), convert to max_tokens
|
|||
|
|
delete(reqBody, "max_output_tokens")
|
|||
|
|
markPatchDelete("max_output_tokens")
|
|||
|
|
if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens {
|
|||
|
|
reqBody["max_tokens"] = maxOutputTokens
|
|||
|
|
disablePatch()
|
|||
|
|
}
|
|||
|
|
bodyModified = true
|
|||
|
|
case PlatformGemini:
|
|||
|
|
// For Gemini, remove (will be handled by Gemini-specific transform)
|
|||
|
|
delete(reqBody, "max_output_tokens")
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchDelete("max_output_tokens")
|
|||
|
|
default:
|
|||
|
|
// For unknown platforms, remove to be safe
|
|||
|
|
delete(reqBody, "max_output_tokens")
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchDelete("max_output_tokens")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Also handle max_completion_tokens (similar logic)
|
|||
|
|
if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens {
|
|||
|
|
if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI {
|
|||
|
|
delete(reqBody, "max_completion_tokens")
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchDelete("max_completion_tokens")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Remove unsupported fields (not supported by upstream OpenAI API)
|
|||
|
|
unsupportedFields := []string{"prompt_cache_retention", "safety_identifier"}
|
|||
|
|
for _, unsupportedField := range unsupportedFields {
|
|||
|
|
if _, has := reqBody[unsupportedField]; has {
|
|||
|
|
delete(reqBody, unsupportedField)
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchDelete(unsupportedField)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 仅在 WSv2 模式保留 previous_response_id,其他模式(HTTP/WSv1)统一过滤。
|
|||
|
|
// 注意:该规则同样适用于 Codex CLI 请求,避免 WSv1 向上游透传不支持字段。
|
|||
|
|
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
|||
|
|
if _, has := reqBody["previous_response_id"]; has {
|
|||
|
|
delete(reqBody, "previous_response_id")
|
|||
|
|
bodyModified = true
|
|||
|
|
markPatchDelete("previous_response_id")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Re-serialize body only if modified
|
|||
|
|
if bodyModified {
|
|||
|
|
serializedByPatch := false
|
|||
|
|
if !patchDisabled && patchHasOp {
|
|||
|
|
var patchErr error
|
|||
|
|
if patchDelete {
|
|||
|
|
body, patchErr = sjson.DeleteBytes(body, patchPath)
|
|||
|
|
} else {
|
|||
|
|
body, patchErr = sjson.SetBytes(body, patchPath, patchValue)
|
|||
|
|
}
|
|||
|
|
if patchErr == nil {
|
|||
|
|
serializedByPatch = true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if !serializedByPatch {
|
|||
|
|
var marshalErr error
|
|||
|
|
body, marshalErr = json.Marshal(reqBody)
|
|||
|
|
if marshalErr != nil {
|
|||
|
|
return nil, fmt.Errorf("serialize request body: %w", marshalErr)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Get access token
|
|||
|
|
token, _, err := s.GetAccessToken(ctx, account)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Capture upstream request body for ops retry of this attempt.
|
|||
|
|
setOpsUpstreamRequestBody(c, body)
|
|||
|
|
|
|||
|
|
// 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。
|
|||
|
|
if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 {
|
|||
|
|
wsReqBody := reqBody
|
|||
|
|
if len(reqBody) > 0 {
|
|||
|
|
wsReqBody = make(map[string]any, len(reqBody))
|
|||
|
|
for k, v := range reqBody {
|
|||
|
|
wsReqBody[k] = v
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
_, hasPreviousResponseID := wsReqBody["previous_response_id"]
|
|||
|
|
logOpenAIWSModeDebug(
|
|||
|
|
"forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v",
|
|||
|
|
account.ID,
|
|||
|
|
account.Type,
|
|||
|
|
mappedModel,
|
|||
|
|
reqStream,
|
|||
|
|
hasPreviousResponseID,
|
|||
|
|
)
|
|||
|
|
maxAttempts := openAIWSReconnectRetryLimit + 1
|
|||
|
|
wsAttempts := 0
|
|||
|
|
var wsResult *OpenAIForwardResult
|
|||
|
|
var wsErr error
|
|||
|
|
wsLastFailureReason := ""
|
|||
|
|
wsPrevResponseRecoveryTried := false
|
|||
|
|
wsInvalidEncryptedContentRecoveryTried := false
|
|||
|
|
recoverPrevResponseNotFound := func(attempt int) bool {
|
|||
|
|
if wsPrevResponseRecoveryTried {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id")
|
|||
|
|
if previousResponseID == "" {
|
|||
|
|
logOpenAIWSModeInfo(
|
|||
|
|
"reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=missing_previous_response_id previous_response_id_present=false",
|
|||
|
|
account.ID,
|
|||
|
|
attempt,
|
|||
|
|
)
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if HasFunctionCallOutput(wsReqBody) {
|
|||
|
|
logOpenAIWSModeInfo(
|
|||
|
|
"reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=has_function_call_output previous_response_id_present=true",
|
|||
|
|
account.ID,
|
|||
|
|
attempt,
|
|||
|
|
)
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
delete(wsReqBody, "previous_response_id")
|
|||
|
|
wsPrevResponseRecoveryTried = true
|
|||
|
|
logOpenAIWSModeInfo(
|
|||
|
|
"reconnect_prev_response_recovery account_id=%d attempt=%d action=drop_previous_response_id retry=1 previous_response_id=%s previous_response_id_kind=%s",
|
|||
|
|
account.ID,
|
|||
|
|
attempt,
|
|||
|
|
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
|
|||
|
|
normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)),
|
|||
|
|
)
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
recoverInvalidEncryptedContent := func(attempt int) bool {
|
|||
|
|
if wsInvalidEncryptedContentRecoveryTried {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
removedReasoningItems := trimOpenAIEncryptedReasoningItems(wsReqBody)
|
|||
|
|
if !removedReasoningItems {
|
|||
|
|
logOpenAIWSModeInfo(
|
|||
|
|
"reconnect_invalid_encrypted_content_recovery_skip account_id=%d attempt=%d reason=missing_encrypted_reasoning_items",
|
|||
|
|
account.ID,
|
|||
|
|
attempt,
|
|||
|
|
)
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id")
|
|||
|
|
hasFunctionCallOutput := HasFunctionCallOutput(wsReqBody)
|
|||
|
|
if previousResponseID != "" && !hasFunctionCallOutput {
|
|||
|
|
delete(wsReqBody, "previous_response_id")
|
|||
|
|
}
|
|||
|
|
wsInvalidEncryptedContentRecoveryTried = true
|
|||
|
|
logOpenAIWSModeInfo(
|
|||
|
|
"reconnect_invalid_encrypted_content_recovery account_id=%d attempt=%d action=drop_encrypted_reasoning_items retry=1 previous_response_id_present=%v previous_response_id=%s previous_response_id_kind=%s has_function_call_output=%v dropped_previous_response_id=%v",
|
|||
|
|
account.ID,
|
|||
|
|
attempt,
|
|||
|
|
previousResponseID != "",
|
|||
|
|
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
|
|||
|
|
normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)),
|
|||
|
|
hasFunctionCallOutput,
|
|||
|
|
previousResponseID != "" && !hasFunctionCallOutput,
|
|||
|
|
)
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
retryBudget := s.openAIWSRetryTotalBudget()
|
|||
|
|
retryStartedAt := time.Now()
|
|||
|
|
wsRetryLoop:
|
|||
|
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
|||
|
|
wsAttempts = attempt
|
|||
|
|
wsResult, wsErr = s.forwardOpenAIWSV2(
|
|||
|
|
ctx,
|
|||
|
|
c,
|
|||
|
|
account,
|
|||
|
|
wsReqBody,
|
|||
|
|
token,
|
|||
|
|
wsDecision,
|
|||
|
|
isCodexCLI,
|
|||
|
|
reqStream,
|
|||
|
|
originalModel,
|
|||
|
|
mappedModel,
|
|||
|
|
startTime,
|
|||
|
|
attempt,
|
|||
|
|
wsLastFailureReason,
|
|||
|
|
)
|
|||
|
|
if wsErr == nil {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
if c != nil && c.Writer != nil && c.Writer.Written() {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
reason, retryable := classifyOpenAIWSReconnectReason(wsErr)
|
|||
|
|
if reason != "" {
|
|||
|
|
wsLastFailureReason = reason
|
|||
|
|
}
|
|||
|
|
// previous_response_not_found 说明续链锚点不可用:
|
|||
|
|
// 对非 function_call_output 场景,允许一次“去掉 previous_response_id 后重放”。
|
|||
|
|
if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if reason == "invalid_encrypted_content" && recoverInvalidEncryptedContent(attempt) {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if retryable && attempt < maxAttempts {
|
|||
|
|
backoff := s.openAIWSRetryBackoff(attempt)
|
|||
|
|
if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget {
|
|||
|
|
s.recordOpenAIWSRetryExhausted()
|
|||
|
|
logOpenAIWSModeInfo(
|
|||
|
|
"reconnect_budget_exhausted account_id=%d attempts=%d max_retries=%d reason=%s elapsed_ms=%d budget_ms=%d",
|
|||
|
|
account.ID,
|
|||
|
|
attempt,
|
|||
|
|
openAIWSReconnectRetryLimit,
|
|||
|
|
normalizeOpenAIWSLogValue(reason),
|
|||
|
|
time.Since(retryStartedAt).Milliseconds(),
|
|||
|
|
retryBudget.Milliseconds(),
|
|||
|
|
)
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
s.recordOpenAIWSRetryAttempt(backoff)
|
|||
|
|
logOpenAIWSModeInfo(
|
|||
|
|
"reconnect_retry account_id=%d retry=%d max_retries=%d reason=%s backoff_ms=%d",
|
|||
|
|
account.ID,
|
|||
|
|
attempt,
|
|||
|
|
openAIWSReconnectRetryLimit,
|
|||
|
|
normalizeOpenAIWSLogValue(reason),
|
|||
|
|
backoff.Milliseconds(),
|
|||
|
|
)
|
|||
|
|
if backoff > 0 {
|
|||
|
|
timer := time.NewTimer(backoff)
|
|||
|
|
select {
|
|||
|
|
case <-ctx.Done():
|
|||
|
|
if !timer.Stop() {
|
|||
|
|
<-timer.C
|
|||
|
|
}
|
|||
|
|
wsErr = wrapOpenAIWSFallback("retry_backoff_canceled", ctx.Err())
|
|||
|
|
break wsRetryLoop
|
|||
|
|
case <-timer.C:
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if retryable {
|
|||
|
|
s.recordOpenAIWSRetryExhausted()
|
|||
|
|
logOpenAIWSModeInfo(
|
|||
|
|
"reconnect_exhausted account_id=%d attempts=%d max_retries=%d reason=%s",
|
|||
|
|
account.ID,
|
|||
|
|
attempt,
|
|||
|
|
openAIWSReconnectRetryLimit,
|
|||
|
|
normalizeOpenAIWSLogValue(reason),
|
|||
|
|
)
|
|||
|
|
} else if reason != "" {
|
|||
|
|
s.recordOpenAIWSNonRetryableFastFallback()
|
|||
|
|
logOpenAIWSModeInfo(
|
|||
|
|
"reconnect_stop account_id=%d attempt=%d reason=%s",
|
|||
|
|
account.ID,
|
|||
|
|
attempt,
|
|||
|
|
normalizeOpenAIWSLogValue(reason),
|
|||
|
|
)
|
|||
|
|
}
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
if wsErr == nil {
|
|||
|
|
firstTokenMs := int64(0)
|
|||
|
|
hasFirstTokenMs := wsResult != nil && wsResult.FirstTokenMs != nil
|
|||
|
|
if hasFirstTokenMs {
|
|||
|
|
firstTokenMs = int64(*wsResult.FirstTokenMs)
|
|||
|
|
}
|
|||
|
|
requestID := ""
|
|||
|
|
if wsResult != nil {
|
|||
|
|
requestID = strings.TrimSpace(wsResult.RequestID)
|
|||
|
|
}
|
|||
|
|
logOpenAIWSModeDebug(
|
|||
|
|
"forward_succeeded account_id=%d request_id=%s stream=%v has_first_token_ms=%v first_token_ms=%d ws_attempts=%d",
|
|||
|
|
account.ID,
|
|||
|
|
requestID,
|
|||
|
|
reqStream,
|
|||
|
|
hasFirstTokenMs,
|
|||
|
|
firstTokenMs,
|
|||
|
|
wsAttempts,
|
|||
|
|
)
|
|||
|
|
return wsResult, nil
|
|||
|
|
}
|
|||
|
|
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
|
|||
|
|
return nil, wsErr
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
httpInvalidEncryptedContentRetryTried := false
|
|||
|
|
for {
|
|||
|
|
// Build upstream request
|
|||
|
|
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
|||
|
|
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
|||
|
|
releaseUpstreamCtx()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Get proxy URL
|
|||
|
|
proxyURL := ""
|
|||
|
|
if account.ProxyID != nil && account.Proxy != nil {
|
|||
|
|
proxyURL = account.Proxy.URL()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Send request
|
|||
|
|
upstreamStart := time.Now()
|
|||
|
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
|||
|
|
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
|||
|
|
if err != nil {
|
|||
|
|
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
|||
|
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
|||
|
|
setOpsUpstreamError(c, 0, safeErr, "")
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: 0,
|
|||
|
|
Kind: "request_error",
|
|||
|
|
Message: safeErr,
|
|||
|
|
})
|
|||
|
|
c.JSON(http.StatusBadGateway, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": "upstream_error",
|
|||
|
|
"message": "Upstream request failed",
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Handle error response
|
|||
|
|
if resp.StatusCode >= 400 {
|
|||
|
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|||
|
|
_ = resp.Body.Close()
|
|||
|
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
|||
|
|
|
|||
|
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
|||
|
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|||
|
|
upstreamCode := extractUpstreamErrorCode(respBody)
|
|||
|
|
if !httpInvalidEncryptedContentRetryTried && resp.StatusCode == http.StatusBadRequest && upstreamCode == "invalid_encrypted_content" {
|
|||
|
|
if trimOpenAIEncryptedReasoningItems(reqBody) {
|
|||
|
|
body, err = json.Marshal(reqBody)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err)
|
|||
|
|
}
|
|||
|
|
setOpsUpstreamRequestBody(c, body)
|
|||
|
|
httpInvalidEncryptedContentRetryTried = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name)
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Skip non-WSv2 invalid_encrypted_content retry because encrypted reasoning items are missing (account: %s)", account.Name)
|
|||
|
|
}
|
|||
|
|
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
|||
|
|
upstreamDetail := ""
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|||
|
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|||
|
|
if maxBytes <= 0 {
|
|||
|
|
maxBytes = 2048
|
|||
|
|
}
|
|||
|
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
|||
|
|
}
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: resp.StatusCode,
|
|||
|
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|||
|
|
Kind: "failover",
|
|||
|
|
Message: upstreamMsg,
|
|||
|
|
Detail: upstreamDetail,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
s.handleFailoverSideEffects(ctx, resp, account)
|
|||
|
|
return nil, &UpstreamFailoverError{
|
|||
|
|
StatusCode: resp.StatusCode,
|
|||
|
|
ResponseBody: respBody,
|
|||
|
|
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return s.handleErrorResponse(ctx, resp, c, account, body)
|
|||
|
|
}
|
|||
|
|
defer func() { _ = resp.Body.Close() }()
|
|||
|
|
|
|||
|
|
// Handle normal response
|
|||
|
|
var usage *OpenAIUsage
|
|||
|
|
var firstTokenMs *int
|
|||
|
|
if reqStream {
|
|||
|
|
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
usage = streamResult.usage
|
|||
|
|
firstTokenMs = streamResult.firstTokenMs
|
|||
|
|
} else {
|
|||
|
|
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
|||
|
|
if account.Type == AccountTypeOAuth {
|
|||
|
|
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
|||
|
|
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if usage == nil {
|
|||
|
|
usage = &OpenAIUsage{}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
|||
|
|
serviceTier := extractOpenAIServiceTier(reqBody)
|
|||
|
|
|
|||
|
|
return &OpenAIForwardResult{
|
|||
|
|
RequestID: resp.Header.Get("x-request-id"),
|
|||
|
|
Usage: *usage,
|
|||
|
|
Model: originalModel,
|
|||
|
|
ServiceTier: serviceTier,
|
|||
|
|
ReasoningEffort: reasoningEffort,
|
|||
|
|
Stream: reqStream,
|
|||
|
|
OpenAIWSMode: false,
|
|||
|
|
Duration: time.Since(startTime),
|
|||
|
|
FirstTokenMs: firstTokenMs,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
|||
|
|
ctx context.Context,
|
|||
|
|
c *gin.Context,
|
|||
|
|
account *Account,
|
|||
|
|
body []byte,
|
|||
|
|
reqModel string,
|
|||
|
|
reasoningEffort *string,
|
|||
|
|
reqStream bool,
|
|||
|
|
startTime time.Time,
|
|||
|
|
) (*OpenAIForwardResult, error) {
|
|||
|
|
if account != nil && account.Type == AccountTypeOAuth {
|
|||
|
|
if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" {
|
|||
|
|
rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field"
|
|||
|
|
setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "")
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: http.StatusForbidden,
|
|||
|
|
Passthrough: true,
|
|||
|
|
Kind: "request_error",
|
|||
|
|
Message: rejectMsg,
|
|||
|
|
Detail: rejectReason,
|
|||
|
|
})
|
|||
|
|
logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body)
|
|||
|
|
c.JSON(http.StatusForbidden, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": "forbidden_error",
|
|||
|
|
"message": rejectMsg,
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body, isOpenAIResponsesCompactPath(c))
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
if normalized {
|
|||
|
|
body = normalizedBody
|
|||
|
|
}
|
|||
|
|
reqStream = gjson.GetBytes(body, "stream").Bool()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway",
|
|||
|
|
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
|
|||
|
|
account.ID,
|
|||
|
|
account.Name,
|
|||
|
|
account.Type,
|
|||
|
|
reqModel,
|
|||
|
|
reqStream,
|
|||
|
|
)
|
|||
|
|
if reqStream && c != nil && c.Request != nil {
|
|||
|
|
if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 {
|
|||
|
|
streamWarnLogger := logger.FromContext(ctx).With(
|
|||
|
|
zap.String("component", "service.openai_gateway"),
|
|||
|
|
zap.Int64("account_id", account.ID),
|
|||
|
|
zap.Strings("timeout_headers", timeoutHeaders),
|
|||
|
|
)
|
|||
|
|
if s.isOpenAIPassthroughTimeoutHeadersAllowed() {
|
|||
|
|
streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流")
|
|||
|
|
} else {
|
|||
|
|
streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Get access token
|
|||
|
|
token, _, err := s.GetAccessToken(ctx, account)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
|||
|
|
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
|||
|
|
releaseUpstreamCtx()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
proxyURL := ""
|
|||
|
|
if account.ProxyID != nil && account.Proxy != nil {
|
|||
|
|
proxyURL = account.Proxy.URL()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
setOpsUpstreamRequestBody(c, body)
|
|||
|
|
if c != nil {
|
|||
|
|
c.Set("openai_passthrough", true)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
upstreamStart := time.Now()
|
|||
|
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
|||
|
|
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
|||
|
|
if err != nil {
|
|||
|
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
|||
|
|
setOpsUpstreamError(c, 0, safeErr, "")
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: 0,
|
|||
|
|
Passthrough: true,
|
|||
|
|
Kind: "request_error",
|
|||
|
|
Message: safeErr,
|
|||
|
|
})
|
|||
|
|
c.JSON(http.StatusBadGateway, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": "upstream_error",
|
|||
|
|
"message": "Upstream request failed",
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
|||
|
|
}
|
|||
|
|
defer func() { _ = resp.Body.Close() }()
|
|||
|
|
|
|||
|
|
if resp.StatusCode >= 400 {
|
|||
|
|
// 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。
|
|||
|
|
return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var usage *OpenAIUsage
|
|||
|
|
var firstTokenMs *int
|
|||
|
|
if reqStream {
|
|||
|
|
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
usage = result.usage
|
|||
|
|
firstTokenMs = result.firstTokenMs
|
|||
|
|
} else {
|
|||
|
|
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
|||
|
|
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if usage == nil {
|
|||
|
|
usage = &OpenAIUsage{}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &OpenAIForwardResult{
|
|||
|
|
RequestID: resp.Header.Get("x-request-id"),
|
|||
|
|
Usage: *usage,
|
|||
|
|
Model: reqModel,
|
|||
|
|
ServiceTier: extractOpenAIServiceTierFromBody(body),
|
|||
|
|
ReasoningEffort: reasoningEffort,
|
|||
|
|
Stream: reqStream,
|
|||
|
|
OpenAIWSMode: false,
|
|||
|
|
Duration: time.Since(startTime),
|
|||
|
|
FirstTokenMs: firstTokenMs,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func logOpenAIPassthroughInstructionsRejected(
|
|||
|
|
ctx context.Context,
|
|||
|
|
c *gin.Context,
|
|||
|
|
account *Account,
|
|||
|
|
reqModel string,
|
|||
|
|
rejectReason string,
|
|||
|
|
body []byte,
|
|||
|
|
) {
|
|||
|
|
if ctx == nil {
|
|||
|
|
ctx = context.Background()
|
|||
|
|
}
|
|||
|
|
accountID := int64(0)
|
|||
|
|
accountName := ""
|
|||
|
|
accountType := ""
|
|||
|
|
if account != nil {
|
|||
|
|
accountID = account.ID
|
|||
|
|
accountName = strings.TrimSpace(account.Name)
|
|||
|
|
accountType = strings.TrimSpace(string(account.Type))
|
|||
|
|
}
|
|||
|
|
fields := []zap.Field{
|
|||
|
|
zap.String("component", "service.openai_gateway"),
|
|||
|
|
zap.Int64("account_id", accountID),
|
|||
|
|
zap.String("account_name", accountName),
|
|||
|
|
zap.String("account_type", accountType),
|
|||
|
|
zap.String("request_model", strings.TrimSpace(reqModel)),
|
|||
|
|
zap.String("reject_reason", strings.TrimSpace(rejectReason)),
|
|||
|
|
}
|
|||
|
|
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body)
|
|||
|
|
logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
|
|||
|
|
ctx context.Context,
|
|||
|
|
c *gin.Context,
|
|||
|
|
account *Account,
|
|||
|
|
body []byte,
|
|||
|
|
token string,
|
|||
|
|
) (*http.Request, error) {
|
|||
|
|
targetURL := openaiPlatformAPIURL
|
|||
|
|
switch account.Type {
|
|||
|
|
case AccountTypeOAuth:
|
|||
|
|
targetURL = chatgptCodexURL
|
|||
|
|
case AccountTypeAPIKey:
|
|||
|
|
baseURL := account.GetOpenAIBaseURL()
|
|||
|
|
if baseURL != "" {
|
|||
|
|
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
targetURL = buildOpenAIResponsesURL(validatedURL)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c))
|
|||
|
|
|
|||
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 透传客户端请求头(安全白名单)。
|
|||
|
|
allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed()
|
|||
|
|
if c != nil && c.Request != nil {
|
|||
|
|
for key, values := range c.Request.Header {
|
|||
|
|
lower := strings.ToLower(strings.TrimSpace(key))
|
|||
|
|
if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
for _, v := range values {
|
|||
|
|
req.Header.Add(key, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 覆盖入站鉴权残留,并注入上游认证
|
|||
|
|
req.Header.Del("authorization")
|
|||
|
|
req.Header.Del("x-api-key")
|
|||
|
|
req.Header.Del("x-goog-api-key")
|
|||
|
|
req.Header.Set("authorization", "Bearer "+token)
|
|||
|
|
|
|||
|
|
// OAuth 透传到 ChatGPT internal API 时补齐必要头。
|
|||
|
|
if account.Type == AccountTypeOAuth {
|
|||
|
|
promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
|||
|
|
req.Host = "chatgpt.com"
|
|||
|
|
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
|||
|
|
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
|||
|
|
}
|
|||
|
|
apiKeyID := getAPIKeyIDFromContext(c)
|
|||
|
|
// 先保存客户端原始值,再做 compact 补充,避免后续统一隔离时读到已处理的值。
|
|||
|
|
clientSessionID := strings.TrimSpace(req.Header.Get("session_id"))
|
|||
|
|
clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id"))
|
|||
|
|
if isOpenAIResponsesCompactPath(c) {
|
|||
|
|
req.Header.Set("accept", "application/json")
|
|||
|
|
if req.Header.Get("version") == "" {
|
|||
|
|
req.Header.Set("version", codexCLIVersion)
|
|||
|
|
}
|
|||
|
|
if clientSessionID == "" {
|
|||
|
|
clientSessionID = resolveOpenAICompactSessionID(c)
|
|||
|
|
}
|
|||
|
|
} else if req.Header.Get("accept") == "" {
|
|||
|
|
req.Header.Set("accept", "text/event-stream")
|
|||
|
|
}
|
|||
|
|
if req.Header.Get("OpenAI-Beta") == "" {
|
|||
|
|
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
|||
|
|
}
|
|||
|
|
if req.Header.Get("originator") == "" {
|
|||
|
|
req.Header.Set("originator", "codex_cli_rs")
|
|||
|
|
}
|
|||
|
|
// 用隔离后的 session 标识符覆盖客户端透传值,防止跨用户会话碰撞。
|
|||
|
|
if clientSessionID == "" {
|
|||
|
|
clientSessionID = promptCacheKey
|
|||
|
|
}
|
|||
|
|
if clientConversationID == "" {
|
|||
|
|
clientConversationID = promptCacheKey
|
|||
|
|
}
|
|||
|
|
if clientSessionID != "" {
|
|||
|
|
req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, clientSessionID))
|
|||
|
|
}
|
|||
|
|
if clientConversationID != "" {
|
|||
|
|
req.Header.Set("conversation_id", isolateOpenAISessionID(apiKeyID, clientConversationID))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 透传模式也支持账户自定义 User-Agent 与 ForceCodexCLI 兜底。
|
|||
|
|
customUA := account.GetOpenAIUserAgent()
|
|||
|
|
if customUA != "" {
|
|||
|
|
req.Header.Set("user-agent", customUA)
|
|||
|
|
}
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
|||
|
|
req.Header.Set("user-agent", codexCLIUserAgent)
|
|||
|
|
}
|
|||
|
|
// OAuth 安全透传:对非 Codex UA 统一兜底,降低被上游风控拦截概率。
|
|||
|
|
if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) {
|
|||
|
|
req.Header.Set("user-agent", codexCLIUserAgent)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if req.Header.Get("content-type") == "" {
|
|||
|
|
req.Header.Set("content-type", "application/json")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return req, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
|
|||
|
|
ctx context.Context,
|
|||
|
|
resp *http.Response,
|
|||
|
|
c *gin.Context,
|
|||
|
|
account *Account,
|
|||
|
|
requestBody []byte,
|
|||
|
|
) error {
|
|||
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|||
|
|
|
|||
|
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
|||
|
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|||
|
|
upstreamDetail := ""
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|||
|
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|||
|
|
if maxBytes <= 0 {
|
|||
|
|
maxBytes = 2048
|
|||
|
|
}
|
|||
|
|
upstreamDetail = truncateString(string(body), maxBytes)
|
|||
|
|
}
|
|||
|
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
|||
|
|
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: resp.StatusCode,
|
|||
|
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|||
|
|
Passthrough: true,
|
|||
|
|
Kind: "http_error",
|
|||
|
|
Message: upstreamMsg,
|
|||
|
|
Detail: upstreamDetail,
|
|||
|
|
UpstreamResponseBody: upstreamDetail,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
|||
|
|
contentType := resp.Header.Get("Content-Type")
|
|||
|
|
if contentType == "" {
|
|||
|
|
contentType = "application/json"
|
|||
|
|
}
|
|||
|
|
c.Data(resp.StatusCode, contentType, body)
|
|||
|
|
|
|||
|
|
if upstreamMsg == "" {
|
|||
|
|
return fmt.Errorf("upstream error: %d", resp.StatusCode)
|
|||
|
|
}
|
|||
|
|
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool {
|
|||
|
|
if lowerKey == "" {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if isOpenAIPassthroughTimeoutHeader(lowerKey) {
|
|||
|
|
return allowTimeoutHeaders
|
|||
|
|
}
|
|||
|
|
return openaiPassthroughAllowedHeaders[lowerKey]
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// isOpenAIPassthroughTimeoutHeader reports whether lowerKey names one of the
// client-side timeout headers that passthrough mode treats specially.
//
// The comparison is exact and case-sensitive, so callers must lower-case the
// header name first.
func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool {
	for _, candidate := range [...]string{
		"x-stainless-timeout",
		"x-stainless-read-timeout",
		"x-stainless-connect-timeout",
		"x-request-timeout",
		"request-timeout",
		"grpc-timeout",
	} {
		if lowerKey == candidate {
			return true
		}
	}
	return false
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool {
|
|||
|
|
return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
|
|||
|
|
if h == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
var matched []string
|
|||
|
|
for key, values := range h {
|
|||
|
|
lowerKey := strings.ToLower(strings.TrimSpace(key))
|
|||
|
|
if isOpenAIPassthroughTimeoutHeader(lowerKey) {
|
|||
|
|
entry := lowerKey
|
|||
|
|
if len(values) > 0 {
|
|||
|
|
entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|"))
|
|||
|
|
}
|
|||
|
|
matched = append(matched, entry)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
sort.Strings(matched)
|
|||
|
|
return matched
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
type openaiStreamingResultPassthrough struct {
|
|||
|
|
usage *OpenAIUsage
|
|||
|
|
firstTokenMs *int
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
|||
|
|
ctx context.Context,
|
|||
|
|
resp *http.Response,
|
|||
|
|
c *gin.Context,
|
|||
|
|
account *Account,
|
|||
|
|
startTime time.Time,
|
|||
|
|
) (*openaiStreamingResultPassthrough, error) {
|
|||
|
|
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
|||
|
|
|
|||
|
|
// SSE headers
|
|||
|
|
c.Header("Content-Type", "text/event-stream")
|
|||
|
|
c.Header("Cache-Control", "no-cache")
|
|||
|
|
c.Header("Connection", "keep-alive")
|
|||
|
|
c.Header("X-Accel-Buffering", "no")
|
|||
|
|
if v := resp.Header.Get("x-request-id"); v != "" {
|
|||
|
|
c.Header("x-request-id", v)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
w := c.Writer
|
|||
|
|
flusher, ok := w.(http.Flusher)
|
|||
|
|
if !ok {
|
|||
|
|
return nil, errors.New("streaming not supported")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
usage := &OpenAIUsage{}
|
|||
|
|
var firstTokenMs *int
|
|||
|
|
clientDisconnected := false
|
|||
|
|
sawDone := false
|
|||
|
|
sawTerminalEvent := false
|
|||
|
|
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
|||
|
|
|
|||
|
|
scanner := bufio.NewScanner(resp.Body)
|
|||
|
|
maxLineSize := defaultMaxLineSize
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
|||
|
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
|||
|
|
}
|
|||
|
|
scanBuf := getSSEScannerBuf64K()
|
|||
|
|
scanner.Buffer(scanBuf[:0], maxLineSize)
|
|||
|
|
defer putSSEScannerBuf64K(scanBuf)
|
|||
|
|
|
|||
|
|
for scanner.Scan() {
|
|||
|
|
line := scanner.Text()
|
|||
|
|
if data, ok := extractOpenAISSEDataLine(line); ok {
|
|||
|
|
dataBytes := []byte(data)
|
|||
|
|
trimmedData := strings.TrimSpace(data)
|
|||
|
|
if trimmedData == "[DONE]" {
|
|||
|
|
sawDone = true
|
|||
|
|
}
|
|||
|
|
if openAIStreamEventIsTerminal(trimmedData) {
|
|||
|
|
sawTerminalEvent = true
|
|||
|
|
}
|
|||
|
|
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
|
|||
|
|
ms := int(time.Since(startTime).Milliseconds())
|
|||
|
|
firstTokenMs = &ms
|
|||
|
|
}
|
|||
|
|
s.parseSSEUsageBytes(dataBytes, usage)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if !clientDisconnected {
|
|||
|
|
if _, err := fmt.Fprintln(w, line); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
|||
|
|
} else {
|
|||
|
|
flusher.Flush()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if err := scanner.Err(); err != nil {
|
|||
|
|
if sawTerminalEvent {
|
|||
|
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
|||
|
|
}
|
|||
|
|
if clientDisconnected {
|
|||
|
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
|||
|
|
}
|
|||
|
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|||
|
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
|||
|
|
}
|
|||
|
|
if errors.Is(err, bufio.ErrTooLong) {
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
|||
|
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
|||
|
|
}
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway",
|
|||
|
|
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
|
|||
|
|
account.ID,
|
|||
|
|
upstreamRequestID,
|
|||
|
|
err,
|
|||
|
|
)
|
|||
|
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
|||
|
|
}
|
|||
|
|
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
|||
|
|
logger.FromContext(ctx).With(
|
|||
|
|
zap.String("component", "service.openai_gateway"),
|
|||
|
|
zap.Int64("account_id", account.ID),
|
|||
|
|
zap.String("upstream_request_id", upstreamRequestID),
|
|||
|
|
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
|||
|
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
|||
|
|
ctx context.Context,
|
|||
|
|
resp *http.Response,
|
|||
|
|
c *gin.Context,
|
|||
|
|
) (*OpenAIUsage, error) {
|
|||
|
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
|||
|
|
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
|||
|
|
if err != nil {
|
|||
|
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
|||
|
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
|||
|
|
c.JSON(http.StatusBadGateway, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": "upstream_error",
|
|||
|
|
"message": "Upstream response too large",
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
usage := &OpenAIUsage{}
|
|||
|
|
usageParsed := false
|
|||
|
|
if len(body) > 0 {
|
|||
|
|
if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(body); ok {
|
|||
|
|
*usage = parsedUsage
|
|||
|
|
usageParsed = true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if !usageParsed {
|
|||
|
|
// 兜底:尝试从 SSE 文本中解析 usage
|
|||
|
|
usage = s.parseSSEUsageFromBody(string(body))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
|||
|
|
|
|||
|
|
contentType := resp.Header.Get("Content-Type")
|
|||
|
|
if contentType == "" {
|
|||
|
|
contentType = "application/json"
|
|||
|
|
}
|
|||
|
|
c.Data(resp.StatusCode, contentType, body)
|
|||
|
|
return usage, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
|
|||
|
|
if dst == nil || src == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if filter != nil {
|
|||
|
|
responseheaders.WriteFilteredHeaders(dst, src, filter)
|
|||
|
|
} else {
|
|||
|
|
// 兜底:尽量保留最基础的 content-type
|
|||
|
|
if v := strings.TrimSpace(src.Get("Content-Type")); v != "" {
|
|||
|
|
dst.Set("Content-Type", v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
// 透传模式强制放行 x-codex-* 响应头(若上游返回)。
|
|||
|
|
// 注意:真实 http.Response.Header 的 key 一般会被 canonicalize;但为了兼容测试/自建响应,
|
|||
|
|
// 这里用 EqualFold 做一次大小写不敏感的查找。
|
|||
|
|
getCaseInsensitiveValues := func(h http.Header, want string) []string {
|
|||
|
|
if h == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
for k, vals := range h {
|
|||
|
|
if strings.EqualFold(k, want) {
|
|||
|
|
return vals
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, rawKey := range []string{
|
|||
|
|
"x-codex-primary-used-percent",
|
|||
|
|
"x-codex-primary-reset-after-seconds",
|
|||
|
|
"x-codex-primary-window-minutes",
|
|||
|
|
"x-codex-secondary-used-percent",
|
|||
|
|
"x-codex-secondary-reset-after-seconds",
|
|||
|
|
"x-codex-secondary-window-minutes",
|
|||
|
|
"x-codex-primary-over-secondary-limit-percent",
|
|||
|
|
} {
|
|||
|
|
vals := getCaseInsensitiveValues(src, rawKey)
|
|||
|
|
if len(vals) == 0 {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
key := http.CanonicalHeaderKey(rawKey)
|
|||
|
|
dst.Del(key)
|
|||
|
|
for _, v := range vals {
|
|||
|
|
dst.Add(key, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
|
|||
|
|
// Determine target URL based on account type
|
|||
|
|
var targetURL string
|
|||
|
|
switch account.Type {
|
|||
|
|
case AccountTypeOAuth:
|
|||
|
|
// OAuth accounts use ChatGPT internal API
|
|||
|
|
targetURL = chatgptCodexURL
|
|||
|
|
case AccountTypeAPIKey:
|
|||
|
|
// API Key accounts use Platform API or custom base URL
|
|||
|
|
baseURL := account.GetOpenAIBaseURL()
|
|||
|
|
if baseURL == "" {
|
|||
|
|
targetURL = openaiPlatformAPIURL
|
|||
|
|
} else {
|
|||
|
|
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
targetURL = buildOpenAIResponsesURL(validatedURL)
|
|||
|
|
}
|
|||
|
|
default:
|
|||
|
|
targetURL = openaiPlatformAPIURL
|
|||
|
|
}
|
|||
|
|
targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c))
|
|||
|
|
|
|||
|
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Set authentication header
|
|||
|
|
req.Header.Set("authorization", "Bearer "+token)
|
|||
|
|
|
|||
|
|
// Set headers specific to OAuth accounts (ChatGPT internal API)
|
|||
|
|
if account.Type == AccountTypeOAuth {
|
|||
|
|
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
|
|||
|
|
req.Host = "chatgpt.com"
|
|||
|
|
// Required: set chatgpt-account-id header
|
|||
|
|
chatgptAccountID := account.GetChatGPTAccountID()
|
|||
|
|
if chatgptAccountID != "" {
|
|||
|
|
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Whitelist passthrough headers
|
|||
|
|
for key, values := range c.Request.Header {
|
|||
|
|
lowerKey := strings.ToLower(key)
|
|||
|
|
if openaiAllowedHeaders[lowerKey] {
|
|||
|
|
for _, v := range values {
|
|||
|
|
req.Header.Add(key, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if account.Type == AccountTypeOAuth {
|
|||
|
|
// 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。
|
|||
|
|
req.Header.Del("conversation_id")
|
|||
|
|
req.Header.Del("session_id")
|
|||
|
|
|
|||
|
|
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
|||
|
|
req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
|
|||
|
|
apiKeyID := getAPIKeyIDFromContext(c)
|
|||
|
|
if isOpenAIResponsesCompactPath(c) {
|
|||
|
|
req.Header.Set("accept", "application/json")
|
|||
|
|
if req.Header.Get("version") == "" {
|
|||
|
|
req.Header.Set("version", codexCLIVersion)
|
|||
|
|
}
|
|||
|
|
compactSession := resolveOpenAICompactSessionID(c)
|
|||
|
|
req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, compactSession))
|
|||
|
|
} else {
|
|||
|
|
req.Header.Set("accept", "text/event-stream")
|
|||
|
|
}
|
|||
|
|
if promptCacheKey != "" {
|
|||
|
|
isolated := isolateOpenAISessionID(apiKeyID, promptCacheKey)
|
|||
|
|
req.Header.Set("conversation_id", isolated)
|
|||
|
|
req.Header.Set("session_id", isolated)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Apply custom User-Agent if configured
|
|||
|
|
customUA := account.GetOpenAIUserAgent()
|
|||
|
|
if customUA != "" {
|
|||
|
|
req.Header.Set("user-agent", customUA)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。
|
|||
|
|
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
|||
|
|
req.Header.Set("user-agent", codexCLIUserAgent)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Ensure required headers exist
|
|||
|
|
if req.Header.Get("content-type") == "" {
|
|||
|
|
req.Header.Set("content-type", "application/json")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return req, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) handleErrorResponse(
|
|||
|
|
ctx context.Context,
|
|||
|
|
resp *http.Response,
|
|||
|
|
c *gin.Context,
|
|||
|
|
account *Account,
|
|||
|
|
requestBody []byte,
|
|||
|
|
) (*OpenAIForwardResult, error) {
|
|||
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|||
|
|
|
|||
|
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
|||
|
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|||
|
|
upstreamDetail := ""
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|||
|
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|||
|
|
if maxBytes <= 0 {
|
|||
|
|
maxBytes = 2048
|
|||
|
|
}
|
|||
|
|
upstreamDetail = truncateString(string(body), maxBytes)
|
|||
|
|
}
|
|||
|
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
|||
|
|
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
|||
|
|
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway",
|
|||
|
|
"OpenAI upstream error %d (account=%d platform=%s type=%s): %s",
|
|||
|
|
resp.StatusCode,
|
|||
|
|
account.ID,
|
|||
|
|
account.Platform,
|
|||
|
|
account.Type,
|
|||
|
|
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
|||
|
|
)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if status, errType, errMsg, matched := applyErrorPassthroughRule(
|
|||
|
|
c,
|
|||
|
|
PlatformOpenAI,
|
|||
|
|
resp.StatusCode,
|
|||
|
|
body,
|
|||
|
|
http.StatusBadGateway,
|
|||
|
|
"upstream_error",
|
|||
|
|
"Upstream request failed",
|
|||
|
|
); matched {
|
|||
|
|
c.JSON(status, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": errType,
|
|||
|
|
"message": errMsg,
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
if upstreamMsg == "" {
|
|||
|
|
upstreamMsg = errMsg
|
|||
|
|
}
|
|||
|
|
if upstreamMsg == "" {
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
|
|||
|
|
}
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Check custom error codes
|
|||
|
|
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: resp.StatusCode,
|
|||
|
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|||
|
|
Kind: "http_error",
|
|||
|
|
Message: upstreamMsg,
|
|||
|
|
Detail: upstreamDetail,
|
|||
|
|
})
|
|||
|
|
c.JSON(http.StatusInternalServerError, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": "upstream_error",
|
|||
|
|
"message": "Upstream gateway error",
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
if upstreamMsg == "" {
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
|||
|
|
}
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Handle upstream error (mark account status)
|
|||
|
|
shouldDisable := false
|
|||
|
|
if s.rateLimitService != nil {
|
|||
|
|
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|||
|
|
}
|
|||
|
|
kind := "http_error"
|
|||
|
|
if shouldDisable {
|
|||
|
|
kind = "failover"
|
|||
|
|
}
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: resp.StatusCode,
|
|||
|
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|||
|
|
Kind: kind,
|
|||
|
|
Message: upstreamMsg,
|
|||
|
|
Detail: upstreamDetail,
|
|||
|
|
})
|
|||
|
|
if shouldDisable {
|
|||
|
|
return nil, &UpstreamFailoverError{
|
|||
|
|
StatusCode: resp.StatusCode,
|
|||
|
|
ResponseBody: body,
|
|||
|
|
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Return appropriate error response
|
|||
|
|
var errType, errMsg string
|
|||
|
|
var statusCode int
|
|||
|
|
|
|||
|
|
switch resp.StatusCode {
|
|||
|
|
case 401:
|
|||
|
|
statusCode = http.StatusBadGateway
|
|||
|
|
errType = "upstream_error"
|
|||
|
|
errMsg = "Upstream authentication failed, please contact administrator"
|
|||
|
|
case 402:
|
|||
|
|
statusCode = http.StatusBadGateway
|
|||
|
|
errType = "upstream_error"
|
|||
|
|
errMsg = "Upstream payment required: insufficient balance or billing issue"
|
|||
|
|
case 403:
|
|||
|
|
statusCode = http.StatusBadGateway
|
|||
|
|
errType = "upstream_error"
|
|||
|
|
errMsg = "Upstream access forbidden, please contact administrator"
|
|||
|
|
case 429:
|
|||
|
|
statusCode = http.StatusTooManyRequests
|
|||
|
|
errType = "rate_limit_error"
|
|||
|
|
errMsg = "Upstream rate limit exceeded, please retry later"
|
|||
|
|
default:
|
|||
|
|
statusCode = http.StatusBadGateway
|
|||
|
|
errType = "upstream_error"
|
|||
|
|
errMsg = "Upstream request failed"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
c.JSON(statusCode, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": errType,
|
|||
|
|
"message": errMsg,
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
if upstreamMsg == "" {
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
|||
|
|
}
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// compatErrorWriter is the signature for format-specific error writers used by
|
|||
|
|
// the compat paths (Chat Completions and Anthropic Messages).
|
|||
|
|
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
|
|||
|
|
|
|||
|
|
// handleCompatErrorResponse is the shared non-failover error handler for the
|
|||
|
|
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
|
|||
|
|
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
|
|||
|
|
// tracking, secondary failover) but delegates the final error write to the
|
|||
|
|
// format-specific writer function.
|
|||
|
|
func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
|||
|
|
resp *http.Response,
|
|||
|
|
c *gin.Context,
|
|||
|
|
account *Account,
|
|||
|
|
writeError compatErrorWriter,
|
|||
|
|
) (*OpenAIForwardResult, error) {
|
|||
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|||
|
|
|
|||
|
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
|||
|
|
if upstreamMsg == "" {
|
|||
|
|
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
|
|||
|
|
}
|
|||
|
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|||
|
|
|
|||
|
|
upstreamDetail := ""
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|||
|
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|||
|
|
if maxBytes <= 0 {
|
|||
|
|
maxBytes = 2048
|
|||
|
|
}
|
|||
|
|
upstreamDetail = truncateString(string(body), maxBytes)
|
|||
|
|
}
|
|||
|
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
|||
|
|
|
|||
|
|
// Apply error passthrough rules
|
|||
|
|
if status, errType, errMsg, matched := applyErrorPassthroughRule(
|
|||
|
|
c, account.Platform, resp.StatusCode, body,
|
|||
|
|
http.StatusBadGateway, "api_error", "Upstream request failed",
|
|||
|
|
); matched {
|
|||
|
|
writeError(c, status, errType, errMsg)
|
|||
|
|
if upstreamMsg == "" {
|
|||
|
|
upstreamMsg = errMsg
|
|||
|
|
}
|
|||
|
|
if upstreamMsg == "" {
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
|
|||
|
|
}
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Check custom error codes — if the account does not handle this status,
|
|||
|
|
// return a generic error without exposing upstream details.
|
|||
|
|
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: resp.StatusCode,
|
|||
|
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|||
|
|
Kind: "http_error",
|
|||
|
|
Message: upstreamMsg,
|
|||
|
|
Detail: upstreamDetail,
|
|||
|
|
})
|
|||
|
|
writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
|
|||
|
|
if upstreamMsg == "" {
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
|||
|
|
}
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Track rate limits and decide whether to trigger secondary failover.
|
|||
|
|
shouldDisable := false
|
|||
|
|
if s.rateLimitService != nil {
|
|||
|
|
shouldDisable = s.rateLimitService.HandleUpstreamError(
|
|||
|
|
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
|||
|
|
)
|
|||
|
|
}
|
|||
|
|
kind := "http_error"
|
|||
|
|
if shouldDisable {
|
|||
|
|
kind = "failover"
|
|||
|
|
}
|
|||
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|||
|
|
Platform: account.Platform,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
AccountName: account.Name,
|
|||
|
|
UpstreamStatusCode: resp.StatusCode,
|
|||
|
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|||
|
|
Kind: kind,
|
|||
|
|
Message: upstreamMsg,
|
|||
|
|
Detail: upstreamDetail,
|
|||
|
|
})
|
|||
|
|
if shouldDisable {
|
|||
|
|
return nil, &UpstreamFailoverError{
|
|||
|
|
StatusCode: resp.StatusCode,
|
|||
|
|
ResponseBody: body,
|
|||
|
|
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Map status code to error type and write response
|
|||
|
|
errType := "api_error"
|
|||
|
|
switch {
|
|||
|
|
case resp.StatusCode == 400:
|
|||
|
|
errType = "invalid_request_error"
|
|||
|
|
case resp.StatusCode == 404:
|
|||
|
|
errType = "not_found_error"
|
|||
|
|
case resp.StatusCode == 429:
|
|||
|
|
errType = "rate_limit_error"
|
|||
|
|
case resp.StatusCode >= 500:
|
|||
|
|
errType = "api_error"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
writeError(c, resp.StatusCode, errType, upstreamMsg)
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// openaiStreamingResult streaming response result
|
|||
|
|
type openaiStreamingResult struct {
|
|||
|
|
usage *OpenAIUsage
|
|||
|
|
firstTokenMs *int
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
|||
|
|
if s.responseHeaderFilter != nil {
|
|||
|
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Set SSE response headers
|
|||
|
|
c.Header("Content-Type", "text/event-stream")
|
|||
|
|
c.Header("Cache-Control", "no-cache")
|
|||
|
|
c.Header("Connection", "keep-alive")
|
|||
|
|
c.Header("X-Accel-Buffering", "no")
|
|||
|
|
|
|||
|
|
// Pass through other headers
|
|||
|
|
if v := resp.Header.Get("x-request-id"); v != "" {
|
|||
|
|
c.Header("x-request-id", v)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
w := c.Writer
|
|||
|
|
flusher, ok := w.(http.Flusher)
|
|||
|
|
if !ok {
|
|||
|
|
return nil, errors.New("streaming not supported")
|
|||
|
|
}
|
|||
|
|
bufferedWriter := bufio.NewWriterSize(w, 4*1024)
|
|||
|
|
flushBuffered := func() error {
|
|||
|
|
if err := bufferedWriter.Flush(); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
flusher.Flush()
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
usage := &OpenAIUsage{}
|
|||
|
|
var firstTokenMs *int
|
|||
|
|
scanner := bufio.NewScanner(resp.Body)
|
|||
|
|
maxLineSize := defaultMaxLineSize
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
|||
|
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
|||
|
|
}
|
|||
|
|
scanBuf := getSSEScannerBuf64K()
|
|||
|
|
scanner.Buffer(scanBuf[:0], maxLineSize)
|
|||
|
|
|
|||
|
|
streamInterval := time.Duration(0)
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
|||
|
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
|||
|
|
}
|
|||
|
|
// 仅监控上游数据间隔超时,不被下游写入阻塞影响
|
|||
|
|
var intervalTicker *time.Ticker
|
|||
|
|
if streamInterval > 0 {
|
|||
|
|
intervalTicker = time.NewTicker(streamInterval)
|
|||
|
|
defer intervalTicker.Stop()
|
|||
|
|
}
|
|||
|
|
var intervalCh <-chan time.Time
|
|||
|
|
if intervalTicker != nil {
|
|||
|
|
intervalCh = intervalTicker.C
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
keepaliveInterval := time.Duration(0)
|
|||
|
|
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
|||
|
|
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
|||
|
|
}
|
|||
|
|
// 下游 keepalive 仅用于防止代理空闲断开
|
|||
|
|
var keepaliveTicker *time.Ticker
|
|||
|
|
if keepaliveInterval > 0 {
|
|||
|
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
|||
|
|
defer keepaliveTicker.Stop()
|
|||
|
|
}
|
|||
|
|
var keepaliveCh <-chan time.Time
|
|||
|
|
if keepaliveTicker != nil {
|
|||
|
|
keepaliveCh = keepaliveTicker.C
|
|||
|
|
}
|
|||
|
|
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
|||
|
|
lastDataAt := time.Now()
|
|||
|
|
|
|||
|
|
// 仅发送一次错误事件,避免多次写入导致协议混乱。
|
|||
|
|
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
|
|||
|
|
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
|
|||
|
|
errorEventSent := false
|
|||
|
|
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
|||
|
|
sawTerminalEvent := false
|
|||
|
|
sendErrorEvent := func(reason string) {
|
|||
|
|
if errorEventSent || clientDisconnected {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
errorEventSent = true
|
|||
|
|
payload := `{"type":"error","sequence_number":0,"error":{"type":"upstream_error","message":` + strconv.Quote(reason) + `,"code":` + strconv.Quote(reason) + `}}`
|
|||
|
|
if err := flushBuffered(); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if _, err := bufferedWriter.WriteString("data: " + payload + "\n\n"); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if err := flushBuffered(); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
needModelReplace := originalModel != mappedModel
|
|||
|
|
resultWithUsage := func() *openaiStreamingResult {
|
|||
|
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
|
|||
|
|
}
|
|||
|
|
finalizeStream := func() (*openaiStreamingResult, error) {
|
|||
|
|
if !clientDisconnected {
|
|||
|
|
if err := flushBuffered(); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if !sawTerminalEvent {
|
|||
|
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
|||
|
|
}
|
|||
|
|
return resultWithUsage(), nil
|
|||
|
|
}
|
|||
|
|
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
|
|||
|
|
if scanErr == nil {
|
|||
|
|
return nil, nil, false
|
|||
|
|
}
|
|||
|
|
if sawTerminalEvent {
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
|
|||
|
|
return resultWithUsage(), nil, true
|
|||
|
|
}
|
|||
|
|
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
|||
|
|
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
|||
|
|
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
|
|||
|
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
|
|||
|
|
}
|
|||
|
|
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
|||
|
|
if clientDisconnected {
|
|||
|
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
|
|||
|
|
}
|
|||
|
|
if errors.Is(scanErr, bufio.ErrTooLong) {
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
|
|||
|
|
sendErrorEvent("response_too_large")
|
|||
|
|
return resultWithUsage(), scanErr, true
|
|||
|
|
}
|
|||
|
|
sendErrorEvent("stream_read_error")
|
|||
|
|
return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true
|
|||
|
|
}
|
|||
|
|
processSSELine := func(line string, queueDrained bool) {
|
|||
|
|
lastDataAt = time.Now()
|
|||
|
|
|
|||
|
|
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
|||
|
|
if data, ok := extractOpenAISSEDataLine(line); ok {
|
|||
|
|
|
|||
|
|
// Replace model in response if needed.
|
|||
|
|
// Fast path: most events do not contain model field values.
|
|||
|
|
if needModelReplace && mappedModel != "" && strings.Contains(data, mappedModel) {
|
|||
|
|
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
dataBytes := []byte(data)
|
|||
|
|
if openAIStreamEventIsTerminal(data) {
|
|||
|
|
sawTerminalEvent = true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
|||
|
|
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
|||
|
|
dataBytes = correctedData
|
|||
|
|
data = string(correctedData)
|
|||
|
|
line = "data: " + data
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 写入客户端(客户端断开后继续 drain 上游)
|
|||
|
|
if !clientDisconnected {
|
|||
|
|
shouldFlush := queueDrained
|
|||
|
|
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
|||
|
|
// 保证首个 token 事件尽快出站,避免影响 TTFT。
|
|||
|
|
shouldFlush = true
|
|||
|
|
}
|
|||
|
|
if _, err := bufferedWriter.WriteString(line); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
|||
|
|
} else if _, err := bufferedWriter.WriteString("\n"); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
|||
|
|
} else if shouldFlush {
|
|||
|
|
if err := flushBuffered(); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Record first token time
|
|||
|
|
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
|||
|
|
ms := int(time.Since(startTime).Milliseconds())
|
|||
|
|
firstTokenMs = &ms
|
|||
|
|
}
|
|||
|
|
s.parseSSEUsageBytes(dataBytes, usage)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Forward non-data lines as-is
|
|||
|
|
if !clientDisconnected {
|
|||
|
|
if _, err := bufferedWriter.WriteString(line); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
|||
|
|
} else if _, err := bufferedWriter.WriteString("\n"); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
|||
|
|
} else if queueDrained {
|
|||
|
|
if err := flushBuffered(); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 无超时/无 keepalive 的常见路径走同步扫描,减少 goroutine 与 channel 开销。
|
|||
|
|
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
|||
|
|
defer putSSEScannerBuf64K(scanBuf)
|
|||
|
|
for scanner.Scan() {
|
|||
|
|
processSSELine(scanner.Text(), true)
|
|||
|
|
}
|
|||
|
|
if result, err, done := handleScanErr(scanner.Err()); done {
|
|||
|
|
return result, err
|
|||
|
|
}
|
|||
|
|
return finalizeStream()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
type scanEvent struct {
|
|||
|
|
line string
|
|||
|
|
err error
|
|||
|
|
}
|
|||
|
|
// 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
|
|||
|
|
events := make(chan scanEvent, 16)
|
|||
|
|
done := make(chan struct{})
|
|||
|
|
sendEvent := func(ev scanEvent) bool {
|
|||
|
|
select {
|
|||
|
|
case events <- ev:
|
|||
|
|
return true
|
|||
|
|
case <-done:
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
var lastReadAt int64
|
|||
|
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
|||
|
|
go func(scanBuf *sseScannerBuf64K) {
|
|||
|
|
defer putSSEScannerBuf64K(scanBuf)
|
|||
|
|
defer close(events)
|
|||
|
|
for scanner.Scan() {
|
|||
|
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
|||
|
|
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if err := scanner.Err(); err != nil {
|
|||
|
|
_ = sendEvent(scanEvent{err: err})
|
|||
|
|
}
|
|||
|
|
}(scanBuf)
|
|||
|
|
defer close(done)
|
|||
|
|
|
|||
|
|
for {
|
|||
|
|
select {
|
|||
|
|
case ev, ok := <-events:
|
|||
|
|
if !ok {
|
|||
|
|
return finalizeStream()
|
|||
|
|
}
|
|||
|
|
if result, err, done := handleScanErr(ev.err); done {
|
|||
|
|
return result, err
|
|||
|
|
}
|
|||
|
|
processSSELine(ev.line, len(events) == 0)
|
|||
|
|
|
|||
|
|
case <-intervalCh:
|
|||
|
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
|||
|
|
if time.Since(lastRead) < streamInterval {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if clientDisconnected {
|
|||
|
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
|||
|
|
}
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
|||
|
|
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
|||
|
|
if s.rateLimitService != nil {
|
|||
|
|
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
|||
|
|
}
|
|||
|
|
sendErrorEvent("stream_timeout")
|
|||
|
|
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
|||
|
|
|
|||
|
|
case <-keepaliveCh:
|
|||
|
|
if clientDisconnected {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if time.Since(lastDataAt) < keepaliveInterval {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if err := flushBuffered(); err != nil {
|
|||
|
|
clientDisconnected = true
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// extractOpenAISSEDataLine extracts the payload of an SSE `data:` line with
// minimal overhead.
//
// Both `data: xxx` and `data:xxx` are accepted; any run of spaces and tabs that
// follows the colon is stripped. The boolean reports whether line is a data
// line at all (it must start exactly with "data:").
func extractOpenAISSEDataLine(line string) (string, bool) {
	const prefix = "data:"
	if !strings.HasPrefix(line, prefix) {
		return "", false
	}
	// Strip both spaces and tabs; comparing against ' ' twice left tabs in place.
	return strings.TrimLeft(line[len(prefix):], " \t"), true
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
|||
|
|
data, ok := extractOpenAISSEDataLine(line)
|
|||
|
|
if !ok {
|
|||
|
|
return line
|
|||
|
|
}
|
|||
|
|
if data == "" || data == "[DONE]" {
|
|||
|
|
return line
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
|
|||
|
|
if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
|
|||
|
|
newData, err := sjson.Set(data, "model", toModel)
|
|||
|
|
if err != nil {
|
|||
|
|
return line
|
|||
|
|
}
|
|||
|
|
return "data: " + newData
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查嵌套的 response.model 字段
|
|||
|
|
if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
|
|||
|
|
newData, err := sjson.Set(data, "response.model", toModel)
|
|||
|
|
if err != nil {
|
|||
|
|
return line
|
|||
|
|
}
|
|||
|
|
return "data: " + newData
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return line
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// correctToolCallsInResponseBody 修正响应体中的工具调用
|
|||
|
|
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
|
|||
|
|
if len(body) == 0 {
|
|||
|
|
return body
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(body)
|
|||
|
|
if changed {
|
|||
|
|
return corrected
|
|||
|
|
}
|
|||
|
|
return body
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
|||
|
|
s.parseSSEUsageBytes([]byte(data), usage)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) {
|
|||
|
|
if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
// 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
|
|||
|
|
if len(data) < 72 {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
eventType := gjson.GetBytes(data, "type").String()
|
|||
|
|
if eventType != "response.completed" && eventType != "response.done" {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
|
|||
|
|
usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
|
|||
|
|
usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
|||
|
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
|||
|
|
return OpenAIUsage{}, false
|
|||
|
|
}
|
|||
|
|
values := gjson.GetManyBytes(
|
|||
|
|
body,
|
|||
|
|
"usage.input_tokens",
|
|||
|
|
"usage.output_tokens",
|
|||
|
|
"usage.input_tokens_details.cached_tokens",
|
|||
|
|
)
|
|||
|
|
return OpenAIUsage{
|
|||
|
|
InputTokens: int(values[0].Int()),
|
|||
|
|
OutputTokens: int(values[1].Int()),
|
|||
|
|
CacheReadInputTokens: int(values[2].Int()),
|
|||
|
|
}, true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
|||
|
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
|||
|
|
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
|||
|
|
if err != nil {
|
|||
|
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
|||
|
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
|||
|
|
c.JSON(http.StatusBadGateway, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": "upstream_error",
|
|||
|
|
"message": "Upstream response too large",
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if account.Type == AccountTypeOAuth {
|
|||
|
|
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
|
|||
|
|
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
|
|||
|
|
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
usageValue, usageOK := extractOpenAIUsageFromJSONBytes(body)
|
|||
|
|
if !usageOK {
|
|||
|
|
return nil, fmt.Errorf("parse response: invalid json response")
|
|||
|
|
}
|
|||
|
|
usage := &usageValue
|
|||
|
|
|
|||
|
|
// Replace model in response if needed
|
|||
|
|
if originalModel != mappedModel {
|
|||
|
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
|||
|
|
|
|||
|
|
contentType := "application/json"
|
|||
|
|
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
|
|||
|
|
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
|
|||
|
|
contentType = upstreamType
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
c.Data(resp.StatusCode, contentType, body)
|
|||
|
|
|
|||
|
|
return usage, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// isEventStreamResponse reports whether the Content-Type header contains
// "text/event-stream", compared case-insensitively.
func isEventStreamResponse(header http.Header) bool {
	mediaType := header.Get("Content-Type")
	return strings.Contains(strings.ToLower(mediaType), "text/event-stream")
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
|||
|
|
bodyText := string(body)
|
|||
|
|
finalResponse, ok := extractCodexFinalResponse(bodyText)
|
|||
|
|
|
|||
|
|
usage := &OpenAIUsage{}
|
|||
|
|
if ok {
|
|||
|
|
if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed {
|
|||
|
|
*usage = parsedUsage
|
|||
|
|
}
|
|||
|
|
body = finalResponse
|
|||
|
|
if originalModel != mappedModel {
|
|||
|
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
|||
|
|
}
|
|||
|
|
// Correct tool calls in final response
|
|||
|
|
body = s.correctToolCallsInResponseBody(body)
|
|||
|
|
} else {
|
|||
|
|
terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText)
|
|||
|
|
if terminalOK && terminalType == "response.failed" {
|
|||
|
|
msg := extractOpenAISSEErrorMessage(terminalPayload)
|
|||
|
|
if msg == "" {
|
|||
|
|
msg = "Upstream compact response failed"
|
|||
|
|
}
|
|||
|
|
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
|
|||
|
|
}
|
|||
|
|
usage = s.parseSSEUsageFromBody(bodyText)
|
|||
|
|
if originalModel != mappedModel {
|
|||
|
|
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
|
|||
|
|
}
|
|||
|
|
body = []byte(bodyText)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
|||
|
|
|
|||
|
|
contentType := "application/json; charset=utf-8"
|
|||
|
|
if !ok {
|
|||
|
|
contentType = resp.Header.Get("Content-Type")
|
|||
|
|
if contentType == "" {
|
|||
|
|
contentType = "text/event-stream"
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
c.Data(resp.StatusCode, contentType, body)
|
|||
|
|
|
|||
|
|
return usage, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
|
|||
|
|
lines := strings.Split(body, "\n")
|
|||
|
|
for _, line := range lines {
|
|||
|
|
data, ok := extractOpenAISSEDataLine(line)
|
|||
|
|
if !ok || data == "" || data == "[DONE]" {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
|
|||
|
|
switch eventType {
|
|||
|
|
case "response.completed", "response.done", "response.failed":
|
|||
|
|
return eventType, []byte(data), true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return "", nil, false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractOpenAISSEErrorMessage(payload []byte) string {
|
|||
|
|
if len(payload) == 0 {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
for _, path := range []string{"response.error.message", "error.message", "message"} {
|
|||
|
|
if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" {
|
|||
|
|
return sanitizeUpstreamErrorMessage(msg)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload)))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error {
|
|||
|
|
message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
|
|||
|
|
if message == "" {
|
|||
|
|
message = "Upstream returned an invalid non-streaming response"
|
|||
|
|
}
|
|||
|
|
setOpsUpstreamError(c, http.StatusBadGateway, message, "")
|
|||
|
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
|||
|
|
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|||
|
|
c.JSON(http.StatusBadGateway, gin.H{
|
|||
|
|
"error": gin.H{
|
|||
|
|
"type": "upstream_error",
|
|||
|
|
"message": message,
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
return fmt.Errorf("non-streaming openai protocol error: %s", message)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
|||
|
|
lines := strings.Split(body, "\n")
|
|||
|
|
for _, line := range lines {
|
|||
|
|
data, ok := extractOpenAISSEDataLine(line)
|
|||
|
|
if !ok {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if data == "" || data == "[DONE]" {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
eventType := gjson.Get(data, "type").String()
|
|||
|
|
if eventType == "response.done" || eventType == "response.completed" {
|
|||
|
|
if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
|
|||
|
|
return []byte(response.Raw), true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return nil, false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
|||
|
|
usage := &OpenAIUsage{}
|
|||
|
|
lines := strings.Split(body, "\n")
|
|||
|
|
for _, line := range lines {
|
|||
|
|
data, ok := extractOpenAISSEDataLine(line)
|
|||
|
|
if !ok {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if data == "" || data == "[DONE]" {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
s.parseSSEUsageBytes([]byte(data), usage)
|
|||
|
|
}
|
|||
|
|
return usage
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
|
|||
|
|
lines := strings.Split(body, "\n")
|
|||
|
|
for i, line := range lines {
|
|||
|
|
if _, ok := extractOpenAISSEDataLine(line); !ok {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
|
|||
|
|
}
|
|||
|
|
return strings.Join(lines, "\n")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
|||
|
|
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
|||
|
|
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
|||
|
|
if err != nil {
|
|||
|
|
return "", fmt.Errorf("invalid base_url: %w", err)
|
|||
|
|
}
|
|||
|
|
return normalized, nil
|
|||
|
|
}
|
|||
|
|
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
|||
|
|
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
|||
|
|
RequireAllowlist: true,
|
|||
|
|
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
|||
|
|
})
|
|||
|
|
if err != nil {
|
|||
|
|
return "", fmt.Errorf("invalid base_url: %w", err)
|
|||
|
|
}
|
|||
|
|
return normalized, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// buildOpenAIResponsesURL derives the OpenAI Responses endpoint from a base
// URL. Surrounding whitespace and trailing slashes are removed first, then:
//   - a base already ending in /responses is returned as is;
//   - a base ending in /v1 gets /responses appended;
//   - anything else gets /v1/responses appended.
func buildOpenAIResponsesURL(base string) string {
	trimmed := strings.TrimSpace(base)
	trimmed = strings.TrimRight(trimmed, "/")

	switch {
	case strings.HasSuffix(trimmed, "/responses"):
		return trimmed
	case strings.HasSuffix(trimmed, "/v1"):
		return trimmed + "/responses"
	default:
		return trimmed + "/v1/responses"
	}
}
|
|||
|
|
|
|||
|
|
func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool {
|
|||
|
|
if len(reqBody) == 0 {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
inputValue, has := reqBody["input"]
|
|||
|
|
if !has {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
switch input := inputValue.(type) {
|
|||
|
|
case []any:
|
|||
|
|
filtered := input[:0]
|
|||
|
|
changed := false
|
|||
|
|
for _, item := range input {
|
|||
|
|
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
|
|||
|
|
if itemChanged {
|
|||
|
|
changed = true
|
|||
|
|
}
|
|||
|
|
if !keep {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
filtered = append(filtered, nextItem)
|
|||
|
|
}
|
|||
|
|
if !changed {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if len(filtered) == 0 {
|
|||
|
|
delete(reqBody, "input")
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
reqBody["input"] = filtered
|
|||
|
|
return true
|
|||
|
|
case []map[string]any:
|
|||
|
|
filtered := input[:0]
|
|||
|
|
changed := false
|
|||
|
|
for _, item := range input {
|
|||
|
|
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
|
|||
|
|
if itemChanged {
|
|||
|
|
changed = true
|
|||
|
|
}
|
|||
|
|
if !keep {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
nextMap, ok := nextItem.(map[string]any)
|
|||
|
|
if !ok {
|
|||
|
|
filtered = append(filtered, item)
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
filtered = append(filtered, nextMap)
|
|||
|
|
}
|
|||
|
|
if !changed {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if len(filtered) == 0 {
|
|||
|
|
delete(reqBody, "input")
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
reqBody["input"] = filtered
|
|||
|
|
return true
|
|||
|
|
case map[string]any:
|
|||
|
|
nextItem, changed, keep := sanitizeEncryptedReasoningInputItem(input)
|
|||
|
|
if !changed {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if !keep {
|
|||
|
|
delete(reqBody, "input")
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
nextMap, ok := nextItem.(map[string]any)
|
|||
|
|
if !ok {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
reqBody["input"] = nextMap
|
|||
|
|
return true
|
|||
|
|
default:
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// sanitizeEncryptedReasoningInputItem strips encrypted_content from a single
// reasoning input item.
//
// Items that are not maps, are not of type "reasoning" (after trimming
// whitespace), or have no encrypted_content are returned unchanged with
// keep=true. Otherwise the key is deleted from the map in place and
// changed=true. If only the "type" key remains, the item is reported as
// next=nil, keep=false so the caller drops it.
func sanitizeEncryptedReasoningInputItem(item any) (next any, changed bool, keep bool) {
	fields, isMap := item.(map[string]any)
	kind := ""
	if isMap {
		kind, _ = fields["type"].(string)
	}
	// Reading from a nil map is safe and simply reports the key as absent.
	_, hasEncrypted := fields["encrypted_content"]
	if !isMap || strings.TrimSpace(kind) != "reasoning" || !hasEncrypted {
		return item, false, true
	}

	delete(fields, "encrypted_content")
	// With encrypted_content gone, an item holding only its "type" carries
	// nothing worth sending.
	if len(fields) > 1 {
		return fields, true, true
	}
	return nil, true, false
}
|
|||
|
|
|
|||
|
|
func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool {
|
|||
|
|
return isOpenAIResponsesCompactPath(c)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func OpenAICompactSessionSeedKeyForTest() string {
|
|||
|
|
return openAICompactSessionSeedKey
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func NormalizeOpenAICompactRequestBodyForTest(body []byte) ([]byte, bool, error) {
|
|||
|
|
return normalizeOpenAICompactRequestBody(body)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func isOpenAIResponsesCompactPath(c *gin.Context) bool {
|
|||
|
|
suffix := strings.TrimSpace(openAIResponsesRequestPathSuffix(c))
|
|||
|
|
return suffix == "/compact" || strings.HasPrefix(suffix, "/compact/")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) {
|
|||
|
|
if len(body) == 0 {
|
|||
|
|
return body, false, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
normalized := []byte(`{}`)
|
|||
|
|
for _, field := range []string{"model", "input", "instructions", "previous_response_id"} {
|
|||
|
|
value := gjson.GetBytes(body, field)
|
|||
|
|
if !value.Exists() {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
next, err := sjson.SetRawBytes(normalized, field, []byte(value.Raw))
|
|||
|
|
if err != nil {
|
|||
|
|
return body, false, fmt.Errorf("normalize compact body %s: %w", field, err)
|
|||
|
|
}
|
|||
|
|
normalized = next
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if bytes.Equal(bytes.TrimSpace(body), bytes.TrimSpace(normalized)) {
|
|||
|
|
return body, false, nil
|
|||
|
|
}
|
|||
|
|
return normalized, true, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func resolveOpenAICompactSessionID(c *gin.Context) string {
|
|||
|
|
if c != nil {
|
|||
|
|
if sessionID := strings.TrimSpace(c.GetHeader("session_id")); sessionID != "" {
|
|||
|
|
return sessionID
|
|||
|
|
}
|
|||
|
|
if conversationID := strings.TrimSpace(c.GetHeader("conversation_id")); conversationID != "" {
|
|||
|
|
return conversationID
|
|||
|
|
}
|
|||
|
|
if seed, ok := c.Get(openAICompactSessionSeedKey); ok {
|
|||
|
|
if seedStr, ok := seed.(string); ok && strings.TrimSpace(seedStr) != "" {
|
|||
|
|
return strings.TrimSpace(seedStr)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return uuid.NewString()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func openAIResponsesRequestPathSuffix(c *gin.Context) string {
|
|||
|
|
if c == nil || c.Request == nil || c.Request.URL == nil {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
|||
|
|
if normalizedPath == "" {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
idx := strings.LastIndex(normalizedPath, "/responses")
|
|||
|
|
if idx < 0 {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
suffix := normalizedPath[idx+len("/responses"):]
|
|||
|
|
if suffix == "" || suffix == "/" {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
if !strings.HasPrefix(suffix, "/") {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
return suffix
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// appendOpenAIResponsesRequestPathSuffix appends a request path suffix (such as
// "/compact") to a Responses endpoint URL.
//
// The base is whitespace-trimmed and stripped of trailing slashes, and the
// suffix is whitespace-trimmed. If either ends up empty, the trimmed base is
// returned alone.
func appendOpenAIResponsesRequestPathSuffix(baseURL, suffix string) string {
	base := strings.TrimRight(strings.TrimSpace(baseURL), "/")
	extra := strings.TrimSpace(suffix)
	if base != "" && extra != "" {
		return base + extra
	}
	return base
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
|||
|
|
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
|
|||
|
|
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
|
|||
|
|
newBody, err := sjson.SetBytes(body, "model", toModel)
|
|||
|
|
if err != nil {
|
|||
|
|
return body
|
|||
|
|
}
|
|||
|
|
return newBody
|
|||
|
|
}
|
|||
|
|
return body
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// OpenAIRecordUsageInput input for recording usage
|
|||
|
|
type OpenAIRecordUsageInput struct {
|
|||
|
|
Result *OpenAIForwardResult
|
|||
|
|
APIKey *APIKey
|
|||
|
|
User *User
|
|||
|
|
Account *Account
|
|||
|
|
Subscription *UserSubscription
|
|||
|
|
InboundEndpoint string
|
|||
|
|
UpstreamEndpoint string
|
|||
|
|
UserAgent string // 请求的 User-Agent
|
|||
|
|
IPAddress string // 请求的客户端 IP 地址
|
|||
|
|
RequestPayloadHash string
|
|||
|
|
APIKeyService APIKeyQuotaUpdater
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RecordUsage records usage and deducts balance
|
|||
|
|
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
|||
|
|
result := input.Result
|
|||
|
|
|
|||
|
|
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
|||
|
|
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
|||
|
|
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
apiKey := input.APIKey
|
|||
|
|
user := input.User
|
|||
|
|
account := input.Account
|
|||
|
|
subscription := input.Subscription
|
|||
|
|
|
|||
|
|
// 计算实际的新输入token(减去缓存读取的token)
|
|||
|
|
// 因为 input_tokens 包含了 cache_read_tokens,而缓存读取的token不应按输入价格计费
|
|||
|
|
actualInputTokens := result.Usage.InputTokens - result.Usage.CacheReadInputTokens
|
|||
|
|
if actualInputTokens < 0 {
|
|||
|
|
actualInputTokens = 0
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Calculate cost
|
|||
|
|
tokens := UsageTokens{
|
|||
|
|
InputTokens: actualInputTokens,
|
|||
|
|
OutputTokens: result.Usage.OutputTokens,
|
|||
|
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
|||
|
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Get rate multiplier
|
|||
|
|
multiplier := s.cfg.Default.RateMultiplier
|
|||
|
|
if apiKey.GroupID != nil && apiKey.Group != nil {
|
|||
|
|
resolver := s.userGroupRateResolver
|
|||
|
|
if resolver == nil {
|
|||
|
|
resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway")
|
|||
|
|
}
|
|||
|
|
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
billingModel := result.Model
|
|||
|
|
if result.BillingModel != "" {
|
|||
|
|
billingModel = result.BillingModel
|
|||
|
|
}
|
|||
|
|
serviceTier := ""
|
|||
|
|
if result.ServiceTier != nil {
|
|||
|
|
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
|||
|
|
}
|
|||
|
|
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
|||
|
|
if err != nil {
|
|||
|
|
cost = &CostBreakdown{ActualCost: 0}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Determine billing type
|
|||
|
|
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
|||
|
|
billingType := BillingTypeBalance
|
|||
|
|
if isSubscriptionBilling {
|
|||
|
|
billingType = BillingTypeSubscription
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create usage log
|
|||
|
|
durationMs := int(result.Duration.Milliseconds())
|
|||
|
|
accountRateMultiplier := account.BillingRateMultiplier()
|
|||
|
|
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
|||
|
|
usageLog := &UsageLog{
|
|||
|
|
UserID: user.ID,
|
|||
|
|
APIKeyID: apiKey.ID,
|
|||
|
|
AccountID: account.ID,
|
|||
|
|
RequestID: requestID,
|
|||
|
|
Model: billingModel,
|
|||
|
|
ServiceTier: result.ServiceTier,
|
|||
|
|
ReasoningEffort: result.ReasoningEffort,
|
|||
|
|
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
|||
|
|
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
|||
|
|
InputTokens: actualInputTokens,
|
|||
|
|
OutputTokens: result.Usage.OutputTokens,
|
|||
|
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
|||
|
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
|||
|
|
InputCost: cost.InputCost,
|
|||
|
|
OutputCost: cost.OutputCost,
|
|||
|
|
CacheCreationCost: cost.CacheCreationCost,
|
|||
|
|
CacheReadCost: cost.CacheReadCost,
|
|||
|
|
TotalCost: cost.TotalCost,
|
|||
|
|
ActualCost: cost.ActualCost,
|
|||
|
|
RateMultiplier: multiplier,
|
|||
|
|
AccountRateMultiplier: &accountRateMultiplier,
|
|||
|
|
BillingType: billingType,
|
|||
|
|
Stream: result.Stream,
|
|||
|
|
OpenAIWSMode: result.OpenAIWSMode,
|
|||
|
|
DurationMs: &durationMs,
|
|||
|
|
FirstTokenMs: result.FirstTokenMs,
|
|||
|
|
CreatedAt: time.Now(),
|
|||
|
|
}
|
|||
|
|
// 添加 UserAgent
|
|||
|
|
if input.UserAgent != "" {
|
|||
|
|
usageLog.UserAgent = &input.UserAgent
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 添加 IPAddress
|
|||
|
|
if input.IPAddress != "" {
|
|||
|
|
usageLog.IPAddress = &input.IPAddress
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if apiKey.GroupID != nil {
|
|||
|
|
usageLog.GroupID = apiKey.GroupID
|
|||
|
|
}
|
|||
|
|
if subscription != nil {
|
|||
|
|
usageLog.SubscriptionID = &subscription.ID
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
|||
|
|
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
|||
|
|
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
|||
|
|
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
billingErr := func() error {
|
|||
|
|
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
|||
|
|
Cost: cost,
|
|||
|
|
User: user,
|
|||
|
|
APIKey: apiKey,
|
|||
|
|
Account: account,
|
|||
|
|
Subscription: subscription,
|
|||
|
|
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
|||
|
|
IsSubscriptionBill: isSubscriptionBilling,
|
|||
|
|
AccountRateMultiplier: accountRateMultiplier,
|
|||
|
|
APIKeyService: input.APIKeyService,
|
|||
|
|
}, s.billingDeps(), s.usageBillingRepo)
|
|||
|
|
return err
|
|||
|
|
}()
|
|||
|
|
|
|||
|
|
if billingErr != nil {
|
|||
|
|
return billingErr
|
|||
|
|
}
|
|||
|
|
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
|||
|
|
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
|||
|
|
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
|||
|
|
snapshot := &OpenAICodexUsageSnapshot{}
|
|||
|
|
hasData := false
|
|||
|
|
|
|||
|
|
// Helper to parse float64 from header
|
|||
|
|
parseFloat := func(key string) *float64 {
|
|||
|
|
if v := headers.Get(key); v != "" {
|
|||
|
|
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
|||
|
|
return &f
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Helper to parse int from header
|
|||
|
|
parseInt := func(key string) *int {
|
|||
|
|
if v := headers.Get(key); v != "" {
|
|||
|
|
if i, err := strconv.Atoi(v); err == nil {
|
|||
|
|
return &i
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Primary (weekly) limits
|
|||
|
|
if v := parseFloat("x-codex-primary-used-percent"); v != nil {
|
|||
|
|
snapshot.PrimaryUsedPercent = v
|
|||
|
|
hasData = true
|
|||
|
|
}
|
|||
|
|
if v := parseInt("x-codex-primary-reset-after-seconds"); v != nil {
|
|||
|
|
snapshot.PrimaryResetAfterSeconds = v
|
|||
|
|
hasData = true
|
|||
|
|
}
|
|||
|
|
if v := parseInt("x-codex-primary-window-minutes"); v != nil {
|
|||
|
|
snapshot.PrimaryWindowMinutes = v
|
|||
|
|
hasData = true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Secondary (5h) limits
|
|||
|
|
if v := parseFloat("x-codex-secondary-used-percent"); v != nil {
|
|||
|
|
snapshot.SecondaryUsedPercent = v
|
|||
|
|
hasData = true
|
|||
|
|
}
|
|||
|
|
if v := parseInt("x-codex-secondary-reset-after-seconds"); v != nil {
|
|||
|
|
snapshot.SecondaryResetAfterSeconds = v
|
|||
|
|
hasData = true
|
|||
|
|
}
|
|||
|
|
if v := parseInt("x-codex-secondary-window-minutes"); v != nil {
|
|||
|
|
snapshot.SecondaryWindowMinutes = v
|
|||
|
|
hasData = true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Overflow ratio
|
|||
|
|
if v := parseFloat("x-codex-primary-over-secondary-limit-percent"); v != nil {
|
|||
|
|
snapshot.PrimaryOverSecondaryPercent = v
|
|||
|
|
hasData = true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if !hasData {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
snapshot.UpdatedAt = time.Now().Format(time.RFC3339)
|
|||
|
|
return snapshot
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func codexSnapshotBaseTime(snapshot *OpenAICodexUsageSnapshot, fallback time.Time) time.Time {
|
|||
|
|
if snapshot == nil {
|
|||
|
|
return fallback
|
|||
|
|
}
|
|||
|
|
if snapshot.UpdatedAt == "" {
|
|||
|
|
return fallback
|
|||
|
|
}
|
|||
|
|
base, err := time.Parse(time.RFC3339, snapshot.UpdatedAt)
|
|||
|
|
if err != nil {
|
|||
|
|
return fallback
|
|||
|
|
}
|
|||
|
|
return base
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// codexResetAtRFC3339 converts a "reset after N seconds" value into an
// absolute RFC 3339 timestamp relative to base.
//
// It returns nil when resetAfterSeconds is nil. Negative delays are treated as
// zero, so the result is never earlier than base.
func codexResetAtRFC3339(base time.Time, resetAfterSeconds *int) *string {
	if resetAfterSeconds == nil {
		return nil
	}

	delay := time.Duration(0)
	if *resetAfterSeconds > 0 {
		delay = time.Duration(*resetAfterSeconds) * time.Second
	}
	formatted := base.Add(delay).Format(time.RFC3339)
	return &formatted
}
|
|||
|
|
|
|||
|
|
func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) map[string]any {
|
|||
|
|
if snapshot == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
baseTime := codexSnapshotBaseTime(snapshot, fallbackNow)
|
|||
|
|
updates := make(map[string]any)
|
|||
|
|
|
|||
|
|
// 保存原始 primary/secondary 字段,便于排查问题
|
|||
|
|
if snapshot.PrimaryUsedPercent != nil {
|
|||
|
|
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
|
|||
|
|
}
|
|||
|
|
if snapshot.PrimaryResetAfterSeconds != nil {
|
|||
|
|
updates["codex_primary_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
|||
|
|
}
|
|||
|
|
if snapshot.PrimaryWindowMinutes != nil {
|
|||
|
|
updates["codex_primary_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
|||
|
|
}
|
|||
|
|
if snapshot.SecondaryUsedPercent != nil {
|
|||
|
|
updates["codex_secondary_used_percent"] = *snapshot.SecondaryUsedPercent
|
|||
|
|
}
|
|||
|
|
if snapshot.SecondaryResetAfterSeconds != nil {
|
|||
|
|
updates["codex_secondary_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
|||
|
|
}
|
|||
|
|
if snapshot.SecondaryWindowMinutes != nil {
|
|||
|
|
updates["codex_secondary_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
|||
|
|
}
|
|||
|
|
if snapshot.PrimaryOverSecondaryPercent != nil {
|
|||
|
|
updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent
|
|||
|
|
}
|
|||
|
|
updates["codex_usage_updated_at"] = baseTime.Format(time.RFC3339)
|
|||
|
|
|
|||
|
|
// 归一化到 5h/7d 规范字段
|
|||
|
|
if normalized := snapshot.Normalize(); normalized != nil {
|
|||
|
|
if normalized.Used5hPercent != nil {
|
|||
|
|
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
|
|||
|
|
}
|
|||
|
|
if normalized.Reset5hSeconds != nil {
|
|||
|
|
updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds
|
|||
|
|
}
|
|||
|
|
if normalized.Window5hMinutes != nil {
|
|||
|
|
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes
|
|||
|
|
}
|
|||
|
|
if normalized.Used7dPercent != nil {
|
|||
|
|
updates["codex_7d_used_percent"] = *normalized.Used7dPercent
|
|||
|
|
}
|
|||
|
|
if normalized.Reset7dSeconds != nil {
|
|||
|
|
updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds
|
|||
|
|
}
|
|||
|
|
if normalized.Window7dMinutes != nil {
|
|||
|
|
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
|
|||
|
|
}
|
|||
|
|
if reset5hAt := codexResetAtRFC3339(baseTime, normalized.Reset5hSeconds); reset5hAt != nil {
|
|||
|
|
updates["codex_5h_reset_at"] = *reset5hAt
|
|||
|
|
}
|
|||
|
|
if reset7dAt := codexResetAtRFC3339(baseTime, normalized.Reset7dSeconds); reset7dAt != nil {
|
|||
|
|
updates["codex_7d_reset_at"] = *reset7dAt
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return updates
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// codexUsagePercentExhausted reports whether a usage percentage has reached
// 100%, allowing a tiny tolerance for floating-point rounding. A nil value is
// never exhausted.
func codexUsagePercentExhausted(value *float64) bool {
	if value == nil {
		return false
	}
	const exhaustedThreshold = 100 - 1e-9
	return *value >= exhaustedThreshold
}
|
|||
|
|
|
|||
|
|
func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time {
|
|||
|
|
if snapshot == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
normalized := snapshot.Normalize()
|
|||
|
|
if normalized == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
baseTime := codexSnapshotBaseTime(snapshot, fallbackNow)
|
|||
|
|
if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil {
|
|||
|
|
resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
|
|||
|
|
return &resetAt
|
|||
|
|
}
|
|||
|
|
if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil {
|
|||
|
|
resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
|
|||
|
|
return &resetAt
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time {
|
|||
|
|
if len(extra) == 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
|||
|
|
resetAt := progress.ResetsAt.UTC()
|
|||
|
|
return &resetAt
|
|||
|
|
}
|
|||
|
|
if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
|||
|
|
resetAt := progress.ResetsAt.UTC()
|
|||
|
|
return &resetAt
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) {
|
|||
|
|
if account == nil || !account.IsOpenAI() {
|
|||
|
|
return nil, false
|
|||
|
|
}
|
|||
|
|
resetAt := codexRateLimitResetAtFromExtra(account.Extra, now)
|
|||
|
|
if resetAt == nil {
|
|||
|
|
return nil, false
|
|||
|
|
}
|
|||
|
|
if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) {
|
|||
|
|
return account.RateLimitResetAt, false
|
|||
|
|
}
|
|||
|
|
account.RateLimitResetAt = resetAt
|
|||
|
|
return resetAt, true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time {
|
|||
|
|
resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now)
|
|||
|
|
if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 {
|
|||
|
|
return resetAt
|
|||
|
|
}
|
|||
|
|
_ = repo.SetRateLimited(ctx, account.ID, *resetAt)
|
|||
|
|
return resetAt
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
|
|||
|
|
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
|
|||
|
|
if snapshot == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if s == nil || s.accountRepo == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
now := time.Now()
|
|||
|
|
updates := buildCodexUsageExtraUpdates(snapshot, now)
|
|||
|
|
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now)
|
|||
|
|
if len(updates) == 0 && resetAt == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now)
|
|||
|
|
if !shouldPersistUpdates && resetAt == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
go func() {
|
|||
|
|
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|||
|
|
defer cancel()
|
|||
|
|
if shouldPersistUpdates {
|
|||
|
|
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
|||
|
|
}
|
|||
|
|
if resetAt != nil {
|
|||
|
|
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
|||
|
|
}
|
|||
|
|
}()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *OpenAIGatewayService) UpdateCodexUsageSnapshotFromHeaders(ctx context.Context, accountID int64, headers http.Header) {
|
|||
|
|
if accountID <= 0 || headers == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if snapshot := ParseCodexRateLimitHeaders(headers); snapshot != nil {
|
|||
|
|
s.updateCodexUsageSnapshot(ctx, accountID, snapshot)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
|
|||
|
|
if reqBody == nil {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Primary: reasoning.effort
|
|||
|
|
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
|||
|
|
if effort, ok := reasoning["effort"].(string); ok {
|
|||
|
|
return normalizeOpenAIReasoningEffort(effort), true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Fallback: some clients may use a flat field.
|
|||
|
|
if effort, ok := reqBody["reasoning_effort"].(string); ok {
|
|||
|
|
return normalizeOpenAIReasoningEffort(effort), true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func deriveOpenAIReasoningEffortFromModel(model string) string {
|
|||
|
|
if strings.TrimSpace(model) == "" {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
modelID := strings.TrimSpace(model)
|
|||
|
|
if strings.Contains(modelID, "/") {
|
|||
|
|
parts := strings.Split(modelID, "/")
|
|||
|
|
modelID = parts[len(parts)-1]
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
|
|||
|
|
switch r {
|
|||
|
|
case '-', '_', ' ':
|
|||
|
|
return true
|
|||
|
|
default:
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
if len(parts) == 0 {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
|
|||
|
|
if len(body) == 0 {
|
|||
|
|
return "", false, ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
model = strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
|||
|
|
stream = gjson.GetBytes(body, "stream").Bool()
|
|||
|
|
promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
|||
|
|
return model, stream, promptCacheKey
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为:
|
|||
|
|
// 1) store=false 2) 非 compact 保持 stream=true;compact 强制 stream=false
|
|||
|
|
func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) {
|
|||
|
|
if len(body) == 0 {
|
|||
|
|
return body, false, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
normalized := body
|
|||
|
|
changed := false
|
|||
|
|
|
|||
|
|
if compact {
|
|||
|
|
if store := gjson.GetBytes(normalized, "store"); store.Exists() {
|
|||
|
|
next, err := sjson.DeleteBytes(normalized, "store")
|
|||
|
|
if err != nil {
|
|||
|
|
return body, false, fmt.Errorf("normalize passthrough body delete store: %w", err)
|
|||
|
|
}
|
|||
|
|
normalized = next
|
|||
|
|
changed = true
|
|||
|
|
}
|
|||
|
|
if stream := gjson.GetBytes(normalized, "stream"); stream.Exists() {
|
|||
|
|
next, err := sjson.DeleteBytes(normalized, "stream")
|
|||
|
|
if err != nil {
|
|||
|
|
return body, false, fmt.Errorf("normalize passthrough body delete stream: %w", err)
|
|||
|
|
}
|
|||
|
|
normalized = next
|
|||
|
|
changed = true
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False {
|
|||
|
|
next, err := sjson.SetBytes(normalized, "store", false)
|
|||
|
|
if err != nil {
|
|||
|
|
return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err)
|
|||
|
|
}
|
|||
|
|
normalized = next
|
|||
|
|
changed = true
|
|||
|
|
}
|
|||
|
|
if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True {
|
|||
|
|
next, err := sjson.SetBytes(normalized, "stream", true)
|
|||
|
|
if err != nil {
|
|||
|
|
return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err)
|
|||
|
|
}
|
|||
|
|
normalized = next
|
|||
|
|
changed = true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return normalized, changed, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
|||
|
|
model := strings.ToLower(strings.TrimSpace(reqModel))
|
|||
|
|
if !strings.Contains(model, "codex") {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
instructions := gjson.GetBytes(body, "instructions")
|
|||
|
|
if !instructions.Exists() {
|
|||
|
|
return "instructions_missing"
|
|||
|
|
}
|
|||
|
|
if instructions.Type != gjson.String {
|
|||
|
|
return "instructions_not_string"
|
|||
|
|
}
|
|||
|
|
if strings.TrimSpace(instructions.String()) == "" {
|
|||
|
|
return "instructions_empty"
|
|||
|
|
}
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
|||
|
|
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
|||
|
|
if reasoningEffort == "" {
|
|||
|
|
reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
|
|||
|
|
}
|
|||
|
|
if reasoningEffort != "" {
|
|||
|
|
normalized := normalizeOpenAIReasoningEffort(reasoningEffort)
|
|||
|
|
if normalized == "" {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return &normalized
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
|
|||
|
|
if value == "" {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return &value
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractOpenAIServiceTier(reqBody map[string]any) *string {
|
|||
|
|
if reqBody == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
raw, ok := reqBody["service_tier"].(string)
|
|||
|
|
if !ok {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return normalizeOpenAIServiceTier(raw)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractOpenAIServiceTierFromBody(body []byte) *string {
|
|||
|
|
if len(body) == 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return normalizeOpenAIServiceTier(gjson.GetBytes(body, "service_tier").String())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// normalizeOpenAIServiceTier canonicalizes a service tier value.
//
// Input is trimmed and lower-cased. The alias "fast" maps to "priority".
// Only "priority" and "flex" are recognised; anything else (including empty
// input) yields nil.
func normalizeOpenAIServiceTier(raw string) *string {
	tier := strings.ToLower(strings.TrimSpace(raw))
	switch tier {
	case "fast":
		tier = "priority"
	case "priority", "flex":
		// already canonical
	default:
		return nil
	}
	return &tier
}
|
|||
|
|
|
|||
|
|
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
|
|||
|
|
if c != nil {
|
|||
|
|
if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok {
|
|||
|
|
if reqBody, ok := cached.(map[string]any); ok && reqBody != nil {
|
|||
|
|
return reqBody, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var reqBody map[string]any
|
|||
|
|
if err := json.Unmarshal(body, &reqBody); err != nil {
|
|||
|
|
return nil, fmt.Errorf("parse request: %w", err)
|
|||
|
|
}
|
|||
|
|
if c != nil {
|
|||
|
|
c.Set(OpenAIParsedRequestBodyKey, reqBody)
|
|||
|
|
}
|
|||
|
|
return reqBody, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
|
|||
|
|
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
|
|||
|
|
if value == "" {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return &value
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
|
|||
|
|
if value == "" {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return &value
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// normalizeOpenAIReasoningEffort maps a raw reasoning-effort string to one of
// "low", "medium", "high" or "xhigh".
//
// Matching ignores case, surrounding whitespace and the separators '-', '_'
// and ' ', so "X-High" and "extra_high" both become "xhigh". "none",
// "minimal" and unknown levels yield "", so that only known levels are stored
// and the UI stays consistent.
func normalizeOpenAIReasoningEffort(raw string) string {
	key := strings.Map(func(r rune) rune {
		switch r {
		case '-', '_', ' ':
			return -1 // drop separators
		}
		return r
	}, strings.ToLower(strings.TrimSpace(raw)))

	switch key {
	case "low", "medium", "high":
		return key
	case "xhigh", "extrahigh":
		return "xhigh"
	default:
		return ""
	}
}
|
|||
|
|
|
|||
|
|
// optionalTrimmedStringPtr returns a pointer to the whitespace-trimmed value
// of raw, or nil when nothing is left after trimming.
func optionalTrimmedStringPtr(raw string) *string {
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		return &trimmed
	}
	return nil
}
|