test(project): achieve ≥70% package coverage across all internal packages
- store/sqlite: 75.4% (repos + db coverage) - host/sub2api: 80.8% (httptest mock server, pure function tests) - app: 74.2% (handler error paths, NewActionSet closures) - pack: 72.4% - provision: 75.2% - access: 77.3% - config: 94.7% (lookup mock tests) All tests pass: build, vet, race, coverage gates.
This commit is contained in:
321
internal/provision/batch_detail_and_reconcile_service.go
Normal file
321
internal/provision/batch_detail_and_reconcile_service.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type BatchDetailResult struct {
|
||||
Batch sqlite.ImportBatch
|
||||
Items []sqlite.ImportBatchItem
|
||||
ManagedResources []sqlite.ManagedResource
|
||||
AccessClosures []sqlite.AccessClosureRecord
|
||||
ReconcileRuns []sqlite.ReconcileRun
|
||||
}
|
||||
|
||||
type BatchDetailService struct {
|
||||
store *sqlite.DB
|
||||
}
|
||||
|
||||
func NewBatchDetailService(store *sqlite.DB) *BatchDetailService {
|
||||
return &BatchDetailService{store: store}
|
||||
}
|
||||
|
||||
func (s *BatchDetailService) Get(ctx context.Context, batchID int64) (BatchDetailResult, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return BatchDetailResult{}, fmt.Errorf("store is required")
|
||||
}
|
||||
batch, err := s.store.ImportBatches().GetByID(ctx, batchID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
items, err := s.store.ImportBatchItems().GetByBatchID(ctx, batchID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
managedResources, err := s.store.ManagedResources().GetByBatchID(ctx, batchID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
accessClosures, err := s.store.AccessClosures().GetByBatchID(ctx, batchID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
reconcileRuns, err := s.store.ReconcileRuns().GetByProviderID(ctx, batch.ProviderID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
return BatchDetailResult{
|
||||
Batch: batch,
|
||||
Items: items,
|
||||
ManagedResources: managedResources,
|
||||
AccessClosures: accessClosures,
|
||||
ReconcileRuns: reconcileRuns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ReconcileRequest struct {
|
||||
HostBaseURL string
|
||||
AccessProbeAPIKey string
|
||||
Pack pack.LoadedPack
|
||||
Provider pack.ProviderManifest
|
||||
}
|
||||
|
||||
type ReconcileResult struct {
|
||||
BatchID int64
|
||||
Status string
|
||||
MissingCount int
|
||||
ExtraCount int
|
||||
ProbeFailureCount int
|
||||
AccessStatus string
|
||||
Summary map[string]any
|
||||
}
|
||||
|
||||
type ReconcileService struct {
|
||||
store *sqlite.DB
|
||||
host sub2api.HostAdapter
|
||||
}
|
||||
|
||||
func NewReconcileService(store *sqlite.DB, host sub2api.HostAdapter) *ReconcileService {
|
||||
return &ReconcileService{store: store, host: host}
|
||||
}
|
||||
|
||||
func (s *ReconcileService) Reconcile(ctx context.Context, req ReconcileRequest) (ReconcileResult, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return ReconcileResult{}, fmt.Errorf("store is required")
|
||||
}
|
||||
if s.host == nil {
|
||||
return ReconcileResult{}, fmt.Errorf("host adapter is required")
|
||||
}
|
||||
if strings.TrimSpace(req.HostBaseURL) == "" {
|
||||
return ReconcileResult{}, fmt.Errorf("host_base_url is required")
|
||||
}
|
||||
hostVersion, err := s.host.GetHostVersion(ctx)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, fmt.Errorf("get host version: %w", err)
|
||||
}
|
||||
if err := pack.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
packRow, err := s.store.Packs().GetByPackID(ctx, req.Pack.Manifest.PackID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
providerRow, err := s.store.Providers().GetByPackIDAndProviderID(ctx, packRow.ID, req.Provider.ProviderID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
batchRow, err := s.store.ImportBatches().GetLatestByProviderID(ctx, providerRow.ID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
storedResources, err := s.store.ManagedResources().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
batchItems, err := s.store.ImportBatchItems().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
accessClosures, err := s.store.AccessClosures().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
|
||||
GroupName: SuggestResourceNames(req.Provider).Group,
|
||||
ChannelName: SuggestResourceNames(req.Provider).Channel,
|
||||
PlanName: SuggestResourceNames(req.Provider).Plan,
|
||||
})
|
||||
if err != nil {
|
||||
return ReconcileResult{}, fmt.Errorf("list managed resources: %w", err)
|
||||
}
|
||||
missing, extra := diffManagedResources(storedResources, snapshot)
|
||||
probeFailures, err := s.rerunAccountProbes(ctx, batchItems, req.Provider.SmokeTestModel)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
accessStatus, accessChecked, err := s.rerunAccessClosure(ctx, batchRow.ID, accessClosures, req.AccessProbeAPIKey, req.Provider.SmokeTestModel)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
status := "active"
|
||||
if missing > 0 || extra > 0 {
|
||||
status = "drifted"
|
||||
} else if probeFailures > 0 || (accessChecked && accessStatus == AccessStatusBroken) {
|
||||
status = "degraded"
|
||||
}
|
||||
summary := map[string]any{
|
||||
"missing_count": missing,
|
||||
"extra_count": extra,
|
||||
"host_version": hostVersion,
|
||||
"probe_failures": probeFailures,
|
||||
"access_status": accessStatus,
|
||||
"access_rechecked": accessChecked,
|
||||
}
|
||||
summaryJSON, err := json.Marshal(summary)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, fmt.Errorf("marshal reconcile summary: %w", err)
|
||||
}
|
||||
if _, err := s.store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{ProviderID: providerRow.ID, Status: status, SummaryJSON: string(summaryJSON)}); err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
return ReconcileResult{BatchID: batchRow.ID, Status: status, MissingCount: missing, ExtraCount: extra, ProbeFailureCount: probeFailures, AccessStatus: accessStatus, Summary: summary}, nil
|
||||
}
|
||||
|
||||
func (s *ReconcileService) rerunAccountProbes(ctx context.Context, items []sqlite.ImportBatchItem, expectedModel string) (int, error) {
|
||||
if len(items) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
failures := 0
|
||||
for _, item := range items {
|
||||
accountID, err := accountIDFromProbeSummary(item.ProbeSummaryJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("decode import batch item %d probe summary: %w", item.ID, err)
|
||||
}
|
||||
if strings.TrimSpace(accountID) == "" {
|
||||
return 0, fmt.Errorf("import batch item %d missing account_id in probe summary", item.ID)
|
||||
}
|
||||
probe, err := s.host.TestAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("re-test account %s: %w", accountID, err)
|
||||
}
|
||||
models, err := s.host.GetAccountModels(ctx, accountID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("reload account models %s: %w", accountID, err)
|
||||
}
|
||||
smokeModelSeen := hasModel(models, expectedModel)
|
||||
status := firstNonEmpty(probe.Status, "unknown")
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"account_id": accountID,
|
||||
"probe_ok": probe.OK,
|
||||
"probe_status": probe.Status,
|
||||
"probe_message": probe.Message,
|
||||
"models": models,
|
||||
"smoke_model_seen": smokeModelSeen,
|
||||
"reconcile_rerun": true,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("marshal probe rerun summary for %s: %w", accountID, err)
|
||||
}
|
||||
if err := s.store.ImportBatchItems().UpdateResult(ctx, item.ID, status, string(payload)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if _, err := s.store.ProbeResults().Create(ctx, sqlite.ProbeResult{BatchItemID: item.ID, ProbeType: "account_smoke_rerun", Status: status, SummaryJSON: string(payload)}); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !probe.OK || !smokeModelSeen {
|
||||
failures++
|
||||
}
|
||||
}
|
||||
return failures, nil
|
||||
}
|
||||
|
||||
func (s *ReconcileService) rerunAccessClosure(ctx context.Context, batchID int64, accessClosures []sqlite.AccessClosureRecord, probeAPIKey, expectedModel string) (string, bool, error) {
|
||||
if len(accessClosures) == 0 {
|
||||
return "not_configured", false, nil
|
||||
}
|
||||
latest := accessClosures[len(accessClosures)-1]
|
||||
status := firstNonEmpty(latest.Status, deriveHealthyAccessStatus(latest.ClosureType))
|
||||
if strings.TrimSpace(probeAPIKey) == "" {
|
||||
return status, false, nil
|
||||
}
|
||||
result, err := s.host.CheckGatewayAccess(ctx, sub2api.GatewayAccessCheckRequest{APIKey: probeAPIKey, ExpectedModel: expectedModel})
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("re-check gateway access: %w", err)
|
||||
}
|
||||
if result.OK && result.HasExpectedModel {
|
||||
status = deriveHealthyAccessStatus(latest.ClosureType)
|
||||
} else {
|
||||
status = AccessStatusBroken
|
||||
}
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"status_code": result.StatusCode,
|
||||
"ok": result.OK,
|
||||
"has_expected_model": result.HasExpectedModel,
|
||||
"models": result.Models,
|
||||
"reconcile_rerun": true,
|
||||
})
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("marshal access rerun summary: %w", err)
|
||||
}
|
||||
if _, err := s.store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchID, ClosureType: latest.ClosureType, Status: status, DetailsJSON: string(payload)}); err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
return status, true, nil
|
||||
}
|
||||
|
||||
func deriveHealthyAccessStatus(closureType string) string {
|
||||
switch strings.TrimSpace(closureType) {
|
||||
case AccessModeSubscription:
|
||||
return AccessStatusSubscriptionReady
|
||||
case AccessModeSelfService:
|
||||
return AccessStatusSelfServiceReady
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func accountIDFromProbeSummary(summaryJSON string) (string, error) {
|
||||
if strings.TrimSpace(summaryJSON) == "" {
|
||||
return "", nil
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(summaryJSON), &payload); err != nil {
|
||||
return "", err
|
||||
}
|
||||
accountID, _ := payload["account_id"].(string)
|
||||
return strings.TrimSpace(accountID), nil
|
||||
}
|
||||
|
||||
func diffManagedResources(stored []sqlite.ManagedResource, snapshot sub2api.ManagedResourceSnapshot) (int, int) {
|
||||
live := map[string]map[string]struct{}{
|
||||
"group": make(map[string]struct{}),
|
||||
"channel": make(map[string]struct{}),
|
||||
"plan": make(map[string]struct{}),
|
||||
"account": make(map[string]struct{}),
|
||||
}
|
||||
for _, resource := range snapshot.Groups {
|
||||
live["group"][strings.TrimSpace(resource.ID)] = struct{}{}
|
||||
}
|
||||
for _, resource := range snapshot.Channels {
|
||||
live["channel"][strings.TrimSpace(resource.ID)] = struct{}{}
|
||||
}
|
||||
for _, resource := range snapshot.Plans {
|
||||
live["plan"][strings.TrimSpace(resource.ID)] = struct{}{}
|
||||
}
|
||||
for _, resource := range snapshot.Accounts {
|
||||
live["account"][strings.TrimSpace(resource.ID)] = struct{}{}
|
||||
}
|
||||
|
||||
storedByType := map[string]map[string]struct{}{
|
||||
"group": make(map[string]struct{}),
|
||||
"channel": make(map[string]struct{}),
|
||||
"plan": make(map[string]struct{}),
|
||||
"account": make(map[string]struct{}),
|
||||
}
|
||||
for _, resource := range stored {
|
||||
storedByType[strings.TrimSpace(resource.ResourceType)][strings.TrimSpace(resource.HostResourceID)] = struct{}{}
|
||||
}
|
||||
|
||||
missing := 0
|
||||
extra := 0
|
||||
for resourceType, storedIDs := range storedByType {
|
||||
for id := range storedIDs {
|
||||
if _, ok := live[resourceType][id]; !ok {
|
||||
missing++
|
||||
}
|
||||
}
|
||||
for id := range live[resourceType] {
|
||||
if _, ok := storedIDs[id]; !ok {
|
||||
extra++
|
||||
}
|
||||
}
|
||||
}
|
||||
return missing, extra
|
||||
}
|
||||
235
internal/provision/batch_detail_service_test.go
Normal file
235
internal/provision/batch_detail_service_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestBatchDetailServiceGetReturnsPersistedArtifacts(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
batchID := seedRuntimeImportForReconcile(t, store, host)
|
||||
providerRow, err := store.Providers().ListByProviderID(context.Background(), sampleProviderManifest().ProviderID)
|
||||
if err != nil {
|
||||
t.Fatalf("Providers().ListByProviderID() error = %v", err)
|
||||
}
|
||||
if len(providerRow) != 1 {
|
||||
t.Fatalf("providers = %d, want 1", len(providerRow))
|
||||
}
|
||||
if _, err := store.ReconcileRuns().Create(context.Background(), sqlite.ReconcileRun{ProviderID: providerRow[0].ID, Status: "active", SummaryJSON: `{"missing_count":0}`}); err != nil {
|
||||
t.Fatalf("ReconcileRuns().Create() error = %v", err)
|
||||
}
|
||||
|
||||
result, err := NewBatchDetailService(store).Get(context.Background(), batchID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if result.Batch.ID != batchID {
|
||||
t.Fatalf("Batch.ID = %d, want %d", result.Batch.ID, batchID)
|
||||
}
|
||||
if len(result.Items) != 2 {
|
||||
t.Fatalf("len(Items) = %d, want 2", len(result.Items))
|
||||
}
|
||||
if len(result.ManagedResources) != 4 {
|
||||
t.Fatalf("len(ManagedResources) = %d, want 4", len(result.ManagedResources))
|
||||
}
|
||||
if len(result.AccessClosures) != 1 {
|
||||
t.Fatalf("len(AccessClosures) = %d, want 1", len(result.AccessClosures))
|
||||
}
|
||||
if len(result.ReconcileRuns) != 1 {
|
||||
t.Fatalf("len(ReconcileRuns) = %d, want 1", len(result.ReconcileRuns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchDetailServiceGetValidatesStore(t *testing.T) {
|
||||
_, err := (*BatchDetailService)(nil).Get(context.Background(), 1)
|
||||
if err == nil || err.Error() != "store is required" {
|
||||
t.Fatalf("nil service Get() error = %v, want store is required", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountIDFromProbeSummary(t *testing.T) {
|
||||
accountID, err := accountIDFromProbeSummary(`{"account_id":" account_1 "}`)
|
||||
if err != nil {
|
||||
t.Fatalf("accountIDFromProbeSummary() error = %v", err)
|
||||
}
|
||||
if accountID != "account_1" {
|
||||
t.Fatalf("accountID = %q, want account_1", accountID)
|
||||
}
|
||||
if _, err := accountIDFromProbeSummary(`{`); err == nil {
|
||||
t.Fatal("accountIDFromProbeSummary() error = nil, want JSON decode error")
|
||||
}
|
||||
blank, err := accountIDFromProbeSummary("")
|
||||
if err != nil {
|
||||
t.Fatalf("accountIDFromProbeSummary(blank) error = %v", err)
|
||||
}
|
||||
if blank != "" {
|
||||
t.Fatalf("blank accountID = %q, want empty", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileServiceRerunAccessClosureWithoutProbeKeyUsesLatestStatus(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
status, checked, err := NewReconcileService(store, &fakeHostAdapter{}).rerunAccessClosure(context.Background(), 1, []sqlite.AccessClosureRecord{{ClosureType: AccessModeSubscription, Status: AccessStatusSubscriptionReady}}, "", "deepseek-chat")
|
||||
if err != nil {
|
||||
t.Fatalf("rerunAccessClosure() error = %v", err)
|
||||
}
|
||||
if checked {
|
||||
t.Fatal("checked = true, want false without probe key")
|
||||
}
|
||||
if status != AccessStatusSubscriptionReady {
|
||||
t.Fatalf("status = %q, want %q", status, AccessStatusSubscriptionReady)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileServiceRerunAccessClosureMarksBrokenWhenGatewayCheckFails(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
hostSeed := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
batchID := seedRuntimeImportForReconcile(t, store, hostSeed)
|
||||
|
||||
host := &fakeHostAdapter{gatewayResult: sub2api.GatewayAccessResult{OK: false, StatusCode: 403, HasExpectedModel: false}}
|
||||
status, checked, err := NewReconcileService(store, host).rerunAccessClosure(context.Background(), batchID, []sqlite.AccessClosureRecord{{ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady}}, "user-key", "deepseek-chat")
|
||||
if err != nil {
|
||||
t.Fatalf("rerunAccessClosure() error = %v", err)
|
||||
}
|
||||
if !checked {
|
||||
t.Fatal("checked = false, want true")
|
||||
}
|
||||
if status != AccessStatusBroken {
|
||||
t.Fatalf("status = %q, want %q", status, AccessStatusBroken)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 2 {
|
||||
t.Fatalf("access_closure_records row count = %d, want 2 after rerun", got)
|
||||
}
|
||||
if host.gatewayProbe.ExpectedModel != "deepseek-chat" {
|
||||
t.Fatalf("ExpectedModel = %q, want deepseek-chat", host.gatewayProbe.ExpectedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffManagedResourcesCountsMissingAndExtra(t *testing.T) {
|
||||
missing, extra := diffManagedResources(
|
||||
[]sqlite.ManagedResource{
|
||||
{ResourceType: "group", HostResourceID: "group_1"},
|
||||
{ResourceType: "account", HostResourceID: "account_1"},
|
||||
},
|
||||
sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_2"}},
|
||||
},
|
||||
)
|
||||
if missing != 1 || extra != 1 {
|
||||
t.Fatalf("diffManagedResources() = (%d, %d), want (1, 1)", missing, extra)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveProviderStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
batchStatus string
|
||||
reconcileStatus string
|
||||
want string
|
||||
}{
|
||||
{name: "reconcile wins", batchStatus: BatchStatusSucceeded, reconcileStatus: "degraded", want: "degraded"},
|
||||
{name: "succeeded batch", batchStatus: BatchStatusSucceeded, reconcileStatus: "not_run", want: ProviderStatusActive},
|
||||
{name: "failed batch", batchStatus: BatchStatusFailed, want: ProviderStatusFailed},
|
||||
{name: "running batch", batchStatus: "running", want: "running"},
|
||||
{name: "unknown fallback", batchStatus: " pending ", want: "pending"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := deriveProviderStatus(tc.batchStatus, tc.reconcileStatus); got != tc.want {
|
||||
t.Fatalf("deriveProviderStatus(%q, %q) = %q, want %q", tc.batchStatus, tc.reconcileStatus, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPackAndProviderRecord(t *testing.T) {
|
||||
packRow, err := buildPackRecord(sampleLoadedPack())
|
||||
if err != nil {
|
||||
t.Fatalf("buildPackRecord() error = %v", err)
|
||||
}
|
||||
if packRow.PackID != "openai-cn-pack" || packRow.TargetHost != "sub2api" {
|
||||
t.Fatalf("packRow = %#v, want populated pack metadata", packRow)
|
||||
}
|
||||
|
||||
providerRow, err := buildProviderRecord(7, sampleProviderManifest())
|
||||
if err != nil {
|
||||
t.Fatalf("buildProviderRecord() error = %v", err)
|
||||
}
|
||||
if providerRow.PackID != 7 || providerRow.ProviderID != sampleProviderManifest().ProviderID {
|
||||
t.Fatalf("providerRow = %#v, want persisted provider metadata", providerRow)
|
||||
}
|
||||
if providerRow.DefaultModelsJSON == "" || providerRow.ManifestJSON == "" {
|
||||
t.Fatalf("providerRow JSON fields = %#v, want serialized JSON", providerRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstNonEmptyAndFingerprintKey(t *testing.T) {
|
||||
if got := firstNonEmpty(" ", "value", "other"); got != "value" {
|
||||
t.Fatalf("firstNonEmpty() = %q, want value", got)
|
||||
}
|
||||
if got := fingerprintKey([]string{" key-1 "}, 0); got == "key-1" || got == "sha256:" || len(got) < 20 {
|
||||
t.Fatalf("fingerprintKey() = %q, want sha256 fingerprint", got)
|
||||
}
|
||||
if got := fingerprintKey(nil, 3); got != "key-4" {
|
||||
t.Fatalf("fingerprintKey(nil, 3) = %q, want key-4", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderStatusServiceGetResourcesRequiresProviderID(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
_, err := NewProviderStatusService(store).GetResources(context.Background(), ProviderQuery{})
|
||||
if err == nil || err.Error() != "provider_id is required" {
|
||||
t.Fatalf("GetResources() error = %v, want provider_id is required", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceSlugFallsBackToProvider(t *testing.T) {
|
||||
if got := resourceSlug(" !!! "); got != "provider" {
|
||||
t.Fatalf("resourceSlug() = %q, want provider", got)
|
||||
}
|
||||
provider := sampleProviderManifest()
|
||||
provider.ProviderID = " DeepSeek CN / Prod "
|
||||
if got := SuggestAccountNamePrefix(provider); got != "deepseek-cn-prod-" {
|
||||
t.Fatalf("SuggestAccountNamePrefix() = %q, want deepseek-cn-prod-", got)
|
||||
}
|
||||
resourceNames := SuggestResourceNames(provider)
|
||||
if resourceNames.Group != "crm-deepseek-cn-prod-group" {
|
||||
t.Fatalf("SuggestResourceNames() = %#v, want slugged resource names", resourceNames)
|
||||
}
|
||||
}
|
||||
343
internal/provision/import_service.go
Normal file
343
internal/provision/import_service.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/access"
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
const (
|
||||
ImportModeStrict = "strict"
|
||||
ImportModePartial = "partial"
|
||||
|
||||
AccessModeSubscription = "subscription"
|
||||
AccessModeSelfService = "self_service"
|
||||
|
||||
BatchStatusSucceeded = "succeeded"
|
||||
BatchStatusPartial = "partially_succeeded"
|
||||
BatchStatusFailed = "failed"
|
||||
|
||||
ProviderStatusActive = "active"
|
||||
ProviderStatusDegraded = "degraded"
|
||||
ProviderStatusFailed = "failed"
|
||||
|
||||
AccessStatusSubscriptionReady = "subscription_ready"
|
||||
AccessStatusSelfServiceReady = "self_service_ready"
|
||||
AccessStatusBroken = "broken"
|
||||
)
|
||||
|
||||
type AccessRequest struct {
|
||||
Mode string
|
||||
ProbeAPIKey string
|
||||
Subscriptions []SubscriptionTarget
|
||||
}
|
||||
|
||||
type SubscriptionTarget struct {
|
||||
UserID string
|
||||
DurationDays int
|
||||
}
|
||||
|
||||
type ImportRequest struct {
|
||||
Provider pack.ProviderManifest
|
||||
Mode string
|
||||
Access AccessRequest
|
||||
Keys []string
|
||||
}
|
||||
|
||||
type ImportReport struct {
|
||||
BatchStatus string
|
||||
ProviderStatus string
|
||||
AccessStatus string
|
||||
AcceptedKeys []string
|
||||
Group sub2api.GroupRef
|
||||
Channel sub2api.ChannelRef
|
||||
Plan *sub2api.PlanRef
|
||||
Accounts []AccountImportResult
|
||||
Gateway sub2api.GatewayAccessResult
|
||||
}
|
||||
|
||||
type AccountImportResult struct {
|
||||
Ref sub2api.AccountRef
|
||||
Probe sub2api.ProbeResult
|
||||
Models []sub2api.AccountModel
|
||||
SmokeModelSeen bool
|
||||
}
|
||||
|
||||
type hostAdapter interface {
|
||||
sub2api.HostAdapter
|
||||
CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
|
||||
}
|
||||
|
||||
type ImportService struct {
|
||||
host hostAdapter
|
||||
}
|
||||
|
||||
func NewImportService(host hostAdapter) *ImportService {
|
||||
return &ImportService{host: host}
|
||||
}
|
||||
|
||||
func (s *ImportService) Import(ctx context.Context, req ImportRequest) (report ImportReport, err error) {
|
||||
normalizedKeys, err := normalizeKeys(req.Keys)
|
||||
if err != nil {
|
||||
return ImportReport{}, err
|
||||
}
|
||||
if err := validateMode(req.Mode); err != nil {
|
||||
return ImportReport{}, err
|
||||
}
|
||||
if err := access.Validate(access.ClosureRequest{
|
||||
Mode: req.Access.Mode,
|
||||
ProbeAPIKey: req.Access.ProbeAPIKey,
|
||||
Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
|
||||
}); err != nil {
|
||||
return ImportReport{}, err
|
||||
}
|
||||
|
||||
report = ImportReport{AcceptedKeys: normalizedKeys}
|
||||
rollback := newManagedResourceRollback(s.host)
|
||||
defer func() {
|
||||
if err == nil || req.Mode != ImportModeStrict {
|
||||
return
|
||||
}
|
||||
if rollbackErr := rollback.Run(ctx); rollbackErr != nil {
|
||||
err = errors.Join(err, fmt.Errorf("rollback managed resources: %w", rollbackErr))
|
||||
}
|
||||
}()
|
||||
group, err := s.host.CreateGroup(ctx, sub2api.CreateGroupRequest{
|
||||
Name: req.Provider.GroupTemplate.Name,
|
||||
RateMultiplier: req.Provider.GroupTemplate.RateMultiplier,
|
||||
})
|
||||
if err != nil {
|
||||
return report, fmt.Errorf("create group: %w", err)
|
||||
}
|
||||
report.Group = group
|
||||
rollback.AddGroup(group.ID)
|
||||
|
||||
channel, err := s.host.CreateChannel(ctx, sub2api.CreateChannelRequest{
|
||||
Name: req.Provider.ChannelTemplate.Name,
|
||||
GroupIDs: []string{group.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return report, fmt.Errorf("create channel: %w", err)
|
||||
}
|
||||
report.Channel = channel
|
||||
rollback.AddChannel(channel.ID)
|
||||
|
||||
if req.Access.Mode == AccessModeSubscription {
|
||||
plan, err := s.host.CreatePlan(ctx, sub2api.CreatePlanRequest{
|
||||
GroupID: group.ID,
|
||||
Name: req.Provider.PlanTemplate.Name,
|
||||
Price: req.Provider.PlanTemplate.Price,
|
||||
ValidityDays: req.Provider.PlanTemplate.ValidityDays,
|
||||
ValidityUnit: req.Provider.PlanTemplate.ValidityUnit,
|
||||
})
|
||||
if err != nil {
|
||||
return report, fmt.Errorf("create plan: %w", err)
|
||||
}
|
||||
report.Plan = &plan
|
||||
rollback.AddPlan(plan.ID)
|
||||
}
|
||||
|
||||
accounts, err := s.host.BatchCreateAccounts(ctx, buildBatchAccountsRequest(req.Provider, group.ID, normalizedKeys))
|
||||
if err != nil {
|
||||
return report, fmt.Errorf("batch create accounts: %w", err)
|
||||
}
|
||||
rollback.AddAccounts(accounts)
|
||||
for _, account := range accounts {
|
||||
probe, err := s.host.TestAccount(ctx, account.ID)
|
||||
if err != nil {
|
||||
return failOrDegrade(report, req.Mode, fmt.Errorf("test account %s: %w", account.ID, err))
|
||||
}
|
||||
models, err := s.host.GetAccountModels(ctx, account.ID)
|
||||
if err != nil {
|
||||
return failOrDegrade(report, req.Mode, fmt.Errorf("get account models %s: %w", account.ID, err))
|
||||
}
|
||||
result := AccountImportResult{Ref: account, Probe: probe, Models: models, SmokeModelSeen: hasModel(models, req.Provider.SmokeTestModel)}
|
||||
report.Accounts = append(report.Accounts, result)
|
||||
}
|
||||
|
||||
failedAccounts := 0
|
||||
for _, account := range report.Accounts {
|
||||
if !account.Probe.OK || !account.SmokeModelSeen {
|
||||
failedAccounts++
|
||||
}
|
||||
}
|
||||
if failedAccounts > 0 && req.Mode == ImportModeStrict {
|
||||
report.BatchStatus = BatchStatusFailed
|
||||
report.ProviderStatus = ProviderStatusFailed
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
return report, fmt.Errorf("strict import failed: %d account(s) did not pass smoke validation", failedAccounts)
|
||||
}
|
||||
|
||||
closureService := access.NewService(s.host)
|
||||
gateway, err := closureService.Close(ctx, access.ClosureRequest{
|
||||
Mode: req.Access.Mode,
|
||||
ProbeAPIKey: req.Access.ProbeAPIKey,
|
||||
Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
|
||||
GroupID: group.ID,
|
||||
ExpectedModel: req.Provider.SmokeTestModel,
|
||||
})
|
||||
if err != nil {
|
||||
return failOrDegrade(report, req.Mode, err)
|
||||
}
|
||||
report.Gateway = gateway
|
||||
|
||||
report.BatchStatus = BatchStatusSucceeded
|
||||
report.ProviderStatus = ProviderStatusActive
|
||||
if failedAccounts > 0 || !gateway.OK || !gateway.HasExpectedModel {
|
||||
report.BatchStatus = BatchStatusPartial
|
||||
report.ProviderStatus = ProviderStatusDegraded
|
||||
}
|
||||
switch req.Access.Mode {
|
||||
case AccessModeSubscription:
|
||||
report.AccessStatus = AccessStatusSubscriptionReady
|
||||
case AccessModeSelfService:
|
||||
report.AccessStatus = AccessStatusSelfServiceReady
|
||||
}
|
||||
if !gateway.OK || !gateway.HasExpectedModel {
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
}
|
||||
return report, nil
|
||||
}
|
||||
|
||||
func validateMode(mode string) error {
|
||||
switch strings.TrimSpace(mode) {
|
||||
case ImportModeStrict, ImportModePartial:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported import mode %q", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func toAccessSubscriptionTargets(targets []SubscriptionTarget) []access.SubscriptionTarget {
|
||||
result := make([]access.SubscriptionTarget, 0, len(targets))
|
||||
for _, target := range targets {
|
||||
result = append(result, access.SubscriptionTarget{UserID: target.UserID, DurationDays: target.DurationDays})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeKeys(keys []string) ([]string, error) {
|
||||
seen := map[string]struct{}{}
|
||||
result := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
normalized := strings.TrimSpace(strings.TrimPrefix(key, "\ufeff"))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
result = append(result, normalized)
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, fmt.Errorf("at least one api key is required")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func buildBatchAccountsRequest(provider pack.ProviderManifest, groupID string, keys []string) sub2api.BatchCreateAccountsRequest {
|
||||
accounts := make([]sub2api.CreateAccountRequest, 0, len(keys))
|
||||
for index, key := range keys {
|
||||
accounts = append(accounts, sub2api.CreateAccountRequest{
|
||||
Name: fmt.Sprintf("%s-%02d", provider.ProviderID, index+1),
|
||||
Platform: provider.Platform,
|
||||
Type: provider.AccountType,
|
||||
GroupIDs: []string{groupID},
|
||||
Credentials: map[string]any{
|
||||
"base_url": provider.BaseURL,
|
||||
"api_key": key,
|
||||
},
|
||||
})
|
||||
}
|
||||
return sub2api.BatchCreateAccountsRequest{Accounts: accounts}
|
||||
}
|
||||
|
||||
func hasModel(models []sub2api.AccountModel, target string) bool {
|
||||
for _, model := range models {
|
||||
if strings.TrimSpace(model.ID) == strings.TrimSpace(target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type managedResourceRollback struct {
|
||||
host hostAdapter
|
||||
groupID string
|
||||
channelID string
|
||||
planID string
|
||||
accountIDs []string
|
||||
}
|
||||
|
||||
func newManagedResourceRollback(host hostAdapter) *managedResourceRollback {
|
||||
return &managedResourceRollback{host: host}
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) AddGroup(groupID string) {
|
||||
r.groupID = strings.TrimSpace(groupID)
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) AddChannel(channelID string) {
|
||||
r.channelID = strings.TrimSpace(channelID)
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) AddPlan(planID string) {
|
||||
r.planID = strings.TrimSpace(planID)
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) AddAccounts(accounts []sub2api.AccountRef) {
|
||||
for _, account := range accounts {
|
||||
accountID := strings.TrimSpace(account.ID)
|
||||
if accountID == "" {
|
||||
continue
|
||||
}
|
||||
r.accountIDs = append(r.accountIDs, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) Run(ctx context.Context) error {
|
||||
if r == nil || r.host == nil {
|
||||
return nil
|
||||
}
|
||||
var errs []error
|
||||
for index := len(r.accountIDs) - 1; index >= 0; index-- {
|
||||
if err := r.host.DeleteAccount(ctx, r.accountIDs[index]); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete account %s: %w", r.accountIDs[index], err))
|
||||
}
|
||||
}
|
||||
if r.planID != "" {
|
||||
if err := r.host.DeletePlan(ctx, r.planID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete plan %s: %w", r.planID, err))
|
||||
}
|
||||
}
|
||||
if r.channelID != "" {
|
||||
if err := r.host.DeleteChannel(ctx, r.channelID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete channel %s: %w", r.channelID, err))
|
||||
}
|
||||
}
|
||||
if r.groupID != "" {
|
||||
if err := r.host.DeleteGroup(ctx, r.groupID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete group %s: %w", r.groupID, err))
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func failOrDegrade(report ImportReport, mode string, err error) (ImportReport, error) {
|
||||
if mode == ImportModeStrict {
|
||||
report.BatchStatus = BatchStatusFailed
|
||||
report.ProviderStatus = ProviderStatusFailed
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
return report, err
|
||||
}
|
||||
report.BatchStatus = BatchStatusPartial
|
||||
report.ProviderStatus = ProviderStatusDegraded
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
return report, err
|
||||
}
|
||||
241
internal/provision/import_service_test.go
Normal file
241
internal/provision/import_service_test.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
func TestImportServiceImportSubscriptionFlow(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
svc := NewImportService(host)
|
||||
report, err := svc.Import(context.Background(), ImportRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModePartial,
|
||||
Access: AccessRequest{
|
||||
Mode: AccessModeSubscription,
|
||||
ProbeAPIKey: "user-key",
|
||||
Subscriptions: []SubscriptionTarget{{UserID: "user_1", DurationDays: 30}},
|
||||
},
|
||||
Keys: []string{" key-1 ", "key-2", "key-1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Import() error = %v", err)
|
||||
}
|
||||
|
||||
if report.BatchStatus != BatchStatusSucceeded {
|
||||
t.Fatalf("BatchStatus = %q, want %q", report.BatchStatus, BatchStatusSucceeded)
|
||||
}
|
||||
if report.ProviderStatus != ProviderStatusActive {
|
||||
t.Fatalf("ProviderStatus = %q, want %q", report.ProviderStatus, ProviderStatusActive)
|
||||
}
|
||||
if report.AccessStatus != AccessStatusSubscriptionReady {
|
||||
t.Fatalf("AccessStatus = %q, want %q", report.AccessStatus, AccessStatusSubscriptionReady)
|
||||
}
|
||||
if !reflect.DeepEqual(report.AcceptedKeys, []string{"key-1", "key-2"}) {
|
||||
t.Fatalf("AcceptedKeys = %#v, want deduped normalized keys", report.AcceptedKeys)
|
||||
}
|
||||
if len(host.assignedSubscriptions) != 1 {
|
||||
t.Fatalf("assigned subscriptions = %d, want 1", len(host.assignedSubscriptions))
|
||||
}
|
||||
if host.gatewayProbe.ExpectedModel != "deepseek-chat" {
|
||||
t.Fatalf("gateway probe model = %q, want %q", host.gatewayProbe.ExpectedModel, "deepseek-chat")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportServiceStrictModeFailsWhenAnyAccountProbeFails(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: false, Status: "failed", Message: "bad key"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewImportService(host)
|
||||
report, err := svc.Import(context.Background(), ImportRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModeStrict,
|
||||
Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"},
|
||||
Keys: []string{"key-1", "key-2"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Import() error = nil, want strict mode failure")
|
||||
}
|
||||
if report.BatchStatus != BatchStatusFailed {
|
||||
t.Fatalf("BatchStatus = %q, want %q", report.BatchStatus, BatchStatusFailed)
|
||||
}
|
||||
if report.ProviderStatus != ProviderStatusFailed {
|
||||
t.Fatalf("ProviderStatus = %q, want %q", report.ProviderStatus, ProviderStatusFailed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportServiceRejectsUnknownMode(t *testing.T) {
|
||||
svc := NewImportService(&fakeHostAdapter{})
|
||||
_, err := svc.Import(context.Background(), ImportRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: "unknown",
|
||||
Access: AccessRequest{Mode: AccessModeSelfService},
|
||||
Keys: []string{"key-1"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Import() error = nil, want mode validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportServiceStrictModeRollsBackCreatedResources(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: false, Status: "failed", Message: "bad key"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewImportService(host)
|
||||
_, err := svc.Import(context.Background(), ImportRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModeStrict,
|
||||
Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"},
|
||||
Keys: []string{"key-1", "key-2"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Import() error = nil, want strict mode failure")
|
||||
}
|
||||
|
||||
want := []string{"account:account_2", "account:account_1", "channel:channel_1", "group:group_1"}
|
||||
if !reflect.DeepEqual(host.deletedResources, want) {
|
||||
t.Fatalf("deleted resources = %#v, want %#v", host.deletedResources, want)
|
||||
}
|
||||
}
|
||||
|
||||
func sampleProviderManifest() pack.ProviderManifest {
|
||||
return pack.ProviderManifest{
|
||||
ProviderID: "deepseek",
|
||||
DisplayName: "DeepSeek OpenAI Compatible",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
Platform: "openai",
|
||||
AccountType: "api",
|
||||
DefaultModels: []string{"deepseek-chat", "deepseek-reasoner"},
|
||||
SmokeTestModel: "deepseek-chat",
|
||||
GroupTemplate: pack.GroupTemplate{Name: "DeepSeek 默认分组", RateMultiplier: 1},
|
||||
ChannelTemplate: pack.ChannelTemplate{Name: "DeepSeek 默认渠道", ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"}},
|
||||
PlanTemplate: pack.PlanTemplate{Name: "DeepSeek 默认套餐", Price: 19.9, ValidityDays: 30, ValidityUnit: "day"},
|
||||
}
|
||||
}
|
||||
|
||||
type fakeHostAdapter struct {
|
||||
batchAccounts []sub2api.AccountRef
|
||||
testResults map[string]sub2api.ProbeResult
|
||||
models map[string][]sub2api.AccountModel
|
||||
gatewayResult sub2api.GatewayAccessResult
|
||||
batchCreateErr error
|
||||
assignErr error
|
||||
gatewayErr error
|
||||
hostVersion string
|
||||
assignedSubscriptions []sub2api.AssignSubscriptionRequest
|
||||
gatewayProbe sub2api.GatewayAccessCheckRequest
|
||||
deletedResources []string
|
||||
managedSnapshot sub2api.ManagedResourceSnapshot
|
||||
}
|
||||
|
||||
func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) {
|
||||
if strings.TrimSpace(f.hostVersion) == "" {
|
||||
return "0.1.126", nil
|
||||
}
|
||||
return f.hostVersion, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) ProbeCapabilities(context.Context) (sub2api.HostCapabilities, error) {
|
||||
return sub2api.HostCapabilities{}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreateGroup(context.Context, sub2api.CreateGroupRequest) (sub2api.GroupRef, error) {
|
||||
return sub2api.GroupRef{ID: "group_1", Name: "g"}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) DeleteGroup(_ context.Context, groupID string) error {
|
||||
f.deletedResources = append(f.deletedResources, "group:"+groupID)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreateChannel(context.Context, sub2api.CreateChannelRequest) (sub2api.ChannelRef, error) {
|
||||
return sub2api.ChannelRef{ID: "channel_1", Name: "c"}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) DeleteChannel(_ context.Context, channelID string) error {
|
||||
f.deletedResources = append(f.deletedResources, "channel:"+channelID)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreatePlan(context.Context, sub2api.CreatePlanRequest) (sub2api.PlanRef, error) {
|
||||
return sub2api.PlanRef{ID: "plan_1", Name: "p"}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) DeletePlan(_ context.Context, planID string) error {
|
||||
f.deletedResources = append(f.deletedResources, "plan:"+planID)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreateAccount(context.Context, sub2api.CreateAccountRequest) (sub2api.AccountRef, error) {
|
||||
return sub2api.AccountRef{}, errors.New("unused")
|
||||
}
|
||||
func (f *fakeHostAdapter) BatchCreateAccounts(_ context.Context, _ sub2api.BatchCreateAccountsRequest) ([]sub2api.AccountRef, error) {
|
||||
if f.batchCreateErr != nil {
|
||||
return nil, f.batchCreateErr
|
||||
}
|
||||
return f.batchAccounts, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) DeleteAccount(_ context.Context, accountID string) error {
|
||||
f.deletedResources = append(f.deletedResources, "account:"+accountID)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeHostAdapter) TestAccount(_ context.Context, accountID string) (sub2api.ProbeResult, error) {
|
||||
result, ok := f.testResults[accountID]
|
||||
if !ok {
|
||||
return sub2api.ProbeResult{}, fmt.Errorf("missing test result for %s", accountID)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) GetAccountModels(_ context.Context, accountID string) ([]sub2api.AccountModel, error) {
|
||||
models, ok := f.models[accountID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing models for %s", accountID)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) AssignSubscription(_ context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) {
|
||||
if f.assignErr != nil {
|
||||
return sub2api.SubscriptionRef{}, f.assignErr
|
||||
}
|
||||
f.assignedSubscriptions = append(f.assignedSubscriptions, req)
|
||||
return sub2api.SubscriptionRef{ID: "subscription_1"}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CheckGatewayAccess(_ context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error) {
|
||||
f.gatewayProbe = req
|
||||
if f.gatewayErr != nil {
|
||||
return sub2api.GatewayAccessResult{}, f.gatewayErr
|
||||
}
|
||||
return f.gatewayResult, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) ListManagedResources(context.Context, sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error) {
|
||||
return f.managedSnapshot, nil
|
||||
}
|
||||
40
internal/provision/naming.go
Normal file
40
internal/provision/naming.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
var nonSlugPattern = regexp.MustCompile(`[^a-z0-9]+`)
|
||||
|
||||
type ResourceNames struct {
|
||||
Group string
|
||||
Channel string
|
||||
Plan string
|
||||
}
|
||||
|
||||
func SuggestAccountNamePrefix(provider pack.ProviderManifest) string {
|
||||
return fmt.Sprintf("%s-", resourceSlug(provider.ProviderID))
|
||||
}
|
||||
|
||||
func SuggestResourceNames(provider pack.ProviderManifest) ResourceNames {
|
||||
slug := resourceSlug(provider.ProviderID)
|
||||
return ResourceNames{
|
||||
Group: fmt.Sprintf("crm-%s-group", slug),
|
||||
Channel: fmt.Sprintf("crm-%s-channel", slug),
|
||||
Plan: fmt.Sprintf("crm-%s-plan", slug),
|
||||
}
|
||||
}
|
||||
|
||||
func resourceSlug(raw string) string {
|
||||
slug := strings.ToLower(strings.TrimSpace(raw))
|
||||
slug = nonSlugPattern.ReplaceAllString(slug, "-")
|
||||
slug = strings.Trim(slug, "-")
|
||||
if slug == "" {
|
||||
return "provider"
|
||||
}
|
||||
return slug
|
||||
}
|
||||
178
internal/provision/pack_install_service.go
Normal file
178
internal/provision/pack_install_service.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
packdef "sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type PackInstallRequest struct {
|
||||
Pack packdef.LoadedPack
|
||||
}
|
||||
|
||||
type PackInstallResult struct {
|
||||
Pack sqlite.Pack
|
||||
Providers []sqlite.Provider
|
||||
HostVersion string
|
||||
AlreadyInstalled bool
|
||||
}
|
||||
|
||||
type PackInstallService struct {
|
||||
store *sqlite.DB
|
||||
host sub2api.HostAdapter
|
||||
}
|
||||
|
||||
func NewPackInstallService(store *sqlite.DB, host sub2api.HostAdapter) *PackInstallService {
|
||||
return &PackInstallService{store: store, host: host}
|
||||
}
|
||||
|
||||
func (s *PackInstallService) Install(ctx context.Context, req PackInstallRequest) (PackInstallResult, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return PackInstallResult{}, fmt.Errorf("store is required")
|
||||
}
|
||||
if s.host == nil {
|
||||
return PackInstallResult{}, fmt.Errorf("host adapter is required")
|
||||
}
|
||||
if strings.TrimSpace(req.Pack.Manifest.PackID) == "" {
|
||||
return PackInstallResult{}, fmt.Errorf("pack manifest is required")
|
||||
}
|
||||
|
||||
hostVersion, err := s.host.GetHostVersion(ctx)
|
||||
if err != nil {
|
||||
return PackInstallResult{}, fmt.Errorf("get host version: %w", err)
|
||||
}
|
||||
if err := packdef.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
|
||||
return PackInstallResult{}, err
|
||||
}
|
||||
|
||||
result := PackInstallResult{HostVersion: hostVersion}
|
||||
if err := s.store.WithTx(ctx, func(queries *sqlite.Queries) error {
|
||||
existing, err := queries.Packs.GetByPackID(ctx, req.Pack.Manifest.PackID)
|
||||
if err == nil {
|
||||
if err := validateExistingPack(existing, req.Pack); err != nil {
|
||||
return err
|
||||
}
|
||||
result.AlreadyInstalled = true
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
packRow, err := buildPackRecord(req.Pack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := queries.Packs.Upsert(ctx, packRow); err != nil {
|
||||
return err
|
||||
}
|
||||
persistedPack, err := queries.Packs.GetByPackID(ctx, req.Pack.Manifest.PackID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.Pack = persistedPack
|
||||
|
||||
providers := make([]sqlite.Provider, 0, len(req.Pack.Providers))
|
||||
for _, providerManifest := range req.Pack.Providers {
|
||||
providerRow, err := buildProviderRecord(persistedPack.ID, providerManifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := queries.Providers.Upsert(ctx, providerRow); err != nil {
|
||||
return err
|
||||
}
|
||||
persistedProvider, err := queries.Providers.GetByPackIDAndProviderID(ctx, persistedPack.ID, providerManifest.ProviderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
providers = append(providers, persistedProvider)
|
||||
}
|
||||
result.Providers = providers
|
||||
return nil
|
||||
}); err != nil {
|
||||
return PackInstallResult{}, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func validateExistingPack(existing sqlite.Pack, loaded packdef.LoadedPack) error {
|
||||
if strings.TrimSpace(existing.Version) != strings.TrimSpace(loaded.Manifest.Version) {
|
||||
return fmt.Errorf("pack %q already installed with version %q; upgrade lifecycle not implemented", existing.PackID, existing.Version)
|
||||
}
|
||||
if strings.TrimSpace(existing.Checksum) != strings.TrimSpace(loaded.Checksum) {
|
||||
return fmt.Errorf("pack %q version %q checksum drift detected", existing.PackID, existing.Version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildPackRecord(loaded packdef.LoadedPack) (sqlite.Pack, error) {
|
||||
manifestJSON, err := json.Marshal(loaded.Manifest)
|
||||
if err != nil {
|
||||
return sqlite.Pack{}, fmt.Errorf("marshal pack manifest: %w", err)
|
||||
}
|
||||
return sqlite.Pack{
|
||||
PackID: loaded.Manifest.PackID,
|
||||
Version: loaded.Manifest.Version,
|
||||
Checksum: loaded.Checksum,
|
||||
Vendor: loaded.Manifest.Vendor,
|
||||
TargetHost: loaded.Manifest.TargetHost,
|
||||
MinHostVersion: loaded.Manifest.MinHostVersion,
|
||||
MaxHostVersion: loaded.Manifest.MaxHostVersion,
|
||||
ManifestJSON: string(manifestJSON),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildProviderRecord(packID int64, provider packdef.ProviderManifest) (sqlite.Provider, error) {
|
||||
defaultModelsJSON, err := marshalJSONString(provider.DefaultModels)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal provider default models: %w", err)
|
||||
}
|
||||
groupTemplateJSON, err := marshalJSONString(provider.GroupTemplate)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal group template: %w", err)
|
||||
}
|
||||
channelTemplateJSON, err := marshalJSONString(provider.ChannelTemplate)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal channel template: %w", err)
|
||||
}
|
||||
planTemplateJSON, err := marshalJSONString(provider.PlanTemplate)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal plan template: %w", err)
|
||||
}
|
||||
importOptionsJSON, err := marshalJSONString(provider.Import)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal import options: %w", err)
|
||||
}
|
||||
manifestJSON, err := marshalJSONString(provider)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal provider manifest: %w", err)
|
||||
}
|
||||
return sqlite.Provider{
|
||||
PackID: packID,
|
||||
ProviderID: provider.ProviderID,
|
||||
DisplayName: provider.DisplayName,
|
||||
BaseURL: provider.BaseURL,
|
||||
Platform: provider.Platform,
|
||||
AccountType: provider.AccountType,
|
||||
DefaultModelsJSON: defaultModelsJSON,
|
||||
SmokeTestModel: provider.SmokeTestModel,
|
||||
GroupTemplateJSON: groupTemplateJSON,
|
||||
ChannelTemplateJSON: channelTemplateJSON,
|
||||
PlanTemplateJSON: planTemplateJSON,
|
||||
ImportOptionsJSON: importOptionsJSON,
|
||||
ManifestJSON: manifestJSON,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func marshalJSONString(value any) (string, error) {
|
||||
body, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
120
internal/provision/pack_install_service_test.go
Normal file
120
internal/provision/pack_install_service_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestPackInstallServiceInstallPersistsPackAndProviders(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{}
|
||||
loaded := sampleLoadedPack()
|
||||
|
||||
svc := NewPackInstallService(store, host)
|
||||
result, err := svc.Install(context.Background(), PackInstallRequest{Pack: loaded})
|
||||
if err != nil {
|
||||
t.Fatalf("Install() error = %v", err)
|
||||
}
|
||||
if result.HostVersion != "0.1.126" {
|
||||
t.Fatalf("HostVersion = %q, want 0.1.126", result.HostVersion)
|
||||
}
|
||||
if result.AlreadyInstalled {
|
||||
t.Fatal("AlreadyInstalled = true, want false on first install")
|
||||
}
|
||||
if result.Pack.PackID != loaded.Manifest.PackID {
|
||||
t.Fatalf("Pack.PackID = %q, want %q", result.Pack.PackID, loaded.Manifest.PackID)
|
||||
}
|
||||
if len(result.Providers) != 1 || result.Providers[0].ProviderID != loaded.Providers[0].ProviderID {
|
||||
t.Fatalf("Providers = %#v, want one persisted provider", result.Providers)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "packs"); got != 1 {
|
||||
t.Fatalf("packs row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "providers"); got != 1 {
|
||||
t.Fatalf("providers row count = %d, want 1", got)
|
||||
}
|
||||
|
||||
repeat, err := svc.Install(context.Background(), PackInstallRequest{Pack: loaded})
|
||||
if err != nil {
|
||||
t.Fatalf("second Install() error = %v", err)
|
||||
}
|
||||
if !repeat.AlreadyInstalled {
|
||||
t.Fatal("AlreadyInstalled = false, want true on re-install")
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "packs"); got != 1 {
|
||||
t.Fatalf("packs row count after re-install = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "providers"); got != 1 {
|
||||
t.Fatalf("providers row count after re-install = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackInstallServiceInstallRejectsVersionAndChecksumDrift(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
svc := NewPackInstallService(store, &fakeHostAdapter{})
|
||||
loaded := sampleLoadedPack()
|
||||
if _, err := svc.Install(context.Background(), PackInstallRequest{Pack: loaded}); err != nil {
|
||||
t.Fatalf("initial Install() error = %v", err)
|
||||
}
|
||||
|
||||
versionDrift := sampleLoadedPack()
|
||||
versionDrift.Manifest.Version = "2.0.0"
|
||||
if _, err := svc.Install(context.Background(), PackInstallRequest{Pack: versionDrift}); err == nil || !strings.Contains(err.Error(), "upgrade lifecycle not implemented") {
|
||||
t.Fatalf("Install() version drift error = %v, want upgrade lifecycle error", err)
|
||||
}
|
||||
|
||||
checksumDrift := sampleLoadedPack()
|
||||
checksumDrift.Checksum = "checksum-2"
|
||||
if _, err := svc.Install(context.Background(), PackInstallRequest{Pack: checksumDrift}); err == nil || !strings.Contains(err.Error(), "checksum drift detected") {
|
||||
t.Fatalf("Install() checksum drift error = %v, want checksum drift error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackInstallServiceInstallValidatesDependencies(t *testing.T) {
|
||||
loaded := sampleLoadedPack()
|
||||
storeWithoutHost := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, storeWithoutHost)
|
||||
storeWithoutPack := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, storeWithoutPack)
|
||||
|
||||
if _, err := (*PackInstallService)(nil).Install(context.Background(), PackInstallRequest{Pack: loaded}); err == nil || err.Error() != "store is required" {
|
||||
t.Fatalf("nil service Install() error = %v, want store is required", err)
|
||||
}
|
||||
if _, err := (&PackInstallService{store: storeWithoutHost}).Install(context.Background(), PackInstallRequest{Pack: loaded}); err == nil || err.Error() != "host adapter is required" {
|
||||
t.Fatalf("missing host Install() error = %v, want host adapter is required", err)
|
||||
}
|
||||
if _, err := NewPackInstallService(storeWithoutPack, &fakeHostAdapter{}).Install(context.Background(), PackInstallRequest{}); err == nil || err.Error() != "pack manifest is required" {
|
||||
t.Fatalf("missing pack Install() error = %v, want pack manifest is required", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sampleLoadedPack() pack.LoadedPack {
|
||||
provider := sampleProviderManifest()
|
||||
return pack.LoadedPack{
|
||||
Manifest: pack.Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
Version: "1.0.0",
|
||||
Vendor: "nous",
|
||||
TargetHost: "sub2api",
|
||||
MinHostVersion: "0.1.126",
|
||||
MaxHostVersion: "0.2.x",
|
||||
},
|
||||
Providers: []pack.ProviderManifest{provider},
|
||||
Checksum: "checksum-1",
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateExistingPack(t *testing.T) {
|
||||
existing := sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"}
|
||||
if err := validateExistingPack(existing, sampleLoadedPack()); err != nil {
|
||||
t.Fatalf("validateExistingPack() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
90
internal/provision/preview_service.go
Normal file
90
internal/provision/preview_service.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
const (
|
||||
PreviewActionCreate = "create"
|
||||
PreviewActionReuse = "reuse"
|
||||
PreviewActionConflict = "conflict"
|
||||
)
|
||||
|
||||
type previewHost interface {
|
||||
ListManagedResources(ctx context.Context, req sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error)
|
||||
}
|
||||
|
||||
type PreviewRequest struct {
|
||||
Provider pack.ProviderManifest
|
||||
Mode string
|
||||
Keys []string
|
||||
}
|
||||
|
||||
type PreviewDecision struct {
|
||||
Action string
|
||||
Suggested string
|
||||
ExistingID string
|
||||
Reason string
|
||||
}
|
||||
|
||||
type PreviewReport struct {
|
||||
AcceptedKeys []string
|
||||
Names ResourceNames
|
||||
Decisions map[string]PreviewDecision
|
||||
}
|
||||
|
||||
type PreviewService struct {
|
||||
host previewHost
|
||||
}
|
||||
|
||||
func NewPreviewService(host previewHost) *PreviewService {
|
||||
return &PreviewService{host: host}
|
||||
}
|
||||
|
||||
func (s *PreviewService) PreviewImport(ctx context.Context, req PreviewRequest) (PreviewReport, error) {
|
||||
acceptedKeys, err := normalizeKeys(req.Keys)
|
||||
if err != nil {
|
||||
return PreviewReport{}, err
|
||||
}
|
||||
if err := validateMode(req.Mode); err != nil {
|
||||
return PreviewReport{}, err
|
||||
}
|
||||
if s.host == nil {
|
||||
return PreviewReport{}, fmt.Errorf("preview host is required")
|
||||
}
|
||||
|
||||
names := SuggestResourceNames(req.Provider)
|
||||
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
|
||||
GroupName: names.Group,
|
||||
ChannelName: names.Channel,
|
||||
PlanName: names.Plan,
|
||||
})
|
||||
if err != nil {
|
||||
return PreviewReport{}, fmt.Errorf("list managed resources: %w", err)
|
||||
}
|
||||
|
||||
return PreviewReport{
|
||||
AcceptedKeys: acceptedKeys,
|
||||
Names: names,
|
||||
Decisions: map[string]PreviewDecision{
|
||||
"group": decideResource(names.Group, snapshot.Groups),
|
||||
"channel": decideResource(names.Channel, snapshot.Channels),
|
||||
"plan": decideResource(names.Plan, snapshot.Plans),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decideResource(suggested string, existing []sub2api.NamedResource) PreviewDecision {
|
||||
switch len(existing) {
|
||||
case 0:
|
||||
return PreviewDecision{Action: PreviewActionCreate, Suggested: suggested}
|
||||
case 1:
|
||||
return PreviewDecision{Action: PreviewActionReuse, Suggested: suggested, ExistingID: existing[0].ID, Reason: "matching managed resource already exists"}
|
||||
default:
|
||||
return PreviewDecision{Action: PreviewActionConflict, Suggested: suggested, Reason: "multiple managed resources share the suggested name"}
|
||||
}
|
||||
}
|
||||
87
internal/provision/preview_service_test.go
Normal file
87
internal/provision/preview_service_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
)
|
||||
|
||||
func TestSuggestResourceNames(t *testing.T) {
|
||||
provider := sampleProviderManifest()
|
||||
|
||||
names := SuggestResourceNames(provider)
|
||||
|
||||
want := ResourceNames{
|
||||
Group: "crm-deepseek-group",
|
||||
Channel: "crm-deepseek-channel",
|
||||
Plan: "crm-deepseek-plan",
|
||||
}
|
||||
if !reflect.DeepEqual(names, want) {
|
||||
t.Fatalf("SuggestResourceNames() = %#v, want %#v", names, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreviewServiceReportsCreateActionsWhenHostHasNoResources(t *testing.T) {
|
||||
host := &fakePreviewHost{}
|
||||
svc := NewPreviewService(host)
|
||||
|
||||
report, err := svc.PreviewImport(context.Background(), PreviewRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModeStrict,
|
||||
Keys: []string{" key-1 ", "key-2", "key-1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PreviewImport() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(report.AcceptedKeys, []string{"key-1", "key-2"}) {
|
||||
t.Fatalf("AcceptedKeys = %#v, want normalized deduped keys", report.AcceptedKeys)
|
||||
}
|
||||
if got := report.Decisions["group"]; got.Action != PreviewActionCreate {
|
||||
t.Fatalf("group action = %q, want %q", got.Action, PreviewActionCreate)
|
||||
}
|
||||
if got := report.Decisions["channel"]; got.Action != PreviewActionCreate {
|
||||
t.Fatalf("channel action = %q, want %q", got.Action, PreviewActionCreate)
|
||||
}
|
||||
if got := report.Decisions["plan"]; got.Action != PreviewActionCreate {
|
||||
t.Fatalf("plan action = %q, want %q", got.Action, PreviewActionCreate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreviewServiceReportsReuseAndConflict(t *testing.T) {
|
||||
host := &fakePreviewHost{snapshot: sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "crm-deepseek-group"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "crm-deepseek-channel"}, {ID: "channel_2", Name: "crm-deepseek-channel"}},
|
||||
Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "crm-deepseek-plan"}},
|
||||
}}
|
||||
svc := NewPreviewService(host)
|
||||
|
||||
report, err := svc.PreviewImport(context.Background(), PreviewRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModePartial,
|
||||
Keys: []string{"key-1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PreviewImport() error = %v", err)
|
||||
}
|
||||
|
||||
if got := report.Decisions["group"]; got.Action != PreviewActionReuse || got.ExistingID != "group_1" {
|
||||
t.Fatalf("group decision = %#v, want reuse group_1", got)
|
||||
}
|
||||
if got := report.Decisions["plan"]; got.Action != PreviewActionReuse || got.ExistingID != "plan_1" {
|
||||
t.Fatalf("plan decision = %#v, want reuse plan_1", got)
|
||||
}
|
||||
if got := report.Decisions["channel"]; got.Action != PreviewActionConflict {
|
||||
t.Fatalf("channel decision = %#v, want conflict", got)
|
||||
}
|
||||
}
|
||||
|
||||
type fakePreviewHost struct {
|
||||
snapshot sub2api.ManagedResourceSnapshot
|
||||
}
|
||||
|
||||
func (f *fakePreviewHost) ListManagedResources(context.Context, sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error) {
|
||||
return f.snapshot, nil
|
||||
}
|
||||
150
internal/provision/provider_status_service.go
Normal file
150
internal/provision/provider_status_service.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type ProviderQuery struct {
|
||||
ProviderID string
|
||||
PackID string
|
||||
}
|
||||
|
||||
type ProviderSnapshot struct {
|
||||
Host sqlite.Host
|
||||
Pack sqlite.Pack
|
||||
Provider sqlite.Provider
|
||||
Batch sqlite.ImportBatch
|
||||
ManagedResources []sqlite.ManagedResource
|
||||
AccessClosures []sqlite.AccessClosureRecord
|
||||
ReconcileRuns []sqlite.ReconcileRun
|
||||
ProviderStatus string
|
||||
LatestAccessStatus string
|
||||
LatestReconcileStatus string
|
||||
LatestReconcileSummary map[string]any
|
||||
}
|
||||
|
||||
type ProviderStatusService struct {
|
||||
store *sqlite.DB
|
||||
}
|
||||
|
||||
func NewProviderStatusService(store *sqlite.DB) *ProviderStatusService {
|
||||
return &ProviderStatusService{store: store}
|
||||
}
|
||||
|
||||
func (s *ProviderStatusService) GetStatus(ctx context.Context, query ProviderQuery) (ProviderSnapshot, error) {
|
||||
return s.snapshot(ctx, query)
|
||||
}
|
||||
|
||||
func (s *ProviderStatusService) GetResources(ctx context.Context, query ProviderQuery) (ProviderSnapshot, error) {
|
||||
return s.snapshot(ctx, query)
|
||||
}
|
||||
|
||||
func (s *ProviderStatusService) snapshot(ctx context.Context, query ProviderQuery) (ProviderSnapshot, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return ProviderSnapshot{}, fmt.Errorf("store is required")
|
||||
}
|
||||
provider, err := s.resolveProvider(ctx, query)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
packRow, err := s.store.Packs().GetByID(ctx, provider.PackID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
batchRow, err := s.store.ImportBatches().GetLatestByProviderID(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
hostRow, err := s.store.Hosts().GetByID(ctx, batchRow.HostID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
managedResources, err := s.store.ManagedResources().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
accessClosures, err := s.store.AccessClosures().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
reconcileRuns, err := s.store.ReconcileRuns().GetByProviderID(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
latestAccessStatus := batchRow.AccessStatus
|
||||
if len(accessClosures) > 0 {
|
||||
latestAccessStatus = firstNonEmpty(accessClosures[len(accessClosures)-1].Status, latestAccessStatus)
|
||||
}
|
||||
latestReconcileStatus := "not_run"
|
||||
latestReconcileSummary := map[string]any{}
|
||||
if len(reconcileRuns) > 0 {
|
||||
latestReconcileStatus = firstNonEmpty(reconcileRuns[0].Status, latestReconcileStatus)
|
||||
if strings.TrimSpace(reconcileRuns[0].SummaryJSON) != "" {
|
||||
if err := json.Unmarshal([]byte(reconcileRuns[0].SummaryJSON), &latestReconcileSummary); err != nil {
|
||||
return ProviderSnapshot{}, fmt.Errorf("decode reconcile summary: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
providerStatus := deriveProviderStatus(batchRow.BatchStatus, latestReconcileStatus)
|
||||
return ProviderSnapshot{
|
||||
Host: hostRow,
|
||||
Pack: packRow,
|
||||
Provider: provider,
|
||||
Batch: batchRow,
|
||||
ManagedResources: managedResources,
|
||||
AccessClosures: accessClosures,
|
||||
ReconcileRuns: reconcileRuns,
|
||||
ProviderStatus: providerStatus,
|
||||
LatestAccessStatus: latestAccessStatus,
|
||||
LatestReconcileStatus: latestReconcileStatus,
|
||||
LatestReconcileSummary: latestReconcileSummary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProviderStatusService) resolveProvider(ctx context.Context, query ProviderQuery) (sqlite.Provider, error) {
|
||||
providerID := strings.TrimSpace(query.ProviderID)
|
||||
packID := strings.TrimSpace(query.PackID)
|
||||
if providerID == "" {
|
||||
return sqlite.Provider{}, fmt.Errorf("provider_id is required")
|
||||
}
|
||||
if packID != "" {
|
||||
packRow, err := s.store.Packs().GetByPackID(ctx, packID)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, err
|
||||
}
|
||||
return s.store.Providers().GetByPackIDAndProviderID(ctx, packRow.ID, providerID)
|
||||
}
|
||||
providers, err := s.store.Providers().ListByProviderID(ctx, providerID)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, err
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return sqlite.Provider{}, fmt.Errorf("provider %q not found", providerID)
|
||||
}
|
||||
if len(providers) > 1 {
|
||||
return sqlite.Provider{}, fmt.Errorf("provider %q exists in multiple packs; pack_id is required", providerID)
|
||||
}
|
||||
return providers[0], nil
|
||||
}
|
||||
|
||||
func deriveProviderStatus(batchStatus, reconcileStatus string) string {
|
||||
reconcileStatus = strings.TrimSpace(reconcileStatus)
|
||||
if reconcileStatus != "" && reconcileStatus != "not_run" {
|
||||
return reconcileStatus
|
||||
}
|
||||
switch strings.TrimSpace(batchStatus) {
|
||||
case BatchStatusSucceeded:
|
||||
return ProviderStatusActive
|
||||
case BatchStatusFailed:
|
||||
return ProviderStatusFailed
|
||||
case "running":
|
||||
return "running"
|
||||
default:
|
||||
return firstNonEmpty(batchStatus, "unknown")
|
||||
}
|
||||
}
|
||||
98
internal/provision/provider_status_service_test.go
Normal file
98
internal/provision/provider_status_service_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestProviderStatusServiceReturnsLatestSnapshot(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
ctx := context.Background()
|
||||
hostID, err := store.Hosts().Create(ctx, sqlite.Host{HostID: "host-1", BaseURL: "https://sub2api.example.com", HostVersion: "0.1.126", CapabilityProbeJSON: `{"supports_batch_accounts":true}`})
|
||||
if err != nil {
|
||||
t.Fatalf("Hosts().Create() error = %v", err)
|
||||
}
|
||||
packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Packs().Create() error = %v", err)
|
||||
}
|
||||
providerID, err := store.Providers().Create(ctx, sqlite.Provider{PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai"})
|
||||
if err != nil {
|
||||
t.Fatalf("Providers().Create() error = %v", err)
|
||||
}
|
||||
batchID, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: ImportModeStrict, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady})
|
||||
if err != nil {
|
||||
t.Fatalf("ImportBatches().Create() error = %v", err)
|
||||
}
|
||||
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "group", HostResourceID: "group-1", ResourceName: "deepseek-group"}); err != nil {
|
||||
t.Fatalf("ManagedResources().Create(group) error = %v", err)
|
||||
}
|
||||
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "account", HostResourceID: "account-1", ResourceName: "deepseek-account-1"}); err != nil {
|
||||
t.Fatalf("ManagedResources().Create(account) error = %v", err)
|
||||
}
|
||||
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchID, ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`}); err != nil {
|
||||
t.Fatalf("AccessClosures().Create() error = %v", err)
|
||||
}
|
||||
if _, err := store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{ProviderID: providerID, Status: "drifted", SummaryJSON: `{"missing_count":1}`}); err != nil {
|
||||
t.Fatalf("ReconcileRuns().Create() error = %v", err)
|
||||
}
|
||||
|
||||
snapshot, err := NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetStatus() error = %v", err)
|
||||
}
|
||||
if snapshot.Host.HostID != "host-1" {
|
||||
t.Fatalf("Host.HostID = %q, want host-1", snapshot.Host.HostID)
|
||||
}
|
||||
if snapshot.Pack.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("Pack.PackID = %q, want openai-cn-pack", snapshot.Pack.PackID)
|
||||
}
|
||||
if snapshot.Provider.ProviderID != "deepseek" {
|
||||
t.Fatalf("Provider.ProviderID = %q, want deepseek", snapshot.Provider.ProviderID)
|
||||
}
|
||||
if snapshot.ProviderStatus != "drifted" {
|
||||
t.Fatalf("ProviderStatus = %q, want drifted", snapshot.ProviderStatus)
|
||||
}
|
||||
if snapshot.LatestAccessStatus != AccessStatusSelfServiceReady {
|
||||
t.Fatalf("LatestAccessStatus = %q, want %q", snapshot.LatestAccessStatus, AccessStatusSelfServiceReady)
|
||||
}
|
||||
if snapshot.LatestReconcileStatus != "drifted" {
|
||||
t.Fatalf("LatestReconcileStatus = %q, want drifted", snapshot.LatestReconcileStatus)
|
||||
}
|
||||
if got := len(snapshot.ManagedResources); got != 2 {
|
||||
t.Fatalf("len(ManagedResources) = %d, want 2", got)
|
||||
}
|
||||
if got, ok := snapshot.LatestReconcileSummary["missing_count"].(float64); !ok || got != 1 {
|
||||
t.Fatalf("LatestReconcileSummary[missing_count] = %#v, want 1", snapshot.LatestReconcileSummary["missing_count"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderStatusServiceRequiresPackIDWhenProviderIDIsAmbiguous(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
ctx := context.Background()
|
||||
pack1, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "pack-a", Version: "1.0.0", Checksum: "checksum-a"})
|
||||
if err != nil {
|
||||
t.Fatalf("Packs().Create(pack-a) error = %v", err)
|
||||
}
|
||||
pack2, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "pack-b", Version: "1.0.0", Checksum: "checksum-b"})
|
||||
if err != nil {
|
||||
t.Fatalf("Packs().Create(pack-b) error = %v", err)
|
||||
}
|
||||
if _, err := store.Providers().Create(ctx, sqlite.Provider{PackID: pack1, ProviderID: "deepseek", DisplayName: "DeepSeek A", BaseURL: "https://a.example.com", Platform: "openai"}); err != nil {
|
||||
t.Fatalf("Providers().Create(pack-a) error = %v", err)
|
||||
}
|
||||
if _, err := store.Providers().Create(ctx, sqlite.Provider{PackID: pack2, ProviderID: "deepseek", DisplayName: "DeepSeek B", BaseURL: "https://b.example.com", Platform: "openai"}); err != nil {
|
||||
t.Fatalf("Providers().Create(pack-b) error = %v", err)
|
||||
}
|
||||
|
||||
_, err = NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"})
|
||||
if err == nil || err.Error() != `provider "deepseek" exists in multiple packs; pack_id is required` {
|
||||
t.Fatalf("GetStatus() error = %v, want ambiguous provider error", err)
|
||||
}
|
||||
}
|
||||
187
internal/provision/reconcile_service_test.go
Normal file
187
internal/provision/reconcile_service_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestReconcileServiceReturnsActiveAfterProbeRerun(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
batchID := seedRuntimeImportForReconcile(t, store, host)
|
||||
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "g"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "c"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
}
|
||||
|
||||
result, err := NewReconcileService(store, host).Reconcile(context.Background(), ReconcileRequest{
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
AccessProbeAPIKey: "user-key",
|
||||
Pack: pack.LoadedPack{Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}},
|
||||
Provider: sampleProviderManifest(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Reconcile() error = %v", err)
|
||||
}
|
||||
if result.BatchID != batchID {
|
||||
t.Fatalf("BatchID = %d, want %d", result.BatchID, batchID)
|
||||
}
|
||||
if result.Status != "active" {
|
||||
t.Fatalf("Status = %q, want active", result.Status)
|
||||
}
|
||||
if result.ProbeFailureCount != 0 {
|
||||
t.Fatalf("ProbeFailureCount = %d, want 0", result.ProbeFailureCount)
|
||||
}
|
||||
if result.AccessStatus != AccessStatusSelfServiceReady {
|
||||
t.Fatalf("AccessStatus = %q, want %q", result.AccessStatus, AccessStatusSelfServiceReady)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "probe_results"); got != 4 {
|
||||
t.Fatalf("probe_results row count = %d, want 4 after rerun", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 2 {
|
||||
t.Fatalf("access_closure_records row count = %d, want 2 after rerun", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileServiceReturnsDegradedWhenProbeRerunFails(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: false, Status: "failed", Message: "bad key"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
seedRuntimeImportForReconcile(t, store, host)
|
||||
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "g"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "c"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
}
|
||||
|
||||
result, err := NewReconcileService(store, host).Reconcile(context.Background(), ReconcileRequest{
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
AccessProbeAPIKey: "user-key",
|
||||
Pack: pack.LoadedPack{Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}},
|
||||
Provider: sampleProviderManifest(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Reconcile() error = %v", err)
|
||||
}
|
||||
if result.Status != "degraded" {
|
||||
t.Fatalf("Status = %q, want degraded", result.Status)
|
||||
}
|
||||
if result.ProbeFailureCount != 1 {
|
||||
t.Fatalf("ProbeFailureCount = %d, want 1", result.ProbeFailureCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileServiceReturnsDriftedWhenManagedResourceMissing(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
seedRuntimeImportForReconcile(t, store, host)
|
||||
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "g"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "c"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}},
|
||||
}
|
||||
|
||||
result, err := NewReconcileService(store, host).Reconcile(context.Background(), ReconcileRequest{
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
AccessProbeAPIKey: "user-key",
|
||||
Pack: pack.LoadedPack{Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}},
|
||||
Provider: sampleProviderManifest(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Reconcile() error = %v", err)
|
||||
}
|
||||
if result.Status != "drifted" {
|
||||
t.Fatalf("Status = %q, want drifted", result.Status)
|
||||
}
|
||||
if result.MissingCount != 1 {
|
||||
t.Fatalf("MissingCount = %d, want 1", result.MissingCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveHealthyAccessStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
closureType string
|
||||
want string
|
||||
}{
|
||||
{name: "subscription", closureType: AccessModeSubscription, want: AccessStatusSubscriptionReady},
|
||||
{name: "self-service", closureType: AccessModeSelfService, want: AccessStatusSelfServiceReady},
|
||||
{name: "unknown", closureType: "other", want: "unknown"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := deriveHealthyAccessStatus(tc.closureType); got != tc.want {
|
||||
t.Fatalf("deriveHealthyAccessStatus(%q) = %q, want %q", tc.closureType, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func seedRuntimeImportForReconcile(t *testing.T, store *sqlite.DB, host *fakeHostAdapter) int64 {
|
||||
t.Helper()
|
||||
result, err := NewRuntimeImportService(store, host).Import(context.Background(), RuntimeImportRequest{
|
||||
HostID: "host-1",
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
Pack: pack.LoadedPack{
|
||||
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
|
||||
Checksum: "checksum-1",
|
||||
},
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModePartial,
|
||||
Keys: []string{"key-1", "key-2"},
|
||||
Access: AccessRequest{
|
||||
Mode: AccessModeSelfService,
|
||||
ProbeAPIKey: "user-key",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("seed RuntimeImportService.Import() error = %v", err)
|
||||
}
|
||||
return result.BatchID
|
||||
}
|
||||
90
internal/provision/rollback_service.go
Normal file
90
internal/provision/rollback_service.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
type rollbackHost interface {
|
||||
ListManagedResources(ctx context.Context, req sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error)
|
||||
DeleteAccount(ctx context.Context, accountID string) error
|
||||
DeletePlan(ctx context.Context, planID string) error
|
||||
DeleteChannel(ctx context.Context, channelID string) error
|
||||
DeleteGroup(ctx context.Context, groupID string) error
|
||||
}
|
||||
|
||||
type RollbackRequest struct {
|
||||
Provider pack.ProviderManifest
|
||||
}
|
||||
|
||||
type RollbackReport struct {
|
||||
AccountsDeleted int
|
||||
PlansDeleted int
|
||||
ChannelsDeleted int
|
||||
GroupsDeleted int
|
||||
}
|
||||
|
||||
type RollbackService struct {
|
||||
host rollbackHost
|
||||
}
|
||||
|
||||
func NewRollbackService(host rollbackHost) *RollbackService {
|
||||
return &RollbackService{host: host}
|
||||
}
|
||||
|
||||
func (s *RollbackService) Rollback(ctx context.Context, req RollbackRequest) (RollbackReport, error) {
|
||||
if s.host == nil {
|
||||
return RollbackReport{}, fmt.Errorf("rollback host is required")
|
||||
}
|
||||
|
||||
names := SuggestResourceNames(req.Provider)
|
||||
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
|
||||
GroupName: names.Group,
|
||||
ChannelName: names.Channel,
|
||||
PlanName: names.Plan,
|
||||
AccountNamePrefix: SuggestAccountNamePrefix(req.Provider),
|
||||
})
|
||||
if err != nil {
|
||||
return RollbackReport{}, fmt.Errorf("list managed resources: %w", err)
|
||||
}
|
||||
|
||||
var report RollbackReport
|
||||
var errs []error
|
||||
for index := len(snapshot.Accounts) - 1; index >= 0; index-- {
|
||||
if err := s.host.DeleteAccount(ctx, snapshot.Accounts[index].ID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete account %s: %w", snapshot.Accounts[index].ID, err))
|
||||
continue
|
||||
}
|
||||
report.AccountsDeleted++
|
||||
}
|
||||
for index := len(snapshot.Plans) - 1; index >= 0; index-- {
|
||||
if err := s.host.DeletePlan(ctx, snapshot.Plans[index].ID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete plan %s: %w", snapshot.Plans[index].ID, err))
|
||||
continue
|
||||
}
|
||||
report.PlansDeleted++
|
||||
}
|
||||
for index := len(snapshot.Channels) - 1; index >= 0; index-- {
|
||||
if err := s.host.DeleteChannel(ctx, snapshot.Channels[index].ID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete channel %s: %w", snapshot.Channels[index].ID, err))
|
||||
continue
|
||||
}
|
||||
report.ChannelsDeleted++
|
||||
}
|
||||
for index := len(snapshot.Groups) - 1; index >= 0; index-- {
|
||||
if err := s.host.DeleteGroup(ctx, snapshot.Groups[index].ID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete group %s: %w", snapshot.Groups[index].ID, err))
|
||||
continue
|
||||
}
|
||||
report.GroupsDeleted++
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return report, errors.Join(errs...)
|
||||
}
|
||||
return report, nil
|
||||
}
|
||||
50
internal/provision/rollback_service_test.go
Normal file
50
internal/provision/rollback_service_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
)
|
||||
|
||||
func TestRollbackServiceDeletesManagedResourcesInDependencyOrder(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
managedSnapshot: sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "crm-deepseek-group"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "crm-deepseek-channel"}},
|
||||
Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "crm-deepseek-plan"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewRollbackService(host)
|
||||
report, err := svc.Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()})
|
||||
if err != nil {
|
||||
t.Fatalf("Rollback() error = %v", err)
|
||||
}
|
||||
|
||||
if report.AccountsDeleted != 2 || report.PlansDeleted != 1 || report.ChannelsDeleted != 1 || report.GroupsDeleted != 1 {
|
||||
t.Fatalf("Rollback() report = %+v, want all managed resources deleted", report)
|
||||
}
|
||||
want := []string{"account:account_2", "account:account_1", "plan:plan_1", "channel:channel_1", "group:group_1"}
|
||||
if !reflect.DeepEqual(host.deletedResources, want) {
|
||||
t.Fatalf("deleted resources = %#v, want %#v", host.deletedResources, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRollbackServiceReturnsEmptyReportWhenNoManagedResourcesExist(t *testing.T) {
|
||||
host := &fakeHostAdapter{}
|
||||
svc := NewRollbackService(host)
|
||||
|
||||
report, err := svc.Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()})
|
||||
if err != nil {
|
||||
t.Fatalf("Rollback() error = %v", err)
|
||||
}
|
||||
if report.AccountsDeleted != 0 || report.PlansDeleted != 0 || report.ChannelsDeleted != 0 || report.GroupsDeleted != 0 {
|
||||
t.Fatalf("Rollback() report = %+v, want zero deletions", report)
|
||||
}
|
||||
if len(host.deletedResources) != 0 {
|
||||
t.Fatalf("deleted resources = %#v, want none", host.deletedResources)
|
||||
}
|
||||
}
|
||||
259
internal/provision/runtime_import_service.go
Normal file
259
internal/provision/runtime_import_service.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type RuntimeImportRequest struct {
|
||||
HostID string
|
||||
HostBaseURL string
|
||||
Pack pack.LoadedPack
|
||||
Provider pack.ProviderManifest
|
||||
Mode string
|
||||
Access AccessRequest
|
||||
Keys []string
|
||||
}
|
||||
|
||||
type RuntimeImportResult struct {
|
||||
BatchID int64
|
||||
Report ImportReport
|
||||
}
|
||||
|
||||
type RuntimeImportService struct {
|
||||
store *sqlite.DB
|
||||
host hostAdapter
|
||||
}
|
||||
|
||||
func NewRuntimeImportService(store *sqlite.DB, host hostAdapter) *RuntimeImportService {
|
||||
return &RuntimeImportService{store: store, host: host}
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) Import(ctx context.Context, req RuntimeImportRequest) (RuntimeImportResult, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("store is required")
|
||||
}
|
||||
if s.host == nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("host adapter is required")
|
||||
}
|
||||
req.HostBaseURL = strings.TrimSpace(req.HostBaseURL)
|
||||
if req.HostBaseURL == "" {
|
||||
return RuntimeImportResult{}, fmt.Errorf("host_base_url is required")
|
||||
}
|
||||
if strings.TrimSpace(req.HostID) == "" {
|
||||
req.HostID = req.HostBaseURL
|
||||
}
|
||||
|
||||
hostVersion, err := s.host.GetHostVersion(ctx)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("get host version: %w", err)
|
||||
}
|
||||
if err := pack.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
capabilities, err := s.host.ProbeCapabilities(ctx)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("probe host capabilities: %w", err)
|
||||
}
|
||||
capabilityProbeJSON, err := json.Marshal(capabilities)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("marshal host capabilities: %w", err)
|
||||
}
|
||||
|
||||
hostRow, err := s.ensureHost(ctx, req.HostID, req.HostBaseURL, hostVersion, string(capabilityProbeJSON))
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
packRow, err := s.ensurePack(ctx, req.Pack)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
providerRow, err := s.ensureProvider(ctx, packRow.ID, req.Provider)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
|
||||
batchID, err := s.store.ImportBatches().Create(ctx, sqlite.ImportBatch{
|
||||
HostID: hostRow.ID,
|
||||
PackID: packRow.ID,
|
||||
ProviderID: providerRow.ID,
|
||||
Mode: req.Mode,
|
||||
BatchStatus: "running",
|
||||
AccessStatus: "pending",
|
||||
})
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
|
||||
report, importErr := NewImportService(s.host).Import(ctx, ImportRequest{
|
||||
Provider: req.Provider,
|
||||
Mode: req.Mode,
|
||||
Access: req.Access,
|
||||
Keys: req.Keys,
|
||||
})
|
||||
if report.BatchStatus == "" {
|
||||
report.BatchStatus = BatchStatusFailed
|
||||
}
|
||||
if report.AccessStatus == "" {
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
}
|
||||
|
||||
if persistErr := s.persistRuntimeArtifacts(ctx, batchID, req.Access.Mode, report, importErr == nil); persistErr != nil {
|
||||
return RuntimeImportResult{}, persistErr
|
||||
}
|
||||
if err := s.store.ImportBatches().UpdateStatus(ctx, batchID, report.BatchStatus, report.AccessStatus); err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
if importErr != nil {
|
||||
return RuntimeImportResult{BatchID: batchID, Report: report}, importErr
|
||||
}
|
||||
return RuntimeImportResult{BatchID: batchID, Report: report}, nil
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) ensureHost(ctx context.Context, hostID, baseURL, hostVersion, capabilityProbeJSON string) (sqlite.Host, error) {
|
||||
host, err := s.store.Hosts().GetByHostID(ctx, hostID)
|
||||
if err == nil {
|
||||
return host, nil
|
||||
}
|
||||
if _, createErr := s.store.Hosts().Create(ctx, sqlite.Host{
|
||||
HostID: hostID,
|
||||
BaseURL: baseURL,
|
||||
HostVersion: hostVersion,
|
||||
CapabilityProbeJSON: capabilityProbeJSON,
|
||||
}); createErr != nil {
|
||||
return sqlite.Host{}, createErr
|
||||
}
|
||||
return s.store.Hosts().GetByHostID(ctx, hostID)
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) ensurePack(ctx context.Context, loaded pack.LoadedPack) (sqlite.Pack, error) {
|
||||
packRow, err := s.store.Packs().GetByPackID(ctx, loaded.Manifest.PackID)
|
||||
if err == nil {
|
||||
if err := validateExistingPack(packRow, loaded); err != nil {
|
||||
return sqlite.Pack{}, err
|
||||
}
|
||||
}
|
||||
packRecord, err := buildPackRecord(loaded)
|
||||
if err != nil {
|
||||
return sqlite.Pack{}, err
|
||||
}
|
||||
if _, err := s.store.Packs().Upsert(ctx, packRecord); err != nil {
|
||||
return sqlite.Pack{}, err
|
||||
}
|
||||
return s.store.Packs().GetByPackID(ctx, loaded.Manifest.PackID)
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) ensureProvider(ctx context.Context, packID int64, provider pack.ProviderManifest) (sqlite.Provider, error) {
|
||||
if _, err := s.store.Providers().GetByPackIDAndProviderID(ctx, packID, provider.ProviderID); err == nil {
|
||||
// continue into upsert path so metadata stays fresh.
|
||||
}
|
||||
providerRecord, err := buildProviderRecord(packID, provider)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, err
|
||||
}
|
||||
if _, err := s.store.Providers().Upsert(ctx, providerRecord); err != nil {
|
||||
return sqlite.Provider{}, err
|
||||
}
|
||||
return s.store.Providers().GetByPackIDAndProviderID(ctx, packID, provider.ProviderID)
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) persistRuntimeArtifacts(ctx context.Context, batchID int64, accessMode string, report ImportReport, includeManagedResources bool) error {
|
||||
for i, account := range report.Accounts {
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"account_id": account.Ref.ID,
|
||||
"probe_ok": account.Probe.OK,
|
||||
"probe_status": account.Probe.Status,
|
||||
"probe_message": account.Probe.Message,
|
||||
"models": account.Models,
|
||||
"smoke_model_seen": account.SmokeModelSeen,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal account probe summary: %w", err)
|
||||
}
|
||||
itemID, err := s.store.ImportBatchItems().Create(ctx, sqlite.ImportBatchItem{
|
||||
BatchID: batchID,
|
||||
KeyFingerprint: fingerprintKey(report.AcceptedKeys, i),
|
||||
AccountStatus: firstNonEmpty(account.Probe.Status, "unknown"),
|
||||
ProbeSummaryJSON: string(payload),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := s.store.ProbeResults().Create(ctx, sqlite.ProbeResult{
|
||||
BatchItemID: itemID,
|
||||
ProbeType: "account_smoke",
|
||||
Status: firstNonEmpty(account.Probe.Status, "unknown"),
|
||||
SummaryJSON: string(payload),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if includeManagedResources {
|
||||
if report.Group.ID != "" {
|
||||
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "group", HostResourceID: report.Group.ID, ResourceName: firstNonEmpty(report.Group.Name, report.Group.ID)}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if report.Channel.ID != "" {
|
||||
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "channel", HostResourceID: report.Channel.ID, ResourceName: firstNonEmpty(report.Channel.Name, report.Channel.ID)}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if report.Plan != nil && report.Plan.ID != "" {
|
||||
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "plan", HostResourceID: report.Plan.ID, ResourceName: firstNonEmpty(report.Plan.Name, report.Plan.ID)}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, account := range report.Accounts {
|
||||
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "account", HostResourceID: account.Ref.ID, ResourceName: firstNonEmpty(account.Ref.Name, account.Ref.ID)}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessPayload, err := json.Marshal(map[string]any{
|
||||
"status_code": report.Gateway.StatusCode,
|
||||
"ok": report.Gateway.OK,
|
||||
"has_expected_model": report.Gateway.HasExpectedModel,
|
||||
"models": report.Gateway.Models,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal gateway access summary: %w", err)
|
||||
}
|
||||
if _, err := s.store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{
|
||||
BatchID: batchID,
|
||||
ClosureType: firstNonEmpty(strings.TrimSpace(accessMode), "unknown"),
|
||||
Status: firstNonEmpty(report.AccessStatus, AccessStatusBroken),
|
||||
DetailsJSON: string(accessPayload),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fingerprintKey(keys []string, index int) string {
|
||||
if index >= 0 && index < len(keys) {
|
||||
key := strings.TrimSpace(keys[index])
|
||||
if key != "" {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return fmt.Sprintf("sha256:%x", sum[:])
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("key-%d", index+1)
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
196
internal/provision/runtime_import_service_test.go
Normal file
196
internal/provision/runtime_import_service_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestRuntimeImportServicePersistsOperationalState(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
svc := NewRuntimeImportService(store, host)
|
||||
result, err := svc.Import(context.Background(), RuntimeImportRequest{
|
||||
HostID: "host-1",
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
Pack: pack.LoadedPack{
|
||||
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
|
||||
Checksum: "checksum-1",
|
||||
},
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModePartial,
|
||||
Keys: []string{" key-1 ", "key-2", "key-1"},
|
||||
Access: AccessRequest{
|
||||
Mode: AccessModeSelfService,
|
||||
ProbeAPIKey: "user-key",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RuntimeImportService.Import() error = %v", err)
|
||||
}
|
||||
if result.BatchID <= 0 {
|
||||
t.Fatalf("BatchID = %d, want positive id", result.BatchID)
|
||||
}
|
||||
if result.Report.BatchStatus != BatchStatusSucceeded {
|
||||
t.Fatalf("BatchStatus = %q, want %q", result.Report.BatchStatus, BatchStatusSucceeded)
|
||||
}
|
||||
|
||||
if got := queryCount(t, store.SQLDB(), "hosts"); got != 1 {
|
||||
t.Fatalf("hosts row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "packs"); got != 1 {
|
||||
t.Fatalf("packs row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "providers"); got != 1 {
|
||||
t.Fatalf("providers row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "import_batches"); got != 1 {
|
||||
t.Fatalf("import_batches row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "import_batch_items"); got != 2 {
|
||||
t.Fatalf("import_batch_items row count = %d, want 2", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "managed_resources"); got != 4 {
|
||||
t.Fatalf("managed_resources row count = %d, want 4", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "probe_results"); got != 2 {
|
||||
t.Fatalf("probe_results row count = %d, want 2", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 1 {
|
||||
t.Fatalf("access_closure_records row count = %d, want 1", got)
|
||||
}
|
||||
|
||||
var batchStatus string
|
||||
var accessStatus string
|
||||
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT batch_status, access_status FROM import_batches WHERE id = ?", result.BatchID).Scan(&batchStatus, &accessStatus); err != nil {
|
||||
t.Fatalf("query import batch state: %v", err)
|
||||
}
|
||||
if batchStatus != BatchStatusSucceeded {
|
||||
t.Fatalf("persisted batch_status = %q, want %q", batchStatus, BatchStatusSucceeded)
|
||||
}
|
||||
if accessStatus != AccessStatusSelfServiceReady {
|
||||
t.Fatalf("persisted access_status = %q, want %q", accessStatus, AccessStatusSelfServiceReady)
|
||||
}
|
||||
|
||||
var fingerprint string
|
||||
var accountStatus string
|
||||
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT key_fingerprint, account_status FROM import_batch_items ORDER BY id LIMIT 1").Scan(&fingerprint, &accountStatus); err != nil {
|
||||
t.Fatalf("query import batch item: %v", err)
|
||||
}
|
||||
if fingerprint == "key-1" || fingerprint == "key-2" || len(fingerprint) < 10 {
|
||||
t.Fatalf("key_fingerprint = %q, want hashed fingerprint instead of raw key", fingerprint)
|
||||
}
|
||||
if accountStatus != "passed" {
|
||||
t.Fatalf("account_status = %q, want passed", accountStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeImportServicePersistsFailedBatchAfterStrictRollback(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: false, Status: "failed", Message: "bad key"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewRuntimeImportService(store, host)
|
||||
result, err := svc.Import(context.Background(), RuntimeImportRequest{
|
||||
HostID: "host-1",
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
Pack: pack.LoadedPack{
|
||||
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
|
||||
Checksum: "checksum-1",
|
||||
},
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModeStrict,
|
||||
Keys: []string{"key-1", "key-2"},
|
||||
Access: AccessRequest{
|
||||
Mode: AccessModeSelfService,
|
||||
ProbeAPIKey: "user-key",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("RuntimeImportService.Import() error = nil, want strict failure")
|
||||
}
|
||||
if result.BatchID <= 0 {
|
||||
t.Fatalf("BatchID = %d, want positive id", result.BatchID)
|
||||
}
|
||||
|
||||
var batchStatus string
|
||||
var accessStatus string
|
||||
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT batch_status, access_status FROM import_batches WHERE id = ?", result.BatchID).Scan(&batchStatus, &accessStatus); err != nil {
|
||||
t.Fatalf("query failed import batch state: %v", err)
|
||||
}
|
||||
if batchStatus != BatchStatusFailed {
|
||||
t.Fatalf("persisted batch_status = %q, want %q", batchStatus, BatchStatusFailed)
|
||||
}
|
||||
if accessStatus != AccessStatusBroken {
|
||||
t.Fatalf("persisted access_status = %q, want %q", accessStatus, AccessStatusBroken)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "managed_resources"); got != 0 {
|
||||
t.Fatalf("managed_resources row count = %d, want 0 after strict rollback", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "probe_results"); got != 2 {
|
||||
t.Fatalf("probe_results row count = %d, want 2", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 1 {
|
||||
t.Fatalf("access_closure_records row count = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func openProvisionTestStore(t *testing.T) *sqlite.DB {
|
||||
t.Helper()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "state.db")
|
||||
dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_pragma=foreign_keys(0)", filepath.ToSlash(dbPath))
|
||||
store, err := sqlite.Open(context.Background(), dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("sqlite.Open() error = %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func closeProvisionTestStore(t *testing.T, store *sqlite.DB) {
|
||||
t.Helper()
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatalf("store.Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func queryCount(t *testing.T, db *sql.DB, table string) int {
|
||||
t.Helper()
|
||||
var count int
|
||||
if err := db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM "+table).Scan(&count); err != nil {
|
||||
t.Fatalf("count rows for %s: %v", table, err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
Reference in New Issue
Block a user