fix(provision): reconcile channel pricing and hosted access

This commit is contained in:
phamnazage-jpg
2026-05-20 22:09:40 +08:00
parent 83ee216a4d
commit ca1d448cc0
27 changed files with 1344 additions and 154 deletions

View File

@@ -27,6 +27,7 @@ type ClosureRequest struct {
}
type Host interface {
EnsureSubscriptionAccess(ctx context.Context, req sub2api.EnsureSubscriptionAccessRequest) (sub2api.SubscriptionAccessRef, error)
AssignSubscription(ctx context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error)
CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
}
@@ -52,7 +53,7 @@ func Validate(req ClosureRequest) error {
default:
return fmt.Errorf("unsupported access mode %q", req.Mode)
}
if strings.TrimSpace(req.ProbeAPIKey) == "" {
if strings.TrimSpace(req.Mode) != ModeSubscription && strings.TrimSpace(req.ProbeAPIKey) == "" {
return fmt.Errorf("access probe api key is required to verify gateway closure")
}
return nil
@@ -65,14 +66,29 @@ func (s *Service) Close(ctx context.Context, req ClosureRequest) (sub2api.Gatewa
if err := Validate(req); err != nil {
return sub2api.GatewayAccessResult{}, err
}
probeAPIKey := strings.TrimSpace(req.ProbeAPIKey)
if strings.TrimSpace(req.Mode) == ModeSubscription {
for _, target := range req.Subscriptions {
if _, err := s.host.AssignSubscription(ctx, sub2api.AssignSubscriptionRequest{UserID: target.UserID, GroupID: req.GroupID, DurationDays: target.DurationDays}); err != nil {
resolvedTarget := target.UserID
accessRef, err := s.host.EnsureSubscriptionAccess(ctx, sub2api.EnsureSubscriptionAccessRequest{UserSelector: target.UserID, GroupID: req.GroupID})
if err != nil {
return sub2api.GatewayAccessResult{}, fmt.Errorf("ensure subscription access for %s: %w", target.UserID, err)
}
if strings.TrimSpace(accessRef.UserID) != "" {
resolvedTarget = accessRef.UserID
}
if strings.TrimSpace(accessRef.APIKey) != "" {
probeAPIKey = strings.TrimSpace(accessRef.APIKey)
}
if _, err := s.host.AssignSubscription(ctx, sub2api.AssignSubscriptionRequest{UserID: resolvedTarget, GroupID: req.GroupID, DurationDays: target.DurationDays}); err != nil {
return sub2api.GatewayAccessResult{}, fmt.Errorf("assign subscription for %s: %w", target.UserID, err)
}
}
}
result, err := s.host.CheckGatewayAccess(ctx, sub2api.GatewayAccessCheckRequest{APIKey: req.ProbeAPIKey, ExpectedModel: req.ExpectedModel})
if probeAPIKey == "" {
return sub2api.GatewayAccessResult{}, fmt.Errorf("access probe api key is required to verify gateway closure")
}
result, err := s.host.CheckGatewayAccess(ctx, sub2api.GatewayAccessCheckRequest{APIKey: probeAPIKey, ExpectedModel: req.ExpectedModel})
if err != nil {
return sub2api.GatewayAccessResult{}, fmt.Errorf("check gateway access: %w", err)
}

View File

@@ -22,14 +22,28 @@ func TestValidateRejectsMissingSubscriptionsForSubscriptionMode(t *testing.T) {
}
}
func TestValidateAllowsManagedSubscriptionProbeWithoutExplicitAPIKey(t *testing.T) {
err := Validate(ClosureRequest{
Mode: "subscription",
GroupID: "group-1",
ExpectedModel: "deepseek-chat",
Subscriptions: []SubscriptionTarget{{UserID: "crm-user-42", DurationDays: 30}},
})
if err != nil {
t.Fatalf("Validate() error = %v, want nil for managed subscription probe", err)
}
}
func TestServiceCloseAssignsSubscriptionsAndProbesGateway(t *testing.T) {
host := &fakeClosureHost{
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
managedAccess: map[string]sub2api.SubscriptionAccessRef{
"user-1": {UserID: "host-user-1", APIKey: "managed-user-key"},
},
}
service := NewService(host)
result, err := service.Close(context.Background(), ClosureRequest{
Mode: "subscription",
ProbeAPIKey: "user-key",
GroupID: "group-1",
ExpectedModel: "deepseek-chat",
Subscriptions: []SubscriptionTarget{{UserID: "user-1", DurationDays: 30}},
@@ -40,7 +54,10 @@ func TestServiceCloseAssignsSubscriptionsAndProbesGateway(t *testing.T) {
if len(host.assigned) != 1 {
t.Fatalf("assigned subscriptions = %d, want 1", len(host.assigned))
}
if host.gatewayProbe.APIKey != "user-key" || host.gatewayProbe.ExpectedModel != "deepseek-chat" {
if host.assigned[0].UserID != "host-user-1" {
t.Fatalf("assigned subscription user = %q, want host-user-1", host.assigned[0].UserID)
}
if host.gatewayProbe.APIKey != "managed-user-key" || host.gatewayProbe.ExpectedModel != "deepseek-chat" {
t.Fatalf("gateway probe = %+v, want api key + expected model", host.gatewayProbe)
}
if !result.OK || !result.HasExpectedModel {
@@ -68,12 +85,20 @@ func TestServiceCloseReturnsSubscriptionErrorBeforeGatewayProbe(t *testing.T) {
type fakeClosureHost struct {
assigned []sub2api.AssignSubscriptionRequest
managedAccess map[string]sub2api.SubscriptionAccessRef
assignErr error
gatewayProbe sub2api.GatewayAccessCheckRequest
gatewayResult sub2api.GatewayAccessResult
gatewayErr error
}
func (f *fakeClosureHost) EnsureSubscriptionAccess(_ context.Context, req sub2api.EnsureSubscriptionAccessRequest) (sub2api.SubscriptionAccessRef, error) {
if ref, ok := f.managedAccess[req.UserSelector]; ok {
return ref, nil
}
return sub2api.SubscriptionAccessRef{}, errors.New("missing managed access")
}
func (f *fakeClosureHost) AssignSubscription(_ context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) {
if f.assignErr != nil {
return sub2api.SubscriptionRef{}, f.assignErr

View File

@@ -5,10 +5,12 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
@@ -874,6 +876,47 @@ func TestHandlerErrorPaths(t *testing.T) {
}
}
func TestResolveLatestAccessStatusAggregatesAcrossModeBatches(t *testing.T) {
store := openAppTestStore(t)
defer closeAppTestStore(t, store)
ctx := context.Background()
hostID, err := store.Hosts().Create(ctx, sqlite.Host{HostID: "host-1", BaseURL: "https://sub2api.example.com", HostVersion: "0.1.126", AuthType: "apikey", AuthToken: "token"})
if err != nil {
t.Fatalf("Hosts().Create() error = %v", err)
}
packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", Checksum: "checksum-1"})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
providerID, err := store.Providers().Create(ctx, sqlite.Provider{PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai"})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
batchSubscription, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: provision.ImportModePartial, BatchStatus: provision.BatchStatusSucceeded, AccessStatus: provision.AccessStatusSubscriptionReady})
if err != nil {
t.Fatalf("ImportBatches().Create(subscription) error = %v", err)
}
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchSubscription, ClosureType: provision.AccessModeSubscription, Status: provision.AccessStatusSubscriptionReady, DetailsJSON: "{}"}); err != nil {
t.Fatalf("AccessClosures().Create(subscription) error = %v", err)
}
batchSelf, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: provision.ImportModePartial, BatchStatus: provision.BatchStatusSucceeded, AccessStatus: provision.AccessStatusSelfServiceReady})
if err != nil {
t.Fatalf("ImportBatches().Create(self_service) error = %v", err)
}
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchSelf, ClosureType: provision.AccessModeSelfService, Status: provision.AccessStatusSelfServiceReady, DetailsJSON: "{}"}); err != nil {
t.Fatalf("AccessClosures().Create(self_service) error = %v", err)
}
got, err := resolveLatestAccessStatus(ctx, store, sqlite.Provider{ID: providerID, ProviderID: "deepseek"}, "host-1")
if err != nil {
t.Fatalf("resolveLatestAccessStatus() error = %v", err)
}
if got != provision.AccessStatusFullyReady {
t.Fatalf("resolveLatestAccessStatus() = %q, want %q", got, provision.AccessStatusFullyReady)
}
}
func TestProviderAccessStatusMultipleClosures(t *testing.T) {
handler := NewAPIHandler("t", ActionSet{
GetProviderAccessStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
@@ -926,6 +969,24 @@ func TestHostSupportStatusRequiresPlansCapability(t *testing.T) {
}
}
func openAppTestStore(t *testing.T) *sqlite.DB {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "state.db")
dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_pragma=foreign_keys(0)", filepath.ToSlash(dbPath))
store, err := sqlite.Open(context.Background(), dsn)
if err != nil {
t.Fatalf("sqlite.Open() error = %v", err)
}
return store
}
func closeAppTestStore(t *testing.T, store *sqlite.DB) {
t.Helper()
if err := store.Close(); err != nil {
t.Fatalf("store.Close() error = %v", err)
}
}
func assertJSONContains(t *testing.T, payload []byte, key string, want any) {
t.Helper()
var decoded map[string]any

View File

@@ -1367,37 +1367,10 @@ func NewActionSet(sqliteDSN string) ActionSet {
return AccessPreviewResult{}, fmt.Errorf("provider %q exists in multiple packs; pack_id is required", req.ProviderID)
}
providerRow := providers[0]
if strings.TrimSpace(req.HostID) != "" {
hostRow, err := store.Hosts().GetByHostID(ctx, req.HostID)
if err != nil {
return AccessPreviewResult{}, err
}
batch, err := store.ImportBatches().GetLatestByProviderIDAndHostID(ctx, providerRow.ID, hostRow.ID)
if err != nil {
return AccessPreviewResult{}, fmt.Errorf("find batch for provider: %w", err)
}
latestStatus := batch.AccessStatus
closures, err := store.AccessClosures().GetByBatchID(ctx, batch.ID)
if err == nil && len(closures) > 0 {
latestStatus = closures[len(closures)-1].Status
}
available := accessStatusSupportsMode(latestStatus, req.Mode)
message := fmt.Sprintf("latest access status: %s", latestStatus)
if !available {
message = fmt.Sprintf("access status %s does not satisfy mode %s", latestStatus, req.Mode)
}
return AccessPreviewResult{ProviderID: req.ProviderID, Mode: req.Mode, Available: available, Message: message}, nil
}
batch, err := store.ImportBatches().GetLatestByProviderID(ctx, providerRow.ID)
latestStatus, err := resolveLatestAccessStatus(ctx, store, providerRow, req.HostID)
if err != nil {
return AccessPreviewResult{}, fmt.Errorf("find batch for provider: %w", err)
}
latestStatus := batch.AccessStatus
closures, err := store.AccessClosures().GetByBatchID(ctx, batch.ID)
if err == nil && len(closures) > 0 {
latestStatus = closures[len(closures)-1].Status
}
available := accessStatusSupportsMode(latestStatus, req.Mode)
message := fmt.Sprintf("latest access status: %s", latestStatus)
if !available {
@@ -1440,6 +1413,45 @@ func resolveProvidersForQuery(ctx context.Context, store *sqlite.DB, req Provide
return store.Providers().ListByProviderID(ctx, providerID)
}
func resolveLatestAccessStatus(ctx context.Context, store *sqlite.DB, providerRow sqlite.Provider, hostID string) (string, error) {
if store == nil {
return "", fmt.Errorf("store is required")
}
if strings.TrimSpace(hostID) != "" {
hostRow, err := store.Hosts().GetByHostID(ctx, hostID)
if err != nil {
return "", err
}
batches, err := store.ImportBatches().ListByProviderIDAndHostID(ctx, providerRow.ID, hostRow.ID)
if err != nil {
return "", err
}
modeStatuses, err := provision.LatestModeAccessStatuses(ctx, store, batches)
if err != nil {
return "", err
}
return provision.AggregateAccessStatus(modeStatuses), nil
}
batches, err := store.ImportBatches().ListByProviderID(ctx, providerRow.ID)
if err != nil {
return "", err
}
if len(batches) == 0 {
return "", fmt.Errorf("latest import batch not found for provider")
}
hostIDValue := batches[0].HostID
for _, batch := range batches[1:] {
if batch.HostID != hostIDValue {
return "", fmt.Errorf("provider exists on multiple hosts; host_id is required")
}
}
modeStatuses, err := provision.LatestModeAccessStatuses(ctx, store, batches)
if err != nil {
return "", err
}
return provision.AggregateAccessStatus(modeStatuses), nil
}
func resolveManagedHost(ctx context.Context, store *sqlite.DB, hostID, baseURL string, auth CreateHostAuth) (sqlite.Host, *sub2api.Client, error) {
if store == nil {
return sqlite.Host{}, nil, fmt.Errorf("store is required")

View File

@@ -1,6 +1,10 @@
package sub2api
import "context"
import (
"context"
"fmt"
"net/http"
)
func (c *Client) CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error) {
var ref ChannelRef
@@ -9,3 +13,15 @@ func (c *Client) CreateChannel(ctx context.Context, req CreateChannelRequest) (C
}
return ref, nil
}
func (c *Client) UpdateChannel(ctx context.Context, channelID string, req CreateChannelRequest) error {
path := fmt.Sprintf("/api/v1/admin/channels/%s", channelID)
statusCode, _, body, err := c.perform(ctx, http.MethodPut, path, req)
if err != nil {
return err
}
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
return newHTTPError(http.MethodPut, path, statusCode, body)
}
return nil
}

View File

@@ -18,6 +18,7 @@ type HostAdapter interface {
CreateGroup(ctx context.Context, req CreateGroupRequest) (GroupRef, error)
DeleteGroup(ctx context.Context, groupID string) error
CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error)
UpdateChannel(ctx context.Context, channelID string, req CreateChannelRequest) error
DeleteChannel(ctx context.Context, channelID string) error
CreatePlan(ctx context.Context, req CreatePlanRequest) (PlanRef, error)
DeletePlan(ctx context.Context, planID string) error
@@ -26,6 +27,7 @@ type HostAdapter interface {
DeleteAccount(ctx context.Context, accountID string) error
TestAccount(ctx context.Context, accountID string) (ProbeResult, error)
GetAccountModels(ctx context.Context, accountID string) ([]AccountModel, error)
EnsureSubscriptionAccess(ctx context.Context, req EnsureSubscriptionAccessRequest) (SubscriptionAccessRef, error)
AssignSubscription(ctx context.Context, req AssignSubscriptionRequest) (SubscriptionRef, error)
CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error)
ListManagedResources(ctx context.Context, req ListManagedResourcesRequest) (ManagedResourceSnapshot, error)
@@ -54,11 +56,38 @@ type GroupRef struct {
}
type CreateChannelRequest struct {
Name string `json:"name"`
GroupIDs []string `json:"group_ids"`
ModelMapping map[string]string `json:"model_mapping,omitempty"`
RestrictModels bool `json:"restrict_models,omitempty"`
BillingModelSource string `json:"billing_model_source,omitempty"`
Name string `json:"name"`
GroupIDs []string `json:"group_ids"`
ModelMapping map[string]string `json:"model_mapping,omitempty"`
ModelPricing []ChannelModelPricing `json:"model_pricing,omitempty"`
Platform string `json:"-"`
RestrictModels bool `json:"restrict_models,omitempty"`
BillingModelSource string `json:"billing_model_source,omitempty"`
}
type ChannelModelPricing struct {
Platform string `json:"platform,omitempty"`
Models []string `json:"models,omitempty"`
BillingMode string `json:"billing_mode,omitempty"`
InputPrice *float64 `json:"input_price,omitempty"`
OutputPrice *float64 `json:"output_price,omitempty"`
CacheWritePrice *float64 `json:"cache_write_price,omitempty"`
CacheReadPrice *float64 `json:"cache_read_price,omitempty"`
ImageOutputPrice *float64 `json:"image_output_price,omitempty"`
PerRequestPrice *float64 `json:"per_request_price,omitempty"`
Intervals []ChannelPricingTier `json:"intervals,omitempty"`
}
type ChannelPricingTier struct {
MinTokens int `json:"min_tokens,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
TierLabel string `json:"tier_label,omitempty"`
InputPrice *float64 `json:"input_price,omitempty"`
OutputPrice *float64 `json:"output_price,omitempty"`
CacheWritePrice *float64 `json:"cache_write_price,omitempty"`
CacheReadPrice *float64 `json:"cache_read_price,omitempty"`
PerRequestPrice *float64 `json:"per_request_price,omitempty"`
SortOrder int `json:"sort_order,omitempty"`
}
type ChannelRef struct {
@@ -116,6 +145,16 @@ type AssignSubscriptionRequest struct {
DurationDays int `json:"validity_days,omitempty"`
}
type EnsureSubscriptionAccessRequest struct {
UserSelector string
GroupID string
}
type SubscriptionAccessRef struct {
UserID string
APIKey string
}
type SubscriptionRef struct {
ID string `json:"id"`
}

View File

@@ -48,12 +48,41 @@ func flexibleIDSliceValues(raw []string) []any {
}
func (r CreateChannelRequest) MarshalJSON() ([]byte, error) {
modelMapping := map[string]map[string]string{}
platform := strings.TrimSpace(r.Platform)
if platform == "" {
platform = "openai"
}
if len(r.ModelMapping) > 0 {
inner := make(map[string]string, len(r.ModelMapping))
for key, value := range r.ModelMapping {
inner[key] = value
}
modelMapping[platform] = inner
}
modelPricing := make([]ChannelModelPricing, 0, len(r.ModelPricing))
for _, entry := range r.ModelPricing {
pricing := entry
if strings.TrimSpace(pricing.Platform) == "" {
pricing.Platform = platform
}
modelPricing = append(modelPricing, pricing)
}
return json.Marshal(struct {
Name string `json:"name"`
GroupIDs []any `json:"group_ids"`
Name string `json:"name"`
GroupIDs []any `json:"group_ids"`
ModelMapping map[string]map[string]string `json:"model_mapping,omitempty"`
ModelPricing []ChannelModelPricing `json:"model_pricing,omitempty"`
RestrictModels bool `json:"restrict_models,omitempty"`
BillingModelSource string `json:"billing_model_source,omitempty"`
}{
Name: r.Name,
GroupIDs: flexibleIDSliceValues(r.GroupIDs),
Name: r.Name,
GroupIDs: flexibleIDSliceValues(r.GroupIDs),
ModelMapping: modelMapping,
ModelPricing: modelPricing,
RestrictModels: r.RestrictModels,
BillingModelSource: r.BillingModelSource,
})
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
@@ -591,8 +592,23 @@ func TestCreateGroupWithMock(t *testing.T) {
func TestCreateChannelWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req struct {
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
ModelPricing []struct {
Platform string `json:"platform"`
Models []string `json:"models"`
BillingMode string `json:"billing_mode"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
ImageOutputPrice *float64 `json:"image_output_price"`
PerRequestPrice *float64 `json:"per_request_price"`
Intervals []any `json:"intervals"`
} `json:"model_pricing"`
RestrictModels bool `json:"restrict_models"`
BillingModelSource string `json:"billing_model_source"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
@@ -603,11 +619,36 @@ func TestCreateChannelWithMock(t *testing.T) {
if len(req.GroupIDs) != 1 || req.GroupIDs[0] != 101 {
t.Fatalf("group_ids = %v, want [101]", req.GroupIDs)
}
if req.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" {
t.Fatalf("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", req.ModelMapping)
}
if len(req.ModelPricing) != 1 {
t.Fatalf("model_pricing len = %d, want 1", len(req.ModelPricing))
}
if req.ModelPricing[0].Platform != "openai" || req.ModelPricing[0].BillingMode != "token" {
t.Fatalf("model_pricing[0] = %+v, want openai/token entry", req.ModelPricing[0])
}
if len(req.ModelPricing[0].Models) != 1 || req.ModelPricing[0].Models[0] != "deepseek-v4-pro" {
t.Fatalf("model_pricing[0].models = %v, want [deepseek-v4-pro]", req.ModelPricing[0].Models)
}
if !req.RestrictModels {
t.Fatal("restrict_models = false, want true")
}
if req.BillingModelSource != "channel_mapped" {
t.Fatalf("billing_model_source = %q, want channel_mapped", req.BillingModelSource)
}
w.Write([]byte(`{"data":{"id":201,"name":"ch"}}`))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{Name: "ch", GroupIDs: []string{"101"}})
ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{
Name: "ch",
GroupIDs: []string{"101"},
ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
ModelPricing: []ChannelModelPricing{{Platform: "openai", Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
RestrictModels: true,
BillingModelSource: "channel_mapped",
})
if err != nil {
t.Fatal(err)
}
@@ -699,6 +740,66 @@ func TestAssignSubscriptionWithMock(t *testing.T) {
}
}
func TestEnsureSubscriptionAccessWithMock(t *testing.T) {
var calls []string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls = append(calls, r.Method+" "+r.URL.Path)
switch {
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
w.Write([]byte(`{"data":{"items":[]}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users":
w.Write([]byte(`{"data":{"id":84,"email":"relay-sub-user-1@sub2api.local"}}`))
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84":
w.Write([]byte(`{"data":{"id":84}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance":
w.Write([]byte(`{"data":{"id":84}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign":
var req struct {
UserID int64 `json:"user_id"`
GroupID int64 `json:"group_id"`
DurationDays int `json:"validity_days"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode assign subscription request: %v", err)
}
if req.UserID != 84 || req.GroupID != 101 || req.DurationDays != 30 {
t.Fatalf("unexpected assign subscription request: %+v", req)
}
w.Write([]byte(`{"data":{"id":401}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login":
w.Write([]byte(`{"data":{"access_token":"user-jwt"}}`))
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode managed key request: %v", err)
}
if _, ok := req["group_id"]; ok {
t.Fatalf("managed key request unexpectedly carried group_id: %+v", req)
}
w.Write([]byte(`{"data":{"id":501,"key":"sk-relay-key","name":"managed-key"}}`))
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501":
w.Write([]byte(`{"data":{"api_key":{"id":501}}}`))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithBearerToken("admin-token"))
ref, err := client.EnsureSubscriptionAccess(context.Background(), EnsureSubscriptionAccessRequest{UserSelector: "crm-user-1", GroupID: "101"})
if err != nil {
t.Fatal(err)
}
if ref.UserID != "84" {
t.Fatalf("user id = %q, want 84", ref.UserID)
}
if !strings.HasPrefix(ref.APIKey, "sk-relay-") {
t.Fatalf("api key = %q, want managed sk-relay-* key", ref.APIKey)
}
if len(calls) < 7 {
t.Fatalf("calls = %v, want managed subscription setup sequence", calls)
}
}
func TestCheckGatewayAccessWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":[{"id":"gpt-4"},{"id":"claude-3"}]}`))
@@ -741,12 +842,19 @@ func TestBatchCreateAccountsWithMock(t *testing.T) {
if len(acct.GroupIDs) != 1 || acct.GroupIDs[0] != 101 {
t.Fatalf("group_ids = %v, want [101]", acct.GroupIDs)
}
rawMapping, ok := acct.Credentials["model_mapping"].(map[string]any)
if !ok {
t.Fatalf("credentials = %+v, want model_mapping map", acct.Credentials)
}
if got, _ := rawMapping["deepseek-v4-pro"].(string); got != "deepseek-v4-pro" {
t.Fatalf("model_mapping = %+v, want deepseek-v4-pro passthrough", rawMapping)
}
w.Write([]byte(`{"data":[{"id":601,"name":"acct1"}]}`))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
refs, err := client.BatchCreateAccounts(context.Background(), BatchCreateAccountsRequest{
Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com"}}},
Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com", "model_mapping": map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"}}}},
})
if err != nil {
t.Fatal(err)

View File

@@ -0,0 +1,320 @@
package sub2api
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
)
const (
managedSubscriptionBalance = 10.0
managedSubscriptionValidityDays = 30
)
type adminUserRecord struct {
ID int64 `json:"id"`
Email string `json:"email"`
}
type adminAPIKeyRecord struct {
ID int64 `json:"id"`
Key string `json:"key"`
Name string `json:"name"`
Group *struct {
ID int64 `json:"id"`
} `json:"group,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
}
type authTokenPair struct {
AccessToken string `json:"access_token"`
}
func (c *Client) EnsureSubscriptionAccess(ctx context.Context, req EnsureSubscriptionAccessRequest) (SubscriptionAccessRef, error) {
if c == nil {
return SubscriptionAccessRef{}, fmt.Errorf("client is required")
}
selector := strings.TrimSpace(req.UserSelector)
groupID := strings.TrimSpace(req.GroupID)
if selector == "" {
return SubscriptionAccessRef{}, fmt.Errorf("user selector is required")
}
if groupID == "" {
return SubscriptionAccessRef{}, fmt.Errorf("group id is required")
}
groupInt, err := strconv.ParseInt(groupID, 10, 64)
if err != nil {
return SubscriptionAccessRef{}, fmt.Errorf("parse group id %q: %w", groupID, err)
}
identity := buildManagedSubscriptionIdentity(selector, groupID)
user, err := c.findManagedSubscriptionUser(ctx, identity.Email)
if err != nil {
return SubscriptionAccessRef{}, err
}
if user == nil {
user, err = c.createManagedSubscriptionUser(ctx, identity, groupInt)
if err != nil {
return SubscriptionAccessRef{}, err
}
}
if err := c.updateManagedSubscriptionUser(ctx, user.ID, groupInt); err != nil {
return SubscriptionAccessRef{}, err
}
if err := c.setManagedSubscriptionBalance(ctx, user.ID); err != nil {
return SubscriptionAccessRef{}, err
}
if err := c.ensureManagedSubscriptionAssignment(ctx, user.ID, groupID); err != nil {
return SubscriptionAccessRef{}, err
}
userClient, err := c.loginAsManagedSubscriptionUser(ctx, identity.Email, identity.Password)
if err != nil {
return SubscriptionAccessRef{}, err
}
keyRecord, err := c.ensureManagedSubscriptionAPIKey(ctx, userClient, user.ID, identity)
if err != nil {
return SubscriptionAccessRef{}, err
}
if err := c.bindManagedSubscriptionAPIKey(ctx, keyRecord.ID, groupInt); err != nil {
return SubscriptionAccessRef{}, err
}
return SubscriptionAccessRef{UserID: strconv.FormatInt(user.ID, 10), APIKey: identity.CustomKey}, nil
}
type managedSubscriptionIdentity struct {
Email string
Username string
Password string
CustomKey string
KeyName string
}
func buildManagedSubscriptionIdentity(selector, groupID string) managedSubscriptionIdentity {
normalizedSelector := strings.TrimSpace(selector)
seedMaterial := strings.ToLower(normalizedSelector) + "|" + strings.TrimSpace(groupID)
sum := sha256.Sum256([]byte(seedMaterial))
hash := hex.EncodeToString(sum[:])
prefix := sanitizeManagedSubscriptionPrefix(normalizedSelector)
if prefix == "" {
prefix = "relay-sub"
}
prefix = truncateManagedSubscriptionToken(prefix, 24)
shortHash := hash[:16]
keyHash := hash[:32]
username := truncateManagedSubscriptionToken(prefix+"-"+shortHash[:8], 32)
return managedSubscriptionIdentity{
Email: fmt.Sprintf("%s-%s@sub2api.local", prefix, shortHash),
Username: username,
Password: "RelayPwd!" + hash[:12],
CustomKey: "sk-relay-" + keyHash,
KeyName: truncateManagedSubscriptionToken(username+"-key", 48),
}
}
func sanitizeManagedSubscriptionPrefix(value string) string {
value = strings.ToLower(strings.TrimSpace(value))
var b strings.Builder
lastDash := false
for _, r := range value {
switch {
case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
b.WriteRune(r)
lastDash = false
case !lastDash:
b.WriteByte('-')
lastDash = true
}
}
return strings.Trim(b.String(), "-")
}
func truncateManagedSubscriptionToken(value string, max int) string {
if len(value) <= max {
return value
}
return strings.Trim(value[:max], "-")
}
func (c *Client) findManagedSubscriptionUser(ctx context.Context, email string) (*adminUserRecord, error) {
statusCode, _, body, err := c.perform(ctx, http.MethodGet, "/api/v1/admin/users?search="+url.QueryEscape(email)+"&page=1&page_size=20&sort_by=created_at&sort_order=desc", nil)
if err != nil {
return nil, fmt.Errorf("list admin users: %w", err)
}
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
return nil, newHTTPError(http.MethodGet, "/api/v1/admin/users", statusCode, body)
}
var envelope struct {
Data struct {
Items []adminUserRecord `json:"items"`
} `json:"data"`
}
if err := json.Unmarshal(body, &envelope); err != nil {
return nil, fmt.Errorf("decode admin users response: %w", err)
}
for _, item := range envelope.Data.Items {
if strings.EqualFold(strings.TrimSpace(item.Email), email) {
user := item
return &user, nil
}
}
return nil, nil
}
func (c *Client) createManagedSubscriptionUser(ctx context.Context, identity managedSubscriptionIdentity, groupID int64) (*adminUserRecord, error) {
payload := map[string]any{
"email": identity.Email,
"password": identity.Password,
"username": identity.Username,
"notes": "managed by sub2api-cn-relay-manager",
"balance": managedSubscriptionBalance,
"concurrency": 5,
"allowed_groups": []int64{groupID},
}
statusCode, _, body, err := c.perform(ctx, http.MethodPost, "/api/v1/admin/users", payload)
if err != nil {
return nil, fmt.Errorf("create admin user: %w", err)
}
if statusCode == http.StatusConflict {
return c.findManagedSubscriptionUser(ctx, identity.Email)
}
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
return nil, newHTTPError(http.MethodPost, "/api/v1/admin/users", statusCode, body)
}
var user adminUserRecord
if err := decodeEnvelopeObject(body, &user); err != nil {
return nil, fmt.Errorf("decode created admin user: %w", err)
}
return &user, nil
}
func (c *Client) updateManagedSubscriptionUser(ctx context.Context, userID, groupID int64) error {
payload := map[string]any{"allowed_groups": []int64{groupID}}
statusCode, _, body, err := c.perform(ctx, http.MethodPut, fmt.Sprintf("/api/v1/admin/users/%d", userID), payload)
if err != nil {
return fmt.Errorf("update admin user groups: %w", err)
}
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
return newHTTPError(http.MethodPut, fmt.Sprintf("/api/v1/admin/users/%d", userID), statusCode, body)
}
return nil
}
func (c *Client) setManagedSubscriptionBalance(ctx context.Context, userID int64) error {
payload := map[string]any{"balance": managedSubscriptionBalance, "operation": "set", "notes": "managed by sub2api-cn-relay-manager"}
statusCode, _, body, err := c.perform(ctx, http.MethodPost, fmt.Sprintf("/api/v1/admin/users/%d/balance", userID), payload)
if err != nil {
return fmt.Errorf("set admin user balance: %w", err)
}
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
return newHTTPError(http.MethodPost, fmt.Sprintf("/api/v1/admin/users/%d/balance", userID), statusCode, body)
}
return nil
}
func (c *Client) ensureManagedSubscriptionAssignment(ctx context.Context, userID int64, groupID string) error {
_, err := c.AssignSubscription(ctx, AssignSubscriptionRequest{
UserID: strconv.FormatInt(userID, 10),
GroupID: groupID,
DurationDays: managedSubscriptionValidityDays,
})
if err != nil {
return fmt.Errorf("assign managed subscription: %w", err)
}
return nil
}
func (c *Client) loginAsManagedSubscriptionUser(ctx context.Context, email, password string) (*Client, error) {
anon := c.cloneWithAuth("", "")
payload := map[string]any{"email": email, "password": password, "turnstile_token": ""}
statusCode, _, body, err := anon.perform(ctx, http.MethodPost, "/api/v1/auth/login", payload)
if err != nil {
return nil, fmt.Errorf("login managed subscription user: %w", err)
}
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
return nil, newHTTPError(http.MethodPost, "/api/v1/auth/login", statusCode, body)
}
var tokenPair authTokenPair
if err := decodeEnvelopeObject(body, &tokenPair); err != nil {
return nil, fmt.Errorf("decode managed user login response: %w", err)
}
if strings.TrimSpace(tokenPair.AccessToken) == "" {
return nil, fmt.Errorf("managed user login returned empty access token")
}
return c.cloneWithAuth("", tokenPair.AccessToken), nil
}
func (c *Client) ensureManagedSubscriptionAPIKey(ctx context.Context, userClient *Client, userID int64, identity managedSubscriptionIdentity) (*adminAPIKeyRecord, error) {
payload := map[string]any{
"name": identity.KeyName,
"custom_key": identity.CustomKey,
}
statusCode, _, body, err := userClient.perform(ctx, http.MethodPost, "/api/v1/keys", payload)
if err != nil {
return nil, fmt.Errorf("create managed api key: %w", err)
}
if statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices {
var key adminAPIKeyRecord
if err := decodeEnvelopeObject(body, &key); err != nil {
return nil, fmt.Errorf("decode created api key: %w", err)
}
return &key, nil
}
if statusCode != http.StatusConflict && statusCode != http.StatusBadRequest {
return nil, newHTTPError(http.MethodPost, "/api/v1/keys", statusCode, body)
}
return c.findManagedSubscriptionAPIKey(ctx, userID, identity)
}
func (c *Client) findManagedSubscriptionAPIKey(ctx context.Context, userID int64, identity managedSubscriptionIdentity) (*adminAPIKeyRecord, error) {
statusCode, _, body, err := c.perform(ctx, http.MethodGet, fmt.Sprintf("/api/v1/admin/users/%d/api-keys?page=1&page_size=100&sort_by=created_at&sort_order=desc", userID), nil)
if err != nil {
return nil, fmt.Errorf("list managed api keys: %w", err)
}
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
return nil, newHTTPError(http.MethodGet, fmt.Sprintf("/api/v1/admin/users/%d/api-keys", userID), statusCode, body)
}
var envelope struct {
Data struct {
Items []adminAPIKeyRecord `json:"items"`
} `json:"data"`
}
if err := json.Unmarshal(body, &envelope); err != nil {
return nil, fmt.Errorf("decode admin api keys response: %w", err)
}
for _, item := range envelope.Data.Items {
if strings.TrimSpace(item.Key) == identity.CustomKey || strings.TrimSpace(item.Name) == identity.KeyName {
key := item
return &key, nil
}
}
return nil, fmt.Errorf("managed api key %q not found for user %d", identity.KeyName, userID)
}
func (c *Client) bindManagedSubscriptionAPIKey(ctx context.Context, keyID, groupID int64) error {
payload := map[string]any{"group_id": groupID}
statusCode, _, body, err := c.perform(ctx, http.MethodPut, fmt.Sprintf("/api/v1/admin/api-keys/%d", keyID), payload)
if err != nil {
return fmt.Errorf("bind managed api key group: %w", err)
}
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
return newHTTPError(http.MethodPut, fmt.Sprintf("/api/v1/admin/api-keys/%d", keyID), statusCode, body)
}
return nil
}
func (c *Client) cloneWithAuth(apiKey, bearerToken string) *Client {
if c == nil {
return nil
}
clone := *c
clone.apiKey = strings.TrimSpace(apiKey)
clone.bearerToken = strings.TrimSpace(bearerToken)
return &clone
}

View File

@@ -157,6 +157,7 @@ func validateProviders(providers []ProviderManifest) error {
seen := make(map[string]struct{}, len(providers))
for _, provider := range providers {
providerID := strings.TrimSpace(provider.ProviderID)
missingDefaultModel := firstMissingDefaultModel(provider.DefaultModels, provider.ChannelTemplate.ModelMapping)
switch {
case providerID == "":
return fmt.Errorf("provider manifest: provider_id is required")
@@ -180,6 +181,10 @@ func validateProviders(providers []ProviderManifest) error {
return fmt.Errorf("provider %q: channel_template.name is required", providerID)
case len(provider.ChannelTemplate.ModelMapping) == 0:
return fmt.Errorf("provider %q: channel_template.model_mapping must not be empty", providerID)
case !containsProviderModel(provider.ChannelTemplate.ModelMapping, provider.SmokeTestModel):
return fmt.Errorf("provider %q: channel_template.model_mapping must include smoke_test_model %q", providerID, provider.SmokeTestModel)
case missingDefaultModel != "":
return fmt.Errorf("provider %q: channel_template.model_mapping must cover default_models, missing %q", providerID, missingDefaultModel)
case strings.TrimSpace(provider.PlanTemplate.Name) == "":
return fmt.Errorf("provider %q: plan_template.name is required", providerID)
case provider.PlanTemplate.ValidityDays <= 0:
@@ -247,3 +252,29 @@ func contains(items []string, target string) bool {
}
return false
}
func containsProviderModel(modelMapping map[string]string, target string) bool {
trimmedTarget := strings.TrimSpace(target)
if trimmedTarget == "" {
return false
}
for sourceModel, mappedModel := range modelMapping {
if strings.TrimSpace(sourceModel) == trimmedTarget || strings.TrimSpace(mappedModel) == trimmedTarget {
return true
}
}
return false
}
func firstMissingDefaultModel(defaultModels []string, modelMapping map[string]string) string {
for _, model := range defaultModels {
trimmedModel := strings.TrimSpace(model)
if trimmedModel == "" {
continue
}
if !containsProviderModel(modelMapping, trimmedModel) {
return trimmedModel
}
}
return ""
}

View File

@@ -30,7 +30,7 @@ func TestLoadDirParsesAndValidatesPack(t *testing.T) {
"default_models": ["deepseek-chat", "deepseek-reasoner"],
"smoke_test_model": "deepseek-chat",
"group_template": {"name": "DeepSeek 默认分组", "rate_multiplier": 1.0},
"channel_template": {"name": "DeepSeek 默认渠道", "model_mapping": {"deepseek-chat": "deepseek-chat"}},
"channel_template": {"name": "DeepSeek 默认渠道", "model_mapping": {"deepseek-chat": "deepseek-chat", "deepseek-reasoner": "deepseek-reasoner"}},
"plan_template": {"name": "DeepSeek 默认套餐", "price": 19.9, "validity_days": 30, "validity_unit": "day"},
"import": {"supports_multi_key": true, "supports_strict": true, "supports_partial": true}
}`,
@@ -82,6 +82,36 @@ func TestLoadDirRejectsInvalidProviderSchema(t *testing.T) {
}
}
func TestLoadDirRejectsSmokeTestModelMissingFromChannelMapping(t *testing.T) {
packDir := createPackFixture(t, map[string]string{
"pack.json": `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`,
"providers/deepseek.json": `{"provider_id":"deepseek","display_name":"DeepSeek","base_url":"https://api.deepseek.com","platform":"openai","account_type":"apikey","default_models":["deepseek-v4-pro","deepseek-v4-flash"],"smoke_test_model":"deepseek-v4-pro","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"deepseek-chat":"deepseek-chat","deepseek-reasoner":"deepseek-reasoner"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`,
})
_, err := LoadDir(packDir)
if err == nil {
t.Fatal("LoadDir() error = nil, want smoke_test_model channel mapping validation failure")
}
if !strings.Contains(err.Error(), "channel_template.model_mapping") || !strings.Contains(err.Error(), "smoke_test_model") {
t.Fatalf("LoadDir() error = %v, want smoke_test_model channel mapping detail", err)
}
}
func TestLoadDirRejectsDefaultModelsMissingFromChannelMapping(t *testing.T) {
packDir := createPackFixture(t, map[string]string{
"pack.json": `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`,
"providers/minimax.json": `{"provider_id":"minimax","display_name":"MiniMax","base_url":"https://api.minimax.example.com","platform":"openai","account_type":"apikey","default_models":["MiniMax-M2.5-highspeed","MiniMax-M2.7-highspeed"],"smoke_test_model":"MiniMax-M2.7-highspeed","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"MiniMax-M2.7-highspeed":"MiniMax-M2.7-highspeed"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`,
})
_, err := LoadDir(packDir)
if err == nil {
t.Fatal("LoadDir() error = nil, want default_models channel mapping validation failure")
}
if !strings.Contains(err.Error(), "default_models") || !strings.Contains(err.Error(), "channel_template.model_mapping") {
t.Fatalf("LoadDir() error = %v, want default_models channel mapping detail", err)
}
}
func createPackFixture(t *testing.T, files map[string]string) string {
t.Helper()

View File

@@ -0,0 +1,128 @@
package provision
import (
"context"
"strings"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
type ModeAccessStatuses struct {
Subscription string
SelfService string
}
func SuggestResourceNamesForMode(provider pack.ProviderManifest, accessMode string) ResourceNames {
base := SuggestResourceNames(provider)
suffix := accessModeResourceSuffix(accessMode)
if suffix == "" {
return base
}
return ResourceNames{
Group: appendResourceNameSuffix(base.Group, suffix),
Channel: appendResourceNameSuffix(base.Channel, suffix),
Plan: appendResourceNameSuffix(base.Plan, suffix),
}
}
func accessModeResourceSuffix(accessMode string) string {
switch strings.TrimSpace(accessMode) {
case AccessModeSubscription:
return "subscription"
case AccessModeSelfService:
return "self-service"
default:
return ""
}
}
func appendResourceNameSuffix(name, suffix string) string {
name = strings.TrimSpace(name)
suffix = strings.TrimSpace(suffix)
if name == "" || suffix == "" {
return name
}
if strings.HasSuffix(name, "-"+suffix) {
return name
}
return name + "-" + suffix
}
func LatestModeAccessStatuses(ctx context.Context, store *sqlite.DB, batches []sqlite.ImportBatch) (ModeAccessStatuses, error) {
var statuses ModeAccessStatuses
for _, batch := range batches {
if statuses.Subscription != "" && statuses.SelfService != "" {
break
}
closures, err := store.AccessClosures().GetByBatchID(ctx, batch.ID)
if err != nil {
return ModeAccessStatuses{}, err
}
batchStatuses := modeAccessStatusesForBatch(batch, closures)
if statuses.Subscription == "" && strings.TrimSpace(batchStatuses.Subscription) != "" {
statuses.Subscription = strings.TrimSpace(batchStatuses.Subscription)
}
if statuses.SelfService == "" && strings.TrimSpace(batchStatuses.SelfService) != "" {
statuses.SelfService = strings.TrimSpace(batchStatuses.SelfService)
}
}
return statuses, nil
}
func modeAccessStatusesForBatch(batch sqlite.ImportBatch, closures []sqlite.AccessClosureRecord) ModeAccessStatuses {
statuses := ModeAccessStatuses{}
for _, closure := range closures {
status := strings.TrimSpace(closure.Status)
switch strings.TrimSpace(closure.ClosureType) {
case AccessModeSubscription:
statuses.Subscription = status
case AccessModeSelfService:
statuses.SelfService = status
}
}
if statuses.Subscription == "" && statuses.SelfService == "" {
return seedModeAccessStatuses(batch.AccessStatus)
}
return statuses
}
func seedModeAccessStatuses(accessStatus string) ModeAccessStatuses {
switch strings.TrimSpace(accessStatus) {
case AccessStatusFullyReady:
return ModeAccessStatuses{Subscription: AccessStatusSubscriptionReady, SelfService: AccessStatusSelfServiceReady}
case AccessStatusSubscriptionReady:
return ModeAccessStatuses{Subscription: AccessStatusSubscriptionReady}
case AccessStatusSelfServiceReady:
return ModeAccessStatuses{SelfService: AccessStatusSelfServiceReady}
default:
return ModeAccessStatuses{}
}
}
func AggregateAccessStatus(statuses ModeAccessStatuses) string {
subscriptionReady := isReadyAccessStatus(statuses.Subscription, AccessModeSubscription)
selfServiceReady := isReadyAccessStatus(statuses.SelfService, AccessModeSelfService)
switch {
case subscriptionReady && selfServiceReady:
return AccessStatusFullyReady
case subscriptionReady:
return AccessStatusSubscriptionReady
case selfServiceReady:
return AccessStatusSelfServiceReady
default:
return AccessStatusBroken
}
}
func isReadyAccessStatus(status, mode string) bool {
status = strings.TrimSpace(status)
switch mode {
case AccessModeSubscription:
return status == AccessStatusSubscriptionReady || status == AccessStatusFullyReady
case AccessModeSelfService:
return status == AccessStatusSelfServiceReady || status == AccessStatusFullyReady
default:
return status != "" && status != AccessStatusBroken
}
}

View File

@@ -278,7 +278,7 @@ func accessClosureType(accessClosures []sqlite.AccessClosureRecord) string {
}
func buildManagedResourceListRequest(provider pack.ProviderManifest, accessMode string) sub2api.ListManagedResourcesRequest {
names := SuggestResourceNames(provider)
names := SuggestResourceNamesForMode(provider, accessMode)
req := sub2api.ListManagedResourcesRequest{
GroupName: names.Group,
ChannelName: names.Channel,

View File

@@ -215,10 +215,11 @@ func TestDeriveProviderStatus(t *testing.T) {
tests := []struct {
name string
batchStatus string
accessStatus string
reconcileStatus string
want string
}{
{name: "reconcile wins", batchStatus: BatchStatusSucceeded, reconcileStatus: "degraded", want: "degraded"},
{name: "recovered success beats stale reconcile", batchStatus: BatchStatusSucceeded, accessStatus: AccessStatusSelfServiceReady, reconcileStatus: "degraded", want: ProviderStatusActive},
{name: "succeeded batch", batchStatus: BatchStatusSucceeded, reconcileStatus: "not_run", want: ProviderStatusActive},
{name: "failed batch", batchStatus: BatchStatusFailed, want: ProviderStatusFailed},
{name: "running batch", batchStatus: "running", want: "running"},
@@ -226,13 +227,60 @@ func TestDeriveProviderStatus(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := deriveProviderStatus(tc.batchStatus, tc.reconcileStatus); got != tc.want {
t.Fatalf("deriveProviderStatus(%q, %q) = %q, want %q", tc.batchStatus, tc.reconcileStatus, got, tc.want)
if got := deriveProviderStatus(tc.batchStatus, tc.accessStatus, tc.reconcileStatus); got != tc.want {
t.Fatalf("deriveProviderStatus(%q, %q, %q) = %q, want %q", tc.batchStatus, tc.accessStatus, tc.reconcileStatus, got, tc.want)
}
})
}
}
func TestProviderStatusServiceAggregatesLatestAccessModesAcrossBatches(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
ctx := context.Background()
hostID := seedProvisionHost(t, store, "host-1", "https://sub2api.example.com")
packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", Checksum: "checksum-1"})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
providerID, err := store.Providers().Create(ctx, sqlite.Provider{PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai"})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
batchSubscription, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSubscriptionReady})
if err != nil {
t.Fatalf("ImportBatches().Create(subscription) error = %v", err)
}
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchSubscription, ClosureType: AccessModeSubscription, Status: AccessStatusSubscriptionReady, DetailsJSON: "{}"}); err != nil {
t.Fatalf("AccessClosures().Create(subscription) error = %v", err)
}
batchSelfService, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady})
if err != nil {
t.Fatalf("ImportBatches().Create(self_service) error = %v", err)
}
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchSelfService, ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady, DetailsJSON: "{}"}); err != nil {
t.Fatalf("AccessClosures().Create(self_service) error = %v", err)
}
if _, err := store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{BatchID: batchSelfService, HostID: hostID, ProviderID: providerID, Status: "drifted", SummaryJSON: `{"missing_count":1}`}); err != nil {
t.Fatalf("ReconcileRuns().Create() error = %v", err)
}
snapshot, err := NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek", PackID: "openai-cn-pack", HostID: "host-1"})
if err != nil {
t.Fatalf("GetStatus() error = %v", err)
}
if snapshot.LatestAccessStatus != AccessStatusFullyReady {
t.Fatalf("LatestAccessStatus = %q, want %q", snapshot.LatestAccessStatus, AccessStatusFullyReady)
}
if snapshot.ProviderStatus != ProviderStatusActive {
t.Fatalf("ProviderStatus = %q, want %q", snapshot.ProviderStatus, ProviderStatusActive)
}
if snapshot.LatestReconcileStatus != "drifted" {
t.Fatalf("LatestReconcileStatus = %q, want drifted", snapshot.LatestReconcileStatus)
}
}
func TestBuildPackAndProviderRecord(t *testing.T) {
packRow, err := buildPackRecord(sampleLoadedPack())
if err != nil {

View File

@@ -199,7 +199,7 @@ func (s *ImportService) Import(ctx context.Context, req ImportRequest) (report I
}
func (s *ImportService) ensureManagedResources(ctx context.Context, provider pack.ProviderManifest, accessMode string) (resolvedManagedResources, error) {
names := SuggestResourceNames(provider)
names := SuggestResourceNamesForMode(provider, accessMode)
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
GroupName: names.Group,
ChannelName: names.Channel,
@@ -210,14 +210,14 @@ func (s *ImportService) ensureManagedResources(ctx context.Context, provider pac
}
result := resolvedManagedResources{}
group, created, err := ensureGroup(ctx, s.host, snapshot.Groups, provider, accessMode)
group, created, err := ensureGroup(ctx, s.host, snapshot.Groups, provider, accessMode, names.Group)
if err != nil {
return resolvedManagedResources{}, fmt.Errorf("ensure group: %w", err)
}
result.Group = group
result.CreatedGroup = created
channel, created, err := ensureChannel(ctx, s.host, snapshot.Channels, provider, group.ID)
channel, created, err := ensureChannel(ctx, s.host, snapshot.Channels, provider, group.ID, names.Channel)
if err != nil {
return resolvedManagedResources{}, fmt.Errorf("ensure channel: %w", err)
}
@@ -225,7 +225,7 @@ func (s *ImportService) ensureManagedResources(ctx context.Context, provider pac
result.CreatedChannel = created
if accessMode == AccessModeSubscription {
plan, created, err := ensurePlan(ctx, s.host, snapshot.Plans, provider, group.ID)
plan, created, err := ensurePlan(ctx, s.host, snapshot.Plans, provider, group.ID, names.Plan)
if err != nil {
return resolvedManagedResources{}, fmt.Errorf("ensure plan: %w", err)
}
@@ -236,10 +236,10 @@ func (s *ImportService) ensureManagedResources(ctx context.Context, provider pac
return result, nil
}
func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, accessMode string) (sub2api.GroupRef, bool, error) {
func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, accessMode, groupName string) (sub2api.GroupRef, bool, error) {
switch len(existing) {
case 0:
groupReq := sub2api.CreateGroupRequest{Name: provider.GroupTemplate.Name, Platform: provider.Platform, RateMultiplier: provider.GroupTemplate.RateMultiplier}
groupReq := sub2api.CreateGroupRequest{Name: groupName, Platform: provider.Platform, RateMultiplier: provider.GroupTemplate.RateMultiplier}
if accessMode == AccessModeSubscription {
groupReq.SubscriptionType = "subscription"
}
@@ -248,38 +248,52 @@ func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.Named
case 1:
return sub2api.GroupRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil
default:
return sub2api.GroupRef{}, false, fmt.Errorf("multiple groups already exist for %q", provider.GroupTemplate.Name)
return sub2api.GroupRef{}, false, fmt.Errorf("multiple groups already exist for %q", groupName)
}
}
func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID string) (sub2api.ChannelRef, bool, error) {
func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, channelName string) (sub2api.ChannelRef, bool, error) {
channelReq := buildChannelRequest(provider, groupID, channelName)
switch len(existing) {
case 0:
channelReq := sub2api.CreateChannelRequest{
Name: provider.ChannelTemplate.Name,
GroupIDs: []string{groupID},
ModelMapping: provider.ChannelTemplate.ModelMapping,
RestrictModels: true,
BillingModelSource: "channel_mapped",
}
channel, err := host.CreateChannel(ctx, channelReq)
return channel, true, err
case 1:
if err := host.UpdateChannel(ctx, existing[0].ID, channelReq); err != nil {
return sub2api.ChannelRef{}, false, err
}
return sub2api.ChannelRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil
default:
return sub2api.ChannelRef{}, false, fmt.Errorf("multiple channels already exist for %q", provider.ChannelTemplate.Name)
return sub2api.ChannelRef{}, false, fmt.Errorf("multiple channels already exist for %q", channelName)
}
}
func ensurePlan(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID string) (sub2api.PlanRef, bool, error) {
func buildChannelRequest(provider pack.ProviderManifest, groupID, channelName string) sub2api.CreateChannelRequest {
return sub2api.CreateChannelRequest{
Name: channelName,
GroupIDs: []string{groupID},
ModelMapping: provider.ChannelTemplate.ModelMapping,
ModelPricing: []sub2api.ChannelModelPricing{{
Platform: provider.Platform,
Models: append([]string(nil), provider.DefaultModels...),
BillingMode: "token",
Intervals: []sub2api.ChannelPricingTier{},
}},
Platform: provider.Platform,
RestrictModels: true,
BillingModelSource: "channel_mapped",
}
}
func ensurePlan(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, planName string) (sub2api.PlanRef, bool, error) {
switch len(existing) {
case 0:
plan, err := host.CreatePlan(ctx, sub2api.CreatePlanRequest{GroupID: groupID, Name: provider.PlanTemplate.Name, Price: provider.PlanTemplate.Price, ValidityDays: provider.PlanTemplate.ValidityDays, ValidityUnit: provider.PlanTemplate.ValidityUnit})
plan, err := host.CreatePlan(ctx, sub2api.CreatePlanRequest{GroupID: groupID, Name: planName, Price: provider.PlanTemplate.Price, ValidityDays: provider.PlanTemplate.ValidityDays, ValidityUnit: provider.PlanTemplate.ValidityUnit})
return plan, true, err
case 1:
return sub2api.PlanRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil
default:
return sub2api.PlanRef{}, false, fmt.Errorf("multiple plans already exist for %q", provider.PlanTemplate.Name)
return sub2api.PlanRef{}, false, fmt.Errorf("multiple plans already exist for %q", planName)
}
}
@@ -329,8 +343,9 @@ func buildBatchAccountsRequest(provider pack.ProviderManifest, groupID string, k
Type: provider.AccountType,
GroupIDs: []string{groupID},
Credentials: map[string]any{
"base_url": provider.BaseURL,
"api_key": key,
"base_url": provider.BaseURL,
"api_key": key,
"model_mapping": provider.ChannelTemplate.ModelMapping,
},
})
}

View File

@@ -152,7 +152,7 @@ func TestImportReusesExistingGroup(t *testing.T) {
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
managedSnapshot: sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_existing", Name: "DeepSeek 默认分组"}},
Groups: []sub2api.NamedResource{{ID: "group_existing", Name: "DeepSeek 默认分组-self-service"}},
},
}
@@ -198,8 +198,8 @@ func TestImportCreatesChannelWithManifestModelMapping(t *testing.T) {
if err != nil {
t.Fatalf("Import() error = %v", err)
}
if host.createChannelReq.Name != "DeepSeek 默认渠道" {
t.Fatalf("CreateChannel().Name = %q, want DeepSeek 默认渠道", host.createChannelReq.Name)
if host.createChannelReq.Name != "DeepSeek 默认渠道-self-service" {
t.Fatalf("CreateChannel().Name = %q, want DeepSeek 默认渠道-self-service", host.createChannelReq.Name)
}
if len(host.createChannelReq.GroupIDs) != 1 || host.createChannelReq.GroupIDs[0] != "group_1" {
t.Fatalf("CreateChannel().GroupIDs = %v, want [group_1]", host.createChannelReq.GroupIDs)
@@ -213,6 +213,31 @@ func TestImportCreatesChannelWithManifestModelMapping(t *testing.T) {
if host.createChannelReq.BillingModelSource != "channel_mapped" {
t.Fatalf("CreateChannel().BillingModelSource = %q, want channel_mapped", host.createChannelReq.BillingModelSource)
}
if len(host.createChannelReq.ModelPricing) != 1 {
t.Fatalf("CreateChannel().ModelPricing len = %d, want 1", len(host.createChannelReq.ModelPricing))
}
if len(host.createChannelReq.ModelPricing[0].Models) != 2 {
t.Fatalf("CreateChannel().ModelPricing[0].Models = %v, want default model coverage", host.createChannelReq.ModelPricing[0].Models)
}
if host.createChannelReq.ModelPricing[0].BillingMode != "token" {
t.Fatalf("CreateChannel().ModelPricing[0].BillingMode = %q, want token", host.createChannelReq.ModelPricing[0].BillingMode)
}
if len(host.batchCreateReq.Accounts) != 1 {
t.Fatalf("BatchCreateAccounts().Accounts len = %d, want 1", len(host.batchCreateReq.Accounts))
}
credentials := host.batchCreateReq.Accounts[0].Credentials
switch rawMapping := credentials["model_mapping"].(type) {
case map[string]string:
if got := rawMapping["deepseek-chat"]; got != "deepseek-chat" {
t.Fatalf("BatchCreateAccounts().Credentials.model_mapping = %+v, want deepseek-chat passthrough", rawMapping)
}
case map[string]any:
if got, _ := rawMapping["deepseek-chat"].(string); got != "deepseek-chat" {
t.Fatalf("BatchCreateAccounts().Credentials.model_mapping = %+v, want deepseek-chat passthrough", rawMapping)
}
default:
t.Fatalf("BatchCreateAccounts().Credentials = %+v, want model_mapping map", credentials)
}
}
func sampleProviderManifest() pack.ProviderManifest {
@@ -230,8 +255,48 @@ func sampleProviderManifest() pack.ProviderManifest {
}
}
func TestImportReconcilesExistingChannelConfiguration(t *testing.T) {
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "ready"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
managedSnapshot: sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_existing", Name: "DeepSeek 默认分组-self-service"}},
Channels: []sub2api.NamedResource{{ID: "channel_existing", Name: "DeepSeek 默认渠道-self-service"}},
},
}
_, err := NewImportService(host).Import(context.Background(), ImportRequest{
Provider: sampleProviderManifest(),
Mode: ImportModePartial,
Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"},
Keys: []string{"key-1"},
})
if err != nil {
t.Fatalf("Import() error = %v", err)
}
if host.createChannelCalls != 0 {
t.Fatalf("CreateChannel() calls = %d, want 0 when channel already exists", host.createChannelCalls)
}
if host.updateChannelCalls != 1 {
t.Fatalf("UpdateChannel() calls = %d, want 1", host.updateChannelCalls)
}
if host.updateChannelID != "channel_existing" {
t.Fatalf("UpdateChannel() id = %q, want channel_existing", host.updateChannelID)
}
if len(host.updateChannelReq.ModelPricing) != 1 {
t.Fatalf("UpdateChannel().ModelPricing len = %d, want 1", len(host.updateChannelReq.ModelPricing))
}
}
type fakeHostAdapter struct {
batchAccounts []sub2api.AccountRef
batchCreateReq sub2api.BatchCreateAccountsRequest
testResults map[string]sub2api.ProbeResult
models map[string][]sub2api.AccountModel
gatewayResult sub2api.GatewayAccessResult
@@ -246,9 +311,12 @@ type fakeHostAdapter struct {
listManagedReq sub2api.ListManagedResourcesRequest
createGroupCalls int
createChannelCalls int
updateChannelCalls int
createPlanCalls int
createGroupReq sub2api.CreateGroupRequest
createChannelReq sub2api.CreateChannelRequest
updateChannelID string
updateChannelReq sub2api.CreateChannelRequest
}
func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) {
@@ -274,6 +342,12 @@ func (f *fakeHostAdapter) CreateChannel(_ context.Context, req sub2api.CreateCha
f.createChannelReq = req
return sub2api.ChannelRef{ID: "channel_1", Name: "c"}, nil
}
func (f *fakeHostAdapter) UpdateChannel(_ context.Context, channelID string, req sub2api.CreateChannelRequest) error {
f.updateChannelCalls++
f.updateChannelID = channelID
f.updateChannelReq = req
return nil
}
func (f *fakeHostAdapter) DeleteChannel(_ context.Context, channelID string) error {
f.deletedResources = append(f.deletedResources, "channel:"+channelID)
return nil
@@ -289,7 +363,8 @@ func (f *fakeHostAdapter) DeletePlan(_ context.Context, planID string) error {
func (f *fakeHostAdapter) CreateAccount(context.Context, sub2api.CreateAccountRequest) (sub2api.AccountRef, error) {
return sub2api.AccountRef{}, errors.New("unused")
}
func (f *fakeHostAdapter) BatchCreateAccounts(_ context.Context, _ sub2api.BatchCreateAccountsRequest) ([]sub2api.AccountRef, error) {
func (f *fakeHostAdapter) BatchCreateAccounts(_ context.Context, req sub2api.BatchCreateAccountsRequest) ([]sub2api.AccountRef, error) {
f.batchCreateReq = req
if f.batchCreateErr != nil {
return nil, f.batchCreateErr
}
@@ -313,6 +388,9 @@ func (f *fakeHostAdapter) GetAccountModels(_ context.Context, accountID string)
}
return models, nil
}
func (f *fakeHostAdapter) EnsureSubscriptionAccess(_ context.Context, req sub2api.EnsureSubscriptionAccessRequest) (sub2api.SubscriptionAccessRef, error) {
return sub2api.SubscriptionAccessRef{UserID: req.UserSelector, APIKey: "managed-subscription-key"}, nil
}
func (f *fakeHostAdapter) AssignSubscription(_ context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) {
if f.assignErr != nil {
return sub2api.SubscriptionRef{}, f.assignErr

View File

@@ -57,7 +57,7 @@ func (s *PreviewService) PreviewImport(ctx context.Context, req PreviewRequest)
return PreviewReport{}, fmt.Errorf("preview host is required")
}
names := SuggestResourceNames(req.Provider)
names := SuggestResourceNamesForMode(req.Provider, req.Mode)
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
GroupName: names.Group,
ChannelName: names.Channel,

View File

@@ -23,6 +23,23 @@ func TestSuggestResourceNames(t *testing.T) {
}
}
func TestSuggestResourceNamesIncludesAccessModeSuffix(t *testing.T) {
provider := sampleProviderManifest()
provider.GroupTemplate.Name = ""
provider.ChannelTemplate.Name = ""
provider.PlanTemplate.Name = ""
names := SuggestResourceNamesForMode(provider, AccessModeSubscription)
want := ResourceNames{
Group: "crm-deepseek-group-subscription",
Channel: "crm-deepseek-channel-subscription",
Plan: "crm-deepseek-plan-subscription",
}
if !reflect.DeepEqual(names, want) {
t.Fatalf("SuggestResourceNamesForMode() = %#v, want %#v", names, want)
}
}
func TestPreviewServiceReportsCreateActionsWhenHostHasNoResources(t *testing.T) {
host := &fakePreviewHost{}
svc := NewPreviewService(host)

View File

@@ -69,13 +69,18 @@ func (s *ProviderStatusService) snapshot(ctx context.Context, query ProviderQuer
if err != nil {
return ProviderSnapshot{}, err
}
reconcileRuns, err := s.store.ReconcileRuns().GetByBatchID(ctx, batchRow.ID)
batches, err := s.store.ImportBatches().ListByProviderIDAndHostID(ctx, provider.ID, hostRow.ID)
if err != nil {
return ProviderSnapshot{}, err
}
latestAccessStatus := batchRow.AccessStatus
if len(accessClosures) > 0 {
latestAccessStatus = firstNonEmpty(accessClosures[len(accessClosures)-1].Status, latestAccessStatus)
modeStatuses, err := LatestModeAccessStatuses(ctx, s.store, batches)
if err != nil {
return ProviderSnapshot{}, err
}
latestAccessStatus := AggregateAccessStatus(modeStatuses)
reconcileRuns, err := s.store.ReconcileRuns().GetByBatchID(ctx, batchRow.ID)
if err != nil {
return ProviderSnapshot{}, err
}
latestReconcileStatus := "not_run"
latestReconcileSummary := map[string]any{}
@@ -87,7 +92,7 @@ func (s *ProviderStatusService) snapshot(ctx context.Context, query ProviderQuer
}
}
}
providerStatus := deriveProviderStatus(batchRow.BatchStatus, latestReconcileStatus)
providerStatus := deriveProviderStatus(batchRow.BatchStatus, latestAccessStatus, latestReconcileStatus)
return ProviderSnapshot{
Host: hostRow,
Pack: packRow,
@@ -162,8 +167,12 @@ func (s *ProviderStatusService) resolveHostAndBatch(ctx context.Context, provide
return hostRow, batches[0], nil
}
func deriveProviderStatus(batchStatus, reconcileStatus string) string {
func deriveProviderStatus(batchStatus, accessStatus, reconcileStatus string) string {
reconcileStatus = strings.TrimSpace(reconcileStatus)
accessStatus = strings.TrimSpace(accessStatus)
if strings.TrimSpace(batchStatus) == BatchStatusSucceeded && accessStatus != "" && accessStatus != AccessStatusBroken {
return ProviderStatusActive
}
if reconcileStatus != "" && reconcileStatus != "not_run" {
return reconcileStatus
}

View File

@@ -54,8 +54,8 @@ func TestProviderStatusServiceReturnsLatestSnapshot(t *testing.T) {
if snapshot.Provider.ProviderID != "deepseek" {
t.Fatalf("Provider.ProviderID = %q, want deepseek", snapshot.Provider.ProviderID)
}
if snapshot.ProviderStatus != "drifted" {
t.Fatalf("ProviderStatus = %q, want drifted", snapshot.ProviderStatus)
if snapshot.ProviderStatus != ProviderStatusActive {
t.Fatalf("ProviderStatus = %q, want %q", snapshot.ProviderStatus, ProviderStatusActive)
}
if snapshot.LatestAccessStatus != AccessStatusSelfServiceReady {
t.Fatalf("LatestAccessStatus = %q, want %q", snapshot.LatestAccessStatus, AccessStatusSelfServiceReady)

View File

@@ -28,8 +28,8 @@ func TestReconcileServiceReturnsActiveAfterProbeRerun(t *testing.T) {
batchID := seedRuntimeImportForReconcile(t, store, host)
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}},
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
}
@@ -82,8 +82,8 @@ func TestReconcileServiceReturnsDegradedWhenProbeRerunFails(t *testing.T) {
seedRuntimeImportForReconcile(t, store, host)
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}},
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
}
@@ -124,8 +124,8 @@ func TestReconcileServiceReturnsDriftedWhenManagedResourceMissing(t *testing.T)
seedRuntimeImportForReconcile(t, store, host)
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}},
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}},
}
@@ -166,8 +166,8 @@ func TestReconcileServiceIgnoresSubscriptionPlanForSelfServiceBatch(t *testing.T
seedRuntimeImportForReconcile(t, store, host)
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}},
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}},
Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "DeepSeek 默认套餐"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
}
@@ -212,8 +212,8 @@ func TestReconcileServicePassesAccountNamePrefixToManagedResourceSnapshot(t *tes
seedRuntimeImportForReconcile(t, store, host)
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}},
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
}