package httpx

import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)

// WithBodyLimit wraps next so that reading the request body fails once more
// than limit bytes have been read.
//
// The body is wrapped with http.MaxBytesReader, which treats a negative limit
// as zero; a non-positive limit therefore rejects any non-empty body. How the
// resulting read error is reported to the client is up to next.
func WithBodyLimit(next http.Handler, limit int64) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		r.Body = http.MaxBytesReader(w, r.Body, limit)
		next.ServeHTTP(w, r)
	})
}

// RateLimiter implements a per-key sliding-window rate limiter: each key may
// make at most limit requests within any trailing window.
//
// Keys whose windows hold no live timestamps are evicted lazily (at most once
// per window, when a new key is inserted), so the key map does not grow
// without bound when many distinct keys are seen.
//
// Create a RateLimiter with NewRateLimiter; the zero value is not usable.
// Its methods may be called concurrently.
//
// Lock order: rl.mu is always acquired before any slidingWindow.mu, and no
// code path acquires rl.mu while holding a slidingWindow.mu.
type RateLimiter struct {
	mu       sync.RWMutex // guards counters and lastSweep
	counters map[string]*slidingWindow
	window   time.Duration
	limit    int

	// lastSweep is when idle windows were last evicted from counters.
	lastSweep time.Time
}

// slidingWindow records the request timestamps of a single key.
type slidingWindow struct {
	mu     sync.Mutex
	tokens []time.Time // guarded by mu; not necessarily sorted

	// evicted is set (under mu) once this window has been removed from the
	// limiter's map. A caller holding a stale pointer must look the key up
	// again instead of recording a token that nobody will ever see.
	evicted bool
}

// NewRateLimiter creates a rate limiter that allows at most limit requests
// per window for each key. A non-positive limit defaults to 10 and a
// non-positive window defaults to one second.
func NewRateLimiter(window time.Duration, limit int) *RateLimiter {
	if limit <= 0 {
		limit = 10
	}
	if window <= 0 {
		window = time.Second
	}
	return &RateLimiter{
		counters: make(map[string]*slidingWindow),
		window:   window,
		limit:    limit,
	}
}

// Allow reports whether a request for key is within the rate limit. An
// allowed request is recorded against the key's window; a rejected one (which
// the caller should answer with 429) is not recorded.
func (rl *RateLimiter) Allow(key string) bool {
	now := time.Now()
	cutoff := now.Add(-rl.window)

	for {
		sw := rl.lookup(key, now)
		sw.mu.Lock()
		if sw.evicted {
			// Lost a race with sweepLocked after lookup returned: this window
			// is no longer reachable from the map, so retry with a fresh one.
			sw.mu.Unlock()
			continue
		}
		allowed := sw.admit(now, cutoff, rl.limit)
		sw.mu.Unlock()
		return allowed
	}
}

// lookup returns the window for key, creating it if necessary. Creating a key
// is the only way the map grows, so it is also where idle windows are swept.
func (rl *RateLimiter) lookup(key string, now time.Time) *slidingWindow {
	// Fast path: existing keys only need the shared lock.
	rl.mu.RLock()
	sw, ok := rl.counters[key]
	rl.mu.RUnlock()
	if ok {
		return sw
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.sweepLocked(now)
	// Re-check: another goroutine may have inserted key between the two locks.
	if sw, ok = rl.counters[key]; !ok {
		sw = &slidingWindow{tokens: make([]time.Time, 0, rl.limit)}
		rl.counters[key] = sw
	}
	return sw
}

// sweepLocked removes windows whose timestamps are all at or before
// now-window. A fully expired window behaves exactly like a brand-new one, so
// eviction does not change rate-limiting decisions. It runs at most once per
// window and must be called with rl.mu held for writing.
func (rl *RateLimiter) sweepLocked(now time.Time) {
	if now.Sub(rl.lastSweep) < rl.window {
		return
	}
	rl.lastSweep = now
	cutoff := now.Add(-rl.window)
	for key, sw := range rl.counters {
		sw.mu.Lock()
		if !sw.hasLive(cutoff) {
			sw.evicted = true
			delete(rl.counters, key) // deleting during range is permitted
		}
		sw.mu.Unlock()
	}
}

// admit drops timestamps at or before cutoff and, if fewer than limit remain,
// records now and returns true. It must be called with sw.mu held.
func (sw *slidingWindow) admit(now, cutoff time.Time, limit int) bool {
	// Filter in place to reuse the backing array. Timestamps are not assumed
	// to be sorted: concurrent callers read the clock before taking the lock,
	// so they may append out of order.
	live := sw.tokens[:0]
	for _, t := range sw.tokens {
		if t.After(cutoff) {
			live = append(live, t)
		}
	}
	sw.tokens = live
	if len(live) >= limit {
		return false
	}
	sw.tokens = append(sw.tokens, now)
	return true
}

// hasLive reports whether any recorded timestamp is after cutoff. It must be
// called with sw.mu held.
func (sw *slidingWindow) hasLive(cutoff time.Time) bool {
	for _, t := range sw.tokens {
		if t.After(cutoff) {
			return true
		}
	}
	return false
}

// WithRateLimit wraps the next handler with per-key rate limiting.
// The key is extracted from X-Forwarded-For or r.RemoteAddr.
// Exceeding the limit returns HTTP 429 with a JSON error body and does not
// call next; no error is propagated to the caller.
func (rl *RateLimiter) WithRateLimit(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		key := rateLimitKey(r)
		if !rl.Allow(key) {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusTooManyRequests)
			// Best effort: the status is already sent, nothing useful to do on error.
			_, _ = w.Write([]byte(`{"error":{"code":"CS_SES_4002","message":"message rate limit exceeded"}}`))
			return
		}
		next.ServeHTTP(w, r)
	})
}

// rateLimitKey extracts a stable key for rate limiting.
//
// It prefers the first (leftmost) X-Forwarded-For entry, trimmed of
// surrounding whitespace. If the header is absent or its first entry is
// empty, it falls back to the host part of r.RemoteAddr, or to the whole
// RemoteAddr when that has no port.
//
// NOTE(review): the leftmost X-Forwarded-For entry is whatever the client
// sent unless a trusted proxy overwrites the header, so a client can pick its
// own key and evade the limit. Confirm this middleware only runs behind such
// a proxy.
func rateLimitKey(r *http.Request) string {
	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		first := fwd
		if i := strings.IndexByte(fwd, ','); i >= 0 {
			first = fwd[:i]
		}
		// Trim so "a, b", " a,b" and "a ,b" all map to key "a". An empty first
		// entry (", b") is malformed; use the peer address rather than lumping
		// all such requests under the key "".
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	// net.SplitHostPort handles bracketed IPv6 literals such as "[::1]:80".
	addr := r.RemoteAddr
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return addr
}