Files
lijiaoqiao/gateway/internal/ratelimit/ratelimit_test.go

334 lines
8.7 KiB
Go
Raw Normal View History

package ratelimit
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestTokenBucketLimiter exercises TokenBucketLimiter: admission within the
// configured rate, rejection once a bucket is empty, time-based refill,
// per-key isolation, and the values reported by GetLimit.
func TestTokenBucketLimiter(t *testing.T) {
	// allower is the part of the limiter API these tests drive.
	type allower interface {
		Allow(ctx context.Context, key string) (bool, error)
	}

	// mustAllow calls Allow and fails the test immediately on error, so a
	// limiter that returns (false, err) cannot masquerade as a correct
	// rejection in the "should be blocked" assertions below.
	mustAllow := func(t *testing.T, l allower, key string) bool {
		t.Helper()
		allowed, err := l.Allow(context.Background(), key)
		if err != nil {
			t.Fatalf("Allow(%q): unexpected error: %v", key, err)
		}
		return allowed
	}

	t.Run("allows requests within limit", func(t *testing.T) {
		limiter := NewTokenBucketLimiter(60, 60000, 1.5) // 60 RPM
		for i := 0; i < 5; i++ {
			if !mustAllow(t, limiter, "test-key") {
				t.Errorf("request %d should be allowed", i+1)
			}
		}
	})

	t.Run("blocks requests over limit", func(t *testing.T) {
		// Built by hand with very low limits and no burst headroom.
		limiter := &TokenBucketLimiter{
			buckets:         make(map[string]*tokenBucket),
			defaultRPM:      2,
			defaultTPM:      100,
			burstMultiplier: 1.0,
			cleanInterval:   10 * time.Minute,
		}
		key := "test-key"
		// Install a 2-request bucket for key. The assertions below rely on
		// newBucket starting at full capacity.
		limiter.buckets[key] = limiter.newBucket(2, 100)

		if !mustAllow(t, limiter, key) {
			t.Error("first request should be allowed")
		}
		if !mustAllow(t, limiter, key) {
			t.Error("second request should be allowed")
		}
		if mustAllow(t, limiter, key) {
			t.Error("third request should be blocked")
		}
	})

	t.Run("refills tokens over time", func(t *testing.T) {
		limiter := &TokenBucketLimiter{
			buckets:         make(map[string]*tokenBucket),
			defaultRPM:      60,
			defaultTPM:      60000,
			burstMultiplier: 1.0,
			cleanInterval:   10 * time.Minute,
		}
		key := "test-key"

		// Drain the bucket: with burstMultiplier 1.0 it holds 60 requests.
		for i := 0; i < 60; i++ {
			if !mustAllow(t, limiter, key) {
				t.Fatalf("request %d should be allowed while draining the bucket", i+1)
			}
		}
		if mustAllow(t, limiter, key) {
			t.Error("should be blocked after consuming all tokens")
		}

		// Simulate time passing by backdating the bucket's last refill, so
		// the next Allow computes a two-minute refill.
		limiter.buckets[key].lastRefill = time.Now().Add(-2 * time.Minute)

		if !mustAllow(t, limiter, key) {
			t.Error("should allow after token refill")
		}
	})

	t.Run("separate buckets for different keys", func(t *testing.T) {
		limiter := NewTokenBucketLimiter(2, 100, 1.0)

		// Exhaust key1's two-request capacity.
		mustAllow(t, limiter, "key1")
		mustAllow(t, limiter, "key1")
		if mustAllow(t, limiter, "key1") {
			t.Error("key1 should be rate limited")
		}

		// key2 has its own bucket and must be unaffected.
		if !mustAllow(t, limiter, "key2") {
			t.Error("key2 should be allowed")
		}
	})

	t.Run("get limit returns correct values", func(t *testing.T) {
		limiter := NewTokenBucketLimiter(60, 60000, 1.5)
		// Touch the key first so a bucket exists for it.
		mustAllow(t, limiter, "test-key")

		limit := limiter.GetLimit("test-key")
		if limit.RPM != 60 {
			t.Errorf("expected RPM 60, got %d", limit.RPM)
		}
		if limit.TPM != 60000 {
			t.Errorf("expected TPM 60000, got %d", limit.TPM)
		}
		if limit.Burst != 90 { // 60 * 1.5
			t.Errorf("expected Burst 90, got %d", limit.Burst)
		}
	})
}
// TestSlidingWindowLimiter exercises SlidingWindowLimiter: admission up to
// maxRequests per window, rejection beyond it, recovery once recorded
// requests age out of the window, per-key isolation, and GetLimit's
// Remaining count.
func TestSlidingWindowLimiter(t *testing.T) {
	// allower is the part of the limiter API these tests drive.
	type allower interface {
		Allow(ctx context.Context, key string) (bool, error)
	}

	// mustAllow calls Allow and fails the test on error, so an erroring
	// limiter cannot pass the "should be blocked" checks by returning false.
	mustAllow := func(t *testing.T, l allower, key string) bool {
		t.Helper()
		allowed, err := l.Allow(context.Background(), key)
		if err != nil {
			t.Fatalf("Allow(%q): unexpected error: %v", key, err)
		}
		return allowed
	}

	t.Run("allows requests within window", func(t *testing.T) {
		limiter := NewSlidingWindowLimiter(time.Minute, 5)
		for i := 0; i < 5; i++ {
			if !mustAllow(t, limiter, "test-key") {
				t.Errorf("request %d should be allowed", i+1)
			}
		}
	})

	t.Run("blocks requests over window limit", func(t *testing.T) {
		limiter := NewSlidingWindowLimiter(time.Minute, 2)
		mustAllow(t, limiter, "test-key")
		mustAllow(t, limiter, "test-key")
		if mustAllow(t, limiter, "test-key") {
			t.Error("third request should be blocked")
		}
	})

	t.Run("sliding window respects time", func(t *testing.T) {
		// Built by hand so the test can reach into windows[key] below.
		limiter := &SlidingWindowLimiter{
			windows:       make(map[string]*slidingWindow),
			windowSize:    time.Minute,
			maxRequests:   2,
			cleanInterval: 10 * time.Minute,
		}
		key := "test-key"

		mustAllow(t, limiter, key)
		mustAllow(t, limiter, key)
		if mustAllow(t, limiter, key) {
			t.Error("should be blocked after reaching limit")
		}

		// Simulate time passing: backdate every recorded request to two
		// minutes ago, outside the one-minute window. Backdating all entries
		// (not just the first two) keeps the test independent of whether
		// the rejected call above was recorded.
		past := time.Now().Add(-2 * time.Minute)
		w := limiter.windows[key]
		for i := range w.requests {
			w.requests[i] = past
		}

		if !mustAllow(t, limiter, key) {
			t.Error("should allow after old requests expire from window")
		}
	})

	t.Run("separate windows for different keys", func(t *testing.T) {
		limiter := NewSlidingWindowLimiter(time.Minute, 1)

		mustAllow(t, limiter, "key1")
		if mustAllow(t, limiter, "key1") {
			t.Error("key1 should be rate limited")
		}
		if !mustAllow(t, limiter, "key2") {
			t.Error("key2 should be allowed")
		}
	})

	t.Run("get limit returns correct remaining", func(t *testing.T) {
		limiter := NewSlidingWindowLimiter(time.Minute, 10)
		for i := 0; i < 3; i++ {
			mustAllow(t, limiter, "test-key")
		}

		limit := limiter.GetLimit("test-key")
		if limit.RPM != 10 {
			t.Errorf("expected RPM 10, got %d", limit.RPM)
		}
		if limit.Remaining != 7 {
			t.Errorf("expected Remaining 7, got %d", limit.Remaining)
		}
	})
}
func TestMiddleware(t *testing.T) {
t.Run("allows request when under limit", func(t *testing.T) {
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
middleware := NewMiddleware(limiter)
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer test-token")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", rr.Code)
}
})
t.Run("sets rate limit headers when blocked", func(t *testing.T) {
// Use very low limit so request is blocked
limiter := &TokenBucketLimiter{
buckets: make(map[string]*tokenBucket),
defaultRPM: 1,
defaultTPM: 100,
burstMultiplier: 1.0,
cleanInterval: 10 * time.Minute,
}
// Exhaust the bucket - key is the extracted token, not the full Authorization header
key := "test-token"
bucket := limiter.newBucket(1, 100)
bucket.tokens = 0
limiter.buckets[key] = bucket
middleware := NewMiddleware(limiter)
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called")
}))
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+key)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Headers should be set when rate limited
if rr.Header().Get("X-RateLimit-Limit") == "" {
t.Error("expected X-RateLimit-Limit header to be set")
}
if rr.Header().Get("X-RateLimit-Remaining") == "" {
t.Error("expected X-RateLimit-Remaining header to be set")
}
if rr.Header().Get("X-RateLimit-Reset") == "" {
t.Error("expected X-RateLimit-Reset header to be set")
}
})
t.Run("blocks request when over limit", func(t *testing.T) {
// Use very low limit
limiter := &TokenBucketLimiter{
buckets: make(map[string]*tokenBucket),
defaultRPM: 1,
defaultTPM: 100,
burstMultiplier: 1.0,
cleanInterval: 10 * time.Minute,
}
// Exhaust the bucket - key is the extracted token, not the full Authorization header
key := "test-token"
bucket := limiter.newBucket(1, 100)
bucket.tokens = 0 // Exhaust
limiter.buckets[key] = bucket
middleware := NewMiddleware(limiter)
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called")
}))
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+key)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusTooManyRequests {
t.Errorf("expected status 429, got %d", rr.Code)
}
})
t.Run("uses remote addr when no auth header", func(t *testing.T) {
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
middleware := NewMiddleware(limiter)
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
// No Authorization header
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", rr.Code)
}
})
}