262 lines
8.6 KiB
Go
262 lines
8.6 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"sub2api-cn-relay-manager/internal/batch"
|
|
"sub2api-cn-relay-manager/internal/store/sqlite"
|
|
)
|
|
|
|
// BatchImportEntryRequest is one relay entry inside a batch-import request.
//
// buildCreateBatchImportRunAction turns each entry into one import-run item:
// BaseURL is stored whitespace-trimmed, APIKey is stored only as the
// fingerprint produced by fingerprintBatchAPIKey, and RequestedModels is
// stored JSON-encoded.
type BatchImportEntryRequest struct {
	// BaseURL is the relay's base URL.
	BaseURL string `json:"base_url"`

	// APIKey is the credential for the relay.
	APIKey string `json:"api_key"`

	// RequestedModels lists the model names asked for on this entry.
	RequestedModels []string `json:"requested_models"`
}
|
|
|
|
type CreateBatchImportRunRequest struct {
|
|
HostID string `json:"host_id,omitempty"`
|
|
Mode string `json:"mode"`
|
|
AccessMode string `json:"access_mode"`
|
|
ConfirmWaitTimeoutSec int `json:"confirm_wait_timeout_sec,omitempty"`
|
|
SubscriptionUsers []string `json:"subscription_users,omitempty"`
|
|
SubscriptionDays int `json:"subscription_days,omitempty"`
|
|
ProbeAPIKey string `json:"probe_api_key,omitempty"`
|
|
Entries []BatchImportEntryRequest `json:"entries"`
|
|
}
|
|
|
|
// BatchImportRunCreateResponse is the JSON body returned after a batch-import
// run has been created.
//
// buildCreateBatchImportRunAction fills RunID, sets State to the running run
// state, sets ResultPage to "/batch-import/runs/<run_id>" and TotalItems to
// the number of submitted entries; the remaining counters are zero at creation.
type BatchImportRunCreateResponse struct {
	// Identity and location of the run.
	RunID      string `json:"run_id"`
	State      string `json:"state"`
	ResultPage string `json:"result_page"`

	// Per-item tallies for the run.
	TotalItems    int `json:"total_items"`
	ActiveItems   int `json:"active_items"`
	DegradedItems int `json:"degraded_items"`
	BrokenItems   int `json:"broken_items"`
	WarningItems  int `json:"warning_items"`
}
|
|
|
|
// ListBatchImportRunsRequest carries the filters for listing import runs.
//
// handleListBatchImportRuns fills it from the query parameters "state",
// "access_mode", "q" (all whitespace-trimmed) and "limit". Empty strings
// disable the corresponding filter in buildListBatchImportRunsAction, and a
// zero Limit makes that action fall back to 50.
type ListBatchImportRunsRequest struct {
	State      string
	AccessMode string
	Query      string
	Limit      int
}
|
|
|
|
// ListBatchImportRunItemsRequest carries the filters for listing the items of
// one import run.
//
// Most filter fields share names with the item columns written by
// buildCreateBatchImportRunAction (CurrentStage, ConfirmationStatus,
// AccessStatus, ProviderID, MatchedAccountState, AccountResolution).
// NOTE(review): presumably an empty string / nil HasWarning means "do not
// filter", as with ListBatchImportRunsRequest — confirm against the consumer.
type ListBatchImportRunItemsRequest struct {
	// RunID selects the run whose items are listed.
	RunID string

	// Per-column filters.
	CurrentStage        string
	ConfirmationStatus  string
	AccessStatus        string
	HasWarning          *bool
	ProviderID          string
	MatchedAccountState string
	AccountResolution   string

	// Free-text search and result cap.
	Query string
	Limit int
}
|
|
|
|
// GetBatchImportRunItemRequest identifies a single item within an import run.
type GetBatchImportRunItemRequest struct {
	RunID  string
	ItemID string
}
|
|
|
|
func handleCreateBatchImportRun(w http.ResponseWriter, r *http.Request, fn func(context.Context, CreateBatchImportRunRequest) (BatchImportRunCreateResponse, error)) {
|
|
if fn == nil {
|
|
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "create-batch-import-run action is not configured"})
|
|
return
|
|
}
|
|
var req CreateBatchImportRunRequest
|
|
if err := decodeJSON(r, &req); err != nil {
|
|
writeHTTPError(w, err)
|
|
return
|
|
}
|
|
if err := validateCreateBatchImportRunRequest(req); err != nil {
|
|
writeHTTPError(w, err)
|
|
return
|
|
}
|
|
result, err := fn(r.Context(), req)
|
|
if err != nil {
|
|
writeHTTPError(w, classifyError(err))
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, result)
|
|
}
|
|
|
|
func handleListBatchImportRuns(w http.ResponseWriter, r *http.Request, fn func(context.Context, ListBatchImportRunsRequest) ([]batch.RunSummaryProjection, error)) {
|
|
if fn == nil {
|
|
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-batch-import-runs action is not configured"})
|
|
return
|
|
}
|
|
result, err := fn(r.Context(), ListBatchImportRunsRequest{
|
|
State: strings.TrimSpace(r.URL.Query().Get("state")),
|
|
AccessMode: strings.TrimSpace(r.URL.Query().Get("access_mode")),
|
|
Query: strings.TrimSpace(r.URL.Query().Get("q")),
|
|
Limit: parsePositiveInt(r.URL.Query().Get("limit")),
|
|
})
|
|
if err != nil {
|
|
writeHTTPError(w, classifyError(err))
|
|
return
|
|
}
|
|
if result == nil {
|
|
result = []batch.RunSummaryProjection{}
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"runs": result})
|
|
}
|
|
|
|
func buildCreateBatchImportRunAction(sqliteDSN string) func(context.Context, CreateBatchImportRunRequest) (BatchImportRunCreateResponse, error) {
|
|
return func(ctx context.Context, req CreateBatchImportRunRequest) (BatchImportRunCreateResponse, error) {
|
|
store, err := sqlite.Open(ctx, sqliteDSN)
|
|
if err != nil {
|
|
return BatchImportRunCreateResponse{}, err
|
|
}
|
|
defer store.Close()
|
|
|
|
runID := fmt.Sprintf("run_%d", time.Now().UnixNano())
|
|
run := sqlite.ImportRun{
|
|
RunID: runID,
|
|
Mode: strings.TrimSpace(req.Mode),
|
|
AccessMode: strings.TrimSpace(req.AccessMode),
|
|
State: string(batch.RunStateRunning),
|
|
TotalItems: len(req.Entries),
|
|
}
|
|
if err := store.ImportRuns().Create(ctx, run); err != nil {
|
|
return BatchImportRunCreateResponse{}, err
|
|
}
|
|
|
|
for idx, entry := range req.Entries {
|
|
item := sqlite.ImportRunItem{
|
|
ItemID: fmt.Sprintf("%s-item-%d", runID, idx+1),
|
|
RunID: runID,
|
|
BaseURL: strings.TrimSpace(entry.BaseURL),
|
|
ProviderID: batch.NormalizeProviderID(entry.BaseURL),
|
|
APIKeyFingerprint: fingerprintBatchAPIKey(entry.APIKey),
|
|
RequestedModelsJSON: mustMarshalAppJSON(entry.RequestedModels, "[]"),
|
|
CurrentStage: string(batch.ItemStageProbe),
|
|
ConfirmationStatus: string(batch.ConfirmationPending),
|
|
AccessStatus: string(batch.AccessStatusUnknown),
|
|
MatchedAccountState: string(batch.MatchedAccountStateNone),
|
|
AccountResolution: string(batch.AccountResolutionCreated),
|
|
ProvisionReused: false,
|
|
CanonicalFamiliesJSON: "[]",
|
|
RawModelsJSON: "[]",
|
|
NormalizedModelsJSON: "[]",
|
|
RecommendedModelsJSON: "[]",
|
|
}
|
|
if err := store.ImportRunItems().Upsert(ctx, item); err != nil {
|
|
return BatchImportRunCreateResponse{}, err
|
|
}
|
|
}
|
|
|
|
return BatchImportRunCreateResponse{
|
|
RunID: runID,
|
|
State: string(batch.RunStateRunning),
|
|
ResultPage: "/batch-import/runs/" + runID,
|
|
TotalItems: len(req.Entries),
|
|
ActiveItems: 0,
|
|
DegradedItems: 0,
|
|
BrokenItems: 0,
|
|
WarningItems: 0,
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
func buildListBatchImportRunsAction(sqliteDSN string) func(context.Context, ListBatchImportRunsRequest) ([]batch.RunSummaryProjection, error) {
|
|
return func(ctx context.Context, req ListBatchImportRunsRequest) ([]batch.RunSummaryProjection, error) {
|
|
store, err := sqlite.Open(ctx, sqliteDSN)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer store.Close()
|
|
|
|
runs, err := store.ImportRuns().List(ctx, defaultPositiveInt(req.Limit, 50))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make([]batch.RunSummaryProjection, 0, len(runs))
|
|
for _, run := range runs {
|
|
if req.State != "" && run.State != req.State {
|
|
continue
|
|
}
|
|
if req.AccessMode != "" && run.AccessMode != req.AccessMode {
|
|
continue
|
|
}
|
|
if req.Query != "" && !strings.Contains(strings.ToLower(run.RunID), strings.ToLower(req.Query)) {
|
|
continue
|
|
}
|
|
result = append(result, batch.ProjectRunSummary(run))
|
|
}
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
func validateCreateBatchImportRunRequest(req CreateBatchImportRunRequest) *httpError {
|
|
if strings.TrimSpace(req.HostID) == "" {
|
|
return &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "host_id is required"}
|
|
}
|
|
if strings.TrimSpace(req.Mode) == "" {
|
|
return &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "mode is required"}
|
|
}
|
|
if strings.TrimSpace(req.AccessMode) == "" {
|
|
return &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "access_mode is required"}
|
|
}
|
|
if len(req.Entries) == 0 {
|
|
return &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "entries is required"}
|
|
}
|
|
switch strings.TrimSpace(req.AccessMode) {
|
|
case "subscription":
|
|
if len(req.SubscriptionUsers) == 0 {
|
|
return &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "subscription_users is required when access_mode=subscription"}
|
|
}
|
|
if req.SubscriptionDays <= 0 {
|
|
return &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "subscription_days is required when access_mode=subscription"}
|
|
}
|
|
case "self_service":
|
|
if strings.TrimSpace(req.ProbeAPIKey) == "" {
|
|
return &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "probe_api_key is required when access_mode=self_service"}
|
|
}
|
|
default:
|
|
return &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "access_mode must be subscription or self_service"}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// parsePositiveInt parses raw, after trimming surrounding whitespace, as a
// base-10 int and returns it when it is strictly positive. Every other input
// returns 0, so callers can apply a default via defaultPositiveInt. That
// includes empty or non-numeric text, zero, negatives, and values that
// overflow int.
func parsePositiveInt(raw string) int {
	if n, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil && n > 0 {
		return n
	}
	return 0
}
|
|
|
|
// defaultPositiveInt returns value when it is strictly positive and fallback
// otherwise (zero or negative).
func defaultPositiveInt(value, fallback int) int {
	if value <= 0 {
		return fallback
	}
	return value
}
|
|
|
|
// fingerprintBatchAPIKey derives a short identifier for an API key;
// buildCreateBatchImportRunAction persists this value instead of the raw key.
//
// Surrounding whitespace is ignored, and a blank key yields "". Otherwise the
// result is "sha256:" followed by the lowercase hex of the first 4 bytes
// (8 hex digits) of the SHA-256 digest of the trimmed key.
//
// NOTE(review): 32 bits of digest makes collisions between distinct keys
// plausible once many keys are stored. Widening it would change every stored
// fingerprint, so that needs a migration.
func fingerprintBatchAPIKey(apiKey string) string {
	key := strings.TrimSpace(apiKey)
	if len(key) == 0 {
		return ""
	}
	digest := sha256.Sum256([]byte(key))
	return "sha256:" + fmt.Sprintf("%x", digest[:4])
}
|
|
|
|
// mustMarshalAppJSON returns the JSON encoding of value as a string, or
// fallback when json.Marshal reports an error. Despite the "must" prefix it
// never panics. A nil slice or map encodes as "null", not as fallback.
func mustMarshalAppJSON(value any, fallback string) string {
	if encoded, err := json.Marshal(value); err == nil {
		return string(encoded)
	}
	return fallback
}
|