package app import ( "bytes" "context" "crypto/sha256" "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "strconv" "strings" "time" "sub2api-cn-relay-manager/internal/access" "sub2api-cn-relay-manager/internal/host/sub2api" "sub2api-cn-relay-manager/internal/routing" "sub2api-cn-relay-manager/internal/store/sqlite" ) const routeChatCompletionsPath = "/v1/chat/completions" type ProxyRouteChatCompletionsRequest struct { RequestID string `json:"request_id,omitempty"` LogicalGroupID string `json:"logical_group_id"` PublicModel string `json:"public_model"` Scope string `json:"scope"` SubjectID string `json:"subject_id"` UserKey string `json:"user_key,omitempty"` ConversationKey string `json:"conversation_key,omitempty"` GatewayAPIKey string `json:"gateway_api_key"` SubscriptionUserID string `json:"subscription_user_id,omitempty"` Messages []ChatCompletionMessage `json:"messages,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` Temperature *float64 `json:"temperature,omitempty"` Sync bool `json:"sync,omitempty"` } type ChatCompletionMessage struct { Role string `json:"role"` Content string `json:"content"` } type ProxyRouteChatCompletionsResult struct { Resolve ResolveRouteInfo `json:"resolve"` Forward RouteChatCompletionsForwardInfo `json:"forward"` } type RouteChatCompletionsForwardInfo struct { OK bool `json:"ok"` HostID string `json:"host_id"` HostBaseURL string `json:"host_base_url"` ShadowGroupID string `json:"shadow_group_id"` ShadowModel string `json:"shadow_model"` EffectiveGatewayKeySource string `json:"effective_gateway_key_source,omitempty"` EffectiveGatewayKeyFingerprint string `json:"effective_gateway_key_fingerprint,omitempty"` ManagedUserID string `json:"managed_user_id,omitempty"` UpstreamPath string `json:"upstream_path"` UpstreamStatus int `json:"upstream_status"` LatencyMS int64 `json:"latency_ms"` ContentType string `json:"content_type,omitempty"` ErrorClass string `json:"error_class,omitempty"` ErrorMessage string `json:"error_message,omitempty"` Response any `json:"response,omitempty"` } func handleProxyRouteChatCompletions(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error)) { if fn == nil { writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "proxy-route-chat-completions action is not configured"}) return } var req ProxyRouteChatCompletionsRequest if err := decodeJSON(r, &req); err != nil { writeHTTPError(w, err) return } result, err := fn(r.Context(), req) if err != nil { writeHTTPError(w, classifyError(err)) return } writeJSON(w, http.StatusOK, result) } func buildProxyRouteChatCompletionsAction( sqliteDSN string, resolveRoute func(context.Context, ResolveRouteRequest) (ResolveRouteInfo, error), writerSource *lazyRouteLogWriter, ) func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) { return func(ctx context.Context, req ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) { req.GatewayAPIKey = strings.TrimSpace(req.GatewayAPIKey) req.SubscriptionUserID = strings.TrimSpace(req.SubscriptionUserID) if req.GatewayAPIKey == "" && req.SubscriptionUserID == "" { return ProxyRouteChatCompletionsResult{}, fmt.Errorf("gateway_api_key or subscription_user_id is required") } resolveInfo, err := resolveRoute(ctx, ResolveRouteRequest{ RequestID: req.RequestID, LogicalGroupID: req.LogicalGroupID, PublicModel: req.PublicModel, Scope: req.Scope, SubjectID: req.SubjectID, UserKey: req.UserKey, ConversationKey: req.ConversationKey, Sync: req.Sync, }) if err != nil { return ProxyRouteChatCompletionsResult{}, err } store, err := sqlite.Open(ctx, sqliteDSN) if err != nil { return ProxyRouteChatCompletionsResult{}, err } defer store.Close() hostRow, err := store.Hosts().GetByHostID(ctx, strings.TrimSpace(resolveInfo.ShadowHostID)) if err != nil { return ProxyRouteChatCompletionsResult{}, fmt.Errorf("get shadow host %q: %w", resolveInfo.ShadowHostID, err) } hostClient, err := newSub2APIClient(hostRow.BaseURL, authFromStoredHost(hostRow)) if err != nil { return ProxyRouteChatCompletionsResult{}, err } shadowModel := strings.TrimSpace(resolveInfo.ShadowModel) if shadowModel == "" { shadowModel = strings.TrimSpace(resolveInfo.PublicModel) } gatewayAPIKey, gatewayKeySource, managedUserID, err := resolveProxyGatewayAPIKey(ctx, store, hostRow, hostClient, resolveInfo, req) if err != nil { return ProxyRouteChatCompletionsResult{}, err } forward := proxyChatCompletionToShadowHost(ctx, hostRow.BaseURL, gatewayAPIKey, shadowModel, req.Messages, req.MaxTokens, req.Temperature) forward.HostID = strings.TrimSpace(hostRow.HostID) forward.HostBaseURL = strings.TrimSpace(hostRow.BaseURL) forward.ShadowGroupID = strings.TrimSpace(resolveInfo.ShadowGroupID) forward.ShadowModel = shadowModel forward.EffectiveGatewayKeySource = gatewayKeySource forward.EffectiveGatewayKeyFingerprint = fingerprintRouteProxySecret(gatewayAPIKey) forward.ManagedUserID = managedUserID if err := appendProxyRouteDecisionLog(ctx, writerSource, req, resolveInfo, forward); err != nil { return ProxyRouteChatCompletionsResult{}, err } if req.Sync { writer, err := writerSource.get(ctx) if err != nil { return ProxyRouteChatCompletionsResult{}, err } if err := writer.Flush(ctx); err != nil { return ProxyRouteChatCompletionsResult{}, err } } return ProxyRouteChatCompletionsResult{ Resolve: resolveInfo, Forward: forward, }, nil } } func appendProxyRouteDecisionLog( ctx context.Context, writerSource *lazyRouteLogWriter, req ProxyRouteChatCompletionsRequest, resolveInfo ResolveRouteInfo, forward RouteChatCompletionsForwardInfo, ) error { writer, err := writerSource.get(ctx) if err != nil { return err } return writer.AppendDecision(ctx, routing.RouteDecisionEvent{ RequestID: strings.TrimSpace(resolveInfo.RequestID), LogicalGroupID: strings.TrimSpace(resolveInfo.LogicalGroupID), PublicModel: strings.TrimSpace(resolveInfo.PublicModel), UserKey: resolveProxyUserKey(req), ConversationKey: resolveProxyConversationKey(req), StickyKey: strings.TrimSpace(resolveInfo.StickyKey), StickyKeyType: strings.TrimSpace(resolveInfo.Scope), StickyHit: resolveInfo.StickyHit, SelectedRouteID: strings.TrimSpace(resolveInfo.RouteID), SelectedShadowGroupID: strings.TrimSpace(resolveInfo.ShadowGroupID), ErrorClass: strings.TrimSpace(forward.ErrorClass), UpstreamStatus: forward.UpstreamStatus, LatencyMS: int(forward.LatencyMS), }) } func proxyChatCompletionToShadowHost( ctx context.Context, baseURL, gatewayAPIKey, shadowModel string, messages []ChatCompletionMessage, maxTokens int, temperature *float64, ) RouteChatCompletionsForwardInfo { info := RouteChatCompletionsForwardInfo{ UpstreamPath: routeChatCompletionsPath, } requestURL, err := joinRouteProxyPath(baseURL, routeChatCompletionsPath) if err != nil { info.ErrorClass = "invalid_host_base_url" info.ErrorMessage = err.Error() return info } payload := map[string]any{ "model": strings.TrimSpace(shadowModel), "messages": normalizeProxyChatMessages(messages), "max_tokens": normalizeProxyMaxTokens(maxTokens), "temperature": normalizeProxyTemperature(temperature), } var body bytes.Buffer if err := json.NewEncoder(&body).Encode(payload); err != nil { info.ErrorClass = "encode_request_failed" info.ErrorMessage = err.Error() return info } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, &body) if err != nil { info.ErrorClass = "build_request_failed" info.ErrorMessage = err.Error() return info } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Accept", "application/json, text/event-stream") httpReq.Header.Set("Authorization", "Bearer "+strings.TrimSpace(gatewayAPIKey)) startedAt := time.Now() resp, err := (&http.Client{Timeout: 20 * time.Second}).Do(httpReq) if err != nil { info.LatencyMS = time.Since(startedAt).Milliseconds() info.ErrorClass = "transport_error" info.ErrorMessage = err.Error() return info } defer resp.Body.Close() info.LatencyMS = time.Since(startedAt).Milliseconds() info.UpstreamStatus = resp.StatusCode info.ContentType = strings.TrimSpace(resp.Header.Get("Content-Type")) responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { info.ErrorClass = "read_response_failed" info.ErrorMessage = readErr.Error() return info } info.OK = resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices info.Response = decodeProxyResponseBody(responseBody) if !info.OK { info.ErrorClass = classifyProxyUpstreamStatus(resp.StatusCode) } return info } func normalizeProxyChatMessages(messages []ChatCompletionMessage) []map[string]string { if len(messages) == 0 { return []map[string]string{{"role": "user", "content": "ping"}} } normalized := make([]map[string]string, 0, len(messages)) for _, message := range messages { role := strings.TrimSpace(message.Role) if role == "" { role = "user" } normalized = append(normalized, map[string]string{ "role": role, "content": strings.TrimSpace(message.Content), }) } return normalized } func normalizeProxyMaxTokens(maxTokens int) int { if maxTokens <= 0 { return 8 } return maxTokens } func normalizeProxyTemperature(temperature *float64) float64 { if temperature == nil { return 0 } return *temperature } func joinRouteProxyPath(baseURL, path string) (string, error) { parsedURL, err := url.Parse(strings.TrimSpace(baseURL)) if err != nil { return "", err } if parsedURL.Scheme == "" || parsedURL.Host == "" { return "", fmt.Errorf("base url must include scheme and host") } resolvedPath := strings.TrimSpace(path) if !strings.HasPrefix(resolvedPath, "/") { resolvedPath = "/" + resolvedPath } return parsedURL.ResolveReference(&url.URL{Path: resolvedPath}).String(), nil } func decodeProxyResponseBody(body []byte) any { trimmed := bytes.TrimSpace(body) if len(trimmed) == 0 { return nil } var payload any if err := json.Unmarshal(trimmed, &payload); err == nil { return payload } return string(trimmed) } func classifyProxyUpstreamStatus(statusCode int) string { switch { case statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden: return "gateway_auth_error" case statusCode == http.StatusTooManyRequests: return "gateway_rate_limited" case statusCode >= http.StatusBadGateway: return "gateway_5xx" case statusCode >= http.StatusBadRequest: return fmt.Sprintf("gateway_%d", statusCode) default: return "" } } func resolveProxyGatewayAPIKey( ctx context.Context, store *sqlite.DB, hostRow sqlite.Host, hostClient *sub2api.Client, resolveInfo ResolveRouteInfo, req ProxyRouteChatCompletionsRequest, ) (string, string, string, error) { gatewayAPIKey := strings.TrimSpace(req.GatewayAPIKey) if gatewayAPIKey != "" { return gatewayAPIKey, access.ProbeKeySourceRequestedProbeAPIKey, "", nil } subscriptionUserID := strings.TrimSpace(req.SubscriptionUserID) if subscriptionUserID == "" { return "", "", "", fmt.Errorf("gateway_api_key or subscription_user_id is required") } shadowGroupHostResourceID, err := resolveShadowGroupHostResourceID(ctx, store, hostRow, hostClient, strings.TrimSpace(resolveInfo.ShadowGroupID)) if err != nil { return "", "", "", err } accessRef, err := hostClient.EnsureSubscriptionAccess(ctx, sub2api.EnsureSubscriptionAccessRequest{ UserSelector: subscriptionUserID, GroupID: shadowGroupHostResourceID, }) if err != nil { return "", "", "", fmt.Errorf("ensure subscription access for route %q: %w", resolveInfo.RouteID, err) } gatewayAPIKey = strings.TrimSpace(accessRef.APIKey) if gatewayAPIKey == "" { return "", "", "", fmt.Errorf("managed subscription access api key is empty") } return gatewayAPIKey, access.ProbeKeySourceManagedSubscription, strings.TrimSpace(accessRef.UserID), nil } func resolveShadowGroupHostResourceID( ctx context.Context, store *sqlite.DB, hostRow sqlite.Host, hostClient *sub2api.Client, shadowGroupID string, ) (string, error) { shadowGroupID = strings.TrimSpace(shadowGroupID) if shadowGroupID == "" { return "", fmt.Errorf("shadow_group_id is required") } if _, err := strconv.ParseInt(shadowGroupID, 10, 64); err == nil { return shadowGroupID, nil } resource, err := store.ManagedResources().GetByResourceIdentity(ctx, hostRow.ID, "group", shadowGroupID) if err == nil { return strings.TrimSpace(resource.HostResourceID), nil } if err != nil && !errors.Is(err, sql.ErrNoRows) { return "", fmt.Errorf("lookup shadow group %q in managed resources: %w", shadowGroupID, err) } snapshot, err := hostClient.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{GroupName: shadowGroupID}) if err != nil { return "", fmt.Errorf("list host groups for %q: %w", shadowGroupID, err) } if len(snapshot.Groups) == 1 { return strings.TrimSpace(snapshot.Groups[0].ID), nil } if len(snapshot.Groups) > 1 { return "", fmt.Errorf("multiple host groups matched shadow_group_id %q", shadowGroupID) } return "", fmt.Errorf("shadow group %q not found on host %q", shadowGroupID, hostRow.HostID) } func resolveProxyUserKey(req ProxyRouteChatCompletionsRequest) string { if key := strings.TrimSpace(req.UserKey); key != "" { return key } if strings.EqualFold(strings.TrimSpace(req.Scope), "user") { return strings.TrimSpace(req.SubjectID) } return "" } func fingerprintRouteProxySecret(value string) string { trimmed := strings.TrimSpace(value) if trimmed == "" { return "" } sum := sha256.Sum256([]byte(trimmed)) return "sha256:" + hex.EncodeToString(sum[:]) } func resolveProxyConversationKey(req ProxyRouteChatCompletionsRequest) string { if key := strings.TrimSpace(req.ConversationKey); key != "" { return key } if strings.EqualFold(strings.TrimSpace(req.Scope), "conversation") { return strings.TrimSpace(req.SubjectID) } return "" }