package sqlite

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"path/filepath"
	"testing"
)

// openTestDB creates a test database with foreign keys disabled.
//
// The database file lives in t.TempDir() and the store is closed through
// t.Cleanup. The test fails immediately if Open returns an error.
func openTestDB(t *testing.T) *DB {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "test.db")
	dsn := "file:" + filepath.ToSlash(dbPath) + "?_pragma=foreign_keys(0)"
	store, err := Open(context.Background(), dsn)
	if err != nil {
		t.Fatalf("Open() error = %v", err)
	}
	t.Cleanup(func() { store.Close() })
	return store
}

// openTestDBWithFK creates a test database with foreign keys enforced.
//
// Unlike openTestDB, the DSN carries no foreign_keys(0) pragma, so
// enforcement relies on whatever Open configures by default.
func openTestDBWithFK(t *testing.T) *DB {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "test-fk.db")
	dsn := "file:" + filepath.ToSlash(dbPath)
	store, err := Open(context.Background(), dsn)
	if err != nil {
		t.Fatalf("Open() error = %v", err)
	}
	t.Cleanup(func() { store.Close() })
	return store
}

// createTestPack inserts a pack whose PackID is derived from the test name
// and returns the id reported by Create. It fails the test on error.
func createTestPack(t *testing.T, store *DB) int64 {
	t.Helper()
	id, err := store.Packs().Create(context.Background(), Pack{
		PackID:   "pack-" + sanitizeTestName(t.Name()),
		Version:  "1.0.0",
		Checksum: "chk",
	})
	if err != nil {
		t.Fatalf("createTestPack error = %v", err)
	}
	return id
}

// createTestHost inserts a host whose HostID is
// "host-" + sanitizeTestName(t.Name()) and returns the id reported by
// Create. It fails the test on error.
func createTestHost(t *testing.T, store *DB) int64 {
	t.Helper()
	id, err := store.Hosts().Create(context.Background(), Host{
		HostID:      "host-" + sanitizeTestName(t.Name()),
		BaseURL:     "https://h.com",
		HostVersion: "0.1.0",
	})
	if err != nil {
		t.Fatalf("createTestHost error = %v", err)
	}
	return id
}

// createTestHostWithBaseURL inserts a host with the caller-supplied hostID
// and baseURL and returns the id reported by Create. It fails the test on
// error.
func createTestHostWithBaseURL(t *testing.T, store *DB, hostID, baseURL string) int64 {
	t.Helper()
	id, err := store.Hosts().Create(context.Background(), Host{
		HostID:      hostID,
		BaseURL:     baseURL,
		HostVersion: "0.1.0",
	})
	if err != nil {
		t.Fatalf("createTestHostWithBaseURL error = %v", err)
	}
	return id
}

// createTestBatch inserts a host, a pack and a provider, then creates an
// import batch referencing all three and returns the batch id. It fails the
// test if the provider or the batch cannot be created.
func createTestBatch(t *testing.T, store *DB) int64 {
	t.Helper()
	ctx := context.Background()
	hostID := createTestHost(t, store)
	packID := createTestPack(t, store)
	providerID, err := store.Providers().Create(ctx, Provider{
		PackID:      packID,
		ProviderID:  "test-provider",
		DisplayName: "Test",
		BaseURL:     "https://t.com",
		Platform:    "openai",
	})
	if err != nil {
		t.Fatalf("createTestBatch create provider error = %v", err)
	}
	id, err := store.ImportBatches().Create(ctx, ImportBatch{
		HostID:       hostID,
		PackID:       packID,
		ProviderID:   providerID,
		Mode:         "partial",
		BatchStatus:  "running",
		AccessStatus: "pending",
	})
	if err != nil {
		t.Fatalf("createTestBatch error = %v", err)
	}
	return id
}

// createTestBatchItem inserts a pending item into the given batch and
// returns the id reported by Create. It fails the test on error.
func createTestBatchItem(t *testing.T, store *DB, batchID int64) int64 {
	t.Helper()
	id, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
		BatchID:        batchID,
		KeyFingerprint: "sha256:test",
		AccountStatus:  "pending",
	})
	if err != nil {
		t.Fatalf("createTestBatchItem error = %v", err)
	}
	return id
}

// sanitizeTestName keeps only lowercase ASCII letters, digits and '-' from
// name. Every other rune (including uppercase letters and the '/' found in
// subtest names) is dropped. It returns "default" when nothing remains.
func sanitizeTestName(name string) string {
	// Every accepted rune is a single ASCII byte, so a byte slice sized to
	// the input avoids the repeated reallocation of string concatenation.
	kept := make([]byte, 0, len(name))
	for _, c := range name {
		if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' {
			kept = append(kept, byte(c))
		}
	}
	if len(kept) == 0 {
		return "default"
	}
	return string(kept)
}

// --- Hosts Repo Tests ---

// TestHostsRepoCreateAndGet creates a host and reads it back through both
// GetByID and GetByHostID.
func TestHostsRepoCreateAndGet(t *testing.T) {
	store := openTestDB(t)
	ctx := context.Background()

	id, err := store.Hosts().Create(ctx, Host{
		HostID:              "host-1",
		BaseURL:             "https://sub2api.example.com",
		HostVersion:         "0.1.126",
		CapabilityProbeJSON: `{"groups":true}`,
	})
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if id <= 0 {
		t.Fatalf("Create() id = %d, want positive", id)
	}

	got, err := store.Hosts().GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got.HostID != "host-1" || got.BaseURL != "https://sub2api.example.com" {
		t.Fatalf("GetByID() = %+v, want host-1", got)
	}

	got2, err := store.Hosts().GetByHostID(ctx, "host-1")
	if err != nil {
		t.Fatalf("GetByHostID() error = %v", err)
	}
	if got2.ID != id {
		t.Fatalf("GetByHostID() id = %d, want %d", got2.ID, id)
	}
}

// TestHostsRepoCreateDefaultsCapabilityProbe expects a host created without
// CapabilityProbeJSON to be read back with the value "{}".
func TestHostsRepoCreateDefaultsCapabilityProbe(t *testing.T) {
	store := openTestDB(t)
	ctx := context.Background()

	// Check both errors so a storage failure is reported as such rather than
	// surfacing as a misleading CapabilityProbeJSON mismatch.
	id, err := store.Hosts().Create(ctx, Host{
		HostID:      "host-empty",
		BaseURL:     "https://example.com",
		HostVersion: "0.1.0",
	})
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	got, err := store.Hosts().GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got.CapabilityProbeJSON != "{}" {
		t.Fatalf("CapabilityProbeJSON = %q, want {}", got.CapabilityProbeJSON)
	}
}

// TestHostsRepoValidationErrors expects Create to reject hosts that are
// missing HostID, BaseURL or HostVersion.
func TestHostsRepoValidationErrors(t *testing.T) {
	store := openTestDB(t)
	for _, tt := range []struct {
		name string
		host Host
	}{
		{"empty host_id", Host{BaseURL: "b", HostVersion: "v"}},
		{"empty base_url", Host{HostID: "h", HostVersion: "v"}},
		{"empty host_version", Host{HostID: "h", BaseURL: "b"}},
	} {
		t.Run(tt.name, func(t *testing.T) {
			_, err := store.Hosts().Create(context.Background(), tt.host)
			if err == nil {
				t.Fatal("Create() error = nil, want validation error")
			}
		})
	}
}

// TestHostsRepoGetByIDZeroError expects GetByID to return an error for id 0.
func TestHostsRepoGetByIDZeroError(t *testing.T) {
	store := openTestDB(t)
	_, err := store.Hosts().GetByID(context.Background(), 0)
	if err == nil {
		t.Fatal("GetByID(0) error = nil, want error")
	}
}

// TestHostsRepoGetByIDNotFound expects GetByID to return an error matching
// sql.ErrNoRows for an id that was never inserted.
func TestHostsRepoGetByIDNotFound(t *testing.T) {
	store := openTestDB(t)
	_, err := store.Hosts().GetByID(context.Background(), 999)
	if !errors.Is(err, sql.ErrNoRows) {
		t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
	}
}

// TestHostsRepoGetByHostIDEmptyError expects GetByHostID to return an error
// for an empty host id.
func TestHostsRepoGetByHostIDEmptyError(t *testing.T) {
	store := openTestDB(t)
	_, err := store.Hosts().GetByHostID(context.Background(), "")
	if err == nil {
		t.Fatal("GetByHostID('') error = nil, want error")
	}
}

// TestHostsRepoGetByHostIDNotFound expects GetByHostID to return an error
// matching sql.ErrNoRows for an unknown host id.
func TestHostsRepoGetByHostIDNotFound(t *testing.T) {
	store := openTestDB(t)
	_, err := store.Hosts().GetByHostID(context.Background(), "nonexistent")
	if !errors.Is(err, sql.ErrNoRows) {
		t.Fatalf("GetByHostID('nonexistent') error = %v, want sql.ErrNoRows", err)
	}
}

// TestHostsRepoListAll expects ListAll to return no hosts on a fresh
// database and exactly two after two hosts have been created.
func TestHostsRepoListAll(t *testing.T) {
	store := openTestDB(t)
	ctx := context.Background()

	hosts, err := store.Hosts().ListAll(ctx)
	if err != nil {
		t.Fatalf("ListAll() on empty DB error = %v", err)
	}
	if len(hosts) != 0 {
		t.Fatalf("ListAll() len = %d, want 0", len(hosts))
	}

	for i := 0; i < 2; i++ {
		_, err := store.Hosts().Create(ctx, Host{
			HostID:      fmt.Sprintf("host-listall-%d", i),
			BaseURL:     "https://h.com",
			HostVersion: "0.1.0",
		})
		if err != nil {
			t.Fatalf("Create() error = %v", err)
		}
	}

	hosts, err = store.Hosts().ListAll(ctx)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if len(hosts) != 2 {
		t.Fatalf("ListAll() len = %d, want 2", len(hosts))
	}
}

// TestHostsRepoDeleteByHostID creates one host, deletes it by host id and
// expects ListAll to return nothing afterwards.
func TestHostsRepoDeleteByHostID(t *testing.T) {
	store := openTestDB(t)
	ctx := context.Background()
	createTestHost(t, store)
	// Same id that createTestHost derives from the test name.
	hostID := "host-" + sanitizeTestName(t.Name())

	if err := store.Hosts().DeleteByHostID(ctx, hostID); err != nil {
		t.Fatalf("DeleteByHostID() error = %v", err)
	}

	hosts, err := store.Hosts().ListAll(ctx)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if len(hosts) != 0 {
		t.Fatalf("ListAll() after delete len = %d, want 0", len(hosts))
	}
}

// TestHostsRepoUpdateProbeByHostID updates a host's version and capability
// probe and expects GetByHostID to return the new values.
func TestHostsRepoUpdateProbeByHostID(t *testing.T) {
	store := openTestDB(t)
	ctx := context.Background()
	createTestHost(t, store)
	// Same id that createTestHost derives from the test name.
	hostID := "host-" + sanitizeTestName(t.Name())

	if err := store.Hosts().UpdateProbeByHostID(ctx, hostID, "0.2.0", `{"groups":true}`); err != nil {
		t.Fatalf("UpdateProbeByHostID() error = %v", err)
	}

	host, err := store.Hosts().GetByHostID(ctx, hostID)
	if err != nil {
		t.Fatalf("GetByHostID() error = %v", err)
	}
	if host.HostVersion != "0.2.0" || host.CapabilityProbeJSON != `{"groups":true}` {
		t.Fatalf("updated host = %+v, want version/capability update", host)
	}
}

// TestHostsRepoDeleteByHostIDNotFound expects DeleteByHostID to fail with
// the exact message `host "nonexistent" not found` for an unknown host id.
func TestHostsRepoDeleteByHostIDNotFound(t *testing.T) {
	store := openTestDB(t)
	err := store.Hosts().DeleteByHostID(context.Background(), "nonexistent")
	if err == nil {
		t.Fatal("DeleteByHostID('nonexistent') error = nil, want error")
	}
	if err.Error() != `host "nonexistent" not found` {
		t.Fatalf("DeleteByHostID() error = %q, want not found error", err)
	}
}

// TestHostsRepoDeleteByHostIDEmptyError expects DeleteByHostID to return an
// error for an empty host id.
func TestHostsRepoDeleteByHostIDEmptyError(t *testing.T) {
	store := openTestDB(t)
	err := store.Hosts().DeleteByHostID(context.Background(), "")
	if err == nil {
		t.Fatal("DeleteByHostID('') error = nil, want error")
	}
}