Files
ai-ops/internal/service/alert_engine.go
2026-05-12 17:48:22 +08:00

278 lines
7.0 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"log/slog"
"strconv"
"sync"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/domain/repository"
)
// TriggerState tracks, for a single rule, when its current threshold breach
// began and when an alert was last fired for it.
type TriggerState struct {
	// FirstTriggeredAt is the time the metric was first seen over the
	// threshold in the current breach.
	FirstTriggeredAt time.Time
	// LastTriggeredAt is the time of the most recent alert for the rule; the
	// zero value means no alert has fired during this breach.
	LastTriggeredAt time.Time
}
// AlertEngine is the alert rule evaluation engine. It periodically compares
// each rule with the latest value of its metric, creates alert events, hands
// them to the notification service and escalates long-lived P2 alerts.
type AlertEngine struct {
	// Collaborators used to load rules/events, read metrics and send
	// notifications.
	alertRepo  repository.AlertRepository
	metricRepo repository.MetricRepository
	notifySvc  *NotificationService

	// interval is the period between rule evaluation cycles.
	interval time.Duration
	// stopCh is closed by Stop to make the evaluation loop exit.
	stopCh chan struct{}

	// triggerStates holds the per-rule breach state used for the
	// "must breach for N minutes" check, keyed by rule ID.
	triggerStates map[string]*TriggerState
	// statesMu guards triggerStates.
	statesMu sync.RWMutex

	// suppressWindow is the suppression period: a rule does not fire again
	// until this long after its previous alert (NewAlertEngine uses 5 minutes).
	suppressWindow time.Duration

	// escalationInterval is how old a P2 alert must be before it is escalated
	// to P1 (NewAlertEngine uses 2 hours).
	escalationInterval time.Duration

	// aggregationWindow and aggregationThreshold are passed to the repository
	// when an event is created. Per the original note, more than 20 alerts for
	// the same resource within 1 minute produce an aggregated alert; the
	// aggregation logic itself lives in the repository.
	aggregationWindow    time.Duration
	aggregationThreshold int
}
// NewAlertEngine creates a rule evaluation engine wired to the given
// repositories and notification service.
//
// Defaults: rules are evaluated every 30 seconds, a rule is suppressed for
// 5 minutes after firing, escalation applies after 2 hours, and aggregation
// uses a 1 minute window with a threshold of 20. The engine does nothing until
// Start is called.
func NewAlertEngine(ar repository.AlertRepository, mr repository.MetricRepository, ns *NotificationService) *AlertEngine {
	engine := new(AlertEngine)

	// Collaborators.
	engine.alertRepo = ar
	engine.metricRepo = mr
	engine.notifySvc = ns

	// Loop control.
	engine.interval = 30 * time.Second
	engine.stopCh = make(chan struct{})

	// Per-rule state and timing policy.
	engine.triggerStates = make(map[string]*TriggerState)
	engine.suppressWindow = 5 * time.Minute
	engine.escalationInterval = 2 * time.Hour
	engine.aggregationWindow = 1 * time.Minute
	engine.aggregationThreshold = 20

	return engine
}
// Start logs the engine's interval and suppression window, then launches the
// evaluation loop in a new goroutine and returns without waiting for it.
func (e *AlertEngine) Start() {
	slog.Info(
		"alert_engine_started",
		slog.Duration("interval", e.interval),
		slog.Duration("suppress_window", e.suppressWindow),
	)
	go e.loop()
}
// Stop tells the evaluation loop to exit by closing the stop channel.
//
// Calling Stop again after it has returned is a no-op: closing an
// already-closed channel panics, so the channel is probed first. Stop does not
// wait for the loop goroutine to finish, and the probe-then-close sequence is
// not atomic, so Stop must not be called from several goroutines at once.
func (e *AlertEngine) Stop() {
	select {
	case <-e.stopCh:
		// Nothing is ever sent on stopCh, so a successful receive means an
		// earlier Stop already closed it.
		return
	default:
	}
	close(e.stopCh)
	slog.Info("alert_engine_stopped")
}
// loop is the engine's background goroutine. It evaluates all rules once
// immediately, then again on every interval tick, and runs the escalation pass
// every 5 minutes, until stopCh is closed.
//
// Every cycle receives a context that is cancelled as soon as stopCh is closed,
// so a repository call that is in flight when the engine is stopped can be
// interrupted instead of holding the goroutine until it completes.
func (e *AlertEngine) loop() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// The loop goroutine is busy while a cycle runs and cannot observe stopCh
	// itself, so a small watcher translates the stop signal into cancellation.
	// The watcher exits when loop returns (deferred cancel above).
	go func() {
		select {
		case <-e.stopCh:
			cancel()
		case <-ctx.Done():
		}
	}()

	// stopped reports whether stopCh has been closed, without blocking.
	stopped := func() bool {
		select {
		case <-e.stopCh:
			return true
		default:
			return false
		}
	}

	ticker := time.NewTicker(e.interval)
	defer ticker.Stop()
	escalationTicker := time.NewTicker(5 * time.Minute)
	defer escalationTicker.Stop()

	if stopped() {
		return
	}
	e.evaluate(ctx)

	for {
		select {
		case <-ticker.C:
			// select picks randomly among ready cases, so a tick may win
			// against a concurrent stop; never begin a new cycle after stop.
			if stopped() {
				return
			}
			e.evaluate(ctx)
		case <-escalationTicker.C:
			if stopped() {
				return
			}
			e.escalate(ctx)
		case <-e.stopCh:
			return
		}
	}
}
// evaluate runs one evaluation cycle: it loads the rules from the alert
// repository and evaluates each in turn. A failure to list the rules aborts the
// cycle; a failure on one rule is logged and the remaining rules still run.
func (e *AlertEngine) evaluate(ctx context.Context) {
	rules, listErr := e.alertRepo.ListRules(ctx)
	if listErr != nil {
		slog.Error("list_rules_failed", "error", listErr)
		return
	}

	for _, current := range rules {
		current := current // each call receives the address of its own copy
		evalErr := e.evaluateRule(ctx, &current)
		if evalErr == nil {
			continue
		}
		slog.Error("evaluate_rule_failed", "rule_id", current.ID, "error", evalErr)
	}
}
// evaluateRule checks a single rule against the latest value of its metric.
//
// When the metric is within bounds, any stored breach state for the rule is
// cleared. When it breaches, an alert fires only if the breach has lasted at
// least rule.DurationMin minutes and the rule is outside its suppression
// window; firing means storing an event through the alert repository and, if
// the rule has channels and a notification service is configured, enqueueing a
// notification.
//
// It returns an error if the metric cannot be fetched, the threshold is not a
// valid float, or the event cannot be stored. It returns nil otherwise,
// including when nothing fires.
func (e *AlertEngine) evaluateRule(ctx context.Context, rule *model.AlertRule) error {
	point, err := e.metricRepo.GetLatest(ctx, rule.MetricSource, rule.MetricName)
	if err != nil {
		return fmt.Errorf("get metric: %w", err)
	}
	threshold, err := strconv.ParseFloat(rule.ThresholdValue, 64)
	if err != nil {
		return fmt.Errorf("parse threshold: %w", err)
	}
	triggered := e.compare(point.Value, threshold, rule.ThresholdType)
	now := time.Now()

	e.statesMu.Lock()
	state, exists := e.triggerStates[rule.ID]
	if !triggered {
		// Metric is back within bounds: forget the breach.
		if exists {
			delete(e.triggerStates, rule.ID)
		}
		e.statesMu.Unlock()
		if exists {
			slog.Info("alert_cleared", "rule_id", rule.ID, "metric", rule.MetricName)
		}
		return nil
	}
	// Metric is over the threshold: start tracking the breach if it is new.
	if !exists {
		state = &TriggerState{FirstTriggeredAt: now}
		e.triggerStates[rule.ID] = state
	}
	// Copy the timestamps while the lock is held so the checks below never
	// read the shared state unsynchronized.
	firstTriggeredAt, lastTriggeredAt := state.FirstTriggeredAt, state.LastTriggeredAt
	e.statesMu.Unlock()

	// Duration check: the breach must have lasted the required number of
	// minutes. Measured against the same "now" used for every other
	// timestamp in this call.
	duration := now.Sub(firstTriggeredAt)
	requiredDuration := time.Duration(rule.DurationMin) * time.Minute
	if duration < requiredDuration {
		slog.Debug("alert_breaching_not_yet_triggered",
			"rule_id", rule.ID,
			"duration", duration,
			"required", requiredDuration,
		)
		return nil
	}

	// Suppression check: do not fire again within the suppression window of
	// the previous alert.
	if !lastTriggeredAt.IsZero() && now.Sub(lastTriggeredAt) < e.suppressWindow {
		return nil
	}

	// Build and store the alert event.
	event := &model.AlertEvent{
		ID:              generateID(),
		RuleID:          rule.ID,
		Level:           rule.Level,
		ResourceType:    rule.MetricSource,
		ResourceID:      rule.MetricName,
		CurrentValue:    fmt.Sprintf("%.4f", point.Value),
		ThresholdValue:  rule.ThresholdValue,
		Status:          "triggered",
		IsAggregated:    false,
		AggregatedCount: 1,
	}
	notifyEvent, err := e.alertRepo.CreateEventWithAggregation(ctx, event, e.aggregationWindow, e.aggregationThreshold)
	if err != nil {
		return fmt.Errorf("create event: %w", err)
	}

	// Start the suppression window only once the event is actually stored.
	// Recording it before the write would silence the rule for the whole
	// window after a failed write, losing the alert instead of retrying on the
	// next cycle.
	e.statesMu.Lock()
	state.LastTriggeredAt = now
	e.statesMu.Unlock()

	// The repository may hand back a different event to notify about; fall
	// back to the one just created when it returns nil.
	if notifyEvent == nil {
		notifyEvent = event
	}

	// Hand the notification to the notification service's queue.
	if e.notifySvc != nil && len(rule.ChannelIDs) > 0 {
		e.notifySvc.Enqueue(notifyEvent, rule.ChannelIDs)
	}

	slog.Info("alert_triggered",
		"rule_id", rule.ID,
		"level", rule.Level,
		"metric", rule.MetricName,
		"value", point.Value,
		"threshold", threshold,
		"duration_min", rule.DurationMin,
	)
	return nil
}
// escalate promotes long-lived P2 alerts to P1. It lists events in the
// "triggered" status, and for every P2 event whose StartedAt is at least
// escalationInterval old it asks the repository to escalate the event and then
// enqueues a P1 notification on the owning rule's channels.
//
// Failures are logged and never abort the pass: an event that cannot be
// escalated is skipped, and a rule that cannot be loaded only costs that
// event's notification.
func (e *AlertEngine) escalate(ctx context.Context) {
	// NOTE(review): only one call is made with the arguments (1, 100),
	// presumably page 1 with a page size of 100 — confirm against the
	// repository. If so, events beyond the first 100 are never considered.
	events, _, err := e.alertRepo.ListEvents(ctx, "triggered", 1, 100)
	if err != nil {
		slog.Error("list_open_events_failed", "error", err)
		return
	}

	now := time.Now()
	for _, event := range events {
		if event.Level != "P2" {
			continue
		}
		age := now.Sub(event.StartedAt)
		if age < e.escalationInterval {
			continue
		}

		// Escalate to P1.
		if err := e.alertRepo.EscalateEvent(ctx, event.ID, "P1"); err != nil {
			slog.Error("escalate_event_failed", "event_id", event.ID, "error", err)
			continue
		}

		// Notification payload: the same event at its new level.
		upgraded := &model.AlertEvent{
			ID:             event.ID,
			RuleID:         event.RuleID,
			Level:          "P1",
			ResourceType:   event.ResourceType,
			ResourceID:     event.ResourceID,
			CurrentValue:   event.CurrentValue,
			ThresholdValue: event.ThresholdValue,
			Status:         "triggered",
		}

		rule, err := e.alertRepo.GetRuleByID(ctx, event.RuleID)
		switch {
		case err != nil:
			// The escalation itself succeeded, so keep going, but record why
			// nobody is being notified rather than dropping the error.
			slog.Error("get_rule_for_escalation_failed",
				"event_id", event.ID,
				"rule_id", event.RuleID,
				"error", err,
			)
		case e.notifySvc != nil && len(rule.ChannelIDs) > 0:
			e.notifySvc.Enqueue(upgraded, rule.ChannelIDs)
		}

		slog.Info("alert_escalated",
			"event_id", event.ID,
			"rule_id", event.RuleID,
			"from_level", "P2",
			"to_level", "P1",
			"duration", age,
		)
	}
}
// compare reports whether value satisfies the relation op against threshold.
// Supported operators are ">", "<", "=", ">=" and "<="; any other operator
// yields false. "=" is an exact floating-point equality test.
func (e *AlertEngine) compare(value, threshold float64, op string) bool {
	above := value > threshold
	below := value < threshold
	equal := value == threshold

	switch op {
	case ">":
		return above
	case ">=":
		return above || equal
	case "<":
		return below
	case "<=":
		return below || equal
	case "=":
		return equal
	}
	return false
}
// generateID returns a random identifier in the canonical 36-character UUID
// text form (8-4-4-4-12 lowercase hex digits) with the version nibble set to 4
// and the variant bits set to 10.
//
// If the random source reports an error, it falls back to a fixed prefix
// followed by the low 12 decimal digits of the current Unix time in
// nanoseconds.
func generateID() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return fmt.Sprintf("00000000-0000-4000-8000-%012d", time.Now().UnixNano()%1_000_000_000_000)
	}
	raw[6] = raw[6]&0x0f | 0x40 // version 4
	raw[8] = raw[8]&0x3f | 0x80 // variant 10xx

	// Hex-encode each group straight into its slot of the output buffer.
	var out [36]byte
	hex.Encode(out[0:8], raw[0:4])
	out[8] = '-'
	hex.Encode(out[9:13], raw[4:6])
	out[13] = '-'
	hex.Encode(out[14:18], raw[6:8])
	out[18] = '-'
	hex.Encode(out[19:23], raw[8:10])
	out[23] = '-'
	hex.Encode(out[24:36], raw[10:16])
	return string(out[:])
}