Files
ai-ops/internal/service/audit_service.go
2026-05-12 17:48:22 +08:00

182 lines
5.6 KiB
Go

package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"time"
"github.com/company/ai-ops/internal/database"
)
// AuditService writes and reads audit records in the ai_ops_audits table.
// It carries no fields of its own; its methods go through database.Pool.
type AuditService struct{}

// NewAuditService constructs an AuditService. Because the type has no
// fields, the returned value needs no further initialization.
func NewAuditService() *AuditService {
	svc := new(AuditService)
	return svc
}
// AuditLog is one audit record as stored in the ai_ops_audits table: an
// action taken on an object, with optional state captured before and after.
// Fields map one-to-one onto the table's columns of the same snake_case name.
type AuditLog struct {
// ID identifies the record; Rollback looks records up by it and fills new
// entries using generateAuditID.
ID string `json:"id"`
// TenantID is the owning tenant. It is stored, but neither List nor Rollback
// filters on it.
TenantID string `json:"tenant_id"`
// ObjectType and ObjectID identify the object the action applied to; List
// can filter on either.
ObjectType string `json:"object_type"`
ObjectID string `json:"object_id"`
// Action names the operation; entries written by Rollback use "rollback".
Action string `json:"action"`
// BeforeState and AfterState hold state captured before/after the action and
// are omitted from JSON when nil. Rollback requires a non-nil BeforeState and
// swaps the two in the entry it writes.
BeforeState map[string]any `json:"before_state,omitempty"`
AfterState map[string]any `json:"after_state,omitempty"`
// RequestID presumably identifies the originating request (TODO confirm);
// Rollback copies it from the record being rolled back.
RequestID string `json:"request_id"`
// ResultCode is the outcome of the action; Rollback entries use "SUCCESS".
ResultCode string `json:"result_code"`
// SourceIP and ActorID presumably describe where/who the action came from
// (TODO confirm); Rollback copies both from the record being rolled back.
SourceIP string `json:"source_ip"`
ActorID string `json:"actor_id"`
// RiskLevel is a risk classification; Rollback entries are always "high".
RiskLevel string `json:"risk_level"`
// ParentAuditID links to an earlier related record (nil is stored as NULL).
// Rollback sets it to the ID of the record being rolled back.
ParentAuditID *string `json:"parent_audit_id,omitempty"`
// CreatedAt is assigned by the database (NOW()) on insert; Record does not
// use any value supplied here when inserting.
CreatedAt time.Time `json:"created_at"`
}
// Record inserts log as a new row in ai_ops_audits.
//
// created_at is always assigned by the database (NOW()); any value already
// in log.CreatedAt is not used for the insert. On success, log.CreatedAt is
// overwritten with the timestamp actually stored, so the caller's struct
// reflects the persisted row. A nil ParentAuditID is stored as SQL NULL.
//
// It returns an error if log is nil or if the insert fails.
func (s *AuditService) Record(ctx context.Context, log *AuditLog) error {
	if log == nil {
		return fmt.Errorf("insert audit: nil audit log")
	}

	// Pass a plain nil (not a typed nil *string) when there is no parent so
	// the column is written as NULL.
	var parentID any
	if log.ParentAuditID != nil {
		parentID = *log.ParentAuditID
	}

	// RETURNING hands back the DB-assigned timestamp; without it the caller's
	// CreatedAt would stay at the zero time.
	var createdAt time.Time
	err := database.Pool.QueryRow(ctx, `
INSERT INTO ai_ops_audits (id, tenant_id, object_type, object_id, action,
before_state, after_state, request_id, result_code, source_ip, actor_id,
risk_level, parent_audit_id, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, NOW())
RETURNING created_at
`, log.ID, log.TenantID, log.ObjectType, log.ObjectID, log.Action,
		log.BeforeState, log.AfterState, log.RequestID, log.ResultCode,
		log.SourceIP, log.ActorID, log.RiskLevel, parentID).Scan(&createdAt)
	if err != nil {
		return fmt.Errorf("insert audit: %w", err)
	}
	log.CreatedAt = createdAt
	return nil
}
// List returns one page of audit records, newest first (ORDER BY created_at
// DESC), optionally filtered by objectType and/or objectID. An empty string
// disables the corresponding filter.
//
// page is 1-based, and values below 1 are treated as 1. A pageSize outside
// [1, 100] falls back to 20; it is not clamped to 100. The int result is the
// number of records matching the filters, ignoring pagination.
//
// On any error the slice is nil and the count is 0; a failure while iterating
// rows is reported instead of returning a partial page. When nothing matches,
// the slice is nil.
//
// NOTE(review): results are not filtered by tenant_id; confirm that callers
// enforce tenant isolation.
func (s *AuditService) List(ctx context.Context, objectType, objectID string, page, pageSize int) ([]AuditLog, int, error) {
	if page < 1 {
		page = 1
	}
	if pageSize < 1 || pageSize > 100 {
		pageSize = 20
	}

	// Column names are fixed identifiers. Caller-supplied values only travel
	// as bind arguments ($N) and are never spliced into the SQL text.
	filters := []struct{ column, value string }{
		{"object_type", objectType},
		{"object_id", objectID},
	}
	where := ""
	var args []any
	for _, f := range filters {
		if f.value == "" {
			continue
		}
		args = append(args, f.value)
		if where == "" {
			where = fmt.Sprintf("WHERE %s = $%d", f.column, len(args))
		} else {
			where += fmt.Sprintf(" AND %s = $%d", f.column, len(args))
		}
	}

	var total int
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM ai_ops_audits %s", where)
	if err := database.Pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("count audits: %w", err)
	}

	limitIdx := len(args) + 1
	dataQuery := fmt.Sprintf(`
SELECT id, tenant_id, object_type, object_id, action,
before_state, after_state, request_id, result_code, source_ip, actor_id,
risk_level, parent_audit_id, created_at
FROM ai_ops_audits %s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, where, limitIdx, limitIdx+1)

	// Build a fresh slice instead of append(args, ...), which could write
	// into args' spare capacity and alias it.
	queryArgs := make([]any, 0, len(args)+2)
	queryArgs = append(queryArgs, args...)
	queryArgs = append(queryArgs, pageSize, (page-1)*pageSize)

	rows, err := database.Pool.Query(ctx, dataQuery, queryArgs...)
	if err != nil {
		return nil, 0, fmt.Errorf("query audits: %w", err)
	}
	defer rows.Close()

	var logs []AuditLog
	for rows.Next() {
		var l AuditLog
		if err := rows.Scan(
			&l.ID, &l.TenantID, &l.ObjectType, &l.ObjectID, &l.Action,
			&l.BeforeState, &l.AfterState, &l.RequestID, &l.ResultCode,
			&l.SourceIP, &l.ActorID, &l.RiskLevel, &l.ParentAuditID, &l.CreatedAt,
		); err != nil {
			return nil, 0, fmt.Errorf("scan audit: %w", err)
		}
		logs = append(logs, l)
	}
	// Next returns false both at end of data and on failure; only Err tells
	// them apart. Do not hand back a partial page together with an error.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("iterate audits: %w", err)
	}
	return logs, total, nil
}
// Rollback loads the audit record auditID and writes a new "rollback" audit
// entry that reverses it, then returns that new entry. The entry is set up
// as follows:
//   - BeforeState is the original's AfterState.
//   - AfterState is the original's BeforeState.
//   - RiskLevel is "high" and ResultCode is "SUCCESS".
//   - ParentAuditID points at the original record.
//
// NOTE(review): this method only writes the audit entry. Nothing here
// applies BeforeState to the target object, and the target is assumed to
// still exist. The entry also copies RequestID, SourceIP and ActorID from
// the original record rather than identifying who requested the rollback.
// auditID is not checked against any tenant. Confirm these are intended.
//
// It returns an error in three cases:
//   - The record cannot be loaded. The message starts with "audit record not
//     found" and wraps the driver error.
//   - The record has no BeforeState.
//   - Recording the rollback entry fails.
func (s *AuditService) Rollback(ctx context.Context, auditID string) (*AuditLog, error) {
	var original AuditLog
	err := database.Pool.QueryRow(ctx, `
SELECT id, tenant_id, object_type, object_id, action,
before_state, after_state, request_id, result_code, source_ip, actor_id,
risk_level, parent_audit_id, created_at
FROM ai_ops_audits WHERE id = $1
`, auditID).Scan(
		&original.ID, &original.TenantID, &original.ObjectType, &original.ObjectID, &original.Action,
		&original.BeforeState, &original.AfterState, &original.RequestID, &original.ResultCode,
		&original.SourceIP, &original.ActorID, &original.RiskLevel, &original.ParentAuditID, &original.CreatedAt,
	)
	if err != nil {
		// Keep the established message prefix for callers that match on it.
		// Wrap the cause so connection or scan failures stay diagnosable and
		// errors.Is/As can still reach the driver's error.
		return nil, fmt.Errorf("audit record not found: %w", err)
	}

	// Without a prior state there is nothing to restore.
	if original.BeforeState == nil {
		return nil, fmt.Errorf("no before_state available for rollback")
	}

	rollbackLog := &AuditLog{
		ID:            generateAuditID(),
		TenantID:      original.TenantID,
		ObjectType:    original.ObjectType,
		ObjectID:      original.ObjectID,
		Action:        "rollback",
		BeforeState:   original.AfterState,
		AfterState:    original.BeforeState,
		RequestID:     original.RequestID,
		ResultCode:    "SUCCESS",
		SourceIP:      original.SourceIP,
		ActorID:       original.ActorID,
		RiskLevel:     "high",
		ParentAuditID: &original.ID,
		// The stored created_at is assigned by the database on insert. This
		// local timestamp keeps the returned entry from carrying a zero time.
		CreatedAt: time.Now(),
	}
	if err := s.Record(ctx, rollbackLog); err != nil {
		return nil, fmt.Errorf("record rollback audit: %w", err)
	}
	return rollbackLog, nil
}
// generateAuditID returns a random version-4 UUID string in the canonical
// lowercase 8-4-4-4-12 hex layout.
//
// If the system random source fails, it falls back to a fixed-prefix,
// time-derived identifier. That fallback has the same shape but is not
// random and can collide.
func generateAuditID() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return fmt.Sprintf("00000000-0000-4000-8000-%012d", time.Now().UnixNano()%1_000_000_000_000)
	}
	// Stamp the version (0100) and variant (10xx) bits.
	raw[6] = raw[6]&0x0f | 0x40
	raw[8] = raw[8]&0x3f | 0x80

	// Encode each group straight into one fixed buffer, with dashes between.
	var out [36]byte
	hex.Encode(out[0:8], raw[0:4])
	out[8] = '-'
	hex.Encode(out[9:13], raw[4:6])
	out[13] = '-'
	hex.Encode(out[14:18], raw[6:8])
	out[18] = '-'
	hex.Encode(out[19:23], raw[8:10])
	out[23] = '-'
	hex.Encode(out[24:36], raw[10:16])
	return string(out[:])
}