Files
lijiaoqiao/supply-api/internal/audit/handler/audit_handler.go
Your Name 8ac23bf7d4 test: improve coverage and fix sanitizer bug
- Fix MaskMap to properly handle []string sensitive fields
- Add missing slice handling in sanitizer
- Add comprehensive tests for GetMetrics and CreateEventsBatch
- Improve audit/handler coverage from 49.8% to 68.8%
- Fix test expectations to match actual sanitizer behavior
- All tests pass
2026-04-08 07:44:58 +08:00

420 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"strconv"
	"strings"
	"time"

	"lijiaoqiao/supply-api/internal/audit/model"
	"lijiaoqiao/supply-api/internal/audit/service"
)
// AuditHandler exposes the audit service over HTTP: single and batch event
// creation, event listing and lookup, and the M-013..M-016 metrics endpoint.
type AuditHandler struct {
	// svc serves every event read and write.
	svc *service.AuditService
	// metricsSvc is optional; when it is nil, GetMetrics answers 503.
	metricsSvc *service.MetricsService
}
// NewAuditHandler builds a handler for the event endpoints only. No metrics
// service is attached, so GetMetrics will respond 503 METRICS_UNAVAILABLE.
func NewAuditHandler(svc *service.AuditService) *AuditHandler {
	h := new(AuditHandler)
	h.svc = svc
	return h
}
// NewAuditHandlerWithMetrics builds a handler that serves both the event
// endpoints (backed by svc) and the metrics endpoint (backed by metricsSvc).
func NewAuditHandlerWithMetrics(svc *service.AuditService, metricsSvc *service.MetricsService) *AuditHandler {
	h := NewAuditHandler(svc)
	h.metricsSvc = metricsSvc
	return h
}
// CreateEventRequest is the JSON body accepted by CreateEvent and, per
// element, by CreateEventsBatch. The handlers require EventName and
// EventCategory to be non-empty; the optional string fields tagged omitempty
// are dropped when the struct is re-encoded with empty values.
type CreateEventRequest struct {
	// Classification of the event (EventName and EventCategory are required).
	EventName        string `json:"event_name"`
	EventCategory    string `json:"event_category"`
	EventSubCategory string `json:"event_sub_category"`

	// Operator, tenant and target-object identifiers plus the action taken.
	OperatorID int64  `json:"operator_id"`
	TenantID   int64  `json:"tenant_id"`
	ObjectType string `json:"object_type"`
	ObjectID   int64  `json:"object_id"`
	Action     string `json:"action"`

	// Optional request metadata; IdempotencyKey is forwarded to the service
	// unchanged.
	IdempotencyKey string `json:"idempotency_key,omitempty"`
	SourceIP       string `json:"source_ip,omitempty"`

	// Outcome of the audited operation.
	Success    bool   `json:"success"`
	ResultCode string `json:"result_code,omitempty"`
}
// ErrorResponse is the JSON error envelope emitted by writeError.
type ErrorResponse struct {
	// Error is the human-readable message.
	Error   string `json:"error"`
	// Code is a machine-readable identifier such as "MISSING_FIELD".
	Code    string `json:"code,omitempty"`
	// Details carries optional extra context; omitted from JSON when empty.
	Details string `json:"details,omitempty"`
}
// ListEventsResponse is the body returned by ListEvents.
type ListEventsResponse struct {
	// Events is the page of events returned by the service.
	Events []*model.AuditEvent `json:"events"`
	// Total is the count reported by the service for the filter
	// (presumably across all pages — confirm against ListEventsWithFilter).
	Total  int64               `json:"total"`
	// Offset and Limit echo the paging values the handler actually used.
	Offset int                 `json:"offset"`
	Limit  int                 `json:"limit"`
}
// CreateEvent handles POST /api/v1/audit/events.
//
// It decodes a CreateEventRequest, requires event_name and event_category,
// and passes the resulting model.AuditEvent to the audit service. The HTTP
// status is taken verbatim from the service result's StatusCode, which is how
// idempotent replays and conflicts can be reported as 200/409 instead of 201.
//
// @Summary 创建审计事件
// @Description 创建新的审计事件,支持幂等
// @Tags audit
// @Accept json
// @Produce json
// @Param event body CreateEventRequest true "事件信息"
// @Success 201 {object} service.CreateEventResult
// @Success 200 {object} service.CreateEventResult "幂等重复"
// @Success 409 {object} service.CreateEventResult "幂等冲突"
// @Failure 400 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /api/v1/audit/events [post]
func (h *AuditHandler) CreateEvent(w http.ResponseWriter, r *http.Request) {
	var in CreateEventRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&in); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+decodeErr.Error())
		return
	}

	// Required fields, checked in the same order as they are reported.
	switch {
	case in.EventName == "":
		writeError(w, http.StatusBadRequest, "MISSING_FIELD", "event_name is required")
		return
	case in.EventCategory == "":
		writeError(w, http.StatusBadRequest, "MISSING_FIELD", "event_category is required")
		return
	}

	res, createErr := h.svc.CreateEvent(r.Context(), &model.AuditEvent{
		EventName:        in.EventName,
		EventCategory:    in.EventCategory,
		EventSubCategory: in.EventSubCategory,
		OperatorID:       in.OperatorID,
		TenantID:         in.TenantID,
		ObjectType:       in.ObjectType,
		ObjectID:         in.ObjectID,
		Action:           in.Action,
		IdempotencyKey:   in.IdempotencyKey,
		SourceIP:         in.SourceIP,
		Success:          in.Success,
		ResultCode:       in.ResultCode,
	})
	if createErr != nil {
		writeError(w, http.StatusInternalServerError, "CREATE_FAILED", createErr.Error())
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(res.StatusCode)
	// The status line is already sent, so an encode failure cannot be reported.
	_ = json.NewEncoder(w).Encode(res)
}
// ListEvents handles GET /api/v1/audit/events.
//
// Query parameters:
//   - tenant_id: optional int64 filter. A value that is present but not a
//     valid integer is rejected with 400 instead of being ignored, because
//     ignoring it would silently widen the query to every tenant.
//   - category, event_name: optional filters passed through to the service.
//   - offset: optional; non-numeric or negative values fall back to 0.
//   - limit: optional, accepted when in 1..1000; anything else falls back to 100.
//
// An empty result is encoded as "events": [] rather than null.
//
// @Summary 查询审计事件
// @Description 查询审计事件列表,支持分页和过滤
// @Tags audit
// @Produce json
// @Param tenant_id query int false "租户ID"
// @Param category query string false "事件类别"
// @Param event_name query string false "事件名称"
// @Param offset query int false "偏移量" default(0)
// @Param limit query int false "限制数量" default(100)
// @Success 200 {object} ListEventsResponse
// @Failure 400 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /api/v1/audit/events [get]
func (h *AuditHandler) ListEvents(w http.ResponseWriter, r *http.Request) {
	// Parse the query string once rather than on every lookup.
	q := r.URL.Query()

	// Empty category/event_name leave the filter fields at their zero value.
	filter := &service.EventFilter{
		Category:  q.Get("category"),
		EventName: q.Get("event_name"),
		Limit:     100,
	}

	if raw := q.Get("tenant_id"); raw != "" {
		tenantID, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			writeError(w, http.StatusBadRequest, "INVALID_PARAM", "invalid tenant_id: "+raw)
			return
		}
		filter.TenantID = tenantID
	}
	if raw := q.Get("offset"); raw != "" {
		// Negative offsets are meaningless for paging; keep the default of 0.
		if offset, err := strconv.Atoi(raw); err == nil && offset >= 0 {
			filter.Offset = offset
		}
	}
	if raw := q.Get("limit"); raw != "" {
		if limit, err := strconv.Atoi(raw); err == nil && limit > 0 && limit <= 1000 {
			filter.Limit = limit
		}
	}

	events, total, err := h.svc.ListEventsWithFilter(r.Context(), filter)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "QUERY_FAILED", err.Error())
		return
	}
	if events == nil {
		// Encode an empty page as [] so clients can iterate without a null check.
		events = []*model.AuditEvent{}
	}

	w.Header().Set("Content-Type", "application/json")
	_ = json.NewEncoder(w).Encode(ListEventsResponse{
		Events: events,
		Total:  total,
		Offset: filter.Offset,
		Limit:  filter.Limit,
	})
}
// GetEventResponse is the body GetEvent returns on success.
type GetEventResponse struct {
	// Event is the record fetched from the audit service.
	Event *model.AuditEvent `json:"event"`
}
// GetEvent handles GET /api/v1/audit/events/{event_id}.
//
// The ID comes from the "event_id" query parameter if present. Otherwise it is
// the first path segment after /api/v1/audit/events/. A missing ID yields 400.
// A service error matching service.ErrEventNotFound yields 404, including one
// wrapped with %w. Any other error yields 500.
//
// @Summary 获取单个审计事件
// @Description 根据事件ID获取审计事件详情
// @Tags audit
// @Produce json
// @Param event_id path string true "事件ID"
// @Success 200 {object} GetEventResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /api/v1/audit/events/{event_id} [get]
func (h *AuditHandler) GetEvent(w http.ResponseWriter, r *http.Request) {
	eventID := r.URL.Query().Get("event_id")
	if eventID == "" {
		// Fall back to the path segment. If the prefix is absent, TrimPrefix is
		// a no-op and the leading "/" produces an empty first segment, so the
		// ID stays "" and the request is rejected below.
		pathParts := strings.Split(strings.TrimPrefix(r.URL.Path, "/api/v1/audit/events/"), "/")
		if len(pathParts) > 0 && pathParts[0] != "" {
			eventID = pathParts[0]
		}
	}
	if eventID == "" {
		writeError(w, http.StatusBadRequest, "MISSING_PARAM", "event_id is required")
		return
	}

	event, err := h.svc.GetEventByID(r.Context(), eventID)
	if err != nil {
		// errors.Is also matches a wrapped not-found error, which == would miss
		// (turning a 404 into a 500).
		if errors.Is(err, service.ErrEventNotFound) {
			writeError(w, http.StatusNotFound, "NOT_FOUND", "event not found")
			return
		}
		writeError(w, http.StatusInternalServerError, "GET_FAILED", err.Error())
		return
	}

	w.Header().Set("Content-Type", "application/json")
	_ = json.NewEncoder(w).Encode(GetEventResponse{Event: event})
}
// GetMetrics handles GET /api/v1/audit/metrics/{metric_id}.
//
// The metric ID is resolved in this order:
//  1. the "metric_id" query parameter;
//  2. the path segment after /api/v1/audit/metrics/;
//  3. the default "m013".
//
// Accepted IDs are m013..m016, with a lower- or upper-case "m". start and end
// are RFC 3339 timestamps. A missing or unparsable start defaults to 24h
// before now, and a missing or unparsable end defaults to now.
//
// Responds 503 when no MetricsService is configured, 400 for an unknown
// metric ID, and 500 when the calculation fails.
//
// @Summary 获取审计指标
// @Description 获取M-013~M-016指标数据
// @Tags audit
// @Produce json
// @Param metric_id path string true "指标ID (m013/m014/m015/m016)"
// @Param start query string false "开始时间 ISO8601"
// @Param end query string false "结束时间 ISO8601"
// @Success 200 {object} service.Metric
// @Failure 400 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Failure 503 {object} ErrorResponse
// @Router /api/v1/audit/metrics/{metric_id} [get]
func (h *AuditHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
	if h.metricsSvc == nil {
		writeError(w, http.StatusServiceUnavailable, "METRICS_UNAVAILABLE", "metrics service not available")
		return
	}

	q := r.URL.Query()

	metricID := q.Get("metric_id")
	if metricID == "" {
		// Honour the documented route parameter. Without this, every
		// /metrics/{id} request silently returned M-013 data.
		const prefix = "/api/v1/audit/metrics/"
		if strings.HasPrefix(r.URL.Path, prefix) {
			seg := strings.TrimPrefix(r.URL.Path, prefix)
			if i := strings.IndexByte(seg, '/'); i >= 0 {
				seg = seg[:i]
			}
			metricID = seg
		}
	}
	if metricID == "" {
		metricID = "m013" // default when no ID is supplied anywhere
	}

	switch metricID {
	case "m013", "M013", "m014", "M014", "m015", "M015", "m016", "M016":
	default:
		writeError(w, http.StatusBadRequest, "INVALID_METRIC", "invalid metric_id: "+metricID)
		return
	}

	// time.Parse fails on "", so a missing bound and a malformed one both
	// keep the default.
	now := time.Now()
	start := now.Add(-24 * time.Hour)
	if t, err := time.Parse(time.RFC3339, q.Get("start")); err == nil {
		start = t
	}
	end := now
	if t, err := time.Parse(time.RFC3339, q.Get("end")); err == nil {
		end = t
	}

	metric, err := h.calculateMetric(r.Context(), metricID, start, end)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "METRICS_FAILED", err.Error())
		return
	}

	w.Header().Set("Content-Type", "application/json")
	_ = json.NewEncoder(w).Encode(metric)
}
// CreateEventsBatchRequest is the body accepted by CreateEventsBatch. Each
// element has the same shape as a single CreateEvent body. Because elements
// are pointers, a JSON null inside the array decodes to a nil entry.
type CreateEventsBatchRequest struct {
	Events []*CreateEventRequest `json:"events"` // at most 50, enforced by the handler
}
// CreateEventsBatchResponse is the 200 body returned by CreateEventsBatch.
type CreateEventsBatchResponse struct {
	// Counts relayed from the service's batch result.
	SuccessCount int `json:"success_count"`
	FailCount    int `json:"fail_count"`

	// Errors relays the service's error strings; omitted when empty.
	Errors []string `json:"errors,omitempty"`
	// EventIDs is omitted when empty. CreateEventsBatch does not currently set it.
	EventIDs []string `json:"event_ids,omitempty"`
}
// CreateEventsBatch handles POST /api/v1/audit/events/batch.
//
// The body must contain between 1 and 50 events. Every element is validated
// before the service is called. A null element, or one missing event_name or
// event_category, rejects the whole batch with 400 and an "event[i]: ..."
// message. On success the service's SuccessCount, FailCount and Errors are
// relayed in a 200 response.
//
// @Summary 批量创建审计事件
// @Description 批量创建审计事件支持最多50条/批次
// @Tags audit
// @Accept json
// @Produce json
// @Param events body CreateEventsBatchRequest true "事件列表"
// @Success 200 {object} CreateEventsBatchResponse
// @Failure 400 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /api/v1/audit/events/batch [post]
func (h *AuditHandler) CreateEventsBatch(w http.ResponseWriter, r *http.Request) {
	var req CreateEventsBatchRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
		return
	}

	// Enforce batch size bounds before doing any per-event work.
	if len(req.Events) > 50 {
		writeError(w, http.StatusBadRequest, "BATCH_TOO_LARGE", "batch size cannot exceed 50")
		return
	}
	if len(req.Events) == 0 {
		writeError(w, http.StatusBadRequest, "EMPTY_BATCH", "batch cannot be empty")
		return
	}

	// Validate and convert every element up front, so that one bad element
	// rejects the batch before anything reaches the service.
	events := make([]*model.AuditEvent, 0, len(req.Events))
	for i, eventReq := range req.Events {
		prefix := "event[" + strconv.Itoa(i) + "]: "
		if eventReq == nil {
			// {"events":[null]} decodes to a nil pointer. Without this check
			// the field accesses below would panic the handler.
			writeError(w, http.StatusBadRequest, "VALIDATION_FAILED", prefix+"event must not be null")
			return
		}
		if eventReq.EventName == "" {
			writeError(w, http.StatusBadRequest, "VALIDATION_FAILED", prefix+"event_name is required")
			return
		}
		if eventReq.EventCategory == "" {
			writeError(w, http.StatusBadRequest, "VALIDATION_FAILED", prefix+"event_category is required")
			return
		}

		events = append(events, &model.AuditEvent{
			EventName:        eventReq.EventName,
			EventCategory:    eventReq.EventCategory,
			EventSubCategory: eventReq.EventSubCategory,
			OperatorID:       eventReq.OperatorID,
			TenantID:         eventReq.TenantID,
			ObjectType:       eventReq.ObjectType,
			ObjectID:         eventReq.ObjectID,
			Action:           eventReq.Action,
			IdempotencyKey:   eventReq.IdempotencyKey,
			SourceIP:         eventReq.SourceIP,
			Success:          eventReq.Success,
			ResultCode:       eventReq.ResultCode,
		})
	}

	batchResult, err := h.svc.CreateEventsBatch(r.Context(), events)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "BATCH_CREATE_FAILED", err.Error())
		return
	}

	// NOTE(review): EventIDs is never filled in here. If the service's batch
	// result exposes the created IDs, copy them into the response.
	response := &CreateEventsBatchResponse{
		SuccessCount: batchResult.SuccessCount,
		FailCount:    batchResult.FailCount,
		Errors:       batchResult.Errors,
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_ = json.NewEncoder(w).Encode(response)
}
// calculateMetric dispatches metricID to the matching MetricsService
// calculator for the start..end window. It accepts m013..m016 with a lower-
// or upper-case "m" ("m013" / "M013").
//
// An unknown ID, or a handler built without a metrics service, returns an
// error instead of (nil, nil). A nil metric with a nil error would otherwise
// be encoded by the caller as JSON null with status 200.
func (h *AuditHandler) calculateMetric(ctx context.Context, metricID string, start, end time.Time) (*service.Metric, error) {
	if h.metricsSvc == nil {
		return nil, errors.New("metrics service not available")
	}
	switch metricID {
	case "m013", "M013":
		return h.metricsSvc.CalculateM013(ctx, start, end)
	case "m014", "M014":
		return h.metricsSvc.CalculateM014(ctx, start, end)
	case "m015", "M015":
		return h.metricsSvc.CalculateM015(ctx, start, end)
	case "m016", "M016":
		return h.metricsSvc.CalculateM016(ctx, start, end)
	default:
		return nil, errors.New("unsupported metric_id: " + metricID)
	}
}
// writeError sends an ErrorResponse as JSON with the given HTTP status.
// message becomes the "error" field and code the "code" field. Details is
// left empty, so it is omitted from the output.
func writeError(w http.ResponseWriter, status int, code, message string) {
	payload := ErrorResponse{Code: code, Error: message}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	// The status is already committed, so an encode failure cannot be reported.
	_ = json.NewEncoder(w).Encode(payload)
}