Files
sub2api-cn-relay-manager/internal/batch/service.go
2026-05-22 14:41:12 +08:00

327 lines
9.7 KiB
Go

package batch
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"strings"
"time"
"sub2api-cn-relay-manager/internal/probe"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
// BatchImportEntry describes one upstream endpoint submitted for import.
// StartRun persists only a fingerprint of APIKey in the item record; the raw
// key is forwarded to the model/capability probes and to the provisioner.
// RequestedModels is stored on the item and used when resolving the
// smoke-test model.
type BatchImportEntry struct {
	BaseURL         string
	APIKey          string
	RequestedModels []string
}

// BatchImportRunRequest is the input to BatchImportService.StartRun.
type BatchImportRunRequest struct {
	RunID       string             // optional; StartRun generates "run-<unixnano>" when blank
	Mode        string             // trimmed and stored on the run record
	AccessMode  string             // trimmed and stored on the run record
	HostID      string             // trimmed and passed to InspectReuse
	HostBaseURL string             // NOTE(review): not read by StartRun in this file; confirm other uses
	Entries     []BatchImportEntry // must be non-empty
}

// BatchImportRunResult reports the identifiers created by StartRun.
type BatchImportRunResult struct {
	RunID   string   // the trimmed request RunID, or the generated one
	ItemIDs []string // one per entry, in entry order, formatted "<RunID>-item-<n>" with n starting at 1
}
// RunStateStore persists import run records. StartRun only ever creates a run;
// this interface offers no way to update its state afterwards.
type RunStateStore interface {
	Create(ctx context.Context, run sqlite.ImportRun) error
}

// ItemStateStore persists per-item import state. StartRun calls Upsert twice
// per item with the same ItemID (once when probing starts, once with the final
// result), so implementations must accept repeated writes for one item.
type ItemStateStore interface {
	Upsert(ctx context.Context, item sqlite.ImportRunItem) error
}
// ReuseLookupInput is passed to BatchImportService.InspectReuse so it can
// report existing state for an entry on the target host.
type ReuseLookupInput struct {
	HostID                 string   // trimmed BatchImportRunRequest.HostID
	ProviderID             string   // NormalizeProviderID applied to the entry's base URL
	BaseURL                string   // the entry's base URL
	APIKeyFingerprint      string   // fingerprintAPIKey of the entry's key; never the raw key
	CanonicalModelFamilies []string // distinct canonical families of the probed models
}

// ReuseLookupResult is returned by BatchImportService.InspectReuse. All fields
// except ExistingModelMapping are forwarded to DecideReuse; the misspelled
// ExistingCanonicalFamilys mirrors the ReuseInput field of the same name.
// ExistingModelMapping is diffed against the probed alias table to decide
// whether a reused provider needs a patch. When InspectReuse is nil, StartRun
// uses the zero value.
type ReuseLookupResult struct {
	ExistingProviderID       string
	ExistingAccessStatus     AccessStatus
	ExistingCanonicalFamilys []string
	MatchedAccountID         int64
	MatchedAccountState      MatchedAccountState
	ExistingModelMapping     map[string]string
}
// ProvisionRequest is sent to BatchProvisioner.Provision by StartRun for
// entries whose reuse decision does not mark provisioning as reused.
// Entry still carries the raw API key. ResolvedModel is the smoke-test model
// chosen by probe.ResolveSmokeModel. RoutingStrategy is built from
// CapabilityProfile.
type ProvisionRequest struct {
	RunID             string
	ItemID            string
	Entry             BatchImportEntry
	ProviderID        string
	ResolvedModel     string
	RoutingStrategy   ImportRoutingStrategy
	CapabilityProfile *probe.CapabilityProfile
}

// ProvisionResult is returned by BatchProvisioner.Provision. StartRun copies
// LegacyBatchID and the whitespace-trimmed LegacyProviderID onto the item
// record.
type ProvisionResult struct {
	LegacyBatchID    *int64
	LegacyProviderID string
}

// PatchProvisionRequest is sent to BatchProvisioner.Patch when an entry reuses
// existing provisioning and shouldPatchAliases reports that the mapping delta
// adds or changes aliases. ProviderID is the provider being reused.
type PatchProvisionRequest struct {
	ProviderID string
	Contract   ChannelPatchContract
}

// BatchProvisioner is the dependency StartRun uses to act on an import:
// Provision for entries that are not reused, and Patch for reused entries
// whose aliases changed. StartRun requires a non-nil provisioner only when
// one of those calls is actually needed.
type BatchProvisioner interface {
	Provision(ctx context.Context, req ProvisionRequest) (ProvisionResult, error)
	Patch(ctx context.Context, req PatchProvisionRequest) error
}
// BatchImportService runs batch imports through StartRun.
//
// RunStore, ItemStore, ProbeModels and ProbeCapabilities are required;
// StartRun rejects a service missing any of them before doing any work.
// InspectReuse is optional: when nil, DecideReuse is given an empty lookup
// result for every entry. Provisioner is only checked when an entry actually
// needs to be provisioned or patched, so a missing provisioner can fail a run
// partway through.
type BatchImportService struct {
	RunStore          RunStateStore
	ItemStore         ItemStateStore
	ProbeModels       func(ctx context.Context, baseURL, apiKey string) (*probe.ModelsResult, error)
	ProbeCapabilities func(ctx context.Context, baseURL, apiKey string, rawModels []string) (*probe.CapabilityProfile, error)
	InspectReuse      func(ctx context.Context, input ReuseLookupInput) (ReuseLookupResult, error)
	Provisioner       BatchProvisioner
}
// StartRun records a new import run and then processes req.Entries one at a
// time, in order. For each entry it:
//
//  1. records the item at the probe stage,
//  2. probes the upstream's models and capabilities,
//  3. resolves the smoke-test model and the routing strategy,
//  4. asks InspectReuse (if set) for existing state and runs DecideReuse,
//  5. either patches the reused provider's aliases or provisions a new one,
//  6. records the item at the confirm stage together with the probe results.
//
// Processing stops at the first failing item. The returned error names the
// item and the step that failed and wraps the underlying error, so errors.Is
// and errors.As still work. Items processed before the failure stay recorded.
// The run record is not updated on failure (RunStateStore only supports
// Create). Cancellation of ctx is checked before each item.
//
// On success the result holds the run ID (req.RunID trimmed, or a generated
// "run-<unixnano>" when blank) and one item ID per entry.
func (s BatchImportService) StartRun(ctx context.Context, req BatchImportRunRequest) (BatchImportRunResult, error) {
	if s.RunStore == nil {
		return BatchImportRunResult{}, fmt.Errorf("run store is required")
	}
	if s.ItemStore == nil {
		return BatchImportRunResult{}, fmt.Errorf("item store is required")
	}
	if s.ProbeModels == nil {
		return BatchImportRunResult{}, fmt.Errorf("model probe is required")
	}
	if s.ProbeCapabilities == nil {
		return BatchImportRunResult{}, fmt.Errorf("capability probe is required")
	}
	if len(req.Entries) == 0 {
		return BatchImportRunResult{}, fmt.Errorf("entries are required")
	}
	runID := strings.TrimSpace(req.RunID)
	if runID == "" {
		runID = fmt.Sprintf("run-%d", time.Now().UnixNano())
	}
	if err := s.RunStore.Create(ctx, sqlite.ImportRun{
		RunID:      runID,
		Mode:       strings.TrimSpace(req.Mode),
		AccessMode: strings.TrimSpace(req.AccessMode),
		State:      string(RunStateRunning),
		TotalItems: len(req.Entries),
	}); err != nil {
		return BatchImportRunResult{}, fmt.Errorf("create run %s: %w", runID, err)
	}

	hostID := strings.TrimSpace(req.HostID)
	itemIDs := make([]string, 0, len(req.Entries))
	for idx, entry := range req.Entries {
		itemID := fmt.Sprintf("%s-item-%d", runID, idx+1)
		// Each item does several network probes and store writes; stop
		// between items once the caller has given up.
		if err := ctx.Err(); err != nil {
			return BatchImportRunResult{}, fmt.Errorf("run %s stopped before item %s: %w", runID, itemID, err)
		}
		if err := s.importOneEntry(ctx, runID, itemID, hostID, entry); err != nil {
			return BatchImportRunResult{}, fmt.Errorf("import item %s: %w", itemID, err)
		}
		itemIDs = append(itemIDs, itemID)
	}
	return BatchImportRunResult{RunID: runID, ItemIDs: itemIDs}, nil
}

// importOneEntry runs the per-entry pipeline described on StartRun for one
// entry. Returned errors name the failing step; StartRun adds the item ID.
func (s BatchImportService) importOneEntry(ctx context.Context, runID, itemID, hostID string, entry BatchImportEntry) error {
	baseURL := strings.TrimSpace(entry.BaseURL)
	providerID := NormalizeProviderID(entry.BaseURL)
	fingerprint := fingerprintAPIKey(entry.APIKey)

	if err := s.ItemStore.Upsert(ctx, sqlite.ImportRunItem{
		ItemID:              itemID,
		RunID:               runID,
		BaseURL:             baseURL,
		ProviderID:          providerID,
		APIKeyFingerprint:   fingerprint,
		CurrentStage:        string(ItemStageProbe),
		ConfirmationStatus:  string(ConfirmationPending),
		AccessStatus:        string(AccessStatusUnknown),
		MatchedAccountState: string(MatchedAccountStateNone),
		AccountResolution:   string(AccountResolutionCreated),
	}); err != nil {
		return fmt.Errorf("record probe stage: %w", err)
	}

	modelsResult, err := s.ProbeModels(ctx, entry.BaseURL, entry.APIKey)
	if err != nil {
		return fmt.Errorf("probe models: %w", err)
	}
	if modelsResult == nil {
		// A probe returning (nil, nil) would otherwise cause a nil dereference below.
		return fmt.Errorf("probe models: no result returned")
	}
	// Copy so later steps cannot alias the probe's backing array.
	rawModels := append([]string(nil), modelsResult.RawModels...)
	capabilityProfile, err := s.ProbeCapabilities(ctx, entry.BaseURL, entry.APIKey, rawModels)
	if err != nil {
		return fmt.Errorf("probe capabilities: %w", err)
	}
	routingStrategy := BuildImportRoutingStrategy(capabilityProfile)
	resolvedSmokeModel, recommendedModels, err := probe.ResolveSmokeModel(entry.RequestedModels, rawModels, capabilityProfile)
	if err != nil {
		return fmt.Errorf("resolve smoke model: %w", err)
	}
	canonicalFamilies := uniqueCanonicalFamilies(rawModels)

	var reuseLookup ReuseLookupResult
	if s.InspectReuse != nil {
		reuseLookup, err = s.InspectReuse(ctx, ReuseLookupInput{
			HostID:     hostID,
			ProviderID: providerID,
			// Use the same trimmed base URL that is persisted on the item so
			// the lookup agrees with stored state.
			BaseURL:                baseURL,
			APIKeyFingerprint:      fingerprint,
			CanonicalModelFamilies: canonicalFamilies,
		})
		if err != nil {
			return fmt.Errorf("inspect reuse: %w", err)
		}
	}
	reuseDecision := DecideReuse(ReuseInput{
		ProviderID:               providerID,
		CanonicalModelFamilies:   canonicalFamilies,
		MatchedAccountID:         reuseLookup.MatchedAccountID,
		MatchedAccountState:      reuseLookup.MatchedAccountState,
		ExistingProviderID:       reuseLookup.ExistingProviderID,
		ExistingAccessStatus:     reuseLookup.ExistingAccessStatus,
		ExistingCanonicalFamilys: reuseLookup.ExistingCanonicalFamilys,
	})

	finalItem := sqlite.ImportRunItem{
		ItemID:                itemID,
		RunID:                 runID,
		BaseURL:               baseURL,
		ProviderID:            providerID,
		APIKeyFingerprint:     fingerprint,
		RequestedModelsJSON:   mustMarshalJSON(entry.RequestedModels, "[]"),
		RawModelsJSON:         mustMarshalJSON(rawModels, "[]"),
		NormalizedModelsJSON:  mustMarshalJSON(uniqueNormalizedModels(rawModels), "[]"),
		CanonicalFamiliesJSON: mustMarshalJSON(canonicalFamilies, "[]"),
		RecommendedModelsJSON: mustMarshalJSON(recommendedModels, "[]"),
		ResolvedSmokeModel:    resolvedSmokeModel,
		CapabilityProfileJSON: mustMarshalJSON(capabilityProfile, "{}"),
		CurrentStage:          string(ItemStageConfirm),
		ConfirmationStatus:    string(ConfirmationPending),
		AccessStatus:          string(AccessStatusUnknown),
		MatchedAccountState:   string(reuseDecision.MatchedAccountState),
		AccountResolution:     string(reuseDecision.AccountResolution),
		ProvisionReused:       reuseDecision.ProvisionReused,
		ReusedFromProviderID:  reuseDecision.ReusedFromProviderID,
		ReusedFromAccountID:   int64PtrIfSet(reuseDecision.ReusedFromAccountID),
	}

	if reuseDecision.ProvisionReused {
		// Reused provisioning: only push alias changes, if there are any.
		patchContract := ModelMappingDelta(reuseLookup.ExistingModelMapping, probe.BuildAliasTable(rawModels))
		if shouldPatchAliases(reuseLookup.ExistingModelMapping, patchContract.ModelMapping) {
			if s.Provisioner == nil {
				return fmt.Errorf("provisioner is required for patch-only flow")
			}
			if err := s.Provisioner.Patch(ctx, PatchProvisionRequest{
				ProviderID: reuseDecision.ReusedFromProviderID,
				Contract:   patchContract,
			}); err != nil {
				return fmt.Errorf("patch aliases of provider %s: %w", reuseDecision.ReusedFromProviderID, err)
			}
		}
	} else {
		if s.Provisioner == nil {
			return fmt.Errorf("provisioner is required")
		}
		provisionResult, provisionErr := s.Provisioner.Provision(ctx, ProvisionRequest{
			RunID:             runID,
			ItemID:            itemID,
			Entry:             entry,
			ProviderID:        providerID,
			ResolvedModel:     resolvedSmokeModel,
			RoutingStrategy:   routingStrategy,
			CapabilityProfile: capabilityProfile,
		})
		if provisionErr != nil {
			return fmt.Errorf("provision: %w", provisionErr)
		}
		finalItem.LegacyBatchID = provisionResult.LegacyBatchID
		finalItem.LegacyProviderID = strings.TrimSpace(provisionResult.LegacyProviderID)
	}

	if err := s.ItemStore.Upsert(ctx, finalItem); err != nil {
		return fmt.Errorf("record confirm stage: %w", err)
	}
	return nil
}
// uniqueCanonicalFamilies maps each raw model ID through
// probe.CanonicalModelFamily and returns the distinct non-empty families in
// first-seen order. The result is never nil, so it encodes to JSON as [].
func uniqueCanonicalFamilies(rawModels []string) []string {
	return uniqueMappedStrings(rawModels, func(model string) string {
		return probe.CanonicalModelFamily(model)
	})
}

// uniqueNormalizedModels maps each raw model ID through probe.NormalizeModelID
// and returns the distinct non-empty results in first-seen order. The result
// is never nil, so it encodes to JSON as [].
func uniqueNormalizedModels(rawModels []string) []string {
	return uniqueMappedStrings(rawModels, func(model string) string {
		return probe.NormalizeModelID(model)
	})
}

// uniqueMappedStrings applies transform to every value and collects the
// distinct non-empty outputs in first-seen order. It always returns a non-nil
// slice, so the result encodes to JSON as [] rather than null.
func uniqueMappedStrings(values []string, transform func(string) string) []string {
	seen := make(map[string]struct{}, len(values))
	out := make([]string, 0, len(values))
	for _, value := range values {
		mapped := transform(value)
		if mapped == "" {
			continue
		}
		if _, dup := seen[mapped]; dup {
			continue
		}
		seen[mapped] = struct{}{}
		out = append(out, mapped)
	}
	return out
}
// mustMarshalJSON encodes value as JSON text for storage in an item column.
//
// Despite the "must" prefix it never panics. If encoding fails, or if value
// encodes to JSON null (a nil slice, map, pointer or interface), fallback is
// returned instead. This keeps stored columns in the shape the caller asked
// for ("[]" for lists, "{}" for objects) rather than a bare null.
func mustMarshalJSON(value any, fallback string) string {
	payload, err := json.Marshal(value)
	if err != nil || string(payload) == "null" {
		return fallback
	}
	return string(payload)
}
// fingerprintAPIKey derives a short identifier for an API key so items can be
// correlated without storing the key itself. Surrounding whitespace is
// ignored, and a blank key yields "". Otherwise the result is "sha256:"
// followed by the lowercase hex of the first 8 bytes of the key's SHA-256
// digest.
func fingerprintAPIKey(apiKey string) string {
	key := strings.TrimSpace(apiKey)
	if len(key) == 0 {
		return ""
	}
	hasher := sha256.New()
	hasher.Write([]byte(key)) // hash.Hash.Write never returns an error
	digest := hasher.Sum(nil)
	return fmt.Sprintf("sha256:%x", digest[:8])
}
// int64PtrIfSet returns nil for zero and otherwise a pointer to a fresh copy
// of value, so an unset (zero) ID maps to a nil pointer field.
func int64PtrIfSet(value int64) *int64 {
	if value != 0 {
		// value is this call's own copy, so handing out its address is safe.
		return &value
	}
	return nil
}
// shouldPatchAliases reports whether next contains any alias that existing
// lacks or maps to a different target. An empty existing mapping never
// triggers a patch. Aliases present only in existing are ignored.
// NOTE(review): the empty-existing short-circuit looks deliberate, presumably
// "no mapping configured"; confirm with the provisioner's semantics.
func shouldPatchAliases(existing map[string]string, next map[string]string) bool {
	if len(existing) > 0 {
		for alias, target := range next {
			if current, found := existing[alias]; found && current == target {
				continue
			}
			return true
		}
	}
	return false
}