Files
lijiaoqiao/supply-api/internal/audit/repository/audit_repository.go

420 lines
12 KiB
Go
Raw Normal View History

package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"lijiaoqiao/supply-api/internal/audit/model"
)
// EventFilter carries the optional criteria and paging parameters for an
// audit event query. It is declared in the repository layer so that the
// repository does not have to import its callers (avoids an import cycle).
//
// As interpreted by PostgresAuditRepository.Query, a zero value (0, "" or nil)
// in a criterion field means "do not filter on this field".
type EventFilter struct {
	// Exact-match criteria on tenant_id and operator_id.
	TenantID, OperatorID int64
	// Exact-match criteria on event_category and event_name.
	Category, EventName string
	// Inclusive lower and upper bounds on the event timestamp.
	StartTime, EndTime *time.Time
	// Paging window: maximum number of rows and number of rows to skip.
	Limit, Offset int
}
// AuditRepository is the storage interface for audit events.
type AuditRepository interface {
// Emit persists a single audit event.
Emit(ctx context.Context, event *model.AuditEvent) error
// Query returns the audit events selected by filter, plus an int64 total.
// NOTE(review): in the PostgreSQL implementation in this file the total is
// the number of rows matching the filter criteria, independent of
// Limit/Offset — confirm other implementations follow the same contract.
Query(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error)
// GetByIdempotencyKey looks up the event stored under the given idempotency key.
// NOTE(review): the PostgreSQL implementation in this file returns (nil, nil)
// when no event matches, so callers must check the event for nil.
GetByIdempotencyKey(ctx context.Context, key string) (*model.AuditEvent, error)
}
// PostgresAuditRepository is the PostgreSQL-backed implementation of
// AuditRepository. All statements are executed through a pgx connection pool.
type PostgresAuditRepository struct {
// pool is the pgx connection pool every query in this type runs against.
pool *pgxpool.Pool
}
// NewPostgresAuditRepository returns a PostgresAuditRepository that executes
// its statements against the supplied connection pool.
func NewPostgresAuditRepository(pool *pgxpool.Pool) *PostgresAuditRepository {
	repo := new(PostgresAuditRepository)
	repo.pool = pool
	return repo
}
// Compile-time assertion that *PostgresAuditRepository implements AuditRepository.
var _ AuditRepository = (*PostgresAuditRepository)(nil)
// Emit inserts event as a new row in the audit_events table.
//
// Before inserting, Emit fills in defaults on the caller's event in place:
// EventID becomes a random UUID when empty, Timestamp becomes the current time
// when zero, and TimestampMs is always recomputed from Timestamp.
//
// Extensions, BeforeState and AfterState are JSON-encoded only when non-nil
// (otherwise a nil byte slice is passed to the driver); SecurityFlags is always
// JSON-encoded. The version column is always written as 1 and created_at as
// the time of the call.
//
// Emit returns ErrDuplicateIdempotencyKey (unwrapped) when the insert is
// rejected because of a duplicate idempotency key, an error when event is nil,
// and a wrapped error for every other failure.
func (r *PostgresAuditRepository) Emit(ctx context.Context, event *model.AuditEvent) error {
	// Guard against a nil-pointer panic on the field accesses below.
	if event == nil {
		return errors.New("failed to emit audit event: event is nil")
	}

	// Assign an event ID when the caller did not supply one.
	if event.EventID == "" {
		event.EventID = uuid.New().String()
	}
	// Default the timestamp and keep the millisecond mirror in sync with it.
	if event.Timestamp.IsZero() {
		event.Timestamp = time.Now()
	}
	event.TimestampMs = event.Timestamp.UnixMilli()

	// Serialize the extension fields.
	var extensionsJSON []byte
	if event.Extensions != nil {
		encoded, err := json.Marshal(event.Extensions)
		if err != nil {
			return fmt.Errorf("failed to marshal extensions: %w", err)
		}
		extensionsJSON = encoded
	}

	// Serialize the security flags (always, even when unset).
	securityFlagsJSON, err := json.Marshal(event.SecurityFlags)
	if err != nil {
		return fmt.Errorf("failed to marshal security flags: %w", err)
	}

	// Serialize the before/after state snapshots.
	var beforeStateJSON, afterStateJSON []byte
	if event.BeforeState != nil {
		beforeStateJSON, err = json.Marshal(event.BeforeState)
		if err != nil {
			return fmt.Errorf("failed to marshal before state: %w", err)
		}
	}
	if event.AfterState != nil {
		afterStateJSON, err = json.Marshal(event.AfterState)
		if err != nil {
			return fmt.Errorf("failed to marshal after state: %w", err)
		}
	}

	query := `
INSERT INTO audit_events (
event_id, event_name, event_category, event_sub_category,
timestamp, timestamp_ms,
request_id, trace_id, span_id,
idempotency_key,
operator_id, operator_type, operator_role,
tenant_id, tenant_type,
object_type, object_id,
action, action_detail,
credential_type, credential_id, credential_fingerprint,
source_type, source_ip, source_region, user_agent,
target_type, target_endpoint, target_direct,
result_code, result_message, success,
before_data, after_data,
security_flags, risk_score,
compliance_tags, invariant_rule,
extensions,
version, created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
$11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30,
$31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41
)
`
	// The argument order must match the column list above one-to-one.
	_, err = r.pool.Exec(ctx, query,
		event.EventID, event.EventName, event.EventCategory, event.EventSubCategory,
		event.Timestamp, event.TimestampMs,
		event.RequestID, event.TraceID, event.SpanID,
		event.IdempotencyKey,
		event.OperatorID, event.OperatorType, event.OperatorRole,
		event.TenantID, event.TenantType,
		event.ObjectType, event.ObjectID,
		event.Action, event.ActionDetail,
		event.CredentialType, event.CredentialID, event.CredentialFingerprint,
		event.SourceType, event.SourceIP, event.SourceRegion, event.UserAgent,
		event.TargetType, event.TargetEndpoint, event.TargetDirect,
		event.ResultCode, event.ResultMessage, event.Success,
		beforeStateJSON, afterStateJSON,
		securityFlagsJSON, event.RiskScore,
		event.ComplianceTags, event.InvariantRule,
		extensionsJSON,
		1, time.Now(),
	)
	if err != nil {
		if isDuplicateIdempotencyKeyErr(err) {
			return ErrDuplicateIdempotencyKey
		}
		return fmt.Errorf("failed to emit audit event: %w", err)
	}
	return nil
}

// isDuplicateIdempotencyKeyErr reports whether err describes a unique-constraint
// violation that involves the idempotency key.
//
// In both paths the error text must mention "idempotency_key"; this relies on
// the violated constraint/index name containing that substring, as the
// original check did.
func isDuplicateIdempotencyKeyErr(err error) bool {
	msg := err.Error()
	if !strings.Contains(msg, "idempotency_key") {
		return false
	}
	// Preferred path: SQLSTATE 23505 is PostgreSQL's unique_violation code.
	// Unlike the human-readable message, it does not change with the server's
	// lc_messages locale. Matching on the method keeps this block free of a
	// direct dependency on the driver's error type.
	var stateErr interface{ SQLState() string }
	if errors.As(err, &stateErr) && stateErr.SQLState() == "23505" {
		return true
	}
	// Fallback: the original substring heuristic, for errors that do not
	// expose a SQLSTATE.
	return strings.Contains(msg, "unique")
}
// Query returns one page of audit events matching filter, newest first
// (ORDER BY timestamp DESC), together with the total number of rows that match
// the filter criteria regardless of paging.
//
// Filter semantics:
//   - TenantID, OperatorID, Category and EventName are exact-match criteria and
//     are skipped when zero/empty; StartTime and EndTime are inclusive bounds
//     and are skipped when nil. All active criteria are combined with AND.
//   - Limit defaults to 100 when <= 0 and is capped at 1000.
//   - A negative Offset is treated as 0.
//   - A nil filter is treated as an empty filter (no criteria, default paging).
//
// The count and the page are fetched with two separate statements on the pool,
// so the total may not exactly agree with the page under concurrent inserts.
//
// On any failure Query returns (nil, 0, err) with err wrapped; the returned
// slice is nil when no rows match.
func (r *PostgresAuditRepository) Query(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error) {
	// Avoid a nil-pointer panic: no filter means "no criteria".
	if filter == nil {
		filter = &EventFilter{}
	}

	// Build the WHERE clause. Placeholder numbers are derived from len(args)
	// so a condition and its argument can never get out of step.
	var (
		conditions []string
		args       []interface{}
	)
	addCondition := func(expr string, value interface{}) {
		args = append(args, value)
		conditions = append(conditions, fmt.Sprintf("%s $%d", expr, len(args)))
	}
	if filter.TenantID != 0 {
		addCondition("tenant_id =", filter.TenantID)
	}
	if filter.Category != "" {
		addCondition("event_category =", filter.Category)
	}
	if filter.EventName != "" {
		addCondition("event_name =", filter.EventName)
	}
	if filter.OperatorID != 0 {
		addCondition("operator_id =", filter.OperatorID)
	}
	if filter.StartTime != nil {
		addCondition("timestamp >=", *filter.StartTime)
	}
	if filter.EndTime != nil {
		addCondition("timestamp <=", *filter.EndTime)
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}

	// Total number of matching rows (ignores paging).
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause)
	var total int64
	if err := r.pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("failed to count audit events: %w", err)
	}

	// Normalize the paging window.
	limit := filter.Limit
	if limit <= 0 {
		limit = 100
	}
	if limit > 1000 {
		limit = 1000
	}
	offset := filter.Offset
	if offset < 0 {
		offset = 0
	}

	// LIMIT and OFFSET take the two placeholders after the filter arguments.
	query := fmt.Sprintf(`
SELECT
event_id, event_name, event_category, event_sub_category,
timestamp, timestamp_ms,
request_id, trace_id, span_id,
idempotency_key,
operator_id, operator_type, operator_role,
tenant_id, tenant_type,
object_type, object_id,
action, action_detail,
credential_type, credential_id, credential_fingerprint,
source_type, source_ip, source_region, user_agent,
target_type, target_endpoint, target_direct,
result_code, result_message, success,
before_data, after_data,
security_flags, risk_score,
compliance_tags, invariant_rule,
extensions,
version, created_at
FROM audit_events
%s
ORDER BY timestamp DESC
LIMIT $%d OFFSET $%d
`, whereClause, len(args)+1, len(args)+2)
	args = append(args, limit, offset)

	rows, err := r.pool.Query(ctx, query, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to query audit events: %w", err)
	}
	defer rows.Close()

	var events []*model.AuditEvent
	for rows.Next() {
		event, err := r.scanAuditEvent(rows)
		if err != nil {
			return nil, 0, fmt.Errorf("failed to scan audit event: %w", err)
		}
		events = append(events, event)
	}
	// rows.Next() also returns false when reading fails part-way through;
	// without this check a truncated page would be reported as success.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("failed to iterate audit events: %w", err)
	}
	return events, total, nil
}
// GetByIdempotencyKey fetches the audit event whose idempotency_key column
// equals key.
//
// It returns (nil, nil) when no row matches, so callers must check the event
// for nil in addition to checking the error. Any other failure is returned
// wrapped.
func (r *PostgresAuditRepository) GetByIdempotencyKey(ctx context.Context, key string) (*model.AuditEvent, error) {
	const lookupSQL = `
SELECT
event_id, event_name, event_category, event_sub_category,
timestamp, timestamp_ms,
request_id, trace_id, span_id,
idempotency_key,
operator_id, operator_type, operator_role,
tenant_id, tenant_type,
object_type, object_id,
action, action_detail,
credential_type, credential_id, credential_fingerprint,
source_type, source_ip, source_region, user_agent,
target_type, target_endpoint, target_direct,
result_code, result_message, success,
before_data, after_data,
security_flags, risk_score,
compliance_tags, invariant_rule,
extensions,
version, created_at
FROM audit_events
WHERE idempotency_key = $1
`
	found, scanErr := r.scanAuditEventRow(r.pool.QueryRow(ctx, lookupSQL, key))
	switch {
	case scanErr == nil:
		return found, nil
	case errors.Is(scanErr, pgx.ErrNoRows):
		// "Not found" is reported as a nil event, not as an error.
		return nil, nil
	default:
		return nil, fmt.Errorf("failed to get event by idempotency key: %w", scanErr)
	}
}
// scanAuditEvent decodes the current row of a multi-row result into an
// AuditEvent. pgx.Rows provides the Scan method required by pgx.Row, so the
// work is delegated to scanAuditEventRow to keep a single copy of the
// column-to-field mapping.
func (r *PostgresAuditRepository) scanAuditEvent(rows pgx.Rows) (*model.AuditEvent, error) {
	return r.scanAuditEventRow(rows)
}

// scanAuditEventRow decodes one audit_events row into an AuditEvent.
//
// The Scan destinations must stay in the same order as the 41-column SELECT
// lists used by Query and GetByIdempotencyKey.
//
// The before_data, after_data, security_flags and extensions columns are read
// as raw bytes and JSON-decoded into BeforeState, AfterState, SecurityFlags
// and Extensions; an empty or NULL column leaves the field at its zero value.
//
// It returns the Scan error unchanged (so callers can test for pgx.ErrNoRows),
// or a wrapped error naming the column when stored JSON cannot be decoded. On
// error the returned event is nil.
func (r *PostgresAuditRepository) scanAuditEventRow(row pgx.Row) (*model.AuditEvent, error) {
	var (
		event model.AuditEvent

		eventSubCategory, traceID, spanID, idempotencyKey, operatorRole string

		beforeData, afterData, extensions, securityFlagsJSON []byte

		complianceTags []string
	)

	err := row.Scan(
		&event.EventID, &event.EventName, &event.EventCategory, &eventSubCategory,
		&event.Timestamp, &event.TimestampMs,
		&event.RequestID, &traceID, &spanID,
		&idempotencyKey,
		&event.OperatorID, &event.OperatorType, &operatorRole,
		&event.TenantID, &event.TenantType,
		&event.ObjectType, &event.ObjectID,
		&event.Action, &event.ActionDetail,
		&event.CredentialType, &event.CredentialID, &event.CredentialFingerprint,
		&event.SourceType, &event.SourceIP, &event.SourceRegion, &event.UserAgent,
		&event.TargetType, &event.TargetEndpoint, &event.TargetDirect,
		&event.ResultCode, &event.ResultMessage, &event.Success,
		&beforeData, &afterData,
		&securityFlagsJSON, &event.RiskScore,
		&complianceTags, &event.InvariantRule,
		&extensions,
		&event.Version, &event.CreatedAt,
	)
	if err != nil {
		return nil, err
	}

	event.EventSubCategory = eventSubCategory
	event.TraceID = traceID
	event.SpanID = spanID
	event.IdempotencyKey = idempotencyKey
	event.OperatorRole = operatorRole
	event.ComplianceTags = complianceTags

	// Decode the JSON columns. Errors are surfaced instead of being dropped so
	// that a corrupt row is not returned as a silently incomplete audit record.
	decode := func(column string, raw []byte, dst interface{}) error {
		if len(raw) == 0 {
			return nil
		}
		if err := json.Unmarshal(raw, dst); err != nil {
			return fmt.Errorf("decoding %s of audit event %q: %w", column, event.EventID, err)
		}
		return nil
	}
	if err := decode("before_data", beforeData, &event.BeforeState); err != nil {
		return nil, err
	}
	if err := decode("after_data", afterData, &event.AfterState); err != nil {
		return nil, err
	}
	if err := decode("security_flags", securityFlagsJSON, &event.SecurityFlags); err != nil {
		return nil, err
	}
	if err := decode("extensions", extensions, &event.Extensions); err != nil {
		return nil, err
	}
	return &event, nil
}
// Sentinel errors returned by this package.
var (
// ErrDuplicateIdempotencyKey is returned by PostgresAuditRepository.Emit when
// the insert fails with an error identifying a duplicate idempotency key.
ErrDuplicateIdempotencyKey = errors.New("duplicate idempotency key")
)