fix(provision): preserve channel model mapping on import
This commit is contained in:
@@ -54,8 +54,11 @@ type GroupRef struct {
|
||||
}
|
||||
|
||||
type CreateChannelRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupIDs []string `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
GroupIDs []string `json:"group_ids"`
|
||||
ModelMapping map[string]string `json:"model_mapping,omitempty"`
|
||||
RestrictModels bool `json:"restrict_models,omitempty"`
|
||||
BillingModelSource string `json:"billing_model_source,omitempty"`
|
||||
}
|
||||
|
||||
type ChannelRef struct {
|
||||
|
||||
@@ -255,7 +255,14 @@ func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.Named
|
||||
func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID string) (sub2api.ChannelRef, bool, error) {
|
||||
switch len(existing) {
|
||||
case 0:
|
||||
channel, err := host.CreateChannel(ctx, sub2api.CreateChannelRequest{Name: provider.ChannelTemplate.Name, GroupIDs: []string{groupID}})
|
||||
channelReq := sub2api.CreateChannelRequest{
|
||||
Name: provider.ChannelTemplate.Name,
|
||||
GroupIDs: []string{groupID},
|
||||
ModelMapping: provider.ChannelTemplate.ModelMapping,
|
||||
RestrictModels: true,
|
||||
BillingModelSource: "channel_mapped",
|
||||
}
|
||||
channel, err := host.CreateChannel(ctx, channelReq)
|
||||
return channel, true, err
|
||||
case 1:
|
||||
return sub2api.ChannelRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil
|
||||
|
||||
@@ -141,11 +141,11 @@ func TestImportServiceStrictModeRollsBackCreatedResources(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportServiceReusesExistingManagedResources(t *testing.T) {
|
||||
func TestImportReusesExistingGroup(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}},
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_1": {OK: true, Status: "ready"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
@@ -177,6 +177,44 @@ func TestImportServiceReusesExistingManagedResources(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCreatesChannelWithManifestModelMapping(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "ready"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
_, err := NewImportService(host).Import(context.Background(), ImportRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModePartial,
|
||||
Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"},
|
||||
Keys: []string{"key-1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Import() error = %v", err)
|
||||
}
|
||||
if host.createChannelReq.Name != "DeepSeek 默认渠道" {
|
||||
t.Fatalf("CreateChannel().Name = %q, want DeepSeek 默认渠道", host.createChannelReq.Name)
|
||||
}
|
||||
if len(host.createChannelReq.GroupIDs) != 1 || host.createChannelReq.GroupIDs[0] != "group_1" {
|
||||
t.Fatalf("CreateChannel().GroupIDs = %v, want [group_1]", host.createChannelReq.GroupIDs)
|
||||
}
|
||||
if got := host.createChannelReq.ModelMapping["deepseek-chat"]; got != "deepseek-chat" {
|
||||
t.Fatalf("CreateChannel().ModelMapping = %+v, want deepseek-chat passthrough", host.createChannelReq.ModelMapping)
|
||||
}
|
||||
if !host.createChannelReq.RestrictModels {
|
||||
t.Fatal("CreateChannel().RestrictModels = false, want true")
|
||||
}
|
||||
if host.createChannelReq.BillingModelSource != "channel_mapped" {
|
||||
t.Fatalf("CreateChannel().BillingModelSource = %q, want channel_mapped", host.createChannelReq.BillingModelSource)
|
||||
}
|
||||
}
|
||||
|
||||
func sampleProviderManifest() pack.ProviderManifest {
|
||||
return pack.ProviderManifest{
|
||||
ProviderID: "deepseek",
|
||||
@@ -210,6 +248,7 @@ type fakeHostAdapter struct {
|
||||
createChannelCalls int
|
||||
createPlanCalls int
|
||||
createGroupReq sub2api.CreateGroupRequest
|
||||
createChannelReq sub2api.CreateChannelRequest
|
||||
}
|
||||
|
||||
func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) {
|
||||
@@ -230,8 +269,9 @@ func (f *fakeHostAdapter) DeleteGroup(_ context.Context, groupID string) error {
|
||||
f.deletedResources = append(f.deletedResources, "group:"+groupID)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreateChannel(context.Context, sub2api.CreateChannelRequest) (sub2api.ChannelRef, error) {
|
||||
func (f *fakeHostAdapter) CreateChannel(_ context.Context, req sub2api.CreateChannelRequest) (sub2api.ChannelRef, error) {
|
||||
f.createChannelCalls++
|
||||
f.createChannelReq = req
|
||||
return sub2api.ChannelRef{ID: "channel_1", Name: "c"}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) DeleteChannel(_ context.Context, channelID string) error {
|
||||
|
||||
Reference in New Issue
Block a user