feat(batch): implement v2 run setup and provision stages
This commit is contained in:
34
internal/batch/capability_profile.go
Normal file
34
internal/batch/capability_profile.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package batch
|
||||
|
||||
import "sub2api-cn-relay-manager/internal/probe"
|
||||
|
||||
type ImportRoutingStrategy struct {
|
||||
UseRawChatCompletions bool
|
||||
SkipResponsesChecks bool
|
||||
RetryInitial503 bool
|
||||
TreatProbe403Advisory bool
|
||||
}
|
||||
|
||||
func BuildImportRoutingStrategy(profile *probe.CapabilityProfile) ImportRoutingStrategy {
|
||||
strategy := ImportRoutingStrategy{
|
||||
RetryInitial503: true,
|
||||
}
|
||||
if profile == nil {
|
||||
return strategy
|
||||
}
|
||||
|
||||
if profile.TransportProfile.SupportsOpenAIChatCompletions && !profile.TransportProfile.SupportsOpenAIResponses {
|
||||
strategy.UseRawChatCompletions = true
|
||||
strategy.SkipResponsesChecks = true
|
||||
}
|
||||
for _, advisory := range profile.TransportProfile.KnownAdvisories {
|
||||
switch advisory {
|
||||
case "responses_unsupported_but_chat_ok":
|
||||
strategy.UseRawChatCompletions = true
|
||||
strategy.SkipResponsesChecks = true
|
||||
case "initial_probe_race_expected":
|
||||
strategy.TreatProbe403Advisory = true
|
||||
}
|
||||
}
|
||||
return strategy
|
||||
}
|
||||
41
internal/batch/channel_evolution.go
Normal file
41
internal/batch/channel_evolution.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package batch
|
||||
|
||||
import "sub2api-cn-relay-manager/internal/probe"
|
||||
|
||||
type ChannelPatchContract struct {
|
||||
ModelMapping map[string]string
|
||||
ModelPricing map[string]any
|
||||
RestrictModels bool
|
||||
BillingModelSource string
|
||||
}
|
||||
|
||||
func ModelMappingDelta(existing map[string]string, discoveredAliases map[string]probe.AliasResult) ChannelPatchContract {
|
||||
modelMapping := make(map[string]string, len(existing)+len(discoveredAliases)*2)
|
||||
modelPricing := make(map[string]any, len(existing)+len(discoveredAliases))
|
||||
|
||||
for raw, canonical := range existing {
|
||||
modelMapping[raw] = canonical
|
||||
modelPricing[canonical] = map[string]any{"billing_mode": "token"}
|
||||
}
|
||||
|
||||
for _, alias := range discoveredAliases {
|
||||
if alias.Canonical == "" {
|
||||
continue
|
||||
}
|
||||
if alias.Raw != "" {
|
||||
modelMapping[alias.Raw] = alias.Canonical
|
||||
}
|
||||
if alias.Normalized != "" {
|
||||
modelMapping[alias.Normalized] = alias.Canonical
|
||||
}
|
||||
modelMapping[alias.Canonical] = alias.Canonical
|
||||
modelPricing[alias.Canonical] = map[string]any{"billing_mode": "token"}
|
||||
}
|
||||
|
||||
return ChannelPatchContract{
|
||||
ModelMapping: modelMapping,
|
||||
ModelPricing: modelPricing,
|
||||
RestrictModels: true,
|
||||
BillingModelSource: "channel_mapped",
|
||||
}
|
||||
}
|
||||
56
internal/batch/channel_evolution_test.go
Normal file
56
internal/batch/channel_evolution_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package batch
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/probe"
|
||||
)
|
||||
|
||||
func TestModelMappingDelta(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("preserves existing entries", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := ModelMappingDelta(
|
||||
map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
|
||||
map[string]probe.AliasResult{
|
||||
"deepseek-v4-pro": {Raw: "deepseek-ai/DeepSeek-V4-Pro", Canonical: "deepseek-v4-pro"},
|
||||
},
|
||||
)
|
||||
if got.ModelMapping["deepseek-v4-pro"] != "deepseek-v4-pro" {
|
||||
t.Fatalf("existing model mapping lost: %#v", got.ModelMapping)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("adds raw aliases mapped to canonical ids", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := ModelMappingDelta(
|
||||
map[string]string{"kimi-k2.6": "kimi-2.6"},
|
||||
map[string]probe.AliasResult{
|
||||
"kimi-k2.6": {Raw: "Kimi-K2.6", Canonical: "kimi-2.6"},
|
||||
},
|
||||
)
|
||||
if got.ModelMapping["Kimi-K2.6"] != "kimi-2.6" {
|
||||
t.Fatalf("ModelMapping[Kimi-K2.6] = %q, want kimi-2.6", got.ModelMapping["Kimi-K2.6"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("always sets canonical patch flags", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := ModelMappingDelta(nil, map[string]probe.AliasResult{
|
||||
"deepseek-v4-pro": {Raw: "deepseek-ai/DeepSeek-V4-Pro", Canonical: "deepseek-v4-pro"},
|
||||
})
|
||||
if !got.RestrictModels {
|
||||
t.Fatal("RestrictModels = false, want true")
|
||||
}
|
||||
if got.BillingModelSource != "channel_mapped" {
|
||||
t.Fatalf("BillingModelSource = %q, want channel_mapped", got.BillingModelSource)
|
||||
}
|
||||
if got.ModelPricing["deepseek-v4-pro"] == nil {
|
||||
t.Fatalf("ModelPricing = %#v, want canonical entry", got.ModelPricing)
|
||||
}
|
||||
})
|
||||
}
|
||||
326
internal/batch/service.go
Normal file
326
internal/batch/service.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package batch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/probe"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type BatchImportEntry struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
RequestedModels []string
|
||||
}
|
||||
|
||||
type BatchImportRunRequest struct {
|
||||
RunID string
|
||||
Mode string
|
||||
AccessMode string
|
||||
HostID string
|
||||
HostBaseURL string
|
||||
Entries []BatchImportEntry
|
||||
}
|
||||
|
||||
type BatchImportRunResult struct {
|
||||
RunID string
|
||||
ItemIDs []string
|
||||
}
|
||||
|
||||
type RunStateStore interface {
|
||||
Create(ctx context.Context, run sqlite.ImportRun) error
|
||||
}
|
||||
|
||||
type ItemStateStore interface {
|
||||
Upsert(ctx context.Context, item sqlite.ImportRunItem) error
|
||||
}
|
||||
|
||||
type ReuseLookupInput struct {
|
||||
HostID string
|
||||
ProviderID string
|
||||
BaseURL string
|
||||
APIKeyFingerprint string
|
||||
CanonicalModelFamilies []string
|
||||
}
|
||||
|
||||
type ReuseLookupResult struct {
|
||||
ExistingProviderID string
|
||||
ExistingAccessStatus AccessStatus
|
||||
ExistingCanonicalFamilys []string
|
||||
MatchedAccountID int64
|
||||
MatchedAccountState MatchedAccountState
|
||||
ExistingModelMapping map[string]string
|
||||
}
|
||||
|
||||
type ProvisionRequest struct {
|
||||
RunID string
|
||||
ItemID string
|
||||
Entry BatchImportEntry
|
||||
ProviderID string
|
||||
ResolvedModel string
|
||||
RoutingStrategy ImportRoutingStrategy
|
||||
CapabilityProfile *probe.CapabilityProfile
|
||||
}
|
||||
|
||||
type ProvisionResult struct {
|
||||
LegacyBatchID *int64
|
||||
LegacyProviderID string
|
||||
}
|
||||
|
||||
type PatchProvisionRequest struct {
|
||||
ProviderID string
|
||||
Contract ChannelPatchContract
|
||||
}
|
||||
|
||||
type BatchProvisioner interface {
|
||||
Provision(ctx context.Context, req ProvisionRequest) (ProvisionResult, error)
|
||||
Patch(ctx context.Context, req PatchProvisionRequest) error
|
||||
}
|
||||
|
||||
type BatchImportService struct {
|
||||
RunStore RunStateStore
|
||||
ItemStore ItemStateStore
|
||||
ProbeModels func(ctx context.Context, baseURL, apiKey string) (*probe.ModelsResult, error)
|
||||
ProbeCapabilities func(ctx context.Context, baseURL, apiKey string, rawModels []string) (*probe.CapabilityProfile, error)
|
||||
InspectReuse func(ctx context.Context, input ReuseLookupInput) (ReuseLookupResult, error)
|
||||
Provisioner BatchProvisioner
|
||||
}
|
||||
|
||||
func (s BatchImportService) StartRun(ctx context.Context, req BatchImportRunRequest) (BatchImportRunResult, error) {
|
||||
if s.RunStore == nil {
|
||||
return BatchImportRunResult{}, fmt.Errorf("run store is required")
|
||||
}
|
||||
if s.ItemStore == nil {
|
||||
return BatchImportRunResult{}, fmt.Errorf("item store is required")
|
||||
}
|
||||
if s.ProbeModels == nil {
|
||||
return BatchImportRunResult{}, fmt.Errorf("model probe is required")
|
||||
}
|
||||
if s.ProbeCapabilities == nil {
|
||||
return BatchImportRunResult{}, fmt.Errorf("capability probe is required")
|
||||
}
|
||||
|
||||
runID := strings.TrimSpace(req.RunID)
|
||||
if runID == "" {
|
||||
runID = fmt.Sprintf("run-%d", time.Now().UnixNano())
|
||||
}
|
||||
if len(req.Entries) == 0 {
|
||||
return BatchImportRunResult{}, fmt.Errorf("entries are required")
|
||||
}
|
||||
|
||||
if err := s.RunStore.Create(ctx, sqlite.ImportRun{
|
||||
RunID: runID,
|
||||
Mode: strings.TrimSpace(req.Mode),
|
||||
AccessMode: strings.TrimSpace(req.AccessMode),
|
||||
State: string(RunStateRunning),
|
||||
TotalItems: len(req.Entries),
|
||||
}); err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
|
||||
result := BatchImportRunResult{
|
||||
RunID: runID,
|
||||
ItemIDs: make([]string, 0, len(req.Entries)),
|
||||
}
|
||||
|
||||
for idx, entry := range req.Entries {
|
||||
itemID := fmt.Sprintf("%s-item-%d", runID, idx+1)
|
||||
result.ItemIDs = append(result.ItemIDs, itemID)
|
||||
|
||||
providerID := NormalizeProviderID(entry.BaseURL)
|
||||
fingerprint := fingerprintAPIKey(entry.APIKey)
|
||||
initialItem := sqlite.ImportRunItem{
|
||||
ItemID: itemID,
|
||||
RunID: runID,
|
||||
BaseURL: strings.TrimSpace(entry.BaseURL),
|
||||
ProviderID: providerID,
|
||||
APIKeyFingerprint: fingerprint,
|
||||
CurrentStage: string(ItemStageProbe),
|
||||
ConfirmationStatus: string(ConfirmationPending),
|
||||
AccessStatus: string(AccessStatusUnknown),
|
||||
MatchedAccountState: string(MatchedAccountStateNone),
|
||||
AccountResolution: string(AccountResolutionCreated),
|
||||
}
|
||||
if err := s.ItemStore.Upsert(ctx, initialItem); err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
|
||||
modelsResult, err := s.ProbeModels(ctx, entry.BaseURL, entry.APIKey)
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
rawModels := append([]string(nil), modelsResult.RawModels...)
|
||||
capabilityProfile, err := s.ProbeCapabilities(ctx, entry.BaseURL, entry.APIKey, rawModels)
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
routingStrategy := BuildImportRoutingStrategy(capabilityProfile)
|
||||
resolvedSmokeModel, recommendedModels, err := probe.ResolveSmokeModel(entry.RequestedModels, rawModels, capabilityProfile)
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
|
||||
canonicalFamilies := uniqueCanonicalFamilies(rawModels)
|
||||
reuseLookup := ReuseLookupResult{}
|
||||
if s.InspectReuse != nil {
|
||||
reuseLookup, err = s.InspectReuse(ctx, ReuseLookupInput{
|
||||
HostID: strings.TrimSpace(req.HostID),
|
||||
ProviderID: providerID,
|
||||
BaseURL: entry.BaseURL,
|
||||
APIKeyFingerprint: fingerprint,
|
||||
CanonicalModelFamilies: canonicalFamilies,
|
||||
})
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
reuseDecision := DecideReuse(ReuseInput{
|
||||
ProviderID: providerID,
|
||||
CanonicalModelFamilies: canonicalFamilies,
|
||||
MatchedAccountID: reuseLookup.MatchedAccountID,
|
||||
MatchedAccountState: reuseLookup.MatchedAccountState,
|
||||
ExistingProviderID: reuseLookup.ExistingProviderID,
|
||||
ExistingAccessStatus: reuseLookup.ExistingAccessStatus,
|
||||
ExistingCanonicalFamilys: reuseLookup.ExistingCanonicalFamilys,
|
||||
})
|
||||
|
||||
finalItem := sqlite.ImportRunItem{
|
||||
ItemID: itemID,
|
||||
RunID: runID,
|
||||
BaseURL: strings.TrimSpace(entry.BaseURL),
|
||||
ProviderID: providerID,
|
||||
APIKeyFingerprint: fingerprint,
|
||||
RequestedModelsJSON: mustMarshalJSON(entry.RequestedModels, "[]"),
|
||||
RawModelsJSON: mustMarshalJSON(rawModels, "[]"),
|
||||
NormalizedModelsJSON: mustMarshalJSON(uniqueNormalizedModels(rawModels), "[]"),
|
||||
CanonicalFamiliesJSON: mustMarshalJSON(canonicalFamilies, "[]"),
|
||||
RecommendedModelsJSON: mustMarshalJSON(recommendedModels, "[]"),
|
||||
ResolvedSmokeModel: resolvedSmokeModel,
|
||||
CapabilityProfileJSON: mustMarshalJSON(capabilityProfile, "{}"),
|
||||
CurrentStage: string(ItemStageConfirm),
|
||||
ConfirmationStatus: string(ConfirmationPending),
|
||||
AccessStatus: string(AccessStatusUnknown),
|
||||
MatchedAccountState: string(reuseDecision.MatchedAccountState),
|
||||
AccountResolution: string(reuseDecision.AccountResolution),
|
||||
ProvisionReused: reuseDecision.ProvisionReused,
|
||||
ReusedFromProviderID: reuseDecision.ReusedFromProviderID,
|
||||
ReusedFromAccountID: int64PtrIfSet(reuseDecision.ReusedFromAccountID),
|
||||
}
|
||||
|
||||
if reuseDecision.ProvisionReused {
|
||||
patchContract := ModelMappingDelta(reuseLookup.ExistingModelMapping, probe.BuildAliasTable(rawModels))
|
||||
if shouldPatchAliases(reuseLookup.ExistingModelMapping, patchContract.ModelMapping) {
|
||||
if s.Provisioner == nil {
|
||||
return BatchImportRunResult{}, fmt.Errorf("provisioner is required for patch-only flow")
|
||||
}
|
||||
if err := s.Provisioner.Patch(ctx, PatchProvisionRequest{
|
||||
ProviderID: reuseDecision.ReusedFromProviderID,
|
||||
Contract: patchContract,
|
||||
}); err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if s.Provisioner == nil {
|
||||
return BatchImportRunResult{}, fmt.Errorf("provisioner is required")
|
||||
}
|
||||
provisionResult, err := s.Provisioner.Provision(ctx, ProvisionRequest{
|
||||
RunID: runID,
|
||||
ItemID: itemID,
|
||||
Entry: entry,
|
||||
ProviderID: providerID,
|
||||
ResolvedModel: resolvedSmokeModel,
|
||||
RoutingStrategy: routingStrategy,
|
||||
CapabilityProfile: capabilityProfile,
|
||||
})
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
finalItem.LegacyBatchID = provisionResult.LegacyBatchID
|
||||
finalItem.LegacyProviderID = strings.TrimSpace(provisionResult.LegacyProviderID)
|
||||
}
|
||||
|
||||
if err := s.ItemStore.Upsert(ctx, finalItem); err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func uniqueCanonicalFamilies(rawModels []string) []string {
|
||||
seen := make(map[string]struct{}, len(rawModels))
|
||||
families := make([]string, 0, len(rawModels))
|
||||
for _, rawModel := range rawModels {
|
||||
family := probe.CanonicalModelFamily(rawModel)
|
||||
if family == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[family]; ok {
|
||||
continue
|
||||
}
|
||||
seen[family] = struct{}{}
|
||||
families = append(families, family)
|
||||
}
|
||||
return families
|
||||
}
|
||||
|
||||
func uniqueNormalizedModels(rawModels []string) []string {
|
||||
seen := make(map[string]struct{}, len(rawModels))
|
||||
models := make([]string, 0, len(rawModels))
|
||||
for _, rawModel := range rawModels {
|
||||
normalized := probe.NormalizeModelID(rawModel)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
models = append(models, normalized)
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func mustMarshalJSON(value any, fallback string) string {
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
func fingerprintAPIKey(apiKey string) string {
|
||||
trimmed := strings.TrimSpace(apiKey)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(trimmed))
|
||||
return fmt.Sprintf("sha256:%x", sum[:8])
|
||||
}
|
||||
|
||||
func int64PtrIfSet(value int64) *int64 {
|
||||
if value == 0 {
|
||||
return nil
|
||||
}
|
||||
result := value
|
||||
return &result
|
||||
}
|
||||
|
||||
func shouldPatchAliases(existing map[string]string, next map[string]string) bool {
|
||||
if len(existing) == 0 {
|
||||
return false
|
||||
}
|
||||
for key, value := range next {
|
||||
if existingValue, ok := existing[key]; !ok || existingValue != value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
296
internal/batch/service_test.go
Normal file
296
internal/batch/service_test.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package batch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/probe"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestBatchImport_StartRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("creates run items and backfills legacy linkage after provision", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runStore := &fakeRunStore{}
|
||||
itemStore := &fakeItemStore{}
|
||||
provisioner := &fakeProvisioner{
|
||||
provisionResult: ProvisionResult{
|
||||
LegacyBatchID: int64Ptr(81),
|
||||
LegacyProviderID: "legacy-provider",
|
||||
},
|
||||
}
|
||||
|
||||
service := BatchImportService{
|
||||
RunStore: runStore,
|
||||
ItemStore: itemStore,
|
||||
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
|
||||
return &probe.ModelsResult{RawModels: []string{"deepseek-ai/DeepSeek-V4-Pro"}}, nil
|
||||
},
|
||||
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
|
||||
return &probe.CapabilityProfile{
|
||||
TransportProfile: probe.TransportProfile{SupportsOpenAIChatCompletions: true},
|
||||
ModelProfiles: []probe.ModelCapabilityProfile{
|
||||
{RawModelID: "deepseek-ai/DeepSeek-V4-Pro", CanonicalModelFamily: "deepseek-v4-pro", SmokeChatOK: true},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
InspectReuse: func(context.Context, ReuseLookupInput) (ReuseLookupResult, error) {
|
||||
return ReuseLookupResult{}, nil
|
||||
},
|
||||
Provisioner: provisioner,
|
||||
}
|
||||
|
||||
result, err := service.StartRun(context.Background(), BatchImportRunRequest{
|
||||
RunID: "run-1",
|
||||
Mode: "strict",
|
||||
AccessMode: "subscription",
|
||||
HostID: "host-1",
|
||||
HostBaseURL: "https://relay.example.com",
|
||||
Entries: []BatchImportEntry{
|
||||
{BaseURL: "https://api.deepseek.com/v1", APIKey: "sk-live", RequestedModels: []string{"DeepSeek V4 Pro"}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartRun() error = %v", err)
|
||||
}
|
||||
if result.RunID != "run-1" {
|
||||
t.Fatalf("RunID = %q, want run-1", result.RunID)
|
||||
}
|
||||
if len(runStore.created) != 1 || runStore.created[0].RunID != "run-1" {
|
||||
t.Fatalf("created runs = %#v, want run-1 persisted", runStore.created)
|
||||
}
|
||||
if provisioner.provisionCalls != 1 {
|
||||
t.Fatalf("provision calls = %d, want 1", provisioner.provisionCalls)
|
||||
}
|
||||
if len(itemStore.upserts) != 2 {
|
||||
t.Fatalf("item upserts = %d, want initial + final", len(itemStore.upserts))
|
||||
}
|
||||
|
||||
finalItem := itemStore.upserts[len(itemStore.upserts)-1]
|
||||
if finalItem.LegacyBatchID == nil || *finalItem.LegacyBatchID != 81 {
|
||||
t.Fatalf("LegacyBatchID = %#v, want 81", finalItem.LegacyBatchID)
|
||||
}
|
||||
if finalItem.LegacyProviderID != "legacy-provider" {
|
||||
t.Fatalf("LegacyProviderID = %q, want legacy-provider", finalItem.LegacyProviderID)
|
||||
}
|
||||
if finalItem.CurrentStage != string(ItemStageConfirm) {
|
||||
t.Fatalf("CurrentStage = %q, want confirm", finalItem.CurrentStage)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("active duplicate account is reused without provision", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providerID := NormalizeProviderID("https://api.kimi.com/v1")
|
||||
itemStore := &fakeItemStore{}
|
||||
provisioner := &fakeProvisioner{}
|
||||
service := BatchImportService{
|
||||
RunStore: &fakeRunStore{},
|
||||
ItemStore: itemStore,
|
||||
Provisioner: provisioner,
|
||||
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
|
||||
return &probe.ModelsResult{RawModels: []string{"kimi-k2.6"}}, nil
|
||||
},
|
||||
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
|
||||
return &probe.CapabilityProfile{
|
||||
ModelProfiles: []probe.ModelCapabilityProfile{{RawModelID: "kimi-k2.6", CanonicalModelFamily: "kimi-2.6", SmokeChatOK: true}},
|
||||
}, nil
|
||||
},
|
||||
InspectReuse: func(context.Context, ReuseLookupInput) (ReuseLookupResult, error) {
|
||||
return ReuseLookupResult{
|
||||
ExistingProviderID: providerID,
|
||||
ExistingAccessStatus: AccessStatusActive,
|
||||
ExistingCanonicalFamilys: []string{"kimi 2.6"},
|
||||
MatchedAccountID: 201,
|
||||
MatchedAccountState: MatchedAccountStateActive,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
_, err := service.StartRun(context.Background(), BatchImportRunRequest{
|
||||
RunID: "run-2",
|
||||
Mode: "strict",
|
||||
AccessMode: "subscription",
|
||||
Entries: []BatchImportEntry{
|
||||
{BaseURL: "https://api.kimi.com/v1", APIKey: "sk-live", RequestedModels: []string{"kimi 2.6"}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartRun() error = %v", err)
|
||||
}
|
||||
if provisioner.provisionCalls != 0 {
|
||||
t.Fatalf("provision calls = %d, want 0", provisioner.provisionCalls)
|
||||
}
|
||||
|
||||
finalItem := itemStore.upserts[len(itemStore.upserts)-1]
|
||||
if !finalItem.ProvisionReused {
|
||||
t.Fatal("ProvisionReused = false, want true")
|
||||
}
|
||||
if finalItem.MatchedAccountState != string(MatchedAccountStateActive) {
|
||||
t.Fatalf("MatchedAccountState = %q, want active", finalItem.MatchedAccountState)
|
||||
}
|
||||
if finalItem.AccountResolution != string(AccountResolutionReused) {
|
||||
t.Fatalf("AccountResolution = %q, want reused", finalItem.AccountResolution)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deprecated duplicate account becomes reactivated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providerID := NormalizeProviderID("https://api.kimi.com/v1")
|
||||
itemStore := &fakeItemStore{}
|
||||
service := BatchImportService{
|
||||
RunStore: &fakeRunStore{},
|
||||
ItemStore: itemStore,
|
||||
Provisioner: &fakeProvisioner{},
|
||||
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
|
||||
return &probe.ModelsResult{RawModels: []string{"kimi-k2.6"}}, nil
|
||||
},
|
||||
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
|
||||
return &probe.CapabilityProfile{
|
||||
ModelProfiles: []probe.ModelCapabilityProfile{{RawModelID: "kimi-k2.6", CanonicalModelFamily: "kimi-2.6", SmokeChatOK: true}},
|
||||
}, nil
|
||||
},
|
||||
InspectReuse: func(context.Context, ReuseLookupInput) (ReuseLookupResult, error) {
|
||||
return ReuseLookupResult{
|
||||
ExistingProviderID: providerID,
|
||||
ExistingAccessStatus: AccessStatusActive,
|
||||
ExistingCanonicalFamilys: []string{"kimi-2.6"},
|
||||
MatchedAccountID: 301,
|
||||
MatchedAccountState: MatchedAccountStateDeprecated,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
_, err := service.StartRun(context.Background(), BatchImportRunRequest{
|
||||
RunID: "run-3",
|
||||
Mode: "strict",
|
||||
AccessMode: "subscription",
|
||||
Entries: []BatchImportEntry{
|
||||
{BaseURL: "https://api.kimi.com/v1", APIKey: "sk-live", RequestedModels: []string{"kimi 2.6"}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartRun() error = %v", err)
|
||||
}
|
||||
|
||||
finalItem := itemStore.upserts[len(itemStore.upserts)-1]
|
||||
if finalItem.AccountResolution != string(AccountResolutionReactivated) {
|
||||
t.Fatalf("AccountResolution = %q, want reactivated", finalItem.AccountResolution)
|
||||
}
|
||||
if !finalItem.ProvisionReused {
|
||||
t.Fatal("ProvisionReused = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("same family new alias only patches mapping", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providerID := NormalizeProviderID("https://api.kimi.com/v1")
|
||||
itemStore := &fakeItemStore{}
|
||||
provisioner := &fakeProvisioner{}
|
||||
service := BatchImportService{
|
||||
RunStore: &fakeRunStore{},
|
||||
ItemStore: itemStore,
|
||||
Provisioner: provisioner,
|
||||
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
|
||||
return &probe.ModelsResult{RawModels: []string{"Kimi-K2.6"}}, nil
|
||||
},
|
||||
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
|
||||
return &probe.CapabilityProfile{
|
||||
ModelProfiles: []probe.ModelCapabilityProfile{{RawModelID: "Kimi-K2.6", CanonicalModelFamily: "kimi-2.6", SmokeChatOK: true}},
|
||||
}, nil
|
||||
},
|
||||
InspectReuse: func(context.Context, ReuseLookupInput) (ReuseLookupResult, error) {
|
||||
return ReuseLookupResult{
|
||||
ExistingProviderID: providerID,
|
||||
ExistingAccessStatus: AccessStatusActive,
|
||||
ExistingCanonicalFamilys: []string{"kimi 2.6"},
|
||||
MatchedAccountID: 401,
|
||||
MatchedAccountState: MatchedAccountStateActive,
|
||||
ExistingModelMapping: map[string]string{"kimi-k2.6": "kimi-2.6"},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
_, err := service.StartRun(context.Background(), BatchImportRunRequest{
|
||||
RunID: "run-4",
|
||||
Mode: "strict",
|
||||
AccessMode: "subscription",
|
||||
Entries: []BatchImportEntry{
|
||||
{BaseURL: "https://api.kimi.com/v1", APIKey: "sk-live", RequestedModels: []string{"kimi 2.6"}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartRun() error = %v", err)
|
||||
}
|
||||
if provisioner.provisionCalls != 0 {
|
||||
t.Fatalf("provision calls = %d, want 0", provisioner.provisionCalls)
|
||||
}
|
||||
if provisioner.patchCalls != 1 {
|
||||
t.Fatalf("patch calls = %d, want 1", provisioner.patchCalls)
|
||||
}
|
||||
if provisioner.lastPatch.Contract.ModelMapping["Kimi-K2.6"] != "kimi-2.6" {
|
||||
t.Fatalf("patch mapping = %#v, want raw alias mapped to canonical family", provisioner.lastPatch.Contract.ModelMapping)
|
||||
}
|
||||
|
||||
finalItem := itemStore.upserts[len(itemStore.upserts)-1]
|
||||
if !finalItem.ProvisionReused {
|
||||
t.Fatal("ProvisionReused = false, want true for patch-only flow")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type fakeRunStore struct {
|
||||
created []sqlite.ImportRun
|
||||
}
|
||||
|
||||
func (f *fakeRunStore) Create(ctx context.Context, run sqlite.ImportRun) error {
|
||||
f.created = append(f.created, run)
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeItemStore struct {
|
||||
upserts []sqlite.ImportRunItem
|
||||
}
|
||||
|
||||
func (f *fakeItemStore) Upsert(ctx context.Context, item sqlite.ImportRunItem) error {
|
||||
f.upserts = append(f.upserts, item)
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeProvisioner struct {
|
||||
provisionCalls int
|
||||
patchCalls int
|
||||
provisionResult ProvisionResult
|
||||
lastPatch PatchProvisionRequest
|
||||
}
|
||||
|
||||
func (f *fakeProvisioner) Provision(ctx context.Context, req ProvisionRequest) (ProvisionResult, error) {
|
||||
f.provisionCalls++
|
||||
return f.provisionResult, nil
|
||||
}
|
||||
|
||||
func (f *fakeProvisioner) Patch(ctx context.Context, req PatchProvisionRequest) error {
|
||||
f.patchCalls++
|
||||
f.lastPatch = req
|
||||
return nil
|
||||
}
|
||||
|
||||
func int64Ptr(value int64) *int64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
func mustJSON(t *testing.T, value any) string {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
return string(payload)
|
||||
}
|
||||
Reference in New Issue
Block a user