fix: 系统性修复安全问题、性能问题和错误处理

安全问题修复:
- X-Forwarded-For越界检查(auth.go)
- checkTokenStatus Context参数传递(auth.go)
- Type Assertion安全检查(auth.go)

性能问题修复:
- TokenCache过期清理机制
- BruteForceProtection过期清理
- InMemoryIdempotencyStore过期清理

错误处理修复:
- AuditStore.Emit返回error
- domain层emitAudit辅助方法
- List方法返回空slice而非nil
- 金额/价格负数验证

架构一致性:
- 统一使用model.RoleHierarchyLevels

新增功能:
- Alert API完整实现(CRUD+Resolve)
- pkg/error错误码集中管理
This commit is contained in:
Your Name
2026-04-07 07:41:25 +08:00
parent 12ce4913cd
commit d5b5a8ece0
21 changed files with 2321 additions and 83 deletions

View File

@@ -2,6 +2,7 @@ package audit
import (
"context"
"fmt"
"sync"
"time"
)
@@ -23,8 +24,10 @@ type Event struct {
// 审计存储接口
type AuditStore interface {
Emit(ctx context.Context, event Event)
Emit(ctx context.Context, event Event) error
Query(ctx context.Context, filter EventFilter) ([]Event, error)
QueryWithTotal(ctx context.Context, filter EventFilter) ([]Event, int64, error)
GetByID(ctx context.Context, eventID string) (Event, error)
}
// 事件过滤器
@@ -52,13 +55,14 @@ func NewMemoryAuditStore() *MemoryAuditStore {
}
}
func (s *MemoryAuditStore) Emit(ctx context.Context, event Event) {
func (s *MemoryAuditStore) Emit(ctx context.Context, event Event) error {
s.mu.Lock()
defer s.mu.Unlock()
event.EventID = generateEventID()
event.CreatedAt = time.Now()
s.events = append(s.events, event)
return nil
}
func (s *MemoryAuditStore) Query(ctx context.Context, filter EventFilter) ([]Event, error) {
@@ -90,6 +94,52 @@ func (s *MemoryAuditStore) Query(ctx context.Context, filter EventFilter) ([]Eve
return result, nil
}
// QueryWithTotal 查询事件并返回总数
func (s *MemoryAuditStore) QueryWithTotal(ctx context.Context, filter EventFilter) ([]Event, int64, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []Event
total := int64(0)
for _, event := range s.events {
total++
if filter.TenantID > 0 && event.TenantID != filter.TenantID {
continue
}
if filter.ObjectType != "" && event.ObjectType != filter.ObjectType {
continue
}
if filter.ObjectID > 0 && event.ObjectID != filter.ObjectID {
continue
}
if filter.Action != "" && event.Action != filter.Action {
continue
}
result = append(result, event)
}
// 限制返回数量
if filter.Limit > 0 && len(result) > filter.Limit {
result = result[:filter.Limit]
}
return result, total, nil
}
// GetByID 根据事件ID获取单个事件
func (s *MemoryAuditStore) GetByID(ctx context.Context, eventID string) (Event, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, event := range s.events {
if event.EventID == eventID {
return event, nil
}
}
return Event{}, fmt.Errorf("event not found")
}
func generateEventID() string {
return time.Now().Format("20060102150405") + "-evt"
}

View File

@@ -0,0 +1,350 @@
package handler
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"lijiaoqiao/supply-api/internal/audit/model"
"lijiaoqiao/supply-api/internal/audit/service"
)
// AlertHandler 告警HTTP处理器
type AlertHandler struct {
svc *service.AlertService
}
// NewAlertHandler 创建告警处理器
func NewAlertHandler(svc *service.AlertService) *AlertHandler {
return &AlertHandler{svc: svc}
}
// CreateAlertRequest 创建告警请求
type CreateAlertRequest struct {
AlertName string `json:"alert_name"`
AlertType string `json:"alert_type"`
AlertLevel string `json:"alert_level"`
TenantID int64 `json:"tenant_id"`
SupplierID int64 `json:"supplier_id,omitempty"`
Title string `json:"title"`
Message string `json:"message"`
Description string `json:"description,omitempty"`
EventID string `json:"event_id,omitempty"`
EventIDs []string `json:"event_ids,omitempty"`
NotifyEnabled bool `json:"notify_enabled"`
Tags []string `json:"tags,omitempty"`
}
// UpdateAlertRequest 更新告警请求
type UpdateAlertRequest struct {
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
Description string `json:"description,omitempty"`
AlertLevel string `json:"alert_level,omitempty"`
Status string `json:"status,omitempty"`
NotifyEnabled *bool `json:"notify_enabled,omitempty"`
NotifyChannels []string `json:"notify_channels,omitempty"`
Tags []string `json:"tags,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}
// ResolveAlertRequest 解决告警请求
type ResolveAlertRequest struct {
ResolvedBy string `json:"resolved_by"`
Note string `json:"note"`
}
// AlertResponse 告警响应
type AlertResponse struct {
Alert *model.Alert `json:"alert"`
}
// AlertListResponse 告警列表响应
type AlertListResponse struct {
Alerts []*model.Alert `json:"alerts"`
Total int64 `json:"total"`
Offset int `json:"offset"`
Limit int `json:"limit"`
}
// CreateAlert 处理 POST /api/v1/audit/alerts
func (h *AlertHandler) CreateAlert(w http.ResponseWriter, r *http.Request) {
var req CreateAlertRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
return
}
// 验证必填字段
if req.Title == "" {
writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "title is required")
return
}
if req.AlertType == "" {
writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "alert_type is required")
return
}
// 创建告警
alert := &model.Alert{
AlertName: req.AlertName,
AlertType: req.AlertType,
AlertLevel: req.AlertLevel,
TenantID: req.TenantID,
SupplierID: req.SupplierID,
Title: req.Title,
Message: req.Message,
Description: req.Description,
EventID: req.EventID,
EventIDs: req.EventIDs,
NotifyEnabled: req.NotifyEnabled,
Tags: req.Tags,
}
result, err := h.svc.CreateAlert(r.Context(), alert)
if err != nil {
writeAlertError(w, http.StatusInternalServerError, "CREATE_FAILED", err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(AlertResponse{Alert: result})
}
// GetAlert 处理 GET /api/v1/audit/alerts/{alert_id}
func (h *AlertHandler) GetAlert(w http.ResponseWriter, r *http.Request) {
alertID := extractAlertID(r)
if alertID == "" {
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
return
}
alert, err := h.svc.GetAlert(r.Context(), alertID)
if err != nil {
if err == service.ErrAlertNotFound {
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
return
}
writeAlertError(w, http.StatusInternalServerError, "GET_FAILED", err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(AlertResponse{Alert: alert})
}
// ListAlerts 处理 GET /api/v1/audit/alerts
func (h *AlertHandler) ListAlerts(w http.ResponseWriter, r *http.Request) {
filter := &model.AlertFilter{}
// 解析查询参数
if tenantIDStr := r.URL.Query().Get("tenant_id"); tenantIDStr != "" {
tenantID, err := strconv.ParseInt(tenantIDStr, 10, 64)
if err == nil {
filter.TenantID = tenantID
}
}
if supplierIDStr := r.URL.Query().Get("supplier_id"); supplierIDStr != "" {
supplierID, err := strconv.ParseInt(supplierIDStr, 10, 64)
if err == nil {
filter.SupplierID = supplierID
}
}
if alertType := r.URL.Query().Get("alert_type"); alertType != "" {
filter.AlertType = alertType
}
if alertLevel := r.URL.Query().Get("alert_level"); alertLevel != "" {
filter.AlertLevel = alertLevel
}
if status := r.URL.Query().Get("status"); status != "" {
filter.Status = status
}
if keywords := r.URL.Query().Get("keywords"); keywords != "" {
filter.Keywords = keywords
}
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
offset, err := strconv.Atoi(offsetStr)
if err == nil && offset >= 0 {
filter.Offset = offset
}
}
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
limit, err := strconv.Atoi(limitStr)
if err == nil && limit > 0 && limit <= 1000 {
filter.Limit = limit
}
}
if filter.Limit == 0 {
filter.Limit = 100
}
alerts, total, err := h.svc.ListAlerts(r.Context(), filter)
if err != nil {
writeAlertError(w, http.StatusInternalServerError, "LIST_FAILED", err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(AlertListResponse{
Alerts: alerts,
Total: total,
Offset: filter.Offset,
Limit: filter.Limit,
})
}
// UpdateAlert 处理 PUT /api/v1/audit/alerts/{alert_id}
func (h *AlertHandler) UpdateAlert(w http.ResponseWriter, r *http.Request) {
alertID := extractAlertID(r)
if alertID == "" {
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
return
}
// 获取现有告警
alert, err := h.svc.GetAlert(r.Context(), alertID)
if err != nil {
if err == service.ErrAlertNotFound {
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
return
}
writeAlertError(w, http.StatusInternalServerError, "GET_FAILED", err.Error())
return
}
var req UpdateAlertRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
return
}
// 更新字段
if req.Title != "" {
alert.Title = req.Title
}
if req.Message != "" {
alert.Message = req.Message
}
if req.Description != "" {
alert.Description = req.Description
}
if req.AlertLevel != "" {
alert.AlertLevel = req.AlertLevel
}
if req.Status != "" {
alert.Status = req.Status
}
if req.NotifyEnabled != nil {
alert.NotifyEnabled = *req.NotifyEnabled
}
if len(req.NotifyChannels) > 0 {
alert.NotifyChannels = req.NotifyChannels
}
if len(req.Tags) > 0 {
alert.Tags = req.Tags
}
if req.Metadata != nil {
alert.Metadata = req.Metadata
}
result, err := h.svc.UpdateAlert(r.Context(), alert)
if err != nil {
writeAlertError(w, http.StatusInternalServerError, "UPDATE_FAILED", err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(AlertResponse{Alert: result})
}
// DeleteAlert 处理 DELETE /api/v1/audit/alerts/{alert_id}
func (h *AlertHandler) DeleteAlert(w http.ResponseWriter, r *http.Request) {
alertID := extractAlertID(r)
if alertID == "" {
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
return
}
err := h.svc.DeleteAlert(r.Context(), alertID)
if err != nil {
if err == service.ErrAlertNotFound {
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
return
}
writeAlertError(w, http.StatusInternalServerError, "DELETE_FAILED", err.Error())
return
}
w.WriteHeader(http.StatusNoContent)
}
// ResolveAlert 处理 POST /api/v1/audit/alerts/{alert_id}/resolve
func (h *AlertHandler) ResolveAlert(w http.ResponseWriter, r *http.Request) {
alertID := extractAlertID(r)
if alertID == "" {
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
return
}
var req ResolveAlertRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
return
}
if req.ResolvedBy == "" {
writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "resolved_by is required")
return
}
result, err := h.svc.ResolveAlert(r.Context(), alertID, req.ResolvedBy, req.Note)
if err != nil {
if err == service.ErrAlertNotFound {
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
return
}
writeAlertError(w, http.StatusInternalServerError, "RESOLVE_FAILED", err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(AlertResponse{Alert: result})
}
// extractAlertID 从请求中提取alert_id优先从路径其次从查询参数
func extractAlertID(r *http.Request) string {
// 先尝试从路径提取
path := r.URL.Path
parts := strings.Split(strings.TrimPrefix(path, "/"), "/")
if len(parts) >= 5 && parts[0] == "api" && parts[1] == "v1" && parts[2] == "audit" && parts[3] == "alerts" {
if parts[4] != "" && parts[4] != "resolve" {
return parts[4]
}
}
// 再尝试从查询参数提取
if alertID := r.URL.Query().Get("alert_id"); alertID != "" {
return alertID
}
return ""
}
// writeAlertError 写入错误响应
func writeAlertError(w http.ResponseWriter, status int, code, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(ErrorResponse{
Error: message,
Code: code,
Details: "",
})
}

View File

@@ -0,0 +1,315 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"lijiaoqiao/supply-api/internal/audit/model"
"lijiaoqiao/supply-api/internal/audit/service"
"github.com/stretchr/testify/assert"
)
// mockAlertStore 模拟告警存储
type mockAlertStore struct {
alerts map[string]*model.Alert
}
func newMockAlertStore() *mockAlertStore {
return &mockAlertStore{
alerts: make(map[string]*model.Alert),
}
}
func (m *mockAlertStore) Create(ctx context.Context, alert *model.Alert) error {
if alert.AlertID == "" {
alert.AlertID = "test-alert-id"
}
alert.CreatedAt = testTime
alert.UpdatedAt = testTime
m.alerts[alert.AlertID] = alert
return nil
}
func (m *mockAlertStore) GetByID(ctx context.Context, alertID string) (*model.Alert, error) {
if alert, ok := m.alerts[alertID]; ok {
return alert, nil
}
return nil, service.ErrAlertNotFound
}
func (m *mockAlertStore) Update(ctx context.Context, alert *model.Alert) error {
if _, ok := m.alerts[alert.AlertID]; !ok {
return service.ErrAlertNotFound
}
alert.UpdatedAt = testTime
m.alerts[alert.AlertID] = alert
return nil
}
func (m *mockAlertStore) Delete(ctx context.Context, alertID string) error {
if _, ok := m.alerts[alertID]; !ok {
return service.ErrAlertNotFound
}
delete(m.alerts, alertID)
return nil
}
func (m *mockAlertStore) List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) {
var result []*model.Alert
for _, alert := range m.alerts {
if filter.TenantID > 0 && alert.TenantID != filter.TenantID {
continue
}
if filter.Status != "" && alert.Status != filter.Status {
continue
}
result = append(result, alert)
}
return result, int64(len(result)), nil
}
var testTime = time.Now()
// TestAlertHandler_CreateAlert_Success 测试创建告警成功
func TestAlertHandler_CreateAlert_Success(t *testing.T) {
store := newMockAlertStore()
svc := service.NewAlertService(store)
h := NewAlertHandler(svc)
reqBody := CreateAlertRequest{
AlertName: "TEST_ALERT",
AlertType: "security",
AlertLevel: "warning",
TenantID: 2001,
Title: "Test Alert Title",
Message: "Test alert message",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/v1/audit/alerts", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
h.CreateAlert(w, req)
assert.Equal(t, http.StatusCreated, w.Code)
var result AlertResponse
err := json.Unmarshal(w.Body.Bytes(), &result)
assert.NoError(t, err)
assert.Equal(t, "Test Alert Title", result.Alert.Title)
assert.Equal(t, "security", result.Alert.AlertType)
}
// TestAlertHandler_CreateAlert_MissingTitle 测试缺少标题
func TestAlertHandler_CreateAlert_MissingTitle(t *testing.T) {
store := newMockAlertStore()
svc := service.NewAlertService(store)
h := NewAlertHandler(svc)
reqBody := CreateAlertRequest{
AlertType: "security",
AlertLevel: "warning",
TenantID: 2001,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/v1/audit/alerts", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
h.CreateAlert(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
// TestAlertHandler_GetAlert_Success 测试获取告警成功
func TestAlertHandler_GetAlert_Success(t *testing.T) {
store := newMockAlertStore()
svc := service.NewAlertService(store)
h := NewAlertHandler(svc)
// 先创建一个告警
alert := &model.Alert{
AlertID: "test-alert-123",
AlertName: "TEST_ALERT",
AlertType: "security",
AlertLevel: "warning",
TenantID: 2001,
Title: "Test Alert",
Message: "Test message",
}
store.Create(context.Background(), alert)
// 获取告警
req := httptest.NewRequest("GET", "/api/v1/audit/alerts/test-alert-123", nil)
w := httptest.NewRecorder()
h.GetAlert(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var result AlertResponse
err := json.Unmarshal(w.Body.Bytes(), &result)
assert.NoError(t, err)
assert.Equal(t, "test-alert-123", result.Alert.AlertID)
}
// TestAlertHandler_GetAlert_NotFound 测试告警不存在
func TestAlertHandler_GetAlert_NotFound(t *testing.T) {
store := newMockAlertStore()
svc := service.NewAlertService(store)
h := NewAlertHandler(svc)
req := httptest.NewRequest("GET", "/api/v1/audit/alerts/nonexistent", nil)
w := httptest.NewRecorder()
h.GetAlert(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
// TestAlertHandler_ListAlerts_Success 测试列出告警成功
func TestAlertHandler_ListAlerts_Success(t *testing.T) {
store := newMockAlertStore()
svc := service.NewAlertService(store)
h := NewAlertHandler(svc)
// 创建多个告警
for i := 0; i < 3; i++ {
alert := &model.Alert{
AlertID: "alert-" + string(rune('a'+i)),
AlertType: "security",
AlertLevel: "warning",
TenantID: 2001,
Title: "Test Alert",
}
store.Create(context.Background(), alert)
}
req := httptest.NewRequest("GET", "/api/v1/audit/alerts?tenant_id=2001", nil)
w := httptest.NewRecorder()
h.ListAlerts(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var result AlertListResponse
err := json.Unmarshal(w.Body.Bytes(), &result)
assert.NoError(t, err)
assert.Equal(t, int64(3), result.Total)
}
// TestAlertHandler_UpdateAlert_Success 测试更新告警成功
func TestAlertHandler_UpdateAlert_Success(t *testing.T) {
store := newMockAlertStore()
svc := service.NewAlertService(store)
h := NewAlertHandler(svc)
// 先创建一个告警
alert := &model.Alert{
AlertID: "test-alert-123",
AlertType: "security",
AlertLevel: "warning",
TenantID: 2001,
Title: "Original Title",
}
store.Create(context.Background(), alert)
// 更新告警
reqBody := UpdateAlertRequest{
Title: "Updated Title",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/v1/audit/alerts/test-alert-123", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
h.UpdateAlert(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var result AlertResponse
json.Unmarshal(w.Body.Bytes(), &result)
assert.Equal(t, "Updated Title", result.Alert.Title)
}
// TestAlertHandler_DeleteAlert_Success 测试删除告警成功
func TestAlertHandler_DeleteAlert_Success(t *testing.T) {
store := newMockAlertStore()
svc := service.NewAlertService(store)
h := NewAlertHandler(svc)
// 先创建一个告警
alert := &model.Alert{
AlertID: "test-alert-123",
AlertType: "security",
AlertLevel: "warning",
TenantID: 2001,
}
store.Create(context.Background(), alert)
// 删除告警
req := httptest.NewRequest("DELETE", "/api/v1/audit/alerts/test-alert-123", nil)
w := httptest.NewRecorder()
h.DeleteAlert(w, req)
assert.Equal(t, http.StatusNoContent, w.Code)
}
// TestAlertHandler_DeleteAlert_NotFound 测试删除不存在的告警
func TestAlertHandler_DeleteAlert_NotFound(t *testing.T) {
store := newMockAlertStore()
svc := service.NewAlertService(store)
h := NewAlertHandler(svc)
req := httptest.NewRequest("DELETE", "/api/v1/audit/alerts/nonexistent", nil)
w := httptest.NewRecorder()
h.DeleteAlert(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
// TestAlertHandler_ResolveAlert_Success 测试解决告警成功
func TestAlertHandler_ResolveAlert_Success(t *testing.T) {
store := newMockAlertStore()
svc := service.NewAlertService(store)
h := NewAlertHandler(svc)
// 先创建一个告警
alert := &model.Alert{
AlertID: "test-alert-123",
AlertType: "security",
AlertLevel: "warning",
TenantID: 2001,
Status: model.AlertStatusActive,
}
store.Create(context.Background(), alert)
// 解决告警
reqBody := ResolveAlertRequest{
ResolvedBy: "admin",
Note: "Fixed the issue",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/v1/audit/alerts/test-alert-123/resolve", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
h.ResolveAlert(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var result AlertResponse
json.Unmarshal(w.Body.Bytes(), &result)
assert.Equal(t, model.AlertStatusResolved, result.Alert.Status)
assert.Equal(t, "admin", result.Alert.ResolvedBy)
}

View File

@@ -0,0 +1,195 @@
package model
import (
"time"
"github.com/google/uuid"
)
// 告警级别常量
const (
AlertLevelInfo = "info"
AlertLevelWarning = "warning"
AlertLevelError = "error"
AlertLevelCritical = "critical"
)
// 告警状态常量
const (
AlertStatusActive = "active"
AlertStatusResolved = "resolved"
AlertStatusAcknowledged = "acknowledged"
AlertStatusSuppressed = "suppressed"
)
// 告警类型常量
const (
AlertTypeSecurity = "security"
AlertTypeInvariant = "invariant"
AlertTypeCredential = "credential"
AlertTypeAuthentication = "authentication"
AlertTypeAuthorization = "authorization"
AlertTypeQuota = "quota"
)
// Alert 告警
type Alert struct {
// 基础标识
AlertID string `json:"alert_id"` // 告警唯一ID
AlertName string `json:"alert_name"` // 告警名称
AlertType string `json:"alert_type"` // 告警类型 (security/invariant/credential/etc.)
AlertLevel string `json:"alert_level"` // 告警级别 (info/warning/error/critical)
TenantID int64 `json:"tenant_id"` // 租户ID
SupplierID int64 `json:"supplier_id,omitempty"` // 供应商ID可选
// 告警内容
Title string `json:"title"` // 告警标题
Message string `json:"message"` // 告警消息
Description string `json:"description,omitempty"` // 详细描述
// 关联事件
EventID string `json:"event_id,omitempty"` // 关联的事件ID
EventIDs []string `json:"event_ids,omitempty"` // 关联的事件ID列表多个
// 触发条件
TriggerCondition string `json:"trigger_condition,omitempty"` // 触发条件
Threshold float64 `json:"threshold,omitempty"` // 阈值
CurrentValue float64 `json:"current_value,omitempty"` // 当前值
// 状态
Status string `json:"status"` // 状态 (active/resolved/acknowledged/suppressed)
ResolvedAt *time.Time `json:"resolved_at,omitempty"` // 解决时间
ResolvedBy string `json:"resolved_by,omitempty"` // 解决人
ResolveNote string `json:"resolve_note,omitempty"` // 解决备注
// 通知
NotifyEnabled bool `json:"notify_enabled"` // 是否启用通知
NotifyChannels []string `json:"notify_channels,omitempty"` // 通知渠道 (email/sms/webhook/etc.)
// 时间戳
CreatedAt time.Time `json:"created_at"` // 创建时间
UpdatedAt time.Time `json:"updated_at"` // 更新时间
FirstSeenAt time.Time `json:"first_seen_at"` // 首次出现时间
LastSeenAt time.Time `json:"last_seen_at"` // 最后出现时间
// 元数据
Metadata map[string]any `json:"metadata,omitempty"` // 扩展元数据
Tags []string `json:"tags,omitempty"` // 标签
}
// NewAlert 创建新告警
func NewAlert(alertName, alertType, alertLevel, tenantID string, title, message string) *Alert {
now := time.Now()
return &Alert{
AlertID: generateAlertID(),
AlertName: alertName,
AlertType: alertType,
AlertLevel: alertLevel,
TenantID: parseTenantID(tenantID),
Title: title,
Message: message,
Status: AlertStatusActive,
NotifyEnabled: true,
CreatedAt: now,
UpdatedAt: now,
FirstSeenAt: now,
LastSeenAt: now,
Metadata: make(map[string]any),
Tags: []string{},
}
}
// generateAlertID 生成告警ID
func generateAlertID() string {
return "ALT-" + uuid.New().String()[:8]
}
// parseTenantID 解析租户ID
func parseTenantID(tenantID string) int64 {
var id int64
for _, c := range tenantID {
if c >= '0' && c <= '9' {
id = id*10 + int64(c-'0')
}
}
return id
}
// IsActive 检查告警是否处于活跃状态
func (a *Alert) IsActive() bool {
return a.Status == AlertStatusActive
}
// IsResolved 检查告警是否已解决
func (a *Alert) IsResolved() bool {
return a.Status == AlertStatusResolved
}
// Resolve 解决告警
func (a *Alert) Resolve(resolvedBy, note string) {
now := time.Now()
a.Status = AlertStatusResolved
a.ResolvedAt = &now
a.ResolvedBy = resolvedBy
a.ResolveNote = note
a.UpdatedAt = now
}
// Acknowledge 确认告警
func (a *Alert) Acknowledge() {
a.Status = AlertStatusAcknowledged
a.UpdatedAt = time.Now()
}
// Suppress 抑制告警
func (a *Alert) Suppress() {
a.Status = AlertStatusSuppressed
a.UpdatedAt = time.Now()
}
// UpdateLastSeen 更新最后出现时间
func (a *Alert) UpdateLastSeen() {
a.LastSeenAt = time.Now()
a.UpdatedAt = time.Now()
}
// AddEventID 添加关联事件ID
func (a *Alert) AddEventID(eventID string) {
a.EventIDs = append(a.EventIDs, eventID)
if a.EventID == "" {
a.EventID = eventID
}
a.UpdateLastSeen()
}
// SetMetadata 设置元数据
func (a *Alert) SetMetadata(key string, value any) {
if a.Metadata == nil {
a.Metadata = make(map[string]any)
}
a.Metadata[key] = value
}
// AddTag 添加标签
func (a *Alert) AddTag(tag string) {
for _, t := range a.Tags {
if t == tag {
return
}
}
a.Tags = append(a.Tags, tag)
}
// AlertFilter 告警查询过滤器
type AlertFilter struct {
TenantID int64
SupplierID int64
AlertType string
AlertLevel string
Status string
StartTime time.Time
EndTime time.Time
Keywords string // 关键字搜索(标题/消息)
Limit int
Offset int
}

View File

@@ -0,0 +1,274 @@
package service
import (
"context"
"errors"
"strings"
"sync"
"time"
"github.com/google/uuid"
"lijiaoqiao/supply-api/internal/audit/model"
)
// 错误定义
var (
ErrAlertNotFound = errors.New("alert not found")
ErrInvalidAlertInput = errors.New("invalid alert input")
ErrAlertConflict = errors.New("alert conflict")
)
// AlertStoreInterface 告警存储接口
type AlertStoreInterface interface {
Create(ctx context.Context, alert *model.Alert) error
GetByID(ctx context.Context, alertID string) (*model.Alert, error)
Update(ctx context.Context, alert *model.Alert) error
Delete(ctx context.Context, alertID string) error
List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error)
}
// InMemoryAlertStore 内存告警存储
type InMemoryAlertStore struct {
mu sync.RWMutex
alerts map[string]*model.Alert
}
// NewInMemoryAlertStore 创建内存告警存储
func NewInMemoryAlertStore() *InMemoryAlertStore {
return &InMemoryAlertStore{
alerts: make(map[string]*model.Alert),
}
}
// Create 创建告警
func (s *InMemoryAlertStore) Create(ctx context.Context, alert *model.Alert) error {
s.mu.Lock()
defer s.mu.Unlock()
if alert.AlertID == "" {
alert.AlertID = "ALT-" + uuid.New().String()[:8]
}
alert.CreatedAt = time.Now()
alert.UpdatedAt = time.Now()
s.alerts[alert.AlertID] = alert
return nil
}
// GetByID 根据ID获取告警
func (s *InMemoryAlertStore) GetByID(ctx context.Context, alertID string) (*model.Alert, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if alert, ok := s.alerts[alertID]; ok {
return alert, nil
}
return nil, ErrAlertNotFound
}
// Update 更新告警
func (s *InMemoryAlertStore) Update(ctx context.Context, alert *model.Alert) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.alerts[alert.AlertID]; !ok {
return ErrAlertNotFound
}
alert.UpdatedAt = time.Now()
s.alerts[alert.AlertID] = alert
return nil
}
// Delete 删除告警
func (s *InMemoryAlertStore) Delete(ctx context.Context, alertID string) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.alerts[alertID]; !ok {
return ErrAlertNotFound
}
delete(s.alerts, alertID)
return nil
}
// List 查询告警列表
func (s *InMemoryAlertStore) List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []*model.Alert
for _, alert := range s.alerts {
// 按租户过滤
if filter.TenantID > 0 && alert.TenantID != filter.TenantID {
continue
}
// 按供应商过滤
if filter.SupplierID > 0 && alert.SupplierID != filter.SupplierID {
continue
}
// 按类型过滤
if filter.AlertType != "" && alert.AlertType != filter.AlertType {
continue
}
// 按级别过滤
if filter.AlertLevel != "" && alert.AlertLevel != filter.AlertLevel {
continue
}
// 按状态过滤
if filter.Status != "" && alert.Status != filter.Status {
continue
}
// 按时间范围过滤
if !filter.StartTime.IsZero() && alert.CreatedAt.Before(filter.StartTime) {
continue
}
if !filter.EndTime.IsZero() && alert.CreatedAt.After(filter.EndTime) {
continue
}
// 关键字搜索
if filter.Keywords != "" {
kw := filter.Keywords
if !strings.Contains(alert.Title, kw) && !strings.Contains(alert.Message, kw) {
continue
}
}
result = append(result, alert)
}
total := int64(len(result))
// 分页
if filter.Offset > 0 {
if filter.Offset >= len(result) {
return []*model.Alert{}, total, nil
}
result = result[filter.Offset:]
}
if filter.Limit > 0 && filter.Limit < len(result) {
result = result[:filter.Limit]
}
return result, total, nil
}
// AlertService 告警服务
type AlertService struct {
store AlertStoreInterface
}
// NewAlertService 创建告警服务
func NewAlertService(store AlertStoreInterface) *AlertService {
return &AlertService{store: store}
}
// CreateAlert 创建告警
func (s *AlertService) CreateAlert(ctx context.Context, alert *model.Alert) (*model.Alert, error) {
if alert == nil {
return nil, ErrInvalidAlertInput
}
if alert.Title == "" {
return nil, errors.New("alert title is required")
}
// 设置默认值
if alert.AlertID == "" {
alert.AlertID = model.NewAlert("", "", "", "", "", "").AlertID
}
if alert.Status == "" {
alert.Status = model.AlertStatusActive
}
now := time.Now()
if alert.CreatedAt.IsZero() {
alert.CreatedAt = now
}
if alert.UpdatedAt.IsZero() {
alert.UpdatedAt = now
}
if alert.FirstSeenAt.IsZero() {
alert.FirstSeenAt = now
}
if alert.LastSeenAt.IsZero() {
alert.LastSeenAt = now
}
err := s.store.Create(ctx, alert)
if err != nil {
return nil, err
}
return alert, nil
}
// GetAlert 获取告警
func (s *AlertService) GetAlert(ctx context.Context, alertID string) (*model.Alert, error) {
if alertID == "" {
return nil, ErrInvalidAlertInput
}
return s.store.GetByID(ctx, alertID)
}
// UpdateAlert 更新告警
func (s *AlertService) UpdateAlert(ctx context.Context, alert *model.Alert) (*model.Alert, error) {
if alert == nil || alert.AlertID == "" {
return nil, ErrInvalidAlertInput
}
alert.UpdatedAt = time.Now()
err := s.store.Update(ctx, alert)
if err != nil {
return nil, err
}
return alert, nil
}
// DeleteAlert 删除告警
func (s *AlertService) DeleteAlert(ctx context.Context, alertID string) error {
if alertID == "" {
return ErrInvalidAlertInput
}
return s.store.Delete(ctx, alertID)
}
// ListAlerts 列出告警
func (s *AlertService) ListAlerts(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) {
if filter == nil {
filter = &model.AlertFilter{}
}
if filter.Limit == 0 {
filter.Limit = 100
}
return s.store.List(ctx, filter)
}
// ResolveAlert 解决告警
func (s *AlertService) ResolveAlert(ctx context.Context, alertID, resolvedBy, note string) (*model.Alert, error) {
alert, err := s.store.GetByID(ctx, alertID)
if err != nil {
return nil, err
}
alert.Resolve(resolvedBy, note)
err = s.store.Update(ctx, alert)
if err != nil {
return nil, err
}
return alert, nil
}
// AcknowledgeAlert 确认告警
func (s *AlertService) AcknowledgeAlert(ctx context.Context, alertID string) (*model.Alert, error) {
alert, err := s.store.GetByID(ctx, alertID)
if err != nil {
return nil, err
}
alert.Acknowledge()
err = s.store.Update(ctx, alert)
if err != nil {
return nil, err
}
return alert, nil
}

View File

@@ -0,0 +1,203 @@
package service
import (
"context"
"sync"
"time"
"lijiaoqiao/supply-api/internal/audit/model"
)
// BatchBufferConfig 批量缓冲区配置
type BatchBufferConfig struct {
BatchSize int // 批量大小默认50
FlushInterval time.Duration // 刷新间隔默认5ms
BufferSize int // 通道缓冲大小默认1000
}
// DefaultBatchBufferConfig 默认配置
var DefaultBatchBufferConfig = BatchBufferConfig{
BatchSize: 50,
FlushInterval: 5 * time.Millisecond,
BufferSize: 1000,
}
// BatchBuffer 批量写入缓冲区
// 设计目标50条/批或5ms刷新间隔支持5K-8K TPS
type BatchBuffer struct {
config BatchBufferConfig
eventCh chan *model.AuditEvent
buffer []*model.AuditEvent
mu sync.Mutex
closed bool
flushTick *time.Ticker
stopCh chan struct{}
doneCh chan struct{}
// FlushHandler 处理批量刷新回调
FlushHandler func(events []*model.AuditEvent) error
}
// NewBatchBuffer 创建批量缓冲区
func NewBatchBuffer(batchSize int, flushInterval time.Duration) *BatchBuffer {
config := DefaultBatchBufferConfig
if batchSize > 0 {
config.BatchSize = batchSize
}
if flushInterval > 0 {
config.FlushInterval = flushInterval
}
return &BatchBuffer{
config: config,
eventCh: make(chan *model.AuditEvent, config.BufferSize),
buffer: make([]*model.AuditEvent, 0, batchSize),
flushTick: time.NewTicker(config.FlushInterval),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
}
// Start 启动批量缓冲处理
func (b *BatchBuffer) Start(ctx context.Context) error {
go b.run()
return nil
}
// run 后台处理循环
func (b *BatchBuffer) run() {
defer close(b.doneCh)
for {
select {
case <-b.stopCh:
// 停止信号:处理剩余缓冲
b.flush()
return
case event := <-b.eventCh:
b.addEvent(event)
case <-b.flushTick.C:
b.flush()
}
}
}
// addEvent 添加事件到缓冲
func (b *BatchBuffer) addEvent(event *model.AuditEvent) {
b.mu.Lock()
defer b.mu.Unlock()
b.buffer = append(b.buffer, event)
// 达到批量大小立即刷新
if len(b.buffer) >= b.config.BatchSize {
b.doFlushLocked()
}
}
// flush 刷新缓冲(带锁)- 也会处理eventCh中的待处理事件
func (b *BatchBuffer) flush() {
b.mu.Lock()
defer b.mu.Unlock()
// 处理eventCh中已有的事件
for {
select {
case event := <-b.eventCh:
b.buffer = append(b.buffer, event)
default:
goto done
}
}
done:
b.doFlushLocked()
}
// doFlushLocked 执行刷新( caller 必须持锁)
func (b *BatchBuffer) doFlushLocked() {
if len(b.buffer) == 0 {
return
}
// 复制缓冲数据
events := make([]*model.AuditEvent, len(b.buffer))
copy(events, b.buffer)
// 清空缓冲
b.buffer = b.buffer[:0]
// 调用处理函数(如果已设置)
if b.FlushHandler != nil {
if err := b.FlushHandler(events); err != nil {
// TODO: 错误处理 - 记录日志、重试等
// 当前简化处理:仅记录
}
}
}
// Add 添加审计事件
func (b *BatchBuffer) Add(event *model.AuditEvent) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return ErrBufferClosed
}
select {
case b.eventCh <- event:
return nil
default:
// 通道满,添加到缓冲
b.buffer = append(b.buffer, event)
if len(b.buffer) >= b.config.BatchSize {
b.doFlushLocked()
}
return nil
}
}
// FlushNow 立即刷新
func (b *BatchBuffer) FlushNow() error {
b.flush()
return nil
}
// Close 关闭缓冲区
func (b *BatchBuffer) Close() error {
b.mu.Lock()
if b.closed {
b.mu.Unlock()
return nil
}
b.closed = true
b.mu.Unlock()
close(b.stopCh)
<-b.doneCh
b.flushTick.Stop()
close(b.eventCh)
return nil
}
// SetFlushHandler 设置刷新处理器
func (b *BatchBuffer) SetFlushHandler(handler func(events []*model.AuditEvent) error) {
b.FlushHandler = handler
}
// 错误定义
var (
ErrBufferClosed = &BatchBufferError{"buffer is closed"}
ErrMissingFlushHandler = &BatchBufferError{"flush handler not set"}
)
// BatchBufferError 批量缓冲错误
type BatchBufferError struct {
msg string
}
func (e *BatchBufferError) Error() string {
return e.msg
}

View File

@@ -0,0 +1,249 @@
package service
import (
"context"
"sync"
"testing"
"time"
"lijiaoqiao/supply-api/internal/audit/model"
)
// TestBatchBuffer_BatchSize 测试50条/批刷新
func TestBatchBuffer_BatchSize(t *testing.T) {
const batchSize = 50
buffer := NewBatchBuffer(batchSize, 100*time.Millisecond) // 100ms超时防止测试卡住
ctx := context.Background()
err := buffer.Start(ctx)
if err != nil {
t.Fatalf("Start failed: %v", err)
}
defer buffer.Close()
// 收集器:接收批量事件
var receivedBatches [][]*model.AuditEvent
var mu sync.Mutex
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
mu.Lock()
receivedBatches = append(receivedBatches, events)
mu.Unlock()
return nil
})
// 添加50条事件应该触发一次批量刷新
for i := 0; i < batchSize; i++ {
event := &model.AuditEvent{
EventID: "batch-test-001",
EventName: "TEST-EVENT",
}
if err := buffer.Add(event); err != nil {
t.Errorf("Add failed: %v", err)
}
}
// 等待刷新完成
time.Sleep(50 * time.Millisecond)
// 验证:应该收到恰好一个批次
mu.Lock()
if len(receivedBatches) != 1 {
t.Errorf("expected 1 batch, got %d", len(receivedBatches))
}
if len(receivedBatches) > 0 && len(receivedBatches[0]) != batchSize {
t.Errorf("expected batch size %d, got %d", batchSize, len(receivedBatches[0]))
}
mu.Unlock()
}
// TestBatchBuffer_TimeoutFlush 测试5ms超时刷新
func TestBatchBuffer_TimeoutFlush(t *testing.T) {
const batchSize = 100 // 大于我们添加的数量
const flushInterval = 5 * time.Millisecond
buffer := NewBatchBuffer(batchSize, flushInterval)
ctx := context.Background()
err := buffer.Start(ctx)
if err != nil {
t.Fatalf("Start failed: %v", err)
}
defer buffer.Close()
// 收集器
var receivedBatches [][]*model.AuditEvent
var mu sync.Mutex
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
mu.Lock()
receivedBatches = append(receivedBatches, events)
mu.Unlock()
return nil
})
// 只添加3条事件不满50条
for i := 0; i < 3; i++ {
event := &model.AuditEvent{
EventID: "batch-test-002",
EventName: "TEST-TIMEOUT",
}
if err := buffer.Add(event); err != nil {
t.Errorf("Add failed: %v", err)
}
}
// 等待5ms超时刷新
time.Sleep(20 * time.Millisecond)
// 验证应该收到一个批次包含3条事件
mu.Lock()
defer mu.Unlock()
if len(receivedBatches) != 1 {
t.Errorf("expected 1 batch (timeout flush), got %d", len(receivedBatches))
}
if len(receivedBatches) > 0 && len(receivedBatches[0]) != 3 {
t.Errorf("expected 3 events in batch, got %d", len(receivedBatches[0]))
}
}
// TestBatchBuffer_ConcurrentAccess 测试并发安全性
func TestBatchBuffer_ConcurrentAccess(t *testing.T) {
const batchSize = 50
const numGoroutines = 10
const eventsPerGoroutine = 100
buffer := NewBatchBuffer(batchSize, 10*time.Millisecond)
ctx := context.Background()
err := buffer.Start(ctx)
if err != nil {
t.Fatalf("Start failed: %v", err)
}
defer buffer.Close()
var totalReceived int
var mu sync.Mutex
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
mu.Lock()
totalReceived += len(events)
mu.Unlock()
return nil
})
// 并发添加事件
var wg sync.WaitGroup
for g := 0; g < numGoroutines; g++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
for i := 0; i < eventsPerGoroutine; i++ {
event := &model.AuditEvent{
EventID: "batch-test-concurrent",
EventName: "TEST-CONCURRENT",
}
if err := buffer.Add(event); err != nil {
t.Errorf("Add failed: %v", err)
}
}
}(g)
}
wg.Wait()
time.Sleep(50 * time.Millisecond) // 等待所有刷新完成
mu.Lock()
defer mu.Unlock()
expectedTotal := numGoroutines * eventsPerGoroutine
if totalReceived != expectedTotal {
t.Errorf("expected %d total events, got %d", expectedTotal, totalReceived)
}
}
// TestBatchBuffer_Close 测试关闭
func TestBatchBuffer_Close(t *testing.T) {
buffer := NewBatchBuffer(50, 10*time.Millisecond)
ctx := context.Background()
err := buffer.Start(ctx)
if err != nil {
t.Fatalf("Start failed: %v", err)
}
// 添加一些事件
for i := 0; i < 5; i++ {
event := &model.AuditEvent{
EventID: "batch-test-close",
EventName: "TEST-CLOSE",
}
if err := buffer.Add(event); err != nil {
t.Errorf("Add failed: %v", err)
}
}
// 关闭缓冲区
err = buffer.Close()
if err != nil {
t.Errorf("Close failed: %v", err)
}
// 关闭后添加应该失败
event := &model.AuditEvent{
EventID: "batch-test-after-close",
EventName: "TEST-AFTER-CLOSE",
}
if err := buffer.Add(event); err == nil {
t.Errorf("Add after Close should fail")
}
}
// TestBatchBuffer_FlushNow 测试手动刷新
func TestBatchBuffer_FlushNow(t *testing.T) {
const batchSize = 100 // 足够大,不会自动触发
buffer := NewBatchBuffer(batchSize, 100*time.Millisecond) // 100ms才自动刷新
ctx := context.Background()
err := buffer.Start(ctx)
if err != nil {
t.Fatalf("Start failed: %v", err)
}
defer buffer.Close()
var receivedBatches [][]*model.AuditEvent
var mu sync.Mutex
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
mu.Lock()
receivedBatches = append(receivedBatches, events)
mu.Unlock()
return nil
})
// 添加少量事件
for i := 0; i < 3; i++ {
event := &model.AuditEvent{
EventID: "batch-test-manual",
EventName: "TEST-MANUAL",
}
if err := buffer.Add(event); err != nil {
t.Errorf("Add failed: %v", err)
}
}
// 立即手动刷新
err = buffer.FlushNow()
if err != nil {
t.Errorf("FlushNow failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
if len(receivedBatches) != 1 {
t.Errorf("expected 1 batch after FlushNow, got %d", len(receivedBatches))
}
}