219 lines
6.5 KiB
Go
219 lines
6.5 KiB
Go
|
|
package main
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"flag"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"os"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"sub2api-cn-relay-manager/internal/app"
|
||
|
|
"sub2api-cn-relay-manager/internal/config"
|
||
|
|
)
|
||
|
|
|
||
|
|
type (
	// batchImportFunc executes a parsed batch-import request and reports the
	// run it created. runBatchImport in this file has this signature; the type
	// presumably exists so callers can substitute another implementation
	// (e.g. in tests) — confirm against the call sites.
	batchImportFunc func(context.Context, batchImportCLIRequest) (batchImportCLIResult, error)

	// batchImportCLIEntry describes one upstream relay to import, as supplied
	// by an --entry flag or by one line of a --batch-file.
	batchImportCLIEntry struct {
		BaseURL, APIKey string   // both required by the parser
		RequestedModels []string // nil when the entry carries no model list
	}

	// batchImportCLIRequest carries the options of the batch-import subcommand
	// after parsing; runBatchImport forwards them to the app layer.
	batchImportCLIRequest struct {
		HostID, Mode, AccessMode string
		ConfirmWaitTimeoutSec    int // whole seconds; 0 when no timeout flag was given
		SubscriptionUsers        []string
		SubscriptionDays         int
		ProbeAPIKey              string
		Entries                  []batchImportCLIEntry
	}

	// batchImportCLIResult identifies the import run that was created.
	batchImportCLIResult struct {
		RunID, ResultPage string
	}
)
|
||
|
|
|
||
|
|
func parseBatchImportCLIArgs(args []string) (batchImportCLIRequest, error) {
|
||
|
|
fs := flag.NewFlagSet("batch-import", flag.ContinueOnError)
|
||
|
|
fs.SetOutput(io.Discard)
|
||
|
|
|
||
|
|
var req batchImportCLIRequest
|
||
|
|
var entryValues cliStringList
|
||
|
|
var batchFile string
|
||
|
|
var subscriptionUsersCSV string
|
||
|
|
var confirmTimeoutRaw string
|
||
|
|
|
||
|
|
fs.StringVar(&req.HostID, "host-id", "", "")
|
||
|
|
fs.Var(&entryValues, "entry", "")
|
||
|
|
fs.StringVar(&batchFile, "batch-file", "", "")
|
||
|
|
fs.StringVar(&req.Mode, "mode", "partial", "")
|
||
|
|
fs.StringVar(&req.AccessMode, "access-mode", "self_service", "")
|
||
|
|
fs.StringVar(&subscriptionUsersCSV, "subscription-users", "", "")
|
||
|
|
fs.IntVar(&req.SubscriptionDays, "subscription-days", 0, "")
|
||
|
|
fs.StringVar(&req.ProbeAPIKey, "probe-api-key", "", "")
|
||
|
|
fs.StringVar(&confirmTimeoutRaw, "confirm-timeout", "", "")
|
||
|
|
fs.StringVar(&confirmTimeoutRaw, "confirm-wait-timeout", "", "")
|
||
|
|
if err := fs.Parse(args); err != nil {
|
||
|
|
return batchImportCLIRequest{}, err
|
||
|
|
}
|
||
|
|
|
||
|
|
req.SubscriptionUsers = splitCSV(subscriptionUsersCSV)
|
||
|
|
|
||
|
|
entries := make([]batchImportCLIEntry, 0, len(entryValues))
|
||
|
|
for _, raw := range entryValues {
|
||
|
|
entry, err := parseBatchImportEntry(raw)
|
||
|
|
if err != nil {
|
||
|
|
return batchImportCLIRequest{}, err
|
||
|
|
}
|
||
|
|
entries = append(entries, entry)
|
||
|
|
}
|
||
|
|
if strings.TrimSpace(batchFile) != "" {
|
||
|
|
fileEntries, err := parseBatchImportFile(batchFile)
|
||
|
|
if err != nil {
|
||
|
|
return batchImportCLIRequest{}, err
|
||
|
|
}
|
||
|
|
entries = append(entries, fileEntries...)
|
||
|
|
}
|
||
|
|
req.Entries = entries
|
||
|
|
|
||
|
|
if strings.TrimSpace(confirmTimeoutRaw) != "" {
|
||
|
|
duration, err := time.ParseDuration(strings.TrimSpace(confirmTimeoutRaw))
|
||
|
|
if err != nil {
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--confirm-timeout is invalid: %w", err)
|
||
|
|
}
|
||
|
|
if duration <= 0 {
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--confirm-timeout must be positive")
|
||
|
|
}
|
||
|
|
req.ConfirmWaitTimeoutSec = int(duration / time.Second)
|
||
|
|
if req.ConfirmWaitTimeoutSec == 0 {
|
||
|
|
req.ConfirmWaitTimeoutSec = 1
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
switch {
|
||
|
|
case strings.TrimSpace(req.HostID) == "":
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--host-id is required")
|
||
|
|
case len(req.Entries) == 0:
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--entry or --batch-file is required")
|
||
|
|
case strings.TrimSpace(req.Mode) == "":
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--mode is required")
|
||
|
|
case strings.TrimSpace(req.AccessMode) == "":
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--access-mode is required")
|
||
|
|
}
|
||
|
|
|
||
|
|
switch strings.TrimSpace(req.AccessMode) {
|
||
|
|
case "subscription":
|
||
|
|
switch {
|
||
|
|
case len(req.SubscriptionUsers) == 0:
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--subscription-users is required when --access-mode=subscription")
|
||
|
|
case req.SubscriptionDays <= 0:
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--subscription-days is required when --access-mode=subscription")
|
||
|
|
}
|
||
|
|
case "self_service":
|
||
|
|
if strings.TrimSpace(req.ProbeAPIKey) == "" {
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--probe-api-key is required when --access-mode=self_service")
|
||
|
|
}
|
||
|
|
default:
|
||
|
|
return batchImportCLIRequest{}, fmt.Errorf("--access-mode must be subscription or self_service")
|
||
|
|
}
|
||
|
|
|
||
|
|
return req, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func runBatchImport(ctx context.Context, req batchImportCLIRequest) (batchImportCLIResult, error) {
|
||
|
|
startupConfig, err := config.LoadStartupFromEnv()
|
||
|
|
if err != nil {
|
||
|
|
return batchImportCLIResult{}, err
|
||
|
|
}
|
||
|
|
|
||
|
|
entries := make([]app.BatchImportEntryRequest, 0, len(req.Entries))
|
||
|
|
for _, entry := range req.Entries {
|
||
|
|
entries = append(entries, app.BatchImportEntryRequest{
|
||
|
|
BaseURL: entry.BaseURL,
|
||
|
|
APIKey: entry.APIKey,
|
||
|
|
RequestedModels: append([]string(nil), entry.RequestedModels...),
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
actionSet := app.NewActionSet(startupConfig.Database.SQLiteDSN)
|
||
|
|
result, err := actionSet.CreateBatchImportRun(ctx, app.CreateBatchImportRunRequest{
|
||
|
|
HostID: req.HostID,
|
||
|
|
Mode: req.Mode,
|
||
|
|
AccessMode: req.AccessMode,
|
||
|
|
ConfirmWaitTimeoutSec: req.ConfirmWaitTimeoutSec,
|
||
|
|
SubscriptionUsers: append([]string(nil), req.SubscriptionUsers...),
|
||
|
|
SubscriptionDays: req.SubscriptionDays,
|
||
|
|
ProbeAPIKey: req.ProbeAPIKey,
|
||
|
|
Entries: entries,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return batchImportCLIResult{}, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return batchImportCLIResult{
|
||
|
|
RunID: result.RunID,
|
||
|
|
ResultPage: result.ResultPage,
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func parseBatchImportEntry(raw string) (batchImportCLIEntry, error) {
|
||
|
|
parts := strings.SplitN(strings.TrimSpace(raw), ",", 3)
|
||
|
|
if len(parts) < 2 {
|
||
|
|
return batchImportCLIEntry{}, fmt.Errorf("--entry must be base_url,api_key[,requested_model|requested_model]")
|
||
|
|
}
|
||
|
|
|
||
|
|
entry := batchImportCLIEntry{
|
||
|
|
BaseURL: strings.TrimSpace(parts[0]),
|
||
|
|
APIKey: strings.TrimSpace(parts[1]),
|
||
|
|
}
|
||
|
|
if entry.BaseURL == "" || entry.APIKey == "" {
|
||
|
|
return batchImportCLIEntry{}, fmt.Errorf("--entry must include non-empty base_url and api_key")
|
||
|
|
}
|
||
|
|
if len(parts) == 3 {
|
||
|
|
entry.RequestedModels = splitPipeCSV(parts[2])
|
||
|
|
}
|
||
|
|
return entry, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func parseBatchImportFile(path string) ([]batchImportCLIEntry, error) {
|
||
|
|
content, err := os.ReadFile(path)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("read batch file %q: %w", path, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
lines := strings.Split(string(content), "\n")
|
||
|
|
entries := make([]batchImportCLIEntry, 0, len(lines))
|
||
|
|
for _, line := range lines {
|
||
|
|
trimmed := strings.TrimSpace(line)
|
||
|
|
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
entry, err := parseBatchImportEntry(trimmed)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("parse batch file %q: %w", path, err)
|
||
|
|
}
|
||
|
|
entries = append(entries, entry)
|
||
|
|
}
|
||
|
|
return entries, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// cliStringList is a flag.Value that collects every occurrence of a repeatable
// string flag (such as --entry), in command-line order.
type cliStringList []string

// String returns the collected values joined with commas.
//
// The flag.Value contract allows String to be called on a zero-valued
// receiver, such as a nil pointer. A nil receiver therefore yields "" instead
// of panicking on the dereference.
func (l *cliStringList) String() string {
	if l == nil {
		return ""
	}
	return strings.Join(*l, ",")
}

// Set appends one occurrence of the flag's value verbatim; it never fails.
func (l *cliStringList) Set(value string) error {
	*l = append(*l, value)
	return nil
}
|
||
|
|
|
||
|
|
func splitPipeCSV(value string) []string {
|
||
|
|
value = strings.ReplaceAll(value, "|", ",")
|
||
|
|
return splitCSV(value)
|
||
|
|
}
|