diff --git a/internal/app/server_timeout_test.go b/internal/app/server_timeout_test.go new file mode 100644 index 00000000..76d1ac3f --- /dev/null +++ b/internal/app/server_timeout_test.go @@ -0,0 +1,69 @@ +package app + +import ( + "context" + "net/http" + "testing" + "time" +) + +func TestServer_TimeoutConfiguration_Real(t *testing.T) { + // 创建一个慢速处理器 + slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.Write([]byte("OK")) + }) + + server := NewServer("127.0.0.1:0", slowHandler, nil) + + // 验证超时配置已设置 + if server.server.ReadTimeout != 30*time.Second { + t.Errorf("ReadTimeout = %v, want 30s", server.server.ReadTimeout) + } + if server.server.ReadHeaderTimeout != 10*time.Second { + t.Errorf("ReadHeaderTimeout = %v, want 10s", server.server.ReadHeaderTimeout) + } + if server.server.WriteTimeout != 30*time.Second { + t.Errorf("WriteTimeout = %v, want 30s", server.server.WriteTimeout) + } + if server.server.IdleTimeout != 120*time.Second { + t.Errorf("IdleTimeout = %v, want 120s", server.server.IdleTimeout) + } + if server.server.MaxHeaderBytes != 1<<20 { + t.Errorf("MaxHeaderBytes = %d, want %d", server.server.MaxHeaderBytes, 1<<20) + } +} + +func TestServer_GracefulShutdown(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + }) + + server := NewServer("127.0.0.1:0", handler, nil) + + ctx, cancel := context.WithCancel(context.Background()) + + // 在后台启动服务 + go func() { + server.Run(ctx) + }() + + // 给服务启动时间 + time.Sleep(50 * time.Millisecond) + + // 发送取消信号触发关闭 + cancel() + + // 等待关闭完成 + time.Sleep(100 * time.Millisecond) + + // 验证服务可以正常关闭 + t.Log("Server shutdown gracefully") +} + +func TestServer_Addr(t *testing.T) { + server := NewServer(":8080", nil, nil) + if server.Addr() != ":8080" { + t.Errorf("Addr() = %q, want :8080", server.Addr()) + } +} diff --git a/internal/log/log_integration_test.go b/internal/log/log_integration_test.go new file mode 100644 index 00000000..669a0ab3 --- /dev/null +++ b/internal/log/log_integration_test.go @@ -0,0 +1,136 @@ +package log + +import ( + "log/slog" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestInitWithConfig_FileOutput_Real(t *testing.T) { + logDir := t.TempDir() + logFile := filepath.Join(logDir, "test.log") + + cfg := Config{ + Level: "info", + Output: logFile, + MaxSize: 1, // 1MB + } + + // 初始化日志 + InitWithConfig(cfg) + + // 写入日志 + Info("test message", "key", "value") + Error("error message", "err", "test error") + + // 验证文件被创建 + if _, err := os.Stat(logFile); os.IsNotExist(err) { + t.Fatal("Log file was not created") + } + + // 读取文件内容 + content, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + // 验证内容 + contentStr := string(content) + if !strings.Contains(contentStr, "test message") { + t.Error("Log file missing 'test message'") + } + if !strings.Contains(contentStr, "key") { + t.Error("Log file missing 'key' field") + } +} + +func TestSanitizeAttrs_Real(t *testing.T) { + tests := []struct { + name string + key string + value string + expected string + }{ + { + name: "token should be redacted", + key: "api_token", + value: "secret123", + expected: "[REDACTED]", + }, + { + name: "password should be redacted", + key: "user_password", + value: "mypassword", + expected: "[REDACTED]", + }, + { + name: "normal key should not be redacted", + key: "user_name", + value: "john", + expected: "john", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + attr := slog.String(tt.key, tt.value) + sanitized := sanitizeAttrs(nil, attr) + + if sanitized.Value.String() != tt.expected { + t.Errorf("sanitizeAttrs() = %q, want %q", sanitized.Value.String(), tt.expected) + } + }) + } +} + +func TestIsSensitive_Real(t *testing.T) { + sensitiveKeys := []string{ + "token", + "password", + "secret", + "api_key", + "private_key", + } + + nonSensitiveKeys := []string{ + "name", + "id", + "timestamp", + "message", + } + + for _, key := range sensitiveKeys { + if !IsSensitive(key) { + t.Errorf("IsSensitive(%q) = false, want true", key) + } + } + + for _, key := range nonSensitiveKeys { + if IsSensitive(key) { + t.Errorf("IsSensitive(%q) = true, want false", key) + } + } +} + +func TestParseLevel_Real(t *testing.T) { + tests := []struct { + input string + expected slog.Level + }{ + {"debug", slog.LevelDebug}, + {"info", slog.LevelInfo}, + {"warn", slog.LevelWarn}, + {"error", slog.LevelError}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + level := parseLevel(tt.input) + if level != tt.expected { + t.Errorf("parseLevel(%q) = %d, want %d", tt.input, level, tt.expected) + } + }) + } +} diff --git a/internal/metrics/metrics_integration_test.go b/internal/metrics/metrics_integration_test.go new file mode 100644 index 00000000..02315733 --- /dev/null +++ b/internal/metrics/metrics_integration_test.go @@ -0,0 +1,115 @@ +package metrics + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestMetricsEndpoint_RealHTTP(t *testing.T) { + // 设置一些指标值 + SetActiveHosts(42) + SetActiveProviders(5) + RecordLogFlushError() + RecordLogDroppedEvent() + + // 创建真实的 HTTP 服务器 + server := httptest.NewServer(Handler()) + defer server.Close() + + // 使用真实 HTTP 客户端访问 + resp, err := http.Get(server.URL + "/metrics") + if err != nil { + t.Fatalf("Failed to GET /metrics: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read body: %v", err) + } + + bodyStr := string(body) + + // 验证 Prometheus 格式 + requiredElements := []string{ + "# HELP", + "# TYPE", + "active_hosts 42", + "active_providers 5", + "log_flush_errors_total", + "log_dropped_events_total", + } + + for _, elem := range requiredElements { + if !strings.Contains(bodyStr, elem) { + t.Errorf("Response missing: %q", elem) + } + } + + // 验证内容类型 + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "text/plain") { + t.Errorf("Content-Type = %q, want text/plain", contentType) + } +} + +func TestMiddleware_RealRequest(t *testing.T) { + // 创建一个使用 middleware 的处理器 + handler := Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + w.Write([]byte("I'm a teapot")) + })) + + server := httptest.NewServer(handler) + defer server.Close() + + // 发送真实请求 + resp, err := http.Get(server.URL + "/test-path") + if err != nil { + t.Fatalf("Failed to GET: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusTeapot { + t.Errorf("Expected status %d, got %d", http.StatusTeapot, resp.StatusCode) + } + + body, _ := io.ReadAll(resp.Body) + if string(body) != "I'm a teapot" { + t.Errorf("Body = %q, want I'm a teapot", string(body)) + } +} + +func TestRecordHTTPRequest_RealMetrics(t *testing.T) { + // 记录请求 + RecordHTTPRequest("GET", "/api/test", 200, 100*time.Millisecond) + + // 启动服务器 + server := httptest.NewServer(Handler()) + defer server.Close() + + resp, err := http.Get(server.URL + "/metrics") + if err != nil { + t.Fatalf("Failed to GET: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + bodyStr := string(body) + + // 验证请求被记录 + if !strings.Contains(bodyStr, "http_requests_total") { + t.Error("Missing http_requests_total metric") + } + if !strings.Contains(bodyStr, "http_request_duration_seconds") { + t.Error("Missing http_request_duration_seconds metric") + } +} diff --git a/internal/routing/logwriter_metrics_test.go b/internal/routing/logwriter_metrics_test.go new file mode 100644 index 00000000..49f5a00e --- /dev/null +++ b/internal/routing/logwriter_metrics_test.go @@ -0,0 +1,250 @@ +package routing + +import ( + "context" + "errors" + "testing" + "time" +) + +// mockLogSink 用于测试错误指标 +type mockLogSink struct { + appendDecisionError error + appendFailoverError error + appendStickyError error + closeCalled bool +} + +func (m *mockLogSink) AppendDecision(ctx context.Context, event RouteDecisionEvent) error { + return m.appendDecisionError +} + +func (m *mockLogSink) AppendFailover(ctx context.Context, event RouteFailoverEvent) error { + return m.appendFailoverError +} + +func (m *mockLogSink) AppendStickyAudit(ctx context.Context, event RouteStickyAuditEvent) error { + return m.appendStickyError +} + +func (m *mockLogSink) Close() error { + m.closeCalled = true + return nil +} + +func TestErrorMetrics_RecordAndGet(t *testing.T) { + var em ErrorMetrics + + // 初始值应为 0 + if em.GetFlushErrors() != 0 { + t.Errorf("GetFlushErrors() = %d, want 0", em.GetFlushErrors()) + } + if em.GetWriteErrors() != 0 { + t.Errorf("GetWriteErrors() = %d, want 0", em.GetWriteErrors()) + } + if em.GetDroppedEvents() != 0 { + t.Errorf("GetDroppedEvents() = %d, want 0", em.GetDroppedEvents()) + } + + // 记录错误 + em.RecordFlushError() + em.RecordFlushError() + em.RecordWriteError() + em.RecordDroppedEvent() + em.RecordDroppedEvent() + em.RecordDroppedEvent() + + // 验证计数 + if em.GetFlushErrors() != 2 { + t.Errorf("GetFlushErrors() = %d, want 2", em.GetFlushErrors()) + } + if em.GetWriteErrors() != 1 { + t.Errorf("GetWriteErrors() = %d, want 1", em.GetWriteErrors()) + } + if em.GetDroppedEvents() != 3 { + t.Errorf("GetDroppedEvents() = %d, want 3", em.GetDroppedEvents()) + } +} + +func TestErrorMetrics_ConcurrentAccess(t *testing.T) { + var em ErrorMetrics + + // 并发记录错误 + done := make(chan bool, 3) + + go func() { + for i := 0; i < 100; i++ { + em.RecordFlushError() + } + done <- true + }() + + go func() { + for i := 0; i < 100; i++ { + em.RecordWriteError() + } + done <- true + }() + + go func() { + for i := 0; i < 100; i++ { + em.RecordDroppedEvent() + } + done <- true + }() + + // 等待所有 goroutine 完成 + for i := 0; i < 3; i++ { + <-done + } + + // 验证计数正确 + if em.GetFlushErrors() != 100 { + t.Errorf("GetFlushErrors() = %d, want 100", em.GetFlushErrors()) + } + if em.GetWriteErrors() != 100 { + t.Errorf("GetWriteErrors() = %d, want 100", em.GetWriteErrors()) + } + if em.GetDroppedEvents() != 100 { + t.Errorf("GetDroppedEvents() = %d, want 100", em.GetDroppedEvents()) + } +} + +func TestAsyncLogWriter_Metrics(t *testing.T) { + sink := &mockLogSink{ + appendDecisionError: errors.New("write error"), + } + + writer := NewAsyncLogWriter(sink, AsyncLogWriterOptions{ + QueueSize: 10, + FlushInterval: time.Hour, + MaxBatchSize: 2, + FallbackWriteTimeout: time.Second, + }) + defer writer.Close() + + // 触发写入错误 + _ = writer.AppendDecision(context.Background(), RouteDecisionEvent{ + RequestID: "test-1", + LogicalGroupID: "test-group", + }) + + // 等待 flush 完成 + time.Sleep(200 * time.Millisecond) + + // 强制执行 flush + _ = writer.Flush(context.Background()) + + // 验证指标被记录 + metrics := writer.Metrics() + + // 由于 batch flush 时会记录错误,应该有 flush error + if metrics.GetFlushErrors() == 0 { + t.Error("Expected FlushErrors > 0 after failed write") + } +} + +func TestAsyncLogWriter_ErrorHandler(t *testing.T) { + var handledErrors []string + errorHandler := func(ctx context.Context, err error, eventType string) { + handledErrors = append(handledErrors, eventType+":"+err.Error()) + } + + sink := &mockLogSink{ + appendDecisionError: errors.New("decision error"), + } + + writer := NewAsyncLogWriter(sink, AsyncLogWriterOptions{ + QueueSize: 10, + FlushInterval: time.Hour, + MaxBatchSize: 1, // 立即触发 flush + FallbackWriteTimeout: time.Second, + OnError: errorHandler, + }) + defer writer.Close() + + // 触发写入 + _ = writer.AppendDecision(context.Background(), RouteDecisionEvent{ + RequestID: "test-1", + LogicalGroupID: "test-group", + }) + + // 等待处理 + time.Sleep(200 * time.Millisecond) + _ = writer.Flush(context.Background()) + + // 验证错误处理器被调用 + if len(handledErrors) == 0 { + t.Error("Expected error handler to be called") + } +} + +func TestAsyncLogWriter_getEventType(t *testing.T) { + writer := &AsyncLogWriter{} + + tests := []struct { + name string + event queuedLogEvent + expected string + }{ + { + name: "decision", + event: queuedLogEvent{decision: &RouteDecisionEvent{}}, + expected: "decision", + }, + { + name: "failover", + event: queuedLogEvent{failover: &RouteFailoverEvent{}}, + expected: "failover", + }, + { + name: "sticky_audit", + event: queuedLogEvent{sticky: &RouteStickyAuditEvent{}}, + expected: "sticky_audit", + }, + { + name: "unknown", + event: queuedLogEvent{}, + expected: "unknown", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := writer.getEventType(tt.event) + if got != tt.expected { + t.Errorf("getEventType() = %q, want %q", got, tt.expected) + } + }) + } +} + +func TestAsyncLogWriter_DroppedEventMetrics(t *testing.T) { + sink := &mockLogSink{} + + // 创建只有 1 个缓冲区的 writer + writer := NewAsyncLogWriter(sink, AsyncLogWriterOptions{ + QueueSize: 1, + FlushInterval: time.Hour, // 不自动 flush + MaxBatchSize: 10, + FallbackWriteTimeout: time.Second, + }) + defer writer.Close() + + // 填满队列并触发丢弃 + for i := 0; i < 5; i++ { + _ = writer.AppendDecision(context.Background(), RouteDecisionEvent{ + RequestID: "test-" + string(rune('0'+i)), + LogicalGroupID: "test-group", + }) + } + + // 给 fallback 写入一点时间 + time.Sleep(100 * time.Millisecond) + + // 验证有事件被记录为丢弃 + metrics := writer.Metrics() + if metrics.GetDroppedEvents() == 0 { + t.Error("Expected DroppedEvents > 0 when queue is full") + } +}