Files
sub2api-cn-relay-manager/internal/store/sqlite/provider_accounts_repo.go
2026-05-29 19:07:01 +08:00

558 lines
18 KiB
Go

package sqlite
import (
"context"
"database/sql"
"fmt"
"strings"
)
// Values accepted for provider_accounts.account_status and for the derived
// binding state reported by the account views in this file.
const (
// Account status values. normalizeProviderAccountStatus accepts exactly these
// four strings and maps anything else to "".
ProviderAccountStatusActive = "active"
ProviderAccountStatusDisabled = "disabled"
// ProviderAccountStatusDeprecated is the status DeprecateMissingForScope
// writes to accounts that are not in its keep list.
ProviderAccountStatusDeprecated = "deprecated"
ProviderAccountStatusBroken = "broken"
// Binding states. They are computed by the view queries rather than stored:
// "assigned" when route_id is non-empty, "conflict" when route_id is empty and
// more than one logical_group_routes row matches the account's host and shadow
// group, and "unassigned" otherwise.
ProviderAccountBindingStateAssigned = "assigned"
ProviderAccountBindingStateUnassigned = "unassigned"
ProviderAccountBindingStateConflict = "conflict"
)
// ProviderAccount mirrors one row of the provider_accounts table as read by
// scanOne and written by Create and Upsert.
type ProviderAccount struct {
// ID is the row id of the provider_accounts row.
ID int64
// HostID references hosts.id (the view queries join hosts on h.id = pa.host_id).
HostID int64
// ProviderID references providers.id (joined on p.id = pa.provider_id).
ProviderID int64
// RouteID matches logical_group_routes.route_id when set; an empty value is
// reported as not assigned by the view queries.
RouteID string
ShadowGroupID string
// HostAccountID identifies the account on its host; together with HostID it
// is the conflict target used by Upsert.
HostAccountID string
// KeyFingerprint is required by normalizeProviderAccount.
KeyFingerprint string
AccountName string
// AccountStatus must be one of the ProviderAccountStatus* values.
AccountStatus string
LastProbeStatus string
LastProbeAt string
DisabledReason string
// CreatedAt and UpdatedAt hold the created_at / updated_at column values as
// returned by the database; the update statements in this file set
// updated_at to CURRENT_TIMESTAMP.
CreatedAt string
UpdatedAt string
}
// ProviderAccountListFilter narrows the result of ProviderAccountsRepo.List.
// String fields are trimmed by List and ignored when empty.
type ProviderAccountListFilter struct {
// HostID is compared with hosts.host_id (not the numeric hosts.id).
HostID string
// ProviderID is compared with providers.provider_id.
ProviderID string
// LogicalGroupID is compared with logical_group_routes.logical_group_id of
// the route the account is bound to.
LogicalGroupID string
// RouteID is compared with provider_accounts.route_id.
RouteID string
// ShadowGroupID is compared with provider_accounts.shadow_group_id.
ShadowGroupID string
// AccountStatus must be one of the ProviderAccountStatus* values; any other
// value normalizes to "" and therefore applies no filter.
AccountStatus string
// BindingState must be one of the ProviderAccountBindingState* values; any
// other value normalizes to "" and therefore applies no filter.
BindingState string
// Query is lower-cased and matched with LIKE against host account id,
// account name, key fingerprint, provider id, host id, logical group id and
// route name.
Query string
// Limit caps the number of returned rows; values <= 0 fall back to 200.
Limit int
}
// ProviderAccountView is the read model produced by GetViewByID and List: a
// provider_accounts row joined with its host, its provider and (optionally,
// via LEFT JOIN) the logical_group_routes row it is bound to.
type ProviderAccountView struct {
// ID is provider_accounts.id.
ID int64 `json:"id"`
// HostID is hosts.host_id.
HostID string `json:"host_id"`
// HostBaseURL is hosts.base_url.
HostBaseURL string `json:"host_base_url"`
// ProviderID is providers.provider_id.
ProviderID string `json:"provider_id"`
// ProviderName is providers.display_name.
ProviderName string `json:"provider_name"`
// RouteName is logical_group_routes.name, or "" when no route row is joined.
RouteName string `json:"route_name,omitempty"`
// RouteID is provider_accounts.route_id, or "" when it is NULL.
RouteID string `json:"route_id,omitempty"`
// LogicalGroupID is logical_group_routes.logical_group_id, or "" when no
// route row is joined.
LogicalGroupID string `json:"logical_group_id,omitempty"`
// ShadowGroupID is provider_accounts.shadow_group_id, or "" when it is NULL.
ShadowGroupID string `json:"shadow_group_id,omitempty"`
// ShadowHostID is logical_group_routes.shadow_host_id, or "" when no route
// row is joined.
ShadowHostID string `json:"shadow_host_id,omitempty"`
// UpstreamBaseURLHint is logical_group_routes.upstream_base_url_hint, or ""
// when no route row is joined.
UpstreamBaseURLHint string `json:"upstream_base_url_hint,omitempty"`
HostAccountID string `json:"host_account_id"`
KeyFingerprint string `json:"key_fingerprint"`
AccountName string `json:"account_name"`
AccountStatus string `json:"account_status"`
// BindingState is computed by the query: "assigned", "conflict" or
// "unassigned" (see the ProviderAccountBindingState* constants).
BindingState string `json:"binding_state,omitempty"`
// BindingCandidateCount is the number of logical_group_routes rows whose
// shadow_host_id equals the host's host_id and whose shadow_group_id equals
// the account's shadow_group_id.
BindingCandidateCount int `json:"binding_candidate_count,omitempty"`
LastProbeStatus string `json:"last_probe_status,omitempty"`
LastProbeAt string `json:"last_probe_at,omitempty"`
DisabledReason string `json:"disabled_reason,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
// ProviderAccountsRepo provides reads and writes on the provider_accounts
// table.
type ProviderAccountsRepo struct {
// db is the handle every statement in this repository is executed on; the
// methods call its ExecContext, QueryContext and QueryRowContext.
db execQuerier
}
// newProviderAccountsRepo returns a repository that runs all of its statements
// on db.
func newProviderAccountsRepo(db execQuerier) *ProviderAccountsRepo {
	repo := new(ProviderAccountsRepo)
	repo.db = db
	return repo
}
// Create inserts account as a new provider_accounts row and returns the id
// reported by the driver for the inserted row.
//
// The account is normalized and validated with normalizeProviderAccount
// first; a validation error is returned unchanged. Database errors are
// wrapped with the host account id for context.
func (r *ProviderAccountsRepo) Create(ctx context.Context, account ProviderAccount) (int64, error) {
	normalized, err := normalizeProviderAccount(account)
	if err != nil {
		return 0, err
	}

	const insertSQL = `INSERT INTO provider_accounts (
host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint,
account_name, account_status, last_probe_status, last_probe_at, disabled_reason
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	// Parameter order must match the column list above.
	params := []any{
		normalized.HostID,
		normalized.ProviderID,
		normalized.RouteID,
		normalized.ShadowGroupID,
		normalized.HostAccountID,
		normalized.KeyFingerprint,
		normalized.AccountName,
		normalized.AccountStatus,
		normalized.LastProbeStatus,
		normalized.LastProbeAt,
		normalized.DisabledReason,
	}

	res, err := r.db.ExecContext(ctx, insertSQL, params...)
	if err != nil {
		return 0, fmt.Errorf("insert provider account %q: %w", normalized.HostAccountID, err)
	}

	insertedID, err := res.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("read inserted provider account id for %q: %w", normalized.HostAccountID, err)
	}
	return insertedID, nil
}
// Upsert inserts account, or — when a row with the same (host_id,
// host_account_id) already exists — overwrites that row's mutable columns and
// bumps updated_at. It returns the id of the affected row.
//
// The account is normalized and validated with normalizeProviderAccount
// first; a validation error is returned unchanged.
//
// The returned id is always resolved by reading the row back through its
// unique key. Result.LastInsertId is deliberately not used: on the
// ON CONFLICT ... DO UPDATE path no row is inserted, so SQLite's
// last-insert rowid is not updated and may refer to an unrelated earlier
// insert. If the read-back fails the error is returned instead of a
// possibly wrong id.
func (r *ProviderAccountsRepo) Upsert(ctx context.Context, account ProviderAccount) (int64, error) {
	account, err := normalizeProviderAccount(account)
	if err != nil {
		return 0, err
	}
	if _, err := r.db.ExecContext(ctx, `INSERT INTO provider_accounts (
host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint,
account_name, account_status, last_probe_status, last_probe_at, disabled_reason
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(host_id, host_account_id) DO UPDATE SET
provider_id = excluded.provider_id,
route_id = excluded.route_id,
shadow_group_id = excluded.shadow_group_id,
key_fingerprint = excluded.key_fingerprint,
account_name = excluded.account_name,
account_status = excluded.account_status,
last_probe_status = excluded.last_probe_status,
last_probe_at = excluded.last_probe_at,
disabled_reason = excluded.disabled_reason,
updated_at = CURRENT_TIMESTAMP`,
		account.HostID,
		account.ProviderID,
		account.RouteID,
		account.ShadowGroupID,
		account.HostAccountID,
		account.KeyFingerprint,
		account.AccountName,
		account.AccountStatus,
		account.LastProbeStatus,
		account.LastProbeAt,
		account.DisabledReason,
	); err != nil {
		return 0, fmt.Errorf("upsert provider account %q: %w", account.HostAccountID, err)
	}

	// The unique key is the only reliable way to find the row on both the
	// insert and the update path.
	existing, err := r.GetByHostIDAndAccountID(ctx, account.HostID, account.HostAccountID)
	if err != nil {
		return 0, fmt.Errorf("resolve provider account id for %q after upsert: %w", account.HostAccountID, err)
	}
	return existing.ID, nil
}
// GetByID loads the provider_accounts row with the given id.
//
// A non-positive id is rejected before the database is touched. Errors from
// the lookup (including the one QueryRow's Scan reports when no row matches)
// are returned as produced by scanOne.
func (r *ProviderAccountsRepo) GetByID(ctx context.Context, id int64) (ProviderAccount, error) {
	const byIDSQL = `SELECT id, host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint, account_name, account_status, last_probe_status, last_probe_at, disabled_reason, created_at, updated_at FROM provider_accounts WHERE id = ?`

	if id > 0 {
		return r.scanOne(ctx, byIDSQL, id)
	}
	return ProviderAccount{}, fmt.Errorf("id is required")
}
// GetByHostIDAndAccountID loads the provider_accounts row identified by the
// pair (host_id, host_account_id).
//
// hostID must be positive and hostAccountID must be non-empty after trimming
// surrounding whitespace; the trimmed value is what is sent to the database.
// Lookup errors are returned as produced by scanOne.
func (r *ProviderAccountsRepo) GetByHostIDAndAccountID(ctx context.Context, hostID int64, hostAccountID string) (ProviderAccount, error) {
	const byKeySQL = `SELECT id, host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint, account_name, account_status, last_probe_status, last_probe_at, disabled_reason, created_at, updated_at FROM provider_accounts WHERE host_id = ? AND host_account_id = ?`

	accountKey := strings.TrimSpace(hostAccountID)
	switch {
	case hostID <= 0:
		return ProviderAccount{}, fmt.Errorf("host_id is required")
	case accountKey == "":
		return ProviderAccount{}, fmt.Errorf("host_account_id is required")
	}
	return r.scanOne(ctx, byKeySQL, hostID, accountKey)
}
// GetViewByID loads the joined view of the provider account with the given
// id: the account row plus its host, its provider and, when route_id matches
// one, its logical_group_routes row.
//
// The query also derives the binding state ('assigned', 'conflict' or
// 'unassigned') and the number of logical_group_routes rows that match the
// account's host and shadow group. A non-positive id is rejected up front;
// lookup errors are returned as produced by scanViewOne.
func (r *ProviderAccountsRepo) GetViewByID(ctx context.Context, id int64) (ProviderAccountView, error) {
	const viewByIDSQL = `SELECT
pa.id,
h.host_id,
h.base_url,
p.provider_id,
p.display_name,
COALESCE(lgr.name, ''),
COALESCE(pa.route_id, ''),
COALESCE(lgr.logical_group_id, ''),
COALESCE(pa.shadow_group_id, ''),
COALESCE(lgr.shadow_host_id, ''),
COALESCE(lgr.upstream_base_url_hint, ''),
pa.host_account_id,
pa.key_fingerprint,
pa.account_name,
pa.account_status,
CASE
WHEN COALESCE(pa.route_id, '') <> '' THEN 'assigned'
WHEN COALESCE((
SELECT COUNT(1)
FROM logical_group_routes lgrb
WHERE lgrb.shadow_host_id = h.host_id AND lgrb.shadow_group_id = pa.shadow_group_id
), 0) > 1 THEN 'conflict'
ELSE 'unassigned'
END,
COALESCE((
SELECT COUNT(1)
FROM logical_group_routes lgrb
WHERE lgrb.shadow_host_id = h.host_id AND lgrb.shadow_group_id = pa.shadow_group_id
), 0),
COALESCE(pa.last_probe_status, ''),
COALESCE(pa.last_probe_at, ''),
COALESCE(pa.disabled_reason, ''),
pa.created_at,
pa.updated_at
FROM provider_accounts pa
JOIN hosts h ON h.id = pa.host_id
JOIN providers p ON p.id = pa.provider_id
LEFT JOIN logical_group_routes lgr ON lgr.route_id = pa.route_id
WHERE pa.id = ?`

	if invalid := id <= 0; invalid {
		return ProviderAccountView{}, fmt.Errorf("id is required")
	}
	return r.scanViewOne(ctx, viewByIDSQL, id)
}
// UpdateStatusByID sets account_status and disabled_reason on the account
// with the given id and refreshes updated_at.
//
// accountStatus must normalize to one of the ProviderAccountStatus* values;
// disabledReason is stored with surrounding whitespace removed. sql.ErrNoRows
// is returned when the statement affects no row.
func (r *ProviderAccountsRepo) UpdateStatusByID(ctx context.Context, id int64, accountStatus, disabledReason string) error {
	if id <= 0 {
		return fmt.Errorf("id is required")
	}
	status := normalizeProviderAccountStatus(accountStatus)
	if status == "" {
		return fmt.Errorf("account_status is required")
	}

	const updateSQL = `UPDATE provider_accounts SET account_status = ?, disabled_reason = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`
	res, err := r.db.ExecContext(ctx, updateSQL, status, strings.TrimSpace(disabledReason), id)
	if err != nil {
		return fmt.Errorf("update provider account %d status: %w", id, err)
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("provider account %d rows affected: %w", id, err)
	case affected == 0:
		// No row matched the id.
		return sql.ErrNoRows
	default:
		return nil
	}
}
// UpdateBindingByID stores routeID and shadowGroupID (both trimmed) on the
// account with the given id and refreshes updated_at. Empty values are
// written as-is, which the view queries report as not assigned.
//
// sql.ErrNoRows is returned when the statement affects no row.
func (r *ProviderAccountsRepo) UpdateBindingByID(ctx context.Context, id int64, routeID, shadowGroupID string) error {
	if id <= 0 {
		return fmt.Errorf("id is required")
	}

	const bindSQL = `UPDATE provider_accounts
SET route_id = ?, shadow_group_id = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = ?`

	route, shadowGroup := strings.TrimSpace(routeID), strings.TrimSpace(shadowGroupID)
	outcome, execErr := r.db.ExecContext(ctx, bindSQL, route, shadowGroup, id)
	if execErr != nil {
		return fmt.Errorf("update provider account %d binding: %w", id, execErr)
	}

	touched, countErr := outcome.RowsAffected()
	if countErr != nil {
		return fmt.Errorf("provider account %d binding rows affected: %w", id, countErr)
	}
	if touched != 0 {
		return nil
	}
	return sql.ErrNoRows
}
// DeprecateMissingForScope marks as deprecated every account of the given
// provider and host whose status is currently 'active' or 'broken' and whose
// host_account_id is not listed in keepHostAccountIDs, recording reason
// (trimmed) as the disabled reason.
//
// Entries of keepHostAccountIDs are trimmed and blank entries are dropped.
// When no usable entry remains, no exclusion is applied and every active or
// broken account in the scope is deprecated.
func (r *ProviderAccountsRepo) DeprecateMissingForScope(ctx context.Context, providerID, hostID int64, keepHostAccountIDs []string, reason string) error {
	if providerID <= 0 {
		return fmt.Errorf("provider_id is required")
	}
	if hostID <= 0 {
		return fmt.Errorf("host_id is required")
	}

	// Collect the usable keep ids first so the placeholder count is known.
	kept := make([]any, 0, len(keepHostAccountIDs))
	for _, raw := range keepHostAccountIDs {
		if name := strings.TrimSpace(raw); name != "" {
			kept = append(kept, name)
		}
	}

	var stmt strings.Builder
	stmt.WriteString(`UPDATE provider_accounts
SET account_status = ?, disabled_reason = ?, updated_at = CURRENT_TIMESTAMP
WHERE provider_id = ? AND host_id = ?`)
	if n := len(kept); n > 0 {
		// One "?" per kept id, comma separated.
		stmt.WriteString(` AND host_account_id NOT IN (`)
		stmt.WriteString(strings.Repeat("?,", n-1))
		stmt.WriteString("?)")
	}
	stmt.WriteString(` AND account_status IN ('active', 'broken')`)

	params := make([]any, 0, 4+len(kept))
	params = append(params, ProviderAccountStatusDeprecated, strings.TrimSpace(reason), providerID, hostID)
	params = append(params, kept...)

	_, err := r.db.ExecContext(ctx, stmt.String(), params...)
	if err != nil {
		return fmt.Errorf("deprecate missing provider accounts for provider_id=%d host_id=%d: %w", providerID, hostID, err)
	}
	return nil
}
// List returns the joined views of the provider accounts that match filter,
// most recently updated first (ties broken by descending id).
//
// Every non-empty filter field adds one AND-ed condition:
//   - HostID, ProviderID, LogicalGroupID, RouteID, ShadowGroupID: exact match
//     after trimming.
//   - AccountStatus / BindingState: exact match on a recognized value; an
//     unrecognized value normalizes to "" and adds no condition.
//   - Query: case-insensitive substring match against host account id,
//     account name, key fingerprint, provider id, host id, logical group id
//     and route name. The term is matched literally: the LIKE wildcards '%'
//     and '_' (and the escape character '\') in the term are escaped, so a
//     search for "_" no longer matches every row.
//
// filter.Limit caps the result; values <= 0 fall back to 200. The returned
// slice is non-nil even when no row matches.
func (r *ProviderAccountsRepo) List(ctx context.Context, filter ProviderAccountListFilter) ([]ProviderAccountView, error) {
	query := `SELECT
pa.id,
h.host_id,
h.base_url,
p.provider_id,
p.display_name,
COALESCE(lgr.name, ''),
COALESCE(pa.route_id, ''),
COALESCE(lgr.logical_group_id, ''),
COALESCE(pa.shadow_group_id, ''),
COALESCE(lgr.shadow_host_id, ''),
COALESCE(lgr.upstream_base_url_hint, ''),
pa.host_account_id,
pa.key_fingerprint,
pa.account_name,
pa.account_status,
CASE
WHEN COALESCE(pa.route_id, '') <> '' THEN 'assigned'
WHEN COALESCE((
SELECT COUNT(1)
FROM logical_group_routes lgrb
WHERE lgrb.shadow_host_id = h.host_id AND lgrb.shadow_group_id = pa.shadow_group_id
), 0) > 1 THEN 'conflict'
ELSE 'unassigned'
END,
COALESCE((
SELECT COUNT(1)
FROM logical_group_routes lgrb
WHERE lgrb.shadow_host_id = h.host_id AND lgrb.shadow_group_id = pa.shadow_group_id
), 0),
COALESCE(pa.last_probe_status, ''),
COALESCE(pa.last_probe_at, ''),
COALESCE(pa.disabled_reason, ''),
pa.created_at,
pa.updated_at
FROM provider_accounts pa
JOIN hosts h ON h.id = pa.host_id
JOIN providers p ON p.id = pa.provider_id
LEFT JOIN logical_group_routes lgr ON lgr.route_id = pa.route_id
WHERE 1 = 1`
	args := make([]any, 0, 16)

	// Simple equality filters: each non-empty value appends its clause and
	// one bound parameter, in this order.
	equalityFilters := []struct {
		clause string
		value  string
	}{
		{` AND h.host_id = ?`, strings.TrimSpace(filter.HostID)},
		{` AND p.provider_id = ?`, strings.TrimSpace(filter.ProviderID)},
		{` AND lgr.logical_group_id = ?`, strings.TrimSpace(filter.LogicalGroupID)},
		{` AND pa.route_id = ?`, strings.TrimSpace(filter.RouteID)},
		{` AND pa.shadow_group_id = ?`, strings.TrimSpace(filter.ShadowGroupID)},
		{` AND pa.account_status = ?`, normalizeProviderAccountStatus(filter.AccountStatus)},
	}
	for _, f := range equalityFilters {
		if f.value == "" {
			continue
		}
		query += f.clause
		args = append(args, f.value)
	}

	// The binding state is derived, not stored, so each state repeats the
	// condition used by the CASE expression in the SELECT list.
	switch normalizeProviderAccountBindingState(filter.BindingState) {
	case ProviderAccountBindingStateAssigned:
		query += ` AND COALESCE(pa.route_id, '') <> ''`
	case ProviderAccountBindingStateConflict:
		query += ` AND COALESCE(pa.route_id, '') = '' AND COALESCE((
SELECT COUNT(1)
FROM logical_group_routes lgrb
WHERE lgrb.shadow_host_id = h.host_id AND lgrb.shadow_group_id = pa.shadow_group_id
), 0) > 1`
	case ProviderAccountBindingStateUnassigned:
		query += ` AND COALESCE(pa.route_id, '') = '' AND COALESCE((
SELECT COUNT(1)
FROM logical_group_routes lgrb
WHERE lgrb.shadow_host_id = h.host_id AND lgrb.shadow_group_id = pa.shadow_group_id
), 0) <= 1`
	}

	if value := strings.TrimSpace(filter.Query); value != "" {
		// Escape the LIKE metacharacters so the user's text is matched
		// literally; '\' is declared as the escape character in each LIKE.
		escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(strings.ToLower(value))
		like := "%" + escaped + "%"
		query += ` AND (
LOWER(pa.host_account_id) LIKE ? ESCAPE '\' OR
LOWER(pa.account_name) LIKE ? ESCAPE '\' OR
LOWER(pa.key_fingerprint) LIKE ? ESCAPE '\' OR
LOWER(p.provider_id) LIKE ? ESCAPE '\' OR
LOWER(h.host_id) LIKE ? ESCAPE '\' OR
LOWER(COALESCE(lgr.logical_group_id, '')) LIKE ? ESCAPE '\' OR
LOWER(COALESCE(lgr.name, '')) LIKE ? ESCAPE '\'
)`
		args = append(args, like, like, like, like, like, like, like)
	}

	query += ` ORDER BY pa.updated_at DESC, pa.id DESC`
	limit := filter.Limit
	if limit <= 0 {
		limit = 200
	}
	query += ` LIMIT ?`
	args = append(args, limit)

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("list provider accounts: %w", err)
	}
	defer rows.Close()

	views := make([]ProviderAccountView, 0)
	for rows.Next() {
		var view ProviderAccountView
		// Destination order must match the SELECT column order above.
		if err := rows.Scan(
			&view.ID,
			&view.HostID,
			&view.HostBaseURL,
			&view.ProviderID,
			&view.ProviderName,
			&view.RouteName,
			&view.RouteID,
			&view.LogicalGroupID,
			&view.ShadowGroupID,
			&view.ShadowHostID,
			&view.UpstreamBaseURLHint,
			&view.HostAccountID,
			&view.KeyFingerprint,
			&view.AccountName,
			&view.AccountStatus,
			&view.BindingState,
			&view.BindingCandidateCount,
			&view.LastProbeStatus,
			&view.LastProbeAt,
			&view.DisabledReason,
			&view.CreatedAt,
			&view.UpdatedAt,
		); err != nil {
			return nil, fmt.Errorf("scan provider account view: %w", err)
		}
		views = append(views, view)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate provider accounts: %w", err)
	}
	return views, nil
}
// scanOne runs query with args, expects at most one row and decodes it into a
// ProviderAccount. The query must select the fourteen provider_accounts
// columns in the order: id, host_id, provider_id, route_id, shadow_group_id,
// host_account_id, key_fingerprint, account_name, account_status,
// last_probe_status, last_probe_at, disabled_reason, created_at, updated_at.
//
// route_id, shadow_group_id, last_probe_status, last_probe_at and
// disabled_reason are read through sql.NullString and a NULL is returned as
// "". The view queries in this file already treat exactly these columns as
// nullable (COALESCE(..., '')); scanning them into a plain string would make
// the whole lookup fail on a NULL.
//
// Errors from Scan, including the no-row error, are returned unwrapped.
func (r *ProviderAccountsRepo) scanOne(ctx context.Context, query string, args ...any) (ProviderAccount, error) {
	var (
		row             ProviderAccount
		routeID         sql.NullString
		shadowGroupID   sql.NullString
		lastProbeStatus sql.NullString
		lastProbeAt     sql.NullString
		disabledReason  sql.NullString
	)
	if err := r.db.QueryRowContext(ctx, query, args...).Scan(
		&row.ID,
		&row.HostID,
		&row.ProviderID,
		&routeID,
		&shadowGroupID,
		&row.HostAccountID,
		&row.KeyFingerprint,
		&row.AccountName,
		&row.AccountStatus,
		&lastProbeStatus,
		&lastProbeAt,
		&disabledReason,
		&row.CreatedAt,
		&row.UpdatedAt,
	); err != nil {
		return ProviderAccount{}, err
	}
	// NullString.String is "" when the column was NULL.
	row.RouteID = routeID.String
	row.ShadowGroupID = shadowGroupID.String
	row.LastProbeStatus = lastProbeStatus.String
	row.LastProbeAt = lastProbeAt.String
	row.DisabledReason = disabledReason.String
	return row, nil
}
// scanViewOne runs query with args, expects at most one row and decodes its
// twenty-two columns into a ProviderAccountView. The column order must match
// the destination list below. Errors from Scan are returned unwrapped.
func (r *ProviderAccountsRepo) scanViewOne(ctx context.Context, query string, args ...any) (ProviderAccountView, error) {
	var out ProviderAccountView

	// Destinations in SELECT column order.
	dest := []any{
		&out.ID,
		&out.HostID,
		&out.HostBaseURL,
		&out.ProviderID,
		&out.ProviderName,
		&out.RouteName,
		&out.RouteID,
		&out.LogicalGroupID,
		&out.ShadowGroupID,
		&out.ShadowHostID,
		&out.UpstreamBaseURLHint,
		&out.HostAccountID,
		&out.KeyFingerprint,
		&out.AccountName,
		&out.AccountStatus,
		&out.BindingState,
		&out.BindingCandidateCount,
		&out.LastProbeStatus,
		&out.LastProbeAt,
		&out.DisabledReason,
		&out.CreatedAt,
		&out.UpdatedAt,
	}

	record := r.db.QueryRowContext(ctx, query, args...)
	if scanErr := record.Scan(dest...); scanErr != nil {
		return ProviderAccountView{}, scanErr
	}
	return out, nil
}
// normalizeProviderAccount returns a copy of account with every free-text
// field trimmed and AccountStatus canonicalized, or an error naming the first
// required field that is missing.
//
// Required, checked in this order: HostID > 0, ProviderID > 0, non-empty
// HostAccountID, non-empty KeyFingerprint, and an AccountStatus that
// normalizeProviderAccountStatus recognizes.
func normalizeProviderAccount(account ProviderAccount) (ProviderAccount, error) {
	// account is a by-value copy, so trimming through pointers does not touch
	// the caller's struct.
	for _, field := range []*string{
		&account.RouteID,
		&account.ShadowGroupID,
		&account.HostAccountID,
		&account.KeyFingerprint,
		&account.AccountName,
		&account.LastProbeStatus,
		&account.LastProbeAt,
		&account.DisabledReason,
	} {
		*field = strings.TrimSpace(*field)
	}
	account.AccountStatus = normalizeProviderAccountStatus(account.AccountStatus)

	if account.HostID <= 0 {
		return ProviderAccount{}, fmt.Errorf("host_id is required")
	}
	if account.ProviderID <= 0 {
		return ProviderAccount{}, fmt.Errorf("provider_id is required")
	}
	if account.HostAccountID == "" {
		return ProviderAccount{}, fmt.Errorf("host_account_id is required")
	}
	if account.KeyFingerprint == "" {
		return ProviderAccount{}, fmt.Errorf("key_fingerprint is required")
	}
	if account.AccountStatus == "" {
		return ProviderAccount{}, fmt.Errorf("account_status is required")
	}
	return account, nil
}
// normalizeProviderAccountStatus trims status and returns it when it is one
// of the four ProviderAccountStatus* values; any other input yields "".
// The comparison is case-sensitive.
func normalizeProviderAccountStatus(status string) string {
	trimmed := strings.TrimSpace(status)
	for _, known := range [...]string{
		ProviderAccountStatusActive,
		ProviderAccountStatusDisabled,
		ProviderAccountStatusDeprecated,
		ProviderAccountStatusBroken,
	} {
		if trimmed == known {
			return known
		}
	}
	return ""
}
// normalizeProviderAccountBindingState trims state and returns it when it is
// one of the three ProviderAccountBindingState* values; any other input
// yields "". The comparison is case-sensitive.
func normalizeProviderAccountBindingState(state string) string {
	candidate := strings.TrimSpace(state)
	recognized := candidate == ProviderAccountBindingStateAssigned ||
		candidate == ProviderAccountBindingStateUnassigned ||
		candidate == ProviderAccountBindingStateConflict
	if !recognized {
		return ""
	}
	return candidate
}