feat(accounts): add provider account inventory api

This commit is contained in:
phamnazage-jpg
2026-05-29 14:43:34 +08:00
parent 83148ce3c1
commit b5343452cb
12 changed files with 1332 additions and 0 deletions

View File

@@ -61,6 +61,10 @@ type ActionSet struct {
GetRouteFailure func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error)
SetRouteCooldown func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error)
GetRouteCooldown func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error)
ListProviderAccounts func(context.Context, ListProviderAccountsRequest) ([]ProviderAccountInfo, error)
EnableProviderAccount func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error)
DisableProviderAccount func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error)
RetireProviderAccount func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error)
CreateProviderDraft func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error)
ListProviderDrafts func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error)
GetProviderDraft func(context.Context, string) (ProviderDraftInfo, error)
@@ -432,6 +436,18 @@ func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet) http.Ha
mux.Handle("GET /api/routing/sticky/cooldowns", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleGetRouteCooldown(w, r, actions.GetRouteCooldown)
})))
mux.Handle("GET /api/provider-accounts", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleListProviderAccounts(w, r, actions.ListProviderAccounts)
})))
mux.Handle("POST /api/provider-accounts/{accountID}/enable", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleEnableProviderAccount(w, r, actions.EnableProviderAccount)
})))
mux.Handle("POST /api/provider-accounts/{accountID}/disable", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleDisableProviderAccount(w, r, actions.DisableProviderAccount)
})))
mux.Handle("POST /api/provider-accounts/{accountID}/retire", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleRetireProviderAccount(w, r, actions.RetireProviderAccount)
})))
mux.Handle("POST /api/provider-drafts", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleCreateProviderDraft(w, r, actions.CreateProviderDraft)
})))
@@ -1280,6 +1296,10 @@ func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRu
GetRouteFailure: buildGetRouteFailureAction(stickyRuntime),
SetRouteCooldown: buildSetRouteCooldownAction(stickyRuntime),
GetRouteCooldown: buildGetRouteCooldownAction(stickyRuntime),
ListProviderAccounts: buildListProviderAccountsAction(sqliteDSN),
EnableProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusActive),
DisableProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusDisabled),
RetireProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusDeprecated),
CreateProviderDraft: func(ctx context.Context, req CreateProviderDraftRequest) (ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {

View File

@@ -0,0 +1,188 @@
package app
import (
"context"
"database/sql"
"fmt"
"net/http"
"strconv"
"strings"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
type ListProviderAccountsRequest struct {
HostID string
ProviderID string
RouteID string
ShadowGroupID string
AccountStatus string
Query string
Limit int
}
type UpdateProviderAccountStatusRequest struct {
AccountID int64 `json:"-"`
AccountStatus string `json:"-"`
DisabledReason string `json:"reason,omitempty"`
}
type ProviderAccountInfo struct {
ID int64 `json:"id"`
HostID string `json:"host_id"`
ProviderID string `json:"provider_id"`
ProviderName string `json:"provider_name"`
RouteID string `json:"route_id,omitempty"`
LogicalGroupID string `json:"logical_group_id,omitempty"`
ShadowGroupID string `json:"shadow_group_id,omitempty"`
HostAccountID string `json:"host_account_id"`
KeyFingerprint string `json:"key_fingerprint"`
AccountName string `json:"account_name"`
AccountStatus string `json:"account_status"`
LastProbeStatus string `json:"last_probe_status,omitempty"`
LastProbeAt string `json:"last_probe_at,omitempty"`
DisabledReason string `json:"disabled_reason,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
func handleListProviderAccounts(w http.ResponseWriter, r *http.Request, fn func(context.Context, ListProviderAccountsRequest) ([]ProviderAccountInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-provider-accounts action is not configured"})
return
}
accounts, err := fn(r.Context(), ListProviderAccountsRequest{
HostID: strings.TrimSpace(r.URL.Query().Get("host_id")),
ProviderID: strings.TrimSpace(r.URL.Query().Get("provider_id")),
RouteID: strings.TrimSpace(r.URL.Query().Get("route_id")),
ShadowGroupID: strings.TrimSpace(r.URL.Query().Get("shadow_group_id")),
AccountStatus: strings.TrimSpace(r.URL.Query().Get("account_status")),
Query: strings.TrimSpace(r.URL.Query().Get("q")),
Limit: parsePositiveInt(r.URL.Query().Get("limit")),
})
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
if accounts == nil {
accounts = []ProviderAccountInfo{}
}
writeJSON(w, http.StatusOK, map[string]any{"provider_accounts": accounts})
}
func handleEnableProviderAccount(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error)) {
handleUpdateProviderAccountStatus(w, r, fn, sqlite.ProviderAccountStatusActive)
}
func handleDisableProviderAccount(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error)) {
handleUpdateProviderAccountStatus(w, r, fn, sqlite.ProviderAccountStatusDisabled)
}
func handleRetireProviderAccount(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error)) {
handleUpdateProviderAccountStatus(w, r, fn, sqlite.ProviderAccountStatusDeprecated)
}
func handleUpdateProviderAccountStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error), accountStatus string) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "update-provider-account-status action is not configured"})
return
}
rawID := strings.TrimSpace(r.PathValue("accountID"))
accountID, err := strconv.ParseInt(rawID, 10, 64)
if err != nil || accountID <= 0 {
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "account_id must be a positive integer"})
return
}
req := UpdateProviderAccountStatusRequest{
AccountID: accountID,
AccountStatus: accountStatus,
}
if r.ContentLength != 0 {
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
req.AccountID = accountID
req.AccountStatus = accountStatus
}
account, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusOK, map[string]any{"provider_account": account})
}
func buildListProviderAccountsAction(sqliteDSN string) func(context.Context, ListProviderAccountsRequest) ([]ProviderAccountInfo, error) {
return func(ctx context.Context, req ListProviderAccountsRequest) ([]ProviderAccountInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return nil, err
}
defer store.Close()
if err := sqlite.SyncProviderAccountsFromLatestImportBatches(ctx, store); err != nil {
return nil, err
}
rows, err := store.ProviderAccounts().List(ctx, sqlite.ProviderAccountListFilter{
HostID: req.HostID,
ProviderID: req.ProviderID,
RouteID: req.RouteID,
ShadowGroupID: req.ShadowGroupID,
AccountStatus: req.AccountStatus,
Query: req.Query,
Limit: req.Limit,
})
if err != nil {
return nil, err
}
result := make([]ProviderAccountInfo, 0, len(rows))
for _, row := range rows {
result = append(result, providerAccountViewToInfo(row))
}
return result, nil
}
}
func buildUpdateProviderAccountStatusAction(sqliteDSN, accountStatus string) func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error) {
return func(ctx context.Context, req UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return ProviderAccountInfo{}, err
}
defer store.Close()
if err := store.ProviderAccounts().UpdateStatusByID(ctx, req.AccountID, accountStatus, strings.TrimSpace(req.DisabledReason)); err != nil {
if err == sql.ErrNoRows {
return ProviderAccountInfo{}, fmt.Errorf("provider account %d not found", req.AccountID)
}
return ProviderAccountInfo{}, err
}
updated, err := store.ProviderAccounts().GetViewByID(ctx, req.AccountID)
if err != nil {
return ProviderAccountInfo{}, err
}
return providerAccountViewToInfo(updated), nil
}
}
func providerAccountViewToInfo(row sqlite.ProviderAccountView) ProviderAccountInfo {
return ProviderAccountInfo{
ID: row.ID,
HostID: row.HostID,
ProviderID: row.ProviderID,
ProviderName: row.ProviderName,
RouteID: row.RouteID,
LogicalGroupID: row.LogicalGroupID,
ShadowGroupID: row.ShadowGroupID,
HostAccountID: row.HostAccountID,
KeyFingerprint: row.KeyFingerprint,
AccountName: row.AccountName,
AccountStatus: row.AccountStatus,
LastProbeStatus: row.LastProbeStatus,
LastProbeAt: row.LastProbeAt,
DisabledReason: row.DisabledReason,
CreatedAt: row.CreatedAt,
UpdatedAt: row.UpdatedAt,
}
}

View File

@@ -0,0 +1,151 @@
package app
import (
"context"
"encoding/json"
"path/filepath"
"testing"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func TestAPIListProviderAccountsReturnsRows(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
ListProviderAccounts: func(_ context.Context, req ListProviderAccountsRequest) ([]ProviderAccountInfo, error) {
if req.ProviderID != "deepseek-official" {
t.Fatalf("ProviderID = %q, want deepseek-official", req.ProviderID)
}
if req.AccountStatus != "disabled" {
t.Fatalf("AccountStatus = %q, want disabled", req.AccountStatus)
}
return []ProviderAccountInfo{{
ID: 7,
HostID: "remote43",
ProviderID: "deepseek-official",
ProviderName: "DeepSeek Official",
HostAccountID: "9",
AccountName: "deepseek-01",
AccountStatus: "disabled",
DisabledReason: "manual_disable",
}}, nil
},
})
request := httptestRequest(t, "GET", "/api/provider-accounts?provider_id=deepseek-official&account_status=disabled", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, 200)
var payload map[string][]ProviderAccountInfo
if err := json.Unmarshal(response.Body().Bytes(), &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
accounts := payload["provider_accounts"]
if len(accounts) != 1 || accounts[0].ID != 7 || accounts[0].AccountStatus != "disabled" {
t.Fatalf("provider_accounts = %+v, want one disabled row id=7", accounts)
}
}
func TestAPIDisableProviderAccountUsesPathID(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
DisableProviderAccount: func(_ context.Context, req UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error) {
if req.AccountID != 42 {
t.Fatalf("AccountID = %d, want 42", req.AccountID)
}
if req.AccountStatus != "disabled" {
t.Fatalf("AccountStatus = %q, want disabled", req.AccountStatus)
}
if req.DisabledReason != "manual_disable" {
t.Fatalf("DisabledReason = %q, want manual_disable", req.DisabledReason)
}
return ProviderAccountInfo{ID: req.AccountID, AccountStatus: req.AccountStatus, DisabledReason: req.DisabledReason}, nil
},
})
request := httptestRequest(t, "POST", "/api/provider-accounts/42/disable", map[string]any{"reason": "manual_disable"}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, 200)
assertJSONContains(t, response.Body().Bytes(), "provider_account.id", float64(42))
assertJSONContains(t, response.Body().Bytes(), "provider_account.account_status", "disabled")
}
func TestNewActionSetProviderAccountListAndStatusFlow(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "provider-accounts.db")
dsn := "file:" + filepath.ToSlash(dbPath) + "?_busy_timeout=5000"
actions := NewActionSet(dsn)
ctx := context.Background()
store, err := sqlite.Open(ctx, dsn)
if err != nil {
t.Fatalf("sqlite.Open() error = %v", err)
}
defer store.Close()
hostID, err := store.Hosts().Create(ctx, sqlite.Host{
HostID: "remote43",
BaseURL: "https://host.example.com",
HostVersion: "0.1.129",
CapabilityProbeJSON: `{"accounts":true}`,
AuthType: "apikey",
AuthToken: "host-key",
})
if err != nil {
t.Fatalf("Hosts().Create() error = %v", err)
}
hostRow, err := store.Hosts().GetByID(ctx, hostID)
if err != nil {
t.Fatalf("Hosts().GetByID() error = %v", err)
}
packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "chk"})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
providerRowID, err := store.Providers().Create(ctx, sqlite.Provider{
PackID: packID,
ProviderID: "deepseek-official",
DisplayName: "DeepSeek Official",
BaseURL: "https://api.deepseek.com",
Platform: "openai",
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
providerAccountID, err := store.ProviderAccounts().Create(ctx, sqlite.ProviderAccount{
HostID: hostRow.ID,
ProviderID: providerRowID,
HostAccountID: "9",
KeyFingerprint: "sha256:abc",
AccountName: "deepseek-01",
AccountStatus: sqlite.ProviderAccountStatusActive,
LastProbeStatus: "passed",
LastProbeAt: "2026-05-29T00:00:00Z",
})
if err != nil {
t.Fatalf("ProviderAccounts().Create() error = %v", err)
}
listed, err := actions.ListProviderAccounts(ctx, ListProviderAccountsRequest{HostID: "remote43", ProviderID: "deepseek-official"})
if err != nil {
t.Fatalf("ListProviderAccounts() error = %v", err)
}
if len(listed) != 1 || listed[0].ID != providerAccountID {
t.Fatalf("ListProviderAccounts() = %+v, want one row for id %d", listed, providerAccountID)
}
disabled, err := actions.DisableProviderAccount(ctx, UpdateProviderAccountStatusRequest{
AccountID: providerAccountID,
DisabledReason: "manual_disable",
})
if err != nil {
t.Fatalf("DisableProviderAccount() error = %v", err)
}
if disabled.AccountStatus != sqlite.ProviderAccountStatusDisabled || disabled.DisabledReason != "manual_disable" {
t.Fatalf("DisableProviderAccount() = %+v", disabled)
}
enabled, err := actions.EnableProviderAccount(ctx, UpdateProviderAccountStatusRequest{AccountID: providerAccountID})
if err != nil {
t.Fatalf("EnableProviderAccount() error = %v", err)
}
if enabled.AccountStatus != sqlite.ProviderAccountStatusActive {
t.Fatalf("EnableProviderAccount() = %+v, want active", enabled)
}
}

View File

@@ -121,6 +121,9 @@ func (s *RuntimeImportService) Import(ctx context.Context, req RuntimeImportRequ
if err := s.store.ImportBatches().UpdateStatus(ctx, batchID, report.BatchStatus, report.AccessStatus); err != nil {
return RuntimeImportResult{}, err
}
if err := sqlite.SyncProviderAccountsFromImportBatch(ctx, s.store, batchID); err != nil {
return RuntimeImportResult{}, err
}
if importErr != nil {
return RuntimeImportResult{BatchID: batchID, Report: report}, importErr
}

View File

@@ -87,6 +87,9 @@ func TestRuntimeImportServicePersistsOperationalState(t *testing.T) {
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 1 {
t.Fatalf("access_closure_records row count = %d, want 1", got)
}
if got := queryCount(t, store.SQLDB(), "provider_accounts"); got != 2 {
t.Fatalf("provider_accounts row count = %d, want 2", got)
}
var batchStatus string
var accessStatus string
@@ -111,6 +114,18 @@ func TestRuntimeImportServicePersistsOperationalState(t *testing.T) {
if accountStatus != "passed" {
t.Fatalf("account_status = %q, want passed", accountStatus)
}
var inventoryStatus string
var inventoryShadowGroup string
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT account_status, shadow_group_id FROM provider_accounts WHERE host_account_id = ? ORDER BY id LIMIT 1", "account_1").Scan(&inventoryStatus, &inventoryShadowGroup); err != nil {
t.Fatalf("query provider account inventory: %v", err)
}
if inventoryStatus != sqlite.ProviderAccountStatusActive {
t.Fatalf("provider_accounts.account_status = %q, want %q", inventoryStatus, sqlite.ProviderAccountStatusActive)
}
if inventoryShadowGroup == "" {
t.Fatal("provider_accounts.shadow_group_id = empty, want group id")
}
}
func TestRuntimeImportServiceIncludesMatchingHostOverlaysInReport(t *testing.T) {

View File

@@ -0,0 +1,25 @@
CREATE TABLE provider_accounts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
host_id INTEGER NOT NULL,
provider_id INTEGER NOT NULL,
route_id TEXT NOT NULL DEFAULT '',
shadow_group_id TEXT NOT NULL DEFAULT '',
host_account_id TEXT NOT NULL,
key_fingerprint TEXT NOT NULL,
account_name TEXT NOT NULL DEFAULT '',
account_status TEXT NOT NULL,
last_probe_status TEXT NOT NULL DEFAULT '',
last_probe_at TEXT NOT NULL DEFAULT '',
disabled_reason TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT fk_provider_accounts_host FOREIGN KEY (host_id) REFERENCES hosts(id) ON DELETE CASCADE,
CONSTRAINT fk_provider_accounts_provider FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE,
CONSTRAINT uq_provider_accounts_host_account UNIQUE (host_id, host_account_id),
CHECK (account_status IN ('active', 'disabled', 'deprecated', 'broken'))
);
CREATE INDEX idx_provider_accounts_provider_host ON provider_accounts(provider_id, host_id);
CREATE INDEX idx_provider_accounts_status ON provider_accounts(account_status);
CREATE INDEX idx_provider_accounts_route_id ON provider_accounts(route_id);
CREATE INDEX idx_provider_accounts_shadow_group_id ON provider_accounts(shadow_group_id);

View File

@@ -30,6 +30,7 @@ type Queries struct {
RouteDecisionLogs *RouteDecisionLogsRepo
RouteFailoverEvents *RouteFailoverEventsRepo
RouteStickyAudit *RouteStickyAuditRepo
ProviderAccounts *ProviderAccountsRepo
ProviderDrafts *ProviderDraftsRepo
ImportBatches *ImportBatchesRepo
ImportBatchItems *ImportBatchItemsRepo
@@ -127,6 +128,10 @@ func (db *DB) RouteStickyAudit() *RouteStickyAuditRepo {
return db.queries.RouteStickyAudit
}
func (db *DB) ProviderAccounts() *ProviderAccountsRepo {
return db.queries.ProviderAccounts
}
func (db *DB) ProviderDrafts() *ProviderDraftsRepo {
return db.queries.ProviderDrafts
}
@@ -206,6 +211,7 @@ func newQueries(db execQuerier) *Queries {
RouteDecisionLogs: newRouteDecisionLogsRepo(db),
RouteFailoverEvents: newRouteFailoverEventsRepo(db),
RouteStickyAudit: newRouteStickyAuditRepo(db),
ProviderAccounts: newProviderAccountsRepo(db),
ProviderDrafts: newProviderDraftsRepo(db),
ImportBatches: newImportBatchesRepo(db),
ImportBatchItems: newImportBatchItemsRepo(db),

View File

@@ -114,6 +114,7 @@ func TestOpenAppliesLogicalRoutingTables(t *testing.T) {
"route_decision_logs",
"route_failover_events",
"route_sticky_audit",
"provider_accounts",
} {
found, err := tableExists(context.Background(), db, table)
if err != nil {

View File

@@ -0,0 +1,438 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"strings"
)
const (
ProviderAccountStatusActive = "active"
ProviderAccountStatusDisabled = "disabled"
ProviderAccountStatusDeprecated = "deprecated"
ProviderAccountStatusBroken = "broken"
)
type ProviderAccount struct {
ID int64
HostID int64
ProviderID int64
RouteID string
ShadowGroupID string
HostAccountID string
KeyFingerprint string
AccountName string
AccountStatus string
LastProbeStatus string
LastProbeAt string
DisabledReason string
CreatedAt string
UpdatedAt string
}
type ProviderAccountListFilter struct {
HostID string
ProviderID string
RouteID string
ShadowGroupID string
AccountStatus string
Query string
Limit int
}
type ProviderAccountView struct {
ID int64 `json:"id"`
HostID string `json:"host_id"`
ProviderID string `json:"provider_id"`
ProviderName string `json:"provider_name"`
RouteID string `json:"route_id,omitempty"`
LogicalGroupID string `json:"logical_group_id,omitempty"`
ShadowGroupID string `json:"shadow_group_id,omitempty"`
HostAccountID string `json:"host_account_id"`
KeyFingerprint string `json:"key_fingerprint"`
AccountName string `json:"account_name"`
AccountStatus string `json:"account_status"`
LastProbeStatus string `json:"last_probe_status,omitempty"`
LastProbeAt string `json:"last_probe_at,omitempty"`
DisabledReason string `json:"disabled_reason,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
type ProviderAccountsRepo struct {
db execQuerier
}
func newProviderAccountsRepo(db execQuerier) *ProviderAccountsRepo {
return &ProviderAccountsRepo{db: db}
}
func (r *ProviderAccountsRepo) Create(ctx context.Context, account ProviderAccount) (int64, error) {
account, err := normalizeProviderAccount(account)
if err != nil {
return 0, err
}
result, err := r.db.ExecContext(ctx, `INSERT INTO provider_accounts (
host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint,
account_name, account_status, last_probe_status, last_probe_at, disabled_reason
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
account.HostID,
account.ProviderID,
account.RouteID,
account.ShadowGroupID,
account.HostAccountID,
account.KeyFingerprint,
account.AccountName,
account.AccountStatus,
account.LastProbeStatus,
account.LastProbeAt,
account.DisabledReason,
)
if err != nil {
return 0, fmt.Errorf("insert provider account %q: %w", account.HostAccountID, err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted provider account id for %q: %w", account.HostAccountID, err)
}
return id, nil
}
func (r *ProviderAccountsRepo) Upsert(ctx context.Context, account ProviderAccount) (int64, error) {
account, err := normalizeProviderAccount(account)
if err != nil {
return 0, err
}
result, err := r.db.ExecContext(ctx, `INSERT INTO provider_accounts (
host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint,
account_name, account_status, last_probe_status, last_probe_at, disabled_reason
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(host_id, host_account_id) DO UPDATE SET
provider_id = excluded.provider_id,
route_id = excluded.route_id,
shadow_group_id = excluded.shadow_group_id,
key_fingerprint = excluded.key_fingerprint,
account_name = excluded.account_name,
account_status = excluded.account_status,
last_probe_status = excluded.last_probe_status,
last_probe_at = excluded.last_probe_at,
disabled_reason = excluded.disabled_reason,
updated_at = CURRENT_TIMESTAMP`,
account.HostID,
account.ProviderID,
account.RouteID,
account.ShadowGroupID,
account.HostAccountID,
account.KeyFingerprint,
account.AccountName,
account.AccountStatus,
account.LastProbeStatus,
account.LastProbeAt,
account.DisabledReason,
)
if err != nil {
return 0, fmt.Errorf("upsert provider account %q: %w", account.HostAccountID, err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read provider account last insert id for %q: %w", account.HostAccountID, err)
}
if existing, err := r.GetByHostIDAndAccountID(ctx, account.HostID, account.HostAccountID); err == nil {
return existing.ID, nil
}
return id, nil
}
func (r *ProviderAccountsRepo) GetByID(ctx context.Context, id int64) (ProviderAccount, error) {
if id <= 0 {
return ProviderAccount{}, fmt.Errorf("id is required")
}
return r.scanOne(ctx, `SELECT id, host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint, account_name, account_status, last_probe_status, last_probe_at, disabled_reason, created_at, updated_at FROM provider_accounts WHERE id = ?`, id)
}
func (r *ProviderAccountsRepo) GetByHostIDAndAccountID(ctx context.Context, hostID int64, hostAccountID string) (ProviderAccount, error) {
if hostID <= 0 {
return ProviderAccount{}, fmt.Errorf("host_id is required")
}
hostAccountID = strings.TrimSpace(hostAccountID)
if hostAccountID == "" {
return ProviderAccount{}, fmt.Errorf("host_account_id is required")
}
return r.scanOne(ctx, `SELECT id, host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint, account_name, account_status, last_probe_status, last_probe_at, disabled_reason, created_at, updated_at FROM provider_accounts WHERE host_id = ? AND host_account_id = ?`, hostID, hostAccountID)
}
func (r *ProviderAccountsRepo) GetViewByID(ctx context.Context, id int64) (ProviderAccountView, error) {
if id <= 0 {
return ProviderAccountView{}, fmt.Errorf("id is required")
}
return r.scanViewOne(ctx, `SELECT
pa.id,
h.host_id,
p.provider_id,
p.display_name,
COALESCE(pa.route_id, ''),
COALESCE(lgr.logical_group_id, ''),
COALESCE(pa.shadow_group_id, ''),
pa.host_account_id,
pa.key_fingerprint,
pa.account_name,
pa.account_status,
COALESCE(pa.last_probe_status, ''),
COALESCE(pa.last_probe_at, ''),
COALESCE(pa.disabled_reason, ''),
pa.created_at,
pa.updated_at
FROM provider_accounts pa
JOIN hosts h ON h.id = pa.host_id
JOIN providers p ON p.id = pa.provider_id
LEFT JOIN logical_group_routes lgr ON lgr.route_id = pa.route_id
WHERE pa.id = ?`, id)
}
func (r *ProviderAccountsRepo) UpdateStatusByID(ctx context.Context, id int64, accountStatus, disabledReason string) error {
if id <= 0 {
return fmt.Errorf("id is required")
}
accountStatus = normalizeProviderAccountStatus(accountStatus)
if accountStatus == "" {
return fmt.Errorf("account_status is required")
}
disabledReason = strings.TrimSpace(disabledReason)
result, err := r.db.ExecContext(ctx, `UPDATE provider_accounts SET account_status = ?, disabled_reason = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, accountStatus, disabledReason, id)
if err != nil {
return fmt.Errorf("update provider account %d status: %w", id, err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("provider account %d rows affected: %w", id, err)
}
if rows == 0 {
return sql.ErrNoRows
}
return nil
}
func (r *ProviderAccountsRepo) DeprecateMissingForScope(ctx context.Context, providerID, hostID int64, keepHostAccountIDs []string, reason string) error {
if providerID <= 0 {
return fmt.Errorf("provider_id is required")
}
if hostID <= 0 {
return fmt.Errorf("host_id is required")
}
reason = strings.TrimSpace(reason)
args := []any{ProviderAccountStatusDeprecated, reason, providerID, hostID}
query := `UPDATE provider_accounts
SET account_status = ?, disabled_reason = ?, updated_at = CURRENT_TIMESTAMP
WHERE provider_id = ? AND host_id = ?`
if len(keepHostAccountIDs) > 0 {
placeholders := make([]string, 0, len(keepHostAccountIDs))
for _, id := range keepHostAccountIDs {
trimmed := strings.TrimSpace(id)
if trimmed == "" {
continue
}
placeholders = append(placeholders, "?")
args = append(args, trimmed)
}
if len(placeholders) > 0 {
query += ` AND host_account_id NOT IN (` + strings.Join(placeholders, ",") + `)`
}
}
query += ` AND account_status IN ('active', 'broken')`
if _, err := r.db.ExecContext(ctx, query, args...); err != nil {
return fmt.Errorf("deprecate missing provider accounts for provider_id=%d host_id=%d: %w", providerID, hostID, err)
}
return nil
}
func (r *ProviderAccountsRepo) List(ctx context.Context, filter ProviderAccountListFilter) ([]ProviderAccountView, error) {
query := `SELECT
pa.id,
h.host_id,
p.provider_id,
p.display_name,
COALESCE(pa.route_id, ''),
COALESCE(lgr.logical_group_id, ''),
COALESCE(pa.shadow_group_id, ''),
pa.host_account_id,
pa.key_fingerprint,
pa.account_name,
pa.account_status,
COALESCE(pa.last_probe_status, ''),
COALESCE(pa.last_probe_at, ''),
COALESCE(pa.disabled_reason, ''),
pa.created_at,
pa.updated_at
FROM provider_accounts pa
JOIN hosts h ON h.id = pa.host_id
JOIN providers p ON p.id = pa.provider_id
LEFT JOIN logical_group_routes lgr ON lgr.route_id = pa.route_id
WHERE 1 = 1`
args := make([]any, 0)
if value := strings.TrimSpace(filter.HostID); value != "" {
query += ` AND h.host_id = ?`
args = append(args, value)
}
if value := strings.TrimSpace(filter.ProviderID); value != "" {
query += ` AND p.provider_id = ?`
args = append(args, value)
}
if value := strings.TrimSpace(filter.RouteID); value != "" {
query += ` AND pa.route_id = ?`
args = append(args, value)
}
if value := strings.TrimSpace(filter.ShadowGroupID); value != "" {
query += ` AND pa.shadow_group_id = ?`
args = append(args, value)
}
if value := normalizeProviderAccountStatus(filter.AccountStatus); value != "" {
query += ` AND pa.account_status = ?`
args = append(args, value)
}
if value := strings.TrimSpace(filter.Query); value != "" {
like := "%" + strings.ToLower(value) + "%"
query += ` AND (
LOWER(pa.host_account_id) LIKE ? OR
LOWER(pa.account_name) LIKE ? OR
LOWER(pa.key_fingerprint) LIKE ? OR
LOWER(p.provider_id) LIKE ? OR
LOWER(h.host_id) LIKE ?
)`
args = append(args, like, like, like, like, like)
}
query += ` ORDER BY pa.updated_at DESC, pa.id DESC`
limit := filter.Limit
if limit <= 0 {
limit = 200
}
query += ` LIMIT ?`
args = append(args, limit)
rows, err := r.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("list provider accounts: %w", err)
}
defer rows.Close()
views := make([]ProviderAccountView, 0)
for rows.Next() {
var view ProviderAccountView
if err := rows.Scan(
&view.ID,
&view.HostID,
&view.ProviderID,
&view.ProviderName,
&view.RouteID,
&view.LogicalGroupID,
&view.ShadowGroupID,
&view.HostAccountID,
&view.KeyFingerprint,
&view.AccountName,
&view.AccountStatus,
&view.LastProbeStatus,
&view.LastProbeAt,
&view.DisabledReason,
&view.CreatedAt,
&view.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan provider account view: %w", err)
}
views = append(views, view)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate provider accounts: %w", err)
}
return views, nil
}
func (r *ProviderAccountsRepo) scanOne(ctx context.Context, query string, args ...any) (ProviderAccount, error) {
var row ProviderAccount
if err := r.db.QueryRowContext(ctx, query, args...).Scan(
&row.ID,
&row.HostID,
&row.ProviderID,
&row.RouteID,
&row.ShadowGroupID,
&row.HostAccountID,
&row.KeyFingerprint,
&row.AccountName,
&row.AccountStatus,
&row.LastProbeStatus,
&row.LastProbeAt,
&row.DisabledReason,
&row.CreatedAt,
&row.UpdatedAt,
); err != nil {
return ProviderAccount{}, err
}
return row, nil
}
func (r *ProviderAccountsRepo) scanViewOne(ctx context.Context, query string, args ...any) (ProviderAccountView, error) {
var view ProviderAccountView
if err := r.db.QueryRowContext(ctx, query, args...).Scan(
&view.ID,
&view.HostID,
&view.ProviderID,
&view.ProviderName,
&view.RouteID,
&view.LogicalGroupID,
&view.ShadowGroupID,
&view.HostAccountID,
&view.KeyFingerprint,
&view.AccountName,
&view.AccountStatus,
&view.LastProbeStatus,
&view.LastProbeAt,
&view.DisabledReason,
&view.CreatedAt,
&view.UpdatedAt,
); err != nil {
return ProviderAccountView{}, err
}
return view, nil
}
func normalizeProviderAccount(account ProviderAccount) (ProviderAccount, error) {
account.RouteID = strings.TrimSpace(account.RouteID)
account.ShadowGroupID = strings.TrimSpace(account.ShadowGroupID)
account.HostAccountID = strings.TrimSpace(account.HostAccountID)
account.KeyFingerprint = strings.TrimSpace(account.KeyFingerprint)
account.AccountName = strings.TrimSpace(account.AccountName)
account.AccountStatus = normalizeProviderAccountStatus(account.AccountStatus)
account.LastProbeStatus = strings.TrimSpace(account.LastProbeStatus)
account.LastProbeAt = strings.TrimSpace(account.LastProbeAt)
account.DisabledReason = strings.TrimSpace(account.DisabledReason)
switch {
case account.HostID <= 0:
return ProviderAccount{}, fmt.Errorf("host_id is required")
case account.ProviderID <= 0:
return ProviderAccount{}, fmt.Errorf("provider_id is required")
case account.HostAccountID == "":
return ProviderAccount{}, fmt.Errorf("host_account_id is required")
case account.KeyFingerprint == "":
return ProviderAccount{}, fmt.Errorf("key_fingerprint is required")
case account.AccountStatus == "":
return ProviderAccount{}, fmt.Errorf("account_status is required")
}
return account, nil
}
func normalizeProviderAccountStatus(status string) string {
switch strings.TrimSpace(status) {
case ProviderAccountStatusActive:
return ProviderAccountStatusActive
case ProviderAccountStatusDisabled:
return ProviderAccountStatusDisabled
case ProviderAccountStatusDeprecated:
return ProviderAccountStatusDeprecated
case ProviderAccountStatusBroken:
return ProviderAccountStatusBroken
default:
return ""
}
}

View File

@@ -0,0 +1,286 @@
package sqlite
import (
"context"
"testing"
)
func TestProviderAccountsRepoCRUDAndFilters(t *testing.T) {
t.Parallel()
store := openTestDBWithFK(t)
ctx := context.Background()
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID, err := store.Providers().Create(ctx, Provider{
PackID: packID,
ProviderID: "deepseek-official",
DisplayName: "DeepSeek Official",
BaseURL: "https://api.deepseek.com",
Platform: "openai",
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
if _, err := store.LogicalGroups().Create(ctx, LogicalGroup{LogicalGroupID: "lg-1", DisplayName: "LG 1", Status: "active"}); err != nil {
t.Fatalf("LogicalGroups().Create() error = %v", err)
}
if _, err := store.LogicalGroupRoutes().Create(ctx, LogicalGroupRoute{
RouteID: "route-1",
LogicalGroupID: "lg-1",
Name: "Route 1",
Status: "active",
Priority: 10,
Weight: 100,
ShadowGroupID: "shadow-group-1",
ShadowHostID: "shadow-host-1",
}); err != nil {
t.Fatalf("LogicalGroupRoutes().Create() error = %v", err)
}
accountRepo := store.ProviderAccounts()
accountID, err := accountRepo.Create(ctx, ProviderAccount{
HostID: hostID,
ProviderID: providerID,
RouteID: "route-1",
ShadowGroupID: "shadow-group-1",
HostAccountID: "account-1",
KeyFingerprint: "sha256:abc",
AccountName: "deepseek-01",
AccountStatus: ProviderAccountStatusActive,
LastProbeStatus: "passed",
LastProbeAt: "2026-05-29T00:00:00Z",
})
if err != nil {
t.Fatalf("ProviderAccounts().Create() error = %v", err)
}
got, err := accountRepo.GetByID(ctx, accountID)
if err != nil {
t.Fatalf("ProviderAccounts().GetByID() error = %v", err)
}
if got.HostAccountID != "account-1" || got.AccountStatus != ProviderAccountStatusActive {
t.Fatalf("ProviderAccounts().GetByID() = %+v", got)
}
if _, err := accountRepo.Upsert(ctx, ProviderAccount{
HostID: hostID,
ProviderID: providerID,
RouteID: "route-1",
ShadowGroupID: "shadow-group-1",
HostAccountID: "account-1",
KeyFingerprint: "sha256:abc",
AccountName: "deepseek-01",
AccountStatus: ProviderAccountStatusBroken,
LastProbeStatus: "failed",
LastProbeAt: "2026-05-29T01:00:00Z",
}); err != nil {
t.Fatalf("ProviderAccounts().Upsert() error = %v", err)
}
view, err := accountRepo.GetViewByID(ctx, accountID)
if err != nil {
t.Fatalf("ProviderAccounts().GetViewByID() error = %v", err)
}
if view.ProviderID != "deepseek-official" || view.LogicalGroupID != "lg-1" || view.AccountStatus != ProviderAccountStatusBroken {
t.Fatalf("ProviderAccounts().GetViewByID() = %+v", view)
}
rows, err := accountRepo.List(ctx, ProviderAccountListFilter{
HostID: "host-" + sanitizeTestName(t.Name()),
ProviderID: "deepseek-official",
RouteID: "route-1",
ShadowGroupID: "shadow-group-1",
AccountStatus: ProviderAccountStatusBroken,
Query: "deepseek",
})
if err != nil {
t.Fatalf("ProviderAccounts().List() error = %v", err)
}
if len(rows) != 1 || rows[0].ID != accountID {
t.Fatalf("ProviderAccounts().List() = %+v, want one row for account_id %d", rows, accountID)
}
if err := accountRepo.UpdateStatusByID(ctx, accountID, ProviderAccountStatusDisabled, "manual_disable"); err != nil {
t.Fatalf("ProviderAccounts().UpdateStatusByID() error = %v", err)
}
got, err = accountRepo.GetByID(ctx, accountID)
if err != nil {
t.Fatalf("ProviderAccounts().GetByID() after status update error = %v", err)
}
if got.AccountStatus != ProviderAccountStatusDisabled || got.DisabledReason != "manual_disable" {
t.Fatalf("ProviderAccounts().GetByID() after status update = %+v", got)
}
}
func TestSyncProviderAccountsFromImportBatchCreatesAndDeprecatesInventory(t *testing.T) {
t.Parallel()
store := openTestDBWithFK(t)
ctx := context.Background()
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID, err := store.Providers().Create(ctx, Provider{
PackID: packID,
ProviderID: "asxs-provider",
DisplayName: "ASXS Provider",
BaseURL: "https://api.asxs.top/v1",
Platform: "openai",
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
batch1, err := store.ImportBatches().Create(ctx, ImportBatch{
HostID: hostID,
PackID: packID,
ProviderID: providerID,
Mode: "strict",
BatchStatus: "succeeded",
AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("ImportBatches().Create(batch1) error = %v", err)
}
if _, err := store.ImportBatchItems().Create(ctx, ImportBatchItem{
BatchID: batch1,
KeyFingerprint: "sha256:key1",
AccountStatus: "passed",
ProbeSummaryJSON: `{"account_id":"account-1","probe_status":"passed"}`,
}); err != nil {
t.Fatalf("ImportBatchItems().Create(batch1) error = %v", err)
}
for _, resource := range []ManagedResource{
{BatchID: batch1, HostID: hostID, ResourceType: "group", HostResourceID: "group-1", ResourceName: "ASXS Group"},
{BatchID: batch1, HostID: hostID, ResourceType: "account", HostResourceID: "account-1", ResourceName: "asxs-01"},
} {
if _, err := store.ManagedResources().Create(ctx, resource); err != nil {
t.Fatalf("ManagedResources().Create(batch1/%s) error = %v", resource.ResourceType, err)
}
}
if err := SyncProviderAccountsFromImportBatch(ctx, store, batch1); err != nil {
t.Fatalf("SyncProviderAccountsFromImportBatch(batch1) error = %v", err)
}
account1, err := store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-1")
if err != nil {
t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID(account-1) error = %v", err)
}
if account1.AccountStatus != ProviderAccountStatusActive || account1.ShadowGroupID != "group-1" {
t.Fatalf("account-1 = %+v, want active shadow group-1", account1)
}
batch2, err := store.ImportBatches().Create(ctx, ImportBatch{
HostID: hostID,
PackID: packID,
ProviderID: providerID,
Mode: "strict",
BatchStatus: "succeeded",
AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("ImportBatches().Create(batch2) error = %v", err)
}
if _, err := store.ImportBatchItems().Create(ctx, ImportBatchItem{
BatchID: batch2,
KeyFingerprint: "sha256:key2",
AccountStatus: "failed",
ProbeSummaryJSON: `{"account_id":"account-2","probe_status":"failed"}`,
}); err != nil {
t.Fatalf("ImportBatchItems().Create(batch2) error = %v", err)
}
for _, resource := range []ManagedResource{
{BatchID: batch2, HostID: hostID, ResourceType: "group", HostResourceID: "group-2", ResourceName: "ASXS Group 2"},
{BatchID: batch2, HostID: hostID, ResourceType: "account", HostResourceID: "account-2", ResourceName: "asxs-02"},
} {
if _, err := store.ManagedResources().Create(ctx, resource); err != nil {
t.Fatalf("ManagedResources().Create(batch2/%s) error = %v", resource.ResourceType, err)
}
}
if err := SyncProviderAccountsFromImportBatch(ctx, store, batch2); err != nil {
t.Fatalf("SyncProviderAccountsFromImportBatch(batch2) error = %v", err)
}
account1, err = store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-1")
if err != nil {
t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID(account-1 after batch2) error = %v", err)
}
if account1.AccountStatus != ProviderAccountStatusDeprecated || account1.DisabledReason != providerAccountDeprecatedMissingReason {
t.Fatalf("account-1 after batch2 = %+v, want deprecated missing_from_latest_batch", account1)
}
account2, err := store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-2")
if err != nil {
t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID(account-2) error = %v", err)
}
if account2.AccountStatus != ProviderAccountStatusBroken || account2.LastProbeStatus != "failed" {
t.Fatalf("account-2 = %+v, want broken failed", account2)
}
}
func TestSyncProviderAccountsFromImportBatchPreservesManualDisabledStatus(t *testing.T) {
t.Parallel()
store := openTestDBWithFK(t)
ctx := context.Background()
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID, err := store.Providers().Create(ctx, Provider{
PackID: packID,
ProviderID: "asxs-provider",
DisplayName: "ASXS Provider",
BaseURL: "https://api.asxs.top/v1",
Platform: "openai",
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
batchID, err := store.ImportBatches().Create(ctx, ImportBatch{
HostID: hostID,
PackID: packID,
ProviderID: providerID,
Mode: "strict",
BatchStatus: "succeeded",
AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("ImportBatches().Create() error = %v", err)
}
if _, err := store.ImportBatchItems().Create(ctx, ImportBatchItem{
BatchID: batchID,
KeyFingerprint: "sha256:key1",
AccountStatus: "passed",
ProbeSummaryJSON: `{"account_id":"account-1","probe_status":"passed"}`,
}); err != nil {
t.Fatalf("ImportBatchItems().Create() error = %v", err)
}
for _, resource := range []ManagedResource{
{BatchID: batchID, HostID: hostID, ResourceType: "group", HostResourceID: "group-1", ResourceName: "ASXS Group"},
{BatchID: batchID, HostID: hostID, ResourceType: "account", HostResourceID: "account-1", ResourceName: "asxs-01"},
} {
if _, err := store.ManagedResources().Create(ctx, resource); err != nil {
t.Fatalf("ManagedResources().Create(%s) error = %v", resource.ResourceType, err)
}
}
if err := SyncProviderAccountsFromImportBatch(ctx, store, batchID); err != nil {
t.Fatalf("SyncProviderAccountsFromImportBatch() error = %v", err)
}
account, err := store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-1")
if err != nil {
t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID() error = %v", err)
}
if err := store.ProviderAccounts().UpdateStatusByID(ctx, account.ID, ProviderAccountStatusDisabled, "manual_disable"); err != nil {
t.Fatalf("ProviderAccounts().UpdateStatusByID() error = %v", err)
}
if err := SyncProviderAccountsFromImportBatch(ctx, store, batchID); err != nil {
t.Fatalf("SyncProviderAccountsFromImportBatch(second) error = %v", err)
}
account, err = store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-1")
if err != nil {
t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID(second) error = %v", err)
}
if account.AccountStatus != ProviderAccountStatusDisabled || account.DisabledReason != "manual_disable" {
t.Fatalf("account after resync = %+v, want disabled manual_disable preserved", account)
}
}

View File

@@ -0,0 +1,181 @@
package sqlite
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
)
const providerAccountDeprecatedMissingReason = "missing_from_latest_batch"
func SyncProviderAccountsFromLatestImportBatches(ctx context.Context, store *DB) error {
if store == nil {
return fmt.Errorf("store is required")
}
batches, err := store.ImportBatches().ListLatestReconcilable(ctx)
if err != nil {
return err
}
for _, batch := range batches {
if err := SyncProviderAccountsFromImportBatch(ctx, store, batch.ID); err != nil {
return err
}
}
return nil
}
func SyncProviderAccountsFromImportBatch(ctx context.Context, store *DB, batchID int64) error {
if store == nil {
return fmt.Errorf("store is required")
}
if batchID <= 0 {
return fmt.Errorf("batch_id is required")
}
batch, err := store.ImportBatches().GetByID(ctx, batchID)
if err != nil {
return fmt.Errorf("get import batch %d: %w", batchID, err)
}
switch strings.TrimSpace(batch.BatchStatus) {
case "succeeded", "partially_succeeded":
default:
return nil
}
resources, err := store.ManagedResources().GetByBatchID(ctx, batchID)
if err != nil {
return fmt.Errorf("get managed resources for batch %d: %w", batchID, err)
}
items, err := store.ImportBatchItems().GetByBatchID(ctx, batchID)
if err != nil {
return fmt.Errorf("get import batch items for batch %d: %w", batchID, err)
}
nowText := time.Now().UTC().Format(time.RFC3339)
shadowGroupID := ""
for _, resource := range resources {
if strings.TrimSpace(resource.ResourceType) == "group" {
shadowGroupID = strings.TrimSpace(resource.HostResourceID)
break
}
}
accountResources := make([]ManagedResource, 0)
for _, resource := range resources {
if strings.TrimSpace(resource.ResourceType) == "account" {
accountResources = append(accountResources, resource)
}
}
itemByAccountID, unmatchedItems := indexBatchItemsByAccountID(items)
keepAccountIDs := make([]string, 0, len(accountResources))
for index, resource := range accountResources {
hostAccountID := strings.TrimSpace(resource.HostResourceID)
if hostAccountID == "" {
continue
}
keepAccountIDs = append(keepAccountIDs, hostAccountID)
match, ok := itemByAccountID[hostAccountID]
if !ok && index < len(unmatchedItems) {
match = unmatchedItems[index]
}
row := ProviderAccount{
HostID: batch.HostID,
ProviderID: batch.ProviderID,
ShadowGroupID: shadowGroupID,
HostAccountID: hostAccountID,
KeyFingerprint: fallbackString(match.KeyFingerprint, "legacy:"+hostAccountID),
AccountName: fallbackString(resource.ResourceName, hostAccountID),
AccountStatus: providerAccountStatusFromLegacy(match.AccountStatus),
LastProbeStatus: strings.TrimSpace(match.ProbeStatus),
LastProbeAt: nowText,
}
if existing, err := store.ProviderAccounts().GetByHostIDAndAccountID(ctx, batch.HostID, hostAccountID); err == nil {
if strings.TrimSpace(existing.RouteID) != "" {
row.RouteID = existing.RouteID
}
if strings.TrimSpace(existing.ShadowGroupID) != "" {
row.ShadowGroupID = existing.ShadowGroupID
}
preserveManagedProviderAccountStatus(&row, existing)
}
if _, err := store.ProviderAccounts().Upsert(ctx, row); err != nil {
return fmt.Errorf("upsert provider account %q from batch %d: %w", hostAccountID, batchID, err)
}
}
if err := store.ProviderAccounts().DeprecateMissingForScope(ctx, batch.ProviderID, batch.HostID, keepAccountIDs, providerAccountDeprecatedMissingReason); err != nil {
return err
}
return nil
}
type legacyBatchAccountProjection struct {
KeyFingerprint string
AccountStatus string
ProbeStatus string
AccountID string
}
func indexBatchItemsByAccountID(items []ImportBatchItem) (map[string]legacyBatchAccountProjection, []legacyBatchAccountProjection) {
indexed := make(map[string]legacyBatchAccountProjection, len(items))
unmatched := make([]legacyBatchAccountProjection, 0, len(items))
for _, item := range items {
projection := legacyBatchAccountProjection{
KeyFingerprint: strings.TrimSpace(item.KeyFingerprint),
AccountStatus: strings.TrimSpace(item.AccountStatus),
}
var payload map[string]any
if err := json.Unmarshal([]byte(defaultJSON(strings.TrimSpace(item.ProbeSummaryJSON), "{}")), &payload); err == nil {
if value, ok := payload["probe_status"].(string); ok {
projection.ProbeStatus = strings.TrimSpace(value)
}
if value, ok := payload["account_id"].(string); ok {
projection.AccountID = strings.TrimSpace(value)
}
}
if projection.AccountID != "" {
indexed[projection.AccountID] = projection
continue
}
unmatched = append(unmatched, projection)
}
return indexed, unmatched
}
func providerAccountStatusFromLegacy(accountStatus string) string {
switch strings.TrimSpace(accountStatus) {
case "passed", "warning":
return ProviderAccountStatusActive
case ProviderAccountStatusDisabled:
return ProviderAccountStatusDisabled
case ProviderAccountStatusDeprecated:
return ProviderAccountStatusDeprecated
default:
return ProviderAccountStatusBroken
}
}
func fallbackString(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
func preserveManagedProviderAccountStatus(row *ProviderAccount, existing ProviderAccount) {
if row == nil {
return
}
switch strings.TrimSpace(existing.AccountStatus) {
case ProviderAccountStatusDisabled:
row.AccountStatus = ProviderAccountStatusDisabled
row.DisabledReason = strings.TrimSpace(existing.DisabledReason)
case ProviderAccountStatusDeprecated:
if strings.TrimSpace(existing.DisabledReason) != providerAccountDeprecatedMissingReason {
row.AccountStatus = ProviderAccountStatusDeprecated
row.DisabledReason = strings.TrimSpace(existing.DisabledReason)
}
}
}

View File

@@ -144,6 +144,7 @@ func TestStoreAppliesLatestMigration(t *testing.T) {
"route_decision_logs",
"route_failover_events",
"route_sticky_audit",
"provider_accounts",
} {
if !tableExists(t, store.SQLDB(), table) {
t.Fatalf("table %q does not exist after latest migration", table)
@@ -272,6 +273,23 @@ func TestStoreAppliesLatestMigration(t *testing.T) {
t.Fatalf("column %q missing from route_sticky_audit", column)
}
}
for _, column := range []string{
"host_id",
"provider_id",
"route_id",
"shadow_group_id",
"host_account_id",
"key_fingerprint",
"account_status",
"last_probe_status",
"last_probe_at",
"disabled_reason",
} {
if !tableColumnExists(t, store.SQLDB(), "provider_accounts", column) {
t.Fatalf("column %q missing from provider_accounts", column)
}
}
}
func TestStoreInitEnforcesLogicalRoutingConstraints(t *testing.T) {