420 lines
12 KiB
Go
420 lines
12 KiB
Go
|
|
package repository
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/google/uuid"
|
||
|
|
"github.com/jackc/pgx/v5"
|
||
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
||
|
|
"lijiaoqiao/supply-api/internal/audit/model"
|
||
|
|
)
|
||
|
|
|
||
|
|
// EventFilter describes the criteria for querying audit events. It is defined
// in the repository layer (rather than the service layer) to avoid a circular
// package dependency.
//
// Zero values mean "no constraint": a zero TenantID/OperatorID, an empty
// Category/EventName and a nil StartTime/EndTime are ignored by
// PostgresAuditRepository.Query. Limit and Offset control pagination.
type EventFilter struct {
	TenantID, OperatorID int64  // exact-match IDs; 0 disables the condition
	Category, EventName  string // exact-match strings; "" disables the condition

	StartTime, EndTime *time.Time // inclusive time-range bounds; nil disables a bound

	Limit, Offset int // page size and number of rows to skip
}
|
||
|
|
|
||
|
|
// AuditRepository 审计事件仓储接口
|
||
|
|
type AuditRepository interface {
|
||
|
|
// Emit 发送审计事件
|
||
|
|
Emit(ctx context.Context, event *model.AuditEvent) error
|
||
|
|
// Query 查询审计事件
|
||
|
|
Query(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error)
|
||
|
|
// GetByIdempotencyKey 根据幂等键获取事件
|
||
|
|
GetByIdempotencyKey(ctx context.Context, key string) (*model.AuditEvent, error)
|
||
|
|
}
|
||
|
|
|
||
|
|
// PostgresAuditRepository PostgreSQL实现的审计仓储
|
||
|
|
type PostgresAuditRepository struct {
|
||
|
|
pool *pgxpool.Pool
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewPostgresAuditRepository 创建PostgreSQL审计仓储
|
||
|
|
func NewPostgresAuditRepository(pool *pgxpool.Pool) *PostgresAuditRepository {
|
||
|
|
return &PostgresAuditRepository{pool: pool}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Ensure interface
|
||
|
|
var _ AuditRepository = (*PostgresAuditRepository)(nil)
|
||
|
|
|
||
|
|
// Emit 发送审计事件
|
||
|
|
func (r *PostgresAuditRepository) Emit(ctx context.Context, event *model.AuditEvent) error {
|
||
|
|
// 生成事件ID
|
||
|
|
if event.EventID == "" {
|
||
|
|
event.EventID = uuid.New().String()
|
||
|
|
}
|
||
|
|
|
||
|
|
// 设置时间戳
|
||
|
|
if event.Timestamp.IsZero() {
|
||
|
|
event.Timestamp = time.Now()
|
||
|
|
}
|
||
|
|
event.TimestampMs = event.Timestamp.UnixMilli()
|
||
|
|
|
||
|
|
// 序列化扩展字段
|
||
|
|
var extensionsJSON []byte
|
||
|
|
if event.Extensions != nil {
|
||
|
|
var err error
|
||
|
|
extensionsJSON, err = json.Marshal(event.Extensions)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to marshal extensions: %w", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 序列化安全标记
|
||
|
|
securityFlagsJSON, err := json.Marshal(event.SecurityFlags)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to marshal security flags: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 序列化状态变更
|
||
|
|
var beforeStateJSON, afterStateJSON []byte
|
||
|
|
if event.BeforeState != nil {
|
||
|
|
beforeStateJSON, err = json.Marshal(event.BeforeState)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to marshal before state: %w", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if event.AfterState != nil {
|
||
|
|
afterStateJSON, err = json.Marshal(event.AfterState)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to marshal after state: %w", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
query := `
|
||
|
|
INSERT INTO audit_events (
|
||
|
|
event_id, event_name, event_category, event_sub_category,
|
||
|
|
timestamp, timestamp_ms,
|
||
|
|
request_id, trace_id, span_id,
|
||
|
|
idempotency_key,
|
||
|
|
operator_id, operator_type, operator_role,
|
||
|
|
tenant_id, tenant_type,
|
||
|
|
object_type, object_id,
|
||
|
|
action, action_detail,
|
||
|
|
credential_type, credential_id, credential_fingerprint,
|
||
|
|
source_type, source_ip, source_region, user_agent,
|
||
|
|
target_type, target_endpoint, target_direct,
|
||
|
|
result_code, result_message, success,
|
||
|
|
before_data, after_data,
|
||
|
|
security_flags, risk_score,
|
||
|
|
compliance_tags, invariant_rule,
|
||
|
|
extensions,
|
||
|
|
version, created_at
|
||
|
|
) VALUES (
|
||
|
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
|
||
|
|
$11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
|
||
|
|
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30,
|
||
|
|
$31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41
|
||
|
|
)
|
||
|
|
`
|
||
|
|
|
||
|
|
_, err = r.pool.Exec(ctx, query,
|
||
|
|
event.EventID, event.EventName, event.EventCategory, event.EventSubCategory,
|
||
|
|
event.Timestamp, event.TimestampMs,
|
||
|
|
event.RequestID, event.TraceID, event.SpanID,
|
||
|
|
event.IdempotencyKey,
|
||
|
|
event.OperatorID, event.OperatorType, event.OperatorRole,
|
||
|
|
event.TenantID, event.TenantType,
|
||
|
|
event.ObjectType, event.ObjectID,
|
||
|
|
event.Action, event.ActionDetail,
|
||
|
|
event.CredentialType, event.CredentialID, event.CredentialFingerprint,
|
||
|
|
event.SourceType, event.SourceIP, event.SourceRegion, event.UserAgent,
|
||
|
|
event.TargetType, event.TargetEndpoint, event.TargetDirect,
|
||
|
|
event.ResultCode, event.ResultMessage, event.Success,
|
||
|
|
beforeStateJSON, afterStateJSON,
|
||
|
|
securityFlagsJSON, event.RiskScore,
|
||
|
|
event.ComplianceTags, event.InvariantRule,
|
||
|
|
extensionsJSON,
|
||
|
|
1, time.Now(),
|
||
|
|
)
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
// 检查幂等键重复
|
||
|
|
if strings.Contains(err.Error(), "idempotency_key") && strings.Contains(err.Error(), "unique") {
|
||
|
|
return ErrDuplicateIdempotencyKey
|
||
|
|
}
|
||
|
|
return fmt.Errorf("failed to emit audit event: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Query 查询审计事件
|
||
|
|
func (r *PostgresAuditRepository) Query(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error) {
|
||
|
|
// 构建查询条件
|
||
|
|
conditions := []string{}
|
||
|
|
args := []interface{}{}
|
||
|
|
argIndex := 1
|
||
|
|
|
||
|
|
if filter.TenantID != 0 {
|
||
|
|
conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex))
|
||
|
|
args = append(args, filter.TenantID)
|
||
|
|
argIndex++
|
||
|
|
}
|
||
|
|
|
||
|
|
if filter.Category != "" {
|
||
|
|
conditions = append(conditions, fmt.Sprintf("event_category = $%d", argIndex))
|
||
|
|
args = append(args, filter.Category)
|
||
|
|
argIndex++
|
||
|
|
}
|
||
|
|
|
||
|
|
if filter.EventName != "" {
|
||
|
|
conditions = append(conditions, fmt.Sprintf("event_name = $%d", argIndex))
|
||
|
|
args = append(args, filter.EventName)
|
||
|
|
argIndex++
|
||
|
|
}
|
||
|
|
|
||
|
|
if filter.OperatorID != 0 {
|
||
|
|
conditions = append(conditions, fmt.Sprintf("operator_id = $%d", argIndex))
|
||
|
|
args = append(args, filter.OperatorID)
|
||
|
|
argIndex++
|
||
|
|
}
|
||
|
|
|
||
|
|
if filter.StartTime != nil {
|
||
|
|
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
|
||
|
|
args = append(args, *filter.StartTime)
|
||
|
|
argIndex++
|
||
|
|
}
|
||
|
|
|
||
|
|
if filter.EndTime != nil {
|
||
|
|
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
|
||
|
|
args = append(args, *filter.EndTime)
|
||
|
|
argIndex++
|
||
|
|
}
|
||
|
|
|
||
|
|
whereClause := ""
|
||
|
|
if len(conditions) > 0 {
|
||
|
|
whereClause = "WHERE " + strings.Join(conditions, " AND ")
|
||
|
|
}
|
||
|
|
|
||
|
|
// 查询总数
|
||
|
|
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause)
|
||
|
|
var total int64
|
||
|
|
err := r.pool.QueryRow(ctx, countQuery, args...).Scan(&total)
|
||
|
|
if err != nil {
|
||
|
|
return nil, 0, fmt.Errorf("failed to count audit events: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 查询事件列表
|
||
|
|
limit := filter.Limit
|
||
|
|
if limit <= 0 {
|
||
|
|
limit = 100
|
||
|
|
}
|
||
|
|
if limit > 1000 {
|
||
|
|
limit = 1000
|
||
|
|
}
|
||
|
|
|
||
|
|
offset := filter.Offset
|
||
|
|
if offset < 0 {
|
||
|
|
offset = 0
|
||
|
|
}
|
||
|
|
|
||
|
|
query := fmt.Sprintf(`
|
||
|
|
SELECT
|
||
|
|
event_id, event_name, event_category, event_sub_category,
|
||
|
|
timestamp, timestamp_ms,
|
||
|
|
request_id, trace_id, span_id,
|
||
|
|
idempotency_key,
|
||
|
|
operator_id, operator_type, operator_role,
|
||
|
|
tenant_id, tenant_type,
|
||
|
|
object_type, object_id,
|
||
|
|
action, action_detail,
|
||
|
|
credential_type, credential_id, credential_fingerprint,
|
||
|
|
source_type, source_ip, source_region, user_agent,
|
||
|
|
target_type, target_endpoint, target_direct,
|
||
|
|
result_code, result_message, success,
|
||
|
|
before_data, after_data,
|
||
|
|
security_flags, risk_score,
|
||
|
|
compliance_tags, invariant_rule,
|
||
|
|
extensions,
|
||
|
|
version, created_at
|
||
|
|
FROM audit_events
|
||
|
|
%s
|
||
|
|
ORDER BY timestamp DESC
|
||
|
|
LIMIT $%d OFFSET $%d
|
||
|
|
`, whereClause, argIndex, argIndex+1)
|
||
|
|
|
||
|
|
args = append(args, limit, offset)
|
||
|
|
|
||
|
|
rows, err := r.pool.Query(ctx, query, args...)
|
||
|
|
if err != nil {
|
||
|
|
return nil, 0, fmt.Errorf("failed to query audit events: %w", err)
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
var events []*model.AuditEvent
|
||
|
|
for rows.Next() {
|
||
|
|
event, err := r.scanAuditEvent(rows)
|
||
|
|
if err != nil {
|
||
|
|
return nil, 0, fmt.Errorf("failed to scan audit event: %w", err)
|
||
|
|
}
|
||
|
|
events = append(events, event)
|
||
|
|
}
|
||
|
|
|
||
|
|
return events, total, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetByIdempotencyKey 根据幂等键获取事件
|
||
|
|
func (r *PostgresAuditRepository) GetByIdempotencyKey(ctx context.Context, key string) (*model.AuditEvent, error) {
|
||
|
|
query := `
|
||
|
|
SELECT
|
||
|
|
event_id, event_name, event_category, event_sub_category,
|
||
|
|
timestamp, timestamp_ms,
|
||
|
|
request_id, trace_id, span_id,
|
||
|
|
idempotency_key,
|
||
|
|
operator_id, operator_type, operator_role,
|
||
|
|
tenant_id, tenant_type,
|
||
|
|
object_type, object_id,
|
||
|
|
action, action_detail,
|
||
|
|
credential_type, credential_id, credential_fingerprint,
|
||
|
|
source_type, source_ip, source_region, user_agent,
|
||
|
|
target_type, target_endpoint, target_direct,
|
||
|
|
result_code, result_message, success,
|
||
|
|
before_data, after_data,
|
||
|
|
security_flags, risk_score,
|
||
|
|
compliance_tags, invariant_rule,
|
||
|
|
extensions,
|
||
|
|
version, created_at
|
||
|
|
FROM audit_events
|
||
|
|
WHERE idempotency_key = $1
|
||
|
|
`
|
||
|
|
|
||
|
|
row := r.pool.QueryRow(ctx, query, key)
|
||
|
|
event, err := r.scanAuditEventRow(row)
|
||
|
|
if err != nil {
|
||
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
return nil, fmt.Errorf("failed to get event by idempotency key: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return event, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// scanAuditEvent 扫描审计事件行
|
||
|
|
func (r *PostgresAuditRepository) scanAuditEvent(rows pgx.Rows) (*model.AuditEvent, error) {
|
||
|
|
var event model.AuditEvent
|
||
|
|
var eventSubCategory, traceID, spanID, idempotencyKey, operatorRole string
|
||
|
|
var beforeData, afterData, extensions []byte
|
||
|
|
var securityFlagsJSON []byte
|
||
|
|
var complianceTags []string
|
||
|
|
|
||
|
|
err := rows.Scan(
|
||
|
|
&event.EventID, &event.EventName, &event.EventCategory, &eventSubCategory,
|
||
|
|
&event.Timestamp, &event.TimestampMs,
|
||
|
|
&event.RequestID, &traceID, &spanID,
|
||
|
|
&idempotencyKey,
|
||
|
|
&event.OperatorID, &event.OperatorType, &operatorRole,
|
||
|
|
&event.TenantID, &event.TenantType,
|
||
|
|
&event.ObjectType, &event.ObjectID,
|
||
|
|
&event.Action, &event.ActionDetail,
|
||
|
|
&event.CredentialType, &event.CredentialID, &event.CredentialFingerprint,
|
||
|
|
&event.SourceType, &event.SourceIP, &event.SourceRegion, &event.UserAgent,
|
||
|
|
&event.TargetType, &event.TargetEndpoint, &event.TargetDirect,
|
||
|
|
&event.ResultCode, &event.ResultMessage, &event.Success,
|
||
|
|
&beforeData, &afterData,
|
||
|
|
&securityFlagsJSON, &event.RiskScore,
|
||
|
|
&complianceTags, &event.InvariantRule,
|
||
|
|
&extensions,
|
||
|
|
&event.Version, &event.CreatedAt,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
event.EventSubCategory = eventSubCategory
|
||
|
|
event.TraceID = traceID
|
||
|
|
event.SpanID = spanID
|
||
|
|
event.IdempotencyKey = idempotencyKey
|
||
|
|
event.OperatorRole = operatorRole
|
||
|
|
event.ComplianceTags = complianceTags
|
||
|
|
|
||
|
|
// 反序列化JSON字段
|
||
|
|
if beforeData != nil {
|
||
|
|
json.Unmarshal(beforeData, &event.BeforeState)
|
||
|
|
}
|
||
|
|
if afterData != nil {
|
||
|
|
json.Unmarshal(afterData, &event.AfterState)
|
||
|
|
}
|
||
|
|
if securityFlagsJSON != nil {
|
||
|
|
json.Unmarshal(securityFlagsJSON, &event.SecurityFlags)
|
||
|
|
}
|
||
|
|
if extensions != nil {
|
||
|
|
json.Unmarshal(extensions, &event.Extensions)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &event, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// scanAuditEventRow 扫描单行审计事件
|
||
|
|
func (r *PostgresAuditRepository) scanAuditEventRow(row pgx.Row) (*model.AuditEvent, error) {
|
||
|
|
var event model.AuditEvent
|
||
|
|
var eventSubCategory, traceID, spanID, idempotencyKey, operatorRole string
|
||
|
|
var beforeData, afterData, extensions []byte
|
||
|
|
var securityFlagsJSON []byte
|
||
|
|
var complianceTags []string
|
||
|
|
|
||
|
|
err := row.Scan(
|
||
|
|
&event.EventID, &event.EventName, &event.EventCategory, &eventSubCategory,
|
||
|
|
&event.Timestamp, &event.TimestampMs,
|
||
|
|
&event.RequestID, &traceID, &spanID,
|
||
|
|
&idempotencyKey,
|
||
|
|
&event.OperatorID, &event.OperatorType, &operatorRole,
|
||
|
|
&event.TenantID, &event.TenantType,
|
||
|
|
&event.ObjectType, &event.ObjectID,
|
||
|
|
&event.Action, &event.ActionDetail,
|
||
|
|
&event.CredentialType, &event.CredentialID, &event.CredentialFingerprint,
|
||
|
|
&event.SourceType, &event.SourceIP, &event.SourceRegion, &event.UserAgent,
|
||
|
|
&event.TargetType, &event.TargetEndpoint, &event.TargetDirect,
|
||
|
|
&event.ResultCode, &event.ResultMessage, &event.Success,
|
||
|
|
&beforeData, &afterData,
|
||
|
|
&securityFlagsJSON, &event.RiskScore,
|
||
|
|
&complianceTags, &event.InvariantRule,
|
||
|
|
&extensions,
|
||
|
|
&event.Version, &event.CreatedAt,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
event.EventSubCategory = eventSubCategory
|
||
|
|
event.TraceID = traceID
|
||
|
|
event.SpanID = spanID
|
||
|
|
event.IdempotencyKey = idempotencyKey
|
||
|
|
event.OperatorRole = operatorRole
|
||
|
|
event.ComplianceTags = complianceTags
|
||
|
|
|
||
|
|
// 反序列化JSON字段
|
||
|
|
if beforeData != nil {
|
||
|
|
json.Unmarshal(beforeData, &event.BeforeState)
|
||
|
|
}
|
||
|
|
if afterData != nil {
|
||
|
|
json.Unmarshal(afterData, &event.AfterState)
|
||
|
|
}
|
||
|
|
if securityFlagsJSON != nil {
|
||
|
|
json.Unmarshal(securityFlagsJSON, &event.SecurityFlags)
|
||
|
|
}
|
||
|
|
if extensions != nil {
|
||
|
|
json.Unmarshal(extensions, &event.Extensions)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &event, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ErrDuplicateIdempotencyKey is returned by Emit when the insert is rejected
// because an event with the same idempotency key is already stored.
var ErrDuplicateIdempotencyKey = errors.New("duplicate idempotency key")
|