Files
ai-ops/internal/service/core_services_test.go
2026-05-12 17:48:22 +08:00

226 lines
7.7 KiB
Go

package service
import (
"bytes"
"context"
"errors"
"strings"
"testing"
"time"
"github.com/company/ai-ops/internal/domain/model"
)
// fakeRuleAlertRepo is an in-memory test double handed to NewRuleService.
// The rule CRUD methods remember the argument they were called with and
// return err; the count and event methods are inert stubs that always succeed.
type fakeRuleAlertRepo struct {
	rules       []model.AlertRule // returned as-is by ListRules
	gotRuleID   string            // last id passed to GetRuleByID
	createdRule *model.AlertRule  // last rule passed to CreateRule
	updatedRule *model.AlertRule  // last rule passed to UpdateRule
	deletedID   string            // last id passed to DeleteRule
	err         error             // error returned by the rule CRUD methods
}

// GetOpenCount always reports an empty count and no error.
func (f *fakeRuleAlertRepo) GetOpenCount(context.Context) (*model.AlertCount, error) {
	return new(model.AlertCount), nil
}

// ListRules returns the configured rules together with the configured error.
func (f *fakeRuleAlertRepo) ListRules(context.Context) ([]model.AlertRule, error) {
	return f.rules, f.err
}

// GetRuleByID records id; on success it fabricates a rule echoing that id.
func (f *fakeRuleAlertRepo) GetRuleByID(_ context.Context, id string) (*model.AlertRule, error) {
	f.gotRuleID = id
	if failure := f.err; failure != nil {
		return nil, failure
	}
	found := &model.AlertRule{ID: id, Name: "rule"}
	return found, nil
}

// CreateRule records the rule pointer and returns the configured error.
func (f *fakeRuleAlertRepo) CreateRule(_ context.Context, rule *model.AlertRule) error {
	f.createdRule = rule
	return f.err
}

// UpdateRule records the rule pointer and returns the configured error.
func (f *fakeRuleAlertRepo) UpdateRule(_ context.Context, rule *model.AlertRule) error {
	f.updatedRule = rule
	return f.err
}

// DeleteRule records the id and returns the configured error.
func (f *fakeRuleAlertRepo) DeleteRule(_ context.Context, id string) error {
	f.deletedID = id
	return f.err
}

// ListEvents always yields no events, a zero total, and no error.
func (f *fakeRuleAlertRepo) ListEvents(context.Context, string, int, int) ([]model.AlertEvent, int, error) {
	return nil, 0, nil
}

// CreateEvent is a no-op that always succeeds.
func (f *fakeRuleAlertRepo) CreateEvent(context.Context, *model.AlertEvent) error {
	return nil
}

// CreateEventWithAggregation echoes the event back unchanged, ignoring the
// remaining arguments.
func (f *fakeRuleAlertRepo) CreateEventWithAggregation(_ context.Context, event *model.AlertEvent, _ time.Duration, _ int) (*model.AlertEvent, error) {
	return event, nil
}

// UpdateEventStatus is a no-op that always succeeds.
func (f *fakeRuleAlertRepo) UpdateEventStatus(context.Context, string, string) error {
	return nil
}

// EscalateEvent is a no-op that always succeeds.
func (f *fakeRuleAlertRepo) EscalateEvent(context.Context, string, string) error {
	return nil
}
// fakeChannelRepository is an in-memory test double handed to
// NewChannelService. Lookups and mutations remember the argument they were
// called with, and every method returns err.
type fakeChannelRepository struct {
	channels []model.NotificationChannel // returned as-is by List
	gotID    string                      // last id passed to GetByID
	created  *model.NotificationChannel  // last channel passed to Create
	updated  *model.NotificationChannel  // last channel passed to Update
	deleted  string                      // last id passed to Delete
	err      error                       // error returned by every method
}

// List returns the configured channels together with the configured error.
func (f *fakeChannelRepository) List(context.Context) ([]model.NotificationChannel, error) {
	return f.channels, f.err
}

// GetByID records id; on success it fabricates a channel echoing that id.
func (f *fakeChannelRepository) GetByID(_ context.Context, id string) (*model.NotificationChannel, error) {
	f.gotID = id
	if failure := f.err; failure != nil {
		return nil, failure
	}
	found := &model.NotificationChannel{ID: id, Name: "webhook"}
	return found, nil
}

// Create records the channel pointer and returns the configured error.
func (f *fakeChannelRepository) Create(_ context.Context, channel *model.NotificationChannel) error {
	f.created = channel
	return f.err
}

// Update records the channel pointer and returns the configured error.
func (f *fakeChannelRepository) Update(_ context.Context, channel *model.NotificationChannel) error {
	f.updated = channel
	return f.err
}

// Delete records the id and returns the configured error.
func (f *fakeChannelRepository) Delete(_ context.Context, id string) error {
	f.deleted = id
	return f.err
}
// fakeLogRepository is an in-memory test double handed to NewLogService.
// Query remembers the filter it received and returns the canned results.
type fakeLogRepository struct {
	logs       []model.RequestLog   // rows returned by Query
	total      int                  // total count returned by Query
	lastFilter model.LogQueryFilter // last filter passed to Query
	err        error                // error returned by Query
}

// Query records filter and returns the configured logs, total, and error.
func (f *fakeLogRepository) Query(_ context.Context, filter model.LogQueryFilter) ([]model.RequestLog, int, error) {
	f.lastFilter = filter
	rows, count, failure := f.logs, f.total, f.err
	return rows, count, failure
}
// TestAuthServiceIssuesAndParsesToken checks that a token issued by
// AuthService parses back with its user id and role intact. It also checks
// that the token is rejected by a service built with a different secret, and
// that a malformed token string is rejected.
func TestAuthServiceIssuesAndParsesToken(t *testing.T) {
	issuer := NewAuthService("secret")

	signed, err := issuer.IssueToken("u1", "admin")
	if err != nil {
		t.Fatal(err)
	}

	parsed, err := issuer.ParseToken(signed)
	if err != nil {
		t.Fatal(err)
	}
	if !(parsed.UserID == "u1" && parsed.Role == "admin") {
		t.Fatalf("unexpected claims: %+v", parsed)
	}

	// A service holding a different secret must refuse the signature.
	stranger := NewAuthService("other")
	if _, err := stranger.ParseToken(signed); err == nil {
		t.Fatal("expected invalid signature error")
	}

	// Input that is not a token at all must not parse.
	if _, err := issuer.ParseToken("not-a-jwt"); err == nil {
		t.Fatal("expected malformed token error")
	}
}
// TestRuleServiceValidationAndRepositoryCalls exercises RuleService against
// fakeRuleAlertRepo. It verifies that:
//   - list and get pass through to the repository;
//   - invalid create and update requests are rejected without reaching the repository;
//   - a valid create is normalized (Enabled=true, Version=1) before persisting;
//   - update increments Version;
//   - delete forwards the id.
func TestRuleServiceValidationAndRepositoryCalls(t *testing.T) {
	ctx := context.Background()
	repo := &fakeRuleAlertRepo{rules: []model.AlertRule{{ID: "r1"}}}
	svc := NewRuleService(repo)

	if rules, err := svc.ListRules(ctx); err != nil || len(rules) != 1 {
		t.Fatalf("list = %v %v", rules, err)
	}
	// Guard against a nil rule with a nil error so the test fails instead of panicking.
	if rule, err := svc.GetRule(ctx, "r1"); err != nil || rule == nil || rule.ID != "r1" {
		t.Fatalf("get = %+v %v", rule, err)
	}

	// Invalid rules must be rejected before anything is persisted.
	if err := svc.CreateRule(ctx, &model.AlertRule{}); err == nil {
		t.Fatal("expected missing id error")
	}
	if err := svc.CreateRule(ctx, &model.AlertRule{ID: "r2"}); err == nil {
		t.Fatal("expected missing name/metric error")
	}
	if repo.createdRule != nil {
		t.Fatalf("invalid rule reached repository: %+v", repo.createdRule)
	}

	rule := &model.AlertRule{ID: "r2", Name: "latency", MetricName: "p99"}
	if err := svc.CreateRule(ctx, rule); err != nil {
		t.Fatal(err)
	}
	if !rule.Enabled || rule.Version != 1 || repo.createdRule != rule {
		t.Fatalf("create did not normalize rule: %+v", rule)
	}

	if err := svc.UpdateRule(ctx, &model.AlertRule{}); err == nil {
		t.Fatal("expected missing update id error")
	}
	if repo.updatedRule != nil {
		t.Fatalf("invalid update reached repository: %+v", repo.updatedRule)
	}
	updating := &model.AlertRule{ID: "r2", Version: 2}
	if err := svc.UpdateRule(ctx, updating); err != nil {
		t.Fatal(err)
	}
	if updating.Version != 3 || repo.updatedRule != updating {
		t.Fatalf("version not incremented: %+v", updating)
	}

	// Report the error and the forwarded id separately so a failure says which one was wrong.
	if err := svc.DeleteRule(ctx, "r2"); err != nil {
		t.Fatalf("delete failed: %v", err)
	}
	if repo.deletedID != "r2" {
		t.Fatalf("delete forwarded id %q, want %q", repo.deletedID, "r2")
	}
}
// TestChannelServiceValidationAndRepositoryCalls exercises ChannelService
// against fakeChannelRepository. It verifies that:
//   - list and get pass through to the repository;
//   - invalid create and update requests are rejected without reaching the repository;
//   - a valid create enables the channel before persisting;
//   - a valid update reaches the repository;
//   - delete forwards the id.
func TestChannelServiceValidationAndRepositoryCalls(t *testing.T) {
	ctx := context.Background()
	repo := &fakeChannelRepository{channels: []model.NotificationChannel{{ID: "c1"}}}
	svc := NewChannelService(repo)

	if channels, err := svc.List(ctx); err != nil || len(channels) != 1 {
		t.Fatalf("list = %v %v", channels, err)
	}
	// Guard against a nil channel with a nil error so the test fails instead of panicking.
	if ch, err := svc.Get(ctx, "c1"); err != nil || ch == nil || ch.ID != "c1" {
		t.Fatalf("get = %+v %v", ch, err)
	}

	if err := svc.Create(ctx, &model.NotificationChannel{}); err == nil {
		t.Fatal("expected validation error")
	}
	if repo.created != nil {
		t.Fatalf("invalid channel reached repository: %+v", repo.created)
	}
	ch := &model.NotificationChannel{Name: "hook", ChannelType: "webhook"}
	if err := svc.Create(ctx, ch); err != nil {
		t.Fatal(err)
	}
	if !ch.Enabled || repo.created != ch {
		t.Fatalf("create did not enable channel: %+v", ch)
	}

	if err := svc.Update(ctx, &model.NotificationChannel{}); err == nil {
		t.Fatal("expected missing id error")
	}
	if repo.updated != nil {
		t.Fatalf("invalid update reached repository: %+v", repo.updated)
	}
	if err := svc.Update(ctx, &model.NotificationChannel{ID: "c1"}); err != nil {
		t.Fatal(err)
	}
	// The fake records what Update received; without this check a no-op Update would pass.
	if repo.updated == nil || repo.updated.ID != "c1" {
		t.Fatalf("update did not reach repository with id %q: %+v", "c1", repo.updated)
	}

	if err := svc.Delete(ctx, "c1"); err != nil {
		t.Fatalf("delete failed: %v", err)
	}
	if repo.deleted != "c1" {
		t.Fatalf("delete forwarded id %q, want %q", repo.deleted, "c1")
	}
}
func TestLogServiceQueryAndExportCSV(t *testing.T) {
repo := &fakeLogRepository{
logs: []model.RequestLog{{Timestamp: time.Date(2026, 5, 12, 1, 2, 3, 0, time.UTC), Service: "api", Path: "/v1", Method: "GET", StatusCode: 200, LatencyMs: 12.34, UserID: "u", SupplierID: "s"}},
total: 1,
}
svc := NewLogService(repo)
logs, total, err := svc.QueryLogs(context.Background(), model.LogQueryFilter{Service: "api", Page: 2, PageSize: 5})
if err != nil || total != 1 || len(logs) != 1 {
t.Fatalf("query = %v %d %v", logs, total, err)
}
if repo.lastFilter.Service != "api" || repo.lastFilter.Page != 2 {
t.Fatalf("filter not passed: %+v", repo.lastFilter)
}
var buf bytes.Buffer
if err := svc.ExportLogsCSV(context.Background(), model.LogQueryFilter{Page: 9, PageSize: 1}, &buf); err != nil {
t.Fatal(err)
}
out := buf.String()
if !strings.Contains(out, "时间,服务名,路径,方法,状态码") || !strings.Contains(out, "api,/v1,GET,200,12.34") {
t.Fatalf("unexpected csv: %s", out)
}
if repo.lastFilter.Page != 1 || repo.lastFilter.PageSize != 10000 {
t.Fatalf("export did not enforce bounds: %+v", repo.lastFilter)
}
}
// TestLogServicePropagatesRepositoryErrors checks that a repository failure
// surfaces from both QueryLogs and ExportLogsCSV. Each surfaced error must
// carry the operation-specific context text.
func TestLogServicePropagatesRepositoryErrors(t *testing.T) {
	ctx := context.Background()
	svc := NewLogService(&fakeLogRepository{err: errors.New("db down")})

	_, _, queryErr := svc.QueryLogs(ctx, model.LogQueryFilter{})
	if queryErr == nil || !strings.Contains(queryErr.Error(), "query logs") {
		t.Fatalf("unexpected query err: %v", queryErr)
	}

	exportErr := svc.ExportLogsCSV(ctx, model.LogQueryFilter{}, new(bytes.Buffer))
	if exportErr == nil || !strings.Contains(exportErr.Error(), "query logs for export") {
		t.Fatalf("unexpected export err: %v", exportErr)
	}
}