package provision import ( "context" "errors" "fmt" "strings" "sub2api-cn-relay-manager/internal/access" "sub2api-cn-relay-manager/internal/host/sub2api" "sub2api-cn-relay-manager/internal/pack" ) const ( ImportModeStrict = "strict" ImportModePartial = "partial" AccessModeSubscription = "subscription" AccessModeSelfService = "self_service" BatchStatusSucceeded = "succeeded" BatchStatusPartial = "partially_succeeded" BatchStatusFailed = "failed" BatchStatusRolledBack = "rolled_back" ProviderStatusActive = "active" ProviderStatusDegraded = "degraded" ProviderStatusFailed = "failed" AccessStatusSubscriptionReady = "subscription_ready" AccessStatusSelfServiceReady = "self_service_ready" AccessStatusFullyReady = "fully_ready" AccessStatusBroken = "broken" ) type AccessRequest struct { Mode string ProbeAPIKey string Subscriptions []SubscriptionTarget } type SubscriptionTarget struct { UserID string DurationDays int } type ImportRequest struct { Provider pack.ProviderManifest Mode string Access AccessRequest Keys []string } type ImportReport struct { BatchStatus string ProviderStatus string AccessStatus string AcceptedKeys []string Group sub2api.GroupRef Channel sub2api.ChannelRef Plan *sub2api.PlanRef Accounts []AccountImportResult Gateway sub2api.GatewayAccessResult } type AccountImportResult struct { Ref sub2api.AccountRef Probe sub2api.ProbeResult Models []sub2api.AccountModel SmokeModelSeen bool } type hostAdapter interface { sub2api.HostAdapter CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error) } type resolvedManagedResources struct { Group sub2api.GroupRef Channel sub2api.ChannelRef Plan *sub2api.PlanRef CreatedGroup bool CreatedChannel bool CreatedPlan bool } type ImportService struct { host hostAdapter } func NewImportService(host hostAdapter) *ImportService { return &ImportService{host: host} } func (s *ImportService) Import(ctx context.Context, req ImportRequest) (report ImportReport, err error) { normalizedKeys, err := normalizeKeys(req.Keys) if err != nil { return ImportReport{}, err } if err := validateMode(req.Mode); err != nil { return ImportReport{}, err } if err := access.Validate(access.ClosureRequest{ Mode: req.Access.Mode, ProbeAPIKey: req.Access.ProbeAPIKey, Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions), }); err != nil { return ImportReport{}, err } report = ImportReport{AcceptedKeys: normalizedKeys} rollback := newManagedResourceRollback(s.host) defer func() { if err == nil || req.Mode != ImportModeStrict { return } if rollbackErr := rollback.Run(ctx); rollbackErr != nil { err = errors.Join(err, fmt.Errorf("rollback managed resources: %w", rollbackErr)) } }() resources, err := s.ensureManagedResources(ctx, req.Provider, req.Access.Mode) if err != nil { return report, err } report.Group = resources.Group report.Channel = resources.Channel report.Plan = resources.Plan if resources.CreatedGroup { rollback.AddGroup(resources.Group.ID) } if resources.CreatedChannel { rollback.AddChannel(resources.Channel.ID) } if resources.CreatedPlan && resources.Plan != nil { rollback.AddPlan(resources.Plan.ID) } accounts, err := s.host.BatchCreateAccounts(ctx, buildBatchAccountsRequest(req.Provider, resources.Group.ID, normalizedKeys)) if err != nil { return report, fmt.Errorf("batch create accounts: %w", err) } rollback.AddAccounts(accounts) for _, account := range accounts { probe, err := s.host.TestAccount(ctx, account.ID) if err != nil { return failOrDegrade(report, req.Mode, fmt.Errorf("test account %s: %w", account.ID, err)) } models, err := s.host.GetAccountModels(ctx, account.ID) if err != nil { return failOrDegrade(report, req.Mode, fmt.Errorf("get account models %s: %w", account.ID, err)) } result := AccountImportResult{Ref: account, Probe: probe, Models: models, SmokeModelSeen: hasModel(models, req.Provider.SmokeTestModel)} report.Accounts = append(report.Accounts, result) } failedAccounts := 0 for _, account := range report.Accounts { if !account.Probe.OK || !account.SmokeModelSeen { failedAccounts++ } } if failedAccounts > 0 && req.Mode == ImportModeStrict { report.BatchStatus = BatchStatusFailed report.ProviderStatus = ProviderStatusFailed report.AccessStatus = AccessStatusBroken return report, fmt.Errorf("strict import failed: %d account(s) did not pass smoke validation", failedAccounts) } closureService := access.NewService(s.host) gateway, err := closureService.Close(ctx, access.ClosureRequest{ Mode: req.Access.Mode, ProbeAPIKey: req.Access.ProbeAPIKey, Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions), GroupID: resources.Group.ID, ExpectedModel: req.Provider.SmokeTestModel, }) if err != nil { return failOrDegrade(report, req.Mode, err) } report.Gateway = gateway report.BatchStatus = BatchStatusSucceeded report.ProviderStatus = ProviderStatusActive if failedAccounts > 0 || !gateway.OK || !gateway.HasExpectedModel { report.BatchStatus = BatchStatusPartial report.ProviderStatus = ProviderStatusDegraded } switch req.Access.Mode { case AccessModeSubscription: report.AccessStatus = AccessStatusSubscriptionReady case AccessModeSelfService: report.AccessStatus = AccessStatusSelfServiceReady } if !gateway.OK || !gateway.HasExpectedModel { report.AccessStatus = AccessStatusBroken } return report, nil } func (s *ImportService) ensureManagedResources(ctx context.Context, provider pack.ProviderManifest, accessMode string) (resolvedManagedResources, error) { names := SuggestResourceNamesForMode(provider, accessMode) snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{ GroupName: names.Group, ChannelName: names.Channel, PlanName: names.Plan, }) if err != nil { return resolvedManagedResources{}, fmt.Errorf("list managed resources: %w", err) } result := resolvedManagedResources{} group, created, err := ensureGroup(ctx, s.host, snapshot.Groups, provider, accessMode, names.Group) if err != nil { return resolvedManagedResources{}, fmt.Errorf("ensure group: %w", err) } result.Group = group result.CreatedGroup = created channel, created, err := ensureChannel(ctx, s.host, snapshot.Channels, provider, group.ID, names.Channel) if err != nil { return resolvedManagedResources{}, fmt.Errorf("ensure channel: %w", err) } result.Channel = channel result.CreatedChannel = created if accessMode == AccessModeSubscription { plan, created, err := ensurePlan(ctx, s.host, snapshot.Plans, provider, group.ID, names.Plan) if err != nil { return resolvedManagedResources{}, fmt.Errorf("ensure plan: %w", err) } result.Plan = &plan result.CreatedPlan = created } return result, nil } func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, accessMode, groupName string) (sub2api.GroupRef, bool, error) { switch len(existing) { case 0: groupReq := sub2api.CreateGroupRequest{Name: groupName, Platform: provider.Platform, RateMultiplier: provider.GroupTemplate.RateMultiplier} if accessMode == AccessModeSubscription { groupReq.SubscriptionType = "subscription" } group, err := host.CreateGroup(ctx, groupReq) return group, true, err case 1: return sub2api.GroupRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil default: return sub2api.GroupRef{}, false, fmt.Errorf("multiple groups already exist for %q", groupName) } } func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, channelName string) (sub2api.ChannelRef, bool, error) { channelReq := buildChannelRequest(provider, groupID, channelName) switch len(existing) { case 0: channel, err := host.CreateChannel(ctx, channelReq) return channel, true, err case 1: if err := host.UpdateChannel(ctx, existing[0].ID, channelReq); err != nil { return sub2api.ChannelRef{}, false, err } return sub2api.ChannelRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil default: return sub2api.ChannelRef{}, false, fmt.Errorf("multiple channels already exist for %q", channelName) } } func buildChannelRequest(provider pack.ProviderManifest, groupID, channelName string) sub2api.CreateChannelRequest { return sub2api.CreateChannelRequest{ Name: channelName, GroupIDs: []string{groupID}, ModelMapping: provider.ChannelTemplate.ModelMapping, ModelPricing: []sub2api.ChannelModelPricing{{ Platform: provider.Platform, Models: append([]string(nil), provider.DefaultModels...), BillingMode: "token", Intervals: []sub2api.ChannelPricingTier{}, }}, Platform: provider.Platform, RestrictModels: true, BillingModelSource: "channel_mapped", } } func ensurePlan(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, planName string) (sub2api.PlanRef, bool, error) { switch len(existing) { case 0: plan, err := host.CreatePlan(ctx, sub2api.CreatePlanRequest{GroupID: groupID, Name: planName, Price: provider.PlanTemplate.Price, ValidityDays: provider.PlanTemplate.ValidityDays, ValidityUnit: provider.PlanTemplate.ValidityUnit}) return plan, true, err case 1: return sub2api.PlanRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil default: return sub2api.PlanRef{}, false, fmt.Errorf("multiple plans already exist for %q", planName) } } func validateMode(mode string) error { switch strings.TrimSpace(mode) { case ImportModeStrict, ImportModePartial: return nil default: return fmt.Errorf("unsupported import mode %q", mode) } } func toAccessSubscriptionTargets(targets []SubscriptionTarget) []access.SubscriptionTarget { result := make([]access.SubscriptionTarget, 0, len(targets)) for _, target := range targets { result = append(result, access.SubscriptionTarget{UserID: target.UserID, DurationDays: target.DurationDays}) } return result } func normalizeKeys(keys []string) ([]string, error) { seen := map[string]struct{}{} result := make([]string, 0, len(keys)) for _, key := range keys { normalized := strings.TrimSpace(strings.TrimPrefix(key, "\ufeff")) if normalized == "" { continue } if _, ok := seen[normalized]; ok { continue } seen[normalized] = struct{}{} result = append(result, normalized) } if len(result) == 0 { return nil, fmt.Errorf("at least one api key is required") } return result, nil } func buildBatchAccountsRequest(provider pack.ProviderManifest, groupID string, keys []string) sub2api.BatchCreateAccountsRequest { accounts := make([]sub2api.CreateAccountRequest, 0, len(keys)) for index, key := range keys { accounts = append(accounts, sub2api.CreateAccountRequest{ Name: fmt.Sprintf("%s-%02d", provider.ProviderID, index+1), Platform: provider.Platform, Type: provider.AccountType, GroupIDs: []string{groupID}, Credentials: map[string]any{ "base_url": provider.BaseURL, "api_key": key, "model_mapping": provider.ChannelTemplate.ModelMapping, }, }) } return sub2api.BatchCreateAccountsRequest{Accounts: accounts} } func hasModel(models []sub2api.AccountModel, target string) bool { for _, model := range models { if strings.TrimSpace(model.ID) == strings.TrimSpace(target) { return true } } return false } type managedResourceRollback struct { host hostAdapter groupID string channelID string planID string accountIDs []string } func newManagedResourceRollback(host hostAdapter) *managedResourceRollback { return &managedResourceRollback{host: host} } func (r *managedResourceRollback) AddGroup(groupID string) { r.groupID = strings.TrimSpace(groupID) } func (r *managedResourceRollback) AddChannel(channelID string) { r.channelID = strings.TrimSpace(channelID) } func (r *managedResourceRollback) AddPlan(planID string) { r.planID = strings.TrimSpace(planID) } func (r *managedResourceRollback) AddAccounts(accounts []sub2api.AccountRef) { for _, account := range accounts { accountID := strings.TrimSpace(account.ID) if accountID == "" { continue } r.accountIDs = append(r.accountIDs, accountID) } } func (r *managedResourceRollback) Run(ctx context.Context) error { if r == nil || r.host == nil { return nil } var errs []error for index := len(r.accountIDs) - 1; index >= 0; index-- { if err := r.host.DeleteAccount(ctx, r.accountIDs[index]); err != nil { errs = append(errs, fmt.Errorf("delete account %s: %w", r.accountIDs[index], err)) } } if r.planID != "" { if err := r.host.DeletePlan(ctx, r.planID); err != nil { errs = append(errs, fmt.Errorf("delete plan %s: %w", r.planID, err)) } } if r.channelID != "" { if err := r.host.DeleteChannel(ctx, r.channelID); err != nil { errs = append(errs, fmt.Errorf("delete channel %s: %w", r.channelID, err)) } } if r.groupID != "" { if err := r.host.DeleteGroup(ctx, r.groupID); err != nil { errs = append(errs, fmt.Errorf("delete group %s: %w", r.groupID, err)) } } return errors.Join(errs...) } func failOrDegrade(report ImportReport, mode string, err error) (ImportReport, error) { if mode == ImportModeStrict { report.BatchStatus = BatchStatusFailed report.ProviderStatus = ProviderStatusFailed report.AccessStatus = AccessStatusBroken return report, err } report.BatchStatus = BatchStatusPartial report.ProviderStatus = ProviderStatusDegraded report.AccessStatus = AccessStatusBroken return report, err }