Files
sub2api-cn-relay-manager/internal/provision/import_service.go

438 lines
14 KiB
Go

package provision
import (
"context"
"errors"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/access"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
)
const (
// Import modes accepted by validateMode. In strict mode Import rolls back the
// resources it created whenever it returns an error; in partial mode it does not.
ImportModeStrict = "strict"
ImportModePartial = "partial"
// Access modes. In subscription mode a plan is provisioned in addition to the
// group and channel, and a newly created group is given the "subscription"
// subscription type.
AccessModeSubscription = "subscription"
AccessModeSelfService = "self_service"
// Batch-level outcomes reported in ImportReport.BatchStatus.
BatchStatusSucceeded = "succeeded"
BatchStatusPartial = "partially_succeeded"
BatchStatusFailed = "failed"
// NOTE(review): Import never assigns BatchStatusRolledBack itself; presumably a
// caller sets it after a strict-mode rollback — confirm.
BatchStatusRolledBack = "rolled_back"
// Provider-level outcomes reported in ImportReport.ProviderStatus.
ProviderStatusActive = "active"
ProviderStatusDegraded = "degraded"
ProviderStatusFailed = "failed"
// Access closure outcomes reported in ImportReport.AccessStatus.
AccessStatusSubscriptionReady = "subscription_ready"
AccessStatusSelfServiceReady = "self_service_ready"
// NOTE(review): AccessStatusFullyReady is not assigned by Import; presumably
// used by code outside this view — confirm.
AccessStatusFullyReady = "fully_ready"
AccessStatusBroken = "broken"
)
// AccessRequest describes how the imported provider should be exposed. Import
// forwards its fields to access.Validate and to the access closure check.
type AccessRequest struct {
// Mode is compared against AccessModeSubscription and AccessModeSelfService.
Mode string
// ProbeAPIKey is handed to the access closure check.
ProbeAPIKey string
// Subscriptions lists the targets handed to the access closure check.
Subscriptions []SubscriptionTarget
}
// SubscriptionTarget names one user and a duration; it is converted
// field-by-field into access.SubscriptionTarget before being used.
type SubscriptionTarget struct {
// UserID identifies the target user (semantics defined by the access package).
UserID string
// DurationDays is presumably the subscription length in days — confirm
// against the access package.
DurationDays int
}
// ImportRequest is the input of ImportService.Import.
type ImportRequest struct {
// Provider is the manifest whose templates drive group, channel, plan and
// account creation.
Provider pack.ProviderManifest
// Mode must be ImportModeStrict or ImportModePartial.
Mode string
// Access selects the access mode and carries the access closure inputs.
Access AccessRequest
// Keys are the raw API keys; they are normalized and de-duplicated before use.
Keys []string
}
// ImportReport is the result of ImportService.Import. It may be only partially
// populated when Import also returns an error.
type ImportReport struct {
// BatchStatus, ProviderStatus and AccessStatus hold the BatchStatus*,
// ProviderStatus* and AccessStatus* constants respectively.
BatchStatus string
ProviderStatus string
AccessStatus string
// AcceptedKeys are the normalized, de-duplicated API keys that were imported.
AcceptedKeys []string
// Group and Channel are the managed resources that were reused or created.
Group sub2api.GroupRef
Channel sub2api.ChannelRef
// Plan is nil unless the access mode is AccessModeSubscription.
Plan *sub2api.PlanRef
// Accounts holds one entry per account that was probed successfully enough to
// obtain both a probe result and a model list.
Accounts []AccountImportResult
// Gateway is the result returned by the access closure check.
Gateway sub2api.GatewayAccessResult
}
// AccountImportResult records what Import learned about one created account.
type AccountImportResult struct {
// Ref is the account reference returned by BatchCreateAccounts.
Ref sub2api.AccountRef
// Probe is the result of TestAccount for this account.
Probe sub2api.ProbeResult
// Models is the list returned by GetAccountModels for this account.
Models []sub2api.AccountModel
// SmokeModelSeen reports whether the provider's SmokeTestModel appears in Models.
SmokeModelSeen bool
}
// hostAdapter is the host surface ImportService depends on: everything in
// sub2api.HostAdapter plus a gateway access check.
type hostAdapter interface {
sub2api.HostAdapter
// CheckGatewayAccess is not called directly in this file; presumably it is
// required by access.NewService, which receives the host — confirm.
CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
}
// resolvedManagedResources is the outcome of ensureManagedResources: the
// group, channel and optional plan in use, plus flags telling which of them
// were created by this call rather than found on the host.
type resolvedManagedResources struct {
Group sub2api.GroupRef
Channel sub2api.ChannelRef
// Plan is only set in subscription access mode.
Plan *sub2api.PlanRef
// Created* flags decide which resources Import registers for rollback.
CreatedGroup bool
CreatedChannel bool
CreatedPlan bool
}
// ImportService imports provider API keys into the host as managed accounts.
type ImportService struct {
// host performs every remote operation (resource lookup, creation, probing,
// deletion during rollback).
host hostAdapter
}
// NewImportService returns an ImportService that talks to the given host.
func NewImportService(host hostAdapter) *ImportService {
	service := new(ImportService)
	service.host = host
	return service
}
// Import provisions everything needed to serve req.Provider through the host:
// it resolves or creates the managed group, channel and (in subscription
// access mode) plan, submits one account per normalized API key, probes each
// account, and finally runs the access closure check.
//
// Validation failures return an empty report. After validation the report is
// returned alongside any error and may be only partially populated. In strict
// mode every error triggers a rollback of the resources this call created; a
// rollback failure is joined onto the returned error. In partial mode nothing
// is rolled back.
func (s *ImportService) Import(ctx context.Context, req ImportRequest) (report ImportReport, err error) {
	normalizedKeys, err := normalizeKeys(req.Keys)
	if err != nil {
		return ImportReport{}, err
	}
	if err := validateMode(req.Mode); err != nil {
		return ImportReport{}, err
	}
	// validateMode tolerates surrounding whitespace, so every later comparison
	// must use the trimmed value; otherwise " strict " would pass validation
	// and then silently behave like a non-strict import (no rollback).
	mode := strings.TrimSpace(req.Mode)
	if err := access.Validate(access.ClosureRequest{
		Mode:          req.Access.Mode,
		ProbeAPIKey:   req.Access.ProbeAPIKey,
		Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
	}); err != nil {
		return ImportReport{}, err
	}

	report = ImportReport{AcceptedKeys: normalizedKeys}
	rollback := newManagedResourceRollback(s.host)
	// The deferred function reads the named result err, so it observes the
	// error of whichever return statement ends the function.
	defer func() {
		if err == nil || mode != ImportModeStrict {
			return
		}
		// NOTE(review): the rollback reuses ctx; if the import failed because
		// ctx was cancelled, these delete calls will presumably fail as well.
		if rollbackErr := rollback.Run(ctx); rollbackErr != nil {
			err = errors.Join(err, fmt.Errorf("rollback managed resources: %w", rollbackErr))
		}
	}()

	resources, err := s.ensureManagedResources(ctx, req.Provider, req.Access.Mode)
	// Register created resources before looking at err: ensureManagedResources
	// reports what it created even when a later step failed, and those
	// resources have to be rolled back too.
	if resources.CreatedGroup {
		rollback.AddGroup(resources.Group.ID)
	}
	if resources.CreatedChannel {
		rollback.AddChannel(resources.Channel.ID)
	}
	if resources.CreatedPlan && resources.Plan != nil {
		rollback.AddPlan(resources.Plan.ID)
	}
	if err != nil {
		return report, err
	}
	report.Group = resources.Group
	report.Channel = resources.Channel
	report.Plan = resources.Plan

	accounts, err := s.host.BatchCreateAccounts(ctx, buildBatchAccountsRequest(req.Provider, resources.Group.ID, normalizedKeys))
	if err != nil {
		return report, fmt.Errorf("batch create accounts: %w", err)
	}
	rollback.AddAccounts(accounts)

	for _, account := range accounts {
		probe, err := s.host.TestAccount(ctx, account.ID)
		if err != nil {
			return failOrDegrade(report, mode, fmt.Errorf("test account %s: %w", account.ID, err))
		}
		models, err := s.host.GetAccountModels(ctx, account.ID)
		if err != nil {
			return failOrDegrade(report, mode, fmt.Errorf("get account models %s: %w", account.ID, err))
		}
		report.Accounts = append(report.Accounts, AccountImportResult{
			Ref:            account,
			Probe:          probe,
			Models:         models,
			SmokeModelSeen: hasModel(models, req.Provider.SmokeTestModel),
		})
	}

	// An account fails smoke validation when its probe is not OK or the
	// provider's smoke-test model is missing from its model list.
	failedAccounts := 0
	for _, account := range report.Accounts {
		if !account.Probe.OK || !account.SmokeModelSeen {
			failedAccounts++
		}
	}
	if failedAccounts > 0 && mode == ImportModeStrict {
		report.BatchStatus = BatchStatusFailed
		report.ProviderStatus = ProviderStatusFailed
		report.AccessStatus = AccessStatusBroken
		return report, fmt.Errorf("strict import failed: %d account(s) did not pass smoke validation", failedAccounts)
	}

	closureService := access.NewService(s.host)
	gateway, err := closureService.Close(ctx, access.ClosureRequest{
		Mode:          req.Access.Mode,
		ProbeAPIKey:   req.Access.ProbeAPIKey,
		Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
		GroupID:       resources.Group.ID,
		ExpectedModel: req.Provider.SmokeTestModel,
	})
	if err != nil {
		return failOrDegrade(report, mode, err)
	}
	report.Gateway = gateway

	gatewayHealthy := gateway.OK && gateway.HasExpectedModel
	report.BatchStatus = BatchStatusSucceeded
	report.ProviderStatus = ProviderStatusActive
	if failedAccounts > 0 || !gatewayHealthy {
		report.BatchStatus = BatchStatusPartial
		report.ProviderStatus = ProviderStatusDegraded
	}
	switch req.Access.Mode {
	case AccessModeSubscription:
		report.AccessStatus = AccessStatusSubscriptionReady
	case AccessModeSelfService:
		report.AccessStatus = AccessStatusSelfServiceReady
	}
	if !gatewayHealthy {
		report.AccessStatus = AccessStatusBroken
	}
	return report, nil
}

// ensureManagedResources looks up the group, channel and plan named by
// SuggestResourceNamesForMode and creates whichever ones do not exist yet. A
// plan is only resolved in subscription access mode.
//
// On failure the returned value still describes the resources resolved before
// the failing step, including their Created* flags, so the caller can roll
// back anything this call created. Callers must therefore inspect the result
// even when the error is non-nil.
func (s *ImportService) ensureManagedResources(ctx context.Context, provider pack.ProviderManifest, accessMode string) (resolvedManagedResources, error) {
	names := SuggestResourceNamesForMode(provider, accessMode)
	snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
		GroupName:   names.Group,
		ChannelName: names.Channel,
		PlanName:    names.Plan,
	})
	if err != nil {
		return resolvedManagedResources{}, fmt.Errorf("list managed resources: %w", err)
	}

	result := resolvedManagedResources{}
	group, created, err := ensureGroup(ctx, s.host, snapshot.Groups, provider, accessMode, names.Group)
	if err != nil {
		// Nothing has been confirmed as created yet.
		return resolvedManagedResources{}, fmt.Errorf("ensure group: %w", err)
	}
	result.Group = group
	result.CreatedGroup = created

	channel, created, err := ensureChannel(ctx, s.host, snapshot.Channels, provider, group.ID, names.Channel)
	if err != nil {
		// Keep the group in the result so a freshly created one is not leaked.
		return result, fmt.Errorf("ensure channel: %w", err)
	}
	result.Channel = channel
	result.CreatedChannel = created

	if accessMode == AccessModeSubscription {
		plan, created, err := ensurePlan(ctx, s.host, snapshot.Plans, provider, group.ID, names.Plan)
		if err != nil {
			// Keep the group and channel in the result for the same reason.
			return result, fmt.Errorf("ensure plan: %w", err)
		}
		result.Plan = &plan
		result.CreatedPlan = created
	}
	return result, nil
}
// ensureGroup returns the single existing group, or creates one named
// groupName when none exists. The boolean reports whether a create call was
// made. More than one existing group is an error.
func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, accessMode, groupName string) (sub2api.GroupRef, bool, error) {
	if len(existing) > 1 {
		return sub2api.GroupRef{}, false, fmt.Errorf("multiple groups already exist for %q", groupName)
	}
	if len(existing) == 1 {
		found := existing[0]
		return sub2api.GroupRef{ID: found.ID, Name: found.Name}, false, nil
	}

	request := sub2api.CreateGroupRequest{
		Name:           groupName,
		Platform:       provider.Platform,
		RateMultiplier: provider.GroupTemplate.RateMultiplier,
	}
	// Only subscription-mode groups carry a subscription type.
	if accessMode == AccessModeSubscription {
		request.SubscriptionType = "subscription"
	}
	createdGroup, err := host.CreateGroup(ctx, request)
	return createdGroup, true, err
}
// ensureChannel creates the channel when none exists, or updates the single
// existing one in place with the request built from the provider manifest.
// The boolean reports whether a create call was made. More than one existing
// channel is an error.
func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, channelName string) (sub2api.ChannelRef, bool, error) {
	request := buildChannelRequest(provider, groupID, channelName)

	if len(existing) > 1 {
		return sub2api.ChannelRef{}, false, fmt.Errorf("multiple channels already exist for %q", channelName)
	}
	if len(existing) == 0 {
		createdChannel, err := host.CreateChannel(ctx, request)
		return createdChannel, true, err
	}

	// Exactly one match: bring it in line with the current manifest.
	found := existing[0]
	if err := host.UpdateChannel(ctx, found.ID, request); err != nil {
		return sub2api.ChannelRef{}, false, err
	}
	return sub2api.ChannelRef{ID: found.ID, Name: found.Name}, false, nil
}
// buildChannelRequest assembles the channel create/update payload for the
// provider: a model-restricted channel bound to groupID with a single
// token-billed pricing entry covering the provider's default models.
func buildChannelRequest(provider pack.ProviderManifest, groupID, channelName string) sub2api.CreateChannelRequest {
	// Copy the default models so the request does not alias the manifest slice.
	pricedModels := append([]string(nil), provider.DefaultModels...)
	pricing := sub2api.ChannelModelPricing{
		Platform:    provider.Platform,
		Models:      pricedModels,
		BillingMode: "token",
		Intervals:   []sub2api.ChannelPricingTier{},
	}

	request := sub2api.CreateChannelRequest{
		Name:               channelName,
		GroupIDs:           []string{groupID},
		ModelMapping:       provider.ChannelTemplate.ModelMapping,
		ModelPricing:       []sub2api.ChannelModelPricing{pricing},
		Platform:           provider.Platform,
		RestrictModels:     true,
		BillingModelSource: "channel_mapped",
	}
	return request
}
// ensurePlan returns the single existing plan, or creates one from the
// provider's plan template when none exists. The boolean reports whether a
// create call was made. More than one existing plan is an error.
func ensurePlan(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, planName string) (sub2api.PlanRef, bool, error) {
	if len(existing) > 1 {
		return sub2api.PlanRef{}, false, fmt.Errorf("multiple plans already exist for %q", planName)
	}
	if len(existing) == 1 {
		found := existing[0]
		return sub2api.PlanRef{ID: found.ID, Name: found.Name}, false, nil
	}

	request := sub2api.CreatePlanRequest{
		GroupID:      groupID,
		Name:         planName,
		Price:        provider.PlanTemplate.Price,
		ValidityDays: provider.PlanTemplate.ValidityDays,
		ValidityUnit: provider.PlanTemplate.ValidityUnit,
	}
	createdPlan, err := host.CreatePlan(ctx, request)
	return createdPlan, true, err
}
// validateMode accepts ImportModeStrict and ImportModePartial, ignoring
// surrounding whitespace, and rejects everything else. The error message
// quotes the mode exactly as it was passed in.
func validateMode(mode string) error {
	trimmed := strings.TrimSpace(mode)
	if trimmed == ImportModeStrict || trimmed == ImportModePartial {
		return nil
	}
	return fmt.Errorf("unsupported import mode %q", mode)
}
// toAccessSubscriptionTargets converts the provision-level targets into the
// access package's equivalent type, field by field. The result is always a
// non-nil slice, even for empty input.
func toAccessSubscriptionTargets(targets []SubscriptionTarget) []access.SubscriptionTarget {
	converted := make([]access.SubscriptionTarget, len(targets))
	for i, target := range targets {
		converted[i] = access.SubscriptionTarget{
			UserID:       target.UserID,
			DurationDays: target.DurationDays,
		}
	}
	return converted
}
// normalizeKeys cleans up raw API keys: each key has surrounding whitespace
// and one leading UTF-8 byte order mark removed, empty results are dropped,
// and duplicates are removed while keeping first-seen order.
//
// It returns an error when no usable key remains.
func normalizeKeys(keys []string) ([]string, error) {
	seen := make(map[string]struct{}, len(keys))
	result := make([]string, 0, len(keys))
	for _, key := range keys {
		// Trim whitespace before as well as after removing the BOM:
		// strings.TrimSpace does not treat U+FEFF as space, so a BOM that
		// follows leading whitespace (" \ufeffkey") would otherwise survive
		// and a BOM-only entry would be accepted as a key.
		normalized := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(key), "\ufeff"))
		if normalized == "" {
			continue
		}
		if _, duplicate := seen[normalized]; duplicate {
			continue
		}
		seen[normalized] = struct{}{}
		result = append(result, normalized)
	}
	if len(result) == 0 {
		return nil, fmt.Errorf("at least one api key is required")
	}
	return result, nil
}
// buildBatchAccountsRequest builds one account creation request per key. Each
// account is named "<ProviderID>-NN" (1-based, at least two digits), attached
// to groupID, and carries the provider base URL, the key and the channel
// template's model mapping as credentials.
func buildBatchAccountsRequest(provider pack.ProviderManifest, groupID string, keys []string) sub2api.BatchCreateAccountsRequest {
	requests := make([]sub2api.CreateAccountRequest, len(keys))
	for i := range keys {
		credentials := map[string]any{
			"base_url":      provider.BaseURL,
			"api_key":       keys[i],
			"model_mapping": provider.ChannelTemplate.ModelMapping,
		}
		requests[i] = sub2api.CreateAccountRequest{
			Name:        fmt.Sprintf("%s-%02d", provider.ProviderID, i+1),
			Platform:    provider.Platform,
			Type:        provider.AccountType,
			GroupIDs:    []string{groupID},
			Credentials: credentials,
		}
	}
	return sub2api.BatchCreateAccountsRequest{Accounts: requests}
}
// hasModel reports whether any model ID equals target once surrounding
// whitespace is trimmed from both sides of the comparison.
func hasModel(models []sub2api.AccountModel, target string) bool {
	wanted := strings.TrimSpace(target)
	for i := range models {
		if strings.TrimSpace(models[i].ID) == wanted {
			return true
		}
	}
	return false
}
// managedResourceRollback collects the IDs of resources created during an
// import so that Run can delete them again.
type managedResourceRollback struct {
// host performs the delete calls.
host hostAdapter
// groupID, channelID and planID are empty when nothing of that kind was
// registered; Run skips empty IDs.
groupID string
channelID string
planID string
// accountIDs are deleted by Run in reverse order of registration.
accountIDs []string
}
// newManagedResourceRollback returns an empty rollback bound to host.
func newManagedResourceRollback(host hostAdapter) *managedResourceRollback {
	rollback := new(managedResourceRollback)
	rollback.host = host
	return rollback
}
// AddGroup registers the group to delete on rollback, replacing any previously
// registered group. Surrounding whitespace is trimmed from the ID.
func (r *managedResourceRollback) AddGroup(groupID string) {
	trimmed := strings.TrimSpace(groupID)
	r.groupID = trimmed
}
// AddChannel registers the channel to delete on rollback, replacing any
// previously registered channel. Surrounding whitespace is trimmed from the ID.
func (r *managedResourceRollback) AddChannel(channelID string) {
	trimmed := strings.TrimSpace(channelID)
	r.channelID = trimmed
}
// AddPlan registers the plan to delete on rollback, replacing any previously
// registered plan. Surrounding whitespace is trimmed from the ID.
func (r *managedResourceRollback) AddPlan(planID string) {
	trimmed := strings.TrimSpace(planID)
	r.planID = trimmed
}
// AddAccounts registers the given accounts for deletion on rollback. IDs are
// trimmed, and accounts whose ID is empty after trimming are ignored.
func (r *managedResourceRollback) AddAccounts(accounts []sub2api.AccountRef) {
	for i := range accounts {
		if id := strings.TrimSpace(accounts[i].ID); id != "" {
			r.accountIDs = append(r.accountIDs, id)
		}
	}
}
// Run deletes every registered resource: accounts first (most recently
// registered first), then the plan, the channel and finally the group. It
// keeps going after a failed delete and returns all failures joined together,
// or nil when every delete succeeded. A nil receiver or nil host is a no-op.
func (r *managedResourceRollback) Run(ctx context.Context) error {
	if r == nil || r.host == nil {
		return nil
	}

	var failures []error
	for remaining := len(r.accountIDs); remaining > 0; remaining-- {
		accountID := r.accountIDs[remaining-1]
		if err := r.host.DeleteAccount(ctx, accountID); err != nil {
			failures = append(failures, fmt.Errorf("delete account %s: %w", accountID, err))
		}
	}
	if planID := r.planID; planID != "" {
		if err := r.host.DeletePlan(ctx, planID); err != nil {
			failures = append(failures, fmt.Errorf("delete plan %s: %w", planID, err))
		}
	}
	if channelID := r.channelID; channelID != "" {
		if err := r.host.DeleteChannel(ctx, channelID); err != nil {
			failures = append(failures, fmt.Errorf("delete channel %s: %w", channelID, err))
		}
	}
	if groupID := r.groupID; groupID != "" {
		if err := r.host.DeleteGroup(ctx, groupID); err != nil {
			failures = append(failures, fmt.Errorf("delete group %s: %w", groupID, err))
		}
	}
	return errors.Join(failures...)
}
// failOrDegrade stamps the report's status fields for an aborted import and
// returns it together with err unchanged. Strict mode marks the batch and
// provider as failed; any other mode marks them as partial and degraded.
// Access is marked broken in both cases.
func failOrDegrade(report ImportReport, mode string, err error) (ImportReport, error) {
	batchStatus, providerStatus := BatchStatusPartial, ProviderStatusDegraded
	if mode == ImportModeStrict {
		batchStatus, providerStatus = BatchStatusFailed, ProviderStatusFailed
	}
	report.BatchStatus = batchStatus
	report.ProviderStatus = providerStatus
	report.AccessStatus = AccessStatusBroken
	return report, err
}