feat(routing): auto-supply managed proxy keys

This commit is contained in:
phamnazage-jpg
2026-05-29 10:49:27 +08:00
parent b4d1b8c377
commit cffe3332ac
2 changed files with 466 additions and 27 deletions

View File

@@ -3,14 +3,21 @@ package app
import (
"bytes"
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"sub2api-cn-relay-manager/internal/access"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/routing"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
@@ -26,6 +33,7 @@ type ProxyRouteChatCompletionsRequest struct {
UserKey string `json:"user_key,omitempty"`
ConversationKey string `json:"conversation_key,omitempty"`
GatewayAPIKey string `json:"gateway_api_key"`
SubscriptionUserID string `json:"subscription_user_id,omitempty"`
Messages []ChatCompletionMessage `json:"messages,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
@@ -48,6 +56,9 @@ type RouteChatCompletionsForwardInfo struct {
HostBaseURL string `json:"host_base_url"`
ShadowGroupID string `json:"shadow_group_id"`
ShadowModel string `json:"shadow_model"`
EffectiveGatewayKeySource string `json:"effective_gateway_key_source,omitempty"`
EffectiveGatewayKeyFingerprint string `json:"effective_gateway_key_fingerprint,omitempty"`
ManagedUserID string `json:"managed_user_id,omitempty"`
UpstreamPath string `json:"upstream_path"`
UpstreamStatus int `json:"upstream_status"`
LatencyMS int64 `json:"latency_ms"`
@@ -82,8 +93,9 @@ func buildProxyRouteChatCompletionsAction(
) func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
return func(ctx context.Context, req ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
req.GatewayAPIKey = strings.TrimSpace(req.GatewayAPIKey)
if req.GatewayAPIKey == "" {
return ProxyRouteChatCompletionsResult{}, fmt.Errorf("gateway_api_key is required")
req.SubscriptionUserID = strings.TrimSpace(req.SubscriptionUserID)
if req.GatewayAPIKey == "" && req.SubscriptionUserID == "" {
return ProxyRouteChatCompletionsResult{}, fmt.Errorf("gateway_api_key or subscription_user_id is required")
}
resolveInfo, err := resolveRoute(ctx, ResolveRouteRequest{
@@ -110,17 +122,29 @@ func buildProxyRouteChatCompletionsAction(
if err != nil {
return ProxyRouteChatCompletionsResult{}, fmt.Errorf("get shadow host %q: %w", resolveInfo.ShadowHostID, err)
}
hostClient, err := newSub2APIClient(hostRow.BaseURL, authFromStoredHost(hostRow))
if err != nil {
return ProxyRouteChatCompletionsResult{}, err
}
shadowModel := strings.TrimSpace(resolveInfo.ShadowModel)
if shadowModel == "" {
shadowModel = strings.TrimSpace(resolveInfo.PublicModel)
}
forward := proxyChatCompletionToShadowHost(ctx, hostRow.BaseURL, req.GatewayAPIKey, shadowModel, req.Messages, req.MaxTokens, req.Temperature)
gatewayAPIKey, gatewayKeySource, managedUserID, err := resolveProxyGatewayAPIKey(ctx, store, hostRow, hostClient, resolveInfo, req)
if err != nil {
return ProxyRouteChatCompletionsResult{}, err
}
forward := proxyChatCompletionToShadowHost(ctx, hostRow.BaseURL, gatewayAPIKey, shadowModel, req.Messages, req.MaxTokens, req.Temperature)
forward.HostID = strings.TrimSpace(hostRow.HostID)
forward.HostBaseURL = strings.TrimSpace(hostRow.BaseURL)
forward.ShadowGroupID = strings.TrimSpace(resolveInfo.ShadowGroupID)
forward.ShadowModel = shadowModel
forward.EffectiveGatewayKeySource = gatewayKeySource
forward.EffectiveGatewayKeyFingerprint = fingerprintRouteProxySecret(gatewayAPIKey)
forward.ManagedUserID = managedUserID
if err := appendProxyRouteDecisionLog(ctx, writerSource, req, resolveInfo, forward); err != nil {
return ProxyRouteChatCompletionsResult{}, err
@@ -317,6 +341,78 @@ func classifyProxyUpstreamStatus(statusCode int) string {
}
}
func resolveProxyGatewayAPIKey(
ctx context.Context,
store *sqlite.DB,
hostRow sqlite.Host,
hostClient *sub2api.Client,
resolveInfo ResolveRouteInfo,
req ProxyRouteChatCompletionsRequest,
) (string, string, string, error) {
gatewayAPIKey := strings.TrimSpace(req.GatewayAPIKey)
if gatewayAPIKey != "" {
return gatewayAPIKey, access.ProbeKeySourceRequestedProbeAPIKey, "", nil
}
subscriptionUserID := strings.TrimSpace(req.SubscriptionUserID)
if subscriptionUserID == "" {
return "", "", "", fmt.Errorf("gateway_api_key or subscription_user_id is required")
}
shadowGroupHostResourceID, err := resolveShadowGroupHostResourceID(ctx, store, hostRow, hostClient, strings.TrimSpace(resolveInfo.ShadowGroupID))
if err != nil {
return "", "", "", err
}
accessRef, err := hostClient.EnsureSubscriptionAccess(ctx, sub2api.EnsureSubscriptionAccessRequest{
UserSelector: subscriptionUserID,
GroupID: shadowGroupHostResourceID,
})
if err != nil {
return "", "", "", fmt.Errorf("ensure subscription access for route %q: %w", resolveInfo.RouteID, err)
}
gatewayAPIKey = strings.TrimSpace(accessRef.APIKey)
if gatewayAPIKey == "" {
return "", "", "", fmt.Errorf("managed subscription access api key is empty")
}
return gatewayAPIKey, access.ProbeKeySourceManagedSubscription, strings.TrimSpace(accessRef.UserID), nil
}
func resolveShadowGroupHostResourceID(
ctx context.Context,
store *sqlite.DB,
hostRow sqlite.Host,
hostClient *sub2api.Client,
shadowGroupID string,
) (string, error) {
shadowGroupID = strings.TrimSpace(shadowGroupID)
if shadowGroupID == "" {
return "", fmt.Errorf("shadow_group_id is required")
}
if _, err := strconv.ParseInt(shadowGroupID, 10, 64); err == nil {
return shadowGroupID, nil
}
resource, err := store.ManagedResources().GetByResourceIdentity(ctx, hostRow.ID, "group", shadowGroupID)
if err == nil {
return strings.TrimSpace(resource.HostResourceID), nil
}
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return "", fmt.Errorf("lookup shadow group %q in managed resources: %w", shadowGroupID, err)
}
snapshot, err := hostClient.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{GroupName: shadowGroupID})
if err != nil {
return "", fmt.Errorf("list host groups for %q: %w", shadowGroupID, err)
}
if len(snapshot.Groups) == 1 {
return strings.TrimSpace(snapshot.Groups[0].ID), nil
}
if len(snapshot.Groups) > 1 {
return "", fmt.Errorf("multiple host groups matched shadow_group_id %q", shadowGroupID)
}
return "", fmt.Errorf("shadow group %q not found on host %q", shadowGroupID, hostRow.HostID)
}
func resolveProxyUserKey(req ProxyRouteChatCompletionsRequest) string {
if key := strings.TrimSpace(req.UserKey); key != "" {
return key
@@ -327,6 +423,15 @@ func resolveProxyUserKey(req ProxyRouteChatCompletionsRequest) string {
return ""
}
func fingerprintRouteProxySecret(value string) string {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return ""
}
sum := sha256.Sum256([]byte(trimmed))
return "sha256:" + hex.EncodeToString(sum[:])
}
func resolveProxyConversationKey(req ProxyRouteChatCompletionsRequest) string {
if key := strings.TrimSpace(req.ConversationKey); key != "" {
return key

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"sub2api-cn-relay-manager/internal/store/sqlite"
@@ -224,6 +225,188 @@ func TestNewActionSetProxyRouteChatCompletionsFlow(t *testing.T) {
}
}
func TestNewActionSetProxyRouteChatCompletionsManagedSubscriptionFlow(t *testing.T) {
t.Parallel()
var (
gotAuthHeader string
gotModel string
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
_, _ = w.Write([]byte(`{"data":{"items":[]}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users":
_, _ = w.Write([]byte(`{"data":{"id":84,"email":"relay-sub-managed-user@sub2api.local"}}`))
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84":
_, _ = w.Write([]byte(`{"data":{"id":84}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance":
_, _ = w.Write([]byte(`{"data":{"id":84}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign":
_, _ = w.Write([]byte(`{"data":{"id":401}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login":
_, _ = w.Write([]byte(`{"data":{"access_token":"user-jwt"}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
_, _ = w.Write([]byte(`{"data":{"id":501,"key":"sk-relay-key","name":"managed-key"}}`))
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501":
_, _ = w.Write([]byte(`{"data":{"api_key":{"id":501}}}`))
case r.Method == http.MethodPost && r.URL.Path == "/v1/chat/completions":
gotAuthHeader = r.Header.Get("Authorization")
var payload struct {
Model string `json:"model"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("json.Decode() error = %v", err)
}
gotModel = payload.Model
writeJSON(w, http.StatusOK, map[string]any{
"id": "chatcmpl_proxy_managed",
"choices": []map[string]any{
{
"message": map[string]any{
"role": "assistant",
"content": "pong-managed",
},
},
},
})
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()
dsn := "file:" + filepath.ToSlash(filepath.Join(t.TempDir(), "route-proxy-managed.db")) + "?_busy_timeout=5000"
actions := NewActionSet(dsn)
ctx := context.Background()
store, err := sqlite.Open(ctx, dsn)
if err != nil {
t.Fatalf("sqlite.Open() error = %v", err)
}
defer store.Close()
hostID, err := store.Hosts().Create(ctx, sqlite.Host{
HostID: "remote43-managed",
BaseURL: server.URL,
HostVersion: "0.1.126",
AuthType: "bearer",
AuthToken: "host-admin-token",
})
if err != nil {
t.Fatalf("Hosts().Create() error = %v", err)
}
packID, err := store.Packs().Create(ctx, sqlite.Pack{
PackID: "managed-pack",
Version: "1.0.0",
Checksum: "sha256-managed-pack",
Vendor: "tksea",
ManifestJSON: "{}",
})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
providerID, err := store.Providers().Create(ctx, sqlite.Provider{
PackID: packID,
ProviderID: "managed-provider",
DisplayName: "Managed Provider",
BaseURL: "https://api.asxs.top/v1",
Platform: "openai",
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
batchID, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{
HostID: hostID,
PackID: packID,
ProviderID: providerID,
Mode: "strict",
BatchStatus: "succeeded",
AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("ImportBatches().Create() error = %v", err)
}
if _, err := actions.CreateLogicalGroup(ctx, CreateLogicalGroupRequest{
LogicalGroupID: "gpt-shared-managed",
DisplayName: "GPT Shared Managed",
Status: "active",
RoutePolicy: "priority",
StickyMode: "conversation_preferred",
ConversationTTLSeconds: 1200,
UserModelTTLSeconds: 600,
FailoverThreshold: 2,
CooldownSeconds: 300,
}); err != nil {
t.Fatalf("CreateLogicalGroup() error = %v", err)
}
if _, err := actions.CreateLogicalGroupModel(ctx, CreateLogicalGroupModelRequest{
LogicalGroupID: "gpt-shared-managed",
PublicModel: "gpt-5.4",
Status: "active",
}); err != nil {
t.Fatalf("CreateLogicalGroupModel() error = %v", err)
}
if _, err := actions.CreateLogicalGroupRoute(ctx, CreateLogicalGroupRouteRequest{
LogicalGroupID: "gpt-shared-managed",
RouteID: "asxs-managed",
Name: "ASXS Managed",
Status: "active",
Priority: 10,
ShadowGroupID: "101",
ShadowHostID: "remote43-managed",
UpstreamBaseURLHint: "https://api.asxs.top/v1",
}); err != nil {
t.Fatalf("CreateLogicalGroupRoute() error = %v", err)
}
if _, err := actions.CreateLogicalGroupRouteModel(ctx, CreateLogicalGroupRouteModelRequest{
LogicalGroupID: "gpt-shared-managed",
RouteID: "asxs-managed",
PublicModel: "gpt-5.4",
ShadowModel: "gpt-5.4-asxs",
Status: "active",
}); err != nil {
t.Fatalf("CreateLogicalGroupRouteModel() error = %v", err)
}
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{
BatchID: batchID,
HostID: hostID,
ResourceType: "group",
HostResourceID: "101",
ResourceName: "shadow-group-asxs",
}); err != nil {
t.Fatalf("ManagedResources().Create() error = %v", err)
}
result, err := actions.ProxyRouteChatCompletions(ctx, ProxyRouteChatCompletionsRequest{
RequestID: "req-proxy-managed-1",
LogicalGroupID: "gpt-shared-managed",
PublicModel: "gpt-5.4",
Scope: "conversation",
SubjectID: "conv-managed-1",
SubscriptionUserID: "crm-user-1",
Sync: true,
})
if err != nil {
t.Fatalf("ProxyRouteChatCompletions() error = %v", err)
}
if !strings.HasPrefix(gotAuthHeader, "Bearer sk-relay-") {
t.Fatalf("Authorization header = %q, want Bearer sk-relay-*", gotAuthHeader)
}
if gotModel != "gpt-5.4-asxs" {
t.Fatalf("forwarded model = %q, want gpt-5.4-asxs", gotModel)
}
if result.Forward.EffectiveGatewayKeySource != "managed_subscription" {
t.Fatalf("EffectiveGatewayKeySource = %q, want managed_subscription", result.Forward.EffectiveGatewayKeySource)
}
if result.Forward.EffectiveGatewayKeyFingerprint == "" {
t.Fatal("EffectiveGatewayKeyFingerprint = empty, want hashed managed key fingerprint")
}
if result.Forward.ManagedUserID != "84" {
t.Fatalf("ManagedUserID = %q, want 84", result.Forward.ManagedUserID)
}
}
func TestProxyChatCompletionToShadowHostReportsNon2xx(t *testing.T) {
t.Parallel()
@@ -274,3 +457,154 @@ func TestRouteProxyHelpers(t *testing.T) {
t.Fatalf("resolveProxyConversationKey(conversation) = %q, want conv-1", got)
}
}
func TestResolveShadowGroupHostResourceID(t *testing.T) {
t.Parallel()
dsn := "file:" + filepath.ToSlash(filepath.Join(t.TempDir(), "route-proxy-helper.db")) + "?_busy_timeout=5000"
ctx := context.Background()
store, err := sqlite.Open(ctx, dsn)
if err != nil {
t.Fatalf("sqlite.Open() error = %v", err)
}
defer store.Close()
hostID, err := store.Hosts().Create(ctx, sqlite.Host{
HostID: "helper-host",
BaseURL: "https://helper.example.com",
HostVersion: "0.1.126",
AuthType: "bearer",
AuthToken: "host-token",
})
if err != nil {
t.Fatalf("Hosts().Create() error = %v", err)
}
hostRow, err := store.Hosts().GetByID(ctx, hostID)
if err != nil {
t.Fatalf("Hosts().GetByID() error = %v", err)
}
if got, err := resolveShadowGroupHostResourceID(ctx, store, hostRow, nil, "101"); err != nil || got != "101" {
t.Fatalf("resolveShadowGroupHostResourceID(numeric) = (%q, %v), want 101", got, err)
}
packID, err := store.Packs().Create(ctx, sqlite.Pack{
PackID: "helper-pack",
Version: "1.0.0",
Checksum: "sha256-helper",
Vendor: "tksea",
ManifestJSON: "{}",
})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
providerID, err := store.Providers().Create(ctx, sqlite.Provider{
PackID: packID,
ProviderID: "helper-provider",
DisplayName: "Helper Provider",
BaseURL: "https://helper.example.com/v1",
Platform: "openai",
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
batchID, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{
HostID: hostID,
PackID: packID,
ProviderID: providerID,
Mode: "strict",
BatchStatus: "succeeded",
AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("ImportBatches().Create() error = %v", err)
}
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{
BatchID: batchID,
HostID: hostID,
ResourceType: "group",
HostResourceID: "202",
ResourceName: "shadow-group-name",
}); err != nil {
t.Fatalf("ManagedResources().Create() error = %v", err)
}
if got, err := resolveShadowGroupHostResourceID(ctx, store, hostRow, nil, "202"); err != nil || got != "202" {
t.Fatalf("resolveShadowGroupHostResourceID(store identity) = (%q, %v), want 202", got, err)
}
}
func TestResolveShadowGroupHostResourceIDFallsBackToHostList(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/groups"):
_, _ = w.Write([]byte(`{"data":[{"id":"303","name":"shadow-group-remote"}]}`))
case strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/channels"):
_, _ = w.Write([]byte(`{"data":[]}`))
case strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/payment/plans"):
_, _ = w.Write([]byte(`{"data":[]}`))
case strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/accounts"):
_, _ = w.Write([]byte(`{"data":{"items":[],"pages":1}}`))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()
dsn := "file:" + filepath.ToSlash(filepath.Join(t.TempDir(), "route-proxy-fallback.db")) + "?_busy_timeout=5000"
ctx := context.Background()
store, err := sqlite.Open(ctx, dsn)
if err != nil {
t.Fatalf("sqlite.Open() error = %v", err)
}
defer store.Close()
hostID, err := store.Hosts().Create(ctx, sqlite.Host{
HostID: "fallback-host",
BaseURL: server.URL,
HostVersion: "0.1.126",
AuthType: "bearer",
AuthToken: "host-token",
})
if err != nil {
t.Fatalf("Hosts().Create() error = %v", err)
}
hostRow, err := store.Hosts().GetByID(ctx, hostID)
if err != nil {
t.Fatalf("Hosts().GetByID() error = %v", err)
}
hostClient, err := newSub2APIClient(server.URL, CreateHostAuth{Type: "bearer", Token: "host-token"})
if err != nil {
t.Fatalf("newSub2APIClient() error = %v", err)
}
got, err := resolveShadowGroupHostResourceID(ctx, store, hostRow, hostClient, "shadow-group-remote")
if err != nil {
t.Fatalf("resolveShadowGroupHostResourceID(host fallback) error = %v", err)
}
if got != "303" {
t.Fatalf("resolveShadowGroupHostResourceID(host fallback) = %q, want 303", got)
}
}
func TestAPIProxyRouteChatCompletionsRejectsMissingGatewayAndSubscriptionUser(t *testing.T) {
t.Parallel()
handler := NewAPIHandler("secret-token", ActionSet{
ProxyRouteChatCompletions: buildProxyRouteChatCompletionsAction("file::memory:?cache=shared", func(context.Context, ResolveRouteRequest) (ResolveRouteInfo, error) {
t.Fatal("ResolveRoute should not be called when auth inputs are missing")
return ResolveRouteInfo{}, nil
}, newLazyRouteLogWriter("file::memory:?cache=shared")),
})
request := httptestRequest(t, http.MethodPost, "/api/routing/proxy/chat/completions", map[string]any{
"logical_group_id": "gpt-shared",
"public_model": "gpt-5.4",
"scope": "conversation",
"subject_id": "conv-1",
}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusBadRequest)
assertJSONContains(t, response.Body().Bytes(), "error.message", "gateway_api_key or subscription_user_id is required")
}