Files
sub2api-cn-relay-manager/internal/store/sqlite/import_run_items_repo.go

295 lines
13 KiB
Go

package sqlite
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"
)
// ImportRunItem mirrors one row of the import_run_items table as read and
// written by ImportRunItemsRepo. Nullable text columns are read back as ""
// (via COALESCE) and nullable integer columns as nil pointers.
type ImportRunItem struct {
	// Identity of the item and the endpoint it imports. Upsert trims these
	// and rejects empty values.
	ItemID            string
	RunID             string
	BaseURL           string
	ProviderID        string
	APIKeyFingerprint string

	// Model discovery results. The *JSON fields are JSON text that Upsert
	// passes through defaultJSON with a "[]" fallback ("{}" for
	// CapabilityProfileJSON). CanonicalFamiliesJSON is stored in the
	// canonical_model_families_json column. ResolvedSmokeModel is nullable
	// text, read back as "" when NULL.
	RequestedModelsJSON   string
	RawModelsJSON         string
	NormalizedModelsJSON  string
	CanonicalFamiliesJSON string
	RecommendedModelsJSON string
	ResolvedSmokeModel    string
	CapabilityProfileJSON string

	// Pipeline state. Upsert trims these and rejects empty values.
	// TryAcquireConfirmationLease only claims rows whose CurrentStage is
	// "confirm" and ConfirmationStatus is "pending".
	CurrentStage        string
	ConfirmationStatus  string
	AccessStatus        string
	MatchedAccountState string
	AccountResolution   string

	// Provisioning outcome. ProvisionReused is stored as an integer 0/1 and
	// read back as true only when the stored value is 1. The *int64 fields
	// are nil when the column is NULL.
	ProvisionReused      bool
	ReusedFromProviderID string
	ReusedFromAccountID  *int64
	ChannelID            *int64
	AccountID            *int64

	// Retry and lease bookkeeping. The text fields are nullable and read back
	// as "" when NULL. TryAcquireConfirmationLease writes LeaseOwner and
	// LeaseUntil (RFC 3339, UTC) and increments ConfirmationAttempts. It also
	// compares NextRetryAt and LeaseUntil against the current time as strings.
	RetryCount           int
	ConfirmationAttempts int
	LastRetryAt          string
	NextRetryAt          string
	LeaseOwner           string
	LeaseUntil           string

	// Diagnostics. AdvisoryMessagesJSON is JSON text with a "[]" fallback.
	// LastErrorStage and LastError are nullable text.
	AdvisoryMessagesJSON string
	LastErrorStage       string
	LastError            string

	// Legacy linkage, both nullable. These presumably point back to an older
	// batch/provider data model; confirm with the migration code.
	LegacyBatchID    *int64
	LegacyProviderID string

	// Row timestamps as returned by the database. Upsert never writes
	// created_at. It sets updated_at to CURRENT_TIMESTAMP only when updating
	// an existing row.
	CreatedAt string
	UpdatedAt string
}
// ImportRunItemsRepo reads and writes rows of the import_run_items table.
type ImportRunItemsRepo struct {
	// db runs every statement via ExecContext, QueryRowContext and
	// QueryContext. execQuerier is declared elsewhere in the package and is
	// presumably satisfied by both *sql.DB and *sql.Tx; confirm.
	db execQuerier
}
// newImportRunItemsRepo returns a repository that issues its statements
// through db.
func newImportRunItemsRepo(db execQuerier) *ImportRunItemsRepo {
	repo := new(ImportRunItemsRepo)
	repo.db = db
	return repo
}
// Create stores item by delegating to Upsert. An existing row with the same
// item_id is therefore overwritten rather than rejected.
func (r *ImportRunItemsRepo) Create(ctx context.Context, item ImportRunItem) error {
	if err := r.Upsert(ctx, item); err != nil {
		return err
	}
	return nil
}
// Update stores item by delegating to Upsert. If no row with the item_id
// exists yet, one is inserted.
func (r *ImportRunItemsRepo) Update(ctx context.Context, item ImportRunItem) error {
	if err := r.Upsert(ctx, item); err != nil {
		return err
	}
	return nil
}
// Upsert inserts item into import_run_items. If a row with the same item_id
// already exists, it overwrites every stored column of that row and sets
// updated_at to CURRENT_TIMESTAMP. created_at is never written here.
//
// The identity and pipeline-state fields are trimmed and must be non-empty.
// The first empty one, in declaration order, is reported as
// "<column> is required". The JSON fields are passed through defaultJSON with
// a "[]" fallback ("{}" for the capability profile). The optional text fields
// are trimmed and passed through nullableString, which presumably maps "" to
// NULL; confirm.
func (r *ImportRunItemsRepo) Upsert(ctx context.Context, item ImportRunItem) error {
	itemID := strings.TrimSpace(item.ItemID)
	runID := strings.TrimSpace(item.RunID)
	baseURL := strings.TrimSpace(item.BaseURL)
	providerID := strings.TrimSpace(item.ProviderID)
	apiKeyFingerprint := strings.TrimSpace(item.APIKeyFingerprint)
	currentStage := strings.TrimSpace(item.CurrentStage)
	confirmationStatus := strings.TrimSpace(item.ConfirmationStatus)
	accessStatus := strings.TrimSpace(item.AccessStatus)
	matchedAccountState := strings.TrimSpace(item.MatchedAccountState)
	accountResolution := strings.TrimSpace(item.AccountResolution)

	// Checked in order; the first empty value wins.
	required := []struct {
		column string
		value  string
	}{
		{"item_id", itemID},
		{"run_id", runID},
		{"base_url", baseURL},
		{"provider_id", providerID},
		{"api_key_fingerprint", apiKeyFingerprint},
		{"current_stage", currentStage},
		{"confirmation_status", confirmationStatus},
		{"access_status", accessStatus},
		{"matched_account_state", matchedAccountState},
		{"account_resolution", accountResolution},
	}
	for _, field := range required {
		if field.value == "" {
			return fmt.Errorf("%s is required", field.column)
		}
	}

	const upsertSQL = `INSERT INTO import_run_items (
item_id, run_id, base_url, provider_id, api_key_fingerprint, requested_models_json, raw_models_json, normalized_models_json,
canonical_model_families_json, recommended_models_json, resolved_smoke_model, capability_profile_json, current_stage,
confirmation_status, access_status, matched_account_state, account_resolution, provision_reused, reused_from_provider_id,
reused_from_account_id, channel_id, account_id, retry_count, confirmation_attempts, last_retry_at, next_retry_at,
lease_owner, lease_until, advisory_messages_json, last_error_stage, last_error, legacy_batch_id, legacy_provider_id
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(item_id) DO UPDATE SET
run_id = excluded.run_id,
base_url = excluded.base_url,
provider_id = excluded.provider_id,
api_key_fingerprint = excluded.api_key_fingerprint,
requested_models_json = excluded.requested_models_json,
raw_models_json = excluded.raw_models_json,
normalized_models_json = excluded.normalized_models_json,
canonical_model_families_json = excluded.canonical_model_families_json,
recommended_models_json = excluded.recommended_models_json,
resolved_smoke_model = excluded.resolved_smoke_model,
capability_profile_json = excluded.capability_profile_json,
current_stage = excluded.current_stage,
confirmation_status = excluded.confirmation_status,
access_status = excluded.access_status,
matched_account_state = excluded.matched_account_state,
account_resolution = excluded.account_resolution,
provision_reused = excluded.provision_reused,
reused_from_provider_id = excluded.reused_from_provider_id,
reused_from_account_id = excluded.reused_from_account_id,
channel_id = excluded.channel_id,
account_id = excluded.account_id,
retry_count = excluded.retry_count,
confirmation_attempts = excluded.confirmation_attempts,
last_retry_at = excluded.last_retry_at,
next_retry_at = excluded.next_retry_at,
lease_owner = excluded.lease_owner,
lease_until = excluded.lease_until,
advisory_messages_json = excluded.advisory_messages_json,
last_error_stage = excluded.last_error_stage,
last_error = excluded.last_error,
legacy_batch_id = excluded.legacy_batch_id,
legacy_provider_id = excluded.legacy_provider_id,
updated_at = CURRENT_TIMESTAMP`

	// One argument per placeholder, in the column order of upsertSQL.
	args := []any{
		// identity
		itemID,
		runID,
		baseURL,
		providerID,
		apiKeyFingerprint,
		// model discovery
		defaultJSON(item.RequestedModelsJSON, "[]"),
		defaultJSON(item.RawModelsJSON, "[]"),
		defaultJSON(item.NormalizedModelsJSON, "[]"),
		defaultJSON(item.CanonicalFamiliesJSON, "[]"),
		defaultJSON(item.RecommendedModelsJSON, "[]"),
		nullableString(strings.TrimSpace(item.ResolvedSmokeModel)),
		defaultJSON(item.CapabilityProfileJSON, "{}"),
		// pipeline state
		currentStage,
		confirmationStatus,
		accessStatus,
		matchedAccountState,
		accountResolution,
		// provisioning
		boolToInt(item.ProvisionReused),
		nullableString(strings.TrimSpace(item.ReusedFromProviderID)),
		item.ReusedFromAccountID,
		item.ChannelID,
		item.AccountID,
		// retry and lease bookkeeping
		item.RetryCount,
		item.ConfirmationAttempts,
		nullableString(strings.TrimSpace(item.LastRetryAt)),
		nullableString(strings.TrimSpace(item.NextRetryAt)),
		nullableString(strings.TrimSpace(item.LeaseOwner)),
		nullableString(strings.TrimSpace(item.LeaseUntil)),
		// diagnostics
		defaultJSON(item.AdvisoryMessagesJSON, "[]"),
		nullableString(strings.TrimSpace(item.LastErrorStage)),
		nullableString(strings.TrimSpace(item.LastError)),
		// legacy linkage
		item.LegacyBatchID,
		nullableString(strings.TrimSpace(item.LegacyProviderID)),
	}

	_, err := r.db.ExecContext(ctx, upsertSQL, args...)
	if err != nil {
		return fmt.Errorf("upsert import run item %q: %w", itemID, err)
	}
	return nil
}
// GetByItemID loads the import_run_items row whose item_id equals the trimmed
// itemID.
//
// NULL handling on read:
//   - nullable text columns come back as "";
//   - nullable integer columns come back as nil pointers;
//   - ProvisionReused is true only when the stored value is 1.
//
// If no row matches, the returned error is sql.ErrNoRows itself, unwrapped,
// so callers that compare with == as well as errors.Is keep working. Any other
// query or scan failure is wrapped with the item_id, matching the error style
// of the repository's other methods.
func (r *ImportRunItemsRepo) GetByItemID(ctx context.Context, itemID string) (ImportRunItem, error) {
	itemID = strings.TrimSpace(itemID)
	if itemID == "" {
		return ImportRunItem{}, fmt.Errorf("item_id is required")
	}
	var item ImportRunItem
	var reusedFromAccountID sql.NullInt64
	var channelID sql.NullInt64
	var accountID sql.NullInt64
	var legacyBatchID sql.NullInt64
	var provisionReused int
	err := r.db.QueryRowContext(ctx, `SELECT item_id, run_id, base_url, provider_id, api_key_fingerprint, requested_models_json, raw_models_json, normalized_models_json, canonical_model_families_json, recommended_models_json, COALESCE(resolved_smoke_model, ''), capability_profile_json, current_stage, confirmation_status, access_status, matched_account_state, account_resolution, provision_reused, COALESCE(reused_from_provider_id, ''), reused_from_account_id, channel_id, account_id, retry_count, confirmation_attempts, COALESCE(last_retry_at, ''), COALESCE(next_retry_at, ''), COALESCE(lease_owner, ''), COALESCE(lease_until, ''), advisory_messages_json, COALESCE(last_error_stage, ''), COALESCE(last_error, ''), legacy_batch_id, COALESCE(legacy_provider_id, ''), created_at, updated_at FROM import_run_items WHERE item_id = ?`, itemID).Scan(
		&item.ItemID, &item.RunID, &item.BaseURL, &item.ProviderID, &item.APIKeyFingerprint,
		&item.RequestedModelsJSON, &item.RawModelsJSON, &item.NormalizedModelsJSON, &item.CanonicalFamiliesJSON, &item.RecommendedModelsJSON,
		&item.ResolvedSmokeModel, &item.CapabilityProfileJSON,
		&item.CurrentStage, &item.ConfirmationStatus, &item.AccessStatus, &item.MatchedAccountState, &item.AccountResolution,
		&provisionReused, &item.ReusedFromProviderID, &reusedFromAccountID, &channelID, &accountID,
		&item.RetryCount, &item.ConfirmationAttempts, &item.LastRetryAt, &item.NextRetryAt, &item.LeaseOwner, &item.LeaseUntil,
		&item.AdvisoryMessagesJSON, &item.LastErrorStage, &item.LastError,
		&legacyBatchID, &item.LegacyProviderID, &item.CreatedAt, &item.UpdatedAt,
	)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Keep the sentinel bare: callers may compare with ==.
			return ImportRunItem{}, err
		}
		return ImportRunItem{}, fmt.Errorf("get import run item %q: %w", itemID, err)
	}
	item.ProvisionReused = provisionReused == 1
	item.ReusedFromAccountID = ptrFromNullInt64(reusedFromAccountID)
	item.ChannelID = ptrFromNullInt64(channelID)
	item.AccountID = ptrFromNullInt64(accountID)
	item.LegacyBatchID = ptrFromNullInt64(legacyBatchID)
	return item, nil
}
// ListByRunID returns every import_run_items row whose run_id equals the
// trimmed runID, ordered by created_at and then item_id. A run with no items
// yields a non-nil empty slice.
//
// NULL handling on read:
//   - nullable text columns come back as "";
//   - nullable integer columns come back as nil pointers;
//   - ProvisionReused is true only when the stored value is 1.
func (r *ImportRunItemsRepo) ListByRunID(ctx context.Context, runID string) ([]ImportRunItem, error) {
	runID = strings.TrimSpace(runID)
	if runID == "" {
		return nil, fmt.Errorf("run_id is required")
	}

	rows, err := r.db.QueryContext(ctx, `SELECT item_id, run_id, base_url, provider_id, api_key_fingerprint, requested_models_json, raw_models_json, normalized_models_json, canonical_model_families_json, recommended_models_json, COALESCE(resolved_smoke_model, ''), capability_profile_json, current_stage, confirmation_status, access_status, matched_account_state, account_resolution, provision_reused, COALESCE(reused_from_provider_id, ''), reused_from_account_id, channel_id, account_id, retry_count, confirmation_attempts, COALESCE(last_retry_at, ''), COALESCE(next_retry_at, ''), COALESCE(lease_owner, ''), COALESCE(lease_until, ''), advisory_messages_json, COALESCE(last_error_stage, ''), COALESCE(last_error, ''), legacy_batch_id, COALESCE(legacy_provider_id, ''), created_at, updated_at FROM import_run_items WHERE run_id = ? ORDER BY created_at, item_id`, runID)
	if err != nil {
		return nil, fmt.Errorf("list import run items by run_id %q: %w", runID, err)
	}
	defer rows.Close()

	result := []ImportRunItem{}
	for rows.Next() {
		var row ImportRunItem
		var reusedAccount, channel, account, legacyBatch sql.NullInt64
		var reusedFlag int

		scanErr := rows.Scan(
			&row.ItemID, &row.RunID, &row.BaseURL, &row.ProviderID, &row.APIKeyFingerprint,
			&row.RequestedModelsJSON, &row.RawModelsJSON, &row.NormalizedModelsJSON, &row.CanonicalFamiliesJSON, &row.RecommendedModelsJSON,
			&row.ResolvedSmokeModel, &row.CapabilityProfileJSON,
			&row.CurrentStage, &row.ConfirmationStatus, &row.AccessStatus, &row.MatchedAccountState, &row.AccountResolution,
			&reusedFlag, &row.ReusedFromProviderID, &reusedAccount, &channel, &account,
			&row.RetryCount, &row.ConfirmationAttempts, &row.LastRetryAt, &row.NextRetryAt, &row.LeaseOwner, &row.LeaseUntil,
			&row.AdvisoryMessagesJSON, &row.LastErrorStage, &row.LastError,
			&legacyBatch, &row.LegacyProviderID, &row.CreatedAt, &row.UpdatedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("scan import run item: %w", scanErr)
		}

		row.ProvisionReused = reusedFlag == 1
		row.ReusedFromAccountID = ptrFromNullInt64(reusedAccount)
		row.ChannelID = ptrFromNullInt64(channel)
		row.AccountID = ptrFromNullInt64(account)
		row.LegacyBatchID = ptrFromNullInt64(legacyBatch)
		result = append(result, row)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate import run items by run_id %q: %w", runID, iterErr)
	}
	return result, nil
}
// TryAcquireConfirmationLease claims itemID for workerID with a single UPDATE.
//
// The claim succeeds only when all of these hold:
//   - the row's current_stage is 'confirm';
//   - its confirmation_status is 'pending';
//   - next_retry_at is unset or not later than now;
//   - any existing lease has expired (lease_until is unset or earlier than now).
//
// On success, lease_owner and lease_until are set, confirmation_attempts is
// incremented, updated_at is refreshed, and the re-read row is returned with
// true. If the row does not qualify or does not exist, it returns a zero item,
// false and a nil error.
//
// Timestamps are written and compared as RFC 3339 UTC strings with second
// precision. That fixed-width format sorts lexicographically in time order,
// which is what the SQL string comparisons rely on. next_retry_at values
// written elsewhere must use the same format. Because of the second
// precision, the effective lease can be up to one second shorter than
// leaseDuration.
//
// Defaults: a non-positive leaseDuration falls back to one minute, and a zero
// now falls back to the current time.
//
// If re-reading the row fails after a successful claim, the lease stays held
// until lease_until passes.
func (r *ImportRunItemsRepo) TryAcquireConfirmationLease(ctx context.Context, itemID, workerID string, now time.Time, leaseDuration time.Duration) (ImportRunItem, bool, error) {
	itemID = strings.TrimSpace(itemID)
	workerID = strings.TrimSpace(workerID)
	if itemID == "" {
		return ImportRunItem{}, false, fmt.Errorf("item_id is required")
	}
	if workerID == "" {
		return ImportRunItem{}, false, fmt.Errorf("worker_id is required")
	}
	if leaseDuration <= 0 {
		leaseDuration = time.Minute
	}
	if now.IsZero() {
		// A zero time would write a lease dated year 1. Every real-time caller
		// would see that lease as already expired, defeating mutual exclusion.
		now = time.Now()
	}
	nowUTC := now.UTC()
	nowText := nowUTC.Format(time.RFC3339)
	leaseUntil := nowUTC.Add(leaseDuration).Format(time.RFC3339)

	result, err := r.db.ExecContext(ctx, `UPDATE import_run_items
SET lease_owner = ?, lease_until = ?, confirmation_attempts = confirmation_attempts + 1, updated_at = CURRENT_TIMESTAMP
WHERE item_id = ?
AND current_stage = 'confirm'
AND confirmation_status = 'pending'
AND (next_retry_at IS NULL OR next_retry_at = '' OR next_retry_at <= ?)
AND (lease_until IS NULL OR lease_until = '' OR lease_until < ?)`,
		workerID, leaseUntil, itemID, nowText, nowText)
	if err != nil {
		return ImportRunItem{}, false, fmt.Errorf("acquire confirmation lease for %q: %w", itemID, err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return ImportRunItem{}, false, fmt.Errorf("acquire confirmation lease for %q: rows affected: %w", itemID, err)
	}
	if affected == 0 {
		// Not claimable right now: wrong stage/status, retry not due, or leased.
		return ImportRunItem{}, false, nil
	}
	item, err := r.GetByItemID(ctx, itemID)
	if err != nil {
		return ImportRunItem{}, false, err
	}
	return item, true, nil
}
// boolToInt encodes value as an integer flag, 1 for true and 0 for false.
// Upsert uses it for the provision_reused column.
func boolToInt(value bool) int {
	result := 0
	if value {
		result = 1
	}
	return result
}
// ptrFromNullInt64 converts a nullable integer column value into a pointer.
// It returns nil when value is NULL (Valid is false); otherwise it returns a
// pointer to a fresh copy of value.Int64.
func ptrFromNullInt64(value sql.NullInt64) *int64 {
	if value.Valid {
		copied := value.Int64
		return &copied
	}
	return nil
}