Files
sub2api-cn-relay-manager/internal/provision/pack_install_service.go

176 lines
5.3 KiB
Go

package provision
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/host/sub2api"
packdef "sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
// PackInstallRequest is the input to PackInstallService.Install.
type PackInstallRequest struct {
// Pack is the loaded pack to install; Install reads its Manifest,
// Checksum and Providers.
Pack packdef.LoadedPack
}
// PackInstallResult describes what PackInstallService.Install persisted.
type PackInstallResult struct {
// Pack is the pack row as read back from the store after the upsert.
Pack sqlite.Pack
// Providers are the provider rows as read back from the store, in the same
// order as the loaded pack's provider manifests.
Providers []sqlite.Provider
// HostVersion is the version reported by the host adapter, against which
// the pack manifest was checked for compatibility.
HostVersion string
// AlreadyInstalled is true when a pack with the same PackID was already
// stored before this install; its rows are still upserted.
AlreadyInstalled bool
}
// PackInstallService installs packs by checking host compatibility through a
// HostAdapter and persisting pack/provider rows in the sqlite store.
// Construct it with NewPackInstallService.
type PackInstallService struct {
// store receives the pack and provider rows.
store *sqlite.DB
// host is queried for the host version before anything is written.
host sub2api.HostAdapter
}
// NewPackInstallService returns a PackInstallService that writes to store and
// asks host for its version. The arguments are not validated here; Install
// reports a missing store or host adapter when it is called.
func NewPackInstallService(store *sqlite.DB, host sub2api.HostAdapter) *PackInstallService {
	svc := new(PackInstallService)
	svc.store = store
	svc.host = host
	return svc
}
// Install checks that req.Pack is compatible with the connected host and then
// persists the pack and all of its provider manifests inside a single
// s.store.WithTx callback.
//
// Steps:
//  1. Reject a nil service, a missing store or host adapter, or a blank PackID.
//  2. Fetch the host version and run packdef.CheckHostCompatibility against the
//     pack manifest; nothing is written if either fails.
//  3. Inside WithTx, upsert the pack row and every provider row and read them
//     back (see persistPackInstall).
//
// Re-installing a pack whose PackID is already stored is allowed: the rows are
// upserted and the result has AlreadyInstalled set. On any error the zero
// PackInstallResult is returned. Errors carry context about the failing step
// (and, for the transaction, the pack ID) while keeping the underlying cause
// reachable through errors.Is / errors.As.
func (s *PackInstallService) Install(ctx context.Context, req PackInstallRequest) (PackInstallResult, error) {
	if s == nil || s.store == nil {
		return PackInstallResult{}, errors.New("store is required")
	}
	if s.host == nil {
		return PackInstallResult{}, errors.New("host adapter is required")
	}
	packID := req.Pack.Manifest.PackID
	if strings.TrimSpace(packID) == "" {
		return PackInstallResult{}, errors.New("pack manifest is required")
	}

	hostVersion, err := s.host.GetHostVersion(ctx)
	if err != nil {
		return PackInstallResult{}, fmt.Errorf("get host version: %w", err)
	}
	if err := packdef.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
		return PackInstallResult{}, fmt.Errorf("check host compatibility for pack %q: %w", packID, err)
	}

	var result PackInstallResult
	err = s.store.WithTx(ctx, func(queries *sqlite.Queries) error {
		persisted, txErr := persistPackInstall(ctx, queries, req.Pack)
		if txErr != nil {
			return txErr
		}
		// Publish only a fully built result so a failing callback never
		// leaves partial state in the value returned to the caller.
		result = persisted
		return nil
	})
	if err != nil {
		return PackInstallResult{}, fmt.Errorf("install pack %q: %w", packID, err)
	}
	result.HostVersion = hostVersion
	return result, nil
}

// persistPackInstall upserts loaded's pack row and one row per provider
// manifest through queries, returning the rows as read back from the store.
// HostVersion is left empty for the caller to fill in. AlreadyInstalled is
// true when a pack with the same PackID was found before the upsert; that
// existing row must first pass validateExistingPack.
func persistPackInstall(ctx context.Context, queries *sqlite.Queries, loaded packdef.LoadedPack) (PackInstallResult, error) {
	var result PackInstallResult
	packID := loaded.Manifest.PackID

	existing, err := queries.Packs.GetByPackID(ctx, packID)
	switch {
	case err == nil:
		if err := validateExistingPack(existing, loaded); err != nil {
			return PackInstallResult{}, err
		}
		result.AlreadyInstalled = true
	case errors.Is(err, sql.ErrNoRows):
		// First install of this pack ID; nothing to validate.
	default:
		return PackInstallResult{}, fmt.Errorf("look up existing pack: %w", err)
	}

	packRow, err := buildPackRecord(loaded)
	if err != nil {
		return PackInstallResult{}, err
	}
	if _, err := queries.Packs.Upsert(ctx, packRow); err != nil {
		return PackInstallResult{}, fmt.Errorf("upsert pack: %w", err)
	}
	// Read the row back so the result, and the providers' PackID column, use
	// the stored row's ID rather than anything built in memory.
	persistedPack, err := queries.Packs.GetByPackID(ctx, packID)
	if err != nil {
		return PackInstallResult{}, fmt.Errorf("reload pack: %w", err)
	}
	result.Pack = persistedPack

	// Non-nil even when the pack declares no providers.
	result.Providers = make([]sqlite.Provider, 0, len(loaded.Providers))
	for _, manifest := range loaded.Providers {
		providerRow, err := buildProviderRecord(persistedPack.ID, manifest)
		if err != nil {
			return PackInstallResult{}, fmt.Errorf("build provider %v: %w", manifest.ProviderID, err)
		}
		if _, err := queries.Providers.Upsert(ctx, providerRow); err != nil {
			return PackInstallResult{}, fmt.Errorf("upsert provider %v: %w", manifest.ProviderID, err)
		}
		persistedProvider, err := queries.Providers.GetByPackIDAndProviderID(ctx, persistedPack.ID, manifest.ProviderID)
		if err != nil {
			return PackInstallResult{}, fmt.Errorf("reload provider %v: %w", manifest.ProviderID, err)
		}
		result.Providers = append(result.Providers, persistedProvider)
	}
	return result, nil
}
// validateExistingPack checks that an already-stored pack row refers to the
// same pack as the one being installed, comparing PackIDs with surrounding
// whitespace ignored. The error reports both IDs exactly as stored/loaded.
//
// NOTE(review): Install fetches existing by the loaded PackID, so this check
// can only fail if the store's lookup matches IDs that differ after trimming;
// confirm whether a stronger check (e.g. version/checksum) was intended.
func validateExistingPack(existing sqlite.Pack, loaded packdef.LoadedPack) error {
	storedID := strings.TrimSpace(existing.PackID)
	loadedID := strings.TrimSpace(loaded.Manifest.PackID)
	if storedID == loadedID {
		return nil
	}
	return fmt.Errorf("existing pack %q does not match loaded pack %q", existing.PackID, loaded.Manifest.PackID)
}
// buildPackRecord converts a loaded pack into the sqlite.Pack row to upsert.
// Manifest fields and the pack checksum are copied directly; the whole
// manifest is additionally stored as JSON in ManifestJSON. The row's database
// ID is not set here.
func buildPackRecord(loaded packdef.LoadedPack) (sqlite.Pack, error) {
	manifest := loaded.Manifest
	encodedManifest, err := marshalJSONString(manifest)
	if err != nil {
		return sqlite.Pack{}, fmt.Errorf("marshal pack manifest: %w", err)
	}
	record := sqlite.Pack{
		PackID:         manifest.PackID,
		Version:        manifest.Version,
		Checksum:       loaded.Checksum,
		Vendor:         manifest.Vendor,
		TargetHost:     manifest.TargetHost,
		MinHostVersion: manifest.MinHostVersion,
		MaxHostVersion: manifest.MaxHostVersion,
		ManifestJSON:   encodedManifest,
	}
	return record, nil
}
// buildProviderRecord converts one provider manifest into the sqlite.Provider
// row to upsert for the pack whose database ID is packID.
//
// Scalar manifest fields are copied directly. DefaultModels, GroupTemplate,
// ChannelTemplate, PlanTemplate, Import and the full manifest are each
// JSON-encoded into their *JSON columns, in that order. The first encoding
// failure is returned wrapped with a message naming the field.
func buildProviderRecord(packID int64, provider packdef.ProviderManifest) (sqlite.Provider, error) {
	record := sqlite.Provider{
		PackID:         packID,
		ProviderID:     provider.ProviderID,
		DisplayName:    provider.DisplayName,
		BaseURL:        provider.BaseURL,
		Platform:       provider.Platform,
		AccountType:    provider.AccountType,
		SmokeTestModel: provider.SmokeTestModel,
	}

	// Each entry pairs a value to encode with the column it fills and the
	// error format used if encoding fails.
	jsonColumns := []struct {
		value  any
		column *string
		errFmt string
	}{
		{provider.DefaultModels, &record.DefaultModelsJSON, "marshal provider default models: %w"},
		{provider.GroupTemplate, &record.GroupTemplateJSON, "marshal group template: %w"},
		{provider.ChannelTemplate, &record.ChannelTemplateJSON, "marshal channel template: %w"},
		{provider.PlanTemplate, &record.PlanTemplateJSON, "marshal plan template: %w"},
		{provider.Import, &record.ImportOptionsJSON, "marshal import options: %w"},
		{provider, &record.ManifestJSON, "marshal provider manifest: %w"},
	}
	for _, col := range jsonColumns {
		encoded, err := marshalJSONString(col.value)
		if err != nil {
			return sqlite.Provider{}, fmt.Errorf(col.errFmt, err)
		}
		*col.column = encoded
	}
	return record, nil
}
// marshalJSONString encodes value with json.Marshal and returns the result as
// a string. On failure it returns "" together with json.Marshal's error.
func marshalJSONString(value any) (encoded string, err error) {
	var raw []byte
	if raw, err = json.Marshal(value); err == nil {
		encoded = string(raw)
	}
	return encoded, err
}