已验证的问题状态:
1. P0-07 补偿处理器 - 已集成到 main.go ✅
2. P0-09 外键校验器 - 已集成到 main.go 并被调用 ✅
3. 幂等协议 Idempotency-Key - 已在 idempotency.go 中实现 ✅
4. 幂等唯一索引 - 已在 SQL 中定义 ✅

Gateway 修复:
- 修复 cors.go 语法错误(重复的函数定义)
- 修复 middleware_test.go 参数不匹配问题
- 将 go.mod 中的 Go 版本降级到 1.21,以解决依赖问题
488 lines
13 KiB
Go
package handler
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"testing"
|
||
"time"
|
||
|
||
"lijiaoqiao/gateway/internal/adapter"
|
||
"lijiaoqiao/gateway/internal/router"
|
||
gwerror "lijiaoqiao/gateway/pkg/error"
|
||
"lijiaoqiao/gateway/pkg/model"
|
||
)
|
||
|
||
// mockRouter 用于测试的Router
|
||
type mockRouter struct {
|
||
providers map[string]adapter.ProviderAdapter
|
||
health map[string]*router.ProviderHealth
|
||
}
|
||
|
||
func (m *mockRouter) SelectProvider(ctx context.Context, model string) (adapter.ProviderAdapter, error) {
|
||
for name := range m.providers {
|
||
return m.providers[name], nil
|
||
}
|
||
return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no provider")
|
||
}
|
||
|
||
func (m *mockRouter) RecordResult(ctx context.Context, providerName string, success bool, latencyMs int64) {}
|
||
|
||
func (m *mockRouter) GetHealthStatus() map[string]*router.ProviderHealth {
|
||
return m.health
|
||
}
|
||
|
||
func (m *mockRouter) GetFallbackProviders(ctx context.Context, model string) ([]adapter.ProviderAdapter, error) {
|
||
return nil, nil
|
||
}
|
||
|
||
// mockProvider 用于测试的Provider
|
||
type mockProvider struct {
|
||
name string
|
||
models []string
|
||
healthy bool
|
||
}
|
||
|
||
func (m *mockProvider) ChatCompletion(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (*adapter.CompletionResponse, error) {
|
||
return &adapter.CompletionResponse{
|
||
ID: "test-id",
|
||
Object: "chat.completion",
|
||
Created: time.Now().Unix(),
|
||
Model: model,
|
||
Choices: []adapter.Choice{
|
||
{
|
||
Index: 0,
|
||
Message: &adapter.Message{
|
||
Role: "assistant",
|
||
Content: "Hello, world!",
|
||
},
|
||
FinishReason: "stop",
|
||
},
|
||
},
|
||
Usage: adapter.Usage{
|
||
PromptTokens: 10,
|
||
CompletionTokens: 5,
|
||
TotalTokens: 15,
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
func (m *mockProvider) ChatCompletionStream(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
|
||
ch := make(chan *adapter.StreamChunk, 1)
|
||
ch <- &adapter.StreamChunk{
|
||
ID: "test-id",
|
||
Object: "chat.completion.chunk",
|
||
Created: time.Now().Unix(),
|
||
Model: model,
|
||
Choices: []adapter.StreamChoice{
|
||
{
|
||
Index: 0,
|
||
Delta: &adapter.Delta{
|
||
Role: "assistant",
|
||
Content: "Hello",
|
||
},
|
||
},
|
||
},
|
||
}
|
||
close(ch)
|
||
return ch, nil
|
||
}
|
||
|
||
func (m *mockProvider) GetUsage(response *adapter.CompletionResponse) adapter.Usage {
|
||
return response.Usage
|
||
}
|
||
|
||
func (m *mockProvider) MapError(err error) adapter.ProviderError {
|
||
return adapter.ProviderError{}
|
||
}
|
||
|
||
func (m *mockProvider) HealthCheck(ctx context.Context) bool {
|
||
return m.healthy
|
||
}
|
||
|
||
func (m *mockProvider) ProviderName() string {
|
||
return m.name
|
||
}
|
||
|
||
func (m *mockProvider) SupportedModels() []string {
|
||
return m.models
|
||
}
|
||
|
||
func TestNewHandler(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
h := NewHandler(r)
|
||
|
||
if h == nil {
|
||
t.Fatal("expected non-nil handler")
|
||
}
|
||
if h.version != "v1" {
|
||
t.Errorf("expected version v1, got %s", h.version)
|
||
}
|
||
}
|
||
|
||
func TestChatCompletionsHandle_InvalidRequest(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
h := NewHandler(r)
|
||
|
||
tests := []struct {
|
||
name string
|
||
body string
|
||
wantStatus int
|
||
}{
|
||
{
|
||
name: "invalid JSON",
|
||
body: "{invalid}",
|
||
wantStatus: 400,
|
||
},
|
||
{
|
||
name: "empty messages",
|
||
body: `{"model": "gpt-4", "messages": []}`,
|
||
wantStatus: 400,
|
||
},
|
||
{
|
||
name: "missing model - passes validation but no provider for empty model",
|
||
body: `{"messages": [{"role": "user", "content": "hello"}]}`,
|
||
wantStatus: 503,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(tt.body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.ChatCompletionsHandle(rr, req)
|
||
|
||
if rr.Code != tt.wantStatus {
|
||
t.Errorf("expected status %d, got %d", tt.wantStatus, rr.Code)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestChatCompletionsHandle_Success(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
|
||
r.RegisterProvider("test", prov)
|
||
|
||
h := NewHandler(r)
|
||
|
||
body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`
|
||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.ChatCompletionsHandle(rr, req)
|
||
|
||
if rr.Code != http.StatusOK {
|
||
t.Errorf("expected status 200, got %d", rr.Code)
|
||
}
|
||
|
||
var resp model.ChatCompletionResponse
|
||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||
t.Fatalf("failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
if resp.ID == "" {
|
||
t.Error("expected non-empty ID")
|
||
}
|
||
if resp.Object != "chat.completion" {
|
||
t.Errorf("expected object chat.completion, got %s", resp.Object)
|
||
}
|
||
if len(resp.Choices) != 1 {
|
||
t.Errorf("expected 1 choice, got %d", len(resp.Choices))
|
||
}
|
||
if resp.Choices[0].Message.Content != "Hello, world!" {
|
||
t.Errorf("unexpected content: %s", resp.Choices[0].Message.Content)
|
||
}
|
||
}
|
||
|
||
func TestChatCompletionsHandle_WithRequestID(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
|
||
r.RegisterProvider("test", prov)
|
||
|
||
h := NewHandler(r)
|
||
|
||
body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`
|
||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("X-Request-ID", "custom-req-id")
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.ChatCompletionsHandle(rr, req)
|
||
|
||
if rr.Header().Get("X-Request-ID") != "custom-req-id" {
|
||
t.Errorf("expected X-Request-ID custom-req-id, got %s", rr.Header().Get("X-Request-ID"))
|
||
}
|
||
}
|
||
|
||
func TestChatCompletionsHandle_ProviderError(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
// 不注册任何provider,会触发ROUTER_NO_PROVIDER_AVAILABLE
|
||
|
||
h := NewHandler(r)
|
||
|
||
body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`
|
||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.ChatCompletionsHandle(rr, req)
|
||
|
||
if rr.Code != 503 {
|
||
t.Errorf("expected status 503, got %d", rr.Code)
|
||
}
|
||
}
|
||
|
||
func TestCompletionsHandle_Success(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
|
||
r.RegisterProvider("test", prov)
|
||
|
||
h := NewHandler(r)
|
||
|
||
body := `{"model": "gpt-4", "prompt": "Say hello"}`
|
||
req := httptest.NewRequest("POST", "/v1/completions", bytes.NewBufferString(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.CompletionsHandle(rr, req)
|
||
|
||
if rr.Code != http.StatusOK {
|
||
t.Errorf("expected status 200, got %d", rr.Code)
|
||
}
|
||
|
||
var resp model.CompletionResponse
|
||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||
t.Fatalf("failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
if resp.Object != "text_completion" {
|
||
t.Errorf("expected object text_completion, got %s", resp.Object)
|
||
}
|
||
}
|
||
|
||
func TestCompletionsHandle_InvalidRequest(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
h := NewHandler(r)
|
||
|
||
body := `{invalid}`
|
||
req := httptest.NewRequest("POST", "/v1/completions", bytes.NewBufferString(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.CompletionsHandle(rr, req)
|
||
|
||
if rr.Code != http.StatusBadRequest {
|
||
t.Errorf("expected status 400, got %d", rr.Code)
|
||
}
|
||
}
|
||
|
||
func TestModelsHandle(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
h := NewHandler(r)
|
||
|
||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.ModelsHandle(rr, req)
|
||
|
||
if rr.Code != http.StatusOK {
|
||
t.Errorf("expected status 200, got %d", rr.Code)
|
||
}
|
||
|
||
var resp map[string]interface{}
|
||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||
t.Fatalf("failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
if resp["object"] != "list" {
|
||
t.Errorf("expected object list, got %v", resp["object"])
|
||
}
|
||
|
||
data, ok := resp["data"].([]interface{})
|
||
if !ok {
|
||
t.Fatal("expected data to be array")
|
||
}
|
||
if len(data) != 4 {
|
||
t.Errorf("expected 4 models, got %d", len(data))
|
||
}
|
||
}
|
||
|
||
func TestHealthHandle_AllHealthy(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
|
||
r.RegisterProvider("test", prov)
|
||
|
||
h := NewHandler(r)
|
||
|
||
req := httptest.NewRequest("GET", "/health", nil)
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.HealthHandle(rr, req)
|
||
|
||
if rr.Code != http.StatusOK {
|
||
t.Errorf("expected status 200, got %d", rr.Code)
|
||
}
|
||
|
||
var resp model.HealthStatus
|
||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||
t.Fatalf("failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
if resp.Status != "healthy" {
|
||
t.Errorf("expected status healthy, got %s", resp.Status)
|
||
}
|
||
}
|
||
|
||
func TestHealthHandle_Degraded(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
prov := &mockProvider{name: "unhealthy", models: []string{}, healthy: false}
|
||
r.RegisterProvider("unhealthy", prov)
|
||
// 标记为不可用
|
||
r.UpdateHealth("unhealthy", false)
|
||
|
||
h := NewHandler(r)
|
||
|
||
req := httptest.NewRequest("GET", "/health", nil)
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.HealthHandle(rr, req)
|
||
|
||
if rr.Code != http.StatusServiceUnavailable {
|
||
t.Errorf("expected status 503, got %d", rr.Code)
|
||
}
|
||
|
||
var resp model.HealthStatus
|
||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||
t.Fatalf("failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
if resp.Status != "degraded" {
|
||
t.Errorf("expected status degraded, got %s", resp.Status)
|
||
}
|
||
}
|
||
|
||
func TestWriteJSON(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
h := NewHandler(r)
|
||
|
||
w := httptest.NewRecorder()
|
||
data := map[string]string{"key": "value"}
|
||
|
||
h.writeJSON(w, http.StatusOK, data, "test-req-id")
|
||
|
||
if w.Code != http.StatusOK {
|
||
t.Errorf("expected status 200, got %d", w.Code)
|
||
}
|
||
if w.Header().Get("Content-Type") != "application/json" {
|
||
t.Errorf("expected Content-Type application/json, got %s", w.Header().Get("Content-Type"))
|
||
}
|
||
if w.Header().Get("X-Request-ID") != "test-req-id" {
|
||
t.Errorf("expected X-Request-ID test-req-id, got %s", w.Header().Get("X-Request-ID"))
|
||
}
|
||
}
|
||
|
||
func TestWriteError(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
h := NewHandler(r)
|
||
|
||
w := httptest.NewRecorder()
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
|
||
gwErr := gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "test error").WithRequestID("req-123")
|
||
|
||
h.writeError(w, req, gwErr)
|
||
|
||
if w.Code != 400 {
|
||
t.Errorf("expected status 400, got %d", w.Code)
|
||
}
|
||
|
||
var resp model.ErrorResponse
|
||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||
t.Fatalf("failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
if resp.Error.Message != "test error" {
|
||
t.Errorf("unexpected error message: %s", resp.Error.Message)
|
||
}
|
||
if resp.Error.Type != "gateway_error" {
|
||
t.Errorf("unexpected error type: %s", resp.Error.Type)
|
||
}
|
||
if resp.Error.Code != "COMMON_001" {
|
||
t.Errorf("unexpected error code: %s", resp.Error.Code)
|
||
}
|
||
}
|
||
|
||
func TestGenerateRequestID(t *testing.T) {
|
||
id1 := generateRequestID()
|
||
id2 := generateRequestID()
|
||
|
||
if id1 == "" {
|
||
t.Error("expected non-empty request ID")
|
||
}
|
||
if id1 == id2 {
|
||
t.Error("expected different request IDs")
|
||
}
|
||
|
||
if len(id1) < 10 {
|
||
t.Error("request ID seems too short")
|
||
}
|
||
}
|
||
|
||
func TestMarshalJSON(t *testing.T) {
|
||
data := map[string]string{"key": "value"}
|
||
result := marshalJSON(data)
|
||
|
||
if result != `{"key":"value"}` {
|
||
t.Errorf("unexpected JSON: %s", result)
|
||
}
|
||
}
|
||
|
||
func TestMarshalJSON_NilValues(t *testing.T) {
|
||
type testStruct struct {
|
||
Name *string
|
||
}
|
||
name := "test"
|
||
obj := testStruct{Name: &name}
|
||
result := marshalJSON(obj)
|
||
|
||
if result == "" {
|
||
t.Error("expected non-empty JSON")
|
||
}
|
||
}
|
||
|
||
// mockFailingProvider 用于测试流式处理失败的Provider
|
||
type mockFailingProvider struct {
|
||
mockProvider
|
||
}
|
||
|
||
func (m *mockFailingProvider) ChatCompletionStream(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
|
||
return nil, errors.New("stream error")
|
||
}
|
||
|
||
func TestHandleStream_ProviderError(t *testing.T) {
|
||
r := router.NewRouter(router.StrategyLatency)
|
||
prov := &mockFailingProvider{}
|
||
r.RegisterProvider("failing", prov)
|
||
|
||
h := NewHandler(r)
|
||
|
||
body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": true}`
|
||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
rr := httptest.NewRecorder()
|
||
|
||
h.ChatCompletionsHandle(rr, req)
|
||
|
||
// 流式请求失败时会写入错误
|
||
if rr.Code == 0 {
|
||
t.Log("stream error handled (code 0 means write error)")
|
||
}
|
||
}
|