273 lines
8.4 KiB
Go
273 lines
8.4 KiB
Go
package sqlite
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
)
|
|
|
|
// Host mirrors one row of the hosts table as read and written by HostsRepo.
type Host struct {
	// ID is the hosts.id column. Create returns it via LastInsertId, and the
	// import_batches, managed_resources and reconcile_runs tables reference
	// it through their host_id column.
	ID int64
	// HostID is the hosts.host_id column, the string key used by the
	// *ByHostID methods. Writers trim surrounding whitespace.
	HostID string
	// BaseURL is the hosts.base_url column (trimmed on write).
	BaseURL string
	// HostVersion is the hosts.host_version column (trimmed on write).
	HostVersion string
	// CapabilityProbeJSON is the hosts.capability_probe_json column. Writers
	// store "{}" when the supplied value is blank.
	CapabilityProbeJSON string
	// AuthType is the hosts.auth_type column. Create and
	// UpdateConnectionByHostID default a blank value to "apikey".
	AuthType string
	// AuthToken is the hosts.auth_token column (trimmed on write).
	AuthToken string
}
|
|
|
|
type HostsRepo struct {
|
|
db execQuerier
|
|
}
|
|
|
|
func newHostsRepo(db execQuerier) *HostsRepo {
|
|
return &HostsRepo{db: db}
|
|
}
|
|
|
|
// GetByID loads the host whose hosts.id column equals id.
//
// Non-positive ids are rejected without touching the database. Query and scan
// errors are returned unwrapped. Presumably the row is a *sql.Row, in which
// case a missing host surfaces as sql.ErrNoRows (TODO confirm execQuerier).
func (r *HostsRepo) GetByID(ctx context.Context, id int64) (Host, error) {
	if id < 1 {
		return Host{}, fmt.Errorf("id is required")
	}

	const query = `SELECT id, host_id, base_url, host_version, capability_probe_json, auth_type, auth_token FROM hosts WHERE id = ?`
	row := r.db.QueryRowContext(ctx, query, id)

	var h Host
	err := row.Scan(&h.ID, &h.HostID, &h.BaseURL, &h.HostVersion, &h.CapabilityProbeJSON, &h.AuthType, &h.AuthToken)
	if err != nil {
		return Host{}, err
	}
	return h, nil
}
|
|
|
|
// GetByHostID loads the host whose hosts.host_id column equals hostID after
// surrounding whitespace is trimmed.
//
// A blank hostID is rejected before querying. Query and scan errors are
// returned unwrapped. DeleteByHostID relies on this when it checks for
// sql.ErrNoRows through RuntimeDependencyCountsByHostID.
func (r *HostsRepo) GetByHostID(ctx context.Context, hostID string) (Host, error) {
	key := strings.TrimSpace(hostID)
	if len(key) == 0 {
		return Host{}, fmt.Errorf("host_id is required")
	}

	const query = `SELECT id, host_id, base_url, host_version, capability_probe_json, auth_type, auth_token FROM hosts WHERE host_id = ?`
	row := r.db.QueryRowContext(ctx, query, key)

	var h Host
	err := row.Scan(&h.ID, &h.HostID, &h.BaseURL, &h.HostVersion, &h.CapabilityProbeJSON, &h.AuthType, &h.AuthToken)
	if err != nil {
		return Host{}, err
	}
	return h, nil
}
|
|
|
|
// GetByBaseURL loads the host whose hosts.base_url column equals baseURL
// after surrounding whitespace is trimmed.
//
// The comparison is an exact SQL equality, with no URL normalisation beyond
// the trim. Query and scan errors are returned unwrapped.
func (r *HostsRepo) GetByBaseURL(ctx context.Context, baseURL string) (Host, error) {
	key := strings.TrimSpace(baseURL)
	if len(key) == 0 {
		return Host{}, fmt.Errorf("base_url is required")
	}

	const query = `SELECT id, host_id, base_url, host_version, capability_probe_json, auth_type, auth_token FROM hosts WHERE base_url = ?`
	row := r.db.QueryRowContext(ctx, query, key)

	var h Host
	err := row.Scan(&h.ID, &h.HostID, &h.BaseURL, &h.HostVersion, &h.CapabilityProbeJSON, &h.AuthType, &h.AuthToken)
	if err != nil {
		return Host{}, err
	}
	return h, nil
}
|
|
|
|
func (r *HostsRepo) Create(ctx context.Context, host Host) (int64, error) {
|
|
hostID := strings.TrimSpace(host.HostID)
|
|
baseURL := strings.TrimSpace(host.BaseURL)
|
|
hostVersion := strings.TrimSpace(host.HostVersion)
|
|
capabilityProbeJSON := strings.TrimSpace(host.CapabilityProbeJSON)
|
|
authType := firstNonEmptyTrimmed(host.AuthType, "apikey")
|
|
authToken := strings.TrimSpace(host.AuthToken)
|
|
|
|
switch {
|
|
case hostID == "":
|
|
return 0, fmt.Errorf("host_id is required")
|
|
case baseURL == "":
|
|
return 0, fmt.Errorf("base_url is required")
|
|
case hostVersion == "":
|
|
return 0, fmt.Errorf("host_version is required")
|
|
case capabilityProbeJSON == "":
|
|
capabilityProbeJSON = "{}"
|
|
}
|
|
|
|
result, err := r.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO hosts (host_id, base_url, host_version, capability_probe_json, auth_type, auth_token)
|
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
|
hostID,
|
|
baseURL,
|
|
hostVersion,
|
|
capabilityProbeJSON,
|
|
authType,
|
|
authToken,
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("insert host %q: %w", hostID, err)
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("read inserted host id for %q: %w", hostID, err)
|
|
}
|
|
|
|
return id, nil
|
|
}
|
|
|
|
// UpdateConnectionByHostID overwrites the connection columns (base_url,
// host_version, capability_probe_json, auth_type, auth_token) of the host
// whose host_id equals hostID.
//
// All inputs are trimmed. hostID, baseURL and hostVersion are required.
// A blank capabilityProbeJSON is stored as "{}" and a blank authType as
// "apikey". It returns a "host %q not found" error when the UPDATE affected
// no rows. Database errors are wrapped with %w, so the driver error remains
// inspectable with errors.Is/As.
func (r *HostsRepo) UpdateConnectionByHostID(ctx context.Context, hostID, baseURL, hostVersion, capabilityProbeJSON, authType, authToken string) error {
	hostID = strings.TrimSpace(hostID)
	baseURL = strings.TrimSpace(baseURL)
	hostVersion = strings.TrimSpace(hostVersion)
	capabilityProbeJSON = strings.TrimSpace(capabilityProbeJSON)
	authType = firstNonEmptyTrimmed(authType, "apikey")
	authToken = strings.TrimSpace(authToken)

	switch {
	case hostID == "":
		return fmt.Errorf("host_id is required")
	case baseURL == "":
		return fmt.Errorf("base_url is required")
	case hostVersion == "":
		return fmt.Errorf("host_version is required")
	}
	if capabilityProbeJSON == "" {
		capabilityProbeJSON = "{}"
	}

	result, err := r.db.ExecContext(ctx, `UPDATE hosts SET base_url = ?, host_version = ?, capability_probe_json = ?, auth_type = ?, auth_token = ? WHERE host_id = ?`, baseURL, hostVersion, capabilityProbeJSON, authType, authToken, hostID)
	if err != nil {
		return fmt.Errorf("update host %q connection: %w", hostID, err)
	}
	rows, err := result.RowsAffected()
	if err != nil {
		// Wrap like the Exec error above so the caller can tell which host
		// and which step failed.
		return fmt.Errorf("read rows affected for host %q connection update: %w", hostID, err)
	}
	if rows == 0 {
		return fmt.Errorf("host %q not found", hostID)
	}
	return nil
}
|
|
|
|
// ListAll returns every row of the hosts table ordered by id.
//
// When the table is empty the result is a nil slice with a nil error. Query,
// scan and iteration failures are each wrapped with a distinct prefix.
func (r *HostsRepo) ListAll(ctx context.Context) ([]Host, error) {
	const query = `SELECT id, host_id, base_url, host_version, capability_probe_json, auth_type, auth_token FROM hosts ORDER BY id`
	rows, err := r.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("list hosts: %w", err)
	}
	defer rows.Close()

	var out []Host
	for rows.Next() {
		var h Host
		if err = rows.Scan(&h.ID, &h.HostID, &h.BaseURL, &h.HostVersion, &h.CapabilityProbeJSON, &h.AuthType, &h.AuthToken); err != nil {
			return nil, fmt.Errorf("scan host: %w", err)
		}
		out = append(out, h)
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate hosts: %w", err)
	}
	return out, nil
}
|
|
func (r *HostsRepo) RuntimeDependencyCountsByHostID(ctx context.Context, hostID string) (HostDeleteBlocker, error) {
|
|
hostID = strings.TrimSpace(hostID)
|
|
if hostID == "" {
|
|
return HostDeleteBlocker{}, fmt.Errorf("host_id is required")
|
|
}
|
|
|
|
host, err := r.GetByHostID(ctx, hostID)
|
|
if err != nil {
|
|
return HostDeleteBlocker{}, err
|
|
}
|
|
|
|
blocker := HostDeleteBlocker{HostID: host.HostID}
|
|
if err := r.db.QueryRowContext(ctx, `SELECT COUNT(1) FROM import_batches WHERE host_id = ?`, host.ID).Scan(&blocker.ImportBatchCount); err != nil {
|
|
return HostDeleteBlocker{}, fmt.Errorf("count import batches for host %q: %w", hostID, err)
|
|
}
|
|
if err := r.db.QueryRowContext(ctx, `SELECT COUNT(1) FROM managed_resources WHERE host_id = ?`, host.ID).Scan(&blocker.ManagedResourceCount); err != nil {
|
|
return HostDeleteBlocker{}, fmt.Errorf("count managed resources for host %q: %w", hostID, err)
|
|
}
|
|
if err := r.db.QueryRowContext(ctx, `SELECT COUNT(1) FROM reconcile_runs WHERE host_id = ?`, host.ID).Scan(&blocker.ReconcileRunCount); err != nil {
|
|
return HostDeleteBlocker{}, fmt.Errorf("count reconcile runs for host %q: %w", hostID, err)
|
|
}
|
|
return blocker, nil
|
|
}
|
|
|
|
// HostDeleteBlocker carries the counts of runtime rows that still reference a
// host. DeleteByHostID returns a *HostDeleteBlocker as its error when any
// count is positive.
type HostDeleteBlocker struct {
	// HostID is the string host_id of the host being inspected.
	HostID string
	// ImportBatchCount is the number of import_batches rows for the host.
	ImportBatchCount int
	// ManagedResourceCount is the number of managed_resources rows for the host.
	ManagedResourceCount int
	// ReconcileRunCount is the number of reconcile_runs rows for the host.
	ReconcileRunCount int
}
|
|
|
|
func (e *HostDeleteBlocker) Error() string {
|
|
if e == nil {
|
|
return "host delete is blocked"
|
|
}
|
|
return fmt.Sprintf(
|
|
"host %q cannot be deleted while runtime state exists (import_batches=%d managed_resources=%d reconcile_runs=%d)",
|
|
e.HostID,
|
|
e.ImportBatchCount,
|
|
e.ManagedResourceCount,
|
|
e.ReconcileRunCount,
|
|
)
|
|
}
|
|
|
|
// DeleteByHostID removes the host whose host_id equals hostID (trimmed),
// refusing while runtime rows still reference it.
//
// It returns:
//   - an error if hostID is blank;
//   - a "host %q not found" error if the host does not exist at lookup time,
//     or if the DELETE affects no rows;
//   - a *HostDeleteBlocker (extractable with errors.As) when any
//     import_batches, managed_resources or reconcile_runs rows reference the
//     host;
//   - otherwise, database errors wrapped with %w.
//
// NOTE(review): the dependency check and the DELETE are separate statements,
// so a dependent row inserted between them is not caught here. Confirm
// whether foreign keys or a caller-held transaction cover that window.
func (r *HostsRepo) DeleteByHostID(ctx context.Context, hostID string) error {
	hostID = strings.TrimSpace(hostID)
	if hostID == "" {
		return fmt.Errorf("host_id is required")
	}

	blocker, err := r.RuntimeDependencyCountsByHostID(ctx, hostID)
	if err != nil {
		// errors.Is instead of == so this keeps working if the lookup ever
		// wraps sql.ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) {
			return fmt.Errorf("host %q not found", hostID)
		}
		return fmt.Errorf("resolve host %q runtime dependencies: %w", hostID, err)
	}
	if blocker.ImportBatchCount > 0 || blocker.ManagedResourceCount > 0 || blocker.ReconcileRunCount > 0 {
		return &blocker
	}

	result, err := r.db.ExecContext(ctx, `DELETE FROM hosts WHERE host_id = ?`, hostID)
	if err != nil {
		return fmt.Errorf("delete host %q: %w", hostID, err)
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read rows affected deleting host %q: %w", hostID, err)
	}
	if rows == 0 {
		return fmt.Errorf("host %q not found", hostID)
	}
	return nil
}
|
|
|
|
// UpdateProbeByHostID records a new capability probe result (host_version and
// capability_probe_json) for the host whose host_id equals hostID.
//
// All inputs are trimmed. hostID and hostVersion are required, and a blank
// capabilityProbeJSON is stored as "{}". It returns a "host %q not found"
// error when the UPDATE affected no rows. Database errors are wrapped with %w.
func (r *HostsRepo) UpdateProbeByHostID(ctx context.Context, hostID, hostVersion, capabilityProbeJSON string) error {
	hostID = strings.TrimSpace(hostID)
	hostVersion = strings.TrimSpace(hostVersion)
	capabilityProbeJSON = strings.TrimSpace(capabilityProbeJSON)

	switch {
	case hostID == "":
		return fmt.Errorf("host_id is required")
	case hostVersion == "":
		return fmt.Errorf("host_version is required")
	}
	if capabilityProbeJSON == "" {
		capabilityProbeJSON = "{}"
	}

	result, err := r.db.ExecContext(ctx, `UPDATE hosts SET host_version = ?, capability_probe_json = ? WHERE host_id = ?`, hostVersion, capabilityProbeJSON, hostID)
	if err != nil {
		return fmt.Errorf("update host %q probe: %w", hostID, err)
	}
	rows, err := result.RowsAffected()
	if err != nil {
		// Wrap like the Exec error above so the failing host and step are clear.
		return fmt.Errorf("read rows affected for host %q probe update: %w", hostID, err)
	}
	if rows == 0 {
		return fmt.Errorf("host %q not found", hostID)
	}
	return nil
}
|
|
|
|
// firstNonEmptyTrimmed scans values in order and returns the first one that
// is non-blank after strings.TrimSpace, in its trimmed form. It returns ""
// when every value is blank or no values are given.
func firstNonEmptyTrimmed(values ...string) string {
	for i := 0; i < len(values); i++ {
		candidate := strings.TrimSpace(values[i])
		if candidate == "" {
			continue
		}
		return candidate
	}
	return ""
}
|