Files
llm-intelligence/internal/collectors/provider_mapper_test.go
2026-05-13 14:42:45 +08:00

168 lines
4.0 KiB
Go

// internal/collectors/provider_mapper_test.go
package collectors
import (
"testing"
)
// TestMapOpenRouterID exercises MapOpenRouterID over known providers, the
// ":free" suffix, an unregistered provider (which falls back to the raw
// prefix and country "unknown"), and malformed IDs that must be rejected.
func TestMapOpenRouterID(t *testing.T) {
	cases := []struct {
		name        string
		rawID       string
		wantErr     bool
		wantProvID  string
		wantProvCN  string
		wantModel   string
		wantFree    bool
		wantCountry string
	}{
		{name: "OpenAI GPT-4o", rawID: "openai/gpt-4o", wantProvID: "openai", wantProvCN: "OpenAI", wantModel: "gpt-4o", wantFree: false, wantCountry: "US"},
		{name: "Anthropic Claude free", rawID: "anthropic/claude-3.5-sonnet:free", wantProvID: "anthropic", wantProvCN: "Anthropic", wantModel: "claude-3.5-sonnet", wantFree: true, wantCountry: "US"},
		{name: "DeepSeek V3", rawID: "deepseek/deepseek-v3", wantProvID: "deepseek", wantProvCN: "深度求索", wantModel: "deepseek-v3", wantFree: false, wantCountry: "CN"},
		{name: "Moonshot Kimi", rawID: "moonshotai/kimi-k2", wantProvID: "moonshot", wantProvCN: "月之暗面", wantModel: "kimi-k2", wantFree: false, wantCountry: "CN"},
		{name: "Unknown provider fallback", rawID: "some-new-ai/model-x", wantProvID: "some-new-ai", wantProvCN: "some-new-ai", wantModel: "model-x", wantFree: false, wantCountry: "unknown"},
		{name: "Empty ID", rawID: "", wantErr: true},
		{name: "Invalid format no slash", rawID: "invalid-id", wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := MapOpenRouterID(tc.rawID)
			if failed := err != nil; failed != tc.wantErr {
				t.Fatalf("MapOpenRouterID(%q) error = %v, wantErr %v", tc.rawID, err, tc.wantErr)
			}
			if tc.wantErr {
				// Expected failure observed; the result is meaningless.
				return
			}

			// String fields are compared in the same order they are reported.
			stringChecks := []struct {
				field, actual, expected string
			}{
				{"Provider.ID", got.Provider.ID, tc.wantProvID},
				{"Provider.NameCN", got.Provider.NameCN, tc.wantProvCN},
				{"ModelName", got.ModelName, tc.wantModel},
			}
			for _, c := range stringChecks {
				if c.actual != c.expected {
					t.Errorf("%s = %q, want %q", c.field, c.actual, c.expected)
				}
			}
			if got.IsFree != tc.wantFree {
				t.Errorf("IsFree = %v, want %v", got.IsFree, tc.wantFree)
			}
			if got.Provider.Country != tc.wantCountry {
				t.Errorf("Country = %q, want %q", got.Provider.Country, tc.wantCountry)
			}
		})
	}
}
// TestProviderMapCompleteness checks that every predefined provider ID has an
// entry in providerNameMap and that ProviderCount reports at least 20.
func TestProviderMapCompleteness(t *testing.T) {
	// Every predefined provider mapping must be present.
	for _, providerID := range []string{
		"openai", "anthropic", "google", "meta", "xai",
		"deepseek", "qwen", "moonshot", "zhipu", "bytedance",
		"baidu", "tencent", "alibaba", "mistral", "cohere",
		"ai21", "perplexity", "nvidia", "microsoft", "openrouter",
	} {
		if _, found := providerNameMap[providerID]; !found {
			t.Errorf("Required provider %q not found in providerNameMap", providerID)
		}
	}

	// The total number of known providers must be at least 20.
	if count := ProviderCount(); count < 20 {
		t.Errorf("ProviderCount() = %d, want >= 20", count)
	}
}
func TestRegisterProvider(t *testing.T) {
// 注册新厂商
RegisterProvider("test-corp", ProviderInfo{
ID: "test-corp",
Name: "Test Corp",
NameCN: "测试公司",
Country: "CN",
})
got, err := MapOpenRouterID("test-corp/model-1")
if err != nil {
t.Fatalf("MapOpenRouterID after RegisterProvider failed: %v", err)
}
if got.Provider.NameCN != "测试公司" {
t.Errorf("After RegisterProvider, NameCN = %q, want %q", got.Provider.NameCN, "测试公司")
}
}
// TestGetAllProviderNames checks that GetAllProviderNames returns a non-empty
// result that includes the "openai" provider.
func TestGetAllProviderNames(t *testing.T) {
	names := GetAllProviderNames()
	if len(names) == 0 {
		t.Error("GetAllProviderNames() returned empty slice")
	}

	// The well-known "openai" provider must be listed.
	containsOpenAI := func() bool {
		for _, candidate := range names {
			if candidate == "openai" {
				return true
			}
		}
		return false
	}
	if !containsOpenAI() {
		t.Error("GetAllProviderNames() missing 'openai'")
	}
}
// mapOpenRouterIDSink keeps the benchmark's last result reachable so the
// measured call cannot be treated as dead code.
var mapOpenRouterIDSink any

// BenchmarkMapOpenRouterID measures mapping a typical OpenRouter model ID.
//
// The input is validated once before timing starts. A regression that makes
// the call fail is therefore reported, instead of the benchmark silently
// measuring the error path.
func BenchmarkMapOpenRouterID(b *testing.B) {
	const rawID = "openai/gpt-4o"

	got, err := MapOpenRouterID(rawID)
	if err != nil {
		b.Fatalf("MapOpenRouterID(%q) failed: %v", rawID, err)
	}

	b.ReportAllocs()
	b.ResetTimer() // exclude the validation call above
	for i := 0; i < b.N; i++ {
		got, err = MapOpenRouterID(rawID)
	}
	b.StopTimer()

	if err != nil {
		b.Fatalf("MapOpenRouterID(%q) failed during benchmark: %v", rawID, err)
	}
	mapOpenRouterIDSink = got
}