package sqlite

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"path/filepath"
	"strings"
	"testing"
)

// openTestDB creates a test database with foreign keys disabled.
//
// The database file lives in t.TempDir() and the store is closed through
// t.Cleanup, so callers do not need to release anything themselves.
func openTestDB(t *testing.T) *DB {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "test.db")
	dsn := "file:" + filepath.ToSlash(dbPath) + "?_busy_timeout=5000&_pragma=foreign_keys(0)"
	store, err := Open(context.Background(), dsn)
	if err != nil {
		t.Fatalf("Open() error = %v", err)
	}
	// Close errors are deliberately ignored: the test has already finished
	// and the temporary directory is removed regardless.
	t.Cleanup(func() { _ = store.Close() })
	return store
}

// openTestDBWithFK creates a test database with foreign keys enforced.
//
// Unlike openTestDB, the DSN carries no foreign_keys(0) pragma, so whatever
// foreign-key behaviour Open configures by default stays in effect.
func openTestDBWithFK(t *testing.T) *DB {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "test-fk.db")
	dsn := "file:" + filepath.ToSlash(dbPath) + "?_busy_timeout=5000"
	store, err := Open(context.Background(), dsn)
	if err != nil {
		t.Fatalf("Open() error = %v", err)
	}
	// Close errors are deliberately ignored: the test has already finished
	// and the temporary directory is removed regardless.
	t.Cleanup(func() { _ = store.Close() })
	return store
}

// createTestPack inserts a pack whose PackID is derived from the test name
// and returns the ID reported by Packs().Create. It fails the test on error.
func createTestPack(t *testing.T, store *DB) int64 {
	t.Helper()
	id, err := store.Packs().Create(context.Background(), Pack{
		PackID:   "pack-" + sanitizeTestName(t.Name()),
		Version:  "1.0.0",
		Checksum: "chk",
	})
	if err != nil {
		t.Fatalf("createTestPack error = %v", err)
	}
	return id
}

// createTestHost inserts a host whose HostID is "host-" followed by the
// sanitized test name and returns the ID reported by Hosts().Create. It fails
// the test on error.
func createTestHost(t *testing.T, store *DB) int64 {
	t.Helper()
	id, err := store.Hosts().Create(context.Background(), Host{
		HostID:      "host-" + sanitizeTestName(t.Name()),
		BaseURL:     "https://h.com",
		HostVersion: "0.1.0",
	})
	if err != nil {
		t.Fatalf("createTestHost error = %v", err)
	}
	return id
}

// createTestHostWithBaseURL inserts a host with the given hostID and baseURL
// and returns the ID reported by Hosts().Create. It fails the test on error.
func createTestHostWithBaseURL(t *testing.T, store *DB, hostID, baseURL string) int64 {
	t.Helper()
	id, err := store.Hosts().Create(context.Background(), Host{
		HostID:      hostID,
		BaseURL:     baseURL,
		HostVersion: "0.1.0",
	})
	if err != nil {
		t.Fatalf("createTestHostWithBaseURL error = %v", err)
	}
	return id
}

// createTestBatch creates a host, a pack, and a provider, then an import
// batch that references all three. It returns the batch ID and fails the test
// on any error.
func createTestBatch(t *testing.T, store *DB) int64 {
	t.Helper()
	hostID := createTestHost(t, store)
	packID := createTestPack(t, store)
	providerID, err := store.Providers().Create(context.Background(), Provider{
		PackID:      packID,
		ProviderID:  "test-provider",
		DisplayName: "Test",
		BaseURL:     "https://t.com",
		Platform:    "openai",
	})
	if err != nil {
		t.Fatalf("createTestBatch create provider error = %v", err)
	}
	id, err := store.ImportBatches().Create(context.Background(), ImportBatch{
		HostID:       hostID,
		PackID:       packID,
		ProviderID:   providerID,
		Mode:         "partial",
		BatchStatus:  "running",
		AccessStatus: "pending",
	})
	if err != nil {
		t.Fatalf("createTestBatch error = %v", err)
	}
	return id
}

// createTestBatchItem inserts a pending item into the given batch and returns
// the ID reported by ImportBatchItems().Create. It fails the test on error.
func createTestBatchItem(t *testing.T, store *DB, batchID int64) int64 {
	t.Helper()
	id, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
		BatchID:        batchID,
		KeyFingerprint: "sha256:test",
		AccountStatus:  "pending",
	})
	if err != nil {
		t.Fatalf("createTestBatchItem error = %v", err)
	}
	return id
}

// sanitizeTestName reduces a test name to the characters a-z, 0-9 and '-'.
//
// Every other character is dropped; in particular uppercase letters are
// removed rather than lowercased. If nothing survives, "default" is returned.
func sanitizeTestName(name string) string {
	var b strings.Builder
	b.Grow(len(name))
	for _, c := range name {
		if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' {
			b.WriteRune(c)
		}
	}
	if b.Len() == 0 {
		return "default"
	}
	return b.String()
}

// --- Hosts Repo Tests ---

// TestHostsRepoCreateAndGet checks that a created host can be read back both
// by row ID and by host ID.
func TestHostsRepoCreateAndGet(t *testing.T) {
	store := openTestDB(t)
	id, err := store.Hosts().Create(context.Background(), Host{
		HostID:              "host-1",
		BaseURL:             "https://sub2api.example.com",
		HostVersion:         "0.1.126",
		CapabilityProbeJSON: `{"groups":true}`,
	})
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if id <= 0 {
		t.Fatalf("Create() id = %d, want positive", id)
	}

	got, err := store.Hosts().GetByID(context.Background(), id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got.HostID != "host-1" || got.BaseURL != "https://sub2api.example.com" {
		t.Fatalf("GetByID() = %+v, want host-1", got)
	}

	got2, err := store.Hosts().GetByHostID(context.Background(), "host-1")
	if err != nil {
		t.Fatalf("GetByHostID() error = %v", err)
	}
	if got2.ID != id {
		t.Fatalf("GetByHostID() id = %d, want %d", got2.ID, id)
	}
}

// TestHostsRepoCreateDefaultsCapabilityProbe checks that a host created
// without CapabilityProbeJSON is read back with "{}".
func TestHostsRepoCreateDefaultsCapabilityProbe(t *testing.T) {
	store := openTestDB(t)
	// Both errors are checked so that a storage failure is reported as such
	// instead of surfacing as a confusing assertion on a zero-value result.
	id, err := store.Hosts().Create(context.Background(), Host{
		HostID:      "host-empty",
		BaseURL:     "https://example.com",
		HostVersion: "0.1.0",
	})
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	got, err := store.Hosts().GetByID(context.Background(), id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got.CapabilityProbeJSON != "{}" {
		t.Fatalf("CapabilityProbeJSON = %q, want {}", got.CapabilityProbeJSON)
	}
}

// TestHostsRepoValidationErrors checks that Create rejects hosts missing any
// one of HostID, BaseURL, or HostVersion.
func TestHostsRepoValidationErrors(t *testing.T) {
	store := openTestDB(t)
	for _, tt := range []struct {
		name string
		host Host
	}{
		{"empty host_id", Host{BaseURL: "b", HostVersion: "v"}},
		{"empty base_url", Host{HostID: "h", HostVersion: "v"}},
		{"empty host_version", Host{HostID: "h", BaseURL: "b"}},
	} {
		t.Run(tt.name, func(t *testing.T) {
			_, err := store.Hosts().Create(context.Background(), tt.host)
			if err == nil {
				t.Fatal("Create() error = nil, want validation error")
			}
		})
	}
}

// TestHostsRepoGetByIDZeroError checks that GetByID rejects ID 0.
func TestHostsRepoGetByIDZeroError(t *testing.T) {
	store := openTestDB(t)
	_, err := store.Hosts().GetByID(context.Background(), 0)
	if err == nil {
		t.Fatal("GetByID(0) error = nil, want error")
	}
}

// TestHostsRepoGetByIDNotFound checks that GetByID reports sql.ErrNoRows for
// an ID with no row.
func TestHostsRepoGetByIDNotFound(t *testing.T) {
	store := openTestDB(t)
	_, err := store.Hosts().GetByID(context.Background(), 999)
	if !errors.Is(err, sql.ErrNoRows) {
		t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
	}
}

// TestHostsRepoGetByHostIDEmptyError checks that GetByHostID rejects an empty
// host ID.
func TestHostsRepoGetByHostIDEmptyError(t *testing.T) {
	store := openTestDB(t)
	_, err := store.Hosts().GetByHostID(context.Background(), "")
	if err == nil {
		t.Fatal("GetByHostID('') error = nil, want error")
	}
}

// TestHostsRepoGetByHostIDNotFound checks that GetByHostID reports
// sql.ErrNoRows for an unknown host ID.
func TestHostsRepoGetByHostIDNotFound(t *testing.T) {
	store := openTestDB(t)
	_, err := store.Hosts().GetByHostID(context.Background(), "nonexistent")
	if !errors.Is(err, sql.ErrNoRows) {
		t.Fatalf("GetByHostID('nonexistent') error = %v, want sql.ErrNoRows", err)
	}
}

// TestHostsRepoGetByBaseURL checks that a host can be looked up by its base
// URL.
func TestHostsRepoGetByBaseURL(t *testing.T) {
	store := openTestDB(t)
	createTestHostWithBaseURL(t, store, "host-base-url", "https://base-url.example.com")

	got, err := store.Hosts().GetByBaseURL(context.Background(), "https://base-url.example.com")
	if err != nil {
		t.Fatalf("GetByBaseURL() error = %v", err)
	}
	if got.HostID != "host-base-url" {
		t.Fatalf("GetByBaseURL() host_id = %q, want host-base-url", got.HostID)
	}
}

// TestHostsRepoGetByBaseURLErrors checks that GetByBaseURL rejects an empty
// URL and reports sql.ErrNoRows for an unknown one.
func TestHostsRepoGetByBaseURLErrors(t *testing.T) {
	store := openTestDB(t)
	if _, err := store.Hosts().GetByBaseURL(context.Background(), ""); err == nil {
		t.Fatal("GetByBaseURL(\"\") error = nil, want error")
	}
	if _, err := store.Hosts().GetByBaseURL(context.Background(), "https://missing.example.com"); !errors.Is(err, sql.ErrNoRows) {
		t.Fatalf("GetByBaseURL(missing) error = %v, want sql.ErrNoRows", err)
	}
}

// TestHostsRepoListAll checks that ListAll returns nothing on an empty
// database and returns every host after two have been created.
func TestHostsRepoListAll(t *testing.T) {
	store := openTestDB(t)

	hosts, err := store.Hosts().ListAll(context.Background())
	if err != nil {
		t.Fatalf("ListAll() on empty DB error = %v", err)
	}
	if len(hosts) != 0 {
		t.Fatalf("ListAll() len = %d, want 0", len(hosts))
	}

	for i := 0; i < 2; i++ {
		_, err := store.Hosts().Create(context.Background(), Host{
			HostID:      fmt.Sprintf("host-listall-%d", i),
			BaseURL:     "https://h.com",
			HostVersion: "0.1.0",
		})
		if err != nil {
			t.Fatalf("Create() error = %v", err)
		}
	}

	hosts, err = store.Hosts().ListAll(context.Background())
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if len(hosts) != 2 {
		t.Fatalf("ListAll() len = %d, want 2", len(hosts))
	}
}

// TestHostsRepoDeleteByHostID checks that deleting the only host leaves
// ListAll empty.
func TestHostsRepoDeleteByHostID(t *testing.T) {
	store := openTestDB(t)
	createTestHost(t, store)
	// Must match the HostID that createTestHost derives from the test name.
	hostID := "host-" + sanitizeTestName(t.Name())

	if err := store.Hosts().DeleteByHostID(context.Background(), hostID); err != nil {
		t.Fatalf("DeleteByHostID() error = %v", err)
	}

	hosts, err := store.Hosts().ListAll(context.Background())
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if len(hosts) != 0 {
		t.Fatalf("ListAll() after delete len = %d, want 0", len(hosts))
	}
}

// TestHostsRepoDeleteByHostIDBlocksWhenRuntimeStateExists checks that
// DeleteByHostID refuses to delete a host that still has an import batch, a
// managed resource, and a reconcile run, and that the returned
// *HostDeleteBlocker counts one of each.
func TestHostsRepoDeleteByHostIDBlocksWhenRuntimeStateExists(t *testing.T) {
	store := openTestDBWithFK(t)
	batchID := createTestBatch(t, store)
	hostRowID := mustHostRowIDForBatch(t, store, batchID)

	host, err := store.Hosts().GetByID(context.Background(), hostRowID)
	if err != nil {
		t.Fatalf("Hosts().GetByID() error = %v", err)
	}

	if _, err := store.ManagedResources().Create(context.Background(), ManagedResource{
		BatchID:        batchID,
		HostID:         host.ID,
		ResourceType:   "group",
		HostResourceID: "group_1",
		ResourceName:   "group",
	}); err != nil {
		t.Fatalf("ManagedResources().Create() error = %v", err)
	}

	providerID := mustProviderIDForBatch(t, store, batchID)
	if _, err := store.ReconcileRuns().Create(context.Background(), ReconcileRun{
		BatchID:     batchID,
		HostID:      host.ID,
		ProviderID:  providerID,
		Status:      "active",
		SummaryJSON: `{}`,
	}); err != nil {
		t.Fatalf("ReconcileRuns().Create() error = %v", err)
	}

	err = store.Hosts().DeleteByHostID(context.Background(), host.HostID)
	if err == nil {
		t.Fatal("DeleteByHostID() error = nil, want blocked error")
	}
	var blocker *HostDeleteBlocker
	if !errors.As(err, &blocker) {
		t.Fatalf("DeleteByHostID() error = %T %v, want HostDeleteBlocker", err, err)
	}
	if blocker.ImportBatchCount != 1 || blocker.ManagedResourceCount != 1 || blocker.ReconcileRunCount != 1 {
		t.Fatalf("blocker = %+v, want all dependency counts = 1", blocker)
	}
}

// TestHostsRepoUpdateProbeByHostID checks that UpdateProbeByHostID persists
// the new host version and capability probe JSON.
func TestHostsRepoUpdateProbeByHostID(t *testing.T) {
	store := openTestDB(t)
	createTestHost(t, store)
	// Must match the HostID that createTestHost derives from the test name.
	hostID := "host-" + sanitizeTestName(t.Name())

	if err := store.Hosts().UpdateProbeByHostID(context.Background(), hostID, "0.2.0", `{"groups":true}`); err != nil {
		t.Fatalf("UpdateProbeByHostID() error = %v", err)
	}

	host, err := store.Hosts().GetByHostID(context.Background(), hostID)
	if err != nil {
		t.Fatalf("GetByHostID() error = %v", err)
	}
	if host.HostVersion != "0.2.0" || host.CapabilityProbeJSON != `{"groups":true}` {
		t.Fatalf("updated host = %+v, want version/capability update", host)
	}
}

// TestHostsRepoUpdateConnectionByHostID checks that UpdateConnectionByHostID
// persists the new URL, version, and token, and that the empty capability
// probe and auth type arguments are read back as "{}" and "apikey".
func TestHostsRepoUpdateConnectionByHostID(t *testing.T) {
	store := openTestDB(t)
	createTestHost(t, store)
	// Must match the HostID that createTestHost derives from the test name.
	hostID := "host-" + sanitizeTestName(t.Name())

	if err := store.Hosts().UpdateConnectionByHostID(context.Background(), hostID, "https://updated.example.com", "0.3.0", "", "", "token-1"); err != nil {
		t.Fatalf("UpdateConnectionByHostID() error = %v", err)
	}

	host, err := store.Hosts().GetByHostID(context.Background(), hostID)
	if err != nil {
		t.Fatalf("GetByHostID() error = %v", err)
	}
	if host.BaseURL != "https://updated.example.com" {
		t.Fatalf("BaseURL = %q, want updated URL", host.BaseURL)
	}
	if host.HostVersion != "0.3.0" {
		t.Fatalf("HostVersion = %q, want 0.3.0", host.HostVersion)
	}
	if host.CapabilityProbeJSON != "{}" {
		t.Fatalf("CapabilityProbeJSON = %q, want {}", host.CapabilityProbeJSON)
	}
	if host.AuthType != "apikey" {
		t.Fatalf("AuthType = %q, want apikey", host.AuthType)
	}
	if host.AuthToken != "token-1" {
		t.Fatalf("AuthToken = %q, want token-1", host.AuthToken)
	}
}

// TestHostsRepoUpdateConnectionByHostIDErrors checks that
// UpdateConnectionByHostID rejects an empty host ID, an empty base URL, an
// empty host version, and a host ID with no matching row.
func TestHostsRepoUpdateConnectionByHostIDErrors(t *testing.T) {
	store := openTestDB(t)
	if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "", "https://example.com", "0.2.0", "{}", "apikey", "token"); err == nil {
		t.Fatal("UpdateConnectionByHostID() empty host_id error = nil")
	}
	if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "", "0.2.0", "{}", "apikey", "token"); err == nil {
		t.Fatal("UpdateConnectionByHostID() empty base_url error = nil")
	}
	if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "https://example.com", "", "{}", "apikey", "token"); err == nil {
		t.Fatal("UpdateConnectionByHostID() empty host_version error = nil")
	}
	if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "https://example.com", "0.2.0", "{}", "apikey", "token"); err == nil {
		t.Fatal("UpdateConnectionByHostID() missing host error = nil")
	}
}

// TestHostsRepoDeleteByHostIDNotFound checks that deleting an unknown host
// fails with the exact message `host "nonexistent" not found`.
func TestHostsRepoDeleteByHostIDNotFound(t *testing.T) {
	store := openTestDB(t)
	err := store.Hosts().DeleteByHostID(context.Background(), "nonexistent")
	if err == nil {
		t.Fatal("DeleteByHostID('nonexistent') error = nil, want error")
	}
	if err.Error() != `host "nonexistent" not found` {
		t.Fatalf("DeleteByHostID() error = %q, want not found error", err)
	}
}

// TestHostsRepoDeleteByHostIDEmptyError checks that DeleteByHostID rejects an
// empty host ID.
func TestHostsRepoDeleteByHostIDEmptyError(t *testing.T) {
	store := openTestDB(t)
	err := store.Hosts().DeleteByHostID(context.Background(), "")
	if err == nil {
		t.Fatal("DeleteByHostID('') error = nil, want error")
	}
}

// mustHostRowIDForBatch reads the host_id column of the import_batches row
// with the given ID directly through the underlying SQL handle. It fails the
// test if the query or scan fails.
func mustHostRowIDForBatch(t *testing.T, store *DB, batchID int64) int64 {
	t.Helper()
	var hostID int64
	if err := store.SQLDB().QueryRow(`SELECT host_id FROM import_batches WHERE id = ?`, batchID).Scan(&hostID); err != nil {
		t.Fatalf("query host_id for batch %d error = %v", batchID, err)
	}
	return hostID
}

// mustProviderIDForBatch reads the provider_id column of the import_batches
// row with the given ID directly through the underlying SQL handle. It fails
// the test if the query or scan fails.
func mustProviderIDForBatch(t *testing.T, store *DB, batchID int64) int64 {
	t.Helper()
	var providerID int64
	if err := store.SQLDB().QueryRow(`SELECT provider_id FROM import_batches WHERE id = ?`, batchID).Scan(&providerID); err != nil {
		t.Fatalf("query provider_id for batch %d error = %v", batchID, err)
	}
	return providerID
}