chore: initial import

This commit is contained in:
phamnazage-jpg
2026-05-12 17:47:32 +08:00
commit fc54ba84b2
104 changed files with 11575 additions and 0 deletions

138
internal/config/config.go Normal file
View File

@@ -0,0 +1,138 @@
package config
import (
"fmt"
"os"
"strings"
"github.com/spf13/viper"
)
// Config is the root application configuration. Each field is decoded from
// the configuration section named in its mapstructure tag.
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Metrics MetricsConfig `mapstructure:"metrics"`
}
// ServerConfig holds the settings decoded from the "server" section.
// Validate requires Port to be in 1..65535 and, when Mode is "production"
// (compared case-insensitively), JWTSecret to be at least 32 characters and
// MetricsAuth at least 16 characters.
type ServerConfig struct {
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // "development" or "production"
JWTSecret string `mapstructure:"jwt_secret"`
MetricsAuth string `mapstructure:"metrics_auth"` // API key for /metrics
}
// DatabaseConfig holds the settings decoded from the "database" section.
// DSN renders these fields as a PostgreSQL connection string; PoolSize is
// emitted there as pool_max_conns.
type DatabaseConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
PoolSize int `mapstructure:"pool_size"`
}
// RedisConfig holds the settings decoded from the "redis" section.
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"` // NOTE(review): presumably the Redis logical database index — confirm
}
// MetricsConfig holds the settings decoded from the "metrics" section.
// Validate requires RetentionDays to be greater than zero.
type MetricsConfig struct {
PrometheusURL string `mapstructure:"prometheus_url"`
RetentionDays int `mapstructure:"retention_days"`
}
// Load builds the application configuration from the file at path, layered
// over built-in defaults, then applies environment-variable overrides and
// validates the result.
//
// It returns an error when the file cannot be read (a
// viper.ConfigFileNotFoundError is tolerated), when decoding into Config
// fails, or when Validate rejects the final configuration.
func Load(path string) (*Config, error) {
	loader := viper.New()
	loader.SetConfigFile(path)
	loader.SetEnvPrefix("AI_OPS")
	loader.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	loader.AutomaticEnv()

	// Defaults applied to any key the config file leaves unset.
	defaults := []struct {
		key   string
		value any
	}{
		{"server.port", 8080},
		{"server.mode", "development"},
		{"database.host", "localhost"},
		{"database.port", 5432},
		{"database.sslmode", "disable"},
		{"database.pool_size", 10},
		{"redis.host", "localhost"},
		{"redis.port", 6379},
		{"metrics.retention_days", 7},
	}
	for _, d := range defaults {
		loader.SetDefault(d.key, d.value)
	}

	if readErr := loader.ReadInConfig(); readErr != nil {
		// Only a "config file not found" error is ignored; anything else aborts.
		_, notFound := readErr.(viper.ConfigFileNotFoundError)
		if !notFound {
			return nil, fmt.Errorf("read config: %w", readErr)
		}
	}

	cfg := new(Config)
	if err := loader.Unmarshal(cfg); err != nil {
		return nil, fmt.Errorf("unmarshal config: %w", err)
	}

	// Environment overrides: compatibility with Spring Boot style database
	// configuration.
	// NOTE(review): the variable's value is copied verbatim into the host
	// field; a full JDBC URL would not be split into host/port/dbname here.
	if springURL := os.Getenv("SPRING_DATASOURCE_URL"); springURL != "" {
		cfg.Database.Host = springURL
	}
	applyExplicitEnvOverrides(cfg)

	if err := cfg.Validate(); err != nil {
		return nil, err
	}
	return cfg, nil
}
// applyExplicitEnvOverrides copies a fixed set of AI_OPS_* environment
// variables onto the matching string fields of cfg. A variable that is unset
// or empty leaves its field untouched.
func applyExplicitEnvOverrides(cfg *Config) {
	overrides := []struct {
		env    string
		target *string
	}{
		{"AI_OPS_SERVER_JWT_SECRET", &cfg.Server.JWTSecret},
		{"AI_OPS_SERVER_METRICS_AUTH", &cfg.Server.MetricsAuth},
		{"AI_OPS_DATABASE_HOST", &cfg.Database.Host},
		{"AI_OPS_DATABASE_USER", &cfg.Database.User},
		{"AI_OPS_DATABASE_PASSWORD", &cfg.Database.Password},
		{"AI_OPS_DATABASE_DBNAME", &cfg.Database.DBName},
		{"AI_OPS_REDIS_HOST", &cfg.Redis.Host},
		{"AI_OPS_REDIS_PASSWORD", &cfg.Redis.Password},
	}
	for _, o := range overrides {
		value := os.Getenv(o.env)
		if value == "" {
			continue
		}
		*o.target = value
	}
}
// Validate checks the configuration and returns the first problem found, or
// nil when it is acceptable.
//
// Always checked, in this order: server.port and database.port must be in
// 1..65535, database.pool_size and metrics.retention_days must be positive.
// When server.mode is "production" (case-insensitive) it additionally
// requires a JWT secret of at least 32 characters, a metrics key of at least
// 16 characters, and non-empty database host, user, password and dbname.
func (c *Config) Validate() error {
	srv, db := c.Server, c.Database

	switch {
	case srv.Port < 1 || srv.Port > 65535:
		return fmt.Errorf("invalid server.port: %d", srv.Port)
	case db.Port < 1 || db.Port > 65535:
		return fmt.Errorf("invalid database.port: %d", db.Port)
	case db.PoolSize < 1:
		return fmt.Errorf("invalid database.pool_size: %d", db.PoolSize)
	case c.Metrics.RetentionDays < 1:
		return fmt.Errorf("invalid metrics.retention_days: %d", c.Metrics.RetentionDays)
	}

	// The remaining rules only apply to production deployments.
	if !strings.EqualFold(srv.Mode, "production") {
		return nil
	}

	switch {
	case len(srv.JWTSecret) < 32:
		return fmt.Errorf("server.jwt_secret must be at least 32 characters in production")
	case len(srv.MetricsAuth) < 16:
		return fmt.Errorf("server.metrics_auth must be at least 16 characters in production")
	case db.Host == "" || db.User == "" || db.Password == "" || db.DBName == "":
		return fmt.Errorf("database host/user/password/dbname are required in production")
	}
	return nil
}
// DSN returns the PostgreSQL keyword/value connection string for c,
// including pool_max_conns taken from PoolSize.
//
// String values are passed through quoteDSNValue so that an empty value, or
// one containing whitespace, a single quote or a backslash (for example a
// generated password), still yields a well-formed connection string. Values
// without such characters are emitted unchanged.
func (c DatabaseConfig) DSN() string {
	return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s pool_max_conns=%d",
		quoteDSNValue(c.Host), c.Port, quoteDSNValue(c.User), quoteDSNValue(c.Password),
		quoteDSNValue(c.DBName), quoteDSNValue(c.SSLMode), c.PoolSize)
}

// quoteDSNValue returns v in a form that is safe as a keyword/value
// connection-string value: plain values are returned as-is, anything else is
// wrapped in single quotes with backslashes and single quotes
// backslash-escaped.
func quoteDSNValue(v string) string {
	if v != "" && !strings.ContainsAny(v, " \t\n\r\v\f'\\") {
		return v
	}
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}

View File

@@ -0,0 +1,136 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
// TestLoadReadsConfigAndBuildsDSN writes a production-mode config file to a
// temporary directory, loads it, and checks selected loaded values as well as
// the key=value fragments of the DSN built from the database section.
func TestLoadReadsConfigAndBuildsDSN(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
content := []byte(`server:
port: 19090
mode: production
jwt_secret: "0123456789abcdef0123456789abcdef"
metrics_auth: "metrics-api-key-123456"
database:
host: db
port: 15432
user: user
password: pass
dbname: aiops
sslmode: require
pool_size: 7
redis:
host: redis
port: 16379
password: redispass
db: 2
metrics:
prometheus_url: http://prom
retention_days: 14
`)
if err := os.WriteFile(path, content, 0o600); err != nil {
t.Fatal(err)
}
cfg, err := Load(path)
if err != nil {
t.Fatal(err)
}
// Spot-check one value from each configuration section.
if cfg.Server.Port != 19090 || cfg.Database.Host != "db" || cfg.Redis.DB != 2 || cfg.Metrics.RetentionDays != 14 {
t.Fatalf("unexpected config: %+v", cfg)
}
dsn := cfg.Database.DSN()
// Every database setting must appear in the DSN as a key=value fragment.
for _, want := range []string{"host=db", "port=15432", "user=user", "password=pass", "dbname=aiops", "sslmode=require", "pool_max_conns=7"} {
if !strings.Contains(dsn, want) {
t.Fatalf("dsn %q missing %q", dsn, want)
}
}
}
// TestLoadAppliesDefaultsAndSpringDatasourceCompatibility loads an empty
// config file and checks that the built-in defaults are used and that
// SPRING_DATASOURCE_URL replaces the database host.
func TestLoadAppliesDefaultsAndSpringDatasourceCompatibility(t *testing.T) {
	t.Setenv("SPRING_DATASOURCE_URL", "spring-host")

	cfgPath := filepath.Join(t.TempDir(), "empty.yaml")
	writeErr := os.WriteFile(cfgPath, []byte("{}\n"), 0o600)
	if writeErr != nil {
		t.Fatal(writeErr)
	}

	cfg, err := Load(cfgPath)
	if err != nil {
		t.Fatal(err)
	}

	defaultsOK := cfg.Server.Port == 8080 &&
		cfg.Database.Port == 5432 &&
		cfg.Redis.Port == 6379 &&
		cfg.Metrics.RetentionDays == 7
	if !defaultsOK {
		t.Fatalf("defaults not applied: %+v", cfg)
	}
	if got := cfg.Database.Host; got != "spring-host" {
		t.Fatalf("spring datasource compatibility not applied: %s", got)
	}
}
// TestLoadReturnsErrorForMalformedConfig checks that Load fails when the
// config file contains syntactically invalid YAML.
func TestLoadReturnsErrorForMalformedConfig(t *testing.T) {
	cfgPath := filepath.Join(t.TempDir(), "bad.yaml")
	writeErr := os.WriteFile(cfgPath, []byte("server: ["), 0o600)
	if writeErr != nil {
		t.Fatal(writeErr)
	}

	_, err := Load(cfgPath)
	if err == nil {
		t.Fatal("expected malformed config error")
	}
}
// TestLoadRejectsWeakProductionSecrets loads a production-mode config whose
// jwt_secret and metrics_auth are too short and expects Load to fail with an
// error that mentions jwt_secret.
func TestLoadRejectsWeakProductionSecrets(t *testing.T) {
path := filepath.Join(t.TempDir(), "config.yaml")
content := []byte(`server:
mode: production
jwt_secret: short
metrics_auth: short
database:
host: db
port: 5432
user: aiops
password: aiops123
dbname: ai_ops
pool_size: 1
metrics:
retention_days: 7
`)
if err := os.WriteFile(path, content, 0o600); err != nil {
t.Fatal(err)
}
_, err := Load(path)
// Validate checks jwt_secret before metrics_auth, so that is the error expected here.
if err == nil || !strings.Contains(err.Error(), "jwt_secret") {
t.Fatalf("expected weak jwt secret error, got %v", err)
}
}
// TestLoadAppliesExplicitEnvironmentOverrides loads a valid production config
// and checks that AI_OPS_DATABASE_PASSWORD and AI_OPS_SERVER_METRICS_AUTH
// replace the corresponding values from the file.
func TestLoadAppliesExplicitEnvironmentOverrides(t *testing.T) {
path := filepath.Join(t.TempDir(), "config.yaml")
content := []byte(`server:
mode: production
jwt_secret: "0123456789abcdef0123456789abcdef"
metrics_auth: "metrics-api-key-123456"
database:
host: db
port: 5432
user: aiops
password: aiops123
dbname: ai_ops
pool_size: 1
metrics:
retention_days: 7
`)
if err := os.WriteFile(path, content, 0o600); err != nil {
t.Fatal(err)
}
// t.Setenv restores the previous environment when the test ends.
t.Setenv("AI_OPS_DATABASE_PASSWORD", "override-pass")
t.Setenv("AI_OPS_SERVER_METRICS_AUTH", "override-metrics-key")
cfg, err := Load(path)
if err != nil {
t.Fatal(err)
}
if cfg.Database.Password != "override-pass" || cfg.Server.MetricsAuth != "override-metrics-key" {
t.Fatalf("env overrides not applied: %+v", cfg)
}
}

View File

@@ -0,0 +1,47 @@
package database
import (
"context"
"fmt"
"time"
"github.com/company/ai-ops/internal/config"
"github.com/jackc/pgx/v5/pgxpool"
)
// Pool is the package-level database connection pool. It is nil until Init
// assigns it; Close closes it when it is non-nil.
var Pool *pgxpool.Pool
// Init creates the package-level connection pool from cfg and verifies it
// with a ping. Parsing, pool creation and the ping share one 10-second
// timeout.
//
// Pool is assigned only after the ping succeeds. When the ping fails the
// freshly created pool is closed before returning, so a failed Init neither
// leaks connections nor publishes an unusable pool.
//
// It returns a wrapped error when the DSN cannot be parsed, the pool cannot
// be created, or the database does not answer the ping.
func Init(cfg config.DatabaseConfig) error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	poolConfig, err := pgxpool.ParseConfig(cfg.DSN())
	if err != nil {
		return fmt.Errorf("parse db config: %w", err)
	}

	poolConfig.MaxConns = int32(cfg.PoolSize)
	// Keep two warm connections, but never more than the configured maximum
	// (a pool_size of 1 would otherwise have MinConns > MaxConns).
	poolConfig.MinConns = 2
	if poolConfig.MinConns > poolConfig.MaxConns {
		poolConfig.MinConns = poolConfig.MaxConns
	}
	poolConfig.MaxConnLifetime = 30 * time.Minute
	poolConfig.MaxConnIdleTime = 10 * time.Minute

	pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
	if err != nil {
		return fmt.Errorf("create db pool: %w", err)
	}
	if err := pool.Ping(ctx); err != nil {
		// Release the pool's resources instead of leaving it half-initialized.
		pool.Close()
		return fmt.Errorf("ping db: %w", err)
	}

	Pool = pool
	return nil
}
// Close closes the package-level pool. It is a no-op when Pool is nil, and it
// does not reset Pool afterwards.
func Close() {
	if Pool == nil {
		return
	}
	Pool.Close()
}

View File

@@ -0,0 +1,37 @@
package database
import (
"testing"
"github.com/company/ai-ops/internal/config"
)
func TestInitAndCloseWithLocalPostgres(t *testing.T) {
ports := []int{15432, 5432}
var lastErr error
for _, port := range ports {
lastErr = Init(config.DatabaseConfig{Host: "localhost", Port: port, User: "aiops", Password: "aiops123", DBName: "ai_ops", SSLMode: "disable", PoolSize: 4})
if lastErr == nil {
break
}
Close()
Pool = nil
}
if lastErr != nil {
t.Skipf("PostgreSQL integration database not available: %v", lastErr)
}
if Pool == nil {
t.Fatal("pool not initialized")
}
Close()
Pool = nil
}
func TestInitReturnsErrorForInvalidConfig(t *testing.T) {
if err := Init(config.DatabaseConfig{Host: "::::bad-host::::", Port: 1, User: "u", Password: "p", DBName: "d", SSLMode: "disable", PoolSize: 1}); err == nil {
Close()
Pool = nil
t.Fatal("expected invalid db config error")
}
Pool = nil
}

View File

@@ -0,0 +1,50 @@
package model
import "time"
// AlertRule is an alert rule. The json tag on each field gives its name in
// the JSON representation.
type AlertRule struct {
ID string `json:"id"`
Name string `json:"name"`
MetricSource string `json:"metric_source"`
MetricName string `json:"metric_name"`
ThresholdType string `json:"threshold_type"`
// NOTE(review): the threshold is carried as a string; the expected format is not visible here.
ThresholdValue string `json:"threshold_value"`
DurationMin int `json:"duration_min"`
Level string `json:"level"`
ChannelIDs []string `json:"channel_ids"`
// HealingAction and HealingConfig are omitted from JSON when nil/empty.
HealingAction *string `json:"healing_action,omitempty"`
HealingConfig map[string]any `json:"healing_config,omitempty"`
IsSandboxed bool `json:"is_sandboxed"`
Enabled bool `json:"enabled"`
Version int `json:"version"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AlertEvent is an alert event. RuleID holds the ID of a rule
// (presumably the AlertRule that produced the event — confirm).
type AlertEvent struct {
ID string `json:"id"`
RuleID string `json:"rule_id"`
Level string `json:"level"`
ResourceType string `json:"resource_type"`
ResourceID string `json:"resource_id"`
CurrentValue string `json:"current_value"`
ThresholdValue string `json:"threshold_value"`
Status string `json:"status"`
IsAggregated bool `json:"is_aggregated"`
AggregatedCount int `json:"aggregated_count"`
// ParentAlertID and ResolvedAt are optional and omitted from JSON when nil.
ParentAlertID *string `json:"parent_alert_id,omitempty"`
StartedAt time.Time `json:"started_at"`
ResolvedAt *time.Time `json:"resolved_at,omitempty"`
}
// AlertCount holds alert statistics: an Open counter plus one counter per
// level P0 through P3.
type AlertCount struct {
Open int `json:"open"`
P0 int `json:"p0"`
P1 int `json:"p1"`
P2 int `json:"p2"`
P3 int `json:"p3"`
}

View File

@@ -0,0 +1,21 @@
package model
import "time"
// NotificationChannel is a notification channel. Config is a free-form map
// of channel-specific settings.
type NotificationChannel struct {
ID string `json:"id"`
Name string `json:"name"`
ChannelType string `json:"channel_type"`
Config map[string]any `json:"config"`
Priority int `json:"priority"`
Enabled bool `json:"enabled"`
CreatedAt time.Time `json:"created_at"`
}
// ChannelConfig is the typed channel configuration structure. Every field is
// optional and omitted from JSON when empty.
type ChannelConfig struct {
WebhookURL string `json:"webhook_url,omitempty"`
EmailTo string `json:"email_to,omitempty"`
APIToken string `json:"api_token,omitempty"`
}

View File

@@ -0,0 +1,30 @@
package model
import "time"
// RequestLog is a single request log record.
type RequestLog struct {
ID string `json:"id"`
Timestamp time.Time `json:"timestamp"`
Service string `json:"service"`
Path string `json:"path"`
StatusCode int `json:"status_code"`
LatencyMs float64 `json:"latency_ms"` // latency in milliseconds, per the field and tag name
UserID string `json:"user_id"`
SupplierID string `json:"supplier_id"`
Method string `json:"method"`
ErrorCode string `json:"error_code,omitempty"` // omitted from JSON when empty
}
// LogQueryFilter holds the filter conditions and paging parameters for a log
// query. Pointer fields and empty strings are omitted from JSON when unset.
// NOTE(review): how an unset field is treated by the query is decided by the
// repository implementation, which is not visible here.
type LogQueryFilter struct {
StartTime *time.Time `json:"start_time,omitempty"`
EndTime *time.Time `json:"end_time,omitempty"`
Service string `json:"service,omitempty"`
Path string `json:"path,omitempty"`
StatusCode *int `json:"status_code,omitempty"`
UserID string `json:"user_id,omitempty"`
SupplierID string `json:"supplier_id,omitempty"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}

View File

@@ -0,0 +1,37 @@
package model
import "time"
// MetricPoint is a single time-series data point: a named value from a
// source, with string tags and a timestamp.
type MetricPoint struct {
Source string `json:"source"`
Name string `json:"name"`
Value float64 `json:"value"`
Tags map[string]string `json:"tags"`
Timestamp time.Time `json:"timestamp"`
}
// MetricQueryRequest is a metric query request.
type MetricQueryRequest struct {
Source string `json:"source"`
Name string `json:"name"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
// Interval is a time.Duration, which encoding/json encodes as an integer
// number of nanoseconds.
Interval time.Duration `json:"interval"`
Tags map[string]string `json:"tags"`
}
// RealtimeMetrics holds the real-time metrics shown on the home page.
// The latency fields are serialized with an "_ms" suffix.
type RealtimeMetrics struct {
QPS float64 `json:"qps"`
AvgLatency float64 `json:"avg_latency_ms"`
P99Latency float64 `json:"p99_latency_ms"`
ErrorRate float64 `json:"error_rate"` // NOTE(review): fraction vs. percentage is not visible here
}
// SupplierCount holds supplier statistics: total, healthy and unhealthy
// counters.
type SupplierCount struct {
Total int `json:"total"`
Healthy int `json:"healthy"`
Unhealthy int `json:"unhealthy"`
}

View File

@@ -0,0 +1,16 @@
package model
import "time"
// NotificationLog records the result of a single send through one
// notification channel. ErrorMessage and SentAt are optional and omitted
// from JSON when nil.
type NotificationLog struct {
ID string `json:"id"`
EventID string `json:"event_id"`
ChannelID string `json:"channel_id"`
ChannelType string `json:"channel_type"`
Status string `json:"status"`
RetryCount int `json:"retry_count"`
ErrorMessage *string `json:"error_message,omitempty"`
SentAt *time.Time `json:"sent_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
}

View File

@@ -0,0 +1,28 @@
package repository
import (
"context"
"time"
"github.com/company/ai-ops/internal/domain/model"
)
// AlertRepository is the storage interface for alert data: alert statistics,
// rule CRUD and alert events.
type AlertRepository interface {
// Alert statistics.
GetOpenCount(ctx context.Context) (*model.AlertCount, error)
// Rule CRUD.
ListRules(ctx context.Context) ([]model.AlertRule, error)
GetRuleByID(ctx context.Context, id string) (*model.AlertRule, error)
CreateRule(ctx context.Context, rule *model.AlertRule) error
UpdateRule(ctx context.Context, rule *model.AlertRule) error
DeleteRule(ctx context.Context, id string) error
// Alert events.
// NOTE(review): the int returned by ListEvents is presumably the total
// number of matching events (the alert handler reports it as "total") — confirm.
ListEvents(ctx context.Context, status string, page, pageSize int) ([]model.AlertEvent, int, error)
CreateEvent(ctx context.Context, event *model.AlertEvent) error
CreateEventWithAggregation(ctx context.Context, event *model.AlertEvent, window time.Duration, threshold int) (*model.AlertEvent, error)
UpdateEventStatus(ctx context.Context, id, status string) error
EscalateEvent(ctx context.Context, id, newLevel string) error
}

View File

@@ -0,0 +1,16 @@
package repository
import (
"context"
"github.com/company/ai-ops/internal/domain/model"
)
// ChannelRepository is the storage interface for notification channels,
// offering list, get-by-ID, create, update and delete operations.
type ChannelRepository interface {
List(ctx context.Context) ([]model.NotificationChannel, error)
GetByID(ctx context.Context, id string) (*model.NotificationChannel, error)
Create(ctx context.Context, ch *model.NotificationChannel) error
Update(ctx context.Context, ch *model.NotificationChannel) error
Delete(ctx context.Context, id string) error
}

View File

@@ -0,0 +1,13 @@
package repository
import (
"context"
"github.com/company/ai-ops/internal/domain/model"
)
// LogRepository is the storage interface for log data.
type LogRepository interface {
// Query looks up request logs matching filter.
// NOTE(review): the int result is presumably the total match count — confirm.
Query(ctx context.Context, filter model.LogQueryFilter) ([]model.RequestLog, int, error)
}

View File

@@ -0,0 +1,17 @@
package repository
import (
"context"
"github.com/company/ai-ops/internal/domain/model"
)
// MetricRepository is the storage interface for metric data.
type MetricRepository interface {
// GetRealtime returns the real-time metrics.
GetRealtime(ctx context.Context) (*model.RealtimeMetrics, error)
// Query returns the metric points matching req.
Query(ctx context.Context, req model.MetricQueryRequest) ([]model.MetricPoint, error)
// GetLatest returns the latest value of the metric identified by source and name.
GetLatest(ctx context.Context, source, name string) (*model.MetricPoint, error)
}

View File

@@ -0,0 +1,14 @@
package repository
import (
"context"
"github.com/company/ai-ops/internal/domain/model"
)
// NotificationLogRepository is the storage interface for notification send
// records: create a record, then mark it as sent or as failed (with the
// retry count and error message).
type NotificationLogRepository interface {
CreateLog(ctx context.Context, log *model.NotificationLog) error
MarkSent(ctx context.Context, id string) error
MarkFailed(ctx context.Context, id string, retryCount int, errMessage string) error
}

View File

@@ -0,0 +1,43 @@
package handler
import (
"net/http"
"strconv"
"github.com/company/ai-ops/internal/domain/repository"
"github.com/company/ai-ops/pkg/errors"
"github.com/company/ai-ops/pkg/response"
)
// AlertHandler is the HTTP handler for alert events. It reads directly from
// an AlertRepository.
type AlertHandler struct {
	repo repository.AlertRepository
}

// NewAlertHandler returns an AlertHandler backed by repo.
func NewAlertHandler(repo repository.AlertRepository) *AlertHandler {
	h := new(AlertHandler)
	h.repo = repo
	return h
}

// RegisterRoutes registers the alert listing endpoint on mux.
func (h *AlertHandler) RegisterRoutes(mux *http.ServeMux) {
	const listPattern = "GET /api/v1/ai-ops/alerts"
	mux.HandleFunc(listPattern, h.ListAlerts)
}
// ListAlerts responds with one page of alert events, optionally filtered by
// the "status" query parameter.
//
// "page" defaults to 1 when it is missing, non-numeric or below 1;
// "page_size" defaults to 20 when it is missing, non-numeric or outside
// 1..100. A repository failure is reported as an internal error.
func (h *AlertHandler) ListAlerts(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()

	// Parse errors are deliberately ignored: non-numeric input yields 0,
	// which the range checks below replace with the default.
	page, _ := strconv.Atoi(params.Get("page"))
	if page < 1 {
		page = 1
	}
	pageSize, _ := strconv.Atoi(params.Get("page_size"))
	if !(pageSize >= 1 && pageSize <= 100) {
		pageSize = 20
	}

	events, total, err := h.repo.ListEvents(r.Context(), params.Get("status"), page, pageSize)
	if err != nil {
		response.Error(w, errors.Wrap(err, errors.ErrInternal))
		return
	}

	payload := map[string]any{
		"items":     events,
		"total":     total,
		"page":      page,
		"page_size": pageSize,
	}
	response.Success(w, payload)
}

View File

@@ -0,0 +1,55 @@
package handler
import (
"net/http"
"strconv"
"github.com/company/ai-ops/internal/service"
"github.com/company/ai-ops/pkg/errors"
"github.com/company/ai-ops/pkg/response"
)
// AuditHandler is the HTTP handler for audit logs. It delegates to an
// AuditService.
type AuditHandler struct {
	service *service.AuditService
}

// NewAuditHandler returns an AuditHandler backed by s.
func NewAuditHandler(s *service.AuditService) *AuditHandler {
	h := new(AuditHandler)
	h.service = s
	return h
}

// RegisterRoutes registers the audit listing and rollback endpoints on mux.
func (h *AuditHandler) RegisterRoutes(mux *http.ServeMux) {
	routes := []struct {
		pattern string
		handle  func(http.ResponseWriter, *http.Request)
	}{
		{"GET /api/v1/ai-ops/audits", h.ListAudits},
		{"POST /api/v1/ai-ops/audits/{id}/rollback", h.Rollback},
	}
	for _, rt := range routes {
		mux.HandleFunc(rt.pattern, rt.handle)
	}
}
// ListAudits responds with one page of audit logs, filtered by the
// "object_type" and "object_id" query parameters.
//
// "page" defaults to 1 when it is missing, non-numeric or below 1;
// "page_size" defaults to 20 when it is missing, non-numeric or outside
// 1..100. A service failure is reported as an internal error.
func (h *AuditHandler) ListAudits(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()

	// Parse errors are deliberately ignored: non-numeric input yields 0,
	// which the range checks below replace with the default.
	page, _ := strconv.Atoi(params.Get("page"))
	if page < 1 {
		page = 1
	}
	pageSize, _ := strconv.Atoi(params.Get("page_size"))
	if !(pageSize >= 1 && pageSize <= 100) {
		pageSize = 20
	}

	entries, total, err := h.service.List(r.Context(), params.Get("object_type"), params.Get("object_id"), page, pageSize)
	if err != nil {
		response.Error(w, errors.Wrap(err, errors.ErrInternal))
		return
	}

	payload := map[string]any{
		"items":     entries,
		"total":     total,
		"page":      page,
		"page_size": pageSize,
	}
	response.Success(w, payload)
}
// Rollback asks the audit service to roll back the audit entry named by the
// {id} path segment and responds with the service's result.
//
// Every service error is reported as a bad request, with the error text
// included in the response detail.
func (h *AuditHandler) Rollback(w http.ResponseWriter, r *http.Request) {
	entry, err := h.service.Rollback(r.Context(), r.PathValue("id"))
	if err == nil {
		response.Success(w, entry)
		return
	}
	detail := map[string]any{"error": err.Error()}
	response.Error(w, errors.Wrap(err, errors.ErrBadRequest).WithDetail(detail))
}

View File

@@ -0,0 +1,59 @@
package handler
import (
"net/http"
"github.com/company/ai-ops/internal/service"
"github.com/company/ai-ops/pkg/errors"
"github.com/company/ai-ops/pkg/response"
)
// AuthHandler is the HTTP handler for authentication. It issues tokens
// through an AuthService.
type AuthHandler struct {
	authSvc *service.AuthService
}

// NewAuthHandler returns an AuthHandler backed by authSvc.
func NewAuthHandler(authSvc *service.AuthService) *AuthHandler {
	h := new(AuthHandler)
	h.authSvc = authSvc
	return h
}

// RegisterRoutes registers the login endpoint on mux.
func (h *AuthHandler) RegisterRoutes(mux *http.ServeMux) {
	const loginPattern = "POST /api/v1/ai-ops/login"
	mux.HandleFunc(loginPattern, h.Login)
}
// Login decodes a JSON body with "username" and "password", picks a role
// from the username and responds with a freshly issued token, the role and
// an "expires_in" value of 28800.
//
// It responds with a bad request when the body cannot be decoded or when
// either field is empty, and with an internal error when the token cannot be
// issued.
//
// SECURITY NOTE(review): the password is only checked for being non-empty
// and the role is derived from the username alone, so any caller can obtain
// an "admin" token. This needs real credential verification before
// production use (see the TODO below).
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
	var creds struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
	if err := decodeJSON(r, &creds); err != nil {
		response.Error(w, errors.ErrBadRequest.WithDetail(map[string]any{"error": err.Error()}))
		return
	}

	// TODO: implement real user verification (this is a simplified implementation).
	if creds.Username == "" || creds.Password == "" {
		response.Error(w, errors.ErrBadRequest.WithDetail(map[string]any{"error": "username and password required"}))
		return
	}

	// Role mapping; any username not listed here is a viewer.
	var role string
	switch creds.Username {
	case "admin":
		role = "admin"
	case "ops":
		role = "operator"
	default:
		role = "viewer"
	}

	token, err := h.authSvc.IssueToken(creds.Username, role)
	if err != nil {
		response.Error(w, errors.Wrap(err, errors.ErrInternal))
		return
	}

	// NOTE(review): 28800 seconds is 8 hours — confirm it matches the
	// lifetime of the tokens the auth service actually issues.
	response.Success(w, map[string]any{
		"token":      token,
		"expires_in": 28800,
		"role":       role,
	})
}

View File

@@ -0,0 +1,97 @@
package handler
import (
"net/http"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/service"
"github.com/company/ai-ops/pkg/errors"
"github.com/company/ai-ops/pkg/response"
)
// ChannelHandler is the HTTP handler for notification channels. It delegates
// to a ChannelService.
type ChannelHandler struct {
	service *service.ChannelService
}

// NewChannelHandler returns a ChannelHandler backed by s.
func NewChannelHandler(s *service.ChannelService) *ChannelHandler {
	h := new(ChannelHandler)
	h.service = s
	return h
}

// RegisterRoutes registers the channel CRUD endpoints and the channel test
// endpoint on mux.
func (h *ChannelHandler) RegisterRoutes(mux *http.ServeMux) {
	routes := []struct {
		pattern string
		handle  func(http.ResponseWriter, *http.Request)
	}{
		{"GET /api/v1/ai-ops/channels", h.ListChannels},
		{"GET /api/v1/ai-ops/channels/{id}", h.GetChannel},
		{"POST /api/v1/ai-ops/channels", h.CreateChannel},
		{"PUT /api/v1/ai-ops/channels/{id}", h.UpdateChannel},
		{"DELETE /api/v1/ai-ops/channels/{id}", h.DeleteChannel},
		{"POST /api/v1/ai-ops/channels/test", h.TestChannel},
	}
	for _, rt := range routes {
		mux.HandleFunc(rt.pattern, rt.handle)
	}
}
// ListChannels responds with all notification channels returned by the
// service, or with an internal error when the service fails.
func (h *ChannelHandler) ListChannels(w http.ResponseWriter, r *http.Request) {
	all, err := h.service.List(r.Context())
	if err == nil {
		response.Success(w, all)
		return
	}
	response.Error(w, errors.Wrap(err, errors.ErrInternal))
}
// GetChannel responds with the channel named by the {id} path segment.
// Every service error is reported as not found.
func (h *ChannelHandler) GetChannel(w http.ResponseWriter, r *http.Request) {
	channel, err := h.service.Get(r.Context(), r.PathValue("id"))
	if err == nil {
		response.Success(w, channel)
		return
	}
	response.Error(w, errors.Wrap(err, errors.ErrNotFound))
}
// CreateChannel decodes a NotificationChannel from the JSON body, creates it
// through the service and responds with status 201 and the channel.
// Decode failures and service errors are both reported as bad requests.
func (h *ChannelHandler) CreateChannel(w http.ResponseWriter, r *http.Request) {
	channel := model.NotificationChannel{}
	if decodeErr := decodeJSON(r, &channel); decodeErr != nil {
		detail := map[string]any{"error": decodeErr.Error()}
		response.Error(w, errors.ErrBadRequest.WithDetail(detail))
		return
	}

	if createErr := h.service.Create(r.Context(), &channel); createErr != nil {
		response.Error(w, errors.Wrap(createErr, errors.ErrBadRequest))
		return
	}

	// The status line is written before the body helper runs.
	w.WriteHeader(http.StatusCreated)
	response.Success(w, channel)
}
// UpdateChannel decodes a NotificationChannel from the JSON body, forces its
// ID to the {id} path segment, updates it through the service and responds
// with the channel. Decode failures and service errors are both reported as
// bad requests.
func (h *ChannelHandler) UpdateChannel(w http.ResponseWriter, r *http.Request) {
	channel := model.NotificationChannel{}
	if decodeErr := decodeJSON(r, &channel); decodeErr != nil {
		detail := map[string]any{"error": decodeErr.Error()}
		response.Error(w, errors.ErrBadRequest.WithDetail(detail))
		return
	}

	// The path is authoritative: any ID in the body is overwritten.
	channel.ID = r.PathValue("id")

	if updateErr := h.service.Update(r.Context(), &channel); updateErr != nil {
		response.Error(w, errors.Wrap(updateErr, errors.ErrBadRequest))
		return
	}
	response.Success(w, channel)
}
// DeleteChannel deletes the channel named by the {id} path segment and
// responds with status 204, or with an internal error when the service
// fails.
func (h *ChannelHandler) DeleteChannel(w http.ResponseWriter, r *http.Request) {
	err := h.service.Delete(r.Context(), r.PathValue("id"))
	if err != nil {
		response.Error(w, errors.Wrap(err, errors.ErrInternal))
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// TestChannel decodes a JSON body with "channel_id" and "message" and
// responds with {"ok": true}. A body that cannot be decoded is reported as a
// bad request.
//
// NOTE(review): the decoded fields are not used after decoding — no channel
// is looked up and no message is sent from this handler.
func (h *ChannelHandler) TestChannel(w http.ResponseWriter, r *http.Request) {
	payload := struct {
		ChannelID string `json:"channel_id"`
		Message   string `json:"message"`
	}{}
	if decodeErr := decodeJSON(r, &payload); decodeErr != nil {
		detail := map[string]any{"error": decodeErr.Error()}
		response.Error(w, errors.ErrBadRequest.WithDetail(detail))
		return
	}
	response.Success(w, map[string]any{"ok": true})
}

View File

@@ -0,0 +1,391 @@
package handler
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sort"
"strings"
"testing"
"time"
"github.com/company/ai-ops/internal/config"
"github.com/company/ai-ops/internal/database"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/service"
)
// fakeHandlerAlertRepo is an in-memory alert repository for handler tests.
// It serves the canned rules and events, and every method reports err.
type fakeHandlerAlertRepo struct {
	rules  []model.AlertRule
	events []model.AlertEvent
	err    error
}

// GetOpenCount returns an empty AlertCount together with the configured error.
func (f *fakeHandlerAlertRepo) GetOpenCount(_ context.Context) (*model.AlertCount, error) {
	return new(model.AlertCount), f.err
}

// ListRules returns the canned rules together with the configured error.
func (f *fakeHandlerAlertRepo) ListRules(_ context.Context) ([]model.AlertRule, error) {
	return f.rules, f.err
}

// GetRuleByID returns a rule named "rule" carrying the requested id, or the
// configured error when one is set.
func (f *fakeHandlerAlertRepo) GetRuleByID(_ context.Context, id string) (*model.AlertRule, error) {
	if f.err == nil {
		return &model.AlertRule{ID: id, Name: "rule"}, nil
	}
	return nil, f.err
}

// CreateRule reports the configured error.
func (f *fakeHandlerAlertRepo) CreateRule(_ context.Context, _ *model.AlertRule) error {
	return f.err
}

// UpdateRule reports the configured error.
func (f *fakeHandlerAlertRepo) UpdateRule(_ context.Context, _ *model.AlertRule) error {
	return f.err
}

// DeleteRule reports the configured error.
func (f *fakeHandlerAlertRepo) DeleteRule(_ context.Context, _ string) error {
	return f.err
}

// ListEvents returns the canned events, their count as the total, and the
// configured error; filter and paging arguments are ignored.
func (f *fakeHandlerAlertRepo) ListEvents(_ context.Context, _ string, _, _ int) ([]model.AlertEvent, int, error) {
	return f.events, len(f.events), f.err
}

// CreateEvent reports the configured error.
func (f *fakeHandlerAlertRepo) CreateEvent(_ context.Context, _ *model.AlertEvent) error {
	return f.err
}

// CreateEventWithAggregation echoes the event back with the configured error.
func (f *fakeHandlerAlertRepo) CreateEventWithAggregation(_ context.Context, event *model.AlertEvent, _ time.Duration, _ int) (*model.AlertEvent, error) {
	return event, f.err
}

// UpdateEventStatus reports the configured error.
func (f *fakeHandlerAlertRepo) UpdateEventStatus(_ context.Context, _, _ string) error {
	return f.err
}

// EscalateEvent reports the configured error.
func (f *fakeHandlerAlertRepo) EscalateEvent(_ context.Context, _, _ string) error {
	return f.err
}
// fakeHandlerChannelRepo is an in-memory channel repository for handler
// tests. It serves the canned channels, and every method reports err.
type fakeHandlerChannelRepo struct {
	channels []model.NotificationChannel
	err      error
}

// List returns the canned channels together with the configured error.
func (f *fakeHandlerChannelRepo) List(_ context.Context) ([]model.NotificationChannel, error) {
	return f.channels, f.err
}

// GetByID returns a channel named "hook" carrying the requested id, or the
// configured error when one is set.
func (f *fakeHandlerChannelRepo) GetByID(_ context.Context, id string) (*model.NotificationChannel, error) {
	if f.err == nil {
		return &model.NotificationChannel{ID: id, Name: "hook"}, nil
	}
	return nil, f.err
}

// Create reports the configured error.
func (f *fakeHandlerChannelRepo) Create(_ context.Context, _ *model.NotificationChannel) error {
	return f.err
}

// Update reports the configured error.
func (f *fakeHandlerChannelRepo) Update(_ context.Context, _ *model.NotificationChannel) error {
	return f.err
}

// Delete reports the configured error.
func (f *fakeHandlerChannelRepo) Delete(_ context.Context, _ string) error {
	return f.err
}
// fakeHandlerLogRepo is an in-memory log repository for handler tests.
type fakeHandlerLogRepo struct {
	logs  []model.RequestLog
	total int
	err   error
}

// Query ignores the filter and returns the canned logs, total and error.
func (f *fakeHandlerLogRepo) Query(_ context.Context, _ model.LogQueryFilter) ([]model.RequestLog, int, error) {
	return f.logs, f.total, f.err
}
// TestAuthHandlerLoginRolesAndValidation checks the username-to-role mapping
// of the login endpoint and that empty credentials and malformed JSON are
// rejected with status 400.
func TestAuthHandlerLoginRolesAndValidation(t *testing.T) {
	h := NewAuthHandler(service.NewAuthService("secret"))

	// login posts body to the handler and returns the recorded response.
	login := func(body string) *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		h.Login(rec, httptest.NewRequest(http.MethodPost, "/api/v1/ai-ops/login", strings.NewReader(body)))
		return rec
	}

	roles := []struct{ username, wantRole string }{
		{"admin", "admin"},
		{"ops", "operator"},
		{"alice", "viewer"},
	}
	for _, tc := range roles {
		rec := login(`{"username":"` + tc.username + `","password":"pw"}`)
		got := rec.Body.String()
		ok := rec.Code == http.StatusOK &&
			strings.Contains(got, `"role":"`+tc.wantRole+`"`) &&
			strings.Contains(got, `"token"`)
		if !ok {
			t.Fatalf("login %s failed: status=%d body=%s", tc.username, rec.Code, got)
		}
	}

	if rec := login(`{"username":"","password":""}`); rec.Code != http.StatusBadRequest {
		t.Fatalf("invalid login status = %d", rec.Code)
	}
	if rec := login(`{`); rec.Code != http.StatusBadRequest {
		t.Fatalf("bad json status = %d", rec.Code)
	}
}
// TestHealthAndDashboardHandlers calls the health, liveness, readiness and
// dashboard handlers directly and checks their status codes and key body
// fragments. Readiness is expected to report DOWN with status 503 here.
func TestHealthAndDashboardHandlers(t *testing.T) {
	health := NewHealthHandler()
	get := func(target string) *http.Request {
		return httptest.NewRequest(http.MethodGet, target, nil)
	}

	healthRec := httptest.NewRecorder()
	health.Health(healthRec, get("/actuator/health"))
	if healthRec.Code != http.StatusOK || !strings.Contains(healthRec.Body.String(), `"status":"UP"`) {
		t.Fatalf("health = %d %s", healthRec.Code, healthRec.Body.String())
	}

	liveRec := httptest.NewRecorder()
	health.Live(liveRec, get("/actuator/health/live"))
	if liveRec.Code != http.StatusOK {
		t.Fatalf("live = %d", liveRec.Code)
	}

	readyRec := httptest.NewRecorder()
	health.Ready(readyRec, get("/actuator/health/ready"))
	if readyRec.Code != http.StatusServiceUnavailable || !strings.Contains(readyRec.Body.String(), `"status":"DOWN"`) {
		t.Fatalf("ready = %d %s", readyRec.Code, readyRec.Body.String())
	}

	dashRec := httptest.NewRecorder()
	NewDashboardHandler().Dashboard(dashRec, get("/ops/dashboard"))
	if dashRec.Code != http.StatusOK || !strings.Contains(dashRec.Body.String(), "AI-Ops 运维看板") {
		t.Fatalf("dashboard = %d", dashRec.Code)
	}
}
// TestRuleHandlerCRUDHappyAndErrorPaths drives the rule endpoints through a
// ServeMux: the CRUD happy paths, malformed and empty create bodies, and a
// list request against a failing repository.
func TestRuleHandlerCRUDHappyAndErrorPaths(t *testing.T) {
	okRepo := &fakeHandlerAlertRepo{rules: []model.AlertRule{{ID: "r1", Name: "rule"}}}
	h := NewRuleHandler(service.NewRuleService(okRepo))
	mux := http.NewServeMux()
	h.RegisterRoutes(mux)

	type step struct {
		method, path string
		body         string
		want         int
	}
	steps := []step{
		{http.MethodGet, "/api/v1/ai-ops/rules", "", http.StatusOK},
		{http.MethodGet, "/api/v1/ai-ops/rules/r1", "", http.StatusOK},
		{http.MethodPost, "/api/v1/ai-ops/rules", `{"id":"r2","name":"latency","metric_name":"p99"}`, http.StatusCreated},
		{http.MethodPut, "/api/v1/ai-ops/rules/r2", `{"name":"latency","metric_name":"p99"}`, http.StatusOK},
		{http.MethodDelete, "/api/v1/ai-ops/rules/r2", "", http.StatusNoContent},
		{http.MethodPost, "/api/v1/ai-ops/rules", `{`, http.StatusBadRequest},
		{http.MethodPost, "/api/v1/ai-ops/rules", `{}`, http.StatusBadRequest},
	}
	for _, s := range steps {
		rec := httptest.NewRecorder()
		mux.ServeHTTP(rec, httptest.NewRequest(s.method, s.path, strings.NewReader(s.body)))
		if rec.Code == s.want {
			continue
		}
		t.Fatalf("%s %s status=%d want=%d body=%s", s.method, s.path, rec.Code, s.want, rec.Body.String())
	}

	// A repository that always fails must surface as a 500 on list.
	failingHandler := NewRuleHandler(service.NewRuleService(&fakeHandlerAlertRepo{err: errors.New("db")}))
	failingMux := http.NewServeMux()
	failingHandler.RegisterRoutes(failingMux)
	rec := httptest.NewRecorder()
	failingMux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/ai-ops/rules", nil))
	if rec.Code != http.StatusInternalServerError {
		t.Fatalf("error list status = %d", rec.Code)
	}
}
func TestChannelHandlerCRUDHappyAndErrorPaths(t *testing.T) {
h := NewChannelHandler(service.NewChannelService(&fakeHandlerChannelRepo{channels: []model.NotificationChannel{{ID: "c1", Name: "hook"}}}))
mux := http.NewServeMux()
h.RegisterRoutes(mux)
for _, tc := range []struct {
method, path, body string
want int
}{
{http.MethodGet, "/api/v1/ai-ops/channels", "", http.StatusOK},
{http.MethodGet, "/api/v1/ai-ops/channels/c1", "", http.StatusOK},
{http.MethodPost, "/api/v1/ai-ops/channels", `{"name":"hook","channel_type":"webhook"}`, http.StatusCreated},
{http.MethodPut, "/api/v1/ai-ops/channels/c1", `{"name":"hook","channel_type":"webhook"}`, http.StatusOK},
{http.MethodDelete, "/api/v1/ai-ops/channels/c1", "", http.StatusNoContent},
{http.MethodPost, "/api/v1/ai-ops/channels/test", `{"channel_id":"c1","message":"hello"}`, http.StatusOK},
{http.MethodPost, "/api/v1/ai-ops/channels", `{}`, http.StatusBadRequest},
{http.MethodPost, "/api/v1/ai-ops/channels/test", `{`, http.StatusBadRequest},
} {
w := httptest.NewRecorder()
mux.ServeHTTP(w, httptest.NewRequest(tc.method, tc.path, strings.NewReader(tc.body)))
if w.Code != tc.want {
t.Fatalf("%s %s status=%d want=%d body=%s", tc.method, tc.path, w.Code, tc.want, w.Body.String())
}
}
}
// TestAlertAndLogHandlers exercises the alert listing endpoint and the log
// query and CSV export endpoints, including an export against a failing
// repository.
func TestAlertAndLogHandlers(t *testing.T) {
	// get issues a GET for target against mux and returns the recorded response.
	get := func(mux *http.ServeMux, target string) *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, target, nil))
		return rec
	}

	alertHandler := NewAlertHandler(&fakeHandlerAlertRepo{events: []model.AlertEvent{{ID: "e1", Status: "triggered"}}})
	alertMux := http.NewServeMux()
	alertHandler.RegisterRoutes(alertMux)
	alerts := get(alertMux, "/api/v1/ai-ops/alerts?status=triggered&page=2&page_size=5")
	if alerts.Code != http.StatusOK || !strings.Contains(alerts.Body.String(), `"items"`) {
		t.Fatalf("alerts = %d %s", alerts.Code, alerts.Body.String())
	}

	sample := model.RequestLog{Timestamp: time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC), Service: "api", Path: "/v1", Method: "GET", StatusCode: 200}
	logHandler := NewLogHandler(service.NewLogService(&fakeHandlerLogRepo{logs: []model.RequestLog{sample}, total: 1}))
	logMux := http.NewServeMux()
	logHandler.RegisterRoutes(logMux)

	logs := get(logMux, "/api/v1/ai-ops/logs?page=2&page_size=5&status_code=200")
	if logs.Code != http.StatusOK || !strings.Contains(logs.Body.String(), `"total_pages"`) {
		t.Fatalf("logs = %d %s", logs.Code, logs.Body.String())
	}

	export := get(logMux, "/api/v1/ai-ops/logs/export")
	if export.Code != http.StatusOK || !strings.Contains(export.Body.String(), "时间,服务名") {
		t.Fatalf("csv = %d %s", export.Code, export.Body.String())
	}

	// Export with unparsable filters against a failing repository.
	failingHandler := NewLogHandler(service.NewLogService(&fakeHandlerLogRepo{err: errors.New("export failed")}))
	failingMux := http.NewServeMux()
	failingHandler.RegisterRoutes(failingMux)
	failed := get(failingMux, "/api/v1/ai-ops/logs/export?start=bad&end=bad&status_code=bad")
	if failed.Code != http.StatusInternalServerError {
		t.Fatalf("csv error = %d %s", failed.Code, failed.Body.String())
	}
}
// fakeHandlerMetricRepo is a canned metric repository for handler tests. When
// err is set every method fails with it; otherwise the stored values are
// returned.
type fakeHandlerMetricRepo struct {
	realtime *model.RealtimeMetrics
	points   []model.MetricPoint
	err      error
}

// GetRealtime returns the stored realtime snapshot, or err when configured.
func (r *fakeHandlerMetricRepo) GetRealtime(context.Context) (*model.RealtimeMetrics, error) {
	if r.err == nil {
		return r.realtime, nil
	}
	return nil, r.err
}

// Query returns the stored points regardless of the request, or err when
// configured.
func (r *fakeHandlerMetricRepo) Query(context.Context, model.MetricQueryRequest) ([]model.MetricPoint, error) {
	if r.err == nil {
		return r.points, nil
	}
	return nil, r.err
}

// GetLatest returns a fresh point with Value 1, or err when configured.
func (r *fakeHandlerMetricRepo) GetLatest(context.Context, string, string) (*model.MetricPoint, error) {
	if r.err == nil {
		return &model.MetricPoint{Value: 1}, nil
	}
	return nil, r.err
}
// TestRegisterRoutesForSmallHandlers registers the auth, dashboard, health and
// healing handlers on one mux and checks the two healing endpoints respond.
func TestRegisterRoutesForSmallHandlers(t *testing.T) {
	mux := http.NewServeMux()
	NewAuthHandler(service.NewAuthService("secret")).RegisterRoutes(mux)
	NewDashboardHandler().RegisterRoutes(mux)
	NewHealthHandler().RegisterRoutes(mux)
	NewHealingHandler().RegisterRoutes(mux)

	// fetch issues a body-less GET against the shared mux.
	fetch := func(target string) *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, target, nil))
		return rec
	}

	listResp := fetch("/api/v1/ai-ops/healings")
	if !(listResp.Code == http.StatusOK && strings.Contains(listResp.Body.String(), `"total":0`)) {
		t.Fatalf("healings = %d %s", listResp.Code, listResp.Body.String())
	}

	itemResp := fetch("/api/v1/ai-ops/healings/h1")
	if !(itemResp.Code == http.StatusOK && strings.Contains(itemResp.Body.String(), `"id":"h1"`)) {
		t.Fatalf("healing = %d %s", itemResp.Code, itemResp.Body.String())
	}
}
func TestMetricHandlerRoutesAndErrors(t *testing.T) {
metricRepo := &fakeHandlerMetricRepo{realtime: &model.RealtimeMetrics{QPS: 9}, points: []model.MetricPoint{{Name: "qps", Value: 1}}}
alertRepo := &fakeHandlerAlertRepo{}
h := NewMetricHandler(service.NewMetricService(metricRepo, alertRepo))
mux := http.NewServeMux()
h.RegisterRoutes(mux)
for _, path := range []string{
"/api/v1/ai-ops/metrics/realtime",
"/api/v1/ai-ops/metrics/suppliers/count",
"/api/v1/ai-ops/alerts/open/count",
"/api/v1/ai-ops/metrics/query?source=prom&name=qps&start=2026-01-01T00:00:00Z&end=2026-01-01T01:00:00Z",
} {
w := httptest.NewRecorder()
mux.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
if w.Code != http.StatusOK {
t.Fatalf("%s status=%d body=%s", path, w.Code, w.Body.String())
}
}
errHandler := NewMetricHandler(service.NewMetricService(&fakeHandlerMetricRepo{err: errors.New("metrics down")}, &fakeHandlerAlertRepo{err: errors.New("alerts down")}))
errMux := http.NewServeMux()
errHandler.RegisterRoutes(errMux)
for _, path := range []string{"/api/v1/ai-ops/metrics/realtime", "/api/v1/ai-ops/metrics/query", "/api/v1/ai-ops/alerts/open/count"} {
w := httptest.NewRecorder()
errMux.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
if w.Code != http.StatusInternalServerError {
t.Fatalf("%s status=%d", path, w.Code)
}
}
}
// setupHandlerAuditDB prepares the PostgreSQL database used by the audit
// handler tests and returns the context to use for subsequent queries.
//
// If database.Pool is not initialised it tries localhost on ports 15432 and
// 5432 in turn and skips the test when neither is reachable. It then applies
// every ../../tech/migrations/*.up.sql file in lexical order while holding
// advisory lock 424242001, so concurrent test runs do not migrate at the same
// time.
//
// The lock is a transaction-scoped advisory lock taken inside one dedicated
// transaction. A session-level pg_advisory_lock / pg_advisory_unlock pair sent
// through the pool may run on two different connections, which leaves the lock
// held and can block the next caller. The migrations still run through the
// pool as before, so the pool needs at least two connections.
func setupHandlerAuditDB(t *testing.T) context.Context {
	t.Helper()
	ctx := context.Background()
	if database.Pool == nil {
		var lastErr error
		for _, port := range []int{15432, 5432} {
			lastErr = database.Init(config.DatabaseConfig{Host: "localhost", Port: port, User: "aiops", Password: "aiops123", DBName: "ai_ops", SSLMode: "disable", PoolSize: 4})
			if lastErr == nil {
				break
			}
			// Discard the half-initialised pool before trying the next port.
			database.Close()
			database.Pool = nil
		}
		if lastErr != nil {
			t.Skipf("PostgreSQL integration database not available: %v", lastErr)
		}
	}

	lockTx, err := database.Pool.Begin(ctx)
	if err != nil {
		t.Fatal(err)
	}
	// Ending the transaction releases the lock; nothing was written in it, so
	// a rollback is sufficient and its error carries no information.
	defer func() { _ = lockTx.Rollback(ctx) }()
	if _, err := lockTx.Exec(ctx, `SELECT pg_advisory_xact_lock(424242001)`); err != nil {
		t.Fatal(err)
	}

	files, err := filepath.Glob(filepath.Join("..", "..", "tech", "migrations", "*.up.sql"))
	if err != nil {
		t.Fatal(err)
	}
	sort.Strings(files)
	for _, f := range files {
		b, err := os.ReadFile(f)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := database.Pool.Exec(ctx, string(b)); err != nil {
			t.Fatalf("apply migration %s: %v", f, err)
		}
	}
	return ctx
}
// handlerUUID returns a random version-4 UUID in canonical 8-4-4-4-12
// lower-case hex form, failing the test if random bytes cannot be read.
func handlerUUID(t *testing.T) string {
	t.Helper()
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		t.Fatal(err)
	}
	raw[6] = raw[6]&0x0f | 0x40 // version nibble = 4
	raw[8] = raw[8]&0x3f | 0x80 // variant bits = 10
	digits := hex.EncodeToString(raw[:])
	return digits[:8] + "-" + digits[8:12] + "-" + digits[12:16] + "-" + digits[16:20] + "-" + digits[20:]
}
// TestAuditHandlerListAndRollback records one audit entry in the integration
// database, then checks that it can be listed, rolled back, and that rolling
// back an unknown id is rejected with 400.
func TestAuditHandlerListAndRollback(t *testing.T) {
	ctx := setupHandlerAuditDB(t)
	svc := service.NewAuditService()
	id := handlerUUID(t)
	// Remove the recorded entry and any rollback rows derived from it.
	defer database.Pool.Exec(ctx, `DELETE FROM ai_ops_audits WHERE id=$1 OR parent_audit_id=$1 OR object_id=$1`, id)

	entry := &service.AuditLog{ID: id, TenantID: "tenant", ObjectType: "rule", ObjectID: id, Action: "update", BeforeState: map[string]any{"enabled": false}, AfterState: map[string]any{"enabled": true}, RequestID: "req", ResultCode: "SUCCESS", SourceIP: "127.0.0.1", ActorID: "actor", RiskLevel: "normal"}
	if err := svc.Record(ctx, entry); err != nil {
		t.Fatal(err)
	}

	mux := http.NewServeMux()
	NewAuditHandler(svc).RegisterRoutes(mux)
	// call issues a body-less request against the audit routes.
	call := func(method, target string) *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		mux.ServeHTTP(rec, httptest.NewRequest(method, target, nil))
		return rec
	}

	listResp := call(http.MethodGet, "/api/v1/ai-ops/audits?object_type=rule&object_id="+id+"&page=0&page_size=999")
	if !(listResp.Code == http.StatusOK && strings.Contains(listResp.Body.String(), id)) {
		t.Fatalf("list audits = %d %s", listResp.Code, listResp.Body.String())
	}

	rollbackResp := call(http.MethodPost, "/api/v1/ai-ops/audits/"+id+"/rollback")
	if !(rollbackResp.Code == http.StatusOK && strings.Contains(rollbackResp.Body.String(), `"action":"rollback"`)) {
		t.Fatalf("rollback = %d %s", rollbackResp.Code, rollbackResp.Body.String())
	}

	missingResp := call(http.MethodPost, "/api/v1/ai-ops/audits/"+handlerUUID(t)+"/rollback")
	if missingResp.Code != http.StatusBadRequest {
		t.Fatalf("missing rollback status = %d", missingResp.Code)
	}
}

View File

@@ -0,0 +1,118 @@
package handler
import (
"html/template"
"net/http"
)
// DashboardHandler serves the front-end dashboard pages.
type DashboardHandler struct {
	templates *template.Template
}

// NewDashboardHandler parses dashboardHTML into a template named "dashboard"
// and returns a handler that renders it. It panics if the markup does not
// parse.
func NewDashboardHandler() *DashboardHandler {
	return &DashboardHandler{
		templates: template.Must(template.New("dashboard").Parse(dashboardHTML)),
	}
}
// RegisterRoutes registers the dashboard page routes; every page path is
// served by the same Dashboard handler.
func (h *DashboardHandler) RegisterRoutes(mux *http.ServeMux) {
	pages := []string{
		"GET /ops/dashboard",
		"GET /ops/dashboard/logs",
		"GET /ops/dashboard/rules",
		"GET /ops/dashboard/alerts",
		"GET /ops/dashboard/channels",
	}
	for _, pattern := range pages {
		mux.HandleFunc(pattern, h.Dashboard)
	}
}
// Dashboard renders the "dashboard" template as UTF-8 HTML. A rendering error
// is deliberately discarded.
func (h *DashboardHandler) Dashboard(w http.ResponseWriter, r *http.Request) {
	header := w.Header()
	header.Set("Content-Type", "text/html; charset=utf-8")
	_ = h.templates.ExecuteTemplate(w, "dashboard", nil)
}
// dashboardHTML is the single-page dashboard markup that NewDashboardHandler
// parses under the template name "dashboard". Its inline script logs in via
// POST /api/v1/ai-ops/login, stores the token in localStorage under
// "ai_ops_token", and fills the metric cards and the alert, rule, channel and
// log tables from the JSON API. Cell values pass through escapeHtml before
// being assigned to innerHTML.
// NOTE(review): the login inputs are pre-filled with admin/admin — confirm
// this is intended outside local development.
const dashboardHTML = `
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<title>AI-Ops 运维看板</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; margin: 0; background: #0f172a; color: #e2e8f0; }
header { padding: 18px 28px; background: #111827; border-bottom: 1px solid #334155; display:flex; justify-content:space-between; align-items:center; }
main { padding: 24px; display: grid; gap: 18px; }
.grid { display: grid; grid-template-columns: repeat(4, minmax(140px, 1fr)); gap: 14px; }
.card { background:#111827; border:1px solid #334155; border-radius:12px; padding:16px; box-shadow:0 10px 24px rgba(0,0,0,.18); }
.metric { font-size: 28px; font-weight: 700; color:#38bdf8; margin-top:8px; }
button, input { border-radius:8px; border:1px solid #475569; background:#0b1220; color:#e2e8f0; padding:8px 10px; }
button { cursor:pointer; background:#2563eb; border-color:#2563eb; }
table { width:100%; border-collapse: collapse; font-size: 14px; }
th, td { border-bottom: 1px solid #334155; padding: 8px; text-align:left; vertical-align: top; }
th { color:#93c5fd; }
.muted { color:#94a3b8; }
.row { display:flex; gap:10px; align-items:center; flex-wrap:wrap; }
.error { color:#fca5a5; white-space:pre-wrap; }
code { color:#bae6fd; }
</style>
</head>
<body>
<header>
<div><strong>AI-Ops 运维看板</strong><span class="muted"> · 规则 / 事件 / 渠道 / 日志</span></div>
<div class="row"><input id="username" placeholder="admin" value="admin"><input id="password" type="password" placeholder="admin" value="admin"><button onclick="login()">登录</button><button onclick="loadAll()">刷新</button></div>
</header>
<main>
<section class="grid">
<div class="card">QPS<div id="qps" class="metric">-</div></div>
<div class="card">平均延迟<div id="avg" class="metric">-</div></div>
<div class="card">P99<div id="p99" class="metric">-</div></div>
<div class="card">错误率<div id="err" class="metric">-</div></div>
</section>
<section class="card"><h3>告警事件</h3><div id="alerts"></div></section>
<section class="card"><h3>告警规则</h3><div id="rules"></div></section>
<section class="card"><h3>通知渠道</h3><div id="channels"></div></section>
<section class="card"><h3>日志</h3><div id="logs"></div></section>
<section class="card error" id="error"></section>
</main>
<script>
const api = '/api/v1/ai-ops';
function token(){ return localStorage.getItem('ai_ops_token') || ''; }
function setError(e){ document.getElementById('error').textContent = e ? String(e) : ''; }
async function login(){
setError('');
const res = await fetch(api + '/login', {method:'POST', headers:{'Content-Type':'application/json'}, body: JSON.stringify({username: username.value, password: password.value})});
const data = await res.json();
const t = data?.data?.token || data?.token;
if(!res.ok || !t){ setError(JSON.stringify(data)); return; }
localStorage.setItem('ai_ops_token', t);
await loadAll();
}
async function get(path){
const res = await fetch(api + path, {headers:{Authorization:'Bearer ' + token()}});
const data = await res.json();
if(!res.ok) throw new Error(path + ' ' + JSON.stringify(data));
return data.data ?? data;
}
function table(rows, cols){
if(!Array.isArray(rows) || rows.length === 0) return '<p class="muted">暂无数据</p>';
return '<table><thead><tr>'+cols.map(c=>'<th>'+c[0]+'</th>').join('')+'</tr></thead><tbody>'+rows.map(r=>'<tr>'+cols.map(c=>'<td>'+escapeHtml(String(r[c[1]] ?? ''))+'</td>').join('')+'</tr>').join('')+'</tbody></table>';
}
function escapeHtml(s){ return s.replace(/[&<>"]/g, m=>({'&':'&amp;','<':'&lt;','>':'&gt;','"':'&quot;'}[m])); }
async function loadAll(){
try{
setError('');
const m = await get('/metrics/realtime');
qps.textContent = m.qps ?? '-'; avg.textContent = (m.avg_latency_ms ?? '-') + 'ms'; p99.textContent = (m.p99_latency_ms ?? '-') + 'ms'; err.textContent = m.error_rate ?? '-';
const ev = await get('/alerts?page=1&page_size=20');
alerts.innerHTML = table(ev.items || [], [['级别','level'],['资源','resource_id'],['状态','status'],['聚合','is_aggregated'],['数量','aggregated_count'],['开始时间','started_at']]);
const rs = await get('/rules');
rules.innerHTML = table(rs || [], [['名称','name'],['指标','metric_name'],['条件','threshold_type'],['阈值','threshold_value'],['级别','level'],['启用','enabled']]);
const cs = await get('/channels');
channels.innerHTML = table(cs || [], [['名称','name'],['类型','channel_type'],['优先级','priority'],['启用','enabled']]);
const lg = await get('/logs?page=1&page_size=20');
logs.innerHTML = table(lg.items || [], [['服务','service'],['级别','level'],['消息','message'],['时间','timestamp']]);
}catch(e){ setError(e); }
}
if(token()) loadAll();
</script>
</body>
</html>
`

View File

@@ -0,0 +1,29 @@
package handler
import (
"net/http"
"github.com/company/ai-ops/pkg/response"
)
// HealingHandler is the HTTP handler for self-healing management.
type HealingHandler struct{}

// NewHealingHandler returns a ready-to-use HealingHandler.
func NewHealingHandler() *HealingHandler {
	return new(HealingHandler)
}

// RegisterRoutes registers the healing list and detail routes on mux.
func (h *HealingHandler) RegisterRoutes(mux *http.ServeMux) {
	const base = "GET /api/v1/ai-ops/healings"
	mux.HandleFunc(base, h.ListHealings)
	mux.HandleFunc(base+"/{id}", h.GetHealing)
}
// ListHealings currently answers with an empty item list and a total of 0.
func (h *HealingHandler) ListHealings(w http.ResponseWriter, r *http.Request) {
	// TODO: implement the list query.
	placeholder := map[string]any{
		"items": []any{},
		"total": 0,
	}
	response.Success(w, placeholder)
}
// GetHealing echoes the {id} path value back with a fixed "pending" status.
func (h *HealingHandler) GetHealing(w http.ResponseWriter, r *http.Request) {
	payload := map[string]any{
		"id":     r.PathValue("id"),
		"status": "pending",
	}
	response.Success(w, payload)
}

View File

@@ -0,0 +1,62 @@
package handler
import (
"net/http"
"github.com/company/ai-ops/internal/database"
"github.com/company/ai-ops/internal/redis"
"github.com/company/ai-ops/pkg/response"
)
// HealthHandler is the HTTP handler for health checks.
type HealthHandler struct{}

// NewHealthHandler returns a ready-to-use HealthHandler.
func NewHealthHandler() *HealthHandler {
	return new(HealthHandler)
}

// RegisterRoutes registers the aggregate, liveness and readiness endpoints.
func (h *HealthHandler) RegisterRoutes(mux *http.ServeMux) {
	const base = "GET /actuator/health"
	mux.HandleFunc(base, h.Health)
	mux.HandleFunc(base+"/live", h.Live)
	mux.HandleFunc(base+"/ready", h.Ready)
}
// Health always reports UP, with a single "self" component that is also UP.
func (h *HealthHandler) Health(w http.ResponseWriter, r *http.Request) {
	self := map[string]any{"status": "UP"}
	components := map[string]any{"self": self}
	response.Success(w, map[string]any{"status": "UP", "components": components})
}
// Live is the liveness probe; it always reports UP.
func (h *HealthHandler) Live(w http.ResponseWriter, r *http.Request) {
	body := map[string]any{"status": "UP"}
	response.Success(w, body)
}
// Ready is the readiness probe. It reports each dependency as UP when its
// package-level client is non-nil. Only a missing database pool turns the
// overall status to DOWN and the HTTP status to 503; a missing Redis client is
// listed as DOWN in the components but leaves the overall status unchanged.
// NOTE(review): confirm Redis is meant to be optional for readiness.
// NOTE(review): WriteHeader runs before response.Success; if that helper sets
// headers or a status itself they would come too late — verify.
func (h *HealthHandler) Ready(w http.ResponseWriter, r *http.Request) {
	up := func() map[string]any { return map[string]any{"status": "UP"} }
	down := func() map[string]any {
		return map[string]any{"status": "DOWN", "detail": "not initialized"}
	}

	components := map[string]any{
		"self":     up(),
		"database": up(),
		"redis":    up(),
	}
	// Only nil-ness is checked here; no query or ping is issued.
	databaseMissing := database.Pool == nil
	if databaseMissing {
		components["database"] = down()
	}
	if redis.Client == nil {
		components["redis"] = down()
	}

	status := "UP"
	if databaseMissing {
		status = "DOWN"
		w.WriteHeader(http.StatusServiceUnavailable)
	}
	response.Success(w, map[string]any{"status": status, "components": components})
}

View File

@@ -0,0 +1,109 @@
package handler
import (
"net/http"
"strconv"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/service"
"github.com/company/ai-ops/pkg/errors"
"github.com/company/ai-ops/pkg/response"
)
// LogHandler is the HTTP handler for request logs.
type LogHandler struct {
	service *service.LogService
}

// NewLogHandler returns a LogHandler backed by s.
func NewLogHandler(s *service.LogService) *LogHandler {
	handler := LogHandler{service: s}
	return &handler
}

// RegisterRoutes registers the log query and CSV export routes.
func (h *LogHandler) RegisterRoutes(mux *http.ServeMux) {
	const base = "GET /api/v1/ai-ops/logs"
	mux.HandleFunc(base, h.QueryLogs)
	mux.HandleFunc(base+"/export", h.ExportLogs)
}
// QueryLogs 日志查询
func (h *LogHandler) QueryLogs(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
filter := model.LogQueryFilter{
Service: query.Get("service"),
Path: query.Get("path"),
UserID: query.Get("user_id"),
SupplierID: query.Get("supplier_id"),
}
if startStr := query.Get("start"); startStr != "" {
if t, err := time.Parse(time.RFC3339, startStr); err == nil {
filter.StartTime = &t
}
}
if endStr := query.Get("end"); endStr != "" {
if t, err := time.Parse(time.RFC3339, endStr); err == nil {
filter.EndTime = &t
}
}
if codeStr := query.Get("status_code"); codeStr != "" {
if code, err := strconv.Atoi(codeStr); err == nil {
filter.StatusCode = &code
}
}
if page, err := strconv.Atoi(query.Get("page")); err == nil && page > 0 {
filter.Page = page
} else {
filter.Page = 1
}
if pageSize, err := strconv.Atoi(query.Get("page_size")); err == nil && pageSize > 0 && pageSize <= 100 {
filter.PageSize = pageSize
} else {
filter.PageSize = 20
}
logs, total, err := h.service.QueryLogs(r.Context(), filter)
if err != nil {
response.Error(w, errors.Wrap(err, errors.ErrInternal))
return
}
response.PaginatedResponse(w, logs, total, filter.Page, filter.PageSize)
}
// ExportLogs 导出日志为 CSV
func (h *LogHandler) ExportLogs(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
filter := model.LogQueryFilter{
Service: query.Get("service"),
Path: query.Get("path"),
UserID: query.Get("user_id"),
SupplierID: query.Get("supplier_id"),
}
if startStr := query.Get("start"); startStr != "" {
if t, err := time.Parse(time.RFC3339, startStr); err == nil {
filter.StartTime = &t
}
}
if endStr := query.Get("end"); endStr != "" {
if t, err := time.Parse(time.RFC3339, endStr); err == nil {
filter.EndTime = &t
}
}
if codeStr := query.Get("status_code"); codeStr != "" {
if code, err := strconv.Atoi(codeStr); err == nil {
filter.StatusCode = &code
}
}
w.Header().Set("Content-Type", "text/csv; charset=utf-8")
w.Header().Set("Content-Disposition", "attachment; filename=logs_"+time.Now().Format("20060102_150405")+".csv")
if err := h.service.ExportLogsCSV(r.Context(), filter, w); err != nil {
response.Error(w, errors.Wrap(err, errors.ErrInternal))
return
}
}

View File

@@ -0,0 +1,86 @@
package handler
import (
"net/http"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/service"
"github.com/company/ai-ops/pkg/errors"
"github.com/company/ai-ops/pkg/response"
)
// MetricHandler is the HTTP handler for metrics.
type MetricHandler struct {
	service *service.MetricService
}

// NewMetricHandler returns a MetricHandler backed by s.
func NewMetricHandler(s *service.MetricService) *MetricHandler {
	handler := MetricHandler{service: s}
	return &handler
}

// RegisterRoutes registers the metric routes, including the open-alert count
// that lives under the alerts path.
func (h *MetricHandler) RegisterRoutes(mux *http.ServeMux) {
	routes := []struct {
		pattern string
		handle  http.HandlerFunc
	}{
		{"GET /api/v1/ai-ops/metrics/realtime", h.GetRealtime},
		{"GET /api/v1/ai-ops/metrics/suppliers/count", h.GetSupplierCount},
		{"GET /api/v1/ai-ops/alerts/open/count", h.GetOpenAlertCount},
		{"GET /api/v1/ai-ops/metrics/query", h.QueryMetrics},
	}
	for _, route := range routes {
		mux.HandleFunc(route.pattern, route.handle)
	}
}
// GetRealtime returns the realtime metrics, or an internal error when the
// service fails.
func (h *MetricHandler) GetRealtime(w http.ResponseWriter, r *http.Request) {
	snapshot, err := h.service.GetRealtimeMetrics(r.Context())
	if err == nil {
		response.Success(w, snapshot)
		return
	}
	response.Error(w, errors.Wrap(err, errors.ErrInternal))
}
// GetSupplierCount returns the number of active suppliers, or an internal
// error when the service fails.
func (h *MetricHandler) GetSupplierCount(w http.ResponseWriter, r *http.Request) {
	suppliers, err := h.service.GetSupplierCount(r.Context())
	if err == nil {
		response.Success(w, suppliers)
		return
	}
	response.Error(w, errors.Wrap(err, errors.ErrInternal))
}
// GetOpenAlertCount returns the number of alerts that are not closed, or an
// internal error when the service fails.
func (h *MetricHandler) GetOpenAlertCount(w http.ResponseWriter, r *http.Request) {
	openAlerts, err := h.service.GetOpenAlertCount(r.Context())
	if err == nil {
		response.Success(w, openAlerts)
		return
	}
	response.Error(w, errors.Wrap(err, errors.ErrInternal))
}
// QueryMetrics runs a drill-down metric query. source and name are passed
// through as given; start and end are applied only when they are valid
// RFC 3339 timestamps and otherwise stay at the zero time.
func (h *MetricHandler) QueryMetrics(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()
	var req model.MetricQueryRequest
	req.Source = params.Get("source")
	req.Name = params.Get("name")

	// An empty or malformed value fails to parse and leaves the field untouched.
	if start, err := time.Parse(time.RFC3339, params.Get("start")); err == nil {
		req.StartTime = start
	}
	if end, err := time.Parse(time.RFC3339, params.Get("end")); err == nil {
		req.EndTime = end
	}

	points, err := h.service.QueryMetrics(r.Context(), req)
	if err == nil {
		response.Success(w, points)
		return
	}
	response.Error(w, errors.Wrap(err, errors.ErrInternal))
}

View File

@@ -0,0 +1,93 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// mockMetricRepo is a testify mock of the metric repository.
//
// The first return value of each method is read with a comma-ok type
// assertion, so an expectation configured as Return(nil, err) yields a nil
// result instead of panicking on the assertion.
type mockMetricRepo struct{ mock.Mock }

// GetRealtime returns the configured realtime metrics and error.
func (m *mockMetricRepo) GetRealtime(ctx context.Context) (*model.RealtimeMetrics, error) {
	args := m.Called(ctx)
	metrics, _ := args.Get(0).(*model.RealtimeMetrics)
	return metrics, args.Error(1)
}

// Query returns the configured metric points and error.
func (m *mockMetricRepo) Query(ctx context.Context, req model.MetricQueryRequest) ([]model.MetricPoint, error) {
	args := m.Called(ctx, req)
	points, _ := args.Get(0).([]model.MetricPoint)
	return points, args.Error(1)
}

// GetLatest returns the configured latest point and error.
func (m *mockMetricRepo) GetLatest(ctx context.Context, source, name string) (*model.MetricPoint, error) {
	args := m.Called(ctx, source, name)
	point, _ := args.Get(0).(*model.MetricPoint)
	return point, args.Error(1)
}
// mockAlertRepo is a testify mock of the alert repository.
//
// Pointer and slice results are read with comma-ok type assertions, so an
// expectation configured as Return(nil, err) yields a nil result instead of
// panicking on the assertion.
type mockAlertRepo struct{ mock.Mock }

// GetOpenCount returns the configured alert count and error.
func (m *mockAlertRepo) GetOpenCount(ctx context.Context) (*model.AlertCount, error) {
	args := m.Called(ctx)
	count, _ := args.Get(0).(*model.AlertCount)
	return count, args.Error(1)
}

// ListRules returns the configured rules and error.
func (m *mockAlertRepo) ListRules(ctx context.Context) ([]model.AlertRule, error) {
	args := m.Called(ctx)
	rules, _ := args.Get(0).([]model.AlertRule)
	return rules, args.Error(1)
}

// GetRuleByID returns the configured rule and error.
func (m *mockAlertRepo) GetRuleByID(ctx context.Context, id string) (*model.AlertRule, error) {
	args := m.Called(ctx, id)
	rule, _ := args.Get(0).(*model.AlertRule)
	return rule, args.Error(1)
}

// CreateRule returns the configured error.
func (m *mockAlertRepo) CreateRule(ctx context.Context, rule *model.AlertRule) error {
	return m.Called(ctx, rule).Error(0)
}

// UpdateRule returns the configured error.
func (m *mockAlertRepo) UpdateRule(ctx context.Context, rule *model.AlertRule) error {
	return m.Called(ctx, rule).Error(0)
}

// DeleteRule returns the configured error.
func (m *mockAlertRepo) DeleteRule(ctx context.Context, id string) error {
	return m.Called(ctx, id).Error(0)
}

// ListEvents returns the configured events, total and error.
func (m *mockAlertRepo) ListEvents(ctx context.Context, status string, page, pageSize int) ([]model.AlertEvent, int, error) {
	args := m.Called(ctx, status, page, pageSize)
	events, _ := args.Get(0).([]model.AlertEvent)
	return events, args.Int(1), args.Error(2)
}

// CreateEvent returns the configured error.
func (m *mockAlertRepo) CreateEvent(ctx context.Context, event *model.AlertEvent) error {
	return m.Called(ctx, event).Error(0)
}

// CreateEventWithAggregation returns the configured event and error.
func (m *mockAlertRepo) CreateEventWithAggregation(ctx context.Context, event *model.AlertEvent, window time.Duration, threshold int) (*model.AlertEvent, error) {
	args := m.Called(ctx, event, window, threshold)
	created, _ := args.Get(0).(*model.AlertEvent)
	return created, args.Error(1)
}

// UpdateEventStatus returns the configured error.
func (m *mockAlertRepo) UpdateEventStatus(ctx context.Context, id, status string) error {
	return m.Called(ctx, id, status).Error(0)
}

// EscalateEvent returns the configured error.
func (m *mockAlertRepo) EscalateEvent(ctx context.Context, id, newLevel string) error {
	return m.Called(ctx, id, newLevel).Error(0)
}
// TestMetricHandler_GetRealtime checks that the handler answers 200 and
// serialises the QPS value supplied by the mocked metric repository.
func TestMetricHandler_GetRealtime(t *testing.T) {
	metricRepo := &mockMetricRepo{}
	alertRepo := &mockAlertRepo{}
	metricRepo.
		On("GetRealtime", mock.Anything).
		Return(&model.RealtimeMetrics{QPS: 100, AvgLatency: 50, P99Latency: 100, ErrorRate: 0.01}, nil)

	handler := NewMetricHandler(service.NewMetricService(metricRepo, alertRepo))
	rec := httptest.NewRecorder()
	handler.GetRealtime(rec, httptest.NewRequest("GET", "/api/v1/ai-ops/metrics/realtime", nil))

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Contains(t, rec.Body.String(), `"qps":100`)
}

View File

@@ -0,0 +1,84 @@
package handler
import (
"net/http"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/service"
"github.com/company/ai-ops/pkg/errors"
"github.com/company/ai-ops/pkg/response"
)
// RuleHandler is the HTTP handler for alert rules.
type RuleHandler struct {
	service *service.RuleService
}

// NewRuleHandler returns a RuleHandler backed by s.
func NewRuleHandler(s *service.RuleService) *RuleHandler {
	handler := RuleHandler{service: s}
	return &handler
}

// RegisterRoutes registers the list, get, create, update and delete routes
// for alert rules.
func (h *RuleHandler) RegisterRoutes(mux *http.ServeMux) {
	routes := []struct {
		pattern string
		handle  http.HandlerFunc
	}{
		{"GET /api/v1/ai-ops/rules", h.ListRules},
		{"GET /api/v1/ai-ops/rules/{id}", h.GetRule},
		{"POST /api/v1/ai-ops/rules", h.CreateRule},
		{"PUT /api/v1/ai-ops/rules/{id}", h.UpdateRule},
		{"DELETE /api/v1/ai-ops/rules/{id}", h.DeleteRule},
	}
	for _, route := range routes {
		mux.HandleFunc(route.pattern, route.handle)
	}
}
// ListRules returns the rules provided by the service, or an internal error
// when the service fails.
func (h *RuleHandler) ListRules(w http.ResponseWriter, r *http.Request) {
	list, err := h.service.ListRules(r.Context())
	if err == nil {
		response.Success(w, list)
		return
	}
	response.Error(w, errors.Wrap(err, errors.ErrInternal))
}
// GetRule returns the rule named by the {id} path value. Any service error is
// reported as not found.
func (h *RuleHandler) GetRule(w http.ResponseWriter, r *http.Request) {
	found, err := h.service.GetRule(r.Context(), r.PathValue("id"))
	if err == nil {
		response.Success(w, found)
		return
	}
	response.Error(w, errors.Wrap(err, errors.ErrNotFound))
}
// CreateRule decodes a rule from the JSON body, creates it through the
// service, and answers 201 with the rule. A malformed body or a service error
// is reported as a bad request.
// NOTE(review): WriteHeader runs before response.Success; if that helper sets
// headers or a status itself they would come too late — verify.
func (h *RuleHandler) CreateRule(w http.ResponseWriter, r *http.Request) {
	var rule model.AlertRule
	decodeErr := decodeJSON(r, &rule)
	if decodeErr != nil {
		detail := map[string]any{"error": decodeErr.Error()}
		response.Error(w, errors.ErrBadRequest.WithDetail(detail))
		return
	}
	createErr := h.service.CreateRule(r.Context(), &rule)
	if createErr != nil {
		response.Error(w, errors.Wrap(createErr, errors.ErrBadRequest))
		return
	}
	w.WriteHeader(http.StatusCreated)
	response.Success(w, rule)
}
// UpdateRule decodes a rule from the JSON body, overrides its ID with the
// {id} path value, and updates it through the service. A malformed body or a
// service error is reported as a bad request.
func (h *RuleHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
	ruleID := r.PathValue("id")
	var rule model.AlertRule
	decodeErr := decodeJSON(r, &rule)
	if decodeErr != nil {
		detail := map[string]any{"error": decodeErr.Error()}
		response.Error(w, errors.ErrBadRequest.WithDetail(detail))
		return
	}
	// The path is authoritative: any id supplied in the body is replaced.
	rule.ID = ruleID
	updateErr := h.service.UpdateRule(r.Context(), &rule)
	if updateErr != nil {
		response.Error(w, errors.Wrap(updateErr, errors.ErrBadRequest))
		return
	}
	response.Success(w, rule)
}
// DeleteRule deletes the rule named by the {id} path value and answers 204,
// or an internal error when the service fails.
func (h *RuleHandler) DeleteRule(w http.ResponseWriter, r *http.Request) {
	err := h.service.DeleteRule(r.Context(), r.PathValue("id"))
	if err == nil {
		w.WriteHeader(http.StatusNoContent)
		return
	}
	response.Error(w, errors.Wrap(err, errors.ErrInternal))
}

10
internal/handler/utils.go Normal file
View File

@@ -0,0 +1,10 @@
package handler
import (
"encoding/json"
"net/http"
)
// decodeJSON decodes the first JSON value of the request body into v and
// returns the decoder's error, if any.
func decodeJSON(r *http.Request, v any) error {
	decoder := json.NewDecoder(r.Body)
	return decoder.Decode(v)
}

View File

@@ -0,0 +1,314 @@
package repository
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"time"
"github.com/company/ai-ops/internal/database"
"github.com/company/ai-ops/internal/domain/model"
"github.com/jackc/pgx/v5"
)
// PGAlertRepository is the PostgreSQL-backed alert store. It holds no state;
// its methods use the package-level database.Pool.
type PGAlertRepository struct{}

// NewPGAlertRepository returns a ready-to-use PGAlertRepository.
func NewPGAlertRepository() *PGAlertRepository {
	return new(PGAlertRepository)
}
// GetOpenCount counts the alerts whose status is not 'resolved', in total and
// per level P0 through P3, with a single query over ai_ops_alerts.
func (r *PGAlertRepository) GetOpenCount(ctx context.Context) (*model.AlertCount, error) {
	const countOpenAlerts = `
SELECT
COUNT(*) FILTER (WHERE status != 'resolved') AS open_count,
COUNT(*) FILTER (WHERE status != 'resolved' AND level = 'P0') AS p0_count,
COUNT(*) FILTER (WHERE status != 'resolved' AND level = 'P1') AS p1_count,
COUNT(*) FILTER (WHERE status != 'resolved' AND level = 'P2') AS p2_count,
COUNT(*) FILTER (WHERE status != 'resolved' AND level = 'P3') AS p3_count
FROM ai_ops_alerts
`
	result := new(model.AlertCount)
	row := database.Pool.QueryRow(ctx, countOpenAlerts)
	if err := row.Scan(&result.Open, &result.P0, &result.P1, &result.P2, &result.P3); err != nil {
		return nil, fmt.Errorf("query alert count: %w", err)
	}
	return result, nil
}
// ListRules returns the enabled rules, newest first. The result is an empty,
// non-nil slice when no rule matches.
func (r *PGAlertRepository) ListRules(ctx context.Context) ([]model.AlertRule, error) {
	const selectEnabledRules = `
SELECT id, name, metric_source, metric_name, threshold_type, threshold_value,
duration_min, level, channel_ids, healing_action, healing_config,
is_sandboxed, enabled, version, created_by, created_at, updated_at
FROM ai_ops_rules
WHERE enabled = true
ORDER BY created_at DESC
`
	rows, err := database.Pool.Query(ctx, selectEnabledRules)
	if err != nil {
		return nil, fmt.Errorf("query rules: %w", err)
	}
	defer rows.Close()

	result := []model.AlertRule{}
	for rows.Next() {
		var (
			rule     model.AlertRule
			channels []string
		)
		scanErr := rows.Scan(
			&rule.ID, &rule.Name, &rule.MetricSource, &rule.MetricName, &rule.ThresholdType, &rule.ThresholdValue,
			&rule.DurationMin, &rule.Level, &channels, &rule.HealingAction, &rule.HealingConfig,
			&rule.IsSandboxed, &rule.Enabled, &rule.Version, &rule.CreatedBy, &rule.CreatedAt, &rule.UpdatedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("scan rule: %w", scanErr)
		}
		rule.ChannelIDs = channels
		result = append(result, rule)
	}
	// rows.Err reports a failure that ended the iteration early.
	return result, rows.Err()
}
// GetRuleByID loads one rule by id. It returns a "rule not found" error when
// no row matches and a wrapped error for any other query failure.
func (r *PGAlertRepository) GetRuleByID(ctx context.Context, id string) (*model.AlertRule, error) {
	const selectRule = `
SELECT id, name, metric_source, metric_name, threshold_type, threshold_value,
duration_min, level, channel_ids, healing_action, healing_config,
is_sandboxed, enabled, version, created_by, created_at, updated_at
FROM ai_ops_rules WHERE id = $1
`
	var (
		rule     model.AlertRule
		channels []string
	)
	err := database.Pool.QueryRow(ctx, selectRule, id).Scan(
		&rule.ID, &rule.Name, &rule.MetricSource, &rule.MetricName, &rule.ThresholdType, &rule.ThresholdValue,
		&rule.DurationMin, &rule.Level, &channels, &rule.HealingAction, &rule.HealingConfig,
		&rule.IsSandboxed, &rule.Enabled, &rule.Version, &rule.CreatedBy, &rule.CreatedAt, &rule.UpdatedAt,
	)
	switch {
	case err == pgx.ErrNoRows:
		return nil, fmt.Errorf("rule not found")
	case err != nil:
		return nil, fmt.Errorf("query rule: %w", err)
	}
	rule.ChannelIDs = channels
	return &rule, nil
}
// CreateRule inserts rule into ai_ops_rules. created_at and updated_at are
// set to NOW() by the statement; every other column comes from rule.
func (r *PGAlertRepository) CreateRule(ctx context.Context, rule *model.AlertRule) error {
	const insertRule = `
INSERT INTO ai_ops_rules (id, name, metric_source, metric_name, threshold_type, threshold_value,
duration_min, level, channel_ids, healing_action, healing_config,
is_sandboxed, enabled, version, created_by, created_at, updated_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,NOW(),NOW())
`
	if _, err := database.Pool.Exec(ctx, insertRule,
		rule.ID, rule.Name, rule.MetricSource, rule.MetricName, rule.ThresholdType, rule.ThresholdValue,
		rule.DurationMin, rule.Level, rule.ChannelIDs, rule.HealingAction, rule.HealingConfig,
		rule.IsSandboxed, rule.Enabled, rule.Version, rule.CreatedBy,
	); err != nil {
		return fmt.Errorf("insert rule: %w", err)
	}
	return nil
}
// UpdateRule overwrites the mutable columns of the rule identified by rule.ID
// and sets updated_at to NOW(). The number of affected rows is not checked,
// so an unknown id is not reported as an error.
func (r *PGAlertRepository) UpdateRule(ctx context.Context, rule *model.AlertRule) error {
	const updateRule = `
UPDATE ai_ops_rules SET
name=$2, metric_source=$3, metric_name=$4, threshold_type=$5, threshold_value=$6,
duration_min=$7, level=$8, channel_ids=$9, healing_action=$10, healing_config=$11,
is_sandboxed=$12, enabled=$13, version=$14, updated_at=NOW()
WHERE id=$1
`
	if _, err := database.Pool.Exec(ctx, updateRule,
		rule.ID, rule.Name, rule.MetricSource, rule.MetricName, rule.ThresholdType, rule.ThresholdValue,
		rule.DurationMin, rule.Level, rule.ChannelIDs, rule.HealingAction, rule.HealingConfig,
		rule.IsSandboxed, rule.Enabled, rule.Version,
	); err != nil {
		return fmt.Errorf("update rule: %w", err)
	}
	return nil
}
// DeleteRule deletes the rule with the given id. The number of affected rows
// is not checked, so deleting an unknown id succeeds.
func (r *PGAlertRepository) DeleteRule(ctx context.Context, id string) error {
	const deleteRule = `DELETE FROM ai_ops_rules WHERE id = $1`
	if _, err := database.Pool.Exec(ctx, deleteRule, id); err != nil {
		return fmt.Errorf("delete rule: %w", err)
	}
	return nil
}
// ListEvents returns one page of alert events, newest first, together with
// the total number of matching rows. An empty status matches every event.
// page falls back to 1 when below 1; pageSize falls back to 20 when outside
// 1..100. The events slice is empty but non-nil when the page has no rows.
func (r *PGAlertRepository) ListEvents(ctx context.Context, status string, page, pageSize int) ([]model.AlertEvent, int, error) {
	where := ""
	args := []any{}
	if status != "" {
		where = "WHERE status = $1"
		args = append(args, status)
	}

	var total int
	countSQL := fmt.Sprintf("SELECT COUNT(*) FROM ai_ops_alerts %s", where)
	if err := database.Pool.QueryRow(ctx, countSQL, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("count events: %w", err)
	}

	if page < 1 {
		page = 1
	}
	if pageSize < 1 || pageSize > 100 {
		pageSize = 20
	}

	// LIMIT and OFFSET take the placeholder numbers after the optional
	// status placeholder.
	limitPos := len(args) + 1
	offsetPos := limitPos + 1
	pageSQL := fmt.Sprintf(`
SELECT id, rule_id, level, resource_type, resource_id, current_value, threshold_value,
status, is_aggregated, aggregated_count, parent_alert_id, started_at, resolved_at
FROM ai_ops_alerts %s
ORDER BY started_at DESC
LIMIT $%d OFFSET $%d
`, where, limitPos, offsetPos)
	pageArgs := append(args, pageSize, (page-1)*pageSize)

	rows, err := database.Pool.Query(ctx, pageSQL, pageArgs...)
	if err != nil {
		return nil, 0, fmt.Errorf("query events: %w", err)
	}
	defer rows.Close()

	result := []model.AlertEvent{}
	for rows.Next() {
		var event model.AlertEvent
		scanErr := rows.Scan(
			&event.ID, &event.RuleID, &event.Level, &event.ResourceType, &event.ResourceID,
			&event.CurrentValue, &event.ThresholdValue, &event.Status, &event.IsAggregated, &event.AggregatedCount,
			&event.ParentAlertID, &event.StartedAt, &event.ResolvedAt,
		)
		if scanErr != nil {
			return nil, 0, fmt.Errorf("scan event: %w", scanErr)
		}
		result = append(result, event)
	}
	return result, total, rows.Err()
}
// CreateEvent inserts event without aggregation by calling
// CreateEventWithAggregation with a zero window and a zero threshold.
func (r *PGAlertRepository) CreateEvent(ctx context.Context, event *model.AlertEvent) error {
	if _, err := r.CreateEventWithAggregation(ctx, event, 0, 0); err != nil {
		return err
	}
	return nil
}
// CreateEventWithAggregation inserts event and, when aggregation is enabled,
// may fold recent alerts for the same resource into one aggregated alert. All
// statements run in a single transaction.
//
// If event.StartedAt is zero it is set to the current time, and the value is
// written back to event. When window or threshold is not positive, the event
// is committed and returned as is. Otherwise the function counts alerts with
// the same resource_type and resource_id that started within window before
// startedAt, are not aggregated, and have no parent. If that count does not
// exceed threshold the event is committed and returned. If it does, a new
// aggregated alert is inserted, every counted alert gets it as
// parent_alert_id, and the aggregated alert is returned.
func (r *PGAlertRepository) CreateEventWithAggregation(ctx context.Context, event *model.AlertEvent, window time.Duration, threshold int) (*model.AlertEvent, error) {
	tx, err := database.Pool.Begin(ctx)
	if err != nil {
		return nil, fmt.Errorf("begin create event: %w", err)
	}
	// Undo partial work on every early return.
	// NOTE(review): presumably harmless after a successful Commit — confirm.
	defer tx.Rollback(ctx)
	startedAt := event.StartedAt
	if startedAt.IsZero() {
		startedAt = time.Now()
	}
	_, err = tx.Exec(ctx, `
INSERT INTO ai_ops_alerts (id, rule_id, level, resource_type, resource_id,
current_value, threshold_value, status, is_aggregated, aggregated_count, parent_alert_id, started_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12)
`, event.ID, event.RuleID, event.Level, event.ResourceType, event.ResourceID,
		event.CurrentValue, event.ThresholdValue, event.Status, event.IsAggregated,
		event.AggregatedCount, event.ParentAlertID, startedAt)
	if err != nil {
		return nil, fmt.Errorf("insert event: %w", err)
	}
	// Expose the effective start time to the caller.
	event.StartedAt = startedAt
	// Aggregation disabled: store the single event only.
	if window <= 0 || threshold <= 0 {
		if err := tx.Commit(ctx); err != nil {
			return nil, fmt.Errorf("commit event: %w", err)
		}
		return event, nil
	}
	// Count aggregation candidates inside the window. The query runs in the
	// same transaction as the insert above, so the new row is counted when it
	// satisfies the filters. The filter does not include rule_id.
	var count int
	err = tx.QueryRow(ctx, `
SELECT COUNT(*)
FROM ai_ops_alerts
WHERE resource_type = $1
AND resource_id = $2
AND started_at >= $3
AND is_aggregated = false
AND parent_alert_id IS NULL
`, event.ResourceType, event.ResourceID, startedAt.Add(-window)).Scan(&count)
	if err != nil {
		return nil, fmt.Errorf("count aggregation candidates: %w", err)
	}
	// Aggregation starts only when the count is strictly above threshold.
	if count <= threshold {
		if err := tx.Commit(ctx); err != nil {
			return nil, fmt.Errorf("commit event: %w", err)
		}
		return event, nil
	}
	// The aggregated alert copies rule, level, resource, value and status
	// from the triggering event and records how many alerts it covers.
	aggregated := &model.AlertEvent{
		ID:              newUUID(),
		RuleID:          event.RuleID,
		Level:           event.Level,
		ResourceType:    event.ResourceType,
		ResourceID:      event.ResourceID,
		CurrentValue:    event.CurrentValue,
		ThresholdValue:  fmt.Sprintf("cluster_count>%d", threshold),
		Status:          event.Status,
		IsAggregated:    true,
		AggregatedCount: count,
		StartedAt:       startedAt,
	}
	_, err = tx.Exec(ctx, `
INSERT INTO ai_ops_alerts (id, rule_id, level, resource_type, resource_id,
current_value, threshold_value, status, is_aggregated, aggregated_count, started_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,true,$9,$10)
`, aggregated.ID, aggregated.RuleID, aggregated.Level, aggregated.ResourceType, aggregated.ResourceID,
		aggregated.CurrentValue, aggregated.ThresholdValue, aggregated.Status, aggregated.AggregatedCount, aggregated.StartedAt)
	if err != nil {
		return nil, fmt.Errorf("insert aggregated event: %w", err)
	}
	// Attach the counted alerts to the aggregated alert. The filter repeats
	// the count query; the aggregated row itself is excluded by
	// is_aggregated = false.
	_, err = tx.Exec(ctx, `
UPDATE ai_ops_alerts
SET parent_alert_id = $1
WHERE resource_type = $2
AND resource_id = $3
AND started_at >= $4
AND is_aggregated = false
AND parent_alert_id IS NULL
`, aggregated.ID, event.ResourceType, event.ResourceID, startedAt.Add(-window))
	if err != nil {
		return nil, fmt.Errorf("attach aggregated children: %w", err)
	}
	if err := tx.Commit(ctx); err != nil {
		return nil, fmt.Errorf("commit aggregated event: %w", err)
	}
	return aggregated, nil
}
// UpdateEventStatus sets the status of the alert identified by id. When the
// new status is "resolved" the row's resolved_at is set to NOW(); for any
// other status resolved_at is reset to NULL.
func (r *PGAlertRepository) UpdateEventStatus(ctx context.Context, id, status string) error {
	// Only one of two fixed SQL keywords is spliced into the statement; the
	// caller-supplied id and status are always bound as parameters.
	resolvedExpr := "NOW()"
	if status != "resolved" {
		resolvedExpr = "NULL"
	}
	stmt := fmt.Sprintf(`
UPDATE ai_ops_alerts SET status = $2, resolved_at = %s WHERE id = $1
`, resolvedExpr)
	if _, err := database.Pool.Exec(ctx, stmt, id, status); err != nil {
		return fmt.Errorf("update event status: %w", err)
	}
	return nil
}
// EscalateEvent overwrites the level of the alert identified by id with newLevel.
func (r *PGAlertRepository) EscalateEvent(ctx context.Context, id, newLevel string) error {
	const stmt = `UPDATE ai_ops_alerts SET level = $2 WHERE id = $1`
	if _, err := database.Pool.Exec(ctx, stmt, id, newLevel); err != nil {
		return fmt.Errorf("escalate event: %w", err)
	}
	return nil
}
// newUUID returns a random version-4 UUID in lowercase 8-4-4-4-12 hex form.
// If reading random bytes fails, it returns a fixed-prefix identifier whose
// last 12 digits are derived from the current time in nanoseconds.
func newUUID() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		nanos := time.Now().UnixNano() % 1_000_000_000_000
		return fmt.Sprintf("00000000-0000-4000-8000-%012d", nanos)
	}
	raw[6] = (raw[6] & 0x0f) | 0x40 // high nibble 4: version field
	raw[8] = (raw[8] & 0x3f) | 0x80 // top bits 10: variant field
	digits := hex.EncodeToString(raw[:])
	return digits[0:8] + "-" + digits[8:12] + "-" + digits[12:16] + "-" + digits[16:20] + "-" + digits[20:32]
}

View File

@@ -0,0 +1,87 @@
package repository
import (
"context"
"fmt"
"github.com/company/ai-ops/internal/database"
"github.com/company/ai-ops/internal/domain/model"
"github.com/jackc/pgx/v5"
)
// PGChannelRepository is the PostgreSQL-backed store for notification channels.
type PGChannelRepository struct{}

// NewPGChannelRepository returns a new, stateless PGChannelRepository.
func NewPGChannelRepository() *PGChannelRepository {
	return new(PGChannelRepository)
}
// List returns the enabled notification channels, highest priority first and
// newest first within the same priority. The result is an empty, non-nil
// slice when no channel is enabled.
func (r *PGChannelRepository) List(ctx context.Context) ([]model.NotificationChannel, error) {
	rows, err := database.Pool.Query(ctx, `
SELECT id, name, channel_type, config, priority, enabled, created_at
FROM ai_ops_channels
WHERE enabled = true
ORDER BY priority DESC, created_at DESC
`)
	if err != nil {
		return nil, fmt.Errorf("query channels: %w", err)
	}
	defer rows.Close()

	result := []model.NotificationChannel{}
	for rows.Next() {
		var ch model.NotificationChannel
		scanErr := rows.Scan(&ch.ID, &ch.Name, &ch.ChannelType, &ch.Config, &ch.Priority, &ch.Enabled, &ch.CreatedAt)
		if scanErr != nil {
			return nil, fmt.Errorf("scan channel: %w", scanErr)
		}
		result = append(result, ch)
	}
	// Rows read so far are returned together with any iteration error.
	err = rows.Err()
	return result, err
}
// GetByID loads the channel with the given id, whether or not it is enabled.
// It returns a "channel not found" error when no row matches.
func (r *PGChannelRepository) GetByID(ctx context.Context, id string) (*model.NotificationChannel, error) {
	row := database.Pool.QueryRow(ctx, `
SELECT id, name, channel_type, config, priority, enabled, created_at
FROM ai_ops_channels
WHERE id = $1
`, id)

	var ch model.NotificationChannel
	switch err := row.Scan(&ch.ID, &ch.Name, &ch.ChannelType, &ch.Config, &ch.Priority, &ch.Enabled, &ch.CreatedAt); {
	case err == pgx.ErrNoRows:
		return nil, fmt.Errorf("channel not found")
	case err != nil:
		return nil, fmt.Errorf("query channel: %w", err)
	}
	return &ch, nil
}
// Create inserts ch as a new channel row. created_at is set by the database
// (NOW()), not taken from ch.
func (r *PGChannelRepository) Create(ctx context.Context, ch *model.NotificationChannel) error {
	const stmt = `
INSERT INTO ai_ops_channels (id, name, channel_type, config, priority, enabled, created_at)
VALUES ($1, $2, $3, $4, $5, $6, NOW())
`
	if _, err := database.Pool.Exec(ctx, stmt, ch.ID, ch.Name, ch.ChannelType, ch.Config, ch.Priority, ch.Enabled); err != nil {
		return fmt.Errorf("insert channel: %w", err)
	}
	return nil
}
// Update overwrites name, type, config, priority and enabled flag of the
// channel whose id equals ch.ID. The number of affected rows is not checked,
// so updating an unknown id returns nil.
func (r *PGChannelRepository) Update(ctx context.Context, ch *model.NotificationChannel) error {
	const stmt = `
UPDATE ai_ops_channels
SET name = $2, channel_type = $3, config = $4, priority = $5, enabled = $6
WHERE id = $1
`
	if _, err := database.Pool.Exec(ctx, stmt, ch.ID, ch.Name, ch.ChannelType, ch.Config, ch.Priority, ch.Enabled); err != nil {
		return fmt.Errorf("update channel: %w", err)
	}
	return nil
}
// Delete removes the channel with the given id. The number of affected rows
// is not checked, so deleting an unknown id returns nil.
func (r *PGChannelRepository) Delete(ctx context.Context, id string) error {
	const stmt = `DELETE FROM ai_ops_channels WHERE id = $1`
	if _, err := database.Pool.Exec(ctx, stmt, id); err != nil {
		return fmt.Errorf("delete channel: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,38 @@
package repository
import (
"context"
"fmt"
"github.com/company/ai-ops/internal/database"
"github.com/company/ai-ops/internal/service"
)
// PGHealingRepository is the PostgreSQL implementation of the self-healing
// record store.
type PGHealingRepository struct{}

// NewPGHealingRepository returns a new, stateless PGHealingRepository.
func NewPGHealingRepository() *PGHealingRepository {
	return new(PGHealingRepository)
}
// CreateHealing inserts h as a new row in ai_ops_healings, storing all of its
// fields including the caller-provided StartedAt.
func (r *PGHealingRepository) CreateHealing(ctx context.Context, h *service.HealingLog) error {
	const stmt = `
INSERT INTO ai_ops_healings (id, alert_id, action_type, config, status, dry_run, result_detail, error_code, started_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
	if _, err := database.Pool.Exec(ctx, stmt,
		h.ID, h.AlertID, h.ActionType, h.Config, h.Status, h.DryRun, h.ResultDetail, h.ErrorCode, h.StartedAt,
	); err != nil {
		return fmt.Errorf("insert healing: %w", err)
	}
	return nil
}
// UpdateHealingStatus records the outcome of the healing identified by id:
// it stores status, result and errCode, and stamps completed_at with NOW().
func (r *PGHealingRepository) UpdateHealingStatus(ctx context.Context, id, status string, result map[string]any, errCode string) error {
	const stmt = `
UPDATE ai_ops_healings SET status = $2, result_detail = $3, error_code = $4, completed_at = NOW()
WHERE id = $1
`
	if _, err := database.Pool.Exec(ctx, stmt, id, status, result, errCode); err != nil {
		return fmt.Errorf("update healing: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,112 @@
package repository
import (
"context"
"fmt"
"strings"
"github.com/company/ai-ops/internal/database"
"github.com/company/ai-ops/internal/domain/model"
)
// PGLogRepository is the PostgreSQL-backed store for request logs.
type PGLogRepository struct{}

// NewPGLogRepository returns a new, stateless PGLogRepository.
func NewPGLogRepository() *PGLogRepository {
	return new(PGLogRepository)
}
// Query returns one page of request logs matching filter, newest first,
// together with the total number of matching rows.
//
// Each set filter field adds one AND-ed condition. Filter values are always
// bound as positional parameters; only fixed column expressions and
// placeholder indexes are formatted into the SQL text. A Page below 1 is
// treated as 1, and a PageSize outside 1..100 falls back to 20.
//
// The returned slice is never nil on success, so an empty page encodes as
// `[]` rather than `null`, matching the other list queries in this package.
func (r *PGLogRepository) Query(ctx context.Context, filter model.LogQueryFilter) ([]model.RequestLog, int, error) {
	var (
		conditions []string
		args       []any
	)
	// addCondition binds value as the next positional parameter and records
	// "<expr> $N" for it, keeping placeholder numbers and args in lockstep.
	addCondition := func(expr string, value any) {
		args = append(args, value)
		conditions = append(conditions, fmt.Sprintf("%s $%d", expr, len(args)))
	}
	if filter.StartTime != nil {
		addCondition("timestamp >=", *filter.StartTime)
	}
	if filter.EndTime != nil {
		addCondition("timestamp <=", *filter.EndTime)
	}
	if filter.Service != "" {
		addCondition("service =", filter.Service)
	}
	if filter.Path != "" {
		addCondition("path =", filter.Path)
	}
	if filter.StatusCode != nil {
		addCondition("status_code =", *filter.StatusCode)
	}
	if filter.UserID != "" {
		addCondition("user_id =", filter.UserID)
	}
	if filter.SupplierID != "" {
		addCondition("supplier_id =", filter.SupplierID)
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}

	// Total number of matching rows, independent of paging.
	var total int
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM ai_ops_request_logs %s", whereClause)
	if err := database.Pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("count logs: %w", err)
	}

	// Normalise paging input.
	page := filter.Page
	if page < 1 {
		page = 1
	}
	pageSize := filter.PageSize
	if pageSize < 1 || pageSize > 100 {
		pageSize = 20
	}
	offset := (page - 1) * pageSize

	// LIMIT and OFFSET take the two placeholders after the filter arguments.
	// The three-index slice forces append to copy instead of writing into
	// spare capacity of args.
	limitIdx := len(args) + 1
	queryArgs := append(args[:len(args):len(args)], pageSize, offset)
	dataQuery := fmt.Sprintf(`
SELECT id, timestamp, service, path, status_code, latency_ms, user_id, supplier_id, method, error_code
FROM ai_ops_request_logs
%s
ORDER BY timestamp DESC
LIMIT $%d OFFSET $%d
`, whereClause, limitIdx, limitIdx+1)
	rows, err := database.Pool.Query(ctx, dataQuery, queryArgs...)
	if err != nil {
		return nil, 0, fmt.Errorf("query logs: %w", err)
	}
	defer rows.Close()

	// pageSize is at most 100, so pre-sizing is cheap and bounded.
	logs := make([]model.RequestLog, 0, pageSize)
	for rows.Next() {
		var l model.RequestLog
		if err := rows.Scan(
			&l.ID, &l.Timestamp, &l.Service, &l.Path, &l.StatusCode,
			&l.LatencyMs, &l.UserID, &l.SupplierID, &l.Method, &l.ErrorCode,
		); err != nil {
			return nil, 0, fmt.Errorf("scan log: %w", err)
		}
		logs = append(logs, l)
	}
	return logs, total, rows.Err()
}

View File

@@ -0,0 +1,95 @@
package repository
import (
"context"
"fmt"
"github.com/company/ai-ops/internal/database"
"github.com/company/ai-ops/internal/domain/model"
"github.com/jackc/pgx/v5"
)
// PGMetricRepository is the PostgreSQL-backed store for metrics.
type PGMetricRepository struct{}

// NewPGMetricRepository returns a new, stateless PGMetricRepository.
func NewPGMetricRepository() *PGMetricRepository {
	return new(PGMetricRepository)
}
// GetRealtime reads the most recently recorded value of the qps, avg_latency,
// p99_latency and error_rate metrics from ai_ops_metrics. A metric with no
// rows is reported as 0; any other query error aborts the call.
func (r *PGMetricRepository) GetRealtime(ctx context.Context) (*model.RealtimeMetrics, error) {
	var qps, avgLatency, p99Latency, errorRate float64
	// Metric name -> destination for its latest value.
	targets := map[string]*float64{
		"qps":         &qps,
		"avg_latency": &avgLatency,
		"p99_latency": &p99Latency,
		"error_rate":  &errorRate,
	}
	for metric, dst := range targets {
		var latest float64
		err := database.Pool.QueryRow(ctx, `
SELECT value FROM ai_ops_metrics
WHERE metric_name = $1
ORDER BY recorded_at DESC
LIMIT 1
`, metric).Scan(&latest)
		// A missing metric is not an error; latest simply stays 0.
		if err != nil && err != pgx.ErrNoRows {
			return nil, fmt.Errorf("query %s: %w", metric, err)
		}
		*dst = latest
	}
	return &model.RealtimeMetrics{
		QPS:        qps,
		AvgLatency: avgLatency,
		P99Latency: p99Latency,
		ErrorRate:  errorRate,
	}, nil
}
// Query returns every stored point of metric req.Name recorded between
// req.StartTime and req.EndTime (both inclusive), newest first.
//
// req.Source is not used as a filter; it is only copied onto each returned
// point. The stored labels become the point's Tags.
//
// The returned slice is never nil on success, so an empty result encodes as
// `[]` rather than `null`, matching the other list queries in this package.
func (r *PGMetricRepository) Query(ctx context.Context, req model.MetricQueryRequest) ([]model.MetricPoint, error) {
	rows, err := database.Pool.Query(ctx, `
SELECT metric_name, labels, value, recorded_at
FROM ai_ops_metrics
WHERE metric_name = $1
AND recorded_at >= $2
AND recorded_at <= $3
ORDER BY recorded_at DESC
`, req.Name, req.StartTime, req.EndTime)
	if err != nil {
		return nil, fmt.Errorf("query metrics: %w", err)
	}
	defer rows.Close()

	points := make([]model.MetricPoint, 0)
	for rows.Next() {
		var (
			p      model.MetricPoint
			labels map[string]string
		)
		if err := rows.Scan(&p.Name, &labels, &p.Value, &p.Timestamp); err != nil {
			return nil, fmt.Errorf("scan metric: %w", err)
		}
		p.Source = req.Source
		p.Tags = labels
		points = append(points, p)
	}
	return points, rows.Err()
}
// GetLatest returns the most recently recorded point of the metric called
// name. source is not used as a filter; it is only copied onto the result.
// Any scan error, including "no rows", is returned wrapped.
func (r *PGMetricRepository) GetLatest(ctx context.Context, source, name string) (*model.MetricPoint, error) {
	var (
		point model.MetricPoint
		tags  map[string]string
	)
	row := database.Pool.QueryRow(ctx, `
SELECT metric_name, labels, value, recorded_at
FROM ai_ops_metrics
WHERE metric_name = $1
ORDER BY recorded_at DESC
LIMIT 1
`, name)
	if err := row.Scan(&point.Name, &tags, &point.Value, &point.Timestamp); err != nil {
		return nil, fmt.Errorf("query latest metric: %w", err)
	}
	point.Source = source
	point.Tags = tags
	return &point, nil
}

View File

@@ -0,0 +1,57 @@
package repository
import (
"context"
"fmt"
"github.com/company/ai-ops/internal/database"
"github.com/company/ai-ops/internal/domain/model"
)
// PGNotificationLogRepository is the PostgreSQL-backed store for notification
// delivery logs.
type PGNotificationLogRepository struct{}

// NewPGNotificationLogRepository returns a new, stateless
// PGNotificationLogRepository.
func NewPGNotificationLogRepository() *PGNotificationLogRepository {
	return new(PGNotificationLogRepository)
}
// CreateLog inserts log into ai_ops_notification_logs. An empty Status
// defaults to "pending" and an empty ID is replaced by a freshly generated
// UUID; both defaults are written back to log. created_at is set by the
// database (NOW()).
func (r *PGNotificationLogRepository) CreateLog(ctx context.Context, log *model.NotificationLog) error {
	if log.Status == "" {
		log.Status = "pending"
	}
	if log.ID == "" {
		log.ID = newUUID()
	}
	const stmt = `
INSERT INTO ai_ops_notification_logs (id, event_id, channel_id, channel_type, status, retry_count, error_message, sent_at, created_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,NOW())
`
	if _, err := database.Pool.Exec(ctx, stmt,
		log.ID, log.EventID, log.ChannelID, log.ChannelType, log.Status, log.RetryCount, log.ErrorMessage, log.SentAt,
	); err != nil {
		return fmt.Errorf("insert notification log: %w", err)
	}
	return nil
}
// MarkSent flags the notification log identified by id as sent: status
// becomes 'sent', sent_at is stamped with NOW() and error_message is cleared.
func (r *PGNotificationLogRepository) MarkSent(ctx context.Context, id string) error {
	const stmt = `
UPDATE ai_ops_notification_logs
SET status='sent', sent_at=NOW(), error_message=NULL
WHERE id=$1
`
	if _, err := database.Pool.Exec(ctx, stmt, id); err != nil {
		return fmt.Errorf("mark notification sent: %w", err)
	}
	return nil
}
// MarkFailed flags the notification log identified by id as failed and
// stores the given retry count and error message. sent_at is left untouched.
func (r *PGNotificationLogRepository) MarkFailed(ctx context.Context, id string, retryCount int, errMessage string) error {
	const stmt = `
UPDATE ai_ops_notification_logs
SET status='failed', retry_count=$2, error_message=$3
WHERE id=$1
`
	if _, err := database.Pool.Exec(ctx, stmt, id, retryCount, errMessage); err != nil {
		return fmt.Errorf("mark notification failed: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,269 @@
package repository
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"os"
"path/filepath"
"sort"
"sync"
"testing"
"time"
"github.com/company/ai-ops/internal/config"
"github.com/company/ai-ops/internal/database"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/service"
)
// pgMigrationOnce guards applyMigrations so it runs at most once per test
// binary; pgMigrationErr keeps its result for every later caller.
var pgMigrationOnce sync.Once
var pgMigrationErr error
// setupPGIntegration prepares the shared database pool for an integration
// test and returns the context to use for queries.
//
// When the pool is not initialised yet it tries localhost on ports 15432 and
// then 5432 with fixed test credentials, and skips the test if neither
// connects. It then applies the migrations once and fails the test if they
// could not be applied.
func setupPGIntegration(t *testing.T) context.Context {
t.Helper()
ctx := context.Background()
if database.Pool == nil {
ports := []int{15432, 5432}
var lastErr error
for _, port := range ports {
lastErr = database.Init(config.DatabaseConfig{Host: "localhost", Port: port, User: "aiops", Password: "aiops123", DBName: "ai_ops", SSLMode: "disable", PoolSize: 4})
if lastErr == nil {
break
}
// Reset the package-level pool before trying the next port.
database.Close()
database.Pool = nil
}
if lastErr != nil {
t.Skipf("PostgreSQL integration database not available: %v", lastErr)
}
}
pgMigrationOnce.Do(func() {
pgMigrationErr = applyMigrations(ctx)
})
// The stored error is re-checked on every call, so every test that follows a
// failed migration fails as well.
if pgMigrationErr != nil {
t.Fatalf("apply migrations: %v", pgMigrationErr)
}
return ctx
}
// applyMigrations executes every *.up.sql file found under
// ../../../tech/migrations, in lexical filename order, against the shared pool.
//
// Concurrent runs are serialised with PostgreSQL advisory lock 424242001.
// The lock is taken with pg_advisory_xact_lock inside a dedicated
// transaction, which pins it to a single connection and guarantees release
// when that transaction ends. A session-level pg_advisory_lock issued through
// the pool could be "unlocked" on a different connection and so stay held.
//
// NOTE(review): the lock transaction occupies one pooled connection while the
// migrations run on another, so the pool needs at least two connections.
func applyMigrations(ctx context.Context) error {
	lockTx, err := database.Pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("begin migration lock: %w", err)
	}
	// Nothing is written in this transaction; rolling it back only ends it,
	// which releases the advisory lock on every return path.
	defer func() { _ = lockTx.Rollback(ctx) }()
	if _, err := lockTx.Exec(ctx, `SELECT pg_advisory_xact_lock(424242001)`); err != nil {
		return err
	}

	files, err := filepath.Glob(filepath.Join("..", "..", "..", "tech", "migrations", "*.up.sql"))
	if err != nil {
		return err
	}
	sort.Strings(files)
	for _, f := range files {
		b, err := os.ReadFile(f)
		if err != nil {
			return err
		}
		if _, err := database.Pool.Exec(ctx, string(b)); err != nil {
			return fmt.Errorf("%s: %w", f, err)
		}
	}
	return nil
}
// testUUID returns a random version-4 UUID in lowercase 8-4-4-4-12 hex form
// and fails the test immediately if random bytes cannot be read.
func testUUID(t *testing.T) string {
	t.Helper()
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		t.Fatal(err)
	}
	raw[6] = (raw[6] & 0x0f) | 0x40 // high nibble 4: version field
	raw[8] = (raw[8] & 0x3f) | 0x80 // top bits 10: variant field
	digits := hex.EncodeToString(raw[:])
	return digits[0:8] + "-" + digits[8:12] + "-" + digits[12:16] + "-" + digits[16:20] + "-" + digits[20:32]
}
// cleanupIDs deletes every row that references any of ids from the tables the
// integration tests touch. Errors are deliberately ignored: this is a
// best-effort teardown and the rows may already be gone.
func cleanupIDs(t *testing.T, ctx context.Context, ids ...string) {
	t.Helper()
	// Executed in this order for each id.
	statements := [...]string{
		`DELETE FROM ai_ops_notification_logs WHERE id=$1 OR event_id=$1 OR channel_id=$1`,
		`DELETE FROM ai_ops_healings WHERE id=$1 OR alert_id=$1`,
		`DELETE FROM ai_ops_alerts WHERE id=$1 OR rule_id=$1 OR parent_alert_id=$1`,
		`DELETE FROM ai_ops_rules WHERE id=$1`,
		`DELETE FROM ai_ops_channels WHERE id=$1`,
		`DELETE FROM ai_ops_request_logs WHERE id=$1`,
	}
	for _, id := range ids {
		for _, stmt := range statements {
			_, _ = database.Pool.Exec(ctx, stmt, id)
		}
	}
}
// TestPGChannelRepositoryCRUD exercises the full channel lifecycle against a
// real database: create, get by id, list, update, delete, and a final lookup
// that must fail.
func TestPGChannelRepositoryCRUD(t *testing.T) {
ctx := setupPGIntegration(t)
repo := NewPGChannelRepository()
id := testUUID(t)
// Remove the row even when the test fails before reaching Delete.
defer cleanupIDs(t, ctx, id)
ch := &model.NotificationChannel{ID: id, Name: "test-channel", ChannelType: "webhook", Config: map[string]any{"webhook_url": "http://example.invalid"}, Priority: 7, Enabled: true}
if err := repo.Create(ctx, ch); err != nil {
t.Fatal(err)
}
got, err := repo.GetByID(ctx, id)
if err != nil || got.ID != id || got.Name != ch.Name {
t.Fatalf("get channel = %+v %v", got, err)
}
// The channel was created enabled, so List must include it.
list, err := repo.List(ctx)
if err != nil {
t.Fatal(err)
}
found := false
for _, item := range list {
if item.ID == id {
found = true
}
}
if !found {
t.Fatalf("created channel not found in list: %+v", list)
}
ch.Name = "updated-channel"
ch.Priority = 8
if err := repo.Update(ctx, ch); err != nil {
t.Fatal(err)
}
updated, err := repo.GetByID(ctx, id)
if err != nil || updated.Name != "updated-channel" || updated.Priority != 8 {
t.Fatalf("updated channel = %+v %v", updated, err)
}
if err := repo.Delete(ctx, id); err != nil {
t.Fatal(err)
}
// After deletion the lookup has to report an error.
if _, err := repo.GetByID(ctx, id); err == nil {
t.Fatal("expected not found after delete")
}
}
// TestPGAlertRepositoryRulesEventsAndAggregation covers rule CRUD, plain
// event creation, status update, escalation, aggregation into a parent alert,
// event listing and the open-alert counter against a real database.
func TestPGAlertRepositoryRulesEventsAndAggregation(t *testing.T) {
ctx := setupPGIntegration(t)
repo := NewPGAlertRepository()
ruleID, eventID, childID := testUUID(t), testUUID(t), testUUID(t)
defer cleanupIDs(t, ctx, ruleID, eventID, childID)
rule := &model.AlertRule{ID: ruleID, Name: "rule-" + ruleID, MetricSource: "prom", MetricName: "p99", ThresholdType: ">", ThresholdValue: "100", DurationMin: 1, Level: "P1", ChannelIDs: []string{}, IsSandboxed: true, Enabled: true, Version: 1, CreatedBy: "test"}
if err := repo.CreateRule(ctx, rule); err != nil {
t.Fatal(err)
}
if got, err := repo.GetRuleByID(ctx, ruleID); err != nil || got.ID != ruleID || got.Name != rule.Name {
t.Fatalf("get rule = %+v %v", got, err)
}
rules, err := repo.ListRules(ctx)
if err != nil || len(rules) == 0 {
t.Fatalf("list rules = %d %v", len(rules), err)
}
rule.Name = "rule-updated-" + ruleID
rule.Version = 2
if err := repo.UpdateRule(ctx, rule); err != nil {
t.Fatal(err)
}
now := time.Now().UTC()
event := &model.AlertEvent{ID: eventID, RuleID: ruleID, Level: "P1", ResourceType: "svc", ResourceID: "res-" + ruleID, CurrentValue: "120", ThresholdValue: "100", Status: "triggered", StartedAt: now}
// With a threshold of 10 a single event for this resource must not
// aggregate, so the event itself is expected back.
created, err := repo.CreateEventWithAggregation(ctx, event, time.Minute, 10)
if err != nil || created.ID != eventID {
t.Fatalf("create event = %+v %v", created, err)
}
// directID uses a different resource id, so it does not take part in the
// aggregation below.
directID := testUUID(t)
defer cleanupIDs(t, ctx, directID)
if err := repo.CreateEvent(ctx, &model.AlertEvent{ID: directID, RuleID: ruleID, Level: "P2", ResourceType: "svc", ResourceID: "direct-" + ruleID, CurrentValue: "101", ThresholdValue: "100", Status: "triggered", StartedAt: now.Add(2 * time.Second)}); err != nil {
t.Fatalf("create direct event: %v", err)
}
if err := repo.UpdateEventStatus(ctx, eventID, "resolved"); err != nil {
t.Fatal(err)
}
if err := repo.EscalateEvent(ctx, eventID, "P0"); err != nil {
t.Fatal(err)
}
// A second event on the same resource with a threshold of 1 brings the
// candidate count to at least 2, so an aggregated parent is expected.
agg, err := repo.CreateEventWithAggregation(ctx, &model.AlertEvent{ID: childID, RuleID: ruleID, Level: "P1", ResourceType: "svc", ResourceID: "res-" + ruleID, CurrentValue: "130", ThresholdValue: "100", Status: "triggered", StartedAt: now.Add(time.Second)}, time.Minute, 1)
if err != nil || !agg.IsAggregated || agg.AggregatedCount < 2 {
t.Fatalf("aggregation = %+v %v", agg, err)
}
// The parent's id is generated by the repository, so it is only known here.
defer cleanupIDs(t, ctx, agg.ID)
events, total, err := repo.ListEvents(ctx, "triggered", 1, 20)
if err != nil || total < 1 || len(events) < 1 {
t.Fatalf("list events = total=%d len=%d err=%v", total, len(events), err)
}
count, err := repo.GetOpenCount(ctx)
if err != nil || count.Open < 1 {
t.Fatalf("open count = %+v %v", count, err)
}
if err := repo.DeleteRule(ctx, ruleID); err != nil {
t.Fatal(err)
}
}
// TestPGMetricAndLogRepositories inserts one metric point and one request log
// directly with SQL and checks that the metric and log repositories read
// them back.
func TestPGMetricAndLogRepositories(t *testing.T) {
ctx := setupPGIntegration(t)
metricRepo := NewPGMetricRepository()
logRepo := NewPGLogRepository()
logID := testUUID(t)
// Reuse the random id to get a metric name that is unique to this run.
metricName := "test_metric_" + logID
defer cleanupIDs(t, ctx, logID)
defer database.Pool.Exec(ctx, `DELETE FROM ai_ops_metrics WHERE metric_name=$1`, metricName)
now := time.Now().UTC()
if _, err := database.Pool.Exec(ctx, `INSERT INTO ai_ops_metrics(metric_name, labels, value, recorded_at) VALUES ($1, $2, $3, $4)`, metricName, map[string]string{"source": "test"}, 42.5, now); err != nil {
t.Fatal(err)
}
// GetLatest must echo the source argument ("unit") on the returned point.
latest, err := metricRepo.GetLatest(ctx, "unit", metricName)
if err != nil || latest.Name != metricName || latest.Source != "unit" || latest.Value != 42.5 {
t.Fatalf("latest metric = %+v %v", latest, err)
}
points, err := metricRepo.Query(ctx, model.MetricQueryRequest{Source: "unit", Name: metricName, StartTime: now.Add(-time.Minute), EndTime: now.Add(time.Minute)})
if err != nil || len(points) != 1 {
t.Fatalf("query metric = %d %v", len(points), err)
}
// Only checks that GetRealtime succeeds; the values depend on existing data.
if realtime, err := metricRepo.GetRealtime(ctx); err != nil || realtime == nil {
t.Fatalf("realtime metric = %+v %v", realtime, err)
}
if _, err := database.Pool.Exec(ctx, `INSERT INTO ai_ops_request_logs(id, timestamp, service, path, method, status_code, latency_ms, user_id, supplier_id, error_code) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)`, logID, now, "svc-test", "/unit", "GET", 200, 11.2, "u1", "s1", ""); err != nil {
t.Fatal(err)
}
status := 200
// Exactly one matching row is expected: the one inserted above.
logs, total, err := logRepo.Query(ctx, model.LogQueryFilter{Service: "svc-test", Path: "/unit", StatusCode: &status, UserID: "u1", SupplierID: "s1", Page: 1, PageSize: 10})
if err != nil || total != 1 || len(logs) != 1 || logs[0].ID != logID {
t.Fatalf("query logs = total=%d logs=%+v err=%v", total, logs, err)
}
}
// TestPGHealingAndNotificationRepositories creates a rule, an event and a
// channel, then writes and updates a healing record and a notification log
// that refer to them. Only the absence of errors is asserted.
func TestPGHealingAndNotificationRepositories(t *testing.T) {
ctx := setupPGIntegration(t)
ruleID, eventID, channelID, healingID, notificationID := testUUID(t), testUUID(t), testUUID(t), testUUID(t), testUUID(t)
defer cleanupIDs(t, ctx, ruleID, eventID, channelID, healingID, notificationID)
alertRepo := NewPGAlertRepository()
channelRepo := NewPGChannelRepository()
healingRepo := NewPGHealingRepository()
notificationRepo := NewPGNotificationLogRepository()
if err := alertRepo.CreateRule(ctx, &model.AlertRule{ID: ruleID, Name: "notify-rule-" + ruleID, MetricSource: "prom", MetricName: "qps", ThresholdType: ">", ThresholdValue: "1", DurationMin: 1, Level: "P2", ChannelIDs: []string{}, IsSandboxed: true, Enabled: true, Version: 1, CreatedBy: "test"}); err != nil {
t.Fatal(err)
}
// Window and threshold of 0 disable aggregation for this event.
if _, err := alertRepo.CreateEventWithAggregation(ctx, &model.AlertEvent{ID: eventID, RuleID: ruleID, Level: "P2", ResourceType: "svc", ResourceID: "res", CurrentValue: "2", ThresholdValue: "1", Status: "triggered", StartedAt: time.Now().UTC()}, 0, 0); err != nil {
t.Fatal(err)
}
if err := channelRepo.Create(ctx, &model.NotificationChannel{ID: channelID, Name: "notify-channel", ChannelType: "webhook", Config: map[string]any{"webhook_url": "http://example.invalid"}, Priority: 1, Enabled: true}); err != nil {
t.Fatal(err)
}
if err := healingRepo.CreateHealing(ctx, &service.HealingLog{ID: healingID, AlertID: eventID, ActionType: "throttle", Config: map[string]any{"endpoint": "http://example.invalid"}, Status: "pending", DryRun: true, StartedAt: time.Now().UTC()}); err != nil {
t.Fatal(err)
}
if err := healingRepo.UpdateHealingStatus(ctx, healingID, "succeeded", map[string]any{"ok": true}, ""); err != nil {
t.Fatal(err)
}
if err := notificationRepo.CreateLog(ctx, &model.NotificationLog{ID: notificationID, EventID: eventID, ChannelID: channelID, ChannelType: "webhook", Status: "pending"}); err != nil {
t.Fatal(err)
}
if err := notificationRepo.MarkSent(ctx, notificationID); err != nil {
t.Fatal(err)
}
if err := notificationRepo.MarkFailed(ctx, notificationID, 1, "retry failed"); err != nil {
t.Fatal(err)
}
}

112
internal/middleware/auth.go Normal file
View File

@@ -0,0 +1,112 @@
package middleware
import (
	"context"
	"crypto/subtle"
	"net/http"
	"strings"

	"github.com/company/ai-ops/internal/config"
	"github.com/company/ai-ops/pkg/errors"
	"github.com/company/ai-ops/pkg/response"
	"github.com/golang-jwt/jwt/v5"
)
// Auth returns middleware that authenticates every request whose path is not
// on the public whitelist (see isPublicPath).
//
// Requests whose path starts with "/metrics" may authenticate with an API
// key, taken from the X-API-Key header or, if that is empty, from the api_key
// query parameter. The key is accepted only when cfg.MetricsAuth is
// non-empty and matches; with no key configured the API-key route is closed,
// so an empty request key can never match an empty configured key. A
// /metrics request without a valid key falls through to the JWT check.
//
// All other requests need an Authorization header carrying an HS256 JWT
// signed with cfg.JWTSecret (an optional "Bearer " prefix is stripped). On
// success the string claims "user_id" and "role" are stored in the request
// context under those same string keys. Failures produce ErrUnauthorized.
func Auth(cfg config.ServerConfig) func(http.Handler) http.Handler {
	metricsKey := []byte(cfg.MetricsAuth)
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Whitelisted paths skip authentication entirely.
			if isPublicPath(r.URL.Path) {
				next.ServeHTTP(w, r)
				return
			}
			// API key check for machine-to-machine endpoints such as /metrics.
			if strings.HasPrefix(r.URL.Path, "/metrics") {
				apiKey := r.Header.Get("X-API-Key")
				if apiKey == "" {
					apiKey = r.URL.Query().Get("api_key")
				}
				// Fail closed when no key is configured, and compare in
				// constant time so the key cannot be guessed via timing.
				if len(metricsKey) > 0 && subtle.ConstantTimeCompare([]byte(apiKey), metricsKey) == 1 {
					next.ServeHTTP(w, r)
					return
				}
			}
			// JWT check.
			tokenStr := r.Header.Get("Authorization")
			if tokenStr == "" {
				response.Error(w, errors.ErrUnauthorized)
				return
			}
			tokenStr = strings.TrimPrefix(tokenStr, "Bearer ")
			token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
				return []byte(cfg.JWTSecret), nil
			}, jwt.WithValidMethods([]string{"HS256"}))
			if err != nil || !token.Valid {
				response.Error(w, errors.ErrUnauthorized)
				return
			}
			// Expose user id and role to downstream handlers. The plain string
			// keys are what RequireRole and RequireWrite read, so they must
			// stay as they are.
			if claims, ok := token.Claims.(jwt.MapClaims); ok {
				if userID, ok := claims["user_id"].(string); ok {
					r = r.WithContext(context.WithValue(r.Context(), "user_id", userID))
				}
				if role, ok := claims["role"].(string); ok {
					r = r.WithContext(context.WithValue(r.Context(), "role", role))
				}
			}
			next.ServeHTTP(w, r)
		})
	}
}
// RequireRole returns middleware that only lets a request through when the
// "role" value stored in its context is one of roles. Any other request,
// including one with no role at all, is rejected with ErrForbidden; the error
// detail lists the required roles and the current one.
func RequireRole(roles ...string) func(http.Handler) http.Handler {
	allowed := make(map[string]bool)
	for i := range roles {
		allowed[roles[i]] = true
	}
	return func(next http.Handler) http.Handler {
		guard := func(w http.ResponseWriter, r *http.Request) {
			// A missing or non-string role yields "".
			current, _ := r.Context().Value("role").(string)
			if allowed[current] {
				next.ServeHTTP(w, r)
				return
			}
			response.Error(w, errors.ErrForbidden.WithDetail(map[string]any{
				"error":    "insufficient permissions",
				"code":     "OPS_AUTH_1001",
				"required": roles,
				"current":  current,
			}))
		}
		return http.HandlerFunc(guard)
	}
}
// RequireWrite lets GET and HEAD requests through unconditionally and
// requires the context "role" to be "operator" or "admin" for every other
// method, answering with ErrForbidden otherwise.
func RequireWrite(next http.Handler) http.Handler {
	guard := func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet, http.MethodHead:
			next.ServeHTTP(w, r)
			return
		}
		// A missing or non-string role yields "" and is rejected below.
		current, _ := r.Context().Value("role").(string)
		switch current {
		case "operator", "admin":
			next.ServeHTTP(w, r)
		default:
			response.Error(w, errors.ErrForbidden.WithDetail(map[string]any{
				"error":   "write permission required",
				"code":    "OPS_AUTH_1001",
				"current": current,
			}))
		}
	}
	return http.HandlerFunc(guard)
}
// isPublicPath reports whether path may be served without authentication:
// a fixed set of exact paths plus everything under the health-actuator and
// dashboard prefixes.
func isPublicPath(path string) bool {
	switch path {
	case "/health", "/api/v1/ai-ops/login", "/openapi.json":
		return true
	}
	for _, prefix := range [...]string{"/actuator/health", "/ops/dashboard"} {
		if strings.HasPrefix(path, prefix) {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,100 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/company/ai-ops/internal/config"
"github.com/company/ai-ops/internal/service"
)
// TestAuthAllowsPublicPaths checks that Auth forwards requests for each
// whitelisted path to the next handler without any credentials.
func TestAuthAllowsPublicPaths(t *testing.T) {
cfg := config.ServerConfig{JWTSecret: "secret", MetricsAuth: "metrics-key"}
for _, path := range []string{"/health", "/actuator/health/ready", "/api/v1/ai-ops/login", "/openapi.json", "/ops/dashboard"} {
t.Run(path, func(t *testing.T) {
// called flips to true only if Auth invokes the wrapped handler.
called := false
h := Auth(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true }))
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, path, nil))
if !called {
t.Fatalf("public path %s was blocked", path)
}
})
}
}
// TestAuthMetricsAPIKeyAndJWT covers the three authenticated paths through
// Auth: a valid metrics API key, a request without any token, and a valid JWT
// whose claims must end up in the request context.
func TestAuthMetricsAPIKeyAndJWT(t *testing.T) {
cfg := config.ServerConfig{JWTSecret: "secret", MetricsAuth: "metrics-key"}
t.Run("metrics api key", func(t *testing.T) {
called := false
h := Auth(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true }))
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
req.Header.Set("X-API-Key", "metrics-key")
h.ServeHTTP(httptest.NewRecorder(), req)
if !called {
t.Fatal("metrics api key did not pass")
}
})
t.Run("missing token rejected", func(t *testing.T) {
h := Auth(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("should not call next") }))
w := httptest.NewRecorder()
h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/v1/ai-ops/rules", nil))
if w.Code != http.StatusUnauthorized {
t.Fatalf("status = %d", w.Code)
}
})
t.Run("valid jwt sets context", func(t *testing.T) {
// The token is issued with the same secret that Auth is configured with.
token, err := service.NewAuthService("secret").IssueToken("u1", "operator")
if err != nil {
t.Fatal(err)
}
h := Auth(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Context().Value("user_id") != "u1" || r.Context().Value("role") != "operator" {
t.Fatalf("context not populated: user=%v role=%v", r.Context().Value("user_id"), r.Context().Value("role"))
}
}))
req := httptest.NewRequest(http.MethodGet, "/api/v1/ai-ops/rules", nil)
req.Header.Set("Authorization", "Bearer "+token)
h.ServeHTTP(httptest.NewRecorder(), req)
})
}
// TestRequireRoleAndRequireWrite checks the allow and deny outcomes of both
// authorization middlewares by injecting the role directly into the request
// context. The wrapped handler answers 202 so a pass-through is
// distinguishable from a rejection.
func TestRequireRoleAndRequireWrite(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusAccepted) })
// RequireRole: matching role passes.
allowed := RequireRole("admin")(next)
req := httptest.NewRequest(http.MethodGet, "/x", nil).WithContext(context.WithValue(context.Background(), "role", "admin"))
w := httptest.NewRecorder()
allowed.ServeHTTP(w, req)
if w.Code != http.StatusAccepted {
t.Fatalf("role allowed status = %d", w.Code)
}
// RequireRole: a request without a role is forbidden.
denied := httptest.NewRecorder()
RequireRole("admin")(next).ServeHTTP(denied, httptest.NewRequest(http.MethodGet, "/x", nil))
if denied.Code != http.StatusForbidden {
t.Fatalf("role denied status = %d", denied.Code)
}
// RequireWrite: GET passes without any role.
read := httptest.NewRecorder()
RequireWrite(next).ServeHTTP(read, httptest.NewRequest(http.MethodGet, "/x", nil))
if read.Code != http.StatusAccepted {
t.Fatalf("read status = %d", read.Code)
}
// RequireWrite: POST without a role is forbidden.
writeDenied := httptest.NewRecorder()
RequireWrite(next).ServeHTTP(writeDenied, httptest.NewRequest(http.MethodPost, "/x", nil))
if writeDenied.Code != http.StatusForbidden {
t.Fatalf("write denied status = %d", writeDenied.Code)
}
// RequireWrite: POST with the operator role passes.
writeAllowed := httptest.NewRecorder()
writeReq := httptest.NewRequest(http.MethodPost, "/x", nil).WithContext(context.WithValue(context.Background(), "role", "operator"))
RequireWrite(next).ServeHTTP(writeAllowed, writeReq)
if writeAllowed.Code != http.StatusAccepted {
t.Fatalf("write allowed status = %d", writeAllowed.Code)
}
}

View File

@@ -0,0 +1,37 @@
package middleware
import (
"log/slog"
"net/http"
"time"
)
// responseWriter wraps an http.ResponseWriter and remembers the status code
// of the response so it can be logged after the handler returns.
//
// Only the first final status is recorded: net/http ignores later WriteHeader
// calls, and a Write without a prior WriteHeader commits the response, so
// anything that arrives afterwards must not change the recorded code.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool
}

// WriteHeader records code unless a final status was already committed, then
// forwards the call unchanged to the wrapped writer. Informational 1xx codes
// other than 101 are forwarded without being recorded, because a final
// status still follows them.
func (rw *responseWriter) WriteHeader(code int) {
	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write marks the response as committed, so a later WriteHeader can no
// longer change the recorded status, and delegates to the wrapped writer.
func (rw *responseWriter) Write(p []byte) (int, error) {
	rw.wroteHeader = true
	return rw.ResponseWriter.Write(p)
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional features (flushing, deadlines, hijacking) of the original writer.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// Logging returns middleware that writes one structured "http_request" log
// entry per request once the wrapped handler has returned. The entry holds
// method, path, response status, duration in milliseconds and remote address.
func Logging(next http.Handler) http.Handler {
	logged := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		// Start from 200 so a handler that never calls WriteHeader is logged
		// with that status.
		recorder := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(recorder, r)
		elapsed := time.Since(began)
		slog.Info("http_request",
			"method", r.Method,
			"path", r.URL.Path,
			"status", recorder.statusCode,
			"duration_ms", float64(elapsed.Microseconds())/1000,
			"remote_addr", r.RemoteAddr,
		)
	}
	return http.HandlerFunc(logged)
}

View File

@@ -0,0 +1,41 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
// TestLoggingCapturesStatusCode checks that the Logging middleware passes the
// handler's status code and body through to the client unchanged.
func TestLoggingCapturesStatusCode(t *testing.T) {
h := Logging(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte("created"))
}))
w := httptest.NewRecorder()
h.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/x", nil))
if w.Code != http.StatusCreated || w.Body.String() != "created" {
t.Fatalf("response = %d %q", w.Code, w.Body.String())
}
}
// TestRecoveryConvertsPanicToInternalError checks that a panicking handler
// wrapped in Recovery yields a 500 response instead of crashing the test.
func TestRecoveryConvertsPanicToInternalError(t *testing.T) {
h := Recovery(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("boom")
}))
w := httptest.NewRecorder()
h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/panic", nil))
if w.Code != http.StatusInternalServerError {
t.Fatalf("status = %d body=%s", w.Code, w.Body.String())
}
}
// TestRecoveryPassesThroughNormalRequest verifies that Recovery does not alter
// the response of a handler that completes without panicking.
func TestRecoveryPassesThroughNormalRequest(t *testing.T) {
	accepted := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusAccepted)
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/ok", nil)
	Recovery(accepted).ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusAccepted {
		t.Fatalf("status = %d", got)
	}
}

View File

@@ -0,0 +1,27 @@
package middleware
import (
"log/slog"
"net/http"
"runtime/debug"
"github.com/company/ai-ops/pkg/errors"
"github.com/company/ai-ops/pkg/response"
)
// Recovery is middleware that recovers from panics raised by the wrapped
// handler, logs the panic value, stack trace and request path, and replies
// with the standard internal-error response.
//
// http.ErrAbortHandler is re-panicked rather than handled: it is the sentinel
// handlers use to abort a response on purpose, and the net/http server needs
// to see it to tear the connection down without logging a stack trace.
func Recovery(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			if rec == http.ErrAbortHandler {
				// Deliberate abort: let the server handle it.
				panic(rec)
			}
			slog.Error("panic_recovered",
				"error", rec,
				"stack", string(debug.Stack()),
				"path", r.URL.Path,
			)
			response.Error(w, errors.ErrInternal)
		}()
		next.ServeHTTP(w, r)
	})
}

40
internal/redis/client.go Normal file
View File

@@ -0,0 +1,40 @@
package redis
import (
"context"
"fmt"
"time"
"github.com/company/ai-ops/internal/config"
"github.com/redis/go-redis/v9"
)
// Client is the package-level Redis client. Init assigns it and Close closes
// it; it is nil until Init has been called.
var Client *redis.Client
// Init creates a Redis client from cfg, verifies connectivity with a PING
// (5 second timeout) and, on success, publishes the client as Client.
//
// On failure the freshly created client is closed and Client is left
// unchanged, so callers never observe a client that failed its health check
// and no connection pool is leaked. The returned error wraps the ping error.
func Init(cfg config.RedisConfig) error {
	client := redis.NewClient(&redis.Options{
		Addr:     fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
		Password: cfg.Password,
		DB:       cfg.DB,
		PoolSize: 10,
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := client.Ping(ctx).Err(); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = client.Close()
		return fmt.Errorf("ping redis: %w", err)
	}

	Client = client
	return nil
}
// Close closes the package-level Redis client. It is a no-op returning nil
// when Client has not been set.
func Close() error {
	if Client == nil {
		return nil
	}
	return Client.Close()
}

View File

@@ -0,0 +1,33 @@
package redis
import (
"testing"
"github.com/company/ai-ops/internal/config"
)
// TestInitAndCloseWithLocalRedis is an integration test: it tries a local
// Redis on ports 16379 and 6379, skips when neither answers, and otherwise
// checks that Close works both with an initialised and with a nil Client.
func TestInitAndCloseWithLocalRedis(t *testing.T) {
	var initErr error
	for _, candidate := range []int{16379, 6379} {
		initErr = Init(config.RedisConfig{Host: "localhost", Port: candidate, DB: 0})
		if initErr == nil {
			break
		}
		// Reset the global before trying the next port.
		_ = Close()
		Client = nil
	}
	if initErr != nil {
		t.Skipf("Redis integration server not available: %v", initErr)
	}

	if Client == nil {
		t.Fatal("client not initialized")
	}
	if closeErr := Close(); closeErr != nil {
		t.Fatal(closeErr)
	}

	// Close must also tolerate a nil client.
	Client = nil
	if closeErr := Close(); closeErr != nil {
		t.Fatal(closeErr)
	}
}

View File

@@ -0,0 +1,277 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"log/slog"
"strconv"
"sync"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/domain/repository"
)
// TriggerState records the trigger state of a single rule while its metric is
// breaching the threshold.
type TriggerState struct {
	FirstTriggeredAt time.Time // time the threshold was first exceeded
	LastTriggeredAt  time.Time // time an alert was most recently fired; zero if none yet
}
// AlertEngine is the alert-rule evaluation engine. It periodically evaluates
// rules against the latest metric values, fires alert events and escalates
// long-running P2 alerts.
type AlertEngine struct {
	alertRepo  repository.AlertRepository
	metricRepo repository.MetricRepository
	notifySvc  *NotificationService
	interval   time.Duration // period between rule evaluations
	stopCh     chan struct{} // closed by Stop to end the background loop
	// Per-rule trigger state, keyed by rule ID (used for the duration check).
	triggerStates map[string]*TriggerState
	statesMu      sync.RWMutex
	// Suppression window: the same rule does not fire again within this
	// period (NewAlertEngine sets 5 minutes).
	suppressWindow time.Duration
	// Alert escalation: a P2 alert still in "triggered" status after this
	// long is escalated to P1 (NewAlertEngine sets 2 hours).
	escalationInterval time.Duration
	// Alert aggregation: window and threshold passed to
	// CreateEventWithAggregation (NewAlertEngine sets 1 minute / 20 events).
	aggregationWindow    time.Duration
	aggregationThreshold int
}
// NewAlertEngine builds a rule evaluation engine around the given
// repositories and notification service, using the built-in defaults:
// 30s evaluation interval, 5m suppression window, 2h escalation interval and
// aggregation of more than 20 events per minute.
func NewAlertEngine(ar repository.AlertRepository, mr repository.MetricRepository, ns *NotificationService) *AlertEngine {
	engine := new(AlertEngine)

	// Collaborators.
	engine.alertRepo = ar
	engine.metricRepo = mr
	engine.notifySvc = ns

	// Loop control and per-rule state.
	engine.interval = 30 * time.Second
	engine.stopCh = make(chan struct{})
	engine.triggerStates = make(map[string]*TriggerState)

	// Tunables.
	engine.suppressWindow = 5 * time.Minute
	engine.escalationInterval = 2 * time.Hour
	engine.aggregationWindow = 1 * time.Minute
	engine.aggregationThreshold = 20

	return engine
}
// Start logs the engine configuration and launches the periodic evaluation
// loop in a background goroutine. It returns immediately.
func (e *AlertEngine) Start() {
	slog.Info(
		"alert_engine_started",
		"interval", e.interval,
		"suppress_window", e.suppressWindow,
	)

	go e.loop()
}
// Stop signals the background loop to exit by closing stopCh.
//
// Calling Stop again after the engine has been stopped is a no-op instead of
// a "close of closed channel" panic. The check is not synchronised, so Stop
// must still not be called concurrently from several goroutines.
func (e *AlertEngine) Stop() {
	select {
	case <-e.stopCh:
		// Already stopped.
		return
	default:
	}
	close(e.stopCh)
	slog.Info("alert_engine_stopped")
}
// loop is the engine's background goroutine. It evaluates all rules once
// immediately, then again on every interval tick, runs escalation every five
// minutes, and returns when stopCh is closed.
func (e *AlertEngine) loop() {
	evalTick := time.NewTicker(e.interval)
	defer evalTick.Stop()

	escalateTick := time.NewTicker(5 * time.Minute)
	defer escalateTick.Stop()

	// First evaluation happens right away rather than after one interval.
	e.evaluate(context.Background())

	for {
		select {
		case <-e.stopCh:
			return
		case <-evalTick.C:
			e.evaluate(context.Background())
		case <-escalateTick.C:
			e.escalate(context.Background())
		}
	}
}
// evaluate loads every rule from the alert repository and evaluates each one
// in turn. Failures are logged and do not stop the remaining rules from being
// evaluated.
func (e *AlertEngine) evaluate(ctx context.Context) {
	rules, listErr := e.alertRepo.ListRules(ctx)
	if listErr != nil {
		slog.Error("list_rules_failed", "error", listErr)
		return
	}

	for i := range rules {
		current := &rules[i]
		evalErr := e.evaluateRule(ctx, current)
		if evalErr == nil {
			continue
		}
		slog.Error("evaluate_rule_failed", "rule_id", current.ID, "error", evalErr)
	}
}
// evaluateRule checks one rule against the latest metric sample and fires an
// alert event when the breach has lasted long enough and is not suppressed.
//
// Steps:
//  1. Fetch the latest sample and parse the rule threshold.
//  2. If the metric is back within bounds, drop the rule's trigger state.
//  3. Otherwise record the first-breach time, then require the breach to have
//     lasted rule.DurationMin minutes and the suppression window to have
//     elapsed since the previous alert.
//  4. Create the event (possibly aggregated by the repository) and enqueue a
//     notification when the rule has channels configured.
//
// It returns an error when the metric cannot be read (including a nil sample),
// the threshold is not a number, or the event cannot be stored. If storing
// fails, the suppression timestamp is restored so the next evaluation retries
// instead of staying silent for a whole suppression window.
func (e *AlertEngine) evaluateRule(ctx context.Context, rule *model.AlertRule) error {
	point, err := e.metricRepo.GetLatest(ctx, rule.MetricSource, rule.MetricName)
	if err != nil {
		return fmt.Errorf("get metric: %w", err)
	}
	if point == nil {
		// A repository may report "no sample" as (nil, nil); fail cleanly
		// rather than dereferencing a nil pointer below.
		return fmt.Errorf("get metric: no data point for %s/%s", rule.MetricSource, rule.MetricName)
	}

	threshold, err := strconv.ParseFloat(rule.ThresholdValue, 64)
	if err != nil {
		return fmt.Errorf("parse threshold: %w", err)
	}

	triggered := e.compare(point.Value, threshold, rule.ThresholdType)
	now := time.Now()
	requiredDuration := time.Duration(rule.DurationMin) * time.Minute

	// All reads and writes of the trigger state happen under the lock.
	e.statesMu.Lock()
	state, exists := e.triggerStates[rule.ID]
	if !triggered {
		// Metric is back to normal: clear the trigger state.
		if exists {
			delete(e.triggerStates, rule.ID)
		}
		e.statesMu.Unlock()
		if exists {
			slog.Info("alert_cleared", "rule_id", rule.ID, "metric", rule.MetricName)
		}
		return nil
	}

	// Metric is breaching the threshold.
	if !exists {
		state = &TriggerState{FirstTriggeredAt: now}
		e.triggerStates[rule.ID] = state
	}
	breachDuration := now.Sub(state.FirstTriggeredAt)
	prevTriggeredAt := state.LastTriggeredAt
	longEnough := breachDuration >= requiredDuration
	suppressed := !prevTriggeredAt.IsZero() && now.Sub(prevTriggeredAt) < e.suppressWindow
	if longEnough && !suppressed {
		// Claim the firing slot while still holding the lock.
		state.LastTriggeredAt = now
	}
	e.statesMu.Unlock()

	// Duration check: the breach must persist for N minutes before firing.
	if !longEnough {
		slog.Debug("alert_breaching_not_yet_triggered",
			"rule_id", rule.ID,
			"duration", breachDuration,
			"required", requiredDuration,
		)
		return nil
	}

	// Suppression check: do not fire again within the suppression window.
	if suppressed {
		return nil
	}

	// Create the alert event.
	event := &model.AlertEvent{
		ID:              generateID(),
		RuleID:          rule.ID,
		Level:           rule.Level,
		ResourceType:    rule.MetricSource,
		ResourceID:      rule.MetricName,
		CurrentValue:    fmt.Sprintf("%.4f", point.Value),
		ThresholdValue:  rule.ThresholdValue,
		Status:          "triggered",
		IsAggregated:    false,
		AggregatedCount: 1,
	}
	notifyEvent, err := e.alertRepo.CreateEventWithAggregation(ctx, event, e.aggregationWindow, e.aggregationThreshold)
	if err != nil {
		// Nothing was stored, so undo the suppression timestamp claimed above.
		e.statesMu.Lock()
		if state.LastTriggeredAt.Equal(now) {
			state.LastTriggeredAt = prevTriggeredAt
		}
		e.statesMu.Unlock()
		return fmt.Errorf("create event: %w", err)
	}
	if notifyEvent == nil {
		notifyEvent = event
	}

	// Hand the notification to the notification service's queue.
	if e.notifySvc != nil && len(rule.ChannelIDs) > 0 {
		e.notifySvc.Enqueue(notifyEvent, rule.ChannelIDs)
	}

	slog.Info("alert_triggered",
		"rule_id", rule.ID,
		"level", rule.Level,
		"metric", rule.MetricName,
		"value", point.Value,
		"threshold", threshold,
		"duration_min", rule.DurationMin,
	)
	return nil
}
// escalate upgrades P2 alerts that have stayed in "triggered" status for at
// least escalationInterval to P1 and enqueues an escalation notification on
// the owning rule's channels.
//
// Only the first page (100 events) of triggered events is inspected per run.
// A failed or empty rule lookup is logged and skips the notification; the
// escalation itself has already been persisted at that point.
func (e *AlertEngine) escalate(ctx context.Context) {
	// Fetch events in "triggered" status; the P2 filter is applied below.
	events, _, err := e.alertRepo.ListEvents(ctx, "triggered", 1, 100)
	if err != nil {
		slog.Error("list_open_events_failed", "error", err)
		return
	}

	now := time.Now()
	for i := range events {
		event := &events[i]
		if event.Level != "P2" {
			continue
		}
		age := now.Sub(event.StartedAt)
		if age < e.escalationInterval {
			continue
		}

		// Escalate to P1.
		if err := e.alertRepo.EscalateEvent(ctx, event.ID, "P1"); err != nil {
			slog.Error("escalate_event_failed", "event_id", event.ID, "error", err)
			continue
		}

		// Send the escalation notification.
		rule, err := e.alertRepo.GetRuleByID(ctx, event.RuleID)
		switch {
		case err != nil:
			slog.Error("escalation_rule_lookup_failed",
				"event_id", event.ID,
				"rule_id", event.RuleID,
				"error", err,
			)
		case rule == nil:
			// The repository may return (nil, nil) for an unknown rule.
			slog.Warn("escalation_rule_missing",
				"event_id", event.ID,
				"rule_id", event.RuleID,
			)
		case e.notifySvc != nil && len(rule.ChannelIDs) > 0:
			upgraded := &model.AlertEvent{
				ID:             event.ID,
				RuleID:         event.RuleID,
				Level:          "P1",
				ResourceType:   event.ResourceType,
				ResourceID:     event.ResourceID,
				CurrentValue:   event.CurrentValue,
				ThresholdValue: event.ThresholdValue,
				Status:         "triggered",
			}
			e.notifySvc.Enqueue(upgraded, rule.ChannelIDs)
		}

		slog.Info("alert_escalated",
			"event_id", event.ID,
			"rule_id", event.RuleID,
			"from_level", "P2",
			"to_level", "P1",
			"duration", age,
		)
	}
}
// compare reports whether value satisfies the comparison operator op against
// threshold. Supported operators are ">", "<", "=", ">=" and "<="; any other
// operator yields false.
func (e *AlertEngine) compare(value, threshold float64, op string) bool {
	if op == ">" {
		return value > threshold
	}
	if op == "<" {
		return value < threshold
	}
	if op == "=" {
		return value == threshold
	}
	if op == ">=" {
		return value >= threshold
	}
	if op == "<=" {
		return value <= threshold
	}
	// Unknown operator.
	return false
}
// generateID returns a random identifier in UUID version 4 text form
// (8-4-4-4-12 lowercase hex digits). If the system random source fails, it
// falls back to a fixed-prefix identifier whose last group is derived from the
// current time in nanoseconds.
func generateID() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return fmt.Sprintf("00000000-0000-4000-8000-%012d", time.Now().UnixNano()%1_000_000_000_000)
	}

	// Stamp the version (4) and variant (10xx) bits.
	raw[6] = raw[6]&0x0f | 0x40
	raw[8] = raw[8]&0x3f | 0x80

	// Encode once, then cut the 32 hex digits into the canonical groups.
	digits := hex.EncodeToString(raw[:])
	return digits[0:8] + "-" + digits[8:12] + "-" + digits[12:16] + "-" + digits[16:20] + "-" + digits[20:32]
}

View File

@@ -0,0 +1,191 @@
package service
import (
"context"
"testing"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/stretchr/testify/mock"
)
// fakeAggregationAlertRepo is an in-memory alert repository used by the
// AlertEngine tests. It serves canned rules and events and records every
// created event and escalation.
type fakeAggregationAlertRepo struct {
	rules         []model.AlertRule
	events        []model.AlertEvent
	createdEvents []*model.AlertEvent
	escalated     []string // entries of the form "<eventID>:<newLevel>"
}

// GetOpenCount returns an empty count.
func (f *fakeAggregationAlertRepo) GetOpenCount(context.Context) (*model.AlertCount, error) {
	return new(model.AlertCount), nil
}

// ListRules returns the canned rules.
func (f *fakeAggregationAlertRepo) ListRules(context.Context) ([]model.AlertRule, error) {
	return f.rules, nil
}

// GetRuleByID returns a pointer to the canned rule with the given ID, or
// (nil, nil) when there is none.
func (f *fakeAggregationAlertRepo) GetRuleByID(_ context.Context, id string) (*model.AlertRule, error) {
	for idx := 0; idx < len(f.rules); idx++ {
		if candidate := &f.rules[idx]; candidate.ID == id {
			return candidate, nil
		}
	}
	return nil, nil
}

// CreateRule is a no-op.
func (f *fakeAggregationAlertRepo) CreateRule(context.Context, *model.AlertRule) error {
	return nil
}

// UpdateRule is a no-op.
func (f *fakeAggregationAlertRepo) UpdateRule(context.Context, *model.AlertRule) error {
	return nil
}

// DeleteRule is a no-op.
func (f *fakeAggregationAlertRepo) DeleteRule(context.Context, string) error {
	return nil
}

// ListEvents returns all canned events regardless of status and paging.
func (f *fakeAggregationAlertRepo) ListEvents(context.Context, string, int, int) ([]model.AlertEvent, int, error) {
	return f.events, len(f.events), nil
}

// CreateEvent records the event.
func (f *fakeAggregationAlertRepo) CreateEvent(_ context.Context, event *model.AlertEvent) error {
	f.createdEvents = append(f.createdEvents, event)
	return nil
}

// CreateEventWithAggregation records the event and, once more events than
// threshold have been recorded, returns a synthetic aggregated event instead
// of the raw one. The window argument is ignored.
func (f *fakeAggregationAlertRepo) CreateEventWithAggregation(_ context.Context, event *model.AlertEvent, _ time.Duration, threshold int) (*model.AlertEvent, error) {
	f.createdEvents = append(f.createdEvents, event)

	recorded := len(f.createdEvents)
	if recorded <= threshold {
		return event, nil
	}

	aggregated := &model.AlertEvent{
		ID:              "agg-1",
		RuleID:          event.RuleID,
		Level:           event.Level,
		ResourceType:    event.ResourceType,
		ResourceID:      event.ResourceID,
		CurrentValue:    event.CurrentValue,
		ThresholdValue:  event.ThresholdValue,
		Status:          "triggered",
		IsAggregated:    true,
		AggregatedCount: recorded,
	}
	return aggregated, nil
}

// UpdateEventStatus is a no-op.
func (f *fakeAggregationAlertRepo) UpdateEventStatus(context.Context, string, string) error {
	return nil
}

// EscalateEvent records the escalation as "<id>:<newLevel>".
func (f *fakeAggregationAlertRepo) EscalateEvent(_ context.Context, id, newLevel string) error {
	entry := id + ":" + newLevel
	f.escalated = append(f.escalated, entry)
	return nil
}
// fakeMetricRepo is a metric repository stub that always reports the same
// latest data point.
type fakeMetricRepo struct {
	point *model.MetricPoint
}

// GetRealtime returns empty realtime metrics.
func (f *fakeMetricRepo) GetRealtime(context.Context) (*model.RealtimeMetrics, error) {
	return new(model.RealtimeMetrics), nil
}

// Query returns no points.
func (f *fakeMetricRepo) Query(context.Context, model.MetricQueryRequest) ([]model.MetricPoint, error) {
	return nil, nil
}

// GetLatest returns the configured point for any source and name.
func (f *fakeMetricRepo) GetLatest(_ context.Context, _, _ string) (*model.MetricPoint, error) {
	return f.point, nil
}
func TestAlertEngineAggregatesWhenSameResourceExceedsTwentyEventsWithinWindow(t *testing.T) {
alertRepo := &fakeAggregationAlertRepo{rules: []model.AlertRule{{
ID: "rule-1",
MetricSource: "service",
MetricName: "api-error-rate",
ThresholdType: ">",
ThresholdValue: "0.1",
DurationMin: 0,
Level: "P2",
}}}
metricRepo := &fakeMetricRepo{point: &model.MetricPoint{Value: 0.5}}
engine := NewAlertEngine(alertRepo, metricRepo, nil)
engine.suppressWindow = 0
var last *model.AlertEvent
for i := 0; i < 21; i++ {
if err := engine.evaluateRule(context.Background(), &alertRepo.rules[0]); err != nil {
t.Fatalf("evaluate rule: %v", err)
}
last = alertRepo.createdEvents[len(alertRepo.createdEvents)-1]
}
if got := len(alertRepo.createdEvents); got != 21 {
t.Fatalf("created events = %d, want 21", got)
}
if last.IsAggregated {
t.Fatalf("raw child event must not be marked aggregated")
}
}
func TestAlertEngineEvaluateAndEscalateBranches(t *testing.T) {
alertRepo := &fakeAggregationAlertRepo{rules: []model.AlertRule{{
ID: "rule-eval",
MetricSource: "service",
MetricName: "latency",
ThresholdType: ">=",
ThresholdValue: "10",
DurationMin: 0,
Level: "P2",
}}}
metricRepo := &fakeMetricRepo{point: &model.MetricPoint{Value: 10}}
engine := NewAlertEngine(alertRepo, metricRepo, nil)
engine.suppressWindow = time.Hour
engine.evaluate(context.Background())
if len(alertRepo.createdEvents) != 1 {
t.Fatalf("created events = %d", len(alertRepo.createdEvents))
}
// suppressed second event
engine.evaluate(context.Background())
if len(alertRepo.createdEvents) != 1 {
t.Fatalf("suppression failed, events = %d", len(alertRepo.createdEvents))
}
if !engine.compare(1, 1, "=") || !engine.compare(1, 2, "<") || !engine.compare(2, 1, ">") || !engine.compare(2, 2, ">=") || !engine.compare(1, 2, "<=") || engine.compare(1, 2, "regex") {
t.Fatal("compare operators not covered as expected")
}
if generateID() == "" {
t.Fatal("empty alert id")
}
}
// TestMetricServiceSupplierAndQuery checks that QueryMetrics forwards the
// request to the metric repository and that GetSupplierCount derives healthy,
// unhealthy and total counts from the "supplier_health" series.
func TestMetricServiceSupplierAndQuery(t *testing.T) {
	metricMock := new(MockMetricRepository)
	alertMock := new(MockAlertRepository)
	svc := NewMetricService(metricMock, alertMock)
	ctx := context.Background()

	qpsQuery := model.MetricQueryRequest{Name: "qps"}
	qpsPoints := []model.MetricPoint{{Name: "qps", Value: 1}}
	metricMock.On("Query", mock.Anything, qpsQuery).Return(qpsPoints, nil).Once()

	healthQuery := model.MetricQueryRequest{Name: "supplier_health"}
	healthPoints := []model.MetricPoint{{Value: 1}, {Value: 0}}
	metricMock.On("Query", mock.Anything, healthQuery).Return(healthPoints, nil).Once()

	got, err := svc.QueryMetrics(ctx, qpsQuery)
	if err != nil || len(got) != 1 {
		t.Fatalf("query metrics = %+v %v", got, err)
	}

	count, err := svc.GetSupplierCount(ctx)
	if err != nil || count.Healthy != 1 || count.Unhealthy != 1 || count.Total != 2 {
		t.Fatalf("supplier count = %+v %v", count, err)
	}
}
func TestAlertEngineStartStopCoversLoop(t *testing.T) {
engine := NewAlertEngine(&fakeAggregationAlertRepo{}, &fakeMetricRepo{point: &model.MetricPoint{Value: 0}}, nil)
engine.interval = time.Hour
engine.Start()
time.Sleep(5 * time.Millisecond)
engine.Stop()
}
// TestAlertEngineEscalatesOldP2EventsOnly checks that escalate upgrades only
// P2 events older than the escalation interval, leaving fresh P2 events and
// existing P1 events alone.
func TestAlertEngineEscalatesOldP2EventsOnly(t *testing.T) {
	threeHoursAgo := time.Now().Add(-3 * time.Hour)

	oldEvent := model.AlertEvent{
		ID:             "old",
		RuleID:         "rule-old",
		Level:          "P2",
		ResourceType:   "svc",
		ResourceID:     "api",
		CurrentValue:   "9",
		ThresholdValue: "1",
		Status:         "triggered",
		StartedAt:      threeHoursAgo,
	}
	freshEvent := model.AlertEvent{ID: "fresh", RuleID: "rule-fresh", Level: "P2", StartedAt: time.Now()}
	p1Event := model.AlertEvent{ID: "p1", RuleID: "rule-p1", Level: "P1", StartedAt: time.Now().Add(-3 * time.Hour)}

	repo := &fakeAggregationAlertRepo{
		events: []model.AlertEvent{oldEvent, freshEvent, p1Event},
		rules:  []model.AlertRule{{ID: "rule-old", ChannelIDs: []string{"ch-1"}}},
	}
	metricRepo := &fakeMetricRepo{point: &model.MetricPoint{Value: 0}}

	NewAlertEngine(repo, metricRepo, nil).escalate(context.Background())

	if len(repo.escalated) != 1 || repo.escalated[0] != "old:P1" {
		t.Fatalf("escalated = %+v", repo.escalated)
	}
}

View File

@@ -0,0 +1,181 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"time"
"github.com/company/ai-ops/internal/database"
)
// AuditService is the audit service. It is stateless; its methods use the
// package-level database pool.
type AuditService struct{}

// NewAuditService returns a ready-to-use AuditService.
func NewAuditService() *AuditService {
	return new(AuditService)
}
// AuditLog is a single audit log record, mirroring one row of ai_ops_audits.
type AuditLog struct {
	ID         string `json:"id"`
	TenantID   string `json:"tenant_id"`
	ObjectType string `json:"object_type"`
	ObjectID   string `json:"object_id"`
	Action     string `json:"action"`
	// BeforeState and AfterState hold the object's state around the action;
	// Rollback requires BeforeState to be present.
	BeforeState map[string]any `json:"before_state,omitempty"`
	AfterState  map[string]any `json:"after_state,omitempty"`
	RequestID   string         `json:"request_id"`
	ResultCode  string         `json:"result_code"`
	SourceIP    string         `json:"source_ip"`
	ActorID     string         `json:"actor_id"`
	RiskLevel   string         `json:"risk_level"`
	// ParentAuditID links a record to the audit it derives from; Rollback
	// sets it to the ID of the record being rolled back.
	ParentAuditID *string   `json:"parent_audit_id,omitempty"`
	CreatedAt     time.Time `json:"created_at"`
}
// Record inserts log into ai_ops_audits.
//
// The row's created_at is assigned by the database (NOW()); the stored value
// is written back into log.CreatedAt so callers see the persisted timestamp.
// A nil ParentAuditID is stored as SQL NULL. It returns an error when log is
// nil or the insert fails.
func (s *AuditService) Record(ctx context.Context, log *AuditLog) error {
	if log == nil {
		return fmt.Errorf("insert audit: nil audit log")
	}

	// Pass an untyped nil (SQL NULL) when there is no parent record.
	var parentID any
	if log.ParentAuditID != nil {
		parentID = *log.ParentAuditID
	}

	err := database.Pool.QueryRow(ctx, `
		INSERT INTO ai_ops_audits (id, tenant_id, object_type, object_id, action,
		before_state, after_state, request_id, result_code, source_ip, actor_id,
		risk_level, parent_audit_id, created_at)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, NOW())
		RETURNING created_at
	`, log.ID, log.TenantID, log.ObjectType, log.ObjectID, log.Action,
		log.BeforeState, log.AfterState, log.RequestID, log.ResultCode,
		log.SourceIP, log.ActorID, log.RiskLevel, parentID).Scan(&log.CreatedAt)
	if err != nil {
		return fmt.Errorf("insert audit: %w", err)
	}
	return nil
}
// List returns one page of audit logs, newest first, together with the total
// number of matching rows.
//
// objectType and objectID are optional equality filters (empty means "any").
// page values below 1 become 1; pageSize values outside 1..100 become 20.
func (s *AuditService) List(ctx context.Context, objectType, objectID string, page, pageSize int) ([]AuditLog, int, error) {
	if page < 1 {
		page = 1
	}
	if pageSize < 1 || pageSize > 100 {
		pageSize = 20
	}

	// Build the WHERE clause and its positional arguments side by side.
	var where string
	args := []any{}
	nextArg := 1
	if objectType != "" {
		where = fmt.Sprintf("WHERE object_type = $%d", nextArg)
		args = append(args, objectType)
		nextArg++
	}
	if objectID != "" {
		if where == "" {
			where = fmt.Sprintf("WHERE object_id = $%d", nextArg)
		} else {
			where += fmt.Sprintf(" AND object_id = $%d", nextArg)
		}
		args = append(args, objectID)
		nextArg++
	}

	// Total row count for the same filter.
	var total int
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM ai_ops_audits %s", where)
	countErr := database.Pool.QueryRow(ctx, countQuery, args...).Scan(&total)
	if countErr != nil {
		return nil, 0, fmt.Errorf("count audits: %w", countErr)
	}

	// Page query: LIMIT and OFFSET take the next two placeholders.
	limitIdx, offsetIdx := nextArg, nextArg+1
	dataQuery := fmt.Sprintf(`
SELECT id, tenant_id, object_type, object_id, action,
before_state, after_state, request_id, result_code, source_ip, actor_id,
risk_level, parent_audit_id, created_at
FROM ai_ops_audits %s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, where, limitIdx, offsetIdx)
	offset := (page - 1) * pageSize
	queryArgs := append(args, pageSize, offset)

	rows, queryErr := database.Pool.Query(ctx, dataQuery, queryArgs...)
	if queryErr != nil {
		return nil, 0, fmt.Errorf("query audits: %w", queryErr)
	}
	defer rows.Close()

	var logs []AuditLog
	for rows.Next() {
		var entry AuditLog
		scanErr := rows.Scan(
			&entry.ID, &entry.TenantID, &entry.ObjectType, &entry.ObjectID, &entry.Action,
			&entry.BeforeState, &entry.AfterState, &entry.RequestID, &entry.ResultCode,
			&entry.SourceIP, &entry.ActorID, &entry.RiskLevel, &entry.ParentAuditID, &entry.CreatedAt,
		)
		if scanErr != nil {
			return nil, 0, fmt.Errorf("scan audit: %w", scanErr)
		}
		logs = append(logs, entry)
	}
	return logs, total, rows.Err()
}
// Rollback records a "rollback" audit entry for the audit identified by
// auditID.
//
// The new entry swaps the original's before/after states, is marked high risk
// and points at the original through ParentAuditID. This method only writes
// the audit entry; it does not itself modify the audited resource.
//
// It returns an error when the original record cannot be loaded (the
// underlying cause is wrapped), when the original has no BeforeState, or when
// the rollback entry cannot be stored.
func (s *AuditService) Rollback(ctx context.Context, auditID string) (*AuditLog, error) {
	// Load the original audit record.
	var original AuditLog
	err := database.Pool.QueryRow(ctx, `
		SELECT id, tenant_id, object_type, object_id, action,
		before_state, after_state, request_id, result_code, source_ip, actor_id,
		risk_level, parent_audit_id, created_at
		FROM ai_ops_audits WHERE id = $1
	`, auditID).Scan(
		&original.ID, &original.TenantID, &original.ObjectType, &original.ObjectID, &original.Action,
		&original.BeforeState, &original.AfterState, &original.RequestID, &original.ResultCode,
		&original.SourceIP, &original.ActorID, &original.RiskLevel, &original.ParentAuditID, &original.CreatedAt,
	)
	if err != nil {
		// Keep the cause: a connection failure is not the same as a missing row.
		return nil, fmt.Errorf("audit record not found: %w", err)
	}

	// Simplification kept from the original code: the target resource is
	// assumed to still exist.
	if original.BeforeState == nil {
		return nil, fmt.Errorf("no before_state available for rollback")
	}

	// Build the rollback audit record.
	parentID := original.ID
	rollbackLog := &AuditLog{
		ID:            generateAuditID(),
		TenantID:      original.TenantID,
		ObjectType:    original.ObjectType,
		ObjectID:      original.ObjectID,
		Action:        "rollback",
		BeforeState:   original.AfterState,
		AfterState:    original.BeforeState,
		RequestID:     original.RequestID,
		ResultCode:    "SUCCESS",
		SourceIP:      original.SourceIP,
		ActorID:       original.ActorID,
		RiskLevel:     "high",
		ParentAuditID: &parentID,
	}
	if err := s.Record(ctx, rollbackLog); err != nil {
		return nil, fmt.Errorf("record rollback audit: %w", err)
	}
	return rollbackLog, nil
}
// generateAuditID returns a random identifier in UUID version 4 text form
// (8-4-4-4-12 lowercase hex digits). If the system random source fails, it
// falls back to a fixed-prefix identifier whose last group is derived from the
// current time in nanoseconds.
func generateAuditID() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return fmt.Sprintf("00000000-0000-4000-8000-%012d", time.Now().UnixNano()%1_000_000_000_000)
	}

	// Stamp the version (4) and variant (10xx) bits.
	raw[6] = raw[6]&0x0f | 0x40
	raw[8] = raw[8]&0x3f | 0x80

	// Encode once, then cut the 32 hex digits into the canonical groups.
	digits := hex.EncodeToString(raw[:])
	return digits[0:8] + "-" + digits[8:12] + "-" + digits[12:16] + "-" + digits[16:20] + "-" + digits[20:32]
}

View File

@@ -0,0 +1,114 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"os"
"path/filepath"
"sort"
"testing"
"time"
"github.com/company/ai-ops/internal/config"
"github.com/company/ai-ops/internal/database"
)
// setupServicePGIntegration prepares the PostgreSQL integration environment
// for service tests and returns the context to use.
//
// When no pool exists yet it tries a local database on ports 15432 and 5432
// and skips the test if neither is reachable. It then applies every
// *.up.sql migration under ../../tech/migrations in sorted order, guarded by
// a PostgreSQL advisory lock.
//
// NOTE(review): pg_advisory_lock is session-scoped. If database.Pool hands
// each Exec to an arbitrary pooled connection, the lock and the deferred
// unlock may run on different sessions — confirm against the pool type.
func setupServicePGIntegration(t *testing.T) context.Context {
	t.Helper()
	ctx := context.Background()

	if database.Pool == nil {
		var connectErr error
		for _, port := range []int{15432, 5432} {
			cfg := config.DatabaseConfig{Host: "localhost", Port: port, User: "aiops", Password: "aiops123", DBName: "ai_ops", SSLMode: "disable", PoolSize: 4}
			connectErr = database.Init(cfg)
			if connectErr == nil {
				break
			}
			// Reset the global before trying the next port.
			database.Close()
			database.Pool = nil
		}
		if connectErr != nil {
			t.Skipf("PostgreSQL integration database not available: %v", connectErr)
		}
	}

	pattern := filepath.Join("..", "..", "tech", "migrations", "*.up.sql")
	migrations, globErr := filepath.Glob(pattern)
	if globErr != nil {
		t.Fatal(globErr)
	}
	sort.Strings(migrations)

	if _, lockErr := database.Pool.Exec(ctx, `SELECT pg_advisory_lock(424242001)`); lockErr != nil {
		t.Fatal(lockErr)
	}
	defer database.Pool.Exec(ctx, `SELECT pg_advisory_unlock(424242001)`)

	for _, path := range migrations {
		script, readErr := os.ReadFile(path)
		if readErr != nil {
			t.Fatal(readErr)
		}
		if _, execErr := database.Pool.Exec(ctx, string(script)); execErr != nil {
			t.Fatalf("apply migration %s: %v", path, execErr)
		}
	}
	return ctx
}
// serviceTestUUID returns a random identifier in UUID version 4 text form for
// use as a test fixture ID. It fails the test if the random source errors.
func serviceTestUUID(t *testing.T) string {
	t.Helper()

	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		t.Fatal(err)
	}

	// Stamp the version (4) and variant (10xx) bits.
	raw[6] = raw[6]&0x0f | 0x40
	raw[8] = raw[8]&0x3f | 0x80

	digits := hex.EncodeToString(raw[:])
	return digits[0:8] + "-" + digits[8:12] + "-" + digits[12:16] + "-" + digits[16:20] + "-" + digits[20:32]
}
// cleanupAudit deletes, for each given ID, every audit row whose id,
// parent_audit_id or object_id equals it. Delete errors are deliberately
// ignored: this is best-effort test cleanup.
func cleanupAudit(t *testing.T, ctx context.Context, ids ...string) {
	t.Helper()
	const deleteStmt = `DELETE FROM ai_ops_audits WHERE id=$1 OR parent_audit_id=$1 OR object_id=$1`
	for i := range ids {
		_, _ = database.Pool.Exec(ctx, deleteStmt, ids[i])
	}
}
// TestAuditServiceRecordListRollback is an integration test covering the
// happy path: record an audit, list it back (with out-of-range paging values
// that List normalises), then roll it back and inspect the rollback entry.
func TestAuditServiceRecordListRollback(t *testing.T) {
	ctx := setupServicePGIntegration(t)
	svc := NewAuditService()

	id := serviceTestUUID(t)
	defer cleanupAudit(t, ctx, id)

	entry := &AuditLog{
		ID:          id,
		TenantID:    "tenant",
		ObjectType:  "rule",
		ObjectID:    id,
		Action:      "update",
		BeforeState: map[string]any{"enabled": false},
		AfterState:  map[string]any{"enabled": true},
		RequestID:   "req",
		ResultCode:  "SUCCESS",
		SourceIP:    "127.0.0.1",
		ActorID:     "actor",
		RiskLevel:   "normal",
	}
	if err := svc.Record(ctx, entry); err != nil {
		t.Fatal(err)
	}

	logs, total, err := svc.List(ctx, "rule", id, 0, 500)
	if err != nil || total != 1 || len(logs) != 1 || logs[0].ID != id {
		t.Fatalf("list = total=%d logs=%+v err=%v", total, logs, err)
	}

	rollback, err := svc.Rollback(ctx, id)
	if err != nil {
		t.Fatal(err)
	}
	linked := rollback.ParentAuditID != nil && *rollback.ParentAuditID == id
	if rollback.Action != "rollback" || !linked || rollback.RiskLevel != "high" {
		t.Fatalf("rollback = %+v", rollback)
	}
	cleanupAudit(t, ctx, rollback.ID)
}
// TestAuditServiceRollbackRejectsMissingBeforeState is an integration test
// for Rollback's failure paths: an audit without a before-state and an audit
// ID that does not exist must both be rejected.
func TestAuditServiceRollbackRejectsMissingBeforeState(t *testing.T) {
	ctx := setupServicePGIntegration(t)
	svc := NewAuditService()

	id := serviceTestUUID(t)
	defer cleanupAudit(t, ctx, id)

	entry := &AuditLog{
		ID:         id,
		TenantID:   "tenant",
		ObjectType: "rule",
		ObjectID:   id,
		Action:     "create",
		AfterState: map[string]any{"enabled": true},
		RequestID:  "req",
		ResultCode: "SUCCESS",
		SourceIP:   "127.0.0.1",
		ActorID:    "actor",
		RiskLevel:  "normal",
		CreatedAt:  time.Now(),
	}
	if err := svc.Record(ctx, entry); err != nil {
		t.Fatal(err)
	}

	_, noBeforeErr := svc.Rollback(ctx, id)
	if noBeforeErr == nil {
		t.Fatal("expected rollback error without before state")
	}

	_, missingErr := svc.Rollback(ctx, serviceTestUUID(t))
	if missingErr == nil {
		t.Fatal("expected missing audit error")
	}
}

View File

@@ -0,0 +1,55 @@
package service
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
)
// AuthService is the authentication service. It issues and verifies JWTs
// using a single shared secret.
type AuthService struct {
	secret []byte
}

// NewAuthService returns an AuthService that signs and verifies tokens with
// the given secret.
func NewAuthService(secret string) *AuthService {
	svc := new(AuthService)
	svc.secret = []byte(secret)
	return svc
}
// Claims is the JWT claim set used by AuthService: the application-specific
// user ID and role plus the registered claims (IssueToken sets the expiry and
// issued-at times).
type Claims struct {
	UserID string `json:"user_id"`
	Role   string `json:"role"`
	jwt.RegisteredClaims
}
// IssueToken issues an HS256-signed JWT carrying userID and role, valid for
// 8 hours from the moment of issue.
//
// It returns an error when the service has no signing secret: a token signed
// with an empty key can be forged by anyone, so none is issued.
func (s *AuthService) IssueToken(userID, role string) (string, error) {
	if len(s.secret) == 0 {
		return "", fmt.Errorf("issue token: signing secret is empty")
	}

	// Read the clock once so iat and exp are exactly 8 hours apart.
	now := time.Now()
	claims := Claims{
		UserID: userID,
		Role:   role,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(8 * time.Hour)),
			IssuedAt:  jwt.NewNumericDate(now),
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString(s.secret)
}
// ParseToken verifies tokenStr and returns its claims.
//
// Only HMAC-signed tokens are accepted; any other signing method is rejected
// before the key is handed to the verifier. It returns an error when the
// service has no secret (every token signed with an empty key would
// otherwise verify), when parsing or validation fails, or when the claims are
// not of the expected type.
func (s *AuthService) ParseToken(tokenStr string) (*Claims, error) {
	if len(s.secret) == 0 {
		return nil, fmt.Errorf("parse token: signing secret is empty")
	}

	token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (any, error) {
		// Refuse anything that is not HMAC to prevent algorithm confusion.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return s.secret, nil
	})
	if err != nil {
		return nil, fmt.Errorf("parse token: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, fmt.Errorf("invalid token")
	}
	return claims, nil
}

View File

@@ -0,0 +1,45 @@
package service
import (
"context"
"fmt"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/domain/repository"
)
// ChannelService is the business layer for notification channels. It
// validates input and delegates persistence to a ChannelRepository.
type ChannelService struct {
	repo repository.ChannelRepository
}

// NewChannelService returns a ChannelService backed by repo.
func NewChannelService(repo repository.ChannelRepository) *ChannelService {
	svc := new(ChannelService)
	svc.repo = repo
	return svc
}
// List returns the notification channels reported by the repository; it adds
// no filtering of its own.
func (s *ChannelService) List(ctx context.Context) ([]model.NotificationChannel, error) {
	return s.repo.List(ctx)
}
// Get returns the notification channel with the given id, passing the
// repository's result and error through unchanged.
func (s *ChannelService) Get(ctx context.Context, id string) (*model.NotificationChannel, error) {
	return s.repo.GetByID(ctx, id)
}
// Create validates and stores a new notification channel.
//
// ch must be non-nil and carry both a Name and a ChannelType. New channels
// are always stored enabled: ch.Enabled is set to true before the repository
// call, regardless of the value passed in.
func (s *ChannelService) Create(ctx context.Context, ch *model.NotificationChannel) error {
	if ch == nil {
		return fmt.Errorf("channel is required")
	}
	if ch.Name == "" || ch.ChannelType == "" {
		return fmt.Errorf("name and channel_type are required")
	}
	ch.Enabled = true
	return s.repo.Create(ctx, ch)
}
// Update stores changes to an existing notification channel.
//
// ch must be non-nil and carry the ID of the channel to update; otherwise an
// error is returned without touching the repository.
func (s *ChannelService) Update(ctx context.Context, ch *model.NotificationChannel) error {
	if ch == nil {
		return fmt.Errorf("channel is required")
	}
	if ch.ID == "" {
		return fmt.Errorf("channel id is required")
	}
	return s.repo.Update(ctx, ch)
}
// Delete removes the notification channel with the given id via the
// repository and returns its error unchanged.
func (s *ChannelService) Delete(ctx context.Context, id string) error {
	return s.repo.Delete(ctx, id)
}

View File

@@ -0,0 +1,225 @@
package service
import (
"bytes"
"context"
"errors"
"strings"
"testing"
"time"
"github.com/company/ai-ops/internal/domain/model"
)
// fakeRuleAlertRepo is an alert repository test double for the rule service
// tests. It records the arguments of rule operations and returns err from
// them when set; event operations are no-ops.
type fakeRuleAlertRepo struct {
	rules       []model.AlertRule
	gotRuleID   string           // last id passed to GetRuleByID
	createdRule *model.AlertRule // last rule passed to CreateRule
	updatedRule *model.AlertRule // last rule passed to UpdateRule
	deletedID   string           // last id passed to DeleteRule
	err         error            // error returned by rule operations
}

// GetOpenCount returns an empty count.
func (f *fakeRuleAlertRepo) GetOpenCount(context.Context) (*model.AlertCount, error) {
	return new(model.AlertCount), nil
}

// ListRules returns the canned rules together with the configured error.
func (f *fakeRuleAlertRepo) ListRules(context.Context) ([]model.AlertRule, error) {
	return f.rules, f.err
}

// GetRuleByID records id and returns a synthetic rule named "rule", or the
// configured error.
func (f *fakeRuleAlertRepo) GetRuleByID(_ context.Context, id string) (*model.AlertRule, error) {
	f.gotRuleID = id
	if f.err == nil {
		return &model.AlertRule{ID: id, Name: "rule"}, nil
	}
	return nil, f.err
}

// CreateRule records the rule.
func (f *fakeRuleAlertRepo) CreateRule(_ context.Context, rule *model.AlertRule) error {
	f.createdRule = rule
	return f.err
}

// UpdateRule records the rule.
func (f *fakeRuleAlertRepo) UpdateRule(_ context.Context, rule *model.AlertRule) error {
	f.updatedRule = rule
	return f.err
}

// DeleteRule records the id.
func (f *fakeRuleAlertRepo) DeleteRule(_ context.Context, id string) error {
	f.deletedID = id
	return f.err
}

// ListEvents returns no events.
func (f *fakeRuleAlertRepo) ListEvents(context.Context, string, int, int) ([]model.AlertEvent, int, error) {
	return nil, 0, nil
}

// CreateEvent is a no-op.
func (f *fakeRuleAlertRepo) CreateEvent(context.Context, *model.AlertEvent) error {
	return nil
}

// CreateEventWithAggregation returns the event unchanged.
func (f *fakeRuleAlertRepo) CreateEventWithAggregation(_ context.Context, event *model.AlertEvent, _ time.Duration, _ int) (*model.AlertEvent, error) {
	return event, nil
}

// UpdateEventStatus is a no-op.
func (f *fakeRuleAlertRepo) UpdateEventStatus(context.Context, string, string) error {
	return nil
}

// EscalateEvent is a no-op.
func (f *fakeRuleAlertRepo) EscalateEvent(context.Context, string, string) error {
	return nil
}
// fakeChannelRepository is a channel repository test double. It records the
// arguments of each call and returns err from every operation when set.
type fakeChannelRepository struct {
	channels []model.NotificationChannel
	gotID    string                     // last id passed to GetByID
	created  *model.NotificationChannel // last channel passed to Create
	updated  *model.NotificationChannel // last channel passed to Update
	deleted  string                     // last id passed to Delete
	err      error                      // error returned by every operation
}

// List returns the canned channels together with the configured error.
func (f *fakeChannelRepository) List(context.Context) ([]model.NotificationChannel, error) {
	return f.channels, f.err
}

// GetByID records id and returns a synthetic channel named "webhook", or the
// configured error.
func (f *fakeChannelRepository) GetByID(_ context.Context, id string) (*model.NotificationChannel, error) {
	f.gotID = id
	if f.err == nil {
		return &model.NotificationChannel{ID: id, Name: "webhook"}, nil
	}
	return nil, f.err
}

// Create records the channel.
func (f *fakeChannelRepository) Create(_ context.Context, channel *model.NotificationChannel) error {
	f.created = channel
	return f.err
}

// Update records the channel.
func (f *fakeChannelRepository) Update(_ context.Context, channel *model.NotificationChannel) error {
	f.updated = channel
	return f.err
}

// Delete records the id.
func (f *fakeChannelRepository) Delete(_ context.Context, id string) error {
	f.deleted = id
	return f.err
}
// fakeLogRepository is a log repository test double that remembers the last
// filter it was queried with and returns canned results.
type fakeLogRepository struct {
	logs       []model.RequestLog
	total      int
	lastFilter model.LogQueryFilter // filter of the most recent Query call
	err        error
}

// Query records the filter and returns the canned logs, total and error.
func (f *fakeLogRepository) Query(_ context.Context, filter model.LogQueryFilter) ([]model.RequestLog, int, error) {
	f.lastFilter = filter
	return f.logs, f.total, f.err
}
// TestAuthServiceIssuesAndParsesToken checks the token round trip and that
// tokens signed with another secret, as well as malformed tokens, are
// rejected.
func TestAuthServiceIssuesAndParsesToken(t *testing.T) {
	svc := NewAuthService("secret")

	token, issueErr := svc.IssueToken("u1", "admin")
	if issueErr != nil {
		t.Fatal(issueErr)
	}

	claims, parseErr := svc.ParseToken(token)
	if parseErr != nil {
		t.Fatal(parseErr)
	}
	if claims.UserID != "u1" || claims.Role != "admin" {
		t.Fatalf("unexpected claims: %+v", claims)
	}

	// A different secret must not verify the signature.
	otherSvc := NewAuthService("other")
	if _, err := otherSvc.ParseToken(token); err == nil {
		t.Fatal("expected invalid signature error")
	}

	// Garbage input must be rejected as well.
	if _, err := svc.ParseToken("not-a-jwt"); err == nil {
		t.Fatal("expected malformed token error")
	}
}
// TestRuleServiceValidationAndRepositoryCalls walks the rule service through
// list, get, create, update and delete, checking input validation, the
// defaults applied on create, the version bump on update and that each call
// reaches the repository.
func TestRuleServiceValidationAndRepositoryCalls(t *testing.T) {
	ctx := context.Background()
	repo := &fakeRuleAlertRepo{rules: []model.AlertRule{{ID: "r1"}}}
	svc := NewRuleService(repo)

	// Read paths.
	rules, err := svc.ListRules(ctx)
	if err != nil || len(rules) != 1 {
		t.Fatalf("list = %v %v", rules, err)
	}
	fetched, err := svc.GetRule(ctx, "r1")
	if err != nil || fetched.ID != "r1" {
		t.Fatalf("get = %+v %v", fetched, err)
	}

	// Create: validation failures first, then the happy path.
	if err := svc.CreateRule(ctx, &model.AlertRule{}); err == nil {
		t.Fatal("expected missing id error")
	}
	if err := svc.CreateRule(ctx, &model.AlertRule{ID: "r2"}); err == nil {
		t.Fatal("expected missing name/metric error")
	}
	rule := &model.AlertRule{ID: "r2", Name: "latency", MetricName: "p99"}
	if err := svc.CreateRule(ctx, rule); err != nil {
		t.Fatal(err)
	}
	if !rule.Enabled || rule.Version != 1 || repo.createdRule != rule {
		t.Fatalf("create did not normalize rule: %+v", rule)
	}

	// Update: the ID is required and the version is incremented.
	if err := svc.UpdateRule(ctx, &model.AlertRule{}); err == nil {
		t.Fatal("expected missing update id error")
	}
	updating := &model.AlertRule{ID: "r2", Version: 2}
	if err := svc.UpdateRule(ctx, updating); err != nil {
		t.Fatal(err)
	}
	if updating.Version != 3 || repo.updatedRule != updating {
		t.Fatalf("version not incremented: %+v", updating)
	}

	// Delete.
	if err := svc.DeleteRule(ctx, "r2"); err != nil || repo.deletedID != "r2" {
		t.Fatalf("delete failed: %v", err)
	}
}
// TestChannelServiceValidationAndRepositoryCalls walks the channel service
// through list, get, create, update and delete, checking input validation,
// that created channels are enabled and that each call reaches the
// repository.
func TestChannelServiceValidationAndRepositoryCalls(t *testing.T) {
	ctx := context.Background()
	repo := &fakeChannelRepository{channels: []model.NotificationChannel{{ID: "c1"}}}
	svc := NewChannelService(repo)

	// Read paths.
	channels, err := svc.List(ctx)
	if err != nil || len(channels) != 1 {
		t.Fatalf("list = %v %v", channels, err)
	}
	fetched, err := svc.Get(ctx, "c1")
	if err != nil || fetched.ID != "c1" {
		t.Fatalf("get = %+v %v", fetched, err)
	}

	// Create: validation failure first, then the happy path.
	if err := svc.Create(ctx, &model.NotificationChannel{}); err == nil {
		t.Fatal("expected validation error")
	}
	ch := &model.NotificationChannel{Name: "hook", ChannelType: "webhook"}
	if err := svc.Create(ctx, ch); err != nil {
		t.Fatal(err)
	}
	if !ch.Enabled || repo.created != ch {
		t.Fatalf("create did not enable channel: %+v", ch)
	}

	// Update: the ID is required.
	if err := svc.Update(ctx, &model.NotificationChannel{}); err == nil {
		t.Fatal("expected missing id error")
	}
	if err := svc.Update(ctx, &model.NotificationChannel{ID: "c1"}); err != nil {
		t.Fatal(err)
	}

	// Delete.
	if err := svc.Delete(ctx, "c1"); err != nil || repo.deleted != "c1" {
		t.Fatalf("delete failed: %v", err)
	}
}
// TestLogServiceQueryAndExportCSV checks that QueryLogs forwards the caller's
// filter to the repository unchanged, and that ExportLogsCSV writes the CSV
// header plus one formatted row while overriding the paging fields.
func TestLogServiceQueryAndExportCSV(t *testing.T) {
repo := &fakeLogRepository{
logs: []model.RequestLog{{Timestamp: time.Date(2026, 5, 12, 1, 2, 3, 0, time.UTC), Service: "api", Path: "/v1", Method: "GET", StatusCode: 200, LatencyMs: 12.34, UserID: "u", SupplierID: "s"}},
total: 1,
}
svc := NewLogService(repo)
logs, total, err := svc.QueryLogs(context.Background(), model.LogQueryFilter{Service: "api", Page: 2, PageSize: 5})
if err != nil || total != 1 || len(logs) != 1 {
t.Fatalf("query = %v %d %v", logs, total, err)
}
// The fake records the last filter it received.
if repo.lastFilter.Service != "api" || repo.lastFilter.Page != 2 {
t.Fatalf("filter not passed: %+v", repo.lastFilter)
}
var buf bytes.Buffer
// Page/PageSize passed here are deliberately unusual; the export must replace them.
if err := svc.ExportLogsCSV(context.Background(), model.LogQueryFilter{Page: 9, PageSize: 1}, &buf); err != nil {
t.Fatal(err)
}
out := buf.String()
if !strings.Contains(out, "时间,服务名,路径,方法,状态码") || !strings.Contains(out, "api,/v1,GET,200,12.34") {
t.Fatalf("unexpected csv: %s", out)
}
// The export always queries page 1 with a page size of 10000.
if repo.lastFilter.Page != 1 || repo.lastFilter.PageSize != 10000 {
t.Fatalf("export did not enforce bounds: %+v", repo.lastFilter)
}
}
// TestLogServicePropagatesRepositoryErrors checks that a failing repository
// surfaces as an error from both QueryLogs and ExportLogsCSV, and that each
// error message carries the prefix of the operation that failed.
func TestLogServicePropagatesRepositoryErrors(t *testing.T) {
svc := NewLogService(&fakeLogRepository{err: errors.New("db down")})
if _, _, err := svc.QueryLogs(context.Background(), model.LogQueryFilter{}); err == nil || !strings.Contains(err.Error(), "query logs") {
t.Fatalf("unexpected query err: %v", err)
}
if err := svc.ExportLogsCSV(context.Background(), model.LogQueryFilter{}, &bytes.Buffer{}); err == nil || !strings.Contains(err.Error(), "query logs for export") {
t.Fatalf("unexpected export err: %v", err)
}
}

View File

@@ -0,0 +1,253 @@
package service
import (
"bytes"
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/domain/repository"
)
// HealingEngine is the self-healing engine. It periodically looks up triggered
// alert events and runs the healing action configured on each event's rule.
type HealingEngine struct {
// alertRepo supplies triggered events and their rules.
alertRepo repository.AlertRepository
// healingRepo persists healing execution records.
healingRepo HealingRepository
// client performs the outbound HTTP calls to healing endpoints.
client *http.Client
// interval is the period between two processing rounds.
interval time.Duration
// stopCh is closed by Stop to end the background loop.
stopCh chan struct{}
}
// HealingRepository is the storage interface for healing execution records.
type HealingRepository interface {
// CreateHealing stores a new healing record.
CreateHealing(ctx context.Context, h *HealingLog) error
// UpdateHealingStatus records the outcome (status, result detail and error
// code) of the healing record identified by id.
UpdateHealingStatus(ctx context.Context, id, status string, result map[string]any, errCode string) error
}
// HealingLog is the record of one healing execution.
type HealingLog struct {
ID string `json:"id"`
// AlertID is the ID of the alert event that triggered the healing.
AlertID string `json:"alert_id"`
// ActionType is the rule's healing action name (e.g. "switch_route").
ActionType string `json:"action_type"`
// Config is the rule's healing configuration, passed through to the action.
Config map[string]any `json:"config"`
// Status is set to "pending", "succeeded" or "failed" by the engine.
Status string `json:"status"`
// DryRun is true when the action is only recorded, not executed.
DryRun bool `json:"dry_run"`
ResultDetail map[string]any `json:"result_detail,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
StartedAt time.Time `json:"started_at"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
}
// NewHealingEngine builds a healing engine on top of the given alert and
// healing repositories. The engine uses an HTTP client with a 20-second
// timeout and a 30-second processing interval; it does not start running
// until Start is called.
func NewHealingEngine(ar repository.AlertRepository, hr HealingRepository) *HealingEngine {
	engine := new(HealingEngine)
	engine.alertRepo = ar
	engine.healingRepo = hr
	engine.client = &http.Client{Timeout: 20 * time.Second}
	engine.interval = 30 * time.Second
	engine.stopCh = make(chan struct{})
	return engine
}
// Start logs the configured interval and launches the processing loop in a
// new goroutine. It returns immediately.
func (e *HealingEngine) Start() {
slog.Info("healing_engine_started", "interval", e.interval)
go e.loop()
}
// Stop signals the processing loop to exit by closing stopCh. It does not wait
// for the loop goroutine to finish. Stop must be called at most once: a second
// call closes an already-closed channel and panics.
func (e *HealingEngine) Stop() {
close(e.stopCh)
slog.Info("healing_engine_stopped")
}
// loop runs one processing round per tick of the engine interval until stopCh
// is closed. Each round gets a fresh background context.
func (e *HealingEngine) loop() {
	tick := time.NewTicker(e.interval)
	defer tick.Stop()

	for {
		select {
		case <-e.stopCh:
			return
		case <-tick.C:
			e.process(context.Background())
		}
	}
}
// process performs one healing round: it fetches the first page (up to 100) of
// alert events in the "triggered" state and handles them one by one. A failure
// on one event is logged and does not stop the remaining events.
func (e *HealingEngine) process(ctx context.Context) {
	triggered, _, listErr := e.alertRepo.ListEvents(ctx, "triggered", 1, 100)
	if listErr != nil {
		slog.Error("list_triggered_events_failed", "error", listErr)
		return
	}

	for i := range triggered {
		// Work on a copy so the handler never aliases the slice element.
		current := triggered[i]
		handleErr := e.handleEvent(ctx, &current)
		if handleErr == nil {
			continue
		}
		slog.Error("handle_event_failed", "event_id", current.ID, "error", handleErr)
	}
}
// handleEvent runs the healing action attached to the rule of a single alert
// event.
//
// It loads the rule, returns nil when the rule has no healing action, and
// otherwise creates a "pending" healing record. For sandboxed rules the record
// is immediately marked "succeeded" without executing anything. For all other
// rules the action is executed and the record is updated to "succeeded" or
// "failed" (error code HEALING_EXEC_FAILED).
//
// An execution failure is logged and persisted but not returned; the returned
// error reflects only rule lookup and healing-record persistence.
func (e *HealingEngine) handleEvent(ctx context.Context, event *model.AlertEvent) error {
	// Load the rule configuration.
	rule, err := e.alertRepo.GetRuleByID(ctx, event.RuleID)
	if err != nil {
		return fmt.Errorf("get rule: %w", err)
	}
	// A repository may report "not found" as (nil, nil); guard against that
	// instead of dereferencing a nil rule below.
	if rule == nil {
		return fmt.Errorf("get rule: rule %s not found", event.RuleID)
	}

	// Nothing to do when the rule has no healing action.
	if rule.HealingAction == nil || *rule.HealingAction == "" {
		return nil
	}

	// Create the healing record before doing any work.
	healing := &HealingLog{
		ID:         generateHealingID(),
		AlertID:    event.ID,
		ActionType: *rule.HealingAction,
		Config:     rule.HealingConfig,
		Status:     "pending",
		DryRun:     rule.IsSandboxed,
		StartedAt:  time.Now(),
	}
	if err := e.healingRepo.CreateHealing(ctx, healing); err != nil {
		return fmt.Errorf("create healing log: %w", err)
	}

	// Sandbox mode: record only, do not execute.
	if healing.DryRun {
		slog.Info("healing_dry_run",
			"healing_id", healing.ID,
			"action", healing.ActionType,
			"alert_id", event.ID,
		)
		healing.Status = "succeeded"
		healing.ResultDetail = map[string]any{"message": "dry run, no actual action executed"}
		return e.healingRepo.UpdateHealingStatus(ctx, healing.ID, healing.Status, healing.ResultDetail, "")
	}

	// Execute the healing action and persist its outcome.
	result, err := e.executeAction(ctx, healing)
	if err != nil {
		healing.Status = "failed"
		healing.ErrorCode = "HEALING_EXEC_FAILED"
		slog.Error("healing_action_failed",
			"healing_id", healing.ID,
			"action", healing.ActionType,
			"error", err,
		)
	} else {
		healing.Status = "succeeded"
		healing.ResultDetail = result
		slog.Info("healing_action_succeeded",
			"healing_id", healing.ID,
			"action", healing.ActionType,
		)
	}
	return e.healingRepo.UpdateHealingStatus(ctx, healing.ID, healing.Status, healing.ResultDetail, healing.ErrorCode)
}
// executeAction dispatches a healing record to the executor registered for its
// ActionType and returns that executor's result. An action type with no
// registered executor yields an "unsupported healing action" error.
func (e *HealingEngine) executeAction(ctx context.Context, healing *HealingLog) (map[string]any, error) {
	type executor func(context.Context, *HealingLog) (map[string]any, error)

	executors := map[string]executor{
		"switch_route":     e.executeSwitchRoute,
		"throttle":         e.executeThrottle,
		"restart_instance": e.executeRestartInstance,
		"invoke_script":    e.executeInvokeScript,
	}

	run, known := executors[healing.ActionType]
	if !known {
		return nil, fmt.Errorf("unsupported healing action: %s", healing.ActionType)
	}
	return run(ctx, healing)
}
// executeSwitchRoute performs the "switch_route" action by calling the
// endpoint configured on the healing record.
func (e *HealingEngine) executeSwitchRoute(ctx context.Context, healing *HealingLog) (map[string]any, error) {
return e.callConfiguredEndpoint(ctx, healing, "switch_route")
}
// executeThrottle performs the "throttle" action by calling the endpoint
// configured on the healing record.
func (e *HealingEngine) executeThrottle(ctx context.Context, healing *HealingLog) (map[string]any, error) {
return e.callConfiguredEndpoint(ctx, healing, "throttle")
}
// executeRestartInstance performs the "restart_instance" action. The restart
// is refused unless the config carries allow_restart as the boolean true;
// a missing key or a value of any other type counts as not allowed.
func (e *HealingEngine) executeRestartInstance(ctx context.Context, healing *HealingLog) (map[string]any, error) {
	permitted, isBool := healing.Config["allow_restart"].(bool)
	if !isBool || !permitted {
		return nil, fmt.Errorf("restart_instance requires allow_restart=true")
	}
	return e.callConfiguredEndpoint(ctx, healing, "restart_instance")
}
// executeInvokeScript performs the "invoke_script" action. The config must
// reference a script through a non-empty string script_id; a missing,
// non-string or empty script_id is rejected, matching how the endpoint setting
// is validated.
func (e *HealingEngine) executeInvokeScript(ctx context.Context, healing *HealingLog) (map[string]any, error) {
	// A failed type assertion yields "", so one emptiness check covers the
	// missing, wrongly typed and empty cases.
	if scriptID, _ := healing.Config["script_id"].(string); scriptID == "" {
		return nil, fmt.Errorf("invoke_script requires script_id; raw script content is not allowed")
	}
	return e.callConfiguredEndpoint(ctx, healing, "invoke_script")
}
// callConfiguredEndpoint sends the healing record as a JSON request to the
// URL stored under the "endpoint" config key.
//
// The HTTP method comes from the "method" config key and defaults to POST;
// only POST, PUT and PATCH are accepted. When the "token" config key holds a
// non-empty string it is sent as a Bearer Authorization header. A response
// status of 400 or above is reported as an error; otherwise a summary with the
// action message, endpoint and status code is returned. The action argument is
// used only in messages.
func (e *HealingEngine) callConfiguredEndpoint(ctx context.Context, healing *HealingLog, action string) (map[string]any, error) {
	// A missing or non-string endpoint asserts to "", same as an empty one.
	endpoint, _ := healing.Config["endpoint"].(string)
	if endpoint == "" {
		return nil, fmt.Errorf("%s requires endpoint", action)
	}

	method, _ := healing.Config["method"].(string)
	switch method {
	case "":
		method = http.MethodPost
	case http.MethodPost, http.MethodPut, http.MethodPatch:
		// Allowed as configured.
	default:
		return nil, fmt.Errorf("%s method %s is not allowed", action, method)
	}

	body, err := json.Marshal(map[string]any{
		"healing_id":  healing.ID,
		"alert_id":    healing.AlertID,
		"action_type": healing.ActionType,
		"config":      healing.Config,
		"dry_run":     healing.DryRun,
	})
	if err != nil {
		return nil, fmt.Errorf("marshal healing payload: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("create healing request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	token, _ := healing.Config["token"].(string)
	if token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	resp, err := e.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("call healing endpoint: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("healing endpoint returned status %d", resp.StatusCode)
	}

	summary := map[string]any{
		"message":     action + " executed",
		"endpoint":    endpoint,
		"status_code": resp.StatusCode,
	}
	return summary, nil
}
// generateHealingID returns a random identifier in the canonical UUID text
// layout (8-4-4-4-12 lowercase hex digits) with the version nibble set to 4 and
// the variant bits set to 10. If the random source fails, it falls back to a
// fixed prefix followed by the low 12 decimal digits of the current Unix time
// in nanoseconds.
func generateHealingID() string {
	raw := make([]byte, 16)
	_, err := rand.Read(raw)
	if err != nil {
		return fmt.Sprintf("00000000-0000-4000-8000-%012d", time.Now().UnixNano()%1_000_000_000_000)
	}

	raw[6] = (raw[6] & 0x0f) | 0x40 // version 4
	raw[8] = (raw[8] & 0x3f) | 0x80 // variant 10xx

	// Encode once and cut the 32 hex digits into the five UUID groups.
	digits := hex.EncodeToString(raw)
	return digits[:8] + "-" + digits[8:12] + "-" + digits[12:16] + "-" + digits[16:20] + "-" + digits[20:]
}

View File

@@ -0,0 +1,128 @@
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/company/ai-ops/internal/domain/model"
)
// fakeHealingRepo is an in-memory HealingRepository for tests. It records every
// call and never fails.
type fakeHealingRepo struct {
// created holds a copy of each record passed to CreateHealing.
created []HealingLog
// updated holds one synthetic record per UpdateHealingStatus call.
updated []HealingLog
}
// CreateHealing appends a copy of h to created.
func (r *fakeHealingRepo) CreateHealing(ctx context.Context, h *HealingLog) error {
r.created = append(r.created, *h)
return nil
}
// UpdateHealingStatus appends a record built from the call's arguments to updated.
func (r *fakeHealingRepo) UpdateHealingStatus(ctx context.Context, id, status string, result map[string]any, errCode string) error {
r.updated = append(r.updated, HealingLog{ID: id, Status: status, ResultDetail: result, ErrorCode: errCode})
return nil
}
// TestHealingEngineExecutesConfiguredEndpointAndRecordsSuccess checks that a
// non-sandboxed "switch_route" rule results in a POST to the configured
// endpoint and in exactly one healing update with status "succeeded".
func TestHealingEngineExecutesConfiguredEndpointAndRecordsSuccess(t *testing.T) {
	called := false
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		if r.Method != http.MethodPost {
			// The handler runs on the server's goroutine, not the test
			// goroutine, so report with Errorf: Fatalf/FailNow must only be
			// called from the goroutine running the test.
			t.Errorf("method = %s, want POST", r.Method)
		}
		w.WriteHeader(http.StatusAccepted)
	}))
	defer server.Close()

	action := "switch_route"
	alertRepo := &fakeAggregationAlertRepo{rules: []model.AlertRule{{
		ID:            "rule-1",
		HealingAction: &action,
		HealingConfig: map[string]any{"endpoint": server.URL},
		IsSandboxed:   false,
	}}}
	healingRepo := &fakeHealingRepo{}
	engine := NewHealingEngine(alertRepo, healingRepo)

	err := engine.handleEvent(context.Background(), &model.AlertEvent{ID: "alert-1", RuleID: "rule-1"})
	if err != nil {
		t.Fatalf("handle event: %v", err)
	}
	if !called {
		t.Fatalf("expected healing endpoint to be called")
	}
	if len(healingRepo.updated) != 1 || healingRepo.updated[0].Status != "succeeded" {
		t.Fatalf("updated healing logs = %#v, want one succeeded", healingRepo.updated)
	}
}
// TestHealingEngineRejectsRestartWithoutExplicitAllow checks that a
// "restart_instance" action whose config lacks allow_restart is refused. The
// engine is built with nil repositories because executeAction does not use them.
func TestHealingEngineRejectsRestartWithoutExplicitAllow(t *testing.T) {
healing := &HealingLog{ActionType: "restart_instance", Config: map[string]any{"endpoint": "http://127.0.0.1"}}
engine := NewHealingEngine(nil, nil)
_, err := engine.executeAction(context.Background(), healing)
if err == nil {
t.Fatalf("expected restart_instance without allow_restart to fail")
}
}
// TestHealingEngineProcessDryRunAndActionBranches covers the sandboxed
// (dry-run) path of handleEvent plus the error branches for an unsupported
// action and for invoke_script without a script_id.
func TestHealingEngineProcessDryRunAndActionBranches(t *testing.T) {
action := "throttle"
alertRepo := &fakeAggregationAlertRepo{rules: []model.AlertRule{{ID: "rule-heal", HealingAction: &action, HealingConfig: map[string]any{"limit": 1}, IsSandboxed: true}}}
alertRepo.createdEvents = nil
healingRepo := &fakeHealingRepo{}
engine := NewHealingEngine(alertRepo, healingRepo)
alertRepo.rules[0].ID = "rule-heal"
// fakeAggregationAlertRepo ListEvents returns nil, so cover direct handleEvent dry-run.
if err := engine.handleEvent(context.Background(), &model.AlertEvent{ID: "event-heal", RuleID: "rule-heal"}); err != nil {
t.Fatal(err)
}
// A dry run must create one record and mark it succeeded with one update.
if len(healingRepo.created) != 1 || len(healingRepo.updated) != 1 || healingRepo.updated[0].Status != "succeeded" {
t.Fatalf("dry-run healing logs = created=%+v updated=%+v", healingRepo.created, healingRepo.updated)
}
if _, err := engine.executeAction(context.Background(), &HealingLog{ActionType: "unsupported", Config: map[string]any{}}); err == nil {
t.Fatal("expected unsupported action error")
}
if _, err := engine.executeInvokeScript(context.Background(), &HealingLog{ActionType: "invoke_script", Config: map[string]any{"endpoint": "http://example.invalid"}}); err == nil {
t.Fatal("expected missing script_id error")
}
if generateHealingID() == "" {
t.Fatal("empty healing id")
}
}
// TestHealingEngineEndpointVariants exercises the endpoint call with a PATCH
// method and bearer token, with an allowed restart, with a script_id, and
// finally with a disallowed GET method.
func TestHealingEngineEndpointVariants(t *testing.T) {
var gotAuth string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Remember the Authorization header of the most recent request.
gotAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusAccepted)
}))
defer server.Close()
engine := NewHealingEngine(&fakeAggregationAlertRepo{}, &fakeHealingRepo{})
if _, err := engine.executeThrottle(context.Background(), &HealingLog{ID: "h", AlertID: "a", ActionType: "throttle", Config: map[string]any{"endpoint": server.URL, "method": http.MethodPatch, "token": "tok"}}); err != nil {
t.Fatal(err)
}
if gotAuth != "Bearer tok" {
t.Fatalf("auth header = %s", gotAuth)
}
if _, err := engine.executeRestartInstance(context.Background(), &HealingLog{ID: "h", AlertID: "a", ActionType: "restart_instance", Config: map[string]any{"endpoint": server.URL, "allow_restart": true}}); err != nil {
t.Fatal(err)
}
if _, err := engine.executeInvokeScript(context.Background(), &HealingLog{ID: "h", AlertID: "a", ActionType: "invoke_script", Config: map[string]any{"endpoint": server.URL, "script_id": "script-1"}}); err != nil {
t.Fatal(err)
}
// GET is not among the allowed methods, so this must fail before any request is sent.
if _, err := engine.callConfiguredEndpoint(context.Background(), &HealingLog{Config: map[string]any{"endpoint": server.URL, "method": http.MethodGet}}, "bad"); err == nil {
t.Fatal("expected disallowed method error")
}
}
// TestHealingEngineStartStopAndProcess runs one processing round directly and
// then starts and stops the background loop. The interval is raised to one
// hour so the ticker does not fire during the test.
func TestHealingEngineStartStopAndProcess(t *testing.T) {
engine := NewHealingEngine(&fakeAggregationAlertRepo{}, &fakeHealingRepo{})
engine.interval = time.Hour
engine.process(context.Background())
engine.Start()
// Give the loop goroutine a moment to start before stopping it.
time.Sleep(5 * time.Millisecond)
engine.Stop()
}

View File

@@ -0,0 +1,105 @@
package service
import (
"context"
"encoding/csv"
"fmt"
"io"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/domain/repository"
"github.com/company/ai-ops/internal/redis"
goredis "github.com/redis/go-redis/v9"
)
// LogService is the business-logic layer for request logs.
type LogService struct {
logRepo repository.LogRepository
}
// NewLogService returns a LogService backed by the given log repository.
func NewLogService(lr repository.LogRepository) *LogService {
return &LogService{logRepo: lr}
}
// QueryLogs returns the request logs matching filter together with the total
// reported by the repository.
//
// When the package-level redis client is set, the result is first looked up
// under a key derived from the filter; after a repository query it is written
// back with a 5-minute TTL. Cache problems never fail the call. The repository
// query runs with a 3-second timeout and its error is returned wrapped with
// "query logs".
//
// NOTE(review): the items are handed to redis Set/Scan as a plain slice.
// Presumably go-redis can only (un)marshal such a value when it implements
// encoding.BinaryMarshaler/BinaryUnmarshaler — confirm that the cache is
// actually effective.
func (s *LogService) QueryLogs(ctx context.Context, filter model.LogQueryFilter) ([]model.RequestLog, int, error) {
	// The cache key is built from the filter fields.
	cacheKey := s.buildCacheKey(filter)

	// Try the cache first.
	if redis.Client != nil {
		var cached []model.RequestLog
		var total int
		err := redis.Client.Get(ctx, cacheKey+":items").Scan(&cached)
		if err == nil {
			// Items and total live under separate keys. Only a pair that
			// was read completely counts as a hit; otherwise the caller
			// would receive items alongside a bogus total of 0.
			err = redis.Client.Get(ctx, cacheKey+":total").Scan(&total)
		}
		switch err {
		case nil:
			return cached, total, nil
		case goredis.Nil:
			// Cache miss: fall through to the repository.
		default:
			// Cache errors must not block the request: fall through to
			// the repository as well.
		}
	}

	// Bound the repository query.
	queryCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	logs, total, err := s.logRepo.Query(queryCtx, filter)
	if err != nil {
		return nil, 0, fmt.Errorf("query logs: %w", err)
	}

	// Best-effort write-back with a 5-minute TTL; failures are ignored on
	// purpose because the cache is only an optimisation.
	if redis.Client != nil {
		redis.Client.Set(ctx, cacheKey+":items", logs, 5*time.Minute)
		redis.Client.Set(ctx, cacheKey+":total", total, 5*time.Minute)
	}
	return logs, total, nil
}
// ExportLogsCSV writes the logs matching filter to w as CSV: one header row
// followed by one row per log entry.
//
// The paging fields of filter are overridden so that the export always reads
// page 1 with a page size of 10000, which is the export cap. Repository errors
// are returned wrapped with "query logs for export". Write errors, including
// those that only surface when the buffered CSV data is flushed to w, are
// returned as well.
func (s *LogService) ExportLogsCSV(ctx context.Context, filter model.LogQueryFilter, w io.Writer) error {
	filter.Page = 1
	filter.PageSize = 10000 // export cap

	logs, _, err := s.logRepo.Query(ctx, filter)
	if err != nil {
		return fmt.Errorf("query logs for export: %w", err)
	}

	csvWriter := csv.NewWriter(w)

	// Header row.
	if err := csvWriter.Write([]string{"时间", "服务名", "路径", "方法", "状态码", "延迟(ms)", "用户ID", "供应商ID", "错误码"}); err != nil {
		return fmt.Errorf("write csv header: %w", err)
	}

	// Data rows.
	for _, l := range logs {
		row := []string{
			l.Timestamp.Format(time.RFC3339),
			l.Service,
			l.Path,
			l.Method,
			fmt.Sprintf("%d", l.StatusCode),
			fmt.Sprintf("%.2f", l.LatencyMs),
			l.UserID,
			l.SupplierID,
			l.ErrorCode,
		}
		if err := csvWriter.Write(row); err != nil {
			return fmt.Errorf("write csv row: %w", err)
		}
	}

	// csv.Writer buffers its output, so a failing destination may only show
	// up on Flush. Flush returns nothing; the error must be read from Error,
	// which a deferred Flush would silently drop.
	csvWriter.Flush()
	if err := csvWriter.Error(); err != nil {
		return fmt.Errorf("flush csv: %w", err)
	}
	return nil
}
// buildCacheKey formats the cache key prefix for a log query from the filter's
// Service, Path, StatusCode, UserID, SupplierID, Page and PageSize fields.
// NOTE(review): any other LogQueryFilter field is not part of the key —
// confirm that none of them changes the query result.
func (s *LogService) buildCacheKey(filter model.LogQueryFilter) string {
return fmt.Sprintf("ai-ops:logs:%s:%s:%d:%s:%s:%d:%d",
filter.Service, filter.Path, filter.StatusCode,
filter.UserID, filter.SupplierID, filter.Page, filter.PageSize)
}

View File

@@ -0,0 +1,60 @@
package service
import (
"context"
"fmt"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/domain/repository"
)
// MetricService is the business-logic layer for metrics.
type MetricService struct {
metricRepo repository.MetricRepository
alertRepo repository.AlertRepository
}
// NewMetricService returns a MetricService backed by the given metric and
// alert repositories.
func NewMetricService(mr repository.MetricRepository, ar repository.AlertRepository) *MetricService {
return &MetricService{metricRepo: mr, alertRepo: ar}
}
// GetRealtimeMetrics returns the real-time metrics for the home page, as
// provided by the metric repository.
func (s *MetricService) GetRealtimeMetrics(ctx context.Context) (*model.RealtimeMetrics, error) {
return s.metricRepo.GetRealtime(ctx)
}
// GetSupplierCount counts suppliers by health. It queries the
// "supplier_health" metric and classifies every returned point: a value above
// 0.5 counts as healthy, anything else as unhealthy. Total is the sum of both.
// Repository errors are returned wrapped with "query supplier health".
func (s *MetricService) GetSupplierCount(ctx context.Context) (*model.SupplierCount, error) {
	req := model.MetricQueryRequest{Name: "supplier_health"}
	points, err := s.metricRepo.Query(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("query supplier health: %w", err)
	}

	healthy := 0
	for i := range points {
		if points[i].Value > 0.5 {
			healthy++
		}
	}
	// Every point that is not healthy is unhealthy.
	unhealthy := len(points) - healthy

	return &model.SupplierCount{
		Total:     healthy + unhealthy,
		Healthy:   healthy,
		Unhealthy: unhealthy,
	}, nil
}
// GetOpenAlertCount returns the count of alerts that are not closed, as
// provided by the alert repository.
func (s *MetricService) GetOpenAlertCount(ctx context.Context) (*model.AlertCount, error) {
return s.alertRepo.GetOpenCount(ctx)
}
// QueryMetrics runs a drill-down metric query by passing req to the metric
// repository unchanged.
func (s *MetricService) QueryMetrics(ctx context.Context, req model.MetricQueryRequest) ([]model.MetricPoint, error) {
return s.metricRepo.Query(ctx, req)
}

View File

@@ -0,0 +1,115 @@
package service
import (
"context"
"testing"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// MockMetricRepository is a testify mock of the metric repository.
//
// Every method tolerates a nil first return value, so error paths can be
// stubbed with Return(nil, err) without the type assertion panicking. A
// non-nil value of the wrong type still panics, which exposes a bad stub.
type MockMetricRepository struct {
	mock.Mock
}

// GetRealtime returns the stubbed real-time metrics and error.
func (m *MockMetricRepository) GetRealtime(ctx context.Context) (*model.RealtimeMetrics, error) {
	args := m.Called(ctx)
	var metrics *model.RealtimeMetrics
	if v := args.Get(0); v != nil {
		metrics = v.(*model.RealtimeMetrics)
	}
	return metrics, args.Error(1)
}

// Query returns the stubbed metric points and error.
func (m *MockMetricRepository) Query(ctx context.Context, req model.MetricQueryRequest) ([]model.MetricPoint, error) {
	args := m.Called(ctx, req)
	var points []model.MetricPoint
	if v := args.Get(0); v != nil {
		points = v.([]model.MetricPoint)
	}
	return points, args.Error(1)
}

// GetLatest returns the stubbed latest metric point and error.
func (m *MockMetricRepository) GetLatest(ctx context.Context, source, name string) (*model.MetricPoint, error) {
	args := m.Called(ctx, source, name)
	var point *model.MetricPoint
	if v := args.Get(0); v != nil {
		point = v.(*model.MetricPoint)
	}
	return point, args.Error(1)
}
// MockAlertRepository is a testify mock of the alert repository.
//
// Methods that return a pointer or slice tolerate a nil first return value, so
// error paths can be stubbed with Return(nil, err) without the type assertion
// panicking. A non-nil value of the wrong type still panics, which exposes a
// bad stub.
type MockAlertRepository struct {
	mock.Mock
}

// GetOpenCount returns the stubbed alert count and error.
func (m *MockAlertRepository) GetOpenCount(ctx context.Context) (*model.AlertCount, error) {
	args := m.Called(ctx)
	var count *model.AlertCount
	if v := args.Get(0); v != nil {
		count = v.(*model.AlertCount)
	}
	return count, args.Error(1)
}

// ListRules returns the stubbed rules and error.
func (m *MockAlertRepository) ListRules(ctx context.Context) ([]model.AlertRule, error) {
	args := m.Called(ctx)
	var rules []model.AlertRule
	if v := args.Get(0); v != nil {
		rules = v.([]model.AlertRule)
	}
	return rules, args.Error(1)
}

// GetRuleByID returns the stubbed rule and error.
func (m *MockAlertRepository) GetRuleByID(ctx context.Context, id string) (*model.AlertRule, error) {
	args := m.Called(ctx, id)
	var rule *model.AlertRule
	if v := args.Get(0); v != nil {
		rule = v.(*model.AlertRule)
	}
	return rule, args.Error(1)
}

// CreateRule returns the stubbed error.
func (m *MockAlertRepository) CreateRule(ctx context.Context, rule *model.AlertRule) error {
	args := m.Called(ctx, rule)
	return args.Error(0)
}

// UpdateRule returns the stubbed error.
func (m *MockAlertRepository) UpdateRule(ctx context.Context, rule *model.AlertRule) error {
	args := m.Called(ctx, rule)
	return args.Error(0)
}

// DeleteRule returns the stubbed error.
func (m *MockAlertRepository) DeleteRule(ctx context.Context, id string) error {
	args := m.Called(ctx, id)
	return args.Error(0)
}

// ListEvents returns the stubbed events, total and error.
func (m *MockAlertRepository) ListEvents(ctx context.Context, status string, page, pageSize int) ([]model.AlertEvent, int, error) {
	args := m.Called(ctx, status, page, pageSize)
	var events []model.AlertEvent
	if v := args.Get(0); v != nil {
		events = v.([]model.AlertEvent)
	}
	return events, args.Int(1), args.Error(2)
}

// CreateEvent returns the stubbed error.
func (m *MockAlertRepository) CreateEvent(ctx context.Context, event *model.AlertEvent) error {
	args := m.Called(ctx, event)
	return args.Error(0)
}

// CreateEventWithAggregation returns the stubbed event and error.
func (m *MockAlertRepository) CreateEventWithAggregation(ctx context.Context, event *model.AlertEvent, window time.Duration, threshold int) (*model.AlertEvent, error) {
	args := m.Called(ctx, event, window, threshold)
	var created *model.AlertEvent
	if v := args.Get(0); v != nil {
		created = v.(*model.AlertEvent)
	}
	return created, args.Error(1)
}

// UpdateEventStatus returns the stubbed error.
func (m *MockAlertRepository) UpdateEventStatus(ctx context.Context, id, status string) error {
	args := m.Called(ctx, id, status)
	return args.Error(0)
}

// EscalateEvent returns the stubbed error.
func (m *MockAlertRepository) EscalateEvent(ctx context.Context, id, newLevel string) error {
	args := m.Called(ctx, id, newLevel)
	return args.Error(0)
}
// TestMetricService_GetRealtimeMetrics checks that GetRealtimeMetrics returns
// exactly what the metric repository provides and that the repository is called.
func TestMetricService_GetRealtimeMetrics(t *testing.T) {
mockMetric := new(MockMetricRepository)
mockAlert := new(MockAlertRepository)
svc := NewMetricService(mockMetric, mockAlert)
expected := &model.RealtimeMetrics{
QPS: 100.5,
AvgLatency: 45.2,
P99Latency: 120.8,
ErrorRate: 0.01,
}
mockMetric.On("GetRealtime", mock.Anything).Return(expected, nil)
result, err := svc.GetRealtimeMetrics(context.Background())
assert.NoError(t, err)
assert.Equal(t, expected, result)
mockMetric.AssertExpectations(t)
}
// TestMetricService_GetOpenAlertCount checks that GetOpenAlertCount returns
// exactly what the alert repository provides and that the repository is called.
func TestMetricService_GetOpenAlertCount(t *testing.T) {
mockMetric := new(MockMetricRepository)
mockAlert := new(MockAlertRepository)
svc := NewMetricService(mockMetric, mockAlert)
expected := &model.AlertCount{Open: 5, P0: 1, P1: 2, P2: 1, P3: 1}
mockAlert.On("GetOpenCount", mock.Anything).Return(expected, nil)
result, err := svc.GetOpenAlertCount(context.Background())
assert.NoError(t, err)
assert.Equal(t, expected, result)
mockAlert.AssertExpectations(t)
}

View File

@@ -0,0 +1,248 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/domain/repository"
)
// NotificationTask is one queued notification: an alert event plus the
// channels it may be delivered through.
type NotificationTask struct {
Event *model.AlertEvent
// ChannelIDs lists the IDs of the candidate channels.
ChannelIDs []string
Priority string // P0, P1, P2, P3
}
// NotificationService delivers alert notifications through configured channels
// using a background worker fed from an in-memory queue.
type NotificationService struct {
channelRepo repository.ChannelRepository
// logRepo is optional; when nil no delivery logs are written.
logRepo repository.NotificationLogRepository
client *http.Client
// queue holds tasks waiting for the worker.
queue chan NotificationTask
// stopCh is closed by Stop to end the worker.
stopCh chan struct{}
}
// NewNotificationService builds a notification service and starts its
// background worker goroutine. Only the first of the optional log repositories
// is used; when none is given, delivery logging is disabled. The service uses
// an HTTP client with a 10-second timeout and a queue of 1000 tasks.
func NewNotificationService(cr repository.ChannelRepository, logRepos ...repository.NotificationLogRepository) *NotificationService {
	ns := &NotificationService{
		channelRepo: cr,
		client:      &http.Client{Timeout: 10 * time.Second},
		queue:       make(chan NotificationTask, 1000),
		stopCh:      make(chan struct{}),
	}
	if len(logRepos) > 0 {
		ns.logRepo = logRepos[0]
	}

	go ns.worker()
	return ns
}
// Stop signals the worker to exit by closing stopCh. It does not wait for the
// worker goroutine to finish. Stop must be called at most once: a second call
// closes an already-closed channel and panics.
func (s *NotificationService) Stop() {
close(s.stopCh)
}
// Enqueue queues a notification for the event, using the event's level as the
// task priority. The call never blocks: when the queue is full the task is
// dropped and a warning is logged.
func (s *NotificationService) Enqueue(event *model.AlertEvent, channelIDs []string) {
	select {
	case s.queue <- NotificationTask{Event: event, ChannelIDs: channelIDs, Priority: event.Level}:
		slog.Info("notification_enqueued", "event_id", event.ID, "priority", event.Level)
	default:
		slog.Warn("notification_queue_full", "event_id", event.ID)
	}
}
// worker takes tasks off the queue and processes them one at a time, each with
// a fresh background context, until stopCh is closed.
func (s *NotificationService) worker() {
	for {
		select {
		case <-s.stopCh:
			return
		case next := <-s.queue:
			s.processTask(context.Background(), next)
		}
	}
}
// processTask delivers one notification task.
//
// The whole delivery is bounded by a timeout that depends on the task
// priority: 30 seconds for P0 and P1, 120 seconds otherwise. The task's
// channels are tried in priority order; delivery stops at the first channel
// that succeeds, and each failed attempt is logged before the next channel is
// tried. When no channel succeeds, a final error is logged.
func (s *NotificationService) processTask(ctx context.Context, task NotificationTask) {
	// Pick the delivery timeout from the priority.
	var timeout time.Duration
	switch task.Priority {
	case "P0", "P1":
		timeout = 30 * time.Second
	default:
		timeout = 120 * time.Second
	}
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	channels, err := s.channelRepo.List(ctx)
	if err != nil {
		slog.Error("list_channels_failed", "error", err, "event_id", task.Event.ID)
		return
	}

	// Try each selected channel, highest priority first, until one succeeds.
	for _, ch := range s.filterAndOrderChannels(channels, task.ChannelIDs) {
		logID := s.createSendLog(ctx, task.Event, ch)

		sendErr := s.sendToChannel(ctx, task.Event, ch)
		if sendErr == nil {
			s.markSendSent(ctx, logID)
			slog.Info("notify_sent",
				"event_id", task.Event.ID,
				"channel_id", ch.ID,
				"channel_type", ch.ChannelType,
			)
			return
		}

		// This channel failed: record it and fall back to the next one.
		s.markSendFailed(ctx, logID, 1, sendErr)
		slog.Error("notify_channel_failed",
			"event_id", task.Event.ID,
			"channel_id", ch.ID,
			"channel_type", ch.ChannelType,
			"error", sendErr,
		)
	}

	// Reached only when no channel delivered the notification.
	slog.Error("notify_all_channels_failed", "event_id", task.Event.ID)
}
// createSendLog stores a "pending" delivery log for the event/channel pair and
// returns the ID found on the log after the repository call. It returns ""
// when logging is disabled or the log could not be created; a creation failure
// is logged, not returned.
func (s *NotificationService) createSendLog(ctx context.Context, event *model.AlertEvent, ch *model.NotificationChannel) string {
	if s.logRepo == nil {
		return ""
	}

	entry := &model.NotificationLog{
		EventID:     event.ID,
		ChannelID:   ch.ID,
		ChannelType: ch.ChannelType,
		Status:      "pending",
	}
	err := s.logRepo.CreateLog(ctx, entry)
	if err == nil {
		return entry.ID
	}

	slog.Error("create_notification_log_failed", "event_id", event.ID, "channel_id", ch.ID, "error", err)
	return ""
}
// markSendSent marks the delivery log as sent. It does nothing when logging is
// disabled or logID is empty; a repository failure is logged, not returned.
func (s *NotificationService) markSendSent(ctx context.Context, logID string) {
	if s.logRepo == nil {
		return
	}
	if logID == "" {
		return
	}

	err := s.logRepo.MarkSent(ctx, logID)
	if err == nil {
		return
	}
	slog.Error("mark_notification_sent_failed", "log_id", logID, "error", err)
}
// markSendFailed marks the delivery log as failed, recording the retry count
// and the text of err. It does nothing when logging is disabled or logID is
// empty; a repository failure is logged, not returned.
func (s *NotificationService) markSendFailed(ctx context.Context, logID string, retryCount int, err error) {
	if s.logRepo == nil {
		return
	}
	if logID == "" {
		return
	}

	markErr := s.logRepo.MarkFailed(ctx, logID, retryCount, err.Error())
	if markErr == nil {
		return
	}
	slog.Error("mark_notification_failed_failed", "log_id", logID, "error", markErr)
}
// filterAndOrderChannels selects the channels of all whose ID appears in ids
// and returns pointers to them ordered by descending Priority. The Enabled
// flag is not consulted. The returned pointers refer to elements of all.
func (s *NotificationService) filterAndOrderChannels(all []model.NotificationChannel, ids []string) []*model.NotificationChannel {
	wanted := make(map[string]struct{})
	for _, id := range ids {
		wanted[id] = struct{}{}
	}

	var selected []*model.NotificationChannel
	for idx := range all {
		if _, ok := wanted[all[idx].ID]; ok {
			selected = append(selected, &all[idx])
		}
	}

	// Order by priority, highest first: each position is swapped with any
	// later entry that has a strictly higher priority.
	for lo := 0; lo < len(selected)-1; lo++ {
		for hi := lo + 1; hi < len(selected); hi++ {
			if selected[lo].Priority < selected[hi].Priority {
				selected[lo], selected[hi] = selected[hi], selected[lo]
			}
		}
	}
	return selected
}
// sendToChannel delivers the event through ch using the sender that matches
// the channel's type ("webhook", "email", "feishu" or "wechat"). Any other
// type yields an "unsupported channel type" error.
func (s *NotificationService) sendToChannel(ctx context.Context, event *model.AlertEvent, ch *model.NotificationChannel) error {
	var send func(context.Context, *model.AlertEvent, *model.NotificationChannel) error

	switch ch.ChannelType {
	case "webhook":
		send = s.sendWebhook
	case "email":
		send = s.sendEmail
	case "feishu":
		send = s.sendFeishu
	case "wechat":
		send = s.sendWechat
	default:
		return fmt.Errorf("unsupported channel type: %s", ch.ChannelType)
	}
	return send(ctx, event, ch)
}
// sendWebhook POSTs the event as JSON to the URL stored under the channel's
// "webhook_url" config key.
//
// It returns an error when the URL is missing or empty, when the payload
// cannot be encoded, when the request cannot be built or sent, or when the
// response status is 400 or above.
func (s *NotificationService) sendWebhook(ctx context.Context, event *model.AlertEvent, ch *model.NotificationChannel) error {
	webhookURL, ok := ch.Config["webhook_url"].(string)
	if !ok || webhookURL == "" {
		return fmt.Errorf("webhook_url not configured")
	}

	payload := map[string]any{
		"alert_id":  event.ID,
		"rule_id":   event.RuleID,
		"level":     event.Level,
		"status":    event.Status,
		"resource":  event.ResourceID,
		"value":     event.CurrentValue,
		"threshold": event.ThresholdValue,
		"timestamp": time.Now().Format(time.RFC3339),
	}
	// Do not send a request with an empty body if encoding fails.
	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal webhook payload: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookURL, bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("create webhook request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.client.Do(req)
	if err != nil {
		return fmt.Errorf("webhook request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		return fmt.Errorf("webhook returned status %d", resp.StatusCode)
	}
	return nil
}
// sendEmail is a placeholder for the email channel; it always returns a
// "not yet implemented" error.
func (s *NotificationService) sendEmail(ctx context.Context, event *model.AlertEvent, ch *model.NotificationChannel) error {
return fmt.Errorf("email channel not yet implemented")
}
// sendFeishu is a placeholder for the Feishu channel; it always returns a
// "not yet implemented" error.
func (s *NotificationService) sendFeishu(ctx context.Context, event *model.AlertEvent, ch *model.NotificationChannel) error {
return fmt.Errorf("feishu channel not yet implemented")
}
// sendWechat is a placeholder for the WeChat channel; it always returns a
// "not yet implemented" error.
func (s *NotificationService) sendWechat(ctx context.Context, event *model.AlertEvent, ch *model.NotificationChannel) error {
return fmt.Errorf("wechat channel not yet implemented")
}

View File

@@ -0,0 +1,139 @@
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/company/ai-ops/internal/domain/model"
)
// fakeChannelRepo is a minimal channel repository for tests. Only List returns
// data; every other method is a no-op that reports success.
type fakeChannelRepo struct {
channels []model.NotificationChannel
}
// List returns the configured channels.
func (r *fakeChannelRepo) List(ctx context.Context) ([]model.NotificationChannel, error) {
return r.channels, nil
}
// GetByID always returns a nil channel and a nil error.
func (r *fakeChannelRepo) GetByID(ctx context.Context, id string) (*model.NotificationChannel, error) {
return nil, nil
}
// Create does nothing and returns nil.
func (r *fakeChannelRepo) Create(ctx context.Context, ch *model.NotificationChannel) error {
return nil
}
// Update does nothing and returns nil.
func (r *fakeChannelRepo) Update(ctx context.Context, ch *model.NotificationChannel) error {
return nil
}
// Delete does nothing and returns nil.
func (r *fakeChannelRepo) Delete(ctx context.Context, id string) error { return nil }
// fakeNotificationLogRepo is an in-memory notification log repository for
// tests. It records every call and never fails.
type fakeNotificationLogRepo struct {
// created holds a copy of each log passed to CreateLog.
created []model.NotificationLog
// sent holds the IDs passed to MarkSent.
sent []string
// failed holds the IDs passed to MarkFailed.
failed []string
}
// CreateLog assigns the ID "log-1" to logs without an ID and records a copy.
func (r *fakeNotificationLogRepo) CreateLog(ctx context.Context, log *model.NotificationLog) error {
if log.ID == "" {
log.ID = "log-1"
}
r.created = append(r.created, *log)
return nil
}
// MarkSent records id in sent.
func (r *fakeNotificationLogRepo) MarkSent(ctx context.Context, id string) error {
r.sent = append(r.sent, id)
return nil
}
// MarkFailed records id in failed; retryCount and errMessage are ignored.
func (r *fakeNotificationLogRepo) MarkFailed(ctx context.Context, id string, retryCount int, errMessage string) error {
r.failed = append(r.failed, id)
return nil
}
// TestNotificationServiceWritesLogWhenWebhookSent checks that a successful
// webhook delivery creates exactly one delivery log, marks it as sent and
// records no failure.
func TestNotificationServiceWritesLogWhenWebhookSent(t *testing.T) {
// The webhook target accepts every request.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
channelRepo := &fakeChannelRepo{channels: []model.NotificationChannel{{
ID: "11111111-1111-4111-8111-111111111111",
Name: "webhook",
ChannelType: "webhook",
Config: map[string]any{"webhook_url": server.URL},
Priority: 10,
Enabled: true,
}}}
logRepo := &fakeNotificationLogRepo{}
svc := NewNotificationService(channelRepo, logRepo)
defer svc.Stop()
// processTask is called directly, bypassing the queue.
svc.processTask(context.Background(), NotificationTask{
Event: &model.AlertEvent{
ID: "22222222-2222-4222-8222-222222222222",
RuleID: "33333333-3333-4333-8333-333333333333",
Level: "P1",
Status: "triggered",
ResourceID: "svc-a",
},
ChannelIDs: []string{"11111111-1111-4111-8111-111111111111"},
Priority: "P1",
})
if len(logRepo.created) != 1 {
t.Fatalf("created logs = %d, want 1", len(logRepo.created))
}
if len(logRepo.sent) != 1 || logRepo.sent[0] != "log-1" {
t.Fatalf("sent logs = %#v, want [log-1]", logRepo.sent)
}
if len(logRepo.failed) != 0 {
t.Fatalf("failed logs = %#v, want empty", logRepo.failed)
}
}
// TestNotificationServiceFailureAndFallbackBranches covers channel selection
// and ordering, the fallback from a failing webhook to the next channel, and
// the error branches of sendToChannel and sendWebhook.
func TestNotificationServiceFailureAndFallbackBranches(t *testing.T) {
// The webhook target rejects every request.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
}))
defer server.Close()
channels := []model.NotificationChannel{
{ID: "c1", ChannelType: "webhook", Config: map[string]any{"webhook_url": server.URL}, Priority: 1, Enabled: true},
{ID: "c2", ChannelType: "email", Priority: 2, Enabled: true},
{ID: "c3", ChannelType: "feishu", Priority: 3, Enabled: true},
{ID: "c4", ChannelType: "wechat", Priority: 4, Enabled: true},
{ID: "c5", ChannelType: "sms", Priority: 5, Enabled: true},
{ID: "disabled", ChannelType: "webhook", Priority: 99, Enabled: false},
}
logs := &fakeNotificationLogRepo{}
svc := NewNotificationService(&fakeChannelRepo{channels: channels}, logs)
defer svc.Stop()
event := &model.AlertEvent{ID: "event-1", RuleID: "rule-1", Level: "P1", ResourceType: "svc", ResourceID: "api", CurrentValue: "10", ThresholdValue: "5"}
// Unknown IDs are dropped; the rest come back by descending priority. The
// channel with Enabled=false is expected to be included.
ordered := svc.filterAndOrderChannels(channels, []string{"c1", "c2", "missing", "disabled"})
if len(ordered) != 3 || ordered[0].ID != "disabled" || ordered[1].ID != "c2" || ordered[2].ID != "c1" {
t.Fatalf("unexpected ordered channels: %+v", ordered)
}
// Both c2 (email placeholder) and c1 (webhook answering 502) must fail.
svc.processTask(context.Background(), NotificationTask{Event: event, ChannelIDs: []string{"c1", "c2"}})
if len(logs.failed) < 2 || len(logs.sent) != 0 {
t.Fatalf("expected multiple failures and no success: sent=%+v failed=%+v", logs.sent, logs.failed)
}
if err := svc.sendToChannel(context.Background(), event, &model.NotificationChannel{ChannelType: "unknown"}); err == nil {
t.Fatal("expected unsupported channel error")
}
if err := svc.sendWebhook(context.Background(), event, &model.NotificationChannel{Config: map[string]any{}}); err == nil {
t.Fatal("expected missing webhook url error")
}
// Exercise the queue path; the outcome is not asserted.
svc.Enqueue(event, []string{"c2"})
}
// TestNotificationServiceExplicitUnsupportedPlaceholders checks that the email,
// feishu and wechat channel types each report an error instead of silently
// succeeding.
func TestNotificationServiceExplicitUnsupportedPlaceholders(t *testing.T) {
svc := NewNotificationService(&fakeChannelRepo{})
defer svc.Stop()
event := &model.AlertEvent{ID: "event-placeholders", RuleID: "rule", Level: "P2"}
for _, channelType := range []string{"email", "feishu", "wechat"} {
err := svc.sendToChannel(context.Background(), event, &model.NotificationChannel{ChannelType: channelType})
if err == nil {
t.Fatalf("expected %s placeholder error", channelType)
}
}
}

View File

@@ -0,0 +1,50 @@
package service
import (
"context"
"fmt"
"github.com/company/ai-ops/internal/domain/model"
"github.com/company/ai-ops/internal/domain/repository"
)
// RuleService is the business layer for alert rules.
type RuleService struct {
repo repository.AlertRepository
}
// NewRuleService returns a RuleService backed by the given alert repository.
func NewRuleService(repo repository.AlertRepository) *RuleService {
return &RuleService{repo: repo}
}
// ListRules returns the alert rules provided by the repository.
func (s *RuleService) ListRules(ctx context.Context) ([]model.AlertRule, error) {
return s.repo.ListRules(ctx)
}
// GetRule returns the alert rule with the given ID, as provided by the repository.
func (s *RuleService) GetRule(ctx context.Context, id string) (*model.AlertRule, error) {
return s.repo.GetRuleByID(ctx, id)
}
// CreateRule validates and stores a new alert rule.
//
// The rule must be non-nil and carry an ID, a name and a metric name. On
// success the rule is stored enabled and with version 1; both fields are set
// on the caller's struct before it is passed to the repository.
func (s *RuleService) CreateRule(ctx context.Context, rule *model.AlertRule) error {
	// Reject a nil rule with an error instead of panicking on the field reads.
	if rule == nil {
		return fmt.Errorf("rule is required")
	}
	if rule.ID == "" {
		return fmt.Errorf("rule id is required")
	}
	if rule.Name == "" || rule.MetricName == "" {
		return fmt.Errorf("name and metric_name are required")
	}
	rule.Enabled = true
	rule.Version = 1
	return s.repo.CreateRule(ctx, rule)
}
// UpdateRule stores a changed alert rule under the next version number.
//
// The rule must be non-nil and carry an ID. Its Version is incremented before
// the repository call. When the repository reports an error the increment is
// undone, so a retry with the same struct does not skip a version.
func (s *RuleService) UpdateRule(ctx context.Context, rule *model.AlertRule) error {
	// Reject a nil rule with an error instead of panicking on the field reads.
	if rule == nil {
		return fmt.Errorf("rule is required")
	}
	if rule.ID == "" {
		return fmt.Errorf("rule id is required")
	}

	rule.Version++
	if err := s.repo.UpdateRule(ctx, rule); err != nil {
		// Nothing was stored: restore the caller's version.
		rule.Version--
		return err
	}
	return nil
}
// DeleteRule deletes the alert rule with the given ID through the repository.
// The ID is passed on without validation.
func (s *RuleService) DeleteRule(ctx context.Context, id string) error {
return s.repo.DeleteRule(ctx, id)
}