fix(api): wire batch import create-run entry pipeline
This commit is contained in:
428
internal/app/batch_runtime.go
Normal file
428
internal/app/batch_runtime.go
Normal file
@@ -0,0 +1,428 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/batch"
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/probe"
|
||||
"sub2api-cn-relay-manager/internal/provision"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
const (
|
||||
generatedBatchImportPackID = "batch-auto-import-v2-generated"
|
||||
generatedBatchImportPackVersion = "2026.05.22"
|
||||
batchImportRetryDelay = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
type batchImportRuntimeRunner struct {
|
||||
store *sqlite.DB
|
||||
hostRow sqlite.Host
|
||||
hostClient *sub2api.Client
|
||||
request CreateBatchImportRunRequest
|
||||
}
|
||||
|
||||
func (r batchImportRuntimeRunner) execute(ctx context.Context) (BatchImportRunCreateResponse, error) {
|
||||
runID := fmt.Sprintf("run_%d", time.Now().UnixNano())
|
||||
|
||||
service := batch.BatchImportService{
|
||||
RunStore: r.store.ImportRuns(),
|
||||
ItemStore: r.store.ImportRunItems(),
|
||||
ProbeModels: probe.ProviderModels,
|
||||
ProbeCapabilities: probe.ProbeCapabilities,
|
||||
Provisioner: batchImportProvisioner{
|
||||
store: r.store,
|
||||
hostRow: r.hostRow,
|
||||
hostClient: r.hostClient,
|
||||
request: r.request,
|
||||
},
|
||||
}
|
||||
|
||||
entries := make([]batch.BatchImportEntry, 0, len(r.request.Entries))
|
||||
for _, entry := range r.request.Entries {
|
||||
entries = append(entries, batch.BatchImportEntry{
|
||||
BaseURL: entry.BaseURL,
|
||||
APIKey: entry.APIKey,
|
||||
RequestedModels: append([]string(nil), entry.RequestedModels...),
|
||||
})
|
||||
}
|
||||
|
||||
if _, err := service.StartRun(ctx, batch.BatchImportRunRequest{
|
||||
RunID: runID,
|
||||
Mode: r.request.Mode,
|
||||
AccessMode: r.request.AccessMode,
|
||||
HostID: r.hostRow.HostID,
|
||||
Entries: entries,
|
||||
}); err != nil {
|
||||
return BatchImportRunCreateResponse{}, err
|
||||
}
|
||||
|
||||
if err := r.advanceRun(ctx, runID); err != nil {
|
||||
return BatchImportRunCreateResponse{}, err
|
||||
}
|
||||
|
||||
run, err := r.store.ImportRuns().GetByRunID(ctx, runID)
|
||||
if err != nil {
|
||||
return BatchImportRunCreateResponse{}, err
|
||||
}
|
||||
return BatchImportRunCreateResponse{
|
||||
RunID: run.RunID,
|
||||
State: run.State,
|
||||
ResultPage: "/batch-import/runs/" + run.RunID,
|
||||
TotalItems: run.TotalItems,
|
||||
ActiveItems: run.ActiveItems,
|
||||
DegradedItems: run.DegradedItems,
|
||||
BrokenItems: run.BrokenItems,
|
||||
WarningItems: run.WarningItems,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r batchImportRuntimeRunner) advanceRun(ctx context.Context, runID string) error {
|
||||
timeout := time.Duration(r.request.ConfirmWaitTimeoutSec) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = time.Second
|
||||
}
|
||||
deadline := time.Now().Add(timeout)
|
||||
|
||||
worker := batch.ConfirmationWorker{
|
||||
WorkerID: "batch-import-api",
|
||||
ItemStore: batchImportRunItemStore{store: r.store, runID: runID},
|
||||
EventStore: r.store.ImportRunEvents(),
|
||||
LeaseDuration: time.Minute,
|
||||
RetryDelay: batchImportRetryDelay,
|
||||
Confirmer: r.confirmItem,
|
||||
}
|
||||
validator := batch.ValidationService{
|
||||
ItemStore: r.store.ImportRunItems(),
|
||||
RunStore: r.store.ImportRuns(),
|
||||
Validator: r.validateItem,
|
||||
}
|
||||
|
||||
for {
|
||||
now := time.Now()
|
||||
if err := worker.Tick(ctx, now); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
items, err := r.store.ImportRunItems().ListByRunID(ctx, runID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pendingWork := false
|
||||
for _, item := range items {
|
||||
switch item.CurrentStage {
|
||||
case string(batch.ItemStageValidate):
|
||||
if err := validator.ValidateItem(ctx, item); err != nil {
|
||||
return err
|
||||
}
|
||||
case string(batch.ItemStageConfirm):
|
||||
if item.ConfirmationStatus == string(batch.ConfirmationPending) {
|
||||
pendingWork = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
run, err := r.store.ImportRuns().GetByRunID(ctx, runID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if run.TotalItems > 0 && run.CompletedItems >= run.TotalItems {
|
||||
return nil
|
||||
}
|
||||
if !pendingWork || !time.Now().Before(deadline) {
|
||||
return nil
|
||||
}
|
||||
if err := sleepWithContext(ctx, batchImportRetryDelay); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r batchImportRuntimeRunner) confirmItem(ctx context.Context, item sqlite.ImportRunItem) (batch.ConfirmationResult, error) {
|
||||
accountID, err := resolveManagedResourceHostID(ctx, r.store, item, "account")
|
||||
if err != nil {
|
||||
return batch.ConfirmationResult{}, err
|
||||
}
|
||||
|
||||
probeResult, err := r.hostClient.TestAccount(ctx, accountID, item.ResolvedSmokeModel)
|
||||
if err != nil {
|
||||
var httpErr *sub2api.HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return batch.ConfirmationResult{StatusCode: httpErr.StatusCode, Message: httpErr.Body}, nil
|
||||
}
|
||||
return batch.ConfirmationResult{}, err
|
||||
}
|
||||
if probeResult.OK {
|
||||
return batch.ConfirmationResult{StatusCode: http.StatusOK, Message: probeResult.Message}, nil
|
||||
}
|
||||
|
||||
message := strings.TrimSpace(probeResult.Message)
|
||||
lowerMessage := strings.ToLower(message)
|
||||
switch {
|
||||
case strings.Contains(lowerMessage, "no available accounts"):
|
||||
return batch.ConfirmationResult{StatusCode: http.StatusServiceUnavailable, Message: message}, nil
|
||||
case strings.Contains(lowerMessage, "forbidden"):
|
||||
return batch.ConfirmationResult{StatusCode: http.StatusForbidden, Message: message}, nil
|
||||
default:
|
||||
return batch.ConfirmationResult{StatusCode: http.StatusBadRequest, Message: message}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r batchImportRuntimeRunner) validateItem(ctx context.Context, item sqlite.ImportRunItem) (sub2api.GatewayCompletionResult, error) {
|
||||
apiKey, err := r.resolveValidationAPIKey(ctx, item)
|
||||
if err != nil {
|
||||
return sub2api.GatewayCompletionResult{}, err
|
||||
}
|
||||
return r.hostClient.CheckGatewayCompletion(ctx, sub2api.GatewayCompletionCheckRequest{
|
||||
APIKey: apiKey,
|
||||
Model: item.ResolvedSmokeModel,
|
||||
Prompt: "ping",
|
||||
MaxTokens: 8,
|
||||
})
|
||||
}
|
||||
|
||||
func (r batchImportRuntimeRunner) resolveValidationAPIKey(ctx context.Context, item sqlite.ImportRunItem) (string, error) {
|
||||
switch strings.TrimSpace(r.request.AccessMode) {
|
||||
case provision.AccessModeSelfService:
|
||||
return strings.TrimSpace(r.request.ProbeAPIKey), nil
|
||||
case provision.AccessModeSubscription:
|
||||
if len(r.request.SubscriptionUsers) == 0 {
|
||||
return "", fmt.Errorf("subscription_users is required")
|
||||
}
|
||||
groupID, err := resolveManagedResourceHostID(ctx, r.store, item, "group")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
accessRef, err := r.hostClient.EnsureSubscriptionAccess(ctx, sub2api.EnsureSubscriptionAccessRequest{
|
||||
UserSelector: r.request.SubscriptionUsers[0],
|
||||
GroupID: groupID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
userID := strings.TrimSpace(accessRef.UserID)
|
||||
if userID == "" {
|
||||
userID = r.request.SubscriptionUsers[0]
|
||||
}
|
||||
if _, err := r.hostClient.AssignSubscription(ctx, sub2api.AssignSubscriptionRequest{
|
||||
UserID: userID,
|
||||
GroupID: groupID,
|
||||
DurationDays: r.request.SubscriptionDays,
|
||||
}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(accessRef.APIKey) == "" {
|
||||
return "", fmt.Errorf("subscription access api key is empty")
|
||||
}
|
||||
return strings.TrimSpace(accessRef.APIKey), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported access mode %q", r.request.AccessMode)
|
||||
}
|
||||
}
|
||||
|
||||
type batchImportProvisioner struct {
|
||||
store *sqlite.DB
|
||||
hostRow sqlite.Host
|
||||
hostClient *sub2api.Client
|
||||
request CreateBatchImportRunRequest
|
||||
}
|
||||
|
||||
func (p batchImportProvisioner) Provision(ctx context.Context, req batch.ProvisionRequest) (batch.ProvisionResult, error) {
|
||||
runtimeService := provision.NewRuntimeImportService(p.store, p.hostClient)
|
||||
providerManifest := generatedBatchImportProviderManifest(req, p.request)
|
||||
result, err := runtimeService.Import(ctx, provision.RuntimeImportRequest{
|
||||
HostID: p.hostRow.HostID,
|
||||
HostBaseURL: p.hostRow.BaseURL,
|
||||
Pack: generatedBatchImportPack(providerManifest),
|
||||
Provider: providerManifest,
|
||||
Mode: firstNonEmptyString(strings.TrimSpace(p.request.Mode), provision.ImportModeStrict),
|
||||
Keys: []string{strings.TrimSpace(req.Entry.APIKey)},
|
||||
Access: batchImportAccessRequest(p.request),
|
||||
})
|
||||
if err != nil {
|
||||
return batch.ProvisionResult{}, err
|
||||
}
|
||||
legacyBatchID := result.BatchID
|
||||
return batch.ProvisionResult{
|
||||
LegacyBatchID: &legacyBatchID,
|
||||
LegacyProviderID: req.ProviderID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p batchImportProvisioner) Patch(_ context.Context, _ batch.PatchProvisionRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type batchImportRunItemStore struct {
|
||||
store *sqlite.DB
|
||||
runID string
|
||||
}
|
||||
|
||||
func (s batchImportRunItemStore) List(ctx context.Context) ([]sqlite.ImportRunItem, error) {
|
||||
return s.store.ImportRunItems().ListByRunID(ctx, s.runID)
|
||||
}
|
||||
|
||||
func (s batchImportRunItemStore) Upsert(ctx context.Context, item sqlite.ImportRunItem) error {
|
||||
return s.store.ImportRunItems().Upsert(ctx, item)
|
||||
}
|
||||
|
||||
func generatedBatchImportPack(providerManifest pack.ProviderManifest) pack.LoadedPack {
|
||||
return pack.LoadedPack{
|
||||
Manifest: pack.Manifest{
|
||||
PackID: generatedBatchImportPackID,
|
||||
Version: generatedBatchImportPackVersion,
|
||||
Vendor: "sub2api-cn-relay-manager",
|
||||
TargetHost: "sub2api",
|
||||
},
|
||||
Providers: []pack.ProviderManifest{providerManifest},
|
||||
Checksum: generatedBatchImportPackID + "@" + generatedBatchImportPackVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func generatedBatchImportProviderManifest(req batch.ProvisionRequest, createReq CreateBatchImportRunRequest) pack.ProviderManifest {
|
||||
defaultModels := uniqueNonEmptyStrings(capabilityProfileModels(req.CapabilityProfile))
|
||||
if len(defaultModels) == 0 {
|
||||
defaultModels = uniqueNonEmptyStrings([]string{req.ResolvedModel})
|
||||
}
|
||||
smokeModel := firstNonEmptyString(strings.TrimSpace(req.ResolvedModel))
|
||||
if smokeModel == "" && len(defaultModels) > 0 {
|
||||
smokeModel = defaultModels[0]
|
||||
}
|
||||
if smokeModel == "" {
|
||||
smokeModel = "ping"
|
||||
}
|
||||
if len(defaultModels) == 0 {
|
||||
defaultModels = []string{smokeModel}
|
||||
}
|
||||
|
||||
modelMapping := make(map[string]string, len(defaultModels))
|
||||
for _, modelID := range defaultModels {
|
||||
modelMapping[modelID] = modelID
|
||||
}
|
||||
|
||||
names := fmt.Sprintf("crm-%s", strings.TrimSpace(req.ProviderID))
|
||||
validityDays := createReq.SubscriptionDays
|
||||
if validityDays <= 0 {
|
||||
validityDays = 30
|
||||
}
|
||||
|
||||
return pack.ProviderManifest{
|
||||
ProviderID: req.ProviderID,
|
||||
DisplayName: req.ProviderID,
|
||||
BaseURL: strings.TrimSpace(req.Entry.BaseURL),
|
||||
Platform: "openai",
|
||||
AccountType: "apikey",
|
||||
DefaultModels: defaultModels,
|
||||
SmokeTestModel: smokeModel,
|
||||
GroupTemplate: pack.GroupTemplate{
|
||||
Name: names + "-group",
|
||||
RateMultiplier: 1,
|
||||
},
|
||||
ChannelTemplate: pack.ChannelTemplate{
|
||||
Name: names + "-channel",
|
||||
ModelMapping: modelMapping,
|
||||
},
|
||||
PlanTemplate: pack.PlanTemplate{
|
||||
Name: names + "-plan",
|
||||
Price: 1,
|
||||
ValidityDays: validityDays,
|
||||
ValidityUnit: "day",
|
||||
},
|
||||
Import: pack.ImportOptions{
|
||||
SupportsMultiKey: true,
|
||||
SupportsStrict: true,
|
||||
SupportsPartial: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func batchImportAccessRequest(req CreateBatchImportRunRequest) provision.AccessRequest {
|
||||
subscriptions := make([]provision.SubscriptionTarget, 0, len(req.SubscriptionUsers))
|
||||
for _, userID := range req.SubscriptionUsers {
|
||||
subscriptions = append(subscriptions, provision.SubscriptionTarget{
|
||||
UserID: userID,
|
||||
DurationDays: req.SubscriptionDays,
|
||||
})
|
||||
}
|
||||
return provision.AccessRequest{
|
||||
Mode: strings.TrimSpace(req.AccessMode),
|
||||
ProbeAPIKey: strings.TrimSpace(req.ProbeAPIKey),
|
||||
Subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
|
||||
func resolveManagedResourceHostID(ctx context.Context, store *sqlite.DB, item sqlite.ImportRunItem, resourceType string) (string, error) {
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("store is required")
|
||||
}
|
||||
if item.LegacyBatchID == nil || *item.LegacyBatchID <= 0 {
|
||||
return "", fmt.Errorf("legacy_batch_id is required for %s lookup", strings.TrimSpace(resourceType))
|
||||
}
|
||||
resources, err := store.ManagedResources().GetByBatchID(ctx, *item.LegacyBatchID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, resource := range resources {
|
||||
if strings.TrimSpace(resource.ResourceType) == strings.TrimSpace(resourceType) {
|
||||
return strings.TrimSpace(resource.HostResourceID), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("%s resource not found for batch %d", resourceType, *item.LegacyBatchID)
|
||||
}
|
||||
|
||||
func capabilityProfileModels(profile *probe.CapabilityProfile) []string {
|
||||
if profile == nil {
|
||||
return nil
|
||||
}
|
||||
models := make([]string, 0, len(profile.ModelProfiles))
|
||||
for _, modelProfile := range profile.ModelProfiles {
|
||||
models = append(models, strings.TrimSpace(modelProfile.RawModelID))
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func uniqueNonEmptyStrings(values []string) []string {
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
result := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func sleepWithContext(ctx context.Context, delay time.Duration) error {
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -2,13 +2,9 @@ package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/batch"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
@@ -118,52 +114,17 @@ func buildCreateBatchImportRunAction(sqliteDSN string) func(context.Context, Cre
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
runID := fmt.Sprintf("run_%d", time.Now().UnixNano())
|
||||
run := sqlite.ImportRun{
|
||||
RunID: runID,
|
||||
Mode: strings.TrimSpace(req.Mode),
|
||||
AccessMode: strings.TrimSpace(req.AccessMode),
|
||||
State: string(batch.RunStateRunning),
|
||||
TotalItems: len(req.Entries),
|
||||
}
|
||||
if err := store.ImportRuns().Create(ctx, run); err != nil {
|
||||
hostRow, client, err := resolveManagedHost(ctx, store, req.HostID, "", CreateHostAuth{})
|
||||
if err != nil {
|
||||
return BatchImportRunCreateResponse{}, err
|
||||
}
|
||||
|
||||
for idx, entry := range req.Entries {
|
||||
item := sqlite.ImportRunItem{
|
||||
ItemID: fmt.Sprintf("%s-item-%d", runID, idx+1),
|
||||
RunID: runID,
|
||||
BaseURL: strings.TrimSpace(entry.BaseURL),
|
||||
ProviderID: batch.NormalizeProviderID(entry.BaseURL),
|
||||
APIKeyFingerprint: fingerprintBatchAPIKey(entry.APIKey),
|
||||
RequestedModelsJSON: mustMarshalAppJSON(entry.RequestedModels, "[]"),
|
||||
CurrentStage: string(batch.ItemStageProbe),
|
||||
ConfirmationStatus: string(batch.ConfirmationPending),
|
||||
AccessStatus: string(batch.AccessStatusUnknown),
|
||||
MatchedAccountState: string(batch.MatchedAccountStateNone),
|
||||
AccountResolution: string(batch.AccountResolutionCreated),
|
||||
ProvisionReused: false,
|
||||
CanonicalFamiliesJSON: "[]",
|
||||
RawModelsJSON: "[]",
|
||||
NormalizedModelsJSON: "[]",
|
||||
RecommendedModelsJSON: "[]",
|
||||
}
|
||||
if err := store.ImportRunItems().Upsert(ctx, item); err != nil {
|
||||
return BatchImportRunCreateResponse{}, err
|
||||
}
|
||||
runner := batchImportRuntimeRunner{
|
||||
store: store,
|
||||
hostRow: hostRow,
|
||||
hostClient: client,
|
||||
request: req,
|
||||
}
|
||||
|
||||
return BatchImportRunCreateResponse{
|
||||
RunID: runID,
|
||||
State: string(batch.RunStateRunning),
|
||||
ResultPage: "/batch-import/runs/" + runID,
|
||||
TotalItems: len(req.Entries),
|
||||
ActiveItems: 0,
|
||||
DegradedItems: 0,
|
||||
BrokenItems: 0,
|
||||
WarningItems: 0,
|
||||
}, nil
|
||||
return runner.execute(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,20 +203,3 @@ func defaultPositiveInt(value, fallback int) int {
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func fingerprintBatchAPIKey(apiKey string) string {
|
||||
trimmed := strings.TrimSpace(apiKey)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(trimmed))
|
||||
return fmt.Sprintf("sha256:%x", sum[:4])
|
||||
}
|
||||
|
||||
func mustMarshalAppJSON(value any, fallback string) string {
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
@@ -2,8 +2,14 @@ package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestBatchImportHTTP(t *testing.T) {
|
||||
@@ -126,4 +132,223 @@ func TestBatchImportHTTP(t *testing.T) {
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "invalid_request")
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.message", "host_id is required")
|
||||
})
|
||||
|
||||
t.Run("create run action wires real batch pipeline", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(newBatchImportActionStubServer(t))
|
||||
defer server.Close()
|
||||
|
||||
dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_pragma=foreign_keys(0)", filepath.ToSlash(filepath.Join(t.TempDir(), "state.db")))
|
||||
store, err := sqlite.Open(context.Background(), dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("sqlite.Open() error = %v", err)
|
||||
}
|
||||
defer closeAppTestStore(t, store)
|
||||
|
||||
if _, err := store.Hosts().Create(context.Background(), sqlite.Host{
|
||||
HostID: "host-1",
|
||||
BaseURL: server.URL,
|
||||
HostVersion: "0.1.126",
|
||||
CapabilityProbeJSON: "{}",
|
||||
AuthType: "apikey",
|
||||
AuthToken: "host-token",
|
||||
}); err != nil {
|
||||
t.Fatalf("Hosts().Create() error = %v", err)
|
||||
}
|
||||
|
||||
action := buildCreateBatchImportRunAction(dsn)
|
||||
result, err := action(context.Background(), CreateBatchImportRunRequest{
|
||||
HostID: "host-1",
|
||||
Mode: "strict",
|
||||
AccessMode: "self_service",
|
||||
ConfirmWaitTimeoutSec: 1,
|
||||
ProbeAPIKey: "gateway-key",
|
||||
Entries: []BatchImportEntryRequest{
|
||||
{
|
||||
BaseURL: server.URL,
|
||||
APIKey: "entry-key",
|
||||
RequestedModels: []string{"kimi-k2.6"},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildCreateBatchImportRunAction() error = %v", err)
|
||||
}
|
||||
if result.State != string("completed") {
|
||||
t.Fatalf("result.State = %q, want completed", result.State)
|
||||
}
|
||||
if result.ActiveItems != 1 {
|
||||
t.Fatalf("result.ActiveItems = %d, want 1", result.ActiveItems)
|
||||
}
|
||||
|
||||
run, err := store.ImportRuns().GetByRunID(context.Background(), result.RunID)
|
||||
if err != nil {
|
||||
t.Fatalf("ImportRuns().GetByRunID() error = %v", err)
|
||||
}
|
||||
if run.State != "completed" {
|
||||
t.Fatalf("run.State = %q, want completed", run.State)
|
||||
}
|
||||
items, err := store.ImportRunItems().ListByRunID(context.Background(), result.RunID)
|
||||
if err != nil {
|
||||
t.Fatalf("ImportRunItems().ListByRunID() error = %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
if items[0].CurrentStage != "done" || items[0].AccessStatus != "active" {
|
||||
t.Fatalf("item = %+v, want current_stage=done and access_status=active", items[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newBatchImportActionStubServer(t *testing.T) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/v1/admin/system/version", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"version": "0.1.126"}})
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{}})
|
||||
case http.MethodPost:
|
||||
writeJSON(w, http.StatusCreated, map[string]any{"data": map[string]any{"id": "group_1", "name": "batch-import-group"}})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/channels", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{}})
|
||||
case http.MethodPost:
|
||||
writeJSON(w, http.StatusCreated, map[string]any{"data": map[string]any{"id": "channel_1", "name": "batch-import-channel"}})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/payment/plans", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{}})
|
||||
case http.MethodPost:
|
||||
writeJSON(w, http.StatusCreated, map[string]any{"data": map[string]any{"id": "plan_1", "name": "batch-import-plan"}})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/accounts", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{}, "pages": 1}})
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/accounts/batch", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{{"id": "account_1", "name": "batch-import-01"}}})
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/accounts/__probe__/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "probe only"})
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/accounts/__probe__/models", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "probe only"})
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/accounts/account_1/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("event: result\n"))
|
||||
_, _ = w.Write([]byte("data: {\"status\":\"passed\",\"message\":\"smoke passed\",\"ok\":true}\n\n"))
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/accounts/account_1/models", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{{"id": "kimi-k2.6", "display_name": "Kimi K2.6", "type": "chat"}}}})
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/subscriptions/assign", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireBatchImportActionAdminToken(t, w, r) {
|
||||
return
|
||||
}
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "probe only"})
|
||||
case http.MethodPost:
|
||||
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": "subscription_1"}})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch strings.TrimSpace(r.Header.Get("Authorization")) {
|
||||
case "Bearer entry-key", "Bearer gateway-key":
|
||||
writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{{"id": "kimi-k2.6"}}})
|
||||
default:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/v1/responses", func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(r.Header.Get("Authorization")) != "Bearer entry-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte(`{"error":"responses unsupported"}`))
|
||||
})
|
||||
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch strings.TrimSpace(r.Header.Get("Authorization")) {
|
||||
case "Bearer entry-key", "Bearer gateway-key":
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"id": "chatcmpl_batch_import",
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "pong",
|
||||
},
|
||||
}},
|
||||
})
|
||||
default:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||
}
|
||||
})
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
func requireBatchImportActionAdminToken(t *testing.T, w http.ResponseWriter, r *http.Request) bool {
|
||||
t.Helper()
|
||||
if strings.TrimSpace(r.Header.Get("x-api-key")) != "host-token" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user