diff --git a/internal/probe/capability.go b/internal/probe/capability.go
new file mode 100644
index 00000000..7fc9aad5
--- /dev/null
+++ b/internal/probe/capability.go
@@ -0,0 +1,210 @@
+package probe
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+)
+
+const (
+	// probeRequestTimeout bounds each individual probe or smoke HTTP request.
+	probeRequestTimeout = 15 * time.Second
+
+	// maxDrainBytes caps how much of a response body is read before it is
+	// closed, so draining never buffers an unbounded payload.
+	maxDrainBytes = 64 << 10
+)
+
+// TransportProfile records which API surfaces a gateway answered on, plus the
+// authentication and model-ID conventions that were detected for it.
+type TransportProfile struct {
+	SupportsOpenAIModels          bool     `json:"supports_openai_models"`
+	SupportsOpenAIChatCompletions bool     `json:"supports_openai_chat_completions"`
+	SupportsOpenAIResponses       bool     `json:"supports_openai_responses"`
+	SupportsAnthropicMessages     bool     `json:"supports_anthropic_messages"`
+	AuthStyle                     string   `json:"auth_style"`
+	ModelIDStyle                  string   `json:"model_id_style"`
+	KnownAdvisories               []string `json:"known_advisories"`
+}
+
+// ModelCapabilityProfile is the probe result for a single model ID.
+type ModelCapabilityProfile struct {
+	RawModelID              string `json:"raw_model_id"`
+	NormalizedModelID       string `json:"normalized_model_id"`
+	CanonicalModelFamily    string `json:"canonical_model_family"`
+	SupportsStream          string `json:"supports_stream"`
+	SupportsTools           string `json:"supports_tools"`
+	SupportsReasoningFields string `json:"supports_reasoning_fields"`
+	SmokeChatOK             bool   `json:"smoke_chat_ok"`
+}
+
+// CapabilityProfile bundles the gateway-wide transport profile with the
+// per-model results produced by ProbeCapabilities.
+type CapabilityProfile struct {
+	TransportProfile TransportProfile         `json:"transport_profile"`
+	ModelProfiles    []ModelCapabilityProfile `json:"model_profiles"`
+}
+
+// ProbeCapabilities POSTs once to /v1/responses, then once to
+// /v1/chat/completions for each non-blank entry of rawModels, and records
+// which requests were answered with a 2xx status.
+//
+// A non-2xx status is not an error; it only shows up in the returned profile.
+// A request that cannot be built or sent aborts the probe and returns the
+// error with a nil profile. KnownAdvisories is never nil in a returned profile.
+func ProbeCapabilities(ctx context.Context, baseURL, apiKey string, rawModels []string) (*CapabilityProfile, error) {
+	profile := &CapabilityProfile{
+		TransportProfile: TransportProfile{
+			SupportsOpenAIModels: len(rawModels) > 0,
+			AuthStyle:            "bearer",
+			ModelIDStyle:         detectModelIDStyle(rawModels),
+			KnownAdvisories:      []string{},
+		},
+		ModelProfiles: make([]ModelCapabilityProfile, 0, len(rawModels)),
+	}
+	transport := &profile.TransportProfile
+
+	responsesStatus, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/responses", map[string]any{
+		"model": firstNonEmptyModel(rawModels),
+		"input": "ping",
+	})
+	if err != nil {
+		return nil, err
+	}
+	transport.SupportsOpenAIResponses = isSuccessStatus(responsesStatus)
+
+	for _, rawModel := range rawModels {
+		modelID := strings.TrimSpace(rawModel)
+		if modelID == "" {
+			// A blank ID would be sent as an empty "model" and recorded as an
+			// unidentifiable profile entry, so it is skipped.
+			continue
+		}
+
+		modelProfile := ModelCapabilityProfile{
+			RawModelID:              modelID,
+			NormalizedModelID:       NormalizeModelID(rawModel),
+			CanonicalModelFamily:    CanonicalModelFamily(rawModel),
+			SupportsStream:          "unknown",
+			SupportsTools:           "unknown",
+			SupportsReasoningFields: "unknown",
+		}
+
+		chatStatus, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/chat/completions", map[string]any{
+			"model": modelID,
+			"messages": []map[string]string{
+				{"role": "user", "content": "ping"},
+			},
+			"max_tokens":  8,
+			"temperature": 0,
+		})
+		if err != nil {
+			return nil, err
+		}
+
+		modelProfile.SmokeChatOK = isSuccessStatus(chatStatus)
+		if modelProfile.SmokeChatOK {
+			transport.SupportsOpenAIChatCompletions = true
+		}
+		if chatStatus == http.StatusForbidden {
+			appendAdvisory(&transport.KnownAdvisories, "initial_probe_race_expected")
+		}
+
+		profile.ModelProfiles = append(profile.ModelProfiles, modelProfile)
+	}
+
+	if !transport.SupportsOpenAIResponses && transport.SupportsOpenAIChatCompletions {
+		appendAdvisory(&transport.KnownAdvisories, "responses_unsupported_but_chat_ok")
+	}
+
+	return profile, nil
+}
+
+// probeJSONEndpoint POSTs payload as JSON to path under baseURL and returns the
+// HTTP status code. A non-blank apiKey is sent as a bearer token. The response
+// body is discarded.
+func probeJSONEndpoint(ctx context.Context, baseURL, apiKey, path string, payload any) (int, error) {
+	requestURL, err := joinGatewayPath(baseURL, path)
+	if err != nil {
+		return 0, fmt.Errorf("resolve %s endpoint: %w", path, err)
+	}
+
+	var body bytes.Buffer
+	if payload != nil {
+		if err := json.NewEncoder(&body).Encode(payload); err != nil {
+			return 0, fmt.Errorf("encode %s probe payload: %w", path, err)
+		}
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, &body)
+	if err != nil {
+		return 0, fmt.Errorf("build %s request: %w", path, err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+	if token := strings.TrimSpace(apiKey); token != "" {
+		req.Header.Set("Authorization", "Bearer "+token)
+	}
+
+	resp, err := (&http.Client{Timeout: probeRequestTimeout}).Do(req)
+	if err != nil {
+		return 0, fmt.Errorf("request %s: %w", path, err)
+	}
+	defer drainAndClose(resp.Body)
+
+	return resp.StatusCode, nil
+}
+
+// isSuccessStatus reports whether status is in the 2xx range.
+func isSuccessStatus(status int) bool {
+	return status >= http.StatusOK && status < http.StatusMultipleChoices
+}
+
+// drainAndClose reads at most maxDrainBytes from body and then closes it,
+// which lets the HTTP transport reuse the connection for a later request.
+func drainAndClose(body io.ReadCloser) {
+	// Best effort: callers already hold the status code, so a failed read or
+	// close carries no information they need.
+	_, _ = io.Copy(io.Discard, io.LimitReader(body, maxDrainBytes))
+	_ = body.Close()
+}
+
+// detectModelIDStyle returns "vendor_prefixed" when any model ID contains a
+// "/" and "canonical" otherwise.
+func detectModelIDStyle(rawModels []string) string {
+	for _, rawModel := range rawModels {
+		if strings.Contains(strings.TrimSpace(rawModel), "/") {
+			return "vendor_prefixed"
+		}
+	}
+	return "canonical"
+}
+
+// appendAdvisory adds advisory to values unless it is empty or already present.
+func appendAdvisory(values *[]string, advisory string) {
+	if advisory == "" {
+		return
+	}
+	for _, existing := range *values {
+		if existing == advisory {
+			return
+		}
+	}
+	*values = append(*values, advisory)
+}
+
+// firstNonEmptyModel returns the first non-blank model ID, trimmed. When there
+// is none it returns the placeholder "ping" so the request still carries a
+// model field.
+func firstNonEmptyModel(rawModels []string) string {
+	for _, rawModel := range rawModels {
+		if trimmed := strings.TrimSpace(rawModel); trimmed != "" {
+			return trimmed
+		}
+	}
+	return "ping"
+}
diff --git a/internal/probe/capability_test.go b/internal/probe/capability_test.go
new file mode 100644
index 00000000..8d33f08b
--- /dev/null
+++ b/internal/probe/capability_test.go
@@ -0,0 +1,149 @@
+package probe
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"sync/atomic"
+	"testing"
+)
+
+func TestProbeCapabilities(t *testing.T) {
+	t.Parallel()
+
+	t.Run("responses unsupported but chat works", func(t *testing.T) {
+		t.Parallel()
+
+		// The handler runs on the server's goroutines, so the counter is
+		// accessed atomically and failures are reported with Errorf.
+		var responseCalls int32
+		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			switch r.URL.Path {
+			case "/v1/responses":
+				atomic.AddInt32(&responseCalls, 1)
+				http.Error(w, `{"error":"unsupported"}`, http.StatusForbidden)
+			case "/v1/chat/completions":
+				w.Header().Set("Content-Type", "application/json")
+				_, _ = w.Write([]byte(`{"id":"chatcmpl-1","choices":[{"message":{"content":"pong"}}]}`))
+			default:
+				t.Errorf("unexpected path %q", r.URL.Path)
+				http.NotFound(w, r)
+			}
+		}))
+		defer server.Close()
+
+		profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"kimi-k2.6"})
+		if err != nil {
+			t.Fatalf("ProbeCapabilities() error = %v", err)
+		}
+		if !profile.TransportProfile.SupportsOpenAIChatCompletions {
+			t.Fatal("SupportsOpenAIChatCompletions = false, want true")
+		}
+		if profile.TransportProfile.SupportsOpenAIResponses {
+			t.Fatal("SupportsOpenAIResponses = true, want false")
+		}
+		if atomic.LoadInt32(&responseCalls) == 0 {
+			t.Fatal("responses endpoint was not probed")
+		}
+		if !containsString(profile.TransportProfile.KnownAdvisories, "responses_unsupported_but_chat_ok") {
+			t.Fatalf("KnownAdvisories = %#v, want responses advisory", profile.TransportProfile.KnownAdvisories)
+		}
+	})
+
+	t.Run("records per model capability profile", func(t *testing.T) {
+		t.Parallel()
+
+		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			switch r.URL.Path {
+			case "/v1/responses":
+				w.WriteHeader(http.StatusOK)
+				_, _ = w.Write([]byte(`{"id":"resp_1"}`))
+			case "/v1/chat/completions":
+				// Consume the request body; its content is not inspected here.
+				_, _ = io.Copy(io.Discard, r.Body)
+				w.Header().Set("Content-Type", "application/json")
+				_, _ = w.Write([]byte(`{"id":"chatcmpl-1","choices":[{"message":{"content":"ok"}}]}`))
+			default:
+				t.Errorf("unexpected path %q", r.URL.Path)
+				http.NotFound(w, r)
+			}
+		}))
+		defer server.Close()
+
+		profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"deepseek-ai/DeepSeek-V4-Pro", "kimi-k2.6"})
+		if err != nil {
+			t.Fatalf("ProbeCapabilities() error = %v", err)
+		}
+		if len(profile.ModelProfiles) != 2 {
+			t.Fatalf("len(ModelProfiles) = %d, want 2", len(profile.ModelProfiles))
+		}
+		if profile.ModelProfiles[0].NormalizedModelID != "deepseek-v4-pro" {
+			t.Fatalf("NormalizedModelID = %q, want %q", profile.ModelProfiles[0].NormalizedModelID, "deepseek-v4-pro")
+		}
+		if profile.ModelProfiles[0].CanonicalModelFamily != "deepseek-v4-pro" {
+			t.Fatalf("CanonicalModelFamily = %q, want %q", profile.ModelProfiles[0].CanonicalModelFamily, "deepseek-v4-pro")
+		}
+		if !profile.ModelProfiles[0].SmokeChatOK {
+			t.Fatal("SmokeChatOK = false, want true")
+		}
+	})
+
+	t.Run("skips blank model IDs", func(t *testing.T) {
+		t.Parallel()
+
+		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"id":"ok"}`))
+		}))
+		defer server.Close()
+
+		profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"  ", "kimi-k2.6"})
+		if err != nil {
+			t.Fatalf("ProbeCapabilities() error = %v", err)
+		}
+		if len(profile.ModelProfiles) != 1 {
+			t.Fatalf("len(ModelProfiles) = %d, want 1", len(profile.ModelProfiles))
+		}
+		if profile.ModelProfiles[0].RawModelID != "kimi-k2.6" {
+			t.Fatalf("RawModelID = %q, want %q", profile.ModelProfiles[0].RawModelID, "kimi-k2.6")
+		}
+	})
+
+	t.Run("records initial probe advisory on transient auth race", func(t *testing.T) {
+		t.Parallel()
+
+		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			switch r.URL.Path {
+			case "/v1/responses":
+				http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
+			case "/v1/chat/completions":
+				http.Error(w, `{"error":"warmup"}`, http.StatusForbidden)
+			default:
+				t.Errorf("unexpected path %q", r.URL.Path)
+				http.NotFound(w, r)
+			}
+		}))
+		defer server.Close()
+
+		profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"kimi-k2.6"})
+		if err != nil {
+			t.Fatalf("ProbeCapabilities() error = %v", err)
+		}
+		if !containsString(profile.TransportProfile.KnownAdvisories, "initial_probe_race_expected") {
+			t.Fatalf("KnownAdvisories = %#v, want initial probe advisory", profile.TransportProfile.KnownAdvisories)
+		}
+	})
+}
+
+// containsString reports whether any element of values equals want after
+// surrounding whitespace is trimmed.
+func containsString(values []string, want string) bool {
+	for _, value := range values {
+		if strings.TrimSpace(value) == want {
+			return true
+		}
+	}
+	return false
+}
diff --git a/internal/probe/completion.go b/internal/probe/completion.go
new file mode 100644
index 00000000..a78c034e
--- /dev/null
+++ b/internal/probe/completion.go
@@ -0,0 +1,166 @@
+package probe
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"strings"
+	"time"
+)
+
+// CompletionResult summarizes one smoke request issued by SmokeCompletion.
+type CompletionResult struct {
+	Model          string
+	HTTPStatus     int
+	LatencyMs      int64
+	Classification string
+	Error          string
+}
+
+// ResolveSmokeModel picks the model ID to use for a smoke request and also
+// returns the list produced by RecommendModels for requested and rawModels.
+//
+// Preference order: the first recommended model the profile allows, then the
+// first non-blank raw model the profile allows, then the first non-blank raw
+// model regardless of the profile. Raw model IDs are returned trimmed. An error
+// is returned when none of these yields a model.
+func ResolveSmokeModel(requested []string, rawModels []string, profile *CapabilityProfile) (string, []string, error) {
+	recommended := RecommendModels(requested, rawModels)
+	for _, candidate := range recommended {
+		if profileAllowsSmoke(profile, candidate) {
+			return candidate, recommended, nil
+		}
+	}
+
+	fallback := ""
+	for _, rawModel := range rawModels {
+		modelID := strings.TrimSpace(rawModel)
+		if modelID == "" {
+			continue
+		}
+		if profileAllowsSmoke(profile, rawModel) {
+			return modelID, recommended, nil
+		}
+		if fallback == "" {
+			fallback = modelID
+		}
+	}
+
+	// No model is known to pass; fall back to the first non-blank raw model
+	// rather than failing outright.
+	if fallback != "" {
+		return fallback, recommended, nil
+	}
+	return "", recommended, fmt.Errorf("no smoke model available")
+}
+
+// SmokeCompletion sends one minimal "ping" request for model and reports the
+// HTTP status and latency.
+//
+// The request goes to /v1/responses when profile says that endpoint is
+// supported, and to /v1/chat/completions otherwise (including a nil profile).
+// A non-2xx status is reported through CompletionResult.Error, not the error
+// return; the error return covers a blank model and request failures.
+func SmokeCompletion(ctx context.Context, baseURL, apiKey, model string, profile *CapabilityProfile) (*CompletionResult, error) {
+	model = strings.TrimSpace(model)
+	if model == "" {
+		return nil, fmt.Errorf("model is required")
+	}
+
+	path := "/v1/chat/completions"
+	classification := "chat_completions"
+	payload := map[string]any{
+		"model": model,
+		"messages": []map[string]string{
+			{"role": "user", "content": "ping"},
+		},
+		"max_tokens":  8,
+		"temperature": 0,
+	}
+
+	if profile != nil && profile.TransportProfile.SupportsOpenAIResponses {
+		path = "/v1/responses"
+		classification = "responses"
+		payload = map[string]any{
+			"model": model,
+			"input": "ping",
+		}
+	}
+
+	requestURL, err := joinGatewayPath(baseURL, path)
+	if err != nil {
+		return nil, fmt.Errorf("resolve smoke endpoint: %w", err)
+	}
+
+	var body bytes.Buffer
+	if err := json.NewEncoder(&body).Encode(payload); err != nil {
+		return nil, fmt.Errorf("encode smoke payload: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, &body)
+	if err != nil {
+		return nil, fmt.Errorf("build smoke request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+	if token := strings.TrimSpace(apiKey); token != "" {
+		req.Header.Set("Authorization", "Bearer "+token)
+	}
+
+	startedAt := time.Now()
+	resp, err := (&http.Client{Timeout: probeRequestTimeout}).Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("request smoke completion: %w", err)
+	}
+	// Latency is taken as soon as Do returns, before the body is drained.
+	latency := time.Since(startedAt)
+	defer drainAndClose(resp.Body)
+
+	result := &CompletionResult{
+		Model:          model,
+		HTTPStatus:     resp.StatusCode,
+		LatencyMs:      latency.Milliseconds(),
+		Classification: classification,
+	}
+	if !isSuccessStatus(resp.StatusCode) {
+		result.Error = fmt.Sprintf("unexpected_status_%d", resp.StatusCode)
+	}
+	return result, nil
+}
+
+// profileAllowsSmoke reports whether profile considers rawModel usable for a
+// smoke request.
+//
+// With no profile, or a profile without model entries, every model is allowed.
+// Otherwise an entry with the same trimmed raw ID decides; failing that, the
+// first entry with the same non-empty canonical family decides. A model that
+// matches no entry is not allowed.
+func profileAllowsSmoke(profile *CapabilityProfile, rawModel string) bool {
+	if profile == nil || len(profile.ModelProfiles) == 0 {
+		return true
+	}
+
+	// An exact ID match is checked first so an earlier sibling from the same
+	// family cannot override the result recorded for this very model.
+	targetRaw := strings.TrimSpace(rawModel)
+	for _, modelProfile := range profile.ModelProfiles {
+		if strings.TrimSpace(modelProfile.RawModelID) == targetRaw {
+			return modelProfile.SmokeChatOK
+		}
+	}
+
+	targetCanonical := CanonicalModelFamily(rawModel)
+	if targetCanonical == "" {
+		// NOTE(review): assumes CanonicalModelFamily may return "" for IDs it
+		// does not recognize; two empty families must not count as a match.
+		return false
+	}
+	for _, modelProfile := range profile.ModelProfiles {
+		if modelProfile.CanonicalModelFamily == targetCanonical {
+			return modelProfile.SmokeChatOK
+		}
+	}
+
+	return false
+}
diff --git a/internal/probe/completion_test.go b/internal/probe/completion_test.go
new file mode 100644
index 00000000..691aea0f
--- /dev/null
+++ b/internal/probe/completion_test.go
@@ -0,0 +1,100 @@
+package probe
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
+
+func TestResolveSmokeModel(t *testing.T) {
+	t.Parallel()
+
+	t.Run("uses requested alias when matched", func(t *testing.T) {
+		t.Parallel()
+
+		profile := &CapabilityProfile{
+			ModelProfiles: []ModelCapabilityProfile{
+				{RawModelID: "kimi-k2.6", CanonicalModelFamily: "kimi-2.6", SmokeChatOK: true},
+			},
+		}
+
+		model, recommended, err := ResolveSmokeModel([]string{"kimi 2.6"}, []string{"kimi-k2.6"}, profile)
+		if err != nil {
+			t.Fatalf("ResolveSmokeModel() error = %v", err)
+		}
+		if model != "kimi-k2.6" {
+			t.Fatalf("ResolveSmokeModel() model = %q, want %q", model, "kimi-k2.6")
+		}
+		if len(recommended) != 1 || recommended[0] != "kimi-k2.6" {
+			t.Fatalf("recommended = %#v, want discovered alias", recommended)
+		}
+	})
+
+	t.Run("falls back to discovered model with smoke support", func(t *testing.T) {
+		t.Parallel()
+
+		profile := &CapabilityProfile{
+			ModelProfiles: []ModelCapabilityProfile{
+				{RawModelID: "deepseek-ai/DeepSeek-V4-Pro", CanonicalModelFamily: "deepseek-v4-pro", SmokeChatOK: true},
+			},
+		}
+
+		model, recommended, err := ResolveSmokeModel([]string{"unknown"}, []string{"deepseek-ai/DeepSeek-V4-Pro"}, profile)
+		if err != nil {
+			t.Fatalf("ResolveSmokeModel() error = %v", err)
+		}
+		if model != "deepseek-ai/DeepSeek-V4-Pro" {
+			t.Fatalf("ResolveSmokeModel() model = %q, want discovered model", model)
+		}
+		if len(recommended) != 0 {
+			t.Fatalf("recommended = %#v, want empty for unknown request", recommended)
+		}
+	})
+}
+
+func TestSmokeCompletion(t *testing.T) {
+	t.Parallel()
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// The handler runs off the test goroutine, so failures are reported
+		// with Errorf and an error response instead of Fatalf.
+		if r.URL.Path != "/v1/chat/completions" {
+			t.Errorf("path = %q, want chat completions fallback", r.URL.Path)
+			http.NotFound(w, r)
+			return
+		}
+		var payload map[string]any
+		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
+			t.Errorf("decode request body: %v", err)
+			http.Error(w, "bad request", http.StatusBadRequest)
+			return
+		}
+		if payload["model"] != "kimi-k2.6" {
+			t.Errorf("payload model = %v, want kimi-k2.6", payload["model"])
+		}
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{"id":"chatcmpl-1","choices":[{"message":{"content":"pong"}}]}`))
+	}))
+	defer server.Close()
+
+	profile := &CapabilityProfile{
+		TransportProfile: TransportProfile{
+			SupportsOpenAIChatCompletions: true,
+			SupportsOpenAIResponses:       false,
+			KnownAdvisories:               []string{"responses_unsupported_but_chat_ok"},
+		},
+	}
+
+	result, err := SmokeCompletion(context.Background(), server.URL, "sk-test", "kimi-k2.6", profile)
+	if err != nil {
+		t.Fatalf("SmokeCompletion() error = %v", err)
+	}
+	if result.HTTPStatus != http.StatusOK {
+		t.Fatalf("HTTPStatus = %d, want %d", result.HTTPStatus, http.StatusOK)
+	}
+	if result.Classification != "chat_completions" {
+		t.Fatalf("Classification = %q, want %q", result.Classification, "chat_completions")
+	}
+}