475 lines
13 KiB
Go
475 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestModelsHandlerReturnsFlatPricingFields(t *testing.T) {
|
|
mux := newMux(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return []modelResponse{{
|
|
ID: "mobile-cloud-huabei-huhehaote-cosyvoice",
|
|
Name: "CosyVoice",
|
|
Provider: "Alibaba",
|
|
ProviderCN: "阿里云",
|
|
Modality: "audio",
|
|
PricingMode: "flat",
|
|
PriceUnit: "10k_characters",
|
|
FlatPrice: 2,
|
|
Currency: "CNY",
|
|
IsFree: false,
|
|
DataConfidence: "official",
|
|
}}, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/models", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", rec.Code)
|
|
}
|
|
|
|
var payload struct {
|
|
Data []modelResponse `json:"data"`
|
|
}
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
|
t.Fatalf("unmarshal response: %v", err)
|
|
}
|
|
if len(payload.Data) != 1 {
|
|
t.Fatalf("expected 1 model, got %d", len(payload.Data))
|
|
}
|
|
got := payload.Data[0]
|
|
if got.PricingMode != "flat" || got.PriceUnit != "10k_characters" || got.FlatPrice != 2 {
|
|
t.Fatalf("unexpected flat pricing payload: %+v", got)
|
|
}
|
|
}
|
|
|
|
func TestModelsHandlerReturnsJSONErrorEnvelope(t *testing.T) {
|
|
mux := newMux(
|
|
nil,
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/models", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusServiceUnavailable {
|
|
t.Fatalf("expected status 503, got %d", rec.Code)
|
|
}
|
|
|
|
var payload struct {
|
|
Error struct {
|
|
Code string `json:"code"`
|
|
Message string `json:"message"`
|
|
} `json:"error"`
|
|
}
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
|
t.Fatalf("unmarshal error response: %v", err)
|
|
}
|
|
if payload.Error.Code != "database_not_configured" {
|
|
t.Fatalf("unexpected error code: %q", payload.Error.Code)
|
|
}
|
|
}
|
|
|
|
func TestHealthHandlerReturnsJSONErrorEnvelope(t *testing.T) {
|
|
mux := newMux(
|
|
nil,
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusServiceUnavailable {
|
|
t.Fatalf("expected status 503, got %d", rec.Code)
|
|
}
|
|
|
|
var payload struct {
|
|
Error struct {
|
|
Code string `json:"code"`
|
|
Message string `json:"message"`
|
|
} `json:"error"`
|
|
}
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
|
t.Fatalf("unmarshal health error response: %v", err)
|
|
}
|
|
if payload.Error.Code != "database_not_configured" {
|
|
t.Fatalf("unexpected error code: %q", payload.Error.Code)
|
|
}
|
|
}
|
|
|
|
func TestLatestReportHTMLHandlerReturnsJSONErrorEnvelope(t *testing.T) {
|
|
mux := newMux(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/reports/latest/html", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Fatalf("expected status 404, got %d", rec.Code)
|
|
}
|
|
|
|
var payload struct {
|
|
Error struct {
|
|
Code string `json:"code"`
|
|
Message string `json:"message"`
|
|
} `json:"error"`
|
|
}
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
|
t.Fatalf("unmarshal latest html error response: %v", err)
|
|
}
|
|
if payload.Error.Code != "latest_report_not_found" {
|
|
t.Fatalf("unexpected error code: %q", payload.Error.Code)
|
|
}
|
|
}
|
|
|
|
func TestFetchModelsQueryEncodesPrimaryPricePriority(t *testing.T) {
|
|
fragments := []string{
|
|
"CASE WHEN lower(rp.region) = 'global' THEN 0 ELSE 1 END",
|
|
"WHEN 'official' THEN 0",
|
|
"WHEN 'reseller' THEN 1",
|
|
"WHEN 'free_tier' THEN 2",
|
|
"rp.effective_date DESC NULLS LAST",
|
|
"rp.id DESC",
|
|
}
|
|
|
|
for _, fragment := range fragments {
|
|
if !strings.Contains(fetchModelsQuery, fragment) {
|
|
t.Fatalf("fetchModelsQuery missing fragment %q", fragment)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSubscriptionPlansHandlerReturnsEnvelope(t *testing.T) {
|
|
mux := newMux(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return []subscriptionPlanResponse{
|
|
{
|
|
PlanFamily: "token_plan",
|
|
PlanCode: "token-plan-lite",
|
|
PlanName: "通用 Token Plan Lite",
|
|
Tier: "Lite",
|
|
Provider: "Tencent",
|
|
ProviderCN: "腾讯",
|
|
Operator: "Tencent Cloud",
|
|
OperatorCN: "腾讯云",
|
|
Currency: "CNY",
|
|
ListPrice: 39,
|
|
PriceUnit: "CNY/month",
|
|
QuotaValue: 35000000,
|
|
QuotaUnit: "tokens/month",
|
|
ContextWindow: 0,
|
|
ModelScope: []string{"tc-code-latest", "glm-5", "glm-5.1"},
|
|
SourceURL: "https://cloud.tencent.com/document/product/1823/130060",
|
|
EffectiveDate: "2026-04-27",
|
|
},
|
|
}, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/subscription-plans", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", rec.Code)
|
|
}
|
|
|
|
var payload struct {
|
|
Data []subscriptionPlanResponse `json:"data"`
|
|
}
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
|
t.Fatalf("unmarshal response: %v", err)
|
|
}
|
|
|
|
if len(payload.Data) != 1 {
|
|
t.Fatalf("expected 1 plan, got %d", len(payload.Data))
|
|
}
|
|
|
|
got := payload.Data[0]
|
|
if got.PlanCode != "token-plan-lite" {
|
|
t.Fatalf("unexpected plan code: %q", got.PlanCode)
|
|
}
|
|
if got.ProviderCN != "腾讯" {
|
|
t.Fatalf("unexpected providerCN: %q", got.ProviderCN)
|
|
}
|
|
if got.OperatorCN != "腾讯云" {
|
|
t.Fatalf("unexpected operatorCN: %q", got.OperatorCN)
|
|
}
|
|
if got.ListPrice != 39 {
|
|
t.Fatalf("unexpected list price: %v", got.ListPrice)
|
|
}
|
|
if len(got.ModelScope) != 3 {
|
|
t.Fatalf("unexpected model scope length: %d", len(got.ModelScope))
|
|
}
|
|
}
|
|
|
|
func TestLatestReportHandlerReturnsEnvelope(t *testing.T) {
|
|
mux := newMux(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return &latestReportResponse{
|
|
ReportDate: "2026-05-13",
|
|
Status: "generated",
|
|
ModelCount: 504,
|
|
MarkdownPath: "reports/daily/daily_report_2026-05-13.md",
|
|
HTMLPath: "reports/daily/html/daily_report_2026-05-13.html",
|
|
MarkdownURL: "/api/v1/reports/latest/markdown",
|
|
HTMLURL: "/api/v1/reports/latest/html",
|
|
}, nil
|
|
},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/reports/latest", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", rec.Code)
|
|
}
|
|
|
|
var payload struct {
|
|
Data latestReportResponse `json:"data"`
|
|
}
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
|
t.Fatalf("unmarshal response: %v", err)
|
|
}
|
|
|
|
if payload.Data.ReportDate != "2026-05-13" {
|
|
t.Fatalf("unexpected report date: %q", payload.Data.ReportDate)
|
|
}
|
|
if payload.Data.HTMLURL != "/api/v1/reports/latest/html" {
|
|
t.Fatalf("unexpected html url: %q", payload.Data.HTMLURL)
|
|
}
|
|
}
|
|
|
|
func TestLatestReportHTMLHandlerServesArtifact(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
htmlPath := tempDir + "/daily_report_2026-05-13.html"
|
|
if err := os.WriteFile(htmlPath, []byte("<html><body>ok</body></html>"), 0644); err != nil {
|
|
t.Fatalf("write temp html: %v", err)
|
|
}
|
|
|
|
mux := newMux(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return &latestReportResponse{
|
|
ReportDate: "2026-05-13",
|
|
Status: "generated",
|
|
MarkdownPath: tempDir + "/daily_report_2026-05-13.md",
|
|
HTMLPath: htmlPath,
|
|
}, nil
|
|
},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/reports/latest/html", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", rec.Code)
|
|
}
|
|
if body := rec.Body.String(); body != "<html><body>ok</body></html>" {
|
|
t.Fatalf("unexpected body: %q", body)
|
|
}
|
|
}
|
|
|
|
func TestModelsHandlerRejectsUnauthenticatedExternalRequests(t *testing.T) {
|
|
mux := newMuxWithConfig(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
serverConfig{BasicAuthUser: "review", BasicAuthPass: "secret", RateLimitPerWindow: 10, RateLimitWindow: time.Minute},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/models", nil)
|
|
req.RemoteAddr = "198.51.100.8:1234"
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected status 401, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestModelsHandlerAllowsBasicAuthForExternalRequests(t *testing.T) {
|
|
mux := newMuxWithConfig(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return []modelResponse{{ID: "openai/gpt-4o", Name: "GPT-4o"}}, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
serverConfig{BasicAuthUser: "review", BasicAuthPass: "secret", RateLimitPerWindow: 10, RateLimitWindow: time.Minute},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/models", nil)
|
|
req.RemoteAddr = "198.51.100.8:1234"
|
|
req.SetBasicAuth("review", "secret")
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestModelsHandlerAllowsBearerTokenForExternalRequests(t *testing.T) {
|
|
mux := newMuxWithConfig(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return []modelResponse{{ID: "openai/gpt-4o", Name: "GPT-4o"}}, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
serverConfig{ServiceToken: "token-123", RateLimitPerWindow: 10, RateLimitWindow: time.Minute},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/models", nil)
|
|
req.RemoteAddr = "198.51.100.8:1234"
|
|
req.Header.Set("Authorization", "Bearer token-123")
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHealthHandlerRejectsExternalRequests(t *testing.T) {
|
|
mux := newMuxWithConfig(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
serverConfig{RateLimitPerWindow: 10, RateLimitWindow: time.Minute},
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
|
req.RemoteAddr = "198.51.100.8:1234"
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Fatalf("expected status 403, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestModelsHandlerAppliesRateLimit(t *testing.T) {
|
|
mux := newMuxWithConfig(
|
|
&sql.DB{},
|
|
func(context.Context, *sql.DB) ([]modelResponse, error) {
|
|
return []modelResponse{{ID: "openai/gpt-4o", Name: "GPT-4o"}}, nil
|
|
},
|
|
func(context.Context, *sql.DB) ([]subscriptionPlanResponse, error) {
|
|
return nil, nil
|
|
},
|
|
func(context.Context, *sql.DB) (*latestReportResponse, error) {
|
|
return nil, sql.ErrNoRows
|
|
},
|
|
serverConfig{RateLimitPerWindow: 1, RateLimitWindow: time.Minute},
|
|
)
|
|
|
|
first := httptest.NewRequest(http.MethodGet, "/api/v1/models", nil)
|
|
first.RemoteAddr = "127.0.0.1:1234"
|
|
firstRec := httptest.NewRecorder()
|
|
mux.ServeHTTP(firstRec, first)
|
|
if firstRec.Code != http.StatusOK {
|
|
t.Fatalf("expected first request status 200, got %d", firstRec.Code)
|
|
}
|
|
|
|
second := httptest.NewRequest(http.MethodGet, "/api/v1/models", nil)
|
|
second.RemoteAddr = "127.0.0.1:1234"
|
|
secondRec := httptest.NewRecorder()
|
|
mux.ServeHTTP(secondRec, second)
|
|
if secondRec.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("expected second request status 429, got %d", secondRec.Code)
|
|
}
|
|
}
|