Files
tokens-reef/backend/internal/service/sora_upstream_forwarder.go
User d96a9f384a
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled
feat: merge sub2apipro features and add Chinese model pricing
## Merged Features from sub2apipro
- Sora video generation integration (OpenAI Sora API)
- Group management enhancements
- Usage log improvements
- Security headers middleware

## Chinese Model Pricing Updates
- GLM-5, GLM-5-Turbo, GLM-5.1, GLM-4.7, GLM-4.5-Air
- Baichuan4, Baichuan4-Turbo, Baichuan4-Air, Baichuan-M3-Plus
- DeepSeek-V3, DeepSeek-V3.2, DeepSeek-R1
- Qwen3-8B (free), Qwen2.5-72B-Instruct

## URL Whitelist Additions
- api.baichuan-ai.com (百川智能)
- api.siliconflow.cn (硅基流动)
- api.z.ai (智谱国际)
- api.groq.com (Groq加速推理)

## Documentation
- Added merge guide (docs/MERGE_GUIDE.md)
- Added quick reference (docs/MERGE_QUICKREF.md)
- Added review reports (docs/reviews/)
2026-04-15 12:02:07 +08:00

150 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions"
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
// 支持流式和非流式响应的直接透传。
func (s *SoraGatewayService) forwardToUpstream(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	clientStream bool,
	startTime time.Time,
) (*ForwardResult, error) {
	// forwardToUpstream relays a Sora chat-completions request over plain HTTP
	// to the upstream Sora service of an API-key account.
	//
	// Upstream endpoint: account.GetBaseURL() + "/sora/v1/chat/completions",
	// authenticated with account.GetCredential("api_key") as a Bearer token.
	// Streaming and non-streaming responses are passed through unchanged.
	//
	// Outcomes:
	//   - Missing/invalid account configuration or request-construction
	//     failure: an error response is written to c and a plain error is
	//     returned.
	//   - Transport failure, or an upstream status accepted by
	//     s.shouldFailoverUpstreamError: nothing is written to c and an
	//     *UpstreamFailoverError is returned, leaving the final client
	//     response (or a retry on another account) to the caller.
	//   - Any other upstream status >= 400: the status, headers and up to
	//     64 KiB of the body are relayed to c and a plain error is returned.
	//   - Success: the body is relayed and a ForwardResult is returned; its
	//     Model field is left empty for the caller to fill in.
	apiKey := account.GetCredential("api_key")
	if apiKey == "" {
		s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
		return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
	}
	baseURL := account.GetBaseURL()
	if baseURL == "" {
		s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
		return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
	}
	// Only the http and https schemes are accepted for the upstream base URL.
	if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
		s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
		return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
	}
	upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"

	upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
	if err != nil {
		s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
		return nil, fmt.Errorf("create upstream request: %w", err)
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
	// Pass through selected client request headers. Setting Accept-Encoding
	// explicitly disables the transport's transparent decompression, so the
	// upstream body may arrive compressed; Content-Encoding is therefore
	// relayed on the success path below.
	for _, header := range []string{"Accept", "Accept-Encoding"} {
		if v := c.GetHeader(header); v != "" {
			upstreamReq.Header.Set(header, v)
		}
	}
	logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)

	// Route through the account's proxy when one is configured.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// Do not write to c here. A failover error hands control back to the
		// caller, which retries on another account or writes the final
		// response itself (exactly as for the status-based failover below).
		// Writing here as well would emit two responses for one request.
		logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d request failed: %v", account.ID, err)
		return nil, &UpstreamFailoverError{
			StatusCode: http.StatusBadGateway,
		}
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	if resp.StatusCode >= 400 {
		// Best-effort, bounded read: a read error just leaves a shorter body
		// for the failover record or the relayed error response.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			return nil, &UpstreamFailoverError{
				StatusCode:      resp.StatusCode,
				ResponseBody:    respBody,
				ResponseHeaders: resp.Header.Clone(),
			}
		}
		// Non-failover error: relay it to the client.
		c.Status(resp.StatusCode)
		for key, values := range resp.Header {
			// respBody may be truncated, so the upstream Content-Length can be
			// wrong for what is written here; hop-by-hop headers describe the
			// upstream connection, not this one. Let net/http frame the reply.
			switch strings.ToLower(key) {
			case "content-length", "transfer-encoding", "connection", "keep-alive":
				continue
			}
			for _, v := range values {
				c.Writer.Header().Add(key, v)
			}
		}
		if _, err := c.Writer.Write(respBody); err != nil {
			return nil, fmt.Errorf("write upstream error response: %w", err)
		}
		return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
	}

	// Success: relay status, content-related headers, then the body.
	c.Status(resp.StatusCode)
	for key, values := range resp.Header {
		switch strings.ToLower(key) {
		case "content-type", "content-encoding", "transfer-encoding", "cache-control", "x-request-id":
			for _, v := range values {
				c.Writer.Header().Add(key, v)
			}
		}
	}

	if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
		// Flush after every chunk so streamed events reach the client promptly.
		buf := make([]byte, 4096)
		for {
			n, readErr := resp.Body.Read(buf)
			if n > 0 {
				if _, err := c.Writer.Write(buf[:n]); err != nil {
					return nil, fmt.Errorf("stream upstream response write: %w", err)
				}
				flusher.Flush()
			}
			if readErr != nil {
				// NOTE(review): any read error (not only io.EOF) ends the stream
				// and the request is still reported as successful, unlike the
				// io.Copy branch below. Kept as-is because it affects usage
				// accounting on client disconnects — confirm intent.
				break
			}
		}
	} else if _, err := io.Copy(c.Writer, resp.Body); err != nil {
		return nil, fmt.Errorf("copy upstream response: %w", err)
	}

	return &ForwardResult{
		RequestID: resp.Header.Get("x-request-id"),
		Model:     "", // filled in by the caller
		Stream:    clientStream,
		Duration:  time.Since(startTime),
	}, nil
}