158 lines
4.8 KiB
Go
158 lines
4.8 KiB
Go
package probe
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
|
|
|
|
// TransportProfile describes gateway-wide (model-independent) API surface
// support as observed by ProbeCapabilities.
type TransportProfile struct {
	// SupportsOpenAIModels is set by ProbeCapabilities to true when the caller
	// supplied at least one raw model ID.
	SupportsOpenAIModels bool `json:"supports_openai_models"`

	// SupportsOpenAIChatCompletions is true when at least one per-model
	// /v1/chat/completions probe returned a 2xx status.
	SupportsOpenAIChatCompletions bool `json:"supports_openai_chat_completions"`

	// SupportsOpenAIResponses is true when the /v1/responses probe returned a
	// 2xx status.
	SupportsOpenAIResponses bool `json:"supports_openai_responses"`

	// SupportsAnthropicMessages is not set by ProbeCapabilities in this file.
	// NOTE(review): presumably populated elsewhere or reserved for a future
	// probe — confirm.
	SupportsAnthropicMessages bool `json:"supports_anthropic_messages"`

	// AuthStyle names how credentials are sent; ProbeCapabilities sets
	// "bearer", matching the Authorization header built by probeJSONEndpoint.
	AuthStyle string `json:"auth_style"`

	// ModelIDStyle is "vendor_prefixed" when any model ID contains '/', and
	// "canonical" otherwise (see detectModelIDStyle).
	ModelIDStyle string `json:"model_id_style"`

	// KnownAdvisories holds deduplicated advisory codes gathered during
	// probing, e.g. "initial_probe_race_expected" or
	// "responses_unsupported_but_chat_ok".
	KnownAdvisories []string `json:"known_advisories"`
}
|
|
|
|
// ModelCapabilityProfile records what ProbeCapabilities learned about a
// single model ID.
type ModelCapabilityProfile struct {
	// RawModelID is the caller-supplied model ID with surrounding whitespace
	// trimmed; it is the exact value sent in the chat probe.
	RawModelID string `json:"raw_model_id"`

	// NormalizedModelID is the result of NormalizeModelID on the raw ID.
	NormalizedModelID string `json:"normalized_model_id"`

	// CanonicalModelFamily is the result of CanonicalModelFamily on the raw ID.
	CanonicalModelFamily string `json:"canonical_model_family"`

	// SupportsStream, SupportsTools and SupportsReasoningFields are string
	// tri-states; ProbeCapabilities always sets them to "unknown".
	// NOTE(review): other values are presumably assigned elsewhere — confirm
	// the allowed set.
	SupportsStream          string `json:"supports_stream"`
	SupportsTools           string `json:"supports_tools"`
	SupportsReasoningFields string `json:"supports_reasoning_fields"`

	// SmokeChatOK is true when the /v1/chat/completions probe for this model
	// returned a 2xx status.
	SmokeChatOK bool `json:"smoke_chat_ok"`
}
|
|
|
|
// CapabilityProfile is the full result of ProbeCapabilities: gateway-wide
// transport support plus one entry per probed model.
type CapabilityProfile struct {
	// TransportProfile summarizes model-independent API surface support.
	TransportProfile TransportProfile `json:"transport_profile"`

	// ModelProfiles has one entry per raw model ID, in input order.
	ModelProfiles []ModelCapabilityProfile `json:"model_profiles"`
}
|
|
|
|
func ProbeCapabilities(ctx context.Context, baseURL, apiKey string, rawModels []string) (*CapabilityProfile, error) {
|
|
profile := &CapabilityProfile{
|
|
TransportProfile: TransportProfile{
|
|
SupportsOpenAIModels: len(rawModels) > 0,
|
|
AuthStyle: "bearer",
|
|
ModelIDStyle: detectModelIDStyle(rawModels),
|
|
KnownAdvisories: []string{},
|
|
},
|
|
ModelProfiles: make([]ModelCapabilityProfile, 0, len(rawModels)),
|
|
}
|
|
|
|
responsesStatus, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/responses", map[string]any{
|
|
"model": firstNonEmptyModel(rawModels),
|
|
"input": "ping",
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
profile.TransportProfile.SupportsOpenAIResponses = responsesStatus >= http.StatusOK && responsesStatus < http.StatusMultipleChoices
|
|
|
|
for _, rawModel := range rawModels {
|
|
modelProfile := ModelCapabilityProfile{
|
|
RawModelID: strings.TrimSpace(rawModel),
|
|
NormalizedModelID: NormalizeModelID(rawModel),
|
|
CanonicalModelFamily: CanonicalModelFamily(rawModel),
|
|
SupportsStream: "unknown",
|
|
SupportsTools: "unknown",
|
|
SupportsReasoningFields: "unknown",
|
|
}
|
|
|
|
chatStatus, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/chat/completions", map[string]any{
|
|
"model": modelProfile.RawModelID,
|
|
"messages": []map[string]string{
|
|
{"role": "user", "content": "ping"},
|
|
},
|
|
"max_tokens": 8,
|
|
"temperature": 0,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
modelProfile.SmokeChatOK = chatStatus >= http.StatusOK && chatStatus < http.StatusMultipleChoices
|
|
if modelProfile.SmokeChatOK {
|
|
profile.TransportProfile.SupportsOpenAIChatCompletions = true
|
|
}
|
|
if chatStatus == http.StatusForbidden {
|
|
appendAdvisory(&profile.TransportProfile.KnownAdvisories, "initial_probe_race_expected")
|
|
}
|
|
|
|
profile.ModelProfiles = append(profile.ModelProfiles, modelProfile)
|
|
}
|
|
|
|
if !profile.TransportProfile.SupportsOpenAIResponses && profile.TransportProfile.SupportsOpenAIChatCompletions {
|
|
appendAdvisory(&profile.TransportProfile.KnownAdvisories, "responses_unsupported_but_chat_ok")
|
|
}
|
|
|
|
return profile, nil
|
|
}
|
|
|
|
func probeJSONEndpoint(ctx context.Context, baseURL, apiKey, path string, payload any) (int, error) {
|
|
requestURL, err := joinGatewayPath(baseURL, path)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("resolve %s endpoint: %w", path, err)
|
|
}
|
|
|
|
var body bytes.Buffer
|
|
if payload != nil {
|
|
if err := json.NewEncoder(&body).Encode(payload); err != nil {
|
|
return 0, fmt.Errorf("encode %s probe payload: %w", path, err)
|
|
}
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, &body)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("build %s request: %w", path, err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if token := strings.TrimSpace(apiKey); token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
resp, err := (&http.Client{Timeout: 15 * time.Second}).Do(req)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("request %s: %w", path, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
return resp.StatusCode, nil
|
|
}
|
|
|
|
// detectModelIDStyle classifies a model list by naming convention. It reports
// "vendor_prefixed" if any entry contains a '/', and "canonical" otherwise,
// including for an empty or nil list.
func detectModelIDStyle(rawModels []string) string {
	for i := range rawModels {
		// Trimming whitespace can never remove a '/', so the raw entry is
		// checked directly.
		if strings.ContainsRune(rawModels[i], '/') {
			return "vendor_prefixed"
		}
	}
	return "canonical"
}
|
|
|
|
// appendAdvisory appends advisory to *values unless it is empty or already
// present, so the slice behaves like an insertion-ordered set. An empty
// advisory returns before values is dereferenced.
func appendAdvisory(values *[]string, advisory string) {
	if len(advisory) == 0 {
		return
	}
	list := *values
	for i := 0; i < len(list); i++ {
		if list[i] == advisory {
			return
		}
	}
	*values = append(list, advisory)
}
|
|
|
|
// firstNonEmptyModel returns the first entry of rawModels that is non-blank
// after trimming surrounding whitespace, in its trimmed form. If every entry
// is blank, or the slice is empty, it returns the literal "ping".
func firstNonEmptyModel(rawModels []string) string {
	const fallback = "ping"
	for i := range rawModels {
		candidate := strings.TrimSpace(rawModels[i])
		if len(candidate) > 0 {
			return candidate
		}
	}
	return fallback
}
|