refactor: 彻底移除 Sora 视频生成模块(全栈清理)
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled

## 后端变更
- 删除 21 个 sora_*.go 服务文件(service/handler/repository/routes)
- 删除 Sora 相关 migration 文件(046/047/063/090)
- 清理 config 中的 sora_* 配置项和平台常量
- 清理 wire 依赖注入中的 Sora 组件
- 修复 wire_gen.go 语法错误(缺少逗号和闭合括号)
- 移除 go.mod 中的 go-sora2api 依赖
- 更新 ent schema usage_log.go 注释

## 前端变更
- 删除 SoraView、SoraAdminView 及 8 个 Sora 子组件
- 删除 sora API 层和路由配置
- 清理 UserEditModal 中的 Sora 存储配额 UI
- 清理 types/index.ts 中 Sora 相关类型定义
- 清理 stores/app.ts 默认配置
- 清理 i18n 翻译文件 en.ts/zh.ts (~110 行)
- 更新相关测试文件

## 文档更新
- README.md / README_CN.md / README_JA.md: 移除 Sora 状态说明和配置段落
- PROJECT_DIFF.md: 移除 Sora 相关差异描述

## 验证结果
-  Go 编译通过 (go build ./...)
-  TypeScript 类型检查通过 (vue-tsc --noEmit)
-  后端测试全通过 (0 failures)
-  前端测试全通过 (59 files, 329 tests, 0 failures)
-  前端生产构建成功 (23.81s)
This commit is contained in:
2026-05-10 14:15:45 +08:00
parent 1da074cfd6
commit 0e057904e6
96 changed files with 726 additions and 20525 deletions

View File

@@ -1,142 +0,0 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SoraHandler handles admin Sora statistics and management.
type SoraHandler struct {
soraGenService *service.SoraGenerationService
soraQuotaService *service.SoraQuotaService
userRepo service.UserRepository
}
// NewSoraHandler creates a new admin Sora handler.
func NewSoraHandler(
soraGenService *service.SoraGenerationService,
soraQuotaService *service.SoraQuotaService,
userRepo service.UserRepository,
) *SoraHandler {
return &SoraHandler{
soraGenService: soraGenService,
soraQuotaService: soraQuotaService,
userRepo: userRepo,
}
}
type SoraSystemStatsResponse struct {
TotalUsers int64 `json:"total_users"`
TotalGenerations int64 `json:"total_generations"`
TotalStorageBytes int64 `json:"total_storage_bytes"`
ActiveGenerations int64 `json:"active_generations"`
ByStatus map[string]int64 `json:"by_status"`
ByModel map[string]int64 `json:"by_model"`
}
// GetSystemStats returns aggregate admin Sora statistics.
func (h *SoraHandler) GetSystemStats(c *gin.Context) {
ctx := c.Request.Context()
users, _, err := h.userRepo.List(ctx, pagination.PaginationParams{Page: 1, PageSize: 10000})
if err != nil {
response.Error(c, 500, "Failed to get users")
return
}
resp := SoraSystemStatsResponse{
TotalUsers: int64(len(users)),
TotalGenerations: 0,
TotalStorageBytes: 0,
ActiveGenerations: 0,
ByStatus: map[string]int64{},
ByModel: map[string]int64{},
}
response.Success(c, resp)
}
type SoraUserStatsResponse struct {
UserID int64 `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
QuotaBytes int64 `json:"quota_bytes"`
UsedBytes int64 `json:"used_bytes"`
AvailableBytes int64 `json:"available_bytes"`
QuotaSource string `json:"quota_source"`
GenerationsCount int64 `json:"generations_count"`
ActiveCount int64 `json:"active_count"`
TotalFileSizeBytes int64 `json:"total_file_size_bytes"`
}
// ListUserStats returns per-user admin Sora usage rows.
func (h *SoraHandler) ListUserStats(c *gin.Context) {
ctx := c.Request.Context()
page, pageSize := response.ParsePagination(c)
search := c.Query("search")
users, result, err := h.userRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}, service.UserListFilters{Search: search})
if err != nil {
response.Error(c, 500, "Failed to get users")
return
}
results := make([]SoraUserStatsResponse, len(users))
for i, u := range users {
quota, _ := h.soraQuotaService.GetQuota(ctx, u.ID)
activeCount, _ := h.soraGenService.CountActiveByUser(ctx, u.ID)
quotaBytes := int64(0)
availableBytes := int64(0)
quotaSource := "unlimited"
if quota != nil {
quotaBytes = quota.QuotaBytes
availableBytes = quota.AvailableBytes
quotaSource = quota.QuotaSource
}
results[i] = SoraUserStatsResponse{
UserID: u.ID,
Username: u.Username,
Email: u.Email,
QuotaBytes: quotaBytes,
UsedBytes: 0,
AvailableBytes: availableBytes,
QuotaSource: quotaSource,
GenerationsCount: 0,
ActiveCount: activeCount,
TotalFileSizeBytes: 0,
}
}
response.Paginated(c, results, result.Total, page, pageSize)
}
type SoraGenerationAdminResponse struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
Model string `json:"model"`
Prompt string `json:"prompt"`
MediaType string `json:"media_type"`
Status string `json:"status"`
StorageType string `json:"storage_type"`
MediaURL string `json:"media_url"`
FileSizeBytes int64 `json:"file_size_bytes"`
ErrorMessage string `json:"error_message"`
CreatedAt string `json:"created_at"`
CompletedAt *string `json:"completed_at"`
}
// ListGenerations returns admin-visible generation rows.
func (h *SoraHandler) ListGenerations(c *gin.Context) {
response.Paginated(c, []SoraGenerationAdminResponse{}, int64(0), 1, 20)
}

View File

@@ -1,262 +0,0 @@
package admin
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestSoraHandler_ListGenerations(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := &SoraHandler{}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/admin/sora/generations", nil)
handler.ListGenerations(c)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "items")
}
func TestSoraSystemStatsResponse_Fields(t *testing.T) {
resp := SoraSystemStatsResponse{
TotalUsers: 10,
TotalGenerations: 100,
TotalStorageBytes: 1024 * 1024 * 1024,
ActiveGenerations: 5,
ByStatus: map[string]int64{"completed": 80, "failed": 20},
ByModel: map[string]int64{"sora2": 50, "sora1": 50},
}
assert.Equal(t, int64(10), resp.TotalUsers)
assert.Equal(t, int64(100), resp.TotalGenerations)
assert.Equal(t, int64(1024*1024*1024), resp.TotalStorageBytes)
assert.Equal(t, int64(5), resp.ActiveGenerations)
assert.Equal(t, int64(80), resp.ByStatus["completed"])
assert.Equal(t, int64(50), resp.ByModel["sora2"])
}
func TestSoraUserStatsResponse_Fields(t *testing.T) {
resp := SoraUserStatsResponse{
UserID: 1,
Username: "testuser",
Email: "test@example.com",
QuotaBytes: 10 * 1024 * 1024 * 1024,
UsedBytes: 1 * 1024 * 1024 * 1024,
AvailableBytes: 9 * 1024 * 1024 * 1024,
QuotaSource: "user",
GenerationsCount: 10,
ActiveCount: 2,
TotalFileSizeBytes: 1 * 1024 * 1024 * 1024,
}
assert.Equal(t, int64(1), resp.UserID)
assert.Equal(t, "testuser", resp.Username)
assert.Equal(t, "test@example.com", resp.Email)
assert.Equal(t, int64(10*1024*1024*1024), resp.QuotaBytes)
assert.Equal(t, int64(1*1024*1024*1024), resp.UsedBytes)
assert.Equal(t, "user", resp.QuotaSource)
assert.Equal(t, int64(10), resp.GenerationsCount)
assert.Equal(t, int64(2), resp.ActiveCount)
}
func TestSoraGenerationAdminResponse_Fields(t *testing.T) {
completedAt := "2024-01-01T12:00:00Z"
resp := SoraGenerationAdminResponse{
ID: 1,
UserID: 100,
Username: "testuser",
Email: "test@example.com",
Model: "sora2",
Prompt: "A beautiful sunset",
MediaType: "video",
Status: "completed",
StorageType: "s3",
MediaURL: "https://example.com/video.mp4",
FileSizeBytes: 1024 * 1024 * 10,
ErrorMessage: "",
CreatedAt: "2024-01-01T10:00:00Z",
CompletedAt: &completedAt,
}
assert.Equal(t, int64(1), resp.ID)
assert.Equal(t, int64(100), resp.UserID)
assert.Equal(t, "testuser", resp.Username)
assert.Equal(t, "sora2", resp.Model)
assert.Equal(t, "video", resp.MediaType)
assert.Equal(t, "completed", resp.Status)
assert.Equal(t, "s3", resp.StorageType)
assert.Equal(t, int64(1024*1024*10), resp.FileSizeBytes)
assert.NotNil(t, resp.CompletedAt)
}
func TestSoraGenerationAdminResponse_NilCompletedAt(t *testing.T) {
resp := SoraGenerationAdminResponse{
ID: 1,
UserID: 100,
Username: "testuser",
Email: "test@example.com",
Model: "sora2",
Prompt: "A beautiful sunset",
MediaType: "video",
Status: "pending",
StorageType: "upstream",
CreatedAt: "2024-01-01T10:00:00Z",
CompletedAt: nil,
}
assert.Equal(t, "pending", resp.Status)
assert.Nil(t, resp.CompletedAt)
}
func TestNewSoraHandler(t *testing.T) {
handler := NewSoraHandler(nil, nil, nil)
assert.NotNil(t, handler)
assert.Nil(t, handler.soraGenService)
assert.Nil(t, handler.soraQuotaService)
assert.Nil(t, handler.userRepo)
}
func TestUser_SoraFields(t *testing.T) {
user := &service.User{
ID: 1,
Email: "test@example.com",
}
assert.Equal(t, int64(1), user.ID)
assert.Equal(t, "test@example.com", user.Email)
}
func TestQuotaInfo_Fields(t *testing.T) {
quota := &service.QuotaInfo{
QuotaBytes: 10 * 1024 * 1024 * 1024,
UsedBytes: 1 * 1024 * 1024 * 1024,
AvailableBytes: 9 * 1024 * 1024 * 1024,
QuotaSource: "user",
}
assert.Equal(t, int64(10*1024*1024*1024), quota.QuotaBytes)
assert.Equal(t, int64(1*1024*1024*1024), quota.UsedBytes)
assert.Equal(t, "user", quota.QuotaSource)
}
func TestSoraSystemStatsResponse_JSON(t *testing.T) {
resp := SoraSystemStatsResponse{
TotalUsers: 10,
TotalGenerations: 100,
TotalStorageBytes: 1024,
ActiveGenerations: 5,
ByStatus: map[string]int64{"completed": 80},
ByModel: map[string]int64{"sora2": 50},
}
// Verify JSON tags by checking field values
assert.Equal(t, int64(10), resp.TotalUsers)
assert.Equal(t, int64(100), resp.TotalGenerations)
assert.Equal(t, int64(1024), resp.TotalStorageBytes)
assert.Equal(t, int64(5), resp.ActiveGenerations)
}
func TestSoraUserStatsResponse_JSON(t *testing.T) {
resp := SoraUserStatsResponse{
UserID: 1,
Username: "testuser",
Email: "test@example.com",
QuotaBytes: 1024,
UsedBytes: 512,
AvailableBytes: 512,
QuotaSource: "user",
GenerationsCount: 10,
ActiveCount: 2,
TotalFileSizeBytes: 1024,
}
// Verify all fields
assert.Equal(t, int64(1), resp.UserID)
assert.Equal(t, "testuser", resp.Username)
assert.Equal(t, "test@example.com", resp.Email)
assert.Equal(t, int64(1024), resp.QuotaBytes)
assert.Equal(t, int64(512), resp.UsedBytes)
assert.Equal(t, int64(512), resp.AvailableBytes)
assert.Equal(t, "user", resp.QuotaSource)
assert.Equal(t, int64(10), resp.GenerationsCount)
assert.Equal(t, int64(2), resp.ActiveCount)
assert.Equal(t, int64(1024), resp.TotalFileSizeBytes)
}
func TestSoraSystemStatsResponse_EmptyMaps(t *testing.T) {
resp := SoraSystemStatsResponse{
TotalUsers: 0,
TotalGenerations: 0,
TotalStorageBytes: 0,
ActiveGenerations: 0,
ByStatus: map[string]int64{},
ByModel: map[string]int64{},
}
assert.Equal(t, int64(0), resp.TotalUsers)
assert.Equal(t, int64(0), resp.TotalGenerations)
assert.Equal(t, int64(0), resp.TotalStorageBytes)
assert.Equal(t, int64(0), resp.ActiveGenerations)
assert.NotNil(t, resp.ByStatus)
assert.NotNil(t, resp.ByModel)
}
func TestSoraUserStatsResponse_QuotaSources(t *testing.T) {
sources := []string{"user", "group", "system", "unlimited"}
for _, source := range sources {
resp := SoraUserStatsResponse{
UserID: 1,
QuotaSource: source,
}
assert.Equal(t, source, resp.QuotaSource)
}
}
func TestSoraGenerationAdminResponse_Statuses(t *testing.T) {
statuses := []string{"pending", "generating", "completed", "failed", "cancelled"}
for _, status := range statuses {
resp := SoraGenerationAdminResponse{
ID: 1,
Status: status,
}
assert.Equal(t, status, resp.Status)
}
}
func TestSoraGenerationAdminResponse_MediaTypes(t *testing.T) {
mediaTypes := []string{"video", "image"}
for _, mt := range mediaTypes {
resp := SoraGenerationAdminResponse{
ID: 1,
MediaType: mt,
}
assert.Equal(t, mt, resp.MediaType)
}
}
func TestSoraGenerationAdminResponse_StorageTypes(t *testing.T) {
storageTypes := []string{"s3", "upstream"}
for _, st := range storageTypes {
resp := SoraGenerationAdminResponse{
ID: 1,
StorageType: st,
}
assert.Equal(t, st, resp.StorageType)
}
}

View File

@@ -179,7 +179,6 @@ type PublicSettings struct {
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"`

View File

@@ -31,7 +31,6 @@ type AdminHandlers struct {
ScheduledTest *admin.ScheduledTestHandler
Channel *admin.ChannelHandler
Payment *admin.PaymentHandler
Sora *admin.SoraHandler
}
// Handlers contains all HTTP handlers
@@ -46,8 +45,6 @@ type Handlers struct {
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler // 从本地版本合并
SoraClient *SoraClientHandler // 从本地版本合并
Setting *SettingHandler
Totp *TotpHandler
Payment *PaymentHandler

View File

@@ -1,979 +0,0 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// 上游模型缓存 TTL
modelCacheTTL = 1 * time.Hour // 上游获取成功
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
)
// SoraClientHandler 处理 Sora 客户端 API 请求。
type SoraClientHandler struct {
genService *service.SoraGenerationService
quotaService *service.SoraQuotaService
s3Storage *service.SoraS3Storage
soraGatewayService *service.SoraGatewayService
gatewayService *service.GatewayService
mediaStorage *service.SoraMediaStorage
apiKeyService *service.APIKeyService
// 上游模型缓存
modelCacheMu sync.RWMutex
cachedFamilies []service.SoraModelFamily
modelCacheTime time.Time
modelCacheUpstream bool // 是否来自上游(决定 TTL
}
// NewSoraClientHandler 创建 Sora 客户端 Handler。
func NewSoraClientHandler(
genService *service.SoraGenerationService,
quotaService *service.SoraQuotaService,
s3Storage *service.SoraS3Storage,
soraGatewayService *service.SoraGatewayService,
gatewayService *service.GatewayService,
mediaStorage *service.SoraMediaStorage,
apiKeyService *service.APIKeyService,
) *SoraClientHandler {
return &SoraClientHandler{
genService: genService,
quotaService: quotaService,
s3Storage: s3Storage,
soraGatewayService: soraGatewayService,
gatewayService: gatewayService,
mediaStorage: mediaStorage,
apiKeyService: apiKeyService,
}
}
// GenerateRequest 生成请求。
type GenerateRequest struct {
Model string `json:"model" binding:"required"`
Prompt string `json:"prompt" binding:"required"`
MediaType string `json:"media_type"` // video / image默认 video
VideoCount int `json:"video_count,omitempty"` // 视频数量1-3
ImageInput string `json:"image_input,omitempty"` // 参考图base64 或 URL
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
}
// Generate 异步生成 — 创建 pending 记录后立即返回。
// POST /api/v1/sora/generate
func (h *SoraClientHandler) Generate(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
var req GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
return
}
if req.MediaType == "" {
req.MediaType = "video"
}
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
// 并发数检查(最多 3 个)
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if activeCount >= 3 {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
// 配额检查(粗略检查,实际文件大小在上传后才知道)
if h.quotaService != nil {
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusForbidden, err.Error())
return
}
}
// 获取 API Key ID 和 Group ID
var apiKeyID *int64
var groupID *int64
if req.APIKeyID != nil && h.apiKeyService != nil {
// 前端传递了 api_key_id需要校验
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
if err != nil {
response.Error(c, http.StatusBadRequest, "API Key 不存在")
return
}
if apiKey.UserID != userID {
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
return
}
if apiKey.Status != service.StatusAPIKeyActive {
response.Error(c, http.StatusForbidden, "API Key 不可用")
return
}
apiKeyID = &apiKey.ID
groupID = apiKey.GroupID
} else if id, ok := c.Get("api_key_id"); ok {
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
if v, ok := id.(int64); ok {
apiKeyID = &v
}
}
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
if err != nil {
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
response.ErrorFrom(c, err)
return
}
// 启动后台异步生成 goroutine
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
response.Success(c, gin.H{
"generation_id": gen.ID,
"status": gen.Status,
})
}
// processGeneration 后台异步执行 Sora 生成任务。
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储S3 → 本地 → 上游)→ 更新记录。
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// 标记为生成中
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
genID,
userID,
groupIDForLog(groupID),
model,
mediaType,
videoCount,
strings.TrimSpace(imageInput) != "",
len(strings.TrimSpace(prompt)),
)
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
if groupID == nil {
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
}
if h.gatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
return
}
// 选择 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
genID,
userID,
groupIDForLog(groupID),
model,
err,
)
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
genID,
userID,
groupIDForLog(groupID),
model,
account.ID,
account.Name,
account.Platform,
account.Type,
)
// 构建 chat completions 请求体(非流式)
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
if h.soraGatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
return
}
// 创建 mock gin 上下文用于 Forward捕获响应以提取媒体 URL
recorder := httptest.NewRecorder()
mockGinCtx, _ := gin.CreateTestContext(recorder)
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
// 调用 Forward非流式
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
err,
)
// 检查是否已取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
return
}
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
return
}
// 提取媒体 URL优先从 ForwardResult其次从响应体解析
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
if mediaURL == "" {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
)
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
return
}
// 检查任务是否已被取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
return
}
// 三层降级存储S3 → 本地 → 上游临时 URL
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
usageAdded := false
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
return
}
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
return
}
usageAdded = true
}
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
gen, _ = h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
// 标记完成
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
}
// storeMediaWithDegradation 实现三层降级存储链S3 → 本地 → 上游。
func (h *SoraClientHandler) storeMediaWithDegradation(
ctx context.Context, userID int64, mediaType string,
mediaURL string, mediaURLs []string,
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
urls := mediaURLs
if len(urls) == 0 {
urls = []string{mediaURL}
}
// 第一层:尝试 S3
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
keys := make([]string, 0, len(urls))
var totalSize int64
allOK := true
for _, u := range urls {
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
allOK = false
// 清理已上传的文件
if len(keys) > 0 {
_ = h.s3Storage.DeleteObjects(ctx, keys)
}
break
}
keys = append(keys, key)
totalSize += size
}
if allOK && len(keys) > 0 {
accessURLs := make([]string, 0, len(keys))
for _, key := range keys {
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
_ = h.s3Storage.DeleteObjects(ctx, keys)
allOK = false
break
}
accessURLs = append(accessURLs, accessURL)
}
if allOK && len(accessURLs) > 0 {
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
}
}
}
// 第二层:尝试本地存储
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
if err == nil && len(storedPaths) > 0 {
firstPath := storedPaths[0]
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
if sizeErr != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
}
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
}
// 第三层:保留上游临时 URL
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
}
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
body := map[string]any{
"model": model,
"messages": []map[string]string{
{"role": "user", "content": prompt},
},
"stream": false,
}
if imageInput != "" {
body["image_input"] = imageInput
}
if videoCount > 1 {
body["video_count"] = videoCount
}
b, _ := json.Marshal(body)
return b
}
func normalizeVideoCount(mediaType string, videoCount int) int {
if mediaType != "video" {
return 1
}
if videoCount <= 0 {
return 1
}
if videoCount > 3 {
return 3
}
return videoCount
}
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
// OAuth 路径ForwardResult.MediaURL 已填充。
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
// 优先从 ForwardResult 获取OAuth 路径)
if result != nil && result.MediaURL != "" {
// 尝试从响应体获取完整 URL 列表
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return result.MediaURL, []string{result.MediaURL}
}
// 从响应体解析APIKey 路径)
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return "", nil
}
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
func parseMediaURLsFromBody(body []byte) []string {
if len(body) == 0 {
return nil
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return nil
}
// 优先 media_urls多图数组
if rawURLs, ok := resp["media_urls"]; ok {
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
urls := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok && s != "" {
urls = append(urls, s)
}
}
if len(urls) > 0 {
return urls
}
}
}
// 回退到 media_url单个 URL
if url, ok := resp["media_url"].(string); ok && url != "" {
return []string{url}
}
return nil
}
// ListGenerations 查询生成记录列表。
// GET /api/v1/sora/generations
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
params := service.SoraGenerationListParams{
UserID: userID,
Status: c.Query("status"),
StorageType: c.Query("storage_type"),
MediaType: c.Query("media_type"),
Page: page,
PageSize: pageSize,
}
gens, total, err := h.genService.List(c.Request.Context(), params)
if err != nil {
response.ErrorFrom(c, err)
return
}
// 为 S3 记录动态生成预签名 URL
for _, gen := range gens {
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
}
response.Success(c, gin.H{
"data": gens,
"total": total,
"page": page,
})
}
// GetGeneration 查询生成记录详情。
// GET /api/v1/sora/generations/:id
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
response.Success(c, gen)
}
// DeleteGeneration 删除生成记录。
// DELETE /api/v1/sora/generations/:id
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
paths := gen.MediaURLs
if len(paths) == 0 && gen.MediaURL != "" {
paths = []string{gen.MediaURL}
}
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
}
}
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
response.Success(c, gin.H{"message": "已删除"})
}
// GetQuota 查询用户存储配额。
// GET /api/v1/sora/quota
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
if h.quotaService == nil {
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
return
}
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, quota)
}
// CancelGeneration 取消生成任务。
// POST /api/v1/sora/generations/:id/cancel
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
// 权限校验
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = gen
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
if errors.Is(err, service.ErrSoraGenerationNotActive) {
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
return
}
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, gin.H{"message": "已取消"})
}
// SaveToStorage 手动保存 upstream 记录到 S3。
// POST /api/v1/sora/generations/:id/save
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
if gen.StorageType != service.SoraStorageTypeUpstream {
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
return
}
if gen.MediaURL == "" {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
return
}
sourceURLs := gen.MediaURLs
if len(sourceURLs) == 0 && gen.MediaURL != "" {
sourceURLs = []string{gen.MediaURL}
}
if len(sourceURLs) == 0 {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
uploadedKeys := make([]string, 0, len(sourceURLs))
accessURLs := make([]string, 0, len(sourceURLs))
var totalSize int64
for _, sourceURL := range sourceURLs {
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
if uploadErr != nil {
if len(uploadedKeys) > 0 {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
}
var upstreamErr *service.UpstreamDownloadError
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
return
}
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
return
}
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
if err != nil {
uploadedKeys = append(uploadedKeys, objectKey)
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
return
}
uploadedKeys = append(uploadedKeys, objectKey)
accessURLs = append(accessURLs, accessURL)
totalSize += fileSize
}
usageAdded := false
if totalSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
return
}
usageAdded = true
}
if err := h.genService.UpdateStorageForCompleted(
c.Request.Context(),
id,
accessURLs[0],
accessURLs,
service.SoraStorageTypeS3,
uploadedKeys,
totalSize,
); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"message": "已保存到 S3",
"object_key": uploadedKeys[0],
"object_keys": uploadedKeys,
})
}
// GetStorageStatus 返回存储状态。
// GET /api/v1/sora/storage-status
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
s3Healthy := false
if s3Enabled {
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
}
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
response.Success(c, gin.H{
"s3_enabled": s3Enabled,
"s3_healthy": s3Healthy,
"local_enabled": localEnabled,
})
}
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
switch storageType {
case service.SoraStorageTypeS3:
if h.s3Storage != nil && len(s3Keys) > 0 {
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
}
}
case service.SoraStorageTypeLocal:
if h.mediaStorage != nil && len(localPaths) > 0 {
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
}
}
}
}
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
func getUserIDFromContext(c *gin.Context) int64 {
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
return subject.UserID
}
if id, ok := c.Get("user_id"); ok {
switch v := id.(type) {
case int64:
return v
case float64:
return int64(v)
case string:
n, _ := strconv.ParseInt(v, 10, 64)
return n
}
}
// 尝试从 JWT claims 获取
if id, ok := c.Get("userID"); ok {
if v, ok := id.(int64); ok {
return v
}
}
return 0
}
func groupIDForLog(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
}
func trimForLog(raw string, maxLen int) string {
trimmed := strings.TrimSpace(raw)
if maxLen <= 0 || len(trimmed) <= maxLen {
return trimmed
}
return trimmed[:maxLen] + "...(truncated)"
}
// GetModels 获取可用 Sora 模型家族列表。
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
// GET /api/v1/sora/models
func (h *SoraClientHandler) GetModels(c *gin.Context) {
families := h.getModelFamilies(c.Request.Context())
response.Success(c, families)
}
// getModelFamilies 获取模型家族列表(带缓存)。
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
// 读锁检查缓存
h.modelCacheMu.RLock()
ttl := modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
families := h.cachedFamilies
h.modelCacheMu.RUnlock()
return families
}
h.modelCacheMu.RUnlock()
// 写锁更新缓存
h.modelCacheMu.Lock()
defer h.modelCacheMu.Unlock()
// double-check
ttl = modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
return h.cachedFamilies
}
// 尝试从上游获取
families, err := h.fetchUpstreamModels(ctx)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
families = service.BuildSoraModelFamilies()
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = false
return families
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = true
return families
}
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
if h.gatewayService == nil {
return nil, fmt.Errorf("gatewayService 未初始化")
}
// 设置 ForcePlatform 用于 Sora 账号选择
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
// 选择一个 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
if err != nil {
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
}
// 仅支持 API Key 类型账号
if account.Type != service.AccountTypeAPIKey {
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
}
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return nil, fmt.Errorf("账号缺少 api_key")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return nil, fmt.Errorf("账号缺少 base_url")
}
// 构建上游模型列表请求
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("请求上游失败: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 解析 OpenAI 格式的模型列表
var modelsResp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.Unmarshal(body, &modelsResp); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if len(modelsResp.Data) == 0 {
return nil, fmt.Errorf("上游返回空模型列表")
}
// 提取模型 ID
modelIDs := make([]string, 0, len(modelsResp.Data))
for _, m := range modelsResp.Data {
modelIDs = append(modelIDs, m.ID)
}
// 转换为模型家族
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
if len(families) == 0 {
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
}
return families, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,694 +0,0 @@
package handler
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// SoraGatewayHandler handles Sora chat completions requests
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
billingCacheService *service.BillingCacheService
usageRecordWorkerPool *service.UsageRecordWorkerPool
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
streamMode string
soraTLSEnabled bool
soraMediaSigningKey string
soraMediaRoot string
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
func NewSoraGatewayHandler(
gatewayService *service.GatewayService,
soraGatewayService *service.SoraGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
cfg *config.Config,
) *SoraGatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 3
streamMode := "force"
soraTLSEnabled := true
signKey := ""
mediaRoot := "/app/data/sora"
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
streamMode = mode
}
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
mediaRoot = root
}
}
return &SoraGatewayHandler{
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
billingCacheService: billingCacheService,
usageRecordWorkerPool: usageRecordWorkerPool,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode),
soraTLSEnabled: soraTLSEnabled,
soraMediaSigningKey: signKey,
soraMediaRoot: mediaRoot,
}
}
// ChatCompletions handles Sora /v1/chat/completions endpoint
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.sora_gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// 校验请求体 JSON 合法性
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
msgsResult := gjson.GetBytes(body, "messages")
if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
return
}
clientStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
if !clientStream {
if h.streamMode == "error" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
return
}
var err error
body, err = sjson.SetBytes(body, "stream", true)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
setOpsRequestContext(c, reqModel, clientStream, body)
platform := ""
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
platform = forced
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
if platform != service.PlatformSora {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
return
}
streamStarted := false
subscription, _ := middleware2.GetSubscriptionFromContext(c)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
if err != nil {
reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
sessionHash := generateOpenAISessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverBody []byte
var lastFailoverHeaders http.Header
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
if err != nil {
reqLog.Warn("sora.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int("last_upstream_status", lastFailoverStatus),
}
if rayID != "" {
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("last_upstream_content_type", contentType))
}
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
proxyBound := account.ProxyID != nil
proxyID := int64(0)
if account.ProxyID != nil {
proxyID = *account.ProxyID
}
tlsFingerprintEnabled := h.soraTLSEnabled
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
reqLog.Warn("sora.account_wait_counter_increment_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
} else if !canWait {
reqLog.Info("sora.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
clientStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("sora.account_slot_acquire_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
lastFailoverBody = failoverErr.ResponseBody
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
}
if rayID != "" {
fields = append(fields, zap.String("upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("upstream_content_type", contentType))
}
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
lastFailoverBody = failoverErr.ResponseBody
switchCount++
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.String("upstream_error_code", upstreamErrCode),
zap.String("upstream_error_message", upstreamErrMsg),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
}
if rayID != "" {
fields = append(fields, zap.String("upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("upstream_content_type", contentType))
}
reqLog.Warn("sora.upstream_failover_switching", fields...)
continue
}
reqLog.Error("sora.forward_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
return
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
}); err != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("sora.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("sora.request_completed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("switch_count", switchCount),
)
return
}
}
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
if sessionID == "" {
return ""
}
hash := sha256.Sum256([]byte(sessionID))
return hex.EncodeToString(hash[:])
}
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
defer func() {
if recovered := recover(); recovered != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
zap.Any("panic", recovered),
).Error("sora.usage_record_task_panic_recovered")
}
}()
task(ctx)
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if strings.EqualFold(upstreamCode, "cf_shield_429") {
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", upstreamMessage
case 429:
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
}
}
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 404:
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
}
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
func cloneHTTPHeaders(headers http.Header) http.Header {
if headers == nil {
return nil
}
return headers.Clone()
}
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
if headers != nil {
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
contentType = strings.TrimSpace(headers.Get("content-type"))
if contentType == "" {
contentType = strings.TrimSpace(headers.Get("Content-Type"))
}
}
rayID = soraerror.ExtractCloudflareRayID(headers, body)
return rayID, mitigated, contentType
}
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
message = strings.TrimSpace(message)
if message == "" {
return false
}
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
lower := strings.ToLower(message)
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
return false
}
}
return true
}
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
flusher, ok := c.Writer.(http.Flusher)
if ok {
errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
}
h.errorResponse(c, status, errType, message)
}
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// MediaProxy serves local Sora media files.
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
h.proxySoraMedia(c, false)
}
// MediaProxySigned serves local Sora media files with signature verification.
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
h.proxySoraMedia(c, true)
}
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
rawPath := c.Param("filepath")
if rawPath == "" {
c.Status(http.StatusNotFound)
return
}
cleaned := path.Clean(rawPath)
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
c.Status(http.StatusNotFound)
return
}
query := c.Request.URL.Query()
if requireSignature {
if h.soraMediaSigningKey == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 媒体签名未配置",
},
})
return
}
expiresStr := strings.TrimSpace(query.Get("expires"))
signature := strings.TrimSpace(query.Get("sig"))
expires, err := strconv.ParseInt(expiresStr, 10, 64)
if err != nil || expires <= time.Now().Unix() {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名已过期",
},
})
return
}
query.Del("sig")
query.Del("expires")
signingQuery := query.Encode()
if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名无效",
},
})
return
}
}
if strings.TrimSpace(h.soraMediaRoot) == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 媒体目录未配置",
},
})
return
}
relative := strings.TrimPrefix(cleaned, "/")
localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
if _, err := os.Stat(localPath); err != nil {
if os.IsNotExist(err) {
c.Status(http.StatusNotFound)
return
}
c.Status(http.StatusInternalServerError)
return
}
c.File(localPath)
}

View File

@@ -1,728 +0,0 @@
//go:build unit
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/testutil"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// 编译期接口断言
var _ service.SoraClient = (*stubSoraClient)(nil)
var _ service.AccountRepository = (*stubAccountRepo)(nil)
var _ service.GroupRepository = (*stubGroupRepo)(nil)
var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
type stubSoraClient struct {
imageURLs []string
}
func (s *stubSoraClient) Enabled() bool { return true }
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
return "upload", nil
}
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
return &service.SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
return nil
}
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
return nil
}
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
return "s_post", nil
}
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
return nil
}
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
type stubAccountRepo struct {
accounts map[int64]*service.Account
}
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
if acc, ok := r.accounts[id]; ok {
return acc, nil
}
return nil, service.ErrAccountNotFound
}
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
var result []*service.Account
for _, id := range ids {
if acc, ok := r.accounts[id]; ok {
result = append(result, acc)
}
}
return result, nil
}
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
_, ok := r.accounts[id]
return ok, nil
}
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return map[string]int64{}, nil
}
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
var result []service.Account
for _, acc := range r.accounts {
for _, platform := range platforms {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
break
}
}
}
return result, nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
func (r *stubAccountRepo) listSchedulable() []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
type stubGroupRepo struct {
group *service.Group
}
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, nil
}
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (r *stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
return nil
}
type stubUsageLogRepo struct{}
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
return true, nil
}
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
return nil, nil
}
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
return nil, nil
}
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
RunMode: config.RunModeSimple,
Gateway: config.GatewayConfig{
SoraStreamMode: "force",
MaxAccountSwitches: 1,
Scheduling: config.GatewaySchedulingConfig{
LoadBatchEnabled: false,
},
},
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.test",
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
groupRepo := &stubGroupRepo{group: group}
usageLogRepo := &stubUsageLogRepo{}
deferredService := service.NewDeferredService(accountRepo, nil, 0)
billingService := service.NewBillingService(cfg, nil)
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
t.Cleanup(func() {
billingCacheService.Stop()
})
gatewayService := service.NewGatewayService(
accountRepo,
groupRepo,
usageLogRepo,
nil, // usageBillingRepo
nil, // userRepo
nil, // userSubRepo
nil, // userGroupRateRepo
testutil.StubGatewayCache{},
cfg,
nil, // schedulerSnapshot
concurrencyService,
billingService,
nil, // rateLimitService
billingCacheService,
nil, // identityService
nil, // httpUpstream
deferredService,
nil, // claudeTokenProvider
testutil.StubSessionLimitCache{},
nil, // rpmCache
nil, // digestStore
nil, // settingService
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
apiKey := &service.APIKey{
ID: 1,
UserID: 1,
Status: service.StatusActive,
GroupID: &group.ID,
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
handler.ChatCompletions(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp["media_url"])
}
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
func TestSoraHandler_StreamForcing(t *testing.T) {
// 测试 1stream=false 时 sjson 强制修改为 true
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
clientStream := gjson.GetBytes(body, "stream").Bool()
require.False(t, clientStream)
newBody, err := sjson.SetBytes(body, "stream", true)
require.NoError(t, err)
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
// 测试 2stream=true 时不修改
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
require.True(t, gjson.GetBytes(body2, "stream").Bool())
// 测试 3无 stream 字段时 gjson 返回 false零值
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
require.False(t, gjson.GetBytes(body3, "stream").Bool())
}
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
func TestSoraHandler_ValidationExtraction(t *testing.T) {
// model 缺失
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
modelResult := gjson.GetBytes(body, "model")
require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "")
// model 为数字 → 类型不是 gjson.String应被拒绝
body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`)
modelResult1b := gjson.GetBytes(body1b, "model")
require.True(t, modelResult1b.Exists())
require.NotEqual(t, gjson.String, modelResult1b.Type)
// messages 缺失
body2 := []byte(`{"model":"sora"}`)
require.False(t, gjson.GetBytes(body2, "messages").IsArray())
// messages 不是 JSON 数组(字符串)
body3 := []byte(`{"model":"sora","messages":"not array"}`)
require.False(t, gjson.GetBytes(body3, "messages").IsArray())
// messages 是对象而非数组 → IsArray 返回 false
body4 := []byte(`{"model":"sora","messages":{}}`)
require.False(t, gjson.GetBytes(body4, "messages").IsArray())
// messages 是空数组 → IsArray 为 true 但 len==0应被拒绝
body5 := []byte(`{"model":"sora","messages":[]}`)
msgsResult := gjson.GetBytes(body5, "messages")
require.True(t, msgsResult.IsArray())
require.Equal(t, 0, len(msgsResult.Array()))
// 非法 JSON 被 gjson.ValidBytes 拦截
require.False(t, gjson.ValidBytes([]byte(`{invalid`)))
}
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
gin.SetMode(gin.TestMode)
// 从 body 提取 prompt_cache_key
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
hash := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash)
// 无 prompt_cache_key 且无 header → 空 hash
body2 := []byte(`{"model":"sora"}`)
hash2 := generateOpenAISessionHash(c, body2)
require.Empty(t, hash2)
// header 优先于 body
c.Request.Header.Set("session_id", "from-header")
hash3 := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash3)
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
}
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
tests := []struct {
name string
errType string
message string
}{
{
name: "包含双引号",
errType: "upstream_error",
message: `upstream returned "invalid" payload`,
},
{
name: "包含换行和制表符",
errType: "rate_limit_error",
message: "line1\nline2\ttab",
},
{
name: "包含反斜杠",
errType: "upstream_error",
message: `path C:\Users\test\file.txt not found`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
require.Equal(t, "event: error", lines[0])
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok, "JSON 中应包含 error 对象")
require.Equal(t, tt.errType, errorObj["type"])
require.Equal(t, tt.message, errorObj["message"])
})
}
}
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"))
require.True(t, strings.HasSuffix(body, "\n\n"))
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
}
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare challenge")
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
}
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "rate_limit_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare shield")
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
}
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
headers := http.Header{}
headers.Set("cf-mitigated", "challenge")
headers.Set("content-type", "text/html")
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
require.Equal(t, "9cff2d62d83bb98d", rayID)
require.Equal(t, "challenge", mitigated)
require.Equal(t, "text/html", contentType)
}

View File

@@ -34,7 +34,6 @@ func ProvideAdminHandlers(
scheduledTestHandler *admin.ScheduledTestHandler,
channelHandler *admin.ChannelHandler,
paymentHandler *admin.PaymentHandler,
soraHandler *admin.SoraHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -62,7 +61,6 @@ func ProvideAdminHandlers(
ScheduledTest: scheduledTestHandler,
Channel: channelHandler,
Payment: paymentHandler,
Sora: soraHandler,
}
}
@@ -88,8 +86,6 @@ func ProvideHandlers(
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler, // 从本地版本合并
soraClientHandler *SoraClientHandler, // 从本地版本合并
settingHandler *SettingHandler,
totpHandler *TotpHandler,
paymentHandler *PaymentHandler,
@@ -108,8 +104,6 @@ func ProvideHandlers(
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler, // 从本地版本合并
SoraClient: soraClientHandler, // 从本地版本合并
Setting: settingHandler,
Totp: totpHandler,
Payment: paymentHandler,
@@ -129,8 +123,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewSoraGatewayHandler, // 从本地版本合并
NewSoraClientHandler, // 从本地版本合并
NewTotpHandler,
ProvideSettingHandler,
NewPaymentHandler,
@@ -162,7 +154,6 @@ var ProviderSet = wire.NewSet(
admin.NewScheduledTestHandler,
admin.NewChannelHandler,
admin.NewPaymentHandler,
admin.NewSoraHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,