// internal/retry/retry_test.go package retry import ( "context" "errors" "testing" "time" ) func TestDo_Success(t *testing.T) { strategy := DefaultStrategy() callCount := 0 err := Do(context.Background(), strategy, func() error { callCount++ return nil }) if err != nil { t.Fatalf("expected no error, got %v", err) } if callCount != 1 { t.Errorf("expected 1 call, got %d", callCount) } } func TestDo_RetryThenSuccess(t *testing.T) { strategy := Strategy{ MaxRetries: 3, BaseDelay: 10 * time.Millisecond, MaxDelay: 100 * time.Millisecond, Multiplier: 2.0, Jitter: false, Retryable: IsRetryable, } callCount := 0 err := Do(context.Background(), strategy, func() error { callCount++ if callCount < 3 { return errors.New("temporary error") } return nil }) if err != nil { t.Fatalf("expected no error, got %v", err) } if callCount != 3 { t.Errorf("expected 3 calls, got %d", callCount) } } func TestDo_MaxRetriesExceeded(t *testing.T) { strategy := Strategy{ MaxRetries: 2, BaseDelay: 5 * time.Millisecond, MaxDelay: 50 * time.Millisecond, Multiplier: 2.0, Jitter: false, Retryable: IsRetryable, } callCount := 0 expectedErr := errors.New("persistent error") err := Do(context.Background(), strategy, func() error { callCount++ return expectedErr }) if err == nil { t.Fatal("expected error, got nil") } if callCount != 3 { // initial + 2 retries t.Errorf("expected 3 calls, got %d", callCount) } } func TestDo_NonRetryableError(t *testing.T) { strategy := Strategy{ MaxRetries: 3, BaseDelay: 10 * time.Millisecond, MaxDelay: 100 * time.Millisecond, Multiplier: 2.0, Jitter: false, Retryable: func(err error) bool { return false }, // 任何错误都不重试 } callCount := 0 err := Do(context.Background(), strategy, func() error { callCount++ return errors.New("non-retryable") }) if err == nil { t.Fatal("expected error, got nil") } if callCount != 1 { t.Errorf("expected 1 call (no retry), got %d", callCount) } } func TestDo_ContextCancellation(t *testing.T) { strategy := Strategy{ MaxRetries: 3, BaseDelay: 1 * time.Second, // 长延迟确保上下文取消优先 MaxDelay: 5 * time.Second, Multiplier: 2.0, Jitter: false, Retryable: IsRetryable, } ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() callCount := 0 err := Do(ctx, strategy, func() error { callCount++ return errors.New("error") }) if err == nil { t.Fatal("expected error, got nil") } if callCount < 1 { t.Error("expected at least 1 call") } if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { t.Errorf("expected context error, got %v", err) } } func TestDoWithResult(t *testing.T) { strategy := Strategy{ MaxRetries: 2, BaseDelay: 5 * time.Millisecond, MaxDelay: 50 * time.Millisecond, Multiplier: 2.0, Jitter: false, Retryable: IsRetryable, } callCount := 0 result, err := DoWithResult(context.Background(), strategy, func() (string, error) { callCount++ if callCount < 2 { return "", errors.New("temp error") } return "success", nil }) if err != nil { t.Fatalf("expected no error, got %v", err) } if result != "success" { t.Errorf("expected 'success', got %q", result) } if callCount != 2 { t.Errorf("expected 2 calls, got %d", callCount) } } func TestDoWithMetrics(t *testing.T) { strategy := Strategy{ MaxRetries: 2, BaseDelay: 10 * time.Millisecond, MaxDelay: 100 * time.Millisecond, Multiplier: 2.0, Jitter: false, Retryable: IsRetryable, } // 成功场景 m, err := DoWithMetrics(context.Background(), strategy, func() error { return nil }) if err != nil { t.Fatalf("expected no error, got %v", err) } if !m.Success { t.Error("expected Success=true") } if m.Attempts != 1 { t.Errorf("expected 1 attempt, got %d", m.Attempts) } // 失败场景 m2, err := DoWithMetrics(context.Background(), strategy, func() error { return errors.New("always fails") }) if err == nil { t.Fatal("expected error, got nil") } if m2.Success { t.Error("expected Success=false") } if m2.Attempts != 3 { t.Errorf("expected 3 attempts, got %d", m2.Attempts) } } func TestCalculateDelay(t *testing.T) { strategy := Strategy{ BaseDelay: 1 * time.Second, MaxDelay: 10 * time.Second, Multiplier: 2.0, Jitter: false, } tests := []struct { attempt int min time.Duration max time.Duration }{ {0, 1 * time.Second, 1 * time.Second}, {1, 2 * time.Second, 2 * time.Second}, {2, 4 * time.Second, 4 * time.Second}, {3, 8 * time.Second, 8 * time.Second}, {4, 10 * time.Second, 10 * time.Second}, // 达到上限 } for _, tt := range tests { delay := calculateDelay(strategy, tt.attempt) if delay < tt.min || delay > tt.max { t.Errorf("attempt %d: delay=%v, want [%v, %v]", tt.attempt, delay, tt.min, tt.max) } } } func BenchmarkDo(b *testing.B) { strategy := Strategy{ MaxRetries: 0, BaseDelay: 0, MaxDelay: 0, Multiplier: 0, Jitter: false, } for i := 0; i < b.N; i++ { _ = Do(context.Background(), strategy, func() error { return nil }) } }