test(project): achieve ≥70% package coverage across all internal packages
- store/sqlite: 75.4% (repos + db coverage) - host/sub2api: 80.8% (httptest mock server, pure function tests) - app: 74.2% (handler error paths, NewActionSet closures) - pack: 72.4% - provision: 75.2% - access: 77.3% - config: 94.7% (lookup mock tests) All tests pass: build, vet, race, coverage gates.
This commit is contained in:
266
internal/pack/extra_test.go
Normal file
266
internal/pack/extra_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateManifestRequiredFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
manifest Manifest
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "empty pack id",
|
||||
manifest: Manifest{
|
||||
Version: "1.0.0",
|
||||
TargetHost: "sub2api",
|
||||
ProvidersDir: "providers",
|
||||
ChecksumFile: "checksums.txt",
|
||||
},
|
||||
wantErr: "pack_id is required",
|
||||
},
|
||||
{
|
||||
name: "empty version",
|
||||
manifest: Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
TargetHost: "sub2api",
|
||||
ProvidersDir: "providers",
|
||||
ChecksumFile: "checksums.txt",
|
||||
},
|
||||
wantErr: "version is required",
|
||||
},
|
||||
{
|
||||
name: "empty target host",
|
||||
manifest: Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
Version: "1.0.0",
|
||||
ProvidersDir: "providers",
|
||||
ChecksumFile: "checksums.txt",
|
||||
},
|
||||
wantErr: "target_host is required",
|
||||
},
|
||||
{
|
||||
name: "empty providers dir",
|
||||
manifest: Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
Version: "1.0.0",
|
||||
TargetHost: "sub2api",
|
||||
ChecksumFile: "checksums.txt",
|
||||
},
|
||||
wantErr: "providers_dir is required",
|
||||
},
|
||||
{
|
||||
name: "empty checksum file",
|
||||
manifest: Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
Version: "1.0.0",
|
||||
TargetHost: "sub2api",
|
||||
ProvidersDir: "providers",
|
||||
},
|
||||
wantErr: "checksum_file is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateManifest(tt.manifest)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("validateManifest() error = %v, want substring %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProvidersRejectsInvalidProviderFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providers []ProviderManifest
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "empty provider id",
|
||||
providers: []ProviderManifest{{
|
||||
DisplayName: "DeepSeek",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
Platform: "openai",
|
||||
AccountType: "api",
|
||||
DefaultModels: []string{"deepseek-chat"},
|
||||
SmokeTestModel: "deepseek-chat",
|
||||
GroupTemplate: GroupTemplate{Name: "g"},
|
||||
ChannelTemplate: ChannelTemplate{
|
||||
Name: "c",
|
||||
ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"},
|
||||
},
|
||||
PlanTemplate: PlanTemplate{Name: "p", ValidityDays: 30},
|
||||
}},
|
||||
wantErr: "provider_id is required",
|
||||
},
|
||||
{
|
||||
name: "empty base url",
|
||||
providers: []ProviderManifest{{
|
||||
ProviderID: "deepseek",
|
||||
DisplayName: "DeepSeek",
|
||||
BaseURL: "",
|
||||
Platform: "openai",
|
||||
AccountType: "api",
|
||||
DefaultModels: []string{"deepseek-chat"},
|
||||
SmokeTestModel: "deepseek-chat",
|
||||
GroupTemplate: GroupTemplate{Name: "g"},
|
||||
ChannelTemplate: ChannelTemplate{
|
||||
Name: "c",
|
||||
ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"},
|
||||
},
|
||||
PlanTemplate: PlanTemplate{Name: "p", ValidityDays: 30},
|
||||
}},
|
||||
wantErr: "base_url must use https",
|
||||
},
|
||||
{
|
||||
name: "missing display name",
|
||||
providers: []ProviderManifest{{
|
||||
ProviderID: "deepseek",
|
||||
DisplayName: "",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
Platform: "openai",
|
||||
AccountType: "api",
|
||||
DefaultModels: []string{"deepseek-chat"},
|
||||
SmokeTestModel: "deepseek-chat",
|
||||
GroupTemplate: GroupTemplate{Name: "g"},
|
||||
ChannelTemplate: ChannelTemplate{
|
||||
Name: "c",
|
||||
ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"},
|
||||
},
|
||||
PlanTemplate: PlanTemplate{Name: "p", ValidityDays: 30},
|
||||
}},
|
||||
wantErr: "display_name is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateProviders(tt.providers)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("validateProviders() error = %v, want substring %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractZipToTempRejectsEmptyArchive(t *testing.T) {
|
||||
archivePath := filepath.Join(t.TempDir(), "empty.zip")
|
||||
file, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatalf("os.Create() error = %v", err)
|
||||
}
|
||||
writer := zip.NewWriter(file)
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("writer.Close() error = %v", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
t.Fatalf("file.Close() error = %v", err)
|
||||
}
|
||||
|
||||
_, cleanup, err := extractZipToTemp(archivePath)
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if err == nil || !strings.Contains(err.Error(), "pack archive is empty") {
|
||||
t.Fatalf("extractZipToTemp() error = %v, want empty archive error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractZipToTempRejectsPathTraversal(t *testing.T) {
|
||||
archivePath := filepath.Join(t.TempDir(), "traversal.zip")
|
||||
file, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatalf("os.Create() error = %v", err)
|
||||
}
|
||||
writer := zip.NewWriter(file)
|
||||
entry, err := writer.Create("../../../evil.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("writer.Create() error = %v", err)
|
||||
}
|
||||
if _, err := entry.Write([]byte("evil")); err != nil {
|
||||
t.Fatalf("entry.Write() error = %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("writer.Close() error = %v", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
t.Fatalf("file.Close() error = %v", err)
|
||||
}
|
||||
|
||||
_, cleanup, err := extractZipToTemp(archivePath)
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid path") {
|
||||
t.Fatalf("extractZipToTemp() error = %v, want invalid path error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesMaxConstraintCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostVersion string
|
||||
maxVersion string
|
||||
want bool
|
||||
}{
|
||||
{name: "exact version match", hostVersion: "1.2.3", maxVersion: "1.2.3", want: true},
|
||||
{name: "wildcard x accepts same minor", hostVersion: "0.2.9", maxVersion: "0.2.x", want: true},
|
||||
{name: "non matching version", hostVersion: "1.2.4", maxVersion: "1.2.3", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := matchesMaxConstraint(tt.hostVersion, tt.maxVersion)
|
||||
if err != nil {
|
||||
t.Fatalf("matchesMaxConstraint() error = %v", err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("matchesMaxConstraint() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesMaxConstraintRejectsWildcardStar(t *testing.T) {
|
||||
_, err := matchesMaxConstraint("1.2.3", "1.2.*")
|
||||
if err == nil || !strings.Contains(err.Error(), `parse version segment "*"`) {
|
||||
t.Fatalf("matchesMaxConstraint() error = %v, want parse failure for wildcard star", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPathRejectsEmptyAndMissingPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantErr string
|
||||
}{
|
||||
{name: "empty path", path: " ", wantErr: "pack path is required"},
|
||||
{name: "missing path", path: filepath.Join(t.TempDir(), "missing-pack"), wantErr: "stat pack path"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadPath(tt.path)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("LoadPath() error = %v, want substring %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadArchiveRejectsNonZipFile(t *testing.T) {
|
||||
filePath := filepath.Join(t.TempDir(), "not-a-zip.zip")
|
||||
mustWrite(t, filePath, "plain text, not a zip archive")
|
||||
|
||||
_, err := LoadArchive(filePath)
|
||||
if err == nil || !strings.Contains(err.Error(), "open pack archive") {
|
||||
t.Fatalf("LoadArchive() error = %v, want open archive error", err)
|
||||
}
|
||||
}
|
||||
249
internal/pack/loader.go
Normal file
249
internal/pack/loader.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Manifest struct {
|
||||
PackID string `json:"pack_id"`
|
||||
Version string `json:"version"`
|
||||
Vendor string `json:"vendor"`
|
||||
TargetHost string `json:"target_host"`
|
||||
MinHostVersion string `json:"min_host_version"`
|
||||
MaxHostVersion string `json:"max_host_version"`
|
||||
ProvidersDir string `json:"providers_dir"`
|
||||
ChecksumFile string `json:"checksum_file"`
|
||||
}
|
||||
|
||||
type ProviderManifest struct {
|
||||
ProviderID string `json:"provider_id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Platform string `json:"platform"`
|
||||
AccountType string `json:"account_type"`
|
||||
DefaultModels []string `json:"default_models"`
|
||||
SmokeTestModel string `json:"smoke_test_model"`
|
||||
GroupTemplate GroupTemplate `json:"group_template"`
|
||||
ChannelTemplate ChannelTemplate `json:"channel_template"`
|
||||
PlanTemplate PlanTemplate `json:"plan_template"`
|
||||
Import ImportOptions `json:"import"`
|
||||
}
|
||||
|
||||
type GroupTemplate struct {
|
||||
Name string `json:"name"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
}
|
||||
|
||||
type ChannelTemplate struct {
|
||||
Name string `json:"name"`
|
||||
ModelMapping map[string]string `json:"model_mapping"`
|
||||
}
|
||||
|
||||
type PlanTemplate struct {
|
||||
Name string `json:"name"`
|
||||
Price float64 `json:"price"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
ValidityUnit string `json:"validity_unit"`
|
||||
}
|
||||
|
||||
type ImportOptions struct {
|
||||
SupportsMultiKey bool `json:"supports_multi_key"`
|
||||
SupportsStrict bool `json:"supports_strict"`
|
||||
SupportsPartial bool `json:"supports_partial"`
|
||||
}
|
||||
|
||||
type LoadedPack struct {
|
||||
Dir string
|
||||
Manifest Manifest
|
||||
Providers []ProviderManifest
|
||||
Checksum string
|
||||
}
|
||||
|
||||
func LoadDir(dir string) (LoadedPack, error) {
|
||||
root := strings.TrimSpace(dir)
|
||||
if root == "" {
|
||||
return LoadedPack{}, fmt.Errorf("pack dir is required")
|
||||
}
|
||||
|
||||
manifestPath := filepath.Join(root, "pack.json")
|
||||
manifestBytes, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return LoadedPack{}, fmt.Errorf("read pack.json: %w", err)
|
||||
}
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
|
||||
return LoadedPack{}, fmt.Errorf("decode pack.json: %w", err)
|
||||
}
|
||||
if err := validateManifest(manifest); err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
|
||||
if err := validateChecksums(root, manifest.ChecksumFile); err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
|
||||
providers, err := loadProviders(root, manifest.ProvidersDir)
|
||||
if err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return LoadedPack{}, fmt.Errorf("providers dir %q does not contain provider manifests", manifest.ProvidersDir)
|
||||
}
|
||||
if err := validateProviders(providers); err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
|
||||
checksum, err := computeAggregateChecksum(root, manifest.ChecksumFile)
|
||||
if err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
|
||||
return LoadedPack{Dir: root, Manifest: manifest, Providers: providers, Checksum: checksum}, nil
|
||||
}
|
||||
|
||||
func validateManifest(manifest Manifest) error {
|
||||
switch {
|
||||
case strings.TrimSpace(manifest.PackID) == "":
|
||||
return fmt.Errorf("pack.json: pack_id is required")
|
||||
case strings.TrimSpace(manifest.Version) == "":
|
||||
return fmt.Errorf("pack.json: version is required")
|
||||
case strings.TrimSpace(manifest.TargetHost) == "":
|
||||
return fmt.Errorf("pack.json: target_host is required")
|
||||
case strings.TrimSpace(manifest.ProvidersDir) == "":
|
||||
return fmt.Errorf("pack.json: providers_dir is required")
|
||||
case strings.TrimSpace(manifest.ChecksumFile) == "":
|
||||
return fmt.Errorf("pack.json: checksum_file is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadProviders(root string, providersDir string) ([]ProviderManifest, error) {
|
||||
dir := filepath.Join(root, providersDir)
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read providers dir %q: %w", providersDir, err)
|
||||
}
|
||||
|
||||
providers := make([]ProviderManifest, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read provider %q: %w", entry.Name(), err)
|
||||
}
|
||||
var provider ProviderManifest
|
||||
if err := json.Unmarshal(body, &provider); err != nil {
|
||||
return nil, fmt.Errorf("decode provider %q: %w", entry.Name(), err)
|
||||
}
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
sort.Slice(providers, func(i, j int) bool { return providers[i].ProviderID < providers[j].ProviderID })
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func validateProviders(providers []ProviderManifest) error {
|
||||
seen := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
providerID := strings.TrimSpace(provider.ProviderID)
|
||||
switch {
|
||||
case providerID == "":
|
||||
return fmt.Errorf("provider manifest: provider_id is required")
|
||||
case strings.TrimSpace(provider.DisplayName) == "":
|
||||
return fmt.Errorf("provider %q: display_name is required", providerID)
|
||||
case !strings.HasPrefix(strings.TrimSpace(provider.BaseURL), "https://"):
|
||||
return fmt.Errorf("provider %q: base_url must use https", providerID)
|
||||
case strings.TrimSpace(provider.Platform) == "":
|
||||
return fmt.Errorf("provider %q: platform is required", providerID)
|
||||
case strings.TrimSpace(provider.AccountType) == "":
|
||||
return fmt.Errorf("provider %q: account_type is required", providerID)
|
||||
case len(provider.DefaultModels) == 0:
|
||||
return fmt.Errorf("provider %q: default_models must not be empty", providerID)
|
||||
case strings.TrimSpace(provider.SmokeTestModel) == "":
|
||||
return fmt.Errorf("provider %q: smoke_test_model is required", providerID)
|
||||
case !contains(provider.DefaultModels, provider.SmokeTestModel):
|
||||
return fmt.Errorf("provider %q: smoke_test_model must be present in default_models", providerID)
|
||||
case strings.TrimSpace(provider.GroupTemplate.Name) == "":
|
||||
return fmt.Errorf("provider %q: group_template.name is required", providerID)
|
||||
case strings.TrimSpace(provider.ChannelTemplate.Name) == "":
|
||||
return fmt.Errorf("provider %q: channel_template.name is required", providerID)
|
||||
case len(provider.ChannelTemplate.ModelMapping) == 0:
|
||||
return fmt.Errorf("provider %q: channel_template.model_mapping must not be empty", providerID)
|
||||
case strings.TrimSpace(provider.PlanTemplate.Name) == "":
|
||||
return fmt.Errorf("provider %q: plan_template.name is required", providerID)
|
||||
case provider.PlanTemplate.ValidityDays <= 0:
|
||||
return fmt.Errorf("provider %q: plan_template.validity_days must be positive", providerID)
|
||||
}
|
||||
if _, ok := seen[providerID]; ok {
|
||||
return fmt.Errorf("duplicate provider_id %q", providerID)
|
||||
}
|
||||
seen[providerID] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateChecksums(root string, checksumFile string) error {
|
||||
path := filepath.Join(root, checksumFile)
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read checksum file %q: %w", checksumFile, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNumber := 0
|
||||
for scanner.Scan() {
|
||||
lineNumber++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("checksum file %q line %d: invalid format", checksumFile, lineNumber)
|
||||
}
|
||||
relativePath := parts[1]
|
||||
body, err := os.ReadFile(filepath.Join(root, relativePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("checksum file %q line %d: read %q: %w", checksumFile, lineNumber, relativePath, err)
|
||||
}
|
||||
sum := sha256.Sum256(body)
|
||||
actual := hex.EncodeToString(sum[:])
|
||||
if !strings.EqualFold(parts[0], actual) {
|
||||
return fmt.Errorf("checksum mismatch for %s", relativePath)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("scan checksum file %q: %w", checksumFile, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeAggregateChecksum(root string, checksumFile string) (string, error) {
|
||||
body, err := os.ReadFile(filepath.Join(root, checksumFile))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read checksum file %q: %w", checksumFile, err)
|
||||
}
|
||||
sum := sha256.Sum256(body)
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func contains(items []string, target string) bool {
|
||||
for _, item := range items {
|
||||
if strings.TrimSpace(item) == strings.TrimSpace(target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
108
internal/pack/loader_test.go
Normal file
108
internal/pack/loader_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadDirParsesAndValidatesPack(t *testing.T) {
|
||||
packDir := createPackFixture(t, map[string]string{
|
||||
"pack.json": `{
|
||||
"pack_id": "openai-cn-pack",
|
||||
"version": "1.0.0",
|
||||
"vendor": "YourTeam",
|
||||
"target_host": "sub2api",
|
||||
"min_host_version": "0.1.126",
|
||||
"max_host_version": "0.2.x",
|
||||
"providers_dir": "providers",
|
||||
"checksum_file": "checksums.txt"
|
||||
}`,
|
||||
"providers/deepseek.json": `{
|
||||
"provider_id": "deepseek",
|
||||
"display_name": "DeepSeek OpenAI Compatible",
|
||||
"base_url": "https://api.deepseek.com",
|
||||
"platform": "openai",
|
||||
"account_type": "api",
|
||||
"default_models": ["deepseek-chat", "deepseek-reasoner"],
|
||||
"smoke_test_model": "deepseek-chat",
|
||||
"group_template": {"name": "DeepSeek 默认分组", "rate_multiplier": 1.0},
|
||||
"channel_template": {"name": "DeepSeek 默认渠道", "model_mapping": {"deepseek-chat": "deepseek-chat"}},
|
||||
"plan_template": {"name": "DeepSeek 默认套餐", "price": 19.9, "validity_days": 30, "validity_unit": "day"},
|
||||
"import": {"supports_multi_key": true, "supports_strict": true, "supports_partial": true}
|
||||
}`,
|
||||
})
|
||||
|
||||
loaded, err := LoadDir(packDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadDir() error = %v", err)
|
||||
}
|
||||
|
||||
if loaded.Manifest.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("PackID = %q, want %q", loaded.Manifest.PackID, "openai-cn-pack")
|
||||
}
|
||||
if len(loaded.Providers) != 1 {
|
||||
t.Fatalf("len(Providers) = %d, want 1", len(loaded.Providers))
|
||||
}
|
||||
if loaded.Providers[0].ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want %q", loaded.Providers[0].ProviderID, "deepseek")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDirRejectsChecksumMismatch(t *testing.T) {
|
||||
packDir := t.TempDir()
|
||||
mustWrite(t, filepath.Join(packDir, "pack.json"), `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`)
|
||||
mustWrite(t, filepath.Join(packDir, "providers", "deepseek.json"), `{"provider_id":"deepseek","display_name":"DeepSeek","base_url":"https://api.deepseek.com","platform":"openai","account_type":"api","default_models":["deepseek-chat"],"smoke_test_model":"deepseek-chat","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"deepseek-chat":"deepseek-chat"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`)
|
||||
mustWrite(t, filepath.Join(packDir, "checksums.txt"), "deadbeef pack.json\ndeadbeef providers/deepseek.json\n")
|
||||
|
||||
_, err := LoadDir(packDir)
|
||||
if err == nil {
|
||||
t.Fatal("LoadDir() error = nil, want checksum mismatch")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "checksum mismatch") {
|
||||
t.Fatalf("LoadDir() error = %v, want checksum mismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDirRejectsInvalidProviderSchema(t *testing.T) {
|
||||
packDir := createPackFixture(t, map[string]string{
|
||||
"pack.json": `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`,
|
||||
"providers/deepseek.json": `{"provider_id":"deepseek","display_name":"DeepSeek","base_url":"http://insecure.example.com","platform":"openai","account_type":"api","default_models":["deepseek-chat"],"smoke_test_model":"missing-model","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"deepseek-chat":"deepseek-chat"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`,
|
||||
})
|
||||
|
||||
_, err := LoadDir(packDir)
|
||||
if err == nil {
|
||||
t.Fatal("LoadDir() error = nil, want schema validation failure")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "https") && !strings.Contains(err.Error(), "smoke_test_model") {
|
||||
t.Fatalf("LoadDir() error = %v, want schema validation detail", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createPackFixture(t *testing.T, files map[string]string) string {
|
||||
t.Helper()
|
||||
|
||||
packDir := t.TempDir()
|
||||
var lines []string
|
||||
for relativePath, content := range files {
|
||||
absolutePath := filepath.Join(packDir, relativePath)
|
||||
mustWrite(t, absolutePath, content)
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
lines = append(lines, hex.EncodeToString(sum[:])+" "+relativePath)
|
||||
}
|
||||
mustWrite(t, filepath.Join(packDir, "checksums.txt"), strings.Join(lines, "\n")+"\n")
|
||||
return packDir
|
||||
}
|
||||
|
||||
func mustWrite(t *testing.T, path string, content string) {
|
||||
t.Helper()
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(%q) error = %v", path, err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(%q) error = %v", path, err)
|
||||
}
|
||||
}
|
||||
171
internal/pack/source_loader.go
Normal file
171
internal/pack/source_loader.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
maxArchiveEntries = 256
|
||||
maxArchiveFileSize = 5 << 20
|
||||
maxArchiveTotalSize = 20 << 20
|
||||
)
|
||||
|
||||
func LoadPath(path string) (LoadedPack, error) {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
if trimmed == "" {
|
||||
return LoadedPack{}, fmt.Errorf("pack path is required")
|
||||
}
|
||||
|
||||
info, err := os.Stat(trimmed)
|
||||
if err != nil {
|
||||
return LoadedPack{}, fmt.Errorf("stat pack path: %w", err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return LoadDir(trimmed)
|
||||
}
|
||||
if strings.EqualFold(filepath.Ext(info.Name()), ".zip") {
|
||||
return LoadArchive(trimmed)
|
||||
}
|
||||
return LoadedPack{}, fmt.Errorf("pack path %q must be a directory or .zip archive", trimmed)
|
||||
}
|
||||
|
||||
func LoadArchive(path string) (LoadedPack, error) {
|
||||
root, cleanup, err := extractZipToTemp(path)
|
||||
if err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
loaded, err := LoadDir(root)
|
||||
if err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
loaded.Dir = strings.TrimSpace(path)
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
func extractZipToTemp(path string) (string, func(), error) {
|
||||
reader, err := zip.OpenReader(strings.TrimSpace(path))
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("open pack archive: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if len(reader.File) == 0 {
|
||||
return "", nil, fmt.Errorf("pack archive is empty")
|
||||
}
|
||||
if len(reader.File) > maxArchiveEntries {
|
||||
return "", nil, fmt.Errorf("pack archive has too many entries: %d", len(reader.File))
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "relay-pack-*")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("create temp dir for pack archive: %w", err)
|
||||
}
|
||||
cleanup := func() { _ = os.RemoveAll(tempDir) }
|
||||
|
||||
var totalSize uint64
|
||||
for _, file := range reader.File {
|
||||
cleanName := filepath.Clean(file.Name)
|
||||
if cleanName == "." || cleanName == "" {
|
||||
continue
|
||||
}
|
||||
if filepath.IsAbs(cleanName) || cleanName == ".." || strings.HasPrefix(cleanName, ".."+string(filepath.Separator)) {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive contains invalid path %q", file.Name)
|
||||
}
|
||||
if file.FileInfo().Mode()&os.ModeSymlink != 0 {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive contains unsupported symlink entry %q", file.Name)
|
||||
}
|
||||
if file.UncompressedSize64 > maxArchiveFileSize {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive entry %q exceeds size limit", file.Name)
|
||||
}
|
||||
totalSize += file.UncompressedSize64
|
||||
if totalSize > maxArchiveTotalSize {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive exceeds total size limit")
|
||||
}
|
||||
|
||||
targetPath := filepath.Join(tempDir, cleanName)
|
||||
relativeTarget, err := filepath.Rel(tempDir, targetPath)
|
||||
if err != nil || relativeTarget == ".." || strings.HasPrefix(relativeTarget, ".."+string(filepath.Separator)) {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive entry %q escapes extraction root", file.Name)
|
||||
}
|
||||
|
||||
if file.FileInfo().IsDir() {
|
||||
if err := os.MkdirAll(targetPath, 0o755); err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("create archive dir %q: %w", file.Name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("create archive parent dir %q: %w", file.Name, err)
|
||||
}
|
||||
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("open archive entry %q: %w", file.Name, err)
|
||||
}
|
||||
dst, err := os.OpenFile(targetPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
src.Close()
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("create archive file %q: %w", file.Name, err)
|
||||
}
|
||||
_, copyErr := io.Copy(dst, src)
|
||||
closeErr := dst.Close()
|
||||
srcErr := src.Close()
|
||||
if copyErr != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("extract archive entry %q: %w", file.Name, copyErr)
|
||||
}
|
||||
if closeErr != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("close archive file %q: %w", file.Name, closeErr)
|
||||
}
|
||||
if srcErr != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("close archive entry %q: %w", file.Name, srcErr)
|
||||
}
|
||||
}
|
||||
|
||||
root, err := resolvePackRoot(tempDir)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return "", nil, err
|
||||
}
|
||||
return root, cleanup, nil
|
||||
}
|
||||
|
||||
func resolvePackRoot(extractDir string) (string, error) {
|
||||
manifestPath := filepath.Join(extractDir, "pack.json")
|
||||
if _, err := os.Stat(manifestPath); err == nil {
|
||||
return extractDir, nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(extractDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read extracted archive root: %w", err)
|
||||
}
|
||||
if len(entries) != 1 || !entries[0].IsDir() {
|
||||
return "", fmt.Errorf("pack archive must contain pack.json at root or a single top-level directory")
|
||||
}
|
||||
|
||||
root := filepath.Join(extractDir, entries[0].Name())
|
||||
if _, err := os.Stat(filepath.Join(root, "pack.json")); err != nil {
|
||||
return "", fmt.Errorf("pack archive root does not contain pack.json")
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
70
internal/pack/source_loader_test.go
Normal file
70
internal/pack/source_loader_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadPathSupportsDirectory(t *testing.T) {
|
||||
loaded, err := LoadPath(filepath.Join("..", "..", "packs", "openai-cn-pack"))
|
||||
if err != nil {
|
||||
t.Fatalf("LoadPath(dir) error = %v", err)
|
||||
}
|
||||
if loaded.Manifest.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("PackID = %q, want %q", loaded.Manifest.PackID, "openai-cn-pack")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPathSupportsZipArchive(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
archivePath := filepath.Join(tempDir, "openai-cn-pack.zip")
|
||||
writePackArchive(t, archivePath)
|
||||
|
||||
loaded, err := LoadPath(archivePath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadPath(zip) error = %v", err)
|
||||
}
|
||||
if loaded.Manifest.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("PackID = %q, want %q", loaded.Manifest.PackID, "openai-cn-pack")
|
||||
}
|
||||
if len(loaded.Providers) == 0 {
|
||||
t.Fatal("Providers = 0, want parsed providers from archive")
|
||||
}
|
||||
}
|
||||
|
||||
func writePackArchive(t *testing.T, archivePath string) {
|
||||
t.Helper()
|
||||
file, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatalf("os.Create() error = %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
writer := zip.NewWriter(file)
|
||||
defer writer.Close()
|
||||
|
||||
sourceRoot := filepath.Join("..", "..", "packs", "openai-cn-pack")
|
||||
files := []string{
|
||||
"pack.json",
|
||||
"checksums.txt",
|
||||
filepath.Join("providers", "deepseek.json"),
|
||||
}
|
||||
for _, relativePath := range files {
|
||||
body, err := os.ReadFile(filepath.Join(sourceRoot, relativePath))
|
||||
if err != nil {
|
||||
t.Fatalf("os.ReadFile(%q) error = %v", relativePath, err)
|
||||
}
|
||||
entry, err := writer.Create(filepath.ToSlash(filepath.Join("openai-cn-pack", relativePath)))
|
||||
if err != nil {
|
||||
t.Fatalf("Create(%q) error = %v", relativePath, err)
|
||||
}
|
||||
if _, err := entry.Write(body); err != nil {
|
||||
t.Fatalf("Write(%q) error = %v", relativePath, err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("Close archive writer: %v", err)
|
||||
}
|
||||
}
|
||||
134
internal/pack/version.go
Normal file
134
internal/pack/version.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func CheckHostCompatibility(manifest Manifest, hostVersion string) error {
|
||||
targetHost := strings.TrimSpace(manifest.TargetHost)
|
||||
if targetHost == "" {
|
||||
return fmt.Errorf("pack manifest target_host is required")
|
||||
}
|
||||
if targetHost != "sub2api" {
|
||||
return fmt.Errorf("pack target_host %q is not supported", targetHost)
|
||||
}
|
||||
|
||||
normalizedHost, err := parseVersion(strings.TrimSpace(hostVersion))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse host version %q: %w", hostVersion, err)
|
||||
}
|
||||
minVersion := strings.TrimSpace(manifest.MinHostVersion)
|
||||
if minVersion != "" {
|
||||
cmp, err := compareVersions(normalizedHost.raw, minVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare min_host_version: %w", err)
|
||||
}
|
||||
if cmp < 0 {
|
||||
return fmt.Errorf("host version %q is below min_host_version %q", hostVersion, minVersion)
|
||||
}
|
||||
}
|
||||
|
||||
maxVersion := strings.TrimSpace(manifest.MaxHostVersion)
|
||||
if maxVersion != "" {
|
||||
ok, err := matchesMaxConstraint(normalizedHost.raw, maxVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare max_host_version: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("host version %q is above max_host_version %q", hostVersion, maxVersion)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type parsedVersion struct {
|
||||
raw string
|
||||
parts [3]int
|
||||
}
|
||||
|
||||
func compareVersions(a, b string) (int, error) {
|
||||
left, err := parseVersion(a)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
right, err := parseVersion(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for i := 0; i < len(left.parts); i++ {
|
||||
if left.parts[i] < right.parts[i] {
|
||||
return -1, nil
|
||||
}
|
||||
if left.parts[i] > right.parts[i] {
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func matchesMaxConstraint(hostVersion, maxVersion string) (bool, error) {
|
||||
normalizedMax := normalizeVersion(maxVersion)
|
||||
if strings.HasSuffix(normalizedMax, ".x") {
|
||||
prefix := strings.TrimSuffix(normalizedMax, ".x")
|
||||
parts := strings.Split(prefix, ".")
|
||||
if len(parts) != 2 {
|
||||
return false, fmt.Errorf("wildcard max version %q must be in N.N.x format", maxVersion)
|
||||
}
|
||||
host, err := parseVersion(hostVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
major, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse major version %q: %w", parts[0], err)
|
||||
}
|
||||
minor, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse minor version %q: %w", parts[1], err)
|
||||
}
|
||||
if host.parts[0] < major {
|
||||
return true, nil
|
||||
}
|
||||
if host.parts[0] > major {
|
||||
return false, nil
|
||||
}
|
||||
return host.parts[1] <= minor, nil
|
||||
}
|
||||
|
||||
cmp, err := compareVersions(hostVersion, maxVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cmp <= 0, nil
|
||||
}
|
||||
|
||||
func parseVersion(value string) (parsedVersion, error) {
|
||||
normalized := normalizeVersion(value)
|
||||
if normalized == "" {
|
||||
return parsedVersion{}, fmt.Errorf("version is required")
|
||||
}
|
||||
parts := strings.Split(normalized, ".")
|
||||
if len(parts) != 3 {
|
||||
return parsedVersion{}, fmt.Errorf("version %q must be in N.N.N format", value)
|
||||
}
|
||||
|
||||
var parsed parsedVersion
|
||||
parsed.raw = normalized
|
||||
for i, part := range parts {
|
||||
number, err := strconv.Atoi(part)
|
||||
if err != nil {
|
||||
return parsedVersion{}, fmt.Errorf("parse version segment %q: %w", part, err)
|
||||
}
|
||||
parsed.parts[i] = number
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func normalizeVersion(value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
trimmed = strings.TrimPrefix(trimmed, "v")
|
||||
trimmed = strings.TrimPrefix(trimmed, "V")
|
||||
return trimmed
|
||||
}
|
||||
32
internal/pack/version_test.go
Normal file
32
internal/pack/version_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package pack
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCheckHostCompatibilityAcceptsRange(t *testing.T) {
|
||||
manifest := Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
TargetHost: "sub2api",
|
||||
MinHostVersion: "0.1.126",
|
||||
MaxHostVersion: "0.2.x",
|
||||
}
|
||||
if err := CheckHostCompatibility(manifest, "0.1.126"); err != nil {
|
||||
t.Fatalf("CheckHostCompatibility() error = %v, want nil", err)
|
||||
}
|
||||
if err := CheckHostCompatibility(manifest, "0.2.9"); err != nil {
|
||||
t.Fatalf("CheckHostCompatibility() error = %v, want nil for wildcard upper bound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckHostCompatibilityRejectsBelowMinimum(t *testing.T) {
|
||||
manifest := Manifest{TargetHost: "sub2api", MinHostVersion: "0.1.126"}
|
||||
if err := CheckHostCompatibility(manifest, "0.1.125"); err == nil {
|
||||
t.Fatal("CheckHostCompatibility() error = nil, want min version failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckHostCompatibilityRejectsDifferentMaxMinor(t *testing.T) {
|
||||
manifest := Manifest{TargetHost: "sub2api", MaxHostVersion: "0.2.x"}
|
||||
if err := CheckHostCompatibility(manifest, "0.3.0"); err == nil {
|
||||
t.Fatal("CheckHostCompatibility() error = nil, want max version failure")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user