339 lines
11 KiB
Go
339 lines
11 KiB
Go
package app
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"sub2api-cn-relay-manager/internal/routing"
|
|
"sub2api-cn-relay-manager/internal/store/sqlite"
|
|
)
|
|
|
|
const (
	// routeChatCompletionsPath is the upstream endpoint that forwarded
	// chat-completion requests are sent to; it is resolved against the
	// selected shadow host's base URL.
	routeChatCompletionsPath = "/v1/chat/completions"
)
|
|
|
|
type ProxyRouteChatCompletionsRequest struct {
|
|
RequestID string `json:"request_id,omitempty"`
|
|
LogicalGroupID string `json:"logical_group_id"`
|
|
PublicModel string `json:"public_model"`
|
|
Scope string `json:"scope"`
|
|
SubjectID string `json:"subject_id"`
|
|
UserKey string `json:"user_key,omitempty"`
|
|
ConversationKey string `json:"conversation_key,omitempty"`
|
|
GatewayAPIKey string `json:"gateway_api_key"`
|
|
Messages []ChatCompletionMessage `json:"messages,omitempty"`
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
Temperature *float64 `json:"temperature,omitempty"`
|
|
Sync bool `json:"sync,omitempty"`
|
|
}
|
|
|
|
// ChatCompletionMessage is one chat message: a role and its text content.
type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
|
|
|
type ProxyRouteChatCompletionsResult struct {
|
|
Resolve ResolveRouteInfo `json:"resolve"`
|
|
Forward RouteChatCompletionsForwardInfo `json:"forward"`
|
|
}
|
|
|
|
// RouteChatCompletionsForwardInfo describes one chat-completion call forwarded
// to a shadow host. Failures while building, sending or reading the request
// are reported through ErrorClass/ErrorMessage rather than as Go errors.
type RouteChatCompletionsForwardInfo struct {
	// OK is true when the upstream answered with a 2xx status.
	OK bool `json:"ok"`

	// Where the request was sent.
	HostID        string `json:"host_id"`
	HostBaseURL   string `json:"host_base_url"`
	ShadowGroupID string `json:"shadow_group_id"`
	ShadowModel   string `json:"shadow_model"`
	UpstreamPath  string `json:"upstream_path"`

	// Upstream response metadata; UpstreamStatus stays 0 when no response
	// was received.
	UpstreamStatus int    `json:"upstream_status"`
	LatencyMS      int64  `json:"latency_ms"`
	ContentType    string `json:"content_type,omitempty"`

	// Failure classification; left empty on success.
	ErrorClass   string `json:"error_class,omitempty"`
	ErrorMessage string `json:"error_message,omitempty"`

	// Response is the upstream body: decoded JSON when it parses, otherwise
	// the whitespace-trimmed text.
	Response any `json:"response,omitempty"`
}
|
|
|
|
func handleProxyRouteChatCompletions(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error)) {
|
|
if fn == nil {
|
|
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "proxy-route-chat-completions action is not configured"})
|
|
return
|
|
}
|
|
var req ProxyRouteChatCompletionsRequest
|
|
if err := decodeJSON(r, &req); err != nil {
|
|
writeHTTPError(w, err)
|
|
return
|
|
}
|
|
result, err := fn(r.Context(), req)
|
|
if err != nil {
|
|
writeHTTPError(w, classifyError(err))
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, result)
|
|
}
|
|
|
|
func buildProxyRouteChatCompletionsAction(
|
|
sqliteDSN string,
|
|
resolveRoute func(context.Context, ResolveRouteRequest) (ResolveRouteInfo, error),
|
|
writerSource *lazyRouteLogWriter,
|
|
) func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
|
|
return func(ctx context.Context, req ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
|
|
req.GatewayAPIKey = strings.TrimSpace(req.GatewayAPIKey)
|
|
if req.GatewayAPIKey == "" {
|
|
return ProxyRouteChatCompletionsResult{}, fmt.Errorf("gateway_api_key is required")
|
|
}
|
|
|
|
resolveInfo, err := resolveRoute(ctx, ResolveRouteRequest{
|
|
RequestID: req.RequestID,
|
|
LogicalGroupID: req.LogicalGroupID,
|
|
PublicModel: req.PublicModel,
|
|
Scope: req.Scope,
|
|
SubjectID: req.SubjectID,
|
|
UserKey: req.UserKey,
|
|
ConversationKey: req.ConversationKey,
|
|
Sync: req.Sync,
|
|
})
|
|
if err != nil {
|
|
return ProxyRouteChatCompletionsResult{}, err
|
|
}
|
|
|
|
store, err := sqlite.Open(ctx, sqliteDSN)
|
|
if err != nil {
|
|
return ProxyRouteChatCompletionsResult{}, err
|
|
}
|
|
defer store.Close()
|
|
|
|
hostRow, err := store.Hosts().GetByHostID(ctx, strings.TrimSpace(resolveInfo.ShadowHostID))
|
|
if err != nil {
|
|
return ProxyRouteChatCompletionsResult{}, fmt.Errorf("get shadow host %q: %w", resolveInfo.ShadowHostID, err)
|
|
}
|
|
|
|
shadowModel := strings.TrimSpace(resolveInfo.ShadowModel)
|
|
if shadowModel == "" {
|
|
shadowModel = strings.TrimSpace(resolveInfo.PublicModel)
|
|
}
|
|
|
|
forward := proxyChatCompletionToShadowHost(ctx, hostRow.BaseURL, req.GatewayAPIKey, shadowModel, req.Messages, req.MaxTokens, req.Temperature)
|
|
forward.HostID = strings.TrimSpace(hostRow.HostID)
|
|
forward.HostBaseURL = strings.TrimSpace(hostRow.BaseURL)
|
|
forward.ShadowGroupID = strings.TrimSpace(resolveInfo.ShadowGroupID)
|
|
forward.ShadowModel = shadowModel
|
|
|
|
if err := appendProxyRouteDecisionLog(ctx, writerSource, req, resolveInfo, forward); err != nil {
|
|
return ProxyRouteChatCompletionsResult{}, err
|
|
}
|
|
if req.Sync {
|
|
writer, err := writerSource.get(ctx)
|
|
if err != nil {
|
|
return ProxyRouteChatCompletionsResult{}, err
|
|
}
|
|
if err := writer.Flush(ctx); err != nil {
|
|
return ProxyRouteChatCompletionsResult{}, err
|
|
}
|
|
}
|
|
|
|
return ProxyRouteChatCompletionsResult{
|
|
Resolve: resolveInfo,
|
|
Forward: forward,
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
func appendProxyRouteDecisionLog(
|
|
ctx context.Context,
|
|
writerSource *lazyRouteLogWriter,
|
|
req ProxyRouteChatCompletionsRequest,
|
|
resolveInfo ResolveRouteInfo,
|
|
forward RouteChatCompletionsForwardInfo,
|
|
) error {
|
|
writer, err := writerSource.get(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return writer.AppendDecision(ctx, routing.RouteDecisionEvent{
|
|
RequestID: strings.TrimSpace(resolveInfo.RequestID),
|
|
LogicalGroupID: strings.TrimSpace(resolveInfo.LogicalGroupID),
|
|
PublicModel: strings.TrimSpace(resolveInfo.PublicModel),
|
|
UserKey: resolveProxyUserKey(req),
|
|
ConversationKey: resolveProxyConversationKey(req),
|
|
StickyKey: strings.TrimSpace(resolveInfo.StickyKey),
|
|
StickyKeyType: strings.TrimSpace(resolveInfo.Scope),
|
|
StickyHit: resolveInfo.StickyHit,
|
|
SelectedRouteID: strings.TrimSpace(resolveInfo.RouteID),
|
|
SelectedShadowGroupID: strings.TrimSpace(resolveInfo.ShadowGroupID),
|
|
ErrorClass: strings.TrimSpace(forward.ErrorClass),
|
|
UpstreamStatus: forward.UpstreamStatus,
|
|
LatencyMS: int(forward.LatencyMS),
|
|
})
|
|
}
|
|
|
|
func proxyChatCompletionToShadowHost(
|
|
ctx context.Context,
|
|
baseURL, gatewayAPIKey, shadowModel string,
|
|
messages []ChatCompletionMessage,
|
|
maxTokens int,
|
|
temperature *float64,
|
|
) RouteChatCompletionsForwardInfo {
|
|
info := RouteChatCompletionsForwardInfo{
|
|
UpstreamPath: routeChatCompletionsPath,
|
|
}
|
|
|
|
requestURL, err := joinRouteProxyPath(baseURL, routeChatCompletionsPath)
|
|
if err != nil {
|
|
info.ErrorClass = "invalid_host_base_url"
|
|
info.ErrorMessage = err.Error()
|
|
return info
|
|
}
|
|
|
|
payload := map[string]any{
|
|
"model": strings.TrimSpace(shadowModel),
|
|
"messages": normalizeProxyChatMessages(messages),
|
|
"max_tokens": normalizeProxyMaxTokens(maxTokens),
|
|
"temperature": normalizeProxyTemperature(temperature),
|
|
}
|
|
|
|
var body bytes.Buffer
|
|
if err := json.NewEncoder(&body).Encode(payload); err != nil {
|
|
info.ErrorClass = "encode_request_failed"
|
|
info.ErrorMessage = err.Error()
|
|
return info
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, &body)
|
|
if err != nil {
|
|
info.ErrorClass = "build_request_failed"
|
|
info.ErrorMessage = err.Error()
|
|
return info
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Accept", "application/json, text/event-stream")
|
|
httpReq.Header.Set("Authorization", "Bearer "+strings.TrimSpace(gatewayAPIKey))
|
|
|
|
startedAt := time.Now()
|
|
resp, err := (&http.Client{Timeout: 20 * time.Second}).Do(httpReq)
|
|
if err != nil {
|
|
info.LatencyMS = time.Since(startedAt).Milliseconds()
|
|
info.ErrorClass = "transport_error"
|
|
info.ErrorMessage = err.Error()
|
|
return info
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
info.LatencyMS = time.Since(startedAt).Milliseconds()
|
|
info.UpstreamStatus = resp.StatusCode
|
|
info.ContentType = strings.TrimSpace(resp.Header.Get("Content-Type"))
|
|
|
|
responseBody, readErr := io.ReadAll(resp.Body)
|
|
if readErr != nil {
|
|
info.ErrorClass = "read_response_failed"
|
|
info.ErrorMessage = readErr.Error()
|
|
return info
|
|
}
|
|
|
|
info.OK = resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices
|
|
info.Response = decodeProxyResponseBody(responseBody)
|
|
if !info.OK {
|
|
info.ErrorClass = classifyProxyUpstreamStatus(resp.StatusCode)
|
|
}
|
|
return info
|
|
}
|
|
|
|
func normalizeProxyChatMessages(messages []ChatCompletionMessage) []map[string]string {
|
|
if len(messages) == 0 {
|
|
return []map[string]string{{"role": "user", "content": "ping"}}
|
|
}
|
|
|
|
normalized := make([]map[string]string, 0, len(messages))
|
|
for _, message := range messages {
|
|
role := strings.TrimSpace(message.Role)
|
|
if role == "" {
|
|
role = "user"
|
|
}
|
|
normalized = append(normalized, map[string]string{
|
|
"role": role,
|
|
"content": strings.TrimSpace(message.Content),
|
|
})
|
|
}
|
|
return normalized
|
|
}
|
|
|
|
// normalizeProxyMaxTokens returns maxTokens unchanged when it is positive and
// substitutes 8 for zero or negative values, so the forwarded request always
// carries a positive limit.
func normalizeProxyMaxTokens(maxTokens int) int {
	const fallbackMaxTokens = 8
	if maxTokens > 0 {
		return maxTokens
	}
	return fallbackMaxTokens
}
|
|
|
|
// normalizeProxyTemperature dereferences temperature, treating an unset
// (nil) value as 0.
func normalizeProxyTemperature(temperature *float64) float64 {
	var value float64
	if temperature != nil {
		value = *temperature
	}
	return value
}
|
|
|
|
// joinRouteProxyPath resolves path against baseURL and returns the absolute
// URL as a string. path is whitespace-trimmed and always treated as absolute
// (a leading "/" is added when missing). It therefore replaces any path,
// query and fragment on baseURL, while the scheme, user info and host are
// kept. baseURL must parse and include both a scheme and a host.
func joinRouteProxyPath(baseURL, path string) (string, error) {
	base, parseErr := url.Parse(strings.TrimSpace(baseURL))
	switch {
	case parseErr != nil:
		return "", parseErr
	case base.Scheme == "" || base.Host == "":
		return "", fmt.Errorf("base url must include scheme and host")
	}

	ref := &url.URL{Path: "/" + strings.TrimPrefix(strings.TrimSpace(path), "/")}
	return base.ResolveReference(ref).String(), nil
}
|
|
|
|
// decodeProxyResponseBody turns an upstream body into a value that can be
// embedded in the JSON result. A blank body yields nil, a body that parses
// as JSON yields the decoded value, and anything else (e.g. plain text or
// SSE frames) yields the whitespace-trimmed text.
func decodeProxyResponseBody(body []byte) any {
	raw := bytes.TrimSpace(body)
	if len(raw) == 0 {
		return nil
	}

	var decoded any
	if json.Unmarshal(raw, &decoded) != nil {
		return string(raw)
	}
	return decoded
}
|
|
|
|
// classifyProxyUpstreamStatus maps an upstream HTTP status to the error
// class recorded in forward info and route-decision logs:
//
//   - 401 and 403 → "gateway_auth_error"
//   - 429 → "gateway_rate_limited"
//   - any 5xx (500 and above) → "gateway_5xx"
//   - any other non-2xx status → "gateway_<code>"
//
// 2xx statuses return "".
func classifyProxyUpstreamStatus(statusCode int) string {
	switch {
	case statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden:
		return "gateway_auth_error"
	case statusCode == http.StatusTooManyRequests:
		return "gateway_rate_limited"
	case statusCode >= http.StatusInternalServerError:
		// Start at 500, not 502, so 500/501 share the server-failure class.
		return "gateway_5xx"
	case statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices:
		return ""
	default:
		// Remaining 4xx plus unexpected 1xx/3xx responses: keep the exact code
		// so a failed forward is never left without a class.
		return fmt.Sprintf("gateway_%d", statusCode)
	}
}
|
|
|
|
func resolveProxyUserKey(req ProxyRouteChatCompletionsRequest) string {
|
|
if key := strings.TrimSpace(req.UserKey); key != "" {
|
|
return key
|
|
}
|
|
if strings.EqualFold(strings.TrimSpace(req.Scope), "user") {
|
|
return strings.TrimSpace(req.SubjectID)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func resolveProxyConversationKey(req ProxyRouteChatCompletionsRequest) string {
|
|
if key := strings.TrimSpace(req.ConversationKey); key != "" {
|
|
return key
|
|
}
|
|
if strings.EqualFold(strings.TrimSpace(req.Scope), "conversation") {
|
|
return strings.TrimSpace(req.SubjectID)
|
|
}
|
|
return ""
|
|
}
|