171 lines
4.3 KiB
Go
171 lines
4.3 KiB
Go
|
|
// internal/retry/retry.go
|
|||
|
|
// 指数退避重试机制
|
|||
|
|
package retry
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"fmt"
|
|||
|
|
"math"
|
|||
|
|
"time"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// Strategy 重试策略
|
|||
|
|
type Strategy struct {
|
|||
|
|
MaxRetries int // 最大重试次数(0=不重试)
|
|||
|
|
BaseDelay time.Duration // 基础延迟
|
|||
|
|
MaxDelay time.Duration // 最大延迟上限
|
|||
|
|
Multiplier float64 // 乘数(默认2.0)
|
|||
|
|
Jitter bool // 是否添加随机抖动
|
|||
|
|
Retryable func(error) bool // 判断错误是否可重试
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DefaultStrategy 返回默认重试策略
|
|||
|
|
func DefaultStrategy() Strategy {
|
|||
|
|
return Strategy{
|
|||
|
|
MaxRetries: 3,
|
|||
|
|
BaseDelay: 1 * time.Second,
|
|||
|
|
MaxDelay: 30 * time.Second,
|
|||
|
|
Multiplier: 2.0,
|
|||
|
|
Jitter: true,
|
|||
|
|
Retryable: IsRetryable,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// IsRetryable 默认重试判定:网络错误、超时、5xx状态码等可重试
|
|||
|
|
func IsRetryable(err error) bool {
|
|||
|
|
if err == nil {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
// 这里可以扩展更多错误类型判定
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Do 执行带重试的操作
|
|||
|
|
func Do(ctx context.Context, strategy Strategy, fn func() error) error {
|
|||
|
|
var lastErr error
|
|||
|
|
|
|||
|
|
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
|
|||
|
|
if err := fn(); err != nil {
|
|||
|
|
lastErr = err
|
|||
|
|
|
|||
|
|
// 不判断最后一次是否需要重试
|
|||
|
|
if attempt == strategy.MaxRetries {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查是否可重试
|
|||
|
|
if strategy.Retryable != nil && !strategy.Retryable(err) {
|
|||
|
|
return fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 计算退避延迟
|
|||
|
|
delay := calculateDelay(strategy, attempt)
|
|||
|
|
|
|||
|
|
// 检查上下文是否已取消
|
|||
|
|
select {
|
|||
|
|
case <-ctx.Done():
|
|||
|
|
return fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
|
|||
|
|
case <-time.After(delay):
|
|||
|
|
// 继续重试
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// calculateDelay 计算指数退避延迟
|
|||
|
|
func calculateDelay(s Strategy, attempt int) time.Duration {
|
|||
|
|
// 指数退避: base * multiplier^attempt
|
|||
|
|
delay := float64(s.BaseDelay) * math.Pow(s.Multiplier, float64(attempt))
|
|||
|
|
|
|||
|
|
// 添加上限
|
|||
|
|
if max := float64(s.MaxDelay); delay > max {
|
|||
|
|
delay = max
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 添加抖动(±25%)
|
|||
|
|
if s.Jitter {
|
|||
|
|
jitter := delay * 0.25
|
|||
|
|
delay = delay - jitter + (jitter * 2 * float64(time.Now().Nanosecond()%1000) / 1000)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return time.Duration(delay)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DoWithResult 执行带重试的操作并返回结果
|
|||
|
|
func DoWithResult[T any](ctx context.Context, strategy Strategy, fn func() (T, error)) (T, error) {
|
|||
|
|
var zero T
|
|||
|
|
var lastErr error
|
|||
|
|
|
|||
|
|
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
|
|||
|
|
result, err := fn()
|
|||
|
|
if err == nil {
|
|||
|
|
return result, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
lastErr = err
|
|||
|
|
if attempt == strategy.MaxRetries {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if strategy.Retryable != nil && !strategy.Retryable(err) {
|
|||
|
|
return zero, fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
delay := calculateDelay(strategy, attempt)
|
|||
|
|
|
|||
|
|
select {
|
|||
|
|
case <-ctx.Done():
|
|||
|
|
return zero, fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
|
|||
|
|
case <-time.After(delay):
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return zero, fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Metrics 重试统计
|
|||
|
|
type Metrics struct {
|
|||
|
|
Attempts int
|
|||
|
|
Success bool
|
|||
|
|
TotalDelay time.Duration
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DoWithMetrics 执行带重试并返回统计信息
|
|||
|
|
func DoWithMetrics(ctx context.Context, strategy Strategy, fn func() error) (Metrics, error) {
|
|||
|
|
m := Metrics{}
|
|||
|
|
var lastErr error
|
|||
|
|
start := time.Now()
|
|||
|
|
|
|||
|
|
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
|
|||
|
|
m.Attempts = attempt + 1
|
|||
|
|
if err := fn(); err != nil {
|
|||
|
|
lastErr = err
|
|||
|
|
if attempt == strategy.MaxRetries {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
if strategy.Retryable != nil && !strategy.Retryable(err) {
|
|||
|
|
m.TotalDelay = time.Since(start)
|
|||
|
|
return m, fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
|
|||
|
|
}
|
|||
|
|
delay := calculateDelay(strategy, attempt)
|
|||
|
|
select {
|
|||
|
|
case <-ctx.Done():
|
|||
|
|
m.TotalDelay = time.Since(start)
|
|||
|
|
return m, fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
|
|||
|
|
case <-time.After(delay):
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
m.Success = true
|
|||
|
|
m.TotalDelay = time.Since(start)
|
|||
|
|
return m, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
m.TotalDelay = time.Since(start)
|
|||
|
|
return m, fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
|
|||
|
|
}
|