feat(admin): persist provider drafts in crm
This commit is contained in:
@@ -23,6 +23,7 @@ type Queries struct {
|
||||
Hosts *HostsRepo
|
||||
Packs *PacksRepo
|
||||
Providers *ProvidersRepo
|
||||
ProviderDrafts *ProviderDraftsRepo
|
||||
ImportBatches *ImportBatchesRepo
|
||||
ImportBatchItems *ImportBatchItemsRepo
|
||||
ImportRuns *ImportRunsRepo
|
||||
@@ -91,6 +92,10 @@ func (db *DB) Providers() *ProvidersRepo {
|
||||
return db.queries.Providers
|
||||
}
|
||||
|
||||
func (db *DB) ProviderDrafts() *ProviderDraftsRepo {
|
||||
return db.queries.ProviderDrafts
|
||||
}
|
||||
|
||||
func (db *DB) ImportBatches() *ImportBatchesRepo {
|
||||
return db.queries.ImportBatches
|
||||
}
|
||||
@@ -159,6 +164,7 @@ func newQueries(db execQuerier) *Queries {
|
||||
Hosts: newHostsRepo(db),
|
||||
Packs: newPacksRepo(db),
|
||||
Providers: newProvidersRepo(db),
|
||||
ProviderDrafts: newProviderDraftsRepo(db),
|
||||
ImportBatches: newImportBatchesRepo(db),
|
||||
ImportBatchItems: newImportBatchItemsRepo(db),
|
||||
ImportRuns: newImportRunsRepo(db),
|
||||
|
||||
259
internal/store/sqlite/provider_drafts_repo.go
Normal file
259
internal/store/sqlite/provider_drafts_repo.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ProviderDraft struct {
|
||||
ID int64
|
||||
DraftID string
|
||||
PackID string
|
||||
ProviderID string
|
||||
DisplayName string
|
||||
Platform string
|
||||
BaseURL string
|
||||
SmokeTestModel string
|
||||
SupportedModelsJSON string
|
||||
ManifestJSON string
|
||||
SourceHostID string
|
||||
Notes string
|
||||
CreatedAt string
|
||||
UpdatedAt string
|
||||
}
|
||||
|
||||
type ListProviderDraftsFilter struct {
|
||||
PackID string
|
||||
ProviderID string
|
||||
Query string
|
||||
}
|
||||
|
||||
type ProviderDraftsRepo struct {
|
||||
db execQuerier
|
||||
}
|
||||
|
||||
func newProviderDraftsRepo(db execQuerier) *ProviderDraftsRepo {
|
||||
return &ProviderDraftsRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ProviderDraftsRepo) Create(ctx context.Context, draft ProviderDraft) (int64, error) {
|
||||
draftID := strings.TrimSpace(draft.DraftID)
|
||||
packID := strings.TrimSpace(draft.PackID)
|
||||
providerID := strings.TrimSpace(draft.ProviderID)
|
||||
displayName := strings.TrimSpace(draft.DisplayName)
|
||||
platform := strings.TrimSpace(draft.Platform)
|
||||
if draft.ManifestJSON = strings.TrimSpace(draft.ManifestJSON); draft.ManifestJSON == "" {
|
||||
draft.ManifestJSON = "{}"
|
||||
}
|
||||
if draft.SupportedModelsJSON = strings.TrimSpace(draft.SupportedModelsJSON); draft.SupportedModelsJSON == "" {
|
||||
draft.SupportedModelsJSON = "[]"
|
||||
}
|
||||
|
||||
switch {
|
||||
case draftID == "":
|
||||
return 0, fmt.Errorf("draft_id is required")
|
||||
case packID == "":
|
||||
return 0, fmt.Errorf("pack_id is required")
|
||||
case providerID == "":
|
||||
return 0, fmt.Errorf("provider_id is required")
|
||||
case displayName == "":
|
||||
return 0, fmt.Errorf("display_name is required")
|
||||
case platform == "":
|
||||
return 0, fmt.Errorf("platform is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO provider_drafts (draft_id, pack_id, provider_id, display_name, platform, base_url, smoke_test_model, supported_models_json, manifest_json, source_host_id, notes)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
draftID,
|
||||
packID,
|
||||
providerID,
|
||||
displayName,
|
||||
platform,
|
||||
strings.TrimSpace(draft.BaseURL),
|
||||
strings.TrimSpace(draft.SmokeTestModel),
|
||||
draft.SupportedModelsJSON,
|
||||
draft.ManifestJSON,
|
||||
strings.TrimSpace(draft.SourceHostID),
|
||||
strings.TrimSpace(draft.Notes),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert provider draft %q: %w", draftID, err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read inserted provider draft id for %q: %w", draftID, err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *ProviderDraftsRepo) UpdateByDraftID(ctx context.Context, draft ProviderDraft) error {
|
||||
draftID := strings.TrimSpace(draft.DraftID)
|
||||
packID := strings.TrimSpace(draft.PackID)
|
||||
providerID := strings.TrimSpace(draft.ProviderID)
|
||||
displayName := strings.TrimSpace(draft.DisplayName)
|
||||
platform := strings.TrimSpace(draft.Platform)
|
||||
if draft.ManifestJSON = strings.TrimSpace(draft.ManifestJSON); draft.ManifestJSON == "" {
|
||||
draft.ManifestJSON = "{}"
|
||||
}
|
||||
if draft.SupportedModelsJSON = strings.TrimSpace(draft.SupportedModelsJSON); draft.SupportedModelsJSON == "" {
|
||||
draft.SupportedModelsJSON = "[]"
|
||||
}
|
||||
|
||||
switch {
|
||||
case draftID == "":
|
||||
return fmt.Errorf("draft_id is required")
|
||||
case packID == "":
|
||||
return fmt.Errorf("pack_id is required")
|
||||
case providerID == "":
|
||||
return fmt.Errorf("provider_id is required")
|
||||
case displayName == "":
|
||||
return fmt.Errorf("display_name is required")
|
||||
case platform == "":
|
||||
return fmt.Errorf("platform is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(
|
||||
ctx,
|
||||
`UPDATE provider_drafts
|
||||
SET pack_id = ?, provider_id = ?, display_name = ?, platform = ?, base_url = ?, smoke_test_model = ?, supported_models_json = ?, manifest_json = ?, source_host_id = ?, notes = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE draft_id = ?`,
|
||||
packID,
|
||||
providerID,
|
||||
displayName,
|
||||
platform,
|
||||
strings.TrimSpace(draft.BaseURL),
|
||||
strings.TrimSpace(draft.SmokeTestModel),
|
||||
draft.SupportedModelsJSON,
|
||||
draft.ManifestJSON,
|
||||
strings.TrimSpace(draft.SourceHostID),
|
||||
strings.TrimSpace(draft.Notes),
|
||||
draftID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update provider draft %q: %w", draftID, err)
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read updated provider draft rows for %q: %w", draftID, err)
|
||||
}
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("provider draft %q not found", draftID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ProviderDraftsRepo) GetByDraftID(ctx context.Context, draftID string) (ProviderDraft, error) {
|
||||
draftID = strings.TrimSpace(draftID)
|
||||
if draftID == "" {
|
||||
return ProviderDraft{}, fmt.Errorf("draft_id is required")
|
||||
}
|
||||
|
||||
var draft ProviderDraft
|
||||
if err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT id, draft_id, pack_id, provider_id, display_name, platform, base_url, smoke_test_model, supported_models_json, manifest_json, source_host_id, notes, created_at, updated_at
|
||||
FROM provider_drafts WHERE draft_id = ?`,
|
||||
draftID,
|
||||
).Scan(
|
||||
&draft.ID,
|
||||
&draft.DraftID,
|
||||
&draft.PackID,
|
||||
&draft.ProviderID,
|
||||
&draft.DisplayName,
|
||||
&draft.Platform,
|
||||
&draft.BaseURL,
|
||||
&draft.SmokeTestModel,
|
||||
&draft.SupportedModelsJSON,
|
||||
&draft.ManifestJSON,
|
||||
&draft.SourceHostID,
|
||||
&draft.Notes,
|
||||
&draft.CreatedAt,
|
||||
&draft.UpdatedAt,
|
||||
); err != nil {
|
||||
return ProviderDraft{}, err
|
||||
}
|
||||
return draft, nil
|
||||
}
|
||||
|
||||
func (r *ProviderDraftsRepo) List(ctx context.Context, filter ListProviderDraftsFilter) ([]ProviderDraft, error) {
|
||||
query := `SELECT id, draft_id, pack_id, provider_id, display_name, platform, base_url, smoke_test_model, supported_models_json, manifest_json, source_host_id, notes, created_at, updated_at
|
||||
FROM provider_drafts`
|
||||
where := make([]string, 0, 3)
|
||||
args := make([]any, 0, 3)
|
||||
|
||||
if packID := strings.TrimSpace(filter.PackID); packID != "" {
|
||||
where = append(where, "pack_id = ?")
|
||||
args = append(args, packID)
|
||||
}
|
||||
if providerID := strings.TrimSpace(filter.ProviderID); providerID != "" {
|
||||
where = append(where, "provider_id = ?")
|
||||
args = append(args, providerID)
|
||||
}
|
||||
if rawQuery := strings.TrimSpace(filter.Query); rawQuery != "" {
|
||||
like := "%" + rawQuery + "%"
|
||||
where = append(where, "(draft_id LIKE ? OR provider_id LIKE ? OR display_name LIKE ?)")
|
||||
args = append(args, like, like, like)
|
||||
}
|
||||
if len(where) > 0 {
|
||||
query += " WHERE " + strings.Join(where, " AND ")
|
||||
}
|
||||
query += " ORDER BY id DESC"
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list provider drafts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
drafts := make([]ProviderDraft, 0)
|
||||
for rows.Next() {
|
||||
var draft ProviderDraft
|
||||
if err := rows.Scan(
|
||||
&draft.ID,
|
||||
&draft.DraftID,
|
||||
&draft.PackID,
|
||||
&draft.ProviderID,
|
||||
&draft.DisplayName,
|
||||
&draft.Platform,
|
||||
&draft.BaseURL,
|
||||
&draft.SmokeTestModel,
|
||||
&draft.SupportedModelsJSON,
|
||||
&draft.ManifestJSON,
|
||||
&draft.SourceHostID,
|
||||
&draft.Notes,
|
||||
&draft.CreatedAt,
|
||||
&draft.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan provider draft: %w", err)
|
||||
}
|
||||
drafts = append(drafts, draft)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate provider drafts: %w", err)
|
||||
}
|
||||
return drafts, nil
|
||||
}
|
||||
|
||||
func (r *ProviderDraftsRepo) DeleteByDraftID(ctx context.Context, draftID string) error {
|
||||
draftID = strings.TrimSpace(draftID)
|
||||
if draftID == "" {
|
||||
return fmt.Errorf("draft_id is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `DELETE FROM provider_drafts WHERE draft_id = ?`, draftID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete provider draft %q: %w", draftID, err)
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read deleted provider draft rows for %q: %w", draftID, err)
|
||||
}
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("provider draft %q not found", draftID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
177
internal/store/sqlite/provider_drafts_repo_test.go
Normal file
177
internal/store/sqlite/provider_drafts_repo_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProviderDraftsRepoCreateGetAndList(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
id, err := store.ProviderDrafts().Create(context.Background(), ProviderDraft{
|
||||
DraftID: "draft_001",
|
||||
PackID: "openai-cn-pack",
|
||||
ProviderID: "openai-zhongzhuan",
|
||||
DisplayName: "OpenAI 中转",
|
||||
Platform: "openai",
|
||||
BaseURL: "https://api.example.com/v1",
|
||||
SmokeTestModel: "gpt-5.4",
|
||||
SupportedModelsJSON: `["gpt-5.4","gpt-5.4-mini"]`,
|
||||
ManifestJSON: `{"provider_id":"openai-zhongzhuan"}`,
|
||||
SourceHostID: "remote43-current-host",
|
||||
Notes: "first draft",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Fatalf("Create() id = %d, want positive", id)
|
||||
}
|
||||
|
||||
got, err := store.ProviderDrafts().GetByDraftID(context.Background(), "draft_001")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByDraftID() error = %v", err)
|
||||
}
|
||||
if got.ProviderID != "openai-zhongzhuan" || got.DisplayName != "OpenAI 中转" {
|
||||
t.Fatalf("GetByDraftID() = %+v, want provider draft payload", got)
|
||||
}
|
||||
|
||||
drafts, err := store.ProviderDrafts().List(context.Background(), ListProviderDraftsFilter{PackID: "openai-cn-pack"})
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
if len(drafts) != 1 {
|
||||
t.Fatalf("List() len = %d, want 1", len(drafts))
|
||||
}
|
||||
if drafts[0].DraftID != "draft_001" {
|
||||
t.Fatalf("List() draft_id = %q, want draft_001", drafts[0].DraftID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderDraftsRepoListSupportsQuery(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
_, err := store.ProviderDrafts().Create(context.Background(), ProviderDraft{
|
||||
DraftID: "draft_alpha",
|
||||
PackID: "openai-cn-pack",
|
||||
ProviderID: "minimax-53hk",
|
||||
DisplayName: "MiniMax 53HK",
|
||||
Platform: "openai",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
_, err = store.ProviderDrafts().Create(context.Background(), ProviderDraft{
|
||||
DraftID: "draft_beta",
|
||||
PackID: "openai-cn-pack",
|
||||
ProviderID: "deepseek-chat-official",
|
||||
DisplayName: "DeepSeek Official",
|
||||
Platform: "openai",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() second error = %v", err)
|
||||
}
|
||||
|
||||
drafts, err := store.ProviderDrafts().List(context.Background(), ListProviderDraftsFilter{Query: "MiniMax"})
|
||||
if err != nil {
|
||||
t.Fatalf("List() query error = %v", err)
|
||||
}
|
||||
if len(drafts) != 1 || drafts[0].ProviderID != "minimax-53hk" {
|
||||
t.Fatalf("List() query = %+v, want only minimax-53hk", drafts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderDraftsRepoUpdateByDraftID(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
_, err := store.ProviderDrafts().Create(context.Background(), ProviderDraft{
|
||||
DraftID: "draft_update",
|
||||
PackID: "openai-cn-pack",
|
||||
ProviderID: "minimax-53hk",
|
||||
DisplayName: "MiniMax 53HK",
|
||||
Platform: "openai",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if err := store.ProviderDrafts().UpdateByDraftID(context.Background(), ProviderDraft{
|
||||
DraftID: "draft_update",
|
||||
PackID: "openai-cn-pack",
|
||||
ProviderID: "minimax-53hk",
|
||||
DisplayName: "MiniMax 53HK Updated",
|
||||
Platform: "openai",
|
||||
BaseURL: "https://api.53hk.cn/v1",
|
||||
SmokeTestModel: "MiniMax-M2.7-highspeed",
|
||||
SupportedModelsJSON: `["MiniMax-M2.7-highspeed"]`,
|
||||
ManifestJSON: `{"provider_id":"minimax-53hk","display_name":"MiniMax 53HK Updated"}`,
|
||||
SourceHostID: "remote43-current-host",
|
||||
Notes: "updated",
|
||||
}); err != nil {
|
||||
t.Fatalf("UpdateByDraftID() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := store.ProviderDrafts().GetByDraftID(context.Background(), "draft_update")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByDraftID() error = %v", err)
|
||||
}
|
||||
if got.DisplayName != "MiniMax 53HK Updated" || got.BaseURL != "https://api.53hk.cn/v1" {
|
||||
t.Fatalf("updated draft = %+v, want updated fields", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderDraftsRepoDeleteByDraftID(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
_, err := store.ProviderDrafts().Create(context.Background(), ProviderDraft{
|
||||
DraftID: "draft_delete",
|
||||
PackID: "openai-cn-pack",
|
||||
ProviderID: "deepseek-chat-official",
|
||||
DisplayName: "DeepSeek Official",
|
||||
Platform: "openai",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if err := store.ProviderDrafts().DeleteByDraftID(context.Background(), "draft_delete"); err != nil {
|
||||
t.Fatalf("DeleteByDraftID() error = %v", err)
|
||||
}
|
||||
_, err = store.ProviderDrafts().GetByDraftID(context.Background(), "draft_delete")
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("GetByDraftID() after delete error = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderDraftsRepoValidationErrors(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
draft ProviderDraft
|
||||
}{
|
||||
{"empty draft_id", ProviderDraft{PackID: "pack", ProviderID: "p", DisplayName: "n", Platform: "openai"}},
|
||||
{"empty pack_id", ProviderDraft{DraftID: "draft", ProviderID: "p", DisplayName: "n", Platform: "openai"}},
|
||||
{"empty provider_id", ProviderDraft{DraftID: "draft", PackID: "pack", DisplayName: "n", Platform: "openai"}},
|
||||
{"empty display_name", ProviderDraft{DraftID: "draft", PackID: "pack", ProviderID: "p", Platform: "openai"}},
|
||||
{"empty platform", ProviderDraft{DraftID: "draft", PackID: "pack", ProviderID: "p", DisplayName: "n"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := store.ProviderDrafts().Create(context.Background(), tt.draft)
|
||||
if err == nil {
|
||||
t.Fatal("Create() error = nil, want validation error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderDraftsRepoGetByDraftIDNotFound(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.ProviderDrafts().GetByDraftID(context.Background(), "missing")
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("GetByDraftID() error = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user