forked from niuniu/llm-intelligence
246 lines
5.1 KiB
Go
246 lines
5.1 KiB
Go
// internal/retry/retry_test.go
|
|
package retry
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestDo_Success(t *testing.T) {
|
|
strategy := DefaultStrategy()
|
|
callCount := 0
|
|
|
|
err := Do(context.Background(), strategy, func() error {
|
|
callCount++
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if callCount != 1 {
|
|
t.Errorf("expected 1 call, got %d", callCount)
|
|
}
|
|
}
|
|
|
|
func TestDo_RetryThenSuccess(t *testing.T) {
|
|
strategy := Strategy{
|
|
MaxRetries: 3,
|
|
BaseDelay: 10 * time.Millisecond,
|
|
MaxDelay: 100 * time.Millisecond,
|
|
Multiplier: 2.0,
|
|
Jitter: false,
|
|
Retryable: IsRetryable,
|
|
}
|
|
callCount := 0
|
|
|
|
err := Do(context.Background(), strategy, func() error {
|
|
callCount++
|
|
if callCount < 3 {
|
|
return errors.New("temporary error")
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if callCount != 3 {
|
|
t.Errorf("expected 3 calls, got %d", callCount)
|
|
}
|
|
}
|
|
|
|
func TestDo_MaxRetriesExceeded(t *testing.T) {
|
|
strategy := Strategy{
|
|
MaxRetries: 2,
|
|
BaseDelay: 5 * time.Millisecond,
|
|
MaxDelay: 50 * time.Millisecond,
|
|
Multiplier: 2.0,
|
|
Jitter: false,
|
|
Retryable: IsRetryable,
|
|
}
|
|
callCount := 0
|
|
expectedErr := errors.New("persistent error")
|
|
|
|
err := Do(context.Background(), strategy, func() error {
|
|
callCount++
|
|
return expectedErr
|
|
})
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if callCount != 3 { // initial + 2 retries
|
|
t.Errorf("expected 3 calls, got %d", callCount)
|
|
}
|
|
}
|
|
|
|
func TestDo_NonRetryableError(t *testing.T) {
|
|
strategy := Strategy{
|
|
MaxRetries: 3,
|
|
BaseDelay: 10 * time.Millisecond,
|
|
MaxDelay: 100 * time.Millisecond,
|
|
Multiplier: 2.0,
|
|
Jitter: false,
|
|
Retryable: func(err error) bool { return false }, // 任何错误都不重试
|
|
}
|
|
callCount := 0
|
|
|
|
err := Do(context.Background(), strategy, func() error {
|
|
callCount++
|
|
return errors.New("non-retryable")
|
|
})
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if callCount != 1 {
|
|
t.Errorf("expected 1 call (no retry), got %d", callCount)
|
|
}
|
|
}
|
|
|
|
func TestDo_ContextCancellation(t *testing.T) {
|
|
strategy := Strategy{
|
|
MaxRetries: 3,
|
|
BaseDelay: 1 * time.Second, // 长延迟确保上下文取消优先
|
|
MaxDelay: 5 * time.Second,
|
|
Multiplier: 2.0,
|
|
Jitter: false,
|
|
Retryable: IsRetryable,
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
|
|
callCount := 0
|
|
err := Do(ctx, strategy, func() error {
|
|
callCount++
|
|
return errors.New("error")
|
|
})
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if callCount < 1 {
|
|
t.Error("expected at least 1 call")
|
|
}
|
|
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
|
t.Errorf("expected context error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDoWithResult(t *testing.T) {
|
|
strategy := Strategy{
|
|
MaxRetries: 2,
|
|
BaseDelay: 5 * time.Millisecond,
|
|
MaxDelay: 50 * time.Millisecond,
|
|
Multiplier: 2.0,
|
|
Jitter: false,
|
|
Retryable: IsRetryable,
|
|
}
|
|
callCount := 0
|
|
|
|
result, err := DoWithResult(context.Background(), strategy, func() (string, error) {
|
|
callCount++
|
|
if callCount < 2 {
|
|
return "", errors.New("temp error")
|
|
}
|
|
return "success", nil
|
|
})
|
|
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if result != "success" {
|
|
t.Errorf("expected 'success', got %q", result)
|
|
}
|
|
if callCount != 2 {
|
|
t.Errorf("expected 2 calls, got %d", callCount)
|
|
}
|
|
}
|
|
|
|
func TestDoWithMetrics(t *testing.T) {
|
|
strategy := Strategy{
|
|
MaxRetries: 2,
|
|
BaseDelay: 10 * time.Millisecond,
|
|
MaxDelay: 100 * time.Millisecond,
|
|
Multiplier: 2.0,
|
|
Jitter: false,
|
|
Retryable: IsRetryable,
|
|
}
|
|
|
|
// 成功场景
|
|
m, err := DoWithMetrics(context.Background(), strategy, func() error {
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if !m.Success {
|
|
t.Error("expected Success=true")
|
|
}
|
|
if m.Attempts != 1 {
|
|
t.Errorf("expected 1 attempt, got %d", m.Attempts)
|
|
}
|
|
|
|
// 失败场景
|
|
m2, err := DoWithMetrics(context.Background(), strategy, func() error {
|
|
return errors.New("always fails")
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if m2.Success {
|
|
t.Error("expected Success=false")
|
|
}
|
|
if m2.Attempts != 3 {
|
|
t.Errorf("expected 3 attempts, got %d", m2.Attempts)
|
|
}
|
|
}
|
|
|
|
func TestCalculateDelay(t *testing.T) {
|
|
strategy := Strategy{
|
|
BaseDelay: 1 * time.Second,
|
|
MaxDelay: 10 * time.Second,
|
|
Multiplier: 2.0,
|
|
Jitter: false,
|
|
}
|
|
|
|
tests := []struct {
|
|
attempt int
|
|
min time.Duration
|
|
max time.Duration
|
|
}{
|
|
{0, 1 * time.Second, 1 * time.Second},
|
|
{1, 2 * time.Second, 2 * time.Second},
|
|
{2, 4 * time.Second, 4 * time.Second},
|
|
{3, 8 * time.Second, 8 * time.Second},
|
|
{4, 10 * time.Second, 10 * time.Second}, // 达到上限
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
delay := calculateDelay(strategy, tt.attempt)
|
|
if delay < tt.min || delay > tt.max {
|
|
t.Errorf("attempt %d: delay=%v, want [%v, %v]", tt.attempt, delay, tt.min, tt.max)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkDo(b *testing.B) {
|
|
strategy := Strategy{
|
|
MaxRetries: 0,
|
|
BaseDelay: 0,
|
|
MaxDelay: 0,
|
|
Multiplier: 0,
|
|
Jitter: false,
|
|
}
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
_ = Do(context.Background(), strategy, func() error {
|
|
return nil
|
|
})
|
|
}
|
|
}
|