226 lines
7.7 KiB
Go
226 lines
7.7 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/company/ai-ops/internal/domain/model"
|
|
)
|
|
|
|
type fakeRuleAlertRepo struct {
|
|
rules []model.AlertRule
|
|
gotRuleID string
|
|
createdRule *model.AlertRule
|
|
updatedRule *model.AlertRule
|
|
deletedID string
|
|
err error
|
|
}
|
|
|
|
func (r *fakeRuleAlertRepo) GetOpenCount(context.Context) (*model.AlertCount, error) {
|
|
return &model.AlertCount{}, nil
|
|
}
|
|
func (r *fakeRuleAlertRepo) ListRules(context.Context) ([]model.AlertRule, error) {
|
|
return r.rules, r.err
|
|
}
|
|
func (r *fakeRuleAlertRepo) GetRuleByID(_ context.Context, id string) (*model.AlertRule, error) {
|
|
r.gotRuleID = id
|
|
if r.err != nil {
|
|
return nil, r.err
|
|
}
|
|
return &model.AlertRule{ID: id, Name: "rule"}, nil
|
|
}
|
|
func (r *fakeRuleAlertRepo) CreateRule(_ context.Context, rule *model.AlertRule) error {
|
|
r.createdRule = rule
|
|
return r.err
|
|
}
|
|
func (r *fakeRuleAlertRepo) UpdateRule(_ context.Context, rule *model.AlertRule) error {
|
|
r.updatedRule = rule
|
|
return r.err
|
|
}
|
|
func (r *fakeRuleAlertRepo) DeleteRule(_ context.Context, id string) error {
|
|
r.deletedID = id
|
|
return r.err
|
|
}
|
|
func (r *fakeRuleAlertRepo) ListEvents(context.Context, string, int, int) ([]model.AlertEvent, int, error) {
|
|
return nil, 0, nil
|
|
}
|
|
func (r *fakeRuleAlertRepo) CreateEvent(context.Context, *model.AlertEvent) error { return nil }
|
|
func (r *fakeRuleAlertRepo) CreateEventWithAggregation(_ context.Context, e *model.AlertEvent, _ time.Duration, _ int) (*model.AlertEvent, error) {
|
|
return e, nil
|
|
}
|
|
func (r *fakeRuleAlertRepo) UpdateEventStatus(context.Context, string, string) error { return nil }
|
|
func (r *fakeRuleAlertRepo) EscalateEvent(context.Context, string, string) error { return nil }
|
|
|
|
type fakeChannelRepository struct {
|
|
channels []model.NotificationChannel
|
|
gotID string
|
|
created *model.NotificationChannel
|
|
updated *model.NotificationChannel
|
|
deleted string
|
|
err error
|
|
}
|
|
|
|
func (r *fakeChannelRepository) List(context.Context) ([]model.NotificationChannel, error) {
|
|
return r.channels, r.err
|
|
}
|
|
func (r *fakeChannelRepository) GetByID(_ context.Context, id string) (*model.NotificationChannel, error) {
|
|
r.gotID = id
|
|
if r.err != nil {
|
|
return nil, r.err
|
|
}
|
|
return &model.NotificationChannel{ID: id, Name: "webhook"}, nil
|
|
}
|
|
func (r *fakeChannelRepository) Create(_ context.Context, ch *model.NotificationChannel) error {
|
|
r.created = ch
|
|
return r.err
|
|
}
|
|
func (r *fakeChannelRepository) Update(_ context.Context, ch *model.NotificationChannel) error {
|
|
r.updated = ch
|
|
return r.err
|
|
}
|
|
func (r *fakeChannelRepository) Delete(_ context.Context, id string) error {
|
|
r.deleted = id
|
|
return r.err
|
|
}
|
|
|
|
type fakeLogRepository struct {
|
|
logs []model.RequestLog
|
|
total int
|
|
lastFilter model.LogQueryFilter
|
|
err error
|
|
}
|
|
|
|
func (r *fakeLogRepository) Query(_ context.Context, filter model.LogQueryFilter) ([]model.RequestLog, int, error) {
|
|
r.lastFilter = filter
|
|
return r.logs, r.total, r.err
|
|
}
|
|
|
|
func TestAuthServiceIssuesAndParsesToken(t *testing.T) {
|
|
svc := NewAuthService("secret")
|
|
token, err := svc.IssueToken("u1", "admin")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
claims, err := svc.ParseToken(token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if claims.UserID != "u1" || claims.Role != "admin" {
|
|
t.Fatalf("unexpected claims: %+v", claims)
|
|
}
|
|
if _, err := NewAuthService("other").ParseToken(token); err == nil {
|
|
t.Fatal("expected invalid signature error")
|
|
}
|
|
if _, err := svc.ParseToken("not-a-jwt"); err == nil {
|
|
t.Fatal("expected malformed token error")
|
|
}
|
|
}
|
|
|
|
func TestRuleServiceValidationAndRepositoryCalls(t *testing.T) {
|
|
repo := &fakeRuleAlertRepo{rules: []model.AlertRule{{ID: "r1"}}}
|
|
svc := NewRuleService(repo)
|
|
if rules, err := svc.ListRules(context.Background()); err != nil || len(rules) != 1 {
|
|
t.Fatalf("list = %v %v", rules, err)
|
|
}
|
|
if rule, err := svc.GetRule(context.Background(), "r1"); err != nil || rule.ID != "r1" {
|
|
t.Fatalf("get = %+v %v", rule, err)
|
|
}
|
|
if err := svc.CreateRule(context.Background(), &model.AlertRule{}); err == nil {
|
|
t.Fatal("expected missing id error")
|
|
}
|
|
if err := svc.CreateRule(context.Background(), &model.AlertRule{ID: "r2"}); err == nil {
|
|
t.Fatal("expected missing name/metric error")
|
|
}
|
|
rule := &model.AlertRule{ID: "r2", Name: "latency", MetricName: "p99"}
|
|
if err := svc.CreateRule(context.Background(), rule); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !rule.Enabled || rule.Version != 1 || repo.createdRule != rule {
|
|
t.Fatalf("create did not normalize rule: %+v", rule)
|
|
}
|
|
if err := svc.UpdateRule(context.Background(), &model.AlertRule{}); err == nil {
|
|
t.Fatal("expected missing update id error")
|
|
}
|
|
updating := &model.AlertRule{ID: "r2", Version: 2}
|
|
if err := svc.UpdateRule(context.Background(), updating); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if updating.Version != 3 || repo.updatedRule != updating {
|
|
t.Fatalf("version not incremented: %+v", updating)
|
|
}
|
|
if err := svc.DeleteRule(context.Background(), "r2"); err != nil || repo.deletedID != "r2" {
|
|
t.Fatalf("delete failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestChannelServiceValidationAndRepositoryCalls(t *testing.T) {
|
|
repo := &fakeChannelRepository{channels: []model.NotificationChannel{{ID: "c1"}}}
|
|
svc := NewChannelService(repo)
|
|
if channels, err := svc.List(context.Background()); err != nil || len(channels) != 1 {
|
|
t.Fatalf("list = %v %v", channels, err)
|
|
}
|
|
if ch, err := svc.Get(context.Background(), "c1"); err != nil || ch.ID != "c1" {
|
|
t.Fatalf("get = %+v %v", ch, err)
|
|
}
|
|
if err := svc.Create(context.Background(), &model.NotificationChannel{}); err == nil {
|
|
t.Fatal("expected validation error")
|
|
}
|
|
ch := &model.NotificationChannel{Name: "hook", ChannelType: "webhook"}
|
|
if err := svc.Create(context.Background(), ch); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !ch.Enabled || repo.created != ch {
|
|
t.Fatalf("create did not enable channel: %+v", ch)
|
|
}
|
|
if err := svc.Update(context.Background(), &model.NotificationChannel{}); err == nil {
|
|
t.Fatal("expected missing id error")
|
|
}
|
|
if err := svc.Update(context.Background(), &model.NotificationChannel{ID: "c1"}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := svc.Delete(context.Background(), "c1"); err != nil || repo.deleted != "c1" {
|
|
t.Fatalf("delete failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestLogServiceQueryAndExportCSV(t *testing.T) {
|
|
repo := &fakeLogRepository{
|
|
logs: []model.RequestLog{{Timestamp: time.Date(2026, 5, 12, 1, 2, 3, 0, time.UTC), Service: "api", Path: "/v1", Method: "GET", StatusCode: 200, LatencyMs: 12.34, UserID: "u", SupplierID: "s"}},
|
|
total: 1,
|
|
}
|
|
svc := NewLogService(repo)
|
|
logs, total, err := svc.QueryLogs(context.Background(), model.LogQueryFilter{Service: "api", Page: 2, PageSize: 5})
|
|
if err != nil || total != 1 || len(logs) != 1 {
|
|
t.Fatalf("query = %v %d %v", logs, total, err)
|
|
}
|
|
if repo.lastFilter.Service != "api" || repo.lastFilter.Page != 2 {
|
|
t.Fatalf("filter not passed: %+v", repo.lastFilter)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := svc.ExportLogsCSV(context.Background(), model.LogQueryFilter{Page: 9, PageSize: 1}, &buf); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
out := buf.String()
|
|
if !strings.Contains(out, "时间,服务名,路径,方法,状态码") || !strings.Contains(out, "api,/v1,GET,200,12.34") {
|
|
t.Fatalf("unexpected csv: %s", out)
|
|
}
|
|
if repo.lastFilter.Page != 1 || repo.lastFilter.PageSize != 10000 {
|
|
t.Fatalf("export did not enforce bounds: %+v", repo.lastFilter)
|
|
}
|
|
}
|
|
|
|
func TestLogServicePropagatesRepositoryErrors(t *testing.T) {
|
|
svc := NewLogService(&fakeLogRepository{err: errors.New("db down")})
|
|
if _, _, err := svc.QueryLogs(context.Background(), model.LogQueryFilter{}); err == nil || !strings.Contains(err.Error(), "query logs") {
|
|
t.Fatalf("unexpected query err: %v", err)
|
|
}
|
|
if err := svc.ExportLogsCSV(context.Background(), model.LogQueryFilter{}, &bytes.Buffer{}); err == nil || !strings.Contains(err.Error(), "query logs for export") {
|
|
t.Fatalf("unexpected export err: %v", err)
|
|
}
|
|
}
|