Files
sub2api-cn-relay-manager/internal/provision/model_pool.go
phamnazage-jpg 492f33a129
Some checks failed
CI / Build & Test (push) Has been cancelled
CI / Lint (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / Docker Build (push) Has been cancelled
CI / Release (push) Has been cancelled
feat(vnext): complete vNext.1 release gate — default chain admission, idempotent init, user key skeleton
- DEFAULT_CHAIN_ADMISSION.md: reviewed and approved, real artifact refs added
- DEFAULT_DATA_IDEMPOTENT_RELEASE_GATE.md: reviewed and approved
- scripts/setup_default_data.sh: idempotent init with --dry-run/--apply/artifact
- scripts/test/test_default_data.sh: 4 test cases all pass
- scripts/acceptance/verify_user_key_self_service.sh: Phase 0 skeleton
- .gitignore: add generated artifact directories
2026-06-05 11:07:50 +08:00

219 lines
6.3 KiB
Go

package provision
import (
"fmt"
"sort"
"strings"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
)
// ModelPoolBuildRequest is the input consumed by BuildModelPool.
type ModelPoolBuildRequest struct {
	// PublicModel names the model the pool is built for. BuildModelPool trims
	// surrounding whitespace and rejects a blank value.
	PublicModel string

	// AllowPluginAdapterCandidates additionally admits candidates whose
	// support level is sub2api.SupportLevelWithPluginAdapter; without it only
	// sub2api.SupportLevelDirect candidates are eligible.
	AllowPluginAdapterCandidates bool

	// Candidates are the routes to evaluate for inclusion in the pool.
	Candidates []ModelPoolCandidate
}
// ModelPoolCandidate describes one route that may be admitted to a model pool.
type ModelPoolCandidate struct {
	// RouteID identifies the route; a blank value makes the pool build fail.
	RouteID string

	// Provider supplies the provider ID, display name and base URL copied onto
	// the resulting route, and is consulted when CallableModel is not set.
	Provider pack.ProviderManifest

	// Priority orders routes inside the pool; lower values sort first.
	Priority int

	// Schedulable excludes the candidate when it points to false. A nil
	// pointer is treated as schedulable.
	Schedulable *bool

	// AdvertisedModel and CallableModel are optional overrides. When blank,
	// the advertised model defaults to the public model and the callable
	// model is resolved from Provider.
	AdvertisedModel, CallableModel string

	// Inventory is the capability inventory for this route. The candidate is
	// skipped unless it is host-ready and lists the public model with an
	// eligible support level.
	Inventory sub2api.CapabilityInventory

	// CooldownUntil and DisableReason are copied (trimmed) onto the route.
	CooldownUntil, DisableReason string
}
// ModelPool is the result of BuildModelPool: the eligible routes for one
// public model.
type ModelPool struct {
	// PublicModel is the trimmed model name the pool was built for.
	// CanonicalModelFamily is the first non-empty family reported by an
	// eligible route, or PublicModel when no route reports one.
	PublicModel, CanonicalModelFamily string

	// Routes holds the eligible routes ordered by ascending Priority, with
	// RouteID breaking ties.
	Routes []PoolRoute
}
// PoolRoute is a single eligible route inside a ModelPool, as produced by
// buildPoolRoute.
type PoolRoute struct {
	// Route identity and the provider details copied (trimmed) from the
	// candidate's provider manifest.
	RouteID, ProviderID, DisplayName, BaseURL string

	// PublicModel is the pool's model name. AdvertisedModel and CallableModel
	// are the candidate overrides or their defaults. CanonicalModelFamily
	// comes from the matching inventory entry.
	PublicModel, AdvertisedModel, CallableModel, CanonicalModelFamily string

	// Priority orders routes within the pool; lower values sort first.
	Priority int
	// Schedulable carries the candidate's flag, true when the flag was unset.
	Schedulable bool

	// SupportLevel is the support level of the matching inventory entry.
	SupportLevel string
	// SupportedModels lists the de-duplicated model names found in the
	// candidate's inventory.
	SupportedModels []string

	// SupportsChat mirrors the inventory entry's chat smoke result.
	// SupportsResponses is false when the entry carries the
	// "responses_unsupported_but_chat_ok" advisory.
	SupportsChat, SupportsResponses bool

	// CooldownUntil and DisableReason are copied (trimmed) from the candidate.
	CooldownUntil, DisableReason string
	// KnownAdvisories is a copy of the matching inventory entry's advisories.
	KnownAdvisories []string
}
// BuildModelPool assembles the pool of eligible routes for req.PublicModel.
//
// Every candidate is evaluated with buildPoolRoute: ineligible candidates are
// skipped, and the first candidate error aborts the build. Surviving routes are
// ordered by ascending Priority, with RouteID breaking ties. The pool's
// CanonicalModelFamily is the first non-empty family reported by an eligible
// route in candidate order (not sorted order), falling back to the trimmed
// public model name.
//
// An error is returned when the public model is blank, when a candidate is
// invalid, or when no candidate is eligible.
func BuildModelPool(req ModelPoolBuildRequest) (ModelPool, error) {
	name := strings.TrimSpace(req.PublicModel)
	if name == "" {
		return ModelPool{}, fmt.Errorf("public_model is required")
	}

	var family string
	eligible := make([]PoolRoute, 0, len(req.Candidates))
	for _, cand := range req.Candidates {
		route, ok, err := buildPoolRoute(name, req.AllowPluginAdapterCandidates, cand)
		if err != nil {
			return ModelPool{}, err
		}
		if ok {
			// Remember the first family seen; later routes never override it.
			if family == "" {
				family = route.CanonicalModelFamily
			}
			eligible = append(eligible, route)
		}
	}
	if len(eligible) == 0 {
		return ModelPool{}, fmt.Errorf("no eligible routes for public_model %q", name)
	}

	sort.SliceStable(eligible, func(a, b int) bool {
		left, right := &eligible[a], &eligible[b]
		if left.Priority == right.Priority {
			return left.RouteID < right.RouteID
		}
		return left.Priority < right.Priority
	})

	pool := ModelPool{
		PublicModel:          name,
		CanonicalModelFamily: family,
		Routes:               eligible,
	}
	if pool.CanonicalModelFamily == "" {
		pool.CanonicalModelFamily = name
	}
	return pool, nil
}
// buildPoolRoute converts one candidate into a PoolRoute for publicModel.
//
// The boolean result reports whether the candidate is eligible. A candidate is
// skipped (false, nil error) when its inventory is not host-ready, when it is
// explicitly marked unschedulable, when its inventory does not list the public
// model, or when the listed support level is not eligible under
// allowPluginAdapter. A blank route ID is the only error case.
func buildPoolRoute(publicModel string, allowPluginAdapter bool, candidate ModelPoolCandidate) (PoolRoute, bool, error) {
	var none PoolRoute

	id := strings.TrimSpace(candidate.RouteID)
	if id == "" {
		return none, false, fmt.Errorf("route_id is required")
	}

	explicitlyOff := candidate.Schedulable != nil && !*candidate.Schedulable
	if !candidate.Inventory.HostReady || explicitlyOff {
		return none, false, nil
	}

	summary, listed := findModelSummary(candidate.Inventory, publicModel)
	if !listed || !isEligibleSupportLevel(summary.SupportLevel, allowPluginAdapter) {
		return none, false, nil
	}

	route := PoolRoute{
		RouteID:              id,
		ProviderID:           strings.TrimSpace(candidate.Provider.ProviderID),
		DisplayName:          strings.TrimSpace(candidate.Provider.DisplayName),
		BaseURL:              strings.TrimSpace(candidate.Provider.BaseURL),
		PublicModel:          publicModel,
		AdvertisedModel:      strings.TrimSpace(candidate.AdvertisedModel),
		CallableModel:        strings.TrimSpace(candidate.CallableModel),
		CanonicalModelFamily: strings.TrimSpace(summary.CanonicalModelFamily),
		Priority:             candidate.Priority,
		// A nil flag means "schedulable"; an explicit false was rejected above.
		Schedulable:     candidate.Schedulable == nil || *candidate.Schedulable,
		SupportLevel:    strings.TrimSpace(summary.SupportLevel),
		SupportedModels: collectSupportedModels(candidate.Inventory),
		SupportsChat:    summary.SmokeChatOK,
		CooldownUntil:   strings.TrimSpace(candidate.CooldownUntil),
		DisableReason:   strings.TrimSpace(candidate.DisableReason),
		// Copy so the route does not alias the inventory's advisory slice.
		KnownAdvisories: append([]string(nil), summary.KnownAdvisories...),
	}

	// Blank overrides fall back to defaults.
	if route.AdvertisedModel == "" {
		route.AdvertisedModel = publicModel
	}
	if route.CallableModel == "" {
		route.CallableModel = resolveCallableModel(publicModel, candidate.Provider)
	}

	// The Responses API is assumed available unless the advisory says otherwise.
	route.SupportsResponses = !contains(summary.KnownAdvisories, "responses_unsupported_but_chat_ok")

	return route, true, nil
}
// findModelSummary returns the first inventory entry whose raw model ID or
// canonical model family equals publicModel, ignoring case and surrounding
// whitespace. The boolean result is false when no entry matches.
func findModelSummary(inventory sub2api.CapabilityInventory, publicModel string) (sub2api.ModelCapabilitySummary, bool) {
	want := strings.TrimSpace(publicModel)
	matches := func(field string) bool {
		return strings.EqualFold(strings.TrimSpace(field), want)
	}

	for _, entry := range inventory.Models {
		if matches(entry.RawModelID) || matches(entry.CanonicalModelFamily) {
			return entry, true
		}
	}
	return sub2api.ModelCapabilitySummary{}, false
}
// isEligibleSupportLevel reports whether a route with the given support level
// may join a pool. Direct support is always eligible; plugin-adapter support
// is eligible only when allowPluginAdapter is set; every other level is not.
func isEligibleSupportLevel(level string, allowPluginAdapter bool) bool {
	normalized := strings.TrimSpace(level)
	if normalized == sub2api.SupportLevelDirect {
		return true
	}
	return allowPluginAdapter && normalized == sub2api.SupportLevelWithPluginAdapter
}
// resolveCallableModel picks the model name to call for publicModel. It
// prefers a non-blank entry for the trimmed public model in the provider's
// channel-template model mapping, then the provider's smoke-test model, and
// finally the trimmed public model itself.
func resolveCallableModel(publicModel string, provider pack.ProviderManifest) string {
	key := strings.TrimSpace(publicModel)

	// A missing key yields the empty string, which is treated like a blank mapping.
	if target := strings.TrimSpace(provider.ChannelTemplate.ModelMapping[key]); target != "" {
		return target
	}

	fallback := strings.TrimSpace(provider.SmokeTestModel)
	if fallback == "" {
		return key
	}
	return fallback
}
// collectSupportedModels lists the model names in inventory in first-seen
// order. Each entry contributes its trimmed raw model ID, or its trimmed
// canonical family when the raw ID is blank; entries with neither are skipped
// and exact duplicates are dropped. The result is never nil.
func collectSupportedModels(inventory sub2api.CapabilityInventory) []string {
	total := len(inventory.Models)
	visited := make(map[string]struct{}, total)
	names := make([]string, 0, total)

	for _, entry := range inventory.Models {
		name := strings.TrimSpace(entry.RawModelID)
		if name == "" {
			name = strings.TrimSpace(entry.CanonicalModelFamily)
		}
		if _, dup := visited[name]; dup || name == "" {
			continue
		}
		visited[name] = struct{}{}
		names = append(names, name)
	}
	return names
}
// contains reports whether any element of values equals target once both are
// stripped of surrounding whitespace. A blank target never matches.
func contains(values []string, target string) bool {
	needle := strings.TrimSpace(target)
	if len(needle) == 0 {
		return false
	}
	for i := 0; i < len(values); i++ {
		if needle == strings.TrimSpace(values[i]) {
			return true
		}
	}
	return false
}