Files
sub2api-cn-relay-manager/internal/app/http_api.go

2562 lines
104 KiB
Go
Raw Normal View History

package app
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"sub2api-cn-relay-manager/internal/batch"
"sub2api-cn-relay-manager/internal/config"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/provision"
"sub2api-cn-relay-manager/internal/reconcile"
"sub2api-cn-relay-manager/internal/store/sqlite"
"sub2api-cn-relay-manager/internal/access"
)
// ActionSet is the set of backend operations the HTTP API dispatches to.
// Each field is a function injected by the caller of NewAPIHandler /
// NewAPIHandlerWithAuth; the route registered for it hands the field the
// request context plus a request value built from the path, query, or JSON
// body. The handlers visible in this file check their action for nil and
// answer 500 "server_misconfigured" when it is unset, and route registration
// itself does not inspect the fields, so a partially populated ActionSet can
// still be used to build a handler.
type ActionSet struct {
// Batch-import runs and their items (/api/batch-import/runs...).
CreateBatchImportRun func(context.Context, CreateBatchImportRunRequest) (BatchImportRunCreateResponse, error)
ListBatchImportRuns func(context.Context, ListBatchImportRunsRequest) (ListBatchImportRunsResponse, error)
GetBatchImportRun func(context.Context, string) (batch.RunSummaryProjection, error)
ListBatchImportRunItems func(context.Context, ListBatchImportRunItemsRequest) (ListBatchImportRunItemsResponse, error)
GetBatchImportRunItem func(context.Context, GetBatchImportRunItemRequest) (batch.ItemDetailProjection, error)
// Logical groups, their models, routes, and per-route models (/api/logical-groups...).
CreateLogicalGroup func(context.Context, CreateLogicalGroupRequest) (LogicalGroupInfo, error)
ListLogicalGroups func(context.Context) ([]LogicalGroupInfo, error)
GetLogicalGroup func(context.Context, string) (LogicalGroupInfo, error)
UpdateLogicalGroup func(context.Context, UpdateLogicalGroupRequest) (LogicalGroupInfo, error)
DeleteLogicalGroup func(context.Context, string) error
CreateLogicalGroupModel func(context.Context, CreateLogicalGroupModelRequest) (LogicalGroupModelInfo, error)
ListLogicalGroupModels func(context.Context, string) ([]LogicalGroupModelInfo, error)
DeleteLogicalGroupModel func(context.Context, DeleteLogicalGroupModelRequest) error
CreateLogicalGroupRoute func(context.Context, CreateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error)
ListLogicalGroupRoutes func(context.Context, string) ([]LogicalGroupRouteInfo, error)
UpdateLogicalGroupRoute func(context.Context, UpdateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error)
DeleteLogicalGroupRoute func(context.Context, DeleteLogicalGroupRouteRequest) error
CreateLogicalGroupRouteModel func(context.Context, CreateLogicalGroupRouteModelRequest) (LogicalGroupRouteModelInfo, error)
ListLogicalGroupRouteModels func(context.Context, ListLogicalGroupRouteModelsRequest) ([]LogicalGroupRouteModelInfo, error)
// Routing audit logs: decisions, failovers, sticky audit (/api/routing/logs/...).
AppendRouteDecisionLog func(context.Context, AppendRouteDecisionLogRequest) (RouteDecisionLogInfo, error)
ListRouteDecisionLogs func(context.Context, ListRouteDecisionLogsRequest) ([]RouteDecisionLogInfo, error)
AppendRouteFailoverEvent func(context.Context, AppendRouteFailoverEventRequest) (RouteFailoverEventInfo, error)
ListRouteFailoverEvents func(context.Context, ListRouteFailoverEventsRequest) ([]RouteFailoverEventInfo, error)
AppendRouteStickyAudit func(context.Context, AppendRouteStickyAuditRequest) (RouteStickyAuditInfo, error)
ListRouteStickyAudit func(context.Context, ListRouteStickyAuditRequest) ([]RouteStickyAuditInfo, error)
// Route resolution plus sticky-binding, failure, and cooldown state
// (/api/routing/resolve and /api/routing/sticky/...).
ResolveRoute func(context.Context, ResolveRouteRequest) (ResolveRouteInfo, error)
SetStickyBinding func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error)
GetStickyBinding func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error)
SetRouteFailure func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error)
GetRouteFailure func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error)
SetRouteCooldown func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error)
GetRouteCooldown func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error)
// Provider drafts (/api/provider-drafts...).
CreateProviderDraft func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error)
ListProviderDrafts func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error)
GetProviderDraft func(context.Context, string) (ProviderDraftInfo, error)
UpdateProviderDraft func(context.Context, UpdateProviderDraftRequest) (ProviderDraftInfo, error)
DeleteProviderDraft func(context.Context, string) error
PublishProviderDraft func(context.Context, PublishProviderDraftRequest) (PublishProviderDraftResult, error)
// Pack installation, import batches, and the provider lifecycle
// (status, preview, import, rollback, reconcile).
InstallPack func(context.Context, InstallPackRequest) (provision.PackInstallResult, error)
BatchDetail func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error)
GetProviderStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
GetProviderResources func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
GetProviderAccessStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
ListProviderImportBatches func(context.Context, ProviderQueryRequest) ([]ImportBatchInfo, error)
PreviewProvider func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error)
ImportProvider func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error)
RollbackProvider func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error)
RollbackBatch func(context.Context, RollbackBatchRequest) (provision.RollbackReport, error)
ReconcileProvider func(context.Context, ReconcileProviderRequest) (reconcile.Result, error)
// Host registry (/api/hosts...).
CreateHost func(context.Context, CreateHostRequest) (HostInfo, error)
ProbeHost func(context.Context, ProbeHostRequest) (HostInfo, error)
ListHosts func(context.Context) ([]HostInfo, error)
GetHost func(context.Context, string) (HostInfo, error)
DeleteHost func(context.Context, string) error
// Pack catalogue (/api/packs...).
ListPacks func(context.Context) ([]PackInfo, error)
GetPack func(context.Context, string) (PackInfo, error)
ListPackProviders func(context.Context, string) ([]PackProviderInfo, error)
// Provider access management (/api/providers/{providerID}/access/...).
AssignAccessSubscriptions func(context.Context, AssignAccessSubscriptionsRequest) (AssignAccessSubscriptionsResult, error)
AccessPreview func(context.Context, AccessPreviewRequest) (AccessPreviewResult, error)
}
// maxJSONBodyBytes is a 1 MiB (1 << 20 bytes) size limit. NOTE(review): it is
// not referenced in this part of the file; presumably decodeJSON uses it to
// cap request bodies — confirm.
const maxJSONBodyBytes int64 = 1 << 20
// HostInfo is the JSON view of a registered host, returned by the CreateHost,
// ProbeHost, ListHosts and GetHost actions. auth_type, status and
// capabilities are omitted from the output when empty/nil.
type HostInfo struct {
HostID string `json:"host_id"`
BaseURL string `json:"base_url"`
HostVersion string `json:"host_version"`
AuthType string `json:"auth_type,omitempty"`
Status string `json:"status,omitempty"`
Capabilities *sub2api.HostCapabilities `json:"capabilities,omitempty"`
}
// CreateHostRequest is the input to the CreateHost action (route
// POST /api/hosts); presumably decoded from the JSON request body.
type CreateHostRequest struct {
Name string `json:"name"`
BaseURL string `json:"base_url"`
Auth CreateHostAuth `json:"auth"`
}
// ProbeHostRequest is the input to the ProbeHost action (route
// POST /api/hosts/{hostID}/probe). HostID is tagged `json:"-"`, so it is
// never read from the body; presumably it is filled from the path.
type ProbeHostRequest struct {
HostID string `json:"-"`
Auth CreateHostAuth `json:"auth"`
}
// CreateHostAuth carries the credential used to talk to a host: an auth
// type label plus its token.
type CreateHostAuth struct {
Type string `json:"type"`
Token string `json:"token"`
}
// ImportBatchInfo is one entry in the {"batches": [...]} response of
// GET /api/providers/{providerID}/import-batches.
type ImportBatchInfo struct {
BatchID int64 `json:"batch_id"`
BatchStatus string `json:"batch_status"`
AccessStatus string `json:"access_status"`
}
// PackInfo is the JSON view of a pack returned by the ListPacks and GetPack
// actions. Empty vendor / target-host / version-bound fields are omitted.
type PackInfo struct {
PackID string `json:"pack_id"`
Version string `json:"version"`
Vendor string `json:"vendor,omitempty"`
TargetHost string `json:"target_host,omitempty"`
MinHostVersion string `json:"min_host_version,omitempty"`
MaxHostVersion string `json:"max_host_version,omitempty"`
}
// PackProviderInfo describes one provider inside a pack, as returned by the
// ListPackProviders action. HostOverlays is presumably a count of
// host-specific overlays — confirm against the action implementation.
type PackProviderInfo struct {
ProviderID string `json:"provider_id"`
DisplayName string `json:"display_name"`
Platform string `json:"platform,omitempty"`
HostOverlays int `json:"host_overlays,omitempty"`
BaseURL string `json:"base_url,omitempty"`
SmokeTestModel string `json:"smoke_test_model,omitempty"`
SupportedModels []string `json:"supported_models,omitempty"`
}
// CreateProviderDraftRequest is the JSON body of POST /api/provider-drafts.
// Manifest is kept as raw JSON (json.RawMessage), so it is forwarded to the
// action without being decoded here.
type CreateProviderDraftRequest struct {
DraftID string `json:"draft_id,omitempty"`
PackID string `json:"pack_id"`
ProviderID string `json:"provider_id"`
DisplayName string `json:"display_name"`
Platform string `json:"platform"`
BaseURL string `json:"base_url,omitempty"`
SmokeTestModel string `json:"smoke_test_model,omitempty"`
SupportedModels []string `json:"supported_models,omitempty"`
Manifest json.RawMessage `json:"manifest,omitempty"`
SourceHostID string `json:"source_host_id,omitempty"`
Notes string `json:"notes,omitempty"`
}
// ListProviderDraftsRequest holds the filters for GET /api/provider-drafts,
// read from the pack_id, provider_id and q query parameters (whitespace-trimmed).
type ListProviderDraftsRequest struct {
PackID string
ProviderID string
Query string
}
// UpdateProviderDraftRequest is the input to the UpdateProviderDraft action.
// The body is decoded into the embedded CreateProviderDraftRequest, and the
// outer DraftID comes from the {draftID} path segment. Because the outer
// field is shallower, req.DraftID selects the path value; the body's
// draft_id stays in req.CreateProviderDraftRequest.DraftID.
type UpdateProviderDraftRequest struct {
DraftID string `json:"-"`
CreateProviderDraftRequest
}
// PublishProviderDraftRequest is the input to the PublishProviderDraft
// action. DraftID comes from the path; commit_message is optional in the body.
type PublishProviderDraftRequest struct {
DraftID string `json:"-"`
CommitMessage string `json:"commit_message,omitempty"`
}
// ProviderDraftInfo is the JSON view of a stored provider draft, returned
// under the "draft" key (single) or "provider_drafts" key (list).
type ProviderDraftInfo struct {
DraftID string `json:"draft_id"`
PackID string `json:"pack_id"`
ProviderID string `json:"provider_id"`
DisplayName string `json:"display_name"`
Platform string `json:"platform"`
BaseURL string `json:"base_url,omitempty"`
SmokeTestModel string `json:"smoke_test_model,omitempty"`
SupportedModels []string `json:"supported_models,omitempty"`
Manifest any `json:"manifest,omitempty"`
SourceHostID string `json:"source_host_id,omitempty"`
Notes string `json:"notes,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
// PublishProviderDraftResult is returned under the "publish" key by
// POST /api/provider-drafts/{draftID}/publish. The commit_sha / repo_root
// fields suggest a git-backed publish — confirm in the action implementation.
type PublishProviderDraftResult struct {
DraftID string `json:"draft_id"`
PackID string `json:"pack_id"`
ProviderID string `json:"provider_id"`
ProviderPath string `json:"provider_path"`
PackVersionBefore string `json:"pack_version_before"`
PackVersionAfter string `json:"pack_version_after"`
PublishMode string `json:"publish_mode"`
CommitMessage string `json:"commit_message"`
CommitSHA string `json:"commit_sha"`
RepoRoot string `json:"repo_root"`
}
// AssignAccessSubscriptionsRequest is the JSON body of
// POST /api/providers/{providerID}/access/assign-subscriptions. The handler
// overwrites ProviderID with the path segment, so a provider_id in the body
// is ignored. The host is identified by host_id or by the inline host_*
// fields (all optional in JSON); presumably one of the two is required —
// confirm in the action implementation.
type AssignAccessSubscriptionsRequest struct {
HostID string `json:"host_id,omitempty"`
PackPath string `json:"pack_path"`
ProviderID string `json:"provider_id"`
HostBaseURL string `json:"host_base_url,omitempty"`
HostAPIKey string `json:"host_api_key,omitempty"`
HostBearerToken string `json:"host_bearer_token,omitempty"`
AccessAPIKey string `json:"access_api_key"`
SubscriptionUsers []string `json:"subscription_users"`
SubscriptionDays int `json:"subscription_days"`
}
// AssignAccessSubscriptionsResult is written back verbatim as the 200
// response of the assign-subscriptions endpoint.
type AssignAccessSubscriptionsResult struct {
ProviderID string `json:"provider_id"`
Assigned int `json:"assigned"`
AccessStatus string `json:"access_status"`
}
// AccessPreviewRequest is the JSON body of
// POST /api/providers/{providerID}/access/preview. ProviderID is taken from
// the path; when pack_id / host_id are absent from the body, the handler
// falls back to the same-named query parameters.
type AccessPreviewRequest struct {
ProviderID string `json:"provider_id"`
PackID string `json:"pack_id,omitempty"`
HostID string `json:"host_id,omitempty"`
Mode string `json:"mode"`
}
// AccessPreviewResult reports whether the requested access mode is available
// for the provider, with an optional explanatory message.
type AccessPreviewResult struct {
ProviderID string `json:"provider_id"`
Mode string `json:"mode"`
Available bool `json:"available"`
Message string `json:"message,omitempty"`
}
// InstallPackRequest is the JSON body of POST /api/packs/install. The host_*
// fields carry the target host's connection details, and pack_path names the
// pack to install.
type InstallPackRequest struct {
HostBaseURL string `json:"host_base_url"`
HostAPIKey string `json:"host_api_key"`
HostBearerToken string `json:"host_bearer_token"`
PackPath string `json:"pack_path"`
}
// BatchDetailRequest identifies an import batch for the BatchDetail action.
// It has no JSON tags; presumably it is built from the {batchID} path
// segment of GET /api/import-batches/{batchID}.
type BatchDetailRequest struct {
BatchID int64
}
// ProviderQueryRequest is the read-only query shared by the provider status,
// resources, access-status and import-batch listing actions. The import-batch
// handler fills it from the {providerID} path segment and the pack_id /
// host_id query parameters.
type ProviderQueryRequest struct {
ProviderID string
PackID string
HostID string
}
// RollbackProviderRequest is the input to the RollbackProvider action. The
// host is given either by host_id or by the inline host_* fields (all
// optional in JSON) — presumably one of the two is required; confirm.
type RollbackProviderRequest struct {
HostID string `json:"host_id,omitempty"`
HostBaseURL string `json:"host_base_url,omitempty"`
HostAPIKey string `json:"host_api_key,omitempty"`
HostBearerToken string `json:"host_bearer_token,omitempty"`
PackPath string `json:"pack_path"`
ProviderID string `json:"provider_id"`
}
// RollbackBatchRequest is the input to the RollbackBatch action. BatchID is
// tagged `json:"-"` and so is never read from the body; only auth is.
type RollbackBatchRequest struct {
BatchID int64 `json:"-"`
Auth CreateHostAuth `json:"auth"`
}
// ReconcileProviderRequest is the input to the ReconcileProvider action; host
// selection works as in RollbackProviderRequest.
type ReconcileProviderRequest struct {
HostID string `json:"host_id,omitempty"`
HostBaseURL string `json:"host_base_url,omitempty"`
HostAPIKey string `json:"host_api_key,omitempty"`
HostBearerToken string `json:"host_bearer_token,omitempty"`
PackPath string `json:"pack_path"`
ProviderID string `json:"provider_id"`
AccessAPIKey string `json:"access_api_key"`
}
// PreviewProviderRequest is the input to the PreviewProvider action: the
// candidate keys and mode to preview, with the same host-selection fields.
type PreviewProviderRequest struct {
HostID string `json:"host_id,omitempty"`
HostBaseURL string `json:"host_base_url,omitempty"`
HostAPIKey string `json:"host_api_key,omitempty"`
HostBearerToken string `json:"host_bearer_token,omitempty"`
PackPath string `json:"pack_path"`
ProviderID string `json:"provider_id"`
Keys []string `json:"keys"`
Mode string `json:"mode"`
}
// ImportProviderRequest is the input to the ImportProvider action. It extends
// the preview fields with the access configuration (access_mode,
// access_api_key, subscription users/days) to apply on import.
type ImportProviderRequest struct {
HostID string `json:"host_id,omitempty"`
HostBaseURL string `json:"host_base_url,omitempty"`
HostAPIKey string `json:"host_api_key,omitempty"`
HostBearerToken string `json:"host_bearer_token,omitempty"`
PackPath string `json:"pack_path"`
ProviderID string `json:"provider_id"`
Keys []string `json:"keys"`
Mode string `json:"mode"`
AccessMode string `json:"access_mode"`
AccessAPIKey string `json:"access_api_key"`
SubscriptionUsers []string `json:"subscription_users"`
SubscriptionDays int `json:"subscription_days"`
}
// httpError is the structured error returned to API clients. StatusCode
// selects the HTTP status and is excluded from the JSON body (`json:"-"`).
// Code and Message are always serialized; UpstreamStatus is serialized only
// when non-zero (omitempty).
type httpError struct {
StatusCode int `json:"-"`
Code string `json:"code"`
Message string `json:"message"`
UpstreamStatus int `json:"upstream_status,omitempty"`
}
// Error implements the error interface by returning the human-readable
// Message, so *httpError values can travel as ordinary errors.
func (e *httpError) Error() string {
return e.Message
}
// NewAPIHandler builds the API router configured only with a static admin
// token. It is shorthand for NewAPIHandlerWithAuth with an AdminAuthConfig
// whose Token field is adminToken and whose other fields are zero.
func NewAPIHandler(adminToken string, actions ActionSet) http.Handler {
	auth := AdminAuthConfig{Token: adminToken}
	return NewAPIHandlerWithAuth(auth, actions)
}
// NewAPIHandlerWithAuth builds the ServeMux for the manager's HTTP API, using
// Go 1.22 "METHOD /path/{wildcard}" patterns.
//
// /healthz and the three /api/admin/session endpoints are registered without
// requireAdminAccess. Every other route is wrapped in
// requireAdminAccess(adminAuth, ...). Each protected route forwards to its
// handle* function together with the matching ActionSet field. The field is
// read from actions when a request is served, not at registration time.
func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet) http.Handler {
	router := http.NewServeMux()

	// Unauthenticated endpoints: liveness probe and admin session management.
	router.HandleFunc("GET /healthz", healthz)
	router.HandleFunc("GET /api/admin/session", func(w http.ResponseWriter, r *http.Request) {
		handleAdminSessionState(w, r, adminAuth)
	})
	router.HandleFunc("POST /api/admin/session/login", func(w http.ResponseWriter, r *http.Request) {
		handleAdminSessionLogin(w, r, adminAuth)
	})
	router.HandleFunc("POST /api/admin/session/logout", func(w http.ResponseWriter, r *http.Request) {
		handleAdminSessionLogout(w, r)
	})

	// admin registers pattern behind the admin-access check.
	admin := func(pattern string, fn http.HandlerFunc) {
		router.Handle(pattern, requireAdminAccess(adminAuth, fn))
	}

	// Batch-import runs.
	admin("POST /api/batch-import/runs", func(w http.ResponseWriter, r *http.Request) {
		handleCreateBatchImportRun(w, r, actions.CreateBatchImportRun)
	})
	admin("GET /api/batch-import/runs", func(w http.ResponseWriter, r *http.Request) {
		handleListBatchImportRuns(w, r, actions.ListBatchImportRuns)
	})
	admin("GET /api/batch-import/runs/{run_id}", func(w http.ResponseWriter, r *http.Request) {
		handleGetBatchImportRun(w, r, actions.GetBatchImportRun)
	})
	admin("GET /api/batch-import/runs/{run_id}/items", func(w http.ResponseWriter, r *http.Request) {
		handleListBatchImportRunItems(w, r, actions.ListBatchImportRunItems)
	})
	admin("GET /api/batch-import/runs/{run_id}/items/{item_id}", func(w http.ResponseWriter, r *http.Request) {
		handleGetBatchImportRunItem(w, r, actions.GetBatchImportRunItem)
	})

	// Logical groups, models, routes, and route models.
	admin("POST /api/logical-groups", func(w http.ResponseWriter, r *http.Request) {
		handleCreateLogicalGroup(w, r, actions.CreateLogicalGroup)
	})
	admin("GET /api/logical-groups", func(w http.ResponseWriter, r *http.Request) {
		handleListLogicalGroups(w, r, actions.ListLogicalGroups)
	})
	admin("GET /api/logical-groups/{groupID}", func(w http.ResponseWriter, r *http.Request) {
		handleGetLogicalGroup(w, r, actions.GetLogicalGroup)
	})
	admin("PUT /api/logical-groups/{groupID}", func(w http.ResponseWriter, r *http.Request) {
		handleUpdateLogicalGroup(w, r, actions.UpdateLogicalGroup)
	})
	admin("DELETE /api/logical-groups/{groupID}", func(w http.ResponseWriter, r *http.Request) {
		handleDeleteLogicalGroup(w, r, actions.DeleteLogicalGroup)
	})
	admin("POST /api/logical-groups/{groupID}/models", func(w http.ResponseWriter, r *http.Request) {
		handleCreateLogicalGroupModel(w, r, actions.CreateLogicalGroupModel)
	})
	admin("GET /api/logical-groups/{groupID}/models", func(w http.ResponseWriter, r *http.Request) {
		handleListLogicalGroupModels(w, r, actions.ListLogicalGroupModels)
	})
	admin("DELETE /api/logical-groups/{groupID}/models/{model}", func(w http.ResponseWriter, r *http.Request) {
		handleDeleteLogicalGroupModel(w, r, actions.DeleteLogicalGroupModel)
	})
	admin("POST /api/logical-groups/{groupID}/routes", func(w http.ResponseWriter, r *http.Request) {
		handleCreateLogicalGroupRoute(w, r, actions.CreateLogicalGroupRoute)
	})
	admin("GET /api/logical-groups/{groupID}/routes", func(w http.ResponseWriter, r *http.Request) {
		handleListLogicalGroupRoutes(w, r, actions.ListLogicalGroupRoutes)
	})
	admin("PUT /api/logical-groups/{groupID}/routes/{routeID}", func(w http.ResponseWriter, r *http.Request) {
		handleUpdateLogicalGroupRoute(w, r, actions.UpdateLogicalGroupRoute)
	})
	admin("DELETE /api/logical-groups/{groupID}/routes/{routeID}", func(w http.ResponseWriter, r *http.Request) {
		handleDeleteLogicalGroupRoute(w, r, actions.DeleteLogicalGroupRoute)
	})
	admin("POST /api/logical-groups/{groupID}/routes/{routeID}/models", func(w http.ResponseWriter, r *http.Request) {
		handleCreateLogicalGroupRouteModel(w, r, actions.CreateLogicalGroupRouteModel)
	})
	admin("GET /api/logical-groups/{groupID}/routes/{routeID}/models", func(w http.ResponseWriter, r *http.Request) {
		handleListLogicalGroupRouteModels(w, r, actions.ListLogicalGroupRouteModels)
	})

	// Routing audit logs.
	admin("POST /api/routing/logs/decisions", func(w http.ResponseWriter, r *http.Request) {
		handleAppendRouteDecisionLog(w, r, actions.AppendRouteDecisionLog)
	})
	admin("GET /api/routing/logs/decisions", func(w http.ResponseWriter, r *http.Request) {
		handleListRouteDecisionLogs(w, r, actions.ListRouteDecisionLogs)
	})
	admin("POST /api/routing/logs/failovers", func(w http.ResponseWriter, r *http.Request) {
		handleAppendRouteFailoverEvent(w, r, actions.AppendRouteFailoverEvent)
	})
	admin("GET /api/routing/logs/failovers", func(w http.ResponseWriter, r *http.Request) {
		handleListRouteFailoverEvents(w, r, actions.ListRouteFailoverEvents)
	})
	admin("POST /api/routing/logs/sticky-audit", func(w http.ResponseWriter, r *http.Request) {
		handleAppendRouteStickyAudit(w, r, actions.AppendRouteStickyAudit)
	})
	admin("GET /api/routing/logs/sticky-audit", func(w http.ResponseWriter, r *http.Request) {
		handleListRouteStickyAudit(w, r, actions.ListRouteStickyAudit)
	})

	// Route resolution and sticky state.
	admin("POST /api/routing/resolve", func(w http.ResponseWriter, r *http.Request) {
		handleResolveRoute(w, r, actions.ResolveRoute)
	})
	admin("POST /api/routing/sticky/bindings", func(w http.ResponseWriter, r *http.Request) {
		handleSetStickyBinding(w, r, actions.SetStickyBinding)
	})
	admin("GET /api/routing/sticky/bindings", func(w http.ResponseWriter, r *http.Request) {
		handleGetStickyBinding(w, r, actions.GetStickyBinding)
	})
	admin("POST /api/routing/sticky/route-failures", func(w http.ResponseWriter, r *http.Request) {
		handleSetRouteFailure(w, r, actions.SetRouteFailure)
	})
	admin("GET /api/routing/sticky/route-failures", func(w http.ResponseWriter, r *http.Request) {
		handleGetRouteFailure(w, r, actions.GetRouteFailure)
	})
	admin("POST /api/routing/sticky/cooldowns", func(w http.ResponseWriter, r *http.Request) {
		handleSetRouteCooldown(w, r, actions.SetRouteCooldown)
	})
	admin("GET /api/routing/sticky/cooldowns", func(w http.ResponseWriter, r *http.Request) {
		handleGetRouteCooldown(w, r, actions.GetRouteCooldown)
	})

	// Provider drafts.
	admin("POST /api/provider-drafts", func(w http.ResponseWriter, r *http.Request) {
		handleCreateProviderDraft(w, r, actions.CreateProviderDraft)
	})
	admin("GET /api/provider-drafts", func(w http.ResponseWriter, r *http.Request) {
		handleListProviderDrafts(w, r, actions.ListProviderDrafts)
	})
	admin("GET /api/provider-drafts/{draftID}", func(w http.ResponseWriter, r *http.Request) {
		handleGetProviderDraft(w, r, actions.GetProviderDraft)
	})
	admin("PUT /api/provider-drafts/{draftID}", func(w http.ResponseWriter, r *http.Request) {
		handleUpdateProviderDraft(w, r, actions.UpdateProviderDraft)
	})
	admin("DELETE /api/provider-drafts/{draftID}", func(w http.ResponseWriter, r *http.Request) {
		handleDeleteProviderDraft(w, r, actions.DeleteProviderDraft)
	})
	admin("POST /api/provider-drafts/{draftID}/publish", func(w http.ResponseWriter, r *http.Request) {
		handlePublishProviderDraft(w, r, actions.PublishProviderDraft)
	})

	// Import batches.
	admin("GET /api/import-batches/{batchID}", func(w http.ResponseWriter, r *http.Request) {
		handleBatchDetail(w, r, actions.BatchDetail)
	})
	admin("POST /api/import-batches/{batchID}/rollback", func(w http.ResponseWriter, r *http.Request) {
		handleRollbackBatch(w, r, actions.RollbackBatch)
	})

	// Provider state and access.
	admin("GET /api/providers/{providerID}/status", func(w http.ResponseWriter, r *http.Request) {
		handleProviderStatus(w, r, actions.GetProviderStatus)
	})
	admin("GET /api/providers/{providerID}/resources", func(w http.ResponseWriter, r *http.Request) {
		handleProviderResources(w, r, actions.GetProviderResources)
	})
	admin("GET /api/providers/{providerID}/access/status", func(w http.ResponseWriter, r *http.Request) {
		handleProviderAccessStatus(w, r, actions.GetProviderAccessStatus)
	})
	admin("GET /api/providers/{providerID}/import-batches", func(w http.ResponseWriter, r *http.Request) {
		handleListProviderImportBatches(w, r, actions.ListProviderImportBatches)
	})
	admin("POST /api/providers/{providerID}/access/assign-subscriptions", func(w http.ResponseWriter, r *http.Request) {
		handleAssignAccessSubscriptions(w, r, actions.AssignAccessSubscriptions)
	})
	admin("POST /api/providers/{providerID}/access/preview", func(w http.ResponseWriter, r *http.Request) {
		handleAccessPreview(w, r, actions.AccessPreview)
	})

	// Packs.
	admin("POST /api/packs/install", func(w http.ResponseWriter, r *http.Request) {
		handleInstallPack(w, r, actions.InstallPack)
	})
	admin("GET /api/packs", func(w http.ResponseWriter, r *http.Request) {
		handleListPacks(w, r, actions.ListPacks)
	})
	admin("GET /api/packs/{packID}", func(w http.ResponseWriter, r *http.Request) {
		handleGetPack(w, r, actions.GetPack)
	})
	admin("GET /api/packs/{packID}/providers", func(w http.ResponseWriter, r *http.Request) {
		handleListPackProviders(w, r, actions.ListPackProviders)
	})

	// Provider lifecycle operations.
	admin("POST /api/providers/{providerID}/preview-import", func(w http.ResponseWriter, r *http.Request) {
		handlePreviewProvider(w, r, actions.PreviewProvider)
	})
	admin("POST /api/providers/{providerID}/import", func(w http.ResponseWriter, r *http.Request) {
		handleImportProvider(w, r, actions.ImportProvider)
	})
	admin("POST /api/providers/{providerID}/rollback", func(w http.ResponseWriter, r *http.Request) {
		handleRollbackProvider(w, r, actions.RollbackProvider)
	})
	admin("POST /api/providers/{providerID}/reconcile", func(w http.ResponseWriter, r *http.Request) {
		handleReconcileProvider(w, r, actions.ReconcileProvider)
	})

	// Hosts.
	admin("GET /api/hosts", func(w http.ResponseWriter, r *http.Request) {
		handleListHosts(w, r, actions.ListHosts)
	})
	admin("GET /api/hosts/{hostID}", func(w http.ResponseWriter, r *http.Request) {
		handleGetHost(w, r, actions.GetHost)
	})
	admin("POST /api/hosts", func(w http.ResponseWriter, r *http.Request) {
		handleCreateHost(w, r, actions.CreateHost)
	})
	admin("POST /api/hosts/{hostID}/probe", func(w http.ResponseWriter, r *http.Request) {
		handleProbeHost(w, r, actions.ProbeHost)
	})
	admin("DELETE /api/hosts/{hostID}", func(w http.ResponseWriter, r *http.Request) {
		handleDeleteHost(w, r, actions.DeleteHost)
	})

	return router
}
// healthz is the unauthenticated liveness probe. It always answers
// 200 OK with the plain body "ok".
func healthz(w http.ResponseWriter, _ *http.Request) {
	payload := []byte("ok")
	w.WriteHeader(http.StatusOK)
	// A failed write means the client has gone away, and there is no one
	// left to report it to, so the error is deliberately dropped.
	_, _ = w.Write(payload)
}
// handleCreateProviderDraft decodes a CreateProviderDraftRequest from the JSON
// body, passes it to fn, and responds 201 Created with {"draft": ...}.
// It responds 500 server_misconfigured if fn is nil. Decode errors are
// written as-is; action errors go through classifyError.
func handleCreateProviderDraft(w http.ResponseWriter, r *http.Request, fn func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "create-provider-draft action is not configured",
		})
		return
	}

	var payload CreateProviderDraftRequest
	if decodeErr := decodeJSON(r, &payload); decodeErr != nil {
		writeHTTPError(w, decodeErr)
		return
	}

	created, actionErr := fn(r.Context(), payload)
	if actionErr != nil {
		writeHTTPError(w, classifyError(actionErr))
		return
	}
	writeJSON(w, http.StatusCreated, map[string]any{"draft": created})
}
// handleListProviderDrafts lists provider drafts filtered by the optional
// pack_id, provider_id and q query parameters (each whitespace-trimmed).
// It responds 200 with {"provider_drafts": [...]}, always encoding an array
// (never null) even when fn returns no drafts.
func handleListProviderDrafts(w http.ResponseWriter, r *http.Request, fn func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "list-provider-drafts action is not configured",
		})
		return
	}

	params := r.URL.Query()
	filter := ListProviderDraftsRequest{
		PackID:     strings.TrimSpace(params.Get("pack_id")),
		ProviderID: strings.TrimSpace(params.Get("provider_id")),
		Query:      strings.TrimSpace(params.Get("q")),
	}

	found, err := fn(r.Context(), filter)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	// Normalize nil to an empty slice so the JSON field is [] rather than null.
	if found == nil {
		found = make([]ProviderDraftInfo, 0)
	}
	writeJSON(w, http.StatusOK, map[string]any{"provider_drafts": found})
}
// handleGetProviderDraft fetches the draft named by the {draftID} path
// segment and responds 200 with {"draft": ...}. A blank (whitespace-only)
// ID is rejected with 400 bad_request before fn is called.
func handleGetProviderDraft(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) (ProviderDraftInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "get-provider-draft action is not configured",
		})
		return
	}

	id := strings.TrimSpace(r.PathValue("draftID"))
	if len(id) == 0 {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusBadRequest,
			Code:       "bad_request",
			Message:    "draft_id is required",
		})
		return
	}

	found, err := fn(r.Context(), id)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"draft": found})
}
// handleUpdateProviderDraft replaces a provider draft. The JSON body is
// decoded into the embedded CreateProviderDraftRequest first. The draft ID
// is then taken from the {draftID} path segment (trimmed) and must be
// non-blank. On success it responds 200 with {"draft": ...}.
//
// NOTE(review): a draft_id field in the body is decoded into the embedded
// CreateProviderDraftRequest.DraftID and forwarded as-is; it is not compared
// with the path ID here. Confirm that the action ignores or validates it.
func handleUpdateProviderDraft(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderDraftRequest) (ProviderDraftInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "update-provider-draft action is not configured",
		})
		return
	}

	var update UpdateProviderDraftRequest
	if decodeErr := decodeJSON(r, &update.CreateProviderDraftRequest); decodeErr != nil {
		writeHTTPError(w, decodeErr)
		return
	}

	pathID := strings.TrimSpace(r.PathValue("draftID"))
	if pathID == "" {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusBadRequest,
			Code:       "bad_request",
			Message:    "draft_id is required",
		})
		return
	}
	update.DraftID = pathID

	saved, err := fn(r.Context(), update)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"draft": saved})
}
// handleDeleteProviderDraft deletes the draft named by the {draftID} path
// segment and responds 204 No Content with no body. A blank ID is rejected
// with 400 bad_request.
func handleDeleteProviderDraft(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) error) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "delete-provider-draft action is not configured",
		})
		return
	}

	target := strings.TrimSpace(r.PathValue("draftID"))
	if target == "" {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusBadRequest,
			Code:       "bad_request",
			Message:    "draft_id is required",
		})
		return
	}

	err := fn(r.Context(), target)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// handlePublishProviderDraft publishes the draft named by the {draftID} path
// segment. It decodes the optional commit_message from the JSON body, then
// checks the path ID, calls fn, and responds 200 with {"publish": ...}.
func handlePublishProviderDraft(w http.ResponseWriter, r *http.Request, fn func(context.Context, PublishProviderDraftRequest) (PublishProviderDraftResult, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "publish-provider-draft action is not configured",
		})
		return
	}

	var publish PublishProviderDraftRequest
	if decodeErr := decodeJSON(r, &publish); decodeErr != nil {
		writeHTTPError(w, decodeErr)
		return
	}

	pathID := strings.TrimSpace(r.PathValue("draftID"))
	if pathID == "" {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusBadRequest,
			Code:       "bad_request",
			Message:    "draft_id is required",
		})
		return
	}
	publish.DraftID = pathID

	outcome, err := fn(r.Context(), publish)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"publish": outcome})
}
func bearerToken(r *http.Request) string {
header := strings.TrimSpace(r.Header.Get("Authorization"))
if !strings.HasPrefix(strings.ToLower(header), "bearer ") {
return ""
}
return strings.TrimSpace(header[len("Bearer "):])
}
// handleInstallPack decodes an InstallPackRequest from the JSON body, runs the
// install through fn, and responds 200 with a flat summary: pack_id, version,
// host_version, already_installed, and a providers array that exposes only
// each provider's ID and display name.
func handleInstallPack(w http.ResponseWriter, r *http.Request, fn func(context.Context, InstallPackRequest) (provision.PackInstallResult, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "install-pack action is not configured",
		})
		return
	}

	var install InstallPackRequest
	if decodeErr := decodeJSON(r, &install); decodeErr != nil {
		writeHTTPError(w, decodeErr)
		return
	}

	outcome, err := fn(r.Context(), install)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}

	summaries := make([]map[string]string, 0, len(outcome.Providers))
	for _, entry := range outcome.Providers {
		summary := map[string]string{
			"provider_id":  entry.ProviderID,
			"display_name": entry.DisplayName,
		}
		summaries = append(summaries, summary)
	}

	response := map[string]any{
		"pack_id":           outcome.Pack.PackID,
		"version":           outcome.Pack.Version,
		"host_version":      outcome.HostVersion,
		"already_installed": outcome.AlreadyInstalled,
		"providers":         summaries,
	}
	writeJSON(w, http.StatusOK, response)
}
// handleListPacks responds 200 with {"packs": [...]}. It always encodes an
// array (never null), even when fn returns no packs.
func handleListPacks(w http.ResponseWriter, r *http.Request, fn func(context.Context) ([]PackInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "list-packs action is not configured",
		})
		return
	}

	installed, err := fn(r.Context())
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	if installed == nil {
		installed = make([]PackInfo, 0)
	}
	writeJSON(w, http.StatusOK, map[string]any{"packs": installed})
}
// handleListProviderImportBatches lists the import batches recorded for the
// provider named in the URL path. Optional pack_id and host_id query
// parameters narrow the lookup. Both are whitespace-trimmed, matching the
// other provider query handlers (status, access status, resources). An empty
// result is rendered as [] rather than null.
func handleListProviderImportBatches(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) ([]ImportBatchInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-provider-import-batches action is not configured"})
		return
	}
	query := r.URL.Query()
	batches, err := fn(r.Context(), ProviderQueryRequest{
		ProviderID: r.PathValue("providerID"),
		// Previously untrimmed, unlike every sibling handler; a stray space
		// in pack_id would otherwise miss an existing pack.
		PackID: strings.TrimSpace(query.Get("pack_id")),
		HostID: strings.TrimSpace(query.Get("host_id")),
	})
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	if batches == nil {
		batches = []ImportBatchInfo{}
	}
	writeJSON(w, http.StatusOK, map[string]any{"batches": batches})
}
// handleAssignAccessSubscriptions decodes an AssignAccessSubscriptionsRequest,
// binds it to the provider named in the URL path and echoes the action's
// result back as JSON.
func handleAssignAccessSubscriptions(w http.ResponseWriter, r *http.Request, fn func(context.Context, AssignAccessSubscriptionsRequest) (AssignAccessSubscriptionsResult, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "assign-access-subscriptions action is not configured"})
		return
	}
	var payload AssignAccessSubscriptionsRequest
	if badReq := decodeJSON(r, &payload); badReq != nil {
		writeHTTPError(w, badReq)
		return
	}
	// The path segment wins over whatever ProviderID the body decoded to.
	payload.ProviderID = r.PathValue("providerID")
	outcome, err := fn(r.Context(), payload)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	writeJSON(w, http.StatusOK, outcome)
}
// handleAccessPreview decodes an AccessPreviewRequest, binds it to the path
// provider, and fills PackID/HostID from the pack_id/host_id query parameters
// only when the body left them empty. The action's result is returned as JSON.
func handleAccessPreview(w http.ResponseWriter, r *http.Request, fn func(context.Context, AccessPreviewRequest) (AccessPreviewResult, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "access-preview action is not configured"})
		return
	}
	var preview AccessPreviewRequest
	if badReq := decodeJSON(r, &preview); badReq != nil {
		writeHTTPError(w, badReq)
		return
	}
	preview.ProviderID = r.PathValue("providerID")
	// Query parameters are consulted lazily, only as a fallback for blank body fields.
	queryValue := func(key string) string {
		return strings.TrimSpace(r.URL.Query().Get(key))
	}
	if preview.PackID == "" {
		preview.PackID = queryValue("pack_id")
	}
	if preview.HostID == "" {
		preview.HostID = queryValue("host_id")
	}
	outcome, err := fn(r.Context(), preview)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	writeJSON(w, http.StatusOK, outcome)
}
// handleGetPack returns a single pack identified by the packID path segment.
// A blank (or whitespace-only) id is rejected with 400 before the action runs.
func handleGetPack(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) (PackInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-pack action is not configured"})
		return
	}
	id := strings.TrimSpace(r.PathValue("packID"))
	if len(id) == 0 {
		writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "pack_id is required"})
		return
	}
	// Named info (not pack) so the imported pack package stays unshadowed.
	info, err := fn(r.Context(), id)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	writeJSON(w, http.StatusOK, info)
}
// handleListPackProviders lists the providers declared by the pack named in
// the packID path segment. The response echoes the trimmed pack id and always
// renders providers as a JSON array.
func handleListPackProviders(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) ([]PackProviderInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-pack-providers action is not configured"})
		return
	}
	id := strings.TrimSpace(r.PathValue("packID"))
	if len(id) == 0 {
		writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "pack_id is required"})
		return
	}
	list, err := fn(r.Context(), id)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	if list == nil {
		list = []PackProviderInfo{}
	}
	body := map[string]any{"pack_id": id, "providers": list}
	writeJSON(w, http.StatusOK, body)
}
// handleBatchDetail returns one import batch (by the batchID path segment)
// with its items, managed resources, access closures and reconcile runs, plus
// a count for each collection.
//
// The batch id is whitespace-trimmed before parsing, consistent with
// handleRollbackBatch. A non-numeric or non-positive id yields 400.
func handleBatchDetail(w http.ResponseWriter, r *http.Request, fn func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "batch-detail action is not configured"})
		return
	}
	batchID, err := strconv.ParseInt(strings.TrimSpace(r.PathValue("batchID")), 10, 64)
	if err != nil || batchID <= 0 {
		writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "batch_id must be a positive integer"})
		return
	}
	result, err := fn(r.Context(), BatchDetailRequest{BatchID: batchID})
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	// Non-nil even when empty so "items" encodes as [] rather than null.
	items := make([]map[string]any, 0, len(result.Items))
	for _, item := range result.Items {
		items = append(items, map[string]any{
			"id":                 item.ID,
			"batch_id":           item.BatchID,
			"key_fingerprint":    item.KeyFingerprint,
			"account_status":     item.AccountStatus,
			"probe_summary_json": item.ProbeSummaryJSON,
		})
	}
	writeJSON(w, http.StatusOK, map[string]any{
		"batch": map[string]any{
			"id":            result.Batch.ID,
			"host_id":       result.Batch.HostID,
			"pack_id":       result.Batch.PackID,
			"provider_id":   result.Batch.ProviderID,
			"mode":          result.Batch.Mode,
			"batch_status":  result.Batch.BatchStatus,
			"access_status": result.Batch.AccessStatus,
		},
		"items":             items,
		"managed_resources": result.ManagedResources,
		"access_closures":   result.AccessClosures,
		"reconcile_runs":    result.ReconcileRuns,
		"items_count":       len(result.Items),
		"managed_count":     len(result.ManagedResources),
		"access_count":      len(result.AccessClosures),
		"reconcile_count":   len(result.ReconcileRuns),
	})
}
// handleProviderStatus summarizes a provider snapshot: the host, pack,
// provider and latest batch as small objects, the latest status fields, and
// counts (not contents) of managed resources, access closures and reconcile
// runs. pack_id and host_id query parameters are trimmed and passed through.
func handleProviderStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-status action is not configured"})
		return
	}
	q := r.URL.Query()
	snap, err := fn(r.Context(), ProviderQueryRequest{
		ProviderID: r.PathValue("providerID"),
		PackID:     strings.TrimSpace(q.Get("pack_id")),
		HostID:     strings.TrimSpace(q.Get("host_id")),
	})
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	hostView := map[string]any{"host_id": snap.Host.HostID, "base_url": snap.Host.BaseURL, "host_version": snap.Host.HostVersion}
	packView := map[string]any{"pack_id": snap.Pack.PackID, "version": snap.Pack.Version}
	providerView := map[string]any{"provider_id": snap.Provider.ProviderID, "display_name": snap.Provider.DisplayName, "platform": snap.Provider.Platform}
	batchView := map[string]any{"id": snap.Batch.ID, "batch_status": snap.Batch.BatchStatus, "access_status": snap.Batch.AccessStatus, "mode": snap.Batch.Mode}
	writeJSON(w, http.StatusOK, map[string]any{
		"host":                     hostView,
		"pack":                     packView,
		"provider":                 providerView,
		"batch":                    batchView,
		"provider_status":          snap.ProviderStatus,
		"latest_access_status":     snap.LatestAccessStatus,
		"latest_reconcile_status":  snap.LatestReconcileStatus,
		"latest_reconcile_summary": snap.LatestReconcileSummary,
		"managed_resources_count":  len(snap.ManagedResources),
		"access_closures_count":    len(snap.AccessClosures),
		"reconcile_runs_count":     len(snap.ReconcileRuns),
	})
}
// handleProviderAccessStatus reports the access state of a provider: batch
// and latest access status, the number of access closures, and the last
// closure in the snapshot's list (an empty object when there are none).
func handleProviderAccessStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-access-status action is not configured"})
		return
	}
	q := r.URL.Query()
	snap, err := fn(r.Context(), ProviderQueryRequest{
		ProviderID: r.PathValue("providerID"),
		PackID:     strings.TrimSpace(q.Get("pack_id")),
		HostID:     strings.TrimSpace(q.Get("host_id")),
	})
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	newest := map[string]any{}
	if closures := snap.AccessClosures; len(closures) != 0 {
		last := closures[len(closures)-1]
		newest = map[string]any{"id": last.ID, "closure_type": last.ClosureType, "status": last.Status, "details_json": last.DetailsJSON}
	}
	body := map[string]any{
		"provider_id":          snap.Provider.ProviderID,
		"pack_id":              snap.Pack.PackID,
		"batch_id":             snap.Batch.ID,
		"batch_access_status":  snap.Batch.AccessStatus,
		"latest_access_status": snap.LatestAccessStatus,
		"closures_count":       len(snap.AccessClosures),
		"latest_closure":       newest,
	}
	writeJSON(w, http.StatusOK, body)
}
// handleProviderResources lists a provider's managed resources, access
// closures and reconcile runs, projected to a fixed set of fields each. All
// three lists are non-nil, so they encode as [] when empty.
func handleProviderResources(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-resources action is not configured"})
		return
	}
	q := r.URL.Query()
	snap, err := fn(r.Context(), ProviderQueryRequest{
		ProviderID: r.PathValue("providerID"),
		PackID:     strings.TrimSpace(q.Get("pack_id")),
		HostID:     strings.TrimSpace(q.Get("host_id")),
	})
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	resourceViews := make([]map[string]any, len(snap.ManagedResources))
	for i, res := range snap.ManagedResources {
		resourceViews[i] = map[string]any{"id": res.ID, "resource_type": res.ResourceType, "host_resource_id": res.HostResourceID, "resource_name": res.ResourceName}
	}
	closureViews := make([]map[string]any, len(snap.AccessClosures))
	for i, c := range snap.AccessClosures {
		closureViews[i] = map[string]any{"id": c.ID, "closure_type": c.ClosureType, "status": c.Status, "details_json": c.DetailsJSON}
	}
	runViews := make([]map[string]any, len(snap.ReconcileRuns))
	for i, run := range snap.ReconcileRuns {
		runViews[i] = map[string]any{"id": run.ID, "status": run.Status, "summary_json": run.SummaryJSON}
	}
	body := map[string]any{
		"provider_id":     snap.Provider.ProviderID,
		"pack_id":         snap.Pack.PackID,
		"batch_id":        snap.Batch.ID,
		"resources":       resourceViews,
		"access_closures": closureViews,
		"reconcile_runs":  runViews,
	}
	writeJSON(w, http.StatusOK, body)
}
// handlePreviewProvider runs a dry-run import for the path provider and
// returns the number of accepted keys (not the keys themselves), host
// overlays, generated names and per-item decisions.
func handlePreviewProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "preview-provider action is not configured"})
		return
	}
	var preview PreviewProviderRequest
	if badReq := decodeJSON(r, &preview); badReq != nil {
		writeHTTPError(w, badReq)
		return
	}
	preview.ProviderID = r.PathValue("providerID")
	report, err := fn(r.Context(), preview)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	body := map[string]any{
		"accepted_keys_count": len(report.AcceptedKeys),
		"host_overlays":       report.HostOverlays,
		"names":               report.Names,
		"decisions":           report.Decisions,
	}
	writeJSON(w, http.StatusOK, body)
}
// handleImportProvider runs a real import for the path provider.
//
// On success it returns the batch report, including group/channel/plan.
// On failure it still returns the partial report plus a classified "error".
// If a batch was already created (non-zero BatchID), the status is 409 so
// callers can inspect the half-finished batch. Otherwise the classified
// error's own status code is used.
func handleImportProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "import-provider action is not configured"})
		return
	}
	var importReq ImportProviderRequest
	if badReq := decodeJSON(r, &importReq); badReq != nil {
		writeHTTPError(w, badReq)
		return
	}
	importReq.ProviderID = r.PathValue("providerID")
	outcome, err := fn(r.Context(), importReq)
	report := outcome.Report
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{
			"batch_id":            outcome.BatchID,
			"batch_status":        report.BatchStatus,
			"provider_status":     report.ProviderStatus,
			"access_status":       report.AccessStatus,
			"accepted_keys_count": len(report.AcceptedKeys),
			"accounts_count":      len(report.Accounts),
			"host_overlays":       report.HostOverlays,
			"group":               report.Group,
			"channel":             report.Channel,
			"plan":                report.Plan,
			"gateway":             report.Gateway,
		})
		return
	}
	classified := classifyError(err)
	status := http.StatusConflict
	if outcome.BatchID == 0 {
		status = classified.StatusCode
	}
	writeJSON(w, status, map[string]any{
		"batch_id":            outcome.BatchID,
		"batch_status":        report.BatchStatus,
		"provider_status":     report.ProviderStatus,
		"access_status":       report.AccessStatus,
		"accepted_keys_count": len(report.AcceptedKeys),
		"accounts_count":      len(report.Accounts),
		"host_overlays":       report.HostOverlays,
		"gateway":             report.Gateway,
		"error":               classified,
	})
}
// handleRollbackProvider rolls back the resources created for the path
// provider and reports how many accounts, plans, channels and groups the
// action deleted.
func handleRollbackProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "rollback-provider action is not configured"})
		return
	}
	var rollback RollbackProviderRequest
	if badReq := decodeJSON(r, &rollback); badReq != nil {
		writeHTTPError(w, badReq)
		return
	}
	rollback.ProviderID = r.PathValue("providerID")
	deleted, err := fn(r.Context(), rollback)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	body := map[string]any{
		"provider_id":      rollback.ProviderID,
		"deleted_accounts": deleted.AccountsDeleted,
		"deleted_plans":    deleted.PlansDeleted,
		"deleted_channels": deleted.ChannelsDeleted,
		"deleted_groups":   deleted.GroupsDeleted,
	}
	writeJSON(w, http.StatusOK, body)
}
// handleRollbackBatch rolls back a single import batch identified by the
// batchID path segment and reports deletion counts. The body is decoded
// before the path id is validated, so a malformed body is reported first.
func handleRollbackBatch(w http.ResponseWriter, r *http.Request, fn func(context.Context, RollbackBatchRequest) (provision.RollbackReport, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "rollback-batch action is not configured"})
		return
	}
	var rollback RollbackBatchRequest
	if badReq := decodeJSON(r, &rollback); badReq != nil {
		writeHTTPError(w, badReq)
		return
	}
	rawID := strings.TrimSpace(r.PathValue("batchID"))
	id, parseErr := strconv.ParseInt(rawID, 10, 64)
	if parseErr != nil || id <= 0 {
		writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "batch_id must be a positive integer"})
		return
	}
	rollback.BatchID = id
	deleted, err := fn(r.Context(), rollback)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	body := map[string]any{
		"batch_id":         rollback.BatchID,
		"deleted_accounts": deleted.AccountsDeleted,
		"deleted_plans":    deleted.PlansDeleted,
		"deleted_channels": deleted.ChannelsDeleted,
		"deleted_groups":   deleted.GroupsDeleted,
	}
	writeJSON(w, http.StatusOK, body)
}
// handleReconcileProvider reconciles the path provider's recorded state
// against the host and returns the run status, drift counts (missing, extra,
// stale noise) and summary.
func handleReconcileProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, ReconcileProviderRequest) (reconcile.Result, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "reconcile-provider action is not configured"})
		return
	}
	var reconcileReq ReconcileProviderRequest
	if badReq := decodeJSON(r, &reconcileReq); badReq != nil {
		writeHTTPError(w, badReq)
		return
	}
	reconcileReq.ProviderID = r.PathValue("providerID")
	run, err := fn(r.Context(), reconcileReq)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	body := map[string]any{
		"provider_id":       reconcileReq.ProviderID,
		"batch_id":          run.BatchID,
		"status":            run.Status,
		"missing_count":     run.MissingCount,
		"extra_count":       run.ExtraCount,
		"stale_noise_count": run.StaleNoiseCount,
		"summary":           run.Summary,
	}
	writeJSON(w, http.StatusOK, body)
}
// handleListHosts returns all registered hosts under a "hosts" key, always
// as a JSON array.
func handleListHosts(w http.ResponseWriter, r *http.Request, fn func(context.Context) ([]HostInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-hosts action is not configured"})
		return
	}
	list, err := fn(r.Context())
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	if len(list) == 0 {
		list = []HostInfo{}
	}
	writeJSON(w, http.StatusOK, map[string]any{"hosts": list})
}
// handleGetHost returns the host named by the hostID path segment. A blank
// id is rejected with 400 before the action is called.
func handleGetHost(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) (HostInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-host action is not configured"})
		return
	}
	id := strings.TrimSpace(r.PathValue("hostID"))
	if len(id) == 0 {
		writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "host_id is required"})
		return
	}
	info, err := fn(r.Context(), id)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	writeJSON(w, http.StatusOK, info)
}
// handleProbeHost decodes a ProbeHostRequest, binds it to the trimmed hostID
// path segment (400 if blank) and returns the refreshed host info. The body
// is decoded before the path id is checked.
func handleProbeHost(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProbeHostRequest) (HostInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "probe-host action is not configured"})
		return
	}
	var probe ProbeHostRequest
	if badReq := decodeJSON(r, &probe); badReq != nil {
		writeHTTPError(w, badReq)
		return
	}
	probe.HostID = strings.TrimSpace(r.PathValue("hostID"))
	if len(probe.HostID) == 0 {
		writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "host_id is required"})
		return
	}
	info, err := fn(r.Context(), probe)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	writeJSON(w, http.StatusOK, info)
}
// handleCreateHost decodes a CreateHostRequest, registers the host and
// returns the stored host info with status 200.
func handleCreateHost(w http.ResponseWriter, r *http.Request, fn func(context.Context, CreateHostRequest) (HostInfo, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "create-host action is not configured"})
		return
	}
	var create CreateHostRequest
	if badReq := decodeJSON(r, &create); badReq != nil {
		writeHTTPError(w, badReq)
		return
	}
	created, err := fn(r.Context(), create)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	writeJSON(w, http.StatusOK, created)
}
// handleDeleteHost removes the host named by the hostID path segment and
// answers 204 No Content on success. Action errors go through classifyError,
// which maps a sqlite.HostDeleteBlocker to 409.
func handleDeleteHost(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) error) {
	if fn == nil {
		writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "delete-host action is not configured"})
		return
	}
	id := strings.TrimSpace(r.PathValue("hostID"))
	if len(id) == 0 {
		writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "host_id is required"})
		return
	}
	err := fn(r.Context(), id)
	if err != nil {
		writeHTTPError(w, classifyError(err))
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// decodeJSON decodes the request body into dest, requiring exactly one JSON
// value.
//
// Rules enforced:
//   - the body is capped at maxJSONBodyBytes; exceeding it yields 413
//     request_too_large (also when the overflow happens after the first value);
//   - unknown object fields are rejected;
//   - only whitespace may follow the first value. A second complete value
//     such as "{}" or "null" is rejected, too.
//
// Every other failure yields 400 bad_request. A nil return means dest was
// populated successfully.
func decodeJSON(r *http.Request, dest any) *httpError {
	if r == nil {
		return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "request is required"}
	}
	// No ResponseWriter is available here. MaxBytesReader accepts nil and
	// still surfaces *http.MaxBytesError once the limit is crossed.
	r.Body = http.MaxBytesReader(nil, r.Body, maxJSONBodyBytes)
	decoder := json.NewDecoder(r.Body)
	decoder.DisallowUnknownFields()
	var maxBytesErr *http.MaxBytesError
	if err := decoder.Decode(dest); err != nil {
		if errors.As(err, &maxBytesErr) {
			return &httpError{StatusCode: http.StatusRequestEntityTooLarge, Code: "request_too_large", Message: fmt.Sprintf("request body exceeds %d bytes", maxJSONBodyBytes)}
		}
		return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: fmt.Sprintf("decode request body: %v", err)}
	}
	// The only acceptable outcome of a second Decode is a clean io.EOF. A nil
	// error means another complete value followed the first. That case was
	// previously accepted, because the old check only rejected non-nil,
	// non-EOF errors.
	var trailing json.RawMessage
	if err := decoder.Decode(&trailing); !errors.Is(err, io.EOF) {
		if errors.As(err, &maxBytesErr) {
			return &httpError{StatusCode: http.StatusRequestEntityTooLarge, Code: "request_too_large", Message: fmt.Sprintf("request body exceeds %d bytes", maxJSONBodyBytes)}
		}
		return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "request body must contain a single JSON object"}
	}
	return nil
}
// writeHTTPError renders err as {"error": err} using err's status code. A nil
// err is treated as a generic 500 internal_error so callers never emit an
// empty error response.
func writeHTTPError(w http.ResponseWriter, err *httpError) {
	resp := err
	if resp == nil {
		resp = &httpError{StatusCode: http.StatusInternalServerError, Code: "internal_error", Message: "internal server error"}
	}
	envelope := map[string]any{"error": resp}
	writeJSON(w, resp.StatusCode, envelope)
}
// writeJSON sends body as an application/json response with the given status
// code. The payload is newline-terminated (json.Encoder behavior).
func writeJSON(w http.ResponseWriter, statusCode int, body any) {
	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	enc := json.NewEncoder(w)
	// The status line has already been written, so an encode failure cannot
	// be reported to the client any more; it is deliberately dropped.
	_ = enc.Encode(body)
}
// classifyError maps an action error to the *httpError returned to clients.
//
// Precedence:
//  1. nil in, nil out.
//  2. An *httpError anywhere in the chain is returned as-is.
//  3. An upstream *sub2api.HTTPError becomes 502 host_request_failed and
//     records the upstream status code.
//  4. A *sqlite.HostDeleteBlocker becomes 409 host_in_use.
//  5. Otherwise the error text is matched against substrings, first match
//     wins, falling back to 500 internal_error.
//
// The substring cascade is order-sensitive. "not found in pack" must be
// tested before the broader "not found". The generic "required"/"decode"
// checks come last so they do not shadow more specific cases.
// NOTE(review): matching on error text is fragile. Sentinel or typed errors
// would be more robust if the producing packages can expose them.
func classifyError(err error) *httpError {
	if err == nil {
		return nil
	}
	var requestErr *httpError
	if errors.As(err, &requestErr) {
		return requestErr
	}
	var upstreamErr *sub2api.HTTPError
	if errors.As(err, &upstreamErr) {
		return &httpError{StatusCode: http.StatusBadGateway, Code: "host_request_failed", Message: err.Error(), UpstreamStatus: upstreamErr.StatusCode}
	}
	var hostDeleteBlocker *sqlite.HostDeleteBlocker
	if errors.As(err, &hostDeleteBlocker) {
		return &httpError{StatusCode: http.StatusConflict, Code: "host_in_use", Message: err.Error()}
	}
	message := err.Error()
	switch {
	// Pack install conflicts.
	case strings.Contains(message, "already installed") || strings.Contains(message, "checksum drift"):
		return &httpError{StatusCode: http.StatusConflict, Code: "pack_conflict", Message: message}
	// Publishing needs a configured repo root; its absence is a service-side condition.
	case strings.Contains(message, "repo root"):
		return &httpError{StatusCode: http.StatusServiceUnavailable, Code: "publish_unavailable", Message: message}
	case strings.Contains(message, "run import again before reconcile"):
		return &httpError{StatusCode: http.StatusConflict, Code: "batch_not_reconcilable", Message: message}
	// Must precede the generic "not found" case: an unknown provider in a pack is a client error.
	case strings.Contains(message, "not found in pack"):
		return &httpError{StatusCode: http.StatusBadRequest, Code: "provider_not_found", Message: message}
	case strings.Contains(message, "not found"):
		return &httpError{StatusCode: http.StatusNotFound, Code: "not_found", Message: message}
	// Catch-all for validation-style failures.
	case strings.Contains(message, "pack path") || strings.Contains(message, "pack dir") || strings.Contains(message, "required") || strings.Contains(message, "decode"):
		return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: message}
	default:
		return &httpError{StatusCode: http.StatusInternalServerError, Code: "internal_error", Message: message}
	}
}
// NewActionSet builds the full ActionSet backed by the SQLite database at
// sqliteDSN, using the default sticky-store runtime. It is shorthand for
// NewActionSetWithStickyRuntime(sqliteDSN, defaultStickyStoreRuntime()).
func NewActionSet(sqliteDSN string) ActionSet {
	sticky := defaultStickyStoreRuntime()
	return NewActionSetWithStickyRuntime(sqliteDSN, sticky)
}
func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRuntime) ActionSet {
routeLogWriter := newLazyRouteLogWriter(sqliteDSN)
resolveRoute := buildResolveRouteAction(sqliteDSN, stickyRuntime, routeLogWriter)
return ActionSet{
CreateBatchImportRun: buildCreateBatchImportRunAction(sqliteDSN),
ListBatchImportRuns: buildListBatchImportRunsAction(sqliteDSN),
GetBatchImportRun: buildGetBatchImportRunAction(sqliteDSN),
ListBatchImportRunItems: buildListBatchImportRunItemsAction(sqliteDSN),
GetBatchImportRunItem: buildGetBatchImportRunItemAction(sqliteDSN),
CreateLogicalGroup: buildCreateLogicalGroupAction(sqliteDSN),
ListLogicalGroups: buildListLogicalGroupsAction(sqliteDSN),
GetLogicalGroup: buildGetLogicalGroupAction(sqliteDSN),
UpdateLogicalGroup: buildUpdateLogicalGroupAction(sqliteDSN),
DeleteLogicalGroup: buildDeleteLogicalGroupAction(sqliteDSN),
CreateLogicalGroupModel: buildCreateLogicalGroupModelAction(sqliteDSN),
ListLogicalGroupModels: buildListLogicalGroupModelsAction(sqliteDSN),
DeleteLogicalGroupModel: buildDeleteLogicalGroupModelAction(sqliteDSN),
CreateLogicalGroupRoute: buildCreateLogicalGroupRouteAction(sqliteDSN),
ListLogicalGroupRoutes: buildListLogicalGroupRoutesAction(sqliteDSN),
UpdateLogicalGroupRoute: buildUpdateLogicalGroupRouteAction(sqliteDSN),
DeleteLogicalGroupRoute: buildDeleteLogicalGroupRouteAction(sqliteDSN),
CreateLogicalGroupRouteModel: buildCreateLogicalGroupRouteModelAction(sqliteDSN),
ListLogicalGroupRouteModels: buildListLogicalGroupRouteModelsAction(sqliteDSN),
AppendRouteDecisionLog: buildAppendRouteDecisionLogAction(routeLogWriter, sqliteDSN),
ListRouteDecisionLogs: buildListRouteDecisionLogsAction(sqliteDSN),
AppendRouteFailoverEvent: buildAppendRouteFailoverEventAction(routeLogWriter, sqliteDSN),
ListRouteFailoverEvents: buildListRouteFailoverEventsAction(sqliteDSN),
AppendRouteStickyAudit: buildAppendRouteStickyAuditAction(routeLogWriter, sqliteDSN),
ListRouteStickyAudit: buildListRouteStickyAuditAction(sqliteDSN),
ResolveRoute: resolveRoute,
SetStickyBinding: buildSetStickyBindingAction(stickyRuntime),
GetStickyBinding: buildGetStickyBindingAction(stickyRuntime),
SetRouteFailure: buildSetRouteFailureAction(stickyRuntime),
GetRouteFailure: buildGetRouteFailureAction(stickyRuntime),
SetRouteCooldown: buildSetRouteCooldownAction(stickyRuntime),
GetRouteCooldown: buildGetRouteCooldownAction(stickyRuntime),
CreateProviderDraft: func(ctx context.Context, req CreateProviderDraftRequest) (ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return ProviderDraftInfo{}, err
}
defer store.Close()
draftID := strings.TrimSpace(req.DraftID)
if draftID == "" {
draftID = fmt.Sprintf("draft_%d", time.Now().UnixNano())
}
manifestJSON, manifestValue, supportedModels, err := normalizeProviderDraftPayload(req)
if err != nil {
return ProviderDraftInfo{}, err
}
if err := validateProviderDraftModelConflicts(ctx, store, strings.TrimSpace(req.PackID), strings.TrimSpace(req.ProviderID), "", supportedModels, strings.TrimSpace(req.SmokeTestModel)); err != nil {
return ProviderDraftInfo{}, err
}
draftRow := sqlite.ProviderDraft{
DraftID: draftID,
PackID: strings.TrimSpace(req.PackID),
ProviderID: strings.TrimSpace(req.ProviderID),
DisplayName: strings.TrimSpace(req.DisplayName),
Platform: strings.TrimSpace(req.Platform),
BaseURL: strings.TrimSpace(req.BaseURL),
SmokeTestModel: strings.TrimSpace(req.SmokeTestModel),
SupportedModelsJSON: encodeStringList(supportedModels),
ManifestJSON: manifestJSON,
SourceHostID: strings.TrimSpace(req.SourceHostID),
Notes: strings.TrimSpace(req.Notes),
}
if _, err := store.ProviderDrafts().Create(ctx, draftRow); err != nil {
return ProviderDraftInfo{}, err
}
persisted, err := store.ProviderDrafts().GetByDraftID(ctx, draftID)
if err != nil {
return ProviderDraftInfo{}, err
}
return providerDraftRecordToInfo(persisted, manifestValue, supportedModels)
},
ListProviderDrafts: func(ctx context.Context, req ListProviderDraftsRequest) ([]ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return nil, err
}
defer store.Close()
rows, err := store.ProviderDrafts().List(ctx, sqlite.ListProviderDraftsFilter{
PackID: req.PackID,
ProviderID: req.ProviderID,
Query: req.Query,
})
if err != nil {
return nil, err
}
result := make([]ProviderDraftInfo, 0, len(rows))
for _, row := range rows {
info, err := providerDraftRecordToInfoFromStored(row)
if err != nil {
return nil, err
}
result = append(result, info)
}
return result, nil
},
GetProviderDraft: func(ctx context.Context, draftID string) (ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return ProviderDraftInfo{}, err
}
defer store.Close()
row, err := store.ProviderDrafts().GetByDraftID(ctx, draftID)
if err != nil {
return ProviderDraftInfo{}, err
}
return providerDraftRecordToInfoFromStored(row)
},
UpdateProviderDraft: func(ctx context.Context, req UpdateProviderDraftRequest) (ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return ProviderDraftInfo{}, err
}
defer store.Close()
manifestJSON, _, supportedModels, err := normalizeProviderDraftPayload(req.CreateProviderDraftRequest)
if err != nil {
return ProviderDraftInfo{}, err
}
if err := validateProviderDraftModelConflicts(ctx, store, strings.TrimSpace(req.PackID), strings.TrimSpace(req.ProviderID), strings.TrimSpace(req.DraftID), supportedModels, strings.TrimSpace(req.SmokeTestModel)); err != nil {
return ProviderDraftInfo{}, err
}
if err := store.ProviderDrafts().UpdateByDraftID(ctx, sqlite.ProviderDraft{
DraftID: strings.TrimSpace(req.DraftID),
PackID: strings.TrimSpace(req.PackID),
ProviderID: strings.TrimSpace(req.ProviderID),
DisplayName: strings.TrimSpace(req.DisplayName),
Platform: strings.TrimSpace(req.Platform),
BaseURL: strings.TrimSpace(req.BaseURL),
SmokeTestModel: strings.TrimSpace(req.SmokeTestModel),
SupportedModelsJSON: encodeStringList(supportedModels),
ManifestJSON: manifestJSON,
SourceHostID: strings.TrimSpace(req.SourceHostID),
Notes: strings.TrimSpace(req.Notes),
}); err != nil {
return ProviderDraftInfo{}, err
}
row, err := store.ProviderDrafts().GetByDraftID(ctx, req.DraftID)
if err != nil {
return ProviderDraftInfo{}, err
}
return providerDraftRecordToInfoFromStored(row)
},
DeleteProviderDraft: func(ctx context.Context, draftID string) error {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return err
}
defer store.Close()
return store.ProviderDrafts().DeleteByDraftID(ctx, draftID)
},
PublishProviderDraft: func(ctx context.Context, req PublishProviderDraftRequest) (PublishProviderDraftResult, error) {
startupCfg, err := config.LoadStartupFromEnv()
if err != nil {
return PublishProviderDraftResult{}, err
}
if strings.TrimSpace(startupCfg.Repository.RepoRoot) == "" {
return PublishProviderDraftResult{}, fmt.Errorf("pack repo root is not configured")
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return PublishProviderDraftResult{}, err
}
defer store.Close()
row, err := store.ProviderDrafts().GetByDraftID(ctx, req.DraftID)
if err != nil {
return PublishProviderDraftResult{}, err
}
manifest, err := buildPublishedProviderManifest(row)
if err != nil {
return PublishProviderDraftResult{}, err
}
if err := validateProviderDraftModelConflicts(ctx, store, strings.TrimSpace(row.PackID), strings.TrimSpace(manifest.ProviderID), strings.TrimSpace(row.DraftID), manifest.DefaultModels, strings.TrimSpace(manifest.SmokeTestModel)); err != nil {
return PublishProviderDraftResult{}, err
}
publishResult, err := pack.PublishProviderManifest(ctx, pack.PublishProviderManifestRequest{
RepoRoot: startupCfg.Repository.RepoRoot,
PackID: row.PackID,
Manifest: manifest,
CommitMessage: strings.TrimSpace(req.CommitMessage),
})
if err != nil {
return PublishProviderDraftResult{}, err
}
return PublishProviderDraftResult{
DraftID: row.DraftID,
PackID: publishResult.PackID,
ProviderID: publishResult.ProviderID,
ProviderPath: publishResult.ProviderPath,
PackVersionBefore: publishResult.PackVersionBefore,
PackVersionAfter: publishResult.PackVersionAfter,
PublishMode: publishResult.PublishMode,
CommitMessage: publishResult.CommitMessage,
CommitSHA: publishResult.CommitSHA,
RepoRoot: publishResult.RepoRoot,
}, nil
},
InstallPack: func(ctx context.Context, req InstallPackRequest) (provision.PackInstallResult, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.PackInstallResult{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.PackInstallResult{}, err
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.PackInstallResult{}, err
}
defer store.Close()
service := provision.NewPackInstallService(store, client)
return service.Install(ctx, provision.PackInstallRequest{Pack: loadedPack})
},
BatchDetail: func(ctx context.Context, req BatchDetailRequest) (provision.BatchDetailResult, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.BatchDetailResult{}, err
}
defer store.Close()
return provision.NewBatchDetailService(store).Get(ctx, req.BatchID)
},
GetProviderStatus: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.ProviderSnapshot{}, err
}
defer store.Close()
return provision.NewProviderStatusService(store).GetStatus(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID, HostID: req.HostID})
},
GetProviderResources: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.ProviderSnapshot{}, err
}
defer store.Close()
return provision.NewProviderStatusService(store).GetResources(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID, HostID: req.HostID})
},
GetProviderAccessStatus: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.ProviderSnapshot{}, err
}
defer store.Close()
return provision.NewProviderStatusService(store).GetStatus(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID, HostID: req.HostID})
},
ListProviderImportBatches: func(ctx context.Context, req ProviderQueryRequest) ([]ImportBatchInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return nil, err
}
defer store.Close()
providers, err := resolveProvidersForQuery(ctx, store, req)
if err != nil {
return nil, err
}
batches := make([]ImportBatchInfo, 0)
for _, providerRow := range providers {
var rows []sqlite.ImportBatch
if strings.TrimSpace(req.HostID) != "" {
hostRow, err := store.Hosts().GetByHostID(ctx, req.HostID)
if err != nil {
return nil, err
}
rows, err = store.ImportBatches().ListByProviderIDAndHostID(ctx, providerRow.ID, hostRow.ID)
if err != nil {
return nil, err
}
} else {
rows, err = store.ImportBatches().ListByProviderID(ctx, providerRow.ID)
if err != nil {
return nil, err
}
if len(rows) > 1 {
firstHostID := rows[0].HostID
for _, row := range rows[1:] {
if row.HostID != firstHostID {
return nil, fmt.Errorf("provider exists on multiple hosts; host_id is required")
}
}
}
}
for _, batch := range rows {
batches = append(batches, ImportBatchInfo{BatchID: batch.ID, BatchStatus: batch.BatchStatus, AccessStatus: batch.AccessStatus})
}
}
return batches, nil
},
PreviewProvider: func(ctx context.Context, req PreviewProviderRequest) (provision.PreviewReport, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.PreviewReport{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.PreviewReport{}, err
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.PreviewReport{}, err
}
defer store.Close()
_, client, err := resolveManagedHost(ctx, store, req.HostID, req.HostBaseURL, createHostAuthFromLegacyFields(req.HostAPIKey, req.HostBearerToken))
if err != nil {
return provision.PreviewReport{}, err
}
hostVersion, err := client.GetHostVersion(ctx)
if err != nil {
return provision.PreviewReport{}, err
}
service := provision.NewPreviewService(client)
return service.PreviewImport(ctx, provision.PreviewRequest{
TargetHost: loadedPack.Manifest.TargetHost,
HostVersion: hostVersion,
Provider: providerManifest,
Mode: req.Mode,
Keys: req.Keys,
})
},
ImportProvider: func(ctx context.Context, req ImportProviderRequest) (provision.RuntimeImportResult, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.RuntimeImportResult{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.RuntimeImportResult{}, err
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.RuntimeImportResult{}, err
}
defer store.Close()
hostRow, client, err := resolveManagedHost(ctx, store, req.HostID, req.HostBaseURL, createHostAuthFromLegacyFields(req.HostAPIKey, req.HostBearerToken))
if err != nil {
return provision.RuntimeImportResult{}, err
}
subscriptions := make([]provision.SubscriptionTarget, 0, len(req.SubscriptionUsers))
for _, userID := range req.SubscriptionUsers {
subscriptions = append(subscriptions, provision.SubscriptionTarget{UserID: userID, DurationDays: req.SubscriptionDays})
}
service := provision.NewRuntimeImportService(store, client)
return service.Import(ctx, provision.RuntimeImportRequest{
HostID: hostRow.HostID,
HostBaseURL: hostRow.BaseURL,
Pack: loadedPack,
Provider: providerManifest,
Mode: req.Mode,
Keys: req.Keys,
Access: provision.AccessRequest{
Mode: req.AccessMode,
ProbeAPIKey: req.AccessAPIKey,
Subscriptions: subscriptions,
},
})
},
RollbackProvider: func(ctx context.Context, req RollbackProviderRequest) (provision.RollbackReport, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.RollbackReport{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.RollbackReport{}, err
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.RollbackReport{}, err
}
defer store.Close()
hostRow, client, err := resolveManagedHost(ctx, store, req.HostID, req.HostBaseURL, createHostAuthFromLegacyFields(req.HostAPIKey, req.HostBearerToken))
if err != nil {
return provision.RollbackReport{}, err
}
packRow, err := store.Packs().GetByPackID(ctx, loadedPack.Manifest.PackID)
if err != nil {
return provision.RollbackReport{}, err
}
providerRow, err := store.Providers().GetByPackIDAndProviderID(ctx, packRow.ID, providerManifest.ProviderID)
if err != nil {
return provision.RollbackReport{}, err
}
batch, err := store.ImportBatches().GetLatestByProviderIDAndHostID(ctx, providerRow.ID, hostRow.ID)
if err != nil {
return provision.RollbackReport{}, fmt.Errorf("find latest batch for provider %q on host %q: %w", providerManifest.ProviderID, hostRow.HostID, err)
}
managedResources, err := store.ManagedResources().GetByBatchID(ctx, batch.ID)
if err != nil {
return provision.RollbackReport{}, err
}
if len(managedResources) == 0 {
return provision.RollbackReport{}, fmt.Errorf("rollback requires stored managed resources for provider %q on host %q", providerManifest.ProviderID, hostRow.HostID)
}
service := provision.NewRollbackService(client)
report, rollbackErr := service.RollbackStoredResources(ctx, managedResources)
if rollbackErr != nil {
_ = store.ImportBatches().UpdateStatus(ctx, batch.ID, provision.BatchStatusFailed, batch.AccessStatus)
return report, rollbackErr
}
if err := store.ImportBatches().UpdateStatus(ctx, batch.ID, provision.BatchStatusRolledBack, batch.AccessStatus); err != nil {
return report, fmt.Errorf("rollback resources succeeded but update batch status: %w", err)
}
return report, nil
},
RollbackBatch: func(ctx context.Context, req RollbackBatchRequest) (provision.RollbackReport, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.RollbackReport{}, err
}
defer store.Close()
batch, err := store.ImportBatches().GetByID(ctx, req.BatchID)
if err != nil {
return provision.RollbackReport{}, err
}
hostRow, err := store.Hosts().GetByID(ctx, batch.HostID)
if err != nil {
return provision.RollbackReport{}, err
}
client, err := newSub2APIClient(hostRow.BaseURL, authFromStoredHost(hostRow))
if err != nil {
return provision.RollbackReport{}, err
}
managedResources, err := store.ManagedResources().GetByBatchID(ctx, batch.ID)
if err != nil {
return provision.RollbackReport{}, err
}
service := provision.NewRollbackService(client)
report, rollbackErr := service.RollbackStoredResources(ctx, managedResources)
if rollbackErr != nil {
_ = store.ImportBatches().UpdateStatus(ctx, batch.ID, provision.BatchStatusFailed, batch.AccessStatus)
return report, rollbackErr
}
if err := store.ImportBatches().UpdateStatus(ctx, batch.ID, provision.BatchStatusRolledBack, batch.AccessStatus); err != nil {
return report, fmt.Errorf("rollback resources succeeded but update batch status: %w", err)
}
return report, nil
},
ReconcileProvider: func(ctx context.Context, req ReconcileProviderRequest) (reconcile.Result, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return reconcile.Result{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return reconcile.Result{}, err
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return reconcile.Result{}, err
}
defer store.Close()
hostRow, client, err := resolveManagedHost(ctx, store, req.HostID, req.HostBaseURL, createHostAuthFromLegacyFields(req.HostAPIKey, req.HostBearerToken))
if err != nil {
return reconcile.Result{}, err
}
service := reconcile.NewService(store, client)
return service.Reconcile(ctx, reconcile.Request{HostID: hostRow.HostID, HostBaseURL: hostRow.BaseURL, AccessProbeAPIKey: req.AccessAPIKey, Pack: loadedPack, Provider: providerManifest})
},
CreateHost: func(ctx context.Context, req CreateHostRequest) (HostInfo, error) {
if strings.TrimSpace(req.BaseURL) == "" {
return HostInfo{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "base_url is required"}
}
client, err := newSub2APIClient(req.BaseURL, req.Auth)
if err != nil {
return HostInfo{}, err
}
hostVersion, capabilities, err := probeHostSnapshot(ctx, client)
if err != nil {
return HostInfo{}, err
}
capabilityJSON, err := json.Marshal(capabilities)
if err != nil {
return HostInfo{}, fmt.Errorf("marshal capabilities: %w", err)
}
name := strings.TrimSpace(req.Name)
if name == "" {
name = req.BaseURL
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return HostInfo{}, err
}
defer store.Close()
if existing, err := store.Hosts().GetByBaseURL(ctx, req.BaseURL); err == nil && existing.HostID != name {
return HostInfo{}, &httpError{StatusCode: http.StatusConflict, Code: "host_conflict", Message: fmt.Sprintf("base_url %s already registered as host_id %s", req.BaseURL, existing.HostID)}
}
hostRecord := sqlite.Host{HostID: name, BaseURL: req.BaseURL, HostVersion: hostVersion, CapabilityProbeJSON: string(capabilityJSON), AuthType: req.Auth.Type, AuthToken: req.Auth.Token}
if _, err := store.Hosts().GetByHostID(ctx, name); err == nil {
if err := store.Hosts().UpdateConnectionByHostID(ctx, name, req.BaseURL, hostVersion, string(capabilityJSON), req.Auth.Type, req.Auth.Token); err != nil {
return HostInfo{}, fmt.Errorf("update host: %w", err)
}
} else {
if _, err := store.Hosts().Create(ctx, hostRecord); err != nil {
return HostInfo{}, fmt.Errorf("save host: %w", err)
}
}
stored, err := store.Hosts().GetByHostID(ctx, name)
if err != nil {
return HostInfo{}, err
}
return hostRecordToInfo(stored), nil
},
ProbeHost: func(ctx context.Context, req ProbeHostRequest) (HostInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return HostInfo{}, err
}
defer store.Close()
hostRow, err := store.Hosts().GetByHostID(ctx, req.HostID)
if err != nil {
return HostInfo{}, err
}
client, err := newSub2APIClient(hostRow.BaseURL, authFromStoredHost(hostRow))
if err != nil {
return HostInfo{}, err
}
hostVersion, capabilities, err := probeHostSnapshot(ctx, client)
if err != nil {
return HostInfo{}, err
}
capabilityJSON, err := json.Marshal(capabilities)
if err != nil {
return HostInfo{}, fmt.Errorf("marshal capabilities: %w", err)
}
if err := store.Hosts().UpdateProbeByHostID(ctx, req.HostID, hostVersion, string(capabilityJSON)); err != nil {
return HostInfo{}, err
}
return HostInfo{HostID: hostRow.HostID, BaseURL: hostRow.BaseURL, HostVersion: hostVersion, Status: hostSupportStatus(capabilities), Capabilities: &capabilities}, nil
},
ListHosts: func(ctx context.Context) ([]HostInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return nil, err
}
defer store.Close()
hosts, err := store.Hosts().ListAll(ctx)
if err != nil {
return nil, err
}
result := make([]HostInfo, 0, len(hosts))
for _, host := range hosts {
result = append(result, hostRecordToInfo(host))
}
return result, nil
},
GetHost: func(ctx context.Context, hostID string) (HostInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return HostInfo{}, err
}
defer store.Close()
host, err := store.Hosts().GetByHostID(ctx, hostID)
if err != nil {
return HostInfo{}, err
}
return hostRecordToInfo(host), nil
},
DeleteHost: func(ctx context.Context, hostID string) error {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return err
}
defer store.Close()
if err := store.Hosts().DeleteByHostID(ctx, hostID); err != nil {
return classifyError(err)
}
return nil
},
ListPacks: func(ctx context.Context) ([]PackInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return nil, err
}
defer store.Close()
packs, err := store.Packs().ListAll(ctx)
if err != nil {
return nil, err
}
result := make([]PackInfo, 0, len(packs))
for _, p := range packs {
result = append(result, packRecordToInfo(p))
}
return result, nil
},
GetPack: func(ctx context.Context, packID string) (PackInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return PackInfo{}, err
}
defer store.Close()
pack, err := store.Packs().GetByPackID(ctx, packID)
if err != nil {
return PackInfo{}, err
}
return packRecordToInfo(pack), nil
},
ListPackProviders: func(ctx context.Context, packID string) ([]PackProviderInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return nil, err
}
defer store.Close()
packRow, err := store.Packs().GetByPackID(ctx, packID)
if err != nil {
return nil, err
}
providers, err := store.Providers().ListByPackID(ctx, packRow.ID)
if err != nil {
return nil, err
}
result := make([]PackProviderInfo, 0, len(providers))
for _, p := range providers {
hostOverlays := 0
baseURL := ""
smokeTestModel := ""
supportedModels := []string{}
if strings.TrimSpace(p.ManifestJSON) != "" {
var providerManifest pack.ProviderManifest
if err := json.Unmarshal([]byte(p.ManifestJSON), &providerManifest); err != nil {
return nil, fmt.Errorf("decode stored provider manifest: %w", err)
}
hostOverlays = len(providerManifest.HostOverlays)
baseURL = strings.TrimSpace(providerManifest.BaseURL)
smokeTestModel = strings.TrimSpace(providerManifest.SmokeTestModel)
supportedModels = append([]string(nil), providerManifest.DefaultModels...)
}
result = append(result, PackProviderInfo{
ProviderID: p.ProviderID,
DisplayName: p.DisplayName,
Platform: p.Platform,
HostOverlays: hostOverlays,
BaseURL: baseURL,
SmokeTestModel: smokeTestModel,
SupportedModels: supportedModels,
})
}
return result, nil
},
AssignAccessSubscriptions: func(ctx context.Context, req AssignAccessSubscriptionsRequest) (AssignAccessSubscriptionsResult, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return AssignAccessSubscriptionsResult{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return AssignAccessSubscriptionsResult{}, err
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return AssignAccessSubscriptionsResult{}, err
}
defer store.Close()
hostRow, client, err := resolveManagedHost(ctx, store, req.HostID, req.HostBaseURL, createHostAuthFromLegacyFields(req.HostAPIKey, req.HostBearerToken))
if err != nil {
return AssignAccessSubscriptionsResult{}, err
}
packRow, err := store.Packs().GetByPackID(ctx, loadedPack.Manifest.PackID)
if err != nil {
return AssignAccessSubscriptionsResult{}, err
}
providerRow, err := store.Providers().GetByPackIDAndProviderID(ctx, packRow.ID, providerManifest.ProviderID)
if err != nil {
return AssignAccessSubscriptionsResult{}, fmt.Errorf("provider %q not found in pack %q", req.ProviderID, loadedPack.Manifest.PackID)
}
batch, err := store.ImportBatches().GetLatestByProviderIDAndHostID(ctx, providerRow.ID, hostRow.ID)
if err != nil {
return AssignAccessSubscriptionsResult{}, fmt.Errorf("find batch for provider on host: %w", err)
}
resources, err := store.ManagedResources().GetByBatchID(ctx, batch.ID)
if err != nil {
return AssignAccessSubscriptionsResult{}, err
}
groupID := ""
for _, r := range resources {
if r.ResourceType == "group" {
groupID = r.HostResourceID
break
}
}
if groupID == "" {
return AssignAccessSubscriptionsResult{}, fmt.Errorf("no group found for provider batch")
}
subscriptions := make([]access.SubscriptionTarget, 0, len(req.SubscriptionUsers))
for _, userID := range req.SubscriptionUsers {
subscriptions = append(subscriptions, access.SubscriptionTarget{UserID: userID, DurationDays: req.SubscriptionDays})
}
accessSvc := access.NewService(client)
gwResult, err := accessSvc.Close(ctx, access.ClosureRequest{Mode: access.ModeSubscription, ProbeAPIKey: req.AccessAPIKey, Subscriptions: subscriptions, GroupID: groupID, ExpectedModel: providerManifest.SmokeTestModel, Prompt: "ping", MaxTokens: 8})
if err != nil {
return AssignAccessSubscriptionsResult{}, err
}
accessStatus := deriveAccessStatus(gwResult)
accessPayload, _ := json.Marshal(provision.BuildAccessClosureDetails(provision.AccessRequest{
Mode: provision.AccessModeSubscription,
ProbeAPIKey: req.AccessAPIKey,
Subscriptions: subscriptionTargets(req.SubscriptionUsers, req.SubscriptionDays),
}, gwResult))
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batch.ID, ClosureType: access.ModeSubscription, Status: accessStatus, DetailsJSON: string(accessPayload)}); err != nil {
return AssignAccessSubscriptionsResult{}, fmt.Errorf("record access closure: %w", err)
}
if err := store.ImportBatches().UpdateStatus(ctx, batch.ID, batch.BatchStatus, accessStatus); err != nil {
return AssignAccessSubscriptionsResult{}, fmt.Errorf("update batch access status: %w", err)
}
return AssignAccessSubscriptionsResult{ProviderID: req.ProviderID, Assigned: len(req.SubscriptionUsers), AccessStatus: accessStatus}, nil
},
AccessPreview: func(ctx context.Context, req AccessPreviewRequest) (AccessPreviewResult, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return AccessPreviewResult{}, err
}
defer store.Close()
providers, err := resolveProvidersForQuery(ctx, store, ProviderQueryRequest{ProviderID: req.ProviderID, PackID: req.PackID, HostID: req.HostID})
if err != nil {
return AccessPreviewResult{}, err
}
if len(providers) == 0 {
return AccessPreviewResult{}, fmt.Errorf("provider %q not found", req.ProviderID)
}
if len(providers) > 1 {
return AccessPreviewResult{}, fmt.Errorf("provider %q exists in multiple packs; pack_id is required", req.ProviderID)
}
providerRow := providers[0]
latestStatus, err := resolveLatestAccessStatus(ctx, store, providerRow, req.HostID)
if err != nil {
return AccessPreviewResult{}, fmt.Errorf("find batch for provider: %w", err)
}
available := accessStatusSupportsMode(latestStatus, req.Mode)
message := fmt.Sprintf("latest access status: %s", latestStatus)
if !available {
message = fmt.Sprintf("access status %s does not satisfy mode %s", latestStatus, req.Mode)
}
return AccessPreviewResult{ProviderID: req.ProviderID, Mode: req.Mode, Available: available, Message: message}, nil
},
}
}
// findProvider returns the manifest of the provider in loaded whose ProviderID
// equals providerID (surrounding whitespace in providerID is ignored). When no
// provider matches, it returns an error naming the requested provider and the
// pack.
func findProvider(loaded pack.LoadedPack, providerID string) (pack.ProviderManifest, error) {
	// Trim once instead of on every iteration.
	wanted := strings.TrimSpace(providerID)
	for _, provider := range loaded.Providers {
		if provider.ProviderID == wanted {
			return provider, nil
		}
	}
	return pack.ProviderManifest{}, fmt.Errorf("provider %q not found in pack %q", providerID, loaded.Manifest.PackID)
}
// resolveProvidersForQuery maps a provider query onto the stored provider rows
// it names. A blank provider_id is rejected.
//
// When pack_id is set, the lookup is narrowed to that single pack and any
// lookup error is returned unchanged. Otherwise every stored row sharing the
// provider_id is returned, which may cover several packs; callers decide how
// to treat multiple matches.
func resolveProvidersForQuery(ctx context.Context, store *sqlite.DB, req ProviderQueryRequest) ([]sqlite.Provider, error) {
	if store == nil {
		return nil, fmt.Errorf("store is required")
	}
	wantedProvider := strings.TrimSpace(req.ProviderID)
	if wantedProvider == "" {
		return nil, fmt.Errorf("provider_id is required")
	}
	scopePack := strings.TrimSpace(req.PackID)
	if scopePack == "" {
		// No pack scope: return every pack's row for this provider ID.
		return store.Providers().ListByProviderID(ctx, wantedProvider)
	}
	packRecord, err := store.Packs().GetByPackID(ctx, scopePack)
	if err != nil {
		return nil, err
	}
	match, err := store.Providers().GetByPackIDAndProviderID(ctx, packRecord.ID, wantedProvider)
	if err != nil {
		return nil, err
	}
	return []sqlite.Provider{match}, nil
}
// resolveLatestAccessStatus returns the access status aggregated from the
// latest per-mode access statuses of providerRow's import batches.
//
// With a non-blank hostID, only batches on that registered host are used, and
// host or batch lookup errors are returned unchanged. With a blank hostID, all
// of the provider's batches are used. In that case there must be at least one
// batch and all of them must share a single host; otherwise an error asks the
// caller to supply host_id.
func resolveLatestAccessStatus(ctx context.Context, store *sqlite.DB, providerRow sqlite.Provider, hostID string) (string, error) {
	if store == nil {
		return "", fmt.Errorf("store is required")
	}
	if strings.TrimSpace(hostID) == "" {
		all, err := store.ImportBatches().ListByProviderID(ctx, providerRow.ID)
		if err != nil {
			return "", err
		}
		if len(all) == 0 {
			return "", fmt.Errorf("latest import batch not found for provider")
		}
		// Without an explicit host the answer is only well-defined when every
		// batch lives on the same host.
		for i := 1; i < len(all); i++ {
			if all[i].HostID != all[0].HostID {
				return "", fmt.Errorf("provider exists on multiple hosts; host_id is required")
			}
		}
		statuses, err := provision.LatestModeAccessStatuses(ctx, store, all)
		if err != nil {
			return "", err
		}
		return provision.AggregateAccessStatus(statuses), nil
	}
	host, err := store.Hosts().GetByHostID(ctx, hostID)
	if err != nil {
		return "", err
	}
	scoped, err := store.ImportBatches().ListByProviderIDAndHostID(ctx, providerRow.ID, host.ID)
	if err != nil {
		return "", err
	}
	statuses, err := provision.LatestModeAccessStatuses(ctx, store, scoped)
	if err != nil {
		return "", err
	}
	return provision.AggregateAccessStatus(statuses), nil
}
// resolveManagedHost finds a registered host and builds an API client for it
// from the credentials stored with that host.
//
// hostID (trimmed) takes precedence. When it is set, the host is loaded by ID,
// and a non-blank baseURL must equal the registered base URL. When only
// baseURL is set, the host is looked up by that base URL. When neither is set,
// "host_id is required" is returned.
//
// NOTE(review): the auth argument is never consulted; the client always uses
// the stored host credentials.
func resolveManagedHost(ctx context.Context, store *sqlite.DB, hostID, baseURL string, auth CreateHostAuth) (sqlite.Host, *sub2api.Client, error) {
	if store == nil {
		return sqlite.Host{}, nil, fmt.Errorf("store is required")
	}
	hostID = strings.TrimSpace(hostID)
	baseURL = strings.TrimSpace(baseURL)

	var hostRow sqlite.Host
	switch {
	case hostID != "":
		byID, err := store.Hosts().GetByHostID(ctx, hostID)
		if err != nil {
			return sqlite.Host{}, nil, err
		}
		// A runtime base URL, if supplied, must agree with the registration.
		if baseURL != "" && baseURL != strings.TrimSpace(byID.BaseURL) {
			return sqlite.Host{}, nil, fmt.Errorf("host %q base_url mismatch: registered=%s runtime=%s", hostID, byID.BaseURL, baseURL)
		}
		hostRow = byID
	case baseURL != "":
		byURL, err := store.Hosts().GetByBaseURL(ctx, baseURL)
		if err != nil {
			return sqlite.Host{}, nil, fmt.Errorf("host_id is required for unregistered host_base_url %q: %w", baseURL, err)
		}
		hostRow = byURL
	default:
		return sqlite.Host{}, nil, fmt.Errorf("host_id is required")
	}

	client, err := newSub2APIClient(hostRow.BaseURL, authFromStoredHost(hostRow))
	if err != nil {
		return sqlite.Host{}, nil, err
	}
	return hostRow, client, nil
}
// authFromStoredHost rebuilds client credentials from a stored host row. The
// token is whitespace-trimmed. A blank stored auth type falls back to "apikey".
func authFromStoredHost(host sqlite.Host) CreateHostAuth {
	creds := CreateHostAuth{Type: "apikey", Token: strings.TrimSpace(host.AuthToken)}
	if storedType := strings.TrimSpace(host.AuthType); storedType != "" {
		creds.Type = storedType
	}
	return creds
}
// createHostAuthFromLegacyFields converts the legacy pair of credential fields
// into a CreateHostAuth. A non-blank bearer token wins. Otherwise the trimmed
// API key is used, even if it is blank.
func createHostAuthFromLegacyFields(apiKey, bearerToken string) CreateHostAuth {
	bearer := strings.TrimSpace(bearerToken)
	if bearer == "" {
		return CreateHostAuth{Type: "apikey", Token: strings.TrimSpace(apiKey)}
	}
	return CreateHostAuth{Type: "bearer", Token: bearer}
}
// newSub2APIClient builds a Sub2API client for baseURL, authenticated as
// described by auth.
//
// The token is trimmed and must be non-blank. The auth type is trimmed and
// matched case-insensitively:
//   - "bearer" selects bearer-token auth;
//   - "apikey", "api_key", or an empty type selects API-key auth.
//
// Any other type, like a missing token, yields an *httpError with status 400.
func newSub2APIClient(baseURL string, auth CreateHostAuth) (*sub2api.Client, error) {
	token := strings.TrimSpace(auth.Token)
	if token == "" {
		return nil, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "auth.token is required"}
	}
	kind := strings.ToLower(strings.TrimSpace(auth.Type))
	if kind == "bearer" {
		return sub2api.NewClient(baseURL, sub2api.WithBearerToken(token))
	}
	if kind == "apikey" || kind == "api_key" || kind == "" {
		return sub2api.NewClient(baseURL, sub2api.WithAPIKey(token))
	}
	// Report the caller's original (untrimmed, original-case) value.
	return nil, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: fmt.Sprintf("unsupported auth type %q (supported: bearer, apikey)", auth.Type)}
}
// probeHostSnapshot asks the host for its version and then for its capability
// set, and returns both. The first failing call aborts the probe. Its error is
// wrapped with the name of the step that failed.
func probeHostSnapshot(ctx context.Context, client *sub2api.Client) (string, sub2api.HostCapabilities, error) {
	version, versionErr := client.GetHostVersion(ctx)
	if versionErr != nil {
		return "", sub2api.HostCapabilities{}, fmt.Errorf("get host version: %w", versionErr)
	}
	caps, capsErr := client.ProbeCapabilities(ctx)
	if capsErr != nil {
		return "", sub2api.HostCapabilities{}, fmt.Errorf("probe host capabilities: %w", capsErr)
	}
	return version, caps, nil
}
// hostSupportStatus returns "supported" only when the host reports all seven
// capabilities checked below. If any of them is missing, it returns
// "unsupported".
func hostSupportStatus(capabilities sub2api.HostCapabilities) string {
	everyCapability := capabilities.Groups &&
		capabilities.Channels &&
		capabilities.Plans &&
		capabilities.Accounts &&
		capabilities.AccountTest &&
		capabilities.AccountModels &&
		capabilities.Subscriptions
	if everyCapability {
		return "supported"
	}
	return "unsupported"
}
// accessStatusSupportsMode reports whether an aggregated access status
// satisfies the requested access mode. Both arguments are whitespace-trimmed.
//
//   - "" or "any": any non-empty status other than broken qualifies.
//   - subscription mode: subscription-ready or fully-ready.
//   - self-service mode: self-service-ready or fully-ready.
//   - any other mode: only fully-ready.
func accessStatusSupportsMode(status, mode string) bool {
	current := strings.TrimSpace(status)
	fullyReady := current == provision.AccessStatusFullyReady
	switch strings.TrimSpace(mode) {
	case "", "any":
		return current != "" && current != provision.AccessStatusBroken
	case provision.AccessModeSubscription:
		return fullyReady || current == provision.AccessStatusSubscriptionReady
	case provision.AccessModeSelfService:
		return fullyReady || current == provision.AccessStatusSelfServiceReady
	default:
		return fullyReady
	}
}
// hostRecordToInfo converts a stored host row into its API representation.
//
// Capabilities and Status are filled in only when the stored capability probe
// is present (not blank and not "{}") and decodes cleanly. A probe that fails
// to decode is ignored rather than reported. NOTE(review): presumably this is
// so that bad stored data does not break host listing.
func hostRecordToInfo(host sqlite.Host) HostInfo {
	info := HostInfo{
		HostID:      host.HostID,
		BaseURL:     host.BaseURL,
		HostVersion: host.HostVersion,
		AuthType:    strings.TrimSpace(host.AuthType),
	}
	probe := host.CapabilityProbeJSON
	if strings.TrimSpace(probe) == "" || probe == "{}" {
		return info
	}
	var caps sub2api.HostCapabilities
	if err := json.Unmarshal([]byte(probe), &caps); err != nil {
		return info
	}
	info.Capabilities = &caps
	info.Status = hostSupportStatus(caps)
	return info
}
// packRecordToInfo copies a stored pack row's identity, vendor, target-host,
// and host-version-range fields into the PackInfo API representation.
func packRecordToInfo(row sqlite.Pack) PackInfo {
	// The parameter is named row rather than pack so it does not shadow the
	// imported pack package inside this function.
	return PackInfo{
		PackID:         row.PackID,
		Version:        row.Version,
		Vendor:         row.Vendor,
		TargetHost:     row.TargetHost,
		MinHostVersion: row.MinHostVersion,
		MaxHostVersion: row.MaxHostVersion,
	}
}
// providerModelOwner records which provider claims which models within a pack,
// and where that claim was read from. It is used for model-conflict detection.
type providerModelOwner struct {
	ProviderID, DisplayName string
	// Models is the normalized, de-duplicated model list the provider claims.
	Models []string
	// Source is "repo", "store", or "draft", depending on where the owner was
	// loaded from. DraftID is only set for draft owners.
	Source, DraftID string
}
// validateProviderDraftModelConflicts rejects a provider draft whose models
// overlap with models already claimed by a different provider in the same
// pack.
//
// The draft's models are its supported models, or its smoke-test model when no
// supported models are listed. The check is skipped when packID or providerID
// is blank or when the draft has no models.
//
// Owners come from collectProviderModelOwners. An owner is ignored when its
// provider ID is blank, when it equals providerID, or when its non-blank draft
// ID equals draftID. Each distinct overlap is reported once as
// "model[/model...] -> provider[source]" in a single 409
// provider_model_conflict *httpError.
func validateProviderDraftModelConflicts(ctx context.Context, store *sqlite.DB, packID, providerID, draftID string, supportedModels []string, smokeTestModel string) error {
	packID = strings.TrimSpace(packID)
	providerID = strings.TrimSpace(providerID)
	draftID = strings.TrimSpace(draftID)
	wanted := normalizeProviderModels(supportedModels, smokeTestModel)
	if packID == "" || providerID == "" || len(wanted) == 0 {
		return nil
	}
	owners, err := collectProviderModelOwners(ctx, store, packID)
	if err != nil {
		return err
	}

	var conflicts []string
	reported := map[string]bool{}
	for _, owner := range owners {
		sameOwner := owner.ProviderID == "" ||
			owner.ProviderID == providerID ||
			(owner.DraftID != "" && owner.DraftID == draftID)
		if sameOwner {
			continue
		}
		overlap := intersectNormalizedModels(wanted, owner.Models)
		if len(overlap) == 0 {
			continue
		}
		label := fmt.Sprintf("%s -> %s[%s]", strings.Join(overlap, "/"), owner.ProviderID, owner.Source)
		if reported[label] {
			continue
		}
		reported[label] = true
		conflicts = append(conflicts, label)
	}

	if len(conflicts) == 0 {
		return nil
	}
	return &httpError{
		StatusCode: http.StatusConflict,
		Code:       "provider_model_conflict",
		Message:    fmt.Sprintf("provider model conflict in pack %q: %s", packID, strings.Join(conflicts, "; ")),
	}
}
// collectProviderModelOwners gathers every provider that claims models in
// packID.
//
// Pack providers are read from the repository checkout first. The stored pack
// is consulted only when the repository yields no owners. Saved provider
// drafts for the pack (Source "draft") are always appended after them.
func collectProviderModelOwners(ctx context.Context, store *sqlite.DB, packID string) ([]providerModelOwner, error) {
	packOwners, err := loadRepoPackModelOwners(packID)
	if err != nil {
		return nil, err
	}
	if len(packOwners) == 0 {
		packOwners, err = loadStoredPackModelOwners(ctx, store, packID)
		if err != nil {
			return nil, err
		}
	}
	drafts, err := store.ProviderDrafts().List(ctx, sqlite.ListProviderDraftsFilter{PackID: packID})
	if err != nil {
		return nil, err
	}

	owners := make([]providerModelOwner, 0, len(packOwners)+len(drafts))
	owners = append(owners, packOwners...)
	for _, draft := range drafts {
		owners = append(owners, providerModelOwner{
			ProviderID:  strings.TrimSpace(draft.ProviderID),
			DisplayName: strings.TrimSpace(draft.DisplayName),
			Models:      normalizeProviderModels(decodeStringList(draft.SupportedModelsJSON), draft.SmokeTestModel),
			Source:      "draft",
			DraftID:     strings.TrimSpace(draft.DraftID),
		})
	}
	return owners, nil
}
// loadRepoPackModelOwners reads <repo_root>/packs/<packID> from the repository
// checkout configured in the startup environment. It returns one owner
// (Source "repo") per provider in that pack.
//
// It returns (nil, nil) in three cases:
//   - no repository root is configured;
//   - packID (trimmed) is not a single plain directory name;
//   - the pack directory does not exist.
//
// Configuration, stat, and pack-load failures are returned as errors.
func loadRepoPackModelOwners(packID string) ([]providerModelOwner, error) {
	startupCfg, err := config.LoadStartupFromEnv()
	if err != nil {
		return nil, err
	}
	repoRoot := strings.TrimSpace(startupCfg.Repository.RepoRoot)
	if repoRoot == "" {
		return nil, nil
	}
	packName := strings.TrimSpace(packID)
	// Only a single plain path element can name a directory directly under
	// packs/. A blank value would resolve to packs/ itself, while ".", "..",
	// or a value containing a separator could point outside it. None of these
	// can be a repository pack, so treat them as "not present".
	if packName == "" || packName == "." || packName == ".." || strings.ContainsAny(packName, `/\`) {
		return nil, nil
	}
	packDir := filepath.Join(repoRoot, "packs", packName)
	if _, err := os.Stat(packDir); err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil, nil
		}
		return nil, fmt.Errorf("stat repo pack dir %q: %w", packDir, err)
	}
	loadedPack, err := pack.LoadPath(packDir)
	if err != nil {
		return nil, fmt.Errorf("load repo pack %q for conflict check: %w", packID, err)
	}
	owners := make([]providerModelOwner, 0, len(loadedPack.Providers))
	for _, provider := range loadedPack.Providers {
		owners = append(owners, providerModelOwner{
			ProviderID:  strings.TrimSpace(provider.ProviderID),
			DisplayName: strings.TrimSpace(provider.DisplayName),
			Models:      normalizeProviderModels(provider.DefaultModels, provider.SmokeTestModel),
			Source:      "repo",
		})
	}
	return owners, nil
}
// loadStoredPackModelOwners returns one owner (Source "store") per provider
// recorded for packID in the local store. A pack missing from the store
// (sql.ErrNoRows) yields (nil, nil). Other lookup errors are returned as-is.
func loadStoredPackModelOwners(ctx context.Context, store *sqlite.DB, packID string) ([]providerModelOwner, error) {
	packRow, err := store.Packs().GetByPackID(ctx, packID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	}
	rows, err := store.Providers().ListByPackID(ctx, packRow.ID)
	if err != nil {
		return nil, err
	}
	owners := make([]providerModelOwner, len(rows))
	for i, row := range rows {
		owners[i] = providerModelOwner{
			ProviderID:  strings.TrimSpace(row.ProviderID),
			DisplayName: strings.TrimSpace(row.DisplayName),
			Models:      normalizeProviderModels(decodeStringList(row.DefaultModelsJSON), row.SmokeTestModel),
			Source:      "store",
		}
	}
	return owners, nil
}
// normalizeProviderDraftPayload derives the manifest for a provider draft
// request. It returns three values:
//   - the manifest as JSON text;
//   - the same manifest as a decoded value;
//   - the request's supported models, trimmed and with blanks removed.
//
// An explicit req.Manifest must be valid JSON and is used verbatim (trimmed).
// Without one, a manifest is synthesized from the request's individual fields.
func normalizeProviderDraftPayload(req CreateProviderDraftRequest) (string, any, []string, error) {
	models := normalizeStringList(req.SupportedModels)
	if len(req.Manifest) == 0 {
		synthesized := map[string]any{
			"provider_id":      strings.TrimSpace(req.ProviderID),
			"display_name":     strings.TrimSpace(req.DisplayName),
			"platform":         strings.TrimSpace(req.Platform),
			"base_url":         strings.TrimSpace(req.BaseURL),
			"smoke_test_model": strings.TrimSpace(req.SmokeTestModel),
			"supported_models": models,
		}
		encoded, err := json.Marshal(synthesized)
		if err != nil {
			return "", nil, nil, fmt.Errorf("marshal manifest: %w", err)
		}
		return string(encoded), synthesized, models, nil
	}

	var decoded any
	if err := json.Unmarshal(req.Manifest, &decoded); err != nil {
		return "", nil, nil, fmt.Errorf("decode manifest: %w", err)
	}
	text := strings.TrimSpace(string(req.Manifest))
	if text == "" {
		text = "{}"
	}
	return text, decoded, models, nil
}
// providerDraftRecordToInfo builds the API view of a provider draft from its
// stored row, an already-decoded manifest, and a model list. A nil manifest is
// reported as an empty object. The model list is copied so the result does not
// alias the caller's slice. The returned error is always nil.
func providerDraftRecordToInfo(row sqlite.ProviderDraft, manifestValue any, supportedModels []string) (ProviderDraftInfo, error) {
	info := ProviderDraftInfo{
		DraftID:         row.DraftID,
		PackID:          row.PackID,
		ProviderID:      row.ProviderID,
		DisplayName:     row.DisplayName,
		Platform:        row.Platform,
		BaseURL:         row.BaseURL,
		SmokeTestModel:  row.SmokeTestModel,
		SupportedModels: append([]string(nil), supportedModels...),
		Manifest:        manifestValue,
		SourceHostID:    row.SourceHostID,
		Notes:           row.Notes,
		CreatedAt:       row.CreatedAt,
		UpdatedAt:       row.UpdatedAt,
	}
	if info.Manifest == nil {
		info.Manifest = map[string]any{}
	}
	return info, nil
}
// providerDraftRecordToInfoFromStored decodes the JSON columns of a stored
// provider draft (its manifest and supported models) and builds the draft's
// API view. Blank columns are treated as absent. Malformed JSON in either
// column is returned as an error.
func providerDraftRecordToInfoFromStored(row sqlite.ProviderDraft) (ProviderDraftInfo, error) {
	var manifest any
	if strings.TrimSpace(row.ManifestJSON) != "" {
		if err := json.Unmarshal([]byte(row.ManifestJSON), &manifest); err != nil {
			return ProviderDraftInfo{}, fmt.Errorf("decode stored provider draft manifest: %w", err)
		}
	}

	models := []string{}
	if strings.TrimSpace(row.SupportedModelsJSON) != "" {
		if err := json.Unmarshal([]byte(row.SupportedModelsJSON), &models); err != nil {
			return ProviderDraftInfo{}, fmt.Errorf("decode stored provider draft supported_models: %w", err)
		}
	}

	return providerDraftRecordToInfo(row, manifest, models)
}
// buildPublishedProviderManifest builds a pack.ProviderManifest from a stored
// provider draft. It works in three passes:
//
//  1. Seed a manifest from the draft's individual columns plus fixed template
//     defaults: "apikey" account type, group/channel/plan templates, and
//     multi-key/strict/partial import options. When either the smoke-test
//     model or the default-model list is missing, fill it from the other.
//     Then map every default model to itself in the channel model mapping.
//  2. If the draft has a non-blank ManifestJSON, decode it on top of the
//     seeded manifest. Fields present in the JSON override the seeded values.
//     NOTE: per encoding/json semantics, decoding a JSON object into the
//     already non-nil ModelMapping map adds or overwrites keys instead of
//     replacing the map. The seeded identity mappings therefore survive,
//     unless the JSON explicitly sets the mapping field to null.
//  3. Trim the identity fields and re-apply the account-type and
//     smoke-test/default-model fallbacks. Then add identity mappings for any
//     default or smoke-test model still missing from the mapping. Existing
//     mapping entries are never overwritten.
//
// The only error returned is a failure to decode ManifestJSON.
func buildPublishedProviderManifest(row sqlite.ProviderDraft) (pack.ProviderManifest, error) {
	// Pass 1: seed from draft columns and fixed template defaults.
	manifest := pack.ProviderManifest{
		ProviderID:     strings.TrimSpace(row.ProviderID),
		DisplayName:    strings.TrimSpace(row.DisplayName),
		BaseURL:        strings.TrimSpace(row.BaseURL),
		Platform:       strings.TrimSpace(row.Platform),
		AccountType:    "apikey",
		SmokeTestModel: strings.TrimSpace(row.SmokeTestModel),
		DefaultModels:  normalizeStringList(decodeStringList(row.SupportedModelsJSON)),
		GroupTemplate: pack.GroupTemplate{
			Name:           strings.TrimSpace(row.DisplayName) + " 默认分组",
			RateMultiplier: 1.0,
		},
		ChannelTemplate: pack.ChannelTemplate{
			Name:         strings.TrimSpace(row.DisplayName) + " 默认渠道",
			ModelMapping: map[string]string{},
		},
		PlanTemplate: pack.PlanTemplate{
			Name:         strings.TrimSpace(row.DisplayName) + " 默认套餐",
			Price:        19.9,
			ValidityDays: 30,
			ValidityUnit: "day",
		},
		Import: pack.ImportOptions{
			SupportsMultiKey: true,
			SupportsStrict:   true,
			SupportsPartial:  true,
		},
	}
	// Fill whichever of smoke-test model / default models is missing from the other.
	if strings.TrimSpace(manifest.SmokeTestModel) == "" && len(manifest.DefaultModels) > 0 {
		manifest.SmokeTestModel = manifest.DefaultModels[0]
	}
	if len(manifest.DefaultModels) == 0 && strings.TrimSpace(manifest.SmokeTestModel) != "" {
		manifest.DefaultModels = []string{manifest.SmokeTestModel}
	}
	// Seed identity mappings; these may be added to by the overlay in pass 2.
	for _, model := range manifest.DefaultModels {
		manifest.ChannelTemplate.ModelMapping[model] = model
	}
	// Pass 2: overlay the stored manifest JSON on top of the seeded values.
	if strings.TrimSpace(row.ManifestJSON) != "" {
		if err := json.Unmarshal([]byte(row.ManifestJSON), &manifest); err != nil {
			return pack.ProviderManifest{}, fmt.Errorf("decode publishable provider manifest: %w", err)
		}
	}
	// Pass 3: normalize whatever the overlay produced.
	manifest.ProviderID = strings.TrimSpace(manifest.ProviderID)
	manifest.DisplayName = strings.TrimSpace(manifest.DisplayName)
	manifest.BaseURL = strings.TrimSpace(manifest.BaseURL)
	manifest.Platform = strings.TrimSpace(manifest.Platform)
	manifest.AccountType = strings.TrimSpace(manifest.AccountType)
	manifest.SmokeTestModel = strings.TrimSpace(manifest.SmokeTestModel)
	if manifest.AccountType == "" {
		manifest.AccountType = "apikey"
	}
	manifest.DefaultModels = normalizeStringList(manifest.DefaultModels)
	if len(manifest.DefaultModels) == 0 && manifest.SmokeTestModel != "" {
		manifest.DefaultModels = []string{manifest.SmokeTestModel}
	}
	if manifest.SmokeTestModel == "" && len(manifest.DefaultModels) > 0 {
		manifest.SmokeTestModel = manifest.DefaultModels[0]
	}
	// The overlay may have nulled the mapping; recreate it before writing.
	if manifest.ChannelTemplate.ModelMapping == nil {
		manifest.ChannelTemplate.ModelMapping = map[string]string{}
	}
	// Add identity mappings only where no mapping exists yet, so explicit
	// mappings from the overlay are preserved.
	for _, model := range manifest.DefaultModels {
		model = strings.TrimSpace(model)
		if model == "" {
			continue
		}
		if _, ok := manifest.ChannelTemplate.ModelMapping[model]; !ok {
			manifest.ChannelTemplate.ModelMapping[model] = model
		}
	}
	// The smoke-test model must also be routable through the channel.
	if manifest.SmokeTestModel != "" {
		if _, ok := manifest.ChannelTemplate.ModelMapping[manifest.SmokeTestModel]; !ok {
			manifest.ChannelTemplate.ModelMapping[manifest.SmokeTestModel] = manifest.SmokeTestModel
		}
	}
	return manifest, nil
}
// encodeStringList serializes values as a JSON array after normalizeStringList
// has trimmed every entry and dropped blank ones. If marshaling fails, the
// literal "[]" is returned so callers always get a valid JSON array string.
func encodeStringList(values []string) string {
	cleaned := normalizeStringList(values)
	if payload, marshalErr := json.Marshal(cleaned); marshalErr == nil {
		return string(payload)
	}
	return "[]"
}
// normalizeProviderModels returns the model names from values after trimming,
// dropping blanks, and removing duplicates. First-seen order is preserved.
// When values yields nothing, the trimmed smokeTestModel (if non-blank) becomes
// the single entry. It returns nil when no model remains at all.
func normalizeProviderModels(values []string, smokeTestModel string) []string {
	candidates := normalizeStringList(values)
	if len(candidates) == 0 {
		fallback := strings.TrimSpace(smokeTestModel)
		if fallback == "" {
			return nil
		}
		candidates = []string{fallback}
	}
	deduped := make([]string, 0, len(candidates))
	present := make(map[string]struct{}, len(candidates))
	for _, name := range candidates {
		if _, dup := present[name]; dup {
			continue
		}
		present[name] = struct{}{}
		deduped = append(deduped, name)
	}
	return deduped
}
// intersectNormalizedModels returns the entries of left that also occur in
// right, in left's order. Duplicates in left are kept. Inputs are compared
// verbatim, so callers are expected to normalize them first. It returns nil
// when either input is empty, and a non-nil (possibly empty) slice otherwise.
func intersectNormalizedModels(left, right []string) []string {
	if len(left) == 0 || len(right) == 0 {
		return nil
	}
	allowed := make(map[string]bool, len(right))
	for _, name := range right {
		allowed[name] = true
	}
	shared := []string{}
	for _, name := range left {
		if allowed[name] {
			shared = append(shared, name)
		}
	}
	return shared
}
// normalizeStringList trims surrounding whitespace from every entry and drops
// entries that are empty after trimming, preserving the original order. The
// result is always non-nil, even when no entries survive.
func normalizeStringList(values []string) []string {
	kept := make([]string, 0, len(values))
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); trimmed != "" {
			kept = append(kept, trimmed)
		}
	}
	return kept
}
// decodeStringList parses raw as a JSON array of strings.
//
// Blank input, input that does not decode as a JSON string array, and the
// JSON literal null all yield an empty, non-nil slice. Callers and any JSON
// re-encoding therefore see [] rather than null. Decoded entries are returned
// exactly as stored; they are not trimmed or filtered.
func decodeStringList(raw string) []string {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return []string{}
	}
	var values []string
	// json.Unmarshal sets a slice to nil when the input is the literal null.
	// Fold that case into the same empty-slice result as a decode error.
	if err := json.Unmarshal([]byte(raw), &values); err != nil || values == nil {
		return []string{}
	}
	return values
}
// deriveAccessStatus maps a gateway access probe result to an access status.
// It returns provision.AccessStatusSubscriptionReady when
// provision.GatewayAccessReady reports gw as ready, and
// provision.AccessStatusBroken otherwise.
func deriveAccessStatus(gw sub2api.GatewayAccessResult) string {
	if !provision.GatewayAccessReady(gw) {
		return provision.AccessStatusBroken
	}
	return provision.AccessStatusSubscriptionReady
}
// subscriptionTargets builds one provision.SubscriptionTarget per entry of
// userIDs, in the same order. Each target's UserID is trimmed of surrounding
// whitespace, and every target shares durationDays. Entries are not filtered,
// so a blank user ID still yields a target (with an empty UserID). The result
// is non-nil even when userIDs is empty.
func subscriptionTargets(userIDs []string, durationDays int) []provision.SubscriptionTarget {
	targets := make([]provision.SubscriptionTarget, len(userIDs))
	for i, rawID := range userIDs {
		targets[i] = provision.SubscriptionTarget{
			UserID:       strings.TrimSpace(rawID),
			DurationDays: durationDays,
		}
	}
	return targets
}