From eac860e72f1b03fc79b6378ae3c0fff62d3fdcef Mon Sep 17 00:00:00 2001 From: phamnazage-jpg Date: Fri, 22 May 2026 16:12:52 +0800 Subject: [PATCH] fix(api): wire batch import create-run entry pipeline --- internal/app/batch_runtime.go | 428 +++++++++++++++++++++++++ internal/app/http_batch_import.go | 72 +---- internal/app/http_batch_import_test.go | 225 +++++++++++++ 3 files changed, 661 insertions(+), 64 deletions(-) create mode 100644 internal/app/batch_runtime.go diff --git a/internal/app/batch_runtime.go b/internal/app/batch_runtime.go new file mode 100644 index 00000000..72633d34 --- /dev/null +++ b/internal/app/batch_runtime.go @@ -0,0 +1,428 @@ +package app + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "sub2api-cn-relay-manager/internal/batch" + "sub2api-cn-relay-manager/internal/host/sub2api" + "sub2api-cn-relay-manager/internal/pack" + "sub2api-cn-relay-manager/internal/probe" + "sub2api-cn-relay-manager/internal/provision" + "sub2api-cn-relay-manager/internal/store/sqlite" +) + +const ( + generatedBatchImportPackID = "batch-auto-import-v2-generated" + generatedBatchImportPackVersion = "2026.05.22" + batchImportRetryDelay = 200 * time.Millisecond +) + +type batchImportRuntimeRunner struct { + store *sqlite.DB + hostRow sqlite.Host + hostClient *sub2api.Client + request CreateBatchImportRunRequest +} + +func (r batchImportRuntimeRunner) execute(ctx context.Context) (BatchImportRunCreateResponse, error) { + runID := fmt.Sprintf("run_%d", time.Now().UnixNano()) + + service := batch.BatchImportService{ + RunStore: r.store.ImportRuns(), + ItemStore: r.store.ImportRunItems(), + ProbeModels: probe.ProviderModels, + ProbeCapabilities: probe.ProbeCapabilities, + Provisioner: batchImportProvisioner{ + store: r.store, + hostRow: r.hostRow, + hostClient: r.hostClient, + request: r.request, + }, + } + + entries := make([]batch.BatchImportEntry, 0, len(r.request.Entries)) + for _, entry := range r.request.Entries { + entries = append(entries, batch.BatchImportEntry{ + BaseURL: entry.BaseURL, + APIKey: entry.APIKey, + RequestedModels: append([]string(nil), entry.RequestedModels...), + }) + } + + if _, err := service.StartRun(ctx, batch.BatchImportRunRequest{ + RunID: runID, + Mode: r.request.Mode, + AccessMode: r.request.AccessMode, + HostID: r.hostRow.HostID, + Entries: entries, + }); err != nil { + return BatchImportRunCreateResponse{}, err + } + + if err := r.advanceRun(ctx, runID); err != nil { + return BatchImportRunCreateResponse{}, err + } + + run, err := r.store.ImportRuns().GetByRunID(ctx, runID) + if err != nil { + return BatchImportRunCreateResponse{}, err + } + return BatchImportRunCreateResponse{ + RunID: run.RunID, + State: run.State, + ResultPage: "/batch-import/runs/" + run.RunID, + TotalItems: run.TotalItems, + ActiveItems: run.ActiveItems, + DegradedItems: run.DegradedItems, + BrokenItems: run.BrokenItems, + WarningItems: run.WarningItems, + }, nil +} + +func (r batchImportRuntimeRunner) advanceRun(ctx context.Context, runID string) error { + timeout := time.Duration(r.request.ConfirmWaitTimeoutSec) * time.Second + if timeout <= 0 { + timeout = time.Second + } + deadline := time.Now().Add(timeout) + + worker := batch.ConfirmationWorker{ + WorkerID: "batch-import-api", + ItemStore: batchImportRunItemStore{store: r.store, runID: runID}, + EventStore: r.store.ImportRunEvents(), + LeaseDuration: time.Minute, + RetryDelay: batchImportRetryDelay, + Confirmer: r.confirmItem, + } + validator := batch.ValidationService{ + ItemStore: r.store.ImportRunItems(), + RunStore: r.store.ImportRuns(), + Validator: r.validateItem, + } + + for { + now := time.Now() + if err := worker.Tick(ctx, now); err != nil { + return err + } + + items, err := r.store.ImportRunItems().ListByRunID(ctx, runID) + if err != nil { + return err + } + + pendingWork := false + for _, item := range items { + switch item.CurrentStage { + case string(batch.ItemStageValidate): + if err := validator.ValidateItem(ctx, item); err != nil { + return err + } + case string(batch.ItemStageConfirm): + if item.ConfirmationStatus == string(batch.ConfirmationPending) { + pendingWork = true + } + } + } + + run, err := r.store.ImportRuns().GetByRunID(ctx, runID) + if err != nil { + return err + } + if run.TotalItems > 0 && run.CompletedItems >= run.TotalItems { + return nil + } + if !pendingWork || !time.Now().Before(deadline) { + return nil + } + if err := sleepWithContext(ctx, batchImportRetryDelay); err != nil { + return err + } + } +} + +func (r batchImportRuntimeRunner) confirmItem(ctx context.Context, item sqlite.ImportRunItem) (batch.ConfirmationResult, error) { + accountID, err := resolveManagedResourceHostID(ctx, r.store, item, "account") + if err != nil { + return batch.ConfirmationResult{}, err + } + + probeResult, err := r.hostClient.TestAccount(ctx, accountID, item.ResolvedSmokeModel) + if err != nil { + var httpErr *sub2api.HTTPError + if errors.As(err, &httpErr) { + return batch.ConfirmationResult{StatusCode: httpErr.StatusCode, Message: httpErr.Body}, nil + } + return batch.ConfirmationResult{}, err + } + if probeResult.OK { + return batch.ConfirmationResult{StatusCode: http.StatusOK, Message: probeResult.Message}, nil + } + + message := strings.TrimSpace(probeResult.Message) + lowerMessage := strings.ToLower(message) + switch { + case strings.Contains(lowerMessage, "no available accounts"): + return batch.ConfirmationResult{StatusCode: http.StatusServiceUnavailable, Message: message}, nil + case strings.Contains(lowerMessage, "forbidden"): + return batch.ConfirmationResult{StatusCode: http.StatusForbidden, Message: message}, nil + default: + return batch.ConfirmationResult{StatusCode: http.StatusBadRequest, Message: message}, nil + } +} + +func (r batchImportRuntimeRunner) validateItem(ctx context.Context, item sqlite.ImportRunItem) (sub2api.GatewayCompletionResult, error) { + apiKey, err := r.resolveValidationAPIKey(ctx, item) + if err != nil { + return sub2api.GatewayCompletionResult{}, err + } + return r.hostClient.CheckGatewayCompletion(ctx, sub2api.GatewayCompletionCheckRequest{ + APIKey: apiKey, + Model: item.ResolvedSmokeModel, + Prompt: "ping", + MaxTokens: 8, + }) +} + +func (r batchImportRuntimeRunner) resolveValidationAPIKey(ctx context.Context, item sqlite.ImportRunItem) (string, error) { + switch strings.TrimSpace(r.request.AccessMode) { + case provision.AccessModeSelfService: + return strings.TrimSpace(r.request.ProbeAPIKey), nil + case provision.AccessModeSubscription: + if len(r.request.SubscriptionUsers) == 0 { + return "", fmt.Errorf("subscription_users is required") + } + groupID, err := resolveManagedResourceHostID(ctx, r.store, item, "group") + if err != nil { + return "", err + } + accessRef, err := r.hostClient.EnsureSubscriptionAccess(ctx, sub2api.EnsureSubscriptionAccessRequest{ + UserSelector: r.request.SubscriptionUsers[0], + GroupID: groupID, + }) + if err != nil { + return "", err + } + userID := strings.TrimSpace(accessRef.UserID) + if userID == "" { + userID = r.request.SubscriptionUsers[0] + } + if _, err := r.hostClient.AssignSubscription(ctx, sub2api.AssignSubscriptionRequest{ + UserID: userID, + GroupID: groupID, + DurationDays: r.request.SubscriptionDays, + }); err != nil { + return "", err + } + if strings.TrimSpace(accessRef.APIKey) == "" { + return "", fmt.Errorf("subscription access api key is empty") + } + return strings.TrimSpace(accessRef.APIKey), nil + default: + return "", fmt.Errorf("unsupported access mode %q", r.request.AccessMode) + } +} + +type batchImportProvisioner struct { + store *sqlite.DB + hostRow sqlite.Host + hostClient *sub2api.Client + request CreateBatchImportRunRequest +} + +func (p batchImportProvisioner) Provision(ctx context.Context, req batch.ProvisionRequest) (batch.ProvisionResult, error) { + runtimeService := provision.NewRuntimeImportService(p.store, p.hostClient) + providerManifest := generatedBatchImportProviderManifest(req, p.request) + result, err := runtimeService.Import(ctx, provision.RuntimeImportRequest{ + HostID: p.hostRow.HostID, + HostBaseURL: p.hostRow.BaseURL, + Pack: generatedBatchImportPack(providerManifest), + Provider: providerManifest, + Mode: firstNonEmptyString(strings.TrimSpace(p.request.Mode), provision.ImportModeStrict), + Keys: []string{strings.TrimSpace(req.Entry.APIKey)}, + Access: batchImportAccessRequest(p.request), + }) + if err != nil { + return batch.ProvisionResult{}, err + } + legacyBatchID := result.BatchID + return batch.ProvisionResult{ + LegacyBatchID: &legacyBatchID, + LegacyProviderID: req.ProviderID, + }, nil +} + +func (p batchImportProvisioner) Patch(_ context.Context, _ batch.PatchProvisionRequest) error { + return nil +} + +type batchImportRunItemStore struct { + store *sqlite.DB + runID string +} + +func (s batchImportRunItemStore) List(ctx context.Context) ([]sqlite.ImportRunItem, error) { + return s.store.ImportRunItems().ListByRunID(ctx, s.runID) +} + +func (s batchImportRunItemStore) Upsert(ctx context.Context, item sqlite.ImportRunItem) error { + return s.store.ImportRunItems().Upsert(ctx, item) +} + +func generatedBatchImportPack(providerManifest pack.ProviderManifest) pack.LoadedPack { + return pack.LoadedPack{ + Manifest: pack.Manifest{ + PackID: generatedBatchImportPackID, + Version: generatedBatchImportPackVersion, + Vendor: "sub2api-cn-relay-manager", + TargetHost: "sub2api", + }, + Providers: []pack.ProviderManifest{providerManifest}, + Checksum: generatedBatchImportPackID + "@" + generatedBatchImportPackVersion, + } +} + +func generatedBatchImportProviderManifest(req batch.ProvisionRequest, createReq CreateBatchImportRunRequest) pack.ProviderManifest { + defaultModels := uniqueNonEmptyStrings(capabilityProfileModels(req.CapabilityProfile)) + if len(defaultModels) == 0 { + defaultModels = uniqueNonEmptyStrings([]string{req.ResolvedModel}) + } + smokeModel := firstNonEmptyString(strings.TrimSpace(req.ResolvedModel)) + if smokeModel == "" && len(defaultModels) > 0 { + smokeModel = defaultModels[0] + } + if smokeModel == "" { + smokeModel = "ping" + } + if len(defaultModels) == 0 { + defaultModels = []string{smokeModel} + } + + modelMapping := make(map[string]string, len(defaultModels)) + for _, modelID := range defaultModels { + modelMapping[modelID] = modelID + } + + names := fmt.Sprintf("crm-%s", strings.TrimSpace(req.ProviderID)) + validityDays := createReq.SubscriptionDays + if validityDays <= 0 { + validityDays = 30 + } + + return pack.ProviderManifest{ + ProviderID: req.ProviderID, + DisplayName: req.ProviderID, + BaseURL: strings.TrimSpace(req.Entry.BaseURL), + Platform: "openai", + AccountType: "apikey", + DefaultModels: defaultModels, + SmokeTestModel: smokeModel, + GroupTemplate: pack.GroupTemplate{ + Name: names + "-group", + RateMultiplier: 1, + }, + ChannelTemplate: pack.ChannelTemplate{ + Name: names + "-channel", + ModelMapping: modelMapping, + }, + PlanTemplate: pack.PlanTemplate{ + Name: names + "-plan", + Price: 1, + ValidityDays: validityDays, + ValidityUnit: "day", + }, + Import: pack.ImportOptions{ + SupportsMultiKey: true, + SupportsStrict: true, + SupportsPartial: true, + }, + } +} + +func batchImportAccessRequest(req CreateBatchImportRunRequest) provision.AccessRequest { + subscriptions := make([]provision.SubscriptionTarget, 0, len(req.SubscriptionUsers)) + for _, userID := range req.SubscriptionUsers { + subscriptions = append(subscriptions, provision.SubscriptionTarget{ + UserID: userID, + DurationDays: req.SubscriptionDays, + }) + } + return provision.AccessRequest{ + Mode: strings.TrimSpace(req.AccessMode), + ProbeAPIKey: strings.TrimSpace(req.ProbeAPIKey), + Subscriptions: subscriptions, + } +} + +func resolveManagedResourceHostID(ctx context.Context, store *sqlite.DB, item sqlite.ImportRunItem, resourceType string) (string, error) { + if store == nil { + return "", fmt.Errorf("store is required") + } + if item.LegacyBatchID == nil || *item.LegacyBatchID <= 0 { + return "", fmt.Errorf("legacy_batch_id is required for %s lookup", strings.TrimSpace(resourceType)) + } + resources, err := store.ManagedResources().GetByBatchID(ctx, *item.LegacyBatchID) + if err != nil { + return "", err + } + for _, resource := range resources { + if strings.TrimSpace(resource.ResourceType) == strings.TrimSpace(resourceType) { + return strings.TrimSpace(resource.HostResourceID), nil + } + } + return "", fmt.Errorf("%s resource not found for batch %d", resourceType, *item.LegacyBatchID) +} + +func capabilityProfileModels(profile *probe.CapabilityProfile) []string { + if profile == nil { + return nil + } + models := make([]string, 0, len(profile.ModelProfiles)) + for _, modelProfile := range profile.ModelProfiles { + models = append(models, strings.TrimSpace(modelProfile.RawModelID)) + } + return models +} + +func uniqueNonEmptyStrings(values []string) []string { + seen := make(map[string]struct{}, len(values)) + result := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + result = append(result, trimmed) + } + return result +} + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/internal/app/http_batch_import.go b/internal/app/http_batch_import.go index 1ca75239..9e319b25 100644 --- a/internal/app/http_batch_import.go +++ b/internal/app/http_batch_import.go @@ -2,13 +2,9 @@ package app import ( "context" - "crypto/sha256" - "encoding/json" - "fmt" "net/http" "strconv" "strings" - "time" "sub2api-cn-relay-manager/internal/batch" "sub2api-cn-relay-manager/internal/store/sqlite" @@ -118,52 +114,17 @@ func buildCreateBatchImportRunAction(sqliteDSN string) func(context.Context, Cre } defer store.Close() - runID := fmt.Sprintf("run_%d", time.Now().UnixNano()) - run := sqlite.ImportRun{ - RunID: runID, - Mode: strings.TrimSpace(req.Mode), - AccessMode: strings.TrimSpace(req.AccessMode), - State: string(batch.RunStateRunning), - TotalItems: len(req.Entries), - } - if err := store.ImportRuns().Create(ctx, run); err != nil { + hostRow, client, err := resolveManagedHost(ctx, store, req.HostID, "", CreateHostAuth{}) + if err != nil { return BatchImportRunCreateResponse{}, err } - - for idx, entry := range req.Entries { - item := sqlite.ImportRunItem{ - ItemID: fmt.Sprintf("%s-item-%d", runID, idx+1), - RunID: runID, - BaseURL: strings.TrimSpace(entry.BaseURL), - ProviderID: batch.NormalizeProviderID(entry.BaseURL), - APIKeyFingerprint: fingerprintBatchAPIKey(entry.APIKey), - RequestedModelsJSON: mustMarshalAppJSON(entry.RequestedModels, "[]"), - CurrentStage: string(batch.ItemStageProbe), - ConfirmationStatus: string(batch.ConfirmationPending), - AccessStatus: string(batch.AccessStatusUnknown), - MatchedAccountState: string(batch.MatchedAccountStateNone), - AccountResolution: string(batch.AccountResolutionCreated), - ProvisionReused: false, - CanonicalFamiliesJSON: "[]", - RawModelsJSON: "[]", - NormalizedModelsJSON: "[]", - RecommendedModelsJSON: "[]", - } - if err := store.ImportRunItems().Upsert(ctx, item); err != nil { - return BatchImportRunCreateResponse{}, err - } + runner := batchImportRuntimeRunner{ + store: store, + hostRow: hostRow, + hostClient: client, + request: req, } - - return BatchImportRunCreateResponse{ - RunID: runID, - State: string(batch.RunStateRunning), - ResultPage: "/batch-import/runs/" + runID, - TotalItems: len(req.Entries), - ActiveItems: 0, - DegradedItems: 0, - BrokenItems: 0, - WarningItems: 0, - }, nil + return runner.execute(ctx) } } @@ -242,20 +203,3 @@ func defaultPositiveInt(value, fallback int) int { } return fallback } - -func fingerprintBatchAPIKey(apiKey string) string { - trimmed := strings.TrimSpace(apiKey) - if trimmed == "" { - return "" - } - sum := sha256.Sum256([]byte(trimmed)) - return fmt.Sprintf("sha256:%x", sum[:4]) -} - -func mustMarshalAppJSON(value any, fallback string) string { - payload, err := json.Marshal(value) - if err != nil { - return fallback - } - return string(payload) -} diff --git a/internal/app/http_batch_import_test.go b/internal/app/http_batch_import_test.go index e5c4f3bf..7fa4b0c2 100644 --- a/internal/app/http_batch_import_test.go +++ b/internal/app/http_batch_import_test.go @@ -2,8 +2,14 @@ package app import ( "context" + "fmt" "net/http" + "net/http/httptest" + "path/filepath" + "strings" "testing" + + "sub2api-cn-relay-manager/internal/store/sqlite" ) func TestBatchImportHTTP(t *testing.T) { @@ -126,4 +132,223 @@ func TestBatchImportHTTP(t *testing.T) { assertJSONContains(t, res.Body().Bytes(), "error.code", "invalid_request") assertJSONContains(t, res.Body().Bytes(), "error.message", "host_id is required") }) + + t.Run("create run action wires real batch pipeline", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(newBatchImportActionStubServer(t)) + defer server.Close() + + dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_pragma=foreign_keys(0)", filepath.ToSlash(filepath.Join(t.TempDir(), "state.db"))) + store, err := sqlite.Open(context.Background(), dsn) + if err != nil { + t.Fatalf("sqlite.Open() error = %v", err) + } + defer closeAppTestStore(t, store) + + if _, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "host-1", + BaseURL: server.URL, + HostVersion: "0.1.126", + CapabilityProbeJSON: "{}", + AuthType: "apikey", + AuthToken: "host-token", + }); err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + + action := buildCreateBatchImportRunAction(dsn) + result, err := action(context.Background(), CreateBatchImportRunRequest{ + HostID: "host-1", + Mode: "strict", + AccessMode: "self_service", + ConfirmWaitTimeoutSec: 1, + ProbeAPIKey: "gateway-key", + Entries: []BatchImportEntryRequest{ + { + BaseURL: server.URL, + APIKey: "entry-key", + RequestedModels: []string{"kimi-k2.6"}, + }, + }, + }) + if err != nil { + t.Fatalf("buildCreateBatchImportRunAction() error = %v", err) + } + if result.State != string("completed") { + t.Fatalf("result.State = %q, want completed", result.State) + } + if result.ActiveItems != 1 { + t.Fatalf("result.ActiveItems = %d, want 1", result.ActiveItems) + } + + run, err := store.ImportRuns().GetByRunID(context.Background(), result.RunID) + if err != nil { + t.Fatalf("ImportRuns().GetByRunID() error = %v", err) + } + if run.State != "completed" { + t.Fatalf("run.State = %q, want completed", run.State) + } + items, err := store.ImportRunItems().ListByRunID(context.Background(), result.RunID) + if err != nil { + t.Fatalf("ImportRunItems().ListByRunID() error = %v", err) + } + if len(items) != 1 { + t.Fatalf("len(items) = %d, want 1", len(items)) + } + if items[0].CurrentStage != "done" || items[0].AccessStatus != "active" { + t.Fatalf("item = %+v, want current_stage=done and access_status=active", items[0]) + } + }) +} + +func newBatchImportActionStubServer(t *testing.T) http.Handler { + t.Helper() + + mux := http.NewServeMux() + mux.HandleFunc("/api/v1/admin/system/version", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"version": "0.1.126"}}) + }) + mux.HandleFunc("/api/v1/admin/groups", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + switch r.Method { + case http.MethodGet: + writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{}}) + case http.MethodPost: + writeJSON(w, http.StatusCreated, map[string]any{"data": map[string]any{"id": "group_1", "name": "batch-import-group"}}) + default: + http.NotFound(w, r) + } + }) + mux.HandleFunc("/api/v1/admin/channels", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + switch r.Method { + case http.MethodGet: + writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{}}) + case http.MethodPost: + writeJSON(w, http.StatusCreated, map[string]any{"data": map[string]any{"id": "channel_1", "name": "batch-import-channel"}}) + default: + http.NotFound(w, r) + } + }) + mux.HandleFunc("/api/v1/admin/payment/plans", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + switch r.Method { + case http.MethodGet: + writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{}}) + case http.MethodPost: + writeJSON(w, http.StatusCreated, map[string]any{"data": map[string]any{"id": "plan_1", "name": "batch-import-plan"}}) + default: + http.NotFound(w, r) + } + }) + mux.HandleFunc("/api/v1/admin/accounts", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{}, "pages": 1}}) + }) + mux.HandleFunc("/api/v1/admin/accounts/batch", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{{"id": "account_1", "name": "batch-import-01"}}}) + }) + mux.HandleFunc("/api/v1/admin/accounts/__probe__/test", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + writeJSON(w, http.StatusBadRequest, map[string]any{"error": "probe only"}) + }) + mux.HandleFunc("/api/v1/admin/accounts/__probe__/models", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + writeJSON(w, http.StatusBadRequest, map[string]any{"error": "probe only"}) + }) + mux.HandleFunc("/api/v1/admin/accounts/account_1/test", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("event: result\n")) + _, _ = w.Write([]byte("data: {\"status\":\"passed\",\"message\":\"smoke passed\",\"ok\":true}\n\n")) + }) + mux.HandleFunc("/api/v1/admin/accounts/account_1/models", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{{"id": "kimi-k2.6", "display_name": "Kimi K2.6", "type": "chat"}}}}) + }) + mux.HandleFunc("/api/v1/admin/subscriptions/assign", func(w http.ResponseWriter, r *http.Request) { + if !requireBatchImportActionAdminToken(t, w, r) { + return + } + switch r.Method { + case http.MethodGet: + writeJSON(w, http.StatusBadRequest, map[string]any{"error": "probe only"}) + case http.MethodPost: + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": "subscription_1"}}) + default: + http.NotFound(w, r) + } + }) + mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) { + switch strings.TrimSpace(r.Header.Get("Authorization")) { + case "Bearer entry-key", "Bearer gateway-key": + writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{{"id": "kimi-k2.6"}}}) + default: + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + } + }) + mux.HandleFunc("/v1/responses", func(w http.ResponseWriter, r *http.Request) { + if strings.TrimSpace(r.Header.Get("Authorization")) != "Bearer entry-key" { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + return + } + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"responses unsupported"}`)) + }) + mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + switch strings.TrimSpace(r.Header.Get("Authorization")) { + case "Bearer entry-key", "Bearer gateway-key": + writeJSON(w, http.StatusOK, map[string]any{ + "id": "chatcmpl_batch_import", + "choices": []map[string]any{{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": "pong", + }, + }}, + }) + default: + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + } + }) + + return mux +} + +func requireBatchImportActionAdminToken(t *testing.T, w http.ResponseWriter, r *http.Request) bool { + t.Helper() + if strings.TrimSpace(r.Header.Get("x-api-key")) != "host-token" { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + return false + } + return true }