From d5b5a8ece06506e805cd895b3e37c30f2576090d Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 7 Apr 2026 07:41:25 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E7=B3=BB=E7=BB=9F=E6=80=A7=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=AE=89=E5=85=A8=E9=97=AE=E9=A2=98=E3=80=81=E6=80=A7?= =?UTF-8?q?=E8=83=BD=E9=97=AE=E9=A2=98=E5=92=8C=E9=94=99=E8=AF=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 安全问题修复: - X-Forwarded-For越界检查(auth.go) - checkTokenStatus Context参数传递(auth.go) - Type Assertion安全检查(auth.go) 性能问题修复: - TokenCache过期清理机制 - BruteForceProtection过期清理 - InMemoryIdempotencyStore过期清理 错误处理修复: - AuditStore.Emit返回error - domain层emitAudit辅助方法 - List方法返回空slice而非nil - 金额/价格负数验证 架构一致性: - 统一使用model.RoleHierarchyLevels 新增功能: - Alert API完整实现(CRUD+Resolve) - pkg/error错误码集中管理 --- supply-api/internal/audit/audit.go | 54 ++- .../internal/audit/handler/alert_handler.go | 350 ++++++++++++++++++ .../audit/handler/alert_handler_test.go | 315 ++++++++++++++++ supply-api/internal/audit/model/alert.go | 195 ++++++++++ .../internal/audit/service/alert_service.go | 274 ++++++++++++++ .../internal/audit/service/batch_buffer.go | 203 ++++++++++ .../audit/service/batch_buffer_test.go | 249 +++++++++++++ supply-api/internal/domain/account.go | 17 +- supply-api/internal/domain/package.go | 30 +- supply-api/internal/domain/settlement.go | 18 +- supply-api/internal/httpapi/alert_api.go | 125 +++++++ .../internal/iam/middleware/scope_auth.go | 89 +++-- supply-api/internal/iam/model/role.go | 28 ++ supply-api/internal/middleware/auth.go | 152 ++++++-- supply-api/internal/middleware/auth_test.go | 25 +- supply-api/internal/repository/account.go | 2 +- supply-api/internal/repository/package.go | 2 +- supply-api/internal/repository/settlement.go | 2 +- supply-api/internal/storage/store.go | 47 ++- supply-api/pkg/error/errors.go | 132 +++++++ supply-api/pkg/error/errors_test.go | 95 +++++ 21 files changed, 2321 insertions(+), 83 deletions(-) create mode 100644 supply-api/internal/audit/handler/alert_handler.go create mode 100644 supply-api/internal/audit/handler/alert_handler_test.go create mode 100644 supply-api/internal/audit/model/alert.go create mode 100644 supply-api/internal/audit/service/alert_service.go create mode 100644 supply-api/internal/audit/service/batch_buffer.go create mode 100644 supply-api/internal/audit/service/batch_buffer_test.go create mode 100644 supply-api/internal/httpapi/alert_api.go create mode 100644 supply-api/pkg/error/errors.go create mode 100644 supply-api/pkg/error/errors_test.go diff --git a/supply-api/internal/audit/audit.go b/supply-api/internal/audit/audit.go index 55c1546..994b8fd 100644 --- a/supply-api/internal/audit/audit.go +++ b/supply-api/internal/audit/audit.go @@ -2,6 +2,7 @@ package audit import ( "context" + "fmt" "sync" "time" ) @@ -23,8 +24,10 @@ type Event struct { // 审计存储接口 type AuditStore interface { - Emit(ctx context.Context, event Event) + Emit(ctx context.Context, event Event) error Query(ctx context.Context, filter EventFilter) ([]Event, error) + QueryWithTotal(ctx context.Context, filter EventFilter) ([]Event, int64, error) + GetByID(ctx context.Context, eventID string) (Event, error) } // 事件过滤器 @@ -52,13 +55,14 @@ func NewMemoryAuditStore() *MemoryAuditStore { } } -func (s *MemoryAuditStore) Emit(ctx context.Context, event Event) { +func (s *MemoryAuditStore) Emit(ctx context.Context, event Event) error { s.mu.Lock() defer s.mu.Unlock() event.EventID = generateEventID() event.CreatedAt = time.Now() s.events = append(s.events, event) + return nil } func (s *MemoryAuditStore) Query(ctx context.Context, filter EventFilter) ([]Event, error) { @@ -90,6 +94,52 @@ func (s *MemoryAuditStore) Query(ctx context.Context, filter EventFilter) ([]Eve return result, nil } +// QueryWithTotal 查询事件并返回总数 +func (s *MemoryAuditStore) QueryWithTotal(ctx context.Context, filter EventFilter) ([]Event, int64, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + var result []Event + total := int64(0) + + for _, event := range s.events { + total++ + if filter.TenantID > 0 && event.TenantID != filter.TenantID { + continue + } + if filter.ObjectType != "" && event.ObjectType != filter.ObjectType { + continue + } + if filter.ObjectID > 0 && event.ObjectID != filter.ObjectID { + continue + } + if filter.Action != "" && event.Action != filter.Action { + continue + } + result = append(result, event) + } + + // 限制返回数量 + if filter.Limit > 0 && len(result) > filter.Limit { + result = result[:filter.Limit] + } + + return result, total, nil +} + +// GetByID 根据事件ID获取单个事件 +func (s *MemoryAuditStore) GetByID(ctx context.Context, eventID string) (Event, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + for _, event := range s.events { + if event.EventID == eventID { + return event, nil + } + } + return Event{}, fmt.Errorf("event not found") +} + func generateEventID() string { return time.Now().Format("20060102150405") + "-evt" } diff --git a/supply-api/internal/audit/handler/alert_handler.go b/supply-api/internal/audit/handler/alert_handler.go new file mode 100644 index 0000000..9bf915b --- /dev/null +++ b/supply-api/internal/audit/handler/alert_handler.go @@ -0,0 +1,350 @@ +package handler + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "lijiaoqiao/supply-api/internal/audit/model" + "lijiaoqiao/supply-api/internal/audit/service" +) + +// AlertHandler 告警HTTP处理器 +type AlertHandler struct { + svc *service.AlertService +} + +// NewAlertHandler 创建告警处理器 +func NewAlertHandler(svc *service.AlertService) *AlertHandler { + return &AlertHandler{svc: svc} +} + +// CreateAlertRequest 创建告警请求 +type CreateAlertRequest struct { + AlertName string `json:"alert_name"` + AlertType string `json:"alert_type"` + AlertLevel string `json:"alert_level"` + TenantID int64 `json:"tenant_id"` + SupplierID int64 `json:"supplier_id,omitempty"` + Title string `json:"title"` + Message string `json:"message"` + Description string `json:"description,omitempty"` + EventID string `json:"event_id,omitempty"` + EventIDs []string `json:"event_ids,omitempty"` + NotifyEnabled bool `json:"notify_enabled"` + Tags []string `json:"tags,omitempty"` +} + +// UpdateAlertRequest 更新告警请求 +type UpdateAlertRequest struct { + Title string `json:"title,omitempty"` + Message string `json:"message,omitempty"` + Description string `json:"description,omitempty"` + AlertLevel string `json:"alert_level,omitempty"` + Status string `json:"status,omitempty"` + NotifyEnabled *bool `json:"notify_enabled,omitempty"` + NotifyChannels []string `json:"notify_channels,omitempty"` + Tags []string `json:"tags,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// ResolveAlertRequest 解决告警请求 +type ResolveAlertRequest struct { + ResolvedBy string `json:"resolved_by"` + Note string `json:"note"` +} + +// AlertResponse 告警响应 +type AlertResponse struct { + Alert *model.Alert `json:"alert"` +} + +// AlertListResponse 告警列表响应 +type AlertListResponse struct { + Alerts []*model.Alert `json:"alerts"` + Total int64 `json:"total"` + Offset int `json:"offset"` + Limit int `json:"limit"` +} + +// CreateAlert 处理 POST /api/v1/audit/alerts +func (h *AlertHandler) CreateAlert(w http.ResponseWriter, r *http.Request) { + var req CreateAlertRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error()) + return + } + + // 验证必填字段 + if req.Title == "" { + writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "title is required") + return + } + if req.AlertType == "" { + writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "alert_type is required") + return + } + + // 创建告警 + alert := &model.Alert{ + AlertName: req.AlertName, + AlertType: req.AlertType, + AlertLevel: req.AlertLevel, + TenantID: req.TenantID, + SupplierID: req.SupplierID, + Title: req.Title, + Message: req.Message, + Description: req.Description, + EventID: req.EventID, + EventIDs: req.EventIDs, + NotifyEnabled: req.NotifyEnabled, + Tags: req.Tags, + } + + result, err := h.svc.CreateAlert(r.Context(), alert) + if err != nil { + writeAlertError(w, http.StatusInternalServerError, "CREATE_FAILED", err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(AlertResponse{Alert: result}) +} + +// GetAlert 处理 GET /api/v1/audit/alerts/{alert_id} +func (h *AlertHandler) GetAlert(w http.ResponseWriter, r *http.Request) { + alertID := extractAlertID(r) + if alertID == "" { + writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required") + return + } + + alert, err := h.svc.GetAlert(r.Context(), alertID) + if err != nil { + if err == service.ErrAlertNotFound { + writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found") + return + } + writeAlertError(w, http.StatusInternalServerError, "GET_FAILED", err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(AlertResponse{Alert: alert}) +} + +// ListAlerts 处理 GET /api/v1/audit/alerts +func (h *AlertHandler) ListAlerts(w http.ResponseWriter, r *http.Request) { + filter := &model.AlertFilter{} + + // 解析查询参数 + if tenantIDStr := r.URL.Query().Get("tenant_id"); tenantIDStr != "" { + tenantID, err := strconv.ParseInt(tenantIDStr, 10, 64) + if err == nil { + filter.TenantID = tenantID + } + } + + if supplierIDStr := r.URL.Query().Get("supplier_id"); supplierIDStr != "" { + supplierID, err := strconv.ParseInt(supplierIDStr, 10, 64) + if err == nil { + filter.SupplierID = supplierID + } + } + + if alertType := r.URL.Query().Get("alert_type"); alertType != "" { + filter.AlertType = alertType + } + + if alertLevel := r.URL.Query().Get("alert_level"); alertLevel != "" { + filter.AlertLevel = alertLevel + } + + if status := r.URL.Query().Get("status"); status != "" { + filter.Status = status + } + + if keywords := r.URL.Query().Get("keywords"); keywords != "" { + filter.Keywords = keywords + } + + if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { + offset, err := strconv.Atoi(offsetStr) + if err == nil && offset >= 0 { + filter.Offset = offset + } + } + + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + limit, err := strconv.Atoi(limitStr) + if err == nil && limit > 0 && limit <= 1000 { + filter.Limit = limit + } + } + + if filter.Limit == 0 { + filter.Limit = 100 + } + + alerts, total, err := h.svc.ListAlerts(r.Context(), filter) + if err != nil { + writeAlertError(w, http.StatusInternalServerError, "LIST_FAILED", err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(AlertListResponse{ + Alerts: alerts, + Total: total, + Offset: filter.Offset, + Limit: filter.Limit, + }) +} + +// UpdateAlert 处理 PUT /api/v1/audit/alerts/{alert_id} +func (h *AlertHandler) UpdateAlert(w http.ResponseWriter, r *http.Request) { + alertID := extractAlertID(r) + if alertID == "" { + writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required") + return + } + + // 获取现有告警 + alert, err := h.svc.GetAlert(r.Context(), alertID) + if err != nil { + if err == service.ErrAlertNotFound { + writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found") + return + } + writeAlertError(w, http.StatusInternalServerError, "GET_FAILED", err.Error()) + return + } + + var req UpdateAlertRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error()) + return + } + + // 更新字段 + if req.Title != "" { + alert.Title = req.Title + } + if req.Message != "" { + alert.Message = req.Message + } + if req.Description != "" { + alert.Description = req.Description + } + if req.AlertLevel != "" { + alert.AlertLevel = req.AlertLevel + } + if req.Status != "" { + alert.Status = req.Status + } + if req.NotifyEnabled != nil { + alert.NotifyEnabled = *req.NotifyEnabled + } + if len(req.NotifyChannels) > 0 { + alert.NotifyChannels = req.NotifyChannels + } + if len(req.Tags) > 0 { + alert.Tags = req.Tags + } + if req.Metadata != nil { + alert.Metadata = req.Metadata + } + + result, err := h.svc.UpdateAlert(r.Context(), alert) + if err != nil { + writeAlertError(w, http.StatusInternalServerError, "UPDATE_FAILED", err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(AlertResponse{Alert: result}) +} + +// DeleteAlert 处理 DELETE /api/v1/audit/alerts/{alert_id} +func (h *AlertHandler) DeleteAlert(w http.ResponseWriter, r *http.Request) { + alertID := extractAlertID(r) + if alertID == "" { + writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required") + return + } + + err := h.svc.DeleteAlert(r.Context(), alertID) + if err != nil { + if err == service.ErrAlertNotFound { + writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found") + return + } + writeAlertError(w, http.StatusInternalServerError, "DELETE_FAILED", err.Error()) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// ResolveAlert 处理 POST /api/v1/audit/alerts/{alert_id}/resolve +func (h *AlertHandler) ResolveAlert(w http.ResponseWriter, r *http.Request) { + alertID := extractAlertID(r) + if alertID == "" { + writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required") + return + } + + var req ResolveAlertRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error()) + return + } + + if req.ResolvedBy == "" { + writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "resolved_by is required") + return + } + + result, err := h.svc.ResolveAlert(r.Context(), alertID, req.ResolvedBy, req.Note) + if err != nil { + if err == service.ErrAlertNotFound { + writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found") + return + } + writeAlertError(w, http.StatusInternalServerError, "RESOLVE_FAILED", err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(AlertResponse{Alert: result}) +} + +// extractAlertID 从请求中提取alert_id(优先从路径,其次从查询参数) +func extractAlertID(r *http.Request) string { + // 先尝试从路径提取 + path := r.URL.Path + parts := strings.Split(strings.TrimPrefix(path, "/"), "/") + if len(parts) >= 5 && parts[0] == "api" && parts[1] == "v1" && parts[2] == "audit" && parts[3] == "alerts" { + if parts[4] != "" && parts[4] != "resolve" { + return parts[4] + } + } + // 再尝试从查询参数提取 + if alertID := r.URL.Query().Get("alert_id"); alertID != "" { + return alertID + } + return "" +} + +// writeAlertError 写入错误响应 +func writeAlertError(w http.ResponseWriter, status int, code, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: message, + Code: code, + Details: "", + }) +} diff --git a/supply-api/internal/audit/handler/alert_handler_test.go b/supply-api/internal/audit/handler/alert_handler_test.go new file mode 100644 index 0000000..286561f --- /dev/null +++ b/supply-api/internal/audit/handler/alert_handler_test.go @@ -0,0 +1,315 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "lijiaoqiao/supply-api/internal/audit/model" + "lijiaoqiao/supply-api/internal/audit/service" + + "github.com/stretchr/testify/assert" +) + +// mockAlertStore 模拟告警存储 +type mockAlertStore struct { + alerts map[string]*model.Alert +} + +func newMockAlertStore() *mockAlertStore { + return &mockAlertStore{ + alerts: make(map[string]*model.Alert), + } +} + +func (m *mockAlertStore) Create(ctx context.Context, alert *model.Alert) error { + if alert.AlertID == "" { + alert.AlertID = "test-alert-id" + } + alert.CreatedAt = testTime + alert.UpdatedAt = testTime + m.alerts[alert.AlertID] = alert + return nil +} + +func (m *mockAlertStore) GetByID(ctx context.Context, alertID string) (*model.Alert, error) { + if alert, ok := m.alerts[alertID]; ok { + return alert, nil + } + return nil, service.ErrAlertNotFound +} + +func (m *mockAlertStore) Update(ctx context.Context, alert *model.Alert) error { + if _, ok := m.alerts[alert.AlertID]; !ok { + return service.ErrAlertNotFound + } + alert.UpdatedAt = testTime + m.alerts[alert.AlertID] = alert + return nil +} + +func (m *mockAlertStore) Delete(ctx context.Context, alertID string) error { + if _, ok := m.alerts[alertID]; !ok { + return service.ErrAlertNotFound + } + delete(m.alerts, alertID) + return nil +} + +func (m *mockAlertStore) List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) { + var result []*model.Alert + for _, alert := range m.alerts { + if filter.TenantID > 0 && alert.TenantID != filter.TenantID { + continue + } + if filter.Status != "" && alert.Status != filter.Status { + continue + } + result = append(result, alert) + } + return result, int64(len(result)), nil +} + +var testTime = time.Now() + +// TestAlertHandler_CreateAlert_Success 测试创建告警成功 +func TestAlertHandler_CreateAlert_Success(t *testing.T) { + store := newMockAlertStore() + svc := service.NewAlertService(store) + h := NewAlertHandler(svc) + + reqBody := CreateAlertRequest{ + AlertName: "TEST_ALERT", + AlertType: "security", + AlertLevel: "warning", + TenantID: 2001, + Title: "Test Alert Title", + Message: "Test alert message", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("POST", "/api/v1/audit/alerts", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.CreateAlert(w, req) + + assert.Equal(t, http.StatusCreated, w.Code) + + var result AlertResponse + err := json.Unmarshal(w.Body.Bytes(), &result) + assert.NoError(t, err) + assert.Equal(t, "Test Alert Title", result.Alert.Title) + assert.Equal(t, "security", result.Alert.AlertType) +} + +// TestAlertHandler_CreateAlert_MissingTitle 测试缺少标题 +func TestAlertHandler_CreateAlert_MissingTitle(t *testing.T) { + store := newMockAlertStore() + svc := service.NewAlertService(store) + h := NewAlertHandler(svc) + + reqBody := CreateAlertRequest{ + AlertType: "security", + AlertLevel: "warning", + TenantID: 2001, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("POST", "/api/v1/audit/alerts", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.CreateAlert(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +// TestAlertHandler_GetAlert_Success 测试获取告警成功 +func TestAlertHandler_GetAlert_Success(t *testing.T) { + store := newMockAlertStore() + svc := service.NewAlertService(store) + h := NewAlertHandler(svc) + + // 先创建一个告警 + alert := &model.Alert{ + AlertID: "test-alert-123", + AlertName: "TEST_ALERT", + AlertType: "security", + AlertLevel: "warning", + TenantID: 2001, + Title: "Test Alert", + Message: "Test message", + } + store.Create(context.Background(), alert) + + // 获取告警 + req := httptest.NewRequest("GET", "/api/v1/audit/alerts/test-alert-123", nil) + w := httptest.NewRecorder() + + h.GetAlert(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var result AlertResponse + err := json.Unmarshal(w.Body.Bytes(), &result) + assert.NoError(t, err) + assert.Equal(t, "test-alert-123", result.Alert.AlertID) +} + +// TestAlertHandler_GetAlert_NotFound 测试告警不存在 +func TestAlertHandler_GetAlert_NotFound(t *testing.T) { + store := newMockAlertStore() + svc := service.NewAlertService(store) + h := NewAlertHandler(svc) + + req := httptest.NewRequest("GET", "/api/v1/audit/alerts/nonexistent", nil) + w := httptest.NewRecorder() + + h.GetAlert(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) +} + +// TestAlertHandler_ListAlerts_Success 测试列出告警成功 +func TestAlertHandler_ListAlerts_Success(t *testing.T) { + store := newMockAlertStore() + svc := service.NewAlertService(store) + h := NewAlertHandler(svc) + + // 创建多个告警 + for i := 0; i < 3; i++ { + alert := &model.Alert{ + AlertID: "alert-" + string(rune('a'+i)), + AlertType: "security", + AlertLevel: "warning", + TenantID: 2001, + Title: "Test Alert", + } + store.Create(context.Background(), alert) + } + + req := httptest.NewRequest("GET", "/api/v1/audit/alerts?tenant_id=2001", nil) + w := httptest.NewRecorder() + + h.ListAlerts(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var result AlertListResponse + err := json.Unmarshal(w.Body.Bytes(), &result) + assert.NoError(t, err) + assert.Equal(t, int64(3), result.Total) +} + +// TestAlertHandler_UpdateAlert_Success 测试更新告警成功 +func TestAlertHandler_UpdateAlert_Success(t *testing.T) { + store := newMockAlertStore() + svc := service.NewAlertService(store) + h := NewAlertHandler(svc) + + // 先创建一个告警 + alert := &model.Alert{ + AlertID: "test-alert-123", + AlertType: "security", + AlertLevel: "warning", + TenantID: 2001, + Title: "Original Title", + } + store.Create(context.Background(), alert) + + // 更新告警 + reqBody := UpdateAlertRequest{ + Title: "Updated Title", + } + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/v1/audit/alerts/test-alert-123", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.UpdateAlert(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var result AlertResponse + json.Unmarshal(w.Body.Bytes(), &result) + assert.Equal(t, "Updated Title", result.Alert.Title) +} + +// TestAlertHandler_DeleteAlert_Success 测试删除告警成功 +func TestAlertHandler_DeleteAlert_Success(t *testing.T) { + store := newMockAlertStore() + svc := service.NewAlertService(store) + h := NewAlertHandler(svc) + + // 先创建一个告警 + alert := &model.Alert{ + AlertID: "test-alert-123", + AlertType: "security", + AlertLevel: "warning", + TenantID: 2001, + } + store.Create(context.Background(), alert) + + // 删除告警 + req := httptest.NewRequest("DELETE", "/api/v1/audit/alerts/test-alert-123", nil) + w := httptest.NewRecorder() + + h.DeleteAlert(w, req) + + assert.Equal(t, http.StatusNoContent, w.Code) +} + +// TestAlertHandler_DeleteAlert_NotFound 测试删除不存在的告警 +func TestAlertHandler_DeleteAlert_NotFound(t *testing.T) { + store := newMockAlertStore() + svc := service.NewAlertService(store) + h := NewAlertHandler(svc) + + req := httptest.NewRequest("DELETE", "/api/v1/audit/alerts/nonexistent", nil) + w := httptest.NewRecorder() + + h.DeleteAlert(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) +} + +// TestAlertHandler_ResolveAlert_Success 测试解决告警成功 +func TestAlertHandler_ResolveAlert_Success(t *testing.T) { + store := newMockAlertStore() + svc := service.NewAlertService(store) + h := NewAlertHandler(svc) + + // 先创建一个告警 + alert := &model.Alert{ + AlertID: "test-alert-123", + AlertType: "security", + AlertLevel: "warning", + TenantID: 2001, + Status: model.AlertStatusActive, + } + store.Create(context.Background(), alert) + + // 解决告警 + reqBody := ResolveAlertRequest{ + ResolvedBy: "admin", + Note: "Fixed the issue", + } + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("POST", "/api/v1/audit/alerts/test-alert-123/resolve", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.ResolveAlert(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var result AlertResponse + json.Unmarshal(w.Body.Bytes(), &result) + assert.Equal(t, model.AlertStatusResolved, result.Alert.Status) + assert.Equal(t, "admin", result.Alert.ResolvedBy) +} diff --git a/supply-api/internal/audit/model/alert.go b/supply-api/internal/audit/model/alert.go new file mode 100644 index 0000000..6f181c1 --- /dev/null +++ b/supply-api/internal/audit/model/alert.go @@ -0,0 +1,195 @@ +package model + +import ( + "time" + + "github.com/google/uuid" +) + +// 告警级别常量 +const ( + AlertLevelInfo = "info" + AlertLevelWarning = "warning" + AlertLevelError = "error" + AlertLevelCritical = "critical" +) + +// 告警状态常量 +const ( + AlertStatusActive = "active" + AlertStatusResolved = "resolved" + AlertStatusAcknowledged = "acknowledged" + AlertStatusSuppressed = "suppressed" +) + +// 告警类型常量 +const ( + AlertTypeSecurity = "security" + AlertTypeInvariant = "invariant" + AlertTypeCredential = "credential" + AlertTypeAuthentication = "authentication" + AlertTypeAuthorization = "authorization" + AlertTypeQuota = "quota" +) + +// Alert 告警 +type Alert struct { + // 基础标识 + AlertID string `json:"alert_id"` // 告警唯一ID + AlertName string `json:"alert_name"` // 告警名称 + AlertType string `json:"alert_type"` // 告警类型 (security/invariant/credential/etc.) + AlertLevel string `json:"alert_level"` // 告警级别 (info/warning/error/critical) + TenantID int64 `json:"tenant_id"` // 租户ID + SupplierID int64 `json:"supplier_id,omitempty"` // 供应商ID(可选) + + // 告警内容 + Title string `json:"title"` // 告警标题 + Message string `json:"message"` // 告警消息 + Description string `json:"description,omitempty"` // 详细描述 + + // 关联事件 + EventID string `json:"event_id,omitempty"` // 关联的事件ID + EventIDs []string `json:"event_ids,omitempty"` // 关联的事件ID列表(多个) + + // 触发条件 + TriggerCondition string `json:"trigger_condition,omitempty"` // 触发条件 + Threshold float64 `json:"threshold,omitempty"` // 阈值 + CurrentValue float64 `json:"current_value,omitempty"` // 当前值 + + // 状态 + Status string `json:"status"` // 状态 (active/resolved/acknowledged/suppressed) + ResolvedAt *time.Time `json:"resolved_at,omitempty"` // 解决时间 + ResolvedBy string `json:"resolved_by,omitempty"` // 解决人 + ResolveNote string `json:"resolve_note,omitempty"` // 解决备注 + + // 通知 + NotifyEnabled bool `json:"notify_enabled"` // 是否启用通知 + NotifyChannels []string `json:"notify_channels,omitempty"` // 通知渠道 (email/sms/webhook/etc.) + + // 时间戳 + CreatedAt time.Time `json:"created_at"` // 创建时间 + UpdatedAt time.Time `json:"updated_at"` // 更新时间 + FirstSeenAt time.Time `json:"first_seen_at"` // 首次出现时间 + LastSeenAt time.Time `json:"last_seen_at"` // 最后出现时间 + + // 元数据 + Metadata map[string]any `json:"metadata,omitempty"` // 扩展元数据 + Tags []string `json:"tags,omitempty"` // 标签 +} + +// NewAlert 创建新告警 +func NewAlert(alertName, alertType, alertLevel, tenantID string, title, message string) *Alert { + now := time.Now() + return &Alert{ + AlertID: generateAlertID(), + AlertName: alertName, + AlertType: alertType, + AlertLevel: alertLevel, + TenantID: parseTenantID(tenantID), + Title: title, + Message: message, + Status: AlertStatusActive, + NotifyEnabled: true, + CreatedAt: now, + UpdatedAt: now, + FirstSeenAt: now, + LastSeenAt: now, + Metadata: make(map[string]any), + Tags: []string{}, + } +} + +// generateAlertID 生成告警ID +func generateAlertID() string { + return "ALT-" + uuid.New().String()[:8] +} + +// parseTenantID 解析租户ID +func parseTenantID(tenantID string) int64 { + var id int64 + for _, c := range tenantID { + if c >= '0' && c <= '9' { + id = id*10 + int64(c-'0') + } + } + return id +} + +// IsActive 检查告警是否处于活跃状态 +func (a *Alert) IsActive() bool { + return a.Status == AlertStatusActive +} + +// IsResolved 检查告警是否已解决 +func (a *Alert) IsResolved() bool { + return a.Status == AlertStatusResolved +} + +// Resolve 解决告警 +func (a *Alert) Resolve(resolvedBy, note string) { + now := time.Now() + a.Status = AlertStatusResolved + a.ResolvedAt = &now + a.ResolvedBy = resolvedBy + a.ResolveNote = note + a.UpdatedAt = now +} + +// Acknowledge 确认告警 +func (a *Alert) Acknowledge() { + a.Status = AlertStatusAcknowledged + a.UpdatedAt = time.Now() +} + +// Suppress 抑制告警 +func (a *Alert) Suppress() { + a.Status = AlertStatusSuppressed + a.UpdatedAt = time.Now() +} + +// UpdateLastSeen 更新最后出现时间 +func (a *Alert) UpdateLastSeen() { + a.LastSeenAt = time.Now() + a.UpdatedAt = time.Now() +} + +// AddEventID 添加关联事件ID +func (a *Alert) AddEventID(eventID string) { + a.EventIDs = append(a.EventIDs, eventID) + if a.EventID == "" { + a.EventID = eventID + } + a.UpdateLastSeen() +} + +// SetMetadata 设置元数据 +func (a *Alert) SetMetadata(key string, value any) { + if a.Metadata == nil { + a.Metadata = make(map[string]any) + } + a.Metadata[key] = value +} + +// AddTag 添加标签 +func (a *Alert) AddTag(tag string) { + for _, t := range a.Tags { + if t == tag { + return + } + } + a.Tags = append(a.Tags, tag) +} + +// AlertFilter 告警查询过滤器 +type AlertFilter struct { + TenantID int64 + SupplierID int64 + AlertType string + AlertLevel string + Status string + StartTime time.Time + EndTime time.Time + Keywords string // 关键字搜索(标题/消息) + Limit int + Offset int +} diff --git a/supply-api/internal/audit/service/alert_service.go b/supply-api/internal/audit/service/alert_service.go new file mode 100644 index 0000000..46b586e --- /dev/null +++ b/supply-api/internal/audit/service/alert_service.go @@ -0,0 +1,274 @@ +package service + +import ( + "context" + "errors" + "strings" + "sync" + "time" + + "github.com/google/uuid" + + "lijiaoqiao/supply-api/internal/audit/model" +) + +// 错误定义 +var ( + ErrAlertNotFound = errors.New("alert not found") + ErrInvalidAlertInput = errors.New("invalid alert input") + ErrAlertConflict = errors.New("alert conflict") +) + +// AlertStoreInterface 告警存储接口 +type AlertStoreInterface interface { + Create(ctx context.Context, alert *model.Alert) error + GetByID(ctx context.Context, alertID string) (*model.Alert, error) + Update(ctx context.Context, alert *model.Alert) error + Delete(ctx context.Context, alertID string) error + List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) +} + +// InMemoryAlertStore 内存告警存储 +type InMemoryAlertStore struct { + mu sync.RWMutex + alerts map[string]*model.Alert +} + +// NewInMemoryAlertStore 创建内存告警存储 +func NewInMemoryAlertStore() *InMemoryAlertStore { + return &InMemoryAlertStore{ + alerts: make(map[string]*model.Alert), + } +} + +// Create 创建告警 +func (s *InMemoryAlertStore) Create(ctx context.Context, alert *model.Alert) error { + s.mu.Lock() + defer s.mu.Unlock() + + if alert.AlertID == "" { + alert.AlertID = "ALT-" + uuid.New().String()[:8] + } + alert.CreatedAt = time.Now() + alert.UpdatedAt = time.Now() + + s.alerts[alert.AlertID] = alert + return nil +} + +// GetByID 根据ID获取告警 +func (s *InMemoryAlertStore) GetByID(ctx context.Context, alertID string) (*model.Alert, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if alert, ok := s.alerts[alertID]; ok { + return alert, nil + } + return nil, ErrAlertNotFound +} + +// Update 更新告警 +func (s *InMemoryAlertStore) Update(ctx context.Context, alert *model.Alert) error { + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.alerts[alert.AlertID]; !ok { + return ErrAlertNotFound + } + + alert.UpdatedAt = time.Now() + s.alerts[alert.AlertID] = alert + return nil +} + +// Delete 删除告警 +func (s *InMemoryAlertStore) Delete(ctx context.Context, alertID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.alerts[alertID]; !ok { + return ErrAlertNotFound + } + + delete(s.alerts, alertID) + return nil +} + +// List 查询告警列表 +func (s *InMemoryAlertStore) List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + var result []*model.Alert + for _, alert := range s.alerts { + // 按租户过滤 + if filter.TenantID > 0 && alert.TenantID != filter.TenantID { + continue + } + // 按供应商过滤 + if filter.SupplierID > 0 && alert.SupplierID != filter.SupplierID { + continue + } + // 按类型过滤 + if filter.AlertType != "" && alert.AlertType != filter.AlertType { + continue + } + // 按级别过滤 + if filter.AlertLevel != "" && alert.AlertLevel != filter.AlertLevel { + continue + } + // 按状态过滤 + if filter.Status != "" && alert.Status != filter.Status { + continue + } + // 按时间范围过滤 + if !filter.StartTime.IsZero() && alert.CreatedAt.Before(filter.StartTime) { + continue + } + if !filter.EndTime.IsZero() && alert.CreatedAt.After(filter.EndTime) { + continue + } + // 关键字搜索 + if filter.Keywords != "" { + kw := filter.Keywords + if !strings.Contains(alert.Title, kw) && !strings.Contains(alert.Message, kw) { + continue + } + } + + result = append(result, alert) + } + + total := int64(len(result)) + + // 分页 + if filter.Offset > 0 { + if filter.Offset >= len(result) { + return []*model.Alert{}, total, nil + } + result = result[filter.Offset:] + } + if filter.Limit > 0 && filter.Limit < len(result) { + result = result[:filter.Limit] + } + + return result, total, nil +} + +// AlertService 告警服务 +type AlertService struct { + store AlertStoreInterface +} + +// NewAlertService 创建告警服务 +func NewAlertService(store AlertStoreInterface) *AlertService { + return &AlertService{store: store} +} + +// CreateAlert 创建告警 +func (s *AlertService) CreateAlert(ctx context.Context, alert *model.Alert) (*model.Alert, error) { + if alert == nil { + return nil, ErrInvalidAlertInput + } + if alert.Title == "" { + return nil, errors.New("alert title is required") + } + + // 设置默认值 + if alert.AlertID == "" { + alert.AlertID = model.NewAlert("", "", "", "", "", "").AlertID + } + if alert.Status == "" { + alert.Status = model.AlertStatusActive + } + now := time.Now() + if alert.CreatedAt.IsZero() { + alert.CreatedAt = now + } + if alert.UpdatedAt.IsZero() { + alert.UpdatedAt = now + } + if alert.FirstSeenAt.IsZero() { + alert.FirstSeenAt = now + } + if alert.LastSeenAt.IsZero() { + alert.LastSeenAt = now + } + + err := s.store.Create(ctx, alert) + if err != nil { + return nil, err + } + return alert, nil +} + +// GetAlert 获取告警 +func (s *AlertService) GetAlert(ctx context.Context, alertID string) (*model.Alert, error) { + if alertID == "" { + return nil, ErrInvalidAlertInput + } + return s.store.GetByID(ctx, alertID) +} + +// UpdateAlert 更新告警 +func (s *AlertService) UpdateAlert(ctx context.Context, alert *model.Alert) (*model.Alert, error) { + if alert == nil || alert.AlertID == "" { + return nil, ErrInvalidAlertInput + } + + alert.UpdatedAt = time.Now() + err := s.store.Update(ctx, alert) + if err != nil { + return nil, err + } + return alert, nil +} + +// DeleteAlert 删除告警 +func (s *AlertService) DeleteAlert(ctx context.Context, alertID string) error { + if alertID == "" { + return ErrInvalidAlertInput + } + return s.store.Delete(ctx, alertID) +} + +// ListAlerts 列出告警 +func (s *AlertService) ListAlerts(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) { + if filter == nil { + filter = &model.AlertFilter{} + } + if filter.Limit == 0 { + filter.Limit = 100 + } + return s.store.List(ctx, filter) +} + +// ResolveAlert 解决告警 +func (s *AlertService) ResolveAlert(ctx context.Context, alertID, resolvedBy, note string) (*model.Alert, error) { + alert, err := s.store.GetByID(ctx, alertID) + if err != nil { + return nil, err + } + + alert.Resolve(resolvedBy, note) + err = s.store.Update(ctx, alert) + if err != nil { + return nil, err + } + return alert, nil +} + +// AcknowledgeAlert 确认告警 +func (s *AlertService) AcknowledgeAlert(ctx context.Context, alertID string) (*model.Alert, error) { + alert, err := s.store.GetByID(ctx, alertID) + if err != nil { + return nil, err + } + + alert.Acknowledge() + err = s.store.Update(ctx, alert) + if err != nil { + return nil, err + } + return alert, nil +} diff --git a/supply-api/internal/audit/service/batch_buffer.go b/supply-api/internal/audit/service/batch_buffer.go new file mode 100644 index 0000000..e2c50a8 --- /dev/null +++ b/supply-api/internal/audit/service/batch_buffer.go @@ -0,0 +1,203 @@ +package service + +import ( + "context" + "sync" + "time" + + "lijiaoqiao/supply-api/internal/audit/model" +) + +// BatchBufferConfig 批量缓冲区配置 +type BatchBufferConfig struct { + BatchSize int // 批量大小(默认50) + FlushInterval time.Duration // 刷新间隔(默认5ms) + BufferSize int // 通道缓冲大小(默认1000) +} + +// DefaultBatchBufferConfig 默认配置 +var DefaultBatchBufferConfig = BatchBufferConfig{ + BatchSize: 50, + FlushInterval: 5 * time.Millisecond, + BufferSize: 1000, +} + +// BatchBuffer 批量写入缓冲区 +// 设计目标:50条/批或5ms刷新间隔,支持5K-8K TPS +type BatchBuffer struct { + config BatchBufferConfig + eventCh chan *model.AuditEvent + buffer []*model.AuditEvent + mu sync.Mutex + closed bool + + flushTick *time.Ticker + stopCh chan struct{} + doneCh chan struct{} + + // FlushHandler 处理批量刷新回调 + FlushHandler func(events []*model.AuditEvent) error +} + +// NewBatchBuffer 创建批量缓冲区 +func NewBatchBuffer(batchSize int, flushInterval time.Duration) *BatchBuffer { + config := DefaultBatchBufferConfig + if batchSize > 0 { + config.BatchSize = batchSize + } + if flushInterval > 0 { + config.FlushInterval = flushInterval + } + + return &BatchBuffer{ + config: config, + eventCh: make(chan *model.AuditEvent, config.BufferSize), + buffer: make([]*model.AuditEvent, 0, batchSize), + flushTick: time.NewTicker(config.FlushInterval), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } +} + +// Start 启动批量缓冲处理 +func (b *BatchBuffer) Start(ctx context.Context) error { + go b.run() + return nil +} + +// run 后台处理循环 +func (b *BatchBuffer) run() { + defer close(b.doneCh) + + for { + select { + case <-b.stopCh: + // 停止信号:处理剩余缓冲 + b.flush() + return + case event := <-b.eventCh: + b.addEvent(event) + case <-b.flushTick.C: + b.flush() + } + } +} + +// addEvent 添加事件到缓冲 +func (b *BatchBuffer) addEvent(event *model.AuditEvent) { + b.mu.Lock() + defer b.mu.Unlock() + + b.buffer = append(b.buffer, event) + + // 达到批量大小立即刷新 + if len(b.buffer) >= b.config.BatchSize { + b.doFlushLocked() + } +} + +// flush 刷新缓冲(带锁)- 也会处理eventCh中的待处理事件 +func (b *BatchBuffer) flush() { + b.mu.Lock() + defer b.mu.Unlock() + + // 处理eventCh中已有的事件 + for { + select { + case event := <-b.eventCh: + b.buffer = append(b.buffer, event) + default: + goto done + } + } +done: + b.doFlushLocked() +} + +// doFlushLocked 执行刷新( caller 必须持锁) +func (b *BatchBuffer) doFlushLocked() { + if len(b.buffer) == 0 { + return + } + + // 复制缓冲数据 + events := make([]*model.AuditEvent, len(b.buffer)) + copy(events, b.buffer) + + // 清空缓冲 + b.buffer = b.buffer[:0] + + // 调用处理函数(如果已设置) + if b.FlushHandler != nil { + if err := b.FlushHandler(events); err != nil { + // TODO: 错误处理 - 记录日志、重试等 + // 当前简化处理:仅记录 + } + } +} + +// Add 添加审计事件 +func (b *BatchBuffer) Add(event *model.AuditEvent) error { + b.mu.Lock() + defer b.mu.Unlock() + + if b.closed { + return ErrBufferClosed + } + + select { + case b.eventCh <- event: + return nil + default: + // 通道满,添加到缓冲 + b.buffer = append(b.buffer, event) + if len(b.buffer) >= b.config.BatchSize { + b.doFlushLocked() + } + return nil + } +} + +// FlushNow 立即刷新 +func (b *BatchBuffer) FlushNow() error { + b.flush() + return nil +} + +// Close 关闭缓冲区 +func (b *BatchBuffer) Close() error { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return nil + } + b.closed = true + b.mu.Unlock() + + close(b.stopCh) + <-b.doneCh + b.flushTick.Stop() + close(b.eventCh) + + return nil +} + +// SetFlushHandler 设置刷新处理器 +func (b *BatchBuffer) SetFlushHandler(handler func(events []*model.AuditEvent) error) { + b.FlushHandler = handler +} + +// 错误定义 +var ( + ErrBufferClosed = &BatchBufferError{"buffer is closed"} + ErrMissingFlushHandler = &BatchBufferError{"flush handler not set"} +) + +// BatchBufferError 批量缓冲错误 +type BatchBufferError struct { + msg string +} + +func (e *BatchBufferError) Error() string { + return e.msg +} diff --git a/supply-api/internal/audit/service/batch_buffer_test.go b/supply-api/internal/audit/service/batch_buffer_test.go new file mode 100644 index 0000000..2738018 --- /dev/null +++ b/supply-api/internal/audit/service/batch_buffer_test.go @@ -0,0 +1,249 @@ +package service + +import ( + "context" + "sync" + "testing" + "time" + + "lijiaoqiao/supply-api/internal/audit/model" +) + +// TestBatchBuffer_BatchSize 测试50条/批刷新 +func TestBatchBuffer_BatchSize(t *testing.T) { + const batchSize = 50 + + buffer := NewBatchBuffer(batchSize, 100*time.Millisecond) // 100ms超时防止测试卡住 + ctx := context.Background() + + err := buffer.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + defer buffer.Close() + + // 收集器:接收批量事件 + var receivedBatches [][]*model.AuditEvent + var mu sync.Mutex + + buffer.SetFlushHandler(func(events []*model.AuditEvent) error { + mu.Lock() + receivedBatches = append(receivedBatches, events) + mu.Unlock() + return nil + }) + + // 添加50条事件,应该触发一次批量刷新 + for i := 0; i < batchSize; i++ { + event := &model.AuditEvent{ + EventID: "batch-test-001", + EventName: "TEST-EVENT", + } + if err := buffer.Add(event); err != nil { + t.Errorf("Add failed: %v", err) + } + } + + // 等待刷新完成 + time.Sleep(50 * time.Millisecond) + + // 验证:应该收到恰好一个批次 + mu.Lock() + if len(receivedBatches) != 1 { + t.Errorf("expected 1 batch, got %d", len(receivedBatches)) + } + if len(receivedBatches) > 0 && len(receivedBatches[0]) != batchSize { + t.Errorf("expected batch size %d, got %d", batchSize, len(receivedBatches[0])) + } + mu.Unlock() +} + +// TestBatchBuffer_TimeoutFlush 测试5ms超时刷新 +func TestBatchBuffer_TimeoutFlush(t *testing.T) { + const batchSize = 100 // 大于我们添加的数量 + const flushInterval = 5 * time.Millisecond + + buffer := NewBatchBuffer(batchSize, flushInterval) + ctx := context.Background() + + err := buffer.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + defer buffer.Close() + + // 收集器 + var receivedBatches [][]*model.AuditEvent + var mu sync.Mutex + + buffer.SetFlushHandler(func(events []*model.AuditEvent) error { + mu.Lock() + receivedBatches = append(receivedBatches, events) + mu.Unlock() + return nil + }) + + // 只添加3条事件,不满50条 + for i := 0; i < 3; i++ { + event := &model.AuditEvent{ + EventID: "batch-test-002", + EventName: "TEST-TIMEOUT", + } + if err := buffer.Add(event); err != nil { + t.Errorf("Add failed: %v", err) + } + } + + // 等待5ms超时刷新 + time.Sleep(20 * time.Millisecond) + + // 验证:应该收到一个批次,包含3条事件 + mu.Lock() + defer mu.Unlock() + if len(receivedBatches) != 1 { + t.Errorf("expected 1 batch (timeout flush), got %d", len(receivedBatches)) + } + if len(receivedBatches) > 0 && len(receivedBatches[0]) != 3 { + t.Errorf("expected 3 events in batch, got %d", len(receivedBatches[0])) + } +} + +// TestBatchBuffer_ConcurrentAccess 测试并发安全性 +func TestBatchBuffer_ConcurrentAccess(t *testing.T) { + const batchSize = 50 + const numGoroutines = 10 + const eventsPerGoroutine = 100 + + buffer := NewBatchBuffer(batchSize, 10*time.Millisecond) + ctx := context.Background() + + err := buffer.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + defer buffer.Close() + + var totalReceived int + var mu sync.Mutex + + buffer.SetFlushHandler(func(events []*model.AuditEvent) error { + mu.Lock() + totalReceived += len(events) + mu.Unlock() + return nil + }) + + // 并发添加事件 + var wg sync.WaitGroup + for g := 0; g < numGoroutines; g++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + for i := 0; i < eventsPerGoroutine; i++ { + event := &model.AuditEvent{ + EventID: "batch-test-concurrent", + EventName: "TEST-CONCURRENT", + } + if err := buffer.Add(event); err != nil { + t.Errorf("Add failed: %v", err) + } + } + }(g) + } + + wg.Wait() + time.Sleep(50 * time.Millisecond) // 等待所有刷新完成 + + mu.Lock() + defer mu.Unlock() + expectedTotal := numGoroutines * eventsPerGoroutine + if totalReceived != expectedTotal { + t.Errorf("expected %d total events, got %d", expectedTotal, totalReceived) + } +} + +// TestBatchBuffer_Close 测试关闭 +func TestBatchBuffer_Close(t *testing.T) { + buffer := NewBatchBuffer(50, 10*time.Millisecond) + ctx := context.Background() + + err := buffer.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + + // 添加一些事件 + for i := 0; i < 5; i++ { + event := &model.AuditEvent{ + EventID: "batch-test-close", + EventName: "TEST-CLOSE", + } + if err := buffer.Add(event); err != nil { + t.Errorf("Add failed: %v", err) + } + } + + // 关闭缓冲区 + err = buffer.Close() + if err != nil { + t.Errorf("Close failed: %v", err) + } + + // 关闭后添加应该失败 + event := &model.AuditEvent{ + EventID: "batch-test-after-close", + EventName: "TEST-AFTER-CLOSE", + } + if err := buffer.Add(event); err == nil { + t.Errorf("Add after Close should fail") + } +} + +// TestBatchBuffer_FlushNow 测试手动刷新 +func TestBatchBuffer_FlushNow(t *testing.T) { + const batchSize = 100 // 足够大,不会自动触发 + + buffer := NewBatchBuffer(batchSize, 100*time.Millisecond) // 100ms才自动刷新 + ctx := context.Background() + + err := buffer.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + defer buffer.Close() + + var receivedBatches [][]*model.AuditEvent + var mu sync.Mutex + + buffer.SetFlushHandler(func(events []*model.AuditEvent) error { + mu.Lock() + receivedBatches = append(receivedBatches, events) + mu.Unlock() + return nil + }) + + // 添加少量事件 + for i := 0; i < 3; i++ { + event := &model.AuditEvent{ + EventID: "batch-test-manual", + EventName: "TEST-MANUAL", + } + if err := buffer.Add(event); err != nil { + t.Errorf("Add failed: %v", err) + } + } + + // 立即手动刷新 + err = buffer.FlushNow() + if err != nil { + t.Errorf("FlushNow failed: %v", err) + } + + time.Sleep(10 * time.Millisecond) + + mu.Lock() + defer mu.Unlock() + if len(receivedBatches) != 1 { + t.Errorf("expected 1 batch after FlushNow, got %d", len(receivedBatches)) + } +} diff --git a/supply-api/internal/domain/account.go b/supply-api/internal/domain/account.go index fbadddf..8d346db 100644 --- a/supply-api/internal/domain/account.go +++ b/supply-api/internal/domain/account.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log" "net/netip" "time" @@ -141,6 +142,14 @@ func NewAccountService(store AccountStore, auditStore audit.AuditStore) AccountS return &accountService{store: store, auditStore: auditStore} } +// emitAudit 安全记录审计日志(失败只记录错误,不影响主流程) +func (s *accountService) emitAudit(ctx context.Context, event audit.Event) { + if err := s.auditStore.Emit(ctx, event); err != nil { + log.Printf("[AUDIT_ERROR] failed to emit audit event: %v, object_type=%s, object_id=%d, action=%s", + err, event.ObjectType, event.ObjectID, event.Action) + } +} + func (s *accountService) Verify(ctx context.Context, supplierID int64, provider Provider, accountType AccountType, credential string) (*VerifyResult, error) { // 开发阶段:模拟验证逻辑 result := &VerifyResult{ @@ -181,7 +190,7 @@ func (s *accountService) Create(ctx context.Context, req *CreateAccountRequest) } // 记录审计日志 - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: req.SupplierID, ObjectType: "supply_account", ObjectID: account.ID, @@ -210,7 +219,7 @@ func (s *accountService) Activate(ctx context.Context, supplierID, accountID int return nil, err } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_account", ObjectID: accountID, @@ -239,7 +248,7 @@ func (s *accountService) Suspend(ctx context.Context, supplierID, accountID int6 return nil, err } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_account", ObjectID: accountID, @@ -260,7 +269,7 @@ func (s *accountService) Delete(ctx context.Context, supplierID, accountID int64 return errors.New("SUP_ACC_4092: cannot delete active accounts") } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_account", ObjectID: accountID, diff --git a/supply-api/internal/domain/package.go b/supply-api/internal/domain/package.go index 52b3858..02ea420 100644 --- a/supply-api/internal/domain/package.go +++ b/supply-api/internal/domain/package.go @@ -3,6 +3,7 @@ package domain import ( "context" "errors" + "log" "net/netip" "time" @@ -132,6 +133,14 @@ func NewPackageService(store PackageStore, accountStore AccountStore, auditStore } } +// emitAudit 安全记录审计日志(失败只记录错误,不影响主流程) +func (s *packageService) emitAudit(ctx context.Context, event audit.Event) { + if err := s.auditStore.Emit(ctx, event); err != nil { + log.Printf("[AUDIT_ERROR] failed to emit audit event: %v, object_type=%s, object_id=%d, action=%s", + err, event.ObjectType, event.ObjectID, event.Action) + } +} + func (s *packageService) CreateDraft(ctx context.Context, supplierID int64, req *CreatePackageDraftRequest) (*Package, error) { pkg := &Package{ SupplierID: supplierID, @@ -154,7 +163,7 @@ func (s *packageService) CreateDraft(ctx context.Context, supplierID int64, req return nil, err } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_package", ObjectID: pkg.ID, @@ -183,7 +192,7 @@ func (s *packageService) Publish(ctx context.Context, supplierID, packageID int6 return nil, err } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_package", ObjectID: packageID, @@ -212,7 +221,7 @@ func (s *packageService) Pause(ctx context.Context, supplierID, packageID int64) return nil, err } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_package", ObjectID: packageID, @@ -237,7 +246,7 @@ func (s *packageService) Unlist(ctx context.Context, supplierID, packageID int64 return nil, err } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_package", ObjectID: packageID, @@ -275,7 +284,7 @@ func (s *packageService) Clone(ctx context.Context, supplierID, packageID int64) return nil, err } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_package", ObjectID: clone.ID, @@ -292,6 +301,17 @@ func (s *packageService) BatchUpdatePrice(ctx context.Context, supplierID int64, } for _, item := range req.Items { + // 验证价格不能为负数 + if item.PricePer1MInput < 0 || item.PricePer1MOutput < 0 { + resp.FailedCount++ + resp.Failures = append(resp.Failures, BatchPriceFailure{ + PackageID: item.PackageID, + ErrorCode: "SUP_PKG_4004", + Message: "price cannot be negative", + }) + continue + } + pkg, err := s.store.GetByID(ctx, supplierID, item.PackageID) if err != nil { resp.FailedCount++ diff --git a/supply-api/internal/domain/settlement.go b/supply-api/internal/domain/settlement.go index 080a397..1a87ac5 100644 --- a/supply-api/internal/domain/settlement.go +++ b/supply-api/internal/domain/settlement.go @@ -3,6 +3,7 @@ package domain import ( "context" "errors" + "log" "net/netip" "time" @@ -160,11 +161,24 @@ func NewSettlementService(store SettlementStore, earningStore EarningStore, audi } } +// emitAudit 安全记录审计日志(失败只记录错误,不影响主流程) +func (s *settlementService) emitAudit(ctx context.Context, event audit.Event) { + if err := s.auditStore.Emit(ctx, event); err != nil { + log.Printf("[AUDIT_ERROR] failed to emit audit event: %v, object_type=%s, object_id=%d, action=%s", + err, event.ObjectType, event.ObjectID, event.Action) + } +} + func (s *settlementService) Withdraw(ctx context.Context, supplierID int64, req *WithdrawRequest) (*Settlement, error) { if req.SMSCode != "123456" { return nil, errors.New("invalid sms code") } + // 验证金额:必须为正数 + if req.Amount <= 0 { + return nil, errors.New("SUP_SET_4003: withdraw amount must be positive") + } + balance, err := s.store.GetWithdrawableBalance(ctx, supplierID) if err != nil { return nil, err @@ -192,7 +206,7 @@ func (s *settlementService) Withdraw(ctx context.Context, supplierID int64, req return nil, err } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_settlement", ObjectID: settlement.ID, @@ -221,7 +235,7 @@ func (s *settlementService) Cancel(ctx context.Context, supplierID, settlementID return nil, err } - s.auditStore.Emit(ctx, audit.Event{ + s.emitAudit(ctx, audit.Event{ TenantID: supplierID, ObjectType: "supply_settlement", ObjectID: settlementID, diff --git a/supply-api/internal/httpapi/alert_api.go b/supply-api/internal/httpapi/alert_api.go new file mode 100644 index 0000000..2836c68 --- /dev/null +++ b/supply-api/internal/httpapi/alert_api.go @@ -0,0 +1,125 @@ +package httpapi + +import ( + "net/http" + "net/url" + + "lijiaoqiao/supply-api/internal/audit/handler" + "lijiaoqiao/supply-api/internal/audit/service" +) + +// AlertAPI 告警API处理器 +type AlertAPI struct { + alertHandler *handler.AlertHandler +} + +// NewAlertAPI 创建告警API处理器 +func NewAlertAPI() *AlertAPI { + // 创建内存告警存储 + alertStore := service.NewInMemoryAlertStore() + // 创建告警服务 + alertSvc := service.NewAlertService(alertStore) + // 创建告警处理器 + alertHandler := handler.NewAlertHandler(alertSvc) + + return &AlertAPI{ + alertHandler: alertHandler, + } +} + +// Register 注册告警路由 +func (a *AlertAPI) Register(mux *http.ServeMux) { + // Alert CRUD + mux.HandleFunc("/api/v1/audit/alerts", a.handleAlert) + mux.HandleFunc("/api/v1/audit/alerts/", a.handleAlertByID) +} + +// handleAlert 处理 /api/v1/audit/alerts 的路由分发 +func (a *AlertAPI) handleAlert(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + a.alertHandler.CreateAlert(w, r) + case http.MethodGet: + a.alertHandler.ListAlerts(w, r) + default: + writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed") + } +} + +// handleAlertByID 处理 /api/v1/audit/alerts/{alert_id} 的路由分发 +func (a *AlertAPI) handleAlertByID(w http.ResponseWriter, r *http.Request) { + // 提取路径最后部分判断操作 + path := r.URL.Path + if len(path) > 0 && path[len(path)-1] == '/' { + path = path[:len(path)-1] + } + + parts := splitPath(path) + if len(parts) < 5 { + writeError(w, http.StatusBadRequest, "INVALID_PATH", "invalid path") + return + } + + alertID := parts[4] + + // 检查是否是特殊操作 + if len(parts) > 5 && parts[5] == "resolve" { + if r.Method == http.MethodPost { + a.alertHandler.ResolveAlert(w, r) + } else { + writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed") + } + return + } + + // 常规CRUD操作 + switch r.Method { + case http.MethodGet: + // 安全设置alert_id到查询参数 + query := make(url.Values) + for k, v := range r.URL.Query() { + query[k] = v + } + query.Set("alert_id", alertID) + r.URL.RawQuery = query.Encode() + a.alertHandler.GetAlert(w, r) + case http.MethodPut: + query := make(url.Values) + for k, v := range r.URL.Query() { + query[k] = v + } + query.Set("alert_id", alertID) + r.URL.RawQuery = query.Encode() + a.alertHandler.UpdateAlert(w, r) + case http.MethodDelete: + query := make(url.Values) + for k, v := range r.URL.Query() { + query[k] = v + } + query.Set("alert_id", alertID) + r.URL.RawQuery = query.Encode() + a.alertHandler.DeleteAlert(w, r) + default: + writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed") + } +} + +// splitPath 分割路径 +func splitPath(path string) []string { + var parts []string + var current []byte + for i := 0; i < len(path); i++ { + if path[i] == '/' { + if len(current) > 0 { + parts = append(parts, string(current)) + current = nil + } + } else { + current = append(current, path[i]) + } + } + if len(current) > 0 { + parts = append(parts, string(current)) + } + return parts +} diff --git a/supply-api/internal/iam/middleware/scope_auth.go b/supply-api/internal/iam/middleware/scope_auth.go index 283d94a..50ba655 100644 --- a/supply-api/internal/iam/middleware/scope_auth.go +++ b/supply-api/internal/iam/middleware/scope_auth.go @@ -6,6 +6,7 @@ import ( "log" "net/http" + "lijiaoqiao/supply-api/internal/iam/model" "lijiaoqiao/supply-api/internal/middleware" ) @@ -17,7 +18,12 @@ const ( IAMTokenClaimsKey iamContextKey = "iam_token_claims" ) +// ClaimsVersion Token Claims版本号,用于迁移追踪 +const ClaimsVersion = 1 + // IAMTokenClaims IAM扩展Token Claims +// 版本: v1 +// 迁移路径: 见 MigrateClaims 函数 type IAMTokenClaims struct { SubjectID string `json:"subject_id"` Role string `json:"role"` @@ -25,30 +31,73 @@ type IAMTokenClaims struct { TenantID int64 `json:"tenant_id"` UserType string `json:"user_type"` // 用户类型: platform/supply/consumer Permissions []string `json:"permissions"` // 细粒度权限列表 + + // 版本控制字段(未来迁移用) + Version int `json:"version,omitempty"` } -// 角色层级定义 -var roleHierarchyLevels = map[string]int{ - "super_admin": 100, - "org_admin": 50, - "supply_admin": 40, - "consumer_admin": 40, - "operator": 30, - "developer": 20, - "finops": 20, - "supply_operator": 30, - "supply_finops": 20, - "supply_viewer": 10, - "consumer_operator": 30, - "consumer_viewer": 10, - "viewer": 10, +// MigrateClaims 将旧版本Claims迁移到当前版本 +// 迁移路径: +// v0 -> v1: 初始版本,添加 Version 字段 +// +// 使用示例: +// claims := &IAMTokenClaims{} +// if err := json.Unmarshal(data, claims); err != nil { +// return err +// } +// migrated := MigrateClaims(claims) +// // 使用 migrated +func MigrateClaims(claims *IAMTokenClaims) *IAMTokenClaims { + if claims == nil { + return nil + } + + // 当前版本是v1,无需迁移 + // 未来版本迁移: + // case 0: + // claims = migrateV0ToV1(claims) + // case 1: + // claims = migrateV1ToV2(claims) + claims.Version = ClaimsVersion + return claims } +// ValidateClaims 验证Claims完整性 +func ValidateClaims(claims *IAMTokenClaims) error { + if claims == nil { + return ErrInvalidClaims + } + if claims.SubjectID == "" { + return ErrInvalidSubjectID + } + return nil +} + +// 迁移相关错误 +var ( + ErrInvalidClaims = &ClaimsError{Code: "IAM_CLAIMS_4001", Message: "invalid claims structure"} + ErrInvalidSubjectID = &ClaimsError{Code: "IAM_CLAIMS_4002", Message: "subject_id is required"} +) + +// ClaimsError Claims相关错误 +type ClaimsError struct { + Code string + Message string +} + +func (e *ClaimsError) Error() string { + return e.Code + ": " + e.Message +} + +// 角色层级定义(已废弃,请使用 model.RoleHierarchyLevels) +// @deprecated 使用 model.RoleHierarchyLevels 获取角色层级 +var roleHierarchyLevels = model.RoleHierarchyLevels + // ScopeAuthMiddleware Scope权限验证中间件 type ScopeAuthMiddleware struct { // 路由-Scope映射 routeScopePolicies map[string][]string - // 角色层级(已废弃,使用包级变量roleHierarchyLevels) + // 角色层级 roleHierarchy map[string]int } @@ -56,7 +105,7 @@ type ScopeAuthMiddleware struct { func NewScopeAuthMiddleware() *ScopeAuthMiddleware { return &ScopeAuthMiddleware{ routeScopePolicies: make(map[string][]string), - roleHierarchy: roleHierarchyLevels, + roleHierarchy: model.RoleHierarchyLevels, // 使用统一的角色层级定义 } } @@ -142,11 +191,9 @@ func HasRoleLevel(ctx context.Context, minLevel int) bool { } // GetRoleLevel 获取角色层级数值 +// @deprecated 请使用 model.GetRoleLevelByCode func GetRoleLevel(role string) int { - if level, ok := roleHierarchyLevels[role]; ok { - return level - } - return 0 + return model.GetRoleLevelByCode(role) } // GetIAMTokenClaims 获取IAM Token Claims diff --git a/supply-api/internal/iam/model/role.go b/supply-api/internal/iam/model/role.go index 59ed4ba..e9c95ba 100644 --- a/supply-api/internal/iam/model/role.go +++ b/supply-api/internal/iam/model/role.go @@ -15,6 +15,7 @@ const ( ) // 角色层级常量(用于权限优先级判断) +// 注意:这些常量值必须与 RoleHierarchyLevels map保持一致 const ( LevelSuperAdmin = 100 LevelOrgAdmin = 50 @@ -25,6 +26,33 @@ const ( LevelViewer = 10 ) +// RoleHierarchyLevels 角色层级映射(用于权限验证) +// 层级越高,权限越大。superset角色可以执行subset角色的操作。 +// 注意:此map的值必须与上述常量保持一致! +var RoleHierarchyLevels = map[string]int{ + "super_admin": LevelSuperAdmin, // 100 - 超级管理员 + "org_admin": LevelOrgAdmin, // 50 - 组织管理员 + "supply_admin": LevelSupplyAdmin, // 40 - 供应商管理员 + "consumer_admin": LevelSupplyAdmin, // 40 - 消费者管理员(同供应商) + "operator": LevelOperator, // 30 - 操作员 + "developer": LevelDeveloper, // 20 - 开发者 + "finops": LevelFinops, // 20 - 财务运营 + "supply_operator": LevelOperator, // 30 - 供应商操作员 + "supply_finops": LevelFinops, // 20 - 供应商财务 + "supply_viewer": LevelViewer, // 10 - 供应商查看者 + "consumer_operator": LevelOperator, // 30 - 消费者操作员 + "consumer_viewer": LevelViewer, // 10 - 消费者查看者 + "viewer": LevelViewer, // 10 - 通用查看者 +} + +// GetRoleLevelByCode 根据角色代码获取层级数值 +func GetRoleLevelByCode(roleCode string) int { + if level, ok := RoleHierarchyLevels[roleCode]; ok { + return level + } + return 0 // 默认最低级别 +} + // 角色错误定义 var ( ErrInvalidRoleCode = errors.New("invalid role code: cannot be empty") diff --git a/supply-api/internal/middleware/auth.go b/supply-api/internal/middleware/auth.go index ab2378c..d0e02bd 100644 --- a/supply-api/internal/middleware/auth.go +++ b/supply-api/internal/middleware/auth.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" "strconv" "strings" @@ -14,6 +15,8 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + + "lijiaoqiao/supply-api/internal/iam/model" ) // TokenClaims JWT token claims @@ -84,11 +87,13 @@ type BruteForceProtection struct { lockoutDuration time.Duration attempts map[string]*attemptRecord mu sync.Mutex + cleanupCounter int64 // 清理触发计数器 } type attemptRecord struct { count int lockedUntil time.Time + lastAttempt time.Time // 最后尝试时间,用于过期清理 } // NewBruteForceProtection 创建暴力破解保护 @@ -114,9 +119,11 @@ func (b *BruteForceProtection) RecordFailedAttempt(ip string) { } record.count++ + record.lastAttempt = time.Now() if record.count >= b.maxAttempts { record.lockedUntil = time.Now().Add(b.lockoutDuration) } + b.triggerCleanup() } // IsLocked 检查IP是否被锁定 @@ -150,6 +157,42 @@ func (b *BruteForceProtection) Reset(ip string) { delete(b.attempts, ip) } +// triggerCleanup 触发清理(每100次操作清理一次过期记录) +func (b *BruteForceProtection) triggerCleanup() { + b.cleanupCounter++ + if b.cleanupCounter >= 100 { + b.cleanupCounter = 0 + b.cleanupExpiredLocked() + } +} + +// cleanupExpiredLocked 清理过期记录(需要持有锁) +// 清理条件:锁定已过期且最后尝试时间超过lockoutDuration +func (b *BruteForceProtection) cleanupExpiredLocked() { + now := time.Now() + threshold := now.Add(-b.lockoutDuration * 2) // 超过两倍锁定时长未活动的记录清理 + for ip, record := range b.attempts { + // 清理:锁定已过期且长时间无活动 + if record.lockedUntil.Before(now) && record.lastAttempt.Before(threshold) { + delete(b.attempts, ip) + } + } +} + +// CleanExpired 主动清理过期记录(可由外部定期调用) +func (b *BruteForceProtection) CleanExpired() { + b.mu.Lock() + defer b.mu.Unlock() + b.cleanupExpiredLocked() +} + +// Len 返回当前记录数量(用于监控) +func (b *BruteForceProtection) Len() int { + b.mu.Lock() + defer b.mu.Unlock() + return len(b.attempts) +} + // QueryKeyRejectMiddleware 拒绝外部query key入站 // 对应M-016指标 func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler { @@ -263,7 +306,19 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler { } } - tokenString := r.Context().Value(bearerTokenKey).(string) + // 安全检查:确保BearerExtractMiddleware已执行 + tokenValue := r.Context().Value(bearerTokenKey) + if tokenValue == nil { + writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_MISSING", + "bearer token is missing") + return + } + tokenString, ok := tokenValue.(string) + if !ok || tokenString == "" { + writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_INVALID", + "bearer token is invalid") + return + } claims, err := m.verifyToken(tokenString) if err != nil { @@ -289,7 +344,7 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler { } // 检查token状态(是否被吊销) - status, err := m.checkTokenStatus(claims.ID) + status, err := m.checkTokenStatus(r.Context(), claims.ID) if err == nil && status != "active" { if m.auditEmitter != nil { m.auditEmitter.Emit(r.Context(), AuditEvent{ @@ -363,24 +418,21 @@ func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(htt } // 检查role权限 - roleHierarchy := map[string]int{ - "admin": 3, - "owner": 2, - "viewer": 1, - } + // 使用model.GetRoleLevelByCode获取统一角色层级定义 - // 路由权限要求 + // 路由权限要求(使用详细角色代码) + // viewer: level 10, operator: level 30, org_admin: level 50 routeRoles := map[string]string{ - "/api/v1/supply/accounts": "owner", - "/api/v1/supply/packages": "owner", - "/api/v1/supply/settlements": "owner", - "/api/v1/supply/billing": "viewer", - "/api/v1/supplier/billing": "viewer", + "/api/v1/supply/accounts": "org_admin", + "/api/v1/supply/packages": "org_admin", + "/api/v1/supply/settlements": "org_admin", + "/api/v1/supply/billing": "viewer", + "/api/v1/supplier/billing": "viewer", } for path, requiredRole := range routeRoles { if strings.HasPrefix(r.URL.Path, path) { - if roleLevel(claims.Role, roleHierarchy) < roleLevel(requiredRole, roleHierarchy) { + if model.GetRoleLevelByCode(claims.Role) < model.GetRoleLevelByCode(requiredRole) { writeAuthError(w, http.StatusForbidden, "AUTH_ROLE_DENIED", fmt.Sprintf("required role '%s' is not granted, current role: '%s'", requiredRole, claims.Role)) return @@ -430,7 +482,7 @@ func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) { } // checkTokenStatus 检查token状态(从缓存或数据库) -func (m *AuthMiddleware) checkTokenStatus(tokenID string) (string, error) { +func (m *AuthMiddleware) checkTokenStatus(ctx context.Context, tokenID string) (string, error) { if m.tokenCache != nil { // 先从缓存检查 if status, found := m.tokenCache.Get(tokenID); found { @@ -440,7 +492,7 @@ func (m *AuthMiddleware) checkTokenStatus(tokenID string) (string, error) { // 缓存未命中,查询后端验证token状态 if m.tokenBackend != nil { - return m.tokenBackend.CheckTokenStatus(context.Background(), tokenID) + return m.tokenBackend.CheckTokenStatus(ctx, tokenID) } // 没有后端实现时,应该拒绝访问而不是默认active @@ -472,7 +524,10 @@ func writeAuthError(w http.ResponseWriter, status int, code, message string) { "message": message, }, } - json.NewEncoder(w).Encode(resp) + if err := json.NewEncoder(w).Encode(resp); err != nil { + // 记录编码错误(响应已经开始发送,无法回退) + log.Printf("[AUTH_ERROR] failed to encode error response: %v, code=%s", err, code) + } } // getRequestID 获取请求ID @@ -488,7 +543,10 @@ func getClientIP(r *http.Request) string { // 优先从X-Forwarded-For获取 if xff := r.Header.Get("X-Forwarded-For"); xff != "" { parts := strings.Split(xff, ",") - return strings.TrimSpace(parts[0]) + // 安全检查:空字符串已在上层判断,但防御性编程 + if len(parts) > 0 { + return strings.TrimSpace(parts[0]) + } } // X-Real-IP @@ -550,14 +608,6 @@ func containsScope(scopes []string, target string) bool { return false } -// roleLevel 获取角色等级 -func roleLevel(role string, hierarchy map[string]int) int { - if level, ok := hierarchy[role]; ok { - return level - } - return 0 -} - // parseSubjectID 解析subject ID func parseSubjectID(subject string) int64 { parts := strings.Split(subject, ":") @@ -570,7 +620,9 @@ func parseSubjectID(subject string) int64 { // TokenCache Token状态缓存 type TokenCache struct { - data map[string]cacheEntry + data map[string]cacheEntry + mu sync.RWMutex + cleanup int64 // 清理触发计数器 } type cacheEntry struct { @@ -581,34 +633,76 @@ type cacheEntry struct { // NewTokenCache 创建token缓存 func NewTokenCache() *TokenCache { return &TokenCache{ - data: make(map[string]cacheEntry), + data: make(map[string]cacheEntry), + cleanup: 0, } } // Get 获取token状态 func (c *TokenCache) Get(tokenID string) (string, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + if entry, ok := c.data[tokenID]; ok { if time.Now().Before(entry.expires) { return entry.status, true } - delete(c.data, tokenID) } return "", false } // Set 设置token状态 func (c *TokenCache) Set(tokenID, status string, ttl time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.data[tokenID] = cacheEntry{ status: status, expires: time.Now().Add(ttl), } + c.triggerCleanup() } // Invalidate 使token失效 func (c *TokenCache) Invalidate(tokenID string) { + c.mu.Lock() + defer c.mu.Unlock() delete(c.data, tokenID) } +// triggerCleanup 触发清理(每100次操作清理一次过期条目) +func (c *TokenCache) triggerCleanup() { + c.cleanup++ + if c.cleanup >= 100 { + c.cleanup = 0 + c.cleanupExpiredLocked() + } +} + +// cleanupExpiredLocked 清理过期条目(需要持有锁) +func (c *TokenCache) cleanupExpiredLocked() { + now := time.Now() + for tokenID, entry := range c.data { + if now.After(entry.expires) { + delete(c.data, tokenID) + } + } +} + +// CleanExpired 主动清理过期条目(可由外部定期调用) +func (c *TokenCache) CleanExpired() { + c.mu.Lock() + defer c.mu.Unlock() + c.cleanupExpiredLocked() +} + +// Len 返回缓存条目数量(用于监控) +func (c *TokenCache) Len() int { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.data) +} + // ComputeFingerprint 计算凭证指纹(用于审计) func ComputeFingerprint(credential string) string { hash := sha256.Sum256([]byte(credential)) diff --git a/supply-api/internal/middleware/auth_test.go b/supply-api/internal/middleware/auth_test.go index 0330b68..c69e2ae 100644 --- a/supply-api/internal/middleware/auth_test.go +++ b/supply-api/internal/middleware/auth_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "net/http" "net/http/httptest" "strings" @@ -8,6 +9,8 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + + "lijiaoqiao/supply-api/internal/iam/model" ) func TestTokenVerify(t *testing.T) { @@ -248,27 +251,25 @@ func TestContainsScope(t *testing.T) { } func TestRoleLevel(t *testing.T) { - hierarchy := map[string]int{ - "admin": 3, - "owner": 2, - "viewer": 1, - } - tests := []struct { role string expected int }{ - {"admin", 3}, - {"owner", 2}, - {"viewer", 1}, + {"super_admin", 100}, + {"org_admin", 50}, + {"supply_admin", 40}, + {"operator", 30}, + {"developer", 20}, + {"finops", 20}, + {"viewer", 10}, {"unknown", 0}, } for _, tt := range tests { t.Run(tt.role, func(t *testing.T) { - result := roleLevel(tt.role, hierarchy) + result := model.GetRoleLevelByCode(tt.role) if result != tt.expected { - t.Errorf("roleLevel(%s) = %d, want %d", tt.role, result, tt.expected) + t.Errorf("GetRoleLevelByCode(%s) = %d, want %d", tt.role, result, tt.expected) } }) } @@ -411,7 +412,7 @@ func TestMED02_TokenCacheMiss_ShouldNotAssumeActive(t *testing.T) { } // act - 查询一个不在缓存中的token - status, err := middleware.checkTokenStatus("nonexistent-token-id") + status, err := middleware.checkTokenStatus(context.Background(), "nonexistent-token-id") // assert - 缓存未命中且没有后端时应该返回错误(安全修复) // 修复前bug:缓存未命中时默认返回"active" diff --git a/supply-api/internal/repository/account.go b/supply-api/internal/repository/account.go index a1db940..f15286b 100644 --- a/supply-api/internal/repository/account.go +++ b/supply-api/internal/repository/account.go @@ -254,7 +254,7 @@ func (r *AccountRepository) List(ctx context.Context, supplierID int64) ([]*doma } defer rows.Close() - var accounts []*domain.Account + accounts := make([]*domain.Account, 0) for rows.Next() { account := &domain.Account{} err := rows.Scan( diff --git a/supply-api/internal/repository/package.go b/supply-api/internal/repository/package.go index 67cbc8c..81f4823 100644 --- a/supply-api/internal/repository/package.go +++ b/supply-api/internal/repository/package.go @@ -206,7 +206,7 @@ func (r *PackageRepository) List(ctx context.Context, supplierID int64) ([]*doma } defer rows.Close() - var packages []*domain.Package + packages := make([]*domain.Package, 0) for rows.Next() { pkg := &domain.Package{} err := rows.Scan( diff --git a/supply-api/internal/repository/settlement.go b/supply-api/internal/repository/settlement.go index cb25ace..b6832a7 100644 --- a/supply-api/internal/repository/settlement.go +++ b/supply-api/internal/repository/settlement.go @@ -195,7 +195,7 @@ func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*d } defer rows.Close() - var settlements []*domain.Settlement + settlements := make([]*domain.Settlement, 0) for rows.Next() { s := &domain.Settlement{} err := rows.Scan( diff --git a/supply-api/internal/storage/store.go b/supply-api/internal/storage/store.go index c281e7a..af707dc 100644 --- a/supply-api/internal/storage/store.go +++ b/supply-api/internal/storage/store.go @@ -66,7 +66,7 @@ func (s *InMemoryAccountStore) List(ctx context.Context, supplierID int64) ([]*d s.mu.RLock() defer s.mu.RUnlock() - var result []*domain.Account + result := make([]*domain.Account, 0) for _, account := range s.accounts { if account.SupplierID == supplierID { result = append(result, account) @@ -129,7 +129,7 @@ func (s *InMemoryPackageStore) List(ctx context.Context, supplierID int64) ([]*d s.mu.RLock() defer s.mu.RUnlock() - var result []*domain.Package + result := make([]*domain.Package, 0) for _, pkg := range s.packages { if pkg.SupplierID == supplierID { result = append(result, pkg) @@ -192,7 +192,7 @@ func (s *InMemorySettlementStore) List(ctx context.Context, supplierID int64) ([ s.mu.RLock() defer s.mu.RUnlock() - var result []*domain.Settlement + result := make([]*domain.Settlement, 0) for _, settlement := range s.settlements { if settlement.SupplierID == supplierID { result = append(result, settlement) @@ -264,8 +264,9 @@ func (s *InMemoryEarningStore) GetBillingSummary(ctx context.Context, supplierID // 内存幂等存储 type InMemoryIdempotencyStore struct { - mu sync.RWMutex - records map[string]*IdempotencyRecord + mu sync.RWMutex + records map[string]*IdempotencyRecord + cleanupCounter int64 // 清理触发计数器 } type IdempotencyRecord struct { @@ -303,6 +304,7 @@ func (s *InMemoryIdempotencyStore) SetProcessing(key string, ttl time.Duration) CreatedAt: time.Now(), ExpiresAt: time.Now().Add(ttl), } + s.triggerCleanupLocked() } func (s *InMemoryIdempotencyStore) SetSuccess(key string, response interface{}, ttl time.Duration) { @@ -316,4 +318,39 @@ func (s *InMemoryIdempotencyStore) SetSuccess(key string, response interface{}, CreatedAt: time.Now(), ExpiresAt: time.Now().Add(ttl), } + s.triggerCleanupLocked() +} + +// triggerCleanupLocked 触发清理(每100次操作清理一次过期记录) +// 调用时必须持有锁 +func (s *InMemoryIdempotencyStore) triggerCleanupLocked() { + s.cleanupCounter++ + if s.cleanupCounter >= 100 { + s.cleanupCounter = 0 + s.cleanupExpiredLocked() + } +} + +// cleanupExpiredLocked 清理过期记录(需要持有锁) +func (s *InMemoryIdempotencyStore) cleanupExpiredLocked() { + now := time.Now() + for key, record := range s.records { + if record.ExpiresAt.Before(now) { + delete(s.records, key) + } + } +} + +// CleanExpired 主动清理过期记录(可由外部定期调用) +func (s *InMemoryIdempotencyStore) CleanExpired() { + s.mu.Lock() + defer s.mu.Unlock() + s.cleanupExpiredLocked() +} + +// Len 返回当前记录数量(用于监控) +func (s *InMemoryIdempotencyStore) Len() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.records) } diff --git a/supply-api/pkg/error/errors.go b/supply-api/pkg/error/errors.go new file mode 100644 index 0000000..6ed9029 --- /dev/null +++ b/supply-api/pkg/error/errors.go @@ -0,0 +1,132 @@ +package error + +import ( + "fmt" + "strings" +) + +// ErrorCode 错误码格式:{DOMAIN}_{CODE} +// 错误码命名规范:{模块}_{问题类型}_{序号} +// +// 示例: +// - SUP_ACC_4001 (Supplier Account - 业务错误 - 4001) +// - AUDIT_EVT_4041 (Audit Event - 资源不存在 - 4041) +// +// 错误码分类: +// - 4xxx: 业务逻辑错误 +// - 5xxx: 系统/服务器错误 +// - 9xxx: 内部/未知错误 + +// 预定义的错误码前缀 +const ( + PrefixSUP = "SUP" // Supplier 模块 + PrefixIAM = "IAM" // Identity & Access Management 模块 + PrefixAudit = "AUDIT" // Audit 模块 + PrefixRepo = "REPO" // Repository 模块 + PrefixSys = "SYS" // 系统级错误 +) + +// CodeError 带错误码的错误 +type CodeError struct { + Code string // 错误码,如 "SUP_ACC_4001" + Message string // 错误消息 + Err error // 底层错误(可选) +} + +// Error 实现 error 接口 +func (e *CodeError) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s: %s (caused by: %v)", e.Code, e.Message, e.Err) + } + return fmt.Sprintf("%s: %s", e.Code, e.Message) +} + +// Unwrap 获取底层错误 +func (e *CodeError) Unwrap() error { + return e.Err +} + +// NewCodeError 创建带错误码的错误 +func NewCodeError(code, message string) *CodeError { + return &CodeError{ + Code: code, + Message: message, + } +} + +// WrapCodeError 包装已有错误 +func WrapCodeError(err error, code, message string) *CodeError { + return &CodeError{ + Code: code, + Message: message, + Err: err, + } +} + +// IsCodeError 检查错误是否为 CodeError +func IsCodeError(err error) bool { + _, ok := err.(*CodeError) + return ok +} + +// GetErrorCode 从错误中提取错误码 +func GetErrorCode(err error) string { + var codeErr *CodeError + if As(err, &codeErr) { + return codeErr.Code + } + return "" +} + +// As 类型断言辅助函数 +func As(err error, target **CodeError) bool { + if err == nil { + return false + } + if e, ok := err.(*CodeError); ok { + *target = e + return true + } + if e, ok := err.(interface{ Unwrap() error }); ok { + return As(e.Unwrap(), target) + } + return false +} + +// Common errors - 可以被各模块引用的通用错误 +var ( + // ErrNotFound 资源不存在 + ErrNotFound = NewCodeError("SYS_4040", "resource not found") + + // ErrInvalidInput 输入参数无效 + ErrInvalidInput = NewCodeError("SYS_4000", "invalid input parameter") + + // ErrUnauthorized 未授权 + ErrUnauthorized = NewCodeError("SYS_4010", "unauthorized") + + // ErrForbidden 禁止访问 + ErrForbidden = NewCodeError("SYS_4030", "forbidden") + + // ErrInternalServer 服务器内部错误 + ErrInternalServer = NewCodeError("SYS_5000", "internal server error") + + // ErrConcurrencyConflict 并发冲突 + ErrConcurrencyConflict = NewCodeError("SYS_4090", "concurrency conflict") +) + +// ValidateErrorCode 验证错误码格式是否合法 +func ValidateErrorCode(code string) bool { + parts := strings.Split(code, "_") + if len(parts) < 2 { + return false + } + // 检查前缀是否为有效值 + prefix := parts[0] + validPrefixes := []string{PrefixSUP, PrefixIAM, PrefixAudit, PrefixRepo, PrefixSys} + for _, p := range validPrefixes { + if prefix == p { + return true + } + } + return false +} diff --git a/supply-api/pkg/error/errors_test.go b/supply-api/pkg/error/errors_test.go new file mode 100644 index 0000000..189ba00 --- /dev/null +++ b/supply-api/pkg/error/errors_test.go @@ -0,0 +1,95 @@ +package error + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewCodeError(t *testing.T) { + err := NewCodeError("TEST_4001", "test error message") + assert.Equal(t, "TEST_4001", err.Code) + assert.Equal(t, "test error message", err.Message) + assert.Nil(t, err.Err) + assert.Equal(t, "TEST_4001: test error message", err.Error()) +} + +func TestWrapCodeError(t *testing.T) { + originalErr := errors.New("original error") + err := WrapCodeError(originalErr, "TEST_4001", "wrapped error") + assert.Equal(t, "TEST_4001", err.Code) + assert.Equal(t, "wrapped error", err.Message) + assert.Equal(t, originalErr, err.Err) + assert.Contains(t, err.Error(), "caused by: original error") +} + +func TestIsCodeError(t *testing.T) { + codeErr := NewCodeError("TEST_4001", "test") + assert.True(t, IsCodeError(codeErr)) + + stdErr := errors.New("standard error") + assert.False(t, IsCodeError(stdErr)) +} + +func TestGetErrorCode(t *testing.T) { + codeErr := NewCodeError("SUP_ACC_4001", "test") + assert.Equal(t, "SUP_ACC_4001", GetErrorCode(codeErr)) + + stdErr := errors.New("standard error") + assert.Equal(t, "", GetErrorCode(stdErr)) +} + +func TestUnwrap(t *testing.T) { + originalErr := errors.New("original") + wrapped := WrapCodeError(originalErr, "TEST_4001", "wrapped") + + // 通过 Unwrap 获取原始错误 + unwrapped := wrapped.Unwrap() + assert.Equal(t, originalErr, unwrapped) +} + +func TestValidateErrorCode(t *testing.T) { + tests := []struct { + code string + expected bool + }{ + {"SUP_ACC_4001", true}, + {"IAM_ROLE_4040", true}, + {"AUDIT_EVT_5000", true}, + {"REPO_NOT_FOUND", true}, + {"SYS_5000", true}, + {"INVALID", false}, // 没有下划线分隔 + {"BAD_CODE", false}, // 前缀不在白名单 + {"X_4001", false}, // 前缀不在白名单 + {"", false}, // 空字符串 + {"TOOLONG_4001", false}, // 前缀太长 + } + + for _, tc := range tests { + t.Run(tc.code, func(t *testing.T) { + result := ValidateErrorCode(tc.code) + assert.Equal(t, tc.expected, result, "code: %s", tc.code) + }) + } +} + +func TestCommonErrors(t *testing.T) { + assert.Equal(t, "SYS_4040", ErrNotFound.Code) + assert.Equal(t, "resource not found", ErrNotFound.Message) + + assert.Equal(t, "SYS_4000", ErrInvalidInput.Code) + assert.Equal(t, "invalid input parameter", ErrInvalidInput.Message) + + assert.Equal(t, "SYS_4010", ErrUnauthorized.Code) + assert.Equal(t, "unauthorized", ErrUnauthorized.Message) + + assert.Equal(t, "SYS_4030", ErrForbidden.Code) + assert.Equal(t, "forbidden", ErrForbidden.Message) + + assert.Equal(t, "SYS_5000", ErrInternalServer.Code) + assert.Equal(t, "internal server error", ErrInternalServer.Message) + + assert.Equal(t, "SYS_4090", ErrConcurrencyConflict.Code) + assert.Equal(t, "concurrency conflict", ErrConcurrencyConflict.Message) +}