diff --git a/scripts/official_pricing_import_common.go b/scripts/official_pricing_import_common.go index 8241bac..4f9827d 100644 --- a/scripts/official_pricing_import_common.go +++ b/scripts/official_pricing_import_common.go @@ -8,6 +8,7 @@ import ( "html" "io" "net/http" + "net/url" "os" "regexp" "strings" @@ -343,7 +344,18 @@ func fetchRawPricingPageOnce(url string, client *http.Client, opts officialPrici resp, err := client.Do(req) if err != nil { - return "", isRetriablePricingFetchError(err), fmt.Errorf("fetch %s: %w", url, err) + if fallbackClient, ok := pricingFetchDirectFallbackClient(client, url); ok { + fallbackResp, fallbackErr := fallbackClient.Do(req.Clone(req.Context())) + if fallbackErr == nil { + resp = fallbackResp + err = nil + } else { + err = fallbackErr + } + } + if err != nil { + return "", isRetriablePricingFetchError(err), fmt.Errorf("fetch %s: %w", url, err) + } } defer resp.Body.Close() @@ -385,6 +397,34 @@ func isRetriablePricingFetchError(err error) bool { return false } +func pricingFetchDirectFallbackClient(baseClient *http.Client, rawURL string) (*http.Client, bool) { + if baseClient == nil || !pricingFetchProxyConfigured() { + return nil, false + } + parsed, err := url.Parse(rawURL) + if err != nil { + return nil, false + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return nil, false + } + transport := http.DefaultTransport.(*http.Transport).Clone() + if custom, ok := baseClient.Transport.(*http.Transport); ok && custom != nil { + transport = custom.Clone() + } + transport.Proxy = nil + return &http.Client{Timeout: baseClient.Timeout, Transport: transport}, true +} + +func pricingFetchProxyConfigured() bool { + for _, key := range []string{"HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"} { + if strings.TrimSpace(os.Getenv(key)) != "" { + return true + } + } + return false +} + func cleanHTMLText(raw string) string { tagPattern := regexp.MustCompile(`(?is)<[^>]+>`) spacePattern := regexp.MustCompile(`[ \t]+`) diff --git a/scripts/official_pricing_import_common_test.go b/scripts/official_pricing_import_common_test.go index fa8693b..cbf39f5 100644 --- a/scripts/official_pricing_import_common_test.go +++ b/scripts/official_pricing_import_common_test.go @@ -5,6 +5,7 @@ package main import ( "net/http" "net/http/httptest" + "os" "sync/atomic" "testing" "time" @@ -44,6 +45,47 @@ func TestIsRetriablePricingFetchErrorRecognizesEOF(t *testing.T) { } } +func TestFetchRawPricingPageFallsBackWithoutProxyOnRetriableProxyFailure(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + oldHTTPProxy, hadHTTPProxy := os.LookupEnv("HTTP_PROXY") + oldHTTPSProxy, hadHTTPSProxy := os.LookupEnv("HTTPS_PROXY") + oldNoProxy, hadNoProxy := os.LookupEnv("NO_PROXY") + defer func() { + if hadHTTPProxy { + _ = os.Setenv("HTTP_PROXY", oldHTTPProxy) + } else { + _ = os.Unsetenv("HTTP_PROXY") + } + if hadHTTPSProxy { + _ = os.Setenv("HTTPS_PROXY", oldHTTPSProxy) + } else { + _ = os.Unsetenv("HTTPS_PROXY") + } + if hadNoProxy { + _ = os.Setenv("NO_PROXY", oldNoProxy) + } else { + _ = os.Unsetenv("NO_PROXY") + } + }() + + _ = os.Setenv("HTTP_PROXY", "http://127.0.0.1:1") + _ = os.Unsetenv("HTTPS_PROXY") + _ = os.Unsetenv("NO_PROXY") + + client := &http.Client{Timeout: 2 * time.Second} + body, err := fetchRawPricingPage(server.URL, "", client) + if err != nil { + t.Fatalf("fetchRawPricingPage returned error with proxy fallback enabled: %v", err) + } + if body != "ok" { + t.Fatalf("body = %q, want ok", body) + } +} + func TestFallbackModalityCanonicalizesAliases(t *testing.T) { if got := fallbackModality("image"); got != "vision" { t.Fatalf("fallbackModality(image) = %q, want vision", got)