forked from niuniu/llm-intelligence
chore: prepare repository for publishing
This commit is contained in:
170
internal/retry/retry.go
Normal file
170
internal/retry/retry.go
Normal file
@@ -0,0 +1,170 @@
|
||||
// internal/retry/retry.go
|
||||
// 指数退避重试机制
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Strategy 重试策略
|
||||
type Strategy struct {
|
||||
MaxRetries int // 最大重试次数(0=不重试)
|
||||
BaseDelay time.Duration // 基础延迟
|
||||
MaxDelay time.Duration // 最大延迟上限
|
||||
Multiplier float64 // 乘数(默认2.0)
|
||||
Jitter bool // 是否添加随机抖动
|
||||
Retryable func(error) bool // 判断错误是否可重试
|
||||
}
|
||||
|
||||
// DefaultStrategy 返回默认重试策略
|
||||
func DefaultStrategy() Strategy {
|
||||
return Strategy{
|
||||
MaxRetries: 3,
|
||||
BaseDelay: 1 * time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
Jitter: true,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
}
|
||||
|
||||
// IsRetryable 默认重试判定:网络错误、超时、5xx状态码等可重试
|
||||
func IsRetryable(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// 这里可以扩展更多错误类型判定
|
||||
return true
|
||||
}
|
||||
|
||||
// Do 执行带重试的操作
|
||||
func Do(ctx context.Context, strategy Strategy, fn func() error) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
|
||||
if err := fn(); err != nil {
|
||||
lastErr = err
|
||||
|
||||
// 不判断最后一次是否需要重试
|
||||
if attempt == strategy.MaxRetries {
|
||||
break
|
||||
}
|
||||
|
||||
// 检查是否可重试
|
||||
if strategy.Retryable != nil && !strategy.Retryable(err) {
|
||||
return fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
|
||||
}
|
||||
|
||||
// 计算退避延迟
|
||||
delay := calculateDelay(strategy, attempt)
|
||||
|
||||
// 检查上下文是否已取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
|
||||
case <-time.After(delay):
|
||||
// 继续重试
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
|
||||
}
|
||||
|
||||
// calculateDelay 计算指数退避延迟
|
||||
func calculateDelay(s Strategy, attempt int) time.Duration {
|
||||
// 指数退避: base * multiplier^attempt
|
||||
delay := float64(s.BaseDelay) * math.Pow(s.Multiplier, float64(attempt))
|
||||
|
||||
// 添加上限
|
||||
if max := float64(s.MaxDelay); delay > max {
|
||||
delay = max
|
||||
}
|
||||
|
||||
// 添加抖动(±25%)
|
||||
if s.Jitter {
|
||||
jitter := delay * 0.25
|
||||
delay = delay - jitter + (jitter * 2 * float64(time.Now().Nanosecond()%1000) / 1000)
|
||||
}
|
||||
|
||||
return time.Duration(delay)
|
||||
}
|
||||
|
||||
// DoWithResult 执行带重试的操作并返回结果
|
||||
func DoWithResult[T any](ctx context.Context, strategy Strategy, fn func() (T, error)) (T, error) {
|
||||
var zero T
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
|
||||
result, err := fn()
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
if attempt == strategy.MaxRetries {
|
||||
break
|
||||
}
|
||||
|
||||
if strategy.Retryable != nil && !strategy.Retryable(err) {
|
||||
return zero, fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
|
||||
}
|
||||
|
||||
delay := calculateDelay(strategy, attempt)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return zero, fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
|
||||
return zero, fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
|
||||
}
|
||||
|
||||
// Metrics 重试统计
|
||||
type Metrics struct {
|
||||
Attempts int
|
||||
Success bool
|
||||
TotalDelay time.Duration
|
||||
}
|
||||
|
||||
// DoWithMetrics 执行带重试并返回统计信息
|
||||
func DoWithMetrics(ctx context.Context, strategy Strategy, fn func() error) (Metrics, error) {
|
||||
m := Metrics{}
|
||||
var lastErr error
|
||||
start := time.Now()
|
||||
|
||||
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
|
||||
m.Attempts = attempt + 1
|
||||
if err := fn(); err != nil {
|
||||
lastErr = err
|
||||
if attempt == strategy.MaxRetries {
|
||||
break
|
||||
}
|
||||
if strategy.Retryable != nil && !strategy.Retryable(err) {
|
||||
m.TotalDelay = time.Since(start)
|
||||
return m, fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
|
||||
}
|
||||
delay := calculateDelay(strategy, attempt)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.TotalDelay = time.Since(start)
|
||||
return m, fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
|
||||
case <-time.After(delay):
|
||||
}
|
||||
} else {
|
||||
m.Success = true
|
||||
m.TotalDelay = time.Since(start)
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
m.TotalDelay = time.Since(start)
|
||||
return m, fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
|
||||
}
|
||||
245
internal/retry/retry_test.go
Normal file
245
internal/retry/retry_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
// internal/retry/retry_test.go
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDo_Success(t *testing.T) {
|
||||
strategy := DefaultStrategy()
|
||||
callCount := 0
|
||||
|
||||
err := Do(context.Background(), strategy, func() error {
|
||||
callCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("expected 1 call, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo_RetryThenSuccess(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 3,
|
||||
BaseDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
callCount := 0
|
||||
|
||||
err := Do(context.Background(), strategy, func() error {
|
||||
callCount++
|
||||
if callCount < 3 {
|
||||
return errors.New("temporary error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if callCount != 3 {
|
||||
t.Errorf("expected 3 calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo_MaxRetriesExceeded(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 2,
|
||||
BaseDelay: 5 * time.Millisecond,
|
||||
MaxDelay: 50 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
callCount := 0
|
||||
expectedErr := errors.New("persistent error")
|
||||
|
||||
err := Do(context.Background(), strategy, func() error {
|
||||
callCount++
|
||||
return expectedErr
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if callCount != 3 { // initial + 2 retries
|
||||
t.Errorf("expected 3 calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo_NonRetryableError(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 3,
|
||||
BaseDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: func(err error) bool { return false }, // 任何错误都不重试
|
||||
}
|
||||
callCount := 0
|
||||
|
||||
err := Do(context.Background(), strategy, func() error {
|
||||
callCount++
|
||||
return errors.New("non-retryable")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("expected 1 call (no retry), got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo_ContextCancellation(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 3,
|
||||
BaseDelay: 1 * time.Second, // 长延迟确保上下文取消优先
|
||||
MaxDelay: 5 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
callCount := 0
|
||||
err := Do(ctx, strategy, func() error {
|
||||
callCount++
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if callCount < 1 {
|
||||
t.Error("expected at least 1 call")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("expected context error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoWithResult(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 2,
|
||||
BaseDelay: 5 * time.Millisecond,
|
||||
MaxDelay: 50 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
callCount := 0
|
||||
|
||||
result, err := DoWithResult(context.Background(), strategy, func() (string, error) {
|
||||
callCount++
|
||||
if callCount < 2 {
|
||||
return "", errors.New("temp error")
|
||||
}
|
||||
return "success", nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if result != "success" {
|
||||
t.Errorf("expected 'success', got %q", result)
|
||||
}
|
||||
if callCount != 2 {
|
||||
t.Errorf("expected 2 calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoWithMetrics(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 2,
|
||||
BaseDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
Retryable: IsRetryable,
|
||||
}
|
||||
|
||||
// 成功场景
|
||||
m, err := DoWithMetrics(context.Background(), strategy, func() error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if !m.Success {
|
||||
t.Error("expected Success=true")
|
||||
}
|
||||
if m.Attempts != 1 {
|
||||
t.Errorf("expected 1 attempt, got %d", m.Attempts)
|
||||
}
|
||||
|
||||
// 失败场景
|
||||
m2, err := DoWithMetrics(context.Background(), strategy, func() error {
|
||||
return errors.New("always fails")
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if m2.Success {
|
||||
t.Error("expected Success=false")
|
||||
}
|
||||
if m2.Attempts != 3 {
|
||||
t.Errorf("expected 3 attempts, got %d", m2.Attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateDelay(t *testing.T) {
|
||||
strategy := Strategy{
|
||||
BaseDelay: 1 * time.Second,
|
||||
MaxDelay: 10 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
Jitter: false,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
attempt int
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
}{
|
||||
{0, 1 * time.Second, 1 * time.Second},
|
||||
{1, 2 * time.Second, 2 * time.Second},
|
||||
{2, 4 * time.Second, 4 * time.Second},
|
||||
{3, 8 * time.Second, 8 * time.Second},
|
||||
{4, 10 * time.Second, 10 * time.Second}, // 达到上限
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
delay := calculateDelay(strategy, tt.attempt)
|
||||
if delay < tt.min || delay > tt.max {
|
||||
t.Errorf("attempt %d: delay=%v, want [%v, %v]", tt.attempt, delay, tt.min, tt.max)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDo(b *testing.B) {
|
||||
strategy := Strategy{
|
||||
MaxRetries: 0,
|
||||
BaseDelay: 0,
|
||||
MaxDelay: 0,
|
||||
Multiplier: 0,
|
||||
Jitter: false,
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = Do(context.Background(), strategy, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user