Files
sub2api-cn-relay-manager/internal/host/sub2api/sub2api_test.go

1202 lines
39 KiB
Go

package sub2api
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestHTTPErrorErrorMessage(t *testing.T) {
e := newHTTPError("POST", "/api/v1/admin/groups", http.StatusTeapot, []byte("short and stout"))
want := "sub2api POST /api/v1/admin/groups returned 418: short and stout"
if got := e.Error(); got != want {
t.Fatalf("HTTPError.Error() = %q, want %q", got, want)
}
}
func TestWithHTTPClientAndOptions(t *testing.T) {
customHTTP := &http.Client{Timeout: 123}
client, err := NewClient("http://localhost:8080",
WithHTTPClient(customHTTP),
WithAPIKey(" sk-abc "),
WithBearerToken(" tok-xyz "),
)
if err != nil {
t.Fatal(err)
}
if client.httpClient != customHTTP {
t.Fatal("WithHTTPClient not applied")
}
if client.apiKey != "sk-abc" {
t.Fatalf("apiKey = %q, want %q", client.apiKey, "sk-abc")
}
if client.bearerToken != "tok-xyz" {
t.Fatalf("bearerToken = %q, want %q", client.bearerToken, "tok-xyz")
}
}
// TestNewClient_RejectsInvalidURLs ensures NewClient returns an error for base
// URLs that are empty, lack a scheme or host, or cannot be parsed.
func TestNewClient_RejectsInvalidURLs(t *testing.T) {
	cases := []struct{ name, url string }{
		{name: "empty", url: ""},
		{name: "no scheme", url: "localhost:8080"},
		{name: "no host", url: "http://"},
		{name: "garbage", url: "://foo"},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			if _, err := NewClient(tc.url); err == nil {
				t.Fatalf("NewClient(%q) error = nil, want error", tc.url)
			}
		})
	}
}
func TestResolvePath(t *testing.T) {
client, err := NewClient("http://host:9090")
if err != nil {
t.Fatal(err)
}
tests := []struct {
path string
want string
}{
{"/v1/models", "http://host:9090/v1/models"},
{"v1/models", "http://host:9090/v1/models"},
{"/v1/models?key=val", "http://host:9090/v1/models?key=val"},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
if got := client.resolvePath(tt.path); got != tt.want {
t.Fatalf("resolvePath(%q) = %q, want %q", tt.path, got, tt.want)
}
})
}
}
// TestApplyAuth verifies which credential header applyAuth sets:
//   - when both an API key and a bearer token are configured, only x-api-key is sent;
//   - a bearer token alone is sent as "Authorization: Bearer <token>";
//   - with no credentials, neither header is set.
//
// Construction errors are checked explicitly, so a failing NewClient or
// NewRequest produces a clear test failure instead of a nil-pointer panic.
func TestApplyAuth(t *testing.T) {
	// newRequest builds the GET request that applyAuth decorates.
	newRequest := func(t *testing.T) *http.Request {
		t.Helper()
		req, err := http.NewRequest("GET", "http://h:8080/path", nil)
		if err != nil {
			t.Fatalf("NewRequest: %v", err)
		}
		return req
	}
	t.Run("api key preferred", func(t *testing.T) {
		c, err := NewClient("http://h:8080", WithAPIKey("key1"), WithBearerToken("btok"))
		if err != nil {
			t.Fatal(err)
		}
		req := newRequest(t)
		c.applyAuth(req)
		if h := req.Header.Get("x-api-key"); h != "key1" {
			t.Fatalf("x-api-key = %q, want %q", h, "key1")
		}
		if h := req.Header.Get("Authorization"); h != "" {
			t.Fatalf("Authorization should be empty, got %q", h)
		}
	})
	t.Run("bearer token fallback", func(t *testing.T) {
		c, err := NewClient("http://h:8080", WithBearerToken("btok"))
		if err != nil {
			t.Fatal(err)
		}
		req := newRequest(t)
		c.applyAuth(req)
		if h := req.Header.Get("Authorization"); h != "Bearer btok" {
			t.Fatalf("Authorization = %q, want %q", h, "Bearer btok")
		}
	})
	t.Run("no auth", func(t *testing.T) {
		c, err := NewClient("http://h:8080")
		if err != nil {
			t.Fatal(err)
		}
		req := newRequest(t)
		c.applyAuth(req)
		if h := req.Header.Get("x-api-key"); h != "" {
			t.Fatalf("x-api-key should be empty, got %q", h)
		}
		if h := req.Header.Get("Authorization"); h != "" {
			t.Fatalf("Authorization should be empty, got %q", h)
		}
	})
}
// TestDecodeEnvelopeObject covers decodeEnvelopeObject on a {"data": {...}}
// envelope, a bare object, an envelope whose data is null (which falls back
// to the top-level fields), and input that is not JSON.
func TestDecodeEnvelopeObject(t *testing.T) {
	// mustDecode decodes body into a GroupRef, aborting the subtest on error.
	mustDecode := func(t *testing.T, body string) GroupRef {
		t.Helper()
		var out GroupRef
		if err := decodeEnvelopeObject([]byte(body), &out); err != nil {
			t.Fatal(err)
		}
		return out
	}
	t.Run("standard envelope", func(t *testing.T) {
		got := mustDecode(t, `{"data":{"id":"g1","name":"test"}}`)
		if got.ID != "g1" || got.Name != "test" {
			t.Fatalf("got %+v, want {ID:g1 Name:test}", got)
		}
	})
	t.Run("flat response (no data wrapper)", func(t *testing.T) {
		got := mustDecode(t, `{"id":"g2","name":"flat"}`)
		if got.ID != "g2" || got.Name != "flat" {
			t.Fatalf("got %+v, want {ID:g2 Name:flat}", got)
		}
	})
	t.Run("data:null returns flat", func(t *testing.T) {
		got := mustDecode(t, `{"data":null,"id":"g3"}`)
		if got.ID != "g3" {
			t.Fatalf("id = %q, want %q", got.ID, "g3")
		}
	})
	t.Run("invalid json returns error", func(t *testing.T) {
		var out GroupRef
		err := decodeEnvelopeObject([]byte(`not json`), &out)
		if err == nil {
			t.Fatal("expected error")
		}
	})
}
// TestDecodeGatewayModelIDs checks that model IDs are extracted (and trimmed)
// from an OpenAI-style {"data":[{"id":...}]} list, and that a missing list,
// an empty list, or invalid JSON all yield nil.
func TestDecodeGatewayModelIDs(t *testing.T) {
	t.Run("standard list", func(t *testing.T) {
		got := decodeGatewayModelIDs([]byte(`{"data":[{"id":"gpt-4"},{"id":" claude-3 "}]}`))
		if !(len(got) == 2 && got[0] == "gpt-4" && got[1] == "claude-3") {
			t.Fatalf("got %v, want [gpt-4 claude-3]", got)
		}
	})
	nilCases := []struct{ name, body string }{
		{name: "empty data", body: `{}`},
		{name: "invalid json", body: `not json`},
		{name: "empty array", body: `{"data":[]}`},
	}
	for _, tc := range nilCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			if got := decodeGatewayModelIDs([]byte(tc.body)); got != nil {
				t.Fatalf("expected nil, got %v", got)
			}
		})
	}
}
// TestFilterNamedResourcesByName checks exact-name filtering: names are
// compared after trimming whitespace, input order is preserved, an unknown
// name matches nothing, and an empty name returns every resource.
func TestFilterNamedResourcesByName(t *testing.T) {
	input := []NamedResource{
		{ID: "g1", Name: "group-a"},
		{ID: "g2", Name: "group-b"},
		{ID: "g3", Name: " group-a "},
	}
	t.Run("match", func(t *testing.T) {
		got := filterNamedResourcesByName(input, "group-a")
		matched := len(got) == 2 && got[0].ID == "g1" && got[1].ID == "g3"
		if !matched {
			t.Fatalf("got %+v, want 2 matches", got)
		}
	})
	t.Run("no match", func(t *testing.T) {
		n := len(filterNamedResourcesByName(input, "nonexistent"))
		if n != 0 {
			t.Fatalf("expected 0, got %d", n)
		}
	})
	t.Run("empty name returns all", func(t *testing.T) {
		n := len(filterNamedResourcesByName(input, ""))
		if n != 3 {
			t.Fatalf("expected 3, got %d", n)
		}
	})
}
// TestFilterNamedResourcesByPrefix checks how many resources survive name
// prefix filtering; an empty prefix keeps everything.
func TestFilterNamedResourcesByPrefix(t *testing.T) {
	input := []NamedResource{
		{ID: "r1", Name: "deepseek-proxy"},
		{ID: "r2", Name: "deepseek-us"},
		{ID: "r3", Name: "claude-eu"},
	}
	cases := []struct {
		name   string
		prefix string
		want   int
	}{
		{name: "prefix matches", prefix: "deepseek", want: 2},
		{name: "no prefix match", prefix: "nope", want: 0},
		{name: "empty prefix returns all", prefix: "", want: 3},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			if got := len(filterNamedResourcesByPrefix(input, tc.prefix)); got != tc.want {
				t.Fatalf("expected %d, got %d", tc.want, got)
			}
		})
	}
}
// TestDecodeNamedResources covers list decoding from a plain data array, from
// a {"items":[...],"pages":N} wrapper (including numeric IDs, which come back
// as strings), the default page count of 1, and invalid JSON.
func TestDecodeNamedResources(t *testing.T) {
	t.Run("envelope", func(t *testing.T) {
		items, pages, err := decodeNamedResources([]byte(`{"data":[{"id":"r1","name":"n1"}]}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case pages != 1:
			t.Fatalf("pages = %d, want 1", pages)
		case len(items) != 1 || items[0].ID != "r1":
			t.Fatalf("got %+v", items)
		}
	})
	t.Run("numeric id", func(t *testing.T) {
		items, pages, err := decodeNamedResources([]byte(`{"data":{"items":[{"id":1,"name":"default"}],"pages":2}}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case pages != 2:
			t.Fatalf("pages = %d, want 2", pages)
		case len(items) != 1 || items[0].ID != "1":
			t.Fatalf("got %+v", items)
		}
	})
	t.Run("wrapper with items", func(t *testing.T) {
		items, pages, err := decodeNamedResources([]byte(`{"data":{"items":[{"id":"r2","name":"n2"}]}}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case pages != 1:
			t.Fatalf("pages = %d, want 1", pages)
		case len(items) != 1 || items[0].ID != "r2":
			t.Fatalf("got %+v", items)
		}
	})
	t.Run("invalid json", func(t *testing.T) {
		if _, _, err := decodeNamedResources([]byte(`not json`)); err == nil {
			t.Fatal("expected error")
		}
	})
}
// TestDecodeAccountRefs covers account-reference decoding from a data array,
// from an items wrapper (with numeric IDs converted to strings), and from
// batch-create "results" payloads (bare or inside a data envelope), where
// only items marked success are kept.
func TestDecodeAccountRefs(t *testing.T) {
	t.Run("envelope", func(t *testing.T) {
		got, err := decodeAccountRefs([]byte(`{"data":[{"id":"a1"}]}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case len(got) != 1 || got[0].ID != "a1":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("numeric id", func(t *testing.T) {
		got, err := decodeAccountRefs([]byte(`{"data":{"items":[{"id":42}]}}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case len(got) != 1 || got[0].ID != "42":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("wrapper with items", func(t *testing.T) {
		got, err := decodeAccountRefs([]byte(`{"data":{"items":[{"id":"a2"}]}}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case len(got) != 1 || got[0].ID != "a2":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("batch results", func(t *testing.T) {
		got, err := decodeAccountRefs([]byte(`{"success":1,"failed":0,"results":[{"name":"k1","id":123,"success":true}]}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case len(got) != 1 || got[0].ID != "123" || got[0].Name != "k1":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("batch results ignores failed items", func(t *testing.T) {
		got, err := decodeAccountRefs([]byte(`{"success":1,"failed":1,"results":[{"name":"k1","id":123,"success":true},{"name":"k2","id":456,"success":false}]}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case len(got) != 1 || got[0].ID != "123":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("data wrapped batch results", func(t *testing.T) {
		got, err := decodeAccountRefs([]byte(`{"code":0,"message":"success","data":{"failed":0,"results":[{"id":5,"name":"deepseek-01","success":true}],"success":1}}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case len(got) != 1 || got[0].ID != "5" || got[0].Name != "deepseek-01":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("invalid json", func(t *testing.T) {
		if _, err := decodeAccountRefs([]byte(`not json`)); err == nil {
			t.Fatal("expected error")
		}
	})
}
// TestDecodeAccountModels covers account-model decoding from a data array,
// from an items wrapper, and invalid JSON.
func TestDecodeAccountModels(t *testing.T) {
	t.Run("envelope", func(t *testing.T) {
		got, err := decodeAccountModels([]byte(`{"data":[{"id":"gpt4","display_name":"GPT-4","type":"chat"}]}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case len(got) != 1 || got[0].ID != "gpt4":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("wrapper with items", func(t *testing.T) {
		got, err := decodeAccountModels([]byte(`{"data":{"items":[{"id":"cl3","display_name":"Claude 3","type":"chat"}]}}`))
		switch {
		case err != nil:
			t.Fatal(err)
		case len(got) != 1 || got[0].ID != "cl3":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("invalid json", func(t *testing.T) {
		if _, err := decodeAccountModels([]byte(`not json`)); err == nil {
			t.Fatal("expected error")
		}
	})
}
// TestParseProbeResult feeds server-sent-event bodies to parseProbeResult.
// Success can be signalled by "ok", "success", or a passing status word; the
// status is normalized to "passed"/"failed"; the last data event wins; and a
// body with no "data:" line is an error.
func TestParseProbeResult(t *testing.T) {
	t.Run("SSE with ok=true", func(t *testing.T) {
		got, err := parseProbeResult([]byte(`data: {"status":"passed","ok":true}` + "\n"))
		switch {
		case err != nil:
			t.Fatal(err)
		case !got.OK || got.Status != "passed":
			t.Fatalf("got %+v, want OK=true Status=passed", got)
		}
	})
	t.Run("SSE with success=true", func(t *testing.T) {
		got, err := parseProbeResult([]byte(`data: {"status":"succeeded","success":true}` + "\n"))
		switch {
		case err != nil:
			t.Fatal(err)
		case !got.OK || got.Status != "passed":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("SSE with ok=false", func(t *testing.T) {
		got, err := parseProbeResult([]byte(`data: {"status":"failed","ok":false}` + "\n"))
		switch {
		case err != nil:
			t.Fatal(err)
		case got.OK || got.Status != "failed":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("SSE with status-based ok", func(t *testing.T) {
		got, err := parseProbeResult([]byte(`data: {"status":"pass","message":"all good"}` + "\n"))
		switch {
		case err != nil:
			t.Fatal(err)
		case !got.OK || got.Message != "all good":
			t.Fatalf("got %+v", got)
		}
	})
	t.Run("multiple SSE events picks last", func(t *testing.T) {
		stream := `data: {"status":"running"}` + "\n" + `data: {"status":"passed","ok":true}` + "\n"
		got, err := parseProbeResult([]byte(stream))
		switch {
		case err != nil:
			t.Fatal(err)
		case !got.OK:
			t.Fatalf("expected OK=true from last event, got %+v", got)
		}
	})
	t.Run("no data events", func(t *testing.T) {
		if _, err := parseProbeResult([]byte("not data\n")); err == nil {
			t.Fatal("expected error")
		}
	})
}
// TestNormalizeProbeStatus checks that known pass/fail status words (in any
// letter case) map to "passed"/"failed", and that unrecognised words fall
// back to the ok flag.
func TestNormalizeProbeStatus(t *testing.T) {
	cases := []struct {
		status string
		ok     bool
		want   string
	}{
		{status: "pass", ok: true, want: "passed"},
		{status: "PASSED", ok: true, want: "passed"},
		{status: "Ok", ok: true, want: "passed"},
		{status: "success", ok: true, want: "passed"},
		{status: "succeeded", ok: true, want: "passed"},
		{status: "fail", ok: false, want: "failed"},
		{status: "FAILED", ok: false, want: "failed"},
		{status: "error", ok: false, want: "failed"},
		{status: "custom_ok", ok: true, want: "passed"},
		{status: "custom_fail", ok: false, want: "failed"},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.status, func(t *testing.T) {
			got := normalizeProbeStatus(tc.status, tc.ok)
			if got == tc.want {
				return
			}
			t.Fatalf("normalizeProbeStatus(%q, %v) = %q, want %q", tc.status, tc.ok, got, tc.want)
		})
	}
}
// TestLooksLikeExistingEndpoint checks the heuristic that separates a real API
// route from a generic miss. A JSON or SSE content type, or a body starting
// like a JSON object/array, counts as a real route; an empty response or an
// HTML body does not.
func TestLooksLikeExistingEndpoint(t *testing.T) {
	cases := []struct {
		name   string
		header http.Header
		body   []byte
		want   bool
		msg    string
	}{
		{"json content type", http.Header{"Content-Type": []string{"application/json"}}, nil, true, "expected true with json content type"},
		{"sse content type", http.Header{"Content-Type": []string{"text/event-stream"}}, nil, true, "expected true with sse content type"},
		{"empty body and no content type", http.Header{}, nil, false, "expected false"},
		{"json-like body", http.Header{}, []byte(`{"error":"not found"}`), true, "expected true for json body"},
		{"array body", http.Header{}, []byte(`[]`), true, "expected true for array body"},
		{"html body", http.Header{}, []byte(`<html>`), false, "expected false for html body"},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			if looksLikeExistingEndpoint(tc.header, tc.body) != tc.want {
				t.Fatal(tc.msg)
			}
		})
	}
}
// TestNewClientWithNilOption verifies that NewClient accepts a nil entry in
// its option list without failing and still returns a non-nil client.
func TestNewClientWithNilOption(t *testing.T) {
	client, err := NewClient("http://localhost:8080", nil)
	if err != nil {
		t.Fatal(err)
	}
	if client == nil {
		t.Fatal("client is nil")
	}
}
func TestNewHTTPError(t *testing.T) {
e := newHTTPError("GET", "/v1/models", 200, []byte(`{"ok":true}`))
if e.Method != "GET" || e.Path != "/v1/models" || e.StatusCode != 200 || e.Body != `{"ok":true}` {
t.Fatalf("unexpected http error: %+v", e)
}
}
// TestPerformWithMockServer exercises the request helpers against a fake
// server. It covers version lookup, enveloped JSON decoding via postJSON and
// getJSON, a 500 reply surfacing as an *HTTPError, and an unknown path
// returning an error.
func TestPerformWithMockServer(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/v1/admin/system/version":
			w.Write([]byte(`{"data":{"version":"v1.2.3"}}`))
		case "/api/v1/admin/groups":
			w.Write([]byte(`{"data":{"id":"g1","name":"test-group"}}`))
		case "/api/v1/admin/channels":
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte(`{"error":"panic"}`))
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithAPIKey("test-key"))
	if err != nil {
		t.Fatal(err)
	}
	t.Run("GetHostVersion", func(t *testing.T) {
		ver, err := client.GetHostVersion(context.Background())
		if err != nil {
			t.Fatal(err)
		}
		if ver != "v1.2.3" {
			t.Fatalf("version = %q, want %q", ver, "v1.2.3")
		}
	})
	t.Run("postJSON success", func(t *testing.T) {
		var ref GroupRef
		if err := client.postJSON(context.Background(), "/api/v1/admin/groups", CreateGroupRequest{Name: "test"}, &ref); err != nil {
			t.Fatal(err)
		}
		if ref.ID != "g1" || ref.Name != "test-group" {
			t.Fatalf("got %+v, want {ID:g1 Name:test-group}", ref)
		}
	})
	t.Run("postJSON error status", func(t *testing.T) {
		var ref GroupRef
		err := client.postJSON(context.Background(), "/api/v1/admin/channels", nil, &ref)
		if err == nil {
			t.Fatal("expected error")
		}
		var httpErr *HTTPError
		if !errors.As(err, &httpErr) {
			t.Fatalf("expected HTTPError, got %T: %v", err, err)
		}
		if httpErr.StatusCode != 500 {
			t.Fatalf("status code = %d, want 500", httpErr.StatusCode)
		}
	})
	t.Run("getJSON success", func(t *testing.T) {
		var ref GroupRef
		if err := client.getJSON(context.Background(), "/api/v1/admin/groups", &ref); err != nil {
			t.Fatal(err)
		}
		// Without this check the subtest only proves the request succeeded,
		// not that the enveloped payload was actually decoded.
		if ref.ID != "g1" || ref.Name != "test-group" {
			t.Fatalf("got %+v, want {ID:g1 Name:test-group}", ref)
		}
	})
	t.Run("getJSON error status", func(t *testing.T) {
		var ref GroupRef
		err := client.getJSON(context.Background(), "/bad/path", &ref)
		if err == nil {
			t.Fatal("expected error")
		}
	})
}
// TestCreateGroupWithMock verifies that CreateGroup sends the group name,
// platform and rate multiplier as JSON, and decodes the group reference from
// the enveloped response.
func TestCreateGroupWithMock(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatalf is not
		// allowed (FailNow must run on the test goroutine). Record the failure
		// with t.Errorf and fail the HTTP call so the client side errors too.
		fail := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected request", http.StatusBadRequest)
		}
		var req struct {
			Name             string  `json:"name"`
			Platform         string  `json:"platform"`
			RateMultiplier   float64 `json:"rate_multiplier"`
			SubscriptionType string  `json:"subscription_type"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			fail("decode request: %v", err)
			return
		}
		if req.Name != "demo" || req.Platform != "openai" || req.RateMultiplier != 1.0 {
			fail("unexpected request: %+v", req)
			return
		}
		w.Write([]byte(`{"data":{"id":"g1","name":"demo"}}`))
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatal(err)
	}
	ref, err := client.CreateGroup(context.Background(), CreateGroupRequest{Name: "demo", Platform: "openai", RateMultiplier: 1.0})
	if err != nil {
		t.Fatal(err)
	}
	if ref.ID != "g1" || ref.Name != "demo" {
		t.Fatalf("got %+v", ref)
	}
}
// TestCreateChannelWithMock verifies the wire format CreateChannel sends:
// group IDs become integers, the flat model mapping is nested under the
// "openai" platform key, and pricing, restrict_models and
// billing_model_source are passed through. It also checks that a numeric
// channel ID in the response comes back as a string.
func TestCreateChannelWithMock(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Fatalf is not allowed off the test goroutine; report with t.Errorf
		// and fail the HTTP call so the client side errors too.
		fail := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected request", http.StatusBadRequest)
		}
		var req struct {
			Name         string                       `json:"name"`
			GroupIDs     []int64                      `json:"group_ids"`
			ModelMapping map[string]map[string]string `json:"model_mapping"`
			ModelPricing []struct {
				Platform         string   `json:"platform"`
				Models           []string `json:"models"`
				BillingMode      string   `json:"billing_mode"`
				InputPrice       *float64 `json:"input_price"`
				OutputPrice      *float64 `json:"output_price"`
				CacheWritePrice  *float64 `json:"cache_write_price"`
				CacheReadPrice   *float64 `json:"cache_read_price"`
				ImageOutputPrice *float64 `json:"image_output_price"`
				PerRequestPrice  *float64 `json:"per_request_price"`
				Intervals        []any    `json:"intervals"`
			} `json:"model_pricing"`
			RestrictModels     bool   `json:"restrict_models"`
			BillingModelSource string `json:"billing_model_source"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			fail("decode request: %v", err)
			return
		}
		if req.Name != "ch" {
			fail("name = %q, want ch", req.Name)
			return
		}
		if len(req.GroupIDs) != 1 || req.GroupIDs[0] != 101 {
			fail("group_ids = %v, want [101]", req.GroupIDs)
			return
		}
		if req.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" {
			fail("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", req.ModelMapping)
			return
		}
		if len(req.ModelPricing) != 1 {
			fail("model_pricing len = %d, want 1", len(req.ModelPricing))
			return
		}
		if req.ModelPricing[0].Platform != "openai" || req.ModelPricing[0].BillingMode != "token" {
			fail("model_pricing[0] = %+v, want openai/token entry", req.ModelPricing[0])
			return
		}
		if len(req.ModelPricing[0].Models) != 1 || req.ModelPricing[0].Models[0] != "deepseek-v4-pro" {
			fail("model_pricing[0].models = %v, want [deepseek-v4-pro]", req.ModelPricing[0].Models)
			return
		}
		if !req.RestrictModels {
			fail("restrict_models = false, want true")
			return
		}
		if req.BillingModelSource != "channel_mapped" {
			fail("billing_model_source = %q, want channel_mapped", req.BillingModelSource)
			return
		}
		w.Write([]byte(`{"data":{"id":201,"name":"ch"}}`))
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatal(err)
	}
	ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{
		Name:               "ch",
		GroupIDs:           []string{"101"},
		ModelMapping:       map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
		ModelPricing:       []ChannelModelPricing{{Platform: "openai", Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
		RestrictModels:     true,
		BillingModelSource: "channel_mapped",
	})
	if err != nil {
		t.Fatal(err)
	}
	if ref.ID != "201" {
		t.Fatalf("id = %q, want 201", ref.ID)
	}
}
// TestUpdateChannelWithMock verifies that UpdateChannel sends a PUT to
// /api/v1/admin/channels/{id}. The body must nest the model mapping under the
// "openai" platform key and pass through pricing, restrict_models and
// billing_model_source.
func TestUpdateChannelWithMock(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Fatalf is not allowed off the test goroutine; report with t.Errorf
		// and fail the HTTP call so UpdateChannel returns an error.
		fail := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected request", http.StatusBadRequest)
		}
		if r.Method != http.MethodPut {
			fail("method = %s, want PUT", r.Method)
			return
		}
		if r.URL.Path != "/api/v1/admin/channels/201" {
			fail("path = %s, want /api/v1/admin/channels/201", r.URL.Path)
			return
		}
		var req struct {
			ModelMapping map[string]map[string]string `json:"model_mapping"`
			ModelPricing []struct {
				Platform    string   `json:"platform"`
				Models      []string `json:"models"`
				BillingMode string   `json:"billing_mode"`
			} `json:"model_pricing"`
			RestrictModels     bool   `json:"restrict_models"`
			BillingModelSource string `json:"billing_model_source"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			fail("decode request: %v", err)
			return
		}
		if req.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" {
			fail("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", req.ModelMapping)
			return
		}
		if len(req.ModelPricing) != 1 || req.ModelPricing[0].Platform != "openai" || req.ModelPricing[0].BillingMode != "token" {
			fail("model_pricing = %+v, want openai/token entry", req.ModelPricing)
			return
		}
		if !req.RestrictModels {
			fail("restrict_models = false, want true")
			return
		}
		if req.BillingModelSource != "channel_mapped" {
			fail("billing_model_source = %q, want channel_mapped", req.BillingModelSource)
			return
		}
		w.Write([]byte(`{"data":{"id":201,"name":"ch"}}`))
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatal(err)
	}
	if err := client.UpdateChannel(context.Background(), "201", CreateChannelRequest{
		Name:               "ch",
		GroupIDs:           []string{"101"},
		ModelMapping:       map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
		ModelPricing:       []ChannelModelPricing{{Platform: "openai", Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
		RestrictModels:     true,
		BillingModelSource: "channel_mapped",
	}); err != nil {
		t.Fatal(err)
	}
}
// TestCreateChannelRequestMarshalJSONDefaultsPricingPlatform checks that
// marshalling a CreateChannelRequest fills an empty pricing platform from the
// request's Platform, and falls back to "openai" when that is also empty.
// It also checks that the flat model mapping is nested under the platform key.
func TestCreateChannelRequestMarshalJSONDefaultsPricingPlatform(t *testing.T) {
	// roundTrip marshals req and decodes the resulting JSON into out,
	// aborting the subtest if either step fails.
	roundTrip := func(t *testing.T, req CreateChannelRequest, out any) {
		t.Helper()
		payload, err := json.Marshal(req)
		if err != nil {
			t.Fatalf("Marshal() error = %v", err)
		}
		if err := json.Unmarshal(payload, out); err != nil {
			t.Fatalf("Unmarshal() error = %v", err)
		}
	}
	t.Run("request platform", func(t *testing.T) {
		var decoded struct {
			ModelMapping map[string]map[string]string `json:"model_mapping"`
			ModelPricing []struct {
				Platform string `json:"platform"`
			} `json:"model_pricing"`
		}
		roundTrip(t, CreateChannelRequest{
			Name:         "ch",
			GroupIDs:     []string{"101"},
			Platform:     "openai",
			ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
			ModelPricing: []ChannelModelPricing{{Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
		}, &decoded)
		if decoded.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" {
			t.Fatalf("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", decoded.ModelMapping)
		}
		if n := len(decoded.ModelPricing); n != 1 || decoded.ModelPricing[0].Platform != "openai" {
			t.Fatalf("model_pricing = %+v, want platform openai", decoded.ModelPricing)
		}
	})
	t.Run("openai fallback", func(t *testing.T) {
		var decoded struct {
			ModelPricing []struct {
				Platform string `json:"platform"`
			} `json:"model_pricing"`
		}
		roundTrip(t, CreateChannelRequest{
			Name:         "ch",
			GroupIDs:     []string{"101"},
			ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
			ModelPricing: []ChannelModelPricing{{Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
		}, &decoded)
		if n := len(decoded.ModelPricing); n != 1 || decoded.ModelPricing[0].Platform != "openai" {
			t.Fatalf("model_pricing = %+v, want platform openai fallback", decoded.ModelPricing)
		}
	})
}
// TestCreatePlanWithMock verifies that CreatePlan sends the group ID as an
// integer along with name, price and validity fields, and that a numeric plan
// ID in the response comes back as a string.
func TestCreatePlanWithMock(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Fatalf is not allowed off the test goroutine; report with t.Errorf
		// and fail the HTTP call so CreatePlan returns an error.
		fail := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected request", http.StatusBadRequest)
		}
		var req struct {
			GroupID      int64   `json:"group_id"`
			Name         string  `json:"name"`
			Price        float64 `json:"price"`
			ValidityDays int     `json:"validity_days"`
			ValidityUnit string  `json:"validity_unit"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			fail("decode request: %v", err)
			return
		}
		if req.GroupID != 101 || req.Name != "plan" || req.Price != 19.9 || req.ValidityDays != 30 || req.ValidityUnit != "day" {
			fail("unexpected request: %+v", req)
			return
		}
		w.Write([]byte(`{"data":{"id":301,"name":"plan"}}`))
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatal(err)
	}
	ref, err := client.CreatePlan(context.Background(), CreatePlanRequest{GroupID: "101", Name: "plan", Price: 19.9, ValidityDays: 30, ValidityUnit: "day"})
	if err != nil {
		t.Fatal(err)
	}
	if ref.ID != "301" {
		t.Fatalf("id = %q, want 301", ref.ID)
	}
}
// TestDeleteWithMock checks that each Delete* client method treats a
// 204 No Content reply as success.
func TestDeleteWithMock(t *testing.T) {
	noContent := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	})
	srv := httptest.NewServer(noContent)
	defer srv.Close()
	client, _ := NewClient(srv.URL, WithAPIKey("k"))
	ctx := context.Background()
	deletions := []struct {
		name string
		del  func() error
	}{
		{"DeleteGroup", func() error { return client.DeleteGroup(ctx, "g1") }},
		{"DeleteChannel", func() error { return client.DeleteChannel(ctx, "c1") }},
		{"DeletePlan", func() error { return client.DeletePlan(ctx, "p1") }},
		{"DeleteAccount", func() error { return client.DeleteAccount(ctx, "a1") }},
	}
	for _, d := range deletions {
		d := d
		t.Run(d.name, func(t *testing.T) {
			if err := d.del(); err != nil {
				t.Fatal(err)
			}
		})
	}
}
// TestAssignSubscriptionWithMock verifies that AssignSubscription sends user
// and group IDs as integers with the duration as validity_days, and that a
// numeric subscription ID in the response comes back as a string.
func TestAssignSubscriptionWithMock(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Fatalf is not allowed off the test goroutine; report with t.Errorf
		// and fail the HTTP call so AssignSubscription returns an error.
		fail := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected request", http.StatusBadRequest)
		}
		var req struct {
			UserID       int64 `json:"user_id"`
			GroupID      int64 `json:"group_id"`
			DurationDays int   `json:"validity_days"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			fail("decode request: %v", err)
			return
		}
		if req.UserID != 501 || req.GroupID != 101 || req.DurationDays != 30 {
			fail("unexpected request: %+v", req)
			return
		}
		w.Write([]byte(`{"data":{"id":401}}`))
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatal(err)
	}
	ref, err := client.AssignSubscription(context.Background(), AssignSubscriptionRequest{UserID: "501", GroupID: "101", DurationDays: 30})
	if err != nil {
		t.Fatal(err)
	}
	if ref.ID != "401" {
		t.Fatalf("id = %q", ref.ID)
	}
}
// TestEnsureSubscriptionAccessWithMock runs EnsureSubscriptionAccess against a
// fake admin API that answers the user, balance, subscription, login and key
// routes the managed-access flow may hit. It asserts that:
//   - the subscription assignment carries user 84, group 101 and 30 days;
//   - the user-side key creation does not send a group_id;
//   - the resolved user ID is "84" and the API key has the "sk-relay-" prefix;
//   - at least seven requests were made.
func TestEnsureSubscriptionAccessWithMock(t *testing.T) {
	var calls []string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Fatalf is not allowed off the test goroutine; report with t.Errorf
		// and fail the HTTP call so the flow under test returns an error.
		fail := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected request", http.StatusBadRequest)
		}
		calls = append(calls, r.Method+" "+r.URL.Path)
		switch {
		case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
			// User lookup finds nobody, so the flow has to create the user.
			w.Write([]byte(`{"data":{"items":[]}}`))
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users":
			w.Write([]byte(`{"data":{"id":84,"email":"relay-sub-user-1@sub2api.local"}}`))
		case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84":
			w.Write([]byte(`{"data":{"id":84}}`))
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance":
			w.Write([]byte(`{"data":{"id":84}}`))
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign":
			var req struct {
				UserID       int64 `json:"user_id"`
				GroupID      int64 `json:"group_id"`
				DurationDays int   `json:"validity_days"`
			}
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				fail("decode assign subscription request: %v", err)
				return
			}
			if req.UserID != 84 || req.GroupID != 101 || req.DurationDays != 30 {
				fail("unexpected assign subscription request: %+v", req)
				return
			}
			w.Write([]byte(`{"data":{"id":401}}`))
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login":
			w.Write([]byte(`{"data":{"access_token":"user-jwt"}}`))
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
			var req map[string]any
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				fail("decode managed key request: %v", err)
				return
			}
			if _, ok := req["group_id"]; ok {
				fail("managed key request unexpectedly carried group_id: %+v", req)
				return
			}
			w.Write([]byte(`{"data":{"id":501,"key":"sk-relay-key","name":"managed-key"}}`))
		case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501":
			w.Write([]byte(`{"data":{"api_key":{"id":501}}}`))
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithBearerToken("admin-token"))
	if err != nil {
		t.Fatal(err)
	}
	ref, err := client.EnsureSubscriptionAccess(context.Background(), EnsureSubscriptionAccessRequest{UserSelector: "crm-user-1", GroupID: "101"})
	if err != nil {
		t.Fatal(err)
	}
	if ref.UserID != "84" {
		t.Fatalf("user id = %q, want 84", ref.UserID)
	}
	if !strings.HasPrefix(ref.APIKey, "sk-relay-") {
		t.Fatalf("api key = %q, want managed sk-relay-* key", ref.APIKey)
	}
	if len(calls) < 7 {
		t.Fatalf("calls = %v, want managed subscription setup sequence", calls)
	}
}
// TestCheckGatewayAccessWithMock verifies that CheckGatewayAccess
// authenticates with the gateway key as a bearer token, not the client's
// admin x-api-key. It also checks that the expected model is reported as
// present when the model list contains it.
func TestCheckGatewayAccessWithMock(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Fatalf is not allowed off the test goroutine; report with t.Errorf
		// and fail the HTTP call so the access check fails as well.
		fail := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected request", http.StatusBadRequest)
		}
		if got := r.Header.Get("Authorization"); got != "Bearer gk" {
			fail("Authorization = %q, want %q", got, "Bearer gk")
			return
		}
		if got := r.Header.Get("x-api-key"); got != "" {
			fail("x-api-key = %q, want empty", got)
			return
		}
		w.Write([]byte(`{"data":[{"id":"gpt-4"},{"id":"claude-3"}]}`))
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatal(err)
	}
	result, err := client.CheckGatewayAccess(context.Background(), GatewayAccessCheckRequest{APIKey: "gk", ExpectedModel: "gpt-4"})
	if err != nil {
		t.Fatal(err)
	}
	if !result.OK {
		t.Fatal("expected OK=true")
	}
	if !result.HasExpectedModel {
		t.Fatal("expected HasExpectedModel=true")
	}
}
// TestCheckGatewayCompletionWithMock verifies that CheckGatewayCompletion
// POSTs a small chat completion (max_tokens 8) to /v1/chat/completions,
// authenticated with the gateway key as a bearer token. It then checks the
// reported OK flag, status code and content type.
func TestCheckGatewayCompletionWithMock(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Fatalf is not allowed off the test goroutine; report with t.Errorf
		// and fail the HTTP call so the completion check fails as well.
		fail := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected request", http.StatusBadRequest)
		}
		if r.URL.Path != "/v1/chat/completions" {
			fail("path = %q, want /v1/chat/completions", r.URL.Path)
			return
		}
		if got := r.Header.Get("Authorization"); got != "Bearer gk" {
			fail("Authorization = %q, want %q", got, "Bearer gk")
			return
		}
		if got := r.Header.Get("x-api-key"); got != "" {
			fail("x-api-key = %q, want empty", got)
			return
		}
		var payload struct {
			Model     string `json:"model"`
			MaxTokens int    `json:"max_tokens"`
		}
		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
			fail("decode request: %v", err)
			return
		}
		if payload.Model != "gpt-4" {
			fail("model = %q, want gpt-4", payload.Model)
			return
		}
		if payload.MaxTokens != 8 {
			fail("max_tokens = %d, want 8", payload.MaxTokens)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{"choices":[{"message":{"content":"pong"}}]}`))
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatal(err)
	}
	result, err := client.CheckGatewayCompletion(context.Background(), GatewayCompletionCheckRequest{APIKey: "gk", Model: "gpt-4"})
	if err != nil {
		t.Fatal(err)
	}
	if !result.OK {
		t.Fatal("expected completion OK=true")
	}
	if result.StatusCode != 200 {
		t.Fatalf("status = %d, want 200", result.StatusCode)
	}
	if result.ContentType != "application/json" {
		t.Fatalf("content type = %q, want application/json", result.ContentType)
	}
}
// TestBatchCreateAccountsWithMock verifies that BatchCreateAccounts sends each
// account's metadata, integer group IDs and credentials (including the nested
// model_mapping). It also checks that the returned account IDs are decoded as
// strings.
func TestBatchCreateAccountsWithMock(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Fatalf is not allowed off the test goroutine; report with t.Errorf
		// and fail the HTTP call so BatchCreateAccounts returns an error.
		fail := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected request", http.StatusBadRequest)
		}
		var req struct {
			Accounts []struct {
				Name        string         `json:"name"`
				Platform    string         `json:"platform"`
				Type        string         `json:"type"`
				Credentials map[string]any `json:"credentials"`
				GroupIDs    []int64        `json:"group_ids"`
			} `json:"accounts"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			fail("decode request: %v", err)
			return
		}
		if len(req.Accounts) != 1 {
			fail("accounts len = %d, want 1", len(req.Accounts))
			return
		}
		acct := req.Accounts[0]
		if acct.Name != "acct1" || acct.Platform != "openai" || acct.Type != "apikey" {
			fail("unexpected account metadata: %+v", acct)
			return
		}
		if len(acct.GroupIDs) != 1 || acct.GroupIDs[0] != 101 {
			fail("group_ids = %v, want [101]", acct.GroupIDs)
			return
		}
		rawMapping, ok := acct.Credentials["model_mapping"].(map[string]any)
		if !ok {
			fail("credentials = %+v, want model_mapping map", acct.Credentials)
			return
		}
		if got, _ := rawMapping["deepseek-v4-pro"].(string); got != "deepseek-v4-pro" {
			fail("model_mapping = %+v, want deepseek-v4-pro passthrough", rawMapping)
			return
		}
		w.Write([]byte(`{"data":[{"id":601,"name":"acct1"}]}`))
	}))
	defer srv.Close()
	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatal(err)
	}
	refs, err := client.BatchCreateAccounts(context.Background(), BatchCreateAccountsRequest{
		Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com", "model_mapping": map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"}}}},
	})
	if err != nil {
		t.Fatal(err)
	}
	if len(refs) != 1 || refs[0].ID != "601" {
		t.Fatalf("got %+v", refs)
	}
}
// TestProbeCapabilitiesWithMock serves 200 with an empty data list on every path.
// It asserts that ProbeCapabilities reports all seven capability flags as supported and
// issues exactly seven requests to the server.
func TestProbeCapabilitiesWithMock(t *testing.T) {
	hits := 0
	alwaysOK := func(w http.ResponseWriter, _ *http.Request) {
		hits++
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"data":[]}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(alwaysOK))
	defer srv.Close()

	client, _ := NewClient(srv.URL, WithAPIKey("k"))
	caps, err := client.ProbeCapabilities(context.Background())
	if err != nil {
		t.Fatal(err)
	}

	allSupported := caps.Groups && caps.Channels && caps.Plans && caps.Accounts &&
		caps.AccountTest && caps.AccountModels && caps.Subscriptions
	if !allSupported {
		t.Fatalf("all capabilities should be true, got %+v", caps)
	}
	if hits != 7 {
		t.Fatalf("callCount = %d, want 7", hits)
	}
}
// TestProbeCapabilitiesDoesNotTreat404AsSupportForAccountOrSubscriptionRoutes uses a mock
// where the groups, channels, plans and accounts list routes answer 200, and the
// account-test, account-models and subscription-assign probe routes answer 404.
// It asserts that the three 404-backed capabilities are reported as unsupported.
func TestProbeCapabilitiesDoesNotTreat404AsSupportForAccountOrSubscriptionRoutes(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")

		status := http.StatusNotFound
		payload := ""
		switch r.URL.Path {
		case "/api/v1/admin/groups", "/api/v1/admin/channels", "/api/v1/admin/payment/plans", "/api/v1/admin/accounts":
			status, payload = http.StatusOK, `{"data":[]}`
		case "/api/v1/admin/accounts/__probe__/test", "/api/v1/admin/accounts/__probe__/models", "/api/v1/admin/subscriptions/assign":
			payload = `{"error":"not found"}`
		}

		w.WriteHeader(status)
		if payload != "" {
			_, _ = w.Write([]byte(payload))
		}
	}))
	defer srv.Close()

	client, _ := NewClient(srv.URL, WithAPIKey("k"))
	caps, err := client.ProbeCapabilities(context.Background())
	if err != nil {
		t.Fatalf("ProbeCapabilities() error = %v", err)
	}

	unsupported := []struct {
		name string
		got  bool
	}{
		{"AccountTest", caps.AccountTest},
		{"AccountModels", caps.AccountModels},
		{"Subscriptions", caps.Subscriptions},
	}
	for _, c := range unsupported {
		if c.got {
			t.Fatal(c.name + " = true, want false on 404 probe route")
		}
	}
}
// TestListManagedResourcesWithMock serves the same one-item paged payload for every path.
// It asserts that ListManagedResources decodes exactly one group from it.
func TestListManagedResourcesWithMock(t *testing.T) {
	onePage := []byte(`{"data":{"items":[
	{"id":"r1","name":"resource-1"}
]}}`)
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write(onePage)
	})
	srv := httptest.NewServer(handler)
	defer srv.Close()

	client, _ := NewClient(srv.URL, WithAPIKey("k"))
	snapshot, err := client.ListManagedResources(context.Background(), ListManagedResourcesRequest{})
	if err != nil {
		t.Fatal(err)
	}
	if groups := len(snapshot.Groups); groups != 1 {
		t.Fatalf("expected 1 group, got %d", groups)
	}
}
// TestListManagedResourcesLoadsAllAccountPages serves a two-page accounts listing
// (pages=2, page_size=100). It asserts that ListManagedResources requests both pages with
// page_size=100 and returns account_1 followed by account_2.
//
// The handler runs on the httptest server's goroutine, so it reports problems with
// t.Errorf instead of t.Fatalf (FailNow is only valid on the test goroutine). For an
// unexpected page it replies 400 so the failure also surfaces through the client call.
func TestListManagedResourcesLoadsAllAccountPages(t *testing.T) {
	accountPages := 0
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/v1/admin/groups", "/api/v1/admin/channels":
			_, _ = w.Write([]byte(`{"data":{"items":[{"id":"r1","name":"resource-1"}],"total":1,"page":1,"page_size":20,"pages":1}}`))
		case "/api/v1/admin/payment/plans":
			_, _ = w.Write([]byte(`{"data":[{"id":"plan_1","name":"plan-1"}]}`))
		case "/api/v1/admin/accounts":
			accountPages++
			query := r.URL.Query()
			page := query.Get("page")
			if page == "" {
				// A request without a page parameter is served as the first page.
				page = "1"
			}
			if got := query.Get("page_size"); got != "100" {
				t.Errorf("page_size = %q, want 100", got)
			}
			switch page {
			case "1":
				_, _ = w.Write([]byte(`{"data":{"items":[{"id":"account_1","name":"deepseek-01"}],"total":2,"page":1,"page_size":100,"pages":2}}`))
			case "2":
				_, _ = w.Write([]byte(`{"data":{"items":[{"id":"account_2","name":"deepseek-02"}],"total":2,"page":2,"page_size":100,"pages":2}}`))
			default:
				t.Errorf("unexpected accounts page %q", page)
				http.Error(w, "unexpected accounts page", http.StatusBadRequest)
			}
		default:
			http.NotFound(w, r)
		}
	}))
	defer srv.Close()

	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}
	snapshot, err := client.ListManagedResources(context.Background(), ListManagedResourcesRequest{AccountNamePrefix: "deepseek-"})
	if err != nil {
		t.Fatal(err)
	}
	if accountPages != 2 {
		t.Fatalf("account pages fetched = %d, want 2", accountPages)
	}
	if len(snapshot.Accounts) != 2 || snapshot.Accounts[0].ID != "account_1" || snapshot.Accounts[1].ID != "account_2" {
		t.Fatalf("Accounts = %+v, want both paged accounts", snapshot.Accounts)
	}
}
// TestTestAccountWithMock checks two things. First, TestAccount sends the requested
// model ID in the JSON body under "model_id". Second, a single "data:" line reporting
// status "passed" with ok=true yields an OK result.
//
// The handler runs on the httptest server's goroutine, so a body-decoding failure is
// reported with t.Errorf plus a 400 reply rather than t.Fatalf, which is only valid on
// the test goroutine.
func TestTestAccountWithMock(t *testing.T) {
	var requestBody map[string]any
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			t.Errorf("decode request body: %v", err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		_, _ = w.Write([]byte("data: {\"status\":\"passed\",\"ok\":true}\n"))
	}))
	defer srv.Close()

	client, err := NewClient(srv.URL, WithAPIKey("k"))
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}
	result, err := client.TestAccount(context.Background(), "a1", "MiniMax-M2.7-highspeed")
	if err != nil {
		t.Fatal(err)
	}
	if !result.OK {
		t.Fatal("expected OK=true")
	}
	if got := requestBody["model_id"]; got != "MiniMax-M2.7-highspeed" {
		t.Fatalf("model_id = %#v, want MiniMax-M2.7-highspeed", got)
	}
}
// TestTestAccountWithMockSSEError streams a test_start event followed by an "error" event
// as text/event-stream. It asserts that TestAccount returns a non-OK result with status
// "failed" and carries the server's error text into Message.
func TestTestAccountWithMockSSEError(t *testing.T) {
	events := []string{
		"data: {\"type\":\"test_start\",\"model\":\"MiniMax-M2.7-highspeed\"}\n\n",
		"data: {\"type\":\"error\",\"error\":\"账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。\"}\n\n",
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		for _, ev := range events {
			w.Write([]byte(ev))
		}
	}))
	defer srv.Close()

	client, _ := NewClient(srv.URL, WithAPIKey("k"))
	result, err := client.TestAccount(context.Background(), "a1", "")
	if err != nil {
		t.Fatal(err)
	}

	switch {
	case result.OK:
		t.Fatal("expected OK=false for SSE error event")
	case result.Status != "failed":
		t.Fatalf("Status = %q, want failed", result.Status)
	case !strings.Contains(result.Message, "测试接口仅支持 Responses API 路径"):
		t.Fatalf("Message = %q, want propagated SSE error message", result.Message)
	}
}
// TestGetAccountModelsWithMock serves a one-entry model list. It asserts that
// GetAccountModels returns exactly one model whose ID is "m1".
func TestGetAccountModelsWithMock(t *testing.T) {
	const payload = `{"data":[{"id":"m1","display_name":"M1","type":"chat"}]}`
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte(payload))
	}))
	defer srv.Close()

	client, _ := NewClient(srv.URL, WithAPIKey("k"))
	models, err := client.GetAccountModels(context.Background(), "a1")
	if err != nil {
		t.Fatal(err)
	}
	if !(len(models) == 1 && models[0].ID == "m1") {
		t.Fatalf("got %+v", models)
	}
}