fix: 系统性修复安全问题、性能问题和错误处理
安全问题修复: - X-Forwarded-For越界检查(auth.go) - checkTokenStatus Context参数传递(auth.go) - Type Assertion安全检查(auth.go) 性能问题修复: - TokenCache过期清理机制 - BruteForceProtection过期清理 - InMemoryIdempotencyStore过期清理 错误处理修复: - AuditStore.Emit返回error - domain层emitAudit辅助方法 - List方法返回空slice而非nil - 金额/价格负数验证 架构一致性: - 统一使用model.RoleHierarchyLevels 新增功能: - Alert API完整实现(CRUD+Resolve) - pkg/error错误码集中管理
This commit is contained in:
@@ -2,6 +2,7 @@ package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -23,8 +24,10 @@ type Event struct {
|
||||
|
||||
// 审计存储接口
|
||||
type AuditStore interface {
|
||||
Emit(ctx context.Context, event Event)
|
||||
Emit(ctx context.Context, event Event) error
|
||||
Query(ctx context.Context, filter EventFilter) ([]Event, error)
|
||||
QueryWithTotal(ctx context.Context, filter EventFilter) ([]Event, int64, error)
|
||||
GetByID(ctx context.Context, eventID string) (Event, error)
|
||||
}
|
||||
|
||||
// 事件过滤器
|
||||
@@ -52,13 +55,14 @@ func NewMemoryAuditStore() *MemoryAuditStore {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MemoryAuditStore) Emit(ctx context.Context, event Event) {
|
||||
func (s *MemoryAuditStore) Emit(ctx context.Context, event Event) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
event.EventID = generateEventID()
|
||||
event.CreatedAt = time.Now()
|
||||
s.events = append(s.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryAuditStore) Query(ctx context.Context, filter EventFilter) ([]Event, error) {
|
||||
@@ -90,6 +94,52 @@ func (s *MemoryAuditStore) Query(ctx context.Context, filter EventFilter) ([]Eve
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// QueryWithTotal 查询事件并返回总数
|
||||
func (s *MemoryAuditStore) QueryWithTotal(ctx context.Context, filter EventFilter) ([]Event, int64, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []Event
|
||||
total := int64(0)
|
||||
|
||||
for _, event := range s.events {
|
||||
total++
|
||||
if filter.TenantID > 0 && event.TenantID != filter.TenantID {
|
||||
continue
|
||||
}
|
||||
if filter.ObjectType != "" && event.ObjectType != filter.ObjectType {
|
||||
continue
|
||||
}
|
||||
if filter.ObjectID > 0 && event.ObjectID != filter.ObjectID {
|
||||
continue
|
||||
}
|
||||
if filter.Action != "" && event.Action != filter.Action {
|
||||
continue
|
||||
}
|
||||
result = append(result, event)
|
||||
}
|
||||
|
||||
// 限制返回数量
|
||||
if filter.Limit > 0 && len(result) > filter.Limit {
|
||||
result = result[:filter.Limit]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// GetByID 根据事件ID获取单个事件
|
||||
func (s *MemoryAuditStore) GetByID(ctx context.Context, eventID string) (Event, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for _, event := range s.events {
|
||||
if event.EventID == eventID {
|
||||
return event, nil
|
||||
}
|
||||
}
|
||||
return Event{}, fmt.Errorf("event not found")
|
||||
}
|
||||
|
||||
func generateEventID() string {
|
||||
return time.Now().Format("20060102150405") + "-evt"
|
||||
}
|
||||
|
||||
350
supply-api/internal/audit/handler/alert_handler.go
Normal file
350
supply-api/internal/audit/handler/alert_handler.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
"lijiaoqiao/supply-api/internal/audit/service"
|
||||
)
|
||||
|
||||
// AlertHandler 告警HTTP处理器
|
||||
type AlertHandler struct {
|
||||
svc *service.AlertService
|
||||
}
|
||||
|
||||
// NewAlertHandler 创建告警处理器
|
||||
func NewAlertHandler(svc *service.AlertService) *AlertHandler {
|
||||
return &AlertHandler{svc: svc}
|
||||
}
|
||||
|
||||
// CreateAlertRequest 创建告警请求
|
||||
type CreateAlertRequest struct {
|
||||
AlertName string `json:"alert_name"`
|
||||
AlertType string `json:"alert_type"`
|
||||
AlertLevel string `json:"alert_level"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
SupplierID int64 `json:"supplier_id,omitempty"`
|
||||
Title string `json:"title"`
|
||||
Message string `json:"message"`
|
||||
Description string `json:"description,omitempty"`
|
||||
EventID string `json:"event_id,omitempty"`
|
||||
EventIDs []string `json:"event_ids,omitempty"`
|
||||
NotifyEnabled bool `json:"notify_enabled"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateAlertRequest 更新告警请求
|
||||
type UpdateAlertRequest struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
AlertLevel string `json:"alert_level,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
NotifyEnabled *bool `json:"notify_enabled,omitempty"`
|
||||
NotifyChannels []string `json:"notify_channels,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ResolveAlertRequest 解决告警请求
|
||||
type ResolveAlertRequest struct {
|
||||
ResolvedBy string `json:"resolved_by"`
|
||||
Note string `json:"note"`
|
||||
}
|
||||
|
||||
// AlertResponse 告警响应
|
||||
type AlertResponse struct {
|
||||
Alert *model.Alert `json:"alert"`
|
||||
}
|
||||
|
||||
// AlertListResponse 告警列表响应
|
||||
type AlertListResponse struct {
|
||||
Alerts []*model.Alert `json:"alerts"`
|
||||
Total int64 `json:"total"`
|
||||
Offset int `json:"offset"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
// CreateAlert 处理 POST /api/v1/audit/alerts
|
||||
func (h *AlertHandler) CreateAlert(w http.ResponseWriter, r *http.Request) {
|
||||
var req CreateAlertRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if req.Title == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "title is required")
|
||||
return
|
||||
}
|
||||
if req.AlertType == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "alert_type is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 创建告警
|
||||
alert := &model.Alert{
|
||||
AlertName: req.AlertName,
|
||||
AlertType: req.AlertType,
|
||||
AlertLevel: req.AlertLevel,
|
||||
TenantID: req.TenantID,
|
||||
SupplierID: req.SupplierID,
|
||||
Title: req.Title,
|
||||
Message: req.Message,
|
||||
Description: req.Description,
|
||||
EventID: req.EventID,
|
||||
EventIDs: req.EventIDs,
|
||||
NotifyEnabled: req.NotifyEnabled,
|
||||
Tags: req.Tags,
|
||||
}
|
||||
|
||||
result, err := h.svc.CreateAlert(r.Context(), alert)
|
||||
if err != nil {
|
||||
writeAlertError(w, http.StatusInternalServerError, "CREATE_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(AlertResponse{Alert: result})
|
||||
}
|
||||
|
||||
// GetAlert 处理 GET /api/v1/audit/alerts/{alert_id}
|
||||
func (h *AlertHandler) GetAlert(w http.ResponseWriter, r *http.Request) {
|
||||
alertID := extractAlertID(r)
|
||||
if alertID == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
alert, err := h.svc.GetAlert(r.Context(), alertID)
|
||||
if err != nil {
|
||||
if err == service.ErrAlertNotFound {
|
||||
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
|
||||
return
|
||||
}
|
||||
writeAlertError(w, http.StatusInternalServerError, "GET_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(AlertResponse{Alert: alert})
|
||||
}
|
||||
|
||||
// ListAlerts 处理 GET /api/v1/audit/alerts
|
||||
func (h *AlertHandler) ListAlerts(w http.ResponseWriter, r *http.Request) {
|
||||
filter := &model.AlertFilter{}
|
||||
|
||||
// 解析查询参数
|
||||
if tenantIDStr := r.URL.Query().Get("tenant_id"); tenantIDStr != "" {
|
||||
tenantID, err := strconv.ParseInt(tenantIDStr, 10, 64)
|
||||
if err == nil {
|
||||
filter.TenantID = tenantID
|
||||
}
|
||||
}
|
||||
|
||||
if supplierIDStr := r.URL.Query().Get("supplier_id"); supplierIDStr != "" {
|
||||
supplierID, err := strconv.ParseInt(supplierIDStr, 10, 64)
|
||||
if err == nil {
|
||||
filter.SupplierID = supplierID
|
||||
}
|
||||
}
|
||||
|
||||
if alertType := r.URL.Query().Get("alert_type"); alertType != "" {
|
||||
filter.AlertType = alertType
|
||||
}
|
||||
|
||||
if alertLevel := r.URL.Query().Get("alert_level"); alertLevel != "" {
|
||||
filter.AlertLevel = alertLevel
|
||||
}
|
||||
|
||||
if status := r.URL.Query().Get("status"); status != "" {
|
||||
filter.Status = status
|
||||
}
|
||||
|
||||
if keywords := r.URL.Query().Get("keywords"); keywords != "" {
|
||||
filter.Keywords = keywords
|
||||
}
|
||||
|
||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||
offset, err := strconv.Atoi(offsetStr)
|
||||
if err == nil && offset >= 0 {
|
||||
filter.Offset = offset
|
||||
}
|
||||
}
|
||||
|
||||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err == nil && limit > 0 && limit <= 1000 {
|
||||
filter.Limit = limit
|
||||
}
|
||||
}
|
||||
|
||||
if filter.Limit == 0 {
|
||||
filter.Limit = 100
|
||||
}
|
||||
|
||||
alerts, total, err := h.svc.ListAlerts(r.Context(), filter)
|
||||
if err != nil {
|
||||
writeAlertError(w, http.StatusInternalServerError, "LIST_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(AlertListResponse{
|
||||
Alerts: alerts,
|
||||
Total: total,
|
||||
Offset: filter.Offset,
|
||||
Limit: filter.Limit,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAlert 处理 PUT /api/v1/audit/alerts/{alert_id}
|
||||
func (h *AlertHandler) UpdateAlert(w http.ResponseWriter, r *http.Request) {
|
||||
alertID := extractAlertID(r)
|
||||
if alertID == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取现有告警
|
||||
alert, err := h.svc.GetAlert(r.Context(), alertID)
|
||||
if err != nil {
|
||||
if err == service.ErrAlertNotFound {
|
||||
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
|
||||
return
|
||||
}
|
||||
writeAlertError(w, http.StatusInternalServerError, "GET_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAlertRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Title != "" {
|
||||
alert.Title = req.Title
|
||||
}
|
||||
if req.Message != "" {
|
||||
alert.Message = req.Message
|
||||
}
|
||||
if req.Description != "" {
|
||||
alert.Description = req.Description
|
||||
}
|
||||
if req.AlertLevel != "" {
|
||||
alert.AlertLevel = req.AlertLevel
|
||||
}
|
||||
if req.Status != "" {
|
||||
alert.Status = req.Status
|
||||
}
|
||||
if req.NotifyEnabled != nil {
|
||||
alert.NotifyEnabled = *req.NotifyEnabled
|
||||
}
|
||||
if len(req.NotifyChannels) > 0 {
|
||||
alert.NotifyChannels = req.NotifyChannels
|
||||
}
|
||||
if len(req.Tags) > 0 {
|
||||
alert.Tags = req.Tags
|
||||
}
|
||||
if req.Metadata != nil {
|
||||
alert.Metadata = req.Metadata
|
||||
}
|
||||
|
||||
result, err := h.svc.UpdateAlert(r.Context(), alert)
|
||||
if err != nil {
|
||||
writeAlertError(w, http.StatusInternalServerError, "UPDATE_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(AlertResponse{Alert: result})
|
||||
}
|
||||
|
||||
// DeleteAlert 处理 DELETE /api/v1/audit/alerts/{alert_id}
|
||||
func (h *AlertHandler) DeleteAlert(w http.ResponseWriter, r *http.Request) {
|
||||
alertID := extractAlertID(r)
|
||||
if alertID == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.svc.DeleteAlert(r.Context(), alertID)
|
||||
if err != nil {
|
||||
if err == service.ErrAlertNotFound {
|
||||
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
|
||||
return
|
||||
}
|
||||
writeAlertError(w, http.StatusInternalServerError, "DELETE_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ResolveAlert 处理 POST /api/v1/audit/alerts/{alert_id}/resolve
|
||||
func (h *AlertHandler) ResolveAlert(w http.ResponseWriter, r *http.Request) {
|
||||
alertID := extractAlertID(r)
|
||||
if alertID == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req ResolveAlertRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.ResolvedBy == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "resolved_by is required")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.ResolveAlert(r.Context(), alertID, req.ResolvedBy, req.Note)
|
||||
if err != nil {
|
||||
if err == service.ErrAlertNotFound {
|
||||
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
|
||||
return
|
||||
}
|
||||
writeAlertError(w, http.StatusInternalServerError, "RESOLVE_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(AlertResponse{Alert: result})
|
||||
}
|
||||
|
||||
// extractAlertID 从请求中提取alert_id(优先从路径,其次从查询参数)
|
||||
func extractAlertID(r *http.Request) string {
|
||||
// 先尝试从路径提取
|
||||
path := r.URL.Path
|
||||
parts := strings.Split(strings.TrimPrefix(path, "/"), "/")
|
||||
if len(parts) >= 5 && parts[0] == "api" && parts[1] == "v1" && parts[2] == "audit" && parts[3] == "alerts" {
|
||||
if parts[4] != "" && parts[4] != "resolve" {
|
||||
return parts[4]
|
||||
}
|
||||
}
|
||||
// 再尝试从查询参数提取
|
||||
if alertID := r.URL.Query().Get("alert_id"); alertID != "" {
|
||||
return alertID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// writeAlertError 写入错误响应
|
||||
func writeAlertError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(ErrorResponse{
|
||||
Error: message,
|
||||
Code: code,
|
||||
Details: "",
|
||||
})
|
||||
}
|
||||
315
supply-api/internal/audit/handler/alert_handler_test.go
Normal file
315
supply-api/internal/audit/handler/alert_handler_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
"lijiaoqiao/supply-api/internal/audit/service"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockAlertStore 模拟告警存储
|
||||
type mockAlertStore struct {
|
||||
alerts map[string]*model.Alert
|
||||
}
|
||||
|
||||
func newMockAlertStore() *mockAlertStore {
|
||||
return &mockAlertStore{
|
||||
alerts: make(map[string]*model.Alert),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) Create(ctx context.Context, alert *model.Alert) error {
|
||||
if alert.AlertID == "" {
|
||||
alert.AlertID = "test-alert-id"
|
||||
}
|
||||
alert.CreatedAt = testTime
|
||||
alert.UpdatedAt = testTime
|
||||
m.alerts[alert.AlertID] = alert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) GetByID(ctx context.Context, alertID string) (*model.Alert, error) {
|
||||
if alert, ok := m.alerts[alertID]; ok {
|
||||
return alert, nil
|
||||
}
|
||||
return nil, service.ErrAlertNotFound
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) Update(ctx context.Context, alert *model.Alert) error {
|
||||
if _, ok := m.alerts[alert.AlertID]; !ok {
|
||||
return service.ErrAlertNotFound
|
||||
}
|
||||
alert.UpdatedAt = testTime
|
||||
m.alerts[alert.AlertID] = alert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) Delete(ctx context.Context, alertID string) error {
|
||||
if _, ok := m.alerts[alertID]; !ok {
|
||||
return service.ErrAlertNotFound
|
||||
}
|
||||
delete(m.alerts, alertID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) {
|
||||
var result []*model.Alert
|
||||
for _, alert := range m.alerts {
|
||||
if filter.TenantID > 0 && alert.TenantID != filter.TenantID {
|
||||
continue
|
||||
}
|
||||
if filter.Status != "" && alert.Status != filter.Status {
|
||||
continue
|
||||
}
|
||||
result = append(result, alert)
|
||||
}
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
var testTime = time.Now()
|
||||
|
||||
// TestAlertHandler_CreateAlert_Success 测试创建告警成功
|
||||
func TestAlertHandler_CreateAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
reqBody := CreateAlertRequest{
|
||||
AlertName: "TEST_ALERT",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Title: "Test Alert Title",
|
||||
Message: "Test alert message",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("POST", "/api/v1/audit/alerts", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.CreateAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
var result AlertResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Test Alert Title", result.Alert.Title)
|
||||
assert.Equal(t, "security", result.Alert.AlertType)
|
||||
}
|
||||
|
||||
// TestAlertHandler_CreateAlert_MissingTitle 测试缺少标题
|
||||
func TestAlertHandler_CreateAlert_MissingTitle(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
reqBody := CreateAlertRequest{
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("POST", "/api/v1/audit/alerts", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.CreateAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
// TestAlertHandler_GetAlert_Success 测试获取告警成功
|
||||
func TestAlertHandler_GetAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 先创建一个告警
|
||||
alert := &model.Alert{
|
||||
AlertID: "test-alert-123",
|
||||
AlertName: "TEST_ALERT",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Title: "Test Alert",
|
||||
Message: "Test message",
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
|
||||
// 获取告警
|
||||
req := httptest.NewRequest("GET", "/api/v1/audit/alerts/test-alert-123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.GetAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var result AlertResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test-alert-123", result.Alert.AlertID)
|
||||
}
|
||||
|
||||
// TestAlertHandler_GetAlert_NotFound 测试告警不存在
|
||||
func TestAlertHandler_GetAlert_NotFound(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/audit/alerts/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.GetAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
// TestAlertHandler_ListAlerts_Success 测试列出告警成功
|
||||
func TestAlertHandler_ListAlerts_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 创建多个告警
|
||||
for i := 0; i < 3; i++ {
|
||||
alert := &model.Alert{
|
||||
AlertID: "alert-" + string(rune('a'+i)),
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Title: "Test Alert",
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/audit/alerts?tenant_id=2001", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ListAlerts(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var result AlertListResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3), result.Total)
|
||||
}
|
||||
|
||||
// TestAlertHandler_UpdateAlert_Success 测试更新告警成功
|
||||
func TestAlertHandler_UpdateAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 先创建一个告警
|
||||
alert := &model.Alert{
|
||||
AlertID: "test-alert-123",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Title: "Original Title",
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
|
||||
// 更新告警
|
||||
reqBody := UpdateAlertRequest{
|
||||
Title: "Updated Title",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/v1/audit/alerts/test-alert-123", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.UpdateAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var result AlertResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.Equal(t, "Updated Title", result.Alert.Title)
|
||||
}
|
||||
|
||||
// TestAlertHandler_DeleteAlert_Success 测试删除告警成功
|
||||
func TestAlertHandler_DeleteAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 先创建一个告警
|
||||
alert := &model.Alert{
|
||||
AlertID: "test-alert-123",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
|
||||
// 删除告警
|
||||
req := httptest.NewRequest("DELETE", "/api/v1/audit/alerts/test-alert-123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.DeleteAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
// TestAlertHandler_DeleteAlert_NotFound 测试删除不存在的告警
|
||||
func TestAlertHandler_DeleteAlert_NotFound(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/v1/audit/alerts/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.DeleteAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
// TestAlertHandler_ResolveAlert_Success 测试解决告警成功
|
||||
func TestAlertHandler_ResolveAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 先创建一个告警
|
||||
alert := &model.Alert{
|
||||
AlertID: "test-alert-123",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Status: model.AlertStatusActive,
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
|
||||
// 解决告警
|
||||
reqBody := ResolveAlertRequest{
|
||||
ResolvedBy: "admin",
|
||||
Note: "Fixed the issue",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("POST", "/api/v1/audit/alerts/test-alert-123/resolve", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ResolveAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var result AlertResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.Equal(t, model.AlertStatusResolved, result.Alert.Status)
|
||||
assert.Equal(t, "admin", result.Alert.ResolvedBy)
|
||||
}
|
||||
195
supply-api/internal/audit/model/alert.go
Normal file
195
supply-api/internal/audit/model/alert.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// 告警级别常量
|
||||
const (
|
||||
AlertLevelInfo = "info"
|
||||
AlertLevelWarning = "warning"
|
||||
AlertLevelError = "error"
|
||||
AlertLevelCritical = "critical"
|
||||
)
|
||||
|
||||
// 告警状态常量
|
||||
const (
|
||||
AlertStatusActive = "active"
|
||||
AlertStatusResolved = "resolved"
|
||||
AlertStatusAcknowledged = "acknowledged"
|
||||
AlertStatusSuppressed = "suppressed"
|
||||
)
|
||||
|
||||
// 告警类型常量
|
||||
const (
|
||||
AlertTypeSecurity = "security"
|
||||
AlertTypeInvariant = "invariant"
|
||||
AlertTypeCredential = "credential"
|
||||
AlertTypeAuthentication = "authentication"
|
||||
AlertTypeAuthorization = "authorization"
|
||||
AlertTypeQuota = "quota"
|
||||
)
|
||||
|
||||
// Alert 告警
|
||||
type Alert struct {
|
||||
// 基础标识
|
||||
AlertID string `json:"alert_id"` // 告警唯一ID
|
||||
AlertName string `json:"alert_name"` // 告警名称
|
||||
AlertType string `json:"alert_type"` // 告警类型 (security/invariant/credential/etc.)
|
||||
AlertLevel string `json:"alert_level"` // 告警级别 (info/warning/error/critical)
|
||||
TenantID int64 `json:"tenant_id"` // 租户ID
|
||||
SupplierID int64 `json:"supplier_id,omitempty"` // 供应商ID(可选)
|
||||
|
||||
// 告警内容
|
||||
Title string `json:"title"` // 告警标题
|
||||
Message string `json:"message"` // 告警消息
|
||||
Description string `json:"description,omitempty"` // 详细描述
|
||||
|
||||
// 关联事件
|
||||
EventID string `json:"event_id,omitempty"` // 关联的事件ID
|
||||
EventIDs []string `json:"event_ids,omitempty"` // 关联的事件ID列表(多个)
|
||||
|
||||
// 触发条件
|
||||
TriggerCondition string `json:"trigger_condition,omitempty"` // 触发条件
|
||||
Threshold float64 `json:"threshold,omitempty"` // 阈值
|
||||
CurrentValue float64 `json:"current_value,omitempty"` // 当前值
|
||||
|
||||
// 状态
|
||||
Status string `json:"status"` // 状态 (active/resolved/acknowledged/suppressed)
|
||||
ResolvedAt *time.Time `json:"resolved_at,omitempty"` // 解决时间
|
||||
ResolvedBy string `json:"resolved_by,omitempty"` // 解决人
|
||||
ResolveNote string `json:"resolve_note,omitempty"` // 解决备注
|
||||
|
||||
// 通知
|
||||
NotifyEnabled bool `json:"notify_enabled"` // 是否启用通知
|
||||
NotifyChannels []string `json:"notify_channels,omitempty"` // 通知渠道 (email/sms/webhook/etc.)
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at"` // 创建时间
|
||||
UpdatedAt time.Time `json:"updated_at"` // 更新时间
|
||||
FirstSeenAt time.Time `json:"first_seen_at"` // 首次出现时间
|
||||
LastSeenAt time.Time `json:"last_seen_at"` // 最后出现时间
|
||||
|
||||
// 元数据
|
||||
Metadata map[string]any `json:"metadata,omitempty"` // 扩展元数据
|
||||
Tags []string `json:"tags,omitempty"` // 标签
|
||||
}
|
||||
|
||||
// NewAlert 创建新告警
|
||||
func NewAlert(alertName, alertType, alertLevel, tenantID string, title, message string) *Alert {
|
||||
now := time.Now()
|
||||
return &Alert{
|
||||
AlertID: generateAlertID(),
|
||||
AlertName: alertName,
|
||||
AlertType: alertType,
|
||||
AlertLevel: alertLevel,
|
||||
TenantID: parseTenantID(tenantID),
|
||||
Title: title,
|
||||
Message: message,
|
||||
Status: AlertStatusActive,
|
||||
NotifyEnabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
FirstSeenAt: now,
|
||||
LastSeenAt: now,
|
||||
Metadata: make(map[string]any),
|
||||
Tags: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// generateAlertID 生成告警ID
|
||||
func generateAlertID() string {
|
||||
return "ALT-" + uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
// parseTenantID 解析租户ID
|
||||
func parseTenantID(tenantID string) int64 {
|
||||
var id int64
|
||||
for _, c := range tenantID {
|
||||
if c >= '0' && c <= '9' {
|
||||
id = id*10 + int64(c-'0')
|
||||
}
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// IsActive 检查告警是否处于活跃状态
|
||||
func (a *Alert) IsActive() bool {
|
||||
return a.Status == AlertStatusActive
|
||||
}
|
||||
|
||||
// IsResolved 检查告警是否已解决
|
||||
func (a *Alert) IsResolved() bool {
|
||||
return a.Status == AlertStatusResolved
|
||||
}
|
||||
|
||||
// Resolve 解决告警
|
||||
func (a *Alert) Resolve(resolvedBy, note string) {
|
||||
now := time.Now()
|
||||
a.Status = AlertStatusResolved
|
||||
a.ResolvedAt = &now
|
||||
a.ResolvedBy = resolvedBy
|
||||
a.ResolveNote = note
|
||||
a.UpdatedAt = now
|
||||
}
|
||||
|
||||
// Acknowledge 确认告警
|
||||
func (a *Alert) Acknowledge() {
|
||||
a.Status = AlertStatusAcknowledged
|
||||
a.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// Suppress 抑制告警
|
||||
func (a *Alert) Suppress() {
|
||||
a.Status = AlertStatusSuppressed
|
||||
a.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// UpdateLastSeen 更新最后出现时间
|
||||
func (a *Alert) UpdateLastSeen() {
|
||||
a.LastSeenAt = time.Now()
|
||||
a.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// AddEventID 添加关联事件ID
|
||||
func (a *Alert) AddEventID(eventID string) {
|
||||
a.EventIDs = append(a.EventIDs, eventID)
|
||||
if a.EventID == "" {
|
||||
a.EventID = eventID
|
||||
}
|
||||
a.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// SetMetadata 设置元数据
|
||||
func (a *Alert) SetMetadata(key string, value any) {
|
||||
if a.Metadata == nil {
|
||||
a.Metadata = make(map[string]any)
|
||||
}
|
||||
a.Metadata[key] = value
|
||||
}
|
||||
|
||||
// AddTag 添加标签
|
||||
func (a *Alert) AddTag(tag string) {
|
||||
for _, t := range a.Tags {
|
||||
if t == tag {
|
||||
return
|
||||
}
|
||||
}
|
||||
a.Tags = append(a.Tags, tag)
|
||||
}
|
||||
|
||||
// AlertFilter 告警查询过滤器
|
||||
type AlertFilter struct {
|
||||
TenantID int64
|
||||
SupplierID int64
|
||||
AlertType string
|
||||
AlertLevel string
|
||||
Status string
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
Keywords string // 关键字搜索(标题/消息)
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
274
supply-api/internal/audit/service/alert_service.go
Normal file
274
supply-api/internal/audit/service/alert_service.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrAlertNotFound = errors.New("alert not found")
|
||||
ErrInvalidAlertInput = errors.New("invalid alert input")
|
||||
ErrAlertConflict = errors.New("alert conflict")
|
||||
)
|
||||
|
||||
// AlertStoreInterface 告警存储接口
|
||||
type AlertStoreInterface interface {
|
||||
Create(ctx context.Context, alert *model.Alert) error
|
||||
GetByID(ctx context.Context, alertID string) (*model.Alert, error)
|
||||
Update(ctx context.Context, alert *model.Alert) error
|
||||
Delete(ctx context.Context, alertID string) error
|
||||
List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error)
|
||||
}
|
||||
|
||||
// InMemoryAlertStore 内存告警存储
|
||||
type InMemoryAlertStore struct {
|
||||
mu sync.RWMutex
|
||||
alerts map[string]*model.Alert
|
||||
}
|
||||
|
||||
// NewInMemoryAlertStore 创建内存告警存储
|
||||
func NewInMemoryAlertStore() *InMemoryAlertStore {
|
||||
return &InMemoryAlertStore{
|
||||
alerts: make(map[string]*model.Alert),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建告警
|
||||
func (s *InMemoryAlertStore) Create(ctx context.Context, alert *model.Alert) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if alert.AlertID == "" {
|
||||
alert.AlertID = "ALT-" + uuid.New().String()[:8]
|
||||
}
|
||||
alert.CreatedAt = time.Now()
|
||||
alert.UpdatedAt = time.Now()
|
||||
|
||||
s.alerts[alert.AlertID] = alert
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取告警
|
||||
func (s *InMemoryAlertStore) GetByID(ctx context.Context, alertID string) (*model.Alert, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if alert, ok := s.alerts[alertID]; ok {
|
||||
return alert, nil
|
||||
}
|
||||
return nil, ErrAlertNotFound
|
||||
}
|
||||
|
||||
// Update 更新告警
|
||||
func (s *InMemoryAlertStore) Update(ctx context.Context, alert *model.Alert) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, ok := s.alerts[alert.AlertID]; !ok {
|
||||
return ErrAlertNotFound
|
||||
}
|
||||
|
||||
alert.UpdatedAt = time.Now()
|
||||
s.alerts[alert.AlertID] = alert
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除告警
|
||||
func (s *InMemoryAlertStore) Delete(ctx context.Context, alertID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, ok := s.alerts[alertID]; !ok {
|
||||
return ErrAlertNotFound
|
||||
}
|
||||
|
||||
delete(s.alerts, alertID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 查询告警列表
|
||||
func (s *InMemoryAlertStore) List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []*model.Alert
|
||||
for _, alert := range s.alerts {
|
||||
// 按租户过滤
|
||||
if filter.TenantID > 0 && alert.TenantID != filter.TenantID {
|
||||
continue
|
||||
}
|
||||
// 按供应商过滤
|
||||
if filter.SupplierID > 0 && alert.SupplierID != filter.SupplierID {
|
||||
continue
|
||||
}
|
||||
// 按类型过滤
|
||||
if filter.AlertType != "" && alert.AlertType != filter.AlertType {
|
||||
continue
|
||||
}
|
||||
// 按级别过滤
|
||||
if filter.AlertLevel != "" && alert.AlertLevel != filter.AlertLevel {
|
||||
continue
|
||||
}
|
||||
// 按状态过滤
|
||||
if filter.Status != "" && alert.Status != filter.Status {
|
||||
continue
|
||||
}
|
||||
// 按时间范围过滤
|
||||
if !filter.StartTime.IsZero() && alert.CreatedAt.Before(filter.StartTime) {
|
||||
continue
|
||||
}
|
||||
if !filter.EndTime.IsZero() && alert.CreatedAt.After(filter.EndTime) {
|
||||
continue
|
||||
}
|
||||
// 关键字搜索
|
||||
if filter.Keywords != "" {
|
||||
kw := filter.Keywords
|
||||
if !strings.Contains(alert.Title, kw) && !strings.Contains(alert.Message, kw) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, alert)
|
||||
}
|
||||
|
||||
total := int64(len(result))
|
||||
|
||||
// 分页
|
||||
if filter.Offset > 0 {
|
||||
if filter.Offset >= len(result) {
|
||||
return []*model.Alert{}, total, nil
|
||||
}
|
||||
result = result[filter.Offset:]
|
||||
}
|
||||
if filter.Limit > 0 && filter.Limit < len(result) {
|
||||
result = result[:filter.Limit]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// AlertService 告警服务
|
||||
type AlertService struct {
|
||||
store AlertStoreInterface
|
||||
}
|
||||
|
||||
// NewAlertService 创建告警服务
|
||||
func NewAlertService(store AlertStoreInterface) *AlertService {
|
||||
return &AlertService{store: store}
|
||||
}
|
||||
|
||||
// CreateAlert 创建告警
|
||||
func (s *AlertService) CreateAlert(ctx context.Context, alert *model.Alert) (*model.Alert, error) {
|
||||
if alert == nil {
|
||||
return nil, ErrInvalidAlertInput
|
||||
}
|
||||
if alert.Title == "" {
|
||||
return nil, errors.New("alert title is required")
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if alert.AlertID == "" {
|
||||
alert.AlertID = model.NewAlert("", "", "", "", "", "").AlertID
|
||||
}
|
||||
if alert.Status == "" {
|
||||
alert.Status = model.AlertStatusActive
|
||||
}
|
||||
now := time.Now()
|
||||
if alert.CreatedAt.IsZero() {
|
||||
alert.CreatedAt = now
|
||||
}
|
||||
if alert.UpdatedAt.IsZero() {
|
||||
alert.UpdatedAt = now
|
||||
}
|
||||
if alert.FirstSeenAt.IsZero() {
|
||||
alert.FirstSeenAt = now
|
||||
}
|
||||
if alert.LastSeenAt.IsZero() {
|
||||
alert.LastSeenAt = now
|
||||
}
|
||||
|
||||
err := s.store.Create(ctx, alert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return alert, nil
|
||||
}
|
||||
|
||||
// GetAlert 获取告警
|
||||
func (s *AlertService) GetAlert(ctx context.Context, alertID string) (*model.Alert, error) {
|
||||
if alertID == "" {
|
||||
return nil, ErrInvalidAlertInput
|
||||
}
|
||||
return s.store.GetByID(ctx, alertID)
|
||||
}
|
||||
|
||||
// UpdateAlert 更新告警
|
||||
func (s *AlertService) UpdateAlert(ctx context.Context, alert *model.Alert) (*model.Alert, error) {
|
||||
if alert == nil || alert.AlertID == "" {
|
||||
return nil, ErrInvalidAlertInput
|
||||
}
|
||||
|
||||
alert.UpdatedAt = time.Now()
|
||||
err := s.store.Update(ctx, alert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return alert, nil
|
||||
}
|
||||
|
||||
// DeleteAlert 删除告警
|
||||
func (s *AlertService) DeleteAlert(ctx context.Context, alertID string) error {
|
||||
if alertID == "" {
|
||||
return ErrInvalidAlertInput
|
||||
}
|
||||
return s.store.Delete(ctx, alertID)
|
||||
}
|
||||
|
||||
// ListAlerts 列出告警
|
||||
func (s *AlertService) ListAlerts(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) {
|
||||
if filter == nil {
|
||||
filter = &model.AlertFilter{}
|
||||
}
|
||||
if filter.Limit == 0 {
|
||||
filter.Limit = 100
|
||||
}
|
||||
return s.store.List(ctx, filter)
|
||||
}
|
||||
|
||||
// ResolveAlert 解决告警
|
||||
func (s *AlertService) ResolveAlert(ctx context.Context, alertID, resolvedBy, note string) (*model.Alert, error) {
|
||||
alert, err := s.store.GetByID(ctx, alertID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
alert.Resolve(resolvedBy, note)
|
||||
err = s.store.Update(ctx, alert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return alert, nil
|
||||
}
|
||||
|
||||
// AcknowledgeAlert 确认告警
|
||||
func (s *AlertService) AcknowledgeAlert(ctx context.Context, alertID string) (*model.Alert, error) {
|
||||
alert, err := s.store.GetByID(ctx, alertID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
alert.Acknowledge()
|
||||
err = s.store.Update(ctx, alert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return alert, nil
|
||||
}
|
||||
203
supply-api/internal/audit/service/batch_buffer.go
Normal file
203
supply-api/internal/audit/service/batch_buffer.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
// BatchBufferConfig 批量缓冲区配置
|
||||
type BatchBufferConfig struct {
|
||||
BatchSize int // 批量大小(默认50)
|
||||
FlushInterval time.Duration // 刷新间隔(默认5ms)
|
||||
BufferSize int // 通道缓冲大小(默认1000)
|
||||
}
|
||||
|
||||
// DefaultBatchBufferConfig 默认配置
|
||||
var DefaultBatchBufferConfig = BatchBufferConfig{
|
||||
BatchSize: 50,
|
||||
FlushInterval: 5 * time.Millisecond,
|
||||
BufferSize: 1000,
|
||||
}
|
||||
|
||||
// BatchBuffer 批量写入缓冲区
|
||||
// 设计目标:50条/批或5ms刷新间隔,支持5K-8K TPS
|
||||
type BatchBuffer struct {
|
||||
config BatchBufferConfig
|
||||
eventCh chan *model.AuditEvent
|
||||
buffer []*model.AuditEvent
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
|
||||
flushTick *time.Ticker
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
|
||||
// FlushHandler 处理批量刷新回调
|
||||
FlushHandler func(events []*model.AuditEvent) error
|
||||
}
|
||||
|
||||
// NewBatchBuffer 创建批量缓冲区
|
||||
func NewBatchBuffer(batchSize int, flushInterval time.Duration) *BatchBuffer {
|
||||
config := DefaultBatchBufferConfig
|
||||
if batchSize > 0 {
|
||||
config.BatchSize = batchSize
|
||||
}
|
||||
if flushInterval > 0 {
|
||||
config.FlushInterval = flushInterval
|
||||
}
|
||||
|
||||
return &BatchBuffer{
|
||||
config: config,
|
||||
eventCh: make(chan *model.AuditEvent, config.BufferSize),
|
||||
buffer: make([]*model.AuditEvent, 0, batchSize),
|
||||
flushTick: time.NewTicker(config.FlushInterval),
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动批量缓冲处理
|
||||
func (b *BatchBuffer) Start(ctx context.Context) error {
|
||||
go b.run()
|
||||
return nil
|
||||
}
|
||||
|
||||
// run 后台处理循环
|
||||
func (b *BatchBuffer) run() {
|
||||
defer close(b.doneCh)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.stopCh:
|
||||
// 停止信号:处理剩余缓冲
|
||||
b.flush()
|
||||
return
|
||||
case event := <-b.eventCh:
|
||||
b.addEvent(event)
|
||||
case <-b.flushTick.C:
|
||||
b.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addEvent 添加事件到缓冲
|
||||
func (b *BatchBuffer) addEvent(event *model.AuditEvent) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
b.buffer = append(b.buffer, event)
|
||||
|
||||
// 达到批量大小立即刷新
|
||||
if len(b.buffer) >= b.config.BatchSize {
|
||||
b.doFlushLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// flush 刷新缓冲(带锁)- 也会处理eventCh中的待处理事件
|
||||
func (b *BatchBuffer) flush() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// 处理eventCh中已有的事件
|
||||
for {
|
||||
select {
|
||||
case event := <-b.eventCh:
|
||||
b.buffer = append(b.buffer, event)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
b.doFlushLocked()
|
||||
}
|
||||
|
||||
// doFlushLocked 执行刷新( caller 必须持锁)
|
||||
func (b *BatchBuffer) doFlushLocked() {
|
||||
if len(b.buffer) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 复制缓冲数据
|
||||
events := make([]*model.AuditEvent, len(b.buffer))
|
||||
copy(events, b.buffer)
|
||||
|
||||
// 清空缓冲
|
||||
b.buffer = b.buffer[:0]
|
||||
|
||||
// 调用处理函数(如果已设置)
|
||||
if b.FlushHandler != nil {
|
||||
if err := b.FlushHandler(events); err != nil {
|
||||
// TODO: 错误处理 - 记录日志、重试等
|
||||
// 当前简化处理:仅记录
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add 添加审计事件
|
||||
func (b *BatchBuffer) Add(event *model.AuditEvent) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
return ErrBufferClosed
|
||||
}
|
||||
|
||||
select {
|
||||
case b.eventCh <- event:
|
||||
return nil
|
||||
default:
|
||||
// 通道满,添加到缓冲
|
||||
b.buffer = append(b.buffer, event)
|
||||
if len(b.buffer) >= b.config.BatchSize {
|
||||
b.doFlushLocked()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FlushNow 立即刷新
|
||||
func (b *BatchBuffer) FlushNow() error {
|
||||
b.flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭缓冲区
|
||||
func (b *BatchBuffer) Close() error {
|
||||
b.mu.Lock()
|
||||
if b.closed {
|
||||
b.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
b.closed = true
|
||||
b.mu.Unlock()
|
||||
|
||||
close(b.stopCh)
|
||||
<-b.doneCh
|
||||
b.flushTick.Stop()
|
||||
close(b.eventCh)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetFlushHandler 设置刷新处理器
|
||||
func (b *BatchBuffer) SetFlushHandler(handler func(events []*model.AuditEvent) error) {
|
||||
b.FlushHandler = handler
|
||||
}
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrBufferClosed = &BatchBufferError{"buffer is closed"}
|
||||
ErrMissingFlushHandler = &BatchBufferError{"flush handler not set"}
|
||||
)
|
||||
|
||||
// BatchBufferError 批量缓冲错误
|
||||
type BatchBufferError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *BatchBufferError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
249
supply-api/internal/audit/service/batch_buffer_test.go
Normal file
249
supply-api/internal/audit/service/batch_buffer_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
// TestBatchBuffer_BatchSize 测试50条/批刷新
|
||||
func TestBatchBuffer_BatchSize(t *testing.T) {
|
||||
const batchSize = 50
|
||||
|
||||
buffer := NewBatchBuffer(batchSize, 100*time.Millisecond) // 100ms超时防止测试卡住
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
defer buffer.Close()
|
||||
|
||||
// 收集器:接收批量事件
|
||||
var receivedBatches [][]*model.AuditEvent
|
||||
var mu sync.Mutex
|
||||
|
||||
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
|
||||
mu.Lock()
|
||||
receivedBatches = append(receivedBatches, events)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
// 添加50条事件,应该触发一次批量刷新
|
||||
for i := 0; i < batchSize; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-001",
|
||||
EventName: "TEST-EVENT",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 等待刷新完成
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 验证:应该收到恰好一个批次
|
||||
mu.Lock()
|
||||
if len(receivedBatches) != 1 {
|
||||
t.Errorf("expected 1 batch, got %d", len(receivedBatches))
|
||||
}
|
||||
if len(receivedBatches) > 0 && len(receivedBatches[0]) != batchSize {
|
||||
t.Errorf("expected batch size %d, got %d", batchSize, len(receivedBatches[0]))
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// TestBatchBuffer_TimeoutFlush 测试5ms超时刷新
|
||||
func TestBatchBuffer_TimeoutFlush(t *testing.T) {
|
||||
const batchSize = 100 // 大于我们添加的数量
|
||||
const flushInterval = 5 * time.Millisecond
|
||||
|
||||
buffer := NewBatchBuffer(batchSize, flushInterval)
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
defer buffer.Close()
|
||||
|
||||
// 收集器
|
||||
var receivedBatches [][]*model.AuditEvent
|
||||
var mu sync.Mutex
|
||||
|
||||
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
|
||||
mu.Lock()
|
||||
receivedBatches = append(receivedBatches, events)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
// 只添加3条事件,不满50条
|
||||
for i := 0; i < 3; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-002",
|
||||
EventName: "TEST-TIMEOUT",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 等待5ms超时刷新
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// 验证:应该收到一个批次,包含3条事件
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(receivedBatches) != 1 {
|
||||
t.Errorf("expected 1 batch (timeout flush), got %d", len(receivedBatches))
|
||||
}
|
||||
if len(receivedBatches) > 0 && len(receivedBatches[0]) != 3 {
|
||||
t.Errorf("expected 3 events in batch, got %d", len(receivedBatches[0]))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchBuffer_ConcurrentAccess 测试并发安全性
|
||||
func TestBatchBuffer_ConcurrentAccess(t *testing.T) {
|
||||
const batchSize = 50
|
||||
const numGoroutines = 10
|
||||
const eventsPerGoroutine = 100
|
||||
|
||||
buffer := NewBatchBuffer(batchSize, 10*time.Millisecond)
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
defer buffer.Close()
|
||||
|
||||
var totalReceived int
|
||||
var mu sync.Mutex
|
||||
|
||||
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
|
||||
mu.Lock()
|
||||
totalReceived += len(events)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
// 并发添加事件
|
||||
var wg sync.WaitGroup
|
||||
for g := 0; g < numGoroutines; g++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < eventsPerGoroutine; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-concurrent",
|
||||
EventName: "TEST-CONCURRENT",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
time.Sleep(50 * time.Millisecond) // 等待所有刷新完成
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
expectedTotal := numGoroutines * eventsPerGoroutine
|
||||
if totalReceived != expectedTotal {
|
||||
t.Errorf("expected %d total events, got %d", expectedTotal, totalReceived)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchBuffer_Close 测试关闭
|
||||
func TestBatchBuffer_Close(t *testing.T) {
|
||||
buffer := NewBatchBuffer(50, 10*time.Millisecond)
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
|
||||
// 添加一些事件
|
||||
for i := 0; i < 5; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-close",
|
||||
EventName: "TEST-CLOSE",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭缓冲区
|
||||
err = buffer.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
// 关闭后添加应该失败
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-after-close",
|
||||
EventName: "TEST-AFTER-CLOSE",
|
||||
}
|
||||
if err := buffer.Add(event); err == nil {
|
||||
t.Errorf("Add after Close should fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchBuffer_FlushNow 测试手动刷新
|
||||
func TestBatchBuffer_FlushNow(t *testing.T) {
|
||||
const batchSize = 100 // 足够大,不会自动触发
|
||||
|
||||
buffer := NewBatchBuffer(batchSize, 100*time.Millisecond) // 100ms才自动刷新
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
defer buffer.Close()
|
||||
|
||||
var receivedBatches [][]*model.AuditEvent
|
||||
var mu sync.Mutex
|
||||
|
||||
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
|
||||
mu.Lock()
|
||||
receivedBatches = append(receivedBatches, events)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
// 添加少量事件
|
||||
for i := 0; i < 3; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-manual",
|
||||
EventName: "TEST-MANUAL",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 立即手动刷新
|
||||
err = buffer.FlushNow()
|
||||
if err != nil {
|
||||
t.Errorf("FlushNow failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(receivedBatches) != 1 {
|
||||
t.Errorf("expected 1 batch after FlushNow, got %d", len(receivedBatches))
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
@@ -141,6 +142,14 @@ func NewAccountService(store AccountStore, auditStore audit.AuditStore) AccountS
|
||||
return &accountService{store: store, auditStore: auditStore}
|
||||
}
|
||||
|
||||
// emitAudit 安全记录审计日志(失败只记录错误,不影响主流程)
|
||||
func (s *accountService) emitAudit(ctx context.Context, event audit.Event) {
|
||||
if err := s.auditStore.Emit(ctx, event); err != nil {
|
||||
log.Printf("[AUDIT_ERROR] failed to emit audit event: %v, object_type=%s, object_id=%d, action=%s",
|
||||
err, event.ObjectType, event.ObjectID, event.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *accountService) Verify(ctx context.Context, supplierID int64, provider Provider, accountType AccountType, credential string) (*VerifyResult, error) {
|
||||
// 开发阶段:模拟验证逻辑
|
||||
result := &VerifyResult{
|
||||
@@ -181,7 +190,7 @@ func (s *accountService) Create(ctx context.Context, req *CreateAccountRequest)
|
||||
}
|
||||
|
||||
// 记录审计日志
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: req.SupplierID,
|
||||
ObjectType: "supply_account",
|
||||
ObjectID: account.ID,
|
||||
@@ -210,7 +219,7 @@ func (s *accountService) Activate(ctx context.Context, supplierID, accountID int
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_account",
|
||||
ObjectID: accountID,
|
||||
@@ -239,7 +248,7 @@ func (s *accountService) Suspend(ctx context.Context, supplierID, accountID int6
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_account",
|
||||
ObjectID: accountID,
|
||||
@@ -260,7 +269,7 @@ func (s *accountService) Delete(ctx context.Context, supplierID, accountID int64
|
||||
return errors.New("SUP_ACC_4092: cannot delete active accounts")
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_account",
|
||||
ObjectID: accountID,
|
||||
|
||||
@@ -3,6 +3,7 @@ package domain
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
@@ -132,6 +133,14 @@ func NewPackageService(store PackageStore, accountStore AccountStore, auditStore
|
||||
}
|
||||
}
|
||||
|
||||
// emitAudit 安全记录审计日志(失败只记录错误,不影响主流程)
|
||||
func (s *packageService) emitAudit(ctx context.Context, event audit.Event) {
|
||||
if err := s.auditStore.Emit(ctx, event); err != nil {
|
||||
log.Printf("[AUDIT_ERROR] failed to emit audit event: %v, object_type=%s, object_id=%d, action=%s",
|
||||
err, event.ObjectType, event.ObjectID, event.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *packageService) CreateDraft(ctx context.Context, supplierID int64, req *CreatePackageDraftRequest) (*Package, error) {
|
||||
pkg := &Package{
|
||||
SupplierID: supplierID,
|
||||
@@ -154,7 +163,7 @@ func (s *packageService) CreateDraft(ctx context.Context, supplierID int64, req
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_package",
|
||||
ObjectID: pkg.ID,
|
||||
@@ -183,7 +192,7 @@ func (s *packageService) Publish(ctx context.Context, supplierID, packageID int6
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_package",
|
||||
ObjectID: packageID,
|
||||
@@ -212,7 +221,7 @@ func (s *packageService) Pause(ctx context.Context, supplierID, packageID int64)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_package",
|
||||
ObjectID: packageID,
|
||||
@@ -237,7 +246,7 @@ func (s *packageService) Unlist(ctx context.Context, supplierID, packageID int64
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_package",
|
||||
ObjectID: packageID,
|
||||
@@ -275,7 +284,7 @@ func (s *packageService) Clone(ctx context.Context, supplierID, packageID int64)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_package",
|
||||
ObjectID: clone.ID,
|
||||
@@ -292,6 +301,17 @@ func (s *packageService) BatchUpdatePrice(ctx context.Context, supplierID int64,
|
||||
}
|
||||
|
||||
for _, item := range req.Items {
|
||||
// 验证价格不能为负数
|
||||
if item.PricePer1MInput < 0 || item.PricePer1MOutput < 0 {
|
||||
resp.FailedCount++
|
||||
resp.Failures = append(resp.Failures, BatchPriceFailure{
|
||||
PackageID: item.PackageID,
|
||||
ErrorCode: "SUP_PKG_4004",
|
||||
Message: "price cannot be negative",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
pkg, err := s.store.GetByID(ctx, supplierID, item.PackageID)
|
||||
if err != nil {
|
||||
resp.FailedCount++
|
||||
|
||||
@@ -3,6 +3,7 @@ package domain
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
@@ -160,11 +161,24 @@ func NewSettlementService(store SettlementStore, earningStore EarningStore, audi
|
||||
}
|
||||
}
|
||||
|
||||
// emitAudit 安全记录审计日志(失败只记录错误,不影响主流程)
|
||||
func (s *settlementService) emitAudit(ctx context.Context, event audit.Event) {
|
||||
if err := s.auditStore.Emit(ctx, event); err != nil {
|
||||
log.Printf("[AUDIT_ERROR] failed to emit audit event: %v, object_type=%s, object_id=%d, action=%s",
|
||||
err, event.ObjectType, event.ObjectID, event.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *settlementService) Withdraw(ctx context.Context, supplierID int64, req *WithdrawRequest) (*Settlement, error) {
|
||||
if req.SMSCode != "123456" {
|
||||
return nil, errors.New("invalid sms code")
|
||||
}
|
||||
|
||||
// 验证金额:必须为正数
|
||||
if req.Amount <= 0 {
|
||||
return nil, errors.New("SUP_SET_4003: withdraw amount must be positive")
|
||||
}
|
||||
|
||||
balance, err := s.store.GetWithdrawableBalance(ctx, supplierID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -192,7 +206,7 @@ func (s *settlementService) Withdraw(ctx context.Context, supplierID int64, req
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_settlement",
|
||||
ObjectID: settlement.ID,
|
||||
@@ -221,7 +235,7 @@ func (s *settlementService) Cancel(ctx context.Context, supplierID, settlementID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.auditStore.Emit(ctx, audit.Event{
|
||||
s.emitAudit(ctx, audit.Event{
|
||||
TenantID: supplierID,
|
||||
ObjectType: "supply_settlement",
|
||||
ObjectID: settlementID,
|
||||
|
||||
125
supply-api/internal/httpapi/alert_api.go
Normal file
125
supply-api/internal/httpapi/alert_api.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/handler"
|
||||
"lijiaoqiao/supply-api/internal/audit/service"
|
||||
)
|
||||
|
||||
// AlertAPI 告警API处理器
|
||||
type AlertAPI struct {
|
||||
alertHandler *handler.AlertHandler
|
||||
}
|
||||
|
||||
// NewAlertAPI 创建告警API处理器
|
||||
func NewAlertAPI() *AlertAPI {
|
||||
// 创建内存告警存储
|
||||
alertStore := service.NewInMemoryAlertStore()
|
||||
// 创建告警服务
|
||||
alertSvc := service.NewAlertService(alertStore)
|
||||
// 创建告警处理器
|
||||
alertHandler := handler.NewAlertHandler(alertSvc)
|
||||
|
||||
return &AlertAPI{
|
||||
alertHandler: alertHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册告警路由
|
||||
func (a *AlertAPI) Register(mux *http.ServeMux) {
|
||||
// Alert CRUD
|
||||
mux.HandleFunc("/api/v1/audit/alerts", a.handleAlert)
|
||||
mux.HandleFunc("/api/v1/audit/alerts/", a.handleAlertByID)
|
||||
}
|
||||
|
||||
// handleAlert 处理 /api/v1/audit/alerts 的路由分发
|
||||
func (a *AlertAPI) handleAlert(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
a.alertHandler.CreateAlert(w, r)
|
||||
case http.MethodGet:
|
||||
a.alertHandler.ListAlerts(w, r)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// handleAlertByID 处理 /api/v1/audit/alerts/{alert_id} 的路由分发
|
||||
func (a *AlertAPI) handleAlertByID(w http.ResponseWriter, r *http.Request) {
|
||||
// 提取路径最后部分判断操作
|
||||
path := r.URL.Path
|
||||
if len(path) > 0 && path[len(path)-1] == '/' {
|
||||
path = path[:len(path)-1]
|
||||
}
|
||||
|
||||
parts := splitPath(path)
|
||||
if len(parts) < 5 {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PATH", "invalid path")
|
||||
return
|
||||
}
|
||||
|
||||
alertID := parts[4]
|
||||
|
||||
// 检查是否是特殊操作
|
||||
if len(parts) > 5 && parts[5] == "resolve" {
|
||||
if r.Method == http.MethodPost {
|
||||
a.alertHandler.ResolveAlert(w, r)
|
||||
} else {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 常规CRUD操作
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// 安全设置alert_id到查询参数
|
||||
query := make(url.Values)
|
||||
for k, v := range r.URL.Query() {
|
||||
query[k] = v
|
||||
}
|
||||
query.Set("alert_id", alertID)
|
||||
r.URL.RawQuery = query.Encode()
|
||||
a.alertHandler.GetAlert(w, r)
|
||||
case http.MethodPut:
|
||||
query := make(url.Values)
|
||||
for k, v := range r.URL.Query() {
|
||||
query[k] = v
|
||||
}
|
||||
query.Set("alert_id", alertID)
|
||||
r.URL.RawQuery = query.Encode()
|
||||
a.alertHandler.UpdateAlert(w, r)
|
||||
case http.MethodDelete:
|
||||
query := make(url.Values)
|
||||
for k, v := range r.URL.Query() {
|
||||
query[k] = v
|
||||
}
|
||||
query.Set("alert_id", alertID)
|
||||
r.URL.RawQuery = query.Encode()
|
||||
a.alertHandler.DeleteAlert(w, r)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// splitPath 分割路径
|
||||
func splitPath(path string) []string {
|
||||
var parts []string
|
||||
var current []byte
|
||||
for i := 0; i < len(path); i++ {
|
||||
if path[i] == '/' {
|
||||
if len(current) > 0 {
|
||||
parts = append(parts, string(current))
|
||||
current = nil
|
||||
}
|
||||
} else {
|
||||
current = append(current, path[i])
|
||||
}
|
||||
}
|
||||
if len(current) > 0 {
|
||||
parts = append(parts, string(current))
|
||||
}
|
||||
return parts
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/iam/model"
|
||||
"lijiaoqiao/supply-api/internal/middleware"
|
||||
)
|
||||
|
||||
@@ -17,7 +18,12 @@ const (
|
||||
IAMTokenClaimsKey iamContextKey = "iam_token_claims"
|
||||
)
|
||||
|
||||
// ClaimsVersion Token Claims版本号,用于迁移追踪
|
||||
const ClaimsVersion = 1
|
||||
|
||||
// IAMTokenClaims IAM扩展Token Claims
|
||||
// 版本: v1
|
||||
// 迁移路径: 见 MigrateClaims 函数
|
||||
type IAMTokenClaims struct {
|
||||
SubjectID string `json:"subject_id"`
|
||||
Role string `json:"role"`
|
||||
@@ -25,30 +31,73 @@ type IAMTokenClaims struct {
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
UserType string `json:"user_type"` // 用户类型: platform/supply/consumer
|
||||
Permissions []string `json:"permissions"` // 细粒度权限列表
|
||||
|
||||
// 版本控制字段(未来迁移用)
|
||||
Version int `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
// 角色层级定义
|
||||
var roleHierarchyLevels = map[string]int{
|
||||
"super_admin": 100,
|
||||
"org_admin": 50,
|
||||
"supply_admin": 40,
|
||||
"consumer_admin": 40,
|
||||
"operator": 30,
|
||||
"developer": 20,
|
||||
"finops": 20,
|
||||
"supply_operator": 30,
|
||||
"supply_finops": 20,
|
||||
"supply_viewer": 10,
|
||||
"consumer_operator": 30,
|
||||
"consumer_viewer": 10,
|
||||
"viewer": 10,
|
||||
// MigrateClaims 将旧版本Claims迁移到当前版本
|
||||
// 迁移路径:
|
||||
// v0 -> v1: 初始版本,添加 Version 字段
|
||||
//
|
||||
// 使用示例:
|
||||
// claims := &IAMTokenClaims{}
|
||||
// if err := json.Unmarshal(data, claims); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// migrated := MigrateClaims(claims)
|
||||
// // 使用 migrated
|
||||
func MigrateClaims(claims *IAMTokenClaims) *IAMTokenClaims {
|
||||
if claims == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 当前版本是v1,无需迁移
|
||||
// 未来版本迁移:
|
||||
// case 0:
|
||||
// claims = migrateV0ToV1(claims)
|
||||
// case 1:
|
||||
// claims = migrateV1ToV2(claims)
|
||||
claims.Version = ClaimsVersion
|
||||
return claims
|
||||
}
|
||||
|
||||
// ValidateClaims 验证Claims完整性
|
||||
func ValidateClaims(claims *IAMTokenClaims) error {
|
||||
if claims == nil {
|
||||
return ErrInvalidClaims
|
||||
}
|
||||
if claims.SubjectID == "" {
|
||||
return ErrInvalidSubjectID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 迁移相关错误
|
||||
var (
|
||||
ErrInvalidClaims = &ClaimsError{Code: "IAM_CLAIMS_4001", Message: "invalid claims structure"}
|
||||
ErrInvalidSubjectID = &ClaimsError{Code: "IAM_CLAIMS_4002", Message: "subject_id is required"}
|
||||
)
|
||||
|
||||
// ClaimsError Claims相关错误
|
||||
type ClaimsError struct {
|
||||
Code string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ClaimsError) Error() string {
|
||||
return e.Code + ": " + e.Message
|
||||
}
|
||||
|
||||
// 角色层级定义(已废弃,请使用 model.RoleHierarchyLevels)
|
||||
// @deprecated 使用 model.RoleHierarchyLevels 获取角色层级
|
||||
var roleHierarchyLevels = model.RoleHierarchyLevels
|
||||
|
||||
// ScopeAuthMiddleware Scope权限验证中间件
|
||||
type ScopeAuthMiddleware struct {
|
||||
// 路由-Scope映射
|
||||
routeScopePolicies map[string][]string
|
||||
// 角色层级(已废弃,使用包级变量roleHierarchyLevels)
|
||||
// 角色层级
|
||||
roleHierarchy map[string]int
|
||||
}
|
||||
|
||||
@@ -56,7 +105,7 @@ type ScopeAuthMiddleware struct {
|
||||
func NewScopeAuthMiddleware() *ScopeAuthMiddleware {
|
||||
return &ScopeAuthMiddleware{
|
||||
routeScopePolicies: make(map[string][]string),
|
||||
roleHierarchy: roleHierarchyLevels,
|
||||
roleHierarchy: model.RoleHierarchyLevels, // 使用统一的角色层级定义
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,11 +191,9 @@ func HasRoleLevel(ctx context.Context, minLevel int) bool {
|
||||
}
|
||||
|
||||
// GetRoleLevel 获取角色层级数值
|
||||
// @deprecated 请使用 model.GetRoleLevelByCode
|
||||
func GetRoleLevel(role string) int {
|
||||
if level, ok := roleHierarchyLevels[role]; ok {
|
||||
return level
|
||||
}
|
||||
return 0
|
||||
return model.GetRoleLevelByCode(role)
|
||||
}
|
||||
|
||||
// GetIAMTokenClaims 获取IAM Token Claims
|
||||
|
||||
@@ -15,6 +15,7 @@ const (
|
||||
)
|
||||
|
||||
// 角色层级常量(用于权限优先级判断)
|
||||
// 注意:这些常量值必须与 RoleHierarchyLevels map保持一致
|
||||
const (
|
||||
LevelSuperAdmin = 100
|
||||
LevelOrgAdmin = 50
|
||||
@@ -25,6 +26,33 @@ const (
|
||||
LevelViewer = 10
|
||||
)
|
||||
|
||||
// RoleHierarchyLevels 角色层级映射(用于权限验证)
|
||||
// 层级越高,权限越大。superset角色可以执行subset角色的操作。
|
||||
// 注意:此map的值必须与上述常量保持一致!
|
||||
var RoleHierarchyLevels = map[string]int{
|
||||
"super_admin": LevelSuperAdmin, // 100 - 超级管理员
|
||||
"org_admin": LevelOrgAdmin, // 50 - 组织管理员
|
||||
"supply_admin": LevelSupplyAdmin, // 40 - 供应商管理员
|
||||
"consumer_admin": LevelSupplyAdmin, // 40 - 消费者管理员(同供应商)
|
||||
"operator": LevelOperator, // 30 - 操作员
|
||||
"developer": LevelDeveloper, // 20 - 开发者
|
||||
"finops": LevelFinops, // 20 - 财务运营
|
||||
"supply_operator": LevelOperator, // 30 - 供应商操作员
|
||||
"supply_finops": LevelFinops, // 20 - 供应商财务
|
||||
"supply_viewer": LevelViewer, // 10 - 供应商查看者
|
||||
"consumer_operator": LevelOperator, // 30 - 消费者操作员
|
||||
"consumer_viewer": LevelViewer, // 10 - 消费者查看者
|
||||
"viewer": LevelViewer, // 10 - 通用查看者
|
||||
}
|
||||
|
||||
// GetRoleLevelByCode 根据角色代码获取层级数值
|
||||
func GetRoleLevelByCode(roleCode string) int {
|
||||
if level, ok := RoleHierarchyLevels[roleCode]; ok {
|
||||
return level
|
||||
}
|
||||
return 0 // 默认最低级别
|
||||
}
|
||||
|
||||
// 角色错误定义
|
||||
var (
|
||||
ErrInvalidRoleCode = errors.New("invalid role code: cannot be empty")
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -14,6 +15,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/iam/model"
|
||||
)
|
||||
|
||||
// TokenClaims JWT token claims
|
||||
@@ -84,11 +87,13 @@ type BruteForceProtection struct {
|
||||
lockoutDuration time.Duration
|
||||
attempts map[string]*attemptRecord
|
||||
mu sync.Mutex
|
||||
cleanupCounter int64 // 清理触发计数器
|
||||
}
|
||||
|
||||
type attemptRecord struct {
|
||||
count int
|
||||
lockedUntil time.Time
|
||||
lastAttempt time.Time // 最后尝试时间,用于过期清理
|
||||
}
|
||||
|
||||
// NewBruteForceProtection 创建暴力破解保护
|
||||
@@ -114,9 +119,11 @@ func (b *BruteForceProtection) RecordFailedAttempt(ip string) {
|
||||
}
|
||||
|
||||
record.count++
|
||||
record.lastAttempt = time.Now()
|
||||
if record.count >= b.maxAttempts {
|
||||
record.lockedUntil = time.Now().Add(b.lockoutDuration)
|
||||
}
|
||||
b.triggerCleanup()
|
||||
}
|
||||
|
||||
// IsLocked 检查IP是否被锁定
|
||||
@@ -150,6 +157,42 @@ func (b *BruteForceProtection) Reset(ip string) {
|
||||
delete(b.attempts, ip)
|
||||
}
|
||||
|
||||
// triggerCleanup 触发清理(每100次操作清理一次过期记录)
|
||||
func (b *BruteForceProtection) triggerCleanup() {
|
||||
b.cleanupCounter++
|
||||
if b.cleanupCounter >= 100 {
|
||||
b.cleanupCounter = 0
|
||||
b.cleanupExpiredLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredLocked 清理过期记录(需要持有锁)
|
||||
// 清理条件:锁定已过期且最后尝试时间超过lockoutDuration
|
||||
func (b *BruteForceProtection) cleanupExpiredLocked() {
|
||||
now := time.Now()
|
||||
threshold := now.Add(-b.lockoutDuration * 2) // 超过两倍锁定时长未活动的记录清理
|
||||
for ip, record := range b.attempts {
|
||||
// 清理:锁定已过期且长时间无活动
|
||||
if record.lockedUntil.Before(now) && record.lastAttempt.Before(threshold) {
|
||||
delete(b.attempts, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CleanExpired 主动清理过期记录(可由外部定期调用)
|
||||
func (b *BruteForceProtection) CleanExpired() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.cleanupExpiredLocked()
|
||||
}
|
||||
|
||||
// Len 返回当前记录数量(用于监控)
|
||||
func (b *BruteForceProtection) Len() int {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return len(b.attempts)
|
||||
}
|
||||
|
||||
// QueryKeyRejectMiddleware 拒绝外部query key入站
|
||||
// 对应M-016指标
|
||||
func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler {
|
||||
@@ -263,7 +306,19 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
tokenString := r.Context().Value(bearerTokenKey).(string)
|
||||
// 安全检查:确保BearerExtractMiddleware已执行
|
||||
tokenValue := r.Context().Value(bearerTokenKey)
|
||||
if tokenValue == nil {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_MISSING",
|
||||
"bearer token is missing")
|
||||
return
|
||||
}
|
||||
tokenString, ok := tokenValue.(string)
|
||||
if !ok || tokenString == "" {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_INVALID",
|
||||
"bearer token is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := m.verifyToken(tokenString)
|
||||
if err != nil {
|
||||
@@ -289,7 +344,7 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// 检查token状态(是否被吊销)
|
||||
status, err := m.checkTokenStatus(claims.ID)
|
||||
status, err := m.checkTokenStatus(r.Context(), claims.ID)
|
||||
if err == nil && status != "active" {
|
||||
if m.auditEmitter != nil {
|
||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||
@@ -363,24 +418,21 @@ func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(htt
|
||||
}
|
||||
|
||||
// 检查role权限
|
||||
roleHierarchy := map[string]int{
|
||||
"admin": 3,
|
||||
"owner": 2,
|
||||
"viewer": 1,
|
||||
}
|
||||
// 使用model.GetRoleLevelByCode获取统一角色层级定义
|
||||
|
||||
// 路由权限要求
|
||||
// 路由权限要求(使用详细角色代码)
|
||||
// viewer: level 10, operator: level 30, org_admin: level 50
|
||||
routeRoles := map[string]string{
|
||||
"/api/v1/supply/accounts": "owner",
|
||||
"/api/v1/supply/packages": "owner",
|
||||
"/api/v1/supply/settlements": "owner",
|
||||
"/api/v1/supply/billing": "viewer",
|
||||
"/api/v1/supplier/billing": "viewer",
|
||||
"/api/v1/supply/accounts": "org_admin",
|
||||
"/api/v1/supply/packages": "org_admin",
|
||||
"/api/v1/supply/settlements": "org_admin",
|
||||
"/api/v1/supply/billing": "viewer",
|
||||
"/api/v1/supplier/billing": "viewer",
|
||||
}
|
||||
|
||||
for path, requiredRole := range routeRoles {
|
||||
if strings.HasPrefix(r.URL.Path, path) {
|
||||
if roleLevel(claims.Role, roleHierarchy) < roleLevel(requiredRole, roleHierarchy) {
|
||||
if model.GetRoleLevelByCode(claims.Role) < model.GetRoleLevelByCode(requiredRole) {
|
||||
writeAuthError(w, http.StatusForbidden, "AUTH_ROLE_DENIED",
|
||||
fmt.Sprintf("required role '%s' is not granted, current role: '%s'", requiredRole, claims.Role))
|
||||
return
|
||||
@@ -430,7 +482,7 @@ func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) {
|
||||
}
|
||||
|
||||
// checkTokenStatus 检查token状态(从缓存或数据库)
|
||||
func (m *AuthMiddleware) checkTokenStatus(tokenID string) (string, error) {
|
||||
func (m *AuthMiddleware) checkTokenStatus(ctx context.Context, tokenID string) (string, error) {
|
||||
if m.tokenCache != nil {
|
||||
// 先从缓存检查
|
||||
if status, found := m.tokenCache.Get(tokenID); found {
|
||||
@@ -440,7 +492,7 @@ func (m *AuthMiddleware) checkTokenStatus(tokenID string) (string, error) {
|
||||
|
||||
// 缓存未命中,查询后端验证token状态
|
||||
if m.tokenBackend != nil {
|
||||
return m.tokenBackend.CheckTokenStatus(context.Background(), tokenID)
|
||||
return m.tokenBackend.CheckTokenStatus(ctx, tokenID)
|
||||
}
|
||||
|
||||
// 没有后端实现时,应该拒绝访问而不是默认active
|
||||
@@ -472,7 +524,10 @@ func writeAuthError(w http.ResponseWriter, status int, code, message string) {
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
// 记录编码错误(响应已经开始发送,无法回退)
|
||||
log.Printf("[AUTH_ERROR] failed to encode error response: %v, code=%s", err, code)
|
||||
}
|
||||
}
|
||||
|
||||
// getRequestID 获取请求ID
|
||||
@@ -488,7 +543,10 @@ func getClientIP(r *http.Request) string {
|
||||
// 优先从X-Forwarded-For获取
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
// 安全检查:空字符串已在上层判断,但防御性编程
|
||||
if len(parts) > 0 {
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
// X-Real-IP
|
||||
@@ -550,14 +608,6 @@ func containsScope(scopes []string, target string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// roleLevel 获取角色等级
|
||||
func roleLevel(role string, hierarchy map[string]int) int {
|
||||
if level, ok := hierarchy[role]; ok {
|
||||
return level
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseSubjectID 解析subject ID
|
||||
func parseSubjectID(subject string) int64 {
|
||||
parts := strings.Split(subject, ":")
|
||||
@@ -570,7 +620,9 @@ func parseSubjectID(subject string) int64 {
|
||||
|
||||
// TokenCache Token状态缓存
|
||||
type TokenCache struct {
|
||||
data map[string]cacheEntry
|
||||
data map[string]cacheEntry
|
||||
mu sync.RWMutex
|
||||
cleanup int64 // 清理触发计数器
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
@@ -581,34 +633,76 @@ type cacheEntry struct {
|
||||
// NewTokenCache 创建token缓存
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
data: make(map[string]cacheEntry),
|
||||
data: make(map[string]cacheEntry),
|
||||
cleanup: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取token状态
|
||||
func (c *TokenCache) Get(tokenID string) (string, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if entry, ok := c.data[tokenID]; ok {
|
||||
if time.Now().Before(entry.expires) {
|
||||
return entry.status, true
|
||||
}
|
||||
delete(c.data, tokenID)
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Set 设置token状态
|
||||
func (c *TokenCache) Set(tokenID, status string, ttl time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.data[tokenID] = cacheEntry{
|
||||
status: status,
|
||||
expires: time.Now().Add(ttl),
|
||||
}
|
||||
c.triggerCleanup()
|
||||
}
|
||||
|
||||
// Invalidate 使token失效
|
||||
func (c *TokenCache) Invalidate(tokenID string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.data, tokenID)
|
||||
}
|
||||
|
||||
// triggerCleanup 触发清理(每100次操作清理一次过期条目)
|
||||
func (c *TokenCache) triggerCleanup() {
|
||||
c.cleanup++
|
||||
if c.cleanup >= 100 {
|
||||
c.cleanup = 0
|
||||
c.cleanupExpiredLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredLocked 清理过期条目(需要持有锁)
|
||||
func (c *TokenCache) cleanupExpiredLocked() {
|
||||
now := time.Now()
|
||||
for tokenID, entry := range c.data {
|
||||
if now.After(entry.expires) {
|
||||
delete(c.data, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CleanExpired 主动清理过期条目(可由外部定期调用)
|
||||
func (c *TokenCache) CleanExpired() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.cleanupExpiredLocked()
|
||||
}
|
||||
|
||||
// Len 返回缓存条目数量(用于监控)
|
||||
func (c *TokenCache) Len() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return len(c.data)
|
||||
}
|
||||
|
||||
// ComputeFingerprint 计算凭证指纹(用于审计)
|
||||
func ComputeFingerprint(credential string) string {
|
||||
hash := sha256.Sum256([]byte(credential))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/iam/model"
|
||||
)
|
||||
|
||||
func TestTokenVerify(t *testing.T) {
|
||||
@@ -248,27 +251,25 @@ func TestContainsScope(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRoleLevel(t *testing.T) {
|
||||
hierarchy := map[string]int{
|
||||
"admin": 3,
|
||||
"owner": 2,
|
||||
"viewer": 1,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
role string
|
||||
expected int
|
||||
}{
|
||||
{"admin", 3},
|
||||
{"owner", 2},
|
||||
{"viewer", 1},
|
||||
{"super_admin", 100},
|
||||
{"org_admin", 50},
|
||||
{"supply_admin", 40},
|
||||
{"operator", 30},
|
||||
{"developer", 20},
|
||||
{"finops", 20},
|
||||
{"viewer", 10},
|
||||
{"unknown", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.role, func(t *testing.T) {
|
||||
result := roleLevel(tt.role, hierarchy)
|
||||
result := model.GetRoleLevelByCode(tt.role)
|
||||
if result != tt.expected {
|
||||
t.Errorf("roleLevel(%s) = %d, want %d", tt.role, result, tt.expected)
|
||||
t.Errorf("GetRoleLevelByCode(%s) = %d, want %d", tt.role, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -411,7 +412,7 @@ func TestMED02_TokenCacheMiss_ShouldNotAssumeActive(t *testing.T) {
|
||||
}
|
||||
|
||||
// act - 查询一个不在缓存中的token
|
||||
status, err := middleware.checkTokenStatus("nonexistent-token-id")
|
||||
status, err := middleware.checkTokenStatus(context.Background(), "nonexistent-token-id")
|
||||
|
||||
// assert - 缓存未命中且没有后端时应该返回错误(安全修复)
|
||||
// 修复前bug:缓存未命中时默认返回"active"
|
||||
|
||||
@@ -254,7 +254,7 @@ func (r *AccountRepository) List(ctx context.Context, supplierID int64) ([]*doma
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accounts []*domain.Account
|
||||
accounts := make([]*domain.Account, 0)
|
||||
for rows.Next() {
|
||||
account := &domain.Account{}
|
||||
err := rows.Scan(
|
||||
|
||||
@@ -206,7 +206,7 @@ func (r *PackageRepository) List(ctx context.Context, supplierID int64) ([]*doma
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var packages []*domain.Package
|
||||
packages := make([]*domain.Package, 0)
|
||||
for rows.Next() {
|
||||
pkg := &domain.Package{}
|
||||
err := rows.Scan(
|
||||
|
||||
@@ -195,7 +195,7 @@ func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*d
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var settlements []*domain.Settlement
|
||||
settlements := make([]*domain.Settlement, 0)
|
||||
for rows.Next() {
|
||||
s := &domain.Settlement{}
|
||||
err := rows.Scan(
|
||||
|
||||
@@ -66,7 +66,7 @@ func (s *InMemoryAccountStore) List(ctx context.Context, supplierID int64) ([]*d
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []*domain.Account
|
||||
result := make([]*domain.Account, 0)
|
||||
for _, account := range s.accounts {
|
||||
if account.SupplierID == supplierID {
|
||||
result = append(result, account)
|
||||
@@ -129,7 +129,7 @@ func (s *InMemoryPackageStore) List(ctx context.Context, supplierID int64) ([]*d
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []*domain.Package
|
||||
result := make([]*domain.Package, 0)
|
||||
for _, pkg := range s.packages {
|
||||
if pkg.SupplierID == supplierID {
|
||||
result = append(result, pkg)
|
||||
@@ -192,7 +192,7 @@ func (s *InMemorySettlementStore) List(ctx context.Context, supplierID int64) ([
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []*domain.Settlement
|
||||
result := make([]*domain.Settlement, 0)
|
||||
for _, settlement := range s.settlements {
|
||||
if settlement.SupplierID == supplierID {
|
||||
result = append(result, settlement)
|
||||
@@ -264,8 +264,9 @@ func (s *InMemoryEarningStore) GetBillingSummary(ctx context.Context, supplierID
|
||||
|
||||
// 内存幂等存储
|
||||
type InMemoryIdempotencyStore struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]*IdempotencyRecord
|
||||
mu sync.RWMutex
|
||||
records map[string]*IdempotencyRecord
|
||||
cleanupCounter int64 // 清理触发计数器
|
||||
}
|
||||
|
||||
type IdempotencyRecord struct {
|
||||
@@ -303,6 +304,7 @@ func (s *InMemoryIdempotencyStore) SetProcessing(key string, ttl time.Duration)
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
s.triggerCleanupLocked()
|
||||
}
|
||||
|
||||
func (s *InMemoryIdempotencyStore) SetSuccess(key string, response interface{}, ttl time.Duration) {
|
||||
@@ -316,4 +318,39 @@ func (s *InMemoryIdempotencyStore) SetSuccess(key string, response interface{},
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
s.triggerCleanupLocked()
|
||||
}
|
||||
|
||||
// triggerCleanupLocked 触发清理(每100次操作清理一次过期记录)
|
||||
// 调用时必须持有锁
|
||||
func (s *InMemoryIdempotencyStore) triggerCleanupLocked() {
|
||||
s.cleanupCounter++
|
||||
if s.cleanupCounter >= 100 {
|
||||
s.cleanupCounter = 0
|
||||
s.cleanupExpiredLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredLocked 清理过期记录(需要持有锁)
|
||||
func (s *InMemoryIdempotencyStore) cleanupExpiredLocked() {
|
||||
now := time.Now()
|
||||
for key, record := range s.records {
|
||||
if record.ExpiresAt.Before(now) {
|
||||
delete(s.records, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CleanExpired 主动清理过期记录(可由外部定期调用)
|
||||
func (s *InMemoryIdempotencyStore) CleanExpired() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cleanupExpiredLocked()
|
||||
}
|
||||
|
||||
// Len 返回当前记录数量(用于监控)
|
||||
func (s *InMemoryIdempotencyStore) Len() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.records)
|
||||
}
|
||||
|
||||
132
supply-api/pkg/error/errors.go
Normal file
132
supply-api/pkg/error/errors.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrorCode 错误码格式:{DOMAIN}_{CODE}
|
||||
// 错误码命名规范:{模块}_{问题类型}_{序号}
|
||||
//
|
||||
// 示例:
|
||||
// - SUP_ACC_4001 (Supplier Account - 业务错误 - 4001)
|
||||
// - AUDIT_EVT_4041 (Audit Event - 资源不存在 - 4041)
|
||||
//
|
||||
// 错误码分类:
|
||||
// - 4xxx: 业务逻辑错误
|
||||
// - 5xxx: 系统/服务器错误
|
||||
// - 9xxx: 内部/未知错误
|
||||
|
||||
// 预定义的错误码前缀
|
||||
const (
|
||||
PrefixSUP = "SUP" // Supplier 模块
|
||||
PrefixIAM = "IAM" // Identity & Access Management 模块
|
||||
PrefixAudit = "AUDIT" // Audit 模块
|
||||
PrefixRepo = "REPO" // Repository 模块
|
||||
PrefixSys = "SYS" // 系统级错误
|
||||
)
|
||||
|
||||
// CodeError 带错误码的错误
|
||||
type CodeError struct {
|
||||
Code string // 错误码,如 "SUP_ACC_4001"
|
||||
Message string // 错误消息
|
||||
Err error // 底层错误(可选)
|
||||
}
|
||||
|
||||
// Error 实现 error 接口
|
||||
func (e *CodeError) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("%s: %s (caused by: %v)", e.Code, e.Message, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap 获取底层错误
|
||||
func (e *CodeError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// NewCodeError 创建带错误码的错误
|
||||
func NewCodeError(code, message string) *CodeError {
|
||||
return &CodeError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// WrapCodeError 包装已有错误
|
||||
func WrapCodeError(err error, code, message string) *CodeError {
|
||||
return &CodeError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// IsCodeError 检查错误是否为 CodeError
|
||||
func IsCodeError(err error) bool {
|
||||
_, ok := err.(*CodeError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetErrorCode 从错误中提取错误码
|
||||
func GetErrorCode(err error) string {
|
||||
var codeErr *CodeError
|
||||
if As(err, &codeErr) {
|
||||
return codeErr.Code
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// As 类型断言辅助函数
|
||||
func As(err error, target **CodeError) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if e, ok := err.(*CodeError); ok {
|
||||
*target = e
|
||||
return true
|
||||
}
|
||||
if e, ok := err.(interface{ Unwrap() error }); ok {
|
||||
return As(e.Unwrap(), target)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Common errors - 可以被各模块引用的通用错误
|
||||
var (
|
||||
// ErrNotFound 资源不存在
|
||||
ErrNotFound = NewCodeError("SYS_4040", "resource not found")
|
||||
|
||||
// ErrInvalidInput 输入参数无效
|
||||
ErrInvalidInput = NewCodeError("SYS_4000", "invalid input parameter")
|
||||
|
||||
// ErrUnauthorized 未授权
|
||||
ErrUnauthorized = NewCodeError("SYS_4010", "unauthorized")
|
||||
|
||||
// ErrForbidden 禁止访问
|
||||
ErrForbidden = NewCodeError("SYS_4030", "forbidden")
|
||||
|
||||
// ErrInternalServer 服务器内部错误
|
||||
ErrInternalServer = NewCodeError("SYS_5000", "internal server error")
|
||||
|
||||
// ErrConcurrencyConflict 并发冲突
|
||||
ErrConcurrencyConflict = NewCodeError("SYS_4090", "concurrency conflict")
|
||||
)
|
||||
|
||||
// ValidateErrorCode 验证错误码格式是否合法
|
||||
func ValidateErrorCode(code string) bool {
|
||||
parts := strings.Split(code, "_")
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
// 检查前缀是否为有效值
|
||||
prefix := parts[0]
|
||||
validPrefixes := []string{PrefixSUP, PrefixIAM, PrefixAudit, PrefixRepo, PrefixSys}
|
||||
for _, p := range validPrefixes {
|
||||
if prefix == p {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
95
supply-api/pkg/error/errors_test.go
Normal file
95
supply-api/pkg/error/errors_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewCodeError(t *testing.T) {
|
||||
err := NewCodeError("TEST_4001", "test error message")
|
||||
assert.Equal(t, "TEST_4001", err.Code)
|
||||
assert.Equal(t, "test error message", err.Message)
|
||||
assert.Nil(t, err.Err)
|
||||
assert.Equal(t, "TEST_4001: test error message", err.Error())
|
||||
}
|
||||
|
||||
func TestWrapCodeError(t *testing.T) {
|
||||
originalErr := errors.New("original error")
|
||||
err := WrapCodeError(originalErr, "TEST_4001", "wrapped error")
|
||||
assert.Equal(t, "TEST_4001", err.Code)
|
||||
assert.Equal(t, "wrapped error", err.Message)
|
||||
assert.Equal(t, originalErr, err.Err)
|
||||
assert.Contains(t, err.Error(), "caused by: original error")
|
||||
}
|
||||
|
||||
func TestIsCodeError(t *testing.T) {
|
||||
codeErr := NewCodeError("TEST_4001", "test")
|
||||
assert.True(t, IsCodeError(codeErr))
|
||||
|
||||
stdErr := errors.New("standard error")
|
||||
assert.False(t, IsCodeError(stdErr))
|
||||
}
|
||||
|
||||
func TestGetErrorCode(t *testing.T) {
|
||||
codeErr := NewCodeError("SUP_ACC_4001", "test")
|
||||
assert.Equal(t, "SUP_ACC_4001", GetErrorCode(codeErr))
|
||||
|
||||
stdErr := errors.New("standard error")
|
||||
assert.Equal(t, "", GetErrorCode(stdErr))
|
||||
}
|
||||
|
||||
func TestUnwrap(t *testing.T) {
|
||||
originalErr := errors.New("original")
|
||||
wrapped := WrapCodeError(originalErr, "TEST_4001", "wrapped")
|
||||
|
||||
// 通过 Unwrap 获取原始错误
|
||||
unwrapped := wrapped.Unwrap()
|
||||
assert.Equal(t, originalErr, unwrapped)
|
||||
}
|
||||
|
||||
func TestValidateErrorCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
code string
|
||||
expected bool
|
||||
}{
|
||||
{"SUP_ACC_4001", true},
|
||||
{"IAM_ROLE_4040", true},
|
||||
{"AUDIT_EVT_5000", true},
|
||||
{"REPO_NOT_FOUND", true},
|
||||
{"SYS_5000", true},
|
||||
{"INVALID", false}, // 没有下划线分隔
|
||||
{"BAD_CODE", false}, // 前缀不在白名单
|
||||
{"X_4001", false}, // 前缀不在白名单
|
||||
{"", false}, // 空字符串
|
||||
{"TOOLONG_4001", false}, // 前缀太长
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.code, func(t *testing.T) {
|
||||
result := ValidateErrorCode(tc.code)
|
||||
assert.Equal(t, tc.expected, result, "code: %s", tc.code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommonErrors(t *testing.T) {
|
||||
assert.Equal(t, "SYS_4040", ErrNotFound.Code)
|
||||
assert.Equal(t, "resource not found", ErrNotFound.Message)
|
||||
|
||||
assert.Equal(t, "SYS_4000", ErrInvalidInput.Code)
|
||||
assert.Equal(t, "invalid input parameter", ErrInvalidInput.Message)
|
||||
|
||||
assert.Equal(t, "SYS_4010", ErrUnauthorized.Code)
|
||||
assert.Equal(t, "unauthorized", ErrUnauthorized.Message)
|
||||
|
||||
assert.Equal(t, "SYS_4030", ErrForbidden.Code)
|
||||
assert.Equal(t, "forbidden", ErrForbidden.Message)
|
||||
|
||||
assert.Equal(t, "SYS_5000", ErrInternalServer.Code)
|
||||
assert.Equal(t, "internal server error", ErrInternalServer.Message)
|
||||
|
||||
assert.Equal(t, "SYS_4090", ErrConcurrencyConflict.Code)
|
||||
assert.Equal(t, "concurrency conflict", ErrConcurrencyConflict.Message)
|
||||
}
|
||||
Reference in New Issue
Block a user