fix(provision): reconcile channel pricing and hosted access
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -591,8 +592,23 @@ func TestCreateGroupWithMock(t *testing.T) {
|
||||
func TestCreateChannelWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
ModelPricing []struct {
|
||||
Platform string `json:"platform"`
|
||||
Models []string `json:"models"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price"`
|
||||
PerRequestPrice *float64 `json:"per_request_price"`
|
||||
Intervals []any `json:"intervals"`
|
||||
} `json:"model_pricing"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
@@ -603,11 +619,36 @@ func TestCreateChannelWithMock(t *testing.T) {
|
||||
if len(req.GroupIDs) != 1 || req.GroupIDs[0] != 101 {
|
||||
t.Fatalf("group_ids = %v, want [101]", req.GroupIDs)
|
||||
}
|
||||
if req.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" {
|
||||
t.Fatalf("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", req.ModelMapping)
|
||||
}
|
||||
if len(req.ModelPricing) != 1 {
|
||||
t.Fatalf("model_pricing len = %d, want 1", len(req.ModelPricing))
|
||||
}
|
||||
if req.ModelPricing[0].Platform != "openai" || req.ModelPricing[0].BillingMode != "token" {
|
||||
t.Fatalf("model_pricing[0] = %+v, want openai/token entry", req.ModelPricing[0])
|
||||
}
|
||||
if len(req.ModelPricing[0].Models) != 1 || req.ModelPricing[0].Models[0] != "deepseek-v4-pro" {
|
||||
t.Fatalf("model_pricing[0].models = %v, want [deepseek-v4-pro]", req.ModelPricing[0].Models)
|
||||
}
|
||||
if !req.RestrictModels {
|
||||
t.Fatal("restrict_models = false, want true")
|
||||
}
|
||||
if req.BillingModelSource != "channel_mapped" {
|
||||
t.Fatalf("billing_model_source = %q, want channel_mapped", req.BillingModelSource)
|
||||
}
|
||||
w.Write([]byte(`{"data":{"id":201,"name":"ch"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{Name: "ch", GroupIDs: []string{"101"}})
|
||||
ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{
|
||||
Name: "ch",
|
||||
GroupIDs: []string{"101"},
|
||||
ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
|
||||
ModelPricing: []ChannelModelPricing{{Platform: "openai", Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: "channel_mapped",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -699,6 +740,66 @@ func TestAssignSubscriptionWithMock(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureSubscriptionAccessWithMock(t *testing.T) {
|
||||
var calls []string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls = append(calls, r.Method+" "+r.URL.Path)
|
||||
switch {
|
||||
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
|
||||
w.Write([]byte(`{"data":{"items":[]}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users":
|
||||
w.Write([]byte(`{"data":{"id":84,"email":"relay-sub-user-1@sub2api.local"}}`))
|
||||
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84":
|
||||
w.Write([]byte(`{"data":{"id":84}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance":
|
||||
w.Write([]byte(`{"data":{"id":84}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign":
|
||||
var req struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
DurationDays int `json:"validity_days"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode assign subscription request: %v", err)
|
||||
}
|
||||
if req.UserID != 84 || req.GroupID != 101 || req.DurationDays != 30 {
|
||||
t.Fatalf("unexpected assign subscription request: %+v", req)
|
||||
}
|
||||
w.Write([]byte(`{"data":{"id":401}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login":
|
||||
w.Write([]byte(`{"data":{"access_token":"user-jwt"}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode managed key request: %v", err)
|
||||
}
|
||||
if _, ok := req["group_id"]; ok {
|
||||
t.Fatalf("managed key request unexpectedly carried group_id: %+v", req)
|
||||
}
|
||||
w.Write([]byte(`{"data":{"id":501,"key":"sk-relay-key","name":"managed-key"}}`))
|
||||
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501":
|
||||
w.Write([]byte(`{"data":{"api_key":{"id":501}}}`))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithBearerToken("admin-token"))
|
||||
ref, err := client.EnsureSubscriptionAccess(context.Background(), EnsureSubscriptionAccessRequest{UserSelector: "crm-user-1", GroupID: "101"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ref.UserID != "84" {
|
||||
t.Fatalf("user id = %q, want 84", ref.UserID)
|
||||
}
|
||||
if !strings.HasPrefix(ref.APIKey, "sk-relay-") {
|
||||
t.Fatalf("api key = %q, want managed sk-relay-* key", ref.APIKey)
|
||||
}
|
||||
if len(calls) < 7 {
|
||||
t.Fatalf("calls = %v, want managed subscription setup sequence", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckGatewayAccessWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":[{"id":"gpt-4"},{"id":"claude-3"}]}`))
|
||||
@@ -741,12 +842,19 @@ func TestBatchCreateAccountsWithMock(t *testing.T) {
|
||||
if len(acct.GroupIDs) != 1 || acct.GroupIDs[0] != 101 {
|
||||
t.Fatalf("group_ids = %v, want [101]", acct.GroupIDs)
|
||||
}
|
||||
rawMapping, ok := acct.Credentials["model_mapping"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("credentials = %+v, want model_mapping map", acct.Credentials)
|
||||
}
|
||||
if got, _ := rawMapping["deepseek-v4-pro"].(string); got != "deepseek-v4-pro" {
|
||||
t.Fatalf("model_mapping = %+v, want deepseek-v4-pro passthrough", rawMapping)
|
||||
}
|
||||
w.Write([]byte(`{"data":[{"id":601,"name":"acct1"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
refs, err := client.BatchCreateAccounts(context.Background(), BatchCreateAccountsRequest{
|
||||
Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com"}}},
|
||||
Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com", "model_mapping": map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"}}}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
Reference in New Issue
Block a user