Files
2026-05-28 07:30:02 +08:00

317 lines
10 KiB
Go

package pack
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"sort"
"strconv"
"strings"
)
// PublishProviderManifestRequest describes a provider manifest to publish into
// an existing pack of a pack repository (consumed by PublishProviderManifest).
type PublishProviderManifestRequest struct {
// RepoRoot is the pack repository root. It is trimmed, made absolute and must
// name an existing directory.
RepoRoot string
// PackID names the pack directory under <RepoRoot>/packs; surrounding
// whitespace is trimmed.
PackID string
// Manifest is the provider manifest to publish. It is normalized (trimmed and
// defaulted) before being validated and written.
Manifest ProviderManifest
// CommitMessage is the git commit message. When blank, a default of the form
// "feat(pack): publish provider draft <provider id>" is used.
CommitMessage string
}
// PublishProviderManifestResult reports the outcome of PublishProviderManifest.
// The json tags define its serialized form.
type PublishProviderManifestResult struct {
// RepoRoot is the absolute path of the pack repository.
RepoRoot string `json:"repo_root"`
// PackID is the pack ID recorded in the pack's manifest (pack.json).
PackID string `json:"pack_id"`
// ProviderID is the whitespace-trimmed ID of the published provider.
ProviderID string `json:"provider_id"`
// ProviderPath is the provider manifest path relative to RepoRoot, with
// forward slashes.
ProviderPath string `json:"provider_path"`
// PackVersionBefore is the pack version before publishing.
PackVersionBefore string `json:"pack_version_before"`
// PackVersionAfter is the pack version after its patch component was bumped.
PackVersionAfter string `json:"pack_version_after"`
// PublishMode is "created" when the provider file did not exist beforehand,
// otherwise "updated".
PublishMode string `json:"publish_mode"`
// CommitMessage is the git commit message that was used.
CommitMessage string `json:"commit_message"`
// CommitSHA is the abbreviated SHA of the new commit
// (output of `git rev-parse --short HEAD`).
CommitSHA string `json:"commit_sha"`
}
// PublishProviderManifest publishes a provider manifest into an existing pack
// of the pack repository and records the change as a git commit.
//
// Steps:
//   - resolve req.RepoRoot;
//   - load the pack at <repo>/packs/<req.PackID>;
//   - normalize and validate the provider manifest;
//   - compute the bumped pack version;
//   - write the provider to <ProvidersDir>/<ProviderID>.json, creating or
//     overwriting it;
//   - write pack.json with the new version;
//   - refresh the checksum entries for both files;
//   - re-load the pack to validate the result;
//   - stage and commit exactly the three touched files.
//
// Every input check happens before any file is written. If a step fails after
// files were modified but before the git step, the modified files are
// restored (best effort) to their previous contents so the pack is not left
// half-published. A failing git step leaves the written files in place,
// because their new contents may already be staged.
//
// Errors are returned for an empty pack ID or one resolving outside packs/,
// for an empty provider ID or one containing path separators, and (wrapped)
// from loading, validation, file I/O and git.
func PublishProviderManifest(ctx context.Context, req PublishProviderManifestRequest) (PublishProviderManifestResult, error) {
	repoRoot, err := resolveRepoRoot(req.RepoRoot)
	if err != nil {
		return PublishProviderManifestResult{}, err
	}

	// Resolve the pack directory and keep its path relative to packs/. Every
	// staged path and the reported ProviderPath are derived from it, so they
	// always match the files actually written, even if pack.json's pack_id
	// differs from the directory name.
	packID := strings.TrimSpace(req.PackID)
	packsRoot := filepath.Join(repoRoot, "packs")
	packDir := filepath.Join(packsRoot, packID)
	packRelDir, err := filepath.Rel(packsRoot, packDir)
	if err != nil || packID == "" || packRelDir == "." || packRelDir == ".." ||
		strings.HasPrefix(packRelDir, ".."+string(filepath.Separator)) {
		return PublishProviderManifestResult{}, fmt.Errorf("invalid pack id %q", req.PackID)
	}
	packRepoDir := filepath.Join("packs", packRelDir)

	manifest := normalizeProviderManifest(req.Manifest)
	loadedPack, err := LoadDir(packDir)
	if err != nil {
		return PublishProviderManifestResult{}, fmt.Errorf("load pack dir %q: %w", packDir, err)
	}
	if err := validateProviders(packDir, []ProviderManifest{manifest}); err != nil {
		return PublishProviderManifestResult{}, err
	}
	// The provider ID becomes a file name; never let it address another directory.
	if manifest.ProviderID == "" || strings.ContainsAny(manifest.ProviderID, `/\`) {
		return PublishProviderManifestResult{}, fmt.Errorf("invalid provider id %q: must be a non-empty file name without path separators", manifest.ProviderID)
	}

	// Compute and marshal everything up front, so that a malformed version or
	// manifest fails before anything on disk changes.
	packManifest := loadedPack.Manifest
	previousVersion := packManifest.Version
	nextVersion, err := bumpPatchVersion(previousVersion)
	if err != nil {
		return PublishProviderManifestResult{}, err
	}
	providerBody, err := json.MarshalIndent(manifest, "", " ")
	if err != nil {
		return PublishProviderManifestResult{}, fmt.Errorf("marshal provider manifest: %w", err)
	}
	providerBody = append(providerBody, '\n')
	packManifest.Version = nextVersion
	packBody, err := json.MarshalIndent(packManifest, "", " ")
	if err != nil {
		return PublishProviderManifestResult{}, fmt.Errorf("marshal pack manifest: %w", err)
	}
	packBody = append(packBody, '\n')

	if err := ctx.Err(); err != nil {
		return PublishProviderManifestResult{}, err
	}

	providersDir := filepath.Join(packDir, packManifest.ProvidersDir)
	if err := os.MkdirAll(providersDir, 0o755); err != nil {
		return PublishProviderManifestResult{}, fmt.Errorf("ensure providers dir %q: %w", providersDir, err)
	}
	providerFileName := manifest.ProviderID + ".json"
	providerPath := filepath.Join(providersDir, providerFileName)
	packManifestPath := filepath.Join(packDir, "pack.json")
	checksumPath := filepath.Join(packDir, packManifest.ChecksumFile)

	// Snapshot every file this call modifies so that a failure before the git
	// step can put the pack back the way it was. A created providers directory
	// is intentionally left behind.
	type fileBackup struct {
		path    string
		body    []byte
		existed bool
	}
	var backups []fileBackup
	for _, path := range []string{providerPath, packManifestPath, checksumPath} {
		body, readErr := os.ReadFile(path)
		switch {
		case readErr == nil:
			backups = append(backups, fileBackup{path: path, body: body, existed: true})
		case os.IsNotExist(readErr):
			backups = append(backups, fileBackup{path: path})
		default:
			return PublishProviderManifestResult{}, fmt.Errorf("read %q before publish: %w", path, readErr)
		}
	}
	publishMode := "created"
	if backups[0].existed {
		publishMode = "updated"
	}

	rollback := true
	defer func() {
		if !rollback {
			return
		}
		// Best effort: the original failure is what the caller needs to see, so
		// secondary restore errors are deliberately ignored.
		for _, b := range backups {
			if b.existed {
				_ = os.WriteFile(b.path, b.body, 0o644)
			} else {
				_ = os.Remove(b.path)
			}
		}
	}()

	if err := os.WriteFile(providerPath, providerBody, 0o644); err != nil {
		return PublishProviderManifestResult{}, fmt.Errorf("write provider manifest %q: %w", providerPath, err)
	}
	if err := os.WriteFile(packManifestPath, packBody, 0o644); err != nil {
		return PublishProviderManifestResult{}, fmt.Errorf("write pack manifest %q: %w", packManifestPath, err)
	}
	if err := updateChecksumFile(packDir, packManifest.ChecksumFile, []string{
		"pack.json",
		filepath.ToSlash(filepath.Join(packManifest.ProvidersDir, providerFileName)),
	}); err != nil {
		return PublishProviderManifestResult{}, err
	}
	if _, err := LoadDir(packDir); err != nil {
		return PublishProviderManifestResult{}, fmt.Errorf("re-validate published pack %q: %w", packDir, err)
	}
	// The files are final from here on: a git failure must not revert them
	// while their new contents may already be staged.
	rollback = false

	commitMessage := strings.TrimSpace(req.CommitMessage)
	if commitMessage == "" {
		commitMessage = fmt.Sprintf("feat(pack): publish provider draft %s", manifest.ProviderID)
	}
	providerRepoPath := filepath.Join(packRepoDir, packManifest.ProvidersDir, providerFileName)
	commitSHA, err := commitPackPublish(ctx, repoRoot, commitMessage, []string{
		filepath.Join(packRepoDir, "pack.json"),
		filepath.Join(packRepoDir, packManifest.ChecksumFile),
		providerRepoPath,
	})
	if err != nil {
		return PublishProviderManifestResult{}, err
	}
	return PublishProviderManifestResult{
		RepoRoot:          repoRoot,
		PackID:            packManifest.PackID,
		ProviderID:        manifest.ProviderID,
		ProviderPath:      filepath.ToSlash(providerRepoPath),
		PackVersionBefore: previousVersion,
		PackVersionAfter:  nextVersion,
		PublishMode:       publishMode,
		CommitMessage:     commitMessage,
		CommitSHA:         commitSHA,
	}, nil
}
// normalizeProviderManifest returns a copy of manifest with string fields
// trimmed and defaults filled in for fields left empty or zero:
//   - AccountType defaults to "apikey";
//   - DefaultModels is trimmed and de-duplicated, and SmokeTestModel is
//     appended when not already listed (via normalizeModels);
//   - group, channel and plan template names default to DisplayName plus a
//     fixed Chinese suffix ("default group" / "default channel" /
//     "default plan");
//   - GroupTemplate.RateMultiplier defaults to 1.0;
//   - every default model without an explicit channel mapping maps to itself;
//   - PlanTemplate.ValidityDays defaults to 30 (for any value <= 0) and
//     PlanTemplate.ValidityUnit to "day";
//   - an all-zero Import block is replaced by one enabling multi-key, strict
//     and partial imports.
//
// The caller's manifest is not modified: ChannelTemplate.ModelMapping is
// copied before default mappings are added.
func normalizeProviderManifest(manifest ProviderManifest) ProviderManifest {
	normalized := manifest
	normalized.ProviderID = strings.TrimSpace(normalized.ProviderID)
	normalized.DisplayName = strings.TrimSpace(normalized.DisplayName)
	normalized.BaseURL = strings.TrimSpace(normalized.BaseURL)
	normalized.Platform = strings.TrimSpace(normalized.Platform)
	normalized.AccountType = strings.TrimSpace(normalized.AccountType)
	normalized.SmokeTestModel = strings.TrimSpace(normalized.SmokeTestModel)
	if normalized.AccountType == "" {
		normalized.AccountType = "apikey"
	}
	normalized.DefaultModels = normalizeModels(normalized.DefaultModels, normalized.SmokeTestModel)
	if normalized.GroupTemplate.Name == "" {
		normalized.GroupTemplate.Name = normalized.DisplayName + " 默认分组"
	}
	if normalized.GroupTemplate.RateMultiplier == 0 {
		normalized.GroupTemplate.RateMultiplier = 1.0
	}
	if normalized.ChannelTemplate.Name == "" {
		normalized.ChannelTemplate.Name = normalized.DisplayName + " 默认渠道"
	}
	// The struct copy above still shares the caller's map. Build a private copy
	// so that filling in identity mappings never mutates the caller's manifest.
	modelMapping := make(map[string]string, len(normalized.ChannelTemplate.ModelMapping)+len(normalized.DefaultModels))
	for source, target := range normalized.ChannelTemplate.ModelMapping {
		modelMapping[source] = target
	}
	for _, model := range normalized.DefaultModels {
		if _, ok := modelMapping[model]; !ok {
			modelMapping[model] = model
		}
	}
	normalized.ChannelTemplate.ModelMapping = modelMapping
	if normalized.PlanTemplate.Name == "" {
		normalized.PlanTemplate.Name = normalized.DisplayName + " 默认套餐"
	}
	if normalized.PlanTemplate.ValidityDays <= 0 {
		normalized.PlanTemplate.ValidityDays = 30
	}
	if normalized.PlanTemplate.ValidityUnit == "" {
		normalized.PlanTemplate.ValidityUnit = "day"
	}
	// An all-zero Import block is treated as "unset" and enables every mode.
	if normalized.Import == (ImportOptions{}) {
		normalized.Import = ImportOptions{
			SupportsMultiKey: true,
			SupportsStrict:   true,
			SupportsPartial:  true,
		}
	}
	return normalized
}
// normalizeModels trims each model name, drops blank and duplicate names while
// keeping first-occurrence order, and then appends the trimmed smokeTestModel
// if it is non-blank and not already present. The result is always a newly
// allocated, non-nil slice.
func normalizeModels(models []string, smokeTestModel string) []string {
	out := make([]string, 0, len(models)+1)
	present := make(map[string]struct{}, len(models)+1)
	add := func(name string) {
		if name == "" {
			return
		}
		if _, dup := present[name]; dup {
			return
		}
		present[name] = struct{}{}
		out = append(out, name)
	}
	for _, m := range models {
		add(strings.TrimSpace(m))
	}
	add(strings.TrimSpace(smokeTestModel))
	return out
}
// resolveRepoRoot trims repoRoot, converts it to an absolute path and checks
// that it names an existing directory. It returns the absolute path, or an
// error if the input is blank, cannot be made absolute, cannot be stat'ed, or
// is not a directory.
func resolveRepoRoot(repoRoot string) (string, error) {
	trimmed := strings.TrimSpace(repoRoot)
	if len(trimmed) == 0 {
		return "", fmt.Errorf("pack repo root is not configured")
	}
	abs, err := filepath.Abs(trimmed)
	if err != nil {
		return "", fmt.Errorf("resolve repo root %q: %w", trimmed, err)
	}
	switch info, statErr := os.Stat(abs); {
	case statErr != nil:
		return "", fmt.Errorf("stat repo root %q: %w", abs, statErr)
	case !info.IsDir():
		return "", fmt.Errorf("repo root %q is not a directory", abs)
	}
	return abs, nil
}
// bumpPatchVersion increments the patch component of an "x.y.z" version
// string, for example "1.4.9" -> "1.4.10". Surrounding whitespace is ignored.
//
// The patch component must be an unsigned decimal integer. Signed values such
// as "-1" or "+3", which strconv.Atoi would accept, are rejected. So is a
// patch value whose increment would overflow. The major and minor components
// are passed through unchanged and are not validated.
func bumpPatchVersion(version string) (string, error) {
	parts := strings.Split(strings.TrimSpace(version), ".")
	if len(parts) != 3 {
		return "", fmt.Errorf("pack version %q must use x.y.z format", version)
	}
	// ParseUint rejects any sign prefix, unlike Atoi.
	patch, err := strconv.ParseUint(parts[2], 10, 64)
	if err != nil {
		return "", fmt.Errorf("parse pack version %q patch: %w", version, err)
	}
	next := patch + 1
	if next == 0 {
		return "", fmt.Errorf("pack version %q patch overflows", version)
	}
	parts[2] = strconv.FormatUint(next, 10)
	return strings.Join(parts, "."), nil
}
// updateChecksumFile recomputes the SHA-256 digest of each file in
// relativePaths (relative to packDir) and rewrites packDir/checksumFile with
// the merged entries.
//
// Existing entries are loaded with readChecksumEntries and kept. Entries for
// the requested paths are added or replaced. Requested paths are trimmed,
// cleaned and converted to forward slashes, and blank paths are skipped. The
// file is written as one "<hex digest> <path>" line per entry, sorted by path
// and followed by a trailing newline.
func updateChecksumFile(packDir string, checksumFile string, relativePaths []string) error {
	checksumPath := filepath.Join(packDir, checksumFile)
	digests, err := readChecksumEntries(checksumPath)
	if err != nil {
		return err
	}

	for _, requested := range relativePaths {
		key := filepath.ToSlash(filepath.Clean(strings.TrimSpace(requested)))
		if key == "" || key == "." {
			continue
		}
		content, readErr := os.ReadFile(filepath.Join(packDir, key))
		if readErr != nil {
			return fmt.Errorf("read checksum target %q: %w", key, readErr)
		}
		digest := sha256.Sum256(content)
		digests[key] = hex.EncodeToString(digest[:])
	}

	keys := make([]string, 0, len(digests))
	for key := range digests {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var out strings.Builder
	for i, key := range keys {
		if i > 0 {
			out.WriteByte('\n')
		}
		out.WriteString(digests[key])
		out.WriteByte(' ')
		out.WriteString(key)
	}
	// The trailing newline is written even when there are no entries.
	out.WriteByte('\n')

	if writeErr := os.WriteFile(checksumPath, []byte(out.String()), 0o644); writeErr != nil {
		return fmt.Errorf("write checksum file %q: %w", checksumPath, writeErr)
	}
	return nil
}
// readChecksumEntries parses the checksum file at path into a map from
// slash-separated relative file path to digest string.
//
// Each non-blank line must hold exactly two whitespace-separated fields,
// "<digest> <relative path>". Any other shape yields an error naming the line.
// Paths are normalized with filepath.Clean and converted to forward slashes,
// which is the normalization updateChecksumFile applies to the paths it
// refreshes. As a result, spellings such as "./pack.json" and "pack.json"
// collapse into one entry instead of leaving a stale duplicate digest behind.
// If a path appears more than once, the last line wins.
func readChecksumEntries(path string) (map[string]string, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("read checksum file %q: %w", path, err)
	}
	defer file.Close()

	entries := map[string]string{}
	scanner := bufio.NewScanner(file)
	lineNumber := 0
	for scanner.Scan() {
		lineNumber++
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		parts := strings.Fields(line)
		if len(parts) != 2 {
			return nil, fmt.Errorf("checksum file %q line %d: invalid format", path, lineNumber)
		}
		entries[filepath.ToSlash(filepath.Clean(parts[1]))] = parts[0]
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("scan checksum file %q: %w", path, err)
	}
	return entries, nil
}
// commitPackPublish stages relativePaths (relative to repoRoot), records a
// commit containing only those paths, and returns the abbreviated SHA of the
// new HEAD.
//
// The commit is path-limited (`git commit -- <paths>`). Unrelated changes that
// happen to be staged in the repository are therefore not swept into the
// publish commit. The "--" separator also keeps a path beginning with "-"
// from being parsed as an option. An empty path list is rejected, because a
// path-less commit would include whatever else is staged.
func commitPackPublish(ctx context.Context, repoRoot string, message string, relativePaths []string) (string, error) {
	if len(relativePaths) == 0 {
		return "", fmt.Errorf("commit pack publish: no paths to commit")
	}
	addArgs := append([]string{"add", "--"}, relativePaths...)
	if _, err := runGit(ctx, repoRoot, addArgs...); err != nil {
		return "", err
	}
	commitArgs := append([]string{"commit", "-m", message, "--"}, relativePaths...)
	if _, err := runGit(ctx, repoRoot, commitArgs...); err != nil {
		return "", err
	}
	sha, err := runGit(ctx, repoRoot, "rev-parse", "--short", "HEAD")
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(sha), nil
}
// runGit runs `git -C repoRoot <args...>` under ctx and returns its standard
// output with surrounding whitespace trimmed.
//
// Stdout and stderr are captured separately, so warnings git prints to stderr
// cannot leak into output the caller parses (such as the SHA printed by
// rev-parse). On failure the error includes both streams, because some git
// failures (for example "nothing to commit") are reported on stdout.
func runGit(ctx context.Context, repoRoot string, args ...string) (string, error) {
	cmd := exec.CommandContext(ctx, "git", append([]string{"-C", repoRoot}, args...)...)
	var stdout, stderr strings.Builder
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		details := strings.TrimSpace(strings.TrimSpace(stderr.String()) + "\n" + strings.TrimSpace(stdout.String()))
		return "", fmt.Errorf("run git %q: %w: %s", strings.Join(args, " "), err, details)
	}
	return strings.TrimSpace(stdout.String()), nil
}