Files
sub2api-cn-relay-manager/internal/store/sqlite/hosts_repo_test.go
2026-05-25 07:30:07 +08:00

422 lines
13 KiB
Go

package sqlite
import (
"context"
"database/sql"
"errors"
"fmt"
"path/filepath"
"testing"
)
// openTestDB opens a brand-new SQLite database file named test.db inside the
// test's temporary directory and arranges for it to be closed via t.Cleanup.
// The DSN sets a 5000 ms busy timeout and requests foreign_keys(0), i.e.
// foreign-key enforcement switched off. Any Open failure aborts the test.
//
// NOTE(review): `_busy_timeout` and `_pragma=...` are DSN spellings used by
// different SQLite drivers; confirm that the driver behind Open honours both.
func openTestDB(t *testing.T) *DB {
	t.Helper()
	path := filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))
	const params = "?_busy_timeout=5000&_pragma=foreign_keys(0)"
	db, err := Open(context.Background(), "file:"+path+params)
	if err != nil {
		t.Fatalf("Open() error = %v", err)
	}
	t.Cleanup(func() { _ = db.Close() })
	return db
}
// openTestDBWithFK opens a brand-new SQLite database file named test-fk.db in
// the test's temporary directory and closes it via t.Cleanup. Unlike
// openTestDB, its DSN carries no foreign_keys pragma, only the 5000 ms busy
// timeout. Its name says foreign keys are enforced; that relies on Open
// enabling them by default, which is not visible here (TODO confirm).
func openTestDBWithFK(t *testing.T) *DB {
	t.Helper()
	path := filepath.ToSlash(filepath.Join(t.TempDir(), "test-fk.db"))
	db, err := Open(context.Background(), "file:"+path+"?_busy_timeout=5000")
	if err != nil {
		t.Fatalf("Open() error = %v", err)
	}
	t.Cleanup(func() { _ = db.Close() })
	return db
}
// createTestPack inserts a pack (version "1.0.0", checksum "chk") whose PackID
// is "pack-" plus the sanitized current test name. It returns the new row ID
// and aborts the test if the insert fails.
func createTestPack(t *testing.T, store *DB) int64 {
	t.Helper()
	pack := Pack{
		PackID:   "pack-" + sanitizeTestName(t.Name()),
		Version:  "1.0.0",
		Checksum: "chk",
	}
	rowID, err := store.Packs().Create(context.Background(), pack)
	if err != nil {
		t.Fatalf("createTestPack error = %v", err)
	}
	return rowID
}
// createTestHost inserts a host whose HostID is "host-" plus the sanitized
// current test name, with base URL "https://h.com" and version "0.1.0".
// It returns the new row ID and aborts the test if the insert fails.
func createTestHost(t *testing.T, store *DB) int64 {
	t.Helper()
	host := Host{
		HostID:      "host-" + sanitizeTestName(t.Name()),
		BaseURL:     "https://h.com",
		HostVersion: "0.1.0",
	}
	rowID, err := store.Hosts().Create(context.Background(), host)
	if err != nil {
		t.Fatalf("createTestHost error = %v", err)
	}
	return rowID
}
// createTestHostWithBaseURL inserts a host with the caller-supplied HostID and
// BaseURL (version "0.1.0"). It returns the new row ID and aborts the test if
// the insert fails.
func createTestHostWithBaseURL(t *testing.T, store *DB, hostID, baseURL string) int64 {
	t.Helper()
	host := Host{
		HostID:      hostID,
		BaseURL:     baseURL,
		HostVersion: "0.1.0",
	}
	rowID, err := store.Hosts().Create(context.Background(), host)
	if err != nil {
		t.Fatalf("createTestHostWithBaseURL error = %v", err)
	}
	return rowID
}
// createTestBatch creates a host, a pack and a provider ("test-provider",
// platform "openai") belonging to that pack. It then inserts an import batch
// that references all three, with mode "partial", batch status "running" and
// access status "pending". It returns the batch row ID and aborts the test on
// any insert error.
func createTestBatch(t *testing.T, store *DB) int64 {
	t.Helper()
	ctx := context.Background()

	hostRowID := createTestHost(t, store)
	packRowID := createTestPack(t, store)

	provider := Provider{
		PackID:      packRowID,
		ProviderID:  "test-provider",
		DisplayName: "Test",
		BaseURL:     "https://t.com",
		Platform:    "openai",
	}
	providerRowID, err := store.Providers().Create(ctx, provider)
	if err != nil {
		t.Fatalf("createTestBatch create provider error = %v", err)
	}

	batch := ImportBatch{
		HostID:       hostRowID,
		PackID:       packRowID,
		ProviderID:   providerRowID,
		Mode:         "partial",
		BatchStatus:  "running",
		AccessStatus: "pending",
	}
	batchRowID, err := store.ImportBatches().Create(ctx, batch)
	if err != nil {
		t.Fatalf("createTestBatch error = %v", err)
	}
	return batchRowID
}
// createTestBatchItem inserts an import batch item for batchID with key
// fingerprint "sha256:test" and account status "pending". It returns the new
// row ID and aborts the test if the insert fails.
func createTestBatchItem(t *testing.T, store *DB, batchID int64) int64 {
	t.Helper()
	item := ImportBatchItem{
		BatchID:        batchID,
		KeyFingerprint: "sha256:test",
		AccountStatus:  "pending",
	}
	rowID, err := store.ImportBatchItems().Create(context.Background(), item)
	if err != nil {
		t.Fatalf("createTestBatchItem error = %v", err)
	}
	return rowID
}
// sanitizeTestName reduces name (in this file always t.Name()) to lowercase
// ASCII letters, ASCII digits and '-'. Every other rune is removed, not
// replaced; that includes uppercase letters, '/', '_' and non-ASCII
// characters. It returns "default" when no rune survives, so callers never
// get an empty suffix.
//
// Because uppercase letters are discarded, different names can collapse to
// the same result (e.g. "TestA" and "TestB" both become "est"). Every caller
// in this file works against its own per-test database, so the result only
// needs to be stable, not globally unique.
func sanitizeTestName(name string) string {
	// Every kept rune is a single ASCII byte, so the output is never longer
	// than len(name). Appending into one pre-sized buffer avoids the
	// quadratic re-copying that building the string with += per rune causes.
	kept := make([]byte, 0, len(name))
	for _, r := range name {
		switch {
		case 'a' <= r && r <= 'z', '0' <= r && r <= '9', r == '-':
			kept = append(kept, byte(r))
		}
	}
	if len(kept) == 0 {
		return "default"
	}
	return string(kept)
}
// --- Hosts Repo Tests ---
// TestHostsRepoCreateAndGet inserts a fully populated host, checks that Create
// returns a positive row ID, and reads the host back first by that row ID and
// then by its HostID string.
func TestHostsRepoCreateAndGet(t *testing.T) {
	ctx := context.Background()
	hosts := openTestDB(t).Hosts()

	input := Host{
		HostID:              "host-1",
		BaseURL:             "https://sub2api.example.com",
		HostVersion:         "0.1.126",
		CapabilityProbeJSON: `{"groups":true}`,
	}
	rowID, err := hosts.Create(ctx, input)
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if rowID <= 0 {
		t.Fatalf("Create() id = %d, want positive", rowID)
	}

	byRowID, err := hosts.GetByID(ctx, rowID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if byRowID.HostID != "host-1" || byRowID.BaseURL != "https://sub2api.example.com" {
		t.Fatalf("GetByID() = %+v, want host-1", byRowID)
	}

	byHostID, err := hosts.GetByHostID(ctx, "host-1")
	if err != nil {
		t.Fatalf("GetByHostID() error = %v", err)
	}
	if byHostID.ID != rowID {
		t.Fatalf("GetByHostID() id = %d, want %d", byHostID.ID, rowID)
	}
}
// TestHostsRepoCreateDefaultsCapabilityProbe creates a host without setting
// CapabilityProbeJSON and checks that the stored value reads back as "{}".
func TestHostsRepoCreateDefaultsCapabilityProbe(t *testing.T) {
	ctx := context.Background()
	store := openTestDB(t)

	// Fail on setup errors explicitly. If they were ignored, a failed Create
	// would surface only as a misleading CapabilityProbeJSON mismatch on a
	// zero-value host, or as a nil dereference if GetByID returns a pointer.
	id, err := store.Hosts().Create(ctx, Host{
		HostID: "host-empty", BaseURL: "https://example.com", HostVersion: "0.1.0",
	})
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	got, err := store.Hosts().GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got.CapabilityProbeJSON != "{}" {
		t.Fatalf("CapabilityProbeJSON = %q, want {}", got.CapabilityProbeJSON)
	}
}
// TestHostsRepoValidationErrors checks that Create returns an error whenever
// exactly one of HostID, BaseURL or HostVersion is left empty.
func TestHostsRepoValidationErrors(t *testing.T) {
	store := openTestDB(t)
	cases := []struct {
		name string
		host Host
	}{
		{name: "empty host_id", host: Host{BaseURL: "b", HostVersion: "v"}},
		{name: "empty base_url", host: Host{HostID: "h", HostVersion: "v"}},
		{name: "empty host_version", host: Host{HostID: "h", BaseURL: "b"}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := store.Hosts().Create(context.Background(), tc.host); err == nil {
				t.Fatal("Create() error = nil, want validation error")
			}
		})
	}
}
// TestHostsRepoGetByIDZeroError checks that GetByID rejects row ID 0 with an error.
func TestHostsRepoGetByIDZeroError(t *testing.T) {
	hosts := openTestDB(t).Hosts()
	if _, err := hosts.GetByID(context.Background(), 0); err == nil {
		t.Fatal("GetByID(0) error = nil, want error")
	}
}
// TestHostsRepoGetByIDNotFound checks that looking up an absent row ID yields
// an error wrapping sql.ErrNoRows.
func TestHostsRepoGetByIDNotFound(t *testing.T) {
	hosts := openTestDB(t).Hosts()
	if _, err := hosts.GetByID(context.Background(), 999); !errors.Is(err, sql.ErrNoRows) {
		t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
	}
}
// TestHostsRepoGetByHostIDEmptyError checks that GetByHostID rejects an empty HostID.
func TestHostsRepoGetByHostIDEmptyError(t *testing.T) {
	hosts := openTestDB(t).Hosts()
	if _, err := hosts.GetByHostID(context.Background(), ""); err == nil {
		t.Fatal("GetByHostID('') error = nil, want error")
	}
}
// TestHostsRepoGetByHostIDNotFound checks that looking up an unknown HostID
// yields an error wrapping sql.ErrNoRows.
func TestHostsRepoGetByHostIDNotFound(t *testing.T) {
	hosts := openTestDB(t).Hosts()
	if _, err := hosts.GetByHostID(context.Background(), "nonexistent"); !errors.Is(err, sql.ErrNoRows) {
		t.Fatalf("GetByHostID('nonexistent') error = %v, want sql.ErrNoRows", err)
	}
}
// TestHostsRepoGetByBaseURL creates a host with a distinctive base URL and
// checks that GetByBaseURL returns that same host.
func TestHostsRepoGetByBaseURL(t *testing.T) {
	const (
		hostID  = "host-base-url"
		baseURL = "https://base-url.example.com"
	)
	store := openTestDB(t)
	createTestHostWithBaseURL(t, store, hostID, baseURL)

	got, err := store.Hosts().GetByBaseURL(context.Background(), baseURL)
	if err != nil {
		t.Fatalf("GetByBaseURL() error = %v", err)
	}
	if got.HostID != hostID {
		t.Fatalf("GetByBaseURL() host_id = %q, want host-base-url", got.HostID)
	}
}
// TestHostsRepoGetByBaseURLErrors checks two failure modes of GetByBaseURL.
// An empty URL is rejected, and an unknown URL yields an error wrapping
// sql.ErrNoRows.
func TestHostsRepoGetByBaseURLErrors(t *testing.T) {
	ctx := context.Background()
	hosts := openTestDB(t).Hosts()

	_, err := hosts.GetByBaseURL(ctx, "")
	if err == nil {
		t.Fatal(`GetByBaseURL("") error = nil, want error`)
	}

	_, err = hosts.GetByBaseURL(ctx, "https://missing.example.com")
	if !errors.Is(err, sql.ErrNoRows) {
		t.Fatalf("GetByBaseURL(missing) error = %v, want sql.ErrNoRows", err)
	}
}
// TestHostsRepoListAll checks that ListAll returns no hosts on a fresh
// database. It then inserts two hosts that share the same base URL and checks
// that ListAll returns both.
func TestHostsRepoListAll(t *testing.T) {
	ctx := context.Background()
	hosts := openTestDB(t).Hosts()

	initial, err := hosts.ListAll(ctx)
	if err != nil {
		t.Fatalf("ListAll() on empty DB error = %v", err)
	}
	if n := len(initial); n != 0 {
		t.Fatalf("ListAll() len = %d, want 0", n)
	}

	const inserted = 2
	for i := 0; i < inserted; i++ {
		h := Host{
			HostID:      fmt.Sprintf("host-listall-%d", i),
			BaseURL:     "https://h.com",
			HostVersion: "0.1.0",
		}
		if _, err := hosts.Create(ctx, h); err != nil {
			t.Fatalf("Create() error = %v", err)
		}
	}

	all, err := hosts.ListAll(ctx)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if n := len(all); n != inserted {
		t.Fatalf("ListAll() len = %d, want 2", n)
	}
}
// TestHostsRepoDeleteByHostID creates one host, deletes it by HostID, and
// checks that ListAll then returns no hosts.
func TestHostsRepoDeleteByHostID(t *testing.T) {
	ctx := context.Background()
	store := openTestDB(t)
	createTestHost(t, store)
	hostID := "host-" + sanitizeTestName(t.Name())

	if err := store.Hosts().DeleteByHostID(ctx, hostID); err != nil {
		t.Fatalf("DeleteByHostID() error = %v", err)
	}

	remaining, err := store.Hosts().ListAll(ctx)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if n := len(remaining); n != 0 {
		t.Fatalf("ListAll() after delete len = %d, want 0", n)
	}
}
// TestHostsRepoDeleteByHostIDBlocksWhenRuntimeStateExists uses a database
// opened with openTestDBWithFK and creates one import batch, one managed
// resource and one reconcile run, all tied to the same host. It then expects
// DeleteByHostID to fail with a *HostDeleteBlocker reporting a count of 1 for
// each of those three dependencies.
func TestHostsRepoDeleteByHostIDBlocksWhenRuntimeStateExists(t *testing.T) {
	ctx := context.Background()
	store := openTestDBWithFK(t)

	batchID := createTestBatch(t, store)
	host, err := store.Hosts().GetByID(ctx, mustHostRowIDForBatch(t, store, batchID))
	if err != nil {
		t.Fatalf("Hosts().GetByID() error = %v", err)
	}

	resource := ManagedResource{
		BatchID:        batchID,
		HostID:         host.ID,
		ResourceType:   "group",
		HostResourceID: "group_1",
		ResourceName:   "group",
	}
	if _, err := store.ManagedResources().Create(ctx, resource); err != nil {
		t.Fatalf("ManagedResources().Create() error = %v", err)
	}

	run := ReconcileRun{
		BatchID:     batchID,
		HostID:      host.ID,
		ProviderID:  mustProviderIDForBatch(t, store, batchID),
		Status:      "active",
		SummaryJSON: `{}`,
	}
	if _, err := store.ReconcileRuns().Create(ctx, run); err != nil {
		t.Fatalf("ReconcileRuns().Create() error = %v", err)
	}

	err = store.Hosts().DeleteByHostID(ctx, host.HostID)
	var blocker *HostDeleteBlocker
	switch {
	case err == nil:
		t.Fatal("DeleteByHostID() error = nil, want blocked error")
	case !errors.As(err, &blocker):
		t.Fatalf("DeleteByHostID() error = %T %v, want HostDeleteBlocker", err, err)
	}
	allOne := blocker.ImportBatchCount == 1 &&
		blocker.ManagedResourceCount == 1 &&
		blocker.ReconcileRunCount == 1
	if !allOne {
		t.Fatalf("blocker = %+v, want all dependency counts = 1", blocker)
	}
}
// TestHostsRepoUpdateProbeByHostID updates a host's version and capability
// probe JSON by HostID and checks that both new values are stored.
func TestHostsRepoUpdateProbeByHostID(t *testing.T) {
	const (
		wantVersion = "0.2.0"
		wantProbe   = `{"groups":true}`
	)
	ctx := context.Background()
	store := openTestDB(t)
	createTestHost(t, store)
	hostID := "host-" + sanitizeTestName(t.Name())

	if err := store.Hosts().UpdateProbeByHostID(ctx, hostID, wantVersion, wantProbe); err != nil {
		t.Fatalf("UpdateProbeByHostID() error = %v", err)
	}

	got, err := store.Hosts().GetByHostID(ctx, hostID)
	if err != nil {
		t.Fatalf("GetByHostID() error = %v", err)
	}
	if got.HostVersion != wantVersion || got.CapabilityProbeJSON != wantProbe {
		t.Fatalf("updated host = %+v, want version/capability update", got)
	}
}
// TestHostsRepoUpdateConnectionByHostID updates a host's base URL, version and
// token while passing empty capability-probe and auth-type arguments. It then
// checks the stored row: the new URL, version and token are saved,
// CapabilityProbeJSON reads back as "{}", and AuthType reads back as "apikey".
func TestHostsRepoUpdateConnectionByHostID(t *testing.T) {
	ctx := context.Background()
	store := openTestDB(t)
	createTestHost(t, store)
	hostID := "host-" + sanitizeTestName(t.Name())

	err := store.Hosts().UpdateConnectionByHostID(ctx, hostID, "https://updated.example.com", "0.3.0", "", "", "token-1")
	if err != nil {
		t.Fatalf("UpdateConnectionByHostID() error = %v", err)
	}

	got, err := store.Hosts().GetByHostID(ctx, hostID)
	if err != nil {
		t.Fatalf("GetByHostID() error = %v", err)
	}
	if got.BaseURL != "https://updated.example.com" {
		t.Fatalf("BaseURL = %q, want updated URL", got.BaseURL)
	}
	if got.HostVersion != "0.3.0" {
		t.Fatalf("HostVersion = %q, want 0.3.0", got.HostVersion)
	}
	if got.CapabilityProbeJSON != "{}" {
		t.Fatalf("CapabilityProbeJSON = %q, want {}", got.CapabilityProbeJSON)
	}
	if got.AuthType != "apikey" {
		t.Fatalf("AuthType = %q, want apikey", got.AuthType)
	}
	if got.AuthToken != "token-1" {
		t.Fatalf("AuthToken = %q, want token-1", got.AuthToken)
	}
}
// TestHostsRepoUpdateConnectionByHostIDErrors checks that
// UpdateConnectionByHostID returns an error in four cases: an empty host ID,
// an empty base URL, an empty host version, and a host that does not exist.
// The cases run in order and the first unexpected success aborts the test.
func TestHostsRepoUpdateConnectionByHostIDErrors(t *testing.T) {
	ctx := context.Background()
	store := openTestDB(t)
	for _, tc := range []struct {
		hostID, baseURL, hostVersion string
		failure                      string
	}{
		{"", "https://example.com", "0.2.0", "UpdateConnectionByHostID() empty host_id error = nil"},
		{"missing", "", "0.2.0", "UpdateConnectionByHostID() empty base_url error = nil"},
		{"missing", "https://example.com", "", "UpdateConnectionByHostID() empty host_version error = nil"},
		{"missing", "https://example.com", "0.2.0", "UpdateConnectionByHostID() missing host error = nil"},
	} {
		err := store.Hosts().UpdateConnectionByHostID(ctx, tc.hostID, tc.baseURL, tc.hostVersion, "{}", "apikey", "token")
		if err == nil {
			t.Fatal(tc.failure)
		}
	}
}
// TestHostsRepoDeleteByHostIDNotFound checks that deleting an unknown HostID
// fails with exactly the message `host "nonexistent" not found`.
func TestHostsRepoDeleteByHostIDNotFound(t *testing.T) {
	hosts := openTestDB(t).Hosts()
	err := hosts.DeleteByHostID(context.Background(), "nonexistent")
	switch {
	case err == nil:
		t.Fatal("DeleteByHostID('nonexistent') error = nil, want error")
	case err.Error() != `host "nonexistent" not found`:
		t.Fatalf("DeleteByHostID() error = %q, want not found error", err)
	}
}
// TestHostsRepoDeleteByHostIDEmptyError checks that DeleteByHostID rejects an empty HostID.
func TestHostsRepoDeleteByHostIDEmptyError(t *testing.T) {
	hosts := openTestDB(t).Hosts()
	if err := hosts.DeleteByHostID(context.Background(), ""); err == nil {
		t.Fatal("DeleteByHostID('') error = nil, want error")
	}
}
// mustHostRowIDForBatch returns the host_id column of the import_batches row
// whose id is batchID. createTestBatch stores the host's row ID there. The
// test aborts if the query or scan fails, including when no such row exists.
func mustHostRowIDForBatch(t *testing.T, store *DB, batchID int64) int64 {
	t.Helper()
	const query = `SELECT host_id FROM import_batches WHERE id = ?`
	row := store.SQLDB().QueryRow(query, batchID)
	var hostRowID int64
	if err := row.Scan(&hostRowID); err != nil {
		t.Fatalf("query host_id for batch %d error = %v", batchID, err)
	}
	return hostRowID
}
// mustProviderIDForBatch returns the provider_id column of the import_batches
// row whose id is batchID. createTestBatch stores the provider's row ID there.
// The test aborts if the query or scan fails, including when no such row
// exists.
func mustProviderIDForBatch(t *testing.T, store *DB, batchID int64) int64 {
	t.Helper()
	const query = `SELECT provider_id FROM import_batches WHERE id = ?`
	row := store.SQLDB().QueryRow(query, batchID)
	var providerRowID int64
	if err := row.Scan(&providerRowID); err != nil {
		t.Fatalf("query provider_id for batch %d error = %v", batchID, err)
	}
	return providerRowID
}