182 lines
5.6 KiB
Go
182 lines
5.6 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/company/ai-ops/internal/database"
|
|
)
|
|
|
|
// AuditService provides methods to record and list audit log entries and to
// record rollbacks of earlier changes. It carries no fields; all persistence
// goes through the package-level database.Pool.
type AuditService struct{}

// NewAuditService returns a new *AuditService. Because the service has no
// fields, no further initialization is required.
func NewAuditService() *AuditService {
	svc := new(AuditService)
	return svc
}
|
|
|
|
// AuditLog is a single audit log record. Its fields correspond one-to-one to
// the columns of the ai_ops_audits table that Record inserts and that List
// and Rollback read back.
type AuditLog struct {
	// ID identifies the record (id column). Rollback looks records up by it.
	ID string `json:"id"`

	// TenantID is stored in the tenant_id column; Rollback copies it from the
	// original record.
	TenantID string `json:"tenant_id"`

	// ObjectType and ObjectID identify the audited object; List can filter
	// on either of them.
	ObjectType string `json:"object_type"`

	ObjectID string `json:"object_id"`

	// Action names what was done; Rollback writes "rollback".
	Action string `json:"action"`

	// BeforeState and AfterState hold the object's state before and after the
	// action (before_state / after_state columns). Rollback requires a
	// non-nil BeforeState and swaps the two. Omitted from JSON when empty.
	BeforeState map[string]any `json:"before_state,omitempty"`

	AfterState map[string]any `json:"after_state,omitempty"`

	// RequestID presumably identifies the request that produced the record
	// (request_id column) — confirm with callers.
	RequestID string `json:"request_id"`

	// ResultCode records the outcome; Rollback writes "SUCCESS".
	ResultCode string `json:"result_code"`

	// SourceIP presumably holds the client address (source_ip column).
	SourceIP string `json:"source_ip"`

	// ActorID presumably identifies who performed the action (actor_id column).
	ActorID string `json:"actor_id"`

	// RiskLevel is a free-form risk label; Rollback writes "high".
	RiskLevel string `json:"risk_level"`

	// ParentAuditID links to another audit record; Rollback sets it to the
	// record being rolled back. Record stores nil as SQL NULL.
	ParentAuditID *string `json:"parent_audit_id,omitempty"`

	// CreatedAt is the creation timestamp. Record does not insert this field;
	// the database assigns NOW() to created_at.
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// Record inserts log as a new row in the ai_ops_audits table.
//
// If log.ID is empty, a fresh ID from generateAuditID is assigned to log.ID
// before the insert (the id column is always bound explicitly, so no database
// default could fill it). created_at is always set by the database via NOW();
// any value already in log.CreatedAt is not written, and on success
// log.CreatedAt is overwritten with the timestamp actually stored. A nil
// log.ParentAuditID is stored as SQL NULL.
//
// Record returns an error if log is nil or the insert fails.
func (s *AuditService) Record(ctx context.Context, log *AuditLog) error {
	if log == nil {
		return fmt.Errorf("insert audit: nil audit log")
	}
	if log.ID == "" {
		log.ID = generateAuditID()
	}

	// Bind an untyped nil (not a typed nil *string) for a missing parent so
	// the driver writes NULL.
	var parentID any
	if log.ParentAuditID != nil {
		parentID = *log.ParentAuditID
	}

	// RETURNING lets the caller see the database-assigned creation time
	// instead of a zero time.Time.
	err := database.Pool.QueryRow(ctx, `
		INSERT INTO ai_ops_audits (id, tenant_id, object_type, object_id, action,
			before_state, after_state, request_id, result_code, source_ip, actor_id,
			risk_level, parent_audit_id, created_at)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, NOW())
		RETURNING created_at
	`, log.ID, log.TenantID, log.ObjectType, log.ObjectID, log.Action,
		log.BeforeState, log.AfterState, log.RequestID, log.ResultCode,
		log.SourceIP, log.ActorID, log.RiskLevel, parentID).Scan(&log.CreatedAt)
	if err != nil {
		return fmt.Errorf("insert audit: %w", err)
	}
	return nil
}
|
|
|
|
// List returns one page of audit records, newest first, together with the
// total number of records matching the filters.
//
// objectType and objectID are optional equality filters; an empty string
// disables that filter. page is 1-based, and values below 1 are treated as 1.
// A pageSize outside 1..100 falls back to 20. Rows sharing the same created_at
// are additionally ordered by id, so page boundaries are deterministic. When
// nothing matches, the returned slice is empty but non-nil, so it encodes as
// [] rather than null in JSON. On error, the slice is nil and the total is 0.
//
// NOTE(review): there is no tenant_id filter, so records of every tenant are
// returned; confirm that callers enforce tenant scoping before calling List.
func (s *AuditService) List(ctx context.Context, objectType, objectID string, page, pageSize int) ([]AuditLog, int, error) {
	if page < 1 {
		page = 1
	}
	if pageSize < 1 || pageSize > 100 {
		pageSize = 20
	}

	// (page-1)*pageSize can overflow int for absurd page numbers taken from
	// untrusted input. Such pages lie past any real data, so clamp the offset
	// instead of letting it wrap to a negative value, which the database
	// would reject.
	const maxInt = int(^uint(0) >> 1)
	offset := maxInt
	if page-1 <= maxInt/pageSize {
		offset = (page - 1) * pageSize
	}

	where, args := auditListFilter(objectType, objectID)

	var total int
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM ai_ops_audits %s", where)
	if err := database.Pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("count audits: %w", err)
	}

	// Without the id tiebreaker, rows with equal created_at may come back in
	// any order, so consecutive pages could repeat or skip records.
	dataQuery := fmt.Sprintf(`
		SELECT id, tenant_id, object_type, object_id, action,
			before_state, after_state, request_id, result_code, source_ip, actor_id,
			risk_level, parent_audit_id, created_at
		FROM ai_ops_audits %s
		ORDER BY created_at DESC, id DESC
		LIMIT $%d OFFSET $%d
	`, where, len(args)+1, len(args)+2)
	args = append(args, pageSize, offset)

	rows, err := database.Pool.Query(ctx, dataQuery, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("query audits: %w", err)
	}
	defer rows.Close()

	logs := make([]AuditLog, 0, pageSize)
	for rows.Next() {
		var l AuditLog
		if err := rows.Scan(
			&l.ID, &l.TenantID, &l.ObjectType, &l.ObjectID, &l.Action,
			&l.BeforeState, &l.AfterState, &l.RequestID, &l.ResultCode,
			&l.SourceIP, &l.ActorID, &l.RiskLevel, &l.ParentAuditID, &l.CreatedAt,
		); err != nil {
			return nil, 0, fmt.Errorf("scan audit: %w", err)
		}
		logs = append(logs, l)
	}
	// Don't hand back a partially read page alongside an error.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("iterate audits: %w", err)
	}
	return logs, total, nil
}

// auditListFilter builds List's WHERE clause and its positional arguments.
// Empty filter values are skipped. With no filters, it returns "" and no
// arguments. Column names are fixed literals; only the values are
// parameterized.
func auditListFilter(objectType, objectID string) (string, []any) {
	filters := [...]struct{ column, value string }{
		{"object_type", objectType},
		{"object_id", objectID},
	}

	where := ""
	var args []any
	for _, f := range filters {
		if f.value == "" {
			continue
		}
		args = append(args, f.value)
		keyword := " AND"
		if len(args) == 1 {
			keyword = "WHERE"
		}
		where += fmt.Sprintf("%s %s = $%d", keyword, f.column, len(args))
	}
	return where, args
}
|
|
|
|
// Rollback records a rollback of the change captured by audit record auditID
// and returns the newly recorded entry.
//
// It loads the original record and writes a new audit entry via Record. The
// new entry has:
//   - Action "rollback", with the original's BeforeState and AfterState swapped.
//   - ResultCode "SUCCESS" and RiskLevel "high".
//   - ParentAuditID pointing at the original record.
//   - The tenant, object, request ID, source IP and actor copied from the
//     original.
//
// NOTE(review): nothing here restores the target object's configuration; only
// the audit entry is written. Confirm whether a caller applies
// original.BeforeState. The lookup is also not scoped by tenant, and the
// rollback reuses the original's RequestID and ActorID rather than those of
// whoever triggered it — confirm both are intended.
//
// Rollback returns an error in these cases:
//   - The original record cannot be loaded. The message starts with "audit
//     record not found" and the underlying query error is wrapped.
//   - The original has no before_state.
//   - Recording the rollback entry fails.
func (s *AuditService) Rollback(ctx context.Context, auditID string) (*AuditLog, error) {
	// Load the original audit record.
	var original AuditLog
	err := database.Pool.QueryRow(ctx, `
		SELECT id, tenant_id, object_type, object_id, action,
			before_state, after_state, request_id, result_code, source_ip, actor_id,
			risk_level, parent_audit_id, created_at
		FROM ai_ops_audits WHERE id = $1
	`, auditID).Scan(
		&original.ID, &original.TenantID, &original.ObjectType, &original.ObjectID, &original.Action,
		&original.BeforeState, &original.AfterState, &original.RequestID, &original.ResultCode,
		&original.SourceIP, &original.ActorID, &original.RiskLevel, &original.ParentAuditID, &original.CreatedAt,
	)
	if err != nil {
		// Wrap the cause so callers can tell a missing row apart from a
		// connection or decoding failure. The old prefix is kept in case
		// callers match on it.
		return nil, fmt.Errorf("audit record not found: %w", err)
	}

	// Checking that the target resource still exists is not implemented; it
	// is assumed to exist.
	if original.BeforeState == nil {
		return nil, fmt.Errorf("no before_state available for rollback")
	}

	// Build the rollback entry: it reverses the original change.
	rollbackLog := &AuditLog{
		ID:            generateAuditID(),
		TenantID:      original.TenantID,
		ObjectType:    original.ObjectType,
		ObjectID:      original.ObjectID,
		Action:        "rollback",
		BeforeState:   original.AfterState,
		AfterState:    original.BeforeState,
		RequestID:     original.RequestID,
		ResultCode:    "SUCCESS",
		SourceIP:      original.SourceIP,
		ActorID:       original.ActorID,
		RiskLevel:     "high",
		ParentAuditID: &original.ID,
	}

	if err := s.Record(ctx, rollbackLog); err != nil {
		return nil, fmt.Errorf("record rollback audit: %w", err)
	}

	return rollbackLog, nil
}
|
|
|
|
func generateAuditID() string {
|
|
b := make([]byte, 16)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return fmt.Sprintf("00000000-0000-4000-8000-%012d", time.Now().UnixNano()%1_000_000_000_000)
|
|
}
|
|
b[6] = (b[6] & 0x0f) | 0x40
|
|
b[8] = (b[8] & 0x3f) | 0x80
|
|
return fmt.Sprintf("%s-%s-%s-%s-%s", hex.EncodeToString(b[0:4]), hex.EncodeToString(b[4:6]), hex.EncodeToString(b[6:8]), hex.EncodeToString(b[8:10]), hex.EncodeToString(b[10:16]))
|
|
}
|