Files
user-system/internal/service/webhook.go

485 lines
12 KiB
Go
Raw Normal View History

package service
import (
	"bytes"
	"context"
	"crypto/hmac"
	cryptorand "crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/repository"
	"gorm.io/gorm"
)
// WebhookService manages webhook subscriptions and delivers published events.
//
// Publish turns an event into deliveryTasks and pushes them onto a buffered
// channel (queue) that a fixed pool of worker goroutines drains; failed
// deliveries may be re-enqueued after a backoff. Instances are created with
// NewWebhookService, which also starts the workers.
type WebhookService struct {
// db is the gorm handle the repository was built from.
db *gorm.DB
// repo persists webhooks and their delivery records.
repo *repository.WebhookRepository
// queue is the buffered channel of pending deliveries (capacity QueueSize);
// senders in this file use a non-blocking send and drop the task when full.
queue chan *deliveryTask
// workers is the number of delivery goroutines startWorkers launches.
workers int
// config is the effective configuration after defaults were applied.
config WebhookServiceConfig
// wg tracks the running delivery worker goroutines.
wg sync.WaitGroup
// once makes startWorkers start the worker pool at most once.
once sync.Once
}
// WebhookServiceConfig configures a WebhookService. NewWebhookService replaces
// zero/negative numeric fields and empty string fields with the values from
// defaultWebhookServiceConfig; Enabled is used as given.
type WebhookServiceConfig struct {
// Enabled turns publishing on; when false, Publish returns immediately.
Enabled bool
// SecretHeader is the request header that carries the "sha256=<hex HMAC>"
// signature of the payload when a webhook has a secret.
SecretHeader string
// TimeoutSec is stored on newly created webhooks and is the fallback
// request timeout when a webhook's own TimeoutSec is not positive.
TimeoutSec int
// MaxRetries is stored on newly created webhooks; the per-webhook value is
// what bounds retries at delivery time.
MaxRetries int
// RetryBackoff selects the retry delay: "fixed" waits 2s, any other value
// uses exponential backoff.
RetryBackoff string
// WorkerCount is the number of delivery goroutines.
WorkerCount int
// QueueSize is the capacity of the delivery queue.
QueueSize int
}
// deliveryTask is one pending delivery of an event to a single webhook.
type deliveryTask struct {
// webhook is the target endpoint and its settings (URL, secret, retries).
webhook *domain.Webhook
// eventType is sent in the X-Webhook-Event header and stored with the delivery record.
eventType domain.WebhookEventType
// payload is the JSON-encoded WebhookEvent; Publish shares the same slice
// across all tasks created for one event.
payload []byte
// attempt is the 1-based attempt number; Publish starts it at 1 and each
// scheduled retry increments it.
attempt int
}
// WebhookEvent is the JSON envelope sent as the request body of every webhook delivery.
type WebhookEvent struct {
// EventID is a random identifier of the form "evt_<16 hex chars>".
EventID string `json:"event_id"`
// EventType is the type of event being published.
EventType domain.WebhookEventType `json:"event_type"`
// Timestamp is the UTC time at which the event was published.
Timestamp time.Time `json:"timestamp"`
// Data is the caller-supplied event payload, encoded with encoding/json.
Data interface{} `json:"data"`
}
// NewWebhookService creates a WebhookService backed by db and starts its
// delivery workers before returning.
//
// cfgs is optional; when present only cfgs[0] is used. Every numeric field of
// that config that is zero or negative, and every empty string field, is
// filled in from defaultWebhookServiceConfig (so MaxRetries cannot be set to
// zero this way). Enabled is taken as supplied: a caller-provided config must
// set Enabled explicitly to turn publishing on.
func NewWebhookService(db *gorm.DB, cfgs ...WebhookServiceConfig) *WebhookService {
	defaults := defaultWebhookServiceConfig()

	cfg := defaults
	if len(cfgs) != 0 {
		cfg = cfgs[0]
	}

	// Fill unset fields from the defaults.
	if cfg.WorkerCount <= 0 {
		cfg.WorkerCount = defaults.WorkerCount
	}
	if cfg.QueueSize <= 0 {
		cfg.QueueSize = defaults.QueueSize
	}
	if len(cfg.SecretHeader) == 0 {
		cfg.SecretHeader = defaults.SecretHeader
	}
	if cfg.TimeoutSec <= 0 {
		cfg.TimeoutSec = defaults.TimeoutSec
	}
	if cfg.MaxRetries <= 0 {
		cfg.MaxRetries = defaults.MaxRetries
	}
	if len(cfg.RetryBackoff) == 0 {
		cfg.RetryBackoff = defaults.RetryBackoff
	}

	service := &WebhookService{
		db:      db,
		repo:    repository.NewWebhookRepository(db),
		queue:   make(chan *deliveryTask, cfg.QueueSize),
		workers: cfg.WorkerCount,
		config:  cfg,
	}
	service.startWorkers()
	return service
}
// defaultWebhookServiceConfig returns the built-in configuration: enabled,
// signature header "X-Webhook-Signature", 10s timeout, 3 retries with
// exponential backoff, 4 workers and a queue of 1000 tasks.
func defaultWebhookServiceConfig() WebhookServiceConfig {
	var cfg WebhookServiceConfig
	cfg.Enabled = true
	cfg.SecretHeader = "X-Webhook-Signature"
	cfg.TimeoutSec = 10
	cfg.MaxRetries = 3
	cfg.RetryBackoff = "exponential"
	cfg.WorkerCount = 4
	cfg.QueueSize = 1000
	return cfg
}
// startWorkers launches s.workers goroutines that receive tasks from s.queue
// and call deliver for each one. The sync.Once guarantees the pool is started
// at most once per service. Each worker is counted in s.wg and returns when
// s.queue is closed.
func (s *WebhookService) startWorkers() {
	s.once.Do(func() {
		worker := func() {
			defer s.wg.Done()
			for t := range s.queue {
				s.deliver(t)
			}
		}
		for remaining := s.workers; remaining > 0; remaining-- {
			s.wg.Add(1)
			go worker()
		}
	})
}
// Publish builds a WebhookEvent for eventType/data and enqueues one delivery
// task for every active webhook subscribed to eventType (or to "*").
//
// Delivery is asynchronous: Publish returns as soon as the tasks have been
// offered to the queue. The enqueue never blocks; when the queue is full the
// task is dropped and a warning is logged. Failing to load webhooks, to
// generate the event ID or to encode the event is logged and aborts the
// publish. Nothing is returned because delivery is best-effort for callers.
// Publish is a no-op when the service is disabled.
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
	if !s.config.Enabled {
		return
	}

	webhooks, err := s.repo.ListActive(ctx)
	if err != nil {
		slog.Error("list active webhooks failed", "event_type", eventType, "error", err)
		return
	}
	if len(webhooks) == 0 {
		// Nobody to notify: skip building and encoding the event.
		return
	}

	eventID, err := generateEventID()
	if err != nil {
		slog.Error("generate event ID failed", "error", err)
		return
	}
	event := &WebhookEvent{
		EventID:   eventID,
		EventType: eventType,
		Timestamp: time.Now().UTC(),
		Data:      data,
	}
	payloadBytes, err := json.Marshal(event)
	if err != nil {
		slog.Error("marshal webhook event failed",
			"event_id", eventID, "event_type", eventType, "error", err)
		return
	}

	for _, wh := range webhooks {
		if !webhookSubscribesTo(wh, eventType) {
			continue
		}
		task := &deliveryTask{
			webhook:   wh,
			eventType: eventType,
			// All tasks share the same read-only payload slice.
			payload: payloadBytes,
			attempt: 1,
		}
		// Non-blocking enqueue: never stall the publisher on a full queue,
		// but make the drop visible.
		select {
		case s.queue <- task:
		default:
			slog.Warn("webhook delivery dropped: queue full",
				"webhook_id", wh.ID, "event_id", eventID, "event_type", eventType)
		}
	}
}
// deliver performs a single HTTP POST delivery attempt for task.
//
// The webhook URL is checked with isSafeURL before the request. Redirects are
// followed only when each redirect target also passes isSafeURL, so an
// endpoint cannot bounce the request to an internal address. At most
// maxWebhookResponseBody bytes of the response body are read and stored; the
// rest is discarded so a misbehaving endpoint cannot exhaust memory.
//
// A 2xx status is recorded as a success. Transport errors and non-2xx
// statuses go through handleFailure, which records the attempt and may
// schedule a retry. Unsafe URLs and request-construction errors are recorded
// as failures without retry.
func (s *WebhookService) deliver(task *deliveryTask) {
	// maxWebhookResponseBody caps how much of a response body is kept.
	const maxWebhookResponseBody = 64 << 10
	// maxWebhookRedirects mirrors net/http's default redirect limit.
	const maxWebhookRedirects = 10

	wh := task.webhook

	// NEW-SEC-01: reject URLs that may target internal services (SSRF).
	if !isSafeURL(wh.URL) {
		s.recordDelivery(task, 0, "", "webhook URL 不安全: 可能存在 SSRF 风险", false)
		return
	}

	// Per-webhook timeout, falling back to the service config, then to 10s.
	timeout := time.Duration(wh.TimeoutSec) * time.Second
	if timeout <= 0 {
		timeout = time.Duration(s.config.TimeoutSec) * time.Second
	}
	if timeout <= 0 {
		timeout = 10 * time.Second
	}

	client := &http.Client{
		Timeout: timeout,
		// The default policy follows redirects to any host, which would
		// bypass the isSafeURL check above; validate every hop instead.
		CheckRedirect: func(next *http.Request, via []*http.Request) error {
			if len(via) >= maxWebhookRedirects {
				return fmt.Errorf("stopped after %d redirects", maxWebhookRedirects)
			}
			if !isSafeURL(next.URL.String()) {
				return fmt.Errorf("redirect to unsafe URL blocked: %s", next.URL.Redacted())
			}
			return nil
		},
	}

	// Bound the whole exchange by the same timeout so the request cannot hang.
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, wh.URL, bytes.NewReader(task.payload))
	if err != nil {
		s.recordDelivery(task, 0, "", err.Error(), false)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", "UserManagementSystem-Webhook/1.0")
	req.Header.Set("X-Webhook-Event", string(task.eventType))
	req.Header.Set("X-Webhook-Attempt", fmt.Sprintf("%d", task.attempt))

	// HMAC signature so the receiver can authenticate the payload.
	if wh.Secret != "" {
		sig := computeHMAC(task.payload, wh.Secret)
		req.Header.Set(s.config.SecretHeader, "sha256="+sig)
	}

	resp, err := client.Do(req)
	if err != nil {
		s.handleFailure(task, 0, "", err.Error())
		return
	}
	defer resp.Body.Close()

	var respBuf bytes.Buffer
	// The status code alone decides success; a read error here only means the
	// stored body is truncated, so it is deliberately not treated as a failure.
	_, _ = respBuf.ReadFrom(io.LimitReader(resp.Body, maxWebhookResponseBody))

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		s.handleFailure(task, resp.StatusCode, respBuf.String(), "非 2xx 响应")
		return
	}
	s.recordDelivery(task, resp.StatusCode, respBuf.String(), "", true)
}
// handleFailure records a failed delivery attempt and, if the webhook still has
// attempts left, schedules the task to be re-enqueued after a backoff.
//
// statusCode is 0 and body empty when no HTTP response was received. A retry
// is scheduled only while task.attempt < task.webhook.MaxRetries. Because
// attempts start at 1, MaxRetries bounds the total number of attempts,
// including the first one. The retry is scheduled with time.AfterFunc, so
// the worker is never blocked. When the timer fires, the attempt counter is
// incremented and the task is offered to the queue without blocking. If the
// queue is full at that moment, the retry is dropped and a warning is logged.
func (s *WebhookService) handleFailure(task *deliveryTask, statusCode int, body, errMsg string) {
	s.recordDelivery(task, statusCode, body, errMsg, false)

	if task.attempt >= task.webhook.MaxRetries {
		return
	}

	delay := retryBackoff(s.config.RetryBackoff, task.attempt)
	time.AfterFunc(delay, func() {
		task.attempt++
		select {
		case s.queue <- task:
		default:
			slog.Warn("webhook retry dropped: delivery queue full",
				"webhook_id", task.webhook.ID, "attempt", task.attempt)
		}
	})
}

// maxRetryBackoff caps the exponential retry delay. It also keeps the
// computation far away from time.Duration overflow for large attempt counts.
const maxRetryBackoff = time.Hour

// retryBackoff returns how long to wait before retrying after the given
// (1-based) attempt failed. The "fixed" strategy always waits 2s; any other
// strategy waits 2^attempt seconds, capped at maxRetryBackoff.
func retryBackoff(strategy string, attempt int) time.Duration {
	if strategy == "fixed" {
		return 2 * time.Second
	}
	d := time.Second
	// Doubling stops once the cap is reached, so d can never overflow.
	for i := 0; i < attempt && d < maxRetryBackoff; i++ {
		d *= 2
	}
	if d > maxRetryBackoff {
		d = maxRetryBackoff
	}
	return d
}
// recordDelivery persists one delivery attempt for task as a
// domain.WebhookDelivery row.
//
// statusCode is 0 and body empty when no HTTP response was received, and
// errMsg is empty on success. DeliveredAt is set only for successful
// attempts. Persisting is best-effort: callers run on background workers or
// timers with nobody to report to, so a repository error is logged instead of
// being silently discarded.
func (s *WebhookService) recordDelivery(task *deliveryTask, statusCode int, body, errMsg string, success bool) {
	delivery := &domain.WebhookDelivery{
		WebhookID:    task.webhook.ID,
		EventType:    task.eventType,
		Payload:      string(task.payload),
		StatusCode:   statusCode,
		ResponseBody: body,
		Attempt:      task.attempt,
		Success:      success,
		Error:        errMsg,
	}
	if success {
		now := time.Now()
		delivery.DeliveredAt = &now
	}
	if err := s.repo.CreateDelivery(context.Background(), delivery); err != nil {
		slog.Error("record webhook delivery failed",
			"webhook_id", task.webhook.ID,
			"event_type", task.eventType,
			"attempt", task.attempt,
			"error", err)
	}
}
// CreateWebhook stores a new active webhook owned by createdBy.
//
// req.Events is stored as a JSON array. When req.Secret is empty, a random
// signing secret is generated. MaxRetries and TimeoutSec are copied from the
// service configuration. The URL is not validated here; deliver checks it
// with isSafeURL before every request.
//
// Errors from encoding the event list, generating the secret or the
// repository are returned; the encoding error is wrapped with %w so callers
// can inspect the cause.
func (s *WebhookService) CreateWebhook(ctx context.Context, req *CreateWebhookRequest, createdBy int64) (*domain.Webhook, error) {
	eventsJSON, err := json.Marshal(req.Events)
	if err != nil {
		return nil, fmt.Errorf("序列化事件列表失败: %w", err)
	}

	secret := req.Secret
	if secret == "" {
		generated, err := generateWebhookSecret()
		if err != nil {
			// generateWebhookSecret already describes the failure; re-wrapping
			// with the same text would only duplicate it.
			return nil, err
		}
		secret = generated
	}

	wh := &domain.Webhook{
		Name:       req.Name,
		URL:        req.URL,
		Secret:     secret,
		Events:     string(eventsJSON),
		Status:     domain.WebhookStatusActive,
		MaxRetries: s.config.MaxRetries,
		TimeoutSec: s.config.TimeoutSec,
		CreatedBy:  createdBy,
	}
	if err := s.repo.Create(ctx, wh); err != nil {
		return nil, err
	}
	return wh, nil
}
// UpdateWebhook applies a partial update to the webhook with the given id.
//
// Only fields that are set in req are written: a non-empty Name or URL, a
// non-empty Events list (stored as a JSON array) and a non-nil Status. Empty
// values therefore cannot be used to clear a field. An error encoding the
// event list is returned wrapped, consistent with CreateWebhook, instead of
// silently storing a bad value. Repository errors are returned unchanged.
func (s *WebhookService) UpdateWebhook(ctx context.Context, id int64, req *UpdateWebhookRequest) error {
	updates := make(map[string]interface{}, 4)
	if req.Name != "" {
		updates["name"] = req.Name
	}
	if req.URL != "" {
		updates["url"] = req.URL
	}
	if len(req.Events) > 0 {
		b, err := json.Marshal(req.Events)
		if err != nil {
			return fmt.Errorf("序列化事件列表失败: %w", err)
		}
		updates["events"] = string(b)
	}
	if req.Status != nil {
		updates["status"] = *req.Status
	}
	return s.repo.Update(ctx, id, updates)
}
// DeleteWebhook deletes the webhook with the given id through the repository
// and returns the repository's error, if any.
func (s *WebhookService) DeleteWebhook(ctx context.Context, id int64) error {
	err := s.repo.Delete(ctx, id)
	return err
}
// GetWebhook loads the webhook with the given id from the repository.
func (s *WebhookService) GetWebhook(ctx context.Context, id int64) (*domain.Webhook, error) {
	wh, err := s.repo.GetByID(ctx, id)
	return wh, err
}
// ListWebhooks returns every webhook created by createdBy (not paginated).
func (s *WebhookService) ListWebhooks(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
	hooks, err := s.repo.ListByCreator(ctx, createdBy)
	return hooks, err
}
// ListWebhooksPaginated returns one page of the webhooks created by
// createdBy, selected by offset and limit, together with the total count
// reported by the repository.
func (s *WebhookService) ListWebhooksPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
	page, total, err := s.repo.ListByCreatorPaginated(ctx, createdBy, offset, limit)
	return page, total, err
}
// GetWebhookDeliveries returns delivery records for webhookID. limit is
// passed through to the repository unchanged.
func (s *WebhookService) GetWebhookDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) {
	records, err := s.repo.ListDeliveries(ctx, webhookID, limit)
	return records, err
}
// ---- Request/Response types ----
// CreateWebhookRequest is the payload for creating a webhook.
// NOTE(review): the `binding` tags are presumably enforced by the HTTP layer
// (gin-style validation) — confirm; CreateWebhook itself does not check them.
type CreateWebhookRequest struct {
// Name is a human-readable label for the webhook.
Name string `json:"name" binding:"required"`
// URL is the delivery endpoint; deliver additionally checks it with isSafeURL.
URL string `json:"url" binding:"required,url"`
// Secret is the HMAC signing key; when empty, CreateWebhook generates one.
Secret string `json:"secret"`
// Events lists the subscribed event types; "*" matches every event.
Events []domain.WebhookEventType `json:"events" binding:"required,min=1"`
}
// UpdateWebhookRequest is the payload for a partial webhook update.
// UpdateWebhook ignores empty strings, an empty Events list and a nil Status,
// so those fields are left unchanged (they cannot be cleared through this request).
type UpdateWebhookRequest struct {
// Name replaces the webhook name when non-empty.
Name string `json:"name"`
// URL replaces the delivery endpoint when non-empty.
URL string `json:"url"`
// Events replaces the subscription list when non-empty.
Events []domain.WebhookEventType `json:"events"`
// Status, when non-nil, replaces the webhook status; a pointer distinguishes "absent" from a value.
Status *domain.WebhookStatus `json:"status"`
}
// ---- Helper functions ----

// webhookSubscribesTo reports whether w is subscribed to eventType.
// w.Events holds a JSON array of event types; the wildcard "*" matches every
// type. An Events value that is not valid JSON is treated as "no
// subscriptions" rather than as an error.
func webhookSubscribesTo(w *domain.Webhook, eventType domain.WebhookEventType) bool {
	var subscribed []domain.WebhookEventType
	if json.Unmarshal([]byte(w.Events), &subscribed) != nil {
		return false
	}
	for i := range subscribed {
		switch subscribed[i] {
		case eventType, "*":
			return true
		}
	}
	return false
}
// Note: a SubscribesTo method cannot be added to domain.Webhook from outside the
// domain package, so the standalone function webhookSubscribesTo is used instead.

// isSafeURL reports whether rawURL is acceptable as a webhook target
// (NEW-SEC-01: SSRF protection).
//
// A URL is rejected when:
//   - it does not parse, or its scheme is not http/https;
//   - it has no host, or the host carries an IPv6 zone;
//   - its host (compared case-insensitively, ignoring a trailing dot) is
//     localhost or *.localhost, or ends in an internal-looking suffix
//     (.internal, .local, .corp, .lan, .intranet);
//   - the host is a literal IP that is loopback, unspecified, link-local,
//     multicast or private (per isPrivateIP);
//   - the host is a non-canonical numeric IPv4 form such as "2130706433" or
//     "0x7f.1", which some system resolvers still map to an address;
//   - the host is a well-known cloud metadata endpoint.
//
// NOTE: the check is purely lexical. A public hostname whose DNS record points
// at an internal address is not detected here; that requires validating the
// resolved address at dial time.
func isSafeURL(rawURL string) bool {
	u, err := url.Parse(rawURL)
	if err != nil || u.Scheme == "" {
		return false
	}
	// Only web schemes are allowed (url.Parse already lower-cases the scheme).
	if u.Scheme != "http" && u.Scheme != "https" {
		return false
	}

	// Hostnames are case-insensitive and "host." equals "host"; normalize so
	// "LOCALHOST" or "localhost." cannot slip past the comparisons below.
	host := strings.TrimSuffix(strings.ToLower(u.Hostname()), ".")
	if host == "" {
		return false
	}
	// A zoned IPv6 literal (fe80::1%eth0) is not accepted by net.ParseIP but
	// is still dialable; no public endpoint needs one.
	if strings.Contains(host, "%") {
		return false
	}
	if host == "localhost" || strings.HasSuffix(host, ".localhost") {
		return false
	}

	if ip := net.ParseIP(host); ip != nil {
		if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() ||
			ip.IsMulticast() || isPrivateIP(ip) {
			return false
		}
	} else if last := host[strings.LastIndexByte(host, '.')+1:]; last != "" && last[0] >= '0' && last[0] <= '9' {
		// Not a canonical IP, yet the final label is numeric. Top-level
		// domains are alphabetic, so this is a legacy numeric IPv4 spelling.
		return false
	}

	// Internal-only DNS suffixes.
	for _, suffix := range []string{".internal", ".local", ".corp", ".lan", ".intranet"} {
		if strings.HasSuffix(host, suffix) {
			return false
		}
	}

	// Well-known cloud metadata endpoints.
	switch host {
	case "metadata.google.internal", // GCP metadata service
		"169.254.169.254",         // AWS/Azure/GCP metadata service
		"metadata.azure.internal", // Azure metadata service
		"100.100.100.200":         // Alibaba Cloud metadata service
		return false
	}
	return true
}
// nonPublicIPv4Ranges lists IPv4 blocks that isPrivateIP must reject but that
// the net.IP classification methods do not cover. The table is built once
// instead of re-parsing CIDR strings on every call.
var nonPublicIPv4Ranges = []*net.IPNet{
	// 0.0.0.0/8: "this network" (RFC 1122); 0.x addresses can reach the local host.
	{IP: net.IPv4(0, 0, 0, 0).To4(), Mask: net.CIDRMask(8, 32)},
	// 100.64.0.0/10: shared address space / CGNAT (RFC 6598); includes
	// Alibaba Cloud's metadata address 100.100.100.200.
	{IP: net.IPv4(100, 64, 0, 0).To4(), Mask: net.CIDRMask(10, 32)},
}

// isPrivateIP reports whether ip lies in a range that webhook delivery must
// not target: private (10/8, 172.16/12, 192.168/16, fc00::/7), loopback
// (127/8, ::1), link-local (169.254/16, fe80::/10), unspecified (0.0.0.0,
// ::), 0.0.0.0/8, or 100.64.0.0/10. IPv4-mapped IPv6 addresses are classified
// by their IPv4 form. A nil ip is reported as not private.
func isPrivateIP(ip net.IP) bool {
	if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() {
		return true
	}
	for _, network := range nonPublicIPv4Ranges {
		if network.Contains(ip) {
			return true
		}
	}
	return false
}
// computeHMAC returns the lowercase hex encoding of HMAC-SHA256 over payload,
// keyed with secret.
func computeHMAC(payload []byte, secret string) string {
	key := []byte(secret)
	signer := hmac.New(sha256.New, key)
	_, _ = signer.Write(payload) // hash.Hash.Write never returns an error
	digest := signer.Sum(nil)
	return hex.EncodeToString(digest)
}
// generateEventID returns a random event identifier: "evt_" followed by 16
// lowercase hex characters encoding 8 bytes from crypto/rand. It returns an
// error only if the system random source fails.
func generateEventID() (string, error) {
	var raw [8]byte
	if _, err := cryptorand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("generate event ID failed: %w", err)
	}
	const prefix = "evt_"
	return prefix + hex.EncodeToString(raw[:]), nil
}
// generateWebhookSecret returns a random webhook signing secret: 24 bytes from
// crypto/rand encoded as 48 lowercase hex characters. It returns an error only
// if the system random source fails.
func generateWebhookSecret() (string, error) {
	secret := make([]byte, 24)
	if _, err := cryptorand.Read(secret); err != nil {
		return "", fmt.Errorf("generate webhook secret failed: %w", err)
	}
	// hex.EncodeToString already emits lowercase digits, so no case folding is needed.
	return hex.EncodeToString(secret), nil
}