diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 47f8f518..e9ce927e 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -43,11 +43,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
server.ProviderSet,
// Payment providers
- payment.ProvideRegistry,
- payment.ProvideEncryptionKey,
- payment.ProvideDefaultLoadBalancer,
- service.ProvidePaymentConfigService,
- service.ProvidePaymentOrderExpiryService,
+ payment.ProviderSet,
// Privacy client factory for OpenAI training opt-out
providePrivacyClientFactory,
@@ -84,6 +80,7 @@ func provideCleanup(
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
+ soraMediaCleanup *service.SoraMediaCleanupService, // 从本地版本合并
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
@@ -245,6 +242,12 @@ func provideCleanup(
}
return nil
}},
+ {"SoraMediaCleanupService", func() error {
+ if soraMediaCleanup != nil {
+ soraMediaCleanup.Stop()
+ }
+ return nil
+ }},
}
infraSteps := []cleanupStep{
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index c288a289..b634bf31 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -50,7 +50,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
- channelRepository := repository.NewChannelRepository(db)
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
@@ -65,7 +64,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
- apiKeyService.SetRateLimitCacheInvalidator(billingCache)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
@@ -73,15 +71,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
- registry := payment.ProvideRegistry()
- encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
- if err != nil {
- return nil, err
- }
- defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
- paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
- paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
- paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
@@ -92,7 +81,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
- usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -120,11 +108,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
+ sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
+ rpmCache := repository.NewRPMCache(redisClient)
+ groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
+ groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
- openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()
@@ -134,7 +125,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
- oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
@@ -142,23 +132,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
identityCache := repository.NewIdentityCache(redisClient)
- geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
- gatewayCache := repository.NewGatewayCache(redisClient)
- schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
- schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
- antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
- internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
+ oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
+ geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
+ gatewayCache := repository.NewGatewayCache(redisClient)
+ schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
+ schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
+ antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
+ internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
- sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
- rpmCache := repository.NewRPMCache(redisClient)
- groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
- groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
@@ -175,6 +162,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
+ usageBillingRepository := repository.NewUsageBillingRepository(client, db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
@@ -183,16 +171,25 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
- claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
+ claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
+ channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
- openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
+ openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
+ encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
+ if err != nil {
+ return nil, err
+ }
+ paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
+ registry := payment.ProvideRegistry()
+ defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
+ paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
@@ -221,20 +218,30 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
- adminPaymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, adminPaymentHandler)
+ paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
+ soraAccountRepository := repository.NewSoraAccountRepository(db)
+ soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
+ soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig)
+ soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
+ soraGenerationRepository := repository.NewSoraGenerationRepository(db)
+ soraS3Storage := service.NewSoraS3Storage(settingService)
+ soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService)
+ soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService)
+ soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
+ soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
- handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, idempotencyCoordinator, idempotencyCleanupService)
+ handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -245,11 +252,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
- tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
+ soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
+ tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
- v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService)
+ paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
+ v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -284,6 +293,7 @@ func provideCleanup(
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
+ soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go
index a6e0551a..4db6f144 100644
--- a/backend/cmd/server/wire_gen_test.go
+++ b/backend/cmd/server/wire_gen_test.go
@@ -57,6 +57,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
&service.OpsCleanupService{},
&service.OpsScheduledReportService{},
opsSystemLogSinkSvc,
+ nil, // soraMediaCleanup (从本地版本合并)
schedulerSnapshotSvc,
tokenRefreshSvc,
accountExpirySvc,
diff --git a/backend/ent/group.go b/backend/ent/group.go
index f10b50c3..80fd7982 100644
--- a/backend/ent/group.go
+++ b/backend/ent/group.go
@@ -79,6 +79,8 @@ type Group struct {
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
+ // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -191,7 +193,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
- case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
+ case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldSoraStorageQuotaBytes:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
values[i] = new(sql.NullString)
@@ -414,6 +416,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
}
}
+ case group.FieldSoraStorageQuotaBytes:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
+ } else if value.Valid {
+ _m.SoraStorageQuotaBytes = value.Int64
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -599,6 +607,9 @@ func (_m *Group) String() string {
builder.WriteString(", ")
builder.WriteString("messages_dispatch_model_config=")
builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
+ builder.WriteString(", ")
+ builder.WriteString("sora_storage_quota_bytes=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go
index b1371630..12ed7987 100644
--- a/backend/ent/group/group.go
+++ b/backend/ent/group/group.go
@@ -76,6 +76,8 @@ const (
FieldDefaultMappedModel = "default_mapped_model"
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
+ // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
+ FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -181,6 +183,7 @@ var Columns = []string{
FieldRequirePrivacySet,
FieldDefaultMappedModel,
FieldMessagesDispatchModelConfig,
+ FieldSoraStorageQuotaBytes,
}
var (
@@ -258,6 +261,8 @@ var (
DefaultMappedModelValidator func(string) error
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
+ // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
+ DefaultSoraStorageQuotaBytes int64
)
// OrderOption defines the ordering options for the Group queries.
@@ -403,6 +408,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
}
+// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
+func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go
index cba2ce5f..cd02f7e3 100644
--- a/backend/ent/group/where.go
+++ b/backend/ent/group/where.go
@@ -190,6 +190,11 @@ func DefaultMappedModel(v string) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
}
+// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
+func SoraStorageQuotaBytes(v int64) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1320,6 +1325,46 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
}
+// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesEQ(v int64) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesNEQ(v int64) predicate.Group {
+ return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group {
+ return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
+}
+
+// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group {
+ return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
+}
+
+// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesGT(v int64) predicate.Group {
+ return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesGTE(v int64) predicate.Group {
+ return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesLT(v int64) predicate.Group {
+ return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesLTE(v int64) predicate.Group {
+ return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {
diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go
index f412fa40..a7263051 100644
--- a/backend/ent/group_create.go
+++ b/backend/ent/group_create.go
@@ -425,6 +425,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
return _c
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate {
+ _c.mutation.SetSoraStorageQuotaBytes(v)
+ return _c
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate {
+ if v != nil {
+ _c.SetSoraStorageQuotaBytes(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -630,6 +644,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultMessagesDispatchModelConfig
_c.mutation.SetMessagesDispatchModelConfig(v)
}
+ if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
+ v := group.DefaultSoraStorageQuotaBytes
+ _c.mutation.SetSoraStorageQuotaBytes(v)
+ }
return nil
}
@@ -717,6 +735,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
}
+ if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
+ return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)}
+ }
return nil
}
@@ -864,6 +885,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
_node.MessagesDispatchModelConfig = value
}
+ if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ _node.SoraStorageQuotaBytes = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1500,6 +1525,24 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert {
return u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert {
+ u.Set(group.FieldSoraStorageQuotaBytes, v)
+ return u
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert {
+ u.SetExcluded(group.FieldSoraStorageQuotaBytes)
+ return u
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert {
+ u.Add(group.FieldSoraStorageQuotaBytes, v)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2105,6 +2148,27 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne {
})
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetSoraStorageQuotaBytes(v)
+ })
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddSoraStorageQuotaBytes(v)
+ })
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateSoraStorageQuotaBytes()
+ })
+}
+
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2876,6 +2940,27 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk {
})
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetSoraStorageQuotaBytes(v)
+ })
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddSoraStorageQuotaBytes(v)
+ })
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateSoraStorageQuotaBytes()
+ })
+}
+
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go
index 7b6d6193..fa398180 100644
--- a/backend/ent/group_update.go
+++ b/backend/ent/group_update.go
@@ -567,6 +567,27 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
return _u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate {
+ _u.mutation.ResetSoraStorageQuotaBytes()
+ _u.mutation.SetSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate {
+ if v != nil {
+ _u.SetSoraStorageQuotaBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
+func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate {
+ _u.mutation.AddSoraStorageQuotaBytes(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1030,6 +1051,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
+ if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
+ _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1875,6 +1902,27 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA
return _u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
+ _u.mutation.ResetSoraStorageQuotaBytes()
+ _u.mutation.SetSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne {
+ if v != nil {
+ _u.SetSoraStorageQuotaBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
+func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
+ _u.mutation.AddSoraStorageQuotaBytes(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -2368,6 +2416,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
+ if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
+ _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index e947b2e8..77ed0682 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -408,6 +408,7 @@ var (
{Name: "require_privacy_set", Type: field.TypeBool, Default: false},
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
{Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
@@ -931,6 +932,7 @@ var (
{Name: "model_mapping_chain", Type: field.TypeString, Nullable: true, Size: 500},
{Name: "billing_tier", Type: field.TypeString, Nullable: true, Size: 50},
{Name: "billing_mode", Type: field.TypeString, Nullable: true, Size: 20},
+ {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
@@ -969,31 +971,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[34]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[34]},
+ Columns: []*schema.Column{UsageLogsColumns[35]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[35]},
+ Columns: []*schema.Column{UsageLogsColumns[36]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[36]},
+ Columns: []*schema.Column{UsageLogsColumns[37]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[37]},
+ Columns: []*schema.Column{UsageLogsColumns[38]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -1002,32 +1004,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[36]},
+ Columns: []*schema.Column{UsageLogsColumns[37]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[34]},
},
{
Name: "usagelog_account_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[34]},
+ Columns: []*schema.Column{UsageLogsColumns[35]},
},
{
Name: "usagelog_group_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[35]},
+ Columns: []*schema.Column{UsageLogsColumns[36]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[37]},
+ Columns: []*schema.Column{UsageLogsColumns[38]},
},
{
Name: "usagelog_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[33]},
},
{
Name: "usagelog_model",
@@ -1047,17 +1049,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[33]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[34], UsageLogsColumns[33]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[33]},
},
},
}
@@ -1078,6 +1080,8 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
+ {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
+ {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 6b2fa838..e973eff2 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -8255,6 +8255,8 @@ type GroupMutation struct {
require_privacy_set *bool
default_mapped_model *string
messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
+ sora_storage_quota_bytes *int64
+ addsora_storage_quota_bytes *int64
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -9843,6 +9845,62 @@ func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
m.messages_dispatch_model_config = nil
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) {
+ m.sora_storage_quota_bytes = &i
+ m.addsora_storage_quota_bytes = nil
+}
+
+// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
+func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
+ v := m.sora_storage_quota_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
+ }
+ return oldValue.SoraStorageQuotaBytes, nil
+}
+
+// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
+func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) {
+ if m.addsora_storage_quota_bytes != nil {
+ *m.addsora_storage_quota_bytes += i
+ } else {
+ m.addsora_storage_quota_bytes = &i
+ }
+}
+
+// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
+func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
+ v := m.addsora_storage_quota_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
+func (m *GroupMutation) ResetSoraStorageQuotaBytes() {
+ m.sora_storage_quota_bytes = nil
+ m.addsora_storage_quota_bytes = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -10201,7 +10259,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
- fields := make([]string, 0, 30)
+ fields := make([]string, 0, 31)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -10292,6 +10350,9 @@ func (m *GroupMutation) Fields() []string {
if m.messages_dispatch_model_config != nil {
fields = append(fields, group.FieldMessagesDispatchModelConfig)
}
+ if m.sora_storage_quota_bytes != nil {
+ fields = append(fields, group.FieldSoraStorageQuotaBytes)
+ }
return fields
}
@@ -10360,6 +10421,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.DefaultMappedModel()
case group.FieldMessagesDispatchModelConfig:
return m.MessagesDispatchModelConfig()
+ case group.FieldSoraStorageQuotaBytes:
+ return m.SoraStorageQuotaBytes()
}
return nil, false
}
@@ -10429,6 +10492,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldDefaultMappedModel(ctx)
case group.FieldMessagesDispatchModelConfig:
return m.OldMessagesDispatchModelConfig(ctx)
+ case group.FieldSoraStorageQuotaBytes:
+ return m.OldSoraStorageQuotaBytes(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -10648,6 +10713,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetMessagesDispatchModelConfig(v)
return nil
+ case group.FieldSoraStorageQuotaBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSoraStorageQuotaBytes(v)
+ return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -10689,6 +10761,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.addsort_order != nil {
fields = append(fields, group.FieldSortOrder)
}
+ if m.addsora_storage_quota_bytes != nil {
+ fields = append(fields, group.FieldSoraStorageQuotaBytes)
+ }
return fields
}
@@ -10719,6 +10794,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedFallbackGroupIDOnInvalidRequest()
case group.FieldSortOrder:
return m.AddedSortOrder()
+ case group.FieldSoraStorageQuotaBytes:
+ return m.AddedSoraStorageQuotaBytes()
}
return nil, false
}
@@ -10805,6 +10882,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddSortOrder(v)
return nil
+ case group.FieldSoraStorageQuotaBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSoraStorageQuotaBytes(v)
+ return nil
}
return fmt.Errorf("unknown Group numeric field %s", name)
}
@@ -10991,6 +11075,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldMessagesDispatchModelConfig:
m.ResetMessagesDispatchModelConfig()
return nil
+ case group.FieldSoraStorageQuotaBytes:
+ m.ResetSoraStorageQuotaBytes()
+ return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -24792,6 +24879,7 @@ type UsageLogMutation struct {
model_mapping_chain *string
billing_tier *string
billing_mode *string
+ media_type *string
input_tokens *int
addinput_tokens *int
output_tokens *int
@@ -25443,6 +25531,55 @@ func (m *UsageLogMutation) ResetBillingMode() {
delete(m.clearedFields, usagelog.FieldBillingMode)
}
+// SetMediaType sets the "media_type" field.
+func (m *UsageLogMutation) SetMediaType(s string) {
+ m.media_type = &s
+}
+
+// MediaType returns the value of the "media_type" field in the mutation.
+func (m *UsageLogMutation) MediaType() (r string, exists bool) {
+ v := m.media_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMediaType returns the old "media_type" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMediaType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMediaType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMediaType: %w", err)
+ }
+ return oldValue.MediaType, nil
+}
+
+// ClearMediaType clears the value of the "media_type" field.
+func (m *UsageLogMutation) ClearMediaType() {
+ m.media_type = nil
+ m.clearedFields[usagelog.FieldMediaType] = struct{}{}
+}
+
+// MediaTypeCleared returns if the "media_type" field was cleared in this mutation.
+func (m *UsageLogMutation) MediaTypeCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldMediaType]
+ return ok
+}
+
+// ResetMediaType resets all changes to the "media_type" field.
+func (m *UsageLogMutation) ResetMediaType() {
+ m.media_type = nil
+ delete(m.clearedFields, usagelog.FieldMediaType)
+}
+
// SetGroupID sets the "group_id" field.
func (m *UsageLogMutation) SetGroupID(i int64) {
m.group = &i
@@ -27015,7 +27152,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
- fields := make([]string, 0, 37)
+ fields := make([]string, 0, 38)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -27049,6 +27186,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.billing_mode != nil {
fields = append(fields, usagelog.FieldBillingMode)
}
+ if m.media_type != nil {
+ fields = append(fields, usagelog.FieldMediaType)
+ }
if m.group != nil {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -27157,6 +27297,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.BillingTier()
case usagelog.FieldBillingMode:
return m.BillingMode()
+ case usagelog.FieldMediaType:
+ return m.MediaType()
case usagelog.FieldGroupID:
return m.GroupID()
case usagelog.FieldSubscriptionID:
@@ -27240,6 +27382,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldBillingTier(ctx)
case usagelog.FieldBillingMode:
return m.OldBillingMode(ctx)
+ case usagelog.FieldMediaType:
+ return m.OldMediaType(ctx)
case usagelog.FieldGroupID:
return m.OldGroupID(ctx)
case usagelog.FieldSubscriptionID:
@@ -27378,6 +27522,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetBillingMode(v)
return nil
+ case usagelog.FieldMediaType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMediaType(v)
+ return nil
case usagelog.FieldGroupID:
v, ok := value.(int64)
if !ok {
@@ -27839,6 +27990,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldBillingMode) {
fields = append(fields, usagelog.FieldBillingMode)
}
+ if m.FieldCleared(usagelog.FieldMediaType) {
+ fields = append(fields, usagelog.FieldMediaType)
+ }
if m.FieldCleared(usagelog.FieldGroupID) {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -27895,6 +28049,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldBillingMode:
m.ClearBillingMode()
return nil
+ case usagelog.FieldMediaType:
+ m.ClearMediaType()
+ return nil
case usagelog.FieldGroupID:
m.ClearGroupID()
return nil
@@ -27960,6 +28117,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldBillingMode:
m.ResetBillingMode()
return nil
+ case usagelog.FieldMediaType:
+ m.ResetMediaType()
+ return nil
case usagelog.FieldGroupID:
m.ResetGroupID()
return nil
@@ -28210,6 +28370,10 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
+ sora_storage_quota_bytes *int64
+ addsora_storage_quota_bytes *int64
+ sora_storage_used_bytes *int64
+ addsora_storage_used_bytes *int64
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -28927,6 +29091,118 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) {
+ m.sora_storage_quota_bytes = &i
+ m.addsora_storage_quota_bytes = nil
+}
+
+// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
+func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
+ v := m.sora_storage_quota_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
+ }
+ return oldValue.SoraStorageQuotaBytes, nil
+}
+
+// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
+func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) {
+ if m.addsora_storage_quota_bytes != nil {
+ *m.addsora_storage_quota_bytes += i
+ } else {
+ m.addsora_storage_quota_bytes = &i
+ }
+}
+
+// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
+func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
+ v := m.addsora_storage_quota_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
+func (m *UserMutation) ResetSoraStorageQuotaBytes() {
+ m.sora_storage_quota_bytes = nil
+ m.addsora_storage_quota_bytes = nil
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (m *UserMutation) SetSoraStorageUsedBytes(i int64) {
+ m.sora_storage_used_bytes = &i
+ m.addsora_storage_used_bytes = nil
+}
+
+// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation.
+func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) {
+ v := m.sora_storage_used_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err)
+ }
+ return oldValue.SoraStorageUsedBytes, nil
+}
+
+// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field.
+func (m *UserMutation) AddSoraStorageUsedBytes(i int64) {
+ if m.addsora_storage_used_bytes != nil {
+ *m.addsora_storage_used_bytes += i
+ } else {
+ m.addsora_storage_used_bytes = &i
+ }
+}
+
+// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation.
+func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) {
+ v := m.addsora_storage_used_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field.
+func (m *UserMutation) ResetSoraStorageUsedBytes() {
+ m.sora_storage_used_bytes = nil
+ m.addsora_storage_used_bytes = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -29501,7 +29777,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 14)
+ fields := make([]string, 0, 16)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -29544,6 +29820,12 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.sora_storage_quota_bytes != nil {
+ fields = append(fields, user.FieldSoraStorageQuotaBytes)
+ }
+ if m.sora_storage_used_bytes != nil {
+ fields = append(fields, user.FieldSoraStorageUsedBytes)
+ }
return fields
}
@@ -29580,6 +29862,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
+ case user.FieldSoraStorageQuotaBytes:
+ return m.SoraStorageQuotaBytes()
+ case user.FieldSoraStorageUsedBytes:
+ return m.SoraStorageUsedBytes()
}
return nil, false
}
@@ -29617,6 +29903,10 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
+ case user.FieldSoraStorageQuotaBytes:
+ return m.OldSoraStorageQuotaBytes(ctx)
+ case user.FieldSoraStorageUsedBytes:
+ return m.OldSoraStorageUsedBytes(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -29724,6 +30014,20 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
+ case user.FieldSoraStorageQuotaBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSoraStorageQuotaBytes(v)
+ return nil
+ case user.FieldSoraStorageUsedBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSoraStorageUsedBytes(v)
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -29738,6 +30042,12 @@ func (m *UserMutation) AddedFields() []string {
if m.addconcurrency != nil {
fields = append(fields, user.FieldConcurrency)
}
+ if m.addsora_storage_quota_bytes != nil {
+ fields = append(fields, user.FieldSoraStorageQuotaBytes)
+ }
+ if m.addsora_storage_used_bytes != nil {
+ fields = append(fields, user.FieldSoraStorageUsedBytes)
+ }
return fields
}
@@ -29750,6 +30060,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedBalance()
case user.FieldConcurrency:
return m.AddedConcurrency()
+ case user.FieldSoraStorageQuotaBytes:
+ return m.AddedSoraStorageQuotaBytes()
+ case user.FieldSoraStorageUsedBytes:
+ return m.AddedSoraStorageUsedBytes()
}
return nil, false
}
@@ -29773,6 +30087,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
}
m.AddConcurrency(v)
return nil
+ case user.FieldSoraStorageQuotaBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSoraStorageQuotaBytes(v)
+ return nil
+ case user.FieldSoraStorageUsedBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSoraStorageUsedBytes(v)
+ return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
}
@@ -29863,6 +30191,12 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
+ case user.FieldSoraStorageQuotaBytes:
+ m.ResetSoraStorageQuotaBytes()
+ return nil
+ case user.FieldSoraStorageUsedBytes:
+ m.ResetSoraStorageUsedBytes()
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 821b7d66..352d14af 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -477,6 +477,10 @@ func init() {
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
+ // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
+ groupDescSoraStorageQuotaBytes := groupFields[27].Descriptor()
+ // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
+ group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
_ = idempotencyrecordMixinFields0
@@ -1120,88 +1124,92 @@ func init() {
usagelogDescBillingMode := usagelogFields[10].Descriptor()
// usagelog.BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
usagelog.BillingModeValidator = usagelogDescBillingMode.Validators[0].(func(string) error)
+ // usagelogDescMediaType is the schema descriptor for media_type field.
+ usagelogDescMediaType := usagelogFields[11].Descriptor()
+ // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
+ usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
- usagelogDescInputTokens := usagelogFields[13].Descriptor()
+ usagelogDescInputTokens := usagelogFields[14].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
- usagelogDescOutputTokens := usagelogFields[14].Descriptor()
+ usagelogDescOutputTokens := usagelogFields[15].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
- usagelogDescCacheCreationTokens := usagelogFields[15].Descriptor()
+ usagelogDescCacheCreationTokens := usagelogFields[16].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
- usagelogDescCacheReadTokens := usagelogFields[16].Descriptor()
+ usagelogDescCacheReadTokens := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
- usagelogDescCacheCreation5mTokens := usagelogFields[17].Descriptor()
+ usagelogDescCacheCreation5mTokens := usagelogFields[18].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
- usagelogDescCacheCreation1hTokens := usagelogFields[18].Descriptor()
+ usagelogDescCacheCreation1hTokens := usagelogFields[19].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field.
- usagelogDescInputCost := usagelogFields[19].Descriptor()
+ usagelogDescInputCost := usagelogFields[20].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field.
- usagelogDescOutputCost := usagelogFields[20].Descriptor()
+ usagelogDescOutputCost := usagelogFields[21].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
- usagelogDescCacheCreationCost := usagelogFields[21].Descriptor()
+ usagelogDescCacheCreationCost := usagelogFields[22].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
- usagelogDescCacheReadCost := usagelogFields[22].Descriptor()
+ usagelogDescCacheReadCost := usagelogFields[23].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field.
- usagelogDescTotalCost := usagelogFields[23].Descriptor()
+ usagelogDescTotalCost := usagelogFields[24].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field.
- usagelogDescActualCost := usagelogFields[24].Descriptor()
+ usagelogDescActualCost := usagelogFields[25].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
- usagelogDescRateMultiplier := usagelogFields[25].Descriptor()
+ usagelogDescRateMultiplier := usagelogFields[26].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
- usagelogDescBillingType := usagelogFields[27].Descriptor()
+ usagelogDescBillingType := usagelogFields[28].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
- usagelogDescStream := usagelogFields[28].Descriptor()
+ usagelogDescStream := usagelogFields[29].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
- usagelogDescUserAgent := usagelogFields[31].Descriptor()
+ usagelogDescUserAgent := usagelogFields[32].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
- usagelogDescIPAddress := usagelogFields[32].Descriptor()
+ usagelogDescIPAddress := usagelogFields[33].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
- usagelogDescImageCount := usagelogFields[33].Descriptor()
+ usagelogDescImageCount := usagelogFields[34].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
- usagelogDescImageSize := usagelogFields[34].Descriptor()
+ usagelogDescImageSize := usagelogFields[35].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
- usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
+ usagelogDescCacheTTLOverridden := usagelogFields[36].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
- usagelogDescCreatedAt := usagelogFields[36].Descriptor()
+ usagelogDescCreatedAt := usagelogFields[37].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()
@@ -1293,6 +1301,14 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
+ // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
+ userDescSoraStorageQuotaBytes := userFields[11].Descriptor()
+ // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
+ user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64)
+ // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field.
+ userDescSoraStorageUsedBytes := userFields[12].Descriptor()
+ // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field.
+ user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go
index d78a6898..82487990 100644
--- a/backend/ent/schema/group.go
+++ b/backend/ent/schema/group.go
@@ -145,6 +145,10 @@ func (Group) Fields() []ent.Field {
Default(domain.OpenAIMessagesDispatchModelConfig{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
+
+ // Sora 存储配额 (从本地版本合并)
+ field.Int64("sora_storage_quota_bytes").
+ Default(0),
}
}
diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go
index bd3ebfcc..867fb7e3 100644
--- a/backend/ent/schema/usage_log.go
+++ b/backend/ent/schema/usage_log.go
@@ -57,6 +57,7 @@ func (UsageLog) Fields() []ent.Field {
field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"),
field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"),
field.String("billing_mode").MaxLen(20).Optional().Nillable().Comment("计费模式:token/per_request/image"),
+ field.String("media_type").MaxLen(16).Optional().Nillable().Comment("媒体类型:video/image(Sora生成)"),
field.Int64("group_id").
Optional().
Nillable(),
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index af143d38..e5b5a83b 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -72,6 +72,12 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
+
+ // Sora 存储配额 (从本地版本合并)
+ field.Int64("sora_storage_quota_bytes").
+ Default(0),
+ field.Int64("sora_storage_used_bytes").
+ Default(0),
}
}
diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go
index a8e0cc6c..ad5680c4 100644
--- a/backend/ent/usagelog.go
+++ b/backend/ent/usagelog.go
@@ -44,6 +44,8 @@ type UsageLog struct {
BillingTier *string `json:"billing_tier,omitempty"`
// 计费模式:token/per_request/image
BillingMode *string `json:"billing_mode,omitempty"`
+ // 媒体类型:video/image(Sora生成)
+ MediaType *string `json:"media_type,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// SubscriptionID holds the value of the "subscription_id" field.
@@ -185,7 +187,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
- case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
+ case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldMediaType, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -282,6 +284,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.BillingMode = new(string)
*_m.BillingMode = value.String
}
+ case usagelog.FieldMediaType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field media_type", values[i])
+ } else if value.Valid {
+ _m.MediaType = new(string)
+ *_m.MediaType = value.String
+ }
case usagelog.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
@@ -552,6 +561,11 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
+ if v := _m.MediaType; v != nil {
+ builder.WriteString("media_type=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go
index a7438e60..5f3ca0e2 100644
--- a/backend/ent/usagelog/usagelog.go
+++ b/backend/ent/usagelog/usagelog.go
@@ -36,6 +36,8 @@ const (
FieldBillingTier = "billing_tier"
// FieldBillingMode holds the string denoting the billing_mode field in the database.
FieldBillingMode = "billing_mode"
+ // FieldMediaType holds the string denoting the media_type field in the database.
+ FieldMediaType = "media_type"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
@@ -151,6 +153,7 @@ var Columns = []string{
FieldModelMappingChain,
FieldBillingTier,
FieldBillingMode,
+ FieldMediaType,
FieldGroupID,
FieldSubscriptionID,
FieldInputTokens,
@@ -204,6 +207,8 @@ var (
BillingTierValidator func(string) error
// BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
BillingModeValidator func(string) error
+ // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
+ MediaTypeValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
DefaultInputTokens int
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
@@ -311,6 +316,11 @@ func ByBillingMode(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingMode, opts...).ToFunc()
}
+// ByMediaType orders the results by the media_type field.
+func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldMediaType, opts...).ToFunc()
+}
+
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go
index b8439a03..ecbd4861 100644
--- a/backend/ent/usagelog/where.go
+++ b/backend/ent/usagelog/where.go
@@ -110,6 +110,11 @@ func BillingMode(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
}
+// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ.
+func MediaType(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
+}
+
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
@@ -855,6 +860,81 @@ func BillingModeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldBillingMode, v))
}
+// MediaTypeEQ applies the EQ predicate on the "media_type" field.
+func MediaTypeEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
+}
+
+// MediaTypeNEQ applies the NEQ predicate on the "media_type" field.
+func MediaTypeNEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v))
+}
+
+// MediaTypeIn applies the In predicate on the "media_type" field.
+func MediaTypeIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...))
+}
+
+// MediaTypeNotIn applies the NotIn predicate on the "media_type" field.
+func MediaTypeNotIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...))
+}
+
+// MediaTypeGT applies the GT predicate on the "media_type" field.
+func MediaTypeGT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGT(FieldMediaType, v))
+}
+
+// MediaTypeGTE applies the GTE predicate on the "media_type" field.
+func MediaTypeGTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v))
+}
+
+// MediaTypeLT applies the LT predicate on the "media_type" field.
+func MediaTypeLT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLT(FieldMediaType, v))
+}
+
+// MediaTypeLTE applies the LTE predicate on the "media_type" field.
+func MediaTypeLTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v))
+}
+
+// MediaTypeContains applies the Contains predicate on the "media_type" field.
+func MediaTypeContains(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContains(FieldMediaType, v))
+}
+
+// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field.
+func MediaTypeHasPrefix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v))
+}
+
+// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field.
+func MediaTypeHasSuffix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v))
+}
+
+// MediaTypeIsNil applies the IsNil predicate on the "media_type" field.
+func MediaTypeIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldMediaType))
+}
+
+// MediaTypeNotNil applies the NotNil predicate on the "media_type" field.
+func MediaTypeNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldMediaType))
+}
+
+// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field.
+func MediaTypeEqualFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v))
+}
+
+// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field.
+func MediaTypeContainsFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
+}
+
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go
index fded364e..97a1ceeb 100644
--- a/backend/ent/usagelog_create.go
+++ b/backend/ent/usagelog_create.go
@@ -141,6 +141,20 @@ func (_c *UsageLogCreate) SetNillableBillingMode(v *string) *UsageLogCreate {
return _c
}
+// SetMediaType sets the "media_type" field.
+func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate {
+ _c.mutation.SetMediaType(v)
+ return _c
+}
+
+// SetNillableMediaType sets the "media_type" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
+ if v != nil {
+ _c.SetMediaType(*v)
+ }
+ return _c
+}
+
// SetGroupID sets the "group_id" field.
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
_c.mutation.SetGroupID(v)
@@ -691,6 +705,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
}
}
+ if v, ok := _c.mutation.MediaType(); ok {
+ if err := usagelog.MediaTypeValidator(v); err != nil {
+ return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.InputTokens(); !ok {
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
}
@@ -828,6 +847,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
_node.BillingMode = &value
}
+ if value, ok := _c.mutation.MediaType(); ok {
+ _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
+ _node.MediaType = &value
+ }
if value, ok := _c.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
_node.InputTokens = value
@@ -1235,6 +1258,24 @@ func (u *UsageLogUpsert) ClearBillingMode() *UsageLogUpsert {
return u
}
+// SetMediaType sets the "media_type" field.
+func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert {
+ u.Set(usagelog.FieldMediaType, v)
+ return u
+}
+
+// UpdateMediaType sets the "media_type" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldMediaType)
+ return u
+}
+
+// ClearMediaType clears the value of the "media_type" field.
+func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldMediaType)
+ return u
+}
+
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldGroupID, v)
@@ -1939,6 +1980,27 @@ func (u *UsageLogUpsertOne) ClearBillingMode() *UsageLogUpsertOne {
})
}
+// SetMediaType sets the "media_type" field.
+func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetMediaType(v)
+ })
+}
+
+// UpdateMediaType sets the "media_type" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateMediaType()
+ })
+}
+
+// ClearMediaType clears the value of the "media_type" field.
+func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearMediaType()
+ })
+}
+
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2885,6 +2947,27 @@ func (u *UsageLogUpsertBulk) ClearBillingMode() *UsageLogUpsertBulk {
})
}
+// SetMediaType sets the "media_type" field.
+func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetMediaType(v)
+ })
+}
+
+// UpdateMediaType sets the "media_type" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateMediaType()
+ })
+}
+
+// ClearMediaType clears the value of the "media_type" field.
+func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearMediaType()
+ })
+}
+
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go
index bb5ac86c..1dd92420 100644
--- a/backend/ent/usagelog_update.go
+++ b/backend/ent/usagelog_update.go
@@ -229,6 +229,26 @@ func (_u *UsageLogUpdate) ClearBillingMode() *UsageLogUpdate {
return _u
}
+// SetMediaType sets the "media_type" field.
+func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate {
+ _u.mutation.SetMediaType(v)
+ return _u
+}
+
+// SetNillableMediaType sets the "media_type" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate {
+ if v != nil {
+ _u.SetMediaType(*v)
+ }
+ return _u
+}
+
+// ClearMediaType clears the value of the "media_type" field.
+func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
+ _u.mutation.ClearMediaType()
+ return _u
+}
+
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
_u.mutation.SetGroupID(v)
@@ -877,6 +897,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
}
}
+ if v, ok := _u.mutation.MediaType(); ok {
+ if err := usagelog.MediaTypeValidator(v); err != nil {
+ return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -961,6 +986,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.BillingModeCleared() {
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
}
+ if value, ok := _u.mutation.MediaType(); ok {
+ _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
+ }
+ if _u.mutation.MediaTypeCleared() {
+ _spec.ClearField(usagelog.FieldMediaType, field.TypeString)
+ }
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}
@@ -1464,6 +1495,26 @@ func (_u *UsageLogUpdateOne) ClearBillingMode() *UsageLogUpdateOne {
return _u
}
+// SetMediaType sets the "media_type" field.
+func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne {
+ _u.mutation.SetMediaType(v)
+ return _u
+}
+
+// SetNillableMediaType sets the "media_type" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetMediaType(*v)
+ }
+ return _u
+}
+
+// ClearMediaType clears the value of the "media_type" field.
+func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
+ _u.mutation.ClearMediaType()
+ return _u
+}
+
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
_u.mutation.SetGroupID(v)
@@ -2125,6 +2176,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
}
}
+ if v, ok := _u.mutation.MediaType(); ok {
+ if err := usagelog.MediaTypeValidator(v); err != nil {
+ return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -2226,6 +2282,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.BillingModeCleared() {
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
}
+ if value, ok := _u.mutation.MediaType(); ok {
+ _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
+ }
+ if _u.mutation.MediaTypeCleared() {
+ _spec.ClearField(usagelog.FieldMediaType, field.TypeString)
+ }
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}
diff --git a/backend/ent/user.go b/backend/ent/user.go
index a0eef2ba..72a1ad7c 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,6 +45,10 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
+ // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
+ // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field.
+ SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -188,7 +192,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
- case user.FieldID, user.FieldConcurrency:
+ case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
@@ -302,6 +306,18 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
+ case user.FieldSoraStorageQuotaBytes:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
+ } else if value.Valid {
+ _m.SoraStorageQuotaBytes = value.Int64
+ }
+ case user.FieldSoraStorageUsedBytes:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i])
+ } else if value.Valid {
+ _m.SoraStorageUsedBytes = value.Int64
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -440,6 +456,12 @@ func (_m *User) String() string {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
+ builder.WriteString(", ")
+ builder.WriteString("sora_storage_quota_bytes=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
+ builder.WriteString(", ")
+ builder.WriteString("sora_storage_used_bytes=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index 338518a8..affc4b53 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,6 +43,10 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
+ // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
+ FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
+ // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database.
+ FieldSoraStorageUsedBytes = "sora_storage_used_bytes"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -161,6 +165,8 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
+ FieldSoraStorageQuotaBytes,
+ FieldSoraStorageUsedBytes,
}
var (
@@ -217,6 +223,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
+ // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
+ DefaultSoraStorageQuotaBytes int64
+ // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field.
+ DefaultSoraStorageUsedBytes int64
)
// OrderOption defines the ordering options for the User queries.
@@ -297,6 +307,16 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
+// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
+func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
+}
+
+// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field.
+func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index b1d1000f..90b5fefb 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,6 +125,16 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
+// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
+func SoraStorageQuotaBytes(v int64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ.
+func SoraStorageUsedBytes(v int64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -860,6 +870,86 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
+// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesEQ(v int64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesNEQ(v int64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesIn(vs ...int64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
+}
+
+// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
+}
+
+// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesGT(v int64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesGTE(v int64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesLT(v int64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesLTE(v int64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesEQ(v int64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesNEQ(v int64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesIn(vs ...int64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...))
+}
+
+// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...))
+}
+
+// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesGT(v int64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesGTE(v int64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesLT(v int64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesLTE(v int64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index 7f1c5df1..bc19e7ef 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -211,6 +211,34 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate {
+ _c.mutation.SetSoraStorageQuotaBytes(v)
+ return _c
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate {
+ if v != nil {
+ _c.SetSoraStorageQuotaBytes(*v)
+ }
+ return _c
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate {
+ _c.mutation.SetSoraStorageUsedBytes(v)
+ return _c
+}
+
+// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
+func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate {
+ if v != nil {
+ _c.SetSoraStorageUsedBytes(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -440,6 +468,14 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
+ if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
+ v := user.DefaultSoraStorageQuotaBytes
+ _c.mutation.SetSoraStorageQuotaBytes(v)
+ }
+ if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
+ v := user.DefaultSoraStorageUsedBytes
+ _c.mutation.SetSoraStorageUsedBytes(v)
+ }
return nil
}
@@ -503,6 +539,12 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
+ if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
+ return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)}
+ }
+ if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
+ return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)}
+ }
return nil
}
@@ -586,6 +628,14 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
+ if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ _node.SoraStorageQuotaBytes = value
+ }
+ if value, ok := _c.mutation.SoraStorageUsedBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ _node.SoraStorageUsedBytes = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -988,6 +1038,42 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert {
+ u.Set(user.FieldSoraStorageQuotaBytes, v)
+ return u
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert {
+ u.SetExcluded(user.FieldSoraStorageQuotaBytes)
+ return u
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert {
+ u.Add(user.FieldSoraStorageQuotaBytes, v)
+ return u
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert {
+ u.Set(user.FieldSoraStorageUsedBytes, v)
+ return u
+}
+
+// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
+func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert {
+ u.SetExcluded(user.FieldSoraStorageUsedBytes)
+ return u
+}
+
+// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
+func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert {
+ u.Add(user.FieldSoraStorageUsedBytes, v)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1250,6 +1336,48 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSoraStorageQuotaBytes(v)
+ })
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddSoraStorageQuotaBytes(v)
+ })
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSoraStorageQuotaBytes()
+ })
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSoraStorageUsedBytes(v)
+ })
+}
+
+// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
+func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddSoraStorageUsedBytes(v)
+ })
+}
+
+// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSoraStorageUsedBytes()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1678,6 +1806,48 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSoraStorageQuotaBytes(v)
+ })
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddSoraStorageQuotaBytes(v)
+ })
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSoraStorageQuotaBytes()
+ })
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSoraStorageUsedBytes(v)
+ })
+}
+
+// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
+func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddSoraStorageUsedBytes(v)
+ })
+}
+
+// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSoraStorageUsedBytes()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index 8107c980..87758a59 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -243,6 +243,48 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate {
+ _u.mutation.ResetSoraStorageQuotaBytes()
+ _u.mutation.SetSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate {
+ if v != nil {
+ _u.SetSoraStorageQuotaBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
+func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate {
+ _u.mutation.AddSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate {
+ _u.mutation.ResetSoraStorageUsedBytes()
+ _u.mutation.SetSoraStorageUsedBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate {
+ if v != nil {
+ _u.SetSoraStorageUsedBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
+func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate {
+ _u.mutation.AddSoraStorageUsedBytes(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -746,6 +788,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
+ _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
+ _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1434,6 +1488,48 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne {
+ _u.mutation.ResetSoraStorageQuotaBytes()
+ _u.mutation.SetSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne {
+ if v != nil {
+ _u.SetSoraStorageQuotaBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
+func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne {
+ _u.mutation.AddSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne {
+ _u.mutation.ResetSoraStorageUsedBytes()
+ _u.mutation.SetSoraStorageUsedBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne {
+ if v != nil {
+ _u.SetSoraStorageUsedBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
+func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne {
+ _u.mutation.AddSoraStorageUsedBytes(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1967,6 +2063,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
+ _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
+ _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/go.mod b/backend/go.mod
index 66b6cc25..968c2d95 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -5,6 +5,7 @@ go 1.26.2
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
+ github.com/DouDOU-start/go-sora2api v1.1.0 // 从本地版本合并,Sora SDK依赖
github.com/alitto/pond/v2 v2.6.2
github.com/andybalholm/brotli v1.2.0
github.com/aws/aws-sdk-go-v2 v1.41.3
@@ -70,7 +71,14 @@ require (
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
+ github.com/bdandy/go-errors v1.2.2 // indirect
+ github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
+ github.com/bogdanfinn/fhttp v0.6.8 // indirect
+ github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect
+ github.com/bogdanfinn/tls-client v1.14.0 // indirect
+ github.com/bogdanfinn/utls v1.7.7-barnius // indirect
+ github.com/bogdanfinn/websocket v1.5.5-barnius // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
@@ -151,6 +159,7 @@ require (
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
+ github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index e4496f2c..863e7ebb 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -10,6 +10,8 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
+github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU=
+github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
@@ -60,10 +62,24 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
+github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
+github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
+github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
+github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
+github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4=
+github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M=
+github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s=
+github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg=
+github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A=
+github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM=
+github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU=
+github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg=
+github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI=
+github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -80,6 +96,10 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
+github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
+github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
+github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
+github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
@@ -183,6 +203,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
+github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
+github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -218,6 +240,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
+github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -251,6 +275,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
+github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
+github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -312,6 +338,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
+github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
+github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
@@ -335,6 +363,8 @@ github.com/stripe/stripe-go/v85 v85.0.0 h1:HMlFJXW6I/9WvkeSAtj8V7dI5pzeDu4gS1Taq
github.com/stripe/stripe-go/v85 v85.0.0/go.mod h1:5P+HGFenpWgak27T5Is6JMsmDfUC1yJnjhhmquz7kXw=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
+github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
+github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
@@ -409,12 +439,15 @@ golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
+golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -424,12 +457,15 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index bc4e5e46..c1245c1e 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -80,6 +80,7 @@ type Config struct {
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
+ Sora SoraConfig `mapstructure:"sora"` // 从本地版本合并
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
@@ -139,6 +140,65 @@ type GeminiTierQuotaConfig struct {
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
+// SoraConfig 直连 Sora 配置 (从本地版本合并)
+type SoraConfig struct {
+ Client SoraClientConfig `mapstructure:"client"`
+ Storage SoraStorageConfig `mapstructure:"storage"`
+}
+
+// SoraClientConfig 直连 Sora 客户端配置 (从本地版本合并)
+type SoraClientConfig struct {
+ BaseURL string `mapstructure:"base_url"`
+ TimeoutSeconds int `mapstructure:"timeout_seconds"`
+ MaxRetries int `mapstructure:"max_retries"`
+ CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
+ PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
+ MaxPollAttempts int `mapstructure:"max_poll_attempts"`
+ RecentTaskLimit int `mapstructure:"recent_task_limit"`
+ RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
+ Debug bool `mapstructure:"debug"`
+ UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
+ Headers map[string]string `mapstructure:"headers"`
+ UserAgent string `mapstructure:"user_agent"`
+ DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
+ CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
+}
+
+// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置 (从本地版本合并)
+type SoraCurlCFFISidecarConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ BaseURL string `mapstructure:"base_url"`
+ Impersonate string `mapstructure:"impersonate"`
+ TimeoutSeconds int `mapstructure:"timeout_seconds"`
+ SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
+ SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
+}
+
+// SoraStorageConfig 媒体存储配置 (从本地版本合并)
+type SoraStorageConfig struct {
+ Type string `mapstructure:"type"`
+ LocalPath string `mapstructure:"local_path"`
+ FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
+ MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
+ DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
+ MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
+ Debug bool `mapstructure:"debug"`
+ Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
+}
+
+// SoraStorageCleanupConfig 媒体清理配置 (从本地版本合并)
+type SoraStorageCleanupConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ Schedule string `mapstructure:"schedule"`
+ RetentionDays int `mapstructure:"retention_days"`
+}
+
+// SoraModelFiltersConfig Sora 模型过滤配置 (从本地版本合并)
+type SoraModelFiltersConfig struct {
+ // HidePromptEnhance 是否隐藏 prompt-enhance 模型
+ HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
+}
+
type UpdateConfig struct {
// ProxyURL 用于访问 GitHub 的代理地址
// 支持 http/https/socks5/socks5h 协议
@@ -403,6 +463,24 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"`
+ // Sora 专用配置 (从本地版本合并)
+ // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size)
+ SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
+ // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制)
+ SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
+ // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制)
+ SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
+ // SoraStreamMode: stream 强制策略(force/error)
+ SoraStreamMode string `mapstructure:"sora_stream_mode"`
+ // SoraModelFilters: 模型列表过滤配置
+ SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
+ // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
+ SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
+ // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
+ SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
+ // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
+ SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
+
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
@@ -1152,14 +1230,46 @@ func setDefaults() {
// Security
viper.SetDefault("security.url_allowlist.enabled", false)
viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
+ // 国际模型
"api.openai.com",
"api.anthropic.com",
- "api.kimi.com",
- "open.bigmodel.cn",
- "api.minimaxi.com",
"generativelanguage.googleapis.com",
"cloudcode-pa.googleapis.com",
"*.openai.azure.com",
+ // 国内模型 - 月之暗面Kimi
+ "api.kimi.com",
+ "api.moonshot.cn",
+ // 国内模型 - 智谱GLM
+ "open.bigmodel.cn",
+ "bigmodel.cn",
+ // 国内模型 - MiniMax
+ "api.minimaxi.com",
+ "minimaxi.com",
+ // 国内模型 - 阿里云通义千问
+ "dashscope.aliyuncs.com",
+ "dashscope.aliyun.com",
+ // 国内模型 - 豆包/火山引擎
+ "ark.cn-beijing.volces.com",
+ "ark-api.volces.com",
+ "api.volcengine.com",
+ // 国内模型 - DeepSeek
+ "api.deepseek.com",
+ // 国内模型 - 百度文心
+ "aip.baidubce.com",
+ // 国内模型 - 讯飞星火
+ "spark-api-open.xf-yun.com",
+ // 国内模型 - 腾讯混元
+ "hunyuan.tencentcloudapi.com",
+ // 国内模型 - 零一万物
+ "api.lingyiwanwu.com",
+ // 国内模型 - 百川智能
+ "api.baichuan-ai.com",
+ // 国内模型 - 硅基流动SiliconFlow
+ "api.siliconflow.cn",
+ // 国内模型 - 智谱API域名(国际)
+ "api.z.ai",
+ // 国内模型 - Groq (加速推理)
+ "api.groq.com",
})
viper.SetDefault("security.url_allowlist.pricing_hosts", []string{
"raw.githubusercontent.com",
diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go
index 429486c3..885ef834 100644
--- a/backend/internal/domain/constants.go
+++ b/backend/internal/domain/constants.go
@@ -22,6 +22,7 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
+ PlatformSora = "sora" // Sora视频生成平台 (从本地版本合并)
)
// Account type constants
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index 906a74f1..adf709b5 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -46,6 +46,8 @@ type Handlers struct {
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
+ SoraGateway *SoraGatewayHandler // 从本地版本合并
+ SoraClient *SoraClientHandler // 从本地版本合并
Setting *SettingHandler
Totp *TotpHandler
Payment *PaymentHandler
diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go
new file mode 100644
index 00000000..80acc833
--- /dev/null
+++ b/backend/internal/handler/sora_client_handler.go
@@ -0,0 +1,979 @@
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ // 上游模型缓存 TTL
+ modelCacheTTL = 1 * time.Hour // 上游获取成功
+ modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
+)
+
+// SoraClientHandler 处理 Sora 客户端 API 请求。
+type SoraClientHandler struct {
+ genService *service.SoraGenerationService
+ quotaService *service.SoraQuotaService
+ s3Storage *service.SoraS3Storage
+ soraGatewayService *service.SoraGatewayService
+ gatewayService *service.GatewayService
+ mediaStorage *service.SoraMediaStorage
+ apiKeyService *service.APIKeyService
+
+ // 上游模型缓存
+ modelCacheMu sync.RWMutex
+ cachedFamilies []service.SoraModelFamily
+ modelCacheTime time.Time
+ modelCacheUpstream bool // 是否来自上游(决定 TTL)
+}
+
+// NewSoraClientHandler 创建 Sora 客户端 Handler。
+func NewSoraClientHandler(
+ genService *service.SoraGenerationService,
+ quotaService *service.SoraQuotaService,
+ s3Storage *service.SoraS3Storage,
+ soraGatewayService *service.SoraGatewayService,
+ gatewayService *service.GatewayService,
+ mediaStorage *service.SoraMediaStorage,
+ apiKeyService *service.APIKeyService,
+) *SoraClientHandler {
+ return &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ s3Storage: s3Storage,
+ soraGatewayService: soraGatewayService,
+ gatewayService: gatewayService,
+ mediaStorage: mediaStorage,
+ apiKeyService: apiKeyService,
+ }
+}
+
+// GenerateRequest 生成请求。
+type GenerateRequest struct {
+ Model string `json:"model" binding:"required"`
+ Prompt string `json:"prompt" binding:"required"`
+ MediaType string `json:"media_type"` // video / image,默认 video
+ VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
+ ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
+ APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
+}
+
+// Generate 异步生成 — 创建 pending 记录后立即返回。
+// POST /api/v1/sora/generate
+func (h *SoraClientHandler) Generate(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ var req GenerateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
+ return
+ }
+
+ if req.MediaType == "" {
+ req.MediaType = "video"
+ }
+ req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
+
+ // 并发数检查(最多 3 个)
+ activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if activeCount >= 3 {
+ response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
+ return
+ }
+
+ // 配额检查(粗略检查,实际文件大小在上传后才知道)
+ if h.quotaService != nil {
+ if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
+ var quotaErr *service.QuotaExceededError
+ if errors.As(err, "aErr) {
+ response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
+ return
+ }
+ response.Error(c, http.StatusForbidden, err.Error())
+ return
+ }
+ }
+
+ // 获取 API Key ID 和 Group ID
+ var apiKeyID *int64
+ var groupID *int64
+
+ if req.APIKeyID != nil && h.apiKeyService != nil {
+ // 前端传递了 api_key_id,需要校验
+ apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "API Key 不存在")
+ return
+ }
+ if apiKey.UserID != userID {
+ response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
+ return
+ }
+ if apiKey.Status != service.StatusAPIKeyActive {
+ response.Error(c, http.StatusForbidden, "API Key 不可用")
+ return
+ }
+ apiKeyID = &apiKey.ID
+ groupID = apiKey.GroupID
+ } else if id, ok := c.Get("api_key_id"); ok {
+ // 兼容 API Key 认证路径(/sora/v1/ 网关路由)
+ if v, ok := id.(int64); ok {
+ apiKeyID = &v
+ }
+ }
+
+ gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
+ if err != nil {
+ if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
+ response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
+ return
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // 启动后台异步生成 goroutine
+ go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
+
+ response.Success(c, gin.H{
+ "generation_id": gen.ID,
+ "status": gen.Status,
+ })
+}
+
+// processGeneration 后台异步执行 Sora 生成任务。
+// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
+func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+ defer cancel()
+
+ // 标记为生成中
+ if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
+ if errors.Is(err, service.ErrSoraGenerationStateConflict) {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
+ return
+ }
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
+ return
+ }
+
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
+ genID,
+ userID,
+ groupIDForLog(groupID),
+ model,
+ mediaType,
+ videoCount,
+ strings.TrimSpace(imageInput) != "",
+ len(strings.TrimSpace(prompt)),
+ )
+
+ // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
+ if groupID == nil {
+ ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
+ }
+
+ if h.gatewayService == nil {
+ _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
+ return
+ }
+
+ // 选择 Sora 账号
+ account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
+ if err != nil {
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
+ genID,
+ userID,
+ groupIDForLog(groupID),
+ model,
+ err,
+ )
+ _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
+ return
+ }
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
+ genID,
+ userID,
+ groupIDForLog(groupID),
+ model,
+ account.ID,
+ account.Name,
+ account.Platform,
+ account.Type,
+ )
+
+ // 构建 chat completions 请求体(非流式)
+ body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
+
+ if h.soraGatewayService == nil {
+ _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
+ return
+ }
+
+ // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
+ recorder := httptest.NewRecorder()
+ mockGinCtx, _ := gin.CreateTestContext(recorder)
+ mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
+
+ // 调用 Forward(非流式)
+ result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
+ if err != nil {
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
+ genID,
+ account.ID,
+ model,
+ recorder.Code,
+ trimForLog(recorder.Body.String(), 400),
+ err,
+ )
+ // 检查是否已取消
+ gen, _ := h.genService.GetByID(ctx, genID, userID)
+ if gen != nil && gen.Status == service.SoraGenStatusCancelled {
+ return
+ }
+ _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
+ return
+ }
+
+ // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
+ mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
+ if mediaURL == "" {
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
+ genID,
+ account.ID,
+ model,
+ recorder.Code,
+ trimForLog(recorder.Body.String(), 400),
+ )
+ _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
+ return
+ }
+
+ // 检查任务是否已被取消
+ gen, _ := h.genService.GetByID(ctx, genID, userID)
+ if gen != nil && gen.Status == service.SoraGenStatusCancelled {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
+ return
+ }
+
+ // 三层降级存储:S3 → 本地 → 上游临时 URL
+ storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
+
+ usageAdded := false
+ if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
+ if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
+ h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
+ var quotaErr *service.QuotaExceededError
+ if errors.As(err, "aErr) {
+ _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
+ return
+ }
+ _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
+ return
+ }
+ usageAdded = true
+ }
+
+ // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
+ gen, _ = h.genService.GetByID(ctx, genID, userID)
+ if gen != nil && gen.Status == service.SoraGenStatusCancelled {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
+ h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
+ if usageAdded && h.quotaService != nil {
+ _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
+ }
+ return
+ }
+
+ // 标记完成
+ if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
+ if errors.Is(err, service.ErrSoraGenerationStateConflict) {
+ h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
+ if usageAdded && h.quotaService != nil {
+ _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
+ }
+ return
+ }
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
+ return
+ }
+
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
+}
+
+// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
+func (h *SoraClientHandler) storeMediaWithDegradation(
+ ctx context.Context, userID int64, mediaType string,
+ mediaURL string, mediaURLs []string,
+) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
+ urls := mediaURLs
+ if len(urls) == 0 {
+ urls = []string{mediaURL}
+ }
+
+ // 第一层:尝试 S3
+ if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
+ keys := make([]string, 0, len(urls))
+ var totalSize int64
+ allOK := true
+ for _, u := range urls {
+ key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
+ if err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
+ allOK = false
+ // 清理已上传的文件
+ if len(keys) > 0 {
+ _ = h.s3Storage.DeleteObjects(ctx, keys)
+ }
+ break
+ }
+ keys = append(keys, key)
+ totalSize += size
+ }
+ if allOK && len(keys) > 0 {
+ accessURLs := make([]string, 0, len(keys))
+ for _, key := range keys {
+ accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
+ if err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
+ _ = h.s3Storage.DeleteObjects(ctx, keys)
+ allOK = false
+ break
+ }
+ accessURLs = append(accessURLs, accessURL)
+ }
+ if allOK && len(accessURLs) > 0 {
+ return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
+ }
+ }
+ }
+
+ // 第二层:尝试本地存储
+ if h.mediaStorage != nil && h.mediaStorage.Enabled() {
+ storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
+ if err == nil && len(storedPaths) > 0 {
+ firstPath := storedPaths[0]
+ totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
+ if sizeErr != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
+ }
+ return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
+ }
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
+ }
+
+ // 第三层:保留上游临时 URL
+ return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
+}
+
+// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
+func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
+ body := map[string]any{
+ "model": model,
+ "messages": []map[string]string{
+ {"role": "user", "content": prompt},
+ },
+ "stream": false,
+ }
+ if imageInput != "" {
+ body["image_input"] = imageInput
+ }
+ if videoCount > 1 {
+ body["video_count"] = videoCount
+ }
+ b, _ := json.Marshal(body)
+ return b
+}
+
+func normalizeVideoCount(mediaType string, videoCount int) int {
+ if mediaType != "video" {
+ return 1
+ }
+ if videoCount <= 0 {
+ return 1
+ }
+ if videoCount > 3 {
+ return 3
+ }
+ return videoCount
+}
+
+// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
+// OAuth 路径:ForwardResult.MediaURL 已填充。
+// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
+func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
+ // 优先从 ForwardResult 获取(OAuth 路径)
+ if result != nil && result.MediaURL != "" {
+ // 尝试从响应体获取完整 URL 列表
+ if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
+ return urls[0], urls
+ }
+ return result.MediaURL, []string{result.MediaURL}
+ }
+
+ // 从响应体解析(APIKey 路径)
+ if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
+ return urls[0], urls
+ }
+
+ return "", nil
+}
+
+// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
+func parseMediaURLsFromBody(body []byte) []string {
+ if len(body) == 0 {
+ return nil
+ }
+ var resp map[string]any
+ if err := json.Unmarshal(body, &resp); err != nil {
+ return nil
+ }
+
+ // 优先 media_urls(多图数组)
+ if rawURLs, ok := resp["media_urls"]; ok {
+ if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
+ urls := make([]string, 0, len(arr))
+ for _, item := range arr {
+ if s, ok := item.(string); ok && s != "" {
+ urls = append(urls, s)
+ }
+ }
+ if len(urls) > 0 {
+ return urls
+ }
+ }
+ }
+
+ // 回退到 media_url(单个 URL)
+ if url, ok := resp["media_url"].(string); ok && url != "" {
+ return []string{url}
+ }
+
+ return nil
+}
+
+// ListGenerations 查询生成记录列表。
+// GET /api/v1/sora/generations
+func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
+ pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
+
+ params := service.SoraGenerationListParams{
+ UserID: userID,
+ Status: c.Query("status"),
+ StorageType: c.Query("storage_type"),
+ MediaType: c.Query("media_type"),
+ Page: page,
+ PageSize: pageSize,
+ }
+
+ gens, total, err := h.genService.List(c.Request.Context(), params)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // 为 S3 记录动态生成预签名 URL
+ for _, gen := range gens {
+ _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
+ }
+
+ response.Success(c, gin.H{
+ "data": gens,
+ "total": total,
+ "page": page,
+ })
+}
+
+// GetGeneration 查询生成记录详情。
+// GET /api/v1/sora/generations/:id
+func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "无效的 ID")
+ return
+ }
+
+ gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
+ if err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+
+ _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
+ response.Success(c, gen)
+}
+
+// DeleteGeneration 删除生成记录。
+// DELETE /api/v1/sora/generations/:id
+func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "无效的 ID")
+ return
+ }
+
+ gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
+ if err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+
+ // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
+ if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
+ paths := gen.MediaURLs
+ if len(paths) == 0 && gen.MediaURL != "" {
+ paths = []string{gen.MediaURL}
+ }
+ if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
+ }
+ }
+
+ if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+
+ response.Success(c, gin.H{"message": "已删除"})
+}
+
+// GetQuota 查询用户存储配额。
+// GET /api/v1/sora/quota
+func (h *SoraClientHandler) GetQuota(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ if h.quotaService == nil {
+ response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
+ return
+ }
+
+ quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, quota)
+}
+
+// CancelGeneration 取消生成任务。
+// POST /api/v1/sora/generations/:id/cancel
+func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "无效的 ID")
+ return
+ }
+
+ // 权限校验
+ gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
+ if err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+ _ = gen
+
+ if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
+ if errors.Is(err, service.ErrSoraGenerationNotActive) {
+ response.Error(c, http.StatusConflict, "任务已结束,无法取消")
+ return
+ }
+ response.Error(c, http.StatusBadRequest, err.Error())
+ return
+ }
+
+ response.Success(c, gin.H{"message": "已取消"})
+}
+
+// SaveToStorage 手动保存 upstream 记录到 S3。
+// POST /api/v1/sora/generations/:id/save
+func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "无效的 ID")
+ return
+ }
+
+ gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
+ if err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+
+ if gen.StorageType != service.SoraStorageTypeUpstream {
+ response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
+ return
+ }
+ if gen.MediaURL == "" {
+ response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
+ return
+ }
+
+ if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
+ response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
+ return
+ }
+
+ sourceURLs := gen.MediaURLs
+ if len(sourceURLs) == 0 && gen.MediaURL != "" {
+ sourceURLs = []string{gen.MediaURL}
+ }
+ if len(sourceURLs) == 0 {
+ response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
+ return
+ }
+
+ uploadedKeys := make([]string, 0, len(sourceURLs))
+ accessURLs := make([]string, 0, len(sourceURLs))
+ var totalSize int64
+
+ for _, sourceURL := range sourceURLs {
+ objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
+ if uploadErr != nil {
+ if len(uploadedKeys) > 0 {
+ _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+ }
+ var upstreamErr *service.UpstreamDownloadError
+ if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
+ response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
+ return
+ }
+ response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
+ return
+ }
+ accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
+ if err != nil {
+ uploadedKeys = append(uploadedKeys, objectKey)
+ _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+ response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
+ return
+ }
+ uploadedKeys = append(uploadedKeys, objectKey)
+ accessURLs = append(accessURLs, accessURL)
+ totalSize += fileSize
+ }
+
+ usageAdded := false
+ if totalSize > 0 && h.quotaService != nil {
+ if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
+ _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+ var quotaErr *service.QuotaExceededError
+ if errors.As(err, "aErr) {
+ response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
+ return
+ }
+ response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
+ return
+ }
+ usageAdded = true
+ }
+
+ if err := h.genService.UpdateStorageForCompleted(
+ c.Request.Context(),
+ id,
+ accessURLs[0],
+ accessURLs,
+ service.SoraStorageTypeS3,
+ uploadedKeys,
+ totalSize,
+ ); err != nil {
+ _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+ if usageAdded && h.quotaService != nil {
+ _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{
+ "message": "已保存到 S3",
+ "object_key": uploadedKeys[0],
+ "object_keys": uploadedKeys,
+ })
+}
+
+// GetStorageStatus 返回存储状态。
+// GET /api/v1/sora/storage-status
+func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
+ s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
+ s3Healthy := false
+ if s3Enabled {
+ s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
+ }
+ localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
+ response.Success(c, gin.H{
+ "s3_enabled": s3Enabled,
+ "s3_healthy": s3Healthy,
+ "local_enabled": localEnabled,
+ })
+}
+
+func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
+ switch storageType {
+ case service.SoraStorageTypeS3:
+ if h.s3Storage != nil && len(s3Keys) > 0 {
+ if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
+ }
+ }
+ case service.SoraStorageTypeLocal:
+ if h.mediaStorage != nil && len(localPaths) > 0 {
+ if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
+ }
+ }
+ }
+}
+
+// getUserIDFromContext 从 gin 上下文中提取用户 ID。
+func getUserIDFromContext(c *gin.Context) int64 {
+ if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
+ return subject.UserID
+ }
+
+ if id, ok := c.Get("user_id"); ok {
+ switch v := id.(type) {
+ case int64:
+ return v
+ case float64:
+ return int64(v)
+ case string:
+ n, _ := strconv.ParseInt(v, 10, 64)
+ return n
+ }
+ }
+ // 尝试从 JWT claims 获取
+ if id, ok := c.Get("userID"); ok {
+ if v, ok := id.(int64); ok {
+ return v
+ }
+ }
+ return 0
+}
+
+func groupIDForLog(groupID *int64) int64 {
+ if groupID == nil {
+ return 0
+ }
+ return *groupID
+}
+
+func trimForLog(raw string, maxLen int) string {
+ trimmed := strings.TrimSpace(raw)
+ if maxLen <= 0 || len(trimmed) <= maxLen {
+ return trimmed
+ }
+ return trimmed[:maxLen] + "...(truncated)"
+}
+
+// GetModels 获取可用 Sora 模型家族列表。
+// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
+// GET /api/v1/sora/models
+func (h *SoraClientHandler) GetModels(c *gin.Context) {
+ families := h.getModelFamilies(c.Request.Context())
+ response.Success(c, families)
+}
+
+// getModelFamilies 获取模型家族列表(带缓存)。
+func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
+ // 读锁检查缓存
+ h.modelCacheMu.RLock()
+ ttl := modelCacheTTL
+ if !h.modelCacheUpstream {
+ ttl = modelCacheFailedTTL
+ }
+ if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
+ families := h.cachedFamilies
+ h.modelCacheMu.RUnlock()
+ return families
+ }
+ h.modelCacheMu.RUnlock()
+
+ // 写锁更新缓存
+ h.modelCacheMu.Lock()
+ defer h.modelCacheMu.Unlock()
+
+ // double-check
+ ttl = modelCacheTTL
+ if !h.modelCacheUpstream {
+ ttl = modelCacheFailedTTL
+ }
+ if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
+ return h.cachedFamilies
+ }
+
+ // 尝试从上游获取
+ families, err := h.fetchUpstreamModels(ctx)
+ if err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
+ families = service.BuildSoraModelFamilies()
+ h.cachedFamilies = families
+ h.modelCacheTime = time.Now()
+ h.modelCacheUpstream = false
+ return families
+ }
+
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
+ h.cachedFamilies = families
+ h.modelCacheTime = time.Now()
+ h.modelCacheUpstream = true
+ return families
+}
+
+// fetchUpstreamModels 从上游 Sora API 获取模型列表。
+func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
+ if h.gatewayService == nil {
+ return nil, fmt.Errorf("gatewayService 未初始化")
+ }
+
+ // 设置 ForcePlatform 用于 Sora 账号选择
+ ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
+
+ // 选择一个 Sora 账号
+ account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
+ if err != nil {
+ return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
+ }
+
+ // 仅支持 API Key 类型账号
+ if account.Type != service.AccountTypeAPIKey {
+ return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
+ }
+
+ apiKey := account.GetCredential("api_key")
+ if apiKey == "" {
+ return nil, fmt.Errorf("账号缺少 api_key")
+ }
+
+ baseURL := account.GetBaseURL()
+ if baseURL == "" {
+ return nil, fmt.Errorf("账号缺少 base_url")
+ }
+
+ // 构建上游模型列表请求
+ modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
+
+ reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
+ if err != nil {
+ return nil, fmt.Errorf("创建请求失败: %w", err)
+ }
+ req.Header.Set("Authorization", "Bearer "+apiKey)
+
+ client := &http.Client{Timeout: 10 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("请求上游失败: %w", err)
+ }
+ defer func() {
+ _ = resp.Body.Close()
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
+ }
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
+ if err != nil {
+ return nil, fmt.Errorf("读取响应失败: %w", err)
+ }
+
+ // 解析 OpenAI 格式的模型列表
+ var modelsResp struct {
+ Data []struct {
+ ID string `json:"id"`
+ } `json:"data"`
+ }
+ if err := json.Unmarshal(body, &modelsResp); err != nil {
+ return nil, fmt.Errorf("解析响应失败: %w", err)
+ }
+
+ if len(modelsResp.Data) == 0 {
+ return nil, fmt.Errorf("上游返回空模型列表")
+ }
+
+ // 提取模型 ID
+ modelIDs := make([]string, 0, len(modelsResp.Data))
+ for _, m := range modelsResp.Data {
+ modelIDs = append(modelIDs, m.ID)
+ }
+
+ // 转换为模型家族
+ families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
+ if len(families) == 0 {
+ return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
+ }
+
+ return families, nil
+}
diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go
new file mode 100644
index 00000000..bc84ed52
--- /dev/null
+++ b/backend/internal/handler/sora_client_handler_test.go
@@ -0,0 +1,3195 @@
+//go:build unit
+
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func init() {
+ gin.SetMode(gin.TestMode)
+}
+
+// ==================== Stub: SoraGenerationRepository ====================
+
+var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil)
+
+type stubSoraGenRepo struct {
+ gens map[int64]*service.SoraGeneration
+ nextID int64
+ createErr error
+ getErr error
+ updateErr error
+ deleteErr error
+ listErr error
+ countErr error
+ countValue int64
+
+ // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
+ updateCallCount *int32
+ updateFailAfterN int32
+
+ // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
+ getByIDCallCount int32
+ getByIDOverrideAfterN int32 // 0 = 不覆盖
+ getByIDOverrideStatus string
+}
+
+func newStubSoraGenRepo() *stubSoraGenRepo {
+ return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1}
+}
+
+func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error {
+ if r.createErr != nil {
+ return r.createErr
+ }
+ gen.ID = r.nextID
+ r.nextID++
+ r.gens[gen.ID] = gen
+ return nil
+}
+func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) {
+ if r.getErr != nil {
+ return nil, r.getErr
+ }
+ gen, ok := r.gens[id]
+ if !ok {
+ return nil, fmt.Errorf("not found")
+ }
+ // 条件性状态覆盖:模拟外部取消等场景
+ if r.getByIDOverrideAfterN > 0 {
+ n := atomic.AddInt32(&r.getByIDCallCount, 1)
+ if n > r.getByIDOverrideAfterN {
+ cp := *gen
+ cp.Status = r.getByIDOverrideStatus
+ return &cp, nil
+ }
+ }
+ return gen, nil
+}
+func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error {
+ // 条件性失败:前 N 次成功,之后失败
+ if r.updateCallCount != nil {
+ n := atomic.AddInt32(r.updateCallCount, 1)
+ if n > r.updateFailAfterN {
+ return fmt.Errorf("conditional update error (call #%d)", n)
+ }
+ }
+ if r.updateErr != nil {
+ return r.updateErr
+ }
+ r.gens[gen.ID] = gen
+ return nil
+}
+func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error {
+ if r.deleteErr != nil {
+ return r.deleteErr
+ }
+ delete(r.gens, id)
+ return nil
+}
+func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
+ if r.listErr != nil {
+ return nil, 0, r.listErr
+ }
+ var result []*service.SoraGeneration
+ for _, gen := range r.gens {
+ if gen.UserID != params.UserID {
+ continue
+ }
+ result = append(result, gen)
+ }
+ return result, int64(len(result)), nil
+}
+func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) {
+ if r.countErr != nil {
+ return 0, r.countErr
+ }
+ return r.countValue, nil
+}
+
+// ==================== 辅助函数 ====================
+
+func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ return &SoraClientHandler{genService: genService}
+}
+
+func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) {
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ if body != "" {
+ c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ } else {
+ c.Request = httptest.NewRequest(method, path, nil)
+ }
+ if userID > 0 {
+ c.Set("user_id", userID)
+ }
+ return c, rec
+}
+
+func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+ var resp map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ return resp
+}
+
+// ==================== 纯函数测试: buildAsyncRequestBody ====================
+
+func TestBuildAsyncRequestBody(t *testing.T) {
+ body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+ require.Equal(t, "sora2-landscape-10s", parsed["model"])
+ require.Equal(t, false, parsed["stream"])
+
+ msgs := parsed["messages"].([]any)
+ require.Len(t, msgs, 1)
+ msg := msgs[0].(map[string]any)
+ require.Equal(t, "user", msg["role"])
+ require.Equal(t, "一只猫在跳舞", msg["content"])
+}
+
+func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) {
+ body := buildAsyncRequestBody("gpt-image", "", "", 1)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+ require.Equal(t, "gpt-image", parsed["model"])
+ msgs := parsed["messages"].([]any)
+ msg := msgs[0].(map[string]any)
+ require.Equal(t, "", msg["content"])
+}
+
+func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) {
+ body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+ require.Equal(t, "https://example.com/ref.png", parsed["image_input"])
+}
+
+func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) {
+ body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+ require.Equal(t, float64(3), parsed["video_count"])
+}
+
+func TestNormalizeVideoCount(t *testing.T) {
+ require.Equal(t, 1, normalizeVideoCount("video", 0))
+ require.Equal(t, 2, normalizeVideoCount("video", 2))
+ require.Equal(t, 3, normalizeVideoCount("video", 5))
+ require.Equal(t, 1, normalizeVideoCount("image", 3))
+}
+
+// ==================== 纯函数测试: parseMediaURLsFromBody ====================
+
+func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) {
+ urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`))
+ require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
+}
+
+func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) {
+ urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`))
+ require.Equal(t, []string{"https://a.com/video.mp4"}, urls)
+}
+
+func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody(nil))
+ require.Nil(t, parseMediaURLsFromBody([]byte{}))
+}
+
+func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte("not json")))
+}
+
+func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`)))
+}
+
+func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`)))
+}
+
+func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`)))
+}
+
+func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) {
+ body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
+ urls := parseMediaURLsFromBody([]byte(body))
+ require.Len(t, urls, 2)
+ require.Equal(t, "https://multi.com/a.mp4", urls[0])
+}
+
+func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) {
+ urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`))
+ require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
+}
+
+func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`)))
+}
+
+func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) {
+ // media_urls 不是 string 数组
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`)))
+}
+
+func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`)))
+}
+
+// ==================== 纯函数测试: extractMediaURLsFromResult ====================
+
+func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) {
+ result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
+ recorder := httptest.NewRecorder()
+ url, urls := extractMediaURLsFromResult(result, recorder)
+ require.Equal(t, "https://oauth.com/video.mp4", url)
+ require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls)
+}
+
+func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) {
+ result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
+ recorder := httptest.NewRecorder()
+ _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`))
+ url, urls := extractMediaURLsFromResult(result, recorder)
+ require.Equal(t, "https://body.com/1.mp4", url)
+ require.Len(t, urls, 2)
+}
+
+func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`))
+ url, urls := extractMediaURLsFromResult(nil, recorder)
+ require.Equal(t, "https://upstream.com/video.mp4", url)
+ require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls)
+}
+
+func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ url, urls := extractMediaURLsFromResult(nil, recorder)
+ require.Empty(t, url)
+ require.Nil(t, urls)
+}
+
+func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) {
+ result := &service.ForwardResult{MediaURL: ""}
+ recorder := httptest.NewRecorder()
+ url, urls := extractMediaURLsFromResult(result, recorder)
+ require.Empty(t, url)
+ require.Nil(t, urls)
+}
+
+// ==================== getUserIDFromContext ====================
+
+func TestGetUserIDFromContext_Int64(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("user_id", int64(42))
+ require.Equal(t, int64(42), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_AuthSubject(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777})
+ require.Equal(t, int64(777), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_Float64(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("user_id", float64(99))
+ require.Equal(t, int64(99), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_String(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("user_id", "123")
+ require.Equal(t, int64(123), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_UserIDFallback(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("userID", int64(55))
+ require.Equal(t, int64(55), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_NoID(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ require.Equal(t, int64(0), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_InvalidString(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("user_id", "not-a-number")
+ require.Equal(t, int64(0), getUserIDFromContext(c))
+}
+
+// ==================== Handler: Generate ====================
+
+func TestGenerate_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0)
+ h.Generate(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestGenerate_BadRequest_MissingModel(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestGenerate_BadRequest_MissingPrompt(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestGenerate_BadRequest_InvalidJSON(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestGenerate_TooManyRequests(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.countValue = 3
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusTooManyRequests, rec.Code)
+}
+
+func TestGenerate_CountError(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.countErr = fmt.Errorf("db error")
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+func TestGenerate_Success(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.NotZero(t, data["generation_id"])
+ require.Equal(t, "pending", data["status"])
+}
+
+func TestGenerate_DefaultMediaType(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "video", repo.gens[1].MediaType)
+}
+
+func TestGenerate_ImageMediaType(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "image", repo.gens[1].MediaType)
+}
+
+func TestGenerate_CreatePendingError(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.createErr = fmt.Errorf("create failed")
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestGenerate_APIKeyInContext(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ c.Set("api_key_id", int64(42))
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, repo.gens[1].APIKeyID)
+ require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_NoAPIKeyInContext(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Nil(t, repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_ConcurrencyBoundary(t *testing.T) {
+ // activeCount == 2 应该允许
+ repo := newStubSoraGenRepo()
+ repo.countValue = 2
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+// ==================== Handler: ListGenerations ====================
+
+func TestListGenerations_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0)
+ h.ListGenerations(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestListGenerations_Success(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"}
+ repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"}
+ repo.nextID = 3
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1)
+ h.ListGenerations(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ items := data["data"].([]any)
+ require.Len(t, items, 2)
+ require.Equal(t, float64(2), data["total"])
+}
+
+func TestListGenerations_ListError(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.listErr = fmt.Errorf("db error")
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
+ h.ListGenerations(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+func TestListGenerations_DefaultPagination(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ // 不传分页参数,应默认 page=1 page_size=20
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
+ h.ListGenerations(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, float64(1), data["page"])
+}
+
+// ==================== Handler: GetGeneration ====================
+
+func TestGetGeneration_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestGetGeneration_InvalidID(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "abc"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestGetGeneration_NotFound(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "999"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestGetGeneration_WrongUser(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestGetGeneration_Success(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, float64(1), data["id"])
+}
+
+// ==================== Handler: DeleteGeneration ====================
+
+func TestDeleteGeneration_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestDeleteGeneration_InvalidID(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "abc"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestDeleteGeneration_NotFound(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "999"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestDeleteGeneration_WrongUser(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestDeleteGeneration_Success(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ _, exists := repo.gens[1]
+ require.False(t, exists)
+}
+
+// ==================== Handler: CancelGeneration ====================
+
+func TestCancelGeneration_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestCancelGeneration_InvalidID(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "abc"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestCancelGeneration_NotFound(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "999"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestCancelGeneration_WrongUser(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestCancelGeneration_Pending(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "cancelled", repo.gens[1].Status)
+}
+
+func TestCancelGeneration_Generating(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "cancelled", repo.gens[1].Status)
+}
+
+func TestCancelGeneration_Completed(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusConflict, rec.Code)
+}
+
+func TestCancelGeneration_Failed(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusConflict, rec.Code)
+}
+
+func TestCancelGeneration_Cancelled(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusConflict, rec.Code)
+}
+
+// ==================== Handler: GetQuota ====================
+
+func TestGetQuota_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0)
+ h.GetQuota(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestGetQuota_NilQuotaService(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
+ h.GetQuota(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, "unlimited", data["source"])
+}
+
+// ==================== Handler: GetModels ====================
+
+func TestGetModels(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0)
+ h.GetModels(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].([]any)
+ require.Len(t, data, 4)
+ // 验证类型分布
+ videoCount, imageCount := 0, 0
+ for _, item := range data {
+ m := item.(map[string]any)
+ if m["type"] == "video" {
+ videoCount++
+ } else if m["type"] == "image" {
+ imageCount++
+ }
+ }
+ require.Equal(t, 3, videoCount)
+ require.Equal(t, 1, imageCount)
+}
+
+// ==================== Handler: GetStorageStatus ====================
+
+func TestGetStorageStatus_NilS3(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
+ h.GetStorageStatus(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, false, data["s3_enabled"])
+ require.Equal(t, false, data["s3_healthy"])
+ require.Equal(t, false, data["local_enabled"])
+}
+
+func TestGetStorageStatus_LocalEnabled(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-storage-status-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{mediaStorage: mediaStorage}
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
+ h.GetStorageStatus(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, false, data["s3_enabled"])
+ require.Equal(t, false, data["s3_healthy"])
+ require.Equal(t, true, data["local_enabled"])
+}
+
+// ==================== Handler: SaveToStorage ====================
+
+func TestSaveToStorage_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestSaveToStorage_InvalidID(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "abc"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestSaveToStorage_NotFound(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "999"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestSaveToStorage_NotUpstream(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestSaveToStorage_EmptyMediaURL(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestSaveToStorage_S3Nil(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusServiceUnavailable, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "云存储")
+}
+
+func TestSaveToStorage_WrongUser(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+// ==================== storeMediaWithDegradation — nil guard 路径 ====================
+
+func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) {
+ h := &SoraClientHandler{}
+ url, urls, storageType, keys, size := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
+ )
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Equal(t, "https://upstream.com/v.mp4", url)
+ require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls)
+ require.Nil(t, keys)
+ require.Equal(t, int64(0), size)
+}
+
+func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) {
+ h := &SoraClientHandler{}
+ url, urls, storageType, keys, size := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"},
+ )
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Equal(t, "https://a.com/1.mp4", url)
+ require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
+ require.Nil(t, keys)
+ require.Equal(t, int64(0), size)
+}
+
+func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) {
+ h := &SoraClientHandler{}
+ url, _, storageType, _, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{},
+ )
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Equal(t, "https://upstream.com/v.mp4", url)
+}
+
+// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
+
+var _ service.UserRepository = (*stubUserRepoForHandler)(nil)
+
+type stubUserRepoForHandler struct {
+ users map[int64]*service.User
+ updateErr error
+}
+
+func newStubUserRepoForHandler() *stubUserRepoForHandler {
+ return &stubUserRepoForHandler{users: make(map[int64]*service.User)}
+}
+
+func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) {
+ if u, ok := r.users[id]; ok {
+ return u, nil
+ }
+ return nil, fmt.Errorf("user not found")
+}
+func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error {
+ if r.updateErr != nil {
+ return r.updateErr
+ }
+ r.users[user.ID] = user
+ return nil
+}
+func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil }
+func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) {
+ return nil, nil
+}
+func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) {
+ return nil, nil
+}
+func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil }
+func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil }
+func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil }
+func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) {
+ return false, nil
+}
+func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubUserRepoForHandler) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+// ==================== NewSoraClientHandler ====================
+
+func TestNewSoraClientHandler(t *testing.T) {
+ h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
+ require.NotNil(t, h)
+}
+
+func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) {
+ h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
+ require.NotNil(t, h)
+ require.Nil(t, h.apiKeyService)
+}
+
+// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
+
+var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil)
+
+type stubAPIKeyRepoForHandler struct {
+ keys map[int64]*service.APIKey
+ getErr error
+}
+
+func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler {
+ return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)}
+}
+
+func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) {
+ if r.getErr != nil {
+ return nil, r.getErr
+ }
+ if k, ok := r.keys[id]; ok {
+ return k, nil
+ }
+ return nil, fmt.Errorf("api key not found: %d", id)
+}
+func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil }
+func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) {
+ return "", 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
+func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
+func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) {
+ return false, nil
+}
+func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) UpdateGroupIDByUserAndGroup(_ context.Context, userID, oldGroupID, newGroupID int64) (int64, error) {
+ var updated int64
+ for id, key := range r.keys {
+ if key.UserID != userID || key.GroupID == nil || *key.GroupID != oldGroupID {
+ continue
+ }
+ clone := *key
+ gid := newGroupID
+ clone.GroupID = &gid
+ r.keys[id] = &clone
+ updated++
+ }
+ return updated, nil
+}
+func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) {
+ return 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
+ return nil
+}
+func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error {
+ return nil
+}
+func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error {
+ return nil
+}
+func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) {
+ return nil, nil
+}
+
+// newTestAPIKeyService 创建测试用的 APIKeyService
+func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
+ return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{})
+}
+
+// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
+
+func TestGenerate_WithAPIKeyID_Success(t *testing.T) {
+ // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ groupID := int64(5)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyActive,
+ GroupID: &groupID,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.NotZero(t, data["generation_id"])
+
+ // 验证 api_key_id 已关联到生成记录
+ gen := repo.gens[1]
+ require.NotNil(t, gen.APIKeyID)
+ require.Equal(t, int64(42), *gen.APIKeyID)
+}
+
+func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) {
+ // 前端传递不存在的 api_key_id → 400
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "不存在")
+}
+
+func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) {
+ // 前端传递别人的 api_key_id → 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 999, // 属于 user 999
+ Status: service.StatusAPIKeyActive,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "不属于")
+}
+
+func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) {
+ // 前端传递已禁用的 api_key_id → 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyDisabled,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "不可用")
+}
+
+func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) {
+ // 前端传递配额耗尽的 api_key_id → 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyQuotaExhausted,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+}
+
+func TestGenerate_WithAPIKeyID_Expired(t *testing.T) {
+ // 前端传递已过期的 api_key_id → 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyExpired,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+}
+
+func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) {
+ // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ h := &SoraClientHandler{genService: genService} // apiKeyService = nil
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
+ require.Nil(t, repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) {
+ // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyActive,
+ GroupID: nil, // 无分组
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, repo.gens[1].APIKeyID)
+ require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) {
+ // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Nil(t, repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) {
+ // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ groupID := int64(10)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyActive,
+ GroupID: &groupID,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ // 应使用 body 中的 api_key_id=42,而不是 context 中的 99
+ require.NotNil(t, repo.gens[1].APIKeyID)
+ require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) {
+ // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ c.Set("api_key_id", int64(99))
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ // 应使用 context 中的 api_key_id=99
+ require.NotNil(t, repo.gens[1].APIKeyID)
+ require.Equal(t, int64(99), *repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
+ // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
+ // api_key_id=0 不存在 → 400
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
+
+func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
+ // groupID 不为 nil → 不设置 ForcePlatform
+ // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ gid := int64(5)
+ h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
+}
+
+func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
+ // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
+}
+
+func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
+ // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
+ require.Equal(t, "cancelled", repo.gens[1].Status)
+}
+
+// ==================== GenerateRequest JSON 解析 ====================
+
+func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
+ // 验证 api_key_id 在 JSON 中正确解析为 *int64
+ var req GenerateRequest
+ err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
+ require.NoError(t, err)
+ require.NotNil(t, req.APIKeyID)
+ require.Equal(t, int64(42), *req.APIKeyID)
+}
+
+func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
+ // 不传 api_key_id → 解析后为 nil
+ var req GenerateRequest
+ err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
+ require.NoError(t, err)
+ require.Nil(t, req.APIKeyID)
+}
+
+func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
+ // api_key_id: null → 解析后为 nil
+ var req GenerateRequest
+ err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
+ require.NoError(t, err)
+ require.Nil(t, req.APIKeyID)
+}
+
+func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
+ // 全字段解析
+ var req GenerateRequest
+ err := json.Unmarshal([]byte(`{
+ "model":"sora2-landscape-10s",
+ "prompt":"test prompt",
+ "media_type":"video",
+ "video_count":2,
+ "image_input":"data:image/png;base64,abc",
+ "api_key_id":100
+ }`), &req)
+ require.NoError(t, err)
+ require.Equal(t, "sora2-landscape-10s", req.Model)
+ require.Equal(t, "test prompt", req.Prompt)
+ require.Equal(t, "video", req.MediaType)
+ require.Equal(t, 2, req.VideoCount)
+ require.Equal(t, "data:image/png;base64,abc", req.ImageInput)
+ require.NotNil(t, req.APIKeyID)
+ require.Equal(t, int64(100), *req.APIKeyID)
+}
+
+func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) {
+ // api_key_id 为 nil 时 JSON 序列化应省略
+ req := GenerateRequest{Model: "sora2", Prompt: "test"}
+ b, err := json.Marshal(req)
+ require.NoError(t, err)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(b, &parsed))
+ _, hasAPIKeyID := parsed["api_key_id"]
+ require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略")
+}
+
+func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) {
+ // api_key_id 不为 nil 时 JSON 序列化应包含
+ id := int64(42)
+ req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id}
+ b, err := json.Marshal(req)
+ require.NoError(t, err)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(b, &parsed))
+ require.Equal(t, float64(42), parsed["api_key_id"])
+}
+
+// ==================== GetQuota: 有配额服务 ====================
+
+func TestGetQuota_WithQuotaService_Success(t *testing.T) {
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 10 * 1024 * 1024,
+ SoraStorageUsedBytes: 3 * 1024 * 1024,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ }
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
+ h.GetQuota(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, "user", data["source"])
+ require.Equal(t, float64(10*1024*1024), data["quota_bytes"])
+ require.Equal(t, float64(3*1024*1024), data["used_bytes"])
+}
+
+func TestGetQuota_WithQuotaService_Error(t *testing.T) {
+ // 用户不存在时 GetQuota 返回错误
+ userRepo := newStubUserRepoForHandler()
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ }
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999)
+ h.GetQuota(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== Generate: 配额检查 ====================
+
+func TestGenerate_QuotaCheckFailed(t *testing.T) {
+ // 配额超限时返回 429
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 1024,
+ SoraStorageUsedBytes: 1025, // 已超限
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ }
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusTooManyRequests, rec.Code)
+}
+
+func TestGenerate_QuotaCheckPassed(t *testing.T) {
+ // 配额充足时允许生成
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 10 * 1024 * 1024,
+ SoraStorageUsedBytes: 0,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ }
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
+
+var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil)
+
+type stubSettingRepoForHandler struct {
+ values map[string]string
+}
+
+func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler {
+ if values == nil {
+ values = make(map[string]string)
+ }
+ return &stubSettingRepoForHandler{values: values}
+}
+
+func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) {
+ if v, ok := r.values[key]; ok {
+ return &service.Setting{Key: key, Value: v}, nil
+ }
+ return nil, service.ErrSettingNotFound
+}
+func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := r.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error {
+ r.values[key] = value
+ return nil
+}
+func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string)
+ for _, k := range keys {
+ if v, ok := r.values[k]; ok {
+ result[k] = v
+ }
+ }
+ return result, nil
+}
+func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error {
+ for k, v := range settings {
+ r.values[k] = v
+ }
+ return nil
+}
+func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) {
+ return r.values, nil
+}
+func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error {
+ delete(r.values, key)
+ return nil
+}
+
+// ==================== S3 / MediaStorage 辅助函数 ====================
+
+// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
+func newS3StorageForHandler(endpoint string) *service.SoraS3Storage {
+ settingRepo := newStubSettingRepoForHandler(map[string]string{
+ "sora_s3_enabled": "true",
+ "sora_s3_endpoint": endpoint,
+ "sora_s3_region": "us-east-1",
+ "sora_s3_bucket": "test-bucket",
+ "sora_s3_access_key_id": "AKIATEST",
+ "sora_s3_secret_access_key": "test-secret",
+ "sora_s3_prefix": "sora",
+ "sora_s3_force_path_style": "true",
+ })
+ settingService := service.NewSettingService(settingRepo, &config.Config{})
+ return service.NewSoraS3Storage(settingService)
+}
+
+// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
+func newFakeSourceServer() *httptest.Server {
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "video/mp4")
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("fake video data for test"))
+ }))
+}
+
+// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
+// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
+func newFakeS3Server(mode string) *httptest.Server {
+ var counter atomic.Int32
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = io.Copy(io.Discard, r.Body)
+ _ = r.Body.Close()
+
+ switch mode {
+ case "ok":
+ w.Header().Set("ETag", `"test-etag"`)
+ w.WriteHeader(http.StatusOK)
+ case "fail":
+ w.WriteHeader(http.StatusForbidden)
+ _, _ = w.Write([]byte(`AccessDenied`))
+ case "fail-second":
+ n := counter.Add(1)
+ if n <= 1 {
+ w.Header().Set("ETag", `"test-etag"`)
+ w.WriteHeader(http.StatusOK)
+ } else {
+ w.WriteHeader(http.StatusForbidden)
+ _, _ = w.Write([]byte(`AccessDenied`))
+ }
+ }
+ }))
+}
+
+// ==================== processGeneration 直接调用测试 ====================
+
+func TestProcessGeneration_MarkGeneratingFails(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ repo.updateErr = fmt.Errorf("db error")
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ // 直接调用(非 goroutine),MarkGenerating 失败 → 早退
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
+ // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
+ // 因此 ErrorMessage 为空(证明未调用 MarkFailed)
+ require.Equal(t, "generating", repo.gens[1].Status)
+ require.Empty(t, repo.gens[1].ErrorMessage)
+}
+
+func TestProcessGeneration_GatewayServiceNil(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+ // gatewayService 未设置 → MarkFailed
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
+}
+
+// ==================== storeMediaWithDegradation: S3 路径 ====================
+
+func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
+ )
+ require.Equal(t, service.SoraStorageTypeS3, storageType)
+ require.Len(t, s3Keys, 1)
+ require.NotEmpty(t, s3Keys[0])
+ require.Len(t, storedURLs, 1)
+ require.Equal(t, storedURL, storedURLs[0])
+ require.Contains(t, storedURL, fakeS3.URL)
+ require.Contains(t, storedURL, "/test-bucket/")
+ require.Greater(t, fileSize, int64(0))
+}
+
+func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
+ storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
+ )
+ require.Equal(t, service.SoraStorageTypeS3, storageType)
+ require.Len(t, s3Keys, 2)
+ require.Len(t, storedURLs, 2)
+ require.Equal(t, storedURL, storedURLs[0])
+ require.Contains(t, storedURLs[0], fakeS3.URL)
+ require.Contains(t, storedURLs[1], fakeS3.URL)
+ require.Greater(t, fileSize, int64(0))
+}
+
+func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
+ // 上游返回 404 → 下载失败 → S3 上传不会开始
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+ badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer badSource.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ _, _, storageType, _, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
+ )
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+}
+
+func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
+ )
+ // S3 失败,降级到 upstream
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Nil(t, s3Keys)
+}
+
+func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail-second")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
+ _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
+ )
+ // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Nil(t, s3Keys)
+}
+
+// ==================== storeMediaWithDegradation: 本地存储路径 ====================
+
+func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) {
+ // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: "/dev/null/invalid_dir",
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{mediaStorage: mediaStorage}
+
+ _, _, storageType, _, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
+ )
+ // 本地存储失败,降级到 upstream
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+}
+
+func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ DownloadTimeoutSeconds: 5,
+ MaxDownloadBytes: 10 * 1024 * 1024,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{mediaStorage: mediaStorage}
+
+ _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
+ )
+ require.Equal(t, service.SoraStorageTypeLocal, storageType)
+ require.Nil(t, s3Keys) // 本地存储不返回 S3 keys
+}
+
+func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ DownloadTimeoutSeconds: 5,
+ MaxDownloadBytes: 10 * 1024 * 1024,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{
+ s3Storage: s3Storage,
+ mediaStorage: mediaStorage,
+ }
+
+ _, _, storageType, _, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
+ )
+ // S3 失败 → 本地存储成功
+ require.Equal(t, service.SoraStorageTypeLocal, storageType)
+}
+
+// ==================== SaveToStorage: S3 路径 ====================
+
+func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, resp["message"], "S3")
+}
+
+func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
+ expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusForbidden)
+ }))
+ defer expiredServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: expiredServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusGone, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "过期")
+}
+
+func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Contains(t, data["message"], "S3")
+ require.NotEmpty(t, data["object_key"])
+ // 验证记录已更新为 S3 存储
+ require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
+}
+
+func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v1.mp4",
+ MediaURLs: []string{
+ sourceServer.URL + "/v1.mp4",
+ sourceServer.URL + "/v2.mp4",
+ },
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Len(t, data["object_keys"].([]any), 2)
+ require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
+ require.Len(t, repo.gens[1].S3ObjectKeys, 2)
+ require.Len(t, repo.gens[1].MediaURLs, 2)
+}
+
+func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 100 * 1024 * 1024,
+ SoraStorageUsedBytes: 0,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ // 验证配额已累加
+ require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
+}
+
+func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
+ repo.updateErr = fmt.Errorf("db error")
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== GetStorageStatus: S3 路径 ====================
+
+func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
+ // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
+ h.GetStorageStatus(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, true, data["s3_enabled"])
+ require.Equal(t, false, data["s3_healthy"])
+}
+
+func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
+ h.GetStorageStatus(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, true, data["s3_enabled"])
+ require.Equal(t, true, data["s3_healthy"])
+}
+
+// ==================== Stub: AccountRepository (用于 GatewayService) ====================
+
+var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)
+
+type stubAccountRepoForHandler struct {
+ accounts []service.Account
+}
+
+func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }
+func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
+ for i := range r.accounts {
+ if r.accounts[i].ID == id {
+ return &r.accounts[i], nil
+ }
+ }
+ return nil, fmt.Errorf("account not found")
+}
+func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
+ return false, nil
+}
+func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
+func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
+func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
+ return 0, nil
+}
+func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }
+func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
+ return 0, nil
+}
+
+func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
+ return nil
+}
+
+func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
+ return nil
+}
+
+// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
+
+var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
+
+type stubSoraClientForHandler struct {
+ videoStatus *service.SoraVideoTaskStatus
+}
+
+func (s *stubSoraClientForHandler) Enabled() bool { return true }
+func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
+ return "task-image", nil
+}
+func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
+ return "task-video", nil
+}
+func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
+ return "task-video", nil
+}
+func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
+ return nil, nil
+}
+func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
+ return nil, nil
+}
+func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
+ return nil
+}
+func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
+ return nil
+}
+func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
+ return nil
+}
+func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
+ return nil, nil
+}
+func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
+ return s.videoStatus, nil
+}
+
+// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
+
+// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
+func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
+ return service.NewGatewayService(
+ accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
+ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
+ )
+}
+
+// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
+func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
+}
+
+// ==================== processGeneration: 更多路径测试 ====================
+
+func TestProcessGeneration_SelectAccountError(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
+ accountRepo := &stubAccountRepoForHandler{accounts: nil}
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
+}
+
+func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
+ // TODO: Re-enable after Sora process generation is stable
+ // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ // 提供可用账号使 SelectAccountForModel 成功
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ // soraGatewayService 为 nil
+ h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService")
+}
+
+func TestProcessGeneration_ForwardError(t *testing.T) {
+ // TODO: Re-enable after Sora process generation is stable
+ // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ // SoraClient 返回视频任务失败
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "failed",
+ ErrorMsg: "content policy violation",
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "生成失败")
+}
+
+func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
+ // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
+ repo.getByIDOverrideAfterN = 1
+ repo.getByIDOverrideStatus = "cancelled"
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"},
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
+ require.Equal(t, "generating", repo.gens[1].Status)
+}
+
+func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
+ // TODO: Re-enable after Sora process generation is stable
+ // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ // SoraClient 返回 completed 但无 URL
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: nil, // 无 URL
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL")
+}
+
+func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
+ // 第 2 次返回 "cancelled" 状态,模拟外部取消
+ repo.getByIDOverrideAfterN = 1
+ repo.getByIDOverrideStatus = "cancelled"
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/video.mp4"},
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
+ require.Equal(t, "generating", repo.gens[1].Status)
+}
+
+func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
+ // TODO: Re-enable after Sora process generation is stable
+ // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/video.mp4"},
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ // 无 S3 和本地存储,降级到 upstream
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
+ require.Equal(t, "completed", repo.gens[1].Status)
+ require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType)
+ require.NotEmpty(t, repo.gens[1].MediaURL)
+}
+
+func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
+ // TODO: Re-enable after Sora process generation is stable
+ // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{sourceServer.URL + "/video.mp4"},
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ s3Storage: s3Storage,
+ quotaService: quotaService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
+ require.Equal(t, "completed", repo.gens[1].Status)
+ require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
+ require.NotEmpty(t, repo.gens[1].S3ObjectKeys)
+ require.Greater(t, repo.gens[1].FileSizeBytes, int64(0))
+ // 验证配额已累加
+ require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
+}
+
+func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
+ // TODO: Re-enable after Sora process generation is stable
+ // t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
+ repo.updateCallCount = new(int32)
+ repo.updateFailAfterN = 1
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/video.mp4"},
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
+ // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
+ // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
+ // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
+ require.Equal(t, "completed", repo.gens[1].Status)
+}
+
+// ==================== cleanupStoredMedia 直接测试 ====================
+
+func TestCleanupStoredMedia_S3Path(t *testing.T) {
+ // S3 清理路径:s3Storage 为 nil 时不 panic
+ h := &SoraClientHandler{}
+ // 不应 panic
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
+}
+
+func TestCleanupStoredMedia_LocalPath(t *testing.T) {
+ // 本地清理路径:mediaStorage 为 nil 时不 panic
+ h := &SoraClientHandler{}
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"})
+}
+
+func TestCleanupStoredMedia_UpstreamPath(t *testing.T) {
+ // upstream 类型不清理
+ h := &SoraClientHandler{}
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil)
+}
+
+func TestCleanupStoredMedia_EmptyKeys(t *testing.T) {
+ // 空 keys 不触发清理
+ h := &SoraClientHandler{}
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil)
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil)
+}
+
+// ==================== DeleteGeneration: 本地存储清理路径 ====================
+
+func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-delete-test-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1,
+ UserID: 1,
+ Status: "completed",
+ StorageType: service.SoraStorageTypeLocal,
+ MediaURL: "video/test.mp4",
+ MediaURLs: []string{"video/test.mp4"},
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ _, exists := repo.gens[1]
+ require.False(t, exists)
+}
+
+func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) {
+ // MediaURLs 为空,使用 MediaURL 作为清理路径
+ tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1,
+ UserID: 1,
+ Status: "completed",
+ StorageType: service.SoraStorageTypeLocal,
+ MediaURL: "video/test.mp4",
+ MediaURLs: nil, // 空
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) {
+ // 非本地存储类型 → 跳过清理
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1,
+ UserID: 1,
+ Status: "completed",
+ StorageType: service.SoraStorageTypeUpstream,
+ MediaURL: "https://upstream.com/v.mp4",
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestDeleteGeneration_DeleteError(t *testing.T) {
+ // repo.Delete 出错
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"}
+ repo.deleteErr = fmt.Errorf("delete failed")
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+// ==================== fetchUpstreamModels 测试 ====================
+
+func TestFetchUpstreamModels_NilGateway(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ h := &SoraClientHandler{}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "gatewayService 未初始化")
+}
+
+func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ accountRepo := &stubAccountRepoForHandler{accounts: nil}
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "选择 Sora 账号失败")
+}
+
+func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "不支持模型同步")
+}
+
+func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"base_url": "https://sora.test"}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "api_key")
+}
+
+func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
+ // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test"}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+}
+
+func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "状态码 500")
+}
+
+func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("not json"))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "解析响应失败")
+}
+
+func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"data":[]}`))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "空模型列表")
+}
+
+func TestFetchUpstreamModels_Success(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // 验证请求头
+ require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
+ require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models"))
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ families, err := h.fetchUpstreamModels(context.Background())
+ require.NoError(t, err)
+ require.NotEmpty(t, families)
+}
+
+func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "未能从上游模型列表中识别")
+}
+
+// ==================== getModelFamilies 缓存测试 ====================
+
+func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
+ // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
+ h := &SoraClientHandler{}
+ families := h.getModelFamilies(context.Background())
+ require.NotEmpty(t, families)
+
+ // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
+ families2 := h.getModelFamilies(context.Background())
+ require.Equal(t, families, families2)
+ require.False(t, h.modelCacheUpstream)
+}
+
+func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
+ // TODO: Re-enable after Sora upstream model sync is stable
+ // t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+
+ families := h.getModelFamilies(context.Background())
+ require.NotEmpty(t, families)
+ require.True(t, h.modelCacheUpstream)
+
+ // 第二次调用命中缓存
+ families2 := h.getModelFamilies(context.Background())
+ require.Equal(t, families, families2)
+}
+
+func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) {
+ // 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
+ h := &SoraClientHandler{
+ cachedFamilies: []service.SoraModelFamily{{ID: "old"}},
+ modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期
+ modelCacheUpstream: false,
+ }
+ // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
+ families := h.getModelFamilies(context.Background())
+ require.NotEmpty(t, families)
+ // 缓存已刷新,不再是 "old"
+ found := false
+ for _, f := range families {
+ if f.ID == "old" {
+ found = true
+ }
+ }
+ require.False(t, found, "过期缓存应被刷新")
+}
+
+// ==================== processGeneration: groupID 与 ForcePlatform ====================
+
+func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) {
+ // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ // 空账号列表 → SelectAccountForModel 失败
+ accountRepo := &stubAccountRepoForHandler{accounts: nil}
+ gatewayService := newMinimalGatewayService(accountRepo)
+
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
+}
+
+// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
+
+func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) {
+ // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
+ userRepo := newStubUserRepoForHandler()
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil)
+
+ body := `{"model":"sora2-landscape-10s","prompt":"test"}`
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+}
+
+// ==================== Generate: CreatePending 并发限制错误 ====================
+
+// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
+type stubSoraGenRepoWithAtomicCreate struct {
+ stubSoraGenRepo
+ limitErr error
+}
+
+func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error {
+ if r.limitErr != nil {
+ return r.limitErr
+ }
+ return r.stubSoraGenRepo.Create(context.Background(), gen)
+}
+
+func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) {
+ repo := &stubSoraGenRepoWithAtomicCreate{
+ stubSoraGenRepo: *newStubSoraGenRepo(),
+ limitErr: service.ErrSoraGenerationConcurrencyLimit,
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil)
+
+ body := `{"model":"sora2-landscape-10s","prompt":"test"}`
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusTooManyRequests, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, resp["message"], "3")
+}
+
+// ==================== SaveToStorage: 配额超限 ====================
+
+func TestSaveToStorage_QuotaExceeded(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ // 用户配额已满
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 10,
+ SoraStorageUsedBytes: 10,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusTooManyRequests, rec.Code)
+}
+
+// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
+
+func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
+ userRepo := newStubUserRepoForHandler()
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== SaveToStorage: MediaURLs 全为空 ====================
+
+func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: "",
+ MediaURLs: []string{},
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, resp["message"], "已过期")
+}
+
+// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
+
+func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail-second")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v1.mp4",
+ MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
+
+func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ repo.updateErr = fmt.Errorf("db error")
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 100 * 1024 * 1024,
+ SoraStorageUsedBytes: 0,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
+
+func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
+}
+
+func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
+}
+
+func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{mediaStorage: mediaStorage}
+
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"})
+}
+
+// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
+
+func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-del-test-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: service.SoraStorageTypeLocal,
+ MediaURL: "nonexistent/video.mp4",
+ MediaURLs: []string{"nonexistent/video.mp4"},
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+// ==================== CancelGeneration: 任务已结束冲突 ====================
+
+func TestCancelGeneration_AlreadyCompleted(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusConflict, rec.Code)
+}
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
new file mode 100644
index 00000000..c9c7de17
--- /dev/null
+++ b/backend/internal/handler/sora_gateway_handler.go
@@ -0,0 +1,694 @@
+package handler
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
+
+ "github.com/gin-gonic/gin"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.uber.org/zap"
+)
+
+// SoraGatewayHandler handles Sora chat completions requests
+type SoraGatewayHandler struct {
+ gatewayService *service.GatewayService
+ soraGatewayService *service.SoraGatewayService
+ billingCacheService *service.BillingCacheService
+ usageRecordWorkerPool *service.UsageRecordWorkerPool
+ concurrencyHelper *ConcurrencyHelper
+ maxAccountSwitches int
+ streamMode string
+ soraTLSEnabled bool
+ soraMediaSigningKey string
+ soraMediaRoot string
+}
+
+// NewSoraGatewayHandler creates a new SoraGatewayHandler
+func NewSoraGatewayHandler(
+ gatewayService *service.GatewayService,
+ soraGatewayService *service.SoraGatewayService,
+ concurrencyService *service.ConcurrencyService,
+ billingCacheService *service.BillingCacheService,
+ usageRecordWorkerPool *service.UsageRecordWorkerPool,
+ cfg *config.Config,
+) *SoraGatewayHandler {
+ pingInterval := time.Duration(0)
+ maxAccountSwitches := 3
+ streamMode := "force"
+ soraTLSEnabled := true
+ signKey := ""
+ mediaRoot := "/app/data/sora"
+ if cfg != nil {
+ pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
+ if cfg.Gateway.MaxAccountSwitches > 0 {
+ maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
+ }
+ if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
+ streamMode = mode
+ }
+ soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
+ signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
+ if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
+ mediaRoot = root
+ }
+ }
+ return &SoraGatewayHandler{
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ billingCacheService: billingCacheService,
+ usageRecordWorkerPool: usageRecordWorkerPool,
+ concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
+ maxAccountSwitches: maxAccountSwitches,
+ streamMode: strings.ToLower(streamMode),
+ soraTLSEnabled: soraTLSEnabled,
+ soraMediaSigningKey: signKey,
+ soraMediaRoot: mediaRoot,
+ }
+}
+
+// ChatCompletions handles Sora /v1/chat/completions endpoint
+func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
+ apiKey, ok := middleware2.GetAPIKeyFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
+ return
+ }
+
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
+ return
+ }
+ reqLog := requestLogger(
+ c,
+ "handler.sora_gateway.chat_completions",
+ zap.Int64("user_id", subject.UserID),
+ zap.Int64("api_key_id", apiKey.ID),
+ zap.Any("group_id", apiKey.GroupID),
+ )
+
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
+ if err != nil {
+ if maxErr, ok := extractMaxBytesError(err); ok {
+ h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
+ return
+ }
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
+ return
+ }
+ if len(body) == 0 {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
+ return
+ }
+
+ setOpsRequestContext(c, "", false, body)
+
+ // 校验请求体 JSON 合法性
+ if !gjson.ValidBytes(body) {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
+ return
+ }
+
+ // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
+ modelResult := gjson.GetBytes(body, "model")
+ if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
+ return
+ }
+ reqModel := modelResult.String()
+
+ msgsResult := gjson.GetBytes(body, "messages")
+ if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
+ return
+ }
+
+ clientStream := gjson.GetBytes(body, "stream").Bool()
+ reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
+ if !clientStream {
+ if h.streamMode == "error" {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
+ return
+ }
+ var err error
+ body, err = sjson.SetBytes(body, "stream", true)
+ if err != nil {
+ h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
+ return
+ }
+ }
+
+ setOpsRequestContext(c, reqModel, clientStream, body)
+
+ platform := ""
+ if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
+ platform = forced
+ } else if apiKey.Group != nil {
+ platform = apiKey.Group.Platform
+ }
+ if platform != service.PlatformSora {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
+ return
+ }
+
+ streamStarted := false
+ subscription, _ := middleware2.GetSubscriptionFromContext(c)
+
+ maxWait := service.CalculateMaxWait(subject.Concurrency)
+ canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
+ waitCounted := false
+ if err != nil {
+ reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
+ } else if !canWait {
+ reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
+ h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
+ return
+ }
+ if err == nil && canWait {
+ waitCounted = true
+ }
+ defer func() {
+ if waitCounted {
+ h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
+ }
+ }()
+
+ userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
+ if err != nil {
+ reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
+ h.handleConcurrencyError(c, err, "user", streamStarted)
+ return
+ }
+ if waitCounted {
+ h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
+ waitCounted = false
+ }
+ userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
+ if userReleaseFunc != nil {
+ defer userReleaseFunc()
+ }
+
+ if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
+ reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
+ status, code, message := billingErrorDetails(err)
+ h.handleStreamingAwareError(c, status, code, message, streamStarted)
+ return
+ }
+
+ sessionHash := generateOpenAISessionHash(c, body)
+
+ maxAccountSwitches := h.maxAccountSwitches
+ switchCount := 0
+ failedAccountIDs := make(map[int64]struct{})
+ lastFailoverStatus := 0
+ var lastFailoverBody []byte
+ var lastFailoverHeaders http.Header
+
+ for {
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
+ if err != nil {
+ reqLog.Warn("sora.account_select_failed",
+ zap.Error(err),
+ zap.Int("excluded_account_count", len(failedAccountIDs)),
+ )
+ if len(failedAccountIDs) == 0 {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
+ return
+ }
+ rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
+ fields := []zap.Field{
+ zap.Int("last_upstream_status", lastFailoverStatus),
+ }
+ if rayID != "" {
+ fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
+ }
+ if mitigated != "" {
+ fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
+ }
+ if contentType != "" {
+ fields = append(fields, zap.String("last_upstream_content_type", contentType))
+ }
+ reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
+ return
+ }
+ account := selection.Account
+ setOpsSelectedAccount(c, account.ID, account.Platform)
+ proxyBound := account.ProxyID != nil
+ proxyID := int64(0)
+ if account.ProxyID != nil {
+ proxyID = *account.ProxyID
+ }
+ tlsFingerprintEnabled := h.soraTLSEnabled
+
+ accountReleaseFunc := selection.ReleaseFunc
+ if !selection.Acquired {
+ if selection.WaitPlan == nil {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
+ return
+ }
+ accountWaitCounted := false
+ canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
+ if err != nil {
+ reqLog.Warn("sora.account_wait_counter_increment_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Error(err),
+ )
+ } else if !canWait {
+ reqLog.Info("sora.account_wait_queue_full",
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
+ )
+ h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
+ return
+ }
+ if err == nil && canWait {
+ accountWaitCounted = true
+ }
+ defer func() {
+ if accountWaitCounted {
+ h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
+ }
+ }()
+
+ accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
+ c,
+ account.ID,
+ selection.WaitPlan.MaxConcurrency,
+ selection.WaitPlan.Timeout,
+ clientStream,
+ &streamStarted,
+ )
+ if err != nil {
+ reqLog.Warn("sora.account_slot_acquire_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Error(err),
+ )
+ h.handleConcurrencyError(c, err, "account", streamStarted)
+ return
+ }
+ if accountWaitCounted {
+ h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
+ accountWaitCounted = false
+ }
+ }
+ accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
+
+ result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
+ if accountReleaseFunc != nil {
+ accountReleaseFunc()
+ }
+ if err != nil {
+ var failoverErr *service.UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ failedAccountIDs[account.ID] = struct{}{}
+ if switchCount >= maxAccountSwitches {
+ lastFailoverStatus = failoverErr.StatusCode
+ lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
+ lastFailoverBody = failoverErr.ResponseBody
+ rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
+ fields := []zap.Field{
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ }
+ if rayID != "" {
+ fields = append(fields, zap.String("upstream_cf_ray", rayID))
+ }
+ if mitigated != "" {
+ fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
+ }
+ if contentType != "" {
+ fields = append(fields, zap.String("upstream_content_type", contentType))
+ }
+ reqLog.Warn("sora.upstream_failover_exhausted", fields...)
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
+ return
+ }
+ lastFailoverStatus = failoverErr.StatusCode
+ lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
+ lastFailoverBody = failoverErr.ResponseBody
+ switchCount++
+ upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
+ rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
+ fields := []zap.Field{
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.String("upstream_error_code", upstreamErrCode),
+ zap.String("upstream_error_message", upstreamErrMsg),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ }
+ if rayID != "" {
+ fields = append(fields, zap.String("upstream_cf_ray", rayID))
+ }
+ if mitigated != "" {
+ fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
+ }
+ if contentType != "" {
+ fields = append(fields, zap.String("upstream_content_type", contentType))
+ }
+ reqLog.Warn("sora.upstream_failover_switching", fields...)
+ continue
+ }
+ reqLog.Error("sora.forward_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Error(err),
+ )
+ return
+ }
+
+ userAgent := c.GetHeader("User-Agent")
+ clientIP := ip.GetClientIP(c)
+ requestPayloadHash := service.HashUsageRequestPayload(body)
+ inboundEndpoint := GetInboundEndpoint(c)
+ upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
+
+ // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ InboundEndpoint: inboundEndpoint,
+ UpstreamEndpoint: upstreamEndpoint,
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ RequestPayloadHash: requestPayloadHash,
+ }); err != nil {
+ logger.L().With(
+ zap.String("component", "handler.sora_gateway.chat_completions"),
+ zap.Int64("user_id", subject.UserID),
+ zap.Int64("api_key_id", apiKey.ID),
+ zap.Any("group_id", apiKey.GroupID),
+ zap.String("model", reqModel),
+ zap.Int64("account_id", account.ID),
+ ).Error("sora.record_usage_failed", zap.Error(err))
+ }
+ })
+ reqLog.Debug("sora.request_completed",
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Int("switch_count", switchCount),
+ )
+ return
+ }
+}
+
+func generateOpenAISessionHash(c *gin.Context, body []byte) string {
+ if c == nil {
+ return ""
+ }
+ sessionID := strings.TrimSpace(c.GetHeader("session_id"))
+ if sessionID == "" {
+ sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
+ }
+ if sessionID == "" && len(body) > 0 {
+ sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
+ }
+ if sessionID == "" {
+ return ""
+ }
+ hash := sha256.Sum256([]byte(sessionID))
+ return hex.EncodeToString(hash[:])
+}
+
+func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
+ if task == nil {
+ return
+ }
+ if h.usageRecordWorkerPool != nil {
+ h.usageRecordWorkerPool.Submit(task)
+ return
+ }
+ // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ logger.L().With(
+ zap.String("component", "handler.sora_gateway.chat_completions"),
+ zap.Any("panic", recovered),
+ ).Error("sora.usage_record_task_panic_recovered")
+ }
+ }()
+ task(ctx)
+}
+
+func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
+ h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
+ fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
+}
+
+func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
+ upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
+ service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
+
+ status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
+ h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
+}
+
+func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
+ if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
+ baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
+ return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
+ }
+
+ upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
+ if strings.EqualFold(upstreamCode, "cf_shield_429") {
+ baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
+ return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
+ }
+ if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
+ switch statusCode {
+ case 401, 403, 404, 500, 502, 503, 504:
+ return http.StatusBadGateway, "upstream_error", upstreamMessage
+ case 429:
+ return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
+ }
+ }
+
+ switch statusCode {
+ case 401:
+ return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
+ case 403:
+ return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
+ case 404:
+ if strings.EqualFold(upstreamCode, "unsupported_country_code") {
+ return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
+ }
+ return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
+ case 429:
+ return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
+ case 529:
+ return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
+ case 500, 502, 503, 504:
+ return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
+ default:
+ return http.StatusBadGateway, "upstream_error", "Upstream request failed"
+ }
+}
+
+func cloneHTTPHeaders(headers http.Header) http.Header {
+ if headers == nil {
+ return nil
+ }
+ return headers.Clone()
+}
+
+func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
+ if headers != nil {
+ mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
+ contentType = strings.TrimSpace(headers.Get("content-type"))
+ if contentType == "" {
+ contentType = strings.TrimSpace(headers.Get("Content-Type"))
+ }
+ }
+ rayID = soraerror.ExtractCloudflareRayID(headers, body)
+ return rayID, mitigated, contentType
+}
+
+func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
+ return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
+}
+
+func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
+ message = strings.TrimSpace(message)
+ if message == "" {
+ return false
+ }
+ if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
+ lower := strings.ToLower(message)
+ if strings.Contains(lower, "
Just a moment...`)
+
+ h := &SoraGatewayHandler{}
+ h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
+
+ lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
+ require.Len(t, lines, 2)
+ jsonStr := strings.TrimPrefix(lines[1], "data: ")
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
+
+ errorObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "upstream_error", errorObj["type"])
+ msg, _ := errorObj["message"].(string)
+ require.Contains(t, msg, "Cloudflare challenge")
+ require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
+}
+
+func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ headers := http.Header{}
+ headers.Set("cf-ray", "9d03b68c086027a1-SEA")
+ body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
+
+ h := &SoraGatewayHandler{}
+ h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
+
+ lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
+ require.Len(t, lines, 2)
+ jsonStr := strings.TrimPrefix(lines[1], "data: ")
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
+
+ errorObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "rate_limit_error", errorObj["type"])
+ msg, _ := errorObj["message"].(string)
+ require.Contains(t, msg, "Cloudflare shield")
+ require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
+}
+
+func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("cf-mitigated", "challenge")
+ headers.Set("content-type", "text/html")
+ body := []byte(``)
+
+ rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
+ require.Equal(t, "9cff2d62d83bb98d", rayID)
+ require.Equal(t, "challenge", mitigated)
+ require.Equal(t, "text/html", contentType)
+}
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 4b54d41a..305a4632 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -88,6 +88,8 @@ func ProvideHandlers(
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
+ soraGatewayHandler *SoraGatewayHandler, // 从本地版本合并
+ soraClientHandler *SoraClientHandler, // 从本地版本合并
settingHandler *SettingHandler,
totpHandler *TotpHandler,
paymentHandler *PaymentHandler,
@@ -106,6 +108,8 @@ func ProvideHandlers(
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
+ SoraGateway: soraGatewayHandler, // 从本地版本合并
+ SoraClient: soraClientHandler, // 从本地版本合并
Setting: settingHandler,
Totp: totpHandler,
Payment: paymentHandler,
@@ -125,6 +129,8 @@ var ProviderSet = wire.NewSet(
NewAnnouncementHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
+ NewSoraGatewayHandler, // 从本地版本合并
+ NewSoraClientHandler, // 从本地版本合并
NewTotpHandler,
ProvideSettingHandler,
NewPaymentHandler,
diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go
index 618b6adb..8c765bef 100644
--- a/backend/internal/pkg/openai/oauth.go
+++ b/backend/internal/pkg/openai/oauth.go
@@ -18,6 +18,9 @@ const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
+ // OAuth Client ID for Sora (从本地版本合并)
+ SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
+
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
TokenURL = "https://auth.openai.com/oauth/token"
diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go
new file mode 100644
index 00000000..ad2ae638
--- /dev/null
+++ b/backend/internal/repository/sora_account_repo.go
@@ -0,0 +1,98 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// soraAccountRepository 实现 service.SoraAccountRepository 接口。
+// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
+//
+// 设计说明:
+// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
+// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
+// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
+type soraAccountRepository struct {
+ sql *sql.DB
+}
+
+// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
+func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
+ return &soraAccountRepository{sql: sqlDB}
+}
+
+// Upsert 创建或更新 Sora 账号扩展信息
+// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
+func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
+ accessToken, accessOK := updates["access_token"].(string)
+ refreshToken, refreshOK := updates["refresh_token"].(string)
+ sessionToken, sessionOK := updates["session_token"].(string)
+
+ if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
+ if !sessionOK {
+ return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
+ }
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_accounts
+ SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
+ updated_at = NOW()
+ WHERE account_id = $1
+ `, accountID, sessionToken)
+ if err != nil {
+ return err
+ }
+ rows, err := result.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if rows == 0 {
+ return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
+ }
+ return nil
+ }
+
+ _, err := r.sql.ExecContext(ctx, `
+ INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
+ VALUES ($1, $2, $3, $4, NOW(), NOW())
+ ON CONFLICT (account_id) DO UPDATE SET
+ access_token = EXCLUDED.access_token,
+ refresh_token = EXCLUDED.refresh_token,
+ session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
+ updated_at = NOW()
+ `, accountID, accessToken, refreshToken, sessionToken)
+ return err
+}
+
+// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
+func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
+ rows, err := r.sql.QueryContext(ctx, `
+ SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
+ FROM sora_accounts
+ WHERE account_id = $1
+ `, accountID)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ return nil, nil // 记录不存在
+ }
+
+ var sa service.SoraAccount
+ if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
+ return nil, err
+ }
+ return &sa, nil
+}
+
+// Delete 删除 Sora 账号扩展信息
+func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
+ _, err := r.sql.ExecContext(ctx, `
+ DELETE FROM sora_accounts WHERE account_id = $1
+ `, accountID)
+ return err
+}
diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go
new file mode 100644
index 00000000..aaf3cb2f
--- /dev/null
+++ b/backend/internal/repository/sora_generation_repo.go
@@ -0,0 +1,419 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
+// 使用原生 SQL 操作 sora_generations 表。
+type soraGenerationRepository struct {
+ sql *sql.DB
+}
+
+// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
+func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
+ return &soraGenerationRepository{sql: sqlDB}
+}
+
+func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
+ mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
+ s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
+
+ err := r.sql.QueryRowContext(ctx, `
+ INSERT INTO sora_generations (
+ user_id, api_key_id, model, prompt, media_type,
+ status, media_url, media_urls, file_size_bytes,
+ storage_type, s3_object_keys, upstream_task_id, error_message
+ ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
+ RETURNING id, created_at
+ `,
+ gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
+ gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
+ gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
+ ).Scan(&gen.ID, &gen.CreatedAt)
+ return err
+}
+
+// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
+func (r *soraGenerationRepository) CreatePendingWithLimit(
+ ctx context.Context,
+ gen *service.SoraGeneration,
+ activeStatuses []string,
+ maxActive int64,
+) error {
+ if gen == nil {
+ return fmt.Errorf("generation is nil")
+ }
+ if maxActive <= 0 {
+ return r.Create(ctx, gen)
+ }
+ if len(activeStatuses) == 0 {
+ activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
+ }
+
+ tx, err := r.sql.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
+ if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
+ return err
+ }
+
+ placeholders := make([]string, len(activeStatuses))
+ args := make([]any, 0, 1+len(activeStatuses))
+ args = append(args, gen.UserID)
+ for i, s := range activeStatuses {
+ placeholders[i] = fmt.Sprintf("$%d", i+2)
+ args = append(args, s)
+ }
+ countQuery := fmt.Sprintf(
+ `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
+ strings.Join(placeholders, ","),
+ )
+ var activeCount int64
+ if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
+ return err
+ }
+ if activeCount >= maxActive {
+ return service.ErrSoraGenerationConcurrencyLimit
+ }
+
+ mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
+ s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
+ if err := tx.QueryRowContext(ctx, `
+ INSERT INTO sora_generations (
+ user_id, api_key_id, model, prompt, media_type,
+ status, media_url, media_urls, file_size_bytes,
+ storage_type, s3_object_keys, upstream_task_id, error_message
+ ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
+ RETURNING id, created_at
+ `,
+ gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
+ gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
+ gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
+ ).Scan(&gen.ID, &gen.CreatedAt); err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
+ gen := &service.SoraGeneration{}
+ var mediaURLsJSON, s3KeysJSON []byte
+ var completedAt sql.NullTime
+ var apiKeyID sql.NullInt64
+
+ err := r.sql.QueryRowContext(ctx, `
+ SELECT id, user_id, api_key_id, model, prompt, media_type,
+ status, media_url, media_urls, file_size_bytes,
+ storage_type, s3_object_keys, upstream_task_id, error_message,
+ created_at, completed_at
+ FROM sora_generations WHERE id = $1
+ `, id).Scan(
+ &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
+ &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
+ &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
+ &gen.CreatedAt, &completedAt,
+ )
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, fmt.Errorf("生成记录不存在")
+ }
+ return nil, err
+ }
+
+ if apiKeyID.Valid {
+ gen.APIKeyID = &apiKeyID.Int64
+ }
+ if completedAt.Valid {
+ gen.CompletedAt = &completedAt.Time
+ }
+ _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
+ _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
+ return gen, nil
+}
+
+func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
+ mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
+ s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
+
+ var completedAt *time.Time
+ if gen.CompletedAt != nil {
+ completedAt = gen.CompletedAt
+ }
+
+ _, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations SET
+ status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
+ storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
+ error_message = $9, completed_at = $10
+ WHERE id = $1
+ `,
+ gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
+ gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
+ gen.ErrorMessage, completedAt,
+ )
+ return err
+}
+
+// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
+func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET status = $2, upstream_task_id = $3
+ WHERE id = $1 AND status = $4
+ `,
+ id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
+func (r *soraGenerationRepository) UpdateCompletedIfActive(
+ ctx context.Context,
+ id int64,
+ mediaURL string,
+ mediaURLs []string,
+ storageType string,
+ s3Keys []string,
+ fileSizeBytes int64,
+ completedAt time.Time,
+) (bool, error) {
+ mediaURLsJSON, _ := json.Marshal(mediaURLs)
+ s3KeysJSON, _ := json.Marshal(s3Keys)
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET status = $2,
+ media_url = $3,
+ media_urls = $4,
+ file_size_bytes = $5,
+ storage_type = $6,
+ s3_object_keys = $7,
+ error_message = '',
+ completed_at = $8
+ WHERE id = $1 AND status IN ($9, $10)
+ `,
+ id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
+ storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
+func (r *soraGenerationRepository) UpdateFailedIfActive(
+ ctx context.Context,
+ id int64,
+ errMsg string,
+ completedAt time.Time,
+) (bool, error) {
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET status = $2,
+ error_message = $3,
+ completed_at = $4
+ WHERE id = $1 AND status IN ($5, $6)
+ `,
+ id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
+func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET status = $2, completed_at = $3
+ WHERE id = $1 AND status IN ($4, $5)
+ `,
+ id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
+func (r *soraGenerationRepository) UpdateStorageIfCompleted(
+ ctx context.Context,
+ id int64,
+ mediaURL string,
+ mediaURLs []string,
+ storageType string,
+ s3Keys []string,
+ fileSizeBytes int64,
+) (bool, error) {
+ mediaURLsJSON, _ := json.Marshal(mediaURLs)
+ s3KeysJSON, _ := json.Marshal(s3Keys)
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET media_url = $2,
+ media_urls = $3,
+ file_size_bytes = $4,
+ storage_type = $5,
+ s3_object_keys = $6
+ WHERE id = $1 AND status = $7
+ `,
+ id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
+ _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
+ return err
+}
+
+func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
+ // 构建 WHERE 条件
+ conditions := []string{"user_id = $1"}
+ args := []any{params.UserID}
+ argIdx := 2
+
+ if params.Status != "" {
+ // 支持逗号分隔的多状态
+ statuses := strings.Split(params.Status, ",")
+ placeholders := make([]string, len(statuses))
+ for i, s := range statuses {
+ placeholders[i] = fmt.Sprintf("$%d", argIdx)
+ args = append(args, strings.TrimSpace(s))
+ argIdx++
+ }
+ conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
+ }
+ if params.StorageType != "" {
+ storageTypes := strings.Split(params.StorageType, ",")
+ placeholders := make([]string, len(storageTypes))
+ for i, s := range storageTypes {
+ placeholders[i] = fmt.Sprintf("$%d", argIdx)
+ args = append(args, strings.TrimSpace(s))
+ argIdx++
+ }
+ conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
+ }
+ if params.MediaType != "" {
+ conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
+ args = append(args, params.MediaType)
+ argIdx++
+ }
+
+ whereClause := "WHERE " + strings.Join(conditions, " AND ")
+
+ // 计数
+ var total int64
+ countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
+ if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
+ return nil, 0, err
+ }
+
+ // 分页查询
+ offset := (params.Page - 1) * params.PageSize
+ listQuery := fmt.Sprintf(`
+ SELECT id, user_id, api_key_id, model, prompt, media_type,
+ status, media_url, media_urls, file_size_bytes,
+ storage_type, s3_object_keys, upstream_task_id, error_message,
+ created_at, completed_at
+ FROM sora_generations %s
+ ORDER BY created_at DESC
+ LIMIT $%d OFFSET $%d
+ `, whereClause, argIdx, argIdx+1)
+ args = append(args, params.PageSize, offset)
+
+ rows, err := r.sql.QueryContext(ctx, listQuery, args...)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer func() {
+ _ = rows.Close()
+ }()
+
+ var results []*service.SoraGeneration
+ for rows.Next() {
+ gen := &service.SoraGeneration{}
+ var mediaURLsJSON, s3KeysJSON []byte
+ var completedAt sql.NullTime
+ var apiKeyID sql.NullInt64
+
+ if err := rows.Scan(
+ &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
+ &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
+ &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
+ &gen.CreatedAt, &completedAt,
+ ); err != nil {
+ return nil, 0, err
+ }
+
+ if apiKeyID.Valid {
+ gen.APIKeyID = &apiKeyID.Int64
+ }
+ if completedAt.Valid {
+ gen.CompletedAt = &completedAt.Time
+ }
+ _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
+ _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
+ results = append(results, gen)
+ }
+
+ return results, total, rows.Err()
+}
+
+func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
+ if len(statuses) == 0 {
+ return 0, nil
+ }
+
+ placeholders := make([]string, len(statuses))
+ args := []any{userID}
+ for i, s := range statuses {
+ placeholders[i] = fmt.Sprintf("$%d", i+2)
+ args = append(args, s)
+ }
+
+ var count int64
+ query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
+ err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
+ return count, err
+}
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index d3adb4a0..e71120f3 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -89,6 +89,8 @@ var ProviderSet = wire.NewSet(
NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository,
NewChannelRepository,
+ NewSoraAccountRepository, // Sora 账号扩展表仓储 (从本地版本合并)
+ NewSoraGenerationRepository, // Sora 生成记录仓储 (从本地版本合并)
// Cache implementations
NewGatewayCache,
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
index 73210bfc..d9ec951e 100644
--- a/backend/internal/server/middleware/security_headers.go
+++ b/backend/internal/server/middleware/security_headers.go
@@ -94,6 +94,7 @@ func isAPIRoutePath(c *gin.Context) bool {
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
+ strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/responses")
}
diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go
new file mode 100644
index 00000000..13fceb81
--- /dev/null
+++ b/backend/internal/server/routes/sora_client.go
@@ -0,0 +1,36 @@
+package routes
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/handler"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
+func RegisterSoraClientRoutes(
+ v1 *gin.RouterGroup,
+ h *handler.Handlers,
+ jwtAuth middleware.JWTAuthMiddleware,
+ settingService *service.SettingService,
+) {
+ if h.SoraClient == nil {
+ return
+ }
+
+ authenticated := v1.Group("/sora")
+ authenticated.Use(gin.HandlerFunc(jwtAuth))
+ authenticated.Use(middleware.BackendModeUserGuard(settingService))
+ {
+ authenticated.POST("/generate", h.SoraClient.Generate)
+ authenticated.GET("/generations", h.SoraClient.ListGenerations)
+ authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
+ authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
+ authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
+ authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
+ authenticated.GET("/quota", h.SoraClient.GetQuota)
+ authenticated.GET("/models", h.SoraClient.GetModels)
+ authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
+ }
+}
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index 763abadb..e65116dc 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -270,6 +270,169 @@ func (s *BillingService) initFallbackPricing() {
SupportsCacheBreakdown: false,
}
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
+
+ // 国内模型定价 - 智谱GLM系列
+ s.fallbackPrices["glm-4-plus"] = &ModelPricing{
+ InputPricePerToken: 6.94e-7, // ¥5/MTok ≈ $0.000694/MTok
+ OutputPricePerToken: 6.94e-7,
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["glm-4-flash"] = &ModelPricing{
+ InputPricePerToken: 1.39e-8, // ¥0.1/MTok
+ OutputPricePerToken: 1.39e-8,
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["glm-4-flashx"] = &ModelPricing{
+ InputPricePerToken: 1.39e-7, // ¥1/MTok
+ OutputPricePerToken: 1.39e-7,
+ SupportsCacheBreakdown: false,
+ }
+
+ // 国内模型定价 - 通义千问系列
+ s.fallbackPrices["qwen-turbo"] = &ModelPricing{
+ InputPricePerToken: 4.17e-7, // ¥0.3/MTok ≈ $0.000042/MTok
+ OutputPricePerToken: 8.33e-7, // ¥0.6/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["qwen-plus"] = &ModelPricing{
+ InputPricePerToken: 1.11e-6, // ¥0.8/MTok
+ OutputPricePerToken: 2.78e-6, // ¥2/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["qwen-max"] = &ModelPricing{
+ InputPricePerToken: 5.56e-6, // ¥4/MTok
+ OutputPricePerToken: 1.67e-5, // ¥12/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["qwen-long"] = &ModelPricing{
+ InputPricePerToken: 6.94e-7, // ¥0.5/MTok
+ OutputPricePerToken: 2.78e-6, // ¥2/MTok
+ SupportsCacheBreakdown: false,
+ }
+
+ // 国内模型定价 - 月之暗面Kimi系列
+ s.fallbackPrices["moonshot-v1-8k"] = &ModelPricing{
+ InputPricePerToken: 2.78e-7, // ¥2/MTok ≈ $0.000278/MTok
+ OutputPricePerToken: 2.78e-7,
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["moonshot-v1-32k"] = &ModelPricing{
+ InputPricePerToken: 5.56e-7, // ¥4/MTok
+ OutputPricePerToken: 5.56e-7,
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["moonshot-v1-128k"] = &ModelPricing{
+ InputPricePerToken: 1.39e-6, // ¥10/MTok
+ OutputPricePerToken: 1.39e-6,
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["kimi-k2"] = &ModelPricing{
+ InputPricePerToken: 6.0e-7, // $0.6/MTok
+ OutputPricePerToken: 2.0e-6, // $2/MTok
+ SupportsCacheBreakdown: false,
+ }
+
+ // 国内模型定价 - MiniMax系列
+ s.fallbackPrices["abab6.5s-chat"] = &ModelPricing{
+ InputPricePerToken: 1.39e-6, // ¥10/MTok ≈ $0.00139/MTok
+ OutputPricePerToken: 1.39e-6,
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["minimax-m2.7"] = &ModelPricing{
+ InputPricePerToken: 3.0e-7, // $0.3/MTok
+ OutputPricePerToken: 1.2e-6, // $1.2/MTok
+ SupportsCacheBreakdown: false,
+ }
+
+ // 国内模型定价 - 豆包系列
+ s.fallbackPrices["doubao-lite-4k"] = &ModelPricing{
+ InputPricePerToken: 4.17e-8, // ¥0.3/MTok
+ OutputPricePerToken: 2.78e-7, // ¥2/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["doubao-pro-128k"] = &ModelPricing{
+ InputPricePerToken: 6.94e-7, // ¥5/MTok
+ OutputPricePerToken: 1.25e-6, // ¥9/MTok
+ SupportsCacheBreakdown: false,
+ }
+
+ // 国内模型定价 - 智谱GLM最新系列
+ s.fallbackPrices["glm-5"] = &ModelPricing{
+ InputPricePerToken: 5.6e-7, // ¥4/MTok
+ OutputPricePerToken: 3.08e-6, // ¥22/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["glm-5-turbo"] = &ModelPricing{
+ InputPricePerToken: 1.2e-6, // $1.2/MTok
+ OutputPricePerToken: 4e-6, // $4/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["glm-5.1"] = &ModelPricing{
+ InputPricePerToken: 8.4e-7, // ¥6/MTok
+ OutputPricePerToken: 3.92e-6, // ¥28/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["glm-4.7"] = &ModelPricing{
+ InputPricePerToken: 2.8e-7, // ¥2/MTok
+ OutputPricePerToken: 2.8e-7, // ¥2/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["glm-4.5-air"] = &ModelPricing{
+ InputPricePerToken: 1.4e-7, // ¥1/MTok
+ OutputPricePerToken: 8.4e-7, // ¥6/MTok
+ SupportsCacheBreakdown: false,
+ }
+
+ // 国内模型定价 - 百川智能系列
+ s.fallbackPrices["baichuan4"] = &ModelPricing{
+ InputPricePerToken: 1.4e-5, // ¥100/MTok
+ OutputPricePerToken: 1.4e-5, // ¥100/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["baichuan4-turbo"] = &ModelPricing{
+ InputPricePerToken: 2.1e-6, // ¥15/MTok
+ OutputPricePerToken: 2.1e-6, // ¥15/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["baichuan4-air"] = &ModelPricing{
+ InputPricePerToken: 1.37e-7, // ¥0.98/MTok
+ OutputPricePerToken: 1.37e-7, // ¥0.98/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["baichuan-m3-plus"] = &ModelPricing{
+ InputPricePerToken: 7e-7, // ¥5/MTok
+ OutputPricePerToken: 1.26e-6, // ¥9/MTok
+ SupportsCacheBreakdown: false,
+ }
+
+ // 国内模型定价 - DeepSeek系列
+ s.fallbackPrices["deepseek-v3"] = &ModelPricing{
+ InputPricePerToken: 2.8e-7, // ¥2/MTok
+ OutputPricePerToken: 1.12e-6, // ¥8/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["deepseek-v3.2"] = &ModelPricing{
+ InputPricePerToken: 2.8e-7, // ¥2/MTok
+ OutputPricePerToken: 4.2e-7, // ¥3/MTok
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["deepseek-r1"] = &ModelPricing{
+ InputPricePerToken: 2.8e-7, // ¥2/MTok
+ OutputPricePerToken: 1.12e-6, // ¥8/MTok
+ SupportsCacheBreakdown: false,
+ }
+
+ // 国内模型定价 - 通义千问最新系列
+ s.fallbackPrices["qwen3-8b"] = &ModelPricing{
+ InputPricePerToken: 0, // 免费
+ OutputPricePerToken: 0, // 免费
+ SupportsCacheBreakdown: false,
+ }
+ s.fallbackPrices["qwen2.5-72b-instruct"] = &ModelPricing{
+ InputPricePerToken: 5.6e-7, // ¥4/MTok
+ OutputPricePerToken: 1.68e-6, // ¥12/MTok
+ SupportsCacheBreakdown: false,
+ }
}
// getFallbackPricing 根据模型系列获取回退价格
@@ -329,6 +492,137 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
}
}
+ // 国内模型匹配 - 智谱GLM系列
+ if strings.Contains(modelLower, "glm-5.1") || strings.Contains(modelLower, "glm-5-1") {
+ return s.fallbackPrices["glm-5.1"]
+ }
+ if strings.Contains(modelLower, "glm-5-turbo") {
+ return s.fallbackPrices["glm-5-turbo"]
+ }
+ if strings.Contains(modelLower, "glm-5") {
+ return s.fallbackPrices["glm-5"]
+ }
+ if strings.Contains(modelLower, "glm-4.7") || strings.Contains(modelLower, "glm-4-7") {
+ return s.fallbackPrices["glm-4.7"]
+ }
+ if strings.Contains(modelLower, "glm-4.5-air") || strings.Contains(modelLower, "glm-4-5-air") {
+ return s.fallbackPrices["glm-4.5-air"]
+ }
+ if strings.Contains(modelLower, "glm-4-plus") {
+ return s.fallbackPrices["glm-4-plus"]
+ }
+ if strings.Contains(modelLower, "glm-4-flashx") || strings.Contains(modelLower, "glm-4-flash-x") {
+ return s.fallbackPrices["glm-4-flashx"]
+ }
+ if strings.Contains(modelLower, "glm-4-flash") {
+ return s.fallbackPrices["glm-4-flash"]
+ }
+ if strings.Contains(modelLower, "glm-z1-airx") {
+ return s.fallbackPrices["glm-z1-airx"]
+ }
+ if strings.Contains(modelLower, "glm-z1-air") || strings.Contains(modelLower, "glm-z1") {
+ return s.fallbackPrices["glm-z1-air"]
+ }
+ if strings.Contains(modelLower, "glm-4") || strings.Contains(modelLower, "glm-3") {
+ return s.fallbackPrices["glm-4-plus"]
+ }
+
+ // 国内模型匹配 - 百川智能系列
+ if strings.Contains(modelLower, "baichuan-m3-plus") || strings.Contains(modelLower, "baichuan-m3") {
+ return s.fallbackPrices["baichuan-m3-plus"]
+ }
+ if strings.Contains(modelLower, "baichuan4-turbo") || strings.Contains(modelLower, "baichuan-4-turbo") {
+ return s.fallbackPrices["baichuan4-turbo"]
+ }
+ if strings.Contains(modelLower, "baichuan4-air") || strings.Contains(modelLower, "baichuan-4-air") {
+ return s.fallbackPrices["baichuan4-air"]
+ }
+ if strings.Contains(modelLower, "baichuan4") || strings.Contains(modelLower, "baichuan-4") {
+ return s.fallbackPrices["baichuan4"]
+ }
+ if strings.Contains(modelLower, "baichuan") {
+ return s.fallbackPrices["baichuan4-turbo"]
+ }
+
+ // 国内模型匹配 - DeepSeek系列
+ if strings.Contains(modelLower, "deepseek-v3.2") || strings.Contains(modelLower, "deepseek-v3-2") {
+ return s.fallbackPrices["deepseek-v3.2"]
+ }
+ if strings.Contains(modelLower, "deepseek-r1") || strings.Contains(modelLower, "deepseek-reasoner") {
+ return s.fallbackPrices["deepseek-r1"]
+ }
+ if strings.Contains(modelLower, "deepseek-v3") {
+ return s.fallbackPrices["deepseek-v3"]
+ }
+ if strings.Contains(modelLower, "deepseek") {
+ return s.fallbackPrices["deepseek-v3"]
+ }
+
+ // 国内模型匹配 - 通义千问系列
+ if strings.Contains(modelLower, "qwen3-8b") || strings.Contains(modelLower, "qwen3-8") {
+ return s.fallbackPrices["qwen3-8b"]
+ }
+ if strings.Contains(modelLower, "qwen2.5-72b") || strings.Contains(modelLower, "qwen-2.5-72b") {
+ return s.fallbackPrices["qwen2.5-72b-instruct"]
+ }
+ if strings.Contains(modelLower, "qwen-max") {
+ return s.fallbackPrices["qwen-max"]
+ }
+ if strings.Contains(modelLower, "qwen-long") {
+ return s.fallbackPrices["qwen-long"]
+ }
+ if strings.Contains(modelLower, "qwen-plus") || strings.Contains(modelLower, "qwen2.5-32b") || strings.Contains(modelLower, "qwen-coder-plus") {
+ return s.fallbackPrices["qwen-plus"]
+ }
+ if strings.Contains(modelLower, "qwen-turbo") || strings.Contains(modelLower, "qwen2.5-14b") || strings.Contains(modelLower, "qwen-coder-turbo") {
+ return s.fallbackPrices["qwen-turbo"]
+ }
+ if strings.Contains(modelLower, "qwen") {
+ return s.fallbackPrices["qwen-turbo"]
+ }
+
+ // 国内模型匹配 - 月之暗面Kimi系列
+ if strings.Contains(modelLower, "kimi-k2") || strings.Contains(modelLower, "kimi-k2.5") {
+ return s.fallbackPrices["kimi-k2"]
+ }
+ if strings.Contains(modelLower, "moonshot-v1-128k") || strings.Contains(modelLower, "moonshot-128k") {
+ return s.fallbackPrices["moonshot-v1-128k"]
+ }
+ if strings.Contains(modelLower, "moonshot-v1-32k") || strings.Contains(modelLower, "moonshot-32k") {
+ return s.fallbackPrices["moonshot-v1-32k"]
+ }
+ if strings.Contains(modelLower, "moonshot") || strings.Contains(modelLower, "kimi") {
+ return s.fallbackPrices["moonshot-v1-8k"]
+ }
+
+ // 国内模型匹配 - MiniMax系列
+ if strings.Contains(modelLower, "minimax-m2") || strings.Contains(modelLower, "minimax-m1") {
+ return s.fallbackPrices["minimax-m2.7"]
+ }
+ if strings.Contains(modelLower, "abab6.5s") || strings.Contains(modelLower, "abab6.5s-chat") {
+ return s.fallbackPrices["abab6.5s-chat"]
+ }
+ if strings.Contains(modelLower, "abab6.5t") || strings.Contains(modelLower, "abab6.5") {
+ return s.fallbackPrices["abab6.5s-chat"]
+ }
+ if strings.Contains(modelLower, "abab") || strings.Contains(modelLower, "minimax") {
+ return s.fallbackPrices["abab6.5s-chat"]
+ }
+
+ // 国内模型匹配 - 豆包系列
+ if strings.Contains(modelLower, "doubao-seed") || strings.Contains(modelLower, "doubao-1.6") {
+ return s.fallbackPrices["doubao-pro-128k"]
+ }
+ if strings.Contains(modelLower, "doubao-pro") || strings.Contains(modelLower, "doubao-1.5-thinking") {
+ return s.fallbackPrices["doubao-pro-128k"]
+ }
+ if strings.Contains(modelLower, "doubao-lite") || strings.Contains(modelLower, "doubao-1.5") {
+ return s.fallbackPrices["doubao-lite-4k"]
+ }
+ if strings.Contains(modelLower, "doubao") {
+ return s.fallbackPrices["doubao-lite-4k"]
+ }
+
return nil
}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 68d7da3b..0f45d733 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -24,6 +24,7 @@ const (
PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity
+ PlatformSora = domain.PlatformSora // 从本地版本合并
)
// Account type constants
@@ -249,6 +250,19 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning = "enable_cch_signing"
+
+ // Sora S3 存储配置 (从本地版本合并)
+ SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
+ SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
+ SettingKeySoraS3Region = "sora_s3_region" // S3 区域
+ SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
+ SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
+ SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
+ SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
+ SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
+ SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
+ SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
+ SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // Sora 默认存储配额(字节)
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 8b0bdc2a..9cc6a8bb 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -504,6 +504,9 @@ type ForwardResult struct {
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
+ // Sora 媒体字段 (从本地版本合并)
+ MediaType string // image / video / prompt
+ MediaURL string // 生成后的媒体地址(可选)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go
index 12262613..8513e218 100644
--- a/backend/internal/service/group.go
+++ b/backend/internal/service/group.go
@@ -59,6 +59,9 @@ type Group struct {
DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
+ // Sora 存储配额 (从本地版本合并)
+ SoraStorageQuotaBytes int64
+
CreatedAt time.Time
UpdatedAt time.Time
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 48f25da0..619f201e 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -24,6 +24,8 @@ import (
var (
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
+ ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") // 从本地版本合并
+ ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") // 从本地版本合并
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
"default subscription group must exist and be subscription type",
@@ -2126,3 +2128,315 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
}
+
+// Sora S3 存储配置 (从本地版本合并)
+type soraS3ProfilesStore struct {
+ ActiveProfileID string `json:"active_profile_id"`
+ Items []soraS3ProfileStoreItem `json:"items"`
+}
+
+type soraS3ProfileStoreItem struct {
+ ProfileID string `json:"profile_id"`
+ Name string `json:"name"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+ UpdatedAt string `json:"updated_at"`
+}
+
+// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
+func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
+ profiles, err := s.ListSoraS3Profiles(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
+ if activeProfile == nil {
+ return &SoraS3Settings{}, nil
+ }
+
+ return &SoraS3Settings{
+ Enabled: activeProfile.Enabled,
+ Endpoint: activeProfile.Endpoint,
+ Region: activeProfile.Region,
+ Bucket: activeProfile.Bucket,
+ AccessKeyID: activeProfile.AccessKeyID,
+ SecretAccessKey: activeProfile.SecretAccessKey,
+ SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured,
+ Prefix: activeProfile.Prefix,
+ ForcePathStyle: activeProfile.ForcePathStyle,
+ CDNURL: activeProfile.CDNURL,
+ DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes,
+ }, nil
+}
+
+// ListSoraS3Profiles 获取 Sora S3 多配置列表
+func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
+ store, err := s.loadSoraS3ProfilesStore(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return convertSoraS3ProfilesStore(store), nil
+}
+
+func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
+ if err == nil {
+ trimmed := strings.TrimSpace(raw)
+ if trimmed == "" {
+ return &soraS3ProfilesStore{}, nil
+ }
+ var store soraS3ProfilesStore
+ if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
+ legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
+ if legacyErr != nil {
+ return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
+ }
+ if isEmptyLegacySoraS3Settings(legacy) {
+ return &soraS3ProfilesStore{}, nil
+ }
+ now := time.Now().UTC().Format(time.RFC3339)
+ return &soraS3ProfilesStore{
+ ActiveProfileID: "default",
+ Items: []soraS3ProfileStoreItem{
+ {
+ ProfileID: "default",
+ Name: "Default",
+ Enabled: legacy.Enabled,
+ Endpoint: strings.TrimSpace(legacy.Endpoint),
+ Region: strings.TrimSpace(legacy.Region),
+ Bucket: strings.TrimSpace(legacy.Bucket),
+ AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
+ SecretAccessKey: legacy.SecretAccessKey,
+ Prefix: strings.TrimSpace(legacy.Prefix),
+ ForcePathStyle: legacy.ForcePathStyle,
+ CDNURL: strings.TrimSpace(legacy.CDNURL),
+ DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
+ UpdatedAt: now,
+ },
+ },
+ }, nil
+ }
+ normalized := normalizeSoraS3ProfilesStore(store)
+ return &normalized, nil
+ }
+
+ if !errors.Is(err, ErrSettingNotFound) {
+ return nil, fmt.Errorf("get sora s3 profiles: %w", err)
+ }
+
+ legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
+ if legacyErr != nil {
+ return nil, legacyErr
+ }
+ if isEmptyLegacySoraS3Settings(legacy) {
+ return &soraS3ProfilesStore{}, nil
+ }
+
+ now := time.Now().UTC().Format(time.RFC3339)
+ return &soraS3ProfilesStore{
+ ActiveProfileID: "default",
+ Items: []soraS3ProfileStoreItem{
+ {
+ ProfileID: "default",
+ Name: "Default",
+ Enabled: legacy.Enabled,
+ Endpoint: strings.TrimSpace(legacy.Endpoint),
+ Region: strings.TrimSpace(legacy.Region),
+ Bucket: strings.TrimSpace(legacy.Bucket),
+ AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
+ SecretAccessKey: legacy.SecretAccessKey,
+ Prefix: strings.TrimSpace(legacy.Prefix),
+ ForcePathStyle: legacy.ForcePathStyle,
+ CDNURL: strings.TrimSpace(legacy.CDNURL),
+ DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
+ UpdatedAt: now,
+ },
+ },
+ }, nil
+}
+
+func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
+ keys := []string{
+ SettingKeySoraS3Enabled,
+ SettingKeySoraS3Endpoint,
+ SettingKeySoraS3Region,
+ SettingKeySoraS3Bucket,
+ SettingKeySoraS3AccessKeyID,
+ SettingKeySoraS3SecretAccessKey,
+ SettingKeySoraS3Prefix,
+ SettingKeySoraS3ForcePathStyle,
+ SettingKeySoraS3CDNURL,
+ SettingKeySoraDefaultStorageQuotaBytes,
+ }
+
+ values, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
+ }
+
+ result := &SoraS3Settings{
+ Enabled: values[SettingKeySoraS3Enabled] == "true",
+ Endpoint: values[SettingKeySoraS3Endpoint],
+ Region: values[SettingKeySoraS3Region],
+ Bucket: values[SettingKeySoraS3Bucket],
+ AccessKeyID: values[SettingKeySoraS3AccessKeyID],
+ SecretAccessKey: values[SettingKeySoraS3SecretAccessKey],
+ SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "",
+ Prefix: values[SettingKeySoraS3Prefix],
+ ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true",
+ CDNURL: values[SettingKeySoraS3CDNURL],
+ }
+ if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil {
+ result.DefaultStorageQuotaBytes = v
+ }
+ return result, nil
+}
+
+func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
+ seen := make(map[string]struct{}, len(store.Items))
+ normalized := soraS3ProfilesStore{
+ ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
+ Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)),
+ }
+ now := time.Now().UTC().Format(time.RFC3339)
+
+ for idx := range store.Items {
+ item := store.Items[idx]
+ item.ProfileID = strings.TrimSpace(item.ProfileID)
+ if item.ProfileID == "" {
+ item.ProfileID = fmt.Sprintf("profile-%d", idx+1)
+ }
+ if _, exists := seen[item.ProfileID]; exists {
+ continue
+ }
+ seen[item.ProfileID] = struct{}{}
+
+ item.Name = strings.TrimSpace(item.Name)
+ if item.Name == "" {
+ item.Name = item.ProfileID
+ }
+ item.Endpoint = strings.TrimSpace(item.Endpoint)
+ item.Region = strings.TrimSpace(item.Region)
+ item.Bucket = strings.TrimSpace(item.Bucket)
+ item.AccessKeyID = strings.TrimSpace(item.AccessKeyID)
+ item.Prefix = strings.TrimSpace(item.Prefix)
+ item.CDNURL = strings.TrimSpace(item.CDNURL)
+ item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0)
+ item.UpdatedAt = strings.TrimSpace(item.UpdatedAt)
+ if item.UpdatedAt == "" {
+ item.UpdatedAt = now
+ }
+ normalized.Items = append(normalized.Items, item)
+ }
+
+ if len(normalized.Items) == 0 {
+ normalized.ActiveProfileID = ""
+ return normalized
+ }
+
+ if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 {
+ return normalized
+ }
+
+ normalized.ActiveProfileID = normalized.Items[0].ProfileID
+ return normalized
+}
+
+func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
+ if store == nil {
+ return &SoraS3ProfileList{}
+ }
+ items := make([]SoraS3Profile, 0, len(store.Items))
+ for idx := range store.Items {
+ item := store.Items[idx]
+ items = append(items, SoraS3Profile{
+ ProfileID: item.ProfileID,
+ Name: item.Name,
+ IsActive: item.ProfileID == store.ActiveProfileID,
+ Enabled: item.Enabled,
+ Endpoint: item.Endpoint,
+ Region: item.Region,
+ Bucket: item.Bucket,
+ AccessKeyID: item.AccessKeyID,
+ SecretAccessKey: item.SecretAccessKey,
+ SecretAccessKeyConfigured: item.SecretAccessKey != "",
+ Prefix: item.Prefix,
+ ForcePathStyle: item.ForcePathStyle,
+ CDNURL: item.CDNURL,
+ DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes,
+ UpdatedAt: item.UpdatedAt,
+ })
+ }
+ return &SoraS3ProfileList{
+ ActiveProfileID: store.ActiveProfileID,
+ Items: items,
+ }
+}
+
+func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
+ for idx := range items {
+ if items[idx].ProfileID == activeProfileID {
+ return &items[idx]
+ }
+ }
+ if len(items) == 0 {
+ return nil
+ }
+ return &items[0]
+}
+
+func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
+ for idx := range items {
+ if items[idx].ProfileID == profileID {
+ return idx
+ }
+ }
+ return -1
+}
+
+func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
+ if settings == nil {
+ return true
+ }
+ if settings.Enabled {
+ return false
+ }
+ if strings.TrimSpace(settings.Endpoint) != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.Region) != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.Bucket) != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.AccessKeyID) != "" {
+ return false
+ }
+ if settings.SecretAccessKey != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.Prefix) != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.CDNURL) != "" {
+ return false
+ }
+ return settings.DefaultStorageQuotaBytes == 0
+}
+
+func maxInt64(value int64, min int64) int64 {
+ if value < min {
+ return min
+ }
+ return value
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index de92b796..d8c1748d 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -113,6 +113,46 @@ type DefaultSubscriptionSetting struct {
ValidityDays int `json:"validity_days"`
}
+// SoraS3Settings Sora S3 存储配置 (从本地版本合并)
+type SoraS3Settings struct {
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端
+ SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+}
+
+// SoraS3Profile Sora S3 多配置项(服务内部模型)(从本地版本合并)
+type SoraS3Profile struct {
+ ProfileID string `json:"profile_id"`
+ Name string `json:"name"`
+ IsActive bool `json:"is_active"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端
+ SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+ UpdatedAt string `json:"updated_at"`
+}
+
+// SoraS3ProfileList Sora S3 多配置列表 (从本地版本合并)
+type SoraS3ProfileList struct {
+ ActiveProfileID string `json:"active_profile_id"`
+ Items []SoraS3Profile `json:"items"`
+}
+
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
diff --git a/backend/internal/service/sora_account_service.go b/backend/internal/service/sora_account_service.go
new file mode 100644
index 00000000..eccc1acf
--- /dev/null
+++ b/backend/internal/service/sora_account_service.go
@@ -0,0 +1,40 @@
+package service
+
+import "context"
+
+// SoraAccountRepository Sora 账号扩展表仓储接口
+// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。
+//
+// 设计说明:
+// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本
+// - Sora gateway 优先读取此表的字段以获得更好的查询性能
+// - 主表 accounts 通过 credentials JSON 字段也存储相同信息
+// - Token 刷新时需要同时更新两个表以保持数据一致性
+type SoraAccountRepository interface {
+ // Upsert 创建或更新 Sora 账号扩展信息
+ // accountID: 关联的 accounts.id
+ // updates: 要更新的字段,支持 access_token、refresh_token、session_token
+ //
+ // 如果记录不存在则创建,存在则更新。
+ // 用于:
+ // 1. 创建 Sora 账号时初始化扩展表
+ // 2. Token 刷新时同步更新扩展表
+ Upsert(ctx context.Context, accountID int64, updates map[string]any) error
+
+ // GetByAccountID 根据账号 ID 获取 Sora 扩展信息
+ // 返回 nil, nil 表示记录不存在(非错误)
+ GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error)
+
+ // Delete 删除 Sora 账号扩展信息
+ // 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理
+ Delete(ctx context.Context, accountID int64) error
+}
+
+// SoraAccount Sora 账号扩展信息
+// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本
+type SoraAccount struct {
+ AccountID int64 // 关联的 accounts.id
+ AccessToken string // OAuth access_token
+ RefreshToken string // OAuth refresh_token
+ SessionToken string // Session token(可选,用于 ST→AT 兜底)
+}
diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go
new file mode 100644
index 00000000..0a914d2d
--- /dev/null
+++ b/backend/internal/service/sora_client.go
@@ -0,0 +1,117 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+)
+
+// SoraClient 定义直连 Sora 的任务操作接口。
+type SoraClient interface {
+ Enabled() bool
+ UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
+ CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
+ CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
+ CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
+ UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
+ GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
+ DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
+ UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
+ FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
+ SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
+ DeleteCharacter(ctx context.Context, account *Account, characterID string) error
+ PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
+ DeletePost(ctx context.Context, account *Account, postID string) error
+ GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
+ EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
+ GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
+ GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
+}
+
+// SoraImageRequest 图片生成请求参数
+type SoraImageRequest struct {
+ Prompt string
+ Width int
+ Height int
+ MediaID string
+}
+
+// SoraVideoRequest 视频生成请求参数
+type SoraVideoRequest struct {
+ Prompt string
+ Orientation string
+ Frames int
+ Model string
+ Size string
+ VideoCount int
+ MediaID string
+ RemixTargetID string
+ CameoIDs []string
+}
+
+// SoraStoryboardRequest 分镜视频生成请求参数
+type SoraStoryboardRequest struct {
+ Prompt string
+ Orientation string
+ Frames int
+ Model string
+ Size string
+ MediaID string
+}
+
+// SoraImageTaskStatus 图片任务状态
+type SoraImageTaskStatus struct {
+ ID string
+ Status string
+ ProgressPct float64
+ URLs []string
+ ErrorMsg string
+}
+
+// SoraVideoTaskStatus 视频任务状态
+type SoraVideoTaskStatus struct {
+ ID string
+ Status string
+ ProgressPct int
+ URLs []string
+ GenerationID string
+ ErrorMsg string
+}
+
+// SoraCameoStatus 角色处理中间态
+type SoraCameoStatus struct {
+ Status string
+ StatusMessage string
+ DisplayNameHint string
+ UsernameHint string
+ ProfileAssetURL string
+ InstructionSetHint any
+ InstructionSet any
+}
+
+// SoraCharacterFinalizeRequest 角色定稿请求参数
+type SoraCharacterFinalizeRequest struct {
+ CameoID string
+ Username string
+ DisplayName string
+ ProfileAssetPointer string
+ InstructionSet any
+}
+
+// SoraUpstreamError 上游错误
+type SoraUpstreamError struct {
+ StatusCode int
+ Message string
+ Headers http.Header
+ Body []byte
+}
+
+func (e *SoraUpstreamError) Error() string {
+ if e == nil {
+ return "sora upstream error"
+ }
+ if e.Message != "" {
+ return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message)
+ }
+ return fmt.Sprintf("sora upstream error: %d", e.StatusCode)
+}
diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go
new file mode 100644
index 00000000..e9d325f4
--- /dev/null
+++ b/backend/internal/service/sora_gateway_service.go
@@ -0,0 +1,1559 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "math"
+ "math/rand"
+ "mime"
+ "net"
+ "net/http"
+ "net/url"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/gin-gonic/gin"
+)
+
+const soraImageInputMaxBytes = 20 << 20
+const soraImageInputMaxRedirects = 3
+const soraImageInputTimeout = 20 * time.Second
+const soraVideoInputMaxBytes = 200 << 20
+const soraVideoInputMaxRedirects = 3
+const soraVideoInputTimeout = 60 * time.Second
+
+var soraImageSizeMap = map[string]string{
+ "gpt-image": "360",
+ "gpt-image-landscape": "540",
+ "gpt-image-portrait": "540",
+}
+
+var soraBlockedHostnames = map[string]struct{}{
+ "localhost": {},
+ "localhost.localdomain": {},
+ "metadata.google.internal": {},
+ "metadata.google.internal.": {},
+}
+
+var soraBlockedCIDRs = mustParseCIDRs([]string{
+ "0.0.0.0/8",
+ "10.0.0.0/8",
+ "100.64.0.0/10",
+ "127.0.0.0/8",
+ "169.254.0.0/16",
+ "172.16.0.0/12",
+ "192.168.0.0/16",
+ "224.0.0.0/4",
+ "240.0.0.0/4",
+ "::/128",
+ "::1/128",
+ "fc00::/7",
+ "fe80::/10",
+})
+
+// SoraGatewayService handles forwarding requests to Sora upstream.
+type SoraGatewayService struct {
+ soraClient SoraClient
+ rateLimitService *RateLimitService
+ httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
+ cfg *config.Config
+}
+
+type soraWatermarkOptions struct {
+ Enabled bool
+ ParseMethod string
+ ParseURL string
+ ParseToken string
+ FallbackOnFailure bool
+ DeletePost bool
+}
+
+type soraCharacterOptions struct {
+ SetPublic bool
+ DeleteAfterGenerate bool
+}
+
+type soraCharacterFlowResult struct {
+ CameoID string
+ CharacterID string
+ Username string
+ DisplayName string
+}
+
+var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
+var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
+var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
+var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
+
+type soraPreflightChecker interface {
+ PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
+}
+
+func NewSoraGatewayService(
+ soraClient SoraClient,
+ rateLimitService *RateLimitService,
+ httpUpstream HTTPUpstream,
+ cfg *config.Config,
+) *SoraGatewayService {
+ return &SoraGatewayService{
+ soraClient: soraClient,
+ rateLimitService: rateLimitService,
+ httpUpstream: httpUpstream,
+ cfg: cfg,
+ }
+}
+
+func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
+ startTime := time.Now()
+
+ // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
+ if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
+ if s.httpUpstream == nil {
+ s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
+ return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
+ }
+ return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
+ }
+
+ if s.soraClient == nil || !s.soraClient.Enabled() {
+ if c != nil {
+ c.JSON(http.StatusServiceUnavailable, gin.H{
+ "error": gin.H{
+ "type": "api_error",
+ "message": "Sora 上游未配置",
+ },
+ })
+ }
+ return nil, errors.New("sora upstream not configured")
+ }
+
+ var reqBody map[string]any
+ if err := json.Unmarshal(body, &reqBody); err != nil {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream)
+ return nil, fmt.Errorf("parse request: %w", err)
+ }
+ reqModel, _ := reqBody["model"].(string)
+ reqStream, _ := reqBody["stream"].(bool)
+ if strings.TrimSpace(reqModel) == "" {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
+ return nil, errors.New("model is required")
+ }
+ originalModel := reqModel
+
+ mappedModel := account.GetMappedModel(reqModel)
+ var upstreamModel string
+ if mappedModel != "" && mappedModel != reqModel {
+ reqModel = mappedModel
+ upstreamModel = mappedModel
+ }
+
+ modelCfg, ok := GetSoraModelConfig(reqModel)
+ if !ok {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
+ return nil, fmt.Errorf("unsupported model: %s", reqModel)
+ }
+ prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
+ prompt = strings.TrimSpace(prompt)
+ imageInput = strings.TrimSpace(imageInput)
+ videoInput = strings.TrimSpace(videoInput)
+ remixTargetID = strings.TrimSpace(remixTargetID)
+
+ if videoInput != "" && modelCfg.Type != "video" {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
+ return nil, errors.New("video input only supports video models")
+ }
+ if videoInput != "" && imageInput != "" {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
+ return nil, errors.New("image input and video input cannot be used together")
+ }
+ characterOnly := videoInput != "" && prompt == ""
+ if modelCfg.Type == "prompt_enhance" && prompt == "" {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
+ return nil, errors.New("prompt is required")
+ }
+ if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
+ return nil, errors.New("prompt is required")
+ }
+
+ reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
+ if cancel != nil {
+ defer cancel()
+ }
+ if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
+ if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
+ return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
+ }
+ }
+
+ if modelCfg.Type == "prompt_enhance" {
+ enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
+ if err != nil {
+ return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
+ }
+ content := strings.TrimSpace(enhancedPrompt)
+ if content == "" {
+ content = prompt
+ }
+ var firstTokenMs *int
+ if clientStream {
+ ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
+ if streamErr != nil {
+ return nil, streamErr
+ }
+ firstTokenMs = ms
+ } else if c != nil {
+ c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
+ }
+ return &ForwardResult{
+ RequestID: "",
+ Model: originalModel,
+ UpstreamModel: upstreamModel,
+ Stream: clientStream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ Usage: ClaudeUsage{},
+ MediaType: "prompt",
+ }, nil
+ }
+
+ characterOpts := parseSoraCharacterOptions(reqBody)
+ watermarkOpts := parseSoraWatermarkOptions(reqBody)
+ var characterResult *soraCharacterFlowResult
+ if videoInput != "" {
+ videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
+ if videoErr != nil {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
+ return nil, videoErr
+ }
+ characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
+ if videoErr != nil {
+ return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
+ }
+ if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
+ characterID := strings.TrimSpace(characterResult.CharacterID)
+ defer func() {
+ cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancelCleanup()
+ if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
+ log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
+ }
+ }()
+ }
+ if characterOnly {
+ content := "角色创建成功"
+ if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
+ content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
+ }
+ var firstTokenMs *int
+ if clientStream {
+ ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
+ if streamErr != nil {
+ return nil, streamErr
+ }
+ firstTokenMs = ms
+ } else if c != nil {
+ resp := buildSoraNonStreamResponse(content, reqModel)
+ if characterResult != nil {
+ resp["character_id"] = characterResult.CharacterID
+ resp["cameo_id"] = characterResult.CameoID
+ resp["character_username"] = characterResult.Username
+ resp["character_display_name"] = characterResult.DisplayName
+ }
+ c.JSON(http.StatusOK, resp)
+ }
+ return &ForwardResult{
+ RequestID: "",
+ Model: originalModel,
+ UpstreamModel: upstreamModel,
+ Stream: clientStream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ Usage: ClaudeUsage{},
+ MediaType: "prompt",
+ }, nil
+ }
+ if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
+ prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
+ }
+ }
+
+ var imageData []byte
+ imageFilename := ""
+ if imageInput != "" {
+ decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
+ if err != nil {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
+ return nil, err
+ }
+ imageData = decoded
+ imageFilename = filename
+ }
+
+ mediaID := ""
+ if len(imageData) > 0 {
+ uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename)
+ if err != nil {
+ return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
+ }
+ mediaID = uploadID
+ }
+
+ taskID := ""
+ var err error
+ videoCount := parseSoraVideoCount(reqBody)
+ switch modelCfg.Type {
+ case "image":
+ taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
+ Prompt: prompt,
+ Width: modelCfg.Width,
+ Height: modelCfg.Height,
+ MediaID: mediaID,
+ })
+ case "video":
+ if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
+ taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
+ Prompt: formatSoraStoryboardPrompt(prompt),
+ Orientation: modelCfg.Orientation,
+ Frames: modelCfg.Frames,
+ Model: modelCfg.Model,
+ Size: modelCfg.Size,
+ MediaID: mediaID,
+ })
+ } else {
+ taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
+ Prompt: prompt,
+ Orientation: modelCfg.Orientation,
+ Frames: modelCfg.Frames,
+ Model: modelCfg.Model,
+ Size: modelCfg.Size,
+ VideoCount: videoCount,
+ MediaID: mediaID,
+ RemixTargetID: remixTargetID,
+ CameoIDs: extractSoraCameoIDs(reqBody),
+ })
+ }
+ default:
+ err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
+ }
+ if err != nil {
+ return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
+ }
+
+ if clientStream && c != nil {
+ s.prepareSoraStream(c, taskID)
+ }
+
+ var mediaURLs []string
+ videoGenerationID := ""
+ mediaType := modelCfg.Type
+ imageCount := 0
+ imageSize := ""
+ switch modelCfg.Type {
+ case "image":
+ urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream)
+ if pollErr != nil {
+ return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
+ }
+ mediaURLs = urls
+ imageCount = len(urls)
+ imageSize = soraImageSizeFromModel(reqModel)
+ case "video":
+ videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
+ if pollErr != nil {
+ return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
+ }
+ if videoStatus != nil {
+ mediaURLs = videoStatus.URLs
+ videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
+ }
+ default:
+ mediaType = "prompt"
+ }
+
+ watermarkPostID := ""
+ if modelCfg.Type == "video" && watermarkOpts.Enabled {
+ watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
+ if watermarkErr != nil {
+ if !watermarkOpts.FallbackOnFailure {
+ return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
+ }
+ log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
+ } else if strings.TrimSpace(watermarkURL) != "" {
+ mediaURLs = []string{strings.TrimSpace(watermarkURL)}
+ watermarkPostID = strings.TrimSpace(postID)
+ }
+ }
+
+ // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
+ // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
+ finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
+ if watermarkPostID != "" && watermarkOpts.DeletePost {
+ if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
+ log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
+ }
+ }
+
+ content := buildSoraContent(mediaType, finalURLs)
+ var firstTokenMs *int
+ if clientStream {
+ ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
+ if streamErr != nil {
+ return nil, streamErr
+ }
+ firstTokenMs = ms
+ } else if c != nil {
+ response := buildSoraNonStreamResponse(content, reqModel)
+ if len(finalURLs) > 0 {
+ response["media_url"] = finalURLs[0]
+ if len(finalURLs) > 1 {
+ response["media_urls"] = finalURLs
+ }
+ }
+ c.JSON(http.StatusOK, response)
+ }
+
+ return &ForwardResult{
+ RequestID: taskID,
+ Model: originalModel,
+ UpstreamModel: upstreamModel,
+ Stream: clientStream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ Usage: ClaudeUsage{},
+ MediaType: mediaType,
+ MediaURL: firstMediaURL(finalURLs),
+ ImageCount: imageCount,
+ ImageSize: imageSize,
+ }, nil
+}
+
+func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
+ if s == nil || s.cfg == nil {
+ return ctx, nil
+ }
+ timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds
+ if stream {
+ timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds
+ }
+ if timeoutSeconds <= 0 {
+ return ctx, nil
+ }
+ return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
+}
+
+func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
+ opts := soraWatermarkOptions{
+ Enabled: parseBoolWithDefault(body, "watermark_free", false),
+ ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
+ ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
+ ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
+ FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
+ DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
+ }
+ if opts.ParseMethod == "" {
+ opts.ParseMethod = "third_party"
+ }
+ return opts
+}
+
+func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
+ return soraCharacterOptions{
+ SetPublic: parseBoolWithDefault(body, "character_set_public", true),
+ DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
+ }
+}
+
+func parseSoraVideoCount(body map[string]any) int {
+ if body == nil {
+ return 1
+ }
+ keys := []string{"video_count", "videos", "n_variants"}
+ for _, key := range keys {
+ count := parseIntWithDefault(body, key, 0)
+ if count > 0 {
+ return clampInt(count, 1, 3)
+ }
+ }
+ return 1
+}
+
+func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
+ if body == nil {
+ return def
+ }
+ val, ok := body[key]
+ if !ok {
+ return def
+ }
+ switch typed := val.(type) {
+ case bool:
+ return typed
+ case int:
+ return typed != 0
+ case int32:
+ return typed != 0
+ case int64:
+ return typed != 0
+ case float64:
+ return typed != 0
+ case string:
+ typed = strings.ToLower(strings.TrimSpace(typed))
+ if typed == "true" || typed == "1" || typed == "yes" {
+ return true
+ }
+ if typed == "false" || typed == "0" || typed == "no" {
+ return false
+ }
+ }
+ return def
+}
+
+func parseStringWithDefault(body map[string]any, key, def string) string {
+ if body == nil {
+ return def
+ }
+ val, ok := body[key]
+ if !ok {
+ return def
+ }
+ if str, ok := val.(string); ok {
+ return str
+ }
+ return def
+}
+
+func parseIntWithDefault(body map[string]any, key string, def int) int {
+ if body == nil {
+ return def
+ }
+ val, ok := body[key]
+ if !ok {
+ return def
+ }
+ switch typed := val.(type) {
+ case int:
+ return typed
+ case int32:
+ return int(typed)
+ case int64:
+ return int(typed)
+ case float64:
+ return int(typed)
+ case string:
+ parsed, err := strconv.Atoi(strings.TrimSpace(typed))
+ if err == nil {
+ return parsed
+ }
+ }
+ return def
+}
+
+func clampInt(v, minVal, maxVal int) int {
+ if v < minVal {
+ return minVal
+ }
+ if v > maxVal {
+ return maxVal
+ }
+ return v
+}
+
+func extractSoraCameoIDs(body map[string]any) []string {
+ if body == nil {
+ return nil
+ }
+ raw, ok := body["cameo_ids"]
+ if !ok {
+ return nil
+ }
+ switch typed := raw.(type) {
+ case []string:
+ out := make([]string, 0, len(typed))
+ for _, item := range typed {
+ item = strings.TrimSpace(item)
+ if item != "" {
+ out = append(out, item)
+ }
+ }
+ return out
+ case []any:
+ out := make([]string, 0, len(typed))
+ for _, item := range typed {
+ str, ok := item.(string)
+ if !ok {
+ continue
+ }
+ str = strings.TrimSpace(str)
+ if str != "" {
+ out = append(out, str)
+ }
+ }
+ return out
+ default:
+ return nil
+ }
+}
+
+func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
+ cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
+ if err != nil {
+ return nil, err
+ }
+
+ cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
+ if err != nil {
+ return nil, err
+ }
+ username := processSoraCharacterUsername(cameoStatus.UsernameHint)
+ displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
+ if displayName == "" {
+ displayName = "Character"
+ }
+ profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
+ if profileAssetURL == "" {
+ return nil, errors.New("profile asset url not found in cameo status")
+ }
+
+ avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
+ if err != nil {
+ return nil, err
+ }
+ assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
+ if err != nil {
+ return nil, err
+ }
+ instructionSet := cameoStatus.InstructionSetHint
+ if instructionSet == nil {
+ instructionSet = cameoStatus.InstructionSet
+ }
+
+ characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
+ CameoID: strings.TrimSpace(cameoID),
+ Username: username,
+ DisplayName: displayName,
+ ProfileAssetPointer: assetPointer,
+ InstructionSet: instructionSet,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ if opts.SetPublic {
+ if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
+ return nil, err
+ }
+ }
+
+ return &soraCharacterFlowResult{
+ CameoID: strings.TrimSpace(cameoID),
+ CharacterID: strings.TrimSpace(characterID),
+ Username: strings.TrimSpace(username),
+ DisplayName: displayName,
+ }, nil
+}
+
+func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
+ timeout := 10 * time.Minute
+ interval := 5 * time.Second
+ maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
+ if maxAttempts < 1 {
+ maxAttempts = 1
+ }
+
+ var lastErr error
+ consecutiveErrors := 0
+ for attempt := 0; attempt < maxAttempts; attempt++ {
+ status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
+ if err != nil {
+ lastErr = err
+ consecutiveErrors++
+ if consecutiveErrors >= 3 {
+ break
+ }
+ if attempt < maxAttempts-1 {
+ if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+ continue
+ }
+ consecutiveErrors = 0
+ if status == nil {
+ if attempt < maxAttempts-1 {
+ if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+ continue
+ }
+ currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
+ statusMessage := strings.TrimSpace(status.StatusMessage)
+ if currentStatus == "failed" {
+ if statusMessage == "" {
+ statusMessage = "character creation failed"
+ }
+ return nil, errors.New(statusMessage)
+ }
+ if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
+ return status, nil
+ }
+ if attempt < maxAttempts-1 {
+ if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+ }
+ if lastErr != nil {
+ return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
+ }
+ return nil, errors.New("cameo processing timeout")
+}
+
+func processSoraCharacterUsername(usernameHint string) string {
+ usernameHint = strings.TrimSpace(usernameHint)
+ if usernameHint == "" {
+ usernameHint = "character"
+ }
+ if strings.Contains(usernameHint, ".") {
+ parts := strings.Split(usernameHint, ".")
+ usernameHint = strings.TrimSpace(parts[len(parts)-1])
+ }
+ if usernameHint == "" {
+ usernameHint = "character"
+ }
+ return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100)
+}
+
+func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
+ generationID = strings.TrimSpace(generationID)
+ if generationID == "" {
+ return "", "", errors.New("generation id is required for watermark-free mode")
+ }
+ postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
+ if err != nil {
+ return "", "", err
+ }
+ postID = strings.TrimSpace(postID)
+ if postID == "" {
+ return "", "", errors.New("watermark-free publish returned empty post id")
+ }
+
+ switch opts.ParseMethod {
+ case "custom":
+ urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
+ if parseErr != nil {
+ return "", postID, parseErr
+ }
+ return strings.TrimSpace(urlVal), postID, nil
+ case "", "third_party":
+ return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
+ default:
+ return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
+ }
+}
+
+func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
+ switch statusCode {
+ case 401, 402, 403, 404, 429, 529:
+ return true
+ default:
+ return statusCode >= 500
+ }
+}
+
+func buildSoraNonStreamResponse(content, model string) map[string]any {
+ return map[string]any{
+ "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
+ "object": "chat.completion",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []any{
+ map[string]any{
+ "index": 0,
+ "message": map[string]any{
+ "role": "assistant",
+ "content": content,
+ },
+ "finish_reason": "stop",
+ },
+ },
+ }
+}
+
+func soraImageSizeFromModel(model string) string {
+ modelLower := strings.ToLower(model)
+ if size, ok := soraImageSizeMap[modelLower]; ok {
+ return size
+ }
+ if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") {
+ return "540"
+ }
+ return "360"
+}
+
+func soraProErrorMessage(model, upstreamMsg string) string {
+ modelLower := strings.ToLower(model)
+ if strings.Contains(modelLower, "sora2pro-hd") {
+ return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号"
+ }
+ if strings.Contains(modelLower, "sora2pro") {
+ return "当前账号无法使用 Sora Pro 模型,请更换模型或账号"
+ }
+ return ""
+}
+
+func firstMediaURL(urls []string) string {
+ if len(urls) == 0 {
+ return ""
+ }
+ return urls[0]
+}
+
+func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string {
+ if path == "" {
+ return path
+ }
+ prefix := "/sora/media"
+ values := url.Values{}
+ if rawQuery != "" {
+ if parsed, err := url.ParseQuery(rawQuery); err == nil {
+ values = parsed
+ }
+ }
+
+ signKey := ""
+ ttlSeconds := 0
+ if s != nil && s.cfg != nil {
+ signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey)
+ ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds
+ }
+ values.Del("sig")
+ values.Del("expires")
+ signingQuery := values.Encode()
+ if signKey != "" && ttlSeconds > 0 {
+ expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix()
+ signature := SignSoraMediaURL(path, signingQuery, expires, signKey)
+ if signature != "" {
+ values.Set("expires", strconv.FormatInt(expires, 10))
+ values.Set("sig", signature)
+ prefix = "/sora/media-signed"
+ }
+ }
+
+ encoded := values.Encode()
+ if encoded == "" {
+ return prefix + path
+ }
+ return prefix + path + "?" + encoded
+}
+
+func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) {
+ if c == nil {
+ return
+ }
+ c.Header("Content-Type", "text/event-stream")
+ c.Header("Cache-Control", "no-cache")
+ c.Header("Connection", "keep-alive")
+ c.Header("X-Accel-Buffering", "no")
+ if strings.TrimSpace(requestID) != "" {
+ c.Header("x-request-id", requestID)
+ }
+}
+
+func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) {
+ if c == nil {
+ return nil, nil
+ }
+ writer := c.Writer
+ flusher, _ := writer.(http.Flusher)
+
+ chunk := map[string]any{
+ "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
+ "object": "chat.completion.chunk",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []any{
+ map[string]any{
+ "index": 0,
+ "delta": map[string]any{
+ "content": content,
+ },
+ },
+ },
+ }
+ encoded, _ := jsonMarshalRaw(chunk)
+ if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil {
+ return nil, err
+ }
+ if flusher != nil {
+ flusher.Flush()
+ }
+ ms := int(time.Since(startTime).Milliseconds())
+ finalChunk := map[string]any{
+ "id": chunk["id"],
+ "object": "chat.completion.chunk",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []any{
+ map[string]any{
+ "index": 0,
+ "delta": map[string]any{},
+ "finish_reason": "stop",
+ },
+ },
+ }
+ finalEncoded, _ := jsonMarshalRaw(finalChunk)
+ if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil {
+ return &ms, err
+ }
+ if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil {
+ return &ms, err
+ }
+ if flusher != nil {
+ flusher.Flush()
+ }
+ return &ms, nil
+}
+
+func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) {
+ if c == nil {
+ return
+ }
+ if stream {
+ flusher, _ := c.Writer.(http.Flusher)
+ errorData := map[string]any{
+ "error": map[string]string{
+ "type": errType,
+ "message": message,
+ },
+ }
+ jsonBytes, err := json.Marshal(errorData)
+ if err != nil {
+ _ = c.Error(err)
+ return
+ }
+ errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
+ _, _ = fmt.Fprint(c.Writer, errorEvent)
+ _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
+ if flusher != nil {
+ flusher.Flush()
+ }
+ return
+ }
+ c.JSON(status, gin.H{
+ "error": gin.H{
+ "type": errType,
+ "message": message,
+ },
+ })
+}
+
+func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error {
+ if err == nil {
+ return nil
+ }
+ var upstreamErr *SoraUpstreamError
+ if errors.As(err, &upstreamErr) {
+ accountID := int64(0)
+ if account != nil {
+ accountID = account.ID
+ }
+ logger.LegacyPrintf(
+ "service.sora",
+ "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
+ accountID,
+ model,
+ upstreamErr.StatusCode,
+ strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
+ strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
+ strings.TrimSpace(upstreamErr.Message),
+ truncateForLog(upstreamErr.Body, 1024),
+ )
+ if s.rateLimitService != nil && account != nil {
+ s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
+ }
+ if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
+ var responseHeaders http.Header
+ if upstreamErr.Headers != nil {
+ responseHeaders = upstreamErr.Headers.Clone()
+ }
+ return &UpstreamFailoverError{
+ StatusCode: upstreamErr.StatusCode,
+ ResponseBody: upstreamErr.Body,
+ ResponseHeaders: responseHeaders,
+ }
+ }
+ msg := upstreamErr.Message
+ if override := soraProErrorMessage(model, msg); override != "" {
+ msg = override
+ }
+ s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream)
+ return err
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream)
+ return err
+ }
+ s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream)
+ return err
+}
+
+func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
+ interval := s.pollInterval()
+ maxAttempts := s.pollMaxAttempts()
+ lastPing := time.Now()
+ for attempt := 0; attempt < maxAttempts; attempt++ {
+ status, err := s.soraClient.GetImageTask(ctx, account, taskID)
+ if err != nil {
+ return nil, err
+ }
+ switch strings.ToLower(status.Status) {
+ case "succeeded", "completed":
+ return status.URLs, nil
+ case "failed":
+ if status.ErrorMsg != "" {
+ return nil, errors.New(status.ErrorMsg)
+ }
+ return nil, errors.New("sora image generation failed")
+ }
+ if stream {
+ s.maybeSendPing(c, &lastPing)
+ }
+ if err := sleepWithContext(ctx, interval); err != nil {
+ return nil, err
+ }
+ }
+ return nil, errors.New("sora image generation timeout")
+}
+
+func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
+ interval := s.pollInterval()
+ maxAttempts := s.pollMaxAttempts()
+ lastPing := time.Now()
+ for attempt := 0; attempt < maxAttempts; attempt++ {
+ status, err := s.soraClient.GetVideoTask(ctx, account, taskID)
+ if err != nil {
+ return nil, err
+ }
+ switch strings.ToLower(status.Status) {
+ case "completed", "succeeded":
+ return status, nil
+ case "failed":
+ if status.ErrorMsg != "" {
+ return nil, errors.New(status.ErrorMsg)
+ }
+ return nil, errors.New("sora video generation failed")
+ }
+ if stream {
+ s.maybeSendPing(c, &lastPing)
+ }
+ if err := sleepWithContext(ctx, interval); err != nil {
+ return nil, err
+ }
+ }
+ return nil, errors.New("sora video generation timeout")
+}
+
+func (s *SoraGatewayService) pollInterval() time.Duration {
+ if s == nil || s.cfg == nil {
+ return 2 * time.Second
+ }
+ interval := s.cfg.Sora.Client.PollIntervalSeconds
+ if interval <= 0 {
+ interval = 2
+ }
+ return time.Duration(interval) * time.Second
+}
+
+func (s *SoraGatewayService) pollMaxAttempts() int {
+ if s == nil || s.cfg == nil {
+ return 600
+ }
+ maxAttempts := s.cfg.Sora.Client.MaxPollAttempts
+ if maxAttempts <= 0 {
+ maxAttempts = 600
+ }
+ return maxAttempts
+}
+
+func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) {
+ if c == nil {
+ return
+ }
+ interval := 10 * time.Second
+ if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 {
+ interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second
+ }
+ if time.Since(*lastPing) < interval {
+ return
+ }
+ if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil {
+ if flusher, ok := c.Writer.(http.Flusher); ok {
+ flusher.Flush()
+ }
+ *lastPing = time.Now()
+ }
+}
+
+func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string {
+ if len(urls) == 0 {
+ return urls
+ }
+ output := make([]string, 0, len(urls))
+ for _, raw := range urls {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ continue
+ }
+ if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
+ output = append(output, raw)
+ continue
+ }
+ pathVal := raw
+ if !strings.HasPrefix(pathVal, "/") {
+ pathVal = "/" + pathVal
+ }
+ output = append(output, s.buildSoraMediaURL(pathVal, ""))
+ }
+ return output
+}
+
+// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符,
+// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。
+func jsonMarshalRaw(v any) ([]byte, error) {
+ var buf bytes.Buffer
+ enc := json.NewEncoder(&buf)
+ enc.SetEscapeHTML(false)
+ if err := enc.Encode(v); err != nil {
+ return nil, err
+ }
+ // Encode 会追加换行符,去掉它
+ b := buf.Bytes()
+ if len(b) > 0 && b[len(b)-1] == '\n' {
+ b = b[:len(b)-1]
+ }
+ return b, nil
+}
+
+func buildSoraContent(mediaType string, urls []string) string {
+ switch mediaType {
+ case "image":
+ parts := make([]string, 0, len(urls))
+ for _, u := range urls {
+ parts = append(parts, fmt.Sprintf("", u))
+ }
+ return strings.Join(parts, "\n")
+ case "video":
+ if len(urls) == 0 {
+ return ""
+ }
+ return fmt.Sprintf("```html\n\n```", urls[0])
+ default:
+ return ""
+ }
+}
+
+func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) {
+ if body == nil {
+ return "", "", "", ""
+ }
+ if v, ok := body["remix_target_id"].(string); ok {
+ remixTargetID = strings.TrimSpace(v)
+ }
+ if v, ok := body["image"].(string); ok {
+ imageInput = v
+ }
+ if v, ok := body["video"].(string); ok {
+ videoInput = v
+ }
+ if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" {
+ prompt = v
+ }
+ if messages, ok := body["messages"].([]any); ok {
+ builder := strings.Builder{}
+ for _, raw := range messages {
+ msg, ok := raw.(map[string]any)
+ if !ok {
+ continue
+ }
+ role, _ := msg["role"].(string)
+ if role != "" && role != "user" {
+ continue
+ }
+ content := msg["content"]
+ text, img, vid := parseSoraMessageContent(content)
+ if text != "" {
+ if builder.Len() > 0 {
+ _, _ = builder.WriteString("\n")
+ }
+ _, _ = builder.WriteString(text)
+ }
+ if imageInput == "" && img != "" {
+ imageInput = img
+ }
+ if videoInput == "" && vid != "" {
+ videoInput = vid
+ }
+ }
+ if prompt == "" {
+ prompt = builder.String()
+ }
+ }
+ if remixTargetID == "" {
+ remixTargetID = extractRemixTargetIDFromPrompt(prompt)
+ }
+ prompt = cleanRemixLinkFromPrompt(prompt)
+ return prompt, imageInput, videoInput, remixTargetID
+}
+
+func parseSoraMessageContent(content any) (text, imageInput, videoInput string) {
+ switch val := content.(type) {
+ case string:
+ return val, "", ""
+ case []any:
+ builder := strings.Builder{}
+ for _, item := range val {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ t, _ := itemMap["type"].(string)
+ switch t {
+ case "text":
+ if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" {
+ if builder.Len() > 0 {
+ _, _ = builder.WriteString("\n")
+ }
+ _, _ = builder.WriteString(txt)
+ }
+ case "image_url":
+ if imageInput == "" {
+ if urlVal, ok := itemMap["image_url"].(map[string]any); ok {
+ imageInput = fmt.Sprintf("%v", urlVal["url"])
+ } else if urlStr, ok := itemMap["image_url"].(string); ok {
+ imageInput = urlStr
+ }
+ }
+ case "video_url":
+ if videoInput == "" {
+ if urlVal, ok := itemMap["video_url"].(map[string]any); ok {
+ videoInput = fmt.Sprintf("%v", urlVal["url"])
+ } else if urlStr, ok := itemMap["video_url"].(string); ok {
+ videoInput = urlStr
+ }
+ }
+ }
+ }
+ return builder.String(), imageInput, videoInput
+ default:
+ return "", "", ""
+ }
+}
+
+func isSoraStoryboardPrompt(prompt string) bool {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return false
+ }
+ return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
+}
+
+func formatSoraStoryboardPrompt(prompt string) string {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return ""
+ }
+ matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
+ if len(matches) == 0 {
+ return prompt
+ }
+ firstBracketPos := strings.Index(prompt, "[")
+ instructions := ""
+ if firstBracketPos > 0 {
+ instructions = strings.TrimSpace(prompt[:firstBracketPos])
+ }
+ shots := make([]string, 0, len(matches))
+ for i, match := range matches {
+ if len(match) < 3 {
+ continue
+ }
+ duration := strings.TrimSpace(match[1])
+ scene := strings.TrimSpace(match[2])
+ if scene == "" {
+ continue
+ }
+ shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
+ }
+ if len(shots) == 0 {
+ return prompt
+ }
+ timeline := strings.Join(shots, "\n\n")
+ if instructions == "" {
+ return timeline
+ }
+ return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
+}
+
+func extractRemixTargetIDFromPrompt(prompt string) string {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return ""
+ }
+ return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
+}
+
+func cleanRemixLinkFromPrompt(prompt string) string {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return prompt
+ }
+ cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
+ cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
+ cleaned = strings.Join(strings.Fields(cleaned), " ")
+ return strings.TrimSpace(cleaned)
+}
+
+func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
+ raw := strings.TrimSpace(input)
+ if raw == "" {
+ return nil, "", errors.New("empty image input")
+ }
+ if strings.HasPrefix(raw, "data:") {
+ parts := strings.SplitN(raw, ",", 2)
+ if len(parts) != 2 {
+ return nil, "", errors.New("invalid data url")
+ }
+ meta := parts[0]
+ payload := parts[1]
+ decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
+ if err != nil {
+ return nil, "", err
+ }
+ ext := ""
+ if strings.HasPrefix(meta, "data:") {
+ metaParts := strings.SplitN(meta[5:], ";", 2)
+ if len(metaParts) > 0 {
+ if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 {
+ ext = exts[0]
+ }
+ }
+ }
+ filename := "image" + ext
+ return decoded, filename, nil
+ }
+ if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
+ return downloadSoraImageInput(ctx, raw)
+ }
+ decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
+ if err != nil {
+ return nil, "", errors.New("invalid base64 image")
+ }
+ return decoded, "image.png", nil
+}
+
+func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
+ raw := strings.TrimSpace(input)
+ if raw == "" {
+ return nil, errors.New("empty video input")
+ }
+ if strings.HasPrefix(raw, "data:") {
+ parts := strings.SplitN(raw, ",", 2)
+ if len(parts) != 2 {
+ return nil, errors.New("invalid video data url")
+ }
+ decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
+ if err != nil {
+ return nil, errors.New("invalid base64 video")
+ }
+ if len(decoded) == 0 {
+ return nil, errors.New("empty video data")
+ }
+ return decoded, nil
+ }
+ if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
+ return downloadSoraVideoInput(ctx, raw)
+ }
+ decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
+ if err != nil {
+ return nil, errors.New("invalid base64 video")
+ }
+ if len(decoded) == 0 {
+ return nil, errors.New("empty video data")
+ }
+ return decoded, nil
+}
+
+func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
+ parsed, err := validateSoraRemoteURL(rawURL)
+ if err != nil {
+ return nil, "", err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
+ if err != nil {
+ return nil, "", err
+ }
+ client := &http.Client{
+ Timeout: soraImageInputTimeout,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ if len(via) >= soraImageInputMaxRedirects {
+ return errors.New("too many redirects")
+ }
+ return validateSoraRemoteURLValue(req.URL)
+ },
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, "", err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode != http.StatusOK {
+ return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
+ }
+ data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes))
+ if err != nil {
+ return nil, "", err
+ }
+ ext := fileExtFromURL(parsed.String())
+ if ext == "" {
+ ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
+ }
+ filename := "image" + ext
+ return data, filename, nil
+}
+
+func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
+ parsed, err := validateSoraRemoteURL(rawURL)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
+ if err != nil {
+ return nil, err
+ }
+ client := &http.Client{
+ Timeout: soraVideoInputTimeout,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ if len(via) >= soraVideoInputMaxRedirects {
+ return errors.New("too many redirects")
+ }
+ return validateSoraRemoteURLValue(req.URL)
+ },
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
+ }
+ data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
+ if err != nil {
+ return nil, err
+ }
+ if len(data) == 0 {
+ return nil, errors.New("empty video content")
+ }
+ return data, nil
+}
+
+func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
+ if maxBytes <= 0 {
+ return nil, errors.New("invalid max bytes limit")
+ }
+ decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
+ limited := io.LimitReader(decoder, maxBytes+1)
+ data, err := io.ReadAll(limited)
+ if err != nil {
+ return nil, err
+ }
+ if int64(len(data)) > maxBytes {
+ return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
+ }
+ return data, nil
+}
+
+func validateSoraRemoteURL(raw string) (*url.URL, error) {
+ if strings.TrimSpace(raw) == "" {
+ return nil, errors.New("empty remote url")
+ }
+ parsed, err := url.Parse(raw)
+ if err != nil {
+ return nil, fmt.Errorf("invalid remote url: %w", err)
+ }
+ if err := validateSoraRemoteURLValue(parsed); err != nil {
+ return nil, err
+ }
+ return parsed, nil
+}
+
+func validateSoraRemoteURLValue(parsed *url.URL) error {
+ if parsed == nil {
+ return errors.New("invalid remote url")
+ }
+ scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
+ if scheme != "http" && scheme != "https" {
+ return errors.New("only http/https remote url is allowed")
+ }
+ if parsed.User != nil {
+ return errors.New("remote url cannot contain userinfo")
+ }
+ host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
+ if host == "" {
+ return errors.New("remote url missing host")
+ }
+ if _, blocked := soraBlockedHostnames[host]; blocked {
+ return errors.New("remote url is not allowed")
+ }
+ if ip := net.ParseIP(host); ip != nil {
+ if isSoraBlockedIP(ip) {
+ return errors.New("remote url is not allowed")
+ }
+ return nil
+ }
+ ips, err := net.LookupIP(host)
+ if err != nil {
+ return fmt.Errorf("resolve remote url failed: %w", err)
+ }
+ for _, ip := range ips {
+ if isSoraBlockedIP(ip) {
+ return errors.New("remote url is not allowed")
+ }
+ }
+ return nil
+}
+
+func isSoraBlockedIP(ip net.IP) bool {
+ if ip == nil {
+ return true
+ }
+ for _, cidr := range soraBlockedCIDRs {
+ if cidr.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+func mustParseCIDRs(values []string) []*net.IPNet {
+ out := make([]*net.IPNet, 0, len(values))
+ for _, val := range values {
+ _, cidr, err := net.ParseCIDR(val)
+ if err != nil {
+ continue
+ }
+ out = append(out, cidr)
+ }
+ return out
+}
diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go
new file mode 100644
index 00000000..2fef600c
--- /dev/null
+++ b/backend/internal/service/sora_gateway_service_test.go
@@ -0,0 +1,564 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+var _ SoraClient = (*stubSoraClientForPoll)(nil)
+
+type stubSoraClientForPoll struct {
+ imageStatus *SoraImageTaskStatus
+ videoStatus *SoraVideoTaskStatus
+ imageCalls int
+ videoCalls int
+ enhanced string
+ enhanceErr error
+ storyboard bool
+ videoReq SoraVideoRequest
+ parseErr error
+ postCalls int
+ deleteCalls int
+}
+
+func (s *stubSoraClientForPoll) Enabled() bool { return true }
+func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
+ return "task-image", nil
+}
+func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
+ s.videoReq = req
+ return "task-video", nil
+}
+func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
+ s.storyboard = true
+ return "task-video", nil
+}
+func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
+ return "cameo-1", nil
+}
+func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
+ return &SoraCameoStatus{
+ Status: "finalized",
+ StatusMessage: "Completed",
+ DisplayNameHint: "Character",
+ UsernameHint: "user.character",
+ ProfileAssetURL: "https://example.com/avatar.webp",
+ }, nil
+}
+func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
+ return []byte("avatar"), nil
+}
+func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
+ return "asset-pointer", nil
+}
+func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
+ return "character-1", nil
+}
+func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
+ return nil
+}
+func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
+ return nil
+}
+func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
+ s.postCalls++
+ return "s_post", nil
+}
+func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
+ s.deleteCalls++
+ return nil
+}
+func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
+ if s.parseErr != nil {
+ return "", s.parseErr
+ }
+ return "https://example.com/no-watermark.mp4", nil
+}
+func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
+ if s.enhanced != "" {
+ return s.enhanced, s.enhanceErr
+ }
+ return "enhanced prompt", s.enhanceErr
+}
+func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
+ s.imageCalls++
+ return s.imageStatus, nil
+}
+func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
+ s.videoCalls++
+ return s.videoStatus, nil
+}
+
+func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ imageStatus: &SoraImageTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/a.png"},
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ service := NewSoraGatewayService(client, nil, nil, cfg)
+
+ urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false)
+ require.NoError(t, err)
+ require.Equal(t, []string{"https://example.com/a.png"}, urls)
+ require.Equal(t, 1, client.imageCalls)
+}
+
+func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ enhanced: "cinematic prompt",
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Credentials: map[string]any{
+ "model_mapping": map[string]any{
+ "prompt-enhance-short-10s": "prompt-enhance-short-15s",
+ },
+ },
+ }
+ body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "prompt", result.MediaType)
+ require.Equal(t, "prompt-enhance-short-10s", result.Model)
+ require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
+}
+
+func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/v.mp4"},
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.True(t, client.storyboard)
+}
+
+func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/v.mp4"},
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 3, client.videoReq.VideoCount)
+}
+
+func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
+ client := &stubSoraClientForPoll{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "prompt", result.MediaType)
+ require.Equal(t, 0, client.videoCalls)
+}
+
+func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/original.mp4"},
+ GenerationID: "gen_1",
+ },
+ parseErr: errors.New("parse failed"),
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
+ require.Equal(t, 1, client.postCalls)
+ require.Equal(t, 0, client.deleteCalls)
+}
+
+func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/original.mp4"},
+ GenerationID: "gen_1",
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
+ require.Equal(t, 1, client.postCalls)
+ require.Equal(t, 1, client.deleteCalls)
+}
+
+func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "failed",
+ ErrorMsg: "reject",
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ service := NewSoraGatewayService(client, nil, nil, cfg)
+
+ status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
+ require.Error(t, err)
+ require.Nil(t, status)
+ require.Contains(t, err.Error(), "reject")
+ require.Equal(t, 1, client.videoCalls)
+}
+
+func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ SoraMediaSigningKey: "test-key",
+ SoraMediaSignedURLTTLSeconds: 600,
+ },
+ }
+ service := NewSoraGatewayService(nil, nil, nil, cfg)
+
+ url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "")
+ require.Contains(t, url, "/sora/media-signed")
+ require.Contains(t, url, "expires=")
+ require.Contains(t, url, "sig=")
+}
+
+func TestNormalizeSoraMediaURLs_Empty(t *testing.T) {
+ svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
+ result := svc.normalizeSoraMediaURLs(nil)
+ require.Empty(t, result)
+
+ result = svc.normalizeSoraMediaURLs([]string{})
+ require.Empty(t, result)
+}
+
+func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) {
+ svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
+ urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"}
+ result := svc.normalizeSoraMediaURLs(urls)
+ require.Equal(t, urls, result)
+}
+
+func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) {
+ cfg := &config.Config{}
+ svc := NewSoraGatewayService(nil, nil, nil, cfg)
+ urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"}
+ result := svc.normalizeSoraMediaURLs(urls)
+ require.Len(t, result, 2)
+ require.Contains(t, result[0], "/sora/media")
+ require.Contains(t, result[1], "/sora/media")
+}
+
+func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) {
+ svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
+ urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"}
+ result := svc.normalizeSoraMediaURLs(urls)
+ require.Len(t, result, 2)
+}
+
+func TestBuildSoraContent_Image(t *testing.T) {
+ content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"})
+ require.Contains(t, content, "")
+ require.Contains(t, content, "")
+}
+
+func TestBuildSoraContent_Video(t *testing.T) {
+ content := buildSoraContent("video", []string{"https://a.com/v.mp4"})
+ require.Contains(t, content, "