forked from niuniu/llm-intelligence
109 lines
2.7 KiB
Go
109 lines
2.7 KiB
Go
|
|
//go:build llm_script
|
||
|
|
|
||
|
|
package main
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"testing"
|
||
|
|
)
|
||
|
|
|
||
|
|
type fakeSource struct {
|
||
|
|
name string
|
||
|
|
prices []ModelPricing
|
||
|
|
err error
|
||
|
|
}
|
||
|
|
|
||
|
|
// Name returns the name the fake was constructed with.
func (s fakeSource) Name() string {
	return s.name
}
|
||
|
|
|
||
|
|
// FetchPricing returns the canned prices and error stored on the fake,
// unchanged; it performs no I/O.
func (s fakeSource) FetchPricing() ([]ModelPricing, error) {
	return s.prices, s.err
}
|
||
|
|
|
||
|
|
// SourceType always returns "official", regardless of the fake's fields.
func (s fakeSource) SourceType() string {
	return "official"
}
|
||
|
|
|
||
|
|
func TestBuildSourcesFiltersRequestedNames(t *testing.T) {
|
||
|
|
sources, err := buildSources("", []string{"moonshot", "openai"})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("buildSources returned error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(sources) != 2 {
|
||
|
|
t.Fatalf("expected 2 sources, got %d", len(sources))
|
||
|
|
}
|
||
|
|
|
||
|
|
if sources[0].Name() != "Moonshot" || sources[1].Name() != "OpenAI" {
|
||
|
|
t.Fatalf("unexpected source order: %s, %s", sources[0].Name(), sources[1].Name())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestBuildSourcesRejectsUnknownNames(t *testing.T) {
|
||
|
|
_, err := buildSources("", []string{"moonshot", "unknown"})
|
||
|
|
if err == nil {
|
||
|
|
t.Fatal("expected error for unknown source")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRunCollectorDryRunSkipsDatabaseWrite(t *testing.T) {
|
||
|
|
cfg := runConfig{DryRun: true}
|
||
|
|
var out bytes.Buffer
|
||
|
|
writeCalled := false
|
||
|
|
|
||
|
|
err := runCollector(
|
||
|
|
cfg,
|
||
|
|
[]DataSource{
|
||
|
|
fakeSource{
|
||
|
|
name: "Moonshot",
|
||
|
|
prices: []ModelPricing{
|
||
|
|
{ModelID: "kimi-k2.6", ProviderCountry: "CN", Currency: "CNY"},
|
||
|
|
{ModelID: "kimi-k2-0905-preview", ProviderCountry: "CN", Currency: "CNY"},
|
||
|
|
},
|
||
|
|
},
|
||
|
|
fakeSource{
|
||
|
|
name: "OpenAI",
|
||
|
|
prices: []ModelPricing{
|
||
|
|
{ModelID: "gpt-5.5", ProviderCountry: "US", Currency: "USD"},
|
||
|
|
},
|
||
|
|
},
|
||
|
|
},
|
||
|
|
func([]ModelPricing) error {
|
||
|
|
writeCalled = true
|
||
|
|
return nil
|
||
|
|
},
|
||
|
|
&out,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("runCollector returned error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if writeCalled {
|
||
|
|
t.Fatal("expected dry-run to skip database write")
|
||
|
|
}
|
||
|
|
|
||
|
|
output := out.String()
|
||
|
|
if output == "" {
|
||
|
|
t.Fatal("expected dry-run summary output")
|
||
|
|
}
|
||
|
|
if !bytes.Contains(out.Bytes(), []byte("sources=2")) {
|
||
|
|
t.Fatalf("expected sources summary, got %q", output)
|
||
|
|
}
|
||
|
|
if !bytes.Contains(out.Bytes(), []byte("models=3")) {
|
||
|
|
t.Fatalf("expected model summary, got %q", output)
|
||
|
|
}
|
||
|
|
if !bytes.Contains(out.Bytes(), []byte("currencies=CNY:2,USD:1")) {
|
||
|
|
t.Fatalf("expected currency summary, got %q", output)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestPricingMetadataClassifiesSourceType(t *testing.T) {
|
||
|
|
freeTier := pricingMetadata(ModelPricing{OperatorType: "official", IsFree: true})
|
||
|
|
if freeTier.SourceType != "free_tier" {
|
||
|
|
t.Fatalf("expected free_tier, got %q", freeTier.SourceType)
|
||
|
|
}
|
||
|
|
if freeTier.FreeQuota == "" {
|
||
|
|
t.Fatal("expected free tier quota description")
|
||
|
|
}
|
||
|
|
|
||
|
|
reseller := pricingMetadata(ModelPricing{OperatorType: "reseller"})
|
||
|
|
if reseller.SourceType != "reseller" {
|
||
|
|
t.Fatalf("expected reseller, got %q", reseller.SourceType)
|
||
|
|
}
|
||
|
|
}
|