Files
sub2api-cn-relay-manager/internal/app/route_proxy_api.go
2026-05-29 13:17:56 +08:00

558 lines
20 KiB
Go

package app
import (
"bytes"
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"sub2api-cn-relay-manager/internal/access"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/routing"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
// routeChatCompletionsPath is the upstream path that chat-completion requests
// are forwarded to on the selected shadow host; it is also reported back as
// the forward's UpstreamPath.
const routeChatCompletionsPath = "/v1/chat/completions"
// ProxyRouteChatCompletionsRequest is the JSON payload of the
// proxy-route-chat-completions action: the routing inputs handed to the route
// resolver plus the chat-completion parameters forwarded to the resolved
// shadow host.
type ProxyRouteChatCompletionsRequest struct {
	// Routing inputs; these are copied into the ResolveRouteRequest.
	RequestID string `json:"request_id,omitempty"`
	LogicalGroupID string `json:"logical_group_id"`
	PublicModel string `json:"public_model"`
	// Scope values "user" and "conversation" (case-insensitive) make SubjectID
	// the fallback user/conversation key written to the route decision log.
	Scope string `json:"scope"`
	SubjectID string `json:"subject_id"`
	UserKey string `json:"user_key,omitempty"`
	ConversationKey string `json:"conversation_key,omitempty"`
	// Credentials: at least one of GatewayAPIKey and SubscriptionUserID must be
	// non-blank. A non-blank GatewayAPIKey is used as-is; otherwise a key is
	// obtained for SubscriptionUserID through the host's subscription access.
	GatewayAPIKey string `json:"gateway_api_key"`
	SubscriptionUserID string `json:"subscription_user_id,omitempty"`
	// Chat payload. When omitted, the forwarder substitutes a single user
	// "ping" message, max_tokens 8 and temperature 0.
	Messages []ChatCompletionMessage `json:"messages,omitempty"`
	MaxTokens int `json:"max_tokens,omitempty"`
	Temperature *float64 `json:"temperature,omitempty"`
	// Sync is passed to the route resolver and, when true, makes the action
	// flush the route log writer before returning.
	Sync bool `json:"sync,omitempty"`
}
// RouteChatCompletionsRequest is the JSON payload of the
// route-chat-completions action. It carries the same data as
// ProxyRouteChatCompletionsRequest but names the model field "model";
// buildRouteChatCompletionsAction copies it field-by-field into the proxy
// request (Model becomes PublicModel).
type RouteChatCompletionsRequest struct {
	RequestID string `json:"request_id,omitempty"`
	LogicalGroupID string `json:"logical_group_id"`
	// Model is forwarded to the route resolver as the public model name.
	Model string `json:"model"`
	Scope string `json:"scope"`
	SubjectID string `json:"subject_id"`
	UserKey string `json:"user_key,omitempty"`
	ConversationKey string `json:"conversation_key,omitempty"`
	// GatewayAPIKey and SubscriptionUserID: the proxy action rejects requests
	// in which both are blank.
	GatewayAPIKey string `json:"gateway_api_key,omitempty"`
	SubscriptionUserID string `json:"subscription_user_id,omitempty"`
	Messages []ChatCompletionMessage `json:"messages,omitempty"`
	MaxTokens int `json:"max_tokens,omitempty"`
	Temperature *float64 `json:"temperature,omitempty"`
	Sync bool `json:"sync,omitempty"`
}
// ChatCompletionMessage is one chat message (role plus text content) of a
// chat-completions request.
type ChatCompletionMessage struct {
	// Role is trimmed before forwarding; a blank role is sent as "user".
	Role string `json:"role"`
	// Content is whitespace-trimmed before forwarding.
	Content string `json:"content"`
}
// ProxyRouteChatCompletionsResult pairs the outcome of route resolution with
// the outcome of forwarding the chat completion to the resolved shadow host.
type ProxyRouteChatCompletionsResult struct {
	// Resolve is the route resolver's answer, returned unchanged.
	Resolve ResolveRouteInfo `json:"resolve"`
	// Forward describes the upstream call, including failures.
	Forward RouteChatCompletionsForwardInfo `json:"forward"`
}
// RouteChatCompletionsForwardInfo reports what happened when a chat completion
// was forwarded to a shadow host. Forwarding failures are recorded in
// ErrorClass/ErrorMessage rather than returned as Go errors.
type RouteChatCompletionsForwardInfo struct {
	// OK is true only when the upstream answered with a 2xx status.
	OK bool `json:"ok"`
	// HostID and HostBaseURL come from the stored host row (trimmed).
	HostID string `json:"host_id"`
	HostBaseURL string `json:"host_base_url"`
	ShadowGroupID string `json:"shadow_group_id"`
	// ShadowModel is the model name sent upstream: the resolved shadow model,
	// or the public model when no shadow model was resolved.
	ShadowModel string `json:"shadow_model"`
	// EffectiveGatewayKeySource tells whether the key was supplied in the
	// request or obtained through managed subscription access.
	EffectiveGatewayKeySource string `json:"effective_gateway_key_source,omitempty"`
	// EffectiveGatewayKeyFingerprint is "sha256:" plus the hex SHA-256 of the
	// key, so the key itself is not echoed back.
	EffectiveGatewayKeyFingerprint string `json:"effective_gateway_key_fingerprint,omitempty"`
	// ManagedUserID is set when the key came from managed subscription access.
	ManagedUserID string `json:"managed_user_id,omitempty"`
	UpstreamPath string `json:"upstream_path"`
	// UpstreamStatus is the upstream HTTP status; it stays 0 when no response
	// was received.
	UpstreamStatus int `json:"upstream_status"`
	// LatencyMS is the time in milliseconds from issuing the upstream request
	// until the HTTP client call returned.
	LatencyMS int64 `json:"latency_ms"`
	ContentType string `json:"content_type,omitempty"`
	ErrorClass string `json:"error_class,omitempty"`
	ErrorMessage string `json:"error_message,omitempty"`
	// Response is the decoded JSON body when the upstream body is valid JSON,
	// otherwise the trimmed body as a string; nil for an empty body.
	Response any `json:"response,omitempty"`
}
// RouteChatCompletionsResult is the response of the route-chat-completions
// action: a flattened view of the route resolution (top-level fields plus
// SelectedRoute) together with the forward outcome. It is built from a
// ProxyRouteChatCompletionsResult by routeChatCompletionsResultFromProxy.
type RouteChatCompletionsResult struct {
	RequestID string `json:"request_id"`
	Backend string `json:"backend"`
	LogicalGroupID string `json:"logical_group_id"`
	// Model is the resolver's public model name.
	Model string `json:"model"`
	Scope string `json:"scope"`
	SubjectID string `json:"subject_id"`
	StickyKey string `json:"sticky_key"`
	StickyHit bool `json:"sticky_hit"`
	StickyAction string `json:"sticky_action"`
	FallbackUsed bool `json:"fallback_used,omitempty"`
	SelectedRoute RouteChatCompletionsRouteInfo `json:"selected_route"`
	Forward RouteChatCompletionsForwardInfo `json:"forward"`
}
// RouteChatCompletionsRouteInfo describes the route selected for a request.
// Every field is copied from the corresponding ResolveRouteInfo field by
// routeChatCompletionsResultFromProxy.
type RouteChatCompletionsRouteInfo struct {
	RouteID string `json:"route_id"`
	RouteName string `json:"route_name,omitempty"`
	ShadowHostID string `json:"shadow_host_id"`
	ShadowGroupID string `json:"shadow_group_id"`
	ShadowModel string `json:"shadow_model,omitempty"`
	Priority int `json:"priority"`
	Weight int `json:"weight"`
	BoundAt string `json:"bound_at,omitempty"`
	ExpiresAt string `json:"expires_at,omitempty"`
}
// handleProxyRouteChatCompletions serves the proxy-route-chat-completions
// endpoint: it decodes the JSON request body, invokes fn with the request
// context, and writes fn's result as a 200 JSON response.
//
// A nil fn is answered with a 500 "server_misconfigured" error. A decode
// failure is written exactly as decodeJSON returned it; an error from fn is
// mapped through classifyError before being written.
func handleProxyRouteChatCompletions(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "proxy-route-chat-completions action is not configured",
		})
		return
	}

	payload := ProxyRouteChatCompletionsRequest{}
	if decodeErr := decodeJSON(r, &payload); decodeErr != nil {
		writeHTTPError(w, decodeErr)
		return
	}

	switch outcome, callErr := fn(r.Context(), payload); {
	case callErr != nil:
		writeHTTPError(w, classifyError(callErr))
	default:
		writeJSON(w, http.StatusOK, outcome)
	}
}
// handleRouteChatCompletions serves the route-chat-completions endpoint: it
// decodes the JSON request body, invokes fn with the request context, and
// writes fn's result as a 200 JSON response.
//
// A nil fn is answered with a 500 "server_misconfigured" error. A decode
// failure is written exactly as decodeJSON returned it; an error from fn is
// mapped through classifyError before being written.
func handleRouteChatCompletions(w http.ResponseWriter, r *http.Request, fn func(context.Context, RouteChatCompletionsRequest) (RouteChatCompletionsResult, error)) {
	if fn == nil {
		writeHTTPError(w, &httpError{
			StatusCode: http.StatusInternalServerError,
			Code:       "server_misconfigured",
			Message:    "route-chat-completions action is not configured",
		})
		return
	}

	payload := RouteChatCompletionsRequest{}
	if decodeErr := decodeJSON(r, &payload); decodeErr != nil {
		writeHTTPError(w, decodeErr)
		return
	}

	switch outcome, callErr := fn(r.Context(), payload); {
	case callErr != nil:
		writeHTTPError(w, classifyError(callErr))
	default:
		writeJSON(w, http.StatusOK, outcome)
	}
}
// buildProxyRouteChatCompletionsAction returns the action behind the
// proxy-route-chat-completions endpoint.
//
// For each request the returned function:
//  1. trims the credentials and rejects the request when both are blank;
//  2. resolves a route through resolveRoute;
//  3. opens the SQLite store at sqliteDSN (closed when the call returns) and
//     loads the resolved shadow host;
//  4. determines the gateway API key (request-supplied or managed);
//  5. forwards the chat completion to the shadow host;
//  6. appends a route decision log entry and, when req.Sync is set, flushes
//     the log writer.
//
// Upstream failures do not produce an error: they are reported inside the
// returned Forward info. Errors from resolution, the store, key resolution or
// the log writer abort the call and yield a zero result.
func buildProxyRouteChatCompletionsAction(
	sqliteDSN string,
	resolveRoute func(context.Context, ResolveRouteRequest) (ResolveRouteInfo, error),
	writerSource *lazyRouteLogWriter,
) func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
	return func(ctx context.Context, req ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
		// none is the zero result returned alongside every error.
		var none ProxyRouteChatCompletionsResult

		// Trim in place so every later step sees the normalised credentials.
		req.GatewayAPIKey = strings.TrimSpace(req.GatewayAPIKey)
		req.SubscriptionUserID = strings.TrimSpace(req.SubscriptionUserID)
		if req.GatewayAPIKey == "" && req.SubscriptionUserID == "" {
			return none, fmt.Errorf("gateway_api_key or subscription_user_id is required")
		}

		routeReq := ResolveRouteRequest{
			RequestID:       req.RequestID,
			LogicalGroupID:  req.LogicalGroupID,
			PublicModel:     req.PublicModel,
			Scope:           req.Scope,
			SubjectID:       req.SubjectID,
			UserKey:         req.UserKey,
			ConversationKey: req.ConversationKey,
			Sync:            req.Sync,
		}
		route, err := resolveRoute(ctx, routeReq)
		if err != nil {
			return none, err
		}

		db, err := sqlite.Open(ctx, sqliteDSN)
		if err != nil {
			return none, err
		}
		defer db.Close()

		host, err := db.Hosts().GetByHostID(ctx, strings.TrimSpace(route.ShadowHostID))
		if err != nil {
			return none, fmt.Errorf("get shadow host %q: %w", route.ShadowHostID, err)
		}
		client, err := newSub2APIClient(host.BaseURL, authFromStoredHost(host))
		if err != nil {
			return none, err
		}

		// Fall back to the public model name when the route has no shadow model.
		model := strings.TrimSpace(route.ShadowModel)
		if model == "" {
			model = strings.TrimSpace(route.PublicModel)
		}

		apiKey, keySource, managedUserID, err := resolveProxyGatewayAPIKey(ctx, db, host, client, route, req)
		if err != nil {
			return none, err
		}

		forward := proxyChatCompletionToShadowHost(ctx, host.BaseURL, apiKey, model, req.Messages, req.MaxTokens, req.Temperature)
		forward.HostID = strings.TrimSpace(host.HostID)
		forward.HostBaseURL = strings.TrimSpace(host.BaseURL)
		forward.ShadowGroupID = strings.TrimSpace(route.ShadowGroupID)
		forward.ShadowModel = model
		forward.EffectiveGatewayKeySource = keySource
		forward.EffectiveGatewayKeyFingerprint = fingerprintRouteProxySecret(apiKey)
		forward.ManagedUserID = managedUserID

		if err := appendProxyRouteDecisionLog(ctx, writerSource, req, route, forward); err != nil {
			return none, err
		}

		result := ProxyRouteChatCompletionsResult{Resolve: route, Forward: forward}
		if !req.Sync {
			return result, nil
		}

		// Sync callers expect the decision entry to be flushed before they
		// get their answer.
		logWriter, err := writerSource.get(ctx)
		if err != nil {
			return none, err
		}
		if err := logWriter.Flush(ctx); err != nil {
			return none, err
		}
		return result, nil
	}
}
// buildRouteChatCompletionsAction adapts the proxy action to the
// route-chat-completions API: the request is copied field-by-field into a
// ProxyRouteChatCompletionsRequest (Model becomes PublicModel), the proxy
// action is invoked, and its result is reshaped with
// routeChatCompletionsResultFromProxy. Errors from the proxy action are
// returned unchanged with a zero result.
func buildRouteChatCompletionsAction(
	proxyRouteChatCompletions func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error),
) func(context.Context, RouteChatCompletionsRequest) (RouteChatCompletionsResult, error) {
	return func(ctx context.Context, req RouteChatCompletionsRequest) (RouteChatCompletionsResult, error) {
		proxyReq := ProxyRouteChatCompletionsRequest{
			RequestID:          req.RequestID,
			LogicalGroupID:     req.LogicalGroupID,
			PublicModel:        req.Model,
			Scope:              req.Scope,
			SubjectID:          req.SubjectID,
			UserKey:            req.UserKey,
			ConversationKey:    req.ConversationKey,
			GatewayAPIKey:      req.GatewayAPIKey,
			SubscriptionUserID: req.SubscriptionUserID,
			Messages:           req.Messages,
			MaxTokens:          req.MaxTokens,
			Temperature:        req.Temperature,
			Sync:               req.Sync,
		}

		proxied, err := proxyRouteChatCompletions(ctx, proxyReq)
		if err == nil {
			return routeChatCompletionsResultFromProxy(proxied), nil
		}
		return RouteChatCompletionsResult{}, err
	}
}
// routeChatCompletionsResultFromProxy reshapes a proxy result into the
// route-chat-completions response: request-level resolve fields go to the top
// level, route-level resolve fields go into SelectedRoute, and the forward
// info is carried over unchanged.
func routeChatCompletionsResultFromProxy(result ProxyRouteChatCompletionsResult) RouteChatCompletionsResult {
	resolved := result.Resolve

	selected := RouteChatCompletionsRouteInfo{
		RouteID:       resolved.RouteID,
		RouteName:     resolved.RouteName,
		ShadowHostID:  resolved.ShadowHostID,
		ShadowGroupID: resolved.ShadowGroupID,
		ShadowModel:   resolved.ShadowModel,
		Priority:      resolved.Priority,
		Weight:        resolved.Weight,
		BoundAt:       resolved.BoundAt,
		ExpiresAt:     resolved.ExpiresAt,
	}

	return RouteChatCompletionsResult{
		RequestID:      resolved.RequestID,
		Backend:        resolved.Backend,
		LogicalGroupID: resolved.LogicalGroupID,
		Model:          resolved.PublicModel,
		Scope:          resolved.Scope,
		SubjectID:      resolved.SubjectID,
		StickyKey:      resolved.StickyKey,
		StickyHit:      resolved.StickyHit,
		StickyAction:   resolved.StickyAction,
		FallbackUsed:   resolved.FallbackUsed,
		SelectedRoute:  selected,
		Forward:        result.Forward,
	}
}
// appendProxyRouteDecisionLog records one route decision event for a proxied
// request. The event combines the resolver's answer (trimmed identifiers,
// sticky state, selected route) with the forward outcome (error class,
// upstream status, latency); the resolve scope is logged as the sticky key
// type. It returns the error from obtaining the writer or from appending.
func appendProxyRouteDecisionLog(
	ctx context.Context,
	writerSource *lazyRouteLogWriter,
	req ProxyRouteChatCompletionsRequest,
	resolveInfo ResolveRouteInfo,
	forward RouteChatCompletionsForwardInfo,
) error {
	logWriter, err := writerSource.get(ctx)
	if err != nil {
		return err
	}

	trim := strings.TrimSpace
	event := routing.RouteDecisionEvent{
		RequestID:             trim(resolveInfo.RequestID),
		LogicalGroupID:        trim(resolveInfo.LogicalGroupID),
		PublicModel:           trim(resolveInfo.PublicModel),
		UserKey:               resolveProxyUserKey(req),
		ConversationKey:       resolveProxyConversationKey(req),
		StickyKey:             trim(resolveInfo.StickyKey),
		StickyKeyType:         trim(resolveInfo.Scope),
		StickyHit:             resolveInfo.StickyHit,
		SelectedRouteID:       trim(resolveInfo.RouteID),
		SelectedShadowGroupID: trim(resolveInfo.ShadowGroupID),
		ErrorClass:            trim(forward.ErrorClass),
		UpstreamStatus:        forward.UpstreamStatus,
		LatencyMS:             int(forward.LatencyMS),
	}
	return logWriter.AppendDecision(ctx, event)
}
// proxyChatCompletionToShadowHost POSTs a chat-completions request to the
// shadow host at baseURL and describes the outcome.
//
// The request body carries the trimmed shadowModel plus normalised messages,
// max_tokens and temperature, and is authenticated with gatewayAPIKey as a
// bearer token. The call is bounded by ctx and by a 20 second client timeout.
//
// The function never returns a Go error. Failures are reported through the
// returned info's ErrorClass/ErrorMessage:
//   - "invalid_host_base_url", "encode_request_failed", "build_request_failed"
//     before anything is sent;
//   - "transport_error" when the HTTP call fails;
//   - "read_response_failed" when the body cannot be read;
//   - the class from classifyProxyUpstreamStatus for non-2xx responses.
//
// The caller fills in the host, model and key fields of the returned info.
func proxyChatCompletionToShadowHost(
	ctx context.Context,
	baseURL, gatewayAPIKey, shadowModel string,
	messages []ChatCompletionMessage,
	maxTokens int,
	temperature *float64,
) RouteChatCompletionsForwardInfo {
	info := RouteChatCompletionsForwardInfo{UpstreamPath: routeChatCompletionsPath}

	// fail stamps the error class and message onto info and hands back
	// whatever has been collected so far.
	fail := func(class string, cause error) RouteChatCompletionsForwardInfo {
		info.ErrorClass = class
		info.ErrorMessage = cause.Error()
		return info
	}

	target, err := joinRouteProxyPath(baseURL, routeChatCompletionsPath)
	if err != nil {
		return fail("invalid_host_base_url", err)
	}

	body := new(bytes.Buffer)
	err = json.NewEncoder(body).Encode(map[string]any{
		"model":       strings.TrimSpace(shadowModel),
		"messages":    normalizeProxyChatMessages(messages),
		"max_tokens":  normalizeProxyMaxTokens(maxTokens),
		"temperature": normalizeProxyTemperature(temperature),
	})
	if err != nil {
		return fail("encode_request_failed", err)
	}

	upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, target, body)
	if err != nil {
		return fail("build_request_failed", err)
	}
	headers := upstreamReq.Header
	headers.Set("Content-Type", "application/json")
	headers.Set("Accept", "application/json, text/event-stream")
	headers.Set("Authorization", "Bearer "+strings.TrimSpace(gatewayAPIKey))

	client := http.Client{Timeout: 20 * time.Second}
	sentAt := time.Now()
	resp, err := client.Do(upstreamReq)
	// Latency is measured up to the point the client call returns, whether
	// it succeeded or not.
	info.LatencyMS = time.Since(sentAt).Milliseconds()
	if err != nil {
		return fail("transport_error", err)
	}
	defer resp.Body.Close()

	info.UpstreamStatus = resp.StatusCode
	info.ContentType = strings.TrimSpace(resp.Header.Get("Content-Type"))

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return fail("read_response_failed", err)
	}

	info.Response = decodeProxyResponseBody(raw)
	info.OK = resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices
	if !info.OK {
		info.ErrorClass = classifyProxyUpstreamStatus(resp.StatusCode)
	}
	return info
}
// normalizeProxyChatMessages converts messages into the role/content maps sent
// upstream. Roles and contents are whitespace-trimmed and a blank role becomes
// "user". When messages is empty, a single user message with content "ping"
// is returned so that the upstream request always carries a message.
func normalizeProxyChatMessages(messages []ChatCompletionMessage) []map[string]string {
	count := len(messages)
	if count == 0 {
		placeholder := map[string]string{"role": "user", "content": "ping"}
		return []map[string]string{placeholder}
	}

	out := make([]map[string]string, count)
	for i := range messages {
		entry := map[string]string{
			"role":    "user",
			"content": strings.TrimSpace(messages[i].Content),
		}
		if role := strings.TrimSpace(messages[i].Role); role != "" {
			entry["role"] = role
		}
		out[i] = entry
	}
	return out
}
// normalizeProxyMaxTokens returns maxTokens when it is positive and the
// default of 8 otherwise.
func normalizeProxyMaxTokens(maxTokens int) int {
	const fallback = 8
	if maxTokens > 0 {
		return maxTokens
	}
	return fallback
}
// normalizeProxyTemperature dereferences temperature, treating a nil pointer
// as 0.
func normalizeProxyTemperature(temperature *float64) float64 {
	var value float64
	if temperature != nil {
		value = *temperature
	}
	return value
}
// joinRouteProxyPath builds the absolute URL for path on the host at baseURL.
//
// Both inputs are whitespace-trimmed and path is given a leading "/" when it
// lacks one. baseURL must parse and contain a scheme and a host, otherwise an
// error is returned.
//
// The path is resolved as an absolute-path reference against baseURL, so any
// path or query already present on baseURL is replaced rather than extended
// (for example "https://h/api" + "/v1/x" yields "https://h/v1/x").
func joinRouteProxyPath(baseURL, path string) (string, error) {
	base, parseErr := url.Parse(strings.TrimSpace(baseURL))
	switch {
	case parseErr != nil:
		return "", parseErr
	case base.Scheme == "" || base.Host == "":
		return "", fmt.Errorf("base url must include scheme and host")
	}

	// Guarantee exactly the original shape: prefix "/" only when missing.
	suffix := "/" + strings.TrimPrefix(strings.TrimSpace(path), "/")
	ref := &url.URL{Path: suffix}
	return base.ResolveReference(ref).String(), nil
}
// decodeProxyResponseBody turns an upstream response body into a value that
// can be embedded in the JSON result: nil for an empty or whitespace-only
// body, the decoded JSON value when the trimmed body is valid JSON, and the
// trimmed body as a string otherwise.
func decodeProxyResponseBody(body []byte) any {
	content := bytes.TrimSpace(body)
	if len(content) == 0 {
		return nil
	}

	var decoded any
	if err := json.Unmarshal(content, &decoded); err != nil {
		// Not JSON (e.g. plain text or an event stream): keep the raw text.
		return string(content)
	}
	return decoded
}
// classifyProxyUpstreamStatus maps an upstream HTTP status code to the error
// class recorded for a forward attempt:
//
//   - 2xx             -> "" (success, no error class)
//   - 401, 403        -> "gateway_auth_error"
//   - 429             -> "gateway_rate_limited"
//   - 502 and above   -> "gateway_5xx"
//   - any other code  -> "gateway_<code>" (e.g. "gateway_404", "gateway_304")
//
// Non-2xx statuses below 400 (1xx and 3xx) are classified as well: the
// forwarder marks every non-2xx response as failed, so leaving them without an
// error class would make a failed forward indistinguishable from a successful
// one in the decision log. Values below 100 are not HTTP statuses and yield "".
//
// NOTE(review): 500 and 501 are reported as "gateway_500"/"gateway_501" rather
// than "gateway_5xx" because the 5xx bucket starts at 502 — confirm that this
// split is intended.
func classifyProxyUpstreamStatus(statusCode int) string {
	switch {
	case statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices:
		return ""
	case statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden:
		return "gateway_auth_error"
	case statusCode == http.StatusTooManyRequests:
		return "gateway_rate_limited"
	case statusCode >= http.StatusBadGateway:
		return "gateway_5xx"
	case statusCode >= http.StatusContinue:
		// Remaining 1xx, 3xx, 4xx and 500/501 responses keep their exact code.
		return fmt.Sprintf("gateway_%d", statusCode)
	default:
		return ""
	}
}
// resolveProxyGatewayAPIKey decides which gateway API key the forward uses.
//
// It returns the key, the key source, and the managed user ID:
//   - a non-blank req.GatewayAPIKey is returned as-is (trimmed) with the
//     requested-probe-key source and no managed user ID;
//   - otherwise req.SubscriptionUserID is required; the shadow group is mapped
//     to its host resource ID and the host is asked to ensure subscription
//     access, whose API key and user ID are returned with the
//     managed-subscription source.
//
// An error is returned when both credentials are blank, when the shadow group
// cannot be resolved, when ensuring access fails, or when the host returns a
// blank API key.
func resolveProxyGatewayAPIKey(
	ctx context.Context,
	store *sqlite.DB,
	hostRow sqlite.Host,
	hostClient *sub2api.Client,
	resolveInfo ResolveRouteInfo,
	req ProxyRouteChatCompletionsRequest,
) (string, string, string, error) {
	// A caller-supplied key always wins over managed access.
	if direct := strings.TrimSpace(req.GatewayAPIKey); direct != "" {
		return direct, access.ProbeKeySourceRequestedProbeAPIKey, "", nil
	}

	userSelector := strings.TrimSpace(req.SubscriptionUserID)
	if userSelector == "" {
		return "", "", "", fmt.Errorf("gateway_api_key or subscription_user_id is required")
	}

	groupID, err := resolveShadowGroupHostResourceID(ctx, store, hostRow, hostClient, strings.TrimSpace(resolveInfo.ShadowGroupID))
	if err != nil {
		return "", "", "", err
	}

	grant, err := hostClient.EnsureSubscriptionAccess(ctx, sub2api.EnsureSubscriptionAccessRequest{
		UserSelector: userSelector,
		GroupID:      groupID,
	})
	if err != nil {
		return "", "", "", fmt.Errorf("ensure subscription access for route %q: %w", resolveInfo.RouteID, err)
	}

	managedKey := strings.TrimSpace(grant.APIKey)
	if managedKey == "" {
		return "", "", "", fmt.Errorf("managed subscription access api key is empty")
	}
	return managedKey, access.ProbeKeySourceManagedSubscription, strings.TrimSpace(grant.UserID), nil
}
// resolveShadowGroupHostResourceID maps a shadow group identifier to the
// group's resource ID on the host.
//
// Resolution order:
//  1. a value that parses as a base-10 int64 is returned as-is;
//  2. otherwise the store's managed resources are searched for a "group"
//     resource with that identity on hostRow, and its host resource ID is
//     returned;
//  3. when the store has no such row (sql.ErrNoRows), the host is queried for
//     groups with that name; exactly one match is required.
//
// A blank shadowGroupID, a store error other than sql.ErrNoRows, a failed host
// listing, and zero or multiple host matches all produce an error.
func resolveShadowGroupHostResourceID(
	ctx context.Context,
	store *sqlite.DB,
	hostRow sqlite.Host,
	hostClient *sub2api.Client,
	shadowGroupID string,
) (string, error) {
	groupRef := strings.TrimSpace(shadowGroupID)
	if groupRef == "" {
		return "", fmt.Errorf("shadow_group_id is required")
	}
	if _, parseErr := strconv.ParseInt(groupRef, 10, 64); parseErr == nil {
		return groupRef, nil
	}

	managed, lookupErr := store.ManagedResources().GetByResourceIdentity(ctx, hostRow.ID, "group", groupRef)
	switch {
	case lookupErr == nil:
		return strings.TrimSpace(managed.HostResourceID), nil
	case !errors.Is(lookupErr, sql.ErrNoRows):
		return "", fmt.Errorf("lookup shadow group %q in managed resources: %w", groupRef, lookupErr)
	}

	// Not tracked locally: ask the host for groups carrying this name.
	listing, listErr := hostClient.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{GroupName: groupRef})
	if listErr != nil {
		return "", fmt.Errorf("list host groups for %q: %w", groupRef, listErr)
	}
	switch len(listing.Groups) {
	case 0:
		return "", fmt.Errorf("shadow group %q not found on host %q", groupRef, hostRow.HostID)
	case 1:
		return strings.TrimSpace(listing.Groups[0].ID), nil
	default:
		return "", fmt.Errorf("multiple host groups matched shadow_group_id %q", groupRef)
	}
}
// resolveProxyUserKey returns the user key to log for req: the trimmed
// UserKey when it is non-blank, otherwise the trimmed SubjectID when the
// scope is "user" (compared case-insensitively), otherwise "".
func resolveProxyUserKey(req ProxyRouteChatCompletionsRequest) string {
	explicit := strings.TrimSpace(req.UserKey)
	switch {
	case explicit != "":
		return explicit
	case strings.EqualFold(strings.TrimSpace(req.Scope), "user"):
		return strings.TrimSpace(req.SubjectID)
	default:
		return ""
	}
}
// fingerprintRouteProxySecret returns "sha256:" followed by the lowercase hex
// SHA-256 digest of the whitespace-trimmed value, or "" when the trimmed value
// is empty. It lets a secret be identified without being echoed back.
func fingerprintRouteProxySecret(value string) string {
	secret := strings.TrimSpace(value)
	if len(secret) == 0 {
		return ""
	}

	hasher := sha256.New()
	hasher.Write([]byte(secret))
	return "sha256:" + hex.EncodeToString(hasher.Sum(nil))
}
// resolveProxyConversationKey returns the conversation key to log for req:
// the trimmed ConversationKey when it is non-blank, otherwise the trimmed
// SubjectID when the scope is "conversation" (compared case-insensitively),
// otherwise "".
func resolveProxyConversationKey(req ProxyRouteChatCompletionsRequest) string {
	explicit := strings.TrimSpace(req.ConversationKey)
	switch {
	case explicit != "":
		return explicit
	case strings.EqualFold(strings.TrimSpace(req.Scope), "conversation"):
		return strings.TrimSpace(req.SubjectID)
	default:
		return ""
	}
}