diff --git a/gateway/go.mod b/gateway/go.mod index fa41636..e7010a7 100644 --- a/gateway/go.mod +++ b/gateway/go.mod @@ -3,10 +3,19 @@ module lijiaoqiao/gateway go 1.21 require ( - github.com/golang-jwt/jwt/v5 v5.2.0 + github.com/jackc/pgx/v5 v5.5.0 + github.com/stretchr/testify v1.8.1 ) require ( - github.com/jackc/pgx/v5 v5.5.0 - golang.org/x/net v0.19.0 + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect + golang.org/x/crypto v0.9.0 // indirect + golang.org/x/sync v0.1.0 // indirect + golang.org/x/text v0.9.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/gateway/internal/adapter/openai_adapter.go b/gateway/internal/adapter/openai_adapter.go index 4829a19..6f56f6e 100644 --- a/gateway/internal/adapter/openai_adapter.go +++ b/gateway/internal/adapter/openai_adapter.go @@ -1,6 +1,7 @@ package adapter import ( + "bufio" "bytes" "context" "encoding/json" @@ -8,8 +9,6 @@ import ( "io" "net/http" "time" - - "lijiaoqiao/gateway/pkg/error" ) // OpenAIAdapter OpenAI适配器 @@ -188,13 +187,9 @@ func (a *OpenAIAdapter) ChatCompletionStream(ctx context.Context, model string, defer close(ch) defer resp.Body.Close() - reader := io.Reader(resp.Body) - for { - line, err := io.ReadLine(reader) - if err != nil { - return - } - + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Bytes() if len(line) < 6 { continue } @@ -262,24 +257,24 @@ func (a *OpenAIAdapter) GetUsage(response *CompletionResponse) Usage { } // MapError 错误码映射 -func (a *OpenAIAdapter) MapError(err error) error { +func (a *OpenAIAdapter) MapError(err error) ProviderError { // 简化实现,实际应根据OpenAI错误响应映射 errStr := err.Error() if contains(errStr, "invalid_api_key") { - return error.NewGatewayError(error.PROVIDER_INVALID_KEY, "Invalid API key").WithInternal(err) + return ProviderError{Code: "PROVIDER_001", Message: "Invalid API key", HTTPStatus: 401, Retryable: false} } if contains(errStr, "rate_limit") { - return error.NewGatewayError(error.PROVIDER_RATE_LIMIT, "Rate limit exceeded").WithInternal(err) + return ProviderError{Code: "PROVIDER_002", Message: "Rate limit exceeded", HTTPStatus: 429, Retryable: true} } if contains(errStr, "quota") { - return error.NewGatewayError(error.PROVIDER_QUOTA_EXCEEDED, "Quota exceeded").WithInternal(err) + return ProviderError{Code: "PROVIDER_003", Message: "Quota exceeded", HTTPStatus: 402, Retryable: false} } if contains(errStr, "model_not_found") { - return error.NewGatewayError(error.PROVIDER_MODEL_NOT_FOUND, "Model not found").WithInternal(err) + return ProviderError{Code: "PROVIDER_004", Message: "Model not found", HTTPStatus: 404, Retryable: false} } - return error.NewGatewayError(error.PROVIDER_ERROR, "Provider error").WithInternal(err) + return ProviderError{Code: "PROVIDER_005", Message: "Provider error", HTTPStatus: 502, Retryable: true} } func contains(s, substr string) bool {