diff --git a/cmd/cli/batch_import.go b/cmd/cli/batch_import.go new file mode 100644 index 00000000..d6f5dcbf --- /dev/null +++ b/cmd/cli/batch_import.go @@ -0,0 +1,218 @@ +package main + +import ( + "context" + "flag" + "fmt" + "io" + "os" + "strings" + "time" + + "sub2api-cn-relay-manager/internal/app" + "sub2api-cn-relay-manager/internal/config" +) + +type batchImportFunc func(context.Context, batchImportCLIRequest) (batchImportCLIResult, error) + +type batchImportCLIEntry struct { + BaseURL string + APIKey string + RequestedModels []string +} + +type batchImportCLIRequest struct { + HostID string + Mode string + AccessMode string + ConfirmWaitTimeoutSec int + SubscriptionUsers []string + SubscriptionDays int + ProbeAPIKey string + Entries []batchImportCLIEntry +} + +type batchImportCLIResult struct { + RunID string + ResultPage string +} + +func parseBatchImportCLIArgs(args []string) (batchImportCLIRequest, error) { + fs := flag.NewFlagSet("batch-import", flag.ContinueOnError) + fs.SetOutput(io.Discard) + + var req batchImportCLIRequest + var entryValues cliStringList + var batchFile string + var subscriptionUsersCSV string + var confirmTimeoutRaw string + + fs.StringVar(&req.HostID, "host-id", "", "") + fs.Var(&entryValues, "entry", "") + fs.StringVar(&batchFile, "batch-file", "", "") + fs.StringVar(&req.Mode, "mode", "partial", "") + fs.StringVar(&req.AccessMode, "access-mode", "self_service", "") + fs.StringVar(&subscriptionUsersCSV, "subscription-users", "", "") + fs.IntVar(&req.SubscriptionDays, "subscription-days", 0, "") + fs.StringVar(&req.ProbeAPIKey, "probe-api-key", "", "") + fs.StringVar(&confirmTimeoutRaw, "confirm-timeout", "", "") + fs.StringVar(&confirmTimeoutRaw, "confirm-wait-timeout", "", "") + if err := fs.Parse(args); err != nil { + return batchImportCLIRequest{}, err + } + + req.SubscriptionUsers = splitCSV(subscriptionUsersCSV) + + entries := make([]batchImportCLIEntry, 0, len(entryValues)) + for _, raw := range entryValues { + entry, err := parseBatchImportEntry(raw) + if err != nil { + return batchImportCLIRequest{}, err + } + entries = append(entries, entry) + } + if strings.TrimSpace(batchFile) != "" { + fileEntries, err := parseBatchImportFile(batchFile) + if err != nil { + return batchImportCLIRequest{}, err + } + entries = append(entries, fileEntries...) + } + req.Entries = entries + + if strings.TrimSpace(confirmTimeoutRaw) != "" { + duration, err := time.ParseDuration(strings.TrimSpace(confirmTimeoutRaw)) + if err != nil { + return batchImportCLIRequest{}, fmt.Errorf("--confirm-timeout is invalid: %w", err) + } + if duration <= 0 { + return batchImportCLIRequest{}, fmt.Errorf("--confirm-timeout must be positive") + } + req.ConfirmWaitTimeoutSec = int(duration / time.Second) + if req.ConfirmWaitTimeoutSec == 0 { + req.ConfirmWaitTimeoutSec = 1 + } + } + + switch { + case strings.TrimSpace(req.HostID) == "": + return batchImportCLIRequest{}, fmt.Errorf("--host-id is required") + case len(req.Entries) == 0: + return batchImportCLIRequest{}, fmt.Errorf("--entry or --batch-file is required") + case strings.TrimSpace(req.Mode) == "": + return batchImportCLIRequest{}, fmt.Errorf("--mode is required") + case strings.TrimSpace(req.AccessMode) == "": + return batchImportCLIRequest{}, fmt.Errorf("--access-mode is required") + } + + switch strings.TrimSpace(req.AccessMode) { + case "subscription": + switch { + case len(req.SubscriptionUsers) == 0: + return batchImportCLIRequest{}, fmt.Errorf("--subscription-users is required when --access-mode=subscription") + case req.SubscriptionDays <= 0: + return batchImportCLIRequest{}, fmt.Errorf("--subscription-days is required when --access-mode=subscription") + } + case "self_service": + if strings.TrimSpace(req.ProbeAPIKey) == "" { + return batchImportCLIRequest{}, fmt.Errorf("--probe-api-key is required when --access-mode=self_service") + } + default: + return batchImportCLIRequest{}, fmt.Errorf("--access-mode must be subscription or self_service") + } + + return req, nil +} + +func runBatchImport(ctx context.Context, req batchImportCLIRequest) (batchImportCLIResult, error) { + startupConfig, err := config.LoadStartupFromEnv() + if err != nil { + return batchImportCLIResult{}, err + } + + entries := make([]app.BatchImportEntryRequest, 0, len(req.Entries)) + for _, entry := range req.Entries { + entries = append(entries, app.BatchImportEntryRequest{ + BaseURL: entry.BaseURL, + APIKey: entry.APIKey, + RequestedModels: append([]string(nil), entry.RequestedModels...), + }) + } + + actionSet := app.NewActionSet(startupConfig.Database.SQLiteDSN) + result, err := actionSet.CreateBatchImportRun(ctx, app.CreateBatchImportRunRequest{ + HostID: req.HostID, + Mode: req.Mode, + AccessMode: req.AccessMode, + ConfirmWaitTimeoutSec: req.ConfirmWaitTimeoutSec, + SubscriptionUsers: append([]string(nil), req.SubscriptionUsers...), + SubscriptionDays: req.SubscriptionDays, + ProbeAPIKey: req.ProbeAPIKey, + Entries: entries, + }) + if err != nil { + return batchImportCLIResult{}, err + } + + return batchImportCLIResult{ + RunID: result.RunID, + ResultPage: result.ResultPage, + }, nil +} + +func parseBatchImportEntry(raw string) (batchImportCLIEntry, error) { + parts := strings.SplitN(strings.TrimSpace(raw), ",", 3) + if len(parts) < 2 { + return batchImportCLIEntry{}, fmt.Errorf("--entry must be base_url,api_key[,requested_model|requested_model]") + } + + entry := batchImportCLIEntry{ + BaseURL: strings.TrimSpace(parts[0]), + APIKey: strings.TrimSpace(parts[1]), + } + if entry.BaseURL == "" || entry.APIKey == "" { + return batchImportCLIEntry{}, fmt.Errorf("--entry must include non-empty base_url and api_key") + } + if len(parts) == 3 { + entry.RequestedModels = splitPipeCSV(parts[2]) + } + return entry, nil +} + +func parseBatchImportFile(path string) ([]batchImportCLIEntry, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read batch file %q: %w", path, err) + } + + lines := strings.Split(string(content), "\n") + entries := make([]batchImportCLIEntry, 0, len(lines)) + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + entry, err := parseBatchImportEntry(trimmed) + if err != nil { + return nil, fmt.Errorf("parse batch file %q: %w", path, err) + } + entries = append(entries, entry) + } + return entries, nil +} + +type cliStringList []string + +func (l *cliStringList) String() string { + return strings.Join(*l, ",") +} + +func (l *cliStringList) Set(value string) error { + *l = append(*l, value) + return nil +} + +func splitPipeCSV(value string) []string { + value = strings.ReplaceAll(value, "|", ",") + return splitCSV(value) +} diff --git a/cmd/cli/batch_import_test.go b/cmd/cli/batch_import_test.go new file mode 100644 index 00000000..9f791326 --- /dev/null +++ b/cmd/cli/batch_import_test.go @@ -0,0 +1,126 @@ +package main + +import ( + "bytes" + "context" + "strings" + "testing" +) + +func TestBatchImportCLI(t *testing.T) { + t.Parallel() + + t.Run("parses subscription request and confirm timeout", func(t *testing.T) { + t.Parallel() + + req, err := parseBatchImportCLIArgs([]string{ + "--host-id", "host-1", + "--entry", "https://kimi.example.com/v1,sk-kimi,kimi-k2.6|kimi-2.6", + "--mode", "strict", + "--access-mode", "subscription", + "--subscription-users", "u1,u2", + "--subscription-days", "30", + "--confirm-timeout", "15s", + }) + if err != nil { + t.Fatalf("parseBatchImportCLIArgs() error = %v", err) + } + if req.HostID != "host-1" { + t.Fatalf("HostID = %q, want host-1", req.HostID) + } + if req.ConfirmWaitTimeoutSec != 15 { + t.Fatalf("ConfirmWaitTimeoutSec = %d, want 15", req.ConfirmWaitTimeoutSec) + } + if len(req.SubscriptionUsers) != 2 || req.SubscriptionUsers[0] != "u1" || req.SubscriptionUsers[1] != "u2" { + t.Fatalf("SubscriptionUsers = %#v, want [u1 u2]", req.SubscriptionUsers) + } + if len(req.Entries) != 1 { + t.Fatalf("Entries length = %d, want 1", len(req.Entries)) + } + if req.Entries[0].BaseURL != "https://kimi.example.com/v1" { + t.Fatalf("Entries[0].BaseURL = %q, want kimi url", req.Entries[0].BaseURL) + } + if len(req.Entries[0].RequestedModels) != 2 { + t.Fatalf("Entries[0].RequestedModels = %#v, want 2 models", req.Entries[0].RequestedModels) + } + }) + + t.Run("subscription requires subscription fields", func(t *testing.T) { + t.Parallel() + + _, err := parseBatchImportCLIArgs([]string{ + "--host-id", "host-1", + "--entry", "https://kimi.example.com/v1,sk-kimi", + "--mode", "strict", + "--access-mode", "subscription", + }) + if err == nil { + t.Fatal("parseBatchImportCLIArgs() error = nil, want validation error") + } + if !strings.Contains(err.Error(), "--subscription-users is required") { + t.Fatalf("parseBatchImportCLIArgs() error = %v, want subscription-users validation", err) + } + }) + + t.Run("self service requires probe api key", func(t *testing.T) { + t.Parallel() + + _, err := parseBatchImportCLIArgs([]string{ + "--host-id", "host-1", + "--entry", "https://deepseek.example.com/v1,sk-deepseek", + "--mode", "partial", + "--access-mode", "self_service", + }) + if err == nil { + t.Fatal("parseBatchImportCLIArgs() error = nil, want validation error") + } + if !strings.Contains(err.Error(), "--probe-api-key is required") { + t.Fatalf("parseBatchImportCLIArgs() error = %v, want probe-api-key validation", err) + } + }) + + t.Run("execute writes run id and result page", func(t *testing.T) { + t.Parallel() + + var output bytes.Buffer + batchImportCalled := false + + err := execute(context.Background(), &output, []string{ + "batch-import", + "--host-id", "host-1", + "--entry", "https://kimi.example.com/v1,sk-kimi", + "--mode", "strict", + "--access-mode", "self_service", + "--probe-api-key", "gateway-key", + "--confirm-timeout", "15s", + }, nil, nil, nil, nil, nil, nil, func(_ context.Context, req batchImportCLIRequest) (batchImportCLIResult, error) { + batchImportCalled = true + if req.HostID != "host-1" { + t.Fatalf("HostID = %q, want host-1", req.HostID) + } + if req.AccessMode != "self_service" { + t.Fatalf("AccessMode = %q, want self_service", req.AccessMode) + } + if req.ProbeAPIKey != "gateway-key" { + t.Fatalf("ProbeAPIKey = %q, want gateway-key", req.ProbeAPIKey) + } + if req.ConfirmWaitTimeoutSec != 15 { + t.Fatalf("ConfirmWaitTimeoutSec = %d, want 15", req.ConfirmWaitTimeoutSec) + } + return batchImportCLIResult{ + RunID: "run_20260522_0001", + ResultPage: "/batch-import/runs/run_20260522_0001", + }, nil + }) + if err != nil { + t.Fatalf("execute() batch-import error = %v", err) + } + if !batchImportCalled { + t.Fatal("execute() did not invoke batchImport") + } + got := output.String() + if !strings.Contains(got, "run_id=run_20260522_0001") || !strings.Contains(got, "result_page=/batch-import/runs/run_20260522_0001") { + t.Fatalf("execute() batch-import output = %q, want run summary", got) + } + }) +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 702d2a0a..73863fb5 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -80,7 +80,7 @@ type rollbackSummary struct { func main() { if err := execute(context.Background(), log.Writer(), os.Args[1:], func(context.Context) (config.StartupConfig, error) { return config.LoadStartupFromEnv() - }, runInstallPack, runImportProvider, runPreviewProvider, runRollbackProvider, runReconcileProvider); err != nil { + }, runInstallPack, runImportProvider, runPreviewProvider, runRollbackProvider, runReconcileProvider, runBatchImport); err != nil { log.Fatalf("run cli: %v", err) } } @@ -95,7 +95,20 @@ func execute( previewProvider previewProviderFunc, rollbackProvider rollbackProviderFunc, reconcileProvider reconcileProviderFunc, + batchImport batchImportFunc, ) error { + if len(args) > 0 && args[0] == "batch-import" { + req, err := parseBatchImportCLIArgs(args[1:]) + if err != nil { + return err + } + result, err := batchImport(ctx, req) + if err != nil { + return err + } + _, err = fmt.Fprintf(output, "run_id=%s\nresult_page=%s\n", result.RunID, result.ResultPage) + return err + } if len(args) > 0 && args[0] == "install-pack" { req, err := parseInstallPackCLIArgs(args[1:]) if err != nil { diff --git a/cmd/cli/main_test.go b/cmd/cli/main_test.go index 1ba2262e..3a9ece85 100644 --- a/cmd/cli/main_test.go +++ b/cmd/cli/main_test.go @@ -32,7 +32,7 @@ func TestExecuteWritesConfigSummaryAfterBootstrap(t *testing.T) { SQLiteDSN: "file:test.db?_foreign_keys=on", }, }, nil - }, nil, nil, nil, nil, nil) + }, nil, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("execute() returned error: %v", err) } @@ -60,7 +60,7 @@ func TestExecuteReturnsBootstrapError(t *testing.T) { err := execute(context.Background(), &bytes.Buffer{}, nil, func(context.Context) (config.StartupConfig, error) { return config.StartupConfig{}, wantErr - }, nil, nil, nil, nil, nil) + }, nil, nil, nil, nil, nil, nil) if !errors.Is(err, wantErr) { t.Fatalf("execute() error = %v, want %v", err, wantErr) } @@ -74,7 +74,7 @@ func TestExecuteReturnsWriteError(t *testing.T) { Server: config.ServerConfig{ListenAddr: ":9292"}, Database: config.DatabaseConfig{SQLiteDSN: "file:test.db"}, }, nil - }, nil, nil, nil, nil, nil) + }, nil, nil, nil, nil, nil, nil) if !errors.Is(err, wantErr) { t.Fatalf("execute() error = %v, want %v", err, wantErr) } @@ -98,7 +98,7 @@ func TestExecuteInstallPackWritesSummary(t *testing.T) { HostVersion: "0.1.126", Providers: []sqlite.Provider{{ProviderID: "deepseek"}}, }, nil - }, nil, nil, nil, nil) + }, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("execute() install-pack error = %v", err) } @@ -133,7 +133,7 @@ func TestExecuteImportProviderWritesSummary(t *testing.T) { AccessStatus: provision.AccessStatusSelfServiceReady, Accounts: []provision.AccountImportResult{{}, {}}, }, nil - }, nil, nil, nil) + }, nil, nil, nil, nil) if err != nil { t.Fatalf("execute() import error = %v", err) } @@ -170,7 +170,7 @@ func TestExecutePreviewProviderWritesSummary(t *testing.T) { "plan": {Action: provision.PreviewActionConflict}, }, }, nil - }, nil, nil) + }, nil, nil, nil) if err != nil { t.Fatalf("execute() preview error = %v", err) } @@ -198,7 +198,7 @@ func TestExecuteRollbackProviderWritesSummary(t *testing.T) { t.Fatalf("unexpected rollback request: %+v", req) } return rollbackSummary{Accounts: 2, Plans: 1, Channels: 1, Groups: 1}, nil - }, nil) + }, nil, nil) if err != nil { t.Fatalf("execute() rollback error = %v", err) } @@ -227,7 +227,7 @@ func TestExecuteReconcileProviderWritesSummary(t *testing.T) { t.Fatalf("unexpected reconcile request: %+v", req) } return provision.ReconcileResult{Status: "drifted", MissingCount: 1, ExtraCount: 2, ProbeFailureCount: 1, AccessStatus: provision.AccessStatusBroken}, nil - }) + }, nil) if err != nil { t.Fatalf("execute() reconcile error = %v", err) }