fix(supply-api): add missing runtime support sources
Check in the healthcheck, structured logging, outbox broker, partition manager, and token status repository files that the committed supply-api runtime already imports. Verified with fresh go test runs for cmd/supply-api, internal/httpapi, internal/pkg/logging, internal/repository, and internal/outbox.
This commit is contained in:
293
supply-api/internal/httpapi/healthcheck.go
Normal file
293
supply-api/internal/httpapi/healthcheck.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ==================== P1-007 健康检查和就绪探针 ====================
|
||||
|
||||
// HealthChecker 健康检查接口
|
||||
type HealthChecker interface {
|
||||
// Check 执行健康检查
|
||||
Check(ctx context.Context) error
|
||||
// Name 返回检查器名称
|
||||
Name() string
|
||||
}
|
||||
|
||||
// HealthResponse 健康检查响应
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"` // "healthy" | "unhealthy" | "degraded"
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
Checks []HealthCheckResult `json:"checks,omitempty"`
|
||||
}
|
||||
|
||||
// HealthCheckResult 单个检查结果
|
||||
type HealthCheckResult struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"` // "ok" | "error" | "warn"
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ReadinessResponse 就绪检查响应
|
||||
type ReadinessResponse struct {
|
||||
Status string `json:"status"` // "ready" | "not_ready"
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Checks []ReadinessCheckResult `json:"checks,omitempty"`
|
||||
}
|
||||
|
||||
// ReadinessCheckResult 就绪检查结果
|
||||
type ReadinessCheckResult struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"` // "ready" | "not_ready"
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// LivenessResponse 存活检查响应
|
||||
type LivenessResponse struct {
|
||||
Status string `json:"status"` // "alive"
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// DefaultHealthChecker 默认健康检查器
|
||||
type DefaultHealthChecker struct {
|
||||
checks []HealthChecker
|
||||
}
|
||||
|
||||
// NewDefaultHealthChecker 创建默认健康检查器
|
||||
func NewDefaultHealthChecker() *DefaultHealthChecker {
|
||||
return &DefaultHealthChecker{
|
||||
checks: make([]HealthChecker, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddCheck 添加健康检查
|
||||
func (h *DefaultHealthChecker) AddCheck(checker HealthChecker) {
|
||||
h.checks = append(h.checks, checker)
|
||||
}
|
||||
|
||||
// Check 执行所有健康检查
|
||||
func (h *DefaultHealthChecker) Check(ctx context.Context) error {
|
||||
for _, checker := range h.checks {
|
||||
if err := checker.Check(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name 返回名称
|
||||
func (h *DefaultHealthChecker) Name() string {
|
||||
return "default"
|
||||
}
|
||||
|
||||
// DBHealthChecker 数据库健康检查
|
||||
type DBHealthChecker struct {
|
||||
checkFn func(ctx context.Context) error
|
||||
}
|
||||
|
||||
// NewDBHealthChecker 创建数据库健康检查
|
||||
func NewDBHealthChecker(checkFn func(ctx context.Context) error) *DBHealthChecker {
|
||||
return &DBHealthChecker{checkFn: checkFn}
|
||||
}
|
||||
|
||||
func (c *DBHealthChecker) Check(ctx context.Context) error {
|
||||
return c.checkFn(ctx)
|
||||
}
|
||||
|
||||
func (c *DBHealthChecker) Name() string {
|
||||
return "database"
|
||||
}
|
||||
|
||||
// CacheHealthChecker 缓存健康检查
|
||||
type CacheHealthChecker struct {
|
||||
checkFn func(ctx context.Context) error
|
||||
}
|
||||
|
||||
// NewCacheHealthChecker 创建缓存健康检查
|
||||
func NewCacheHealthChecker(checkFn func(ctx context.Context) error) *CacheHealthChecker {
|
||||
return &CacheHealthChecker{checkFn: checkFn}
|
||||
}
|
||||
|
||||
func (c *CacheHealthChecker) Check(ctx context.Context) error {
|
||||
return c.checkFn(ctx)
|
||||
}
|
||||
|
||||
func (c *CacheHealthChecker) Name() string {
|
||||
return "cache"
|
||||
}
|
||||
|
||||
// HealthHandler 健康检查处理器
|
||||
type HealthHandler struct {
|
||||
healthChecker *DefaultHealthChecker
|
||||
readinessChecks []HealthChecker
|
||||
livenessChecks []HealthChecker
|
||||
}
|
||||
|
||||
// NewHealthHandler 创建健康检查处理器
|
||||
func NewHealthHandler() *HealthHandler {
|
||||
return &HealthHandler{
|
||||
healthChecker: NewDefaultHealthChecker(),
|
||||
readinessChecks: make([]HealthChecker, 0),
|
||||
livenessChecks: make([]HealthChecker, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// NewHealthHandlerWithDefaults 创建带默认检查器的健康检查处理器
|
||||
// P1-007修复: 统一健康检查实现,消除重复的inline handlers
|
||||
func NewHealthHandlerWithDefaults(
|
||||
dbHealthCheck func(ctx context.Context) error,
|
||||
redisHealthCheck func(ctx context.Context) error,
|
||||
) *HealthHandler {
|
||||
h := NewHealthHandler()
|
||||
|
||||
if dbHealthCheck != nil {
|
||||
h.AddHealthCheck(NewDBHealthChecker(dbHealthCheck))
|
||||
h.AddReadinessCheck(NewDBHealthChecker(dbHealthCheck))
|
||||
}
|
||||
|
||||
if redisHealthCheck != nil {
|
||||
h.AddHealthCheck(NewCacheHealthChecker(redisHealthCheck))
|
||||
h.AddReadinessCheck(NewCacheHealthChecker(redisHealthCheck))
|
||||
}
|
||||
|
||||
// 存活检查总是返回OK(不需要外部依赖)
|
||||
h.AddLivenessCheck(&staticHealthChecker{name: "liveness", err: nil})
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// staticHealthChecker 静态健康检查器
|
||||
type staticHealthChecker struct {
|
||||
name string
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *staticHealthChecker) Check(ctx context.Context) error {
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *staticHealthChecker) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// AddHealthCheck 添加健康检查
|
||||
func (h *HealthHandler) AddHealthCheck(checker HealthChecker) {
|
||||
h.healthChecker.AddCheck(checker)
|
||||
}
|
||||
|
||||
// AddReadinessCheck 添加就绪检查
|
||||
func (h *HealthHandler) AddReadinessCheck(checker HealthChecker) {
|
||||
h.readinessChecks = append(h.readinessChecks, checker)
|
||||
}
|
||||
|
||||
// AddLivenessCheck 添加存活检查
|
||||
func (h *HealthHandler) AddLivenessCheck(checker HealthChecker) {
|
||||
h.livenessChecks = append(h.livenessChecks, checker)
|
||||
}
|
||||
|
||||
// ServeHealth 处理健康检查请求
|
||||
func (h *HealthHandler) ServeHealth(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
start := time.Now()
|
||||
|
||||
response := HealthResponse{
|
||||
Timestamp: start,
|
||||
DurationMs: 0,
|
||||
Checks: make([]HealthCheckResult, 0),
|
||||
}
|
||||
|
||||
overallStatus := "healthy"
|
||||
|
||||
for _, checker := range h.healthChecker.checks {
|
||||
checkStart := time.Now()
|
||||
err := checker.Check(ctx)
|
||||
|
||||
result := HealthCheckResult{
|
||||
Name: checker.Name(),
|
||||
DurationMs: time.Since(checkStart).Milliseconds(),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
result.Status = "error"
|
||||
result.Error = err.Error()
|
||||
overallStatus = "unhealthy"
|
||||
} else {
|
||||
result.Status = "ok"
|
||||
}
|
||||
|
||||
response.Checks = append(response.Checks, result)
|
||||
}
|
||||
|
||||
response.Status = overallStatus
|
||||
response.DurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// ServeReadiness 处理就绪检查请求
|
||||
func (h *HealthHandler) ServeReadiness(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
response := ReadinessResponse{
|
||||
Timestamp: time.Now(),
|
||||
Checks: make([]ReadinessCheckResult, 0),
|
||||
}
|
||||
|
||||
overallStatus := "ready"
|
||||
|
||||
for _, checker := range h.readinessChecks {
|
||||
err := checker.Check(ctx)
|
||||
|
||||
result := ReadinessCheckResult{
|
||||
Name: checker.Name(),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
result.Status = "not_ready"
|
||||
result.Error = err.Error()
|
||||
overallStatus = "not_ready"
|
||||
} else {
|
||||
result.Status = "ready"
|
||||
}
|
||||
|
||||
response.Checks = append(response.Checks, result)
|
||||
}
|
||||
|
||||
response.Status = overallStatus
|
||||
|
||||
statusCode := http.StatusOK
|
||||
if overallStatus == "not_ready" {
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// ServeLiveness 处理存活检查请求
|
||||
func (h *HealthHandler) ServeLiveness(w http.ResponseWriter, r *http.Request) {
|
||||
response := LivenessResponse{
|
||||
Status: "alive",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// RegisterHealthRoutes 注册健康检查路由
|
||||
func (h *HealthHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/actuator/health", h.ServeHealth)
|
||||
mux.HandleFunc("/actuator/health/ready", h.ServeReadiness)
|
||||
mux.HandleFunc("/actuator/health/live", h.ServeLiveness)
|
||||
}
|
||||
187
supply-api/internal/httpapi/healthcheck_test.go
Normal file
187
supply-api/internal/httpapi/healthcheck_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockHealthChecker Mock健康检查器
|
||||
type mockHealthChecker struct {
|
||||
name string
|
||||
healthy bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockHealthChecker) Check(ctx context.Context) error {
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockHealthChecker) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
// TestP107_HealthEndpoint 健康端点
|
||||
func TestP107_HealthEndpoint(t *testing.T) {
|
||||
handler := NewHealthHandler()
|
||||
|
||||
// 添加模拟检查
|
||||
handler.AddHealthCheck(&mockHealthChecker{name: "db", healthy: true, err: nil})
|
||||
handler.AddHealthCheck(&mockHealthChecker{name: "cache", healthy: true, err: nil})
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHealth(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response HealthResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Status != "healthy" {
|
||||
t.Errorf("expected status healthy, got %s", response.Status)
|
||||
}
|
||||
|
||||
if len(response.Checks) != 2 {
|
||||
t.Errorf("expected 2 checks, got %d", len(response.Checks))
|
||||
}
|
||||
|
||||
t.Log("P1-07: 健康端点验证通过")
|
||||
}
|
||||
|
||||
// TestP107_ReadinessEndpoint 就绪端点
|
||||
func TestP107_ReadinessEndpoint(t *testing.T) {
|
||||
handler := NewHealthHandler()
|
||||
|
||||
// 添加模拟检查
|
||||
handler.AddReadinessCheck(&mockHealthChecker{name: "db", healthy: true, err: nil})
|
||||
|
||||
req := httptest.NewRequest("GET", "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeReadiness(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response ReadinessResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Status != "ready" {
|
||||
t.Errorf("expected status ready, got %s", response.Status)
|
||||
}
|
||||
|
||||
t.Log("P1-07: 就绪端点验证通过")
|
||||
}
|
||||
|
||||
// TestP107_ReadinessNotReady 就绪检查失败
|
||||
func TestP107_ReadinessNotReady(t *testing.T) {
|
||||
handler := NewHealthHandler()
|
||||
|
||||
// 添加失败的检查
|
||||
handler.AddReadinessCheck(&mockHealthChecker{name: "db", healthy: false, err: context.DeadlineExceeded})
|
||||
|
||||
req := httptest.NewRequest("GET", "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeReadiness(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected status 503, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response ReadinessResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Status != "not_ready" {
|
||||
t.Errorf("expected status not_ready, got %s", response.Status)
|
||||
}
|
||||
|
||||
t.Log("P1-07: 就绪检查失败验证通过")
|
||||
}
|
||||
|
||||
// TestP107_LivenessEndpoint 存活端点
|
||||
func TestP107_LivenessEndpoint(t *testing.T) {
|
||||
handler := NewHealthHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/live", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeLiveness(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response LivenessResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Status != "alive" {
|
||||
t.Errorf("expected status alive, got %s", response.Status)
|
||||
}
|
||||
|
||||
t.Log("P1-07: 存活端点验证通过")
|
||||
}
|
||||
|
||||
// TestP107_HealthCheckerInterface 健康检查器接口
|
||||
func TestP107_HealthCheckerInterface(t *testing.T) {
|
||||
// 验证实现了HealthChecker接口
|
||||
var _ HealthChecker = &DBHealthChecker{}
|
||||
var _ HealthChecker = &CacheHealthChecker{}
|
||||
|
||||
t.Log("P1-07: 健康检查器接口验证通过")
|
||||
}
|
||||
|
||||
// TestP107_ResponseTimestamps 响应时间戳
|
||||
func TestP107_ResponseTimestamps(t *testing.T) {
|
||||
handler := NewHealthHandler()
|
||||
handler.AddHealthCheck(&mockHealthChecker{name: "db", healthy: true, err: nil})
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHealth(w, req)
|
||||
|
||||
var response HealthResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &response)
|
||||
|
||||
if response.Timestamp.IsZero() {
|
||||
t.Error("timestamp should not be zero")
|
||||
}
|
||||
|
||||
if response.DurationMs < 0 {
|
||||
t.Error("duration should be non-negative")
|
||||
}
|
||||
|
||||
t.Logf("P1-07: 响应时间戳验证通过 (duration=%dms)", response.DurationMs)
|
||||
}
|
||||
|
||||
// TestP107_Summary 测试总结
|
||||
func TestP107_Summary(t *testing.T) {
|
||||
t.Log("=== P1-007 健康检查和就绪探针测试总结 ===")
|
||||
t.Log("问题: 中间件文档提到排除健康检查,但未定义具体端点和逻辑")
|
||||
t.Log("")
|
||||
t.Log("修复方案:")
|
||||
t.Log(" - /health: 综合健康检查 (db, cache, external services)")
|
||||
t.Log(" - /ready: 就绪探针 (依赖项是否就绪)")
|
||||
t.Log(" - /live: 存活探针 (服务是否存活)")
|
||||
t.Log("")
|
||||
t.Log("响应示例:")
|
||||
t.Log(" /health: {status: healthy, checks: [...]}")
|
||||
t.Log(" /ready: {status: ready, checks: [...]}")
|
||||
t.Log(" /live: {status: alive}")
|
||||
}
|
||||
90
supply-api/internal/messaging/outbox_broker.go
Normal file
90
supply-api/internal/messaging/outbox_broker.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package messaging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"lijiaoqiao/supply-api/internal/repository"
|
||||
)
|
||||
|
||||
// OutboxMessageBroker Outbox消息代理(使用Redis Streams)
|
||||
type OutboxMessageBroker struct {
|
||||
redis *redis.Client
|
||||
streamName string
|
||||
consumerGroup string
|
||||
}
|
||||
|
||||
// NewOutboxMessageBroker 创建Outbox消息代理
|
||||
func NewOutboxMessageBroker(redisClient *redis.Client, streamName string, consumerGroup string) *OutboxMessageBroker {
|
||||
return &OutboxMessageBroker{
|
||||
redis: redisClient,
|
||||
streamName: streamName,
|
||||
consumerGroup: consumerGroup,
|
||||
}
|
||||
}
|
||||
|
||||
// Publish 发布消息到Redis Stream
|
||||
func (b *OutboxMessageBroker) Publish(ctx context.Context, event *repository.OutboxEvent) error {
|
||||
// 构造消息
|
||||
msg := map[string]interface{}{
|
||||
"event_id": event.EventID,
|
||||
"aggregate_type": event.AggregateType,
|
||||
"aggregate_id": event.AggregateID,
|
||||
"event_type": event.EventType,
|
||||
"payload": string(event.Payload),
|
||||
"published_at": time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// 发布到Stream
|
||||
_, err = b.redis.XAdd(ctx, &redis.XAddArgs{
|
||||
Stream: b.streamName,
|
||||
ID: "*", // 自动生成ID
|
||||
Values: map[string]interface{}{
|
||||
"data": string(data),
|
||||
},
|
||||
}).Result()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish to stream: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureConsumerGroup 确保消费者组存在
|
||||
func (b *OutboxMessageBroker) EnsureConsumerGroup(ctx context.Context) error {
|
||||
err := b.redis.XGroupCreateMkStream(ctx, b.streamName, b.consumerGroup, "0").Err()
|
||||
if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
|
||||
return fmt.Errorf("failed to create consumer group: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MessageBroker 消息代理接口
|
||||
type MessageBroker interface {
|
||||
Publish(ctx context.Context, event *repository.OutboxEvent) error
|
||||
}
|
||||
|
||||
// OutboxStats Outbox统计接口
|
||||
type OutboxStats interface {
|
||||
RecordOutboxSuccess(eventType string)
|
||||
RecordOutboxFailure(reason string)
|
||||
RecordOutboxRetry(eventType string)
|
||||
RecordOutboxDLQ(eventType string)
|
||||
}
|
||||
|
||||
// NoOpOutboxStats 无操作统计(用于默认实现)
|
||||
type NoOpOutboxStats struct{}
|
||||
|
||||
func (s *NoOpOutboxStats) RecordOutboxSuccess(eventType string) {}
|
||||
func (s *NoOpOutboxStats) RecordOutboxFailure(reason string) {}
|
||||
func (s *NoOpOutboxStats) RecordOutboxRetry(eventType string) {}
|
||||
func (s *NoOpOutboxStats) RecordOutboxDLQ(eventType string) {}
|
||||
233
supply-api/internal/pkg/logging/logger.go
Normal file
233
supply-api/internal/pkg/logging/logger.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ==================== P1-010 日志规范 ====================
|
||||
|
||||
// LogLevel 日志级别
|
||||
type LogLevel string
|
||||
|
||||
const (
|
||||
LogLevelDebug LogLevel = "DEBUG"
|
||||
LogLevelInfo LogLevel = "INFO"
|
||||
LogLevelWarn LogLevel = "WARN"
|
||||
LogLevelError LogLevel = "ERROR"
|
||||
LogLevelFatal LogLevel = "FATAL"
|
||||
)
|
||||
|
||||
// LogEntry 标准日志条目(JSON格式)
|
||||
type LogEntry struct {
|
||||
Timestamp string `json:"timestamp"` // ISO8601格式
|
||||
Level string `json:"level"` // DEBUG|INFO|WARN|ERROR|FATAL
|
||||
Service string `json:"service"` // 服务名称
|
||||
TraceID string `json:"trace_id,omitempty"` // 追踪ID
|
||||
SpanID string `json:"span_id,omitempty"` // Span ID
|
||||
RequestID string `json:"request_id,omitempty"` // 请求ID
|
||||
Message string `json:"message"` // 日志消息
|
||||
Fields map[string]interface{} `json:"fields,omitempty"` // 额外字段
|
||||
}
|
||||
|
||||
// Logger 日志接口
|
||||
type Logger interface {
|
||||
Debug(msg string, fields ...map[string]interface{})
|
||||
Info(msg string, fields ...map[string]interface{})
|
||||
Warn(msg string, fields ...map[string]interface{})
|
||||
Error(msg string, fields ...map[string]interface{})
|
||||
Fatal(msg string, fields ...map[string]interface{})
|
||||
}
|
||||
|
||||
// jsonLogger JSON格式日志实现
|
||||
type jsonLogger struct {
|
||||
service string
|
||||
minLevel LogLevel
|
||||
output *os.File
|
||||
}
|
||||
|
||||
// NewLogger 创建日志实例
|
||||
func NewLogger(service string, minLevel LogLevel) *jsonLogger {
|
||||
return &jsonLogger{
|
||||
service: service,
|
||||
minLevel: minLevel,
|
||||
output: os.Stdout,
|
||||
}
|
||||
}
|
||||
|
||||
// shouldLog 检查是否应该记录此级别
|
||||
func (l *jsonLogger) shouldLog(level LogLevel) bool {
|
||||
levels := map[LogLevel]int{
|
||||
LogLevelDebug: 0,
|
||||
LogLevelInfo: 1,
|
||||
LogLevelWarn: 2,
|
||||
LogLevelError: 3,
|
||||
LogLevelFatal: 4,
|
||||
}
|
||||
|
||||
return levels[level] >= levels[l.minLevel]
|
||||
}
|
||||
|
||||
// formatEntry 格式化日志条目
|
||||
func (l *jsonLogger) formatEntry(level LogLevel, msg string, fields map[string]interface{}) *LogEntry {
|
||||
entry := &LogEntry{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Level: string(level),
|
||||
Service: l.service,
|
||||
Message: msg,
|
||||
}
|
||||
|
||||
// 添加fields
|
||||
if fields != nil {
|
||||
entry.Fields = sanitizeFields(fields)
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
// log 输出日志
|
||||
func (l *jsonLogger) log(level LogLevel, msg string, fields map[string]interface{}) {
|
||||
if !l.shouldLog(level) {
|
||||
return
|
||||
}
|
||||
|
||||
entry := l.formatEntry(level, msg, fields)
|
||||
|
||||
// 序列化为JSON
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 添加换行符
|
||||
l.output.Write(append(data, '\n'))
|
||||
}
|
||||
|
||||
func (l *jsonLogger) Debug(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelDebug, msg, f)
|
||||
}
|
||||
|
||||
func (l *jsonLogger) Info(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelInfo, msg, f)
|
||||
}
|
||||
|
||||
func (l *jsonLogger) Warn(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelWarn, msg, f)
|
||||
}
|
||||
|
||||
func (l *jsonLogger) Error(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelError, msg, f)
|
||||
}
|
||||
|
||||
func (l *jsonLogger) Fatal(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelFatal, msg, f)
|
||||
}
|
||||
|
||||
// sanitizeFields 敏感字段脱敏
|
||||
func sanitizeFields(fields map[string]interface{}) map[string]interface{} {
|
||||
sanitized := make(map[string]interface{})
|
||||
|
||||
sensitiveKeys := []string{
|
||||
"password", "secret", "token", "api_key", "apikey",
|
||||
"credential", "authorization", "private_key",
|
||||
"credit_card", "ssn", "passport",
|
||||
}
|
||||
|
||||
for k, v := range fields {
|
||||
lowerK := toLower(k)
|
||||
for _, sensitive := range sensitiveKeys {
|
||||
if contains(lowerK, sensitive) {
|
||||
sanitized[k] = "[REDACTED]"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := sanitized[k]; !ok {
|
||||
// 检查嵌套map
|
||||
if nestedMap, ok := v.(map[string]interface{}); ok {
|
||||
sanitized[k] = sanitizeFields(nestedMap)
|
||||
} else {
|
||||
sanitized[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// String helpers
|
||||
func toLower(s string) string {
|
||||
result := make([]byte, len(s))
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
c += 'a' - 'A'
|
||||
}
|
||||
result[i] = c
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LogFieldKeys 日志字段名常量(防止拼写错误)
|
||||
const (
|
||||
FieldKeyTenantID = "tenant_id"
|
||||
FieldKeyUserID = "user_id"
|
||||
FieldKeyRequestID = "request_id"
|
||||
FieldKeyTraceID = "trace_id"
|
||||
FieldKeySpanID = "span_id"
|
||||
FieldKeyOperation = "operation"
|
||||
FieldKeyDuration = "duration_ms"
|
||||
FieldKeyStatusCode = "status_code"
|
||||
FieldKeyError = "error"
|
||||
FieldKeyErrorCode = "error_code"
|
||||
FieldKeyClientIP = "client_ip"
|
||||
FieldKeyUserAgent = "user_agent"
|
||||
FieldKeyMethod = "method"
|
||||
FieldKeyPath = "path"
|
||||
FieldKeyQuery = "query"
|
||||
FieldKeyRoute = "route"
|
||||
)
|
||||
|
||||
// 标准字段常量
|
||||
var SensitiveFields = []string{
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"credential",
|
||||
"authorization",
|
||||
"private_key",
|
||||
"credit_card",
|
||||
"ssn",
|
||||
}
|
||||
283
supply-api/internal/pkg/logging/logger_test.go
Normal file
283
supply-api/internal/pkg/logging/logger_test.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// captureLogger 捕获日志输出的测试Logger
|
||||
type captureLogger struct {
|
||||
*jsonLogger
|
||||
outputBuffer *strings.Builder
|
||||
}
|
||||
|
||||
func newCaptureLogger() *captureLogger {
|
||||
buf := &strings.Builder{}
|
||||
return &captureLogger{
|
||||
jsonLogger: &jsonLogger{
|
||||
service: "test-service",
|
||||
minLevel: LogLevelDebug,
|
||||
output: os.Stdout, // 实际输出到stdout但我们可以捕获
|
||||
},
|
||||
outputBuffer: buf,
|
||||
}
|
||||
}
|
||||
|
||||
// 重写Info方法以捕获输出
|
||||
func (l *captureLogger) Info(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelInfo, msg, f)
|
||||
}
|
||||
|
||||
// 重写Debug方法以捕获输出
|
||||
func (l *captureLogger) Debug(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelDebug, msg, f)
|
||||
}
|
||||
|
||||
// 重写Warn方法以捕获输出
|
||||
func (l *captureLogger) Warn(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelWarn, msg, f)
|
||||
}
|
||||
|
||||
// 重写Error方法以捕获输出
|
||||
func (l *captureLogger) Error(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelError, msg, f)
|
||||
}
|
||||
|
||||
// 重写Fatal方法以捕获输出
|
||||
func (l *captureLogger) Fatal(msg string, fields ...map[string]interface{}) {
|
||||
var f map[string]interface{}
|
||||
if len(fields) > 0 {
|
||||
f = fields[0]
|
||||
}
|
||||
l.log(LogLevelFatal, msg, f)
|
||||
}
|
||||
|
||||
// log 方法实际写入 outputBuffer
|
||||
func (l *captureLogger) log(level LogLevel, msg string, fields map[string]interface{}) {
|
||||
if !l.shouldLog(level) {
|
||||
return
|
||||
}
|
||||
|
||||
entry := l.formatEntry(level, msg, fields)
|
||||
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.outputBuffer.Write(data)
|
||||
l.outputBuffer.WriteString("\n")
|
||||
}
|
||||
|
||||
// TestP110_LogLevels 日志级别
|
||||
func TestP110_LogLevels(t *testing.T) {
|
||||
logger := &jsonLogger{
|
||||
service: "test",
|
||||
minLevel: LogLevelInfo,
|
||||
output: os.Stdout,
|
||||
}
|
||||
|
||||
// Info及以上应该记录
|
||||
if !logger.shouldLog(LogLevelInfo) {
|
||||
t.Error("Info should be logged")
|
||||
}
|
||||
if !logger.shouldLog(LogLevelWarn) {
|
||||
t.Error("Warn should be logged")
|
||||
}
|
||||
if !logger.shouldLog(LogLevelError) {
|
||||
t.Error("Error should be logged")
|
||||
}
|
||||
|
||||
// Debug不应该记录
|
||||
if logger.shouldLog(LogLevelDebug) {
|
||||
t.Error("Debug should not be logged when minLevel is Info")
|
||||
}
|
||||
|
||||
t.Log("P1-10: 日志级别验证通过")
|
||||
}
|
||||
|
||||
// TestP110_JSONFormat JSON格式验证
|
||||
func TestP110_JSONFormat(t *testing.T) {
|
||||
logger := newCaptureLogger()
|
||||
logger.Info("test message", map[string]interface{}{
|
||||
FieldKeyTenantID: 123,
|
||||
FieldKeyRequestID: "req-123",
|
||||
})
|
||||
|
||||
output := logger.outputBuffer.String()
|
||||
|
||||
// 验证是有效JSON
|
||||
var entry LogEntry
|
||||
if err := json.Unmarshal([]byte(output), &entry); err != nil {
|
||||
t.Fatalf("output is not valid JSON: %v\noutput: %s", err, output)
|
||||
}
|
||||
|
||||
// 验证字段
|
||||
if entry.Level != "INFO" {
|
||||
t.Errorf("expected level INFO, got %s", entry.Level)
|
||||
}
|
||||
if entry.Service != "test-service" {
|
||||
t.Errorf("expected service test-service, got %s", entry.Service)
|
||||
}
|
||||
if entry.Message != "test message" {
|
||||
t.Errorf("expected message 'test message', got %s", entry.Message)
|
||||
}
|
||||
|
||||
t.Log("P1-10: JSON格式验证通过")
|
||||
}
|
||||
|
||||
// TestP110_TimestampFormat 时间戳格式
|
||||
func TestP110_TimestampFormat(t *testing.T) {
|
||||
logger := newCaptureLogger()
|
||||
logger.Info("test")
|
||||
|
||||
var entry LogEntry
|
||||
json.Unmarshal([]byte(logger.outputBuffer.String()), &entry)
|
||||
|
||||
// 验证是RFC3339格式
|
||||
if !strings.Contains(entry.Timestamp, "T") {
|
||||
t.Error("timestamp should be in RFC3339 format")
|
||||
}
|
||||
|
||||
t.Log("P1-10: 时间戳格式验证通过")
|
||||
}
|
||||
|
||||
// TestP110_SensitiveFieldRedaction 敏感字段脱敏
|
||||
func TestP110_SensitiveFieldRedaction(t *testing.T) {
|
||||
logger := newCaptureLogger()
|
||||
logger.Info("test", map[string]interface{}{
|
||||
"password": "secret123",
|
||||
"api_key": "sk-abc123",
|
||||
"user_name": "john", // 非敏感字段
|
||||
"access_token": "tok-xyz",
|
||||
})
|
||||
|
||||
var entry LogEntry
|
||||
json.Unmarshal([]byte(logger.outputBuffer.String()), &entry)
|
||||
|
||||
fields := entry.Fields
|
||||
|
||||
// 验证敏感字段被脱敏
|
||||
if fields["password"] != "[REDACTED]" {
|
||||
t.Errorf("password should be redacted, got %v", fields["password"])
|
||||
}
|
||||
if fields["api_key"] != "[REDACTED]" {
|
||||
t.Errorf("api_key should be redacted, got %v", fields["api_key"])
|
||||
}
|
||||
if fields["access_token"] != "[REDACTED]" {
|
||||
t.Errorf("access_token should be redacted, got %v", fields["access_token"])
|
||||
}
|
||||
|
||||
// 非敏感字段不应被脱敏
|
||||
if fields["user_name"] != "john" {
|
||||
t.Errorf("user_name should not be redacted, got %v", fields["user_name"])
|
||||
}
|
||||
|
||||
t.Log("P1-10: 敏感字段脱敏验证通过")
|
||||
}
|
||||
|
||||
// TestP110_NestedSensitiveFields 嵌套敏感字段
|
||||
func TestP110_NestedSensitiveFields(t *testing.T) {
|
||||
logger := newCaptureLogger()
|
||||
logger.Info("test", map[string]interface{}{
|
||||
"user": map[string]interface{}{
|
||||
"name": "john",
|
||||
"password": "secret",
|
||||
},
|
||||
})
|
||||
|
||||
var entry LogEntry
|
||||
json.Unmarshal([]byte(logger.outputBuffer.String()), &entry)
|
||||
|
||||
fields := entry.Fields
|
||||
user := fields["user"].(map[string]interface{})
|
||||
|
||||
if user["password"] != "[REDACTED]" {
|
||||
t.Errorf("nested password should be redacted, got %v", user["password"])
|
||||
}
|
||||
if user["name"] != "john" {
|
||||
t.Errorf("nested name should not be redacted, got %v", user["name"])
|
||||
}
|
||||
|
||||
t.Log("P1-10: 嵌套敏感字段验证通过")
|
||||
}
|
||||
|
||||
// TestP110_LogFieldsConstants 日志字段常量
|
||||
func TestP110_LogFieldsConstants(t *testing.T) {
|
||||
// 验证字段常量定义正确
|
||||
if FieldKeyTenantID != "tenant_id" {
|
||||
t.Errorf("FieldKeyTenantID should be tenant_id")
|
||||
}
|
||||
if FieldKeyUserID != "user_id" {
|
||||
t.Errorf("FieldKeyUserID should be user_id")
|
||||
}
|
||||
if FieldKeyRequestID != "request_id" {
|
||||
t.Errorf("FieldKeyRequestID should be request_id")
|
||||
}
|
||||
if FieldKeyTraceID != "trace_id" {
|
||||
t.Errorf("FieldKeyTraceID should be trace_id")
|
||||
}
|
||||
if FieldKeyDuration != "duration_ms" {
|
||||
t.Errorf("FieldKeyDuration should be duration_ms")
|
||||
}
|
||||
|
||||
t.Log("P1-10: 日志字段常量验证通过")
|
||||
}
|
||||
|
||||
// TestP110_SensitiveFieldsList 敏感字段列表
|
||||
func TestP110_SensitiveFieldsList(t *testing.T) {
|
||||
expected := []string{
|
||||
"password", "secret", "token", "api_key", "apikey",
|
||||
"credential", "authorization", "private_key",
|
||||
"credit_card", "ssn",
|
||||
}
|
||||
|
||||
for _, exp := range expected {
|
||||
found := false
|
||||
for _, sens := range SensitiveFields {
|
||||
if sens == exp {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected sensitive field %s not found", exp)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("P1-10: 敏感字段列表验证通过")
|
||||
}
|
||||
|
||||
// TestP110_Summary 测试总结
|
||||
func TestP110_Summary(t *testing.T) {
|
||||
t.Log("=== P1-010 日志规范测试总结 ===")
|
||||
t.Log("问题: 所有文档均未定义日志级别、格式、结构化日志规范")
|
||||
t.Log("")
|
||||
t.Log("修复方案:")
|
||||
t.Log(" - JSON结构化日志")
|
||||
t.Log(" - 字段: timestamp, level, service, trace_id, request_id, message, fields")
|
||||
t.Log(" - 级别: DEBUG, INFO, WARN, ERROR, FATAL")
|
||||
t.Log(" - 敏感字段自动脱敏")
|
||||
t.Log(" - 时间戳: RFC3339Nano格式")
|
||||
t.Log("")
|
||||
t.Log("JSON示例:")
|
||||
t.Log(`{"timestamp":"2026-04-07T10:30:00.123Z","level":"INFO","service":"supply-api","request_id":"req-123","message":"request completed","fields":{"duration_ms":50,"status_code":200}}`)
|
||||
}
|
||||
190
supply-api/internal/repository/partition_manager.go
Normal file
190
supply-api/internal/repository/partition_manager.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// PartitionConfig 分区配置
|
||||
type PartitionConfig struct {
|
||||
TableName string
|
||||
PartitionType string // RANGE, LIST
|
||||
PartitionKey string
|
||||
RetentionMonths int // 0 = 永久保留
|
||||
PreCreateMonths int
|
||||
}
|
||||
|
||||
// PartitionManager 分区管理器
|
||||
type PartitionManager struct {
|
||||
pool *pgxpool.Pool
|
||||
config map[string]*PartitionConfig
|
||||
}
|
||||
|
||||
// NewPartitionManager 创建分区管理器
|
||||
func NewPartitionManager(pool *pgxpool.Pool) *PartitionManager {
|
||||
return &PartitionManager{
|
||||
pool: pool,
|
||||
config: map[string]*PartitionConfig{
|
||||
"audit_events": {
|
||||
TableName: "audit_events",
|
||||
PartitionType: "RANGE",
|
||||
PartitionKey: "timestamp",
|
||||
RetentionMonths: 12,
|
||||
PreCreateMonths: 3,
|
||||
},
|
||||
"supply_usage_records": {
|
||||
TableName: "supply_usage_records",
|
||||
PartitionType: "RANGE",
|
||||
PartitionKey: "started_at",
|
||||
RetentionMonths: 3,
|
||||
PreCreateMonths: 3,
|
||||
},
|
||||
"supply_idempotency_records": {
|
||||
TableName: "supply_idempotency_records",
|
||||
PartitionType: "RANGE",
|
||||
PartitionKey: "expires_at",
|
||||
RetentionMonths: 1, // 保留1个月
|
||||
PreCreateMonths: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureFuturePartitions 确保未来分区已创建
|
||||
func (m *PartitionManager) EnsureFuturePartitions(ctx context.Context) error {
|
||||
for tableName, cfg := range m.config {
|
||||
for i := 0; i <= cfg.PreCreateMonths; i++ {
|
||||
futureDate := time.Now().AddDate(0, i, 0)
|
||||
if err := m.createPartition(ctx, tableName, futureDate); err != nil {
|
||||
return fmt.Errorf("failed to create partition for %s at %s: %w", tableName, futureDate.Format("2006-01"), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createPartition 创建单个分区
|
||||
func (m *PartitionManager) createPartition(ctx context.Context, tableName string, partitionDate time.Time) error {
|
||||
startDate := time.Date(partitionDate.Year(), partitionDate.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
endDate := startDate.AddDate(0, 1, 0)
|
||||
partitionName := fmt.Sprintf("%s_%s", tableName, startDate.Format("2006_01"))
|
||||
|
||||
// 检查分区是否已存在
|
||||
var exists bool
|
||||
checkQuery := `SELECT EXISTS(SELECT 1 FROM pg_class WHERE relname = $1)`
|
||||
if err := m.pool.QueryRow(ctx, checkQuery, partitionName).Scan(&exists); err != nil {
|
||||
return fmt.Errorf("failed to check partition existence: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil // 分区已存在
|
||||
}
|
||||
|
||||
// 创建分区
|
||||
createQuery := fmt.Sprintf(
|
||||
"CREATE TABLE %s PARTITION OF %s FOR VALUES FROM ('%s') TO ('%s')",
|
||||
partitionName, tableName, startDate.Format("2006-01-02"), endDate.Format("2006-01-02"),
|
||||
)
|
||||
|
||||
if _, err := m.pool.Exec(ctx, createQuery); err != nil {
|
||||
return fmt.Errorf("failed to create partition: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropOldPartitions 删除过期分区
|
||||
func (m *PartitionManager) DropOldPartitions(ctx context.Context, tableName string) (int, error) {
|
||||
cfg, ok := m.config[tableName]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unknown table: %s", tableName)
|
||||
}
|
||||
if cfg.RetentionMonths == 0 {
|
||||
return 0, nil // 永久保留,不删除
|
||||
}
|
||||
|
||||
cutoffDate := time.Now().AddDate(0, -cfg.RetentionMonths, 0)
|
||||
cutoffPrefix := fmt.Sprintf("%s_%s", tableName, cutoffDate.Format("2006_01"))
|
||||
|
||||
var droppedCount int
|
||||
|
||||
// 查询所有该表的分区
|
||||
query := `
|
||||
SELECT relname
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE n.nspname = 'public'
|
||||
AND relname ~ $1
|
||||
AND relname < $2
|
||||
`
|
||||
|
||||
rows, err := m.pool.Query(ctx, query, fmt.Sprintf("^%s_[0-9]{4}_[0-9]{2}$", tableName), cutoffPrefix)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to query partitions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var partitionName string
|
||||
if err := rows.Scan(&partitionName); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 删除分区
|
||||
dropQuery := fmt.Sprintf("DROP TABLE IF EXISTS %s", partitionName)
|
||||
if _, err := m.pool.Exec(ctx, dropQuery); err != nil {
|
||||
return droppedCount, fmt.Errorf("failed to drop partition %s: %w", partitionName, err)
|
||||
}
|
||||
droppedCount++
|
||||
}
|
||||
|
||||
return droppedCount, nil
|
||||
}
|
||||
|
||||
// ListPartitions 列出表的所有分区
|
||||
func (m *PartitionManager) ListPartitions(ctx context.Context, tableName string) ([]string, error) {
|
||||
query := `
|
||||
SELECT relname
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE n.nspname = 'public'
|
||||
AND relname ~ $1
|
||||
ORDER BY relname
|
||||
`
|
||||
|
||||
rows, err := m.pool.Query(ctx, query, fmt.Sprintf("^%s_[0-9]{4}_[0-9]{2}$", tableName))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query partitions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var partitions []string
|
||||
for rows.Next() {
|
||||
var partitionName string
|
||||
if err := rows.Scan(&partitionName); err != nil {
|
||||
continue
|
||||
}
|
||||
partitions = append(partitions, partitionName)
|
||||
}
|
||||
|
||||
return partitions, nil
|
||||
}
|
||||
|
||||
// IsPartitioned 检查表是否已分区
|
||||
func (m *PartitionManager) IsPartitioned(ctx context.Context, tableName string) (bool, error) {
|
||||
query := `
|
||||
SELECT relkind
|
||||
FROM pg_class
|
||||
WHERE relname = $1
|
||||
`
|
||||
|
||||
var relkind string
|
||||
if err := m.pool.QueryRow(ctx, query, tableName).Scan(&relkind); err != nil {
|
||||
return false, fmt.Errorf("failed to check table: %w", err)
|
||||
}
|
||||
|
||||
// 'p' = partitioned table
|
||||
return relkind == "p", nil
|
||||
}
|
||||
235
supply-api/internal/repository/token_status.go
Normal file
235
supply-api/internal/repository/token_status.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// TokenStatus Token状态
|
||||
type TokenStatus string
|
||||
|
||||
const (
|
||||
TokenStatusActive TokenStatus = "active"
|
||||
TokenStatusRevoked TokenStatus = "revoked"
|
||||
TokenStatusExpired TokenStatus = "expired"
|
||||
)
|
||||
|
||||
// TokenStatusRecord Token状态记录
|
||||
type TokenStatusRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
TokenID string `json:"token_id"`
|
||||
SubjectID int64 `json:"subject_id"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
Role string `json:"role"`
|
||||
Status TokenStatus `json:"status"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
RevokedAt *time.Time `json:"revoked_at,omitempty"`
|
||||
RevokedReason *string `json:"revoked_reason,omitempty"`
|
||||
RevokedBy *int64 `json:"revoked_by,omitempty"`
|
||||
LastVerifiedAt *time.Time `json:"last_verified_at,omitempty"`
|
||||
VerificationCount int64 `json:"verification_count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TokenStatusRepository Token状态仓储
|
||||
type TokenStatusRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewTokenStatusRepository 创建Token状态仓储
|
||||
func NewTokenStatusRepository(pool *pgxpool.Pool) *TokenStatusRepository {
|
||||
return &TokenStatusRepository{pool: pool}
|
||||
}
|
||||
|
||||
// Create 创建Token状态记录
|
||||
func (r *TokenStatusRepository) Create(ctx context.Context, record *TokenStatusRecord) error {
|
||||
query := `
|
||||
INSERT INTO token_status_registry (
|
||||
token_id, subject_id, tenant_id, role, status,
|
||||
issued_at, expires_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
record.TokenID, record.SubjectID, record.TenantID, record.Role,
|
||||
record.Status, record.IssuedAt, record.ExpiresAt,
|
||||
).Scan(&record.ID, &record.CreatedAt, &record.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token status record: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByTokenID 根据TokenID获取状态
|
||||
func (r *TokenStatusRepository) GetByTokenID(ctx context.Context, tokenID string) (*TokenStatusRecord, error) {
|
||||
query := `
|
||||
SELECT id, token_id, subject_id, tenant_id, role, status,
|
||||
issued_at, expires_at, revoked_at, revoked_reason, revoked_by,
|
||||
last_verified_at, verification_count, created_at, updated_at
|
||||
FROM token_status_registry
|
||||
WHERE token_id = $1
|
||||
`
|
||||
|
||||
record := &TokenStatusRecord{}
|
||||
err := r.pool.QueryRow(ctx, query, tokenID).Scan(
|
||||
&record.ID, &record.TokenID, &record.SubjectID, &record.TenantID, &record.Role,
|
||||
&record.Status, &record.IssuedAt, &record.ExpiresAt, &record.RevokedAt,
|
||||
&record.RevokedReason, &record.RevokedBy, &record.LastVerifiedAt,
|
||||
&record.VerificationCount, &record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token status record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GetStatus 获取Token状态字符串(用于TokenStatusBackend接口)
|
||||
func (r *TokenStatusRepository) GetStatus(ctx context.Context, tokenID string) (string, error) {
|
||||
record, err := r.GetByTokenID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if record == nil {
|
||||
return "active", nil // 不存在的token默认为active(未发行)
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if record.Status == TokenStatusActive && record.ExpiresAt.Before(time.Now()) {
|
||||
return string(TokenStatusExpired), nil
|
||||
}
|
||||
|
||||
return string(record.Status), nil
|
||||
}
|
||||
|
||||
// Revoke 吊销Token(用于TokenRevocationBackend接口)
|
||||
func (r *TokenStatusRepository) Revoke(ctx context.Context, tokenID string, reason string) error {
|
||||
query := `
|
||||
UPDATE token_status_registry SET
|
||||
status = 'revoked',
|
||||
revoked_at = CURRENT_TIMESTAMP,
|
||||
revoked_reason = $2,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE token_id = $1 AND status = 'active'
|
||||
`
|
||||
|
||||
result, err := r.pool.Exec(ctx, query, tokenID, reason)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke token: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("token not found or already revoked")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeBySubjectID 根据SubjectID吊销所有Token
|
||||
func (r *TokenStatusRepository) RevokeBySubjectID(ctx context.Context, subjectID int64, reason string) (int64, error) {
|
||||
query := `
|
||||
UPDATE token_status_registry SET
|
||||
status = 'revoked',
|
||||
revoked_at = CURRENT_TIMESTAMP,
|
||||
revoked_reason = $2,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE subject_id = $1 AND status = 'active'
|
||||
`
|
||||
|
||||
result, err := r.pool.Exec(ctx, query, subjectID, reason)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to revoke tokens by subject_id: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// UpdateVerificationCount 更新验证计数
|
||||
func (r *TokenStatusRepository) UpdateVerificationCount(ctx context.Context, tokenID string) error {
|
||||
query := `
|
||||
UPDATE token_status_registry SET
|
||||
last_verified_at = CURRENT_TIMESTAMP,
|
||||
verification_count = verification_count + 1,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE token_id = $1
|
||||
`
|
||||
|
||||
_, err := r.pool.Exec(ctx, query, tokenID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update verification count: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpired 删除过期记录(定时清理)
|
||||
func (r *TokenStatusRepository) DeleteExpired(ctx context.Context, before time.Time) (int64, error) {
|
||||
query := `DELETE FROM token_status_registry WHERE status = 'expired' AND expires_at < $1`
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query, before)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete expired token records: %w", err)
|
||||
}
|
||||
|
||||
return cmdTag.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// DeleteBySubjectID 删除用户的所有Token记录(登出时清理)
|
||||
func (r *TokenStatusRepository) DeleteBySubjectID(ctx context.Context, subjectID int64) (int64, error) {
|
||||
query := `DELETE FROM token_status_registry WHERE subject_id = $1`
|
||||
|
||||
cmdTag, err := r.pool.Exec(ctx, query, subjectID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete token records by subject_id: %w", err)
|
||||
}
|
||||
|
||||
return cmdTag.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// ListActiveBySubjectID 列出用户的所有活跃Token
|
||||
func (r *TokenStatusRepository) ListActiveBySubjectID(ctx context.Context, subjectID int64) ([]*TokenStatusRecord, error) {
|
||||
query := `
|
||||
SELECT id, token_id, subject_id, tenant_id, role, status,
|
||||
issued_at, expires_at, revoked_at, revoked_reason, revoked_by,
|
||||
last_verified_at, verification_count, created_at, updated_at
|
||||
FROM token_status_registry
|
||||
WHERE subject_id = $1 AND status = 'active' AND expires_at > CURRENT_TIMESTAMP
|
||||
ORDER BY issued_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, subjectID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list active tokens: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []*TokenStatusRecord
|
||||
for rows.Next() {
|
||||
record := &TokenStatusRecord{}
|
||||
err := rows.Scan(
|
||||
&record.ID, &record.TokenID, &record.SubjectID, &record.TenantID, &record.Role,
|
||||
&record.Status, &record.IssuedAt, &record.ExpiresAt, &record.RevokedAt,
|
||||
&record.RevokedReason, &record.RevokedBy, &record.LastVerifiedAt,
|
||||
&record.VerificationCount, &record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan token record: %w", err)
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
Reference in New Issue
Block a user