package service

import (
	"bytes"
	"context"
	"crypto/hmac"
	cryptorand "crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/repository"
	"gorm.io/gorm"
)

// WebhookService publishes events to registered webhooks. Publish enqueues
// delivery tasks on a buffered channel (capacity WebhookServiceConfig.QueueSize)
// that a fixed pool of worker goroutines drains, delivering each task as an
// HTTP POST.
//
// NOTE(review): no visible code closes queue or waits on wg, so the workers
// run for the lifetime of the process — confirm whether a shutdown path
// exists elsewhere in the package.
type WebhookService struct {
	db      *gorm.DB
	repo    *repository.WebhookRepository
	queue   chan *deliveryTask // buffered task queue, drained by the workers
	workers int                // number of worker goroutines started by startWorkers
	config  WebhookServiceConfig
	wg      sync.WaitGroup // counts running worker goroutines
	once    sync.Once      // makes startWorkers launch the pool only once
}

// WebhookServiceConfig configures a WebhookService. NewWebhookService replaces
// zero/negative numeric fields and empty string fields with the values from
// defaultWebhookServiceConfig; Enabled is used exactly as supplied.
type WebhookServiceConfig struct {
	Enabled      bool   // when false, Publish returns immediately
	SecretHeader string // request header carrying the "sha256=<hex>" HMAC signature
	TimeoutSec   int    // per-delivery timeout used when the webhook's own TimeoutSec <= 0
	MaxRetries   int    // NOTE(review): the visible retry check reads Webhook.MaxRetries, not this field
	RetryBackoff string // "fixed" selects a 2s delay in handleFailure; anything else takes the other branch
	WorkerCount  int    // number of delivery worker goroutines
	QueueSize    int    // capacity of the delivery queue channel
}

// deliveryTask is one pending delivery of an event payload to one webhook.
type deliveryTask struct {
	webhook   *domain.Webhook
	eventType domain.WebhookEventType
	payload   []byte // JSON-encoded WebhookEvent; the same slice is shared by every task of one Publish call
	attempt   int    // 1-based attempt number, sent in the X-Webhook-Attempt header
}

// WebhookEvent is the JSON envelope sent as the body of every delivery.
type WebhookEvent struct {
	EventID   string                  `json:"event_id"`
	EventType domain.WebhookEventType `json:"event_type"`
	Timestamp time.Time               `json:"timestamp"`
	Data      interface{}             `json:"data"`
}

// NewWebhookService creates a WebhookService backed by db and immediately
// starts its delivery workers.
//
// cfgs is optional: only cfgs[0] is used, and when it is absent the values
// from defaultWebhookServiceConfig apply. In a supplied config, zero/negative
// numeric fields and empty string fields are replaced with the defaults.
//
// NOTE(review): Enabled is not defaulted, so a supplied config literal that
// omits it leaves publishing disabled. MaxRetries <= 0 is overwritten with the
// default, so retries cannot be switched off through the config.
func NewWebhookService(db *gorm.DB, cfgs ...WebhookServiceConfig) *WebhookService {
	cfg := defaultWebhookServiceConfig()
	if len(cfgs) > 0 {
		cfg = cfgs[0]
	}
	if cfg.WorkerCount <= 0 {
		cfg.WorkerCount = defaultWebhookServiceConfig().WorkerCount
	}
	if cfg.QueueSize <= 0 {
		cfg.QueueSize = defaultWebhookServiceConfig().QueueSize
	}
	if cfg.SecretHeader == "" {
		cfg.SecretHeader = defaultWebhookServiceConfig().SecretHeader
	}
	if cfg.TimeoutSec <= 0 {
		cfg.TimeoutSec = defaultWebhookServiceConfig().TimeoutSec
	}
	if cfg.MaxRetries <= 0 {
		cfg.MaxRetries = defaultWebhookServiceConfig().MaxRetries
	}
	if cfg.RetryBackoff == "" {
		cfg.RetryBackoff = defaultWebhookServiceConfig().RetryBackoff
	}
	svc := &WebhookService{
		db:      db,
		repo:    repository.NewWebhookRepository(db),
		queue:   make(chan *deliveryTask, cfg.QueueSize),
		workers: cfg.WorkerCount,
		config:  cfg,
	}
	svc.startWorkers()
	return svc
}

// defaultWebhookServiceConfig returns the built-in defaults: enabled, signature
// header "X-Webhook-Signature", 10s timeout, 3 retries, "exponential" backoff,
// 4 workers and a queue of 1000 tasks.
func defaultWebhookServiceConfig() WebhookServiceConfig {
	return WebhookServiceConfig{
		Enabled:      true,
		SecretHeader: "X-Webhook-Signature",
		TimeoutSec:   10,
		MaxRetries:   3,
		RetryBackoff: "exponential",
		WorkerCount:  4,
		QueueSize:    1000,
	}
}

// startWorkers launches s.workers background delivery goroutines, at most once
// per service. Each worker calls deliver for every task it receives from
// s.queue and exits only when the channel is closed.
func (s *WebhookService) startWorkers() {
	s.once.Do(func() {
		for i := 0; i < s.workers; i++ {
			s.wg.Add(1)
			go func() {
				defer s.wg.Done()
				for task := range s.queue {
					s.deliver(task)
				}
			}()
		}
	})
}

// Publish sends an event to every active webhook that subscribes to eventType.
// It marshals a single WebhookEvent payload, then enqueues one delivery task
// per matching webhook without blocking; the HTTP calls happen on the worker
// goroutines. ctx is used only for the repository lookup.
//
// NOTE(review): ListActive and json.Marshal failures return silently (only
// event-ID generation failures are logged), and tasks are discarded when the
// queue is full, so callers get no signal in any of these cases.
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
	if !s.config.Enabled {
		return
	}
	// Load all active webhooks.
	webhooks, err := s.repo.ListActive(ctx)
	if err != nil {
		return
	}
	// Build the event payload once; it is shared by every delivery task.
	eventID, err := generateEventID()
	if err != nil {
		slog.Error("generate event ID failed", "error", err)
		return
	}
	event := &WebhookEvent{
		EventID:   eventID,
		EventType: eventType,
		Timestamp: time.Now().UTC(),
		Data:      data,
	}
	payloadBytes, err := json.Marshal(event)
	if err != nil {
		return
	}
	for i := range webhooks {
		wh := webhooks[i]
		// Skip webhooks that are not subscribed to this event type.
		if !webhookSubscribesTo(wh, eventType) {
			continue
		}
		task := &deliveryTask{
			webhook:   wh,
			eventType: eventType,
			payload:   payloadBytes,
			attempt:   1,
		}
		// Non-blocking enqueue.
		select {
		case s.queue <- task:
		default:
			// Queue is full: the task is dropped without blocking.
			// NOTE(review): the original comment said the drop is recorded,
			// but nothing is logged or stored here.
		}
	}
}

// deliver performs a single HTTP POST attempt for task and records the result.
// An unsafe URL or a request-construction error is recorded as a failure with
// no retry. Transport errors and non-2xx responses are passed to
// handleFailure, which holds the retry logic. A 2xx response is recorded as a
// success together with the response body.
func (s *WebhookService) deliver(task *deliveryTask) {
	wh := task.webhook
	// NEW-SEC-01 fix: reject URLs that could be used for SSRF.
	// NOTE(review): only wh.URL is vetted. http.Client follows up to 10
	// redirects by default, and hostnames are not re-checked at dial time, so
	// a public endpoint can still redirect (or resolve) to an internal address.
	// Consider a CheckRedirect that re-runs isSafeURL plus a dialer-level
	// address check.
	if !isSafeURL(wh.URL) {
		s.recordDelivery(task, 0, "", "webhook URL 不安全: 可能存在 SSRF 风险", false)
		return
	}
	// Timeout precedence: webhook setting, then service config, then 10s.
	timeout := time.Duration(wh.TimeoutSec) * time.Second
	if timeout <= 0 {
		timeout = time.Duration(s.config.TimeoutSec) * time.Second
	}
	if timeout <= 0 {
		timeout = 10 * time.Second
	}
	client := &http.Client{Timeout: timeout}
	req, err := http.NewRequest("POST", wh.URL, bytes.NewReader(task.payload))
	if err != nil {
		s.recordDelivery(task, 0, "", err.Error(), false)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", "UserManagementSystem-Webhook/1.0")
	req.Header.Set("X-Webhook-Event", string(task.eventType))
	req.Header.Set("X-Webhook-Attempt", fmt.Sprintf("%d", task.attempt))
	// HMAC signature over the exact payload bytes, sent as "sha256=<hex>".
	if wh.Secret != "" {
		sig := computeHMAC(task.payload, wh.Secret)
		req.Header.Set(s.config.SecretHeader, "sha256="+sig)
	}
	// Bound the request with a context timeout so it cannot hang indefinitely.
	// This duplicates client.Timeout. Because the context derives from
	// context.Background(), in-flight deliveries are not cancelled by any
	// caller or shutdown signal.
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	resp, err := client.Do(req.WithContext(ctx))
	if err != nil {
		s.handleFailure(task, 0, "", err.Error())
		return
	}
	defer resp.Body.Close()
	// NOTE(review): the ReadFrom error is ignored and the read is unbounded.
	// An endpoint can stream an arbitrarily large body into memory (limited
	// only by the timeout); consider io.LimitReader.
	var respBuf bytes.Buffer
	respBuf.ReadFrom(resp.Body)
	success := resp.StatusCode >= 200 && resp.StatusCode < 300
	if !success {
		s.handleFailure(task, resp.StatusCode, respBuf.String(), "非 2xx 响应")
		return
	}
	s.recordDelivery(task, resp.StatusCode, respBuf.String(), "", true)
}

// handleFailure handles a failed delivery (retry logic). It records the
// failed attempt, then, while task.attempt is below the webhook's own
// MaxRetries, picks a backoff: 2s for "fixed", otherwise the exponential branch.
func (s *WebhookService) handleFailure(task *deliveryTask, statusCode int, body, errMsg string) {
	s.recordDelivery(task, statusCode, body, errMsg, false)
	// Retry with backoff (exponential unless RetryBackoff == "fixed").
	if task.attempt < task.webhook.MaxRetries {
		backoff := time.Second
		if s.config.RetryBackoff == "fixed" {
			backoff = 2 * time.Second
		} else {
			// NOTE(review): SOURCE IS CORRUPTED HERE. Everything between
			// "time.Duration(1<" and "> 0 {" is missing (apparently stripped
			// as a <...> markup span). That removes the rest of handleFailure,
			// presumably the recordDelivery and CreateWebhook definitions, and
			// the head of UpdateWebhook. The lines below are the surviving
			// UpdateWebhook tail. Restore this region from version control;
			// do not reconstruct it by hand.
			backoff = time.Duration(1< 0 {
		// NOTE(review): the json.Marshal error is discarded here.
		b, _ := json.Marshal(req.Events)
		updates["events"] = string(b)
	}
	if req.Status != nil {
		updates["status"] = *req.Status
	}
	return s.repo.Update(ctx, id, updates)
}

// DeleteWebhook deletes the webhook with the given ID via the repository.
func (s *WebhookService) DeleteWebhook(ctx context.Context, id int64) error {
	return s.repo.Delete(ctx, id)
}

// GetWebhook returns the webhook with the given ID from the repository.
func (s *WebhookService) GetWebhook(ctx context.Context, id int64) (*domain.Webhook, error) {
	return s.repo.GetByID(ctx, id)
}

// ListWebhooks returns all webhooks created by createdBy (no pagination).
func (s *WebhookService) ListWebhooks(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
	return s.repo.ListByCreator(ctx, createdBy)
}

// ListWebhooksPaginated returns one page of webhooks created by createdBy.
// The int64 result is presumably the total count across all pages — confirm
// against WebhookRepository.ListByCreatorPaginated.
func (s *WebhookService) ListWebhooksPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
	return s.repo.ListByCreatorPaginated(ctx, createdBy, offset, limit)
}
获取投递记录 func (s *WebhookService) GetWebhookDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) { return s.repo.ListDeliveries(ctx, webhookID, limit) } // ---- Request/Response 结构 ---- // CreateWebhookRequest 创建 Webhook 请求 type CreateWebhookRequest struct { Name string `json:"name" binding:"required"` URL string `json:"url" binding:"required,url"` Secret string `json:"secret"` Events []domain.WebhookEventType `json:"events" binding:"required,min=1"` } // UpdateWebhookRequest 更新 Webhook 请求 type UpdateWebhookRequest struct { Name string `json:"name"` URL string `json:"url"` Events []domain.WebhookEventType `json:"events"` Status *domain.WebhookStatus `json:"status"` } // ---- 辅助函数 ---- // webhookSubscribesTo 检查 Webhook 是否订阅了指定事件类型 func webhookSubscribesTo(w *domain.Webhook, eventType domain.WebhookEventType) bool { var events []domain.WebhookEventType if err := json.Unmarshal([]byte(w.Events), &events); err != nil { return false } for _, e := range events { if e == eventType || e == "*" { return true } } return false } // SubscribesTo 检查 Webhook 是否订阅了指定事件类型(为 domain.Webhook 添加方法,通过包装实现) // 注意:此函数在 domain 包外部无法直接扩展,使用独立函数代替 // isSafeURL 检查 URL 是否安全(防止 SSRF 攻击) // NEW-SEC-01 修复:添加完整的 URL 安全检查 func isSafeURL(rawURL string) bool { u, err := url.Parse(rawURL) if err != nil || u.Scheme == "" { return false } // 只允许 http/https if u.Scheme != "http" && u.Scheme != "https" { return false } host := u.Hostname() // 禁止 localhost if host == "localhost" || host == "127.0.0.1" || host == "::1" { return false } // 检查内网 IP if ip := net.ParseIP(host); ip != nil { if isPrivateIP(ip) { return false } } // 检查内网域名 if strings.HasSuffix(host, ".internal") || strings.HasSuffix(host, ".local") || strings.HasSuffix(host, ".corp") || strings.HasSuffix(host, ".lan") || strings.HasSuffix(host, ".intranet") { return false } // 检查知名内网服务地址 blockedHosts := []string{ "metadata.google.internal", // GCP 元数据服务 "169.254.169.254", // AWS/Azure/GCP 元数据服务 
"metadata.azure.internal", // Azure 元数据服务 "100.100.100.200", // 阿里云元数据服务 } for _, blocked := range blockedHosts { if host == blocked { return false } } return true } // isPrivateIP 检查是否为内网 IP func isPrivateIP(ip net.IP) bool { privateRanges := []string{ "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8", "::1/128", "fc00::/7", } for _, cidr := range privateRanges { _, network, err := net.ParseCIDR(cidr) if err != nil { continue } if network.Contains(ip) { return true } } return false } // computeHMAC 计算 HMAC-SHA256 签名 func computeHMAC(payload []byte, secret string) string { mac := hmac.New(sha256.New, []byte(secret)) mac.Write(payload) return hex.EncodeToString(mac.Sum(nil)) } // generateEventID 生成随机事件 ID func generateEventID() (string, error) { b := make([]byte, 8) if _, err := cryptorand.Read(b); err != nil { return "", fmt.Errorf("generate event ID failed: %w", err) } return "evt_" + hex.EncodeToString(b), nil } // generateWebhookSecret 生成随机 Webhook 签名密钥 func generateWebhookSecret() (string, error) { b := make([]byte, 24) if _, err := cryptorand.Read(b); err != nil { return "", fmt.Errorf("generate webhook secret failed: %w", err) } return strings.ToLower(hex.EncodeToString(b)), nil }