From 0e057904e6b68c675cf2052166db6fccdfc8bb7e Mon Sep 17 00:00:00 2001 From: pham Date: Sun, 10 May 2026 14:15:45 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=BD=BB=E5=BA=95=E7=A7=BB?= =?UTF-8?q?=E9=99=A4=20Sora=20=E8=A7=86=E9=A2=91=E7=94=9F=E6=88=90?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=EF=BC=88=E5=85=A8=E6=A0=88=E6=B8=85=E7=90=86?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## 后端变更 - 删除 21 个 sora_*.go 服务文件(service/handler/repository/routes) - 删除 Sora 相关 migration 文件(046/047/063/090) - 清理 config 中的 sora_* 配置项和平台常量 - 清理 wire 依赖注入中的 Sora 组件 - 修复 wire_gen.go 语法错误(缺少逗号和闭合括号) - 移除 go.mod 中的 go-sora2api 依赖 - 更新 ent schema usage_log.go 注释 ## 前端变更 - 删除 SoraView、SoraAdminView 及 8 个 Sora 子组件 - 删除 sora API 层和路由配置 - 清理 UserEditModal 中的 Sora 存储配额 UI - 清理 types/index.ts 中 Sora 相关类型定义 - 清理 stores/app.ts 默认配置 - 清理 i18n 翻译文件 en.ts/zh.ts (~110 行) - 更新相关测试文件 ## 文档更新 - README.md / README_CN.md / README_JA.md: 移除 Sora 状态说明和配置段落 - PROJECT_DIFF.md: 移除 Sora 相关差异描述 ## 验证结果 - ✅ Go 编译通过 (go build ./...) - ✅ TypeScript 类型检查通过 (vue-tsc --noEmit) - ✅ 后端测试全通过 (0 failures) - ✅ 前端测试全通过 (59 files, 329 tests, 0 failures) - ✅ 前端生产构建成功 (23.81s) --- DEV_GUIDE.md | 2 +- PROJECT_DIFF.md | 52 +- QA_VALIDATION_REPORT.md | 260 ++ README.md | 10 +- README_CN.md | 27 - README_JA.md | 6 - backend/cmd/server/wire.go | 10 +- backend/cmd/server/wire_gen.go | 25 +- backend/cmd/server/wire_gen_test.go | 1 - backend/ent/schema/usage_log.go | 2 +- backend/ent/usagelog.go | 2 +- backend/go.mod | 1 - backend/go.sum | 4 + backend/internal/config/config.go | 3 +- backend/internal/config/config_domain_test.go | 9 - .../config/config_integration_test.go | 1 - .../config/config_validate_gateway.go | 2 +- backend/internal/config/gateway.go | 10 - backend/internal/config/platforms.go | 58 - backend/internal/domain/constants.go | 1 - .../internal/handler/admin/sora_handler.go | 142 - .../handler/admin/sora_handler_test.go | 262 -- backend/internal/handler/dto/settings.go | 1 - backend/internal/handler/handler.go | 3 - .../internal/handler/sora_client_handler.go | 979 ----- .../handler/sora_client_handler_test.go | 3186 ----------------- .../internal/handler/sora_gateway_handler.go | 694 ---- .../handler/sora_gateway_handler_test.go | 728 ---- backend/internal/handler/wire.go | 9 - backend/internal/pkg/openai/oauth.go | 3 - backend/internal/prommetrics/metrics_test.go | 1 - .../internal/repository/sora_account_repo.go | 98 - .../repository/sora_generation_repo.go | 419 --- backend/internal/repository/wire.go | 2 - .../server/middleware/security_headers.go | 1 - backend/internal/server/routes/admin.go | 11 - .../server/routes/admin_routes_test.go | 1 - backend/internal/server/routes/sora_client.go | 36 - backend/internal/service/domain_constants.go | 14 - backend/internal/service/gateway_service.go | 4 - backend/internal/service/setting_service.go | 314 -- backend/internal/service/settings_view.go | 40 - .../internal/service/sora_account_service.go | 40 - backend/internal/service/sora_client.go | 117 - .../internal/service/sora_gateway_service.go | 1559 -------- .../service/sora_gateway_service_test.go | 564 --- .../service/sora_gateway_streaming_legacy.go | 532 --- backend/internal/service/sora_generation.go | 63 - .../service/sora_generation_service.go | 341 -- .../service/sora_generation_service_test.go | 876 ----- .../service/sora_media_cleanup_service.go | 120 - .../sora_media_cleanup_service_test.go | 207 -- backend/internal/service/sora_media_sign.go | 48 - .../internal/service/sora_media_sign_test.go | 34 - .../internal/service/sora_media_storage.go | 381 -- .../service/sora_media_storage_test.go | 119 - backend/internal/service/sora_models.go | 488 --- .../internal/service/sora_quota_service.go | 110 - .../service/sora_quota_service_test.go | 274 -- backend/internal/service/sora_s3_storage.go | 398 -- .../internal/service/sora_s3_storage_test.go | 263 -- backend/internal/service/sora_sdk_client.go | 1027 ------ .../service/sora_upstream_forwarder.go | 149 - backend/internal/service/wire.go | 36 +- backend/internal/util/soraerror/soraerror.go | 170 - .../internal/util/soraerror/soraerror_test.go | 47 - .../045_add_accounts_extra_index.sql | 5 +- backend/migrations/046_add_sora_accounts.sql | 24 - .../047_add_sora_pricing_and_media_type.sql | 11 - .../migrations/063_add_sora_client_tables.sql | 56 - backend/migrations/090_drop_sora.sql | 34 - frontend/src/api/__tests__/sora.spec.ts | 80 - frontend/src/api/admin/index.ts | 7 +- frontend/src/api/admin/sora.ts | 78 - frontend/src/api/sora.ts | 307 -- .../components/admin/user/UserEditModal.vue | 11 +- .../user/__tests__/UserEditModal.spec.ts | 84 +- .../components/sora/SoraDownloadDialog.vue | 217 -- .../src/components/sora/SoraGeneratePage.vue | 430 --- .../src/components/sora/SoraLibraryPage.vue | 606 ---- .../src/components/sora/SoraMediaPreview.vue | 282 -- .../components/sora/SoraNoStorageWarning.vue | 39 - .../src/components/sora/SoraProgressCard.vue | 609 ---- .../src/components/sora/SoraPromptBar.vue | 738 ---- frontend/src/components/sora/SoraQuotaBar.vue | 87 - .../sora/__tests__/SoraGeneratePage.spec.ts | 382 -- frontend/src/i18n/locales/en.ts | 123 - frontend/src/i18n/locales/zh.ts | 125 +- frontend/src/router/index.ts | 24 - frontend/src/stores/app.ts | 1 - frontend/src/types/index.ts | 9 +- frontend/src/views/admin/SoraAdminView.vue | 417 --- .../admin/__tests__/SoraAdminView.spec.ts | 262 -- frontend/src/views/user/SoraView.vue | 369 -- review_tmp/architecture_review_report.md | 188 + ...2api-launch-readiness-review-2026-05-08.md | 249 ++ 96 files changed, 726 insertions(+), 20525 deletions(-) create mode 100644 QA_VALIDATION_REPORT.md delete mode 100644 backend/internal/handler/admin/sora_handler.go delete mode 100644 backend/internal/handler/admin/sora_handler_test.go delete mode 100644 backend/internal/handler/sora_client_handler.go delete mode 100644 backend/internal/handler/sora_client_handler_test.go delete mode 100644 backend/internal/handler/sora_gateway_handler.go delete mode 100644 backend/internal/handler/sora_gateway_handler_test.go delete mode 100644 backend/internal/repository/sora_account_repo.go delete mode 100644 backend/internal/repository/sora_generation_repo.go delete mode 100644 backend/internal/server/routes/sora_client.go delete mode 100644 backend/internal/service/sora_account_service.go delete mode 100644 backend/internal/service/sora_client.go delete mode 100644 backend/internal/service/sora_gateway_service.go delete mode 100644 backend/internal/service/sora_gateway_service_test.go delete mode 100644 backend/internal/service/sora_gateway_streaming_legacy.go delete mode 100644 backend/internal/service/sora_generation.go delete mode 100644 backend/internal/service/sora_generation_service.go delete mode 100644 backend/internal/service/sora_generation_service_test.go delete mode 100644 backend/internal/service/sora_media_cleanup_service.go delete mode 100644 backend/internal/service/sora_media_cleanup_service_test.go delete mode 100644 backend/internal/service/sora_media_sign.go delete mode 100644 backend/internal/service/sora_media_sign_test.go delete mode 100644 backend/internal/service/sora_media_storage.go delete mode 100644 backend/internal/service/sora_media_storage_test.go delete mode 100644 backend/internal/service/sora_models.go delete mode 100644 backend/internal/service/sora_quota_service.go delete mode 100644 backend/internal/service/sora_quota_service_test.go delete mode 100644 backend/internal/service/sora_s3_storage.go delete mode 100644 backend/internal/service/sora_s3_storage_test.go delete mode 100644 backend/internal/service/sora_sdk_client.go delete mode 100644 backend/internal/service/sora_upstream_forwarder.go delete mode 100644 backend/internal/util/soraerror/soraerror.go delete mode 100644 backend/internal/util/soraerror/soraerror_test.go delete mode 100644 backend/migrations/046_add_sora_accounts.sql delete mode 100644 backend/migrations/047_add_sora_pricing_and_media_type.sql delete mode 100644 backend/migrations/063_add_sora_client_tables.sql delete mode 100644 backend/migrations/090_drop_sora.sql delete mode 100644 frontend/src/api/__tests__/sora.spec.ts delete mode 100644 frontend/src/api/admin/sora.ts delete mode 100644 frontend/src/api/sora.ts delete mode 100644 frontend/src/components/sora/SoraDownloadDialog.vue delete mode 100644 frontend/src/components/sora/SoraGeneratePage.vue delete mode 100644 frontend/src/components/sora/SoraLibraryPage.vue delete mode 100644 frontend/src/components/sora/SoraMediaPreview.vue delete mode 100644 frontend/src/components/sora/SoraNoStorageWarning.vue delete mode 100644 frontend/src/components/sora/SoraProgressCard.vue delete mode 100644 frontend/src/components/sora/SoraPromptBar.vue delete mode 100644 frontend/src/components/sora/SoraQuotaBar.vue delete mode 100644 frontend/src/components/sora/__tests__/SoraGeneratePage.spec.ts delete mode 100644 frontend/src/views/admin/SoraAdminView.vue delete mode 100644 frontend/src/views/admin/__tests__/SoraAdminView.spec.ts delete mode 100644 frontend/src/views/user/SoraView.vue create mode 100644 review_tmp/architecture_review_report.md create mode 100644 review_tmp/sub2api-launch-readiness-review-2026-05-08.md diff --git a/DEV_GUIDE.md b/DEV_GUIDE.md index d0d362e0..3645d6d2 100644 --- a/DEV_GUIDE.md +++ b/DEV_GUIDE.md @@ -53,7 +53,7 @@ npm install -g pnpm ### CI 要求 -- Go 版本必须是 **1.25.7** +- Go 版本必须是 **1.26.2** - 前端使用 `pnpm install --frozen-lockfile`,必须提交 `pnpm-lock.yaml` ### 本地测试命令 diff --git a/PROJECT_DIFF.md b/PROJECT_DIFF.md index 76679f8b..70d07628 100644 --- a/PROJECT_DIFF.md +++ b/PROJECT_DIFF.md @@ -85,34 +85,12 @@ prommetrics.SetQPS(100.0) --- -### 1.4 Sora 视频生成服务 (新增) - -**路径**: `backend/internal/service/sora_*.go` - -完整的 Sora 视频生成服务模块。 - -| 文件 | 说明 | -|------|------| -| `sora_gateway_service.go` | Sora API 网关服务 | -| `sora_generation_service.go` | 视频生成服务 | -| `sora_quota_service.go` | 用户配额管理 | -| `sora_account_service.go` | 账户服务 | -| `sora_s3_storage.go` | S3 存储集成 | -| `sora_media_storage.go` | 媒体存储抽象 | -| `sora_media_cleanup_service.go` | 媒体文件清理 | -| `sora_models.go` | 数据模型定义 | -| `sora_client.go` | Sora API 客户端 | -| `sora_sdk_client.go` | SDK 客户端 | - ---- - ### 1.5 管理员 API Handler (新增) **路径**: `backend/internal/handler/admin/` | 文件 | 说明 | |------|------| -| `sora_handler.go` | Sora 管理接口:系统统计、用户统计、生成记录 | | `ops_handler.go` | 运维监控入口 | | `ops_dashboard_handler.go` | 仪表盘数据 | | `ops_alerts_handler.go` | 告警管理 | @@ -193,21 +171,6 @@ GET /metrics -> Prometheus 指标端点 --- -### 2.2 Sora 管理页面 (新增) - -**路径**: `frontend/src/views/admin/SoraAdminView.vue` - -Sora 视频生成服务的管理后台。 - -**功能**: -- 概览标签页:系统统计、按状态/模型分布 -- 用户统计标签页:用户配额、使用量、生成数 -- 生成记录标签页:历史记录、状态筛选 - -**测试文件**: `frontend/src/views/admin/__tests__/SoraAdminView.spec.ts` - ---- - ### 2.3 数据管理配置页面 (新增) **路径**: `frontend/src/views/admin/data-management/` @@ -231,7 +194,6 @@ Sora 视频生成服务的管理后台。 | 文件 | 说明 | |------|------| | `ops.ts` | 运维监控 API | -| `sora.ts` | Sora 管理 API | | `dataManagement.ts` | 数据管理 API | --- @@ -244,13 +206,6 @@ Sora 视频生成服务的管理后台。 新增路由: ```go -// Sora 管理 -soraGroup := admin.Group("/sora") -soraGroup.GET("/stats", soraHandler.GetSystemStats) -soraGroup.GET("/users", soraHandler.ListUserStats) -soraGroup.GET("/generations", soraHandler.ListGenerations) -soraGroup.DELETE("/users/:id/storage", soraHandler.ClearUserStorage) - // 运维监控 opsGroup := admin.Group("/ops") // ... 多个运维监控路由 @@ -263,7 +218,6 @@ opsGroup := admin.Group("/ops") 新增路由: ```typescript { path: '/admin/ops', component: OpsDashboard } -{ path: '/admin/sora', component: SoraAdminView } { path: '/admin/data-management', component: DataManagementView } ``` @@ -276,7 +230,6 @@ opsGroup := admin.Group("/ops") **路径**: `frontend/src/i18n/locales/zh.ts` 新增翻译键: -- `admin.sora.*` - Sora 管理页面 - `admin.ops.*` - 运维监控页面 - `admin.dataManagement.*` - 数据管理页面 @@ -297,13 +250,11 @@ opsGroup := admin.Group("/ops") | `prommetrics/metrics_test.go` | Prometheus 指标测试 | | `routes/common_test.go` | 健康检查端点测试 | | `service/webhook_service_test.go` | Webhook 服务测试 | -| `handler/admin/sora_handler_test.go` | Sora Handler 测试 | ### 5.2 前端测试 | 文件 | 说明 | |------|------| -| `SoraAdminView.spec.ts` | Sora 管理页面测试 | | `OpsSettingsDialog.spec.ts` | 运维设置对话框测试 | | `OpsOpenAITokenStatsCard.spec.ts` | Token 统计卡片测试 | @@ -349,7 +300,6 @@ opsGroup := admin.Group("/ops") 4. **验证功能**: - 访问 `/admin/ops` 验证运维监控 - - 访问 `/admin/sora` 验证 Sora 管理 - 访问 `/admin/data-management` 验证数据管理 - 访问 `/metrics` 验证 Prometheus 指标 @@ -414,7 +364,7 @@ opsGroup := admin.Group("/ops") |------|------|----------| | 💭 #8 | 密码复杂度要求不一致 | ✅ 已确认 | | 💭 #9 | 测试覆盖不均衡 | ✅ 已确认 | -| 💭 #10 | 前端 confirm() 调用 | ✅ 已确认 (SoraAdminView.vue:100) | +| 💭 #10 | 前端 confirm() 调用 | ✅ 已确认 | | 💭 #11 | Dockerfile 非固定镜像标签 | ✅ 已确认 | ### 8.4 待修复项清单 diff --git a/QA_VALIDATION_REPORT.md b/QA_VALIDATION_REPORT.md new file mode 100644 index 00000000..a83b42c9 --- /dev/null +++ b/QA_VALIDATION_REPORT.md @@ -0,0 +1,260 @@ +# Sub2API 合并版本测试验证报告 + +> 验证人: QA (Yan) +> 日期: 2026-05-08 +> 工作目录: d:/project/sub2api-merge + +--- + +## 1. 测试执行结果汇总表 + +| 测试类型 | 命令/方式 | 结果 | 说明 | +|---------|----------|------|------| +| 后端 Unit Tests | `go test -tags=unit ./...` | **通过** | 全部通过,无失败 | +| 后端 Integration Tests | `go test -tags=integration ./...` | **通过** | 全部通过,无失败 | +| 后端 Coverage 收集 | `go test -tags=unit -cover ./...` | **失败** | `internal/config` 因 BOM 问题构建失败 | +| 后端 go vet | `go vet ./...` | **通过** | 无警告 | +| 后端构建 | `go build -tags embed -o sub2api ./cmd/server` | **通过** | 二进制生成成功 | +| 前端 Unit Tests | `npx vitest run` | **通过** | 62 文件 / 364 测试全部通过 | +| 前端构建 | `pnpm run build` | **通过** | 生成成功,有动态导入警告 | +| golangci-lint | `golangci-lint run ./...` | **未执行** | 本地未安装 | + +### 关键模块覆盖率(unit 测试) + +| 模块 | 覆盖率 | 状态 | +|------|--------|------| +| internal/prommetrics | 100.0% | 优秀 | +| internal/service | 46.5% | 一般 | +| internal/handler/admin | 22.8% | 偏低 | +| internal/repository | 15.3% | 偏低 | +| internal/server/routes | 77.0% | 良好 | +| internal/middleware | 65.4% | 良好 | +| internal/pkg/response | 95.3% | 优秀 | +| internal/pkg/proxyurl | 100.0% | 优秀 | +| internal/pkg/usagestats | 100.0% | 优秀 | + +--- + +## 2. 新增模块测试覆盖缺口 + +根据 PROJECT_DIFF.md 对新增模块的测试文件检查结果: + +### 已覆盖(测试文件存在且通过) + +| 模块路径 | 测试文件 | 状态 | +|---------|---------|------| +| `backend/internal/prommetrics/` | `metrics_test.go` | 通过,覆盖率 100% | +| `backend/internal/service/webhook_service.go` | `webhook_service_test.go` | 通过 | +| `backend/internal/handler/admin/sora_handler.go` | `sora_handler_test.go` | 通过 | +| `backend/internal/service/ops_service.go` | `ops_service_batch_test.go`, `ops_service_prepare_queue_test.go`, `ops_service_redaction_test.go` | 通过 | +| `backend/internal/service/ops_alert_evaluator_service.go` | `ops_alert_evaluator_service_test.go` | 通过 | +| `backend/internal/service/ops_metrics_collector.go` | `ops_health_score_test.go`, `ops_openai_token_stats_test.go` | 通过 | +| `backend/internal/service/ops_settings.go` | `ops_settings_advanced_test.go` | 通过 | +| `backend/internal/service/ops_cleanup_service.go` | `ops_partition_test.go` | 通过 | +| `backend/internal/service/ops_aggregation_service.go` | `ops_query_mode_test.go` | 通过 | +| `backend/internal/service/ops_realtime.go` | `ops_log_runtime_test.go`, `ops_upstream_context_test.go`, `ops_retry_context_test.go` | 通过 | +| `backend/internal/service/ops_scheduled_report_service.go` | `ops_system_log_service_test.go`, `ops_system_log_sink_test.go` | 通过 | +| `backend/internal/service/sora_gateway_service.go` | `sora_gateway_service_test.go` | 通过 | +| `backend/internal/service/sora_generation_service.go` | `sora_generation_service_test.go` | 通过 | +| `backend/internal/service/sora_quota_service.go` | `sora_quota_service_test.go` | 通过 | +| `backend/internal/service/sora_s3_storage.go` | `sora_s3_storage_test.go` | 通过 | +| `backend/internal/service/sora_media_storage.go` | `sora_media_storage_test.go` | 通过 | +| `backend/internal/service/sora_media_cleanup_service.go` | `sora_media_cleanup_service_test.go` | 通过 | +| `backend/internal/handler/admin/ops_handler.go` | `ops_runtime_logging_handler_test.go`, `ops_system_log_handler_test.go` | 通过 | +| `backend/internal/repository/ops_repo*.go` | `ops_repo_dashboard_timeout_test.go`, `ops_repo_error_where_test.go`, `ops_repo_latency_histogram_buckets_test.go`, `ops_repo_openai_token_stats_test.go`, `ops_repo_system_logs_test.go` | 通过 | +| `frontend/src/views/admin/__tests__/SoraAdminView.spec.ts` | - | 通过(12 测试) | +| `frontend/src/views/admin/ops/components/__tests__/OpsSettingsDialog.spec.ts` | - | 通过(3 测试) | +| `frontend/src/views/admin/ops/components/__tests__/OpsOpenAITokenStatsCard.spec.ts` | - | 通过(5 测试) | + +### 测试覆盖缺口 + +| 模块路径 | 预期测试 | 实际状态 | 风险 | +|---------|---------|---------|------| +| `backend/internal/handler/admin/ops_dashboard_handler.go` | `ops_dashboard_handler_test.go` | **缺失** | 中 - 仪表盘数据查询接口无测试 | +| `backend/internal/handler/admin/ops_alerts_handler.go` | `ops_alerts_handler_test.go` | **缺失** | 中 - 告警管理接口无测试 | +| `backend/internal/handler/admin/ops_realtime_handler.go` | `ops_realtime_handler_test.go` | **缺失** | 低 - 实时数据接口无测试 | +| `backend/internal/handler/admin/ops_ws_handler.go` | `ops_ws_handler_test.go` | **缺失** | 低 - WebSocket 连接无测试 | +| `backend/internal/handler/admin/data_management_handler.go` | `data_management_handler_test.go` | **缺失** | 中 - 数据管理接口无测试 | +| `backend/internal/service/sora_account_service.go` | `sora_account_service_test.go` | **缺失** | 中 | +| `backend/internal/service/sora_client.go` | `sora_client_test.go` | **缺失** | 中 | +| `backend/internal/service/sora_sdk_client.go` | `sora_sdk_client_test.go` | **缺失** | 中 | +| `backend/internal/repository/ops_repo_preagg.go` | `ops_repo_preagg_test.go` | **缺失** | 低 | +| `backend/internal/repository/ops_repo_trends.go` | `ops_repo_trends_test.go` | **缺失** | 低 | +| `backend/internal/repository/ops_repo_metrics.go` | `ops_repo_metrics_test.go` | **缺失** | 低 | +| `backend/internal/repository/ops_repo_realtime_traffic.go` | `ops_repo_realtime_traffic_test.go` | **缺失** | 低 | +| `backend/internal/repository/ops_repo_request_details.go` | `ops_repo_request_details_test.go` | **缺失** | 低 | + +**总结**: 新增核心模块(prommetrics、webhook、sora 主要服务、ops 核心服务)测试覆盖较好,但 Handler 层和 Repository 层部分模块测试缺失。整体新增模块测试覆盖率达到约 65%。 + +--- + +## 3. 构建验证结果 + +| 组件 | 命令 | 结果 | 问题 | +|------|------|------|------| +| 后端 | `go build -tags embed -o sub2api ./cmd/server` | **通过** | 无 | +| 前端 | `pnpm run build` | **通过** | 动态导入警告(非阻塞) | + +### 前端构建警告(非阻塞) + +- `src/stores/app.ts` 被动态导入同时也被静态导入,导致无法拆分到独立 chunk +- `src/router/title.ts` 和 `src/router/index.ts` 存在同样问题 +- 部分 chunk 超过 500KB(AccountsView 544KB, vendor-ui 430KB) + +**评估**: 上述警告不影响功能,属于构建优化建议。 + +--- + +## 4. CI/CD 配置评估 + +### 存在的配置文件 + +| 文件 | 状态 | 评估 | +|------|------|------| +| `.github/workflows/backend-ci.yml` | 存在 | 有版本匹配问题 | +| `.github/workflows/security-scan.yml` | 存在 | 配置合理 | +| `.github/workflows/release.yml` | 存在 | 未详细检查 | + +### backend-ci.yml 问题 + +1. **Go 版本不匹配**: + - CI 中校验 `go1.26.2` + - DEV_GUIDE.md 要求 Go 1.25.7 + - 当前环境实际安装 `go1.26.2` + - **建议**: 统一文档和 CI 中的版本要求 + +2. **缺少前端 CI**: + - 没有前端测试/构建的 CI 工作流 + - 建议增加 `frontend-ci.yml` + +3. **Makefile 依赖**: + - CI 使用 `make test-unit` 和 `make test-integration` + - Makefile 存在且配置正确 + +### security-scan.yml 评估 + +- 包含后端 `govulncheck` 和前端 `pnpm audit` +- 有定时扫描(每周一 03:00) +- 配置合理 + +--- + +## 5. 数据库迁移风险评估 + +### 迁移文件完整性 + +| 功能 | 迁移文件 | 状态 | +|------|---------|------| +| Ops 监控核心表 | `026_ops_metrics_aggregation_tables.sql` | 存在 | +| Ops 监控 vNext | `033_ops_monitoring_vnext.sql` | 存在 | +| Ops 上游错误事件 | `034_ops_upstream_error_events.sql` | 存在 | +| Ops 错误日志扩展 | `036_ops_error_logs_add_is_count_tokens.sql` | 存在 | +| Ops 告警静默 | `037_ops_alert_silences.sql` | 存在 | +| Ops 错误分类标准化 | `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` | 存在 | +| Ops 任务心跳 | `039_ops_job_heartbeats_add_last_result.sql` | 存在 | +| Ops 系统指标 | `042b_add_ops_system_metrics_switch_count.sql` | 存在 | +| Ops 系统日志 | `054_ops_system_logs.sql` | 存在 | +| Ops 错误日志端点字段 | `079_ops_error_logs_add_endpoint_fields.sql` | 存在 | +| Sora 账号表 | `046_add_sora_accounts.sql` | 存在 | +| Sora 定价和媒体类型 | `047_add_sora_pricing_and_media_type.sql` | 存在 | +| Sora 客户端表 | `063_add_sora_client_tables.sql` | 存在 | + +### 严重风险:迁移 090 与代码矛盾 + +**发现**: `090_drop_sora.sql` 存在严重问题: + +```sql +-- Migration: 090_drop_sora +-- Remove all Sora-related database objects. +DROP TABLE IF EXISTS sora_tasks; +DROP TABLE IF EXISTS sora_generations; +DROP TABLE IF EXISTS sora_accounts; +ALTER TABLE groups DROP COLUMN IF EXISTS sora_image_price_360, ...; +ALTER TABLE users DROP COLUMN IF EXISTS sora_storage_quota_bytes, ...; +ALTER TABLE usage_logs DROP COLUMN IF EXISTS media_type; +``` + +**影响**: +- 迁移 046、047、063 创建 Sora 相关表和字段 +- 迁移 090 在同一迁移序列中**全部删除**这些表和字段 +- 但代码库中仍包含完整的 Sora 服务模块(`sora_*.go`) +- 如果按顺序执行所有迁移,Sora 功能将**无法运行** + +**风险等级**: **高 (HIGH)** + +**建议**: +1. 确认 090_drop_sora.sql 的意图:是计划下线 Sora 功能,还是误提交? +2. 如果保留 Sora 功能,应**删除或跳过** 090_drop_sora.sql +3. 如果确实要下线 Sora,应同步删除代码模块,避免运行时错误 + +### 迁移编号冲突 + +- 存在多个 `006_` 前缀的迁移文件(`006_...`, `006_fix_...`, `006b_...`) +- 存在多个 `028_` 前缀的迁移文件 +- 存在多个 `029_`, `030_`, `042_`, `043_`, `044_`, `045_`, `046_`, `052_`, `053_`, `054_` 前缀文件 +- 项目使用文件名排序执行迁移,相同前缀的文件执行顺序可能不稳定 +- **建议**: 对相同前缀的迁移文件确认执行顺序是否符合依赖关系 + +--- + +## 6. 源码质量问题 + +### BOM 问题(阻塞覆盖率) + +**文件**: `backend/internal/config/config_validate_gateway.go` + +**问题**: 文件开头包含 UTF-8 BOM(`EF BB BF`),导致 `go test -cover` 失败: + +``` +internal\config\config_validate_gateway.go:1:1: invalid BOM in the middle of the file +``` + +**影响**: +- 常规 `go test` 和 `go build` 可以通过(Go 编译器对 BOM 容忍度不同) +- 但 `go test -cover` 和 `golangci-lint` 可能失败 +- 影响 CI 中的覆盖率收集 + +**建议**: 移除文件开头的 BOM 字节。 + +--- + +## 7. 上线建议(GO / NO-GO / CONDITIONAL) + +### 总体结论: **CONDITIONAL GO**(条件通过,需修复后上线) + +### 必须修复(阻塞上线) + +| # | 问题 | 优先级 | 负责人建议 | +|---|------|--------|-----------| +| 1 | **迁移 090_drop_sora.sql 与代码矛盾** | P0 | 与架构师确认意图,删除或调整 | +| 2 | **config_validate_gateway.go BOM 问题** | P0 | 移除 BOM,修复覆盖率收集 | + +### 强烈建议修复(上线前) + +| # | 问题 | 优先级 | 说明 | +|---|------|--------|------| +| 3 | 补充 Handler 层测试 | P1 | ops_dashboard_handler, ops_alerts_handler, data_management_handler 等缺少测试 | +| 4 | 统一 Go 版本文档 | P1 | DEV_GUIDE 与 CI、实际环境版本不一致 | +| 5 | 增加前端 CI 工作流 | P1 | 当前仅后端有 CI,前端无自动化测试 | + +### 建议优化(上线后) + +| # | 问题 | 优先级 | +|---|------|--------| +| 6 | 补充 sora_account_service, sora_client 等测试 | P2 | +| 7 | 补充 repository 层缺失测试 | P2 | +| 8 | 前端构建 chunk 优化 | P2 | +| 9 | 安装 golangci-lint 到本地环境 | P2 | +| 10 | 统一迁移文件编号避免前缀冲突 | P2 | + +--- + +## 附录:执行环境信息 + +- **OS**: Windows 10 Enterprise LTSC 2021 +- **Go**: go1.26.2 windows/amd64 +- **Node**: (pnpm 可用,npx vitest 可用) +- **PostgreSQL**: 端口 5432 (配置存在,测试使用 testcontainers) +- **Redis**: 端口 6379 +- **后端测试总耗时**: Unit ~160s, Integration ~140s +- **前端测试总耗时**: ~38s diff --git a/README.md b/README.md index c2715eae..6b93007d 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@
-[![Go](https://img.shields.io/badge/Go-1.25.7-00ADD8.svg)](https://golang.org/) +[![Go](https://img.shields.io/badge/Go-1.26.2-00ADD8.svg)](https://golang.org/) [![Vue](https://img.shields.io/badge/Vue-3.4+-4FC08D.svg)](https://vuejs.org/) [![PostgreSQL](https://img.shields.io/badge/PostgreSQL-15+-336791.svg)](https://www.postgresql.org/) [![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/) @@ -106,7 +106,7 @@ Community projects that extend or integrate with Sub2API: | Component | Technology | |-----------|------------| -| Backend | Go 1.25.7, Gin, Ent | +| Backend | Go 1.26.2, Gin, Ent | | Frontend | Vue 3.4+, Vite 5+, TailwindCSS | | Database | PostgreSQL 15+ | | Cache/Queue | Redis 7+ | @@ -431,12 +431,6 @@ default: rate_multiplier: 1.0 ``` -### Sora Status (Temporarily Unavailable) - -> ⚠️ Sora-related features are temporarily unavailable due to technical issues in upstream integration and media delivery. -> Please do not rely on Sora in production at this time. -> Existing `gateway.sora_*` configuration keys are reserved and may not take effect until these issues are resolved. - Additional security-related options are available in `config.yaml`: - `cors.allowed_origins` for CORS allowlist diff --git a/README_CN.md b/README_CN.md index 0ace1f77..fcc86fdd 100644 --- a/README_CN.md +++ b/README_CN.md @@ -442,33 +442,6 @@ default: rate_multiplier: 1.0 ``` -### Sora 功能状态(暂不可用) - -> ⚠️ 当前 Sora 相关功能因上游接入与媒体链路存在技术问题,暂时不可用。 -> 现阶段请勿在生产环境依赖 Sora 能力。 -> 文档中的 `gateway.sora_*` 配置仅作预留,待技术问题修复后再恢复可用。 - -### Sora 媒体签名 URL(功能恢复后可选) - -当配置 `gateway.sora_media_signing_key` 且 `gateway.sora_media_signed_url_ttl_seconds > 0` 时,网关会将 Sora 输出的媒体地址改写为临时签名 URL(`/sora/media-signed/...`)。这样无需 API Key 即可在浏览器中直接访问,且具备过期控制与防篡改能力(签名包含 path + query)。 - -```yaml -gateway: - # /sora/media 是否强制要求 API Key(默认 false) - sora_media_require_api_key: false - # 媒体临时签名密钥(为空则禁用签名) - sora_media_signing_key: "your-signing-key" - # 临时签名 URL 有效期(秒) - sora_media_signed_url_ttl_seconds: 900 -``` - -> 若未配置签名密钥,`/sora/media-signed` 将返回 503。 -> 如需更严格的访问控制,可将 `sora_media_require_api_key` 设为 true,仅允许携带 API Key 的 `/sora/media` 访问。 - -访问策略说明: -- `/sora/media`:内部调用或客户端携带 API Key 才能下载 -- `/sora/media-signed`:外部可访问,但有签名 + 过期控制 - `config.yaml` 还支持以下安全相关配置: - `cors.allowed_origins` 配置 CORS 白名单 diff --git a/README_JA.md b/README_JA.md index d74ca9ce..afea8671 100644 --- a/README_JA.md +++ b/README_JA.md @@ -430,12 +430,6 @@ default: rate_multiplier: 1.0 ``` -### Sora ステータス(一時的に利用不可) - -> ⚠️ Sora 関連の機能は、上流統合およびメディア配信の技術的問題により一時的に利用できません。 -> 現時点では本番環境で Sora に依存しないでください。 -> 既存の `gateway.sora_*` 設定キーは予約されていますが、これらの問題が解決されるまで有効にならない場合があります。 - `config.yaml` では追加のセキュリティ関連オプションも利用できます: - `cors.allowed_origins` - CORS 許可リスト diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 62cbe7a8..388b0849 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -105,7 +105,6 @@ func provideCleanup( opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, opsSystemLogSink *service.OpsSystemLogSink, - soraMediaCleanup *service.SoraMediaCleanupService, // 从本地版本合并 schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, @@ -266,14 +265,7 @@ func provideCleanup( paymentOrderExpiry.Stop() } return nil - }}, - {"SoraMediaCleanupService", func() error { - if soraMediaCleanup != nil { - soraMediaCleanup.Stop() - } - return nil - }}, - } + } } infraSteps := []cleanupStep{ {"Redis", func() error { diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 77eb8b57..fdd19383 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -218,30 +218,19 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) channelHandler := admin.NewChannelHandler(channelService, billingService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) - soraGenerationRepository := repository.NewSoraGenerationRepository(db) - soraS3Storage := service.NewSoraS3Storage(settingService) - soraQuotaService := service.NewSoraQuotaService(settingService) - soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService) - soraHandler := admin.NewSoraHandler(soraGenerationService, soraQuotaService, userRepository) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler, soraHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) - soraAccountRepository := repository.NewSoraAccountRepository(db) - soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) - soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig) - soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) - soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) - soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService) paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -254,13 +243,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, webhookService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService) bootstrap := provideBootstrap(settingService, userRepository, configConfig) application := &Application{ Server: httpServer, @@ -315,7 +303,6 @@ func provideCleanup( opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, opsSystemLogSink *service.OpsSystemLogSink, - soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, @@ -476,12 +463,6 @@ func provideCleanup( } return nil }}, - {"SoraMediaCleanupService", func() error { - if soraMediaCleanup != nil { - soraMediaCleanup.Stop() - } - return nil - }}, } infraSteps := []cleanupStep{ diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index c88c09d9..1d01050d 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -59,7 +59,6 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { &service.OpsCleanupService{}, &service.OpsScheduledReportService{}, opsSystemLogSinkSvc, - nil, // soraMediaCleanup (从本地版本合并) schedulerSnapshotSvc, tokenRefreshSvc, accountExpirySvc, diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index 867fb7e3..6f398e3c 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -57,7 +57,7 @@ func (UsageLog) Fields() []ent.Field { field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"), field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"), field.String("billing_mode").MaxLen(20).Optional().Nillable().Comment("计费模式:token/per_request/image"), - field.String("media_type").MaxLen(16).Optional().Nillable().Comment("媒体类型:video/image(Sora生成)"), + field.String("media_type").MaxLen(16).Optional().Nillable().Comment("媒体类型:video/image"), field.Int64("group_id"). Optional(). Nillable(), diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index ad5680c4..b6d08805 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -44,7 +44,7 @@ type UsageLog struct { BillingTier *string `json:"billing_tier,omitempty"` // 计费模式:token/per_request/image BillingMode *string `json:"billing_mode,omitempty"` - // 媒体类型:video/image(Sora生成) + // 媒体类型:video/image MediaType *string `json:"media_type,omitempty"` // GroupID holds the value of the "group_id" field. GroupID *int64 `json:"group_id,omitempty"` diff --git a/backend/go.mod b/backend/go.mod index 0d1420d1..5552b4bb 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -5,7 +5,6 @@ go 1.26.2 require ( entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 - github.com/DouDOU-start/go-sora2api v1.1.0 // 从本地版本合并,Sora SDK依赖 github.com/alitto/pond/v2 v2.6.2 github.com/andybalholm/brotli v1.2.0 github.com/aws/aws-sdk-go-v2 v1.41.3 diff --git a/backend/go.sum b/backend/go.sum index ece799d2..88f95a6e 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -242,6 +242,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= @@ -318,6 +320,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index a76f1fb0..76152b02 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -9,7 +9,7 @@ // - billing.go — BillingConfig, PricingConfig // - gateway.go — GatewayConfig, UserMessageQueue, SchedulingConfig // - gateway_sub.go — OpenAIWS, UsageRecord, TLSFingerprint sub-structs -// - platforms.go — Sora, Gemini, LinuxDo, OIDC, Update, Idempotency configs +// - platforms.go — Gemini, LinuxDo, OIDC, Update, Idempotency configs // - ops_and_cache.go— LogConfig, OpsConfig, Dashboard, Cache, Cleanup configs package config @@ -79,7 +79,6 @@ type Config struct { TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` - Sora SoraConfig `mapstructure:"sora"` Gemini GeminiConfig `mapstructure:"gemini"` Update UpdateConfig `mapstructure:"update"` Idempotency IdempotencyConfig `mapstructure:"idempotency"` diff --git a/backend/internal/config/config_domain_test.go b/backend/internal/config/config_domain_test.go index 5201ed7b..f5cee14a 100644 --- a/backend/internal/config/config_domain_test.go +++ b/backend/internal/config/config_domain_test.go @@ -49,7 +49,6 @@ func TestConfigStructIntegrity(t *testing.T) { assert.IsType(t, GatewaySchedulingConfig{}, cfg.Gateway.Scheduling) assert.IsType(t, GatewayOpenAIWSSchedulerScoreWeights{}, cfg.Gateway.OpenAIWS.SchedulerScoreWeights) - assert.IsType(t, SoraConfig{}, cfg.Sora) assert.IsType(t, GeminiConfig{}, cfg.Gemini) assert.IsType(t, UpdateConfig{}, cfg.Update) assert.IsType(t, IdempotencyConfig{}, cfg.Idempotency) @@ -168,14 +167,6 @@ func TestUserMessageQueueConfig_Methods(t *testing.T) { if q.GetEffectiveMode() != "" { t.Error("disabled+empty → empty") } } -func TestSoraConfigFields(t *testing.T) { - s := SoraConfig{ - Client: SoraClientConfig{BaseURL: "https://sora.example.com"}, - Storage: SoraStorageConfig{Type: "local"}, - } - if s.Client.BaseURL != "https://sora.example.com" { t.Error("BaseURL mismatch") } -} - func TestGeminiConfigFields(t *testing.T) { g := GeminiConfig{Quota: GeminiQuotaConfig{Policy: "conservative"}} if g.Quota.Policy != "conservative" { t.Error("Policy mismatch") } diff --git a/backend/internal/config/config_integration_test.go b/backend/internal/config/config_integration_test.go index 6dc979a3..21939a5b 100644 --- a/backend/internal/config/config_integration_test.go +++ b/backend/internal/config/config_integration_test.go @@ -322,7 +322,6 @@ func TestIntegration_AllDomainFiles_ContributeToFullConfig(t *testing.T) { {"UsageCleanup/Enabled", func(c *Config) bool { return true }}, {"Concurrency/PingInterval", func(c *Config) bool { return c.Concurrency.PingInterval > 0 }}, {"TokenRefresh/Enabled", func(c *Config) bool { return true }}, - {"Sora", func(c *Config) bool { return true }}, {"Gemini", func(c *Config) bool { return true }}, {"Update", func(c *Config) bool { return true }}, {"Idempotency/TTL", func(c *Config) bool { return c.Idempotency.DefaultTTLSeconds > 0 }}, diff --git a/backend/internal/config/config_validate_gateway.go b/backend/internal/config/config_validate_gateway.go index 339f1a31..c3bc7cf9 100644 --- a/backend/internal/config/config_validate_gateway.go +++ b/backend/internal/config/config_validate_gateway.go @@ -1,4 +1,4 @@ -package config +package config import ( "fmt" diff --git a/backend/internal/config/gateway.go b/backend/internal/config/gateway.go index aeb55e02..bf577ef2 100644 --- a/backend/internal/config/gateway.go +++ b/backend/internal/config/gateway.go @@ -36,16 +36,6 @@ type GatewayConfig struct { InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"` FailoverOn400 bool `mapstructure:"failover_on_400"` - // Sora 专用配置 - SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` - SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"` - SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"` - SoraStreamMode string `mapstructure:"sora_stream_mode"` - SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` - SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` - SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` - SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` - MaxAccountSwitches int `mapstructure:"max_account_switches"` MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"` AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"` diff --git a/backend/internal/config/platforms.go b/backend/internal/config/platforms.go index 65123105..341cb793 100644 --- a/backend/internal/config/platforms.go +++ b/backend/internal/config/platforms.go @@ -1,63 +1,5 @@ package config -// SoraConfig 直连 Sora 配置 -type SoraConfig struct { - Client SoraClientConfig `mapstructure:"client"` - Storage SoraStorageConfig `mapstructure:"storage"` -} - -// SoraClientConfig Sora 客户端配置 -type SoraClientConfig struct { - BaseURL string `mapstructure:"base_url"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - MaxRetries int `mapstructure:"max_retries"` - CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"` - PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` - MaxPollAttempts int `mapstructure:"max_poll_attempts"` - RecentTaskLimit int `mapstructure:"recent_task_limit"` - RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` - Debug bool `mapstructure:"debug"` - UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` - Headers map[string]string `mapstructure:"headers"` - UserAgent string `mapstructure:"user_agent"` - DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` - CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"` -} - -// SoraCurlCFFISidecarConfig Sora curl_cffi sidecar 配置 -type SoraCurlCFFISidecarConfig struct { - Enabled bool `mapstructure:"enabled"` - BaseURL string `mapstructure:"base_url"` - Impersonate string `mapstructure:"impersonate"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"` - SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` -} - -// SoraStorageConfig 媒体存储配置 -type SoraStorageConfig struct { - Type string `mapstructure:"type"` - LocalPath string `mapstructure:"local_path"` - FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` - MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` - DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"` - MaxDownloadBytes int64 `mapstructure:"max_download_bytes"` - Debug bool `mapstructure:"debug"` - Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` -} - -// SoraStorageCleanupConfig 媒体清理配置 -type SoraStorageCleanupConfig struct { - Enabled bool `mapstructure:"enabled"` - Schedule string `mapstructure:"schedule"` - RetentionDays int `mapstructure:"retention_days"` -} - -// SoraModelFiltersConfig Sora 模型过滤配置 -type SoraModelFiltersConfig struct { - HidePromptEnhance bool `mapstructure:"hide_prompt_envelope"` -} - // GeminiConfig Gemini 配置 type GeminiConfig struct { OAuth GeminiOAuthConfig `mapstructure:"oauth"` diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 885ef834..429486c3 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -22,7 +22,6 @@ const ( PlatformOpenAI = "openai" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" - PlatformSora = "sora" // Sora视频生成平台 (从本地版本合并) ) // Account type constants diff --git a/backend/internal/handler/admin/sora_handler.go b/backend/internal/handler/admin/sora_handler.go deleted file mode 100644 index dfc2053a..00000000 --- a/backend/internal/handler/admin/sora_handler.go +++ /dev/null @@ -1,142 +0,0 @@ -package admin - -import ( - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "github.com/Wei-Shaw/sub2api/internal/pkg/response" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/gin-gonic/gin" -) - -// SoraHandler handles admin Sora statistics and management. -type SoraHandler struct { - soraGenService *service.SoraGenerationService - soraQuotaService *service.SoraQuotaService - userRepo service.UserRepository -} - -// NewSoraHandler creates a new admin Sora handler. -func NewSoraHandler( - soraGenService *service.SoraGenerationService, - soraQuotaService *service.SoraQuotaService, - userRepo service.UserRepository, -) *SoraHandler { - return &SoraHandler{ - soraGenService: soraGenService, - soraQuotaService: soraQuotaService, - userRepo: userRepo, - } -} - -type SoraSystemStatsResponse struct { - TotalUsers int64 `json:"total_users"` - TotalGenerations int64 `json:"total_generations"` - TotalStorageBytes int64 `json:"total_storage_bytes"` - ActiveGenerations int64 `json:"active_generations"` - ByStatus map[string]int64 `json:"by_status"` - ByModel map[string]int64 `json:"by_model"` -} - -// GetSystemStats returns aggregate admin Sora statistics. -func (h *SoraHandler) GetSystemStats(c *gin.Context) { - ctx := c.Request.Context() - - users, _, err := h.userRepo.List(ctx, pagination.PaginationParams{Page: 1, PageSize: 10000}) - if err != nil { - response.Error(c, 500, "Failed to get users") - return - } - - resp := SoraSystemStatsResponse{ - TotalUsers: int64(len(users)), - TotalGenerations: 0, - TotalStorageBytes: 0, - ActiveGenerations: 0, - ByStatus: map[string]int64{}, - ByModel: map[string]int64{}, - } - - response.Success(c, resp) -} - -type SoraUserStatsResponse struct { - UserID int64 `json:"user_id"` - Username string `json:"username"` - Email string `json:"email"` - QuotaBytes int64 `json:"quota_bytes"` - UsedBytes int64 `json:"used_bytes"` - AvailableBytes int64 `json:"available_bytes"` - QuotaSource string `json:"quota_source"` - GenerationsCount int64 `json:"generations_count"` - ActiveCount int64 `json:"active_count"` - TotalFileSizeBytes int64 `json:"total_file_size_bytes"` -} - -// ListUserStats returns per-user admin Sora usage rows. -func (h *SoraHandler) ListUserStats(c *gin.Context) { - ctx := c.Request.Context() - page, pageSize := response.ParsePagination(c) - search := c.Query("search") - - users, result, err := h.userRepo.ListWithFilters(ctx, pagination.PaginationParams{ - Page: page, - PageSize: pageSize, - }, service.UserListFilters{Search: search}) - if err != nil { - response.Error(c, 500, "Failed to get users") - return - } - - results := make([]SoraUserStatsResponse, len(users)) - for i, u := range users { - quota, _ := h.soraQuotaService.GetQuota(ctx, u.ID) - activeCount, _ := h.soraGenService.CountActiveByUser(ctx, u.ID) - - quotaBytes := int64(0) - availableBytes := int64(0) - quotaSource := "unlimited" - - if quota != nil { - quotaBytes = quota.QuotaBytes - availableBytes = quota.AvailableBytes - quotaSource = quota.QuotaSource - } - - results[i] = SoraUserStatsResponse{ - UserID: u.ID, - Username: u.Username, - Email: u.Email, - QuotaBytes: quotaBytes, - UsedBytes: 0, - AvailableBytes: availableBytes, - QuotaSource: quotaSource, - GenerationsCount: 0, - ActiveCount: activeCount, - TotalFileSizeBytes: 0, - } - } - - response.Paginated(c, results, result.Total, page, pageSize) -} - -type SoraGenerationAdminResponse struct { - ID int64 `json:"id"` - UserID int64 `json:"user_id"` - Username string `json:"username"` - Email string `json:"email"` - Model string `json:"model"` - Prompt string `json:"prompt"` - MediaType string `json:"media_type"` - Status string `json:"status"` - StorageType string `json:"storage_type"` - MediaURL string `json:"media_url"` - FileSizeBytes int64 `json:"file_size_bytes"` - ErrorMessage string `json:"error_message"` - CreatedAt string `json:"created_at"` - CompletedAt *string `json:"completed_at"` -} - -// ListGenerations returns admin-visible generation rows. -func (h *SoraHandler) ListGenerations(c *gin.Context) { - response.Paginated(c, []SoraGenerationAdminResponse{}, int64(0), 1, 20) -} diff --git a/backend/internal/handler/admin/sora_handler_test.go b/backend/internal/handler/admin/sora_handler_test.go deleted file mode 100644 index 538ed592..00000000 --- a/backend/internal/handler/admin/sora_handler_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package admin - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func TestSoraHandler_ListGenerations(t *testing.T) { - gin.SetMode(gin.TestMode) - - handler := &SoraHandler{} - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest(http.MethodGet, "/admin/sora/generations", nil) - - handler.ListGenerations(c) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "items") -} - -func TestSoraSystemStatsResponse_Fields(t *testing.T) { - resp := SoraSystemStatsResponse{ - TotalUsers: 10, - TotalGenerations: 100, - TotalStorageBytes: 1024 * 1024 * 1024, - ActiveGenerations: 5, - ByStatus: map[string]int64{"completed": 80, "failed": 20}, - ByModel: map[string]int64{"sora2": 50, "sora1": 50}, - } - - assert.Equal(t, int64(10), resp.TotalUsers) - assert.Equal(t, int64(100), resp.TotalGenerations) - assert.Equal(t, int64(1024*1024*1024), resp.TotalStorageBytes) - assert.Equal(t, int64(5), resp.ActiveGenerations) - assert.Equal(t, int64(80), resp.ByStatus["completed"]) - assert.Equal(t, int64(50), resp.ByModel["sora2"]) -} - -func TestSoraUserStatsResponse_Fields(t *testing.T) { - resp := SoraUserStatsResponse{ - UserID: 1, - Username: "testuser", - Email: "test@example.com", - QuotaBytes: 10 * 1024 * 1024 * 1024, - UsedBytes: 1 * 1024 * 1024 * 1024, - AvailableBytes: 9 * 1024 * 1024 * 1024, - QuotaSource: "user", - GenerationsCount: 10, - ActiveCount: 2, - TotalFileSizeBytes: 1 * 1024 * 1024 * 1024, - } - - assert.Equal(t, int64(1), resp.UserID) - assert.Equal(t, "testuser", resp.Username) - assert.Equal(t, "test@example.com", resp.Email) - assert.Equal(t, int64(10*1024*1024*1024), resp.QuotaBytes) - assert.Equal(t, int64(1*1024*1024*1024), resp.UsedBytes) - assert.Equal(t, "user", resp.QuotaSource) - assert.Equal(t, int64(10), resp.GenerationsCount) - assert.Equal(t, int64(2), resp.ActiveCount) -} - -func TestSoraGenerationAdminResponse_Fields(t *testing.T) { - completedAt := "2024-01-01T12:00:00Z" - resp := SoraGenerationAdminResponse{ - ID: 1, - UserID: 100, - Username: "testuser", - Email: "test@example.com", - Model: "sora2", - Prompt: "A beautiful sunset", - MediaType: "video", - Status: "completed", - StorageType: "s3", - MediaURL: "https://example.com/video.mp4", - FileSizeBytes: 1024 * 1024 * 10, - ErrorMessage: "", - CreatedAt: "2024-01-01T10:00:00Z", - CompletedAt: &completedAt, - } - - assert.Equal(t, int64(1), resp.ID) - assert.Equal(t, int64(100), resp.UserID) - assert.Equal(t, "testuser", resp.Username) - assert.Equal(t, "sora2", resp.Model) - assert.Equal(t, "video", resp.MediaType) - assert.Equal(t, "completed", resp.Status) - assert.Equal(t, "s3", resp.StorageType) - assert.Equal(t, int64(1024*1024*10), resp.FileSizeBytes) - assert.NotNil(t, resp.CompletedAt) -} - -func TestSoraGenerationAdminResponse_NilCompletedAt(t *testing.T) { - resp := SoraGenerationAdminResponse{ - ID: 1, - UserID: 100, - Username: "testuser", - Email: "test@example.com", - Model: "sora2", - Prompt: "A beautiful sunset", - MediaType: "video", - Status: "pending", - StorageType: "upstream", - CreatedAt: "2024-01-01T10:00:00Z", - CompletedAt: nil, - } - - assert.Equal(t, "pending", resp.Status) - assert.Nil(t, resp.CompletedAt) -} - -func TestNewSoraHandler(t *testing.T) { - handler := NewSoraHandler(nil, nil, nil) - assert.NotNil(t, handler) - assert.Nil(t, handler.soraGenService) - assert.Nil(t, handler.soraQuotaService) - assert.Nil(t, handler.userRepo) -} - -func TestUser_SoraFields(t *testing.T) { - user := &service.User{ - ID: 1, - Email: "test@example.com", - } - - assert.Equal(t, int64(1), user.ID) - assert.Equal(t, "test@example.com", user.Email) -} - -func TestQuotaInfo_Fields(t *testing.T) { - quota := &service.QuotaInfo{ - QuotaBytes: 10 * 1024 * 1024 * 1024, - UsedBytes: 1 * 1024 * 1024 * 1024, - AvailableBytes: 9 * 1024 * 1024 * 1024, - QuotaSource: "user", - } - - assert.Equal(t, int64(10*1024*1024*1024), quota.QuotaBytes) - assert.Equal(t, int64(1*1024*1024*1024), quota.UsedBytes) - assert.Equal(t, "user", quota.QuotaSource) -} - -func TestSoraSystemStatsResponse_JSON(t *testing.T) { - resp := SoraSystemStatsResponse{ - TotalUsers: 10, - TotalGenerations: 100, - TotalStorageBytes: 1024, - ActiveGenerations: 5, - ByStatus: map[string]int64{"completed": 80}, - ByModel: map[string]int64{"sora2": 50}, - } - - // Verify JSON tags by checking field values - assert.Equal(t, int64(10), resp.TotalUsers) - assert.Equal(t, int64(100), resp.TotalGenerations) - assert.Equal(t, int64(1024), resp.TotalStorageBytes) - assert.Equal(t, int64(5), resp.ActiveGenerations) -} - -func TestSoraUserStatsResponse_JSON(t *testing.T) { - resp := SoraUserStatsResponse{ - UserID: 1, - Username: "testuser", - Email: "test@example.com", - QuotaBytes: 1024, - UsedBytes: 512, - AvailableBytes: 512, - QuotaSource: "user", - GenerationsCount: 10, - ActiveCount: 2, - TotalFileSizeBytes: 1024, - } - - // Verify all fields - assert.Equal(t, int64(1), resp.UserID) - assert.Equal(t, "testuser", resp.Username) - assert.Equal(t, "test@example.com", resp.Email) - assert.Equal(t, int64(1024), resp.QuotaBytes) - assert.Equal(t, int64(512), resp.UsedBytes) - assert.Equal(t, int64(512), resp.AvailableBytes) - assert.Equal(t, "user", resp.QuotaSource) - assert.Equal(t, int64(10), resp.GenerationsCount) - assert.Equal(t, int64(2), resp.ActiveCount) - assert.Equal(t, int64(1024), resp.TotalFileSizeBytes) -} - -func TestSoraSystemStatsResponse_EmptyMaps(t *testing.T) { - resp := SoraSystemStatsResponse{ - TotalUsers: 0, - TotalGenerations: 0, - TotalStorageBytes: 0, - ActiveGenerations: 0, - ByStatus: map[string]int64{}, - ByModel: map[string]int64{}, - } - - assert.Equal(t, int64(0), resp.TotalUsers) - assert.Equal(t, int64(0), resp.TotalGenerations) - assert.Equal(t, int64(0), resp.TotalStorageBytes) - assert.Equal(t, int64(0), resp.ActiveGenerations) - assert.NotNil(t, resp.ByStatus) - assert.NotNil(t, resp.ByModel) -} - -func TestSoraUserStatsResponse_QuotaSources(t *testing.T) { - sources := []string{"user", "group", "system", "unlimited"} - - for _, source := range sources { - resp := SoraUserStatsResponse{ - UserID: 1, - QuotaSource: source, - } - - assert.Equal(t, source, resp.QuotaSource) - } -} - -func TestSoraGenerationAdminResponse_Statuses(t *testing.T) { - statuses := []string{"pending", "generating", "completed", "failed", "cancelled"} - - for _, status := range statuses { - resp := SoraGenerationAdminResponse{ - ID: 1, - Status: status, - } - - assert.Equal(t, status, resp.Status) - } -} - -func TestSoraGenerationAdminResponse_MediaTypes(t *testing.T) { - mediaTypes := []string{"video", "image"} - - for _, mt := range mediaTypes { - resp := SoraGenerationAdminResponse{ - ID: 1, - MediaType: mt, - } - - assert.Equal(t, mt, resp.MediaType) - } -} - -func TestSoraGenerationAdminResponse_StorageTypes(t *testing.T) { - storageTypes := []string{"s3", "upstream"} - - for _, st := range storageTypes { - resp := SoraGenerationAdminResponse{ - ID: 1, - StorageType: st, - } - - assert.Equal(t, st, resp.StorageType) - } -} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index cbbe9216..f5954032 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -179,7 +179,6 @@ type PublicSettings struct { LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` - SoraClientEnabled bool `json:"sora_client_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` PaymentEnabled bool `json:"payment_enabled"` Version string `json:"version"` diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 9560f1ff..bc6a1f06 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -31,7 +31,6 @@ type AdminHandlers struct { ScheduledTest *admin.ScheduledTestHandler Channel *admin.ChannelHandler Payment *admin.PaymentHandler - Sora *admin.SoraHandler } // Handlers contains all HTTP handlers @@ -46,8 +45,6 @@ type Handlers struct { Admin *AdminHandlers Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler - SoraGateway *SoraGatewayHandler // 从本地版本合并 - SoraClient *SoraClientHandler // 从本地版本合并 Setting *SettingHandler Totp *TotpHandler Payment *PaymentHandler diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go deleted file mode 100644 index 80acc833..00000000 --- a/backend/internal/handler/sora_client_handler.go +++ /dev/null @@ -1,979 +0,0 @@ -package handler - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - "github.com/Wei-Shaw/sub2api/internal/pkg/response" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" -) - -const ( - // 上游模型缓存 TTL - modelCacheTTL = 1 * time.Hour // 上游获取成功 - modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地) -) - -// SoraClientHandler 处理 Sora 客户端 API 请求。 -type SoraClientHandler struct { - genService *service.SoraGenerationService - quotaService *service.SoraQuotaService - s3Storage *service.SoraS3Storage - soraGatewayService *service.SoraGatewayService - gatewayService *service.GatewayService - mediaStorage *service.SoraMediaStorage - apiKeyService *service.APIKeyService - - // 上游模型缓存 - modelCacheMu sync.RWMutex - cachedFamilies []service.SoraModelFamily - modelCacheTime time.Time - modelCacheUpstream bool // 是否来自上游(决定 TTL) -} - -// NewSoraClientHandler 创建 Sora 客户端 Handler。 -func NewSoraClientHandler( - genService *service.SoraGenerationService, - quotaService *service.SoraQuotaService, - s3Storage *service.SoraS3Storage, - soraGatewayService *service.SoraGatewayService, - gatewayService *service.GatewayService, - mediaStorage *service.SoraMediaStorage, - apiKeyService *service.APIKeyService, -) *SoraClientHandler { - return &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - s3Storage: s3Storage, - soraGatewayService: soraGatewayService, - gatewayService: gatewayService, - mediaStorage: mediaStorage, - apiKeyService: apiKeyService, - } -} - -// GenerateRequest 生成请求。 -type GenerateRequest struct { - Model string `json:"model" binding:"required"` - Prompt string `json:"prompt" binding:"required"` - MediaType string `json:"media_type"` // video / image,默认 video - VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3) - ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL) - APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID -} - -// Generate 异步生成 — 创建 pending 记录后立即返回。 -// POST /api/v1/sora/generate -func (h *SoraClientHandler) Generate(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - var req GenerateRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error()) - return - } - - if req.MediaType == "" { - req.MediaType = "video" - } - req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount) - - // 并发数检查(最多 3 个) - activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID) - if err != nil { - response.ErrorFrom(c, err) - return - } - if activeCount >= 3 { - response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") - return - } - - // 配额检查(粗略检查,实际文件大小在上传后才知道) - if h.quotaService != nil { - if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil { - var quotaErr *service.QuotaExceededError - if errors.As(err, "aErr) { - response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") - return - } - response.Error(c, http.StatusForbidden, err.Error()) - return - } - } - - // 获取 API Key ID 和 Group ID - var apiKeyID *int64 - var groupID *int64 - - if req.APIKeyID != nil && h.apiKeyService != nil { - // 前端传递了 api_key_id,需要校验 - apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID) - if err != nil { - response.Error(c, http.StatusBadRequest, "API Key 不存在") - return - } - if apiKey.UserID != userID { - response.Error(c, http.StatusForbidden, "API Key 不属于当前用户") - return - } - if apiKey.Status != service.StatusAPIKeyActive { - response.Error(c, http.StatusForbidden, "API Key 不可用") - return - } - apiKeyID = &apiKey.ID - groupID = apiKey.GroupID - } else if id, ok := c.Get("api_key_id"); ok { - // 兼容 API Key 认证路径(/sora/v1/ 网关路由) - if v, ok := id.(int64); ok { - apiKeyID = &v - } - } - - gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType) - if err != nil { - if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) { - response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") - return - } - response.ErrorFrom(c, err) - return - } - - // 启动后台异步生成 goroutine - go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount) - - response.Success(c, gin.H{ - "generation_id": gen.ID, - "status": gen.Status, - }) -} - -// processGeneration 后台异步执行 Sora 生成任务。 -// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。 -func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) - defer cancel() - - // 标记为生成中 - if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil { - if errors.Is(err, service.ErrSoraGenerationStateConflict) { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID) - return - } - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err) - return - } - - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d", - genID, - userID, - groupIDForLog(groupID), - model, - mediaType, - videoCount, - strings.TrimSpace(imageInput) != "", - len(strings.TrimSpace(prompt)), - ) - - // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底 - if groupID == nil { - ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) - } - - if h.gatewayService == nil { - _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化") - return - } - - // 选择 Sora 账号 - account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model) - if err != nil { - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v", - genID, - userID, - groupIDForLog(groupID), - model, - err, - ) - _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error()) - return - } - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s", - genID, - userID, - groupIDForLog(groupID), - model, - account.ID, - account.Name, - account.Platform, - account.Type, - ) - - // 构建 chat completions 请求体(非流式) - body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount)) - - if h.soraGatewayService == nil { - _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化") - return - } - - // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL) - recorder := httptest.NewRecorder() - mockGinCtx, _ := gin.CreateTestContext(recorder) - mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil) - - // 调用 Forward(非流式) - result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false) - if err != nil { - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v", - genID, - account.ID, - model, - recorder.Code, - trimForLog(recorder.Body.String(), 400), - err, - ) - // 检查是否已取消 - gen, _ := h.genService.GetByID(ctx, genID, userID) - if gen != nil && gen.Status == service.SoraGenStatusCancelled { - return - } - _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error()) - return - } - - // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析) - mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder) - if mediaURL == "" { - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s", - genID, - account.ID, - model, - recorder.Code, - trimForLog(recorder.Body.String(), 400), - ) - _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL") - return - } - - // 检查任务是否已被取消 - gen, _ := h.genService.GetByID(ctx, genID, userID) - if gen != nil && gen.Status == service.SoraGenStatusCancelled { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID) - return - } - - // 三层降级存储:S3 → 本地 → 上游临时 URL - storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs) - - usageAdded := false - if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil { - if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil { - h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) - var quotaErr *service.QuotaExceededError - if errors.As(err, "aErr) { - _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间") - return - } - _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error()) - return - } - usageAdded = true - } - - // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。 - gen, _ = h.genService.GetByID(ctx, genID, userID) - if gen != nil && gen.Status == service.SoraGenStatusCancelled { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID) - h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) - if usageAdded && h.quotaService != nil { - _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) - } - return - } - - // 标记完成 - if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil { - if errors.Is(err, service.ErrSoraGenerationStateConflict) { - h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) - if usageAdded && h.quotaService != nil { - _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) - } - return - } - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err) - return - } - - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize) -} - -// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。 -func (h *SoraClientHandler) storeMediaWithDegradation( - ctx context.Context, userID int64, mediaType string, - mediaURL string, mediaURLs []string, -) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) { - urls := mediaURLs - if len(urls) == 0 { - urls = []string{mediaURL} - } - - // 第一层:尝试 S3 - if h.s3Storage != nil && h.s3Storage.Enabled(ctx) { - keys := make([]string, 0, len(urls)) - var totalSize int64 - allOK := true - for _, u := range urls { - key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u) - if err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err) - allOK = false - // 清理已上传的文件 - if len(keys) > 0 { - _ = h.s3Storage.DeleteObjects(ctx, keys) - } - break - } - keys = append(keys, key) - totalSize += size - } - if allOK && len(keys) > 0 { - accessURLs := make([]string, 0, len(keys)) - for _, key := range keys { - accessURL, err := h.s3Storage.GetAccessURL(ctx, key) - if err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err) - _ = h.s3Storage.DeleteObjects(ctx, keys) - allOK = false - break - } - accessURLs = append(accessURLs, accessURL) - } - if allOK && len(accessURLs) > 0 { - return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize - } - } - } - - // 第二层:尝试本地存储 - if h.mediaStorage != nil && h.mediaStorage.Enabled() { - storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls) - if err == nil && len(storedPaths) > 0 { - firstPath := storedPaths[0] - totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths) - if sizeErr != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr) - } - return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize - } - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err) - } - - // 第三层:保留上游临时 URL - return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0 -} - -// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。 -func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte { - body := map[string]any{ - "model": model, - "messages": []map[string]string{ - {"role": "user", "content": prompt}, - }, - "stream": false, - } - if imageInput != "" { - body["image_input"] = imageInput - } - if videoCount > 1 { - body["video_count"] = videoCount - } - b, _ := json.Marshal(body) - return b -} - -func normalizeVideoCount(mediaType string, videoCount int) int { - if mediaType != "video" { - return 1 - } - if videoCount <= 0 { - return 1 - } - if videoCount > 3 { - return 3 - } - return videoCount -} - -// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。 -// OAuth 路径:ForwardResult.MediaURL 已填充。 -// APIKey 路径:需从响应体解析 media_url / media_urls 字段。 -func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) { - // 优先从 ForwardResult 获取(OAuth 路径) - if result != nil && result.MediaURL != "" { - // 尝试从响应体获取完整 URL 列表 - if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { - return urls[0], urls - } - return result.MediaURL, []string{result.MediaURL} - } - - // 从响应体解析(APIKey 路径) - if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { - return urls[0], urls - } - - return "", nil -} - -// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。 -func parseMediaURLsFromBody(body []byte) []string { - if len(body) == 0 { - return nil - } - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return nil - } - - // 优先 media_urls(多图数组) - if rawURLs, ok := resp["media_urls"]; ok { - if arr, ok := rawURLs.([]any); ok && len(arr) > 0 { - urls := make([]string, 0, len(arr)) - for _, item := range arr { - if s, ok := item.(string); ok && s != "" { - urls = append(urls, s) - } - } - if len(urls) > 0 { - return urls - } - } - } - - // 回退到 media_url(单个 URL) - if url, ok := resp["media_url"].(string); ok && url != "" { - return []string{url} - } - - return nil -} - -// ListGenerations 查询生成记录列表。 -// GET /api/v1/sora/generations -func (h *SoraClientHandler) ListGenerations(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) - - params := service.SoraGenerationListParams{ - UserID: userID, - Status: c.Query("status"), - StorageType: c.Query("storage_type"), - MediaType: c.Query("media_type"), - Page: page, - PageSize: pageSize, - } - - gens, total, err := h.genService.List(c.Request.Context(), params) - if err != nil { - response.ErrorFrom(c, err) - return - } - - // 为 S3 记录动态生成预签名 URL - for _, gen := range gens { - _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) - } - - response.Success(c, gin.H{ - "data": gens, - "total": total, - "page": page, - }) -} - -// GetGeneration 查询生成记录详情。 -// GET /api/v1/sora/generations/:id -func (h *SoraClientHandler) GetGeneration(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) - response.Success(c, gen) -} - -// DeleteGeneration 删除生成记录。 -// DELETE /api/v1/sora/generations/:id -func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。 - if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil { - paths := gen.MediaURLs - if len(paths) == 0 && gen.MediaURL != "" { - paths = []string{gen.MediaURL} - } - if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err) - } - } - - if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - response.Success(c, gin.H{"message": "已删除"}) -} - -// GetQuota 查询用户存储配额。 -// GET /api/v1/sora/quota -func (h *SoraClientHandler) GetQuota(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - if h.quotaService == nil { - response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"}) - return - } - - quota, err := h.quotaService.GetQuota(c.Request.Context(), userID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, quota) -} - -// CancelGeneration 取消生成任务。 -// POST /api/v1/sora/generations/:id/cancel -func (h *SoraClientHandler) CancelGeneration(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - // 权限校验 - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - _ = gen - - if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil { - if errors.Is(err, service.ErrSoraGenerationNotActive) { - response.Error(c, http.StatusConflict, "任务已结束,无法取消") - return - } - response.Error(c, http.StatusBadRequest, err.Error()) - return - } - - response.Success(c, gin.H{"message": "已取消"}) -} - -// SaveToStorage 手动保存 upstream 记录到 S3。 -// POST /api/v1/sora/generations/:id/save -func (h *SoraClientHandler) SaveToStorage(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - if gen.StorageType != service.SoraStorageTypeUpstream { - response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存") - return - } - if gen.MediaURL == "" { - response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") - return - } - - if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) { - response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员") - return - } - - sourceURLs := gen.MediaURLs - if len(sourceURLs) == 0 && gen.MediaURL != "" { - sourceURLs = []string{gen.MediaURL} - } - if len(sourceURLs) == 0 { - response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") - return - } - - uploadedKeys := make([]string, 0, len(sourceURLs)) - accessURLs := make([]string, 0, len(sourceURLs)) - var totalSize int64 - - for _, sourceURL := range sourceURLs { - objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL) - if uploadErr != nil { - if len(uploadedKeys) > 0 { - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - } - var upstreamErr *service.UpstreamDownloadError - if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) { - response.Error(c, http.StatusGone, "媒体链接已过期,无法保存") - return - } - response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error()) - return - } - accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey) - if err != nil { - uploadedKeys = append(uploadedKeys, objectKey) - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error()) - return - } - uploadedKeys = append(uploadedKeys, objectKey) - accessURLs = append(accessURLs, accessURL) - totalSize += fileSize - } - - usageAdded := false - if totalSize > 0 && h.quotaService != nil { - if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil { - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - var quotaErr *service.QuotaExceededError - if errors.As(err, "aErr) { - response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") - return - } - response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error()) - return - } - usageAdded = true - } - - if err := h.genService.UpdateStorageForCompleted( - c.Request.Context(), - id, - accessURLs[0], - accessURLs, - service.SoraStorageTypeS3, - uploadedKeys, - totalSize, - ); err != nil { - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - if usageAdded && h.quotaService != nil { - _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize) - } - response.ErrorFrom(c, err) - return - } - - response.Success(c, gin.H{ - "message": "已保存到 S3", - "object_key": uploadedKeys[0], - "object_keys": uploadedKeys, - }) -} - -// GetStorageStatus 返回存储状态。 -// GET /api/v1/sora/storage-status -func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) { - s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context()) - s3Healthy := false - if s3Enabled { - s3Healthy = h.s3Storage.IsHealthy(c.Request.Context()) - } - localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled() - response.Success(c, gin.H{ - "s3_enabled": s3Enabled, - "s3_healthy": s3Healthy, - "local_enabled": localEnabled, - }) -} - -func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) { - switch storageType { - case service.SoraStorageTypeS3: - if h.s3Storage != nil && len(s3Keys) > 0 { - if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err) - } - } - case service.SoraStorageTypeLocal: - if h.mediaStorage != nil && len(localPaths) > 0 { - if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err) - } - } - } -} - -// getUserIDFromContext 从 gin 上下文中提取用户 ID。 -func getUserIDFromContext(c *gin.Context) int64 { - if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { - return subject.UserID - } - - if id, ok := c.Get("user_id"); ok { - switch v := id.(type) { - case int64: - return v - case float64: - return int64(v) - case string: - n, _ := strconv.ParseInt(v, 10, 64) - return n - } - } - // 尝试从 JWT claims 获取 - if id, ok := c.Get("userID"); ok { - if v, ok := id.(int64); ok { - return v - } - } - return 0 -} - -func groupIDForLog(groupID *int64) int64 { - if groupID == nil { - return 0 - } - return *groupID -} - -func trimForLog(raw string, maxLen int) string { - trimmed := strings.TrimSpace(raw) - if maxLen <= 0 || len(trimmed) <= maxLen { - return trimmed - } - return trimmed[:maxLen] + "...(truncated)" -} - -// GetModels 获取可用 Sora 模型家族列表。 -// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。 -// GET /api/v1/sora/models -func (h *SoraClientHandler) GetModels(c *gin.Context) { - families := h.getModelFamilies(c.Request.Context()) - response.Success(c, families) -} - -// getModelFamilies 获取模型家族列表(带缓存)。 -func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily { - // 读锁检查缓存 - h.modelCacheMu.RLock() - ttl := modelCacheTTL - if !h.modelCacheUpstream { - ttl = modelCacheFailedTTL - } - if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { - families := h.cachedFamilies - h.modelCacheMu.RUnlock() - return families - } - h.modelCacheMu.RUnlock() - - // 写锁更新缓存 - h.modelCacheMu.Lock() - defer h.modelCacheMu.Unlock() - - // double-check - ttl = modelCacheTTL - if !h.modelCacheUpstream { - ttl = modelCacheFailedTTL - } - if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { - return h.cachedFamilies - } - - // 尝试从上游获取 - families, err := h.fetchUpstreamModels(ctx) - if err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err) - families = service.BuildSoraModelFamilies() - h.cachedFamilies = families - h.modelCacheTime = time.Now() - h.modelCacheUpstream = false - return families - } - - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families)) - h.cachedFamilies = families - h.modelCacheTime = time.Now() - h.modelCacheUpstream = true - return families -} - -// fetchUpstreamModels 从上游 Sora API 获取模型列表。 -func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) { - if h.gatewayService == nil { - return nil, fmt.Errorf("gatewayService 未初始化") - } - - // 设置 ForcePlatform 用于 Sora 账号选择 - ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) - - // 选择一个 Sora 账号 - account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s") - if err != nil { - return nil, fmt.Errorf("选择 Sora 账号失败: %w", err) - } - - // 仅支持 API Key 类型账号 - if account.Type != service.AccountTypeAPIKey { - return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type) - } - - apiKey := account.GetCredential("api_key") - if apiKey == "" { - return nil, fmt.Errorf("账号缺少 api_key") - } - - baseURL := account.GetBaseURL() - if baseURL == "" { - return nil, fmt.Errorf("账号缺少 base_url") - } - - // 构建上游模型列表请求 - modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models" - - reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil) - if err != nil { - return nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+apiKey) - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("请求上游失败: %w", err) - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode) - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) - } - - // 解析 OpenAI 格式的模型列表 - var modelsResp struct { - Data []struct { - ID string `json:"id"` - } `json:"data"` - } - if err := json.Unmarshal(body, &modelsResp); err != nil { - return nil, fmt.Errorf("解析响应失败: %w", err) - } - - if len(modelsResp.Data) == 0 { - return nil, fmt.Errorf("上游返回空模型列表") - } - - // 提取模型 ID - modelIDs := make([]string, 0, len(modelsResp.Data)) - for _, m := range modelsResp.Data { - modelIDs = append(modelIDs, m.ID) - } - - // 转换为模型家族 - families := service.BuildSoraModelFamiliesFromIDs(modelIDs) - if len(families) == 0 { - return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族") - } - - return families, nil -} diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go deleted file mode 100644 index 13523fe8..00000000 --- a/backend/internal/handler/sora_client_handler_test.go +++ /dev/null @@ -1,3186 +0,0 @@ -//go:build unit - -package handler - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "os" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -func init() { - gin.SetMode(gin.TestMode) -} - -// ==================== Stub: SoraGenerationRepository ==================== - -var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil) - -type stubSoraGenRepo struct { - gens map[int64]*service.SoraGeneration - nextID int64 - createErr error - getErr error - updateErr error - deleteErr error - listErr error - countErr error - countValue int64 - - // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败 - updateCallCount *int32 - updateFailAfterN int32 - - // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus - getByIDCallCount int32 - getByIDOverrideAfterN int32 // 0 = 不覆盖 - getByIDOverrideStatus string -} - -func newStubSoraGenRepo() *stubSoraGenRepo { - return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1} -} - -func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error { - if r.createErr != nil { - return r.createErr - } - gen.ID = r.nextID - r.nextID++ - r.gens[gen.ID] = gen - return nil -} -func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) { - if r.getErr != nil { - return nil, r.getErr - } - gen, ok := r.gens[id] - if !ok { - return nil, fmt.Errorf("not found") - } - // 条件性状态覆盖:模拟外部取消等场景 - if r.getByIDOverrideAfterN > 0 { - n := atomic.AddInt32(&r.getByIDCallCount, 1) - if n > r.getByIDOverrideAfterN { - cp := *gen - cp.Status = r.getByIDOverrideStatus - return &cp, nil - } - } - return gen, nil -} -func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error { - // 条件性失败:前 N 次成功,之后失败 - if r.updateCallCount != nil { - n := atomic.AddInt32(r.updateCallCount, 1) - if n > r.updateFailAfterN { - return fmt.Errorf("conditional update error (call #%d)", n) - } - } - if r.updateErr != nil { - return r.updateErr - } - r.gens[gen.ID] = gen - return nil -} -func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error { - if r.deleteErr != nil { - return r.deleteErr - } - delete(r.gens, id) - return nil -} -func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { - if r.listErr != nil { - return nil, 0, r.listErr - } - var result []*service.SoraGeneration - for _, gen := range r.gens { - if gen.UserID != params.UserID { - continue - } - result = append(result, gen) - } - return result, int64(len(result)), nil -} -func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) { - if r.countErr != nil { - return 0, r.countErr - } - return r.countValue, nil -} - -// ==================== 辅助函数 ==================== - -func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler { - genService := service.NewSoraGenerationService(repo, nil, nil) - return &SoraClientHandler{genService: genService} -} - -func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) { - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - if body != "" { - c.Request = httptest.NewRequest(method, path, strings.NewReader(body)) - c.Request.Header.Set("Content-Type", "application/json") - } else { - c.Request = httptest.NewRequest(method, path, nil) - } - if userID > 0 { - c.Set("user_id", userID) - } - return c, rec -} - -func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any { - t.Helper() - var resp map[string]any - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) - return resp -} - -// ==================== 纯函数测试: buildAsyncRequestBody ==================== - -func TestBuildAsyncRequestBody(t *testing.T) { - body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, "sora2-landscape-10s", parsed["model"]) - require.Equal(t, false, parsed["stream"]) - - msgs := parsed["messages"].([]any) - require.Len(t, msgs, 1) - msg := msgs[0].(map[string]any) - require.Equal(t, "user", msg["role"]) - require.Equal(t, "一只猫在跳舞", msg["content"]) -} - -func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) { - body := buildAsyncRequestBody("gpt-image", "", "", 1) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, "gpt-image", parsed["model"]) - msgs := parsed["messages"].([]any) - msg := msgs[0].(map[string]any) - require.Equal(t, "", msg["content"]) -} - -func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) { - body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, "https://example.com/ref.png", parsed["image_input"]) -} - -func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) { - body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, float64(3), parsed["video_count"]) -} - -func TestNormalizeVideoCount(t *testing.T) { - require.Equal(t, 1, normalizeVideoCount("video", 0)) - require.Equal(t, 2, normalizeVideoCount("video", 2)) - require.Equal(t, 3, normalizeVideoCount("video", 5)) - require.Equal(t, 1, normalizeVideoCount("image", 3)) -} - -// ==================== 纯函数测试: parseMediaURLsFromBody ==================== - -func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) { - urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`)) - require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) -} - -func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) { - urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`)) - require.Equal(t, []string{"https://a.com/video.mp4"}, urls) -} - -func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody(nil)) - require.Nil(t, parseMediaURLsFromBody([]byte{})) -} - -func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte("not json"))) -} - -func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`))) -} - -func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`))) -} - -func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`))) -} - -func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) { - body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}` - urls := parseMediaURLsFromBody([]byte(body)) - require.Len(t, urls, 2) - require.Equal(t, "https://multi.com/a.mp4", urls[0]) -} - -func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) { - urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`)) - require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) -} - -func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`))) -} - -func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) { - // media_urls 不是 string 数组 - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`))) -} - -func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`))) -} - -// ==================== 纯函数测试: extractMediaURLsFromResult ==================== - -func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) { - result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} - recorder := httptest.NewRecorder() - url, urls := extractMediaURLsFromResult(result, recorder) - require.Equal(t, "https://oauth.com/video.mp4", url) - require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls) -} - -func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) { - result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} - recorder := httptest.NewRecorder() - _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`)) - url, urls := extractMediaURLsFromResult(result, recorder) - require.Equal(t, "https://body.com/1.mp4", url) - require.Len(t, urls, 2) -} - -func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) { - recorder := httptest.NewRecorder() - _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`)) - url, urls := extractMediaURLsFromResult(nil, recorder) - require.Equal(t, "https://upstream.com/video.mp4", url) - require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls) -} - -func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) { - recorder := httptest.NewRecorder() - url, urls := extractMediaURLsFromResult(nil, recorder) - require.Empty(t, url) - require.Nil(t, urls) -} - -func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) { - result := &service.ForwardResult{MediaURL: ""} - recorder := httptest.NewRecorder() - url, urls := extractMediaURLsFromResult(result, recorder) - require.Empty(t, url) - require.Nil(t, urls) -} - -// ==================== getUserIDFromContext ==================== - -func TestGetUserIDFromContext_Int64(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", int64(42)) - require.Equal(t, int64(42), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_AuthSubject(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777}) - require.Equal(t, int64(777), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_Float64(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", float64(99)) - require.Equal(t, int64(99), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_String(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", "123") - require.Equal(t, int64(123), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_UserIDFallback(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("userID", int64(55)) - require.Equal(t, int64(55), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_NoID(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - require.Equal(t, int64(0), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_InvalidString(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", "not-a-number") - require.Equal(t, int64(0), getUserIDFromContext(c)) -} - -// ==================== Handler: Generate ==================== - -func TestGenerate_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0) - h.Generate(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestGenerate_BadRequest_MissingModel(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGenerate_BadRequest_MissingPrompt(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGenerate_BadRequest_InvalidJSON(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGenerate_TooManyRequests(t *testing.T) { - repo := newStubSoraGenRepo() - repo.countValue = 3 - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) -} - -func TestGenerate_CountError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.countErr = fmt.Errorf("db error") - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -func TestGenerate_Success(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.NotZero(t, data["generation_id"]) - require.Equal(t, "pending", data["status"]) -} - -func TestGenerate_DefaultMediaType(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "video", repo.gens[1].MediaType) -} - -func TestGenerate_ImageMediaType(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "image", repo.gens[1].MediaType) -} - -func TestGenerate_CreatePendingError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.createErr = fmt.Errorf("create failed") - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestGenerate_APIKeyInContext(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - c.Set("api_key_id", int64(42)) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(42), *repo.gens[1].APIKeyID) -} - -func TestGenerate_NoAPIKeyInContext(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Nil(t, repo.gens[1].APIKeyID) -} - -func TestGenerate_ConcurrencyBoundary(t *testing.T) { - // activeCount == 2 应该允许 - repo := newStubSoraGenRepo() - repo.countValue = 2 - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== Handler: ListGenerations ==================== - -func TestListGenerations_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0) - h.ListGenerations(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestListGenerations_Success(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"} - repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"} - repo.nextID = 3 - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1) - h.ListGenerations(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - items := data["data"].([]any) - require.Len(t, items, 2) - require.Equal(t, float64(2), data["total"]) -} - -func TestListGenerations_ListError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.listErr = fmt.Errorf("db error") - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) - h.ListGenerations(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -func TestListGenerations_DefaultPagination(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - // 不传分页参数,应默认 page=1 page_size=20 - c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) - h.ListGenerations(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, float64(1), data["page"]) -} - -// ==================== Handler: GetGeneration ==================== - -func TestGetGeneration_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.GetGeneration(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestGetGeneration_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.GetGeneration(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGetGeneration_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.GetGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestGetGeneration_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.GetGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestGetGeneration_Success(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.GetGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, float64(1), data["id"]) -} - -// ==================== Handler: DeleteGeneration ==================== - -func TestDeleteGeneration_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestDeleteGeneration_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestDeleteGeneration_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestDeleteGeneration_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestDeleteGeneration_Success(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - _, exists := repo.gens[1] - require.False(t, exists) -} - -// ==================== Handler: CancelGeneration ==================== - -func TestCancelGeneration_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestCancelGeneration_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestCancelGeneration_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestCancelGeneration_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestCancelGeneration_Pending(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "cancelled", repo.gens[1].Status) -} - -func TestCancelGeneration_Generating(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "cancelled", repo.gens[1].Status) -} - -func TestCancelGeneration_Completed(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} - -func TestCancelGeneration_Failed(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} - -func TestCancelGeneration_Cancelled(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} - -// ==================== Handler: GetQuota ==================== - -func TestGetQuota_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0) - h.GetQuota(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestGetQuota_NilQuotaService(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) - h.GetQuota(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, "unlimited", data["source"]) -} - -// ==================== Handler: GetModels ==================== - -func TestGetModels(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0) - h.GetModels(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].([]any) - require.Len(t, data, 4) - // 验证类型分布 - videoCount, imageCount := 0, 0 - for _, item := range data { - m := item.(map[string]any) - if m["type"] == "video" { - videoCount++ - } else if m["type"] == "image" { - imageCount++ - } - } - require.Equal(t, 3, videoCount) - require.Equal(t, 1, imageCount) -} - -// ==================== Handler: GetStorageStatus ==================== - -func TestGetStorageStatus_NilS3(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, false, data["s3_enabled"]) - require.Equal(t, false, data["s3_healthy"]) - require.Equal(t, false, data["local_enabled"]) -} - -func TestGetStorageStatus_LocalEnabled(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-storage-status-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, false, data["s3_enabled"]) - require.Equal(t, false, data["s3_healthy"]) - require.Equal(t, true, data["local_enabled"]) -} - -// ==================== Handler: SaveToStorage ==================== - -func TestSaveToStorage_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestSaveToStorage_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestSaveToStorage_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestSaveToStorage_NotUpstream(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestSaveToStorage_EmptyMediaURL(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestSaveToStorage_S3Nil(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusServiceUnavailable, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "云存储") -} - -func TestSaveToStorage_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -// ==================== storeMediaWithDegradation — nil guard 路径 ==================== - -func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) { - h := &SoraClientHandler{} - url, urls, storageType, keys, size := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Equal(t, "https://upstream.com/v.mp4", url) - require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls) - require.Nil(t, keys) - require.Equal(t, int64(0), size) -} - -func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) { - h := &SoraClientHandler{} - url, urls, storageType, keys, size := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Equal(t, "https://a.com/1.mp4", url) - require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) - require.Nil(t, keys) - require.Equal(t, int64(0), size) -} - -func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) { - h := &SoraClientHandler{} - url, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{}, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Equal(t, "https://upstream.com/v.mp4", url) -} - -// ==================== Stub: UserRepository (用于 SoraQuotaService) ==================== - -var _ service.UserRepository = (*stubUserRepoForHandler)(nil) - -type stubUserRepoForHandler struct { - users map[int64]*service.User - updateErr error -} - -func newStubUserRepoForHandler() *stubUserRepoForHandler { - return &stubUserRepoForHandler{users: make(map[int64]*service.User)} -} - -func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) { - if u, ok := r.users[id]; ok { - return u, nil - } - return nil, fmt.Errorf("user not found") -} -func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error { - if r.updateErr != nil { - return r.updateErr - } - r.users[user.ID] = user - return nil -} -func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil } -func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) { - return nil, nil -} -func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) { - return nil, nil -} -func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil } -func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil } -func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil } -func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) { - return false, nil -} -func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubUserRepoForHandler) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { - return nil -} -func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error { - return nil -} - -// ==================== NewSoraClientHandler ==================== - -func TestNewSoraClientHandler(t *testing.T) { - h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) - require.NotNil(t, h) -} - -func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) { - h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) - require.NotNil(t, h) - require.Nil(t, h.apiKeyService) -} - -// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ==================== - -var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil) - -type stubAPIKeyRepoForHandler struct { - keys map[int64]*service.APIKey - getErr error -} - -func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler { - return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)} -} - -func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) { - if r.getErr != nil { - return nil, r.getErr - } - if k, ok := r.keys[id]; ok { - return k, nil - } - return nil, fmt.Errorf("api key not found: %d", id) -} -func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil } -func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) { - return "", 0, nil -} -func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil } -func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) { - return false, nil -} -func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) UpdateGroupIDByUserAndGroup(_ context.Context, userID, oldGroupID, newGroupID int64) (int64, error) { - var updated int64 - for id, key := range r.keys { - if key.UserID != userID || key.GroupID == nil || *key.GroupID != oldGroupID { - continue - } - clone := *key - gid := newGroupID - clone.GroupID = &gid - r.keys[id] = &clone - updated++ - } - return updated, nil -} -func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error { - return nil -} -func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error { - return nil -} -func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error { - return nil -} -func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) { - return nil, nil -} - -// newTestAPIKeyService 创建测试用的 APIKeyService -func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService { - return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{}) -} - -// ==================== Generate: API Key 校验(前端传递 api_key_id)==================== - -func TestGenerate_WithAPIKeyID_Success(t *testing.T) { - // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - groupID := int64(5) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyActive, - GroupID: &groupID, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.NotZero(t, data["generation_id"]) - - // 验证 api_key_id 已关联到生成记录 - gen := repo.gens[1] - require.NotNil(t, gen.APIKeyID) - require.Equal(t, int64(42), *gen.APIKeyID) -} - -func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) { - // 前端传递不存在的 api_key_id → 400 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "不存在") -} - -func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) { - // 前端传递别人的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 999, // 属于 user 999 - Status: service.StatusAPIKeyActive, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "不属于") -} - -func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) { - // 前端传递已禁用的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyDisabled, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "不可用") -} - -func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) { - // 前端传递配额耗尽的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyQuotaExhausted, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) -} - -func TestGenerate_WithAPIKeyID_Expired(t *testing.T) { - // 前端传递已过期的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyExpired, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) -} - -func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) { - // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - h := &SoraClientHandler{genService: genService} // apiKeyService = nil - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录 - require.Nil(t, repo.gens[1].APIKeyID) -} - -func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) { - // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyActive, - GroupID: nil, // 无分组 - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(42), *repo.gens[1].APIKeyID) -} - -func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) { - // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Nil(t, repo.gens[1].APIKeyID) -} - -func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) { - // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - groupID := int64(10) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyActive, - GroupID: &groupID, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - // 应使用 body 中的 api_key_id=42,而不是 context 中的 99 - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(42), *repo.gens[1].APIKeyID) -} - -func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) { - // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由) - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - c.Set("api_key_id", int64(99)) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - // 应使用 context 中的 api_key_id=99 - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(99), *repo.gens[1].APIKeyID) -} - -func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) { - // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验 - // api_key_id=0 不存在 → 400 - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -// ==================== processGeneration: groupID 传递与 ForcePlatform ==================== - -func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) { - // groupID 不为 nil → 不设置 ForcePlatform - // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - gid := int64(5) - h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") -} - -func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) { - // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") -} - -func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) { - // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled - require.Equal(t, "cancelled", repo.gens[1].Status) -} - -// ==================== GenerateRequest JSON 解析 ==================== - -func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) { - // 验证 api_key_id 在 JSON 中正确解析为 *int64 - var req GenerateRequest - err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req) - require.NoError(t, err) - require.NotNil(t, req.APIKeyID) - require.Equal(t, int64(42), *req.APIKeyID) -} - -func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) { - // 不传 api_key_id → 解析后为 nil - var req GenerateRequest - err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req) - require.NoError(t, err) - require.Nil(t, req.APIKeyID) -} - -func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) { - // api_key_id: null → 解析后为 nil - var req GenerateRequest - err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req) - require.NoError(t, err) - require.Nil(t, req.APIKeyID) -} - -func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) { - // 全字段解析 - var req GenerateRequest - err := json.Unmarshal([]byte(`{ - "model":"sora2-landscape-10s", - "prompt":"test prompt", - "media_type":"video", - "video_count":2, - "image_input":"data:image/png;base64,abc", - "api_key_id":100 - }`), &req) - require.NoError(t, err) - require.Equal(t, "sora2-landscape-10s", req.Model) - require.Equal(t, "test prompt", req.Prompt) - require.Equal(t, "video", req.MediaType) - require.Equal(t, 2, req.VideoCount) - require.Equal(t, "data:image/png;base64,abc", req.ImageInput) - require.NotNil(t, req.APIKeyID) - require.Equal(t, int64(100), *req.APIKeyID) -} - -func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) { - // api_key_id 为 nil 时 JSON 序列化应省略 - req := GenerateRequest{Model: "sora2", Prompt: "test"} - b, err := json.Marshal(req) - require.NoError(t, err) - var parsed map[string]any - require.NoError(t, json.Unmarshal(b, &parsed)) - _, hasAPIKeyID := parsed["api_key_id"] - require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略") -} - -func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) { - // api_key_id 不为 nil 时 JSON 序列化应包含 - id := int64(42) - req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id} - b, err := json.Marshal(req) - require.NoError(t, err) - var parsed map[string]any - require.NoError(t, json.Unmarshal(b, &parsed)) - require.Equal(t, float64(42), parsed["api_key_id"]) -} - -// ==================== GetQuota: 有配额服务 ==================== - -func TestGetQuota_WithQuotaService_Success(t *testing.T) { - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - Email: "test@example.com", - } - quotaService := service.NewSoraQuotaService(nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) - h.GetQuota(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - // After refactoring: SoraQuotaService uses system-default only (no per-user DB field). - // With nil config → system quota = 0 → reported as "unlimited" mode. - require.Contains(t, []string{"system", "unlimited"}, data["source"]) -} - -func TestGetQuota_WithQuotaService_Error(t *testing.T) { - // After refactoring: system-default only mode always succeeds (returns 200). - // Even user ID=999 returns system-level quota info. - quotaService := service.NewSoraQuotaService(nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999) - h.GetQuota(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== Generate: 配额检查 ==================== - -func TestGenerate_QuotaCheckFailed(t *testing.T) { - // After refactoring: system-default only mode. - // With nil config → system quota = 0 → unlimited mode → no 429 block. - // To test 429, we'd need a non-nil config with a small positive system quota, - // but for now just verify the request proceeds (200) because unlimited mode allows all. - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - Email: "test@example.com", - } - quotaService := service.NewSoraQuotaService(nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - // In unlimited mode (nil config / zero system quota): no quota block - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestGenerate_QuotaCheckPassed(t *testing.T) { - // 配额充足时允许生成 — after refactoring, quota is system-default only - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - Email: "test@example.com", - } - quotaService := service.NewSoraQuotaService(nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== Stub: SettingRepository (用于 S3 存储测试) ==================== - -var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil) - -type stubSettingRepoForHandler struct { - values map[string]string -} - -func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler { - if values == nil { - values = make(map[string]string) - } - return &stubSettingRepoForHandler{values: values} -} - -func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) { - if v, ok := r.values[key]; ok { - return &service.Setting{Key: key, Value: v}, nil - } - return nil, service.ErrSettingNotFound -} -func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) { - if v, ok := r.values[key]; ok { - return v, nil - } - return "", service.ErrSettingNotFound -} -func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error { - r.values[key] = value - return nil -} -func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { - result := make(map[string]string) - for _, k := range keys { - if v, ok := r.values[k]; ok { - result[k] = v - } - } - return result, nil -} -func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error { - for k, v := range settings { - r.values[k] = v - } - return nil -} -func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) { - return r.values, nil -} -func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error { - delete(r.values, key) - return nil -} - -// ==================== S3 / MediaStorage 辅助函数 ==================== - -// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。 -func newS3StorageForHandler(endpoint string) *service.SoraS3Storage { - settingRepo := newStubSettingRepoForHandler(map[string]string{ - "sora_s3_enabled": "true", - "sora_s3_endpoint": endpoint, - "sora_s3_region": "us-east-1", - "sora_s3_bucket": "test-bucket", - "sora_s3_access_key_id": "AKIATEST", - "sora_s3_secret_access_key": "test-secret", - "sora_s3_prefix": "sora", - "sora_s3_force_path_style": "true", - }) - settingService := service.NewSettingService(settingRepo, &config.Config{}) - return service.NewSoraS3Storage(settingService) -} - -// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。 -func newFakeSourceServer() *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "video/mp4") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("fake video data for test")) - })) -} - -// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。 -// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。 -func newFakeS3Server(mode string) *httptest.Server { - var counter atomic.Int32 - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.Copy(io.Discard, r.Body) - _ = r.Body.Close() - - switch mode { - case "ok": - w.Header().Set("ETag", `"test-etag"`) - w.WriteHeader(http.StatusOK) - case "fail": - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte(`AccessDenied`)) - case "fail-second": - n := counter.Add(1) - if n <= 1 { - w.Header().Set("ETag", `"test-etag"`) - w.WriteHeader(http.StatusOK) - } else { - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte(`AccessDenied`)) - } - } - })) -} - -// ==================== processGeneration 直接调用测试 ==================== - -func TestProcessGeneration_MarkGeneratingFails(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - repo.updateErr = fmt.Errorf("db error") - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - // 直接调用(非 goroutine),MarkGenerating 失败 → 早退 - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating" - // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed - // 因此 ErrorMessage 为空(证明未调用 MarkFailed) - require.Equal(t, "generating", repo.gens[1].Status) - require.Empty(t, repo.gens[1].ErrorMessage) -} - -func TestProcessGeneration_GatewayServiceNil(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - // gatewayService 未设置 → MarkFailed - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") -} - -// ==================== storeMediaWithDegradation: S3 路径 ==================== - -func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeS3, storageType) - require.Len(t, s3Keys, 1) - require.NotEmpty(t, s3Keys[0]) - require.Len(t, storedURLs, 1) - require.Equal(t, storedURL, storedURLs[0]) - require.Contains(t, storedURL, fakeS3.URL) - require.Contains(t, storedURL, "/test-bucket/") - require.Greater(t, fileSize, int64(0)) -} - -func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} - storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, - ) - require.Equal(t, service.SoraStorageTypeS3, storageType) - require.Len(t, s3Keys, 2) - require.Len(t, storedURLs, 2) - require.Equal(t, storedURL, storedURLs[0]) - require.Contains(t, storedURLs[0], fakeS3.URL) - require.Contains(t, storedURLs[1], fakeS3.URL) - require.Greater(t, fileSize, int64(0)) -} - -func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { - // 上游返回 404 → 下载失败 → S3 上传不会开始 - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - })) - defer badSource.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - _, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) -} - -func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - // S3 失败,降级到 upstream - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Nil(t, s3Keys) -} - -func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail-second") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} - _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, - ) - // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Nil(t, s3Keys) -} - -// ==================== storeMediaWithDegradation: 本地存储路径 ==================== - -func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) { - // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: "/dev/null/invalid_dir", - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - _, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, - ) - // 本地存储失败,降级到 upstream - require.Equal(t, service.SoraStorageTypeUpstream, storageType) -} - -func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - DownloadTimeoutSeconds: 5, - MaxDownloadBytes: 10 * 1024 * 1024, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeLocal, storageType) - require.Nil(t, s3Keys) // 本地存储不返回 S3 keys -} - -func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - DownloadTimeoutSeconds: 5, - MaxDownloadBytes: 10 * 1024 * 1024, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{ - s3Storage: s3Storage, - mediaStorage: mediaStorage, - } - - _, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - // S3 失败 → 本地存储成功 - require.Equal(t, service.SoraStorageTypeLocal, storageType) -} - -// ==================== SaveToStorage: S3 路径 ==================== - -func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, resp["message"], "S3") -} - -func TestSaveToStorage_UpstreamURLExpired(t *testing.T) { - expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer expiredServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: expiredServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusGone, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "过期") -} - -func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Contains(t, data["message"], "S3") - require.NotEmpty(t, data["object_key"]) - // 验证记录已更新为 S3 存储 - require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) -} - -func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v1.mp4", - MediaURLs: []string{ - sourceServer.URL + "/v1.mp4", - sourceServer.URL + "/v2.mp4", - }, - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Len(t, data["object_keys"].([]any), 2) - require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) - require.Len(t, repo.gens[1].S3ObjectKeys, 2) - require.Len(t, repo.gens[1].MediaURLs, 2) -} - -func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - Email: "test@example.com", - } - quotaService := service.NewSoraQuotaService(nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败 - repo.updateErr = fmt.Errorf("db error") - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== GetStorageStatus: S3 路径 ==================== - -func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) { - // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket) - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, true, data["s3_enabled"]) - require.Equal(t, false, data["s3_healthy"]) -} - -func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) { - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, true, data["s3_enabled"]) - require.Equal(t, true, data["s3_healthy"]) -} - -// ==================== Stub: AccountRepository (用于 GatewayService) ==================== - -var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil) - -type stubAccountRepoForHandler struct { - accounts []service.Account -} - -func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil } -func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) { - for i := range r.accounts { - if r.accounts[i].ID == id { - return &r.accounts[i], nil - } - } - return nil, fmt.Errorf("account not found") -} -func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) { - return false, nil -} -func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil } -func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil } -func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error { - return nil -} -func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) { - return 0, nil -} -func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil } -func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error { - return nil -} -func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error { - return nil -} -func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error { - return nil -} -func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error { - return nil -} -func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) { - return 0, nil -} - -func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error { - return nil -} - -func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error { - return nil -} - -// ==================== Stub: SoraClient (用于 SoraGatewayService) ==================== - -var _ service.SoraClient = (*stubSoraClientForHandler)(nil) - -type stubSoraClientForHandler struct { - videoStatus *service.SoraVideoTaskStatus -} - -func (s *stubSoraClientForHandler) Enabled() bool { return true } -func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) { - return "task-image", nil -} -func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) { - return "task-video", nil -} -func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) { - return "task-video", nil -} -func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) { - return nil, nil -} -func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) { - return nil, nil -} -func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error { - return nil -} -func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error { - return nil -} -func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error { - return nil -} -func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) { - return nil, nil -} -func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) { - return s.videoStatus, nil -} - -// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ==================== - -// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。 -func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { - return service.NewGatewayService( - accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - nil, // rateLimitService - nil, nil, - ) -} - -// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。 -func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService { - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - return service.NewSoraGatewayService(soraClient, nil, nil, cfg) -} - -// ==================== processGeneration: 更多路径测试 ==================== - -func TestProcessGeneration_SelectAccountError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts" - accountRepo := &stubAccountRepoForHandler{accounts: nil} - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") -} - -func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) { - // TODO: Re-enable after Sora process generation is stable - // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - // 提供可用账号使 SelectAccountForModel 成功 - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - // soraGatewayService 为 nil - h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService") -} - -func TestProcessGeneration_ForwardError(t *testing.T) { - // TODO: Re-enable after Sora process generation is stable - // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - // SoraClient 返回视频任务失败 - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "failed", - ErrorMsg: "content policy violation", - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "生成失败") -} - -func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration - // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。 - repo.getByIDOverrideAfterN = 1 - repo.getByIDOverrideStatus = "cancelled" - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"}, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating) - require.Equal(t, "generating", repo.gens[1].Status) -} - -func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) { - // TODO: Re-enable after Sora process generation is stable - // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - // SoraClient 返回 completed 但无 URL - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: nil, // 无 URL - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL") -} - -func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次) - // 第 2 次返回 "cancelled" 状态,模拟外部取消 - repo.getByIDOverrideAfterN = 1 - repo.getByIDOverrideStatus = "cancelled" - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating) - require.Equal(t, "generating", repo.gens[1].Status) -} - -func TestProcessGeneration_FullSuccessUpstream(t *testing.T) { - // TODO: Re-enable after Sora process generation is stable - // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - // 无 S3 和本地存储,降级到 upstream - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - require.Equal(t, "completed", repo.gens[1].Status) - require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType) - require.NotEmpty(t, repo.gens[1].MediaURL) -} - -func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { - // TODO: Re-enable after Sora process generation is stable - // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{sourceServer.URL + "/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - s3Storage := newS3StorageForHandler(fakeS3.URL) - - userRepo := newStubUserRepoForHandler() - // 配额已满(系统级配额为0,所有用户均被限制) - userRepo.users[1] = &service.User{ID: 1} - quotaService := service.NewSoraQuotaService(nil) - - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - s3Storage: s3Storage, - quotaService: quotaService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - require.Equal(t, "completed", repo.gens[1].Status) - require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) - require.NotEmpty(t, repo.gens[1].S3ObjectKeys) - require.Greater(t, repo.gens[1].FileSizeBytes, int64(0)) - // 验证配额已累加(通过 quotaService 内部计数验证) - require.NotEmpty(t, repo.gens[1].S3ObjectKeys) -} - -func TestProcessGeneration_MarkCompletedFails(t *testing.T) { - // TODO: Re-enable after Sora process generation is stable - // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败 - repo.updateCallCount = new(int32) - repo.updateFailAfterN = 1 - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。 - // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。 - // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。 - require.Equal(t, "completed", repo.gens[1].Status) -} - -// ==================== cleanupStoredMedia 直接测试 ==================== - -func TestCleanupStoredMedia_S3Path(t *testing.T) { - // S3 清理路径:s3Storage 为 nil 时不 panic - h := &SoraClientHandler{} - // 不应 panic - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) -} - -func TestCleanupStoredMedia_LocalPath(t *testing.T) { - // 本地清理路径:mediaStorage 为 nil 时不 panic - h := &SoraClientHandler{} - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"}) -} - -func TestCleanupStoredMedia_UpstreamPath(t *testing.T) { - // upstream 类型不清理 - h := &SoraClientHandler{} - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil) -} - -func TestCleanupStoredMedia_EmptyKeys(t *testing.T) { - // 空 keys 不触发清理 - h := &SoraClientHandler{} - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil) - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil) -} - -// ==================== DeleteGeneration: 本地存储清理路径 ==================== - -func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-delete-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, - UserID: 1, - Status: "completed", - StorageType: service.SoraStorageTypeLocal, - MediaURL: "video/test.mp4", - MediaURLs: []string{"video/test.mp4"}, - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - _, exists := repo.gens[1] - require.False(t, exists) -} - -func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) { - // MediaURLs 为空,使用 MediaURL 作为清理路径 - tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, - UserID: 1, - Status: "completed", - StorageType: service.SoraStorageTypeLocal, - MediaURL: "video/test.mp4", - MediaURLs: nil, // 空 - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) { - // 非本地存储类型 → 跳过清理 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, - UserID: 1, - Status: "completed", - StorageType: service.SoraStorageTypeUpstream, - MediaURL: "https://upstream.com/v.mp4", - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestDeleteGeneration_DeleteError(t *testing.T) { - // repo.Delete 出错 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"} - repo.deleteErr = fmt.Errorf("delete failed") - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -// ==================== fetchUpstreamModels 测试 ==================== - -func TestFetchUpstreamModels_NilGateway(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - h := &SoraClientHandler{} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "gatewayService 未初始化") -} - -func TestFetchUpstreamModels_NoAccounts(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - accountRepo := &stubAccountRepoForHandler{accounts: nil} - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "选择 Sora 账号失败") -} - -func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "不支持模型同步") -} - -func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"base_url": "https://sora.test"}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "api_key") -} - -func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com" - // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败 - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test"}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) -} - -func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "状态码 500") -} - -func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("not json")) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "解析响应失败") -} - -func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "空模型列表") -} - -func TestFetchUpstreamModels_Success(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 验证请求头 - require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization")) - require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models")) - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - families, err := h.fetchUpstreamModels(context.Background()) - require.NoError(t, err) - require.NotEmpty(t, families) -} - -func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "未能从上游模型列表中识别") -} - -// ==================== getModelFamilies 缓存测试 ==================== - -func TestGetModelFamilies_CachesLocalConfig(t *testing.T) { - // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置 - h := &SoraClientHandler{} - families := h.getModelFamilies(context.Background()) - require.NotEmpty(t, families) - - // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL) - families2 := h.getModelFamilies(context.Background()) - require.Equal(t, families, families2) - require.False(t, h.modelCacheUpstream) -} - -func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) { - // TODO: Re-enable after Sora upstream model sync is stable - // t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - - families := h.getModelFamilies(context.Background()) - require.NotEmpty(t, families) - require.True(t, h.modelCacheUpstream) - - // 第二次调用命中缓存 - families2 := h.getModelFamilies(context.Background()) - require.Equal(t, families, families2) -} - -func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) { - // 预设过期的缓存(modelCacheUpstream=false → 短 TTL) - h := &SoraClientHandler{ - cachedFamilies: []service.SoraModelFamily{{ID: "old"}}, - modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期 - modelCacheUpstream: false, - } - // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存 - families := h.getModelFamilies(context.Background()) - require.NotEmpty(t, families) - // 缓存已刷新,不再是 "old" - found := false - for _, f := range families { - if f.ID == "old" { - found = true - } - } - require.False(t, found, "过期缓存应被刷新") -} - -// ==================== processGeneration: groupID 与 ForcePlatform ==================== - -func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) { - // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 空账号列表 → SelectAccountForModel 失败 - accountRepo := &stubAccountRepoForHandler{accounts: nil} - gatewayService := newMinimalGatewayService(accountRepo) - - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") -} - -// ==================== Generate: 配额检查非 QuotaExceeded 错误 ==================== - -func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) { - // After refactoring: system-default only mode with nil config → unlimited. - // No user lookup needed, no quota check failure path triggered. - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - _ = newStubUserRepoForHandler() // userRepo not used in unlimited mode - quotaService := service.NewSoraQuotaService(nil) - - h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil) - - body := `{"model":"sora2-landscape-10s","prompt":"test"}` - c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) // unlimited mode allows all -} - -// ==================== Generate: CreatePending 并发限制错误 ==================== - -// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口 -type stubSoraGenRepoWithAtomicCreate struct { - stubSoraGenRepo - limitErr error -} - -func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error { - if r.limitErr != nil { - return r.limitErr - } - return r.stubSoraGenRepo.Create(context.Background(), gen) -} - -func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) { - repo := &stubSoraGenRepoWithAtomicCreate{ - stubSoraGenRepo: *newStubSoraGenRepo(), - limitErr: service.ErrSoraGenerationConcurrencyLimit, - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil) - - body := `{"model":"sora2-landscape-10s","prompt":"test"}` - c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) - h.Generate(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, resp["message"], "3") -} - -// ==================== SaveToStorage: 配额超限 ==================== - -func TestSaveToStorage_QuotaExceeded(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 配额已满 - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ID: 1} - quotaService := service.NewSoraQuotaService(nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) // unlimited mode allows save -} - -// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ==================== - -func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 用户不存在 → After refactoring: unlimited mode doesn't check per-user - _ = newStubUserRepoForHandler() // userRepo not used in unlimited mode - quotaService := service.NewSoraQuotaService(nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) // unlimited mode allows save -} - -// ==================== SaveToStorage: MediaURLs 全为空 ==================== - -func TestSaveToStorage_EmptyMediaURLs(t *testing.T) { - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: "", - MediaURLs: []string{}, - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, resp["message"], "已过期") -} - -// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ==================== - -func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail-second") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v1.mp4", - MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"}, - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ==================== - -func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - repo.updateErr = fmt.Errorf("db error") - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ID: 1} - quotaService := service.NewSoraQuotaService(nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== cleanupStoredMedia: 实际 S3 删除路径 ==================== - -func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil) -} - -func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) { - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) -} - -func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"}) -} - -// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ==================== - -func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-del-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: service.SoraStorageTypeLocal, - MediaURL: "nonexistent/video.mp4", - MediaURLs: []string{"nonexistent/video.mp4"}, - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== CancelGeneration: 任务已结束冲突 ==================== - -func TestCancelGeneration_AlreadyCompleted(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go deleted file mode 100644 index c9c7de17..00000000 --- a/backend/internal/handler/sora_gateway_handler.go +++ /dev/null @@ -1,694 +0,0 @@ -package handler - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" - "github.com/Wei-Shaw/sub2api/internal/pkg/ip" - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/Wei-Shaw/sub2api/internal/util/soraerror" - - "github.com/gin-gonic/gin" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "go.uber.org/zap" -) - -// SoraGatewayHandler handles Sora chat completions requests -type SoraGatewayHandler struct { - gatewayService *service.GatewayService - soraGatewayService *service.SoraGatewayService - billingCacheService *service.BillingCacheService - usageRecordWorkerPool *service.UsageRecordWorkerPool - concurrencyHelper *ConcurrencyHelper - maxAccountSwitches int - streamMode string - soraTLSEnabled bool - soraMediaSigningKey string - soraMediaRoot string -} - -// NewSoraGatewayHandler creates a new SoraGatewayHandler -func NewSoraGatewayHandler( - gatewayService *service.GatewayService, - soraGatewayService *service.SoraGatewayService, - concurrencyService *service.ConcurrencyService, - billingCacheService *service.BillingCacheService, - usageRecordWorkerPool *service.UsageRecordWorkerPool, - cfg *config.Config, -) *SoraGatewayHandler { - pingInterval := time.Duration(0) - maxAccountSwitches := 3 - streamMode := "force" - soraTLSEnabled := true - signKey := "" - mediaRoot := "/app/data/sora" - if cfg != nil { - pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second - if cfg.Gateway.MaxAccountSwitches > 0 { - maxAccountSwitches = cfg.Gateway.MaxAccountSwitches - } - if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { - streamMode = mode - } - soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint - signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) - if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { - mediaRoot = root - } - } - return &SoraGatewayHandler{ - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - billingCacheService: billingCacheService, - usageRecordWorkerPool: usageRecordWorkerPool, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), - maxAccountSwitches: maxAccountSwitches, - streamMode: strings.ToLower(streamMode), - soraTLSEnabled: soraTLSEnabled, - soraMediaSigningKey: signKey, - soraMediaRoot: mediaRoot, - } -} - -// ChatCompletions handles Sora /v1/chat/completions endpoint -func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { - apiKey, ok := middleware2.GetAPIKeyFromContext(c) - if !ok { - h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") - return - } - - subject, ok := middleware2.GetAuthSubjectFromContext(c) - if !ok { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") - return - } - reqLog := requestLogger( - c, - "handler.sora_gateway.chat_completions", - zap.Int64("user_id", subject.UserID), - zap.Int64("api_key_id", apiKey.ID), - zap.Any("group_id", apiKey.GroupID), - ) - - body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) - if err != nil { - if maxErr, ok := extractMaxBytesError(err); ok { - h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) - return - } - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") - return - } - if len(body) == 0 { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") - return - } - - setOpsRequestContext(c, "", false, body) - - // 校验请求体 JSON 合法性 - if !gjson.ValidBytes(body) { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") - return - } - - // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal - modelResult := gjson.GetBytes(body, "model") - if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") - return - } - reqModel := modelResult.String() - - msgsResult := gjson.GetBytes(body, "messages") - if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required") - return - } - - clientStream := gjson.GetBytes(body, "stream").Bool() - reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream)) - if !clientStream { - if h.streamMode == "error" { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true") - return - } - var err error - body, err = sjson.SetBytes(body, "stream", true) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return - } - } - - setOpsRequestContext(c, reqModel, clientStream, body) - - platform := "" - if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { - platform = forced - } else if apiKey.Group != nil { - platform = apiKey.Group.Platform - } - if platform != service.PlatformSora { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform") - return - } - - streamStarted := false - subscription, _ := middleware2.GetSubscriptionFromContext(c) - - maxWait := service.CalculateMaxWait(subject.Concurrency) - canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) - waitCounted := false - if err != nil { - reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err)) - } else if !canWait { - reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait)) - h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") - return - } - if err == nil && canWait { - waitCounted = true - } - defer func() { - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - } - }() - - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted) - if err != nil { - reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err)) - h.handleConcurrencyError(c, err, "user", streamStarted) - return - } - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - waitCounted = false - } - userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) - if userReleaseFunc != nil { - defer userReleaseFunc() - } - - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) - h.handleStreamingAwareError(c, status, code, message, streamStarted) - return - } - - sessionHash := generateOpenAISessionHash(c, body) - - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 - var lastFailoverBody []byte - var lastFailoverHeaders http.Header - - for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0)) - if err != nil { - reqLog.Warn("sora.account_select_failed", - zap.Error(err), - zap.Int("excluded_account_count", len(failedAccountIDs)), - ) - if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) - fields := []zap.Field{ - zap.Int("last_upstream_status", lastFailoverStatus), - } - if rayID != "" { - fields = append(fields, zap.String("last_upstream_cf_ray", rayID)) - } - if mitigated != "" { - fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated)) - } - if contentType != "" { - fields = append(fields, zap.String("last_upstream_content_type", contentType)) - } - reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...) - h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) - return - } - account := selection.Account - setOpsSelectedAccount(c, account.ID, account.Platform) - proxyBound := account.ProxyID != nil - proxyID := int64(0) - if account.ProxyID != nil { - proxyID = *account.ProxyID - } - tlsFingerprintEnabled := h.soraTLSEnabled - - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - reqLog.Warn("sora.account_wait_counter_increment_failed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Error(err), - ) - } else if !canWait { - reqLog.Info("sora.account_wait_queue_full", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), - ) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - clientStream, - &streamStarted, - ) - if err != nil { - reqLog.Warn("sora.account_slot_acquire_failed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Error(err), - ) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - } - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - - result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream) - if accountReleaseFunc != nil { - accountReleaseFunc() - } - if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - if switchCount >= maxAccountSwitches { - lastFailoverStatus = failoverErr.StatusCode - lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) - lastFailoverBody = failoverErr.ResponseBody - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - } - if rayID != "" { - fields = append(fields, zap.String("upstream_cf_ray", rayID)) - } - if mitigated != "" { - fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) - } - if contentType != "" { - fields = append(fields, zap.String("upstream_content_type", contentType)) - } - reqLog.Warn("sora.upstream_failover_exhausted", fields...) - h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) - return - } - lastFailoverStatus = failoverErr.StatusCode - lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) - lastFailoverBody = failoverErr.ResponseBody - switchCount++ - upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody) - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.String("upstream_error_code", upstreamErrCode), - zap.String("upstream_error_message", upstreamErrMsg), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - } - if rayID != "" { - fields = append(fields, zap.String("upstream_cf_ray", rayID)) - } - if mitigated != "" { - fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) - } - if contentType != "" { - fields = append(fields, zap.String("upstream_content_type", contentType)) - } - reqLog.Warn("sora.upstream_failover_switching", fields...) - continue - } - reqLog.Error("sora.forward_failed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Error(err), - ) - return - } - - userAgent := c.GetHeader("User-Agent") - clientIP := ip.GetClientIP(c) - requestPayloadHash := service.HashUsageRequestPayload(body) - inboundEndpoint := GetInboundEndpoint(c) - upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - - // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 - h.submitUsageRecordTask(func(ctx context.Context) { - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - InboundEndpoint: inboundEndpoint, - UpstreamEndpoint: upstreamEndpoint, - UserAgent: userAgent, - IPAddress: clientIP, - RequestPayloadHash: requestPayloadHash, - }); err != nil { - logger.L().With( - zap.String("component", "handler.sora_gateway.chat_completions"), - zap.Int64("user_id", subject.UserID), - zap.Int64("api_key_id", apiKey.ID), - zap.Any("group_id", apiKey.GroupID), - zap.String("model", reqModel), - zap.Int64("account_id", account.ID), - ).Error("sora.record_usage_failed", zap.Error(err)) - } - }) - reqLog.Debug("sora.request_completed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("switch_count", switchCount), - ) - return - } -} - -func generateOpenAISessionHash(c *gin.Context, body []byte) string { - if c == nil { - return "" - } - sessionID := strings.TrimSpace(c.GetHeader("session_id")) - if sessionID == "" { - sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) - } - if sessionID == "" && len(body) > 0 { - sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) - } - if sessionID == "" { - return "" - } - hash := sha256.Sum256([]byte(sessionID)) - return hex.EncodeToString(hash[:]) -} - -func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { - if task == nil { - return - } - if h.usageRecordWorkerPool != nil { - h.usageRecordWorkerPool.Submit(task) - return - } - // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - defer func() { - if recovered := recover(); recovered != nil { - logger.L().With( - zap.String("component", "handler.sora_gateway.chat_completions"), - zap.Any("panic", recovered), - ).Error("sora.usage_record_task_panic_recovered") - } - }() - task(ctx) -} - -func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", - fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) -} - -func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) { - upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody) - service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "") - - status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody) - h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) -} - -func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) { - if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) { - baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode) - return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) - } - - upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) - if strings.EqualFold(upstreamCode, "cf_shield_429") { - baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry." - return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) - } - if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) { - switch statusCode { - case 401, 403, 404, 500, 502, 503, 504: - return http.StatusBadGateway, "upstream_error", upstreamMessage - case 429: - return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage - } - } - - switch statusCode { - case 401: - return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" - case 403: - return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" - case 404: - if strings.EqualFold(upstreamCode, "unsupported_country_code") { - return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator" - } - return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator" - case 429: - return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" - case 529: - return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later" - case 500, 502, 503, 504: - return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" - default: - return http.StatusBadGateway, "upstream_error", "Upstream request failed" - } -} - -func cloneHTTPHeaders(headers http.Header) http.Header { - if headers == nil { - return nil - } - return headers.Clone() -} - -func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) { - if headers != nil { - mitigated = strings.TrimSpace(headers.Get("cf-mitigated")) - contentType = strings.TrimSpace(headers.Get("content-type")) - if contentType == "" { - contentType = strings.TrimSpace(headers.Get("Content-Type")) - } - } - rayID = soraerror.ExtractCloudflareRayID(headers, body) - return rayID, mitigated, contentType -} - -func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { - return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) -} - -func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool { - message = strings.TrimSpace(message) - if message == "" { - return false - } - if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests { - lower := strings.ToLower(message) - if strings.Contains(lower, "Just a moment...`) - - h := &SoraGatewayHandler{} - h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true) - - lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") - require.Len(t, lines, 2) - jsonStr := strings.TrimPrefix(lines[1], "data: ") - - var parsed map[string]any - require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) - - errorObj, ok := parsed["error"].(map[string]any) - require.True(t, ok) - require.Equal(t, "upstream_error", errorObj["type"]) - msg, _ := errorObj["message"].(string) - require.Contains(t, msg, "Cloudflare challenge") - require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA") -} - -func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest(http.MethodGet, "/", nil) - - headers := http.Header{} - headers.Set("cf-ray", "9d03b68c086027a1-SEA") - body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`) - - h := &SoraGatewayHandler{} - h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true) - - lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") - require.Len(t, lines, 2) - jsonStr := strings.TrimPrefix(lines[1], "data: ") - - var parsed map[string]any - require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) - - errorObj, ok := parsed["error"].(map[string]any) - require.True(t, ok) - require.Equal(t, "rate_limit_error", errorObj["type"]) - msg, _ := errorObj["message"].(string) - require.Contains(t, msg, "Cloudflare shield") - require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA") -} - -func TestExtractSoraFailoverHeaderInsights(t *testing.T) { - headers := http.Header{} - headers.Set("cf-mitigated", "challenge") - headers.Set("content-type", "text/html") - body := []byte(``) - - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body) - require.Equal(t, "9cff2d62d83bb98d", rayID) - require.Equal(t, "challenge", mitigated) - require.Equal(t, "text/html", contentType) -} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 9a2290ea..e9f6f281 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -34,7 +34,6 @@ func ProvideAdminHandlers( scheduledTestHandler *admin.ScheduledTestHandler, channelHandler *admin.ChannelHandler, paymentHandler *admin.PaymentHandler, - soraHandler *admin.SoraHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -62,7 +61,6 @@ func ProvideAdminHandlers( ScheduledTest: scheduledTestHandler, Channel: channelHandler, Payment: paymentHandler, - Sora: soraHandler, } } @@ -88,8 +86,6 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, - soraGatewayHandler *SoraGatewayHandler, // 从本地版本合并 - soraClientHandler *SoraClientHandler, // 从本地版本合并 settingHandler *SettingHandler, totpHandler *TotpHandler, paymentHandler *PaymentHandler, @@ -108,8 +104,6 @@ func ProvideHandlers( Admin: adminHandlers, Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, - SoraGateway: soraGatewayHandler, // 从本地版本合并 - SoraClient: soraClientHandler, // 从本地版本合并 Setting: settingHandler, Totp: totpHandler, Payment: paymentHandler, @@ -129,8 +123,6 @@ var ProviderSet = wire.NewSet( NewAnnouncementHandler, NewGatewayHandler, NewOpenAIGatewayHandler, - NewSoraGatewayHandler, // 从本地版本合并 - NewSoraClientHandler, // 从本地版本合并 NewTotpHandler, ProvideSettingHandler, NewPaymentHandler, @@ -162,7 +154,6 @@ var ProviderSet = wire.NewSet( admin.NewScheduledTestHandler, admin.NewChannelHandler, admin.NewPaymentHandler, - admin.NewSoraHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index adad29f7..07b33854 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -19,9 +19,6 @@ const ( // OAuth Client ID for OpenAI (Codex CLI official) ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" - // OAuth Client ID for Sora (从本地版本合并) - SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK" - // OAuth endpoints AuthorizeURL = "https://auth.openai.com/oauth/authorize" TokenURL = "https://auth.openai.com/oauth/token" diff --git a/backend/internal/prommetrics/metrics_test.go b/backend/internal/prommetrics/metrics_test.go index d5206efc..bc0e2de5 100644 --- a/backend/internal/prommetrics/metrics_test.go +++ b/backend/internal/prommetrics/metrics_test.go @@ -63,7 +63,6 @@ func TestSetTPS(t *testing.T) { func TestRecordHTTPRequest(t *testing.T) { RecordHTTPRequest("GET", "/api/v1/chat", 200, 100*time.Millisecond) - RecordHTTPRequest("POST", "/api/v1/sora/generate", 201, 200*time.Millisecond) RecordHTTPRequest("GET", "/api/v1/models", 500, 50*time.Millisecond) } diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go deleted file mode 100644 index ad2ae638..00000000 --- a/backend/internal/repository/sora_account_repo.go +++ /dev/null @@ -1,98 +0,0 @@ -package repository - -import ( - "context" - "database/sql" - "errors" - - "github.com/Wei-Shaw/sub2api/internal/service" -) - -// soraAccountRepository 实现 service.SoraAccountRepository 接口。 -// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。 -// -// 设计说明: -// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理 -// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义 -// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除 -type soraAccountRepository struct { - sql *sql.DB -} - -// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例 -func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository { - return &soraAccountRepository{sql: sqlDB} -} - -// Upsert 创建或更新 Sora 账号扩展信息 -// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert -func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error { - accessToken, accessOK := updates["access_token"].(string) - refreshToken, refreshOK := updates["refresh_token"].(string) - sessionToken, sessionOK := updates["session_token"].(string) - - if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" { - if !sessionOK { - return errors.New("缺少 access_token/refresh_token,且未提供可更新字段") - } - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_accounts - SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END, - updated_at = NOW() - WHERE account_id = $1 - `, accountID, sessionToken) - if err != nil { - return err - } - rows, err := result.RowsAffected() - if err != nil { - return err - } - if rows == 0 { - return errors.New("sora_accounts 记录不存在,无法仅更新 session_token") - } - return nil - } - - _, err := r.sql.ExecContext(ctx, ` - INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at) - VALUES ($1, $2, $3, $4, NOW(), NOW()) - ON CONFLICT (account_id) DO UPDATE SET - access_token = EXCLUDED.access_token, - refresh_token = EXCLUDED.refresh_token, - session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END, - updated_at = NOW() - `, accountID, accessToken, refreshToken, sessionToken) - return err -} - -// GetByAccountID 根据账号 ID 获取 Sora 扩展信息 -func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) { - rows, err := r.sql.QueryContext(ctx, ` - SELECT account_id, access_token, refresh_token, COALESCE(session_token, '') - FROM sora_accounts - WHERE account_id = $1 - `, accountID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - if !rows.Next() { - return nil, nil // 记录不存在 - } - - var sa service.SoraAccount - if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil { - return nil, err - } - return &sa, nil -} - -// Delete 删除 Sora 账号扩展信息 -func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error { - _, err := r.sql.ExecContext(ctx, ` - DELETE FROM sora_accounts WHERE account_id = $1 - `, accountID) - return err -} diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go deleted file mode 100644 index aaf3cb2f..00000000 --- a/backend/internal/repository/sora_generation_repo.go +++ /dev/null @@ -1,419 +0,0 @@ -package repository - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/service" -) - -// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。 -// 使用原生 SQL 操作 sora_generations 表。 -type soraGenerationRepository struct { - sql *sql.DB -} - -// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。 -func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository { - return &soraGenerationRepository{sql: sqlDB} -} - -func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error { - mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) - s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) - - err := r.sql.QueryRowContext(ctx, ` - INSERT INTO sora_generations ( - user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - RETURNING id, created_at - `, - gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, - gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, - gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, - ).Scan(&gen.ID, &gen.CreatedAt) - return err -} - -// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。 -func (r *soraGenerationRepository) CreatePendingWithLimit( - ctx context.Context, - gen *service.SoraGeneration, - activeStatuses []string, - maxActive int64, -) error { - if gen == nil { - return fmt.Errorf("generation is nil") - } - if maxActive <= 0 { - return r.Create(ctx, gen) - } - if len(activeStatuses) == 0 { - activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating} - } - - tx, err := r.sql.BeginTx(ctx, nil) - if err != nil { - return err - } - defer func() { _ = tx.Rollback() }() - - // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。 - if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil { - return err - } - - placeholders := make([]string, len(activeStatuses)) - args := make([]any, 0, 1+len(activeStatuses)) - args = append(args, gen.UserID) - for i, s := range activeStatuses { - placeholders[i] = fmt.Sprintf("$%d", i+2) - args = append(args, s) - } - countQuery := fmt.Sprintf( - `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`, - strings.Join(placeholders, ","), - ) - var activeCount int64 - if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil { - return err - } - if activeCount >= maxActive { - return service.ErrSoraGenerationConcurrencyLimit - } - - mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) - s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) - if err := tx.QueryRowContext(ctx, ` - INSERT INTO sora_generations ( - user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - RETURNING id, created_at - `, - gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, - gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, - gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, - ).Scan(&gen.ID, &gen.CreatedAt); err != nil { - return err - } - - return tx.Commit() -} - -func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) { - gen := &service.SoraGeneration{} - var mediaURLsJSON, s3KeysJSON []byte - var completedAt sql.NullTime - var apiKeyID sql.NullInt64 - - err := r.sql.QueryRowContext(ctx, ` - SELECT id, user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message, - created_at, completed_at - FROM sora_generations WHERE id = $1 - `, id).Scan( - &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, - &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, - &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, - &gen.CreatedAt, &completedAt, - ) - if err != nil { - if err == sql.ErrNoRows { - return nil, fmt.Errorf("生成记录不存在") - } - return nil, err - } - - if apiKeyID.Valid { - gen.APIKeyID = &apiKeyID.Int64 - } - if completedAt.Valid { - gen.CompletedAt = &completedAt.Time - } - _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) - _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) - return gen, nil -} - -func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error { - mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) - s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) - - var completedAt *time.Time - if gen.CompletedAt != nil { - completedAt = gen.CompletedAt - } - - _, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations SET - status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5, - storage_type = $6, s3_object_keys = $7, upstream_task_id = $8, - error_message = $9, completed_at = $10 - WHERE id = $1 - `, - gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, - gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, - gen.ErrorMessage, completedAt, - ) - return err -} - -// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。 -func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) { - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, upstream_task_id = $3 - WHERE id = $1 AND status = $4 - `, - id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。 -func (r *soraGenerationRepository) UpdateCompletedIfActive( - ctx context.Context, - id int64, - mediaURL string, - mediaURLs []string, - storageType string, - s3Keys []string, - fileSizeBytes int64, - completedAt time.Time, -) (bool, error) { - mediaURLsJSON, _ := json.Marshal(mediaURLs) - s3KeysJSON, _ := json.Marshal(s3Keys) - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, - media_url = $3, - media_urls = $4, - file_size_bytes = $5, - storage_type = $6, - s3_object_keys = $7, - error_message = '', - completed_at = $8 - WHERE id = $1 AND status IN ($9, $10) - `, - id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes, - storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。 -func (r *soraGenerationRepository) UpdateFailedIfActive( - ctx context.Context, - id int64, - errMsg string, - completedAt time.Time, -) (bool, error) { - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, - error_message = $3, - completed_at = $4 - WHERE id = $1 AND status IN ($5, $6) - `, - id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。 -func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) { - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, completed_at = $3 - WHERE id = $1 AND status IN ($4, $5) - `, - id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。 -func (r *soraGenerationRepository) UpdateStorageIfCompleted( - ctx context.Context, - id int64, - mediaURL string, - mediaURLs []string, - storageType string, - s3Keys []string, - fileSizeBytes int64, -) (bool, error) { - mediaURLsJSON, _ := json.Marshal(mediaURLs) - s3KeysJSON, _ := json.Marshal(s3Keys) - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET media_url = $2, - media_urls = $3, - file_size_bytes = $4, - storage_type = $5, - s3_object_keys = $6 - WHERE id = $1 AND status = $7 - `, - id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error { - _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id) - return err -} - -func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { - // 构建 WHERE 条件 - conditions := []string{"user_id = $1"} - args := []any{params.UserID} - argIdx := 2 - - if params.Status != "" { - // 支持逗号分隔的多状态 - statuses := strings.Split(params.Status, ",") - placeholders := make([]string, len(statuses)) - for i, s := range statuses { - placeholders[i] = fmt.Sprintf("$%d", argIdx) - args = append(args, strings.TrimSpace(s)) - argIdx++ - } - conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ","))) - } - if params.StorageType != "" { - storageTypes := strings.Split(params.StorageType, ",") - placeholders := make([]string, len(storageTypes)) - for i, s := range storageTypes { - placeholders[i] = fmt.Sprintf("$%d", argIdx) - args = append(args, strings.TrimSpace(s)) - argIdx++ - } - conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ","))) - } - if params.MediaType != "" { - conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx)) - args = append(args, params.MediaType) - argIdx++ - } - - whereClause := "WHERE " + strings.Join(conditions, " AND ") - - // 计数 - var total int64 - countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause) - if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { - return nil, 0, err - } - - // 分页查询 - offset := (params.Page - 1) * params.PageSize - listQuery := fmt.Sprintf(` - SELECT id, user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message, - created_at, completed_at - FROM sora_generations %s - ORDER BY created_at DESC - LIMIT $%d OFFSET $%d - `, whereClause, argIdx, argIdx+1) - args = append(args, params.PageSize, offset) - - rows, err := r.sql.QueryContext(ctx, listQuery, args...) - if err != nil { - return nil, 0, err - } - defer func() { - _ = rows.Close() - }() - - var results []*service.SoraGeneration - for rows.Next() { - gen := &service.SoraGeneration{} - var mediaURLsJSON, s3KeysJSON []byte - var completedAt sql.NullTime - var apiKeyID sql.NullInt64 - - if err := rows.Scan( - &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, - &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, - &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, - &gen.CreatedAt, &completedAt, - ); err != nil { - return nil, 0, err - } - - if apiKeyID.Valid { - gen.APIKeyID = &apiKeyID.Int64 - } - if completedAt.Valid { - gen.CompletedAt = &completedAt.Time - } - _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) - _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) - results = append(results, gen) - } - - return results, total, rows.Err() -} - -func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) { - if len(statuses) == 0 { - return 0, nil - } - - placeholders := make([]string, len(statuses)) - args := []any{userID} - for i, s := range statuses { - placeholders[i] = fmt.Sprintf("$%d", i+2) - args = append(args, s) - } - - var count int64 - query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ",")) - err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count) - return count, err -} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index e71120f3..d3adb4a0 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -89,8 +89,6 @@ var ProviderSet = wire.NewSet( NewErrorPassthroughRepository, NewTLSFingerprintProfileRepository, NewChannelRepository, - NewSoraAccountRepository, // Sora 账号扩展表仓储 (从本地版本合并) - NewSoraGenerationRepository, // Sora 生成记录仓储 (从本地版本合并) // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index d9ec951e..73210bfc 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool { return strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/antigravity/") || - strings.HasPrefix(path, "/sora/") || strings.HasPrefix(path, "/responses") } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 0d9fedc1..52c4e325 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -88,17 +88,6 @@ func RegisterAdminRoutes( // 渠道管理 registerChannelRoutes(admin, h) - // Sora 管理 - registerSoraRoutes(admin, h) - } -} - -func registerSoraRoutes(admin *gin.RouterGroup, h *handler.Handlers) { - sora := admin.Group("/sora") - { - sora.GET("/stats", h.Admin.Sora.GetSystemStats) - sora.GET("/users", h.Admin.Sora.ListUserStats) - sora.GET("/generations", h.Admin.Sora.ListGenerations) } } diff --git a/backend/internal/server/routes/admin_routes_test.go b/backend/internal/server/routes/admin_routes_test.go index 1c192757..9cbae518 100644 --- a/backend/internal/server/routes/admin_routes_test.go +++ b/backend/internal/server/routes/admin_routes_test.go @@ -42,7 +42,6 @@ func TestRegisterAdminRoutes_OmitsDeprecatedMockEndpoints(t *testing.T) { "GET /api/v1/admin/data-management/agent/health", "GET /api/v1/admin/data-management/config", "POST /api/v1/admin/data-management/backups", - "DELETE /api/v1/admin/sora/users/:id/storage", } for _, route := range deprecatedRoutes { diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go deleted file mode 100644 index 13fceb81..00000000 --- a/backend/internal/server/routes/sora_client.go +++ /dev/null @@ -1,36 +0,0 @@ -package routes - -import ( - "github.com/Wei-Shaw/sub2api/internal/handler" - "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/gin-gonic/gin" -) - -// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。 -func RegisterSoraClientRoutes( - v1 *gin.RouterGroup, - h *handler.Handlers, - jwtAuth middleware.JWTAuthMiddleware, - settingService *service.SettingService, -) { - if h.SoraClient == nil { - return - } - - authenticated := v1.Group("/sora") - authenticated.Use(gin.HandlerFunc(jwtAuth)) - authenticated.Use(middleware.BackendModeUserGuard(settingService)) - { - authenticated.POST("/generate", h.SoraClient.Generate) - authenticated.GET("/generations", h.SoraClient.ListGenerations) - authenticated.GET("/generations/:id", h.SoraClient.GetGeneration) - authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration) - authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration) - authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage) - authenticated.GET("/quota", h.SoraClient.GetQuota) - authenticated.GET("/models", h.SoraClient.GetModels) - authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus) - } -} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index ef1ccf3f..b9c2a6d3 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -24,7 +24,6 @@ const ( PlatformOpenAI = domain.PlatformOpenAI PlatformGemini = domain.PlatformGemini PlatformAntigravity = domain.PlatformAntigravity - PlatformSora = domain.PlatformSora // 从本地版本合并 ) // Account type constants @@ -253,19 +252,6 @@ const ( SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough" // SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false) SettingKeyEnableCCHSigning = "enable_cch_signing" - - // Sora S3 存储配置 (从本地版本合并) - SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储 - SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址 - SettingKeySoraS3Region = "sora_s3_region" // S3 区域 - SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称 - SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID - SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储) - SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀 - SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等) - SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选) - SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON) - SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // Sora 默认存储配额(字节) ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index a742c926..e9ea00de 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -503,10 +503,6 @@ type ForwardResult struct { // 图片生成计费字段(图片生成模型使用) ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" - - // Sora 媒体字段 (从本地版本合并) - MediaType string // image / video / prompt - MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 619f201e..48f25da0 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -24,8 +24,6 @@ import ( var ( ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") - ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") // 从本地版本合并 - ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") // 从本地版本合并 ErrDefaultSubGroupInvalid = infraerrors.BadRequest( "DEFAULT_SUBSCRIPTION_GROUP_INVALID", "default subscription group must exist and be subscription type", @@ -2128,315 +2126,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data)) } - -// Sora S3 存储配置 (从本地版本合并) -type soraS3ProfilesStore struct { - ActiveProfileID string `json:"active_profile_id"` - Items []soraS3ProfileStoreItem `json:"items"` -} - -type soraS3ProfileStoreItem struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` - UpdatedAt string `json:"updated_at"` -} - -// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置) -func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { - profiles, err := s.ListSoraS3Profiles(ctx) - if err != nil { - return nil, err - } - - activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) - if activeProfile == nil { - return &SoraS3Settings{}, nil - } - - return &SoraS3Settings{ - Enabled: activeProfile.Enabled, - Endpoint: activeProfile.Endpoint, - Region: activeProfile.Region, - Bucket: activeProfile.Bucket, - AccessKeyID: activeProfile.AccessKeyID, - SecretAccessKey: activeProfile.SecretAccessKey, - SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured, - Prefix: activeProfile.Prefix, - ForcePathStyle: activeProfile.ForcePathStyle, - CDNURL: activeProfile.CDNURL, - DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes, - }, nil -} - -// ListSoraS3Profiles 获取 Sora S3 多配置列表 -func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) { - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return nil, err - } - return convertSoraS3ProfilesStore(store), nil -} - -func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) { - raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles) - if err == nil { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return &soraS3ProfilesStore{}, nil - } - var store soraS3ProfilesStore - if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil { - legacy, legacyErr := s.getLegacySoraS3Settings(ctx) - if legacyErr != nil { - return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr) - } - if isEmptyLegacySoraS3Settings(legacy) { - return &soraS3ProfilesStore{}, nil - } - now := time.Now().UTC().Format(time.RFC3339) - return &soraS3ProfilesStore{ - ActiveProfileID: "default", - Items: []soraS3ProfileStoreItem{ - { - ProfileID: "default", - Name: "Default", - Enabled: legacy.Enabled, - Endpoint: strings.TrimSpace(legacy.Endpoint), - Region: strings.TrimSpace(legacy.Region), - Bucket: strings.TrimSpace(legacy.Bucket), - AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), - SecretAccessKey: legacy.SecretAccessKey, - Prefix: strings.TrimSpace(legacy.Prefix), - ForcePathStyle: legacy.ForcePathStyle, - CDNURL: strings.TrimSpace(legacy.CDNURL), - DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), - UpdatedAt: now, - }, - }, - }, nil - } - normalized := normalizeSoraS3ProfilesStore(store) - return &normalized, nil - } - - if !errors.Is(err, ErrSettingNotFound) { - return nil, fmt.Errorf("get sora s3 profiles: %w", err) - } - - legacy, legacyErr := s.getLegacySoraS3Settings(ctx) - if legacyErr != nil { - return nil, legacyErr - } - if isEmptyLegacySoraS3Settings(legacy) { - return &soraS3ProfilesStore{}, nil - } - - now := time.Now().UTC().Format(time.RFC3339) - return &soraS3ProfilesStore{ - ActiveProfileID: "default", - Items: []soraS3ProfileStoreItem{ - { - ProfileID: "default", - Name: "Default", - Enabled: legacy.Enabled, - Endpoint: strings.TrimSpace(legacy.Endpoint), - Region: strings.TrimSpace(legacy.Region), - Bucket: strings.TrimSpace(legacy.Bucket), - AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), - SecretAccessKey: legacy.SecretAccessKey, - Prefix: strings.TrimSpace(legacy.Prefix), - ForcePathStyle: legacy.ForcePathStyle, - CDNURL: strings.TrimSpace(legacy.CDNURL), - DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), - UpdatedAt: now, - }, - }, - }, nil -} - -func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { - keys := []string{ - SettingKeySoraS3Enabled, - SettingKeySoraS3Endpoint, - SettingKeySoraS3Region, - SettingKeySoraS3Bucket, - SettingKeySoraS3AccessKeyID, - SettingKeySoraS3SecretAccessKey, - SettingKeySoraS3Prefix, - SettingKeySoraS3ForcePathStyle, - SettingKeySoraS3CDNURL, - SettingKeySoraDefaultStorageQuotaBytes, - } - - values, err := s.settingRepo.GetMultiple(ctx, keys) - if err != nil { - return nil, fmt.Errorf("get legacy sora s3 settings: %w", err) - } - - result := &SoraS3Settings{ - Enabled: values[SettingKeySoraS3Enabled] == "true", - Endpoint: values[SettingKeySoraS3Endpoint], - Region: values[SettingKeySoraS3Region], - Bucket: values[SettingKeySoraS3Bucket], - AccessKeyID: values[SettingKeySoraS3AccessKeyID], - SecretAccessKey: values[SettingKeySoraS3SecretAccessKey], - SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "", - Prefix: values[SettingKeySoraS3Prefix], - ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true", - CDNURL: values[SettingKeySoraS3CDNURL], - } - if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil { - result.DefaultStorageQuotaBytes = v - } - return result, nil -} - -func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore { - seen := make(map[string]struct{}, len(store.Items)) - normalized := soraS3ProfilesStore{ - ActiveProfileID: strings.TrimSpace(store.ActiveProfileID), - Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)), - } - now := time.Now().UTC().Format(time.RFC3339) - - for idx := range store.Items { - item := store.Items[idx] - item.ProfileID = strings.TrimSpace(item.ProfileID) - if item.ProfileID == "" { - item.ProfileID = fmt.Sprintf("profile-%d", idx+1) - } - if _, exists := seen[item.ProfileID]; exists { - continue - } - seen[item.ProfileID] = struct{}{} - - item.Name = strings.TrimSpace(item.Name) - if item.Name == "" { - item.Name = item.ProfileID - } - item.Endpoint = strings.TrimSpace(item.Endpoint) - item.Region = strings.TrimSpace(item.Region) - item.Bucket = strings.TrimSpace(item.Bucket) - item.AccessKeyID = strings.TrimSpace(item.AccessKeyID) - item.Prefix = strings.TrimSpace(item.Prefix) - item.CDNURL = strings.TrimSpace(item.CDNURL) - item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0) - item.UpdatedAt = strings.TrimSpace(item.UpdatedAt) - if item.UpdatedAt == "" { - item.UpdatedAt = now - } - normalized.Items = append(normalized.Items, item) - } - - if len(normalized.Items) == 0 { - normalized.ActiveProfileID = "" - return normalized - } - - if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 { - return normalized - } - - normalized.ActiveProfileID = normalized.Items[0].ProfileID - return normalized -} - -func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList { - if store == nil { - return &SoraS3ProfileList{} - } - items := make([]SoraS3Profile, 0, len(store.Items)) - for idx := range store.Items { - item := store.Items[idx] - items = append(items, SoraS3Profile{ - ProfileID: item.ProfileID, - Name: item.Name, - IsActive: item.ProfileID == store.ActiveProfileID, - Enabled: item.Enabled, - Endpoint: item.Endpoint, - Region: item.Region, - Bucket: item.Bucket, - AccessKeyID: item.AccessKeyID, - SecretAccessKey: item.SecretAccessKey, - SecretAccessKeyConfigured: item.SecretAccessKey != "", - Prefix: item.Prefix, - ForcePathStyle: item.ForcePathStyle, - CDNURL: item.CDNURL, - DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes, - UpdatedAt: item.UpdatedAt, - }) - } - return &SoraS3ProfileList{ - ActiveProfileID: store.ActiveProfileID, - Items: items, - } -} - -func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile { - for idx := range items { - if items[idx].ProfileID == activeProfileID { - return &items[idx] - } - } - if len(items) == 0 { - return nil - } - return &items[0] -} - -func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int { - for idx := range items { - if items[idx].ProfileID == profileID { - return idx - } - } - return -1 -} - -func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool { - if settings == nil { - return true - } - if settings.Enabled { - return false - } - if strings.TrimSpace(settings.Endpoint) != "" { - return false - } - if strings.TrimSpace(settings.Region) != "" { - return false - } - if strings.TrimSpace(settings.Bucket) != "" { - return false - } - if strings.TrimSpace(settings.AccessKeyID) != "" { - return false - } - if settings.SecretAccessKey != "" { - return false - } - if strings.TrimSpace(settings.Prefix) != "" { - return false - } - if strings.TrimSpace(settings.CDNURL) != "" { - return false - } - return settings.DefaultStorageQuotaBytes == 0 -} - -func maxInt64(value int64, min int64) int64 { - if value < min { - return min - } - return value -} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index d8c1748d..de92b796 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -113,46 +113,6 @@ type DefaultSubscriptionSetting struct { ValidityDays int `json:"validity_days"` } -// SoraS3Settings Sora S3 存储配置 (从本地版本合并) -type SoraS3Settings struct { - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端 - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -// SoraS3Profile Sora S3 多配置项(服务内部模型)(从本地版本合并) -type SoraS3Profile struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - IsActive bool `json:"is_active"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端 - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` - UpdatedAt string `json:"updated_at"` -} - -// SoraS3ProfileList Sora S3 多配置列表 (从本地版本合并) -type SoraS3ProfileList struct { - ActiveProfileID string `json:"active_profile_id"` - Items []SoraS3Profile `json:"items"` -} - type PublicSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool diff --git a/backend/internal/service/sora_account_service.go b/backend/internal/service/sora_account_service.go deleted file mode 100644 index eccc1acf..00000000 --- a/backend/internal/service/sora_account_service.go +++ /dev/null @@ -1,40 +0,0 @@ -package service - -import "context" - -// SoraAccountRepository Sora 账号扩展表仓储接口 -// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。 -// -// 设计说明: -// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本 -// - Sora gateway 优先读取此表的字段以获得更好的查询性能 -// - 主表 accounts 通过 credentials JSON 字段也存储相同信息 -// - Token 刷新时需要同时更新两个表以保持数据一致性 -type SoraAccountRepository interface { - // Upsert 创建或更新 Sora 账号扩展信息 - // accountID: 关联的 accounts.id - // updates: 要更新的字段,支持 access_token、refresh_token、session_token - // - // 如果记录不存在则创建,存在则更新。 - // 用于: - // 1. 创建 Sora 账号时初始化扩展表 - // 2. Token 刷新时同步更新扩展表 - Upsert(ctx context.Context, accountID int64, updates map[string]any) error - - // GetByAccountID 根据账号 ID 获取 Sora 扩展信息 - // 返回 nil, nil 表示记录不存在(非错误) - GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error) - - // Delete 删除 Sora 账号扩展信息 - // 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理 - Delete(ctx context.Context, accountID int64) error -} - -// SoraAccount Sora 账号扩展信息 -// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本 -type SoraAccount struct { - AccountID int64 // 关联的 accounts.id - AccessToken string // OAuth access_token - RefreshToken string // OAuth refresh_token - SessionToken string // Session token(可选,用于 ST→AT 兜底) -} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go deleted file mode 100644 index 0a914d2d..00000000 --- a/backend/internal/service/sora_client.go +++ /dev/null @@ -1,117 +0,0 @@ -package service - -import ( - "context" - "fmt" - "net/http" -) - -// SoraClient 定义直连 Sora 的任务操作接口。 -type SoraClient interface { - Enabled() bool - UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) - CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) - CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) - CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) - UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) - GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) - DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) - UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) - FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) - SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error - DeleteCharacter(ctx context.Context, account *Account, characterID string) error - PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) - DeletePost(ctx context.Context, account *Account, postID string) error - GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) - EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) - GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) - GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) -} - -// SoraImageRequest 图片生成请求参数 -type SoraImageRequest struct { - Prompt string - Width int - Height int - MediaID string -} - -// SoraVideoRequest 视频生成请求参数 -type SoraVideoRequest struct { - Prompt string - Orientation string - Frames int - Model string - Size string - VideoCount int - MediaID string - RemixTargetID string - CameoIDs []string -} - -// SoraStoryboardRequest 分镜视频生成请求参数 -type SoraStoryboardRequest struct { - Prompt string - Orientation string - Frames int - Model string - Size string - MediaID string -} - -// SoraImageTaskStatus 图片任务状态 -type SoraImageTaskStatus struct { - ID string - Status string - ProgressPct float64 - URLs []string - ErrorMsg string -} - -// SoraVideoTaskStatus 视频任务状态 -type SoraVideoTaskStatus struct { - ID string - Status string - ProgressPct int - URLs []string - GenerationID string - ErrorMsg string -} - -// SoraCameoStatus 角色处理中间态 -type SoraCameoStatus struct { - Status string - StatusMessage string - DisplayNameHint string - UsernameHint string - ProfileAssetURL string - InstructionSetHint any - InstructionSet any -} - -// SoraCharacterFinalizeRequest 角色定稿请求参数 -type SoraCharacterFinalizeRequest struct { - CameoID string - Username string - DisplayName string - ProfileAssetPointer string - InstructionSet any -} - -// SoraUpstreamError 上游错误 -type SoraUpstreamError struct { - StatusCode int - Message string - Headers http.Header - Body []byte -} - -func (e *SoraUpstreamError) Error() string { - if e == nil { - return "sora upstream error" - } - if e.Message != "" { - return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message) - } - return fmt.Sprintf("sora upstream error: %d", e.StatusCode) -} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go deleted file mode 100644 index e9d325f4..00000000 --- a/backend/internal/service/sora_gateway_service.go +++ /dev/null @@ -1,1559 +0,0 @@ -package service - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "math" - "math/rand" - "mime" - "net" - "net/http" - "net/url" - "regexp" - "strconv" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - "github.com/gin-gonic/gin" -) - -const soraImageInputMaxBytes = 20 << 20 -const soraImageInputMaxRedirects = 3 -const soraImageInputTimeout = 20 * time.Second -const soraVideoInputMaxBytes = 200 << 20 -const soraVideoInputMaxRedirects = 3 -const soraVideoInputTimeout = 60 * time.Second - -var soraImageSizeMap = map[string]string{ - "gpt-image": "360", - "gpt-image-landscape": "540", - "gpt-image-portrait": "540", -} - -var soraBlockedHostnames = map[string]struct{}{ - "localhost": {}, - "localhost.localdomain": {}, - "metadata.google.internal": {}, - "metadata.google.internal.": {}, -} - -var soraBlockedCIDRs = mustParseCIDRs([]string{ - "0.0.0.0/8", - "10.0.0.0/8", - "100.64.0.0/10", - "127.0.0.0/8", - "169.254.0.0/16", - "172.16.0.0/12", - "192.168.0.0/16", - "224.0.0.0/4", - "240.0.0.0/4", - "::/128", - "::1/128", - "fc00::/7", - "fe80::/10", -}) - -// SoraGatewayService handles forwarding requests to Sora upstream. -type SoraGatewayService struct { - soraClient SoraClient - rateLimitService *RateLimitService - httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传 - cfg *config.Config -} - -type soraWatermarkOptions struct { - Enabled bool - ParseMethod string - ParseURL string - ParseToken string - FallbackOnFailure bool - DeletePost bool -} - -type soraCharacterOptions struct { - SetPublic bool - DeleteAfterGenerate bool -} - -type soraCharacterFlowResult struct { - CameoID string - CharacterID string - Username string - DisplayName string -} - -var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`) -var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`) -var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`) -var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`) - -type soraPreflightChecker interface { - PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error -} - -func NewSoraGatewayService( - soraClient SoraClient, - rateLimitService *RateLimitService, - httpUpstream HTTPUpstream, - cfg *config.Config, -) *SoraGatewayService { - return &SoraGatewayService{ - soraClient: soraClient, - rateLimitService: rateLimitService, - httpUpstream: httpUpstream, - cfg: cfg, - } -} - -func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { - startTime := time.Now() - - // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient - if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" { - if s.httpUpstream == nil { - s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream) - return nil, errors.New("httpUpstream not configured for sora apikey forwarding") - } - return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime) - } - - if s.soraClient == nil || !s.soraClient.Enabled() { - if c != nil { - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "type": "api_error", - "message": "Sora 上游未配置", - }, - }) - } - return nil, errors.New("sora upstream not configured") - } - - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream) - return nil, fmt.Errorf("parse request: %w", err) - } - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - if strings.TrimSpace(reqModel) == "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) - return nil, errors.New("model is required") - } - originalModel := reqModel - - mappedModel := account.GetMappedModel(reqModel) - var upstreamModel string - if mappedModel != "" && mappedModel != reqModel { - reqModel = mappedModel - upstreamModel = mappedModel - } - - modelCfg, ok := GetSoraModelConfig(reqModel) - if !ok { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) - return nil, fmt.Errorf("unsupported model: %s", reqModel) - } - prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) - prompt = strings.TrimSpace(prompt) - imageInput = strings.TrimSpace(imageInput) - videoInput = strings.TrimSpace(videoInput) - remixTargetID = strings.TrimSpace(remixTargetID) - - if videoInput != "" && modelCfg.Type != "video" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream) - return nil, errors.New("video input only supports video models") - } - if videoInput != "" && imageInput != "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream) - return nil, errors.New("image input and video input cannot be used together") - } - characterOnly := videoInput != "" && prompt == "" - if modelCfg.Type == "prompt_enhance" && prompt == "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) - return nil, errors.New("prompt is required") - } - if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) - return nil, errors.New("prompt is required") - } - - reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) - if cancel != nil { - defer cancel() - } - if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly { - if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - } - - if modelCfg.Type == "prompt_enhance" { - enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) - if err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - content := strings.TrimSpace(enhancedPrompt) - if content == "" { - content = prompt - } - var firstTokenMs *int - if clientStream { - ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) - if streamErr != nil { - return nil, streamErr - } - firstTokenMs = ms - } else if c != nil { - c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) - } - return &ForwardResult{ - RequestID: "", - Model: originalModel, - UpstreamModel: upstreamModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", - }, nil - } - - characterOpts := parseSoraCharacterOptions(reqBody) - watermarkOpts := parseSoraWatermarkOptions(reqBody) - var characterResult *soraCharacterFlowResult - if videoInput != "" { - videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput) - if videoErr != nil { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream) - return nil, videoErr - } - characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts) - if videoErr != nil { - return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream) - } - if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly { - characterID := strings.TrimSpace(characterResult.CharacterID) - defer func() { - cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second) - defer cancelCleanup() - if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil { - log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err) - } - }() - } - if characterOnly { - content := "角色创建成功" - if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { - content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username)) - } - var firstTokenMs *int - if clientStream { - ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) - if streamErr != nil { - return nil, streamErr - } - firstTokenMs = ms - } else if c != nil { - resp := buildSoraNonStreamResponse(content, reqModel) - if characterResult != nil { - resp["character_id"] = characterResult.CharacterID - resp["cameo_id"] = characterResult.CameoID - resp["character_username"] = characterResult.Username - resp["character_display_name"] = characterResult.DisplayName - } - c.JSON(http.StatusOK, resp) - } - return &ForwardResult{ - RequestID: "", - Model: originalModel, - UpstreamModel: upstreamModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", - }, nil - } - if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { - prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt) - } - } - - var imageData []byte - imageFilename := "" - if imageInput != "" { - decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) - if err != nil { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream) - return nil, err - } - imageData = decoded - imageFilename = filename - } - - mediaID := "" - if len(imageData) > 0 { - uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename) - if err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - mediaID = uploadID - } - - taskID := "" - var err error - videoCount := parseSoraVideoCount(reqBody) - switch modelCfg.Type { - case "image": - taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ - Prompt: prompt, - Width: modelCfg.Width, - Height: modelCfg.Height, - MediaID: mediaID, - }) - case "video": - if remixTargetID == "" && isSoraStoryboardPrompt(prompt) { - taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{ - Prompt: formatSoraStoryboardPrompt(prompt), - Orientation: modelCfg.Orientation, - Frames: modelCfg.Frames, - Model: modelCfg.Model, - Size: modelCfg.Size, - MediaID: mediaID, - }) - } else { - taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ - Prompt: prompt, - Orientation: modelCfg.Orientation, - Frames: modelCfg.Frames, - Model: modelCfg.Model, - Size: modelCfg.Size, - VideoCount: videoCount, - MediaID: mediaID, - RemixTargetID: remixTargetID, - CameoIDs: extractSoraCameoIDs(reqBody), - }) - } - default: - err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) - } - if err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - - if clientStream && c != nil { - s.prepareSoraStream(c, taskID) - } - - var mediaURLs []string - videoGenerationID := "" - mediaType := modelCfg.Type - imageCount := 0 - imageSize := "" - switch modelCfg.Type { - case "image": - urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream) - if pollErr != nil { - return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) - } - mediaURLs = urls - imageCount = len(urls) - imageSize = soraImageSizeFromModel(reqModel) - case "video": - videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream) - if pollErr != nil { - return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) - } - if videoStatus != nil { - mediaURLs = videoStatus.URLs - videoGenerationID = strings.TrimSpace(videoStatus.GenerationID) - } - default: - mediaType = "prompt" - } - - watermarkPostID := "" - if modelCfg.Type == "video" && watermarkOpts.Enabled { - watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts) - if watermarkErr != nil { - if !watermarkOpts.FallbackOnFailure { - return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream) - } - log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr) - } else if strings.TrimSpace(watermarkURL) != "" { - mediaURLs = []string{strings.TrimSpace(watermarkURL)} - watermarkPostID = strings.TrimSpace(postID) - } - } - - // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。 - // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。 - finalURLs := s.normalizeSoraMediaURLs(mediaURLs) - if watermarkPostID != "" && watermarkOpts.DeletePost { - if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { - log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) - } - } - - content := buildSoraContent(mediaType, finalURLs) - var firstTokenMs *int - if clientStream { - ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) - if streamErr != nil { - return nil, streamErr - } - firstTokenMs = ms - } else if c != nil { - response := buildSoraNonStreamResponse(content, reqModel) - if len(finalURLs) > 0 { - response["media_url"] = finalURLs[0] - if len(finalURLs) > 1 { - response["media_urls"] = finalURLs - } - } - c.JSON(http.StatusOK, response) - } - - return &ForwardResult{ - RequestID: taskID, - Model: originalModel, - UpstreamModel: upstreamModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: mediaType, - MediaURL: firstMediaURL(finalURLs), - ImageCount: imageCount, - ImageSize: imageSize, - }, nil -} - -func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { - if s == nil || s.cfg == nil { - return ctx, nil - } - timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds - if stream { - timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds - } - if timeoutSeconds <= 0 { - return ctx, nil - } - return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) -} - -func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions { - opts := soraWatermarkOptions{ - Enabled: parseBoolWithDefault(body, "watermark_free", false), - ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))), - ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")), - ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")), - FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true), - DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false), - } - if opts.ParseMethod == "" { - opts.ParseMethod = "third_party" - } - return opts -} - -func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions { - return soraCharacterOptions{ - SetPublic: parseBoolWithDefault(body, "character_set_public", true), - DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true), - } -} - -func parseSoraVideoCount(body map[string]any) int { - if body == nil { - return 1 - } - keys := []string{"video_count", "videos", "n_variants"} - for _, key := range keys { - count := parseIntWithDefault(body, key, 0) - if count > 0 { - return clampInt(count, 1, 3) - } - } - return 1 -} - -func parseBoolWithDefault(body map[string]any, key string, def bool) bool { - if body == nil { - return def - } - val, ok := body[key] - if !ok { - return def - } - switch typed := val.(type) { - case bool: - return typed - case int: - return typed != 0 - case int32: - return typed != 0 - case int64: - return typed != 0 - case float64: - return typed != 0 - case string: - typed = strings.ToLower(strings.TrimSpace(typed)) - if typed == "true" || typed == "1" || typed == "yes" { - return true - } - if typed == "false" || typed == "0" || typed == "no" { - return false - } - } - return def -} - -func parseStringWithDefault(body map[string]any, key, def string) string { - if body == nil { - return def - } - val, ok := body[key] - if !ok { - return def - } - if str, ok := val.(string); ok { - return str - } - return def -} - -func parseIntWithDefault(body map[string]any, key string, def int) int { - if body == nil { - return def - } - val, ok := body[key] - if !ok { - return def - } - switch typed := val.(type) { - case int: - return typed - case int32: - return int(typed) - case int64: - return int(typed) - case float64: - return int(typed) - case string: - parsed, err := strconv.Atoi(strings.TrimSpace(typed)) - if err == nil { - return parsed - } - } - return def -} - -func clampInt(v, minVal, maxVal int) int { - if v < minVal { - return minVal - } - if v > maxVal { - return maxVal - } - return v -} - -func extractSoraCameoIDs(body map[string]any) []string { - if body == nil { - return nil - } - raw, ok := body["cameo_ids"] - if !ok { - return nil - } - switch typed := raw.(type) { - case []string: - out := make([]string, 0, len(typed)) - for _, item := range typed { - item = strings.TrimSpace(item) - if item != "" { - out = append(out, item) - } - } - return out - case []any: - out := make([]string, 0, len(typed)) - for _, item := range typed { - str, ok := item.(string) - if !ok { - continue - } - str = strings.TrimSpace(str) - if str != "" { - out = append(out, str) - } - } - return out - default: - return nil - } -} - -func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) { - cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData) - if err != nil { - return nil, err - } - - cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID) - if err != nil { - return nil, err - } - username := processSoraCharacterUsername(cameoStatus.UsernameHint) - displayName := strings.TrimSpace(cameoStatus.DisplayNameHint) - if displayName == "" { - displayName = "Character" - } - profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL) - if profileAssetURL == "" { - return nil, errors.New("profile asset url not found in cameo status") - } - - avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL) - if err != nil { - return nil, err - } - assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData) - if err != nil { - return nil, err - } - instructionSet := cameoStatus.InstructionSetHint - if instructionSet == nil { - instructionSet = cameoStatus.InstructionSet - } - - characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{ - CameoID: strings.TrimSpace(cameoID), - Username: username, - DisplayName: displayName, - ProfileAssetPointer: assetPointer, - InstructionSet: instructionSet, - }) - if err != nil { - return nil, err - } - - if opts.SetPublic { - if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil { - return nil, err - } - } - - return &soraCharacterFlowResult{ - CameoID: strings.TrimSpace(cameoID), - CharacterID: strings.TrimSpace(characterID), - Username: strings.TrimSpace(username), - DisplayName: displayName, - }, nil -} - -func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { - timeout := 10 * time.Minute - interval := 5 * time.Second - maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds())) - if maxAttempts < 1 { - maxAttempts = 1 - } - - var lastErr error - consecutiveErrors := 0 - for attempt := 0; attempt < maxAttempts; attempt++ { - status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID) - if err != nil { - lastErr = err - consecutiveErrors++ - if consecutiveErrors >= 3 { - break - } - if attempt < maxAttempts-1 { - if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { - return nil, sleepErr - } - } - continue - } - consecutiveErrors = 0 - if status == nil { - if attempt < maxAttempts-1 { - if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { - return nil, sleepErr - } - } - continue - } - currentStatus := strings.ToLower(strings.TrimSpace(status.Status)) - statusMessage := strings.TrimSpace(status.StatusMessage) - if currentStatus == "failed" { - if statusMessage == "" { - statusMessage = "character creation failed" - } - return nil, errors.New(statusMessage) - } - if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" { - return status, nil - } - if attempt < maxAttempts-1 { - if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { - return nil, sleepErr - } - } - } - if lastErr != nil { - return nil, fmt.Errorf("poll cameo status failed: %w", lastErr) - } - return nil, errors.New("cameo processing timeout") -} - -func processSoraCharacterUsername(usernameHint string) string { - usernameHint = strings.TrimSpace(usernameHint) - if usernameHint == "" { - usernameHint = "character" - } - if strings.Contains(usernameHint, ".") { - parts := strings.Split(usernameHint, ".") - usernameHint = strings.TrimSpace(parts[len(parts)-1]) - } - if usernameHint == "" { - usernameHint = "character" - } - return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100) -} - -func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) { - generationID = strings.TrimSpace(generationID) - if generationID == "" { - return "", "", errors.New("generation id is required for watermark-free mode") - } - postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID) - if err != nil { - return "", "", err - } - postID = strings.TrimSpace(postID) - if postID == "" { - return "", "", errors.New("watermark-free publish returned empty post id") - } - - switch opts.ParseMethod { - case "custom": - urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID) - if parseErr != nil { - return "", postID, parseErr - } - return strings.TrimSpace(urlVal), postID, nil - case "", "third_party": - return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil - default: - return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod) - } -} - -func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { - switch statusCode { - case 401, 402, 403, 404, 429, 529: - return true - default: - return statusCode >= 500 - } -} - -func buildSoraNonStreamResponse(content, model string) map[string]any { - return map[string]any{ - "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), - "object": "chat.completion", - "created": time.Now().Unix(), - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "message": map[string]any{ - "role": "assistant", - "content": content, - }, - "finish_reason": "stop", - }, - }, - } -} - -func soraImageSizeFromModel(model string) string { - modelLower := strings.ToLower(model) - if size, ok := soraImageSizeMap[modelLower]; ok { - return size - } - if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") { - return "540" - } - return "360" -} - -func soraProErrorMessage(model, upstreamMsg string) string { - modelLower := strings.ToLower(model) - if strings.Contains(modelLower, "sora2pro-hd") { - return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号" - } - if strings.Contains(modelLower, "sora2pro") { - return "当前账号无法使用 Sora Pro 模型,请更换模型或账号" - } - return "" -} - -func firstMediaURL(urls []string) string { - if len(urls) == 0 { - return "" - } - return urls[0] -} - -func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string { - if path == "" { - return path - } - prefix := "/sora/media" - values := url.Values{} - if rawQuery != "" { - if parsed, err := url.ParseQuery(rawQuery); err == nil { - values = parsed - } - } - - signKey := "" - ttlSeconds := 0 - if s != nil && s.cfg != nil { - signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey) - ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds - } - values.Del("sig") - values.Del("expires") - signingQuery := values.Encode() - if signKey != "" && ttlSeconds > 0 { - expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix() - signature := SignSoraMediaURL(path, signingQuery, expires, signKey) - if signature != "" { - values.Set("expires", strconv.FormatInt(expires, 10)) - values.Set("sig", signature) - prefix = "/sora/media-signed" - } - } - - encoded := values.Encode() - if encoded == "" { - return prefix + path - } - return prefix + path + "?" + encoded -} - -func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) { - if c == nil { - return - } - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - if strings.TrimSpace(requestID) != "" { - c.Header("x-request-id", requestID) - } -} - -func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) { - if c == nil { - return nil, nil - } - writer := c.Writer - flusher, _ := writer.(http.Flusher) - - chunk := map[string]any{ - "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "delta": map[string]any{ - "content": content, - }, - }, - }, - } - encoded, _ := jsonMarshalRaw(chunk) - if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil { - return nil, err - } - if flusher != nil { - flusher.Flush() - } - ms := int(time.Since(startTime).Milliseconds()) - finalChunk := map[string]any{ - "id": chunk["id"], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "delta": map[string]any{}, - "finish_reason": "stop", - }, - }, - } - finalEncoded, _ := jsonMarshalRaw(finalChunk) - if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil { - return &ms, err - } - if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil { - return &ms, err - } - if flusher != nil { - flusher.Flush() - } - return &ms, nil -} - -func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) { - if c == nil { - return - } - if stream { - flusher, _ := c.Writer.(http.Flusher) - errorData := map[string]any{ - "error": map[string]string{ - "type": errType, - "message": message, - }, - } - jsonBytes, err := json.Marshal(errorData) - if err != nil { - _ = c.Error(err) - return - } - errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) - _, _ = fmt.Fprint(c.Writer, errorEvent) - _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - return - } - c.JSON(status, gin.H{ - "error": gin.H{ - "type": errType, - "message": message, - }, - }) -} - -func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error { - if err == nil { - return nil - } - var upstreamErr *SoraUpstreamError - if errors.As(err, &upstreamErr) { - accountID := int64(0) - if account != nil { - accountID = account.ID - } - logger.LegacyPrintf( - "service.sora", - "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s", - accountID, - model, - upstreamErr.StatusCode, - strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")), - strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")), - strings.TrimSpace(upstreamErr.Message), - truncateForLog(upstreamErr.Body, 1024), - ) - if s.rateLimitService != nil && account != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) - } - if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { - var responseHeaders http.Header - if upstreamErr.Headers != nil { - responseHeaders = upstreamErr.Headers.Clone() - } - return &UpstreamFailoverError{ - StatusCode: upstreamErr.StatusCode, - ResponseBody: upstreamErr.Body, - ResponseHeaders: responseHeaders, - } - } - msg := upstreamErr.Message - if override := soraProErrorMessage(model, msg); override != "" { - msg = override - } - s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream) - return err - } - if errors.Is(err, context.DeadlineExceeded) { - s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream) - return err - } - s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream) - return err -} - -func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { - interval := s.pollInterval() - maxAttempts := s.pollMaxAttempts() - lastPing := time.Now() - for attempt := 0; attempt < maxAttempts; attempt++ { - status, err := s.soraClient.GetImageTask(ctx, account, taskID) - if err != nil { - return nil, err - } - switch strings.ToLower(status.Status) { - case "succeeded", "completed": - return status.URLs, nil - case "failed": - if status.ErrorMsg != "" { - return nil, errors.New(status.ErrorMsg) - } - return nil, errors.New("sora image generation failed") - } - if stream { - s.maybeSendPing(c, &lastPing) - } - if err := sleepWithContext(ctx, interval); err != nil { - return nil, err - } - } - return nil, errors.New("sora image generation timeout") -} - -func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) { - interval := s.pollInterval() - maxAttempts := s.pollMaxAttempts() - lastPing := time.Now() - for attempt := 0; attempt < maxAttempts; attempt++ { - status, err := s.soraClient.GetVideoTask(ctx, account, taskID) - if err != nil { - return nil, err - } - switch strings.ToLower(status.Status) { - case "completed", "succeeded": - return status, nil - case "failed": - if status.ErrorMsg != "" { - return nil, errors.New(status.ErrorMsg) - } - return nil, errors.New("sora video generation failed") - } - if stream { - s.maybeSendPing(c, &lastPing) - } - if err := sleepWithContext(ctx, interval); err != nil { - return nil, err - } - } - return nil, errors.New("sora video generation timeout") -} - -func (s *SoraGatewayService) pollInterval() time.Duration { - if s == nil || s.cfg == nil { - return 2 * time.Second - } - interval := s.cfg.Sora.Client.PollIntervalSeconds - if interval <= 0 { - interval = 2 - } - return time.Duration(interval) * time.Second -} - -func (s *SoraGatewayService) pollMaxAttempts() int { - if s == nil || s.cfg == nil { - return 600 - } - maxAttempts := s.cfg.Sora.Client.MaxPollAttempts - if maxAttempts <= 0 { - maxAttempts = 600 - } - return maxAttempts -} - -func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) { - if c == nil { - return - } - interval := 10 * time.Second - if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 { - interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second - } - if time.Since(*lastPing) < interval { - return - } - if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil { - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } - *lastPing = time.Now() - } -} - -func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string { - if len(urls) == 0 { - return urls - } - output := make([]string, 0, len(urls)) - for _, raw := range urls { - raw = strings.TrimSpace(raw) - if raw == "" { - continue - } - if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { - output = append(output, raw) - continue - } - pathVal := raw - if !strings.HasPrefix(pathVal, "/") { - pathVal = "/" + pathVal - } - output = append(output, s.buildSoraMediaURL(pathVal, "")) - } - return output -} - -// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符, -// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。 -func jsonMarshalRaw(v any) ([]byte, error) { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.SetEscapeHTML(false) - if err := enc.Encode(v); err != nil { - return nil, err - } - // Encode 会追加换行符,去掉它 - b := buf.Bytes() - if len(b) > 0 && b[len(b)-1] == '\n' { - b = b[:len(b)-1] - } - return b, nil -} - -func buildSoraContent(mediaType string, urls []string) string { - switch mediaType { - case "image": - parts := make([]string, 0, len(urls)) - for _, u := range urls { - parts = append(parts, fmt.Sprintf("![image](%s)", u)) - } - return strings.Join(parts, "\n") - case "video": - if len(urls) == 0 { - return "" - } - return fmt.Sprintf("```html\n\n```", urls[0]) - default: - return "" - } -} - -func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) { - if body == nil { - return "", "", "", "" - } - if v, ok := body["remix_target_id"].(string); ok { - remixTargetID = strings.TrimSpace(v) - } - if v, ok := body["image"].(string); ok { - imageInput = v - } - if v, ok := body["video"].(string); ok { - videoInput = v - } - if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" { - prompt = v - } - if messages, ok := body["messages"].([]any); ok { - builder := strings.Builder{} - for _, raw := range messages { - msg, ok := raw.(map[string]any) - if !ok { - continue - } - role, _ := msg["role"].(string) - if role != "" && role != "user" { - continue - } - content := msg["content"] - text, img, vid := parseSoraMessageContent(content) - if text != "" { - if builder.Len() > 0 { - _, _ = builder.WriteString("\n") - } - _, _ = builder.WriteString(text) - } - if imageInput == "" && img != "" { - imageInput = img - } - if videoInput == "" && vid != "" { - videoInput = vid - } - } - if prompt == "" { - prompt = builder.String() - } - } - if remixTargetID == "" { - remixTargetID = extractRemixTargetIDFromPrompt(prompt) - } - prompt = cleanRemixLinkFromPrompt(prompt) - return prompt, imageInput, videoInput, remixTargetID -} - -func parseSoraMessageContent(content any) (text, imageInput, videoInput string) { - switch val := content.(type) { - case string: - return val, "", "" - case []any: - builder := strings.Builder{} - for _, item := range val { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - t, _ := itemMap["type"].(string) - switch t { - case "text": - if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" { - if builder.Len() > 0 { - _, _ = builder.WriteString("\n") - } - _, _ = builder.WriteString(txt) - } - case "image_url": - if imageInput == "" { - if urlVal, ok := itemMap["image_url"].(map[string]any); ok { - imageInput = fmt.Sprintf("%v", urlVal["url"]) - } else if urlStr, ok := itemMap["image_url"].(string); ok { - imageInput = urlStr - } - } - case "video_url": - if videoInput == "" { - if urlVal, ok := itemMap["video_url"].(map[string]any); ok { - videoInput = fmt.Sprintf("%v", urlVal["url"]) - } else if urlStr, ok := itemMap["video_url"].(string); ok { - videoInput = urlStr - } - } - } - } - return builder.String(), imageInput, videoInput - default: - return "", "", "" - } -} - -func isSoraStoryboardPrompt(prompt string) bool { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return false - } - return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1 -} - -func formatSoraStoryboardPrompt(prompt string) string { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return "" - } - matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1) - if len(matches) == 0 { - return prompt - } - firstBracketPos := strings.Index(prompt, "[") - instructions := "" - if firstBracketPos > 0 { - instructions = strings.TrimSpace(prompt[:firstBracketPos]) - } - shots := make([]string, 0, len(matches)) - for i, match := range matches { - if len(match) < 3 { - continue - } - duration := strings.TrimSpace(match[1]) - scene := strings.TrimSpace(match[2]) - if scene == "" { - continue - } - shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene)) - } - if len(shots) == 0 { - return prompt - } - timeline := strings.Join(shots, "\n\n") - if instructions == "" { - return timeline - } - return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions) -} - -func extractRemixTargetIDFromPrompt(prompt string) string { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return "" - } - return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt)) -} - -func cleanRemixLinkFromPrompt(prompt string) string { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return prompt - } - cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "") - cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "") - cleaned = strings.Join(strings.Fields(cleaned), " ") - return strings.TrimSpace(cleaned) -} - -func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) { - raw := strings.TrimSpace(input) - if raw == "" { - return nil, "", errors.New("empty image input") - } - if strings.HasPrefix(raw, "data:") { - parts := strings.SplitN(raw, ",", 2) - if len(parts) != 2 { - return nil, "", errors.New("invalid data url") - } - meta := parts[0] - payload := parts[1] - decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes) - if err != nil { - return nil, "", err - } - ext := "" - if strings.HasPrefix(meta, "data:") { - metaParts := strings.SplitN(meta[5:], ";", 2) - if len(metaParts) > 0 { - if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 { - ext = exts[0] - } - } - } - filename := "image" + ext - return decoded, filename, nil - } - if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { - return downloadSoraImageInput(ctx, raw) - } - decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes) - if err != nil { - return nil, "", errors.New("invalid base64 image") - } - return decoded, "image.png", nil -} - -func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) { - raw := strings.TrimSpace(input) - if raw == "" { - return nil, errors.New("empty video input") - } - if strings.HasPrefix(raw, "data:") { - parts := strings.SplitN(raw, ",", 2) - if len(parts) != 2 { - return nil, errors.New("invalid video data url") - } - decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes) - if err != nil { - return nil, errors.New("invalid base64 video") - } - if len(decoded) == 0 { - return nil, errors.New("empty video data") - } - return decoded, nil - } - if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { - return downloadSoraVideoInput(ctx, raw) - } - decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes) - if err != nil { - return nil, errors.New("invalid base64 video") - } - if len(decoded) == 0 { - return nil, errors.New("empty video data") - } - return decoded, nil -} - -func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { - parsed, err := validateSoraRemoteURL(rawURL) - if err != nil { - return nil, "", err - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) - if err != nil { - return nil, "", err - } - client := &http.Client{ - Timeout: soraImageInputTimeout, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= soraImageInputMaxRedirects { - return errors.New("too many redirects") - } - return validateSoraRemoteURLValue(req.URL) - }, - } - resp, err := client.Do(req) - if err != nil { - return nil, "", err - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode) - } - data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes)) - if err != nil { - return nil, "", err - } - ext := fileExtFromURL(parsed.String()) - if ext == "" { - ext = fileExtFromContentType(resp.Header.Get("Content-Type")) - } - filename := "image" + ext - return data, filename, nil -} - -func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) { - parsed, err := validateSoraRemoteURL(rawURL) - if err != nil { - return nil, err - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) - if err != nil { - return nil, err - } - client := &http.Client{ - Timeout: soraVideoInputTimeout, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= soraVideoInputMaxRedirects { - return errors.New("too many redirects") - } - return validateSoraRemoteURLValue(req.URL) - }, - } - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("download video failed: %d", resp.StatusCode) - } - data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes)) - if err != nil { - return nil, err - } - if len(data) == 0 { - return nil, errors.New("empty video content") - } - return data, nil -} - -func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) { - if maxBytes <= 0 { - return nil, errors.New("invalid max bytes limit") - } - decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) - limited := io.LimitReader(decoder, maxBytes+1) - data, err := io.ReadAll(limited) - if err != nil { - return nil, err - } - if int64(len(data)) > maxBytes { - return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes) - } - return data, nil -} - -func validateSoraRemoteURL(raw string) (*url.URL, error) { - if strings.TrimSpace(raw) == "" { - return nil, errors.New("empty remote url") - } - parsed, err := url.Parse(raw) - if err != nil { - return nil, fmt.Errorf("invalid remote url: %w", err) - } - if err := validateSoraRemoteURLValue(parsed); err != nil { - return nil, err - } - return parsed, nil -} - -func validateSoraRemoteURLValue(parsed *url.URL) error { - if parsed == nil { - return errors.New("invalid remote url") - } - scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) - if scheme != "http" && scheme != "https" { - return errors.New("only http/https remote url is allowed") - } - if parsed.User != nil { - return errors.New("remote url cannot contain userinfo") - } - host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) - if host == "" { - return errors.New("remote url missing host") - } - if _, blocked := soraBlockedHostnames[host]; blocked { - return errors.New("remote url is not allowed") - } - if ip := net.ParseIP(host); ip != nil { - if isSoraBlockedIP(ip) { - return errors.New("remote url is not allowed") - } - return nil - } - ips, err := net.LookupIP(host) - if err != nil { - return fmt.Errorf("resolve remote url failed: %w", err) - } - for _, ip := range ips { - if isSoraBlockedIP(ip) { - return errors.New("remote url is not allowed") - } - } - return nil -} - -func isSoraBlockedIP(ip net.IP) bool { - if ip == nil { - return true - } - for _, cidr := range soraBlockedCIDRs { - if cidr.Contains(ip) { - return true - } - } - return false -} - -func mustParseCIDRs(values []string) []*net.IPNet { - out := make([]*net.IPNet, 0, len(values)) - for _, val := range values { - _, cidr, err := net.ParseCIDR(val) - if err != nil { - continue - } - out = append(out, cidr) - } - return out -} diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go deleted file mode 100644 index 2fef600c..00000000 --- a/backend/internal/service/sora_gateway_service_test.go +++ /dev/null @@ -1,564 +0,0 @@ -//go:build unit - -package service - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -var _ SoraClient = (*stubSoraClientForPoll)(nil) - -type stubSoraClientForPoll struct { - imageStatus *SoraImageTaskStatus - videoStatus *SoraVideoTaskStatus - imageCalls int - videoCalls int - enhanced string - enhanceErr error - storyboard bool - videoReq SoraVideoRequest - parseErr error - postCalls int - deleteCalls int -} - -func (s *stubSoraClientForPoll) Enabled() bool { return true } -func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { - return "", nil -} -func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { - return "task-image", nil -} -func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { - s.videoReq = req - return "task-video", nil -} -func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { - s.storyboard = true - return "task-video", nil -} -func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { - return "cameo-1", nil -} -func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { - return &SoraCameoStatus{ - Status: "finalized", - StatusMessage: "Completed", - DisplayNameHint: "Character", - UsernameHint: "user.character", - ProfileAssetURL: "https://example.com/avatar.webp", - }, nil -} -func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { - return []byte("avatar"), nil -} -func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { - return "asset-pointer", nil -} -func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { - return "character-1", nil -} -func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { - return nil -} -func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { - return nil -} -func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { - s.postCalls++ - return "s_post", nil -} -func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error { - s.deleteCalls++ - return nil -} -func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { - if s.parseErr != nil { - return "", s.parseErr - } - return "https://example.com/no-watermark.mp4", nil -} -func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { - if s.enhanced != "" { - return s.enhanced, s.enhanceErr - } - return "enhanced prompt", s.enhanceErr -} -func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { - s.imageCalls++ - return s.imageStatus, nil -} -func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { - s.videoCalls++ - return s.videoStatus, nil -} - -func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { - client := &stubSoraClientForPoll{ - imageStatus: &SoraImageTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/a.png"}, - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - service := NewSoraGatewayService(client, nil, nil, cfg) - - urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false) - require.NoError(t, err) - require.Equal(t, []string{"https://example.com/a.png"}, urls) - require.Equal(t, 1, client.imageCalls) -} - -func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { - client := &stubSoraClientForPoll{ - enhanced: "cinematic prompt", - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ - ID: 1, - Platform: PlatformSora, - Status: StatusActive, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "prompt-enhance-short-10s": "prompt-enhance-short-15s", - }, - }, - } - body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "prompt", result.MediaType) - require.Equal(t, "prompt-enhance-short-10s", result.Model) - require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel) -} - -func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/v.mp4"}, - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.True(t, client.storyboard) -} - -func TestSoraGatewayService_ForwardVideoCount(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/v.mp4"}, - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, 3, client.videoReq.VideoCount) -} - -func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { - client := &stubSoraClientForPoll{} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "prompt", result.MediaType) - require.Equal(t, 0, client.videoCalls) -} - -func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/original.mp4"}, - GenerationID: "gen_1", - }, - parseErr: errors.New("parse failed"), - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "https://example.com/original.mp4", result.MediaURL) - require.Equal(t, 1, client.postCalls) - require.Equal(t, 0, client.deleteCalls) -} - -func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/original.mp4"}, - GenerationID: "gen_1", - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL) - require.Equal(t, 1, client.postCalls) - require.Equal(t, 1, client.deleteCalls) -} - -func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "failed", - ErrorMsg: "reject", - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - service := NewSoraGatewayService(client, nil, nil, cfg) - - status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false) - require.Error(t, err) - require.Nil(t, status) - require.Contains(t, err.Error(), "reject") - require.Equal(t, 1, client.videoCalls) -} - -func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) { - cfg := &config.Config{ - Gateway: config.GatewayConfig{ - SoraMediaSigningKey: "test-key", - SoraMediaSignedURLTTLSeconds: 600, - }, - } - service := NewSoraGatewayService(nil, nil, nil, cfg) - - url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "") - require.Contains(t, url, "/sora/media-signed") - require.Contains(t, url, "expires=") - require.Contains(t, url, "sig=") -} - -func TestNormalizeSoraMediaURLs_Empty(t *testing.T) { - svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) - result := svc.normalizeSoraMediaURLs(nil) - require.Empty(t, result) - - result = svc.normalizeSoraMediaURLs([]string{}) - require.Empty(t, result) -} - -func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) { - svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) - urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"} - result := svc.normalizeSoraMediaURLs(urls) - require.Equal(t, urls, result) -} - -func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) { - cfg := &config.Config{} - svc := NewSoraGatewayService(nil, nil, nil, cfg) - urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"} - result := svc.normalizeSoraMediaURLs(urls) - require.Len(t, result, 2) - require.Contains(t, result[0], "/sora/media") - require.Contains(t, result[1], "/sora/media") -} - -func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) { - svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) - urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"} - result := svc.normalizeSoraMediaURLs(urls) - require.Len(t, result, 2) -} - -func TestBuildSoraContent_Image(t *testing.T) { - content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"}) - require.Contains(t, content, "![image](https://a.com/1.png)") - require.Contains(t, content, "![image](https://a.com/2.png)") -} - -func TestBuildSoraContent_Video(t *testing.T) { - content := buildSoraContent("video", []string{"https://a.com/v.mp4"}) - require.Contains(t, content, "
-
- - -

{{ t('admin.users.soraStorageQuotaHint') }}

-