feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers
This commit is contained in:
484
internal/service/webhook.go
Normal file
484
internal/service/webhook.go
Normal file
@@ -0,0 +1,484 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// WebhookService Webhook 服务
|
||||
type WebhookService struct {
|
||||
db *gorm.DB
|
||||
repo *repository.WebhookRepository
|
||||
queue chan *deliveryTask
|
||||
workers int
|
||||
config WebhookServiceConfig
|
||||
wg sync.WaitGroup
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
type WebhookServiceConfig struct {
|
||||
Enabled bool
|
||||
SecretHeader string
|
||||
TimeoutSec int
|
||||
MaxRetries int
|
||||
RetryBackoff string
|
||||
WorkerCount int
|
||||
QueueSize int
|
||||
}
|
||||
|
||||
// deliveryTask 投递任务
|
||||
type deliveryTask struct {
|
||||
webhook *domain.Webhook
|
||||
eventType domain.WebhookEventType
|
||||
payload []byte
|
||||
attempt int
|
||||
}
|
||||
|
||||
// WebhookEvent 发布的事件结构
|
||||
type WebhookEvent struct {
|
||||
EventID string `json:"event_id"`
|
||||
EventType domain.WebhookEventType `json:"event_type"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// NewWebhookService 创建 Webhook 服务
|
||||
func NewWebhookService(db *gorm.DB, cfgs ...WebhookServiceConfig) *WebhookService {
|
||||
cfg := defaultWebhookServiceConfig()
|
||||
if len(cfgs) > 0 {
|
||||
cfg = cfgs[0]
|
||||
}
|
||||
if cfg.WorkerCount <= 0 {
|
||||
cfg.WorkerCount = defaultWebhookServiceConfig().WorkerCount
|
||||
}
|
||||
if cfg.QueueSize <= 0 {
|
||||
cfg.QueueSize = defaultWebhookServiceConfig().QueueSize
|
||||
}
|
||||
if cfg.SecretHeader == "" {
|
||||
cfg.SecretHeader = defaultWebhookServiceConfig().SecretHeader
|
||||
}
|
||||
if cfg.TimeoutSec <= 0 {
|
||||
cfg.TimeoutSec = defaultWebhookServiceConfig().TimeoutSec
|
||||
}
|
||||
if cfg.MaxRetries <= 0 {
|
||||
cfg.MaxRetries = defaultWebhookServiceConfig().MaxRetries
|
||||
}
|
||||
if cfg.RetryBackoff == "" {
|
||||
cfg.RetryBackoff = defaultWebhookServiceConfig().RetryBackoff
|
||||
}
|
||||
|
||||
svc := &WebhookService{
|
||||
db: db,
|
||||
repo: repository.NewWebhookRepository(db),
|
||||
queue: make(chan *deliveryTask, cfg.QueueSize),
|
||||
workers: cfg.WorkerCount,
|
||||
config: cfg,
|
||||
}
|
||||
svc.startWorkers()
|
||||
return svc
|
||||
}
|
||||
|
||||
func defaultWebhookServiceConfig() WebhookServiceConfig {
|
||||
return WebhookServiceConfig{
|
||||
Enabled: true,
|
||||
SecretHeader: "X-Webhook-Signature",
|
||||
TimeoutSec: 10,
|
||||
MaxRetries: 3,
|
||||
RetryBackoff: "exponential",
|
||||
WorkerCount: 4,
|
||||
QueueSize: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
// startWorkers 启动后台投递 worker
|
||||
func (s *WebhookService) startWorkers() {
|
||||
s.once.Do(func() {
|
||||
for i := 0; i < s.workers; i++ {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
for task := range s.queue {
|
||||
s.deliver(task)
|
||||
}
|
||||
}()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Publish 发布事件:找到订阅该事件的所有 Webhook,异步投递
|
||||
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
|
||||
if !s.config.Enabled {
|
||||
return
|
||||
}
|
||||
// 查询所有活跃 Webhook
|
||||
webhooks, err := s.repo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 构建事件载荷
|
||||
eventID, err := generateEventID()
|
||||
if err != nil {
|
||||
slog.Error("generate event ID failed", "error", err)
|
||||
return
|
||||
}
|
||||
event := &WebhookEvent{
|
||||
EventID: eventID,
|
||||
EventType: eventType,
|
||||
Timestamp: time.Now().UTC(),
|
||||
Data: data,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range webhooks {
|
||||
wh := webhooks[i]
|
||||
// 检查是否订阅了该事件类型
|
||||
if !webhookSubscribesTo(wh, eventType) {
|
||||
continue
|
||||
}
|
||||
|
||||
task := &deliveryTask{
|
||||
webhook: wh,
|
||||
eventType: eventType,
|
||||
payload: payloadBytes,
|
||||
attempt: 1,
|
||||
}
|
||||
|
||||
// 非阻塞投递到队列
|
||||
select {
|
||||
case s.queue <- task:
|
||||
default:
|
||||
// 队列满时记录但不阻塞
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deliver 执行单次 HTTP 投递
|
||||
func (s *WebhookService) deliver(task *deliveryTask) {
|
||||
wh := task.webhook
|
||||
|
||||
// NEW-SEC-01 修复:检查 URL 安全性
|
||||
if !isSafeURL(wh.URL) {
|
||||
s.recordDelivery(task, 0, "", "webhook URL 不安全: 可能存在 SSRF 风险", false)
|
||||
return
|
||||
}
|
||||
|
||||
timeout := time.Duration(wh.TimeoutSec) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = time.Duration(s.config.TimeoutSec) * time.Second
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
req, err := http.NewRequest("POST", wh.URL, bytes.NewReader(task.payload))
|
||||
if err != nil {
|
||||
s.recordDelivery(task, 0, "", err.Error(), false)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "UserManagementSystem-Webhook/1.0")
|
||||
req.Header.Set("X-Webhook-Event", string(task.eventType))
|
||||
req.Header.Set("X-Webhook-Attempt", fmt.Sprintf("%d", task.attempt))
|
||||
|
||||
// HMAC 签名
|
||||
if wh.Secret != "" {
|
||||
sig := computeHMAC(task.payload, wh.Secret)
|
||||
req.Header.Set(s.config.SecretHeader, "sha256="+sig)
|
||||
}
|
||||
|
||||
// 使用带超时的 context 避免请求无限等待
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
s.handleFailure(task, 0, "", err.Error())
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var respBuf bytes.Buffer
|
||||
respBuf.ReadFrom(resp.Body)
|
||||
success := resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||
|
||||
if !success {
|
||||
s.handleFailure(task, resp.StatusCode, respBuf.String(), "非 2xx 响应")
|
||||
return
|
||||
}
|
||||
|
||||
s.recordDelivery(task, resp.StatusCode, respBuf.String(), "", true)
|
||||
}
|
||||
|
||||
// handleFailure 处理投递失败(重试逻辑)
|
||||
func (s *WebhookService) handleFailure(task *deliveryTask, statusCode int, body, errMsg string) {
|
||||
s.recordDelivery(task, statusCode, body, errMsg, false)
|
||||
|
||||
// 指数退避重试
|
||||
if task.attempt < task.webhook.MaxRetries {
|
||||
backoff := time.Second
|
||||
if s.config.RetryBackoff == "fixed" {
|
||||
backoff = 2 * time.Second
|
||||
} else {
|
||||
backoff = time.Duration(1<<uint(task.attempt)) * time.Second
|
||||
}
|
||||
time.AfterFunc(backoff, func() {
|
||||
task.attempt++
|
||||
select {
|
||||
case s.queue <- task:
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// recordDelivery 记录投递日志
|
||||
func (s *WebhookService) recordDelivery(task *deliveryTask, statusCode int, body, errMsg string, success bool) {
|
||||
now := time.Now()
|
||||
delivery := &domain.WebhookDelivery{
|
||||
WebhookID: task.webhook.ID,
|
||||
EventType: task.eventType,
|
||||
Payload: string(task.payload),
|
||||
StatusCode: statusCode,
|
||||
ResponseBody: body,
|
||||
Attempt: task.attempt,
|
||||
Success: success,
|
||||
Error: errMsg,
|
||||
}
|
||||
if success {
|
||||
delivery.DeliveredAt = &now
|
||||
}
|
||||
_ = s.repo.CreateDelivery(context.Background(), delivery)
|
||||
}
|
||||
|
||||
// CreateWebhook 创建 Webhook
|
||||
func (s *WebhookService) CreateWebhook(ctx context.Context, req *CreateWebhookRequest, createdBy int64) (*domain.Webhook, error) {
|
||||
eventsJSON, err := json.Marshal(req.Events)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化事件列表失败")
|
||||
}
|
||||
|
||||
secret := req.Secret
|
||||
if secret == "" {
|
||||
generatedSecret, err := generateWebhookSecret()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate webhook secret failed: %w", err)
|
||||
}
|
||||
secret = generatedSecret
|
||||
}
|
||||
|
||||
wh := &domain.Webhook{
|
||||
Name: req.Name,
|
||||
URL: req.URL,
|
||||
Secret: secret,
|
||||
Events: string(eventsJSON),
|
||||
Status: domain.WebhookStatusActive,
|
||||
MaxRetries: s.config.MaxRetries,
|
||||
TimeoutSec: s.config.TimeoutSec,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
if err := s.repo.Create(ctx, wh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wh, nil
|
||||
}
|
||||
|
||||
// UpdateWebhook 更新 Webhook
|
||||
func (s *WebhookService) UpdateWebhook(ctx context.Context, id int64, req *UpdateWebhookRequest) error {
|
||||
updates := map[string]interface{}{}
|
||||
if req.Name != "" {
|
||||
updates["name"] = req.Name
|
||||
}
|
||||
if req.URL != "" {
|
||||
updates["url"] = req.URL
|
||||
}
|
||||
if len(req.Events) > 0 {
|
||||
b, _ := json.Marshal(req.Events)
|
||||
updates["events"] = string(b)
|
||||
}
|
||||
if req.Status != nil {
|
||||
updates["status"] = *req.Status
|
||||
}
|
||||
return s.repo.Update(ctx, id, updates)
|
||||
}
|
||||
|
||||
// DeleteWebhook 删除 Webhook
|
||||
func (s *WebhookService) DeleteWebhook(ctx context.Context, id int64) error {
|
||||
return s.repo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *WebhookService) GetWebhook(ctx context.Context, id int64) (*domain.Webhook, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// ListWebhooks 获取 Webhook 列表(不分页)
|
||||
func (s *WebhookService) ListWebhooks(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
|
||||
return s.repo.ListByCreator(ctx, createdBy)
|
||||
}
|
||||
|
||||
// ListWebhooksPaginated 获取 Webhook 列表(分页)
|
||||
func (s *WebhookService) ListWebhooksPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
|
||||
return s.repo.ListByCreatorPaginated(ctx, createdBy, offset, limit)
|
||||
}
|
||||
|
||||
// GetWebhookDeliveries 获取投递记录
|
||||
func (s *WebhookService) GetWebhookDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) {
|
||||
return s.repo.ListDeliveries(ctx, webhookID, limit)
|
||||
}
|
||||
|
||||
// ---- Request/Response 结构 ----
|
||||
|
||||
// CreateWebhookRequest 创建 Webhook 请求
|
||||
type CreateWebhookRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
URL string `json:"url" binding:"required,url"`
|
||||
Secret string `json:"secret"`
|
||||
Events []domain.WebhookEventType `json:"events" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
// UpdateWebhookRequest 更新 Webhook 请求
|
||||
type UpdateWebhookRequest struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Events []domain.WebhookEventType `json:"events"`
|
||||
Status *domain.WebhookStatus `json:"status"`
|
||||
}
|
||||
|
||||
// ---- 辅助函数 ----
|
||||
|
||||
// webhookSubscribesTo 检查 Webhook 是否订阅了指定事件类型
|
||||
func webhookSubscribesTo(w *domain.Webhook, eventType domain.WebhookEventType) bool {
|
||||
var events []domain.WebhookEventType
|
||||
if err := json.Unmarshal([]byte(w.Events), &events); err != nil {
|
||||
return false
|
||||
}
|
||||
for _, e := range events {
|
||||
if e == eventType || e == "*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SubscribesTo 检查 Webhook 是否订阅了指定事件类型(为 domain.Webhook 添加方法,通过包装实现)
|
||||
// 注意:此函数在 domain 包外部无法直接扩展,使用独立函数代替
|
||||
|
||||
// isSafeURL 检查 URL 是否安全(防止 SSRF 攻击)
|
||||
// NEW-SEC-01 修复:添加完整的 URL 安全检查
|
||||
func isSafeURL(rawURL string) bool {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil || u.Scheme == "" {
|
||||
return false
|
||||
}
|
||||
// 只允许 http/https
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return false
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
|
||||
// 禁止 localhost
|
||||
if host == "localhost" || host == "127.0.0.1" || host == "::1" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查内网 IP
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isPrivateIP(ip) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 检查内网域名
|
||||
if strings.HasSuffix(host, ".internal") ||
|
||||
strings.HasSuffix(host, ".local") ||
|
||||
strings.HasSuffix(host, ".corp") ||
|
||||
strings.HasSuffix(host, ".lan") ||
|
||||
strings.HasSuffix(host, ".intranet") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查知名内网服务地址
|
||||
blockedHosts := []string{
|
||||
"metadata.google.internal", // GCP 元数据服务
|
||||
"169.254.169.254", // AWS/Azure/GCP 元数据服务
|
||||
"metadata.azure.internal", // Azure 元数据服务
|
||||
"100.100.100.200", // 阿里云元数据服务
|
||||
}
|
||||
for _, blocked := range blockedHosts {
|
||||
if host == blocked {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isPrivateIP 检查是否为内网 IP
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
privateRanges := []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
}
|
||||
for _, cidr := range privateRanges {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// computeHMAC 计算 HMAC-SHA256 签名
|
||||
func computeHMAC(payload []byte, secret string) string {
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write(payload)
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// generateEventID 生成随机事件 ID
|
||||
func generateEventID() (string, error) {
|
||||
b := make([]byte, 8)
|
||||
if _, err := cryptorand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate event ID failed: %w", err)
|
||||
}
|
||||
return "evt_" + hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// generateWebhookSecret 生成随机 Webhook 签名密钥
|
||||
func generateWebhookSecret() (string, error) {
|
||||
b := make([]byte, 24)
|
||||
if _, err := cryptorand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate webhook secret failed: %w", err)
|
||||
}
|
||||
return strings.ToLower(hex.EncodeToString(b)), nil
|
||||
}
|
||||
Reference in New Issue
Block a user