Files
llm-intelligence/internal/retry/retry_test.go
2026-05-13 14:42:45 +08:00

246 lines
5.1 KiB
Go

// internal/retry/retry_test.go
package retry
import (
"context"
"errors"
"testing"
"time"
)
func TestDo_Success(t *testing.T) {
strategy := DefaultStrategy()
callCount := 0
err := Do(context.Background(), strategy, func() error {
callCount++
return nil
})
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("expected 1 call, got %d", callCount)
}
}
// TestDo_RetryThenSuccess checks that Do keeps retrying an operation that
// fails with a retryable error and stops as soon as it succeeds. The
// operation succeeds on its third attempt, within the MaxRetries budget.
func TestDo_RetryThenSuccess(t *testing.T) {
	const succeedOn = 3

	s := Strategy{
		MaxRetries: 3,
		BaseDelay:  10 * time.Millisecond,
		MaxDelay:   100 * time.Millisecond,
		Multiplier: 2.0,
		Jitter:     false,
		Retryable:  IsRetryable,
	}

	attempts := 0
	op := func() error {
		attempts++
		if attempts >= succeedOn {
			return nil
		}
		return errors.New("temporary error")
	}

	if err := Do(context.Background(), s, op); err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if attempts != succeedOn {
		t.Errorf("expected 3 calls, got %d", attempts)
	}
}
// TestDo_MaxRetriesExceeded checks that Do gives up with a non-nil error once
// the retry budget is spent. The operation is invoked once initially plus
// MaxRetries more times.
func TestDo_MaxRetriesExceeded(t *testing.T) {
	s := Strategy{
		MaxRetries: 2,
		BaseDelay:  5 * time.Millisecond,
		MaxDelay:   50 * time.Millisecond,
		Multiplier: 2.0,
		Jitter:     false,
		Retryable:  IsRetryable,
	}
	persistent := errors.New("persistent error")

	var attempts int
	err := Do(context.Background(), s, func() error {
		attempts++
		return persistent
	})

	if err == nil {
		t.Fatal("expected error, got nil")
	}
	// One initial attempt plus two retries.
	if attempts != 3 {
		t.Errorf("expected 3 calls, got %d", attempts)
	}
}
// TestDo_NonRetryableError checks that Do returns immediately, with no
// retries, when the strategy's Retryable predicate rejects the error.
func TestDo_NonRetryableError(t *testing.T) {
	// A predicate that refuses to retry any error.
	neverRetry := func(err error) bool { return false }

	s := Strategy{
		MaxRetries: 3,
		BaseDelay:  10 * time.Millisecond,
		MaxDelay:   100 * time.Millisecond,
		Multiplier: 2.0,
		Jitter:     false,
		Retryable:  neverRetry,
	}

	var attempts int
	err := Do(context.Background(), s, func() error {
		attempts++
		return errors.New("non-retryable")
	})

	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if attempts != 1 {
		t.Errorf("expected 1 call (no retry), got %d", attempts)
	}
}
// TestDo_ContextCancellation checks that Do aborts with the context's error
// when the context expires during the backoff wait. It also checks that Do
// returns promptly instead of sleeping through the deliberately long retry
// delay.
func TestDo_ContextCancellation(t *testing.T) {
	const timeout = 50 * time.Millisecond

	strategy := Strategy{
		MaxRetries: 3,
		BaseDelay:  1 * time.Second, // long delay so the context deadline fires first
		MaxDelay:   5 * time.Second,
		Multiplier: 2.0,
		Jitter:     false,
		Retryable:  IsRetryable,
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	callCount := 0
	start := time.Now()
	err := Do(ctx, strategy, func() error {
		callCount++
		return errors.New("error")
	})
	elapsed := time.Since(start)

	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if callCount < 1 {
		t.Error("expected at least 1 call")
	}
	if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
		t.Errorf("expected context error, got %v", err)
	}
	// If Do only noticed the cancellation after finishing its backoff sleep,
	// it would return a context error yet take at least BaseDelay. Requiring
	// a return well before BaseDelay verifies the wait actually observes ctx.
	if elapsed >= strategy.BaseDelay {
		t.Errorf("Do returned after %v; expected it to stop near the %v context timeout", elapsed, timeout)
	}
}
// TestDoWithResult checks that DoWithResult retries a failing operation and
// returns the value from the first successful attempt (here the second).
func TestDoWithResult(t *testing.T) {
	s := Strategy{
		MaxRetries: 2,
		BaseDelay:  5 * time.Millisecond,
		MaxDelay:   50 * time.Millisecond,
		Multiplier: 2.0,
		Jitter:     false,
		Retryable:  IsRetryable,
	}

	attempts := 0
	fetch := func() (string, error) {
		attempts++
		if attempts == 1 {
			return "", errors.New("temp error")
		}
		return "success", nil
	}

	got, err := DoWithResult(context.Background(), s, fetch)
	switch {
	case err != nil:
		t.Fatalf("expected no error, got %v", err)
	case got != "success":
		t.Errorf("expected 'success', got %q", got)
	}
	if attempts != 2 {
		t.Errorf("expected 2 calls, got %d", attempts)
	}
}
// TestDoWithMetrics checks the metrics DoWithMetrics reports, covering both
// an immediate success and an operation that fails on every attempt.
func TestDoWithMetrics(t *testing.T) {
	s := Strategy{
		MaxRetries: 2,
		BaseDelay:  10 * time.Millisecond,
		MaxDelay:   100 * time.Millisecond,
		Multiplier: 2.0,
		Jitter:     false,
		Retryable:  IsRetryable,
	}
	ctx := context.Background()

	// Success case: one attempt, reported as successful.
	okRun, err := DoWithMetrics(ctx, s, func() error { return nil })
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if !okRun.Success {
		t.Error("expected Success=true")
	}
	if okRun.Attempts != 1 {
		t.Errorf("expected 1 attempt, got %d", okRun.Attempts)
	}

	// Failure case: every attempt fails, so the initial attempt plus both
	// retries are recorded and Success stays false.
	failRun, err := DoWithMetrics(ctx, s, func() error { return errors.New("always fails") })
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if failRun.Success {
		t.Error("expected Success=false")
	}
	if failRun.Attempts != 3 {
		t.Errorf("expected 3 attempts, got %d", failRun.Attempts)
	}
}
// TestCalculateDelay checks the exponential backoff schedule without jitter.
// The delay is BaseDelay * Multiplier^attempt, capped at MaxDelay.
func TestCalculateDelay(t *testing.T) {
	s := Strategy{
		BaseDelay:  1 * time.Second,
		MaxDelay:   10 * time.Second,
		Multiplier: 2.0,
		Jitter:     false,
	}

	cases := []struct {
		attempt int
		lo, hi  time.Duration // inclusive acceptable range
	}{
		{attempt: 0, lo: 1 * time.Second, hi: 1 * time.Second},
		{attempt: 1, lo: 2 * time.Second, hi: 2 * time.Second},
		{attempt: 2, lo: 4 * time.Second, hi: 4 * time.Second},
		{attempt: 3, lo: 8 * time.Second, hi: 8 * time.Second},
		{attempt: 4, lo: 10 * time.Second, hi: 10 * time.Second}, // 16s is clamped to MaxDelay
	}

	for i := range cases {
		tc := cases[i]
		got := calculateDelay(s, tc.attempt)
		if got >= tc.lo && got <= tc.hi {
			continue
		}
		t.Errorf("attempt %d: delay=%v, want [%v, %v]", tc.attempt, got, tc.lo, tc.hi)
	}
}
// BenchmarkDo measures the overhead of Do on its fast path: an operation that
// succeeds immediately under a strategy with no retries and no delays.
func BenchmarkDo(b *testing.B) {
	// The zero-value Strategy has no retries, zero delays, no jitter and a
	// nil Retryable.
	var noRetry Strategy
	ctx := context.Background()
	op := func() error { return nil }

	for n := 0; n < b.N; n++ {
		// The operation never fails, so the returned error is always nil.
		_ = Do(ctx, noRetry, op)
	}
}