feat(admin): persist provider drafts in crm

This commit is contained in:
phamnazage-jpg
2026-05-27 21:49:12 +08:00
parent ebd86a4256
commit 8d7aa925df
18 changed files with 2687 additions and 3 deletions

View File

@@ -161,6 +161,144 @@ func TestAPIPreviewProviderReturnsSummary(t *testing.T) {
assertJSONContains(t, response.Body().Bytes(), "accepted_keys_count", float64(2))
}
func TestAPICreateProviderDraftReturnsCreated(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
CreateProviderDraft: func(_ context.Context, req CreateProviderDraftRequest) (ProviderDraftInfo, error) {
if req.ProviderID != "openai-zhongzhuan" {
t.Fatalf("ProviderID = %q, want openai-zhongzhuan", req.ProviderID)
}
return ProviderDraftInfo{
DraftID: "draft_001",
PackID: req.PackID,
ProviderID: req.ProviderID,
DisplayName: req.DisplayName,
Platform: req.Platform,
SmokeTestModel: req.SmokeTestModel,
SupportedModels: []string{
"gpt-5.4",
},
Manifest: map[string]any{"provider_id": req.ProviderID},
}, nil
},
})
request := httptestRequest(t, http.MethodPost, "/api/provider-drafts", map[string]any{
"pack_id": "openai-cn-pack",
"provider_id": "openai-zhongzhuan",
"display_name": "OpenAI 中转",
"platform": "openai",
"smoke_test_model": "gpt-5.4",
"supported_models": []string{"gpt-5.4"},
}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusCreated)
assertJSONContains(t, response.Body().Bytes(), "draft.draft_id", "draft_001")
assertJSONContains(t, response.Body().Bytes(), "draft.provider_id", "openai-zhongzhuan")
}
func TestAPIListProviderDraftsReturnsCollection(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
ListProviderDrafts: func(_ context.Context, req ListProviderDraftsRequest) ([]ProviderDraftInfo, error) {
if req.PackID != "openai-cn-pack" {
t.Fatalf("PackID = %q, want openai-cn-pack", req.PackID)
}
return []ProviderDraftInfo{{
DraftID: "draft_001",
PackID: req.PackID,
ProviderID: "minimax-53hk",
DisplayName: "MiniMax 53HK",
Platform: "openai",
Manifest: map[string]any{"provider_id": "minimax-53hk"},
SourceHostID: "remote43-current-host",
}}, nil
},
})
request := httptestRequest(t, http.MethodGet, "/api/provider-drafts?pack_id=openai-cn-pack", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
var payload map[string]any
if err := json.Unmarshal(response.Body().Bytes(), &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
drafts, ok := payload["provider_drafts"].([]any)
if !ok || len(drafts) != 1 {
t.Fatalf("provider_drafts = %#v, want one item", payload["provider_drafts"])
}
item, ok := drafts[0].(map[string]any)
if !ok {
t.Fatalf("draft[0] = %#v, want object", drafts[0])
}
if got := item["provider_id"]; got != "minimax-53hk" {
t.Fatalf("provider_id = %#v, want minimax-53hk", got)
}
}
func TestAPIGetProviderDraftReturnsItem(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
GetProviderDraft: func(_ context.Context, draftID string) (ProviderDraftInfo, error) {
if draftID != "draft_001" {
t.Fatalf("draftID = %q, want draft_001", draftID)
}
return ProviderDraftInfo{
DraftID: draftID,
PackID: "openai-cn-pack",
ProviderID: "deepseek-chat-official",
DisplayName: "DeepSeek Official",
Platform: "openai",
Manifest: map[string]any{"provider_id": "deepseek-chat-official"},
SourceHostID: "remote43-current-host",
}, nil
},
})
request := httptestRequest(t, http.MethodGet, "/api/provider-drafts/draft_001", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "draft.provider_id", "deepseek-chat-official")
}
func TestAPIUpdateProviderDraftReturnsUpdatedItem(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
UpdateProviderDraft: func(_ context.Context, req UpdateProviderDraftRequest) (ProviderDraftInfo, error) {
if req.DraftID != "draft_001" {
t.Fatalf("DraftID = %q, want draft_001", req.DraftID)
}
return ProviderDraftInfo{
DraftID: req.DraftID,
PackID: req.PackID,
ProviderID: req.ProviderID,
DisplayName: req.DisplayName,
Platform: req.Platform,
BaseURL: req.BaseURL,
Manifest: map[string]any{"provider_id": req.ProviderID},
SourceHostID: req.SourceHostID,
}, nil
},
})
request := httptestRequest(t, http.MethodPut, "/api/provider-drafts/draft_001", map[string]any{
"pack_id": "openai-cn-pack",
"provider_id": "openai-zhongzhuan",
"display_name": "OpenAI 中转 Updated",
"platform": "openai",
"base_url": "https://api.example.com/v1",
}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "draft.display_name", "OpenAI 中转 Updated")
}
func TestAPIDeleteProviderDraftReturnsNoContent(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
DeleteProviderDraft: func(_ context.Context, draftID string) error {
if draftID != "draft_001" {
t.Fatalf("draftID = %q, want draft_001", draftID)
}
return nil
},
})
request := httptestRequest(t, http.MethodDelete, "/api/provider-drafts/draft_001", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusNoContent)
}
func TestAPIImportProviderReturnsConflictWithBatchStatus(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
ImportProvider: func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error) {

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"strconv"
"strings"
"time"
"sub2api-cn-relay-manager/internal/batch"
"sub2api-cn-relay-manager/internal/host/sub2api"
@@ -26,6 +27,11 @@ type ActionSet struct {
GetBatchImportRun func(context.Context, string) (batch.RunSummaryProjection, error)
ListBatchImportRunItems func(context.Context, ListBatchImportRunItemsRequest) (ListBatchImportRunItemsResponse, error)
GetBatchImportRunItem func(context.Context, GetBatchImportRunItemRequest) (batch.ItemDetailProjection, error)
CreateProviderDraft func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error)
ListProviderDrafts func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error)
GetProviderDraft func(context.Context, string) (ProviderDraftInfo, error)
UpdateProviderDraft func(context.Context, UpdateProviderDraftRequest) (ProviderDraftInfo, error)
DeleteProviderDraft func(context.Context, string) error
InstallPack func(context.Context, InstallPackRequest) (provision.PackInstallResult, error)
BatchDetail func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error)
GetProviderStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
@@ -98,6 +104,47 @@ type PackProviderInfo struct {
HostOverlays int `json:"host_overlays,omitempty"`
}
type CreateProviderDraftRequest struct {
DraftID string `json:"draft_id,omitempty"`
PackID string `json:"pack_id"`
ProviderID string `json:"provider_id"`
DisplayName string `json:"display_name"`
Platform string `json:"platform"`
BaseURL string `json:"base_url,omitempty"`
SmokeTestModel string `json:"smoke_test_model,omitempty"`
SupportedModels []string `json:"supported_models,omitempty"`
Manifest json.RawMessage `json:"manifest,omitempty"`
SourceHostID string `json:"source_host_id,omitempty"`
Notes string `json:"notes,omitempty"`
}
type ListProviderDraftsRequest struct {
PackID string
ProviderID string
Query string
}
type UpdateProviderDraftRequest struct {
DraftID string `json:"-"`
CreateProviderDraftRequest
}
type ProviderDraftInfo struct {
DraftID string `json:"draft_id"`
PackID string `json:"pack_id"`
ProviderID string `json:"provider_id"`
DisplayName string `json:"display_name"`
Platform string `json:"platform"`
BaseURL string `json:"base_url,omitempty"`
SmokeTestModel string `json:"smoke_test_model,omitempty"`
SupportedModels []string `json:"supported_models,omitempty"`
Manifest any `json:"manifest,omitempty"`
SourceHostID string `json:"source_host_id,omitempty"`
Notes string `json:"notes,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
type AssignAccessSubscriptionsRequest struct {
HostID string `json:"host_id,omitempty"`
PackPath string `json:"pack_path"`
@@ -226,6 +273,21 @@ func NewAPIHandler(adminToken string, actions ActionSet) http.Handler {
mux.Handle("GET /api/batch-import/runs/{run_id}/items/{item_id}", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleGetBatchImportRunItem(w, r, actions.GetBatchImportRunItem)
})))
mux.Handle("POST /api/provider-drafts", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleCreateProviderDraft(w, r, actions.CreateProviderDraft)
})))
mux.Handle("GET /api/provider-drafts", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleListProviderDrafts(w, r, actions.ListProviderDrafts)
})))
mux.Handle("GET /api/provider-drafts/{draftID}", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleGetProviderDraft(w, r, actions.GetProviderDraft)
})))
mux.Handle("PUT /api/provider-drafts/{draftID}", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleUpdateProviderDraft(w, r, actions.UpdateProviderDraft)
})))
mux.Handle("DELETE /api/provider-drafts/{draftID}", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleDeleteProviderDraft(w, r, actions.DeleteProviderDraft)
})))
mux.Handle("GET /api/import-batches/{batchID}", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleBatchDetail(w, r, actions.BatchDetail)
})))
@@ -297,6 +359,102 @@ func healthz(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("ok"))
}
func handleCreateProviderDraft(w http.ResponseWriter, r *http.Request, fn func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "create-provider-draft action is not configured"})
return
}
var req CreateProviderDraftRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
draft, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusCreated, map[string]any{"draft": draft})
}
func handleListProviderDrafts(w http.ResponseWriter, r *http.Request, fn func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-provider-drafts action is not configured"})
return
}
drafts, err := fn(r.Context(), ListProviderDraftsRequest{
PackID: strings.TrimSpace(r.URL.Query().Get("pack_id")),
ProviderID: strings.TrimSpace(r.URL.Query().Get("provider_id")),
Query: strings.TrimSpace(r.URL.Query().Get("q")),
})
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
if drafts == nil {
drafts = []ProviderDraftInfo{}
}
writeJSON(w, http.StatusOK, map[string]any{"provider_drafts": drafts})
}
func handleGetProviderDraft(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) (ProviderDraftInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-provider-draft action is not configured"})
return
}
draftID := strings.TrimSpace(r.PathValue("draftID"))
if draftID == "" {
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "draft_id is required"})
return
}
draft, err := fn(r.Context(), draftID)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusOK, map[string]any{"draft": draft})
}
func handleUpdateProviderDraft(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderDraftRequest) (ProviderDraftInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "update-provider-draft action is not configured"})
return
}
var req UpdateProviderDraftRequest
if err := decodeJSON(r, &req.CreateProviderDraftRequest); err != nil {
writeHTTPError(w, err)
return
}
req.DraftID = strings.TrimSpace(r.PathValue("draftID"))
if req.DraftID == "" {
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "draft_id is required"})
return
}
draft, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusOK, map[string]any{"draft": draft})
}
func handleDeleteProviderDraft(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) error) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "delete-provider-draft action is not configured"})
return
}
draftID := strings.TrimSpace(r.PathValue("draftID"))
if draftID == "" {
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "draft_id is required"})
return
}
if err := fn(r.Context(), draftID); err != nil {
writeHTTPError(w, classifyError(err))
return
}
w.WriteHeader(http.StatusNoContent)
}
func requireAdminToken(token string, next http.Handler) http.Handler {
if strings.TrimSpace(token) == "" {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -912,6 +1070,123 @@ func NewActionSet(sqliteDSN string) ActionSet {
GetBatchImportRun: buildGetBatchImportRunAction(sqliteDSN),
ListBatchImportRunItems: buildListBatchImportRunItemsAction(sqliteDSN),
GetBatchImportRunItem: buildGetBatchImportRunItemAction(sqliteDSN),
CreateProviderDraft: func(ctx context.Context, req CreateProviderDraftRequest) (ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return ProviderDraftInfo{}, err
}
defer store.Close()
draftID := strings.TrimSpace(req.DraftID)
if draftID == "" {
draftID = fmt.Sprintf("draft_%d", time.Now().UnixNano())
}
manifestJSON, manifestValue, supportedModels, err := normalizeProviderDraftPayload(req)
if err != nil {
return ProviderDraftInfo{}, err
}
draftRow := sqlite.ProviderDraft{
DraftID: draftID,
PackID: strings.TrimSpace(req.PackID),
ProviderID: strings.TrimSpace(req.ProviderID),
DisplayName: strings.TrimSpace(req.DisplayName),
Platform: strings.TrimSpace(req.Platform),
BaseURL: strings.TrimSpace(req.BaseURL),
SmokeTestModel: strings.TrimSpace(req.SmokeTestModel),
SupportedModelsJSON: encodeStringList(supportedModels),
ManifestJSON: manifestJSON,
SourceHostID: strings.TrimSpace(req.SourceHostID),
Notes: strings.TrimSpace(req.Notes),
}
if _, err := store.ProviderDrafts().Create(ctx, draftRow); err != nil {
return ProviderDraftInfo{}, err
}
persisted, err := store.ProviderDrafts().GetByDraftID(ctx, draftID)
if err != nil {
return ProviderDraftInfo{}, err
}
return providerDraftRecordToInfo(persisted, manifestValue, supportedModels)
},
ListProviderDrafts: func(ctx context.Context, req ListProviderDraftsRequest) ([]ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return nil, err
}
defer store.Close()
rows, err := store.ProviderDrafts().List(ctx, sqlite.ListProviderDraftsFilter{
PackID: req.PackID,
ProviderID: req.ProviderID,
Query: req.Query,
})
if err != nil {
return nil, err
}
result := make([]ProviderDraftInfo, 0, len(rows))
for _, row := range rows {
info, err := providerDraftRecordToInfoFromStored(row)
if err != nil {
return nil, err
}
result = append(result, info)
}
return result, nil
},
GetProviderDraft: func(ctx context.Context, draftID string) (ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return ProviderDraftInfo{}, err
}
defer store.Close()
row, err := store.ProviderDrafts().GetByDraftID(ctx, draftID)
if err != nil {
return ProviderDraftInfo{}, err
}
return providerDraftRecordToInfoFromStored(row)
},
UpdateProviderDraft: func(ctx context.Context, req UpdateProviderDraftRequest) (ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return ProviderDraftInfo{}, err
}
defer store.Close()
manifestJSON, _, supportedModels, err := normalizeProviderDraftPayload(req.CreateProviderDraftRequest)
if err != nil {
return ProviderDraftInfo{}, err
}
if err := store.ProviderDrafts().UpdateByDraftID(ctx, sqlite.ProviderDraft{
DraftID: strings.TrimSpace(req.DraftID),
PackID: strings.TrimSpace(req.PackID),
ProviderID: strings.TrimSpace(req.ProviderID),
DisplayName: strings.TrimSpace(req.DisplayName),
Platform: strings.TrimSpace(req.Platform),
BaseURL: strings.TrimSpace(req.BaseURL),
SmokeTestModel: strings.TrimSpace(req.SmokeTestModel),
SupportedModelsJSON: encodeStringList(supportedModels),
ManifestJSON: manifestJSON,
SourceHostID: strings.TrimSpace(req.SourceHostID),
Notes: strings.TrimSpace(req.Notes),
}); err != nil {
return ProviderDraftInfo{}, err
}
row, err := store.ProviderDrafts().GetByDraftID(ctx, req.DraftID)
if err != nil {
return ProviderDraftInfo{}, err
}
return providerDraftRecordToInfoFromStored(row)
},
DeleteProviderDraft: func(ctx context.Context, draftID string) error {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return err
}
defer store.Close()
return store.ProviderDrafts().DeleteByDraftID(ctx, draftID)
},
InstallPack: func(ctx context.Context, req InstallPackRequest) (provision.PackInstallResult, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
@@ -1650,6 +1925,92 @@ func packRecordToInfo(pack sqlite.Pack) PackInfo {
}
}
func normalizeProviderDraftPayload(req CreateProviderDraftRequest) (string, any, []string, error) {
supportedModels := normalizeStringList(req.SupportedModels)
if len(req.Manifest) > 0 {
var manifestValue any
if err := json.Unmarshal(req.Manifest, &manifestValue); err != nil {
return "", nil, nil, fmt.Errorf("decode manifest: %w", err)
}
manifestJSON := strings.TrimSpace(string(req.Manifest))
if manifestJSON == "" {
manifestJSON = "{}"
}
return manifestJSON, manifestValue, supportedModels, nil
}
manifestValue := map[string]any{
"provider_id": strings.TrimSpace(req.ProviderID),
"display_name": strings.TrimSpace(req.DisplayName),
"platform": strings.TrimSpace(req.Platform),
"base_url": strings.TrimSpace(req.BaseURL),
"smoke_test_model": strings.TrimSpace(req.SmokeTestModel),
"supported_models": supportedModels,
}
manifestJSONBytes, err := json.Marshal(manifestValue)
if err != nil {
return "", nil, nil, fmt.Errorf("marshal manifest: %w", err)
}
return string(manifestJSONBytes), manifestValue, supportedModels, nil
}
func providerDraftRecordToInfo(row sqlite.ProviderDraft, manifestValue any, supportedModels []string) (ProviderDraftInfo, error) {
if manifestValue == nil {
manifestValue = map[string]any{}
}
return ProviderDraftInfo{
DraftID: row.DraftID,
PackID: row.PackID,
ProviderID: row.ProviderID,
DisplayName: row.DisplayName,
Platform: row.Platform,
BaseURL: row.BaseURL,
SmokeTestModel: row.SmokeTestModel,
SupportedModels: append([]string(nil), supportedModels...),
Manifest: manifestValue,
SourceHostID: row.SourceHostID,
Notes: row.Notes,
CreatedAt: row.CreatedAt,
UpdatedAt: row.UpdatedAt,
}, nil
}
func providerDraftRecordToInfoFromStored(row sqlite.ProviderDraft) (ProviderDraftInfo, error) {
var manifestValue any
if strings.TrimSpace(row.ManifestJSON) != "" {
if err := json.Unmarshal([]byte(row.ManifestJSON), &manifestValue); err != nil {
return ProviderDraftInfo{}, fmt.Errorf("decode stored provider draft manifest: %w", err)
}
}
supportedModels := []string{}
if strings.TrimSpace(row.SupportedModelsJSON) != "" {
if err := json.Unmarshal([]byte(row.SupportedModelsJSON), &supportedModels); err != nil {
return ProviderDraftInfo{}, fmt.Errorf("decode stored provider draft supported_models: %w", err)
}
}
return providerDraftRecordToInfo(row, manifestValue, supportedModels)
}
func encodeStringList(values []string) string {
encoded, err := json.Marshal(normalizeStringList(values))
if err != nil {
return "[]"
}
return string(encoded)
}
func normalizeStringList(values []string) []string {
normalized := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
normalized = append(normalized, value)
}
return normalized
}
func deriveAccessStatus(gw sub2api.GatewayAccessResult) string {
if provision.GatewayAccessReady(gw) {
return provision.AccessStatusSubscriptionReady

View File

@@ -0,0 +1,19 @@
CREATE TABLE provider_drafts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
draft_id TEXT NOT NULL UNIQUE,
pack_id TEXT NOT NULL,
provider_id TEXT NOT NULL,
display_name TEXT NOT NULL,
platform TEXT NOT NULL DEFAULT '',
base_url TEXT NOT NULL DEFAULT '',
smoke_test_model TEXT NOT NULL DEFAULT '',
supported_models_json TEXT NOT NULL DEFAULT '[]',
manifest_json TEXT NOT NULL DEFAULT '{}',
source_host_id TEXT NOT NULL DEFAULT '',
notes TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX idx_provider_drafts_pack_id ON provider_drafts(pack_id);
CREATE INDEX idx_provider_drafts_provider_id ON provider_drafts(provider_id);

View File

@@ -23,6 +23,7 @@ type Queries struct {
Hosts *HostsRepo
Packs *PacksRepo
Providers *ProvidersRepo
ProviderDrafts *ProviderDraftsRepo
ImportBatches *ImportBatchesRepo
ImportBatchItems *ImportBatchItemsRepo
ImportRuns *ImportRunsRepo
@@ -91,6 +92,10 @@ func (db *DB) Providers() *ProvidersRepo {
return db.queries.Providers
}
func (db *DB) ProviderDrafts() *ProviderDraftsRepo {
return db.queries.ProviderDrafts
}
func (db *DB) ImportBatches() *ImportBatchesRepo {
return db.queries.ImportBatches
}
@@ -159,6 +164,7 @@ func newQueries(db execQuerier) *Queries {
Hosts: newHostsRepo(db),
Packs: newPacksRepo(db),
Providers: newProvidersRepo(db),
ProviderDrafts: newProviderDraftsRepo(db),
ImportBatches: newImportBatchesRepo(db),
ImportBatchItems: newImportBatchItemsRepo(db),
ImportRuns: newImportRunsRepo(db),

View File

@@ -0,0 +1,259 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type ProviderDraft struct {
ID int64
DraftID string
PackID string
ProviderID string
DisplayName string
Platform string
BaseURL string
SmokeTestModel string
SupportedModelsJSON string
ManifestJSON string
SourceHostID string
Notes string
CreatedAt string
UpdatedAt string
}
type ListProviderDraftsFilter struct {
PackID string
ProviderID string
Query string
}
type ProviderDraftsRepo struct {
db execQuerier
}
func newProviderDraftsRepo(db execQuerier) *ProviderDraftsRepo {
return &ProviderDraftsRepo{db: db}
}
func (r *ProviderDraftsRepo) Create(ctx context.Context, draft ProviderDraft) (int64, error) {
draftID := strings.TrimSpace(draft.DraftID)
packID := strings.TrimSpace(draft.PackID)
providerID := strings.TrimSpace(draft.ProviderID)
displayName := strings.TrimSpace(draft.DisplayName)
platform := strings.TrimSpace(draft.Platform)
if draft.ManifestJSON = strings.TrimSpace(draft.ManifestJSON); draft.ManifestJSON == "" {
draft.ManifestJSON = "{}"
}
if draft.SupportedModelsJSON = strings.TrimSpace(draft.SupportedModelsJSON); draft.SupportedModelsJSON == "" {
draft.SupportedModelsJSON = "[]"
}
switch {
case draftID == "":
return 0, fmt.Errorf("draft_id is required")
case packID == "":
return 0, fmt.Errorf("pack_id is required")
case providerID == "":
return 0, fmt.Errorf("provider_id is required")
case displayName == "":
return 0, fmt.Errorf("display_name is required")
case platform == "":
return 0, fmt.Errorf("platform is required")
}
result, err := r.db.ExecContext(
ctx,
`INSERT INTO provider_drafts (draft_id, pack_id, provider_id, display_name, platform, base_url, smoke_test_model, supported_models_json, manifest_json, source_host_id, notes)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
draftID,
packID,
providerID,
displayName,
platform,
strings.TrimSpace(draft.BaseURL),
strings.TrimSpace(draft.SmokeTestModel),
draft.SupportedModelsJSON,
draft.ManifestJSON,
strings.TrimSpace(draft.SourceHostID),
strings.TrimSpace(draft.Notes),
)
if err != nil {
return 0, fmt.Errorf("insert provider draft %q: %w", draftID, err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted provider draft id for %q: %w", draftID, err)
}
return id, nil
}
func (r *ProviderDraftsRepo) UpdateByDraftID(ctx context.Context, draft ProviderDraft) error {
draftID := strings.TrimSpace(draft.DraftID)
packID := strings.TrimSpace(draft.PackID)
providerID := strings.TrimSpace(draft.ProviderID)
displayName := strings.TrimSpace(draft.DisplayName)
platform := strings.TrimSpace(draft.Platform)
if draft.ManifestJSON = strings.TrimSpace(draft.ManifestJSON); draft.ManifestJSON == "" {
draft.ManifestJSON = "{}"
}
if draft.SupportedModelsJSON = strings.TrimSpace(draft.SupportedModelsJSON); draft.SupportedModelsJSON == "" {
draft.SupportedModelsJSON = "[]"
}
switch {
case draftID == "":
return fmt.Errorf("draft_id is required")
case packID == "":
return fmt.Errorf("pack_id is required")
case providerID == "":
return fmt.Errorf("provider_id is required")
case displayName == "":
return fmt.Errorf("display_name is required")
case platform == "":
return fmt.Errorf("platform is required")
}
result, err := r.db.ExecContext(
ctx,
`UPDATE provider_drafts
SET pack_id = ?, provider_id = ?, display_name = ?, platform = ?, base_url = ?, smoke_test_model = ?, supported_models_json = ?, manifest_json = ?, source_host_id = ?, notes = ?, updated_at = CURRENT_TIMESTAMP
WHERE draft_id = ?`,
packID,
providerID,
displayName,
platform,
strings.TrimSpace(draft.BaseURL),
strings.TrimSpace(draft.SmokeTestModel),
draft.SupportedModelsJSON,
draft.ManifestJSON,
strings.TrimSpace(draft.SourceHostID),
strings.TrimSpace(draft.Notes),
draftID,
)
if err != nil {
return fmt.Errorf("update provider draft %q: %w", draftID, err)
}
affected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("read updated provider draft rows for %q: %w", draftID, err)
}
if affected == 0 {
return fmt.Errorf("provider draft %q not found", draftID)
}
return nil
}
func (r *ProviderDraftsRepo) GetByDraftID(ctx context.Context, draftID string) (ProviderDraft, error) {
draftID = strings.TrimSpace(draftID)
if draftID == "" {
return ProviderDraft{}, fmt.Errorf("draft_id is required")
}
var draft ProviderDraft
if err := r.db.QueryRowContext(
ctx,
`SELECT id, draft_id, pack_id, provider_id, display_name, platform, base_url, smoke_test_model, supported_models_json, manifest_json, source_host_id, notes, created_at, updated_at
FROM provider_drafts WHERE draft_id = ?`,
draftID,
).Scan(
&draft.ID,
&draft.DraftID,
&draft.PackID,
&draft.ProviderID,
&draft.DisplayName,
&draft.Platform,
&draft.BaseURL,
&draft.SmokeTestModel,
&draft.SupportedModelsJSON,
&draft.ManifestJSON,
&draft.SourceHostID,
&draft.Notes,
&draft.CreatedAt,
&draft.UpdatedAt,
); err != nil {
return ProviderDraft{}, err
}
return draft, nil
}
func (r *ProviderDraftsRepo) List(ctx context.Context, filter ListProviderDraftsFilter) ([]ProviderDraft, error) {
query := `SELECT id, draft_id, pack_id, provider_id, display_name, platform, base_url, smoke_test_model, supported_models_json, manifest_json, source_host_id, notes, created_at, updated_at
FROM provider_drafts`
where := make([]string, 0, 3)
args := make([]any, 0, 3)
if packID := strings.TrimSpace(filter.PackID); packID != "" {
where = append(where, "pack_id = ?")
args = append(args, packID)
}
if providerID := strings.TrimSpace(filter.ProviderID); providerID != "" {
where = append(where, "provider_id = ?")
args = append(args, providerID)
}
if rawQuery := strings.TrimSpace(filter.Query); rawQuery != "" {
like := "%" + rawQuery + "%"
where = append(where, "(draft_id LIKE ? OR provider_id LIKE ? OR display_name LIKE ?)")
args = append(args, like, like, like)
}
if len(where) > 0 {
query += " WHERE " + strings.Join(where, " AND ")
}
query += " ORDER BY id DESC"
rows, err := r.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("list provider drafts: %w", err)
}
defer rows.Close()
drafts := make([]ProviderDraft, 0)
for rows.Next() {
var draft ProviderDraft
if err := rows.Scan(
&draft.ID,
&draft.DraftID,
&draft.PackID,
&draft.ProviderID,
&draft.DisplayName,
&draft.Platform,
&draft.BaseURL,
&draft.SmokeTestModel,
&draft.SupportedModelsJSON,
&draft.ManifestJSON,
&draft.SourceHostID,
&draft.Notes,
&draft.CreatedAt,
&draft.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan provider draft: %w", err)
}
drafts = append(drafts, draft)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate provider drafts: %w", err)
}
return drafts, nil
}
func (r *ProviderDraftsRepo) DeleteByDraftID(ctx context.Context, draftID string) error {
draftID = strings.TrimSpace(draftID)
if draftID == "" {
return fmt.Errorf("draft_id is required")
}
result, err := r.db.ExecContext(ctx, `DELETE FROM provider_drafts WHERE draft_id = ?`, draftID)
if err != nil {
return fmt.Errorf("delete provider draft %q: %w", draftID, err)
}
affected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("read deleted provider draft rows for %q: %w", draftID, err)
}
if affected == 0 {
return fmt.Errorf("provider draft %q not found", draftID)
}
return nil
}

View File

@@ -0,0 +1,177 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"testing"
)
func TestProviderDraftsRepoCreateGetAndList(t *testing.T) {
store := openTestDB(t)
id, err := store.ProviderDrafts().Create(context.Background(), ProviderDraft{
DraftID: "draft_001",
PackID: "openai-cn-pack",
ProviderID: "openai-zhongzhuan",
DisplayName: "OpenAI 中转",
Platform: "openai",
BaseURL: "https://api.example.com/v1",
SmokeTestModel: "gpt-5.4",
SupportedModelsJSON: `["gpt-5.4","gpt-5.4-mini"]`,
ManifestJSON: `{"provider_id":"openai-zhongzhuan"}`,
SourceHostID: "remote43-current-host",
Notes: "first draft",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Fatalf("Create() id = %d, want positive", id)
}
got, err := store.ProviderDrafts().GetByDraftID(context.Background(), "draft_001")
if err != nil {
t.Fatalf("GetByDraftID() error = %v", err)
}
if got.ProviderID != "openai-zhongzhuan" || got.DisplayName != "OpenAI 中转" {
t.Fatalf("GetByDraftID() = %+v, want provider draft payload", got)
}
drafts, err := store.ProviderDrafts().List(context.Background(), ListProviderDraftsFilter{PackID: "openai-cn-pack"})
if err != nil {
t.Fatalf("List() error = %v", err)
}
if len(drafts) != 1 {
t.Fatalf("List() len = %d, want 1", len(drafts))
}
if drafts[0].DraftID != "draft_001" {
t.Fatalf("List() draft_id = %q, want draft_001", drafts[0].DraftID)
}
}
func TestProviderDraftsRepoListSupportsQuery(t *testing.T) {
store := openTestDB(t)
_, err := store.ProviderDrafts().Create(context.Background(), ProviderDraft{
DraftID: "draft_alpha",
PackID: "openai-cn-pack",
ProviderID: "minimax-53hk",
DisplayName: "MiniMax 53HK",
Platform: "openai",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
_, err = store.ProviderDrafts().Create(context.Background(), ProviderDraft{
DraftID: "draft_beta",
PackID: "openai-cn-pack",
ProviderID: "deepseek-chat-official",
DisplayName: "DeepSeek Official",
Platform: "openai",
})
if err != nil {
t.Fatalf("Create() second error = %v", err)
}
drafts, err := store.ProviderDrafts().List(context.Background(), ListProviderDraftsFilter{Query: "MiniMax"})
if err != nil {
t.Fatalf("List() query error = %v", err)
}
if len(drafts) != 1 || drafts[0].ProviderID != "minimax-53hk" {
t.Fatalf("List() query = %+v, want only minimax-53hk", drafts)
}
}
func TestProviderDraftsRepoUpdateByDraftID(t *testing.T) {
store := openTestDB(t)
_, err := store.ProviderDrafts().Create(context.Background(), ProviderDraft{
DraftID: "draft_update",
PackID: "openai-cn-pack",
ProviderID: "minimax-53hk",
DisplayName: "MiniMax 53HK",
Platform: "openai",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if err := store.ProviderDrafts().UpdateByDraftID(context.Background(), ProviderDraft{
DraftID: "draft_update",
PackID: "openai-cn-pack",
ProviderID: "minimax-53hk",
DisplayName: "MiniMax 53HK Updated",
Platform: "openai",
BaseURL: "https://api.53hk.cn/v1",
SmokeTestModel: "MiniMax-M2.7-highspeed",
SupportedModelsJSON: `["MiniMax-M2.7-highspeed"]`,
ManifestJSON: `{"provider_id":"minimax-53hk","display_name":"MiniMax 53HK Updated"}`,
SourceHostID: "remote43-current-host",
Notes: "updated",
}); err != nil {
t.Fatalf("UpdateByDraftID() error = %v", err)
}
got, err := store.ProviderDrafts().GetByDraftID(context.Background(), "draft_update")
if err != nil {
t.Fatalf("GetByDraftID() error = %v", err)
}
if got.DisplayName != "MiniMax 53HK Updated" || got.BaseURL != "https://api.53hk.cn/v1" {
t.Fatalf("updated draft = %+v, want updated fields", got)
}
}
func TestProviderDraftsRepoDeleteByDraftID(t *testing.T) {
store := openTestDB(t)
_, err := store.ProviderDrafts().Create(context.Background(), ProviderDraft{
DraftID: "draft_delete",
PackID: "openai-cn-pack",
ProviderID: "deepseek-chat-official",
DisplayName: "DeepSeek Official",
Platform: "openai",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if err := store.ProviderDrafts().DeleteByDraftID(context.Background(), "draft_delete"); err != nil {
t.Fatalf("DeleteByDraftID() error = %v", err)
}
_, err = store.ProviderDrafts().GetByDraftID(context.Background(), "draft_delete")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByDraftID() after delete error = %v, want sql.ErrNoRows", err)
}
}
func TestProviderDraftsRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
tests := []struct {
name string
draft ProviderDraft
}{
{"empty draft_id", ProviderDraft{PackID: "pack", ProviderID: "p", DisplayName: "n", Platform: "openai"}},
{"empty pack_id", ProviderDraft{DraftID: "draft", ProviderID: "p", DisplayName: "n", Platform: "openai"}},
{"empty provider_id", ProviderDraft{DraftID: "draft", PackID: "pack", DisplayName: "n", Platform: "openai"}},
{"empty display_name", ProviderDraft{DraftID: "draft", PackID: "pack", ProviderID: "p", Platform: "openai"}},
{"empty platform", ProviderDraft{DraftID: "draft", PackID: "pack", ProviderID: "p", DisplayName: "n"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := store.ProviderDrafts().Create(context.Background(), tt.draft)
if err == nil {
t.Fatal("Create() error = nil, want validation error")
}
})
}
}
func TestProviderDraftsRepoGetByDraftIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.ProviderDrafts().GetByDraftID(context.Background(), "missing")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByDraftID() error = %v, want sql.ErrNoRows", err)
}
}