feat(probe): add model discovery and canonical family normalization
This commit is contained in:
97
internal/probe/models.go
Normal file
97
internal/probe/models.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package probe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrAuthFailed = errors.New("upstream auth failed")
|
||||
|
||||
type ModelsResult struct {
|
||||
RawModels []string
|
||||
HTTPStatus int
|
||||
LatencyMs int64
|
||||
Error string
|
||||
}
|
||||
|
||||
func ProviderModels(ctx context.Context, baseURL, apiKey string) (*ModelsResult, error) {
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
|
||||
requestURL, err := joinGatewayPath(baseURL, "/v1/models")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve models endpoint: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build models request: %w", err)
|
||||
}
|
||||
if token := strings.TrimSpace(apiKey); token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request models: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
result := &ModelsResult{
|
||||
RawModels: []string{},
|
||||
HTTPStatus: resp.StatusCode,
|
||||
LatencyMs: time.Since(startedAt).Milliseconds(),
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
Error any `json:"error"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return nil, fmt.Errorf("decode models response: %w", err)
|
||||
}
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
result.Error = "auth_failed"
|
||||
return result, ErrAuthFailed
|
||||
}
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
result.Error = fmt.Sprintf("unexpected_status_%d", resp.StatusCode)
|
||||
return result, fmt.Errorf("models endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
for _, item := range payload.Data {
|
||||
if modelID := strings.TrimSpace(item.ID); modelID != "" {
|
||||
result.RawModels = append(result.RawModels, modelID)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func joinGatewayPath(baseURL, path string) (string, error) {
|
||||
parsedURL, err := url.Parse(strings.TrimSpace(baseURL))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
return "", fmt.Errorf("base url must include scheme and host")
|
||||
}
|
||||
|
||||
resolvedPath := strings.TrimSpace(path)
|
||||
if !strings.HasPrefix(resolvedPath, "/") {
|
||||
resolvedPath = "/" + resolvedPath
|
||||
}
|
||||
|
||||
return parsedURL.ResolveReference(&url.URL{Path: resolvedPath}).String(), nil
|
||||
}
|
||||
Reference in New Issue
Block a user