// Reconstructed from the whitespace-collapsed patch
// "feat(probe): add model discovery and canonical family normalization"
// (commit 2bc7554cf8fc4e93f2e7e30124bcc013f6088bea).
//
// ---- internal/probe/aliases.go (package probe; imports: strings, unicode) ----

// AliasResult records the naming forms derived from one discovered model ID.
type AliasResult struct {
	// Raw is the discovered ID with surrounding whitespace trimmed.
	Raw string
	// Normalized is NormalizeModelID(Raw).
	Normalized string
	// Canonical is CanonicalModelFamily(Raw).
	Canonical string
}

// NormalizeModelID lowercases raw, drops any vendor prefix (everything up to
// the last "/"), and rewrites every run of whitespace, "_" and "-" into a
// single "-". Leading and trailing separators are removed. It returns "" when
// nothing but separators remains.
func NormalizeModelID(raw string) string {
	id := strings.ToLower(strings.TrimSpace(raw))
	// Ignore trailing slashes so "vendor/model/" still yields "model" rather
	// than an empty ID.
	id = strings.TrimRight(id, "/")
	if idx := strings.LastIndex(id, "/"); idx >= 0 {
		id = id[idx+1:]
	}

	var b strings.Builder
	b.Grow(len(id))
	pendingDash := false
	for _, r := range id {
		if r == '-' || r == '_' || unicode.IsSpace(r) {
			// Only remember a separator once real content has been written;
			// this trims leading separators and collapses runs.
			pendingDash = b.Len() > 0
			continue
		}
		if pendingDash {
			b.WriteByte('-')
			pendingDash = false
		}
		b.WriteRune(r)
	}
	// A separator still pending here was trailing and is dropped.
	return b.String()
}

// CanonicalModelID returns the normalized form of raw. It is currently
// identical to NormalizeModelID.
func CanonicalModelID(raw string) string {
	return NormalizeModelID(raw)
}

// CanonicalModelFamily normalizes raw and then folds known spelling variants
// into one family name. At present that means "kimi-k2", "kimi-k2.<x>" and
// "kimi-k2-<x>" become "kimi-2", "kimi-2.<x>" and "kimi-2-<x>".
func CanonicalModelFamily(raw string) string {
	const kimiK2 = "kimi-k2"

	normalized := NormalizeModelID(raw)
	if !strings.HasPrefix(normalized, kimiK2) {
		return normalized
	}
	rest := normalized[len(kimiK2):]
	// Require a version boundary so an ID such as "kimi-k25" is left alone.
	if rest == "" || rest[0] == '.' || rest[0] == '-' {
		return "kimi-2" + rest
	}
	return normalized
}

// BuildAliasTable maps every naming form of each discovered model to its
// AliasResult. Blank entries in rawModels are skipped.
//
// Keys are inserted in priority tiers: raw IDs, then normalized IDs, then
// canonical family names, then punctuation-free lookup keys. Within a tier
// the earlier model wins. Tiering guarantees that an exact discovered ID is
// never shadowed by a looser alias of a model that merely appears earlier in
// rawModels.
func BuildAliasTable(rawModels []string) map[string]AliasResult {
	results := make([]AliasResult, 0, len(rawModels))
	for _, rawModel := range rawModels {
		rawModel = strings.TrimSpace(rawModel)
		if rawModel == "" {
			continue
		}
		results = append(results, AliasResult{
			Raw:        rawModel,
			Normalized: NormalizeModelID(rawModel),
			Canonical:  CanonicalModelFamily(rawModel),
		})
	}

	tiers := []func(AliasResult) string{
		func(r AliasResult) string { return r.Raw },
		func(r AliasResult) string { return r.Normalized },
		func(r AliasResult) string { return r.Canonical },
		func(r AliasResult) string { return lookupKey(r.Canonical) },
	}

	table := make(map[string]AliasResult, len(results)*len(tiers))
	for _, keyOf := range tiers {
		for _, result := range results {
			key := keyOf(result)
			if key == "" {
				continue
			}
			if _, exists := table[key]; !exists {
				table[key] = result
			}
		}
	}
	return table
}

// resolveAlias looks requested up in table, trying the most specific key
// first: the trimmed string itself, its normalized form, its canonical family
// and finally its punctuation-free lookup key.
func resolveAlias(table map[string]AliasResult, requested string) (AliasResult, bool) {
	canonical := CanonicalModelFamily(requested)
	candidates := [...]string{
		strings.TrimSpace(requested),
		NormalizeModelID(requested),
		canonical,
		lookupKey(canonical),
	}
	for _, key := range candidates {
		if key == "" {
			continue
		}
		if result, ok := table[key]; ok {
			return result, true
		}
	}
	return AliasResult{}, false
}

// ResolveRequestedModel returns the discovered raw ID that best matches
// requested. ok is false when no discovered model matches under any naming
// form.
func ResolveRequestedModel(requested string, rawModels []string) (resolved string, ok bool) {
	result, ok := resolveAlias(BuildAliasTable(rawModels), requested)
	if !ok {
		return "", false
	}
	return result.Raw, true
}

// RecommendModels resolves each requested name against rawModels and returns
// the matching discovered raw IDs in request order, without duplicates.
// Requests that match nothing are omitted.
func RecommendModels(requested []string, rawModels []string) []string {
	table := BuildAliasTable(rawModels)
	recommended := make([]string, 0, len(requested))
	seen := make(map[string]struct{}, len(requested))

	for _, requestedModel := range requested {
		result, ok := resolveAlias(table, requestedModel)
		if !ok {
			continue
		}
		if _, exists := seen[result.Raw]; exists {
			continue
		}
		seen[result.Raw] = struct{}{}
		recommended = append(recommended, result.Raw)
	}

	return recommended
}

// lookupKey reduces raw to its canonical family name with everything except
// letters and digits removed, so "Kimi K2.6" and "kimi-2.6" share one key.
func lookupKey(raw string) string {
	canonical := CanonicalModelFamily(raw)
	if canonical == "" {
		return ""
	}

	var builder strings.Builder
	builder.Grow(len(canonical))
	for _, r := range canonical {
		if unicode.IsLetter(r) || unicode.IsDigit(r) {
			builder.WriteRune(unicode.ToLower(r))
		}
	}
	return builder.String()
}

// ---- internal/probe/aliases_test.go (package probe; imports: reflect, testing) ----

func TestCanonicalModelFamily(t *testing.T) {
	t.Parallel()

	t.Run("kimi aliases collapse into one family", func(t *testing.T) {
		t.Parallel()

		variants := []string{"kimi 2.6", "kimi-2.6", "kimi-k2.6", "Kimi-K2.6"}
		for _, variant := range variants {
			if got := CanonicalModelFamily(variant); got != "kimi-2.6" {
				t.Fatalf("CanonicalModelFamily(%q) = %q, want %q", variant, got, "kimi-2.6")
			}
		}
	})

	t.Run("deepseek vendor prefix normalizes away", func(t *testing.T) {
		t.Parallel()

		if got := NormalizeModelID("deepseek-ai/DeepSeek-V4-Pro"); got != "deepseek-v4-pro" {
			t.Fatalf("NormalizeModelID() = %q, want %q", got, "deepseek-v4-pro")
		}
		if got := CanonicalModelID("deepseek-ai/DeepSeek-V4-Pro"); got != "deepseek-v4-pro" {
			t.Fatalf("CanonicalModelID() = %q, want %q", got, "deepseek-v4-pro")
		}
	})

	t.Run("alias table and requested model resolution prefer discovered ids", func(t *testing.T) {
		t.Parallel()

		rawModels := []string{"deepseek-ai/DeepSeek-V4-Pro", "kimi-k2.6"}
		table := BuildAliasTable(rawModels)
		if got := table["deepseek-v4-pro"].Canonical; got != "deepseek-v4-pro" {
			t.Fatalf("alias canonical = %q, want %q", got, "deepseek-v4-pro")
		}

		resolved, ok := ResolveRequestedModel("DeepSeek V4 Pro", rawModels)
		if !ok {
			t.Fatal("ResolveRequestedModel() ok = false, want true")
		}
		if resolved != "deepseek-ai/DeepSeek-V4-Pro" {
			t.Fatalf("ResolveRequestedModel() = %q, want discovered raw id", resolved)
		}
	})

	t.Run("exact discovered id is not shadowed by an earlier family alias", func(t *testing.T) {
		t.Parallel()

		rawModels := []string{"kimi-k2.6", "kimi-2.6"}
		resolved, ok := ResolveRequestedModel("kimi-2.6", rawModels)
		if !ok || resolved != "kimi-2.6" {
			t.Fatalf("ResolveRequestedModel() = %q, %v, want %q, true", resolved, ok, "kimi-2.6")
		}
	})

	t.Run("recommend models returns canonical discovered candidates", func(t *testing.T) {
		t.Parallel()

		rawModels := []string{"kimi-k2.6", "deepseek-ai/DeepSeek-V4-Pro"}
		got := RecommendModels([]string{"kimi 2.6", "DeepSeek V4 Pro", "unknown"}, rawModels)
		want := []string{"kimi-k2.6", "deepseek-ai/DeepSeek-V4-Pro"}
		if !reflect.DeepEqual(got, want) {
			t.Fatalf("RecommendModels() = %#v, want %#v", got, want)
		}
	})
}
// ---- internal/probe/models.go (package probe; imports: context,
// encoding/json, errors, fmt, net/http, net/url, strings, time) ----

// ErrAuthFailed is returned by ProviderModels when the upstream answers with
// 401 Unauthorized or 403 Forbidden.
var ErrAuthFailed = errors.New("upstream auth failed")

// ModelsResult is the outcome of one model-discovery request.
type ModelsResult struct {
	// RawModels holds the non-blank model IDs reported by the upstream,
	// whitespace-trimmed, in response order. It is empty, never nil.
	RawModels []string
	// HTTPStatus is the status code of the upstream response.
	HTTPStatus int
	// LatencyMs is the time from sending the request until the response
	// headers arrived, in milliseconds.
	LatencyMs int64
	// Error is "" on success, "auth_failed", or "unexpected_status_<code>".
	Error string
}

// ProviderModels fetches "<baseURL>/v1/models" and returns the model IDs it
// lists. A non-blank apiKey is sent as a bearer token.
//
// The HTTP status is classified before the body is decoded, so an auth or
// server failure with a non-JSON body (HTML error page, plain text) is still
// reported as such. For those failures both a populated result and a non-nil
// error are returned; the error is ErrAuthFailed for 401/403. Transport
// failures and undecodable 2xx bodies return a nil result.
func ProviderModels(ctx context.Context, baseURL, apiKey string) (*ModelsResult, error) {
	client := &http.Client{Timeout: 15 * time.Second}

	requestURL, err := joinGatewayPath(baseURL, "/v1/models")
	if err != nil {
		return nil, fmt.Errorf("resolve models endpoint: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
	if err != nil {
		return nil, fmt.Errorf("build models request: %w", err)
	}
	if token := strings.TrimSpace(apiKey); token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	startedAt := time.Now()
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request models: %w", err)
	}
	defer resp.Body.Close()

	result := &ModelsResult{
		RawModels:  []string{},
		HTTPStatus: resp.StatusCode,
		LatencyMs:  time.Since(startedAt).Milliseconds(),
	}

	// Classify first: error responses are frequently not JSON.
	if code, statusErr := classifyModelsStatus(resp.StatusCode); statusErr != nil {
		result.Error = code
		return result, statusErr
	}

	var payload struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return nil, fmt.Errorf("decode models response: %w", err)
	}

	for _, item := range payload.Data {
		if modelID := strings.TrimSpace(item.ID); modelID != "" {
			result.RawModels = append(result.RawModels, modelID)
		}
	}

	return result, nil
}

// classifyModelsStatus maps an HTTP status to the ModelsResult.Error code and
// the error ProviderModels should return. Both are zero for 2xx statuses.
func classifyModelsStatus(status int) (code string, err error) {
	switch {
	case status == http.StatusUnauthorized || status == http.StatusForbidden:
		return "auth_failed", ErrAuthFailed
	case status < http.StatusOK || status >= http.StatusMultipleChoices:
		return fmt.Sprintf("unexpected_status_%d", status),
			fmt.Errorf("models endpoint returned status %d", status)
	default:
		return "", nil
	}
}

// joinGatewayPath appends path to baseURL and returns the resulting URL.
//
// The base URL's own path is preserved, so a gateway mounted under a prefix
// ("https://host/api") yields "https://host/api/v1/models". Leading segments
// of path that the base already ends with are not repeated:
// "https://host/openai/v1" plus "/v1/models" gives
// "https://host/openai/v1/models". Any query string or fragment on baseURL is
// dropped. baseURL must contain a scheme and a host.
func joinGatewayPath(baseURL, path string) (string, error) {
	parsedURL, err := url.Parse(strings.TrimSpace(baseURL))
	if err != nil {
		return "", err
	}
	if parsedURL.Scheme == "" || parsedURL.Host == "" {
		return "", errors.New("base url must include scheme and host")
	}

	suffix := strings.TrimSpace(path)
	if !strings.HasPrefix(suffix, "/") {
		suffix = "/" + suffix
	}
	prefix := strings.TrimRight(parsedURL.Path, "/")

	// Find the longest run of whole leading segments of suffix that prefix
	// already ends with, and skip it. suffix[:i] always starts with "/", so a
	// match is guaranteed to sit on a segment boundary in prefix as well.
	for i := len(suffix); i > 0; i-- {
		if i < len(suffix) && suffix[i] != '/' {
			continue
		}
		if strings.HasSuffix(prefix, suffix[:i]) {
			suffix = suffix[i:]
			break
		}
	}

	endpoint := url.URL{
		Scheme: parsedURL.Scheme,
		User:   parsedURL.User,
		Host:   parsedURL.Host,
		Path:   prefix + suffix,
	}
	return endpoint.String(), nil
}

// ---- internal/probe/models_test.go (package probe; imports: context,
// errors, net/http, net/http/httptest, testing) ----

func TestProviderModels(t *testing.T) {
	t.Parallel()

	t.Run("parses openai models response", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Handlers run on the server goroutine, where FailNow/Fatalf must
			// not be called; record failures with Errorf instead.
			if r.URL.Path != "/v1/models" {
				t.Errorf("path = %q, want /v1/models", r.URL.Path)
			}
			if got := r.Header.Get("Authorization"); got != "Bearer sk-test" {
				t.Errorf("authorization = %q, want bearer auth", got)
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"data":[{"id":" kimi 2.6 "},{"id":"deepseek-ai/DeepSeek-V4-Pro"}]}`))
		}))
		defer server.Close()

		result, err := ProviderModels(context.Background(), server.URL, "sk-test")
		if err != nil {
			t.Fatalf("ProviderModels() error = %v", err)
		}
		if result.HTTPStatus != http.StatusOK {
			t.Fatalf("HTTPStatus = %d, want %d", result.HTTPStatus, http.StatusOK)
		}
		if len(result.RawModels) != 2 {
			t.Fatalf("len(RawModels) = %d, want 2", len(result.RawModels))
		}
		if result.RawModels[0] != "kimi 2.6" || result.RawModels[1] != "deepseek-ai/DeepSeek-V4-Pro" {
			t.Fatalf("RawModels = %#v, want normalized trim order", result.RawModels)
		}
	})

	t.Run("returns empty slice when upstream has no models", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"data":[]}`))
		}))
		defer server.Close()

		result, err := ProviderModels(context.Background(), server.URL, "sk-empty")
		if err != nil {
			t.Fatalf("ProviderModels() error = %v", err)
		}
		if result.HTTPStatus != http.StatusOK {
			t.Fatalf("HTTPStatus = %d, want %d", result.HTTPStatus, http.StatusOK)
		}
		if len(result.RawModels) != 0 {
			t.Fatalf("len(RawModels) = %d, want 0", len(result.RawModels))
		}
	})

	t.Run("classifies auth failure", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
		}))
		defer server.Close()

		_, err := ProviderModels(context.Background(), server.URL, "sk-nope")
		if !errors.Is(err, ErrAuthFailed) {
			t.Fatalf("ProviderModels() error = %v, want ErrAuthFailed", err)
		}
	})

	t.Run("classifies auth failure with non-json body", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, "<html>Unauthorized</html>", http.StatusUnauthorized)
		}))
		defer server.Close()

		result, err := ProviderModels(context.Background(), server.URL, "sk-nope")
		if !errors.Is(err, ErrAuthFailed) {
			t.Fatalf("ProviderModels() error = %v, want ErrAuthFailed", err)
		}
		if result == nil || result.Error != "auth_failed" || result.HTTPStatus != http.StatusUnauthorized {
			t.Fatalf("ProviderModels() result = %#v, want auth_failed with status 401", result)
		}
	})

	t.Run("keeps gateway base path", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/openai/v1/models" {
				t.Errorf("path = %q, want /openai/v1/models", r.URL.Path)
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"data":[{"id":"kimi-k2.6"}]}`))
		}))
		defer server.Close()

		result, err := ProviderModels(context.Background(), server.URL+"/openai/v1", "sk-test")
		if err != nil {
			t.Fatalf("ProviderModels() error = %v", err)
		}
		if len(result.RawModels) != 1 || result.RawModels[0] != "kimi-k2.6" {
			t.Fatalf("RawModels = %#v, want one kimi-k2.6 entry", result.RawModels)
		}
	})
}