This commit is contained in:
129
internal/collectors/collector.go
Normal file
129
internal/collectors/collector.go
Normal file
@@ -0,0 +1,129 @@
|
||||
// internal/collectors/collector.go
|
||||
// Collector 接口定义:所有数据源采集器的统一抽象
|
||||
package collectors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Result 采集结果
|
||||
type Result struct {
|
||||
Models []ModelInfo
|
||||
Meta CollectionMeta
|
||||
}
|
||||
|
||||
// CollectionMeta 采集元信息
|
||||
type CollectionMeta struct {
|
||||
Source string
|
||||
Count int
|
||||
Duration time.Duration
|
||||
Timestamp time.Time
|
||||
BatchID string
|
||||
CollectorVersion string
|
||||
}
|
||||
|
||||
// ModelInfo 标准模型信息(与 fetch_openrouter.go 兼容)
|
||||
type ModelInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Provider string
|
||||
ProviderID string
|
||||
Version string
|
||||
Modality string
|
||||
ContextLength int
|
||||
Capabilities []string
|
||||
Pricing ModelPricing
|
||||
Description string
|
||||
IsFree bool
|
||||
SourceURL string
|
||||
}
|
||||
|
||||
// ModelPricing 标准定价信息
|
||||
type ModelPricing struct {
|
||||
Input float64
|
||||
Output float64
|
||||
}
|
||||
|
||||
// Collector 采集器接口
|
||||
type Collector interface {
|
||||
// Name 返回采集器名称
|
||||
Name() string
|
||||
|
||||
// Collect 执行采集,返回标准模型列表
|
||||
Collect(ctx context.Context) (Result, error)
|
||||
|
||||
// Schedule 返回推荐调度周期(如 "0 8 * * *")
|
||||
Schedule() string
|
||||
|
||||
// Timeout 返回单次采集超时时间
|
||||
Timeout() time.Duration
|
||||
|
||||
// RetryCount 返回最大重试次数
|
||||
RetryCount() int
|
||||
}
|
||||
|
||||
// BaseCollector 提供默认实现的嵌入类型
|
||||
type BaseCollector struct {
|
||||
name string
|
||||
schedule string
|
||||
timeout time.Duration
|
||||
retryCount int
|
||||
version string
|
||||
}
|
||||
|
||||
func (b *BaseCollector) Name() string { return b.name }
|
||||
func (b *BaseCollector) Schedule() string { return b.schedule }
|
||||
func (b *BaseCollector) Timeout() time.Duration { return b.timeout }
|
||||
func (b *BaseCollector) RetryCount() int { return b.retryCount }
|
||||
func (b *BaseCollector) Version() string { return b.version }
|
||||
|
||||
// NewBaseCollector 创建基础采集器配置
|
||||
func NewBaseCollector(name, schedule string, timeout time.Duration, retry int, version string) BaseCollector {
|
||||
return BaseCollector{
|
||||
name: name,
|
||||
schedule: schedule,
|
||||
timeout: timeout,
|
||||
retryCount: retry,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
// CollectorRegistry 采集器注册表
|
||||
type CollectorRegistry struct {
|
||||
collectors map[string]Collector
|
||||
}
|
||||
|
||||
// NewRegistry 创建采集器注册表
|
||||
func NewRegistry() *CollectorRegistry {
|
||||
return &CollectorRegistry{collectors: make(map[string]Collector)}
|
||||
}
|
||||
|
||||
// Register 注册采集器
|
||||
func (r *CollectorRegistry) Register(c Collector) {
|
||||
r.collectors[c.Name()] = c
|
||||
}
|
||||
|
||||
// Get 获取采集器
|
||||
func (r *CollectorRegistry) Get(name string) (Collector, bool) {
|
||||
c, ok := r.collectors[name]
|
||||
return c, ok
|
||||
}
|
||||
|
||||
// All 返回所有已注册采集器
|
||||
func (r *CollectorRegistry) All() []Collector {
|
||||
cs := make([]Collector, 0, len(r.collectors))
|
||||
for _, c := range r.collectors {
|
||||
cs = append(cs, c)
|
||||
}
|
||||
return cs
|
||||
}
|
||||
|
||||
// Names 返回所有已注册采集器名称
|
||||
func (r *CollectorRegistry) Names() []string {
|
||||
names := make([]string, 0, len(r.collectors))
|
||||
for n := range r.collectors {
|
||||
names = append(names, n)
|
||||
}
|
||||
return names
|
||||
}
|
||||
127
internal/collectors/collector_test.go
Normal file
127
internal/collectors/collector_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
// internal/collectors/collector_test.go
|
||||
package collectors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockCollector 用于测试的模拟采集器
|
||||
type mockCollector struct {
|
||||
BaseCollector
|
||||
collectFunc func(ctx context.Context) (Result, error)
|
||||
}
|
||||
|
||||
func (m *mockCollector) Collect(ctx context.Context) (Result, error) {
|
||||
return m.collectFunc(ctx)
|
||||
}
|
||||
|
||||
func TestCollectorInterface(t *testing.T) {
|
||||
c := &mockCollector{
|
||||
BaseCollector: NewBaseCollector("test", "0 8 * * *", 30*time.Second, 3, "v1.0"),
|
||||
collectFunc: func(ctx context.Context) (Result, error) {
|
||||
return Result{
|
||||
Models: []ModelInfo{{ID: "test/model-1", Name: "Test Model"}},
|
||||
Meta: CollectionMeta{Source: "test", Count: 1},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// 测试接口方法
|
||||
if c.Name() != "test" {
|
||||
t.Errorf("Name() = %q, want %q", c.Name(), "test")
|
||||
}
|
||||
if c.Schedule() != "0 8 * * *" {
|
||||
t.Errorf("Schedule() = %q, want %q", c.Schedule(), "0 8 * * *")
|
||||
}
|
||||
if c.Timeout() != 30*time.Second {
|
||||
t.Errorf("Timeout() = %v, want %v", c.Timeout(), 30*time.Second)
|
||||
}
|
||||
if c.RetryCount() != 3 {
|
||||
t.Errorf("RetryCount() = %d, want %d", c.RetryCount(), 3)
|
||||
}
|
||||
|
||||
// 测试 Collect
|
||||
ctx := context.Background()
|
||||
result, err := c.Collect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Collect() error = %v", err)
|
||||
}
|
||||
if len(result.Models) != 1 {
|
||||
t.Errorf("len(Models) = %d, want 1", len(result.Models))
|
||||
}
|
||||
if result.Meta.Count != 1 {
|
||||
t.Errorf("Meta.Count = %d, want 1", result.Meta.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectorRegistry(t *testing.T) {
|
||||
reg := NewRegistry()
|
||||
|
||||
c1 := &mockCollector{
|
||||
BaseCollector: NewBaseCollector("openrouter", "0 8 * * *", 30*time.Second, 3, "v1.0"),
|
||||
collectFunc: func(ctx context.Context) (Result, error) { return Result{}, nil },
|
||||
}
|
||||
c2 := &mockCollector{
|
||||
BaseCollector: NewBaseCollector("siliconflow", "0 9 * * *", 30*time.Second, 3, "v1.0"),
|
||||
collectFunc: func(ctx context.Context) (Result, error) { return Result{}, nil },
|
||||
}
|
||||
|
||||
reg.Register(c1)
|
||||
reg.Register(c2)
|
||||
|
||||
// 测试 Get
|
||||
got, ok := reg.Get("openrouter")
|
||||
if !ok {
|
||||
t.Fatal("Get(openrouter) not found")
|
||||
}
|
||||
if got.Name() != "openrouter" {
|
||||
t.Errorf("Get() Name = %q, want %q", got.Name(), "openrouter")
|
||||
}
|
||||
|
||||
// 测试 Names
|
||||
names := reg.Names()
|
||||
if len(names) != 2 {
|
||||
t.Errorf("Names() len = %d, want 2", len(names))
|
||||
}
|
||||
|
||||
// 测试 All
|
||||
all := reg.All()
|
||||
if len(all) != 2 {
|
||||
t.Errorf("All() len = %d, want 2", len(all))
|
||||
}
|
||||
|
||||
// 测试不存在的采集器
|
||||
_, ok = reg.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("Get(nonexistent) should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectorTimeout(t *testing.T) {
|
||||
c := &mockCollector{
|
||||
BaseCollector: NewBaseCollector("slow", "0 8 * * *", 100*time.Millisecond, 0, "v1.0"),
|
||||
collectFunc: func(ctx context.Context) (Result, error) {
|
||||
// 模拟耗时操作
|
||||
select {
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
return Result{}, nil
|
||||
case <-ctx.Done():
|
||||
return Result{}, ctx.Err()
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.Collect(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, got nil")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Errorf("Expected DeadlineExceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
115
internal/collectors/provider_mapper.go
Normal file
115
internal/collectors/provider_mapper.go
Normal file
@@ -0,0 +1,115 @@
|
||||
// internal/collectors/provider_mapper.go
|
||||
// ProviderMapper: 将 OpenRouter 模型 ID 映射为标准厂商/模型名称
|
||||
package collectors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProviderInfo 标准厂商信息
|
||||
type ProviderInfo struct {
|
||||
ID string // 标准ID: "openai", "anthropic", "deepseek"...
|
||||
Name string // 英文名
|
||||
NameCN string // 中文名
|
||||
Country string // "US" / "CN" / "EU"
|
||||
}
|
||||
|
||||
// ModelMapping 模型映射结果
|
||||
type ModelMapping struct {
|
||||
Provider ProviderInfo
|
||||
ModelName string // 纯模型名,不含厂商前缀
|
||||
RawID string // 原始 OpenRouter ID
|
||||
IsFree bool // 是否免费版(:free 后缀)
|
||||
}
|
||||
|
||||
// providerNameMap 标准厂商名称映射表
|
||||
// key 为标准ID(也兼容 OpenRouter 原始格式作为别名)
|
||||
var providerNameMap = map[string]ProviderInfo{
|
||||
"openai": {ID: "openai", Name: "OpenAI", NameCN: "OpenAI", Country: "US"},
|
||||
"anthropic": {ID: "anthropic", Name: "Anthropic", NameCN: "Anthropic", Country: "US"},
|
||||
"google": {ID: "google", Name: "Google", NameCN: "谷歌", Country: "US"},
|
||||
"meta": {ID: "meta", Name: "Meta", NameCN: "Meta", Country: "US"},
|
||||
"xai": {ID: "xai", Name: "xAI", NameCN: "xAI", Country: "US"},
|
||||
"x-ai": {ID: "xai", Name: "xAI", NameCN: "xAI", Country: "US"}, // OpenRouter别名
|
||||
"deepseek": {ID: "deepseek", Name: "DeepSeek", NameCN: "深度求索", Country: "CN"},
|
||||
"qwen": {ID: "qwen", Name: "Qwen", NameCN: "通义千问", Country: "CN"},
|
||||
"alibaba": {ID: "alibaba", Name: "Alibaba", NameCN: "阿里巴巴", Country: "CN"},
|
||||
"moonshot": {ID: "moonshot", Name: "Moonshot AI", NameCN: "月之暗面", Country: "CN"},
|
||||
"moonshotai": {ID: "moonshot", Name: "Moonshot AI", NameCN: "月之暗面", Country: "CN"}, // OpenRouter别名
|
||||
"zhipu": {ID: "zhipu", Name: "Zhipu AI", NameCN: "智谱AI", Country: "CN"},
|
||||
"zhipuai": {ID: "zhipu", Name: "Zhipu AI", NameCN: "智谱AI", Country: "CN"}, // OpenRouter别名
|
||||
"bytedance": {ID: "bytedance", Name: "ByteDance", NameCN: "字节跳动", Country: "CN"},
|
||||
"baidu": {ID: "baidu", Name: "Baidu", NameCN: "百度", Country: "CN"},
|
||||
"tencent": {ID: "tencent", Name: "Tencent", NameCN: "腾讯", Country: "CN"},
|
||||
"mistral": {ID: "mistral", Name: "Mistral AI", NameCN: "Mistral", Country: "EU"},
|
||||
"cohere": {ID: "cohere", Name: "Cohere", NameCN: "Cohere", Country: "US"},
|
||||
"ai21": {ID: "ai21", Name: "AI21 Labs", NameCN: "AI21", Country: "US"},
|
||||
"perplexity": {ID: "perplexity", Name: "Perplexity", NameCN: "Perplexity", Country: "US"},
|
||||
"nvidia": {ID: "nvidia", Name: "NVIDIA", NameCN: "英伟达", Country: "US"},
|
||||
"microsoft": {ID: "microsoft", Name: "Microsoft", NameCN: "微软", Country: "US"},
|
||||
"openrouter": {ID: "openrouter", Name: "OpenRouter", NameCN: "OpenRouter", Country: "US"},
|
||||
}
|
||||
|
||||
// MapOpenRouterID 将 OpenRouter 模型 ID 映射为标准信息
|
||||
// OpenRouter ID 格式: "provider/model-name" 或 "provider/model-name:free"
|
||||
func MapOpenRouterID(rawID string) (ModelMapping, error) {
|
||||
if rawID == "" {
|
||||
return ModelMapping{}, fmt.Errorf("empty model ID")
|
||||
}
|
||||
|
||||
// 检测 :free 后缀
|
||||
isFree := false
|
||||
modelPart := rawID
|
||||
if strings.HasSuffix(rawID, ":free") {
|
||||
isFree = true
|
||||
modelPart = rawID[:len(rawID)-5]
|
||||
}
|
||||
|
||||
// 分割 provider / model
|
||||
parts := strings.SplitN(modelPart, "/", 2)
|
||||
if len(parts) < 2 {
|
||||
return ModelMapping{}, fmt.Errorf("invalid model ID format: %s", rawID)
|
||||
}
|
||||
|
||||
providerKey := strings.ToLower(parts[0])
|
||||
modelName := parts[1]
|
||||
|
||||
// 查找厂商信息
|
||||
provider, ok := providerNameMap[providerKey]
|
||||
if !ok {
|
||||
// 未识别厂商,返回通用信息
|
||||
provider = ProviderInfo{
|
||||
ID: providerKey,
|
||||
Name: providerKey,
|
||||
NameCN: providerKey,
|
||||
Country: "unknown",
|
||||
}
|
||||
}
|
||||
|
||||
return ModelMapping{
|
||||
Provider: provider,
|
||||
ModelName: modelName,
|
||||
RawID: rawID,
|
||||
IsFree: isFree,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAllProviderNames 返回所有已注册的厂商ID列表(用于测试覆盖度检查)
|
||||
func GetAllProviderNames() []string {
|
||||
names := make([]string, 0, len(providerNameMap))
|
||||
for k := range providerNameMap {
|
||||
names = append(names, k)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// RegisterProvider 动态注册新厂商(用于扩展)
|
||||
func RegisterProvider(key string, info ProviderInfo) {
|
||||
providerNameMap[strings.ToLower(key)] = info
|
||||
}
|
||||
|
||||
// ProviderCount 返回已注册厂商数量
|
||||
func ProviderCount() int {
|
||||
return len(providerNameMap)
|
||||
}
|
||||
167
internal/collectors/provider_mapper_test.go
Normal file
167
internal/collectors/provider_mapper_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
// internal/collectors/provider_mapper_test.go
|
||||
package collectors
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMapOpenRouterID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawID string
|
||||
wantErr bool
|
||||
wantProvID string
|
||||
wantProvCN string
|
||||
wantModel string
|
||||
wantFree bool
|
||||
wantCountry string
|
||||
}{
|
||||
{
|
||||
name: "OpenAI GPT-4o",
|
||||
rawID: "openai/gpt-4o",
|
||||
wantProvID: "openai",
|
||||
wantProvCN: "OpenAI",
|
||||
wantModel: "gpt-4o",
|
||||
wantFree: false,
|
||||
wantCountry: "US",
|
||||
},
|
||||
{
|
||||
name: "Anthropic Claude free",
|
||||
rawID: "anthropic/claude-3.5-sonnet:free",
|
||||
wantProvID: "anthropic",
|
||||
wantProvCN: "Anthropic",
|
||||
wantModel: "claude-3.5-sonnet",
|
||||
wantFree: true,
|
||||
wantCountry: "US",
|
||||
},
|
||||
{
|
||||
name: "DeepSeek V3",
|
||||
rawID: "deepseek/deepseek-v3",
|
||||
wantProvID: "deepseek",
|
||||
wantProvCN: "深度求索",
|
||||
wantModel: "deepseek-v3",
|
||||
wantFree: false,
|
||||
wantCountry: "CN",
|
||||
},
|
||||
{
|
||||
name: "Moonshot Kimi",
|
||||
rawID: "moonshotai/kimi-k2",
|
||||
wantProvID: "moonshot",
|
||||
wantProvCN: "月之暗面",
|
||||
wantModel: "kimi-k2",
|
||||
wantFree: false,
|
||||
wantCountry: "CN",
|
||||
},
|
||||
{
|
||||
name: "Unknown provider fallback",
|
||||
rawID: "some-new-ai/model-x",
|
||||
wantProvID: "some-new-ai",
|
||||
wantProvCN: "some-new-ai",
|
||||
wantModel: "model-x",
|
||||
wantFree: false,
|
||||
wantCountry: "unknown",
|
||||
},
|
||||
{
|
||||
name: "Empty ID",
|
||||
rawID: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid format no slash",
|
||||
rawID: "invalid-id",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := MapOpenRouterID(tt.rawID)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("MapOpenRouterID(%q) error = %v, wantErr %v", tt.rawID, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
if got.Provider.ID != tt.wantProvID {
|
||||
t.Errorf("Provider.ID = %q, want %q", got.Provider.ID, tt.wantProvID)
|
||||
}
|
||||
if got.Provider.NameCN != tt.wantProvCN {
|
||||
t.Errorf("Provider.NameCN = %q, want %q", got.Provider.NameCN, tt.wantProvCN)
|
||||
}
|
||||
if got.ModelName != tt.wantModel {
|
||||
t.Errorf("ModelName = %q, want %q", got.ModelName, tt.wantModel)
|
||||
}
|
||||
if got.IsFree != tt.wantFree {
|
||||
t.Errorf("IsFree = %v, want %v", got.IsFree, tt.wantFree)
|
||||
}
|
||||
if got.Provider.Country != tt.wantCountry {
|
||||
t.Errorf("Country = %q, want %q", got.Provider.Country, tt.wantCountry)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderMapCompleteness(t *testing.T) {
|
||||
// 验证所有预定义的厂商映射
|
||||
requiredProviders := []string{
|
||||
"openai", "anthropic", "google", "meta", "xai",
|
||||
"deepseek", "qwen", "moonshot", "zhipu", "bytedance",
|
||||
"baidu", "tencent", "alibaba", "mistral", "cohere",
|
||||
"ai21", "perplexity", "nvidia", "microsoft", "openrouter",
|
||||
}
|
||||
|
||||
for _, id := range requiredProviders {
|
||||
_, ok := providerNameMap[id]
|
||||
if !ok {
|
||||
t.Errorf("Required provider %q not found in providerNameMap", id)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证总数 >= 20
|
||||
if ProviderCount() < 20 {
|
||||
t.Errorf("ProviderCount() = %d, want >= 20", ProviderCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProvider(t *testing.T) {
|
||||
// 注册新厂商
|
||||
RegisterProvider("test-corp", ProviderInfo{
|
||||
ID: "test-corp",
|
||||
Name: "Test Corp",
|
||||
NameCN: "测试公司",
|
||||
Country: "CN",
|
||||
})
|
||||
|
||||
got, err := MapOpenRouterID("test-corp/model-1")
|
||||
if err != nil {
|
||||
t.Fatalf("MapOpenRouterID after RegisterProvider failed: %v", err)
|
||||
}
|
||||
if got.Provider.NameCN != "测试公司" {
|
||||
t.Errorf("After RegisterProvider, NameCN = %q, want %q", got.Provider.NameCN, "测试公司")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllProviderNames(t *testing.T) {
|
||||
names := GetAllProviderNames()
|
||||
if len(names) == 0 {
|
||||
t.Error("GetAllProviderNames() returned empty slice")
|
||||
}
|
||||
// 验证包含 openai
|
||||
found := false
|
||||
for _, n := range names {
|
||||
if n == "openai" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("GetAllProviderNames() missing 'openai'")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMapOpenRouterID(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = MapOpenRouterID("openai/gpt-4o")
|
||||
}
|
||||
}
|
||||
170
internal/retry/retry.go
Normal file
170
internal/retry/retry.go
Normal file
@@ -0,0 +1,170 @@
|
||||
// internal/retry/retry.go
|
||||
// 指数退避重试机制
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Strategy 重试策略
|
||||
type Strategy struct {
|
||||
MaxRetries int // 最大重试次数(0=不重试)
|
||||
BaseDelay time.Duration // 基础延迟
|
||||
MaxDelay time.Duration // 最大延迟上限
|
||||
Multiplier float64 // 乘数(默认2.0)
|
||||
Jitter bool // 是否添加随机抖动
|
||||
Retryable func(error) bool // 判断错误是否可重试
|
||||
}
|
||||
|
||||
// DefaultStrategy 返回默认重试策略
|
||||
func DefaultStrategy() Strategy {
|
||||
return Strategy{
|
||||
MaxRetries: 3,
|
||||
BaseDelay: 1 * time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
Jitter: true,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
}
|
||||
|
||||
// IsRetryable 默认重试判定:网络错误、超时、5xx状态码等可重试
|
||||
func IsRetryable(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// 这里可以扩展更多错误类型判定
|
||||
return true
|
||||
}
|
||||
|
||||
// Do 执行带重试的操作
|
||||
func Do(ctx context.Context, strategy Strategy, fn func() error) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
|
||||
if err := fn(); err != nil {
|
||||
lastErr = err
|
||||
|
||||
// 不判断最后一次是否需要重试
|
||||
if attempt == strategy.MaxRetries {
|
||||
break
|
||||
}
|
||||
|
||||
// 检查是否可重试
|
||||
if strategy.Retryable != nil && !strategy.Retryable(err) {
|
||||
return fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
|
||||
}
|
||||
|
||||
// 计算退避延迟
|
||||
delay := calculateDelay(strategy, attempt)
|
||||
|
||||
// 检查上下文是否已取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
|
||||
case <-time.After(delay):
|
||||
// 继续重试
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
|
||||
}
|
||||
|
||||
// calculateDelay 计算指数退避延迟
|
||||
func calculateDelay(s Strategy, attempt int) time.Duration {
|
||||
// 指数退避: base * multiplier^attempt
|
||||
delay := float64(s.BaseDelay) * math.Pow(s.Multiplier, float64(attempt))
|
||||
|
||||
// 添加上限
|
||||
if max := float64(s.MaxDelay); delay > max {
|
||||
delay = max
|
||||
}
|
||||
|
||||
// 添加抖动(±25%)
|
||||
if s.Jitter {
|
||||
jitter := delay * 0.25
|
||||
delay = delay - jitter + (jitter * 2 * float64(time.Now().Nanosecond()%1000) / 1000)
|
||||
}
|
||||
|
||||
return time.Duration(delay)
|
||||
}
|
||||
|
||||
// DoWithResult 执行带重试的操作并返回结果
|
||||
func DoWithResult[T any](ctx context.Context, strategy Strategy, fn func() (T, error)) (T, error) {
|
||||
var zero T
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
|
||||
result, err := fn()
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
if attempt == strategy.MaxRetries {
|
||||
break
|
||||
}
|
||||
|
||||
if strategy.Retryable != nil && !strategy.Retryable(err) {
|
||||
return zero, fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
|
||||
}
|
||||
|
||||
delay := calculateDelay(strategy, attempt)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return zero, fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
|
||||
return zero, fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
|
||||
}
|
||||
|
||||
// Metrics 重试统计
|
||||
type Metrics struct {
|
||||
Attempts int
|
||||
Success bool
|
||||
TotalDelay time.Duration
|
||||
}
|
||||
|
||||
// DoWithMetrics 执行带重试并返回统计信息
|
||||
func DoWithMetrics(ctx context.Context, strategy Strategy, fn func() error) (Metrics, error) {
|
||||
m := Metrics{}
|
||||
var lastErr error
|
||||
start := time.Now()
|
||||
|
||||
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
|
||||
m.Attempts = attempt + 1
|
||||
if err := fn(); err != nil {
|
||||
lastErr = err
|
||||
if attempt == strategy.MaxRetries {
|
||||
break
|
||||
}
|
||||
if strategy.Retryable != nil && !strategy.Retryable(err) {
|
||||
m.TotalDelay = time.Since(start)
|
||||
return m, fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
|
||||
}
|
||||
delay := calculateDelay(strategy, attempt)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.TotalDelay = time.Since(start)
|
||||
return m, fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
|
||||
case <-time.After(delay):
|
||||
}
|
||||
} else {
|
||||
m.Success = true
|
||||
m.TotalDelay = time.Since(start)
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
m.TotalDelay = time.Since(start)
|
||||
return m, fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
|
||||
}
|
||||
245
internal/retry/retry_test.go
Normal file
245
internal/retry/retry_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
// internal/retry/retry_test.go
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDo_Success(t *testing.T) {
|
||||
strategy := DefaultStrategy()
|
||||
callCount := 0
|
||||
|
||||
err := Do(context.Background(), strategy, func() error {
|
||||
callCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("expected 1 call, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo_RetryThenSuccess(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 3,
|
||||
BaseDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
callCount := 0
|
||||
|
||||
err := Do(context.Background(), strategy, func() error {
|
||||
callCount++
|
||||
if callCount < 3 {
|
||||
return errors.New("temporary error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if callCount != 3 {
|
||||
t.Errorf("expected 3 calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo_MaxRetriesExceeded(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 2,
|
||||
BaseDelay: 5 * time.Millisecond,
|
||||
MaxDelay: 50 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
callCount := 0
|
||||
expectedErr := errors.New("persistent error")
|
||||
|
||||
err := Do(context.Background(), strategy, func() error {
|
||||
callCount++
|
||||
return expectedErr
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if callCount != 3 { // initial + 2 retries
|
||||
t.Errorf("expected 3 calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo_NonRetryableError(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 3,
|
||||
BaseDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: func(err error) bool { return false }, // 任何错误都不重试
|
||||
}
|
||||
callCount := 0
|
||||
|
||||
err := Do(context.Background(), strategy, func() error {
|
||||
callCount++
|
||||
return errors.New("non-retryable")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("expected 1 call (no retry), got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo_ContextCancellation(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 3,
|
||||
BaseDelay: 1 * time.Second, // 长延迟确保上下文取消优先
|
||||
MaxDelay: 5 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
callCount := 0
|
||||
err := Do(ctx, strategy, func() error {
|
||||
callCount++
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if callCount < 1 {
|
||||
t.Error("expected at least 1 call")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("expected context error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoWithResult(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 2,
|
||||
BaseDelay: 5 * time.Millisecond,
|
||||
MaxDelay: 50 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
callCount := 0
|
||||
|
||||
result, err := DoWithResult(context.Background(), strategy, func() (string, error) {
|
||||
callCount++
|
||||
if callCount < 2 {
|
||||
return "", errors.New("temp error")
|
||||
}
|
||||
return "success", nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if result != "success" {
|
||||
t.Errorf("expected 'success', got %q", result)
|
||||
}
|
||||
if callCount != 2 {
|
||||
t.Errorf("expected 2 calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoWithMetrics(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 2,
|
||||
BaseDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
|
||||
// 成功场景
|
||||
m, err := DoWithMetrics(context.Background(), strategy, func() error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if !m.Success {
|
||||
t.Error("expected Success=true")
|
||||
}
|
||||
if m.Attempts != 1 {
|
||||
t.Errorf("expected 1 attempt, got %d", m.Attempts)
|
||||
}
|
||||
|
||||
// 失败场景
|
||||
m2, err := DoWithMetrics(context.Background(), strategy, func() error {
|
||||
return errors.New("always fails")
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if m2.Success {
|
||||
t.Error("expected Success=false")
|
||||
}
|
||||
if m2.Attempts != 3 {
|
||||
t.Errorf("expected 3 attempts, got %d", m2.Attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateDelay(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
BaseDelay: 1 * time.Second,
|
||||
MaxDelay: 10 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
attempt int
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
}{
|
||||
{0, 1 * time.Second, 1 * time.Second},
|
||||
{1, 2 * time.Second, 2 * time.Second},
|
||||
{2, 4 * time.Second, 4 * time.Second},
|
||||
{3, 8 * time.Second, 8 * time.Second},
|
||||
{4, 10 * time.Second, 10 * time.Second}, // 达到上限
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
delay := calculateDelay(strategy, tt.attempt)
|
||||
if delay < tt.min || delay > tt.max {
|
||||
t.Errorf("attempt %d: delay=%v, want [%v, %v]", tt.attempt, delay, tt.min, tt.max)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDo(b *testing.B) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 0,
|
||||
BaseDelay: 0,
|
||||
MaxDelay: 0,
|
||||
Multiplier: 0,
|
||||
Jitter: false,
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = Do(context.Background(), strategy, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user