438 lines
14 KiB
Go
438 lines
14 KiB
Go
package provision
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"sub2api-cn-relay-manager/internal/access"
|
|
"sub2api-cn-relay-manager/internal/host/sub2api"
|
|
"sub2api-cn-relay-manager/internal/pack"
|
|
)
|
|
|
|
// Import modes accepted in ImportRequest.Mode.
const (
	// ImportModeStrict makes Import roll back every resource it created
	// whenever it returns an error, and treats any failed account as fatal.
	ImportModeStrict = "strict"
	// ImportModePartial keeps whatever was provisioned and reports degraded
	// statuses instead of rolling back.
	ImportModePartial = "partial"
)

// Access modes accepted in AccessRequest.Mode.
const (
	// AccessModeSubscription provisions a subscription-typed group plus a plan.
	AccessModeSubscription = "subscription"
	// AccessModeSelfService provisions only the group and channel (no plan).
	AccessModeSelfService = "self_service"
)

// Values reported in ImportReport.BatchStatus. BatchStatusRolledBack is not
// assigned by Import in this file.
const (
	BatchStatusSucceeded  = "succeeded"
	BatchStatusPartial    = "partially_succeeded"
	BatchStatusFailed     = "failed"
	BatchStatusRolledBack = "rolled_back"
)

// Values reported in ImportReport.ProviderStatus.
const (
	ProviderStatusActive   = "active"
	ProviderStatusDegraded = "degraded"
	ProviderStatusFailed   = "failed"
)

// Values reported in ImportReport.AccessStatus. AccessStatusFullyReady is not
// assigned by Import in this file.
const (
	AccessStatusSubscriptionReady = "subscription_ready"
	AccessStatusSelfServiceReady  = "self_service_ready"
	AccessStatusFullyReady        = "fully_ready"
	AccessStatusBroken            = "broken"
)
|
|
|
|
type (
	// AccessRequest describes the access-closure step that Import runs after
	// accounts are created. Its fields are forwarded to access.Validate and
	// to access.Service.Close via access.ClosureRequest.
	AccessRequest struct {
		// Mode selects the access model; Import maps AccessModeSubscription
		// and AccessModeSelfService to their respective ready statuses.
		Mode string
		// ProbeAPIKey is passed through as access.ClosureRequest.ProbeAPIKey.
		ProbeAPIKey string
		// Subscriptions are converted to access.SubscriptionTarget values.
		Subscriptions []SubscriptionTarget
	}

	// SubscriptionTarget names one user and a duration in days; it mirrors
	// access.SubscriptionTarget field for field.
	SubscriptionTarget struct {
		UserID       string
		DurationDays int
	}
)
|
|
|
|
type ImportRequest struct {
|
|
Provider pack.ProviderManifest
|
|
Mode string
|
|
Access AccessRequest
|
|
Keys []string
|
|
}
|
|
|
|
type ImportReport struct {
|
|
BatchStatus string
|
|
ProviderStatus string
|
|
AccessStatus string
|
|
AcceptedKeys []string
|
|
Group sub2api.GroupRef
|
|
Channel sub2api.ChannelRef
|
|
Plan *sub2api.PlanRef
|
|
Accounts []AccountImportResult
|
|
Gateway sub2api.GatewayAccessResult
|
|
}
|
|
|
|
type AccountImportResult struct {
|
|
Ref sub2api.AccountRef
|
|
Probe sub2api.ProbeResult
|
|
Models []sub2api.AccountModel
|
|
SmokeModelSeen bool
|
|
}
|
|
|
|
type hostAdapter interface {
|
|
sub2api.HostAdapter
|
|
CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
|
|
}
|
|
|
|
type resolvedManagedResources struct {
|
|
Group sub2api.GroupRef
|
|
Channel sub2api.ChannelRef
|
|
Plan *sub2api.PlanRef
|
|
|
|
CreatedGroup bool
|
|
CreatedChannel bool
|
|
CreatedPlan bool
|
|
}
|
|
|
|
type ImportService struct {
|
|
host hostAdapter
|
|
}
|
|
|
|
func NewImportService(host hostAdapter) *ImportService {
|
|
return &ImportService{host: host}
|
|
}
|
|
|
|
func (s *ImportService) Import(ctx context.Context, req ImportRequest) (report ImportReport, err error) {
|
|
normalizedKeys, err := normalizeKeys(req.Keys)
|
|
if err != nil {
|
|
return ImportReport{}, err
|
|
}
|
|
if err := validateMode(req.Mode); err != nil {
|
|
return ImportReport{}, err
|
|
}
|
|
if err := access.Validate(access.ClosureRequest{
|
|
Mode: req.Access.Mode,
|
|
ProbeAPIKey: req.Access.ProbeAPIKey,
|
|
Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
|
|
}); err != nil {
|
|
return ImportReport{}, err
|
|
}
|
|
|
|
report = ImportReport{AcceptedKeys: normalizedKeys}
|
|
rollback := newManagedResourceRollback(s.host)
|
|
defer func() {
|
|
if err == nil || req.Mode != ImportModeStrict {
|
|
return
|
|
}
|
|
if rollbackErr := rollback.Run(ctx); rollbackErr != nil {
|
|
err = errors.Join(err, fmt.Errorf("rollback managed resources: %w", rollbackErr))
|
|
}
|
|
}()
|
|
resources, err := s.ensureManagedResources(ctx, req.Provider, req.Access.Mode)
|
|
if err != nil {
|
|
return report, err
|
|
}
|
|
report.Group = resources.Group
|
|
report.Channel = resources.Channel
|
|
report.Plan = resources.Plan
|
|
if resources.CreatedGroup {
|
|
rollback.AddGroup(resources.Group.ID)
|
|
}
|
|
if resources.CreatedChannel {
|
|
rollback.AddChannel(resources.Channel.ID)
|
|
}
|
|
if resources.CreatedPlan && resources.Plan != nil {
|
|
rollback.AddPlan(resources.Plan.ID)
|
|
}
|
|
|
|
accounts, err := s.host.BatchCreateAccounts(ctx, buildBatchAccountsRequest(req.Provider, resources.Group.ID, normalizedKeys))
|
|
if err != nil {
|
|
return report, fmt.Errorf("batch create accounts: %w", err)
|
|
}
|
|
rollback.AddAccounts(accounts)
|
|
for _, account := range accounts {
|
|
probe, err := s.host.TestAccount(ctx, account.ID)
|
|
if err != nil {
|
|
return failOrDegrade(report, req.Mode, fmt.Errorf("test account %s: %w", account.ID, err))
|
|
}
|
|
models, err := s.host.GetAccountModels(ctx, account.ID)
|
|
if err != nil {
|
|
return failOrDegrade(report, req.Mode, fmt.Errorf("get account models %s: %w", account.ID, err))
|
|
}
|
|
result := AccountImportResult{Ref: account, Probe: probe, Models: models, SmokeModelSeen: hasModel(models, req.Provider.SmokeTestModel)}
|
|
report.Accounts = append(report.Accounts, result)
|
|
}
|
|
|
|
failedAccounts := 0
|
|
for _, account := range report.Accounts {
|
|
if !account.Probe.OK || !account.SmokeModelSeen {
|
|
failedAccounts++
|
|
}
|
|
}
|
|
if failedAccounts > 0 && req.Mode == ImportModeStrict {
|
|
report.BatchStatus = BatchStatusFailed
|
|
report.ProviderStatus = ProviderStatusFailed
|
|
report.AccessStatus = AccessStatusBroken
|
|
return report, fmt.Errorf("strict import failed: %d account(s) did not pass smoke validation", failedAccounts)
|
|
}
|
|
|
|
closureService := access.NewService(s.host)
|
|
gateway, err := closureService.Close(ctx, access.ClosureRequest{
|
|
Mode: req.Access.Mode,
|
|
ProbeAPIKey: req.Access.ProbeAPIKey,
|
|
Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
|
|
GroupID: resources.Group.ID,
|
|
ExpectedModel: req.Provider.SmokeTestModel,
|
|
})
|
|
if err != nil {
|
|
return failOrDegrade(report, req.Mode, err)
|
|
}
|
|
report.Gateway = gateway
|
|
|
|
report.BatchStatus = BatchStatusSucceeded
|
|
report.ProviderStatus = ProviderStatusActive
|
|
if failedAccounts > 0 || !gateway.OK || !gateway.HasExpectedModel {
|
|
report.BatchStatus = BatchStatusPartial
|
|
report.ProviderStatus = ProviderStatusDegraded
|
|
}
|
|
switch req.Access.Mode {
|
|
case AccessModeSubscription:
|
|
report.AccessStatus = AccessStatusSubscriptionReady
|
|
case AccessModeSelfService:
|
|
report.AccessStatus = AccessStatusSelfServiceReady
|
|
}
|
|
if !gateway.OK || !gateway.HasExpectedModel {
|
|
report.AccessStatus = AccessStatusBroken
|
|
}
|
|
return report, nil
|
|
}
|
|
|
|
func (s *ImportService) ensureManagedResources(ctx context.Context, provider pack.ProviderManifest, accessMode string) (resolvedManagedResources, error) {
|
|
names := SuggestResourceNamesForMode(provider, accessMode)
|
|
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
|
|
GroupName: names.Group,
|
|
ChannelName: names.Channel,
|
|
PlanName: names.Plan,
|
|
})
|
|
if err != nil {
|
|
return resolvedManagedResources{}, fmt.Errorf("list managed resources: %w", err)
|
|
}
|
|
|
|
result := resolvedManagedResources{}
|
|
group, created, err := ensureGroup(ctx, s.host, snapshot.Groups, provider, accessMode, names.Group)
|
|
if err != nil {
|
|
return resolvedManagedResources{}, fmt.Errorf("ensure group: %w", err)
|
|
}
|
|
result.Group = group
|
|
result.CreatedGroup = created
|
|
|
|
channel, created, err := ensureChannel(ctx, s.host, snapshot.Channels, provider, group.ID, names.Channel)
|
|
if err != nil {
|
|
return resolvedManagedResources{}, fmt.Errorf("ensure channel: %w", err)
|
|
}
|
|
result.Channel = channel
|
|
result.CreatedChannel = created
|
|
|
|
if accessMode == AccessModeSubscription {
|
|
plan, created, err := ensurePlan(ctx, s.host, snapshot.Plans, provider, group.ID, names.Plan)
|
|
if err != nil {
|
|
return resolvedManagedResources{}, fmt.Errorf("ensure plan: %w", err)
|
|
}
|
|
result.Plan = &plan
|
|
result.CreatedPlan = created
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, accessMode, groupName string) (sub2api.GroupRef, bool, error) {
|
|
switch len(existing) {
|
|
case 0:
|
|
groupReq := sub2api.CreateGroupRequest{Name: groupName, Platform: provider.Platform, RateMultiplier: provider.GroupTemplate.RateMultiplier}
|
|
if accessMode == AccessModeSubscription {
|
|
groupReq.SubscriptionType = "subscription"
|
|
}
|
|
group, err := host.CreateGroup(ctx, groupReq)
|
|
return group, true, err
|
|
case 1:
|
|
return sub2api.GroupRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil
|
|
default:
|
|
return sub2api.GroupRef{}, false, fmt.Errorf("multiple groups already exist for %q", groupName)
|
|
}
|
|
}
|
|
|
|
func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, channelName string) (sub2api.ChannelRef, bool, error) {
|
|
channelReq := buildChannelRequest(provider, groupID, channelName)
|
|
switch len(existing) {
|
|
case 0:
|
|
channel, err := host.CreateChannel(ctx, channelReq)
|
|
return channel, true, err
|
|
case 1:
|
|
if err := host.UpdateChannel(ctx, existing[0].ID, channelReq); err != nil {
|
|
return sub2api.ChannelRef{}, false, err
|
|
}
|
|
return sub2api.ChannelRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil
|
|
default:
|
|
return sub2api.ChannelRef{}, false, fmt.Errorf("multiple channels already exist for %q", channelName)
|
|
}
|
|
}
|
|
|
|
func buildChannelRequest(provider pack.ProviderManifest, groupID, channelName string) sub2api.CreateChannelRequest {
|
|
return sub2api.CreateChannelRequest{
|
|
Name: channelName,
|
|
GroupIDs: []string{groupID},
|
|
ModelMapping: provider.ChannelTemplate.ModelMapping,
|
|
ModelPricing: []sub2api.ChannelModelPricing{{
|
|
Platform: provider.Platform,
|
|
Models: append([]string(nil), provider.DefaultModels...),
|
|
BillingMode: "token",
|
|
Intervals: []sub2api.ChannelPricingTier{},
|
|
}},
|
|
Platform: provider.Platform,
|
|
RestrictModels: true,
|
|
BillingModelSource: "channel_mapped",
|
|
}
|
|
}
|
|
|
|
func ensurePlan(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, planName string) (sub2api.PlanRef, bool, error) {
|
|
switch len(existing) {
|
|
case 0:
|
|
plan, err := host.CreatePlan(ctx, sub2api.CreatePlanRequest{GroupID: groupID, Name: planName, Price: provider.PlanTemplate.Price, ValidityDays: provider.PlanTemplate.ValidityDays, ValidityUnit: provider.PlanTemplate.ValidityUnit})
|
|
return plan, true, err
|
|
case 1:
|
|
return sub2api.PlanRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil
|
|
default:
|
|
return sub2api.PlanRef{}, false, fmt.Errorf("multiple plans already exist for %q", planName)
|
|
}
|
|
}
|
|
|
|
func validateMode(mode string) error {
|
|
switch strings.TrimSpace(mode) {
|
|
case ImportModeStrict, ImportModePartial:
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("unsupported import mode %q", mode)
|
|
}
|
|
}
|
|
|
|
func toAccessSubscriptionTargets(targets []SubscriptionTarget) []access.SubscriptionTarget {
|
|
result := make([]access.SubscriptionTarget, 0, len(targets))
|
|
for _, target := range targets {
|
|
result = append(result, access.SubscriptionTarget{UserID: target.UserID, DurationDays: target.DurationDays})
|
|
}
|
|
return result
|
|
}
|
|
|
|
// normalizeKeys cleans a raw list of API keys. For each key it strips a
// leading UTF-8 BOM and then surrounding whitespace. Empty results are
// dropped, and duplicates are removed keeping the first occurrence in input
// order. It returns an error when no key survives.
func normalizeKeys(keys []string) ([]string, error) {
	distinct := make([]string, 0, len(keys))
	seen := make(map[string]struct{}, len(keys))

	for _, raw := range keys {
		key := strings.TrimSpace(strings.TrimPrefix(raw, "\ufeff"))
		if key == "" {
			continue
		}
		if _, duplicate := seen[key]; duplicate {
			continue
		}
		seen[key] = struct{}{}
		distinct = append(distinct, key)
	}

	if len(distinct) == 0 {
		return nil, fmt.Errorf("at least one api key is required")
	}
	return distinct, nil
}
|
|
|
|
func buildBatchAccountsRequest(provider pack.ProviderManifest, groupID string, keys []string) sub2api.BatchCreateAccountsRequest {
|
|
accounts := make([]sub2api.CreateAccountRequest, 0, len(keys))
|
|
for index, key := range keys {
|
|
accounts = append(accounts, sub2api.CreateAccountRequest{
|
|
Name: fmt.Sprintf("%s-%02d", provider.ProviderID, index+1),
|
|
Platform: provider.Platform,
|
|
Type: provider.AccountType,
|
|
GroupIDs: []string{groupID},
|
|
Credentials: map[string]any{
|
|
"base_url": provider.BaseURL,
|
|
"api_key": key,
|
|
"model_mapping": provider.ChannelTemplate.ModelMapping,
|
|
},
|
|
})
|
|
}
|
|
return sub2api.BatchCreateAccountsRequest{Accounts: accounts}
|
|
}
|
|
|
|
func hasModel(models []sub2api.AccountModel, target string) bool {
|
|
for _, model := range models {
|
|
if strings.TrimSpace(model.ID) == strings.TrimSpace(target) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
type managedResourceRollback struct {
|
|
host hostAdapter
|
|
groupID string
|
|
channelID string
|
|
planID string
|
|
accountIDs []string
|
|
}
|
|
|
|
func newManagedResourceRollback(host hostAdapter) *managedResourceRollback {
|
|
return &managedResourceRollback{host: host}
|
|
}
|
|
|
|
func (r *managedResourceRollback) AddGroup(groupID string) {
|
|
r.groupID = strings.TrimSpace(groupID)
|
|
}
|
|
|
|
func (r *managedResourceRollback) AddChannel(channelID string) {
|
|
r.channelID = strings.TrimSpace(channelID)
|
|
}
|
|
|
|
func (r *managedResourceRollback) AddPlan(planID string) {
|
|
r.planID = strings.TrimSpace(planID)
|
|
}
|
|
|
|
func (r *managedResourceRollback) AddAccounts(accounts []sub2api.AccountRef) {
|
|
for _, account := range accounts {
|
|
accountID := strings.TrimSpace(account.ID)
|
|
if accountID == "" {
|
|
continue
|
|
}
|
|
r.accountIDs = append(r.accountIDs, accountID)
|
|
}
|
|
}
|
|
|
|
func (r *managedResourceRollback) Run(ctx context.Context) error {
|
|
if r == nil || r.host == nil {
|
|
return nil
|
|
}
|
|
var errs []error
|
|
for index := len(r.accountIDs) - 1; index >= 0; index-- {
|
|
if err := r.host.DeleteAccount(ctx, r.accountIDs[index]); err != nil {
|
|
errs = append(errs, fmt.Errorf("delete account %s: %w", r.accountIDs[index], err))
|
|
}
|
|
}
|
|
if r.planID != "" {
|
|
if err := r.host.DeletePlan(ctx, r.planID); err != nil {
|
|
errs = append(errs, fmt.Errorf("delete plan %s: %w", r.planID, err))
|
|
}
|
|
}
|
|
if r.channelID != "" {
|
|
if err := r.host.DeleteChannel(ctx, r.channelID); err != nil {
|
|
errs = append(errs, fmt.Errorf("delete channel %s: %w", r.channelID, err))
|
|
}
|
|
}
|
|
if r.groupID != "" {
|
|
if err := r.host.DeleteGroup(ctx, r.groupID); err != nil {
|
|
errs = append(errs, fmt.Errorf("delete group %s: %w", r.groupID, err))
|
|
}
|
|
}
|
|
return errors.Join(errs...)
|
|
}
|
|
|
|
func failOrDegrade(report ImportReport, mode string, err error) (ImportReport, error) {
|
|
if mode == ImportModeStrict {
|
|
report.BatchStatus = BatchStatusFailed
|
|
report.ProviderStatus = ProviderStatusFailed
|
|
report.AccessStatus = AccessStatusBroken
|
|
return report, err
|
|
}
|
|
report.BatchStatus = BatchStatusPartial
|
|
report.ProviderStatus = ProviderStatusDegraded
|
|
report.AccessStatus = AccessStatusBroken
|
|
return report, err
|
|
}
|