Files
ai-ops/internal/service/alert_engine_test.go
2026-05-12 17:48:22 +08:00

192 lines
6.8 KiB
Go

package service
import (
"context"
"testing"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/stretchr/testify/mock"
)
// fakeAggregationAlertRepo is an in-memory alert repository used by the alert
// engine tests. Reads are served from the seeded rules/events slices; writes
// that tests care about (event creation, escalation) are recorded for later
// inspection, and every other write is a no-op.
type fakeAggregationAlertRepo struct {
	rules         []model.AlertRule   // served by ListRules and GetRuleByID
	events        []model.AlertEvent  // served by ListEvents (status/paging ignored)
	createdEvents []*model.AlertEvent // every event passed to CreateEvent / CreateEventWithAggregation
	escalated     []string            // "id:newLevel" entries recorded by EscalateEvent
}

// GetOpenCount always reports a zero-valued count.
func (f *fakeAggregationAlertRepo) GetOpenCount(ctx context.Context) (*model.AlertCount, error) {
	return &model.AlertCount{}, nil
}

// ListRules returns the seeded rules as-is.
func (f *fakeAggregationAlertRepo) ListRules(ctx context.Context) ([]model.AlertRule, error) {
	return f.rules, nil
}

// GetRuleByID returns a pointer into the seeded rules slice (not a copy) for
// the first rule whose ID matches. An unknown ID yields (nil, nil).
func (f *fakeAggregationAlertRepo) GetRuleByID(ctx context.Context, id string) (*model.AlertRule, error) {
	for idx := range f.rules {
		if candidate := &f.rules[idx]; candidate.ID == id {
			return candidate, nil
		}
	}
	return nil, nil
}

// CreateRule is a no-op.
func (f *fakeAggregationAlertRepo) CreateRule(ctx context.Context, rule *model.AlertRule) error {
	return nil
}

// UpdateRule is a no-op.
func (f *fakeAggregationAlertRepo) UpdateRule(ctx context.Context, rule *model.AlertRule) error {
	return nil
}

// DeleteRule is a no-op.
func (f *fakeAggregationAlertRepo) DeleteRule(ctx context.Context, id string) error {
	return nil
}

// ListEvents returns all seeded events and their count, ignoring the status
// filter and pagination arguments.
func (f *fakeAggregationAlertRepo) ListEvents(ctx context.Context, status string, page, pageSize int) ([]model.AlertEvent, int, error) {
	total := len(f.events)
	return f.events, total, nil
}

// CreateEvent records the event pointer in createdEvents.
func (f *fakeAggregationAlertRepo) CreateEvent(ctx context.Context, event *model.AlertEvent) error {
	f.createdEvents = append(f.createdEvents, event)
	return nil
}

// CreateEventWithAggregation records the raw event, then compares the total
// number of events recorded so far (across all rules/resources; window is
// ignored) against threshold. At or below the threshold the raw event is
// returned unchanged; above it, a fresh aggregated event with ID "agg-1" is
// returned that copies the raw event's identifying fields and carries the
// running count. The raw event itself is never modified.
func (f *fakeAggregationAlertRepo) CreateEventWithAggregation(ctx context.Context, event *model.AlertEvent, window time.Duration, threshold int) (*model.AlertEvent, error) {
	f.createdEvents = append(f.createdEvents, event)
	seen := len(f.createdEvents)
	if seen <= threshold {
		return event, nil
	}
	aggregated := &model.AlertEvent{
		ID:              "agg-1",
		RuleID:          event.RuleID,
		Level:           event.Level,
		ResourceType:    event.ResourceType,
		ResourceID:      event.ResourceID,
		CurrentValue:    event.CurrentValue,
		ThresholdValue:  event.ThresholdValue,
		Status:          "triggered",
		IsAggregated:    true,
		AggregatedCount: seen,
	}
	return aggregated, nil
}

// UpdateEventStatus is a no-op.
func (f *fakeAggregationAlertRepo) UpdateEventStatus(ctx context.Context, id, status string) error {
	return nil
}

// EscalateEvent records the escalation as "id:newLevel".
func (f *fakeAggregationAlertRepo) EscalateEvent(ctx context.Context, id, newLevel string) error {
	entry := id + ":" + newLevel
	f.escalated = append(f.escalated, entry)
	return nil
}
// fakeMetricRepo is an in-memory metric repository for the alert engine tests.
// GetLatest always returns the configured point, whatever source/name is asked for.
type fakeMetricRepo struct {
	point *model.MetricPoint // returned verbatim by GetLatest (may be nil)
}

// GetRealtime always reports zero-valued realtime metrics.
func (m *fakeMetricRepo) GetRealtime(ctx context.Context) (*model.RealtimeMetrics, error) {
	return &model.RealtimeMetrics{}, nil
}

// Query returns no points and no error.
func (m *fakeMetricRepo) Query(ctx context.Context, req model.MetricQueryRequest) ([]model.MetricPoint, error) {
	return nil, nil
}

// GetLatest ignores source and name and hands back the configured point.
func (m *fakeMetricRepo) GetLatest(ctx context.Context, source, name string) (*model.MetricPoint, error) {
	latest := m.point
	return latest, nil
}
// TestAlertEngineAggregatesWhenSameResourceExceedsTwentyEventsWithinWindow
// evaluates an always-firing rule 21 times with suppression disabled. It checks
// two things: every evaluation records a raw event in the repository, and the
// last raw event is not flagged as aggregated. The fake repository builds a
// separate aggregated event once its threshold is exceeded, so the raw child
// event must stay un-aggregated.
func TestAlertEngineAggregatesWhenSameResourceExceedsTwentyEventsWithinWindow(t *testing.T) {
	alertRepo := &fakeAggregationAlertRepo{rules: []model.AlertRule{{
		ID:             "rule-1",
		MetricSource:   "service",
		MetricName:     "api-error-rate",
		ThresholdType:  ">",
		ThresholdValue: "0.1",
		DurationMin:    0,
		Level:          "P2",
	}}}
	// 0.5 > 0.1, so the rule fires on every evaluation.
	metricRepo := &fakeMetricRepo{point: &model.MetricPoint{Value: 0.5}}
	engine := NewAlertEngine(alertRepo, metricRepo, nil)
	// Disable suppression so repeated firings are not deduplicated.
	engine.suppressWindow = 0

	// One more than twenty. Presumably the engine's aggregation threshold is 20,
	// per the test name; confirm against NewAlertEngine.
	const evaluations = 21
	ctx := context.Background()
	for i := 0; i < evaluations; i++ {
		if err := engine.evaluateRule(ctx, &alertRepo.rules[0]); err != nil {
			t.Fatalf("evaluate rule: %v", err)
		}
	}

	// Check the count before indexing, so a missing event is reported as a
	// failure instead of an index-out-of-range panic.
	if got := len(alertRepo.createdEvents); got != evaluations {
		t.Fatalf("created events = %d, want %d", got, evaluations)
	}
	last := alertRepo.createdEvents[evaluations-1]
	if last == nil {
		t.Fatalf("last created event is nil")
	}
	if last.IsAggregated {
		t.Fatalf("raw child event must not be marked aggregated")
	}
}
// TestAlertEngineEvaluateAndEscalateBranches covers several paths:
//   - evaluate creating an event for a firing rule;
//   - evaluate suppressing a repeat firing inside the suppression window;
//   - the comparison operators accepted by compare, including that an
//     unsupported operator yields false;
//   - generateID returning a non-empty ID.
func TestAlertEngineEvaluateAndEscalateBranches(t *testing.T) {
	alertRepo := &fakeAggregationAlertRepo{rules: []model.AlertRule{{
		ID:             "rule-eval",
		MetricSource:   "service",
		MetricName:     "latency",
		ThresholdType:  ">=",
		ThresholdValue: "10",
		DurationMin:    0,
		Level:          "P2",
	}}}
	// 10 >= 10, so the rule fires.
	metricRepo := &fakeMetricRepo{point: &model.MetricPoint{Value: 10}}
	engine := NewAlertEngine(alertRepo, metricRepo, nil)
	engine.suppressWindow = time.Hour
	ctx := context.Background()

	engine.evaluate(ctx)
	if got := len(alertRepo.createdEvents); got != 1 {
		t.Fatalf("created events = %d, want 1", got)
	}

	// A second evaluation inside the one-hour suppression window must not
	// create another event.
	engine.evaluate(ctx)
	if got := len(alertRepo.createdEvents); got != 1 {
		t.Fatalf("suppression failed, events = %d, want 1", got)
	}

	// Each operator is checked and reported on its own, so a failure names the
	// exact case. Results are computed inline to keep compare's argument types
	// untyped constants, exactly as in the direct calls.
	compareCases := []struct {
		desc string
		got  bool
		want bool
	}{
		{desc: `1 "=" 1`, got: engine.compare(1, 1, "="), want: true},
		{desc: `1 "<" 2`, got: engine.compare(1, 2, "<"), want: true},
		{desc: `2 ">" 1`, got: engine.compare(2, 1, ">"), want: true},
		{desc: `2 ">=" 2`, got: engine.compare(2, 2, ">="), want: true},
		{desc: `1 "<=" 2`, got: engine.compare(1, 2, "<="), want: true},
		// Unsupported operators must never match.
		{desc: `1 "regex" 2`, got: engine.compare(1, 2, "regex"), want: false},
	}
	for _, tc := range compareCases {
		if tc.got != tc.want {
			t.Errorf("compare(%s) = %v, want %v", tc.desc, tc.got, tc.want)
		}
	}

	if generateID() == "" {
		t.Fatal("empty alert id")
	}
}
// TestMetricServiceSupplierAndQuery checks two MetricService behaviors:
//   - QueryMetrics passes the request through to the metric repository;
//   - GetSupplierCount turns the "supplier_health" series into healthy and
//     unhealthy counts. Per the assertions below, one point with value 1 and
//     one with value 0 give Healthy=1, Unhealthy=1, Total=2.
//
// Both repository calls are registered with .Once(), and the test verifies at
// the end that each one happened.
func TestMetricServiceSupplierAndQuery(t *testing.T) {
	mockMetric := new(MockMetricRepository)
	mockAlert := new(MockAlertRepository)
	svc := NewMetricService(mockMetric, mockAlert)

	query := model.MetricQueryRequest{Name: "qps"}
	points := []model.MetricPoint{{Name: "qps", Value: 1}}
	mockMetric.On("Query", mock.Anything, query).Return(points, nil).Once()
	mockMetric.On("Query", mock.Anything, model.MetricQueryRequest{Name: "supplier_health"}).Return([]model.MetricPoint{{Value: 1}, {Value: 0}}, nil).Once()

	got, err := svc.QueryMetrics(context.Background(), query)
	if err != nil || len(got) != 1 {
		t.Fatalf("query metrics = %+v %v", got, err)
	}

	count, err := svc.GetSupplierCount(context.Background())
	if err != nil || count.Healthy != 1 || count.Unhealthy != 1 || count.Total != 2 {
		t.Fatalf("supplier count = %+v %v", count, err)
	}

	// Without this, an expectation registered with .Once() that was never hit
	// would go unnoticed.
	mockMetric.AssertExpectations(t)
}
func TestAlertEngineStartStopCoversLoop(t *testing.T) {
engine := NewAlertEngine(&fakeAggregationAlertRepo{}, &fakeMetricRepo{point: &model.MetricPoint{Value: 0}}, nil)
engine.interval = time.Hour
engine.Start()
time.Sleep(5 * time.Millisecond)
engine.Stop()
}
func TestAlertEngineEscalatesOldP2EventsOnly(t *testing.T) {
oldEvent := model.AlertEvent{ID: "old", RuleID: "rule-old", Level: "P2", ResourceType: "svc", ResourceID: "api", CurrentValue: "9", ThresholdValue: "1", Status: "triggered", StartedAt: time.Now().Add(-3 * time.Hour)}
freshEvent := model.AlertEvent{ID: "fresh", RuleID: "rule-fresh", Level: "P2", StartedAt: time.Now()}
p1Event := model.AlertEvent{ID: "p1", RuleID: "rule-p1", Level: "P1", StartedAt: time.Now().Add(-3 * time.Hour)}
repo := &fakeAggregationAlertRepo{
events: []model.AlertEvent{oldEvent, freshEvent, p1Event},
rules: []model.AlertRule{{ID: "rule-old", ChannelIDs: []string{"ch-1"}}},
}
engine := NewAlertEngine(repo, &fakeMetricRepo{point: &model.MetricPoint{Value: 0}}, nil)
engine.escalate(context.Background())
if len(repo.escalated) != 1 || repo.escalated[0] != "old:P1" {
t.Fatalf("escalated = %+v", repo.escalated)
}
}