422 lines
13 KiB
Go
422 lines
13 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"path/filepath"
|
|
"testing"
|
|
)
|
|
|
|
// openTestDB creates a test database with foreign keys disabled.
|
|
func openTestDB(t *testing.T) *DB {
|
|
t.Helper()
|
|
dbPath := filepath.Join(t.TempDir(), "test.db")
|
|
dsn := "file:" + filepath.ToSlash(dbPath) + "?_busy_timeout=5000&_pragma=foreign_keys(0)"
|
|
store, err := Open(context.Background(), dsn)
|
|
if err != nil {
|
|
t.Fatalf("Open() error = %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = store.Close() })
|
|
return store
|
|
}
|
|
|
|
// openTestDBWithFK creates a test database with foreign keys enforced.
|
|
func openTestDBWithFK(t *testing.T) *DB {
|
|
t.Helper()
|
|
dbPath := filepath.Join(t.TempDir(), "test-fk.db")
|
|
dsn := "file:" + filepath.ToSlash(dbPath) + "?_busy_timeout=5000"
|
|
store, err := Open(context.Background(), dsn)
|
|
if err != nil {
|
|
t.Fatalf("Open() error = %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = store.Close() })
|
|
return store
|
|
}
|
|
|
|
func createTestPack(t *testing.T, store *DB) int64 {
|
|
t.Helper()
|
|
id, err := store.Packs().Create(context.Background(), Pack{
|
|
PackID: "pack-" + sanitizeTestName(t.Name()), Version: "1.0.0", Checksum: "chk",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("createTestPack error = %v", err)
|
|
}
|
|
return id
|
|
}
|
|
|
|
func createTestHost(t *testing.T, store *DB) int64 {
|
|
t.Helper()
|
|
id, err := store.Hosts().Create(context.Background(), Host{
|
|
HostID: "host-" + sanitizeTestName(t.Name()), BaseURL: "https://h.com", HostVersion: "0.1.0",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("createTestHost error = %v", err)
|
|
}
|
|
return id
|
|
}
|
|
|
|
func createTestHostWithBaseURL(t *testing.T, store *DB, hostID, baseURL string) int64 {
|
|
t.Helper()
|
|
id, err := store.Hosts().Create(context.Background(), Host{
|
|
HostID: hostID, BaseURL: baseURL, HostVersion: "0.1.0",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("createTestHostWithBaseURL error = %v", err)
|
|
}
|
|
return id
|
|
}
|
|
|
|
func createTestBatch(t *testing.T, store *DB) int64 {
|
|
t.Helper()
|
|
hostID := createTestHost(t, store)
|
|
packID := createTestPack(t, store)
|
|
providerID, err := store.Providers().Create(context.Background(), Provider{
|
|
PackID: packID, ProviderID: "test-provider", DisplayName: "Test",
|
|
BaseURL: "https://t.com", Platform: "openai",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("createTestBatch create provider error = %v", err)
|
|
}
|
|
id, err := store.ImportBatches().Create(context.Background(), ImportBatch{
|
|
HostID: hostID, PackID: packID, ProviderID: providerID,
|
|
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("createTestBatch error = %v", err)
|
|
}
|
|
return id
|
|
}
|
|
|
|
func createTestBatchItem(t *testing.T, store *DB, batchID int64) int64 {
|
|
t.Helper()
|
|
id, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
|
|
BatchID: batchID, KeyFingerprint: "sha256:test", AccountStatus: "pending",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("createTestBatchItem error = %v", err)
|
|
}
|
|
return id
|
|
}
|
|
|
|
// sanitizeTestName keeps only the characters a-z, 0-9 and '-' from name. All
// other characters are discarded, including upper-case letters, '/' and '_'.
// The result can be embedded in fixture identifiers such as host and pack
// IDs. It returns "default" when no character survives the filter.
func sanitizeTestName(name string) string {
	// Every kept rune is single-byte ASCII, so appending bytes to a slice
	// pre-sized to the input avoids the quadratic cost of repeated string
	// concatenation.
	kept := make([]byte, 0, len(name))
	for _, c := range name {
		if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' {
			kept = append(kept, byte(c))
		}
	}
	if len(kept) == 0 {
		return "default"
	}
	return string(kept)
}
|
|
|
|
// --- Hosts Repo Tests ---
|
|
|
|
func TestHostsRepoCreateAndGet(t *testing.T) {
|
|
store := openTestDB(t)
|
|
|
|
id, err := store.Hosts().Create(context.Background(), Host{
|
|
HostID: "host-1",
|
|
BaseURL: "https://sub2api.example.com",
|
|
HostVersion: "0.1.126",
|
|
CapabilityProbeJSON: `{"groups":true}`,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
if id <= 0 {
|
|
t.Fatalf("Create() id = %d, want positive", id)
|
|
}
|
|
|
|
got, err := store.Hosts().GetByID(context.Background(), id)
|
|
if err != nil {
|
|
t.Fatalf("GetByID() error = %v", err)
|
|
}
|
|
if got.HostID != "host-1" || got.BaseURL != "https://sub2api.example.com" {
|
|
t.Fatalf("GetByID() = %+v, want host-1", got)
|
|
}
|
|
|
|
got2, err := store.Hosts().GetByHostID(context.Background(), "host-1")
|
|
if err != nil {
|
|
t.Fatalf("GetByHostID() error = %v", err)
|
|
}
|
|
if got2.ID != id {
|
|
t.Fatalf("GetByHostID() id = %d, want %d", got2.ID, id)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoCreateDefaultsCapabilityProbe(t *testing.T) {
|
|
store := openTestDB(t)
|
|
id, _ := store.Hosts().Create(context.Background(), Host{
|
|
HostID: "host-empty", BaseURL: "https://example.com", HostVersion: "0.1.0",
|
|
})
|
|
got, _ := store.Hosts().GetByID(context.Background(), id)
|
|
if got.CapabilityProbeJSON != "{}" {
|
|
t.Fatalf("CapabilityProbeJSON = %q, want {}", got.CapabilityProbeJSON)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoValidationErrors(t *testing.T) {
|
|
store := openTestDB(t)
|
|
for _, tt := range []struct {
|
|
name string
|
|
host Host
|
|
}{
|
|
{"empty host_id", Host{BaseURL: "b", HostVersion: "v"}},
|
|
{"empty base_url", Host{HostID: "h", HostVersion: "v"}},
|
|
{"empty host_version", Host{HostID: "h", BaseURL: "b"}},
|
|
} {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, err := store.Hosts().Create(context.Background(), tt.host)
|
|
if err == nil {
|
|
t.Fatal("Create() error = nil, want validation error")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoGetByIDZeroError(t *testing.T) {
|
|
store := openTestDB(t)
|
|
_, err := store.Hosts().GetByID(context.Background(), 0)
|
|
if err == nil {
|
|
t.Fatal("GetByID(0) error = nil, want error")
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoGetByIDNotFound(t *testing.T) {
|
|
store := openTestDB(t)
|
|
_, err := store.Hosts().GetByID(context.Background(), 999)
|
|
if !errors.Is(err, sql.ErrNoRows) {
|
|
t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoGetByHostIDEmptyError(t *testing.T) {
|
|
store := openTestDB(t)
|
|
_, err := store.Hosts().GetByHostID(context.Background(), "")
|
|
if err == nil {
|
|
t.Fatal("GetByHostID('') error = nil, want error")
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoGetByHostIDNotFound(t *testing.T) {
|
|
store := openTestDB(t)
|
|
_, err := store.Hosts().GetByHostID(context.Background(), "nonexistent")
|
|
if !errors.Is(err, sql.ErrNoRows) {
|
|
t.Fatalf("GetByHostID('nonexistent') error = %v, want sql.ErrNoRows", err)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoGetByBaseURL(t *testing.T) {
|
|
store := openTestDB(t)
|
|
createTestHostWithBaseURL(t, store, "host-base-url", "https://base-url.example.com")
|
|
|
|
got, err := store.Hosts().GetByBaseURL(context.Background(), "https://base-url.example.com")
|
|
if err != nil {
|
|
t.Fatalf("GetByBaseURL() error = %v", err)
|
|
}
|
|
if got.HostID != "host-base-url" {
|
|
t.Fatalf("GetByBaseURL() host_id = %q, want host-base-url", got.HostID)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoGetByBaseURLErrors(t *testing.T) {
|
|
store := openTestDB(t)
|
|
|
|
if _, err := store.Hosts().GetByBaseURL(context.Background(), ""); err == nil {
|
|
t.Fatal("GetByBaseURL(\"\") error = nil, want error")
|
|
}
|
|
if _, err := store.Hosts().GetByBaseURL(context.Background(), "https://missing.example.com"); !errors.Is(err, sql.ErrNoRows) {
|
|
t.Fatalf("GetByBaseURL(missing) error = %v, want sql.ErrNoRows", err)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoListAll(t *testing.T) {
|
|
store := openTestDB(t)
|
|
|
|
hosts, err := store.Hosts().ListAll(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("ListAll() on empty DB error = %v", err)
|
|
}
|
|
if len(hosts) != 0 {
|
|
t.Fatalf("ListAll() len = %d, want 0", len(hosts))
|
|
}
|
|
|
|
for i := 0; i < 2; i++ {
|
|
_, err := store.Hosts().Create(context.Background(), Host{
|
|
HostID: fmt.Sprintf("host-listall-%d", i), BaseURL: "https://h.com", HostVersion: "0.1.0",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
}
|
|
|
|
hosts, err = store.Hosts().ListAll(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("ListAll() error = %v", err)
|
|
}
|
|
if len(hosts) != 2 {
|
|
t.Fatalf("ListAll() len = %d, want 2", len(hosts))
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoDeleteByHostID(t *testing.T) {
|
|
store := openTestDB(t)
|
|
createTestHost(t, store)
|
|
|
|
if err := store.Hosts().DeleteByHostID(context.Background(), "host-"+sanitizeTestName(t.Name())); err != nil {
|
|
t.Fatalf("DeleteByHostID() error = %v", err)
|
|
}
|
|
|
|
hosts, err := store.Hosts().ListAll(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("ListAll() error = %v", err)
|
|
}
|
|
if len(hosts) != 0 {
|
|
t.Fatalf("ListAll() after delete len = %d, want 0", len(hosts))
|
|
}
|
|
}
|
|
func TestHostsRepoDeleteByHostIDBlocksWhenRuntimeStateExists(t *testing.T) {
|
|
store := openTestDBWithFK(t)
|
|
batchID := createTestBatch(t, store)
|
|
|
|
hostRowID := mustHostRowIDForBatch(t, store, batchID)
|
|
host, err := store.Hosts().GetByID(context.Background(), hostRowID)
|
|
if err != nil {
|
|
t.Fatalf("Hosts().GetByID() error = %v", err)
|
|
}
|
|
if _, err := store.ManagedResources().Create(context.Background(), ManagedResource{
|
|
BatchID: batchID,
|
|
HostID: host.ID,
|
|
ResourceType: "group",
|
|
HostResourceID: "group_1",
|
|
ResourceName: "group",
|
|
}); err != nil {
|
|
t.Fatalf("ManagedResources().Create() error = %v", err)
|
|
}
|
|
providerID := mustProviderIDForBatch(t, store, batchID)
|
|
if _, err := store.ReconcileRuns().Create(context.Background(), ReconcileRun{
|
|
BatchID: batchID,
|
|
HostID: host.ID,
|
|
ProviderID: providerID,
|
|
Status: "active",
|
|
SummaryJSON: `{}`,
|
|
}); err != nil {
|
|
t.Fatalf("ReconcileRuns().Create() error = %v", err)
|
|
}
|
|
|
|
err = store.Hosts().DeleteByHostID(context.Background(), host.HostID)
|
|
if err == nil {
|
|
t.Fatal("DeleteByHostID() error = nil, want blocked error")
|
|
}
|
|
var blocker *HostDeleteBlocker
|
|
if !errors.As(err, &blocker) {
|
|
t.Fatalf("DeleteByHostID() error = %T %v, want HostDeleteBlocker", err, err)
|
|
}
|
|
if blocker.ImportBatchCount != 1 || blocker.ManagedResourceCount != 1 || blocker.ReconcileRunCount != 1 {
|
|
t.Fatalf("blocker = %+v, want all dependency counts = 1", blocker)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoUpdateProbeByHostID(t *testing.T) {
|
|
store := openTestDB(t)
|
|
createTestHost(t, store)
|
|
|
|
if err := store.Hosts().UpdateProbeByHostID(context.Background(), "host-"+sanitizeTestName(t.Name()), "0.2.0", `{"groups":true}`); err != nil {
|
|
t.Fatalf("UpdateProbeByHostID() error = %v", err)
|
|
}
|
|
|
|
host, err := store.Hosts().GetByHostID(context.Background(), "host-"+sanitizeTestName(t.Name()))
|
|
if err != nil {
|
|
t.Fatalf("GetByHostID() error = %v", err)
|
|
}
|
|
if host.HostVersion != "0.2.0" || host.CapabilityProbeJSON != `{"groups":true}` {
|
|
t.Fatalf("updated host = %+v, want version/capability update", host)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoUpdateConnectionByHostID(t *testing.T) {
|
|
store := openTestDB(t)
|
|
createTestHost(t, store)
|
|
|
|
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "host-"+sanitizeTestName(t.Name()), "https://updated.example.com", "0.3.0", "", "", "token-1"); err != nil {
|
|
t.Fatalf("UpdateConnectionByHostID() error = %v", err)
|
|
}
|
|
|
|
host, err := store.Hosts().GetByHostID(context.Background(), "host-"+sanitizeTestName(t.Name()))
|
|
if err != nil {
|
|
t.Fatalf("GetByHostID() error = %v", err)
|
|
}
|
|
if host.BaseURL != "https://updated.example.com" {
|
|
t.Fatalf("BaseURL = %q, want updated URL", host.BaseURL)
|
|
}
|
|
if host.HostVersion != "0.3.0" {
|
|
t.Fatalf("HostVersion = %q, want 0.3.0", host.HostVersion)
|
|
}
|
|
if host.CapabilityProbeJSON != "{}" {
|
|
t.Fatalf("CapabilityProbeJSON = %q, want {}", host.CapabilityProbeJSON)
|
|
}
|
|
if host.AuthType != "apikey" {
|
|
t.Fatalf("AuthType = %q, want apikey", host.AuthType)
|
|
}
|
|
if host.AuthToken != "token-1" {
|
|
t.Fatalf("AuthToken = %q, want token-1", host.AuthToken)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoUpdateConnectionByHostIDErrors(t *testing.T) {
|
|
store := openTestDB(t)
|
|
|
|
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "", "https://example.com", "0.2.0", "{}", "apikey", "token"); err == nil {
|
|
t.Fatal("UpdateConnectionByHostID() empty host_id error = nil")
|
|
}
|
|
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "", "0.2.0", "{}", "apikey", "token"); err == nil {
|
|
t.Fatal("UpdateConnectionByHostID() empty base_url error = nil")
|
|
}
|
|
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "https://example.com", "", "{}", "apikey", "token"); err == nil {
|
|
t.Fatal("UpdateConnectionByHostID() empty host_version error = nil")
|
|
}
|
|
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "https://example.com", "0.2.0", "{}", "apikey", "token"); err == nil {
|
|
t.Fatal("UpdateConnectionByHostID() missing host error = nil")
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoDeleteByHostIDNotFound(t *testing.T) {
|
|
store := openTestDB(t)
|
|
err := store.Hosts().DeleteByHostID(context.Background(), "nonexistent")
|
|
if err == nil {
|
|
t.Fatal("DeleteByHostID('nonexistent') error = nil, want error")
|
|
}
|
|
if err.Error() != `host "nonexistent" not found` {
|
|
t.Fatalf("DeleteByHostID() error = %q, want not found error", err)
|
|
}
|
|
}
|
|
|
|
func TestHostsRepoDeleteByHostIDEmptyError(t *testing.T) {
|
|
store := openTestDB(t)
|
|
err := store.Hosts().DeleteByHostID(context.Background(), "")
|
|
if err == nil {
|
|
t.Fatal("DeleteByHostID('') error = nil, want error")
|
|
}
|
|
}
|
|
func mustHostRowIDForBatch(t *testing.T, store *DB, batchID int64) int64 {
|
|
t.Helper()
|
|
var hostID int64
|
|
if err := store.SQLDB().QueryRow(`SELECT host_id FROM import_batches WHERE id = ?`, batchID).Scan(&hostID); err != nil {
|
|
t.Fatalf("query host_id for batch %d error = %v", batchID, err)
|
|
}
|
|
return hostID
|
|
}
|
|
|
|
func mustProviderIDForBatch(t *testing.T, store *DB, batchID int64) int64 {
|
|
t.Helper()
|
|
var providerID int64
|
|
if err := store.SQLDB().QueryRow(`SELECT provider_id FROM import_batches WHERE id = ?`, batchID).Scan(&providerID); err != nil {
|
|
t.Fatalf("query provider_id for batch %d error = %v", batchID, err)
|
|
}
|
|
return providerID
|
|
}
|