400 lines
9.2 KiB
Go
400 lines
9.2 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"io/fs"
|
|
"sort"
|
|
"strings"
|
|
|
|
_ "modernc.org/sqlite"
|
|
"sub2api-cn-relay-manager/internal/store/migrations"
|
|
)
|
|
|
|
// execQuerier is the minimal statement-execution surface the repositories
// need. Both *sql.DB and *sql.Tx satisfy it, which lets the same repository
// code run either directly against the connection pool or inside a
// transaction.
type execQuerier interface {
	QueryContext(context.Context, string, ...any) (*sql.Rows, error)
	QueryRowContext(context.Context, string, ...any) *sql.Row
	ExecContext(context.Context, string, ...any) (sql.Result, error)
}
|
|
|
|
type Queries struct {
|
|
Hosts *HostsRepo
|
|
Packs *PacksRepo
|
|
Providers *ProvidersRepo
|
|
ImportBatches *ImportBatchesRepo
|
|
ImportBatchItems *ImportBatchItemsRepo
|
|
ImportRuns *ImportRunsRepo
|
|
ImportRunItems *ImportRunItemsRepo
|
|
ImportRunEvents *ImportRunItemEventsRepo
|
|
ManagedResources *ManagedResourcesRepo
|
|
ProbeResults *ProbeResultsRepo
|
|
AccessClosures *AccessClosureRecordsRepo
|
|
ReconcileRuns *ReconcileRunsRepo
|
|
}
|
|
|
|
type DB struct {
|
|
sqlDB *sql.DB
|
|
queries *Queries
|
|
}
|
|
|
|
func Open(ctx context.Context, dsn string) (*DB, error) {
|
|
sqlDB, err := sql.Open("sqlite", withForeignKeysEnabled(dsn))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open sqlite database: %w", err)
|
|
}
|
|
|
|
if err := sqlDB.PingContext(ctx); err != nil {
|
|
_ = sqlDB.Close()
|
|
return nil, fmt.Errorf("ping sqlite database: %w", err)
|
|
}
|
|
|
|
if err := ensureForeignKeys(ctx, sqlDB); err != nil {
|
|
_ = sqlDB.Close()
|
|
return nil, err
|
|
}
|
|
|
|
if err := migrate(ctx, sqlDB); err != nil {
|
|
_ = sqlDB.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return &DB{
|
|
sqlDB: sqlDB,
|
|
queries: newQueries(sqlDB),
|
|
}, nil
|
|
}
|
|
|
|
func (db *DB) Close() error {
|
|
return db.sqlDB.Close()
|
|
}
|
|
|
|
func (db *DB) SQLDB() *sql.DB {
|
|
return db.sqlDB
|
|
}
|
|
|
|
func (db *DB) Hosts() *HostsRepo {
|
|
return db.queries.Hosts
|
|
}
|
|
|
|
func (db *DB) Packs() *PacksRepo {
|
|
return db.queries.Packs
|
|
}
|
|
|
|
func (db *DB) Providers() *ProvidersRepo {
|
|
return db.queries.Providers
|
|
}
|
|
|
|
func (db *DB) ImportBatches() *ImportBatchesRepo {
|
|
return db.queries.ImportBatches
|
|
}
|
|
|
|
func (db *DB) ImportBatchItems() *ImportBatchItemsRepo {
|
|
return db.queries.ImportBatchItems
|
|
}
|
|
|
|
func (db *DB) ImportRuns() *ImportRunsRepo {
|
|
return db.queries.ImportRuns
|
|
}
|
|
|
|
func (db *DB) ImportRunItems() *ImportRunItemsRepo {
|
|
return db.queries.ImportRunItems
|
|
}
|
|
|
|
func (db *DB) ImportRunEvents() *ImportRunItemEventsRepo {
|
|
return db.queries.ImportRunEvents
|
|
}
|
|
|
|
func (db *DB) ImportRunItemEvents() *ImportRunItemEventsRepo {
|
|
return db.queries.ImportRunEvents
|
|
}
|
|
|
|
func (db *DB) ManagedResources() *ManagedResourcesRepo {
|
|
return db.queries.ManagedResources
|
|
}
|
|
|
|
func (db *DB) ProbeResults() *ProbeResultsRepo {
|
|
return db.queries.ProbeResults
|
|
}
|
|
|
|
func (db *DB) AccessClosures() *AccessClosureRecordsRepo {
|
|
return db.queries.AccessClosures
|
|
}
|
|
|
|
func (db *DB) ReconcileRuns() *ReconcileRunsRepo {
|
|
return db.queries.ReconcileRuns
|
|
}
|
|
|
|
func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error {
|
|
tx, err := db.sqlDB.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("begin sqlite transaction: %w", err)
|
|
}
|
|
|
|
queries := newQueries(tx)
|
|
if err := fn(queries); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
return errors.Join(err, fmt.Errorf("rollback sqlite transaction: %w", rollbackErr))
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("commit sqlite transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func newQueries(db execQuerier) *Queries {
|
|
return &Queries{
|
|
Hosts: newHostsRepo(db),
|
|
Packs: newPacksRepo(db),
|
|
Providers: newProvidersRepo(db),
|
|
ImportBatches: newImportBatchesRepo(db),
|
|
ImportBatchItems: newImportBatchItemsRepo(db),
|
|
ImportRuns: newImportRunsRepo(db),
|
|
ImportRunItems: newImportRunItemsRepo(db),
|
|
ImportRunEvents: newImportRunItemEventsRepo(db),
|
|
ManagedResources: newManagedResourcesRepo(db),
|
|
ProbeResults: newProbeResultsRepo(db),
|
|
AccessClosures: newAccessClosureRecordsRepo(db),
|
|
ReconcileRuns: newReconcileRunsRepo(db),
|
|
}
|
|
}
|
|
|
|
func migrate(ctx context.Context, db *sql.DB) error {
|
|
migrationNames, err := migrationFileNames()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("begin sqlite migration transaction: %w", err)
|
|
}
|
|
|
|
if err := ensureMigrationLedger(ctx, tx); err != nil {
|
|
return rollbackMigration(tx, err)
|
|
}
|
|
|
|
appliedMigrations, err := loadAppliedMigrations(ctx, tx)
|
|
if err != nil {
|
|
return rollbackMigration(tx, err)
|
|
}
|
|
|
|
if err := backfillLegacySchemaIfNeeded(ctx, tx, migrationNames, appliedMigrations); err != nil {
|
|
return rollbackMigration(tx, err)
|
|
}
|
|
|
|
for _, name := range migrationNames {
|
|
if appliedMigrations[name] {
|
|
continue
|
|
}
|
|
|
|
migrationSQL, err := readMigration(name)
|
|
if err != nil {
|
|
return rollbackMigration(tx, err)
|
|
}
|
|
|
|
if _, err := tx.ExecContext(ctx, migrationSQL); err != nil {
|
|
return rollbackMigration(tx, fmt.Errorf("apply sqlite migration %s: %w", name, err))
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
"INSERT INTO schema_migrations (version) VALUES (?)",
|
|
name,
|
|
); err != nil {
|
|
return rollbackMigration(tx, fmt.Errorf("record sqlite migration %s: %w", name, err))
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("commit sqlite migration transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// withForeignKeysEnabled appends the query parameter
// "_pragma=foreign_keys(1)" to dsn. It starts a query string with '?' when
// dsn has none yet, and extends the existing one with '&' otherwise.
func withForeignKeysEnabled(dsn string) string {
	const pragma = "_pragma=foreign_keys(1)"

	sep := "?"
	if strings.ContainsRune(dsn, '?') {
		sep = "&"
	}
	return dsn + sep + pragma
}
|
|
|
|
// ensureForeignKeys issues "PRAGMA foreign_keys = ON" and then reads the
// pragma back. It fails if the pragma does not report 1, so Open refuses a
// database on which foreign-key enforcement could not be switched on.
func ensureForeignKeys(ctx context.Context, db *sql.DB) error {
	if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
		return fmt.Errorf("enable sqlite foreign keys: %w", err)
	}

	var enabled int
	switch err := db.QueryRowContext(ctx, "PRAGMA foreign_keys").Scan(&enabled); {
	case err != nil:
		return fmt.Errorf("verify sqlite foreign keys: %w", err)
	case enabled != 1:
		return errors.New("sqlite foreign keys are disabled")
	default:
		return nil
	}
}
|
|
|
|
// ensureMigrationLedger creates the schema_migrations table if it does not
// already exist. Each row records one applied migration: version holds the
// migration file name, and applied_at defaults to CURRENT_TIMESTAMP.
func ensureMigrationLedger(ctx context.Context, tx *sql.Tx) error {
	const createLedgerSQL = `
CREATE TABLE IF NOT EXISTS schema_migrations (
	version TEXT PRIMARY KEY,
	applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
)`

	_, err := tx.ExecContext(ctx, createLedgerSQL)
	if err == nil {
		return nil
	}
	return fmt.Errorf("create schema_migrations table: %w", err)
}
|
|
|
|
// loadAppliedMigrations reads every version recorded in schema_migrations
// and returns them as a set (each recorded version maps to true). The
// returned map is always non-nil on success, even when the ledger is empty.
func loadAppliedMigrations(ctx context.Context, tx *sql.Tx) (map[string]bool, error) {
	rows, err := tx.QueryContext(ctx, "SELECT version FROM schema_migrations")
	if err != nil {
		return nil, fmt.Errorf("query applied sqlite migrations: %w", err)
	}
	defer rows.Close()

	seen := map[string]bool{}
	for rows.Next() {
		var v string
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, fmt.Errorf("scan applied sqlite migration: %w", scanErr)
		}
		seen[v] = true
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate applied sqlite migrations: %w", iterErr)
	}
	return seen, nil
}
|
|
|
|
func migrationFileNames() ([]string, error) {
|
|
entries, err := fs.ReadDir(migrations.Files, ".")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read embedded sqlite migrations: %w", err)
|
|
}
|
|
|
|
var names []string
|
|
for _, entry := range entries {
|
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") {
|
|
continue
|
|
}
|
|
|
|
names = append(names, entry.Name())
|
|
}
|
|
|
|
sort.Strings(names)
|
|
return names, nil
|
|
}
|
|
|
|
func readMigration(name string) (string, error) {
|
|
data, err := fs.ReadFile(migrations.Files, name)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read embedded migration %s: %w", name, err)
|
|
}
|
|
|
|
return string(data), nil
|
|
}
|
|
|
|
// rollbackMigration rolls back the migration transaction and returns err.
// If the rollback fails as well, both errors are combined with errors.Join,
// so callers can still match the original cause with errors.Is/As.
func rollbackMigration(tx *sql.Tx, err error) error {
	rbErr := tx.Rollback()
	if rbErr == nil {
		return err
	}
	return errors.Join(err, fmt.Errorf("rollback sqlite migration transaction: %w", rbErr))
}
|
|
|
|
func backfillLegacySchemaIfNeeded(ctx context.Context, tx *sql.Tx, migrationNames []string, appliedMigrations map[string]bool) error {
|
|
if len(migrationNames) == 0 {
|
|
return nil
|
|
}
|
|
if len(appliedMigrations) != 0 {
|
|
return nil
|
|
}
|
|
|
|
firstMigration := migrationNames[0]
|
|
if firstMigration != "0001_init.sql" {
|
|
return nil
|
|
}
|
|
|
|
complete, partial, err := detectLegacy0001Schema(ctx, tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if partial {
|
|
return errors.New("legacy sqlite schema is partially applied without schema_migrations")
|
|
}
|
|
if !complete {
|
|
return nil
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
"INSERT INTO schema_migrations (version) VALUES (?)",
|
|
firstMigration,
|
|
); err != nil {
|
|
return fmt.Errorf("backfill sqlite migration %s: %w", firstMigration, err)
|
|
}
|
|
|
|
appliedMigrations[firstMigration] = true
|
|
return nil
|
|
}
|
|
|
|
func detectLegacy0001Schema(ctx context.Context, tx *sql.Tx) (complete bool, partial bool, err error) {
|
|
legacyTables := []string{"hosts", "packs", "providers"}
|
|
|
|
existing := 0
|
|
for _, table := range legacyTables {
|
|
found, err := tableExists(ctx, tx, table)
|
|
if err != nil {
|
|
return false, false, err
|
|
}
|
|
if found {
|
|
existing++
|
|
}
|
|
}
|
|
|
|
switch existing {
|
|
case 0:
|
|
return false, false, nil
|
|
case len(legacyTables):
|
|
return true, false, nil
|
|
default:
|
|
return false, true, nil
|
|
}
|
|
}
|
|
|
|
func tableExists(ctx context.Context, db execQuerier, table string) (bool, error) {
|
|
var name string
|
|
err := db.QueryRowContext(
|
|
ctx,
|
|
"SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?",
|
|
table,
|
|
).Scan(&name)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, fmt.Errorf("check sqlite table %s: %w", table, err)
|
|
}
|
|
|
|
return name == table, nil
|
|
}
|