From 2818892255d22c1eb20e3856e4b9088dc0f16c9e Mon Sep 17 00:00:00 2001 From: phamnazage-jpg Date: Thu, 28 May 2026 15:57:34 +0800 Subject: [PATCH] feat(routing): add logical group admin api --- internal/app/http_api.go | 144 ++- internal/app/logical_groups_api.go | 840 ++++++++++++++++++ internal/app/logical_groups_api_test.go | 271 ++++++ internal/store/sqlite/db.go | 76 +- .../store/sqlite/logical_group_models_repo.go | 136 +++ .../sqlite/logical_group_route_models_repo.go | 113 +++ .../store/sqlite/logical_group_routes_repo.go | 238 +++++ internal/store/sqlite/logical_groups_repo.go | 249 ++++++ .../store/sqlite/logical_groups_repo_test.go | 308 +++++++ 9 files changed, 2312 insertions(+), 63 deletions(-) create mode 100644 internal/app/logical_groups_api.go create mode 100644 internal/app/logical_groups_api_test.go create mode 100644 internal/store/sqlite/logical_group_models_repo.go create mode 100644 internal/store/sqlite/logical_group_route_models_repo.go create mode 100644 internal/store/sqlite/logical_group_routes_repo.go create mode 100644 internal/store/sqlite/logical_groups_repo.go create mode 100644 internal/store/sqlite/logical_groups_repo_test.go diff --git a/internal/app/http_api.go b/internal/app/http_api.go index 43ffb67c..b83533b0 100644 --- a/internal/app/http_api.go +++ b/internal/app/http_api.go @@ -26,38 +26,52 @@ import ( ) type ActionSet struct { - CreateBatchImportRun func(context.Context, CreateBatchImportRunRequest) (BatchImportRunCreateResponse, error) - ListBatchImportRuns func(context.Context, ListBatchImportRunsRequest) (ListBatchImportRunsResponse, error) - GetBatchImportRun func(context.Context, string) (batch.RunSummaryProjection, error) - ListBatchImportRunItems func(context.Context, ListBatchImportRunItemsRequest) (ListBatchImportRunItemsResponse, error) - GetBatchImportRunItem func(context.Context, GetBatchImportRunItemRequest) (batch.ItemDetailProjection, error) - CreateProviderDraft func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error) - ListProviderDrafts func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error) - GetProviderDraft func(context.Context, string) (ProviderDraftInfo, error) - UpdateProviderDraft func(context.Context, UpdateProviderDraftRequest) (ProviderDraftInfo, error) - DeleteProviderDraft func(context.Context, string) error - PublishProviderDraft func(context.Context, PublishProviderDraftRequest) (PublishProviderDraftResult, error) - InstallPack func(context.Context, InstallPackRequest) (provision.PackInstallResult, error) - BatchDetail func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error) - GetProviderStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) - GetProviderResources func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) - GetProviderAccessStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) - ListProviderImportBatches func(context.Context, ProviderQueryRequest) ([]ImportBatchInfo, error) - PreviewProvider func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error) - ImportProvider func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error) - RollbackProvider func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) - RollbackBatch func(context.Context, RollbackBatchRequest) (provision.RollbackReport, error) - ReconcileProvider func(context.Context, ReconcileProviderRequest) (reconcile.Result, error) - CreateHost func(context.Context, CreateHostRequest) (HostInfo, error) - ProbeHost func(context.Context, ProbeHostRequest) (HostInfo, error) - ListHosts func(context.Context) ([]HostInfo, error) - GetHost func(context.Context, string) (HostInfo, error) - DeleteHost func(context.Context, string) error - ListPacks func(context.Context) ([]PackInfo, error) - GetPack func(context.Context, string) (PackInfo, error) - ListPackProviders func(context.Context, string) ([]PackProviderInfo, error) - AssignAccessSubscriptions func(context.Context, AssignAccessSubscriptionsRequest) (AssignAccessSubscriptionsResult, error) - AccessPreview func(context.Context, AccessPreviewRequest) (AccessPreviewResult, error) + CreateBatchImportRun func(context.Context, CreateBatchImportRunRequest) (BatchImportRunCreateResponse, error) + ListBatchImportRuns func(context.Context, ListBatchImportRunsRequest) (ListBatchImportRunsResponse, error) + GetBatchImportRun func(context.Context, string) (batch.RunSummaryProjection, error) + ListBatchImportRunItems func(context.Context, ListBatchImportRunItemsRequest) (ListBatchImportRunItemsResponse, error) + GetBatchImportRunItem func(context.Context, GetBatchImportRunItemRequest) (batch.ItemDetailProjection, error) + CreateLogicalGroup func(context.Context, CreateLogicalGroupRequest) (LogicalGroupInfo, error) + ListLogicalGroups func(context.Context) ([]LogicalGroupInfo, error) + GetLogicalGroup func(context.Context, string) (LogicalGroupInfo, error) + UpdateLogicalGroup func(context.Context, UpdateLogicalGroupRequest) (LogicalGroupInfo, error) + DeleteLogicalGroup func(context.Context, string) error + CreateLogicalGroupModel func(context.Context, CreateLogicalGroupModelRequest) (LogicalGroupModelInfo, error) + ListLogicalGroupModels func(context.Context, string) ([]LogicalGroupModelInfo, error) + DeleteLogicalGroupModel func(context.Context, DeleteLogicalGroupModelRequest) error + CreateLogicalGroupRoute func(context.Context, CreateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error) + ListLogicalGroupRoutes func(context.Context, string) ([]LogicalGroupRouteInfo, error) + UpdateLogicalGroupRoute func(context.Context, UpdateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error) + DeleteLogicalGroupRoute func(context.Context, DeleteLogicalGroupRouteRequest) error + CreateLogicalGroupRouteModel func(context.Context, CreateLogicalGroupRouteModelRequest) (LogicalGroupRouteModelInfo, error) + ListLogicalGroupRouteModels func(context.Context, ListLogicalGroupRouteModelsRequest) ([]LogicalGroupRouteModelInfo, error) + CreateProviderDraft func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error) + ListProviderDrafts func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error) + GetProviderDraft func(context.Context, string) (ProviderDraftInfo, error) + UpdateProviderDraft func(context.Context, UpdateProviderDraftRequest) (ProviderDraftInfo, error) + DeleteProviderDraft func(context.Context, string) error + PublishProviderDraft func(context.Context, PublishProviderDraftRequest) (PublishProviderDraftResult, error) + InstallPack func(context.Context, InstallPackRequest) (provision.PackInstallResult, error) + BatchDetail func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error) + GetProviderStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) + GetProviderResources func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) + GetProviderAccessStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) + ListProviderImportBatches func(context.Context, ProviderQueryRequest) ([]ImportBatchInfo, error) + PreviewProvider func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error) + ImportProvider func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error) + RollbackProvider func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) + RollbackBatch func(context.Context, RollbackBatchRequest) (provision.RollbackReport, error) + ReconcileProvider func(context.Context, ReconcileProviderRequest) (reconcile.Result, error) + CreateHost func(context.Context, CreateHostRequest) (HostInfo, error) + ProbeHost func(context.Context, ProbeHostRequest) (HostInfo, error) + ListHosts func(context.Context) ([]HostInfo, error) + GetHost func(context.Context, string) (HostInfo, error) + DeleteHost func(context.Context, string) error + ListPacks func(context.Context) ([]PackInfo, error) + GetPack func(context.Context, string) (PackInfo, error) + ListPackProviders func(context.Context, string) ([]PackProviderInfo, error) + AssignAccessSubscriptions func(context.Context, AssignAccessSubscriptionsRequest) (AssignAccessSubscriptionsResult, error) + AccessPreview func(context.Context, AccessPreviewRequest) (AccessPreviewResult, error) } const maxJSONBodyBytes int64 = 1 << 20 @@ -312,6 +326,48 @@ func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet) http.Ha mux.Handle("GET /api/batch-import/runs/{run_id}/items/{item_id}", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handleGetBatchImportRunItem(w, r, actions.GetBatchImportRunItem) }))) + mux.Handle("POST /api/logical-groups", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleCreateLogicalGroup(w, r, actions.CreateLogicalGroup) + }))) + mux.Handle("GET /api/logical-groups", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleListLogicalGroups(w, r, actions.ListLogicalGroups) + }))) + mux.Handle("GET /api/logical-groups/{groupID}", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleGetLogicalGroup(w, r, actions.GetLogicalGroup) + }))) + mux.Handle("PUT /api/logical-groups/{groupID}", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleUpdateLogicalGroup(w, r, actions.UpdateLogicalGroup) + }))) + mux.Handle("DELETE /api/logical-groups/{groupID}", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleDeleteLogicalGroup(w, r, actions.DeleteLogicalGroup) + }))) + mux.Handle("POST /api/logical-groups/{groupID}/models", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleCreateLogicalGroupModel(w, r, actions.CreateLogicalGroupModel) + }))) + mux.Handle("GET /api/logical-groups/{groupID}/models", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleListLogicalGroupModels(w, r, actions.ListLogicalGroupModels) + }))) + mux.Handle("DELETE /api/logical-groups/{groupID}/models/{model}", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleDeleteLogicalGroupModel(w, r, actions.DeleteLogicalGroupModel) + }))) + mux.Handle("POST /api/logical-groups/{groupID}/routes", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleCreateLogicalGroupRoute(w, r, actions.CreateLogicalGroupRoute) + }))) + mux.Handle("GET /api/logical-groups/{groupID}/routes", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleListLogicalGroupRoutes(w, r, actions.ListLogicalGroupRoutes) + }))) + mux.Handle("PUT /api/logical-groups/{groupID}/routes/{routeID}", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleUpdateLogicalGroupRoute(w, r, actions.UpdateLogicalGroupRoute) + }))) + mux.Handle("DELETE /api/logical-groups/{groupID}/routes/{routeID}", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleDeleteLogicalGroupRoute(w, r, actions.DeleteLogicalGroupRoute) + }))) + mux.Handle("POST /api/logical-groups/{groupID}/routes/{routeID}/models", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleCreateLogicalGroupRouteModel(w, r, actions.CreateLogicalGroupRouteModel) + }))) + mux.Handle("GET /api/logical-groups/{groupID}/routes/{routeID}/models", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleListLogicalGroupRouteModels(w, r, actions.ListLogicalGroupRouteModels) + }))) mux.Handle("POST /api/provider-drafts", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handleCreateProviderDraft(w, r, actions.CreateProviderDraft) }))) @@ -1117,11 +1173,25 @@ func classifyError(err error) *httpError { func NewActionSet(sqliteDSN string) ActionSet { return ActionSet{ - CreateBatchImportRun: buildCreateBatchImportRunAction(sqliteDSN), - ListBatchImportRuns: buildListBatchImportRunsAction(sqliteDSN), - GetBatchImportRun: buildGetBatchImportRunAction(sqliteDSN), - ListBatchImportRunItems: buildListBatchImportRunItemsAction(sqliteDSN), - GetBatchImportRunItem: buildGetBatchImportRunItemAction(sqliteDSN), + CreateBatchImportRun: buildCreateBatchImportRunAction(sqliteDSN), + ListBatchImportRuns: buildListBatchImportRunsAction(sqliteDSN), + GetBatchImportRun: buildGetBatchImportRunAction(sqliteDSN), + ListBatchImportRunItems: buildListBatchImportRunItemsAction(sqliteDSN), + GetBatchImportRunItem: buildGetBatchImportRunItemAction(sqliteDSN), + CreateLogicalGroup: buildCreateLogicalGroupAction(sqliteDSN), + ListLogicalGroups: buildListLogicalGroupsAction(sqliteDSN), + GetLogicalGroup: buildGetLogicalGroupAction(sqliteDSN), + UpdateLogicalGroup: buildUpdateLogicalGroupAction(sqliteDSN), + DeleteLogicalGroup: buildDeleteLogicalGroupAction(sqliteDSN), + CreateLogicalGroupModel: buildCreateLogicalGroupModelAction(sqliteDSN), + ListLogicalGroupModels: buildListLogicalGroupModelsAction(sqliteDSN), + DeleteLogicalGroupModel: buildDeleteLogicalGroupModelAction(sqliteDSN), + CreateLogicalGroupRoute: buildCreateLogicalGroupRouteAction(sqliteDSN), + ListLogicalGroupRoutes: buildListLogicalGroupRoutesAction(sqliteDSN), + UpdateLogicalGroupRoute: buildUpdateLogicalGroupRouteAction(sqliteDSN), + DeleteLogicalGroupRoute: buildDeleteLogicalGroupRouteAction(sqliteDSN), + CreateLogicalGroupRouteModel: buildCreateLogicalGroupRouteModelAction(sqliteDSN), + ListLogicalGroupRouteModels: buildListLogicalGroupRouteModelsAction(sqliteDSN), CreateProviderDraft: func(ctx context.Context, req CreateProviderDraftRequest) (ProviderDraftInfo, error) { store, err := sqlite.Open(ctx, sqliteDSN) if err != nil { diff --git a/internal/app/logical_groups_api.go b/internal/app/logical_groups_api.go new file mode 100644 index 00000000..457f777c --- /dev/null +++ b/internal/app/logical_groups_api.go @@ -0,0 +1,840 @@ +package app + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + "strings" + + "sub2api-cn-relay-manager/internal/store/sqlite" +) + +type CreateLogicalGroupRequest struct { + LogicalGroupID string `json:"logical_group_id"` + DisplayName string `json:"display_name"` + Status string `json:"status"` + Description string `json:"description,omitempty"` + RoutePolicy string `json:"route_policy,omitempty"` + StickyMode string `json:"sticky_mode,omitempty"` + ConversationTTLSeconds int `json:"conversation_ttl_seconds,omitempty"` + UserModelTTLSeconds int `json:"user_model_ttl_seconds,omitempty"` + FailoverThreshold int `json:"failover_threshold,omitempty"` + CooldownSeconds int `json:"cooldown_seconds,omitempty"` +} + +type UpdateLogicalGroupRequest struct { + LogicalGroupID string `json:"-"` + DisplayName string `json:"display_name"` + Status string `json:"status"` + Description string `json:"description,omitempty"` + RoutePolicy string `json:"route_policy,omitempty"` + StickyMode string `json:"sticky_mode,omitempty"` + ConversationTTLSeconds int `json:"conversation_ttl_seconds,omitempty"` + UserModelTTLSeconds int `json:"user_model_ttl_seconds,omitempty"` + FailoverThreshold int `json:"failover_threshold,omitempty"` + CooldownSeconds int `json:"cooldown_seconds,omitempty"` +} + +type LogicalGroupInfo struct { + LogicalGroupID string `json:"logical_group_id"` + DisplayName string `json:"display_name"` + Status string `json:"status"` + Description string `json:"description,omitempty"` + RoutePolicy string `json:"route_policy,omitempty"` + StickyMode string `json:"sticky_mode,omitempty"` + ConversationTTLSeconds int `json:"conversation_ttl_seconds,omitempty"` + UserModelTTLSeconds int `json:"user_model_ttl_seconds,omitempty"` + FailoverThreshold int `json:"failover_threshold,omitempty"` + CooldownSeconds int `json:"cooldown_seconds,omitempty"` + Models []LogicalGroupModelInfo `json:"models,omitempty"` + Routes []LogicalGroupRouteInfo `json:"routes,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type CreateLogicalGroupModelRequest struct { + LogicalGroupID string `json:"-"` + PublicModel string `json:"public_model"` + Status string `json:"status,omitempty"` +} + +type DeleteLogicalGroupModelRequest struct { + LogicalGroupID string + PublicModel string +} + +type LogicalGroupModelInfo struct { + PublicModel string `json:"public_model"` + Status string `json:"status,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type CreateLogicalGroupRouteRequest struct { + LogicalGroupID string `json:"-"` + RouteID string `json:"route_id"` + Name string `json:"name"` + Status string `json:"status"` + Priority int `json:"priority"` + Weight int `json:"weight,omitempty"` + ShadowGroupID string `json:"shadow_group_id"` + ShadowHostID string `json:"shadow_host_id"` + UpstreamBaseURLHint string `json:"upstream_base_url_hint,omitempty"` + CooldownUntil string `json:"cooldown_until,omitempty"` +} + +type UpdateLogicalGroupRouteRequest struct { + LogicalGroupID string `json:"-"` + RouteID string `json:"-"` + Name string `json:"name"` + Status string `json:"status"` + Priority int `json:"priority"` + Weight int `json:"weight,omitempty"` + ShadowGroupID string `json:"shadow_group_id"` + ShadowHostID string `json:"shadow_host_id"` + UpstreamBaseURLHint string `json:"upstream_base_url_hint,omitempty"` + CooldownUntil string `json:"cooldown_until,omitempty"` +} + +type DeleteLogicalGroupRouteRequest struct { + LogicalGroupID string + RouteID string +} + +type LogicalGroupRouteInfo struct { + RouteID string `json:"route_id"` + LogicalGroupID string `json:"logical_group_id"` + Name string `json:"name"` + Status string `json:"status"` + Priority int `json:"priority"` + Weight int `json:"weight,omitempty"` + ShadowGroupID string `json:"shadow_group_id"` + ShadowHostID string `json:"shadow_host_id"` + UpstreamBaseURLHint string `json:"upstream_base_url_hint,omitempty"` + CooldownUntil string `json:"cooldown_until,omitempty"` + Models []LogicalGroupRouteModelInfo `json:"models,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type CreateLogicalGroupRouteModelRequest struct { + LogicalGroupID string `json:"-"` + RouteID string `json:"-"` + PublicModel string `json:"public_model"` + ShadowModel string `json:"shadow_model,omitempty"` + Status string `json:"status,omitempty"` +} + +type ListLogicalGroupRouteModelsRequest struct { + LogicalGroupID string + RouteID string +} + +type LogicalGroupRouteModelInfo struct { + PublicModel string `json:"public_model"` + ShadowModel string `json:"shadow_model,omitempty"` + Status string `json:"status,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +func handleCreateLogicalGroup(w http.ResponseWriter, r *http.Request, fn func(context.Context, CreateLogicalGroupRequest) (LogicalGroupInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "create-logical-group action is not configured"}) + return + } + var req CreateLogicalGroupRequest + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + group, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusCreated, map[string]any{"logical_group": group}) +} + +func handleListLogicalGroups(w http.ResponseWriter, r *http.Request, fn func(context.Context) ([]LogicalGroupInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-logical-groups action is not configured"}) + return + } + groups, err := fn(r.Context()) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"logical_groups": groups}) +} + +func handleGetLogicalGroup(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) (LogicalGroupInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-logical-group action is not configured"}) + return + } + groupID := strings.TrimSpace(r.PathValue("groupID")) + group, err := fn(r.Context(), groupID) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"logical_group": group}) +} + +func handleUpdateLogicalGroup(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateLogicalGroupRequest) (LogicalGroupInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "update-logical-group action is not configured"}) + return + } + var req UpdateLogicalGroupRequest + req.LogicalGroupID = strings.TrimSpace(r.PathValue("groupID")) + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + group, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"logical_group": group}) +} + +func handleDeleteLogicalGroup(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) error) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "delete-logical-group action is not configured"}) + return + } + if err := fn(r.Context(), strings.TrimSpace(r.PathValue("groupID"))); err != nil { + writeHTTPError(w, classifyError(err)) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func handleCreateLogicalGroupModel(w http.ResponseWriter, r *http.Request, fn func(context.Context, CreateLogicalGroupModelRequest) (LogicalGroupModelInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "create-logical-group-model action is not configured"}) + return + } + var req CreateLogicalGroupModelRequest + req.LogicalGroupID = strings.TrimSpace(r.PathValue("groupID")) + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + model, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusCreated, map[string]any{"logical_group_model": model}) +} + +func handleListLogicalGroupModels(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) ([]LogicalGroupModelInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-logical-group-models action is not configured"}) + return + } + models, err := fn(r.Context(), strings.TrimSpace(r.PathValue("groupID"))) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"models": models}) +} + +func handleDeleteLogicalGroupModel(w http.ResponseWriter, r *http.Request, fn func(context.Context, DeleteLogicalGroupModelRequest) error) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "delete-logical-group-model action is not configured"}) + return + } + if err := fn(r.Context(), DeleteLogicalGroupModelRequest{ + LogicalGroupID: strings.TrimSpace(r.PathValue("groupID")), + PublicModel: strings.TrimSpace(r.PathValue("model")), + }); err != nil { + writeHTTPError(w, classifyError(err)) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func handleCreateLogicalGroupRoute(w http.ResponseWriter, r *http.Request, fn func(context.Context, CreateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "create-logical-group-route action is not configured"}) + return + } + var req CreateLogicalGroupRouteRequest + req.LogicalGroupID = strings.TrimSpace(r.PathValue("groupID")) + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + route, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusCreated, map[string]any{"route": route}) +} + +func handleListLogicalGroupRoutes(w http.ResponseWriter, r *http.Request, fn func(context.Context, string) ([]LogicalGroupRouteInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-logical-group-routes action is not configured"}) + return + } + routes, err := fn(r.Context(), strings.TrimSpace(r.PathValue("groupID"))) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"routes": routes}) +} + +func handleUpdateLogicalGroupRoute(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "update-logical-group-route action is not configured"}) + return + } + var req UpdateLogicalGroupRouteRequest + req.LogicalGroupID = strings.TrimSpace(r.PathValue("groupID")) + req.RouteID = strings.TrimSpace(r.PathValue("routeID")) + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + route, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"route": route}) +} + +func handleDeleteLogicalGroupRoute(w http.ResponseWriter, r *http.Request, fn func(context.Context, DeleteLogicalGroupRouteRequest) error) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "delete-logical-group-route action is not configured"}) + return + } + if err := fn(r.Context(), DeleteLogicalGroupRouteRequest{ + LogicalGroupID: strings.TrimSpace(r.PathValue("groupID")), + RouteID: strings.TrimSpace(r.PathValue("routeID")), + }); err != nil { + writeHTTPError(w, classifyError(err)) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func handleCreateLogicalGroupRouteModel(w http.ResponseWriter, r *http.Request, fn func(context.Context, CreateLogicalGroupRouteModelRequest) (LogicalGroupRouteModelInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "create-logical-group-route-model action is not configured"}) + return + } + var req CreateLogicalGroupRouteModelRequest + req.LogicalGroupID = strings.TrimSpace(r.PathValue("groupID")) + req.RouteID = strings.TrimSpace(r.PathValue("routeID")) + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + model, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusCreated, map[string]any{"route_model": model}) +} + +func handleListLogicalGroupRouteModels(w http.ResponseWriter, r *http.Request, fn func(context.Context, ListLogicalGroupRouteModelsRequest) ([]LogicalGroupRouteModelInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-logical-group-route-models action is not configured"}) + return + } + models, err := fn(r.Context(), ListLogicalGroupRouteModelsRequest{ + LogicalGroupID: strings.TrimSpace(r.PathValue("groupID")), + RouteID: strings.TrimSpace(r.PathValue("routeID")), + }) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"route_models": models}) +} + +func buildCreateLogicalGroupAction(sqliteDSN string) func(context.Context, CreateLogicalGroupRequest) (LogicalGroupInfo, error) { + return func(ctx context.Context, req CreateLogicalGroupRequest) (LogicalGroupInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return LogicalGroupInfo{}, err + } + defer store.Close() + + row := logicalGroupRequestToRow(req) + if _, err := store.LogicalGroups().Create(ctx, row); err != nil { + return LogicalGroupInfo{}, err + } + return loadLogicalGroupInfo(ctx, store, row.LogicalGroupID) + } +} + +func buildListLogicalGroupsAction(sqliteDSN string) func(context.Context) ([]LogicalGroupInfo, error) { + return func(ctx context.Context) ([]LogicalGroupInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return nil, err + } + defer store.Close() + + rows, err := store.LogicalGroups().List(ctx) + if err != nil { + return nil, err + } + result := make([]LogicalGroupInfo, 0, len(rows)) + for _, row := range rows { + info, err := loadLogicalGroupInfo(ctx, store, row.LogicalGroupID) + if err != nil { + return nil, err + } + result = append(result, info) + } + return result, nil + } +} + +func buildGetLogicalGroupAction(sqliteDSN string) func(context.Context, string) (LogicalGroupInfo, error) { + return func(ctx context.Context, logicalGroupID string) (LogicalGroupInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return LogicalGroupInfo{}, err + } + defer store.Close() + return loadLogicalGroupInfo(ctx, store, logicalGroupID) + } +} + +func buildUpdateLogicalGroupAction(sqliteDSN string) func(context.Context, UpdateLogicalGroupRequest) (LogicalGroupInfo, error) { + return func(ctx context.Context, req UpdateLogicalGroupRequest) (LogicalGroupInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return LogicalGroupInfo{}, err + } + defer store.Close() + + if _, err := getLogicalGroupRow(ctx, store, req.LogicalGroupID); err != nil { + return LogicalGroupInfo{}, err + } + if err := store.LogicalGroups().UpdateByLogicalGroupID(ctx, logicalGroupRequestToRow(CreateLogicalGroupRequest{ + LogicalGroupID: req.LogicalGroupID, + DisplayName: req.DisplayName, + Status: req.Status, + Description: req.Description, + RoutePolicy: req.RoutePolicy, + StickyMode: req.StickyMode, + ConversationTTLSeconds: req.ConversationTTLSeconds, + UserModelTTLSeconds: req.UserModelTTLSeconds, + FailoverThreshold: req.FailoverThreshold, + CooldownSeconds: req.CooldownSeconds, + })); err != nil { + return LogicalGroupInfo{}, err + } + return loadLogicalGroupInfo(ctx, store, req.LogicalGroupID) + } +} + +func buildDeleteLogicalGroupAction(sqliteDSN string) func(context.Context, string) error { + return func(ctx context.Context, logicalGroupID string) error { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return err + } + defer store.Close() + return store.LogicalGroups().DeleteByLogicalGroupID(ctx, strings.TrimSpace(logicalGroupID)) + } +} + +func buildCreateLogicalGroupModelAction(sqliteDSN string) func(context.Context, CreateLogicalGroupModelRequest) (LogicalGroupModelInfo, error) { + return func(ctx context.Context, req CreateLogicalGroupModelRequest) (LogicalGroupModelInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return LogicalGroupModelInfo{}, err + } + defer store.Close() + + if _, err := getLogicalGroupRow(ctx, store, req.LogicalGroupID); err != nil { + return LogicalGroupModelInfo{}, err + } + row := sqlite.LogicalGroupModel{ + LogicalGroupID: strings.TrimSpace(req.LogicalGroupID), + PublicModel: strings.TrimSpace(req.PublicModel), + Status: strings.TrimSpace(req.Status), + } + if _, err := store.LogicalGroupModels().Create(ctx, row); err != nil { + return LogicalGroupModelInfo{}, err + } + models, err := store.LogicalGroupModels().ListByLogicalGroupID(ctx, row.LogicalGroupID) + if err != nil { + return LogicalGroupModelInfo{}, err + } + for _, model := range models { + if model.PublicModel == row.PublicModel { + return logicalGroupModelRowToInfo(model), nil + } + } + return LogicalGroupModelInfo{}, fmt.Errorf("logical group model %q/%q not found", row.LogicalGroupID, row.PublicModel) + } +} + +func buildListLogicalGroupModelsAction(sqliteDSN string) func(context.Context, string) ([]LogicalGroupModelInfo, error) { + return func(ctx context.Context, logicalGroupID string) ([]LogicalGroupModelInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return nil, err + } + defer store.Close() + + if _, err := getLogicalGroupRow(ctx, store, logicalGroupID); err != nil { + return nil, err + } + rows, err := store.LogicalGroupModels().ListByLogicalGroupID(ctx, logicalGroupID) + if err != nil { + return nil, err + } + return logicalGroupModelRowsToInfo(rows), nil + } +} + +func buildDeleteLogicalGroupModelAction(sqliteDSN string) func(context.Context, DeleteLogicalGroupModelRequest) error { + return func(ctx context.Context, req DeleteLogicalGroupModelRequest) error { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return err + } + defer store.Close() + + if _, err := getLogicalGroupRow(ctx, store, req.LogicalGroupID); err != nil { + return err + } + return store.LogicalGroupModels().DeleteByLogicalGroupIDAndModel(ctx, strings.TrimSpace(req.LogicalGroupID), strings.TrimSpace(req.PublicModel)) + } +} + +func buildCreateLogicalGroupRouteAction(sqliteDSN string) func(context.Context, CreateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error) { + return func(ctx context.Context, req CreateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return LogicalGroupRouteInfo{}, err + } + defer store.Close() + + if _, err := getLogicalGroupRow(ctx, store, req.LogicalGroupID); err != nil { + return LogicalGroupRouteInfo{}, err + } + row := logicalGroupRouteRequestToRow(req) + if _, err := store.LogicalGroupRoutes().Create(ctx, row); err != nil { + return LogicalGroupRouteInfo{}, err + } + return loadLogicalGroupRouteInfo(ctx, store, row.RouteID) + } +} + +func buildListLogicalGroupRoutesAction(sqliteDSN string) func(context.Context, string) ([]LogicalGroupRouteInfo, error) { + return func(ctx context.Context, logicalGroupID string) ([]LogicalGroupRouteInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return nil, err + } + defer store.Close() + + if _, err := getLogicalGroupRow(ctx, store, logicalGroupID); err != nil { + return nil, err + } + rows, err := store.LogicalGroupRoutes().ListByLogicalGroupID(ctx, logicalGroupID) + if err != nil { + return nil, err + } + return logicalGroupRouteRowsToInfo(ctx, store, rows) + } +} + +func buildUpdateLogicalGroupRouteAction(sqliteDSN string) func(context.Context, UpdateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error) { + return func(ctx context.Context, req UpdateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return LogicalGroupRouteInfo{}, err + } + defer store.Close() + + if _, err := getLogicalGroupRow(ctx, store, req.LogicalGroupID); err != nil { + return LogicalGroupRouteInfo{}, err + } + existing, err := getLogicalGroupRouteRow(ctx, store, req.RouteID) + if err != nil { + return LogicalGroupRouteInfo{}, err + } + if existing.LogicalGroupID != strings.TrimSpace(req.LogicalGroupID) { + return LogicalGroupRouteInfo{}, fmt.Errorf("logical group route %q not found under logical group %q", req.RouteID, req.LogicalGroupID) + } + row := logicalGroupRouteRequestToRow(CreateLogicalGroupRouteRequest{ + LogicalGroupID: req.LogicalGroupID, + RouteID: req.RouteID, + Name: req.Name, + Status: req.Status, + Priority: req.Priority, + Weight: req.Weight, + ShadowGroupID: req.ShadowGroupID, + ShadowHostID: req.ShadowHostID, + UpstreamBaseURLHint: req.UpstreamBaseURLHint, + CooldownUntil: req.CooldownUntil, + }) + if err := store.LogicalGroupRoutes().UpdateByRouteID(ctx, row); err != nil { + return LogicalGroupRouteInfo{}, err + } + return loadLogicalGroupRouteInfo(ctx, store, req.RouteID) + } +} + +func buildDeleteLogicalGroupRouteAction(sqliteDSN string) func(context.Context, DeleteLogicalGroupRouteRequest) error { + return func(ctx context.Context, req DeleteLogicalGroupRouteRequest) error { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return err + } + defer store.Close() + + existing, err := getLogicalGroupRouteRow(ctx, store, req.RouteID) + if err != nil { + return err + } + if existing.LogicalGroupID != strings.TrimSpace(req.LogicalGroupID) { + return fmt.Errorf("logical group route %q not found under logical group %q", req.RouteID, req.LogicalGroupID) + } + return store.LogicalGroupRoutes().DeleteByRouteID(ctx, strings.TrimSpace(req.RouteID)) + } +} + +func buildCreateLogicalGroupRouteModelAction(sqliteDSN string) func(context.Context, CreateLogicalGroupRouteModelRequest) (LogicalGroupRouteModelInfo, error) { + return func(ctx context.Context, req CreateLogicalGroupRouteModelRequest) (LogicalGroupRouteModelInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return LogicalGroupRouteModelInfo{}, err + } + defer store.Close() + + route, err := getLogicalGroupRouteRow(ctx, store, req.RouteID) + if err != nil { + return LogicalGroupRouteModelInfo{}, err + } + if route.LogicalGroupID != strings.TrimSpace(req.LogicalGroupID) { + return LogicalGroupRouteModelInfo{}, fmt.Errorf("logical group route %q not found under logical group %q", req.RouteID, req.LogicalGroupID) + } + row := sqlite.LogicalGroupRouteModel{ + RouteID: strings.TrimSpace(req.RouteID), + PublicModel: strings.TrimSpace(req.PublicModel), + ShadowModel: strings.TrimSpace(req.ShadowModel), + Status: strings.TrimSpace(req.Status), + } + if _, err := store.LogicalGroupRouteModels().Create(ctx, row); err != nil { + return LogicalGroupRouteModelInfo{}, err + } + models, err := store.LogicalGroupRouteModels().ListByRouteID(ctx, row.RouteID) + if err != nil { + return LogicalGroupRouteModelInfo{}, err + } + for _, model := range models { + if model.PublicModel == row.PublicModel { + return logicalGroupRouteModelRowToInfo(model), nil + } + } + return LogicalGroupRouteModelInfo{}, fmt.Errorf("logical group route model %q/%q not found", row.RouteID, row.PublicModel) + } +} + +func buildListLogicalGroupRouteModelsAction(sqliteDSN string) func(context.Context, ListLogicalGroupRouteModelsRequest) ([]LogicalGroupRouteModelInfo, error) { + return func(ctx context.Context, req ListLogicalGroupRouteModelsRequest) ([]LogicalGroupRouteModelInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return nil, err + } + defer store.Close() + + route, err := getLogicalGroupRouteRow(ctx, store, req.RouteID) + if err != nil { + return nil, err + } + if route.LogicalGroupID != strings.TrimSpace(req.LogicalGroupID) { + return nil, fmt.Errorf("logical group route %q not found under logical group %q", req.RouteID, req.LogicalGroupID) + } + rows, err := store.LogicalGroupRouteModels().ListByRouteID(ctx, req.RouteID) + if err != nil { + return nil, err + } + return logicalGroupRouteModelRowsToInfo(rows), nil + } +} + +func logicalGroupRequestToRow(req CreateLogicalGroupRequest) sqlite.LogicalGroup { + return sqlite.LogicalGroup{ + LogicalGroupID: strings.TrimSpace(req.LogicalGroupID), + DisplayName: strings.TrimSpace(req.DisplayName), + Status: strings.TrimSpace(req.Status), + Description: strings.TrimSpace(req.Description), + RoutePolicy: strings.TrimSpace(req.RoutePolicy), + StickyMode: strings.TrimSpace(req.StickyMode), + ConversationTTLSeconds: req.ConversationTTLSeconds, + UserModelTTLSeconds: req.UserModelTTLSeconds, + FailoverThreshold: req.FailoverThreshold, + CooldownSeconds: req.CooldownSeconds, + } +} + +func logicalGroupRouteRequestToRow(req CreateLogicalGroupRouteRequest) sqlite.LogicalGroupRoute { + return sqlite.LogicalGroupRoute{ + RouteID: strings.TrimSpace(req.RouteID), + LogicalGroupID: strings.TrimSpace(req.LogicalGroupID), + Name: strings.TrimSpace(req.Name), + Status: strings.TrimSpace(req.Status), + Priority: req.Priority, + Weight: req.Weight, + ShadowGroupID: strings.TrimSpace(req.ShadowGroupID), + ShadowHostID: strings.TrimSpace(req.ShadowHostID), + UpstreamBaseURLHint: strings.TrimSpace(req.UpstreamBaseURLHint), + CooldownUntil: strings.TrimSpace(req.CooldownUntil), + } +} + +func loadLogicalGroupInfo(ctx context.Context, store *sqlite.DB, logicalGroupID string) (LogicalGroupInfo, error) { + group, err := getLogicalGroupRow(ctx, store, logicalGroupID) + if err != nil { + return LogicalGroupInfo{}, err + } + models, err := store.LogicalGroupModels().ListByLogicalGroupID(ctx, group.LogicalGroupID) + if err != nil { + return LogicalGroupInfo{}, err + } + routes, err := store.LogicalGroupRoutes().ListByLogicalGroupID(ctx, group.LogicalGroupID) + if err != nil { + return LogicalGroupInfo{}, err + } + routeInfos, err := logicalGroupRouteRowsToInfo(ctx, store, routes) + if err != nil { + return LogicalGroupInfo{}, err + } + return logicalGroupRowToInfo(group, logicalGroupModelRowsToInfo(models), routeInfos), nil +} + +func loadLogicalGroupRouteInfo(ctx context.Context, store *sqlite.DB, routeID string) (LogicalGroupRouteInfo, error) { + route, err := getLogicalGroupRouteRow(ctx, store, routeID) + if err != nil { + return LogicalGroupRouteInfo{}, err + } + models, err := store.LogicalGroupRouteModels().ListByRouteID(ctx, route.RouteID) + if err != nil { + return LogicalGroupRouteInfo{}, err + } + return logicalGroupRouteRowToInfo(route, logicalGroupRouteModelRowsToInfo(models)), nil +} + +func logicalGroupRowToInfo(group sqlite.LogicalGroup, models []LogicalGroupModelInfo, routes []LogicalGroupRouteInfo) LogicalGroupInfo { + return LogicalGroupInfo{ + LogicalGroupID: group.LogicalGroupID, + DisplayName: group.DisplayName, + Status: group.Status, + Description: group.Description, + RoutePolicy: group.RoutePolicy, + StickyMode: group.StickyMode, + ConversationTTLSeconds: group.ConversationTTLSeconds, + UserModelTTLSeconds: group.UserModelTTLSeconds, + FailoverThreshold: group.FailoverThreshold, + CooldownSeconds: group.CooldownSeconds, + Models: models, + Routes: routes, + CreatedAt: group.CreatedAt, + UpdatedAt: group.UpdatedAt, + } +} + +func logicalGroupModelRowsToInfo(rows []sqlite.LogicalGroupModel) []LogicalGroupModelInfo { + result := make([]LogicalGroupModelInfo, 0, len(rows)) + for _, row := range rows { + result = append(result, logicalGroupModelRowToInfo(row)) + } + return result +} + +func logicalGroupModelRowToInfo(row sqlite.LogicalGroupModel) LogicalGroupModelInfo { + return LogicalGroupModelInfo{ + PublicModel: row.PublicModel, + Status: row.Status, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + } +} + +func logicalGroupRouteRowsToInfo(ctx context.Context, store *sqlite.DB, rows []sqlite.LogicalGroupRoute) ([]LogicalGroupRouteInfo, error) { + result := make([]LogicalGroupRouteInfo, 0, len(rows)) + for _, row := range rows { + models, err := store.LogicalGroupRouteModels().ListByRouteID(ctx, row.RouteID) + if err != nil { + return nil, err + } + result = append(result, logicalGroupRouteRowToInfo(row, logicalGroupRouteModelRowsToInfo(models))) + } + return result, nil +} + +func logicalGroupRouteRowToInfo(row sqlite.LogicalGroupRoute, models []LogicalGroupRouteModelInfo) LogicalGroupRouteInfo { + return LogicalGroupRouteInfo{ + RouteID: row.RouteID, + LogicalGroupID: row.LogicalGroupID, + Name: row.Name, + Status: row.Status, + Priority: row.Priority, + Weight: row.Weight, + ShadowGroupID: row.ShadowGroupID, + ShadowHostID: row.ShadowHostID, + UpstreamBaseURLHint: row.UpstreamBaseURLHint, + CooldownUntil: row.CooldownUntil, + Models: models, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + } +} + +func logicalGroupRouteModelRowsToInfo(rows []sqlite.LogicalGroupRouteModel) []LogicalGroupRouteModelInfo { + result := make([]LogicalGroupRouteModelInfo, 0, len(rows)) + for _, row := range rows { + result = append(result, logicalGroupRouteModelRowToInfo(row)) + } + return result +} + +func logicalGroupRouteModelRowToInfo(row sqlite.LogicalGroupRouteModel) LogicalGroupRouteModelInfo { + return LogicalGroupRouteModelInfo{ + PublicModel: row.PublicModel, + ShadowModel: row.ShadowModel, + Status: row.Status, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + } +} + +func getLogicalGroupRow(ctx context.Context, store *sqlite.DB, logicalGroupID string) (sqlite.LogicalGroup, error) { + row, err := store.LogicalGroups().GetByLogicalGroupID(ctx, strings.TrimSpace(logicalGroupID)) + if errors.Is(err, sql.ErrNoRows) { + return sqlite.LogicalGroup{}, fmt.Errorf("logical group %q not found", strings.TrimSpace(logicalGroupID)) + } + return row, err +} + +func getLogicalGroupRouteRow(ctx context.Context, store *sqlite.DB, routeID string) (sqlite.LogicalGroupRoute, error) { + row, err := store.LogicalGroupRoutes().GetByRouteID(ctx, strings.TrimSpace(routeID)) + if errors.Is(err, sql.ErrNoRows) { + return sqlite.LogicalGroupRoute{}, fmt.Errorf("logical group route %q not found", strings.TrimSpace(routeID)) + } + return row, err +} diff --git a/internal/app/logical_groups_api_test.go b/internal/app/logical_groups_api_test.go new file mode 100644 index 00000000..a50ca63b --- /dev/null +++ b/internal/app/logical_groups_api_test.go @@ -0,0 +1,271 @@ +package app + +import ( + "context" + "encoding/json" + "net/http" + "path/filepath" + "testing" + + "sub2api-cn-relay-manager/internal/store/sqlite" +) + +func TestAPICreateLogicalGroupReturnsCreated(t *testing.T) { + handler := NewAPIHandler("secret-token", ActionSet{ + CreateLogicalGroup: func(_ context.Context, req CreateLogicalGroupRequest) (LogicalGroupInfo, error) { + if req.LogicalGroupID != "gpt-shared" { + t.Fatalf("LogicalGroupID = %q, want gpt-shared", req.LogicalGroupID) + } + return LogicalGroupInfo{ + LogicalGroupID: req.LogicalGroupID, + DisplayName: req.DisplayName, + Status: req.Status, + }, nil + }, + }) + + request := httptestRequest(t, http.MethodPost, "/api/logical-groups", map[string]any{ + "logical_group_id": "gpt-shared", + "display_name": "GPT Shared", + "status": "active", + }, "secret-token") + response := httptestRecorder(handler, request) + assertStatusCode(t, response, http.StatusCreated) + assertJSONContains(t, response.Body().Bytes(), "logical_group.logical_group_id", "gpt-shared") +} + +func TestAPIGetLogicalGroupReturnsAggregatedItem(t *testing.T) { + handler := NewAPIHandler("secret-token", ActionSet{ + GetLogicalGroup: func(_ context.Context, groupID string) (LogicalGroupInfo, error) { + if groupID != "gpt-shared" { + t.Fatalf("groupID = %q, want gpt-shared", groupID) + } + return LogicalGroupInfo{ + LogicalGroupID: groupID, + DisplayName: "GPT Shared", + Status: "active", + Models: []LogicalGroupModelInfo{{PublicModel: "gpt-5.4", Status: "active"}}, + Routes: []LogicalGroupRouteInfo{{ + RouteID: "asxs", + LogicalGroupID: groupID, + Name: "ASXS", + Status: "active", + Models: []LogicalGroupRouteModelInfo{{PublicModel: "gpt-5.4", ShadowModel: "gpt-5.4"}}, + }}, + }, nil + }, + }) + + request := httptestRequest(t, http.MethodGet, "/api/logical-groups/gpt-shared", nil, "secret-token") + response := httptestRecorder(handler, request) + assertStatusCode(t, response, http.StatusOK) + var payload map[string]any + if err := json.Unmarshal(response.Body().Bytes(), &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + group, ok := payload["logical_group"].(map[string]any) + if !ok { + t.Fatalf("logical_group = %#v, want object", payload["logical_group"]) + } + models, ok := group["models"].([]any) + if !ok || len(models) != 1 { + t.Fatalf("models = %#v, want one item", group["models"]) + } + firstModel, ok := models[0].(map[string]any) + if !ok || firstModel["public_model"] != "gpt-5.4" { + t.Fatalf("first model = %#v, want public_model gpt-5.4", models[0]) + } + routes, ok := group["routes"].([]any) + if !ok || len(routes) != 1 { + t.Fatalf("routes = %#v, want one item", group["routes"]) + } + firstRoute, ok := routes[0].(map[string]any) + if !ok || firstRoute["route_id"] != "asxs" { + t.Fatalf("first route = %#v, want route_id asxs", routes[0]) + } +} + +func TestAPICreateLogicalGroupRouteUsesPathGroupID(t *testing.T) { + handler := NewAPIHandler("secret-token", ActionSet{ + CreateLogicalGroupRoute: func(_ context.Context, req CreateLogicalGroupRouteRequest) (LogicalGroupRouteInfo, error) { + if req.LogicalGroupID != "gpt-shared" { + t.Fatalf("LogicalGroupID = %q, want gpt-shared", req.LogicalGroupID) + } + return LogicalGroupRouteInfo{ + RouteID: req.RouteID, + LogicalGroupID: req.LogicalGroupID, + Name: req.Name, + Status: req.Status, + }, nil + }, + }) + + request := httptestRequest(t, http.MethodPost, "/api/logical-groups/gpt-shared/routes", map[string]any{ + "route_id": "asxs", + "name": "ASXS", + "status": "active", + "priority": 10, + "shadow_group_id": "gpt-shared__asxs", + "shadow_host_id": "remote43", + }, "secret-token") + response := httptestRecorder(handler, request) + assertStatusCode(t, response, http.StatusCreated) + assertJSONContains(t, response.Body().Bytes(), "route.logical_group_id", "gpt-shared") + assertJSONContains(t, response.Body().Bytes(), "route.route_id", "asxs") +} + +func TestAPICreateLogicalGroupRouteModelUsesPathValues(t *testing.T) { + handler := NewAPIHandler("secret-token", ActionSet{ + CreateLogicalGroupRouteModel: func(_ context.Context, req CreateLogicalGroupRouteModelRequest) (LogicalGroupRouteModelInfo, error) { + if req.LogicalGroupID != "gpt-shared" { + t.Fatalf("LogicalGroupID = %q, want gpt-shared", req.LogicalGroupID) + } + if req.RouteID != "asxs" { + t.Fatalf("RouteID = %q, want asxs", req.RouteID) + } + return LogicalGroupRouteModelInfo{ + PublicModel: req.PublicModel, + ShadowModel: req.ShadowModel, + Status: req.Status, + }, nil + }, + }) + + request := httptestRequest(t, http.MethodPost, "/api/logical-groups/gpt-shared/routes/asxs/models", map[string]any{ + "public_model": "gpt-5.4", + "shadow_model": "gpt-5.4", + "status": "active", + }, "secret-token") + response := httptestRecorder(handler, request) + assertStatusCode(t, response, http.StatusCreated) + assertJSONContains(t, response.Body().Bytes(), "route_model.public_model", "gpt-5.4") +} + +func TestNewActionSetLogicalGroupCRUDFlow(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "logical-groups.db") + dsn := "file:" + filepath.ToSlash(dbPath) + "?_busy_timeout=5000" + actions := NewActionSet(dsn) + ctx := context.Background() + + createdGroup, err := actions.CreateLogicalGroup(ctx, CreateLogicalGroupRequest{ + LogicalGroupID: "gpt-shared", + DisplayName: "GPT Shared", + Status: "active", + }) + if err != nil { + t.Fatalf("CreateLogicalGroup() error = %v", err) + } + if createdGroup.LogicalGroupID != "gpt-shared" { + t.Fatalf("CreateLogicalGroup() = %+v, want logical_group_id gpt-shared", createdGroup) + } + + if _, err := actions.CreateLogicalGroupModel(ctx, CreateLogicalGroupModelRequest{ + LogicalGroupID: "gpt-shared", + PublicModel: "gpt-5.4", + }); err != nil { + t.Fatalf("CreateLogicalGroupModel() error = %v", err) + } + + if _, err := actions.CreateLogicalGroupRoute(ctx, CreateLogicalGroupRouteRequest{ + LogicalGroupID: "gpt-shared", + RouteID: "asxs", + Name: "ASXS", + Status: "active", + Priority: 10, + ShadowGroupID: "gpt-shared__asxs", + ShadowHostID: "remote43", + }); err != nil { + t.Fatalf("CreateLogicalGroupRoute() error = %v", err) + } + + if _, err := actions.CreateLogicalGroupRouteModel(ctx, CreateLogicalGroupRouteModelRequest{ + LogicalGroupID: "gpt-shared", + RouteID: "asxs", + PublicModel: "gpt-5.4", + }); err != nil { + t.Fatalf("CreateLogicalGroupRouteModel() error = %v", err) + } + + group, err := actions.GetLogicalGroup(ctx, "gpt-shared") + if err != nil { + t.Fatalf("GetLogicalGroup() error = %v", err) + } + if len(group.Models) != 1 || group.Models[0].PublicModel != "gpt-5.4" { + t.Fatalf("GetLogicalGroup().Models = %+v, want gpt-5.4", group.Models) + } + if len(group.Routes) != 1 || group.Routes[0].RouteID != "asxs" { + t.Fatalf("GetLogicalGroup().Routes = %+v, want route asxs", group.Routes) + } + if len(group.Routes[0].Models) != 1 || group.Routes[0].Models[0].ShadowModel != "gpt-5.4" { + t.Fatalf("GetLogicalGroup().Routes[0].Models = %+v, want shadow gpt-5.4", group.Routes[0].Models) + } + + if _, err := actions.UpdateLogicalGroup(ctx, UpdateLogicalGroupRequest{ + LogicalGroupID: "gpt-shared", + DisplayName: "GPT Shared Updated", + Status: "paused", + }); err != nil { + t.Fatalf("UpdateLogicalGroup() error = %v", err) + } + if _, err := actions.UpdateLogicalGroupRoute(ctx, UpdateLogicalGroupRouteRequest{ + LogicalGroupID: "gpt-shared", + RouteID: "asxs", + Name: "ASXS Updated", + Status: "degraded", + Priority: 20, + Weight: 80, + ShadowGroupID: "gpt-shared__asxs", + ShadowHostID: "remote43", + CooldownUntil: "2026-05-28T16:00:00Z", + }); err != nil { + t.Fatalf("UpdateLogicalGroupRoute() error = %v", err) + } + + groups, err := actions.ListLogicalGroups(ctx) + if err != nil { + t.Fatalf("ListLogicalGroups() error = %v", err) + } + if len(groups) != 1 || groups[0].DisplayName != "GPT Shared Updated" { + t.Fatalf("ListLogicalGroups() = %+v, want updated group", groups) + } + + routeModels, err := actions.ListLogicalGroupRouteModels(ctx, ListLogicalGroupRouteModelsRequest{ + LogicalGroupID: "gpt-shared", + RouteID: "asxs", + }) + if err != nil { + t.Fatalf("ListLogicalGroupRouteModels() error = %v", err) + } + if len(routeModels) != 1 || routeModels[0].PublicModel != "gpt-5.4" { + t.Fatalf("ListLogicalGroupRouteModels() = %+v, want gpt-5.4", routeModels) + } + + if err := actions.DeleteLogicalGroupRoute(ctx, DeleteLogicalGroupRouteRequest{ + LogicalGroupID: "gpt-shared", + RouteID: "asxs", + }); err != nil { + t.Fatalf("DeleteLogicalGroupRoute() error = %v", err) + } + if err := actions.DeleteLogicalGroupModel(ctx, DeleteLogicalGroupModelRequest{ + LogicalGroupID: "gpt-shared", + PublicModel: "gpt-5.4", + }); err != nil { + t.Fatalf("DeleteLogicalGroupModel() error = %v", err) + } + if err := actions.DeleteLogicalGroup(ctx, "gpt-shared"); err != nil { + t.Fatalf("DeleteLogicalGroup() error = %v", err) + } + + store, err := sqlite.Open(ctx, dsn) + if err != nil { + t.Fatalf("sqlite.Open() error = %v", err) + } + defer store.Close() + remaining, err := store.LogicalGroups().List(ctx) + if err != nil { + t.Fatalf("LogicalGroups().List() error = %v", err) + } + if len(remaining) != 0 { + t.Fatalf("remaining logical groups = %+v, want empty", remaining) + } +} diff --git a/internal/store/sqlite/db.go b/internal/store/sqlite/db.go index 772a6413..619a461c 100644 --- a/internal/store/sqlite/db.go +++ b/internal/store/sqlite/db.go @@ -20,19 +20,23 @@ type execQuerier interface { } type Queries struct { - Hosts *HostsRepo - Packs *PacksRepo - Providers *ProvidersRepo - ProviderDrafts *ProviderDraftsRepo - ImportBatches *ImportBatchesRepo - ImportBatchItems *ImportBatchItemsRepo - ImportRuns *ImportRunsRepo - ImportRunItems *ImportRunItemsRepo - ImportRunEvents *ImportRunItemEventsRepo - ManagedResources *ManagedResourcesRepo - ProbeResults *ProbeResultsRepo - AccessClosures *AccessClosureRecordsRepo - ReconcileRuns *ReconcileRunsRepo + Hosts *HostsRepo + Packs *PacksRepo + Providers *ProvidersRepo + LogicalGroups *LogicalGroupsRepo + LogicalGroupModels *LogicalGroupModelsRepo + LogicalGroupRoutes *LogicalGroupRoutesRepo + LogicalGroupRouteModels *LogicalGroupRouteModelsRepo + ProviderDrafts *ProviderDraftsRepo + ImportBatches *ImportBatchesRepo + ImportBatchItems *ImportBatchItemsRepo + ImportRuns *ImportRunsRepo + ImportRunItems *ImportRunItemsRepo + ImportRunEvents *ImportRunItemEventsRepo + ManagedResources *ManagedResourcesRepo + ProbeResults *ProbeResultsRepo + AccessClosures *AccessClosureRecordsRepo + ReconcileRuns *ReconcileRunsRepo } type DB struct { @@ -92,6 +96,22 @@ func (db *DB) Providers() *ProvidersRepo { return db.queries.Providers } +func (db *DB) LogicalGroups() *LogicalGroupsRepo { + return db.queries.LogicalGroups +} + +func (db *DB) LogicalGroupModels() *LogicalGroupModelsRepo { + return db.queries.LogicalGroupModels +} + +func (db *DB) LogicalGroupRoutes() *LogicalGroupRoutesRepo { + return db.queries.LogicalGroupRoutes +} + +func (db *DB) LogicalGroupRouteModels() *LogicalGroupRouteModelsRepo { + return db.queries.LogicalGroupRouteModels +} + func (db *DB) ProviderDrafts() *ProviderDraftsRepo { return db.queries.ProviderDrafts } @@ -161,19 +181,23 @@ func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error { func newQueries(db execQuerier) *Queries { return &Queries{ - Hosts: newHostsRepo(db), - Packs: newPacksRepo(db), - Providers: newProvidersRepo(db), - ProviderDrafts: newProviderDraftsRepo(db), - ImportBatches: newImportBatchesRepo(db), - ImportBatchItems: newImportBatchItemsRepo(db), - ImportRuns: newImportRunsRepo(db), - ImportRunItems: newImportRunItemsRepo(db), - ImportRunEvents: newImportRunItemEventsRepo(db), - ManagedResources: newManagedResourcesRepo(db), - ProbeResults: newProbeResultsRepo(db), - AccessClosures: newAccessClosureRecordsRepo(db), - ReconcileRuns: newReconcileRunsRepo(db), + Hosts: newHostsRepo(db), + Packs: newPacksRepo(db), + Providers: newProvidersRepo(db), + LogicalGroups: newLogicalGroupsRepo(db), + LogicalGroupModels: newLogicalGroupModelsRepo(db), + LogicalGroupRoutes: newLogicalGroupRoutesRepo(db), + LogicalGroupRouteModels: newLogicalGroupRouteModelsRepo(db), + ProviderDrafts: newProviderDraftsRepo(db), + ImportBatches: newImportBatchesRepo(db), + ImportBatchItems: newImportBatchItemsRepo(db), + ImportRuns: newImportRunsRepo(db), + ImportRunItems: newImportRunItemsRepo(db), + ImportRunEvents: newImportRunItemEventsRepo(db), + ManagedResources: newManagedResourcesRepo(db), + ProbeResults: newProbeResultsRepo(db), + AccessClosures: newAccessClosureRecordsRepo(db), + ReconcileRuns: newReconcileRunsRepo(db), } } diff --git a/internal/store/sqlite/logical_group_models_repo.go b/internal/store/sqlite/logical_group_models_repo.go new file mode 100644 index 00000000..947cff61 --- /dev/null +++ b/internal/store/sqlite/logical_group_models_repo.go @@ -0,0 +1,136 @@ +package sqlite + +import ( + "context" + "fmt" + "strings" +) + +type LogicalGroupModel struct { + ID int64 + LogicalGroupID string + PublicModel string + Status string + CreatedAt string + UpdatedAt string +} + +type LogicalGroupModelsRepo struct { + db execQuerier +} + +func newLogicalGroupModelsRepo(db execQuerier) *LogicalGroupModelsRepo { + return &LogicalGroupModelsRepo{db: db} +} + +func (r *LogicalGroupModelsRepo) Create(ctx context.Context, model LogicalGroupModel) (int64, error) { + model, err := normalizeLogicalGroupModel(model) + if err != nil { + return 0, err + } + + result, err := r.db.ExecContext( + ctx, + `INSERT INTO logical_group_models (logical_group_id, public_model, status) + VALUES (?, ?, ?)`, + model.LogicalGroupID, + model.PublicModel, + model.Status, + ) + if err != nil { + return 0, fmt.Errorf("insert logical group model %q/%q: %w", model.LogicalGroupID, model.PublicModel, err) + } + + id, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("read inserted logical group model id for %q/%q: %w", model.LogicalGroupID, model.PublicModel, err) + } + return id, nil +} + +func (r *LogicalGroupModelsRepo) ListByLogicalGroupID(ctx context.Context, logicalGroupID string) ([]LogicalGroupModel, error) { + logicalGroupID = strings.TrimSpace(logicalGroupID) + if logicalGroupID == "" { + return nil, fmt.Errorf("logical_group_id is required") + } + + rows, err := r.db.QueryContext( + ctx, + `SELECT id, logical_group_id, public_model, status, created_at, updated_at + FROM logical_group_models + WHERE logical_group_id = ? + ORDER BY id ASC`, + logicalGroupID, + ) + if err != nil { + return nil, fmt.Errorf("list logical group models for %q: %w", logicalGroupID, err) + } + defer rows.Close() + + models := make([]LogicalGroupModel, 0) + for rows.Next() { + var model LogicalGroupModel + if err := rows.Scan( + &model.ID, + &model.LogicalGroupID, + &model.PublicModel, + &model.Status, + &model.CreatedAt, + &model.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan logical group model: %w", err) + } + models = append(models, model) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate logical group models for %q: %w", logicalGroupID, err) + } + return models, nil +} + +func (r *LogicalGroupModelsRepo) DeleteByLogicalGroupIDAndModel(ctx context.Context, logicalGroupID string, publicModel string) error { + logicalGroupID = strings.TrimSpace(logicalGroupID) + publicModel = strings.TrimSpace(publicModel) + if logicalGroupID == "" { + return fmt.Errorf("logical_group_id is required") + } + if publicModel == "" { + return fmt.Errorf("public_model is required") + } + + result, err := r.db.ExecContext( + ctx, + `DELETE FROM logical_group_models + WHERE logical_group_id = ? AND public_model = ?`, + logicalGroupID, + publicModel, + ) + if err != nil { + return fmt.Errorf("delete logical group model %q/%q: %w", logicalGroupID, publicModel, err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("read deleted logical group model rows for %q/%q: %w", logicalGroupID, publicModel, err) + } + if affected == 0 { + return fmt.Errorf("logical group model %q/%q not found", logicalGroupID, publicModel) + } + return nil +} + +func normalizeLogicalGroupModel(model LogicalGroupModel) (LogicalGroupModel, error) { + model.LogicalGroupID = strings.TrimSpace(model.LogicalGroupID) + model.PublicModel = strings.TrimSpace(model.PublicModel) + model.Status = strings.TrimSpace(model.Status) + + switch { + case model.LogicalGroupID == "": + return LogicalGroupModel{}, fmt.Errorf("logical_group_id is required") + case model.PublicModel == "": + return LogicalGroupModel{}, fmt.Errorf("public_model is required") + } + if model.Status == "" { + model.Status = "active" + } + return model, nil +} diff --git a/internal/store/sqlite/logical_group_route_models_repo.go b/internal/store/sqlite/logical_group_route_models_repo.go new file mode 100644 index 00000000..adaf640d --- /dev/null +++ b/internal/store/sqlite/logical_group_route_models_repo.go @@ -0,0 +1,113 @@ +package sqlite + +import ( + "context" + "fmt" + "strings" +) + +type LogicalGroupRouteModel struct { + ID int64 + RouteID string + PublicModel string + ShadowModel string + Status string + CreatedAt string + UpdatedAt string +} + +type LogicalGroupRouteModelsRepo struct { + db execQuerier +} + +func newLogicalGroupRouteModelsRepo(db execQuerier) *LogicalGroupRouteModelsRepo { + return &LogicalGroupRouteModelsRepo{db: db} +} + +func (r *LogicalGroupRouteModelsRepo) Create(ctx context.Context, model LogicalGroupRouteModel) (int64, error) { + model, err := normalizeLogicalGroupRouteModel(model) + if err != nil { + return 0, err + } + + result, err := r.db.ExecContext( + ctx, + `INSERT INTO logical_group_route_models (route_id, public_model, shadow_model, status) + VALUES (?, ?, ?, ?)`, + model.RouteID, + model.PublicModel, + model.ShadowModel, + model.Status, + ) + if err != nil { + return 0, fmt.Errorf("insert logical group route model %q/%q: %w", model.RouteID, model.PublicModel, err) + } + + id, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("read inserted logical group route model id for %q/%q: %w", model.RouteID, model.PublicModel, err) + } + return id, nil +} + +func (r *LogicalGroupRouteModelsRepo) ListByRouteID(ctx context.Context, routeID string) ([]LogicalGroupRouteModel, error) { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return nil, fmt.Errorf("route_id is required") + } + + rows, err := r.db.QueryContext( + ctx, + `SELECT id, route_id, public_model, shadow_model, status, created_at, updated_at + FROM logical_group_route_models + WHERE route_id = ? + ORDER BY id ASC`, + routeID, + ) + if err != nil { + return nil, fmt.Errorf("list logical group route models for %q: %w", routeID, err) + } + defer rows.Close() + + models := make([]LogicalGroupRouteModel, 0) + for rows.Next() { + var model LogicalGroupRouteModel + if err := rows.Scan( + &model.ID, + &model.RouteID, + &model.PublicModel, + &model.ShadowModel, + &model.Status, + &model.CreatedAt, + &model.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan logical group route model: %w", err) + } + models = append(models, model) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate logical group route models for %q: %w", routeID, err) + } + return models, nil +} + +func normalizeLogicalGroupRouteModel(model LogicalGroupRouteModel) (LogicalGroupRouteModel, error) { + model.RouteID = strings.TrimSpace(model.RouteID) + model.PublicModel = strings.TrimSpace(model.PublicModel) + model.ShadowModel = strings.TrimSpace(model.ShadowModel) + model.Status = strings.TrimSpace(model.Status) + + switch { + case model.RouteID == "": + return LogicalGroupRouteModel{}, fmt.Errorf("route_id is required") + case model.PublicModel == "": + return LogicalGroupRouteModel{}, fmt.Errorf("public_model is required") + } + if model.ShadowModel == "" { + model.ShadowModel = model.PublicModel + } + if model.Status == "" { + model.Status = "active" + } + return model, nil +} diff --git a/internal/store/sqlite/logical_group_routes_repo.go b/internal/store/sqlite/logical_group_routes_repo.go new file mode 100644 index 00000000..c95b7f90 --- /dev/null +++ b/internal/store/sqlite/logical_group_routes_repo.go @@ -0,0 +1,238 @@ +package sqlite + +import ( + "context" + "fmt" + "strings" +) + +type LogicalGroupRoute struct { + ID int64 + RouteID string + LogicalGroupID string + Name string + Status string + Priority int + Weight int + ShadowGroupID string + ShadowHostID string + UpstreamBaseURLHint string + CooldownUntil string + CreatedAt string + UpdatedAt string +} + +type LogicalGroupRoutesRepo struct { + db execQuerier +} + +func newLogicalGroupRoutesRepo(db execQuerier) *LogicalGroupRoutesRepo { + return &LogicalGroupRoutesRepo{db: db} +} + +func (r *LogicalGroupRoutesRepo) Create(ctx context.Context, route LogicalGroupRoute) (int64, error) { + route, err := normalizeLogicalGroupRoute(route) + if err != nil { + return 0, err + } + + result, err := r.db.ExecContext( + ctx, + `INSERT INTO logical_group_routes ( + route_id, + logical_group_id, + name, + status, + priority, + weight, + shadow_group_id, + shadow_host_id, + upstream_base_url_hint, + cooldown_until + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + route.RouteID, + route.LogicalGroupID, + route.Name, + route.Status, + route.Priority, + route.Weight, + route.ShadowGroupID, + route.ShadowHostID, + route.UpstreamBaseURLHint, + route.CooldownUntil, + ) + if err != nil { + return 0, fmt.Errorf("insert logical group route %q: %w", route.RouteID, err) + } + + id, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("read inserted logical group route id for %q: %w", route.RouteID, err) + } + return id, nil +} + +func (r *LogicalGroupRoutesRepo) GetByRouteID(ctx context.Context, routeID string) (LogicalGroupRoute, error) { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return LogicalGroupRoute{}, fmt.Errorf("route_id is required") + } + + var route LogicalGroupRoute + if err := r.db.QueryRowContext( + ctx, + `SELECT id, route_id, logical_group_id, name, status, priority, weight, shadow_group_id, shadow_host_id, upstream_base_url_hint, cooldown_until, created_at, updated_at + FROM logical_group_routes + WHERE route_id = ?`, + routeID, + ).Scan( + &route.ID, + &route.RouteID, + &route.LogicalGroupID, + &route.Name, + &route.Status, + &route.Priority, + &route.Weight, + &route.ShadowGroupID, + &route.ShadowHostID, + &route.UpstreamBaseURLHint, + &route.CooldownUntil, + &route.CreatedAt, + &route.UpdatedAt, + ); err != nil { + return LogicalGroupRoute{}, err + } + return route, nil +} + +func (r *LogicalGroupRoutesRepo) ListByLogicalGroupID(ctx context.Context, logicalGroupID string) ([]LogicalGroupRoute, error) { + logicalGroupID = strings.TrimSpace(logicalGroupID) + if logicalGroupID == "" { + return nil, fmt.Errorf("logical_group_id is required") + } + + rows, err := r.db.QueryContext( + ctx, + `SELECT id, route_id, logical_group_id, name, status, priority, weight, shadow_group_id, shadow_host_id, upstream_base_url_hint, cooldown_until, created_at, updated_at + FROM logical_group_routes + WHERE logical_group_id = ? + ORDER BY priority ASC, id ASC`, + logicalGroupID, + ) + if err != nil { + return nil, fmt.Errorf("list logical group routes for %q: %w", logicalGroupID, err) + } + defer rows.Close() + + routes := make([]LogicalGroupRoute, 0) + for rows.Next() { + var route LogicalGroupRoute + if err := rows.Scan( + &route.ID, + &route.RouteID, + &route.LogicalGroupID, + &route.Name, + &route.Status, + &route.Priority, + &route.Weight, + &route.ShadowGroupID, + &route.ShadowHostID, + &route.UpstreamBaseURLHint, + &route.CooldownUntil, + &route.CreatedAt, + &route.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan logical group route: %w", err) + } + routes = append(routes, route) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate logical group routes for %q: %w", logicalGroupID, err) + } + return routes, nil +} + +func (r *LogicalGroupRoutesRepo) UpdateByRouteID(ctx context.Context, route LogicalGroupRoute) error { + route, err := normalizeLogicalGroupRoute(route) + if err != nil { + return err + } + + result, err := r.db.ExecContext( + ctx, + `UPDATE logical_group_routes + SET logical_group_id = ?, name = ?, status = ?, priority = ?, weight = ?, shadow_group_id = ?, shadow_host_id = ?, upstream_base_url_hint = ?, cooldown_until = ?, updated_at = CURRENT_TIMESTAMP + WHERE route_id = ?`, + route.LogicalGroupID, + route.Name, + route.Status, + route.Priority, + route.Weight, + route.ShadowGroupID, + route.ShadowHostID, + route.UpstreamBaseURLHint, + route.CooldownUntil, + route.RouteID, + ) + if err != nil { + return fmt.Errorf("update logical group route %q: %w", route.RouteID, err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("read updated logical group route rows for %q: %w", route.RouteID, err) + } + if affected == 0 { + return fmt.Errorf("logical group route %q not found", route.RouteID) + } + return nil +} + +func (r *LogicalGroupRoutesRepo) DeleteByRouteID(ctx context.Context, routeID string) error { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return fmt.Errorf("route_id is required") + } + + result, err := r.db.ExecContext(ctx, `DELETE FROM logical_group_routes WHERE route_id = ?`, routeID) + if err != nil { + return fmt.Errorf("delete logical group route %q: %w", routeID, err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("read deleted logical group route rows for %q: %w", routeID, err) + } + if affected == 0 { + return fmt.Errorf("logical group route %q not found", routeID) + } + return nil +} + +func normalizeLogicalGroupRoute(route LogicalGroupRoute) (LogicalGroupRoute, error) { + route.RouteID = strings.TrimSpace(route.RouteID) + route.LogicalGroupID = strings.TrimSpace(route.LogicalGroupID) + route.Name = strings.TrimSpace(route.Name) + route.Status = strings.TrimSpace(route.Status) + route.ShadowGroupID = strings.TrimSpace(route.ShadowGroupID) + route.ShadowHostID = strings.TrimSpace(route.ShadowHostID) + route.UpstreamBaseURLHint = strings.TrimSpace(route.UpstreamBaseURLHint) + route.CooldownUntil = strings.TrimSpace(route.CooldownUntil) + + switch { + case route.RouteID == "": + return LogicalGroupRoute{}, fmt.Errorf("route_id is required") + case route.LogicalGroupID == "": + return LogicalGroupRoute{}, fmt.Errorf("logical_group_id is required") + case route.Name == "": + return LogicalGroupRoute{}, fmt.Errorf("name is required") + case route.Status == "": + return LogicalGroupRoute{}, fmt.Errorf("status is required") + case route.ShadowGroupID == "": + return LogicalGroupRoute{}, fmt.Errorf("shadow_group_id is required") + case route.ShadowHostID == "": + return LogicalGroupRoute{}, fmt.Errorf("shadow_host_id is required") + } + if route.Weight <= 0 { + route.Weight = 100 + } + return route, nil +} diff --git a/internal/store/sqlite/logical_groups_repo.go b/internal/store/sqlite/logical_groups_repo.go new file mode 100644 index 00000000..15c85cb8 --- /dev/null +++ b/internal/store/sqlite/logical_groups_repo.go @@ -0,0 +1,249 @@ +package sqlite + +import ( + "context" + "fmt" + "strings" +) + +const ( + defaultLogicalGroupRoutePolicy = "priority" + defaultLogicalGroupStickyMode = "conversation_preferred" + defaultConversationTTLSeconds = 7200 + defaultUserModelTTLSeconds = 1800 + defaultFailoverThreshold = 2 + defaultCooldownSeconds = 600 +) + +type LogicalGroup struct { + ID int64 + LogicalGroupID string + DisplayName string + Status string + Description string + RoutePolicy string + StickyMode string + ConversationTTLSeconds int + UserModelTTLSeconds int + FailoverThreshold int + CooldownSeconds int + CreatedAt string + UpdatedAt string +} + +type LogicalGroupsRepo struct { + db execQuerier +} + +func newLogicalGroupsRepo(db execQuerier) *LogicalGroupsRepo { + return &LogicalGroupsRepo{db: db} +} + +func (r *LogicalGroupsRepo) Create(ctx context.Context, group LogicalGroup) (int64, error) { + group, err := normalizeLogicalGroup(group) + if err != nil { + return 0, err + } + + result, err := r.db.ExecContext( + ctx, + `INSERT INTO logical_groups ( + logical_group_id, + display_name, + status, + description, + route_policy, + sticky_mode, + conversation_ttl_seconds, + user_model_ttl_seconds, + failover_threshold, + cooldown_seconds + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + group.LogicalGroupID, + group.DisplayName, + group.Status, + group.Description, + group.RoutePolicy, + group.StickyMode, + group.ConversationTTLSeconds, + group.UserModelTTLSeconds, + group.FailoverThreshold, + group.CooldownSeconds, + ) + if err != nil { + return 0, fmt.Errorf("insert logical group %q: %w", group.LogicalGroupID, err) + } + + id, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("read inserted logical group id for %q: %w", group.LogicalGroupID, err) + } + return id, nil +} + +func (r *LogicalGroupsRepo) GetByLogicalGroupID(ctx context.Context, logicalGroupID string) (LogicalGroup, error) { + logicalGroupID = strings.TrimSpace(logicalGroupID) + if logicalGroupID == "" { + return LogicalGroup{}, fmt.Errorf("logical_group_id is required") + } + + var group LogicalGroup + if err := r.db.QueryRowContext( + ctx, + `SELECT id, logical_group_id, display_name, status, description, route_policy, sticky_mode, conversation_ttl_seconds, user_model_ttl_seconds, failover_threshold, cooldown_seconds, created_at, updated_at + FROM logical_groups + WHERE logical_group_id = ?`, + logicalGroupID, + ).Scan( + &group.ID, + &group.LogicalGroupID, + &group.DisplayName, + &group.Status, + &group.Description, + &group.RoutePolicy, + &group.StickyMode, + &group.ConversationTTLSeconds, + &group.UserModelTTLSeconds, + &group.FailoverThreshold, + &group.CooldownSeconds, + &group.CreatedAt, + &group.UpdatedAt, + ); err != nil { + return LogicalGroup{}, err + } + return group, nil +} + +func (r *LogicalGroupsRepo) List(ctx context.Context) ([]LogicalGroup, error) { + rows, err := r.db.QueryContext( + ctx, + `SELECT id, logical_group_id, display_name, status, description, route_policy, sticky_mode, conversation_ttl_seconds, user_model_ttl_seconds, failover_threshold, cooldown_seconds, created_at, updated_at + FROM logical_groups + ORDER BY id ASC`, + ) + if err != nil { + return nil, fmt.Errorf("list logical groups: %w", err) + } + defer rows.Close() + + groups := make([]LogicalGroup, 0) + for rows.Next() { + var group LogicalGroup + if err := rows.Scan( + &group.ID, + &group.LogicalGroupID, + &group.DisplayName, + &group.Status, + &group.Description, + &group.RoutePolicy, + &group.StickyMode, + &group.ConversationTTLSeconds, + &group.UserModelTTLSeconds, + &group.FailoverThreshold, + &group.CooldownSeconds, + &group.CreatedAt, + &group.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan logical group: %w", err) + } + groups = append(groups, group) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate logical groups: %w", err) + } + return groups, nil +} + +func (r *LogicalGroupsRepo) UpdateByLogicalGroupID(ctx context.Context, group LogicalGroup) error { + group, err := normalizeLogicalGroup(group) + if err != nil { + return err + } + + result, err := r.db.ExecContext( + ctx, + `UPDATE logical_groups + SET display_name = ?, status = ?, description = ?, route_policy = ?, sticky_mode = ?, conversation_ttl_seconds = ?, user_model_ttl_seconds = ?, failover_threshold = ?, cooldown_seconds = ?, updated_at = CURRENT_TIMESTAMP + WHERE logical_group_id = ?`, + group.DisplayName, + group.Status, + group.Description, + group.RoutePolicy, + group.StickyMode, + group.ConversationTTLSeconds, + group.UserModelTTLSeconds, + group.FailoverThreshold, + group.CooldownSeconds, + group.LogicalGroupID, + ) + if err != nil { + return fmt.Errorf("update logical group %q: %w", group.LogicalGroupID, err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("read updated logical group rows for %q: %w", group.LogicalGroupID, err) + } + if affected == 0 { + return fmt.Errorf("logical group %q not found", group.LogicalGroupID) + } + return nil +} + +func (r *LogicalGroupsRepo) DeleteByLogicalGroupID(ctx context.Context, logicalGroupID string) error { + logicalGroupID = strings.TrimSpace(logicalGroupID) + if logicalGroupID == "" { + return fmt.Errorf("logical_group_id is required") + } + + result, err := r.db.ExecContext(ctx, `DELETE FROM logical_groups WHERE logical_group_id = ?`, logicalGroupID) + if err != nil { + return fmt.Errorf("delete logical group %q: %w", logicalGroupID, err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("read deleted logical group rows for %q: %w", logicalGroupID, err) + } + if affected == 0 { + return fmt.Errorf("logical group %q not found", logicalGroupID) + } + return nil +} + +func normalizeLogicalGroup(group LogicalGroup) (LogicalGroup, error) { + group.LogicalGroupID = strings.TrimSpace(group.LogicalGroupID) + group.DisplayName = strings.TrimSpace(group.DisplayName) + group.Status = strings.TrimSpace(group.Status) + group.Description = strings.TrimSpace(group.Description) + group.RoutePolicy = strings.TrimSpace(group.RoutePolicy) + group.StickyMode = strings.TrimSpace(group.StickyMode) + + switch { + case group.LogicalGroupID == "": + return LogicalGroup{}, fmt.Errorf("logical_group_id is required") + case group.DisplayName == "": + return LogicalGroup{}, fmt.Errorf("display_name is required") + case group.Status == "": + return LogicalGroup{}, fmt.Errorf("status is required") + } + + if group.RoutePolicy == "" { + group.RoutePolicy = defaultLogicalGroupRoutePolicy + } + if group.StickyMode == "" { + group.StickyMode = defaultLogicalGroupStickyMode + } + if group.ConversationTTLSeconds <= 0 { + group.ConversationTTLSeconds = defaultConversationTTLSeconds + } + if group.UserModelTTLSeconds <= 0 { + group.UserModelTTLSeconds = defaultUserModelTTLSeconds + } + if group.FailoverThreshold <= 0 { + group.FailoverThreshold = defaultFailoverThreshold + } + if group.CooldownSeconds <= 0 { + group.CooldownSeconds = defaultCooldownSeconds + } + + return group, nil +} diff --git a/internal/store/sqlite/logical_groups_repo_test.go b/internal/store/sqlite/logical_groups_repo_test.go new file mode 100644 index 00000000..be4cd5bb --- /dev/null +++ b/internal/store/sqlite/logical_groups_repo_test.go @@ -0,0 +1,308 @@ +package sqlite + +import ( + "context" + "database/sql" + "errors" + "testing" +) + +func TestLogicalGroupsRepoCreateGetUpdateDelete(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + id, err := store.LogicalGroups().Create(ctx, LogicalGroup{ + LogicalGroupID: "gpt-shared", + DisplayName: "GPT Shared", + Status: "active", + Description: "shared group", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + if id <= 0 { + t.Fatalf("Create() id = %d, want positive", id) + } + + group, err := store.LogicalGroups().GetByLogicalGroupID(ctx, "gpt-shared") + if err != nil { + t.Fatalf("GetByLogicalGroupID() error = %v", err) + } + if group.RoutePolicy != defaultLogicalGroupRoutePolicy { + t.Fatalf("RoutePolicy = %q, want %q", group.RoutePolicy, defaultLogicalGroupRoutePolicy) + } + if group.StickyMode != defaultLogicalGroupStickyMode { + t.Fatalf("StickyMode = %q, want %q", group.StickyMode, defaultLogicalGroupStickyMode) + } + + if err := store.LogicalGroups().UpdateByLogicalGroupID(ctx, LogicalGroup{ + LogicalGroupID: "gpt-shared", + DisplayName: "GPT Shared Updated", + Status: "paused", + Description: "updated", + RoutePolicy: "priority", + StickyMode: "user_preferred", + ConversationTTLSeconds: 3600, + UserModelTTLSeconds: 900, + FailoverThreshold: 3, + CooldownSeconds: 120, + }); err != nil { + t.Fatalf("UpdateByLogicalGroupID() error = %v", err) + } + + updated, err := store.LogicalGroups().GetByLogicalGroupID(ctx, "gpt-shared") + if err != nil { + t.Fatalf("GetByLogicalGroupID(updated) error = %v", err) + } + if updated.DisplayName != "GPT Shared Updated" || updated.Status != "paused" { + t.Fatalf("updated group = %+v, want updated fields", updated) + } + + if err := store.LogicalGroups().DeleteByLogicalGroupID(ctx, "gpt-shared"); err != nil { + t.Fatalf("DeleteByLogicalGroupID() error = %v", err) + } + _, err = store.LogicalGroups().GetByLogicalGroupID(ctx, "gpt-shared") + if !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("GetByLogicalGroupID() after delete error = %v, want sql.ErrNoRows", err) + } +} + +func TestLogicalGroupsRepoList(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + for _, group := range []LogicalGroup{ + {LogicalGroupID: "group-a", DisplayName: "Group A", Status: "active"}, + {LogicalGroupID: "group-b", DisplayName: "Group B", Status: "active"}, + } { + if _, err := store.LogicalGroups().Create(ctx, group); err != nil { + t.Fatalf("Create(%q) error = %v", group.LogicalGroupID, err) + } + } + + groups, err := store.LogicalGroups().List(ctx) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(groups) != 2 { + t.Fatalf("List() len = %d, want 2", len(groups)) + } + if groups[0].LogicalGroupID != "group-a" || groups[1].LogicalGroupID != "group-b" { + t.Fatalf("List() = %+v, want insertion order", groups) + } +} + +func TestLogicalGroupModelsRepoCreateListDelete(t *testing.T) { + store := openTestDBWithFK(t) + ctx := context.Background() + + if _, err := store.LogicalGroups().Create(ctx, LogicalGroup{ + LogicalGroupID: "gpt-shared", + DisplayName: "GPT Shared", + Status: "active", + }); err != nil { + t.Fatalf("LogicalGroups().Create() error = %v", err) + } + + if _, err := store.LogicalGroupModels().Create(ctx, LogicalGroupModel{ + LogicalGroupID: "gpt-shared", + PublicModel: "gpt-5.4", + }); err != nil { + t.Fatalf("LogicalGroupModels().Create() error = %v", err) + } + + models, err := store.LogicalGroupModels().ListByLogicalGroupID(ctx, "gpt-shared") + if err != nil { + t.Fatalf("ListByLogicalGroupID() error = %v", err) + } + if len(models) != 1 || models[0].PublicModel != "gpt-5.4" { + t.Fatalf("ListByLogicalGroupID() = %+v, want gpt-5.4", models) + } + + if err := store.LogicalGroupModels().DeleteByLogicalGroupIDAndModel(ctx, "gpt-shared", "gpt-5.4"); err != nil { + t.Fatalf("DeleteByLogicalGroupIDAndModel() error = %v", err) + } + models, err = store.LogicalGroupModels().ListByLogicalGroupID(ctx, "gpt-shared") + if err != nil { + t.Fatalf("ListByLogicalGroupID() after delete error = %v", err) + } + if len(models) != 0 { + t.Fatalf("ListByLogicalGroupID() after delete len = %d, want 0", len(models)) + } +} + +func TestLogicalGroupRoutesRepoCreateGetListUpdateDelete(t *testing.T) { + store := openTestDBWithFK(t) + ctx := context.Background() + + if _, err := store.LogicalGroups().Create(ctx, LogicalGroup{ + LogicalGroupID: "gpt-shared", + DisplayName: "GPT Shared", + Status: "active", + }); err != nil { + t.Fatalf("LogicalGroups().Create() error = %v", err) + } + + if _, err := store.LogicalGroupRoutes().Create(ctx, LogicalGroupRoute{ + RouteID: "asxs", + LogicalGroupID: "gpt-shared", + Name: "ASXS", + Status: "active", + Priority: 10, + ShadowGroupID: "gpt-shared__asxs", + ShadowHostID: "remote43", + }); err != nil { + t.Fatalf("LogicalGroupRoutes().Create() error = %v", err) + } + + route, err := store.LogicalGroupRoutes().GetByRouteID(ctx, "asxs") + if err != nil { + t.Fatalf("GetByRouteID() error = %v", err) + } + if route.Weight != 100 { + t.Fatalf("Weight = %d, want 100", route.Weight) + } + + routes, err := store.LogicalGroupRoutes().ListByLogicalGroupID(ctx, "gpt-shared") + if err != nil { + t.Fatalf("ListByLogicalGroupID() error = %v", err) + } + if len(routes) != 1 || routes[0].RouteID != "asxs" { + t.Fatalf("ListByLogicalGroupID() = %+v, want route asxs", routes) + } + + if err := store.LogicalGroupRoutes().UpdateByRouteID(ctx, LogicalGroupRoute{ + RouteID: "asxs", + LogicalGroupID: "gpt-shared", + Name: "ASXS Updated", + Status: "degraded", + Priority: 20, + Weight: 80, + ShadowGroupID: "gpt-shared__asxs", + ShadowHostID: "remote43", + UpstreamBaseURLHint: "https://api.asxs.top/v1", + CooldownUntil: "2026-05-28T16:00:00Z", + }); err != nil { + t.Fatalf("UpdateByRouteID() error = %v", err) + } + + updated, err := store.LogicalGroupRoutes().GetByRouteID(ctx, "asxs") + if err != nil { + t.Fatalf("GetByRouteID(updated) error = %v", err) + } + if updated.Name != "ASXS Updated" || updated.Status != "degraded" || updated.Weight != 80 { + t.Fatalf("updated route = %+v, want updated fields", updated) + } + + if err := store.LogicalGroupRoutes().DeleteByRouteID(ctx, "asxs"); err != nil { + t.Fatalf("DeleteByRouteID() error = %v", err) + } + _, err = store.LogicalGroupRoutes().GetByRouteID(ctx, "asxs") + if !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("GetByRouteID() after delete error = %v, want sql.ErrNoRows", err) + } +} + +func TestLogicalGroupRouteModelsRepoCreateList(t *testing.T) { + store := openTestDBWithFK(t) + ctx := context.Background() + + if _, err := store.LogicalGroups().Create(ctx, LogicalGroup{ + LogicalGroupID: "gpt-shared", + DisplayName: "GPT Shared", + Status: "active", + }); err != nil { + t.Fatalf("LogicalGroups().Create() error = %v", err) + } + if _, err := store.LogicalGroupRoutes().Create(ctx, LogicalGroupRoute{ + RouteID: "codex2api", + LogicalGroupID: "gpt-shared", + Name: "Codex2API", + Status: "active", + Priority: 20, + ShadowGroupID: "gpt-shared__codex2api", + ShadowHostID: "remote43", + }); err != nil { + t.Fatalf("LogicalGroupRoutes().Create() error = %v", err) + } + + if _, err := store.LogicalGroupRouteModels().Create(ctx, LogicalGroupRouteModel{ + RouteID: "codex2api", + PublicModel: "gpt-5.4", + }); err != nil { + t.Fatalf("LogicalGroupRouteModels().Create() error = %v", err) + } + + models, err := store.LogicalGroupRouteModels().ListByRouteID(ctx, "codex2api") + if err != nil { + t.Fatalf("ListByRouteID() error = %v", err) + } + if len(models) != 1 { + t.Fatalf("ListByRouteID() len = %d, want 1", len(models)) + } + if models[0].ShadowModel != "gpt-5.4" { + t.Fatalf("ShadowModel = %q, want default public model", models[0].ShadowModel) + } +} + +func TestLogicalGroupReposEnforceForeignKeysAndCascadeDelete(t *testing.T) { + store := openTestDBWithFK(t) + ctx := context.Background() + + if _, err := store.LogicalGroupModels().Create(ctx, LogicalGroupModel{ + LogicalGroupID: "missing-group", + PublicModel: "gpt-5.4", + }); err == nil { + t.Fatal("LogicalGroupModels().Create() error = nil, want foreign key failure") + } + + if _, err := store.LogicalGroups().Create(ctx, LogicalGroup{ + LogicalGroupID: "gpt-shared", + DisplayName: "GPT Shared", + Status: "active", + }); err != nil { + t.Fatalf("LogicalGroups().Create() error = %v", err) + } + if _, err := store.LogicalGroupModels().Create(ctx, LogicalGroupModel{ + LogicalGroupID: "gpt-shared", + PublicModel: "gpt-5.4", + }); err != nil { + t.Fatalf("LogicalGroupModels().Create() error = %v", err) + } + if _, err := store.LogicalGroupRoutes().Create(ctx, LogicalGroupRoute{ + RouteID: "asxs", + LogicalGroupID: "gpt-shared", + Name: "ASXS", + Status: "active", + Priority: 10, + ShadowGroupID: "gpt-shared__asxs", + ShadowHostID: "remote43", + }); err != nil { + t.Fatalf("LogicalGroupRoutes().Create() error = %v", err) + } + if _, err := store.LogicalGroupRouteModels().Create(ctx, LogicalGroupRouteModel{ + RouteID: "asxs", + PublicModel: "gpt-5.4", + }); err != nil { + t.Fatalf("LogicalGroupRouteModels().Create() error = %v", err) + } + + if err := store.LogicalGroups().DeleteByLogicalGroupID(ctx, "gpt-shared"); err != nil { + t.Fatalf("DeleteByLogicalGroupID() error = %v", err) + } + + models, err := store.LogicalGroupModels().ListByLogicalGroupID(ctx, "gpt-shared") + if err != nil { + t.Fatalf("ListByLogicalGroupID() after cascade error = %v", err) + } + if len(models) != 0 { + t.Fatalf("models after cascade len = %d, want 0", len(models)) + } + routes, err := store.LogicalGroupRoutes().ListByLogicalGroupID(ctx, "gpt-shared") + if err != nil { + t.Fatalf("ListByLogicalGroupID(routes) after cascade error = %v", err) + } + if len(routes) != 0 { + t.Fatalf("routes after cascade len = %d, want 0", len(routes)) + } +}