Files
lijiaoqiao/gateway/internal/handler/handler_test.go
Your Name d90cc382a4 fix: 验证并修复comprehensive_review_v4问题
已验证的问题状态:
1. P0-07补偿处理器 - 已集成到main.go 
2. P0-09外键校验器 - 已集成到main.go并调用 
3. 幂等协议Idempotency-Key - 已在idempotency.go实现 
4. 幂等唯一索引 - 已在SQL中定义 

Gateway修复:
- 修复cors.go语法错误(重复函数定义)
- 修复middleware_test.go参数不匹配问题
- 将go.mod降级到go 1.21以解决依赖问题
2026-04-08 20:17:07 +08:00

488 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"lijiaoqiao/gateway/internal/adapter"
"lijiaoqiao/gateway/internal/router"
gwerror "lijiaoqiao/gateway/pkg/error"
"lijiaoqiao/gateway/pkg/model"
)
// mockRouter is a router test double that serves providers from a fixed map
// and reports a fixed health snapshot.
//
// NOTE(review): nothing in this file constructs a mockRouter; the tests below
// use router.NewRouter directly. Confirm it is still needed.
type mockRouter struct {
	providers map[string]adapter.ProviderAdapter
	health    map[string]*router.ProviderHealth
}

// SelectProvider returns a registered provider regardless of the requested
// model. When several providers are registered, it returns the one with the
// lexicographically smallest name. This keeps results deterministic, because
// Go randomizes map iteration order. It returns a ROUTER_NO_PROVIDER_AVAILABLE
// gateway error when no provider is registered (including a nil map).
func (m *mockRouter) SelectProvider(ctx context.Context, model string) (adapter.ProviderAdapter, error) {
	var (
		best  string
		found bool
	)
	for name := range m.providers {
		if !found || name < best {
			best, found = name, true
		}
	}
	if !found {
		return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no provider")
	}
	return m.providers[best], nil
}

// RecordResult is a no-op; the mock does not track outcomes.
func (m *mockRouter) RecordResult(ctx context.Context, providerName string, success bool, latencyMs int64) {}

// GetHealthStatus returns the health map the mock was constructed with.
func (m *mockRouter) GetHealthStatus() map[string]*router.ProviderHealth {
	return m.health
}

// GetFallbackProviders always reports no fallback providers and no error.
func (m *mockRouter) GetFallbackProviders(ctx context.Context, model string) ([]adapter.ProviderAdapter, error) {
	return nil, nil
}
// mockProvider is a configurable in-memory provider used by the handler tests.
// Its completion methods return canned content. The name, models and healthy
// fields are reported back verbatim by the corresponding accessors.
type mockProvider struct {
	name    string
	models  []string
	healthy bool
}

// Compile-time check that *mockProvider satisfies adapter.ProviderAdapter
// (the provider type mockRouter serves). If the interface drifts, the build
// fails here instead of at a registration call site.
var _ adapter.ProviderAdapter = (*mockProvider)(nil)

// ChatCompletion ignores messages and options. It returns a single-choice
// "chat.completion" response that echoes model and contains the assistant
// reply "Hello, world!", with fixed usage (10 prompt + 5 completion tokens).
func (m *mockProvider) ChatCompletion(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (*adapter.CompletionResponse, error) {
	return &adapter.CompletionResponse{
		ID:      "test-id",
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   model,
		Choices: []adapter.Choice{
			{
				Index: 0,
				Message: &adapter.Message{
					Role:    "assistant",
					Content: "Hello, world!",
				},
				FinishReason: "stop",
			},
		},
		Usage: adapter.Usage{
			PromptTokens:     10,
			CompletionTokens: 5,
			TotalTokens:      15,
		},
	}, nil
}

// ChatCompletionStream returns a channel that yields exactly one
// "chat.completion.chunk" (assistant delta "Hello") and is then closed.
// The channel has capacity 1, so the send completes without a reader and the
// channel can be closed before it is returned.
func (m *mockProvider) ChatCompletionStream(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
	ch := make(chan *adapter.StreamChunk, 1)
	ch <- &adapter.StreamChunk{
		ID:      "test-id",
		Object:  "chat.completion.chunk",
		Created: time.Now().Unix(),
		Model:   model,
		Choices: []adapter.StreamChoice{
			{
				Index: 0,
				Delta: &adapter.Delta{
					Role:    "assistant",
					Content: "Hello",
				},
			},
		},
	}
	close(ch)
	return ch, nil
}

// GetUsage returns response's usage block. It returns the zero Usage when
// response is nil instead of panicking.
func (m *mockProvider) GetUsage(response *adapter.CompletionResponse) adapter.Usage {
	if response == nil {
		return adapter.Usage{}
	}
	return response.Usage
}

// MapError returns a zero-value ProviderError for any input error.
func (m *mockProvider) MapError(err error) adapter.ProviderError {
	return adapter.ProviderError{}
}

// HealthCheck reports the configured healthy flag.
func (m *mockProvider) HealthCheck(ctx context.Context) bool {
	return m.healthy
}

// ProviderName reports the configured name.
func (m *mockProvider) ProviderName() string {
	return m.name
}

// SupportedModels reports the configured model list (not copied).
func (m *mockProvider) SupportedModels() []string {
	return m.models
}
// TestNewHandler checks that NewHandler builds a handler and stamps it with
// the "v1" API version.
func TestNewHandler(t *testing.T) {
	h := NewHandler(router.NewRouter(router.StrategyLatency))
	if h == nil {
		t.Fatal("expected non-nil handler")
	}
	if got := h.version; got != "v1" {
		t.Errorf("expected version v1, got %s", got)
	}
}
// TestChatCompletionsHandle_InvalidRequest sends malformed or incomplete
// chat-completion bodies to a handler with no registered providers and checks
// the HTTP status returned for each case.
func TestChatCompletionsHandle_InvalidRequest(t *testing.T) {
	h := NewHandler(router.NewRouter(router.StrategyLatency))

	cases := []struct {
		name       string
		body       string
		wantStatus int
	}{
		{"invalid JSON", "{invalid}", http.StatusBadRequest},
		{"empty messages", `{"model": "gpt-4", "messages": []}`, http.StatusBadRequest},
		// A body without a model is well-formed, so the failure is expected
		// to surface later as 503 because no provider matches.
		{"missing model - passes validation but no provider for empty model", `{"messages": [{"role": "user", "content": "hello"}]}`, http.StatusServiceUnavailable},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewBufferString(tc.body))
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()

			h.ChatCompletionsHandle(rec, req)

			if rec.Code != tc.wantStatus {
				t.Errorf("expected status %d, got %d", tc.wantStatus, rec.Code)
			}
		})
	}
}
// TestChatCompletionsHandle_Success routes a chat completion to a registered
// mockProvider and checks the fields of the returned chat.completion response.
func TestChatCompletionsHandle_Success(t *testing.T) {
	r := router.NewRouter(router.StrategyLatency)
	prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
	r.RegisterProvider("test", prov)
	h := NewHandler(r)

	body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()

	h.ChatCompletionsHandle(rr, req)

	// A non-200 body is an error payload; decoding it as a completion would
	// only produce confusing follow-on failures, so stop here.
	if rr.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d (body: %s)", rr.Code, rr.Body.String())
	}

	var resp model.ChatCompletionResponse
	if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if resp.ID == "" {
		t.Error("expected non-empty ID")
	}
	if resp.Object != "chat.completion" {
		t.Errorf("expected object chat.completion, got %s", resp.Object)
	}
	// Must be fatal: the content check below indexes Choices[0] and would
	// panic on an empty slice.
	if len(resp.Choices) != 1 {
		t.Fatalf("expected 1 choice, got %d", len(resp.Choices))
	}
	if resp.Choices[0].Message.Content != "Hello, world!" {
		t.Errorf("unexpected content: %s", resp.Choices[0].Message.Content)
	}
}
func TestChatCompletionsHandle_WithRequestID(t *testing.T) {
r := router.NewRouter(router.StrategyLatency)
prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
r.RegisterProvider("test", prov)
h := NewHandler(r)
body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Request-ID", "custom-req-id")
rr := httptest.NewRecorder()
h.ChatCompletionsHandle(rr, req)
if rr.Header().Get("X-Request-ID") != "custom-req-id" {
t.Errorf("expected X-Request-ID custom-req-id, got %s", rr.Header().Get("X-Request-ID"))
}
}
// TestChatCompletionsHandle_ProviderError checks that a chat request is
// answered with 503 when the router has no providers registered.
func TestChatCompletionsHandle_ProviderError(t *testing.T) {
	// No provider is registered, which should trigger
	// ROUTER_NO_PROVIDER_AVAILABLE.
	h := NewHandler(router.NewRouter(router.StrategyLatency))

	payload := bytes.NewBufferString(`{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", payload)
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	h.ChatCompletionsHandle(rec, req)

	if rec.Code != http.StatusServiceUnavailable {
		t.Errorf("expected status 503, got %d", rec.Code)
	}
}
// TestCompletionsHandle_Success sends a legacy prompt-style completion to a
// registered mockProvider and checks for a text_completion response.
func TestCompletionsHandle_Success(t *testing.T) {
	r := router.NewRouter(router.StrategyLatency)
	r.RegisterProvider("test", &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true})
	h := NewHandler(r)

	payload := bytes.NewBufferString(`{"model": "gpt-4", "prompt": "Say hello"}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/completions", payload)
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	h.CompletionsHandle(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}

	var completion model.CompletionResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &completion); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if obj := completion.Object; obj != "text_completion" {
		t.Errorf("expected object text_completion, got %s", obj)
	}
}
// TestCompletionsHandle_InvalidRequest checks that an unparseable JSON body on
// the completions endpoint is rejected with 400.
func TestCompletionsHandle_InvalidRequest(t *testing.T) {
	h := NewHandler(router.NewRouter(router.StrategyLatency))

	req := httptest.NewRequest(http.MethodPost, "/v1/completions", bytes.NewBufferString(`{invalid}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	h.CompletionsHandle(rec, req)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", got)
	}
}
// TestModelsHandle checks that GET /v1/models returns a JSON "list" object
// whose data array holds exactly four model entries.
func TestModelsHandle(t *testing.T) {
	h := NewHandler(router.NewRouter(router.StrategyLatency))

	rec := httptest.NewRecorder()
	h.ModelsHandle(rec, httptest.NewRequest(http.MethodGet, "/v1/models", nil))

	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}

	var payload map[string]interface{}
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if obj := payload["object"]; obj != "list" {
		t.Errorf("expected object list, got %v", obj)
	}

	entries, isArray := payload["data"].([]interface{})
	if !isArray {
		t.Fatal("expected data to be array")
	}
	if n := len(entries); n != 4 {
		t.Errorf("expected 4 models, got %d", n)
	}
}
// TestHealthHandle_AllHealthy checks that /health answers 200 with status
// "healthy" when the only registered provider is healthy.
func TestHealthHandle_AllHealthy(t *testing.T) {
	r := router.NewRouter(router.StrategyLatency)
	r.RegisterProvider("test", &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true})
	h := NewHandler(r)

	rec := httptest.NewRecorder()
	h.HealthHandle(rec, httptest.NewRequest(http.MethodGet, "/health", nil))

	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}

	var health model.HealthStatus
	if err := json.Unmarshal(rec.Body.Bytes(), &health); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if health.Status != "healthy" {
		t.Errorf("expected status healthy, got %s", health.Status)
	}
}
// TestHealthHandle_Degraded checks that /health answers 503 with status
// "degraded" once the only registered provider is marked unhealthy.
func TestHealthHandle_Degraded(t *testing.T) {
	r := router.NewRouter(router.StrategyLatency)
	r.RegisterProvider("unhealthy", &mockProvider{name: "unhealthy", models: []string{}, healthy: false})
	// Explicitly mark the provider as unavailable in the router.
	r.UpdateHealth("unhealthy", false)
	h := NewHandler(r)

	rec := httptest.NewRecorder()
	h.HealthHandle(rec, httptest.NewRequest(http.MethodGet, "/health", nil))

	if rec.Code != http.StatusServiceUnavailable {
		t.Errorf("expected status 503, got %d", rec.Code)
	}

	var health model.HealthStatus
	if err := json.Unmarshal(rec.Body.Bytes(), &health); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if health.Status != "degraded" {
		t.Errorf("expected status degraded, got %s", health.Status)
	}
}
// TestWriteJSON checks that writeJSON sets the given status code, the JSON
// Content-Type and the supplied X-Request-ID header.
func TestWriteJSON(t *testing.T) {
	h := NewHandler(router.NewRouter(router.StrategyLatency))
	rec := httptest.NewRecorder()

	h.writeJSON(rec, http.StatusOK, map[string]string{"key": "value"}, "test-req-id")

	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}
	if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
		t.Errorf("expected Content-Type application/json, got %s", ct)
	}
	if id := rec.Header().Get("X-Request-ID"); id != "test-req-id" {
		t.Errorf("expected X-Request-ID test-req-id, got %s", id)
	}
}
// TestWriteError checks that writeError maps COMMON_INVALID_REQUEST to 400 and
// serializes the message, type and code into the error response body.
func TestWriteError(t *testing.T) {
	h := NewHandler(router.NewRouter(router.StrategyLatency))
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	gwErr := gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "test error").WithRequestID("req-123")

	h.writeError(rec, req, gwErr)

	if rec.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", rec.Code)
	}

	var errResp model.ErrorResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &errResp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	detail := errResp.Error
	if detail.Message != "test error" {
		t.Errorf("unexpected error message: %s", detail.Message)
	}
	if detail.Type != "gateway_error" {
		t.Errorf("unexpected error type: %s", detail.Type)
	}
	if detail.Code != "COMMON_001" {
		t.Errorf("unexpected error code: %s", detail.Code)
	}
}
// TestGenerateRequestID checks that consecutive generateRequestID calls return
// non-empty, distinct IDs of at least 10 characters. Every check is reported
// independently.
func TestGenerateRequestID(t *testing.T) {
	first, second := generateRequestID(), generateRequestID()

	checks := []struct {
		failed bool
		msg    string
	}{
		{first == "", "expected non-empty request ID"},
		{first == second, "expected different request IDs"},
		{len(first) < 10, "request ID seems too short"},
	}
	for _, c := range checks {
		if c.failed {
			t.Error(c.msg)
		}
	}
}
func TestMarshalJSON(t *testing.T) {
data := map[string]string{"key": "value"}
result := marshalJSON(data)
if result != `{"key":"value"}` {
t.Errorf("unexpected JSON: %s", result)
}
}
// TestMarshalJSON_NilValues checks marshalJSON on a struct with a pointer
// field, both when the pointer is set and when it is nil. The nil case is the
// one the test name refers to. The expected strings assume the same compact
// encoding/json output that TestMarshalJSON pins for maps, where a nil pointer
// field encodes as null.
func TestMarshalJSON_NilValues(t *testing.T) {
	type testStruct struct {
		Name *string
	}
	name := "test"

	cases := []struct {
		label string
		obj   testStruct
		want  string
	}{
		{"non-nil pointer", testStruct{Name: &name}, `{"Name":"test"}`},
		{"nil pointer", testStruct{}, `{"Name":null}`},
	}
	for _, tc := range cases {
		result := marshalJSON(tc.obj)
		if result == "" {
			t.Errorf("%s: expected non-empty JSON", tc.label)
			continue
		}
		if result != tc.want {
			t.Errorf("%s: expected %s, got %s", tc.label, tc.want, result)
		}
	}
}
// mockFailingProvider behaves like mockProvider, except that opening a stream
// always fails. Tests use it to exercise the handler's streaming error path.
type mockFailingProvider struct {
	mockProvider
}

// ChatCompletionStream overrides the embedded mockProvider method. It never
// returns a channel and always reports the error "stream error".
func (f *mockFailingProvider) ChatCompletionStream(_ context.Context, _ string, _ []adapter.Message, _ adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
	streamErr := errors.New("stream error")
	return nil, streamErr
}
// TestHandleStream_ProviderError sends a streaming chat request to a provider
// whose ChatCompletionStream always fails. It checks that the handler answers
// with an error status instead of a successful response.
func TestHandleStream_ProviderError(t *testing.T) {
	r := router.NewRouter(router.StrategyLatency)
	// Give the embedded mockProvider a name, a matching model and a healthy
	// flag so the router can select it and the request reaches the failing
	// stream call. NOTE(review): with a zero-value mockProvider the router
	// might reject it before streaming is attempted.
	prov := &mockFailingProvider{
		mockProvider: mockProvider{name: "failing", models: []string{"gpt-4"}, healthy: true},
	}
	r.RegisterProvider("failing", prov)
	h := NewHandler(r)

	body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": true}`
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()

	h.ChatCompletionsHandle(rr, req)

	// httptest.NewRecorder starts with Code 200, so a handler that wrote
	// nothing, or wrote success, shows up as a non-error code here.
	// NOTE(review): assumes the handler reports the stream failure via an
	// HTTP error status before writing any stream headers. Confirm against
	// the streaming path.
	if rr.Code < http.StatusBadRequest {
		t.Errorf("expected an error status for a failing stream, got %d (body: %s)", rr.Code, rr.Body.String())
	}
}