Files
lijiaoqiao/gateway/internal/handler/handler_test.go

488 lines
13 KiB
Go
Raw Normal View History

package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"lijiaoqiao/gateway/internal/adapter"
"lijiaoqiao/gateway/internal/router"
gwerror "lijiaoqiao/gateway/pkg/error"
"lijiaoqiao/gateway/pkg/model"
)
// mockRouter is a hand-rolled router stub for handler tests. It hands out
// providers from a fixed map and reports a fixed health snapshot; results are
// never recorded and no fallbacks are ever offered.
type mockRouter struct {
	providers map[string]adapter.ProviderAdapter
	health    map[string]*router.ProviderHealth
}

// SelectProvider returns whichever entry of m.providers map iteration yields
// first (Go map order is unspecified), or a ROUTER_NO_PROVIDER_AVAILABLE
// gateway error when the map is empty. The model argument is ignored.
func (m *mockRouter) SelectProvider(ctx context.Context, model string) (adapter.ProviderAdapter, error) {
	for _, p := range m.providers {
		return p, nil
	}
	return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no provider")
}

// RecordResult is a no-op in this stub.
func (m *mockRouter) RecordResult(ctx context.Context, providerName string, success bool, latencyMs int64) {
}

// GetHealthStatus returns the health map the stub was constructed with.
func (m *mockRouter) GetHealthStatus() map[string]*router.ProviderHealth {
	return m.health
}

// GetFallbackProviders always reports an empty (nil) fallback list and no error.
func (m *mockRouter) GetFallbackProviders(ctx context.Context, model string) ([]adapter.ProviderAdapter, error) {
	var none []adapter.ProviderAdapter
	return none, nil
}
// mockProvider is a canned provider adapter for handler tests. Completion
// calls always succeed with fixed content and echo back the requested model;
// HealthCheck reports the healthy field verbatim.
type mockProvider struct {
	name    string
	models  []string
	healthy bool
}

// ChatCompletion returns a single-choice "chat.completion" response whose
// assistant message is "Hello, world!", finished with "stop", and whose usage
// is 10 prompt + 5 completion = 15 total tokens. Messages and options are
// ignored.
func (m *mockProvider) ChatCompletion(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (*adapter.CompletionResponse, error) {
	reply := &adapter.Message{Role: "assistant", Content: "Hello, world!"}
	only := adapter.Choice{Index: 0, Message: reply, FinishReason: "stop"}
	usage := adapter.Usage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15}

	resp := &adapter.CompletionResponse{
		ID:      "test-id",
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   model,
		Choices: []adapter.Choice{only},
		Usage:   usage,
	}
	return resp, nil
}

// ChatCompletionStream yields exactly one "chat.completion.chunk" whose delta
// carries role "assistant" and content "Hello", then closes the channel. The
// channel has a buffer of one so the single send does not block.
func (m *mockProvider) ChatCompletionStream(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
	chunk := &adapter.StreamChunk{
		ID:      "test-id",
		Object:  "chat.completion.chunk",
		Created: time.Now().Unix(),
		Model:   model,
		Choices: []adapter.StreamChoice{{
			Index: 0,
			Delta: &adapter.Delta{Role: "assistant", Content: "Hello"},
		}},
	}

	out := make(chan *adapter.StreamChunk, 1)
	out <- chunk
	close(out)
	return out, nil
}

// GetUsage returns the usage block already embedded in response.
func (m *mockProvider) GetUsage(response *adapter.CompletionResponse) adapter.Usage {
	return response.Usage
}

// MapError ignores err and returns a zero-valued ProviderError.
func (m *mockProvider) MapError(err error) adapter.ProviderError {
	var zero adapter.ProviderError
	return zero
}

// HealthCheck reports the configured healthy flag.
func (m *mockProvider) HealthCheck(ctx context.Context) bool {
	return m.healthy
}

// ProviderName returns the configured name.
func (m *mockProvider) ProviderName() string {
	return m.name
}

// SupportedModels returns the configured model list (not copied).
func (m *mockProvider) SupportedModels() []string {
	return m.models
}
func TestNewHandler(t *testing.T) {
r := router.NewRouter(router.StrategyLatency)
h := NewHandler(r)
if h == nil {
t.Fatal("expected non-nil handler")
}
if h.version != "v1" {
t.Errorf("expected version v1, got %s", h.version)
}
}
func TestChatCompletionsHandle_InvalidRequest(t *testing.T) {
r := router.NewRouter(router.StrategyLatency)
h := NewHandler(r)
tests := []struct {
name string
body string
wantStatus int
}{
{
name: "invalid JSON",
body: "{invalid}",
wantStatus: 400,
},
{
name: "empty messages",
body: `{"model": "gpt-4", "messages": []}`,
wantStatus: 400,
},
{
name: "missing model - passes validation but no provider for empty model",
body: `{"messages": [{"role": "user", "content": "hello"}]}`,
wantStatus: 503,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(tt.body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
h.ChatCompletionsHandle(rr, req)
if rr.Code != tt.wantStatus {
t.Errorf("expected status %d, got %d", tt.wantStatus, rr.Code)
}
})
}
}
// TestChatCompletionsHandle_Success routes a well-formed chat request to a
// registered mockProvider. It expects a 200 response that decodes into a
// ChatCompletionResponse with a non-empty ID, object "chat.completion", and
// exactly one choice carrying the mock's "Hello, world!" reply.
func TestChatCompletionsHandle_Success(t *testing.T) {
	r := router.NewRouter(router.StrategyLatency)
	prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
	r.RegisterProvider("test", prov)
	h := NewHandler(r)

	body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`
	req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()

	h.ChatCompletionsHandle(rr, req)

	// Stop on a non-200 status: the body is then an error payload, and the
	// decoding checks below would only add confusing follow-on failures.
	if rr.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d (body: %s)", rr.Code, rr.Body.String())
	}

	var resp model.ChatCompletionResponse
	if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if resp.ID == "" {
		t.Error("expected non-empty ID")
	}
	if resp.Object != "chat.completion" {
		t.Errorf("expected object chat.completion, got %s", resp.Object)
	}
	// Fatal, not Error: the next check indexes Choices[0] and would panic on
	// an empty slice.
	if len(resp.Choices) != 1 {
		t.Fatalf("expected 1 choice, got %d", len(resp.Choices))
	}
	if resp.Choices[0].Message.Content != "Hello, world!" {
		t.Errorf("unexpected content: %s", resp.Choices[0].Message.Content)
	}
}
// TestChatCompletionsHandle_WithRequestID checks that a client-supplied
// X-Request-ID header is echoed back on the response.
func TestChatCompletionsHandle_WithRequestID(t *testing.T) {
	const reqID = "custom-req-id"

	r := router.NewRouter(router.StrategyLatency)
	r.RegisterProvider("test", &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true})
	h := NewHandler(r)

	payload := bytes.NewBufferString(`{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`)
	req := httptest.NewRequest("POST", "/v1/chat/completions", payload)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-Request-ID", reqID)

	rr := httptest.NewRecorder()
	h.ChatCompletionsHandle(rr, req)

	if got := rr.Header().Get("X-Request-ID"); got != reqID {
		t.Errorf("expected X-Request-ID custom-req-id, got %s", got)
	}
}
func TestChatCompletionsHandle_ProviderError(t *testing.T) {
r := router.NewRouter(router.StrategyLatency)
// 不注册任何provider会触发ROUTER_NO_PROVIDER_AVAILABLE
h := NewHandler(r)
body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
h.ChatCompletionsHandle(rr, req)
if rr.Code != 503 {
t.Errorf("expected status 503, got %d", rr.Code)
}
}
// TestCompletionsHandle_Success sends a prompt-style completion request
// through a registered mockProvider. It expects 200 with object
// "text_completion".
func TestCompletionsHandle_Success(t *testing.T) {
	rt := router.NewRouter(router.StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true})
	h := NewHandler(rt)

	payload := bytes.NewBufferString(`{"model": "gpt-4", "prompt": "Say hello"}`)
	req := httptest.NewRequest("POST", "/v1/completions", payload)
	req.Header.Set("Content-Type", "application/json")

	rr := httptest.NewRecorder()
	h.CompletionsHandle(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rr.Code)
	}

	var resp model.CompletionResponse
	if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if got := resp.Object; got != "text_completion" {
		t.Errorf("expected object text_completion, got %s", got)
	}
}
func TestCompletionsHandle_InvalidRequest(t *testing.T) {
r := router.NewRouter(router.StrategyLatency)
h := NewHandler(r)
body := `{invalid}`
req := httptest.NewRequest("POST", "/v1/completions", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
h.CompletionsHandle(rr, req)
if rr.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", rr.Code)
}
}
func TestModelsHandle(t *testing.T) {
r := router.NewRouter(router.StrategyLatency)
h := NewHandler(r)
req := httptest.NewRequest("GET", "/v1/models", nil)
rr := httptest.NewRecorder()
h.ModelsHandle(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", rr.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["object"] != "list" {
t.Errorf("expected object list, got %v", resp["object"])
}
data, ok := resp["data"].([]interface{})
if !ok {
t.Fatal("expected data to be array")
}
if len(data) != 4 {
t.Errorf("expected 4 models, got %d", len(data))
}
}
// TestHealthHandle_AllHealthy registers a single healthy mockProvider and
// expects the health endpoint to answer 200 with status "healthy".
func TestHealthHandle_AllHealthy(t *testing.T) {
	r := router.NewRouter(router.StrategyLatency)
	r.RegisterProvider("test", &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true})
	h := NewHandler(r)

	rr := httptest.NewRecorder()
	h.HealthHandle(rr, httptest.NewRequest("GET", "/health", nil))

	if rr.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rr.Code)
	}

	var status model.HealthStatus
	if err := json.Unmarshal(rr.Body.Bytes(), &status); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if got := status.Status; got != "healthy" {
		t.Errorf("expected status healthy, got %s", got)
	}
}
// TestHealthHandle_Degraded marks the only registered provider as unavailable.
// It expects the health endpoint to answer 503 with status "degraded".
func TestHealthHandle_Degraded(t *testing.T) {
	r := router.NewRouter(router.StrategyLatency)
	r.RegisterProvider("unhealthy", &mockProvider{name: "unhealthy", models: []string{}, healthy: false})
	// Mark the provider as unavailable through the router.
	r.UpdateHealth("unhealthy", false)
	h := NewHandler(r)

	rr := httptest.NewRecorder()
	h.HealthHandle(rr, httptest.NewRequest("GET", "/health", nil))

	if rr.Code != http.StatusServiceUnavailable {
		t.Errorf("expected status 503, got %d", rr.Code)
	}

	var status model.HealthStatus
	if err := json.Unmarshal(rr.Body.Bytes(), &status); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if got := status.Status; got != "degraded" {
		t.Errorf("expected status degraded, got %s", got)
	}
}
func TestWriteJSON(t *testing.T) {
r := router.NewRouter(router.StrategyLatency)
h := NewHandler(r)
w := httptest.NewRecorder()
data := map[string]string{"key": "value"}
h.writeJSON(w, http.StatusOK, data, "test-req-id")
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if w.Header().Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %s", w.Header().Get("Content-Type"))
}
if w.Header().Get("X-Request-ID") != "test-req-id" {
t.Errorf("expected X-Request-ID test-req-id, got %s", w.Header().Get("X-Request-ID"))
}
}
func TestWriteError(t *testing.T) {
r := router.NewRouter(router.StrategyLatency)
h := NewHandler(r)
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
gwErr := gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "test error").WithRequestID("req-123")
h.writeError(w, req, gwErr)
if w.Code != 400 {
t.Errorf("expected status 400, got %d", w.Code)
}
var resp model.ErrorResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp.Error.Message != "test error" {
t.Errorf("unexpected error message: %s", resp.Error.Message)
}
if resp.Error.Type != "gateway_error" {
t.Errorf("unexpected error type: %s", resp.Error.Type)
}
if resp.Error.Code != "COMMON_001" {
t.Errorf("unexpected error code: %s", resp.Error.Code)
}
}
// TestGenerateRequestID checks that two consecutive IDs are non-empty,
// distinct, and at least 10 characters long.
func TestGenerateRequestID(t *testing.T) {
	first, second := generateRequestID(), generateRequestID()

	if first == "" {
		t.Error("expected non-empty request ID")
	}
	if first == second {
		t.Error("expected different request IDs")
	}
	if len(first) < 10 {
		t.Error("request ID seems too short")
	}
}
// TestMarshalJSON checks that marshalJSON renders a simple string map as
// compact JSON.
func TestMarshalJSON(t *testing.T) {
	const want = `{"key":"value"}`
	if got := marshalJSON(map[string]string{"key": "value"}); got != want {
		t.Errorf("unexpected JSON: %s", got)
	}
}
// TestMarshalJSON_NilValues checks that marshalJSON handles pointer fields in
// both the nil and the non-nil state. The output must be non-empty, valid JSON
// that decodes back to an equivalent value. (Previously only the non-nil case
// was exercised, despite the test's name.)
func TestMarshalJSON_NilValues(t *testing.T) {
	type testStruct struct {
		Name *string
	}
	name := "test"

	cases := []struct {
		label string
		obj   testStruct
	}{
		{label: "nil pointer", obj: testStruct{}},
		{label: "non-nil pointer", obj: testStruct{Name: &name}},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			result := marshalJSON(tc.obj)
			if result == "" {
				t.Fatal("expected non-empty JSON")
			}

			// Round-trip through encoding/json to prove the output is well-formed
			// and preserves the pointer's nil-ness and pointee value.
			var decoded testStruct
			if err := json.Unmarshal([]byte(result), &decoded); err != nil {
				t.Fatalf("marshalJSON produced invalid JSON %q: %v", result, err)
			}
			switch {
			case tc.obj.Name == nil && decoded.Name != nil:
				t.Errorf("expected nil Name after round-trip, got %q", *decoded.Name)
			case tc.obj.Name != nil && (decoded.Name == nil || *decoded.Name != *tc.obj.Name):
				t.Errorf("expected Name %q after round-trip, got %v (JSON %s)", *tc.obj.Name, decoded.Name, result)
			}
		})
	}
}
// mockFailingProvider behaves like the embedded mockProvider, except that
// opening a stream always fails. This lets tests drive the handler's
// stream-error path.
type mockFailingProvider struct {
	mockProvider
}

// ChatCompletionStream shadows the embedded mockProvider method. It always
// returns a nil channel together with the error "stream error".
func (m *mockFailingProvider) ChatCompletionStream(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
	streamErr := errors.New("stream error")
	return nil, streamErr
}
// TestHandleStream_ProviderError sends a streaming chat request to a provider
// whose ChatCompletionStream always fails. It checks that the handler writes
// an error payload rather than silently producing nothing.
func TestHandleStream_ProviderError(t *testing.T) {
	r := router.NewRouter(router.StrategyLatency)
	// Configure the failing mock like the other tests' providers (name, model
	// list, healthy). A zero-value mock may never be selected for "gpt-4",
	// which would turn this into a no-provider test instead of a stream test.
	prov := &mockFailingProvider{
		mockProvider: mockProvider{name: "failing", models: []string{"gpt-4"}, healthy: true},
	}
	r.RegisterProvider("failing", prov)
	h := NewHandler(r)

	body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": true}`
	req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()

	h.ChatCompletionsHandle(rr, req)

	// httptest.NewRecorder starts Code at 200, so a `rr.Code == 0` check can
	// never fire. Instead, assert that the handler actually wrote an error
	// payload when the stream could not be opened, whatever the status code.
	if !bytes.Contains(rr.Body.Bytes(), []byte("error")) {
		t.Errorf("expected an error payload for a failed stream, got status %d body %q", rr.Code, rr.Body.String())
	}
}