package httpclient

import (
	"errors"
	"io"
	"net/http"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// tests can stub the base transport wrapped by validatedTransport.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by invoking f.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}

// stubValidateResolvedIP replaces the package-level validateResolvedIP hook
// with fn and restores the original value when the test finishes.
// Because the hook is shared package state, tests using this helper must not
// call t.Parallel.
func stubValidateResolvedIP(t *testing.T, fn func(host string) error) {
	t.Helper()
	original := validateResolvedIP
	validateResolvedIP = fn
	t.Cleanup(func() { validateResolvedIP = original })
}

// okJSONResponse returns a minimal 200 OK response with an empty JSON object
// body and an empty header map, for use as a stubbed base-transport result.
func okJSONResponse() *http.Response {
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(`{}`)),
		Header:     make(http.Header),
	}
}

// roundTripAndClose sends req through rt, requires the round trip to succeed,
// and closes the response body so no response is left open.
// rt is taken as a function (e.g. a method value such as transport.RoundTrip)
// so the helper does not depend on the transport's concrete type.
func roundTripAndClose(t *testing.T, rt func(*http.Request) (*http.Response, error), req *http.Request) {
	t.Helper()
	resp, err := rt(req)
	require.NoError(t, err)
	if resp != nil && resp.Body != nil {
		require.NoError(t, resp.Body.Close())
	}
}

// TestValidatedTransport_CacheHostValidation checks that two requests to the
// same host within the cache window invoke the validator only once, while the
// base transport still handles both requests.
func TestValidatedTransport_CacheHostValidation(t *testing.T) {
	var validateCalls int32
	stubValidateResolvedIP(t, func(host string) error {
		atomic.AddInt32(&validateCalls, 1)
		require.Equal(t, "api.openai.com", host)
		return nil
	})

	var baseCalls int32
	base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
		atomic.AddInt32(&baseCalls, 1)
		return okJSONResponse(), nil
	})

	now := time.Unix(1730000000, 0)
	transport := newValidatedTransport(base)
	transport.now = func() time.Time { return now }

	req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
	require.NoError(t, err)

	roundTripAndClose(t, transport.RoundTrip, req)
	roundTripAndClose(t, transport.RoundTrip, req)

	require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls))
	require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls))
}

// TestValidatedTransport_ExpiredCacheTriggersRevalidation checks that once the
// injected clock moves past validatedHostTTL, the next request validates the
// host again.
func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) {
	var validateCalls int32
	stubValidateResolvedIP(t, func(_ string) error {
		atomic.AddInt32(&validateCalls, 1)
		return nil
	})

	base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
		return okJSONResponse(), nil
	})

	now := time.Unix(1730001000, 0)
	transport := newValidatedTransport(base)
	// The closure captures now by reference, so advancing now below also
	// advances the transport's clock.
	transport.now = func() time.Time { return now }

	req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
	require.NoError(t, err)

	roundTripAndClose(t, transport.RoundTrip, req)

	now = now.Add(validatedHostTTL + time.Second)
	roundTripAndClose(t, transport.RoundTrip, req)

	require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls))
}

// TestValidatedTransport_ValidationErrorStopsRoundTrip checks that a validator
// error is returned to the caller (matchable with errors.Is) and that the base
// transport is never invoked.
func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
	expectedErr := errors.New("dns rebinding rejected")
	stubValidateResolvedIP(t, func(_ string) error { return expectedErr })

	var baseCalls int32
	base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
		atomic.AddInt32(&baseCalls, 1)
		return okJSONResponse(), nil
	})

	transport := newValidatedTransport(base)
	req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
	require.NoError(t, err)

	resp, err := transport.RoundTrip(req)
	if resp != nil && resp.Body != nil {
		// Close defensively in case a response is (unexpectedly) returned
		// alongside the error.
		_ = resp.Body.Close()
	}
	require.ErrorIs(t, err, expectedErr)
	require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
}

// TestBuildClientKey checks that buildClientKey yields a non-empty key that
// is stable for identical options and differs when an option (Timeout)
// changes.
func TestBuildClientKey(t *testing.T) {
	opts1 := Options{
		ProxyURL:              "http://proxy:8080",
		Timeout:               30 * time.Second,
		ResponseHeaderTimeout: 10 * time.Second,
		InsecureSkipVerify:    false,
		ValidateResolvedIP:    true,
		AllowPrivateHosts:     false,
		MaxIdleConns:          100,
		MaxIdleConnsPerHost:   10,
		MaxConnsPerHost:       0,
	}

	key1 := buildClientKey(opts1)
	require.NotEmpty(t, key1)

	// Same options should produce same key
	key2 := buildClientKey(opts1)
	require.Equal(t, key1, key2)

	// Different options should produce different key
	opts2 := opts1
	opts2.Timeout = 60 * time.Second
	key3 := buildClientKey(opts2)
	require.NotEqual(t, key1, key3)
}

// TestBuildClientKeyTrimsSpaces checks that surrounding whitespace in
// ProxyURL does not affect the generated client key.
func TestBuildClientKeyTrimsSpaces(t *testing.T) {
	opts1 := Options{ProxyURL: "http://proxy:8080"}
	opts2 := Options{ProxyURL: "  http://proxy:8080  "}

	key1 := buildClientKey(opts1)
	key2 := buildClientKey(opts2)

	require.Equal(t, key1, key2)
}

// TestIsValidatedHost checks isValidatedHost against a manually seeded cache
// entry: valid at the seeding instant, invalid just past validatedHostTTL, and
// invalid for a host that was never stored.
func TestIsValidatedHost(t *testing.T) {
	base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
		return okJSONResponse(), nil
	})
	transport := newValidatedTransport(base)

	now := time.Unix(1730000000, 0)
	transport.now = func() time.Time { return now }

	host := "example.com"
	// Seed the cache directly; the stored value is treated as the entry's
	// expiry instant by the assertions below.
	transport.validatedHosts.Store(host, now.Add(validatedHostTTL))

	require.True(t, transport.isValidatedHost(host, now))
	require.False(t, transport.isValidatedHost(host, now.Add(validatedHostTTL+time.Nanosecond)))
	require.False(t, transport.isValidatedHost("other.com", now))
}

// TestIsValidatedHostNilTransport checks that isValidatedHost on a nil
// *validatedTransport reports false instead of panicking.
func TestIsValidatedHostNilTransport(t *testing.T) {
	var transport *validatedTransport
	now := time.Now()
	require.False(t, transport.isValidatedHost("example.com", now))
}

// TestNewValidatedTransport checks that the constructor populates the base
// transport and the clock function.
func TestNewValidatedTransport(t *testing.T) {
	base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
		return okJSONResponse(), nil
	})
	transport := newValidatedTransport(base)
	require.NotNil(t, transport)
	require.NotNil(t, transport.base)
	require.NotNil(t, transport.now)
}

// TestBuildClient checks that buildClient accepts ordinary options and
// rejects InsecureSkipVerify with a descriptive error.
func TestBuildClient(t *testing.T) {
	t.Run("valid options", func(t *testing.T) {
		opts := Options{
			Timeout:             30 * time.Second,
			MaxIdleConns:        100,
			MaxIdleConnsPerHost: 10,
		}
		client, err := buildClient(opts)
		require.NoError(t, err)
		require.NotNil(t, client)
	})

	t.Run("insecure skip verify not allowed", func(t *testing.T) {
		opts := Options{
			InsecureSkipVerify: true,
		}
		_, err := buildClient(opts)
		require.Error(t, err)
		require.Contains(t, err.Error(), "insecure_skip_verify is not allowed")
	})
}

// TestBuildTransport checks that buildTransport falls back to the package
// defaults for zero-valued idle-connection limits and honors explicit values.
func TestBuildTransport(t *testing.T) {
	t.Run("default values", func(t *testing.T) {
		opts := Options{}
		transport, err := buildTransport(opts)
		require.NoError(t, err)
		require.NotNil(t, transport)
		require.Equal(t, defaultMaxIdleConns, transport.MaxIdleConns)
		require.Equal(t, defaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
	})

	t.Run("custom values", func(t *testing.T) {
		opts := Options{
			MaxIdleConns:        50,
			MaxIdleConnsPerHost: 5,
		}
		transport, err := buildTransport(opts)
		require.NoError(t, err)
		require.NotNil(t, transport)
		require.Equal(t, 50, transport.MaxIdleConns)
		require.Equal(t, 5, transport.MaxIdleConnsPerHost)
	})
}