package sqlite

import (
	"context"
	"fmt"
	"strings"
)

// ProviderDraft mirrors a row of the provider_drafts table.
// SupportedModelsJSON and ManifestJSON hold raw JSON text (an array and an
// object respectively); CreatedAt and UpdatedAt are the timestamp values
// read back from the database as text.
type ProviderDraft struct {
	ID                  int64
	DraftID             string
	PackID              string
	ProviderID          string
	DisplayName         string
	Platform            string
	BaseURL             string
	SmokeTestModel      string
	SupportedModelsJSON string
	ManifestJSON        string
	SourceHostID        string
	Notes               string
	CreatedAt           string
	UpdatedAt           string
}

// ListProviderDraftsFilter narrows the results of List. Blank fields are
// ignored. Query is a substring match (LIKE '%query%') against draft_id,
// provider_id and display_name; '%' and '_' inside Query are not escaped,
// so they act as LIKE wildcards.
type ListProviderDraftsFilter struct {
	PackID     string
	ProviderID string
	Query      string
}

// ProviderDraftsRepo provides CRUD access to the provider_drafts table.
type ProviderDraftsRepo struct {
	db execQuerier
}

func newProviderDraftsRepo(db execQuerier) *ProviderDraftsRepo {
	return &ProviderDraftsRepo{db: db}
}

// Create inserts draft and returns the new row ID. String fields are
// trimmed; an empty ManifestJSON defaults to "{}" and an empty
// SupportedModelsJSON to "[]". DraftID, PackID, ProviderID, DisplayName and
// Platform are required. created_at and updated_at are not set here and are
// left to the table's column defaults.
func (r *ProviderDraftsRepo) Create(ctx context.Context, draft ProviderDraft) (int64, error) {
	draftID := strings.TrimSpace(draft.DraftID)
	packID := strings.TrimSpace(draft.PackID)
	providerID := strings.TrimSpace(draft.ProviderID)
	displayName := strings.TrimSpace(draft.DisplayName)
	platform := strings.TrimSpace(draft.Platform)

	// Normalize the JSON columns so they are never stored as empty strings.
	if draft.ManifestJSON = strings.TrimSpace(draft.ManifestJSON); draft.ManifestJSON == "" {
		draft.ManifestJSON = "{}"
	}
	if draft.SupportedModelsJSON = strings.TrimSpace(draft.SupportedModelsJSON); draft.SupportedModelsJSON == "" {
		draft.SupportedModelsJSON = "[]"
	}

	switch {
	case draftID == "":
		return 0, fmt.Errorf("draft_id is required")
	case packID == "":
		return 0, fmt.Errorf("pack_id is required")
	case providerID == "":
		return 0, fmt.Errorf("provider_id is required")
	case displayName == "":
		return 0, fmt.Errorf("display_name is required")
	case platform == "":
		return 0, fmt.Errorf("platform is required")
	}

	result, err := r.db.ExecContext(
		ctx,
		`INSERT INTO provider_drafts (
			draft_id, pack_id, provider_id, display_name, platform, base_url,
			smoke_test_model, supported_models_json, manifest_json, source_host_id, notes
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		draftID,
		packID,
		providerID,
		displayName,
		platform,
		strings.TrimSpace(draft.BaseURL),
		strings.TrimSpace(draft.SmokeTestModel),
		draft.SupportedModelsJSON,
		draft.ManifestJSON,
		strings.TrimSpace(draft.SourceHostID),
		strings.TrimSpace(draft.Notes),
	)
	if err != nil {
		return 0, fmt.Errorf("insert provider draft %q: %w", draftID, err)
	}

	id, err := result.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("read inserted provider draft id for %q: %w", draftID, err)
	}
	return id, nil
}

// UpdateByDraftID overwrites every mutable column of the draft identified by
// draft.DraftID and sets updated_at to CURRENT_TIMESTAMP. Inputs are
// normalized and validated the same way as in Create. It returns an error if
// no draft with that ID exists.
func (r *ProviderDraftsRepo) UpdateByDraftID(ctx context.Context, draft ProviderDraft) error {
	draftID := strings.TrimSpace(draft.DraftID)
	packID := strings.TrimSpace(draft.PackID)
	providerID := strings.TrimSpace(draft.ProviderID)
	displayName := strings.TrimSpace(draft.DisplayName)
	platform := strings.TrimSpace(draft.Platform)

	if draft.ManifestJSON = strings.TrimSpace(draft.ManifestJSON); draft.ManifestJSON == "" {
		draft.ManifestJSON = "{}"
	}
	if draft.SupportedModelsJSON = strings.TrimSpace(draft.SupportedModelsJSON); draft.SupportedModelsJSON == "" {
		draft.SupportedModelsJSON = "[]"
	}

	switch {
	case draftID == "":
		return fmt.Errorf("draft_id is required")
	case packID == "":
		return fmt.Errorf("pack_id is required")
	case providerID == "":
		return fmt.Errorf("provider_id is required")
	case displayName == "":
		return fmt.Errorf("display_name is required")
	case platform == "":
		return fmt.Errorf("platform is required")
	}

	result, err := r.db.ExecContext(
		ctx,
		`UPDATE provider_drafts SET
			pack_id = ?, provider_id = ?, display_name = ?, platform = ?,
			base_url = ?, smoke_test_model = ?, supported_models_json = ?,
			manifest_json = ?, source_host_id = ?, notes = ?,
			updated_at = CURRENT_TIMESTAMP
		WHERE draft_id = ?`,
		packID,
		providerID,
		displayName,
		platform,
		strings.TrimSpace(draft.BaseURL),
		strings.TrimSpace(draft.SmokeTestModel),
		draft.SupportedModelsJSON,
		draft.ManifestJSON,
		strings.TrimSpace(draft.SourceHostID),
		strings.TrimSpace(draft.Notes),
		draftID,
	)
	if err != nil {
		return fmt.Errorf("update provider draft %q: %w", draftID, err)
	}

	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read updated provider draft rows for %q: %w", draftID, err)
	}
	if affected == 0 {
		return fmt.Errorf("provider draft %q not found", draftID)
	}
	return nil
}

// GetByDraftID returns the draft with the given draft_id. The error from
// Scan is returned unwrapped; with database/sql this is sql.ErrNoRows when
// no row matches.
func (r *ProviderDraftsRepo) GetByDraftID(ctx context.Context, draftID string) (ProviderDraft, error) {
	draftID = strings.TrimSpace(draftID)
	if draftID == "" {
		return ProviderDraft{}, fmt.Errorf("draft_id is required")
	}

	var draft ProviderDraft
	if err := r.db.QueryRowContext(
		ctx,
		`SELECT id, draft_id, pack_id, provider_id, display_name, platform, base_url,
			smoke_test_model, supported_models_json, manifest_json, source_host_id, notes,
			created_at, updated_at
		FROM provider_drafts
		WHERE draft_id = ?`,
		draftID,
	).Scan(
		&draft.ID,
		&draft.DraftID,
		&draft.PackID,
		&draft.ProviderID,
		&draft.DisplayName,
		&draft.Platform,
		&draft.BaseURL,
		&draft.SmokeTestModel,
		&draft.SupportedModelsJSON,
		&draft.ManifestJSON,
		&draft.SourceHostID,
		&draft.Notes,
		&draft.CreatedAt,
		&draft.UpdatedAt,
	); err != nil {
		return ProviderDraft{}, err
	}
	return draft, nil
}

// List returns the drafts matching filter, ordered by id descending. It
// returns an empty, non-nil slice when nothing matches.
func (r *ProviderDraftsRepo) List(ctx context.Context, filter ListProviderDraftsFilter) ([]ProviderDraft, error) {
	query := `SELECT id, draft_id, pack_id, provider_id, display_name, platform, base_url,
			smoke_test_model, supported_models_json, manifest_json, source_host_id, notes,
			created_at, updated_at
		FROM provider_drafts`

	// Build the WHERE clause from the non-blank filter fields. Every value is
	// bound as a query parameter, never interpolated into the SQL text.
	where := make([]string, 0, 3)
	args := make([]any, 0, 3)
	if packID := strings.TrimSpace(filter.PackID); packID != "" {
		where = append(where, "pack_id = ?")
		args = append(args, packID)
	}
	if providerID := strings.TrimSpace(filter.ProviderID); providerID != "" {
		where = append(where, "provider_id = ?")
		args = append(args, providerID)
	}
	if rawQuery := strings.TrimSpace(filter.Query); rawQuery != "" {
		like := "%" + rawQuery + "%"
		where = append(where, "(draft_id LIKE ? OR provider_id LIKE ? OR display_name LIKE ?)")
		args = append(args, like, like, like)
	}
	if len(where) > 0 {
		query += " WHERE " + strings.Join(where, " AND ")
	}
	query += " ORDER BY id DESC"

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("list provider drafts: %w", err)
	}
	defer rows.Close()

	drafts := make([]ProviderDraft, 0)
	for rows.Next() {
		var draft ProviderDraft
		if err := rows.Scan(
			&draft.ID,
			&draft.DraftID,
			&draft.PackID,
			&draft.ProviderID,
			&draft.DisplayName,
			&draft.Platform,
			&draft.BaseURL,
			&draft.SmokeTestModel,
			&draft.SupportedModelsJSON,
			&draft.ManifestJSON,
			&draft.SourceHostID,
			&draft.Notes,
			&draft.CreatedAt,
			&draft.UpdatedAt,
		); err != nil {
			return nil, fmt.Errorf("scan provider draft: %w", err)
		}
		drafts = append(drafts, draft)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate provider drafts: %w", err)
	}
	return drafts, nil
}

// DeleteByDraftID removes the draft with the given draft_id. It returns an
// error if no such draft exists.
func (r *ProviderDraftsRepo) DeleteByDraftID(ctx context.Context, draftID string) error {
	draftID = strings.TrimSpace(draftID)
	if draftID == "" {
		return fmt.Errorf("draft_id is required")
	}

	result, err := r.db.ExecContext(ctx, `DELETE FROM provider_drafts WHERE draft_id = ?`, draftID)
	if err != nil {
		return fmt.Errorf("delete provider draft %q: %w", draftID, err)
	}

	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read deleted provider draft rows for %q: %w", draftID, err)
	}
	if affected == 0 {
		return fmt.Errorf("provider draft %q not found", draftID)
	}
	return nil
}