fix(provision): preserve channel model mapping on import

This commit is contained in:
phamnazage-jpg
2026-05-19 22:24:32 +08:00
parent 18e1b085eb
commit 83ee216a4d
5 changed files with 64 additions and 13 deletions

View File

@@ -54,8 +54,11 @@ type GroupRef struct {
}
type CreateChannelRequest struct {
Name string `json:"name"`
GroupIDs []string `json:"group_ids"`
Name string `json:"name"`
GroupIDs []string `json:"group_ids"`
ModelMapping map[string]string `json:"model_mapping,omitempty"`
RestrictModels bool `json:"restrict_models,omitempty"`
BillingModelSource string `json:"billing_model_source,omitempty"`
}
type ChannelRef struct {

View File

@@ -255,7 +255,14 @@ func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.Named
func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID string) (sub2api.ChannelRef, bool, error) {
switch len(existing) {
case 0:
channel, err := host.CreateChannel(ctx, sub2api.CreateChannelRequest{Name: provider.ChannelTemplate.Name, GroupIDs: []string{groupID}})
channelReq := sub2api.CreateChannelRequest{
Name: provider.ChannelTemplate.Name,
GroupIDs: []string{groupID},
ModelMapping: provider.ChannelTemplate.ModelMapping,
RestrictModels: true,
BillingModelSource: "channel_mapped",
}
channel, err := host.CreateChannel(ctx, channelReq)
return channel, true, err
case 1:
return sub2api.ChannelRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil

View File

@@ -141,11 +141,11 @@ func TestImportServiceStrictModeRollsBackCreatedResources(t *testing.T) {
}
}
func TestImportServiceReusesExistingManagedResources(t *testing.T) {
func TestImportReusesExistingGroup(t *testing.T) {
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}},
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_1": {OK: true, Status: "ready"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
@@ -177,6 +177,44 @@ func TestImportServiceReusesExistingManagedResources(t *testing.T) {
}
}
func TestImportCreatesChannelWithManifestModelMapping(t *testing.T) {
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "ready"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
}
_, err := NewImportService(host).Import(context.Background(), ImportRequest{
Provider: sampleProviderManifest(),
Mode: ImportModePartial,
Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"},
Keys: []string{"key-1"},
})
if err != nil {
t.Fatalf("Import() error = %v", err)
}
if host.createChannelReq.Name != "DeepSeek 默认渠道" {
t.Fatalf("CreateChannel().Name = %q, want DeepSeek 默认渠道", host.createChannelReq.Name)
}
if len(host.createChannelReq.GroupIDs) != 1 || host.createChannelReq.GroupIDs[0] != "group_1" {
t.Fatalf("CreateChannel().GroupIDs = %v, want [group_1]", host.createChannelReq.GroupIDs)
}
if got := host.createChannelReq.ModelMapping["deepseek-chat"]; got != "deepseek-chat" {
t.Fatalf("CreateChannel().ModelMapping = %+v, want deepseek-chat passthrough", host.createChannelReq.ModelMapping)
}
if !host.createChannelReq.RestrictModels {
t.Fatal("CreateChannel().RestrictModels = false, want true")
}
if host.createChannelReq.BillingModelSource != "channel_mapped" {
t.Fatalf("CreateChannel().BillingModelSource = %q, want channel_mapped", host.createChannelReq.BillingModelSource)
}
}
func sampleProviderManifest() pack.ProviderManifest {
return pack.ProviderManifest{
ProviderID: "deepseek",
@@ -210,6 +248,7 @@ type fakeHostAdapter struct {
createChannelCalls int
createPlanCalls int
createGroupReq sub2api.CreateGroupRequest
createChannelReq sub2api.CreateChannelRequest
}
func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) {
@@ -230,8 +269,9 @@ func (f *fakeHostAdapter) DeleteGroup(_ context.Context, groupID string) error {
f.deletedResources = append(f.deletedResources, "group:"+groupID)
return nil
}
func (f *fakeHostAdapter) CreateChannel(context.Context, sub2api.CreateChannelRequest) (sub2api.ChannelRef, error) {
func (f *fakeHostAdapter) CreateChannel(_ context.Context, req sub2api.CreateChannelRequest) (sub2api.ChannelRef, error) {
f.createChannelCalls++
f.createChannelReq = req
return sub2api.ChannelRef{ID: "channel_1", Name: "c"}, nil
}
func (f *fakeHostAdapter) DeleteChannel(_ context.Context, channelID string) error {