feat(cli): add v2 batch import command
This commit is contained in:
218
cmd/cli/batch_import.go
Normal file
218
cmd/cli/batch_import.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/app"
|
||||
"sub2api-cn-relay-manager/internal/config"
|
||||
)
|
||||
|
||||
type batchImportFunc func(context.Context, batchImportCLIRequest) (batchImportCLIResult, error)
|
||||
|
||||
type batchImportCLIEntry struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
RequestedModels []string
|
||||
}
|
||||
|
||||
type batchImportCLIRequest struct {
|
||||
HostID string
|
||||
Mode string
|
||||
AccessMode string
|
||||
ConfirmWaitTimeoutSec int
|
||||
SubscriptionUsers []string
|
||||
SubscriptionDays int
|
||||
ProbeAPIKey string
|
||||
Entries []batchImportCLIEntry
|
||||
}
|
||||
|
||||
type batchImportCLIResult struct {
|
||||
RunID string
|
||||
ResultPage string
|
||||
}
|
||||
|
||||
func parseBatchImportCLIArgs(args []string) (batchImportCLIRequest, error) {
|
||||
fs := flag.NewFlagSet("batch-import", flag.ContinueOnError)
|
||||
fs.SetOutput(io.Discard)
|
||||
|
||||
var req batchImportCLIRequest
|
||||
var entryValues cliStringList
|
||||
var batchFile string
|
||||
var subscriptionUsersCSV string
|
||||
var confirmTimeoutRaw string
|
||||
|
||||
fs.StringVar(&req.HostID, "host-id", "", "")
|
||||
fs.Var(&entryValues, "entry", "")
|
||||
fs.StringVar(&batchFile, "batch-file", "", "")
|
||||
fs.StringVar(&req.Mode, "mode", "partial", "")
|
||||
fs.StringVar(&req.AccessMode, "access-mode", "self_service", "")
|
||||
fs.StringVar(&subscriptionUsersCSV, "subscription-users", "", "")
|
||||
fs.IntVar(&req.SubscriptionDays, "subscription-days", 0, "")
|
||||
fs.StringVar(&req.ProbeAPIKey, "probe-api-key", "", "")
|
||||
fs.StringVar(&confirmTimeoutRaw, "confirm-timeout", "", "")
|
||||
fs.StringVar(&confirmTimeoutRaw, "confirm-wait-timeout", "", "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return batchImportCLIRequest{}, err
|
||||
}
|
||||
|
||||
req.SubscriptionUsers = splitCSV(subscriptionUsersCSV)
|
||||
|
||||
entries := make([]batchImportCLIEntry, 0, len(entryValues))
|
||||
for _, raw := range entryValues {
|
||||
entry, err := parseBatchImportEntry(raw)
|
||||
if err != nil {
|
||||
return batchImportCLIRequest{}, err
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
if strings.TrimSpace(batchFile) != "" {
|
||||
fileEntries, err := parseBatchImportFile(batchFile)
|
||||
if err != nil {
|
||||
return batchImportCLIRequest{}, err
|
||||
}
|
||||
entries = append(entries, fileEntries...)
|
||||
}
|
||||
req.Entries = entries
|
||||
|
||||
if strings.TrimSpace(confirmTimeoutRaw) != "" {
|
||||
duration, err := time.ParseDuration(strings.TrimSpace(confirmTimeoutRaw))
|
||||
if err != nil {
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--confirm-timeout is invalid: %w", err)
|
||||
}
|
||||
if duration <= 0 {
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--confirm-timeout must be positive")
|
||||
}
|
||||
req.ConfirmWaitTimeoutSec = int(duration / time.Second)
|
||||
if req.ConfirmWaitTimeoutSec == 0 {
|
||||
req.ConfirmWaitTimeoutSec = 1
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(req.HostID) == "":
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--host-id is required")
|
||||
case len(req.Entries) == 0:
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--entry or --batch-file is required")
|
||||
case strings.TrimSpace(req.Mode) == "":
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--mode is required")
|
||||
case strings.TrimSpace(req.AccessMode) == "":
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--access-mode is required")
|
||||
}
|
||||
|
||||
switch strings.TrimSpace(req.AccessMode) {
|
||||
case "subscription":
|
||||
switch {
|
||||
case len(req.SubscriptionUsers) == 0:
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--subscription-users is required when --access-mode=subscription")
|
||||
case req.SubscriptionDays <= 0:
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--subscription-days is required when --access-mode=subscription")
|
||||
}
|
||||
case "self_service":
|
||||
if strings.TrimSpace(req.ProbeAPIKey) == "" {
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--probe-api-key is required when --access-mode=self_service")
|
||||
}
|
||||
default:
|
||||
return batchImportCLIRequest{}, fmt.Errorf("--access-mode must be subscription or self_service")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func runBatchImport(ctx context.Context, req batchImportCLIRequest) (batchImportCLIResult, error) {
|
||||
startupConfig, err := config.LoadStartupFromEnv()
|
||||
if err != nil {
|
||||
return batchImportCLIResult{}, err
|
||||
}
|
||||
|
||||
entries := make([]app.BatchImportEntryRequest, 0, len(req.Entries))
|
||||
for _, entry := range req.Entries {
|
||||
entries = append(entries, app.BatchImportEntryRequest{
|
||||
BaseURL: entry.BaseURL,
|
||||
APIKey: entry.APIKey,
|
||||
RequestedModels: append([]string(nil), entry.RequestedModels...),
|
||||
})
|
||||
}
|
||||
|
||||
actionSet := app.NewActionSet(startupConfig.Database.SQLiteDSN)
|
||||
result, err := actionSet.CreateBatchImportRun(ctx, app.CreateBatchImportRunRequest{
|
||||
HostID: req.HostID,
|
||||
Mode: req.Mode,
|
||||
AccessMode: req.AccessMode,
|
||||
ConfirmWaitTimeoutSec: req.ConfirmWaitTimeoutSec,
|
||||
SubscriptionUsers: append([]string(nil), req.SubscriptionUsers...),
|
||||
SubscriptionDays: req.SubscriptionDays,
|
||||
ProbeAPIKey: req.ProbeAPIKey,
|
||||
Entries: entries,
|
||||
})
|
||||
if err != nil {
|
||||
return batchImportCLIResult{}, err
|
||||
}
|
||||
|
||||
return batchImportCLIResult{
|
||||
RunID: result.RunID,
|
||||
ResultPage: result.ResultPage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseBatchImportEntry(raw string) (batchImportCLIEntry, error) {
|
||||
parts := strings.SplitN(strings.TrimSpace(raw), ",", 3)
|
||||
if len(parts) < 2 {
|
||||
return batchImportCLIEntry{}, fmt.Errorf("--entry must be base_url,api_key[,requested_model|requested_model]")
|
||||
}
|
||||
|
||||
entry := batchImportCLIEntry{
|
||||
BaseURL: strings.TrimSpace(parts[0]),
|
||||
APIKey: strings.TrimSpace(parts[1]),
|
||||
}
|
||||
if entry.BaseURL == "" || entry.APIKey == "" {
|
||||
return batchImportCLIEntry{}, fmt.Errorf("--entry must include non-empty base_url and api_key")
|
||||
}
|
||||
if len(parts) == 3 {
|
||||
entry.RequestedModels = splitPipeCSV(parts[2])
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func parseBatchImportFile(path string) ([]batchImportCLIEntry, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read batch file %q: %w", path, err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(content), "\n")
|
||||
entries := make([]batchImportCLIEntry, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
entry, err := parseBatchImportEntry(trimmed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse batch file %q: %w", path, err)
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
type cliStringList []string
|
||||
|
||||
func (l *cliStringList) String() string {
|
||||
return strings.Join(*l, ",")
|
||||
}
|
||||
|
||||
func (l *cliStringList) Set(value string) error {
|
||||
*l = append(*l, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func splitPipeCSV(value string) []string {
|
||||
value = strings.ReplaceAll(value, "|", ",")
|
||||
return splitCSV(value)
|
||||
}
|
||||
126
cmd/cli/batch_import_test.go
Normal file
126
cmd/cli/batch_import_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBatchImportCLI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("parses subscription request and confirm timeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := parseBatchImportCLIArgs([]string{
|
||||
"--host-id", "host-1",
|
||||
"--entry", "https://kimi.example.com/v1,sk-kimi,kimi-k2.6|kimi-2.6",
|
||||
"--mode", "strict",
|
||||
"--access-mode", "subscription",
|
||||
"--subscription-users", "u1,u2",
|
||||
"--subscription-days", "30",
|
||||
"--confirm-timeout", "15s",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("parseBatchImportCLIArgs() error = %v", err)
|
||||
}
|
||||
if req.HostID != "host-1" {
|
||||
t.Fatalf("HostID = %q, want host-1", req.HostID)
|
||||
}
|
||||
if req.ConfirmWaitTimeoutSec != 15 {
|
||||
t.Fatalf("ConfirmWaitTimeoutSec = %d, want 15", req.ConfirmWaitTimeoutSec)
|
||||
}
|
||||
if len(req.SubscriptionUsers) != 2 || req.SubscriptionUsers[0] != "u1" || req.SubscriptionUsers[1] != "u2" {
|
||||
t.Fatalf("SubscriptionUsers = %#v, want [u1 u2]", req.SubscriptionUsers)
|
||||
}
|
||||
if len(req.Entries) != 1 {
|
||||
t.Fatalf("Entries length = %d, want 1", len(req.Entries))
|
||||
}
|
||||
if req.Entries[0].BaseURL != "https://kimi.example.com/v1" {
|
||||
t.Fatalf("Entries[0].BaseURL = %q, want kimi url", req.Entries[0].BaseURL)
|
||||
}
|
||||
if len(req.Entries[0].RequestedModels) != 2 {
|
||||
t.Fatalf("Entries[0].RequestedModels = %#v, want 2 models", req.Entries[0].RequestedModels)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("subscription requires subscription fields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := parseBatchImportCLIArgs([]string{
|
||||
"--host-id", "host-1",
|
||||
"--entry", "https://kimi.example.com/v1,sk-kimi",
|
||||
"--mode", "strict",
|
||||
"--access-mode", "subscription",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("parseBatchImportCLIArgs() error = nil, want validation error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--subscription-users is required") {
|
||||
t.Fatalf("parseBatchImportCLIArgs() error = %v, want subscription-users validation", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("self service requires probe api key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := parseBatchImportCLIArgs([]string{
|
||||
"--host-id", "host-1",
|
||||
"--entry", "https://deepseek.example.com/v1,sk-deepseek",
|
||||
"--mode", "partial",
|
||||
"--access-mode", "self_service",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("parseBatchImportCLIArgs() error = nil, want validation error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--probe-api-key is required") {
|
||||
t.Fatalf("parseBatchImportCLIArgs() error = %v, want probe-api-key validation", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("execute writes run id and result page", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var output bytes.Buffer
|
||||
batchImportCalled := false
|
||||
|
||||
err := execute(context.Background(), &output, []string{
|
||||
"batch-import",
|
||||
"--host-id", "host-1",
|
||||
"--entry", "https://kimi.example.com/v1,sk-kimi",
|
||||
"--mode", "strict",
|
||||
"--access-mode", "self_service",
|
||||
"--probe-api-key", "gateway-key",
|
||||
"--confirm-timeout", "15s",
|
||||
}, nil, nil, nil, nil, nil, nil, func(_ context.Context, req batchImportCLIRequest) (batchImportCLIResult, error) {
|
||||
batchImportCalled = true
|
||||
if req.HostID != "host-1" {
|
||||
t.Fatalf("HostID = %q, want host-1", req.HostID)
|
||||
}
|
||||
if req.AccessMode != "self_service" {
|
||||
t.Fatalf("AccessMode = %q, want self_service", req.AccessMode)
|
||||
}
|
||||
if req.ProbeAPIKey != "gateway-key" {
|
||||
t.Fatalf("ProbeAPIKey = %q, want gateway-key", req.ProbeAPIKey)
|
||||
}
|
||||
if req.ConfirmWaitTimeoutSec != 15 {
|
||||
t.Fatalf("ConfirmWaitTimeoutSec = %d, want 15", req.ConfirmWaitTimeoutSec)
|
||||
}
|
||||
return batchImportCLIResult{
|
||||
RunID: "run_20260522_0001",
|
||||
ResultPage: "/batch-import/runs/run_20260522_0001",
|
||||
}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("execute() batch-import error = %v", err)
|
||||
}
|
||||
if !batchImportCalled {
|
||||
t.Fatal("execute() did not invoke batchImport")
|
||||
}
|
||||
got := output.String()
|
||||
if !strings.Contains(got, "run_id=run_20260522_0001") || !strings.Contains(got, "result_page=/batch-import/runs/run_20260522_0001") {
|
||||
t.Fatalf("execute() batch-import output = %q, want run summary", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -80,7 +80,7 @@ type rollbackSummary struct {
|
||||
func main() {
|
||||
if err := execute(context.Background(), log.Writer(), os.Args[1:], func(context.Context) (config.StartupConfig, error) {
|
||||
return config.LoadStartupFromEnv()
|
||||
}, runInstallPack, runImportProvider, runPreviewProvider, runRollbackProvider, runReconcileProvider); err != nil {
|
||||
}, runInstallPack, runImportProvider, runPreviewProvider, runRollbackProvider, runReconcileProvider, runBatchImport); err != nil {
|
||||
log.Fatalf("run cli: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -95,7 +95,20 @@ func execute(
|
||||
previewProvider previewProviderFunc,
|
||||
rollbackProvider rollbackProviderFunc,
|
||||
reconcileProvider reconcileProviderFunc,
|
||||
batchImport batchImportFunc,
|
||||
) error {
|
||||
if len(args) > 0 && args[0] == "batch-import" {
|
||||
req, err := parseBatchImportCLIArgs(args[1:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := batchImport(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(output, "run_id=%s\nresult_page=%s\n", result.RunID, result.ResultPage)
|
||||
return err
|
||||
}
|
||||
if len(args) > 0 && args[0] == "install-pack" {
|
||||
req, err := parseInstallPackCLIArgs(args[1:])
|
||||
if err != nil {
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestExecuteWritesConfigSummaryAfterBootstrap(t *testing.T) {
|
||||
SQLiteDSN: "file:test.db?_foreign_keys=on",
|
||||
},
|
||||
}, nil
|
||||
}, nil, nil, nil, nil, nil)
|
||||
}, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() returned error: %v", err)
|
||||
}
|
||||
@@ -60,7 +60,7 @@ func TestExecuteReturnsBootstrapError(t *testing.T) {
|
||||
|
||||
err := execute(context.Background(), &bytes.Buffer{}, nil, func(context.Context) (config.StartupConfig, error) {
|
||||
return config.StartupConfig{}, wantErr
|
||||
}, nil, nil, nil, nil, nil)
|
||||
}, nil, nil, nil, nil, nil, nil)
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("execute() error = %v, want %v", err, wantErr)
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func TestExecuteReturnsWriteError(t *testing.T) {
|
||||
Server: config.ServerConfig{ListenAddr: ":9292"},
|
||||
Database: config.DatabaseConfig{SQLiteDSN: "file:test.db"},
|
||||
}, nil
|
||||
}, nil, nil, nil, nil, nil)
|
||||
}, nil, nil, nil, nil, nil, nil)
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("execute() error = %v, want %v", err, wantErr)
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func TestExecuteInstallPackWritesSummary(t *testing.T) {
|
||||
HostVersion: "0.1.126",
|
||||
Providers: []sqlite.Provider{{ProviderID: "deepseek"}},
|
||||
}, nil
|
||||
}, nil, nil, nil, nil)
|
||||
}, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() install-pack error = %v", err)
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func TestExecuteImportProviderWritesSummary(t *testing.T) {
|
||||
AccessStatus: provision.AccessStatusSelfServiceReady,
|
||||
Accounts: []provision.AccountImportResult{{}, {}},
|
||||
}, nil
|
||||
}, nil, nil, nil)
|
||||
}, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() import error = %v", err)
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func TestExecutePreviewProviderWritesSummary(t *testing.T) {
|
||||
"plan": {Action: provision.PreviewActionConflict},
|
||||
},
|
||||
}, nil
|
||||
}, nil, nil)
|
||||
}, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() preview error = %v", err)
|
||||
}
|
||||
@@ -198,7 +198,7 @@ func TestExecuteRollbackProviderWritesSummary(t *testing.T) {
|
||||
t.Fatalf("unexpected rollback request: %+v", req)
|
||||
}
|
||||
return rollbackSummary{Accounts: 2, Plans: 1, Channels: 1, Groups: 1}, nil
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() rollback error = %v", err)
|
||||
}
|
||||
@@ -227,7 +227,7 @@ func TestExecuteReconcileProviderWritesSummary(t *testing.T) {
|
||||
t.Fatalf("unexpected reconcile request: %+v", req)
|
||||
}
|
||||
return provision.ReconcileResult{Status: "drifted", MissingCount: 1, ExtraCount: 2, ProbeFailureCount: 1, AccessStatus: provision.AccessStatusBroken}, nil
|
||||
})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() reconcile error = %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user