package sqlite import ( "context" "database/sql" "errors" "fmt" "io/fs" "sort" "strings" _ "modernc.org/sqlite" "sub2api-cn-relay-manager/internal/store/migrations" ) type execQuerier interface { ExecContext(context.Context, string, ...any) (sql.Result, error) QueryContext(context.Context, string, ...any) (*sql.Rows, error) QueryRowContext(context.Context, string, ...any) *sql.Row } type Queries struct { Hosts *HostsRepo Packs *PacksRepo Providers *ProvidersRepo ImportBatches *ImportBatchesRepo ImportBatchItems *ImportBatchItemsRepo ImportRuns *ImportRunsRepo ImportRunItems *ImportRunItemsRepo ImportRunEvents *ImportRunItemEventsRepo ManagedResources *ManagedResourcesRepo ProbeResults *ProbeResultsRepo AccessClosures *AccessClosureRecordsRepo ReconcileRuns *ReconcileRunsRepo } type DB struct { sqlDB *sql.DB queries *Queries } func Open(ctx context.Context, dsn string) (*DB, error) { sqlDB, err := sql.Open("sqlite", withForeignKeysEnabled(dsn)) if err != nil { return nil, fmt.Errorf("open sqlite database: %w", err) } if err := sqlDB.PingContext(ctx); err != nil { _ = sqlDB.Close() return nil, fmt.Errorf("ping sqlite database: %w", err) } if err := ensureForeignKeys(ctx, sqlDB); err != nil { _ = sqlDB.Close() return nil, err } if err := migrate(ctx, sqlDB); err != nil { _ = sqlDB.Close() return nil, err } return &DB{ sqlDB: sqlDB, queries: newQueries(sqlDB), }, nil } func (db *DB) Close() error { return db.sqlDB.Close() } func (db *DB) SQLDB() *sql.DB { return db.sqlDB } func (db *DB) Hosts() *HostsRepo { return db.queries.Hosts } func (db *DB) Packs() *PacksRepo { return db.queries.Packs } func (db *DB) Providers() *ProvidersRepo { return db.queries.Providers } func (db *DB) ImportBatches() *ImportBatchesRepo { return db.queries.ImportBatches } func (db *DB) ImportBatchItems() *ImportBatchItemsRepo { return db.queries.ImportBatchItems } func (db *DB) ImportRuns() *ImportRunsRepo { return db.queries.ImportRuns } func (db *DB) ImportRunItems() *ImportRunItemsRepo { return db.queries.ImportRunItems } func (db *DB) ImportRunEvents() *ImportRunItemEventsRepo { return db.queries.ImportRunEvents } func (db *DB) ImportRunItemEvents() *ImportRunItemEventsRepo { return db.queries.ImportRunEvents } func (db *DB) ManagedResources() *ManagedResourcesRepo { return db.queries.ManagedResources } func (db *DB) ProbeResults() *ProbeResultsRepo { return db.queries.ProbeResults } func (db *DB) AccessClosures() *AccessClosureRecordsRepo { return db.queries.AccessClosures } func (db *DB) ReconcileRuns() *ReconcileRunsRepo { return db.queries.ReconcileRuns } func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error { tx, err := db.sqlDB.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin sqlite transaction: %w", err) } queries := newQueries(tx) if err := fn(queries); err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { return errors.Join(err, fmt.Errorf("rollback sqlite transaction: %w", rollbackErr)) } return err } if err := tx.Commit(); err != nil { _ = tx.Rollback() return fmt.Errorf("commit sqlite transaction: %w", err) } return nil } func newQueries(db execQuerier) *Queries { return &Queries{ Hosts: newHostsRepo(db), Packs: newPacksRepo(db), Providers: newProvidersRepo(db), ImportBatches: newImportBatchesRepo(db), ImportBatchItems: newImportBatchItemsRepo(db), ImportRuns: newImportRunsRepo(db), ImportRunItems: newImportRunItemsRepo(db), ImportRunEvents: newImportRunItemEventsRepo(db), ManagedResources: newManagedResourcesRepo(db), ProbeResults: newProbeResultsRepo(db), AccessClosures: newAccessClosureRecordsRepo(db), ReconcileRuns: newReconcileRunsRepo(db), } } func migrate(ctx context.Context, db *sql.DB) error { migrationNames, err := migrationFileNames() if err != nil { return err } tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin sqlite migration transaction: %w", err) } if err := ensureMigrationLedger(ctx, tx); err != nil { return rollbackMigration(tx, err) } appliedMigrations, err := loadAppliedMigrations(ctx, tx) if err != nil { return rollbackMigration(tx, err) } if err := backfillLegacySchemaIfNeeded(ctx, tx, migrationNames, appliedMigrations); err != nil { return rollbackMigration(tx, err) } for _, name := range migrationNames { if appliedMigrations[name] { continue } migrationSQL, err := readMigration(name) if err != nil { return rollbackMigration(tx, err) } if _, err := tx.ExecContext(ctx, migrationSQL); err != nil { return rollbackMigration(tx, fmt.Errorf("apply sqlite migration %s: %w", name, err)) } if _, err := tx.ExecContext( ctx, "INSERT INTO schema_migrations (version) VALUES (?)", name, ); err != nil { return rollbackMigration(tx, fmt.Errorf("record sqlite migration %s: %w", name, err)) } } if err := tx.Commit(); err != nil { _ = tx.Rollback() return fmt.Errorf("commit sqlite migration transaction: %w", err) } return nil } func withForeignKeysEnabled(dsn string) string { const pragma = "_pragma=foreign_keys(1)" if strings.Contains(dsn, "?") { return dsn + "&" + pragma } return dsn + "?" + pragma } func ensureForeignKeys(ctx context.Context, db *sql.DB) error { if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { return fmt.Errorf("enable sqlite foreign keys: %w", err) } var enabled int if err := db.QueryRowContext(ctx, "PRAGMA foreign_keys").Scan(&enabled); err != nil { return fmt.Errorf("verify sqlite foreign keys: %w", err) } if enabled != 1 { return errors.New("sqlite foreign keys are disabled") } return nil } func ensureMigrationLedger(ctx context.Context, tx *sql.Tx) error { const createLedgerSQL = ` CREATE TABLE IF NOT EXISTS schema_migrations ( version TEXT PRIMARY KEY, applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP )` if _, err := tx.ExecContext(ctx, createLedgerSQL); err != nil { return fmt.Errorf("create schema_migrations table: %w", err) } return nil } func loadAppliedMigrations(ctx context.Context, tx *sql.Tx) (map[string]bool, error) { rows, err := tx.QueryContext(ctx, "SELECT version FROM schema_migrations") if err != nil { return nil, fmt.Errorf("query applied sqlite migrations: %w", err) } defer rows.Close() applied := make(map[string]bool) for rows.Next() { var version string if err := rows.Scan(&version); err != nil { return nil, fmt.Errorf("scan applied sqlite migration: %w", err) } applied[version] = true } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate applied sqlite migrations: %w", err) } return applied, nil } func migrationFileNames() ([]string, error) { entries, err := fs.ReadDir(migrations.Files, ".") if err != nil { return nil, fmt.Errorf("read embedded sqlite migrations: %w", err) } var names []string for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") { continue } names = append(names, entry.Name()) } sort.Strings(names) return names, nil } func readMigration(name string) (string, error) { data, err := fs.ReadFile(migrations.Files, name) if err != nil { return "", fmt.Errorf("read embedded migration %s: %w", name, err) } return string(data), nil } func rollbackMigration(tx *sql.Tx, err error) error { if rollbackErr := tx.Rollback(); rollbackErr != nil { return errors.Join(err, fmt.Errorf("rollback sqlite migration transaction: %w", rollbackErr)) } return err } func backfillLegacySchemaIfNeeded(ctx context.Context, tx *sql.Tx, migrationNames []string, appliedMigrations map[string]bool) error { if len(migrationNames) == 0 { return nil } if len(appliedMigrations) != 0 { return nil } firstMigration := migrationNames[0] if firstMigration != "0001_init.sql" { return nil } complete, partial, err := detectLegacy0001Schema(ctx, tx) if err != nil { return err } if partial { return errors.New("legacy sqlite schema is partially applied without schema_migrations") } if !complete { return nil } if _, err := tx.ExecContext( ctx, "INSERT INTO schema_migrations (version) VALUES (?)", firstMigration, ); err != nil { return fmt.Errorf("backfill sqlite migration %s: %w", firstMigration, err) } appliedMigrations[firstMigration] = true return nil } func detectLegacy0001Schema(ctx context.Context, tx *sql.Tx) (complete bool, partial bool, err error) { legacyTables := []string{"hosts", "packs", "providers"} existing := 0 for _, table := range legacyTables { found, err := tableExists(ctx, tx, table) if err != nil { return false, false, err } if found { existing++ } } switch existing { case 0: return false, false, nil case len(legacyTables): return true, false, nil default: return false, true, nil } } func tableExists(ctx context.Context, db execQuerier, table string) (bool, error) { var name string err := db.QueryRowContext( ctx, "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table, ).Scan(&name) if errors.Is(err, sql.ErrNoRows) { return false, nil } if err != nil { return false, fmt.Errorf("check sqlite table %s: %w", table, err) } return name == table, nil }