fix: 系统性修复安全问题、性能问题和错误处理
安全问题修复: - X-Forwarded-For越界检查(auth.go) - checkTokenStatus Context参数传递(auth.go) - Type Assertion安全检查(auth.go) 性能问题修复: - TokenCache过期清理机制 - BruteForceProtection过期清理 - InMemoryIdempotencyStore过期清理 错误处理修复: - AuditStore.Emit返回error - domain层emitAudit辅助方法 - List方法返回空slice而非nil - 金额/价格负数验证 架构一致性: - 统一使用model.RoleHierarchyLevels 新增功能: - Alert API完整实现(CRUD+Resolve) - pkg/error错误码集中管理
This commit is contained in:
274
supply-api/internal/audit/service/alert_service.go
Normal file
274
supply-api/internal/audit/service/alert_service.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrAlertNotFound = errors.New("alert not found")
|
||||
ErrInvalidAlertInput = errors.New("invalid alert input")
|
||||
ErrAlertConflict = errors.New("alert conflict")
|
||||
)
|
||||
|
||||
// AlertStoreInterface 告警存储接口
|
||||
type AlertStoreInterface interface {
|
||||
Create(ctx context.Context, alert *model.Alert) error
|
||||
GetByID(ctx context.Context, alertID string) (*model.Alert, error)
|
||||
Update(ctx context.Context, alert *model.Alert) error
|
||||
Delete(ctx context.Context, alertID string) error
|
||||
List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error)
|
||||
}
|
||||
|
||||
// InMemoryAlertStore 内存告警存储
|
||||
type InMemoryAlertStore struct {
|
||||
mu sync.RWMutex
|
||||
alerts map[string]*model.Alert
|
||||
}
|
||||
|
||||
// NewInMemoryAlertStore 创建内存告警存储
|
||||
func NewInMemoryAlertStore() *InMemoryAlertStore {
|
||||
return &InMemoryAlertStore{
|
||||
alerts: make(map[string]*model.Alert),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建告警
|
||||
func (s *InMemoryAlertStore) Create(ctx context.Context, alert *model.Alert) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if alert.AlertID == "" {
|
||||
alert.AlertID = "ALT-" + uuid.New().String()[:8]
|
||||
}
|
||||
alert.CreatedAt = time.Now()
|
||||
alert.UpdatedAt = time.Now()
|
||||
|
||||
s.alerts[alert.AlertID] = alert
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取告警
|
||||
func (s *InMemoryAlertStore) GetByID(ctx context.Context, alertID string) (*model.Alert, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if alert, ok := s.alerts[alertID]; ok {
|
||||
return alert, nil
|
||||
}
|
||||
return nil, ErrAlertNotFound
|
||||
}
|
||||
|
||||
// Update 更新告警
|
||||
func (s *InMemoryAlertStore) Update(ctx context.Context, alert *model.Alert) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, ok := s.alerts[alert.AlertID]; !ok {
|
||||
return ErrAlertNotFound
|
||||
}
|
||||
|
||||
alert.UpdatedAt = time.Now()
|
||||
s.alerts[alert.AlertID] = alert
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除告警
|
||||
func (s *InMemoryAlertStore) Delete(ctx context.Context, alertID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, ok := s.alerts[alertID]; !ok {
|
||||
return ErrAlertNotFound
|
||||
}
|
||||
|
||||
delete(s.alerts, alertID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 查询告警列表
|
||||
func (s *InMemoryAlertStore) List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []*model.Alert
|
||||
for _, alert := range s.alerts {
|
||||
// 按租户过滤
|
||||
if filter.TenantID > 0 && alert.TenantID != filter.TenantID {
|
||||
continue
|
||||
}
|
||||
// 按供应商过滤
|
||||
if filter.SupplierID > 0 && alert.SupplierID != filter.SupplierID {
|
||||
continue
|
||||
}
|
||||
// 按类型过滤
|
||||
if filter.AlertType != "" && alert.AlertType != filter.AlertType {
|
||||
continue
|
||||
}
|
||||
// 按级别过滤
|
||||
if filter.AlertLevel != "" && alert.AlertLevel != filter.AlertLevel {
|
||||
continue
|
||||
}
|
||||
// 按状态过滤
|
||||
if filter.Status != "" && alert.Status != filter.Status {
|
||||
continue
|
||||
}
|
||||
// 按时间范围过滤
|
||||
if !filter.StartTime.IsZero() && alert.CreatedAt.Before(filter.StartTime) {
|
||||
continue
|
||||
}
|
||||
if !filter.EndTime.IsZero() && alert.CreatedAt.After(filter.EndTime) {
|
||||
continue
|
||||
}
|
||||
// 关键字搜索
|
||||
if filter.Keywords != "" {
|
||||
kw := filter.Keywords
|
||||
if !strings.Contains(alert.Title, kw) && !strings.Contains(alert.Message, kw) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, alert)
|
||||
}
|
||||
|
||||
total := int64(len(result))
|
||||
|
||||
// 分页
|
||||
if filter.Offset > 0 {
|
||||
if filter.Offset >= len(result) {
|
||||
return []*model.Alert{}, total, nil
|
||||
}
|
||||
result = result[filter.Offset:]
|
||||
}
|
||||
if filter.Limit > 0 && filter.Limit < len(result) {
|
||||
result = result[:filter.Limit]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// AlertService 告警服务
|
||||
type AlertService struct {
|
||||
store AlertStoreInterface
|
||||
}
|
||||
|
||||
// NewAlertService 创建告警服务
|
||||
func NewAlertService(store AlertStoreInterface) *AlertService {
|
||||
return &AlertService{store: store}
|
||||
}
|
||||
|
||||
// CreateAlert 创建告警
|
||||
func (s *AlertService) CreateAlert(ctx context.Context, alert *model.Alert) (*model.Alert, error) {
|
||||
if alert == nil {
|
||||
return nil, ErrInvalidAlertInput
|
||||
}
|
||||
if alert.Title == "" {
|
||||
return nil, errors.New("alert title is required")
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if alert.AlertID == "" {
|
||||
alert.AlertID = model.NewAlert("", "", "", "", "", "").AlertID
|
||||
}
|
||||
if alert.Status == "" {
|
||||
alert.Status = model.AlertStatusActive
|
||||
}
|
||||
now := time.Now()
|
||||
if alert.CreatedAt.IsZero() {
|
||||
alert.CreatedAt = now
|
||||
}
|
||||
if alert.UpdatedAt.IsZero() {
|
||||
alert.UpdatedAt = now
|
||||
}
|
||||
if alert.FirstSeenAt.IsZero() {
|
||||
alert.FirstSeenAt = now
|
||||
}
|
||||
if alert.LastSeenAt.IsZero() {
|
||||
alert.LastSeenAt = now
|
||||
}
|
||||
|
||||
err := s.store.Create(ctx, alert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return alert, nil
|
||||
}
|
||||
|
||||
// GetAlert 获取告警
|
||||
func (s *AlertService) GetAlert(ctx context.Context, alertID string) (*model.Alert, error) {
|
||||
if alertID == "" {
|
||||
return nil, ErrInvalidAlertInput
|
||||
}
|
||||
return s.store.GetByID(ctx, alertID)
|
||||
}
|
||||
|
||||
// UpdateAlert 更新告警
|
||||
func (s *AlertService) UpdateAlert(ctx context.Context, alert *model.Alert) (*model.Alert, error) {
|
||||
if alert == nil || alert.AlertID == "" {
|
||||
return nil, ErrInvalidAlertInput
|
||||
}
|
||||
|
||||
alert.UpdatedAt = time.Now()
|
||||
err := s.store.Update(ctx, alert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return alert, nil
|
||||
}
|
||||
|
||||
// DeleteAlert 删除告警
|
||||
func (s *AlertService) DeleteAlert(ctx context.Context, alertID string) error {
|
||||
if alertID == "" {
|
||||
return ErrInvalidAlertInput
|
||||
}
|
||||
return s.store.Delete(ctx, alertID)
|
||||
}
|
||||
|
||||
// ListAlerts 列出告警
|
||||
func (s *AlertService) ListAlerts(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) {
|
||||
if filter == nil {
|
||||
filter = &model.AlertFilter{}
|
||||
}
|
||||
if filter.Limit == 0 {
|
||||
filter.Limit = 100
|
||||
}
|
||||
return s.store.List(ctx, filter)
|
||||
}
|
||||
|
||||
// ResolveAlert 解决告警
|
||||
func (s *AlertService) ResolveAlert(ctx context.Context, alertID, resolvedBy, note string) (*model.Alert, error) {
|
||||
alert, err := s.store.GetByID(ctx, alertID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
alert.Resolve(resolvedBy, note)
|
||||
err = s.store.Update(ctx, alert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return alert, nil
|
||||
}
|
||||
|
||||
// AcknowledgeAlert 确认告警
|
||||
func (s *AlertService) AcknowledgeAlert(ctx context.Context, alertID string) (*model.Alert, error) {
|
||||
alert, err := s.store.GetByID(ctx, alertID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
alert.Acknowledge()
|
||||
err = s.store.Update(ctx, alert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return alert, nil
|
||||
}
|
||||
203
supply-api/internal/audit/service/batch_buffer.go
Normal file
203
supply-api/internal/audit/service/batch_buffer.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
// BatchBufferConfig 批量缓冲区配置
|
||||
type BatchBufferConfig struct {
|
||||
BatchSize int // 批量大小(默认50)
|
||||
FlushInterval time.Duration // 刷新间隔(默认5ms)
|
||||
BufferSize int // 通道缓冲大小(默认1000)
|
||||
}
|
||||
|
||||
// DefaultBatchBufferConfig 默认配置
|
||||
var DefaultBatchBufferConfig = BatchBufferConfig{
|
||||
BatchSize: 50,
|
||||
FlushInterval: 5 * time.Millisecond,
|
||||
BufferSize: 1000,
|
||||
}
|
||||
|
||||
// BatchBuffer 批量写入缓冲区
|
||||
// 设计目标:50条/批或5ms刷新间隔,支持5K-8K TPS
|
||||
type BatchBuffer struct {
|
||||
config BatchBufferConfig
|
||||
eventCh chan *model.AuditEvent
|
||||
buffer []*model.AuditEvent
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
|
||||
flushTick *time.Ticker
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
|
||||
// FlushHandler 处理批量刷新回调
|
||||
FlushHandler func(events []*model.AuditEvent) error
|
||||
}
|
||||
|
||||
// NewBatchBuffer 创建批量缓冲区
|
||||
func NewBatchBuffer(batchSize int, flushInterval time.Duration) *BatchBuffer {
|
||||
config := DefaultBatchBufferConfig
|
||||
if batchSize > 0 {
|
||||
config.BatchSize = batchSize
|
||||
}
|
||||
if flushInterval > 0 {
|
||||
config.FlushInterval = flushInterval
|
||||
}
|
||||
|
||||
return &BatchBuffer{
|
||||
config: config,
|
||||
eventCh: make(chan *model.AuditEvent, config.BufferSize),
|
||||
buffer: make([]*model.AuditEvent, 0, batchSize),
|
||||
flushTick: time.NewTicker(config.FlushInterval),
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动批量缓冲处理
|
||||
func (b *BatchBuffer) Start(ctx context.Context) error {
|
||||
go b.run()
|
||||
return nil
|
||||
}
|
||||
|
||||
// run 后台处理循环
|
||||
func (b *BatchBuffer) run() {
|
||||
defer close(b.doneCh)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.stopCh:
|
||||
// 停止信号:处理剩余缓冲
|
||||
b.flush()
|
||||
return
|
||||
case event := <-b.eventCh:
|
||||
b.addEvent(event)
|
||||
case <-b.flushTick.C:
|
||||
b.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addEvent 添加事件到缓冲
|
||||
func (b *BatchBuffer) addEvent(event *model.AuditEvent) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
b.buffer = append(b.buffer, event)
|
||||
|
||||
// 达到批量大小立即刷新
|
||||
if len(b.buffer) >= b.config.BatchSize {
|
||||
b.doFlushLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// flush 刷新缓冲(带锁)- 也会处理eventCh中的待处理事件
|
||||
func (b *BatchBuffer) flush() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// 处理eventCh中已有的事件
|
||||
for {
|
||||
select {
|
||||
case event := <-b.eventCh:
|
||||
b.buffer = append(b.buffer, event)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
b.doFlushLocked()
|
||||
}
|
||||
|
||||
// doFlushLocked 执行刷新( caller 必须持锁)
|
||||
func (b *BatchBuffer) doFlushLocked() {
|
||||
if len(b.buffer) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 复制缓冲数据
|
||||
events := make([]*model.AuditEvent, len(b.buffer))
|
||||
copy(events, b.buffer)
|
||||
|
||||
// 清空缓冲
|
||||
b.buffer = b.buffer[:0]
|
||||
|
||||
// 调用处理函数(如果已设置)
|
||||
if b.FlushHandler != nil {
|
||||
if err := b.FlushHandler(events); err != nil {
|
||||
// TODO: 错误处理 - 记录日志、重试等
|
||||
// 当前简化处理:仅记录
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add 添加审计事件
|
||||
func (b *BatchBuffer) Add(event *model.AuditEvent) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
return ErrBufferClosed
|
||||
}
|
||||
|
||||
select {
|
||||
case b.eventCh <- event:
|
||||
return nil
|
||||
default:
|
||||
// 通道满,添加到缓冲
|
||||
b.buffer = append(b.buffer, event)
|
||||
if len(b.buffer) >= b.config.BatchSize {
|
||||
b.doFlushLocked()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FlushNow 立即刷新
|
||||
func (b *BatchBuffer) FlushNow() error {
|
||||
b.flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭缓冲区
|
||||
func (b *BatchBuffer) Close() error {
|
||||
b.mu.Lock()
|
||||
if b.closed {
|
||||
b.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
b.closed = true
|
||||
b.mu.Unlock()
|
||||
|
||||
close(b.stopCh)
|
||||
<-b.doneCh
|
||||
b.flushTick.Stop()
|
||||
close(b.eventCh)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetFlushHandler 设置刷新处理器
|
||||
func (b *BatchBuffer) SetFlushHandler(handler func(events []*model.AuditEvent) error) {
|
||||
b.FlushHandler = handler
|
||||
}
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrBufferClosed = &BatchBufferError{"buffer is closed"}
|
||||
ErrMissingFlushHandler = &BatchBufferError{"flush handler not set"}
|
||||
)
|
||||
|
||||
// BatchBufferError 批量缓冲错误
|
||||
type BatchBufferError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *BatchBufferError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
249
supply-api/internal/audit/service/batch_buffer_test.go
Normal file
249
supply-api/internal/audit/service/batch_buffer_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
// TestBatchBuffer_BatchSize 测试50条/批刷新
|
||||
func TestBatchBuffer_BatchSize(t *testing.T) {
|
||||
const batchSize = 50
|
||||
|
||||
buffer := NewBatchBuffer(batchSize, 100*time.Millisecond) // 100ms超时防止测试卡住
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
defer buffer.Close()
|
||||
|
||||
// 收集器:接收批量事件
|
||||
var receivedBatches [][]*model.AuditEvent
|
||||
var mu sync.Mutex
|
||||
|
||||
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
|
||||
mu.Lock()
|
||||
receivedBatches = append(receivedBatches, events)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
// 添加50条事件,应该触发一次批量刷新
|
||||
for i := 0; i < batchSize; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-001",
|
||||
EventName: "TEST-EVENT",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 等待刷新完成
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 验证:应该收到恰好一个批次
|
||||
mu.Lock()
|
||||
if len(receivedBatches) != 1 {
|
||||
t.Errorf("expected 1 batch, got %d", len(receivedBatches))
|
||||
}
|
||||
if len(receivedBatches) > 0 && len(receivedBatches[0]) != batchSize {
|
||||
t.Errorf("expected batch size %d, got %d", batchSize, len(receivedBatches[0]))
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// TestBatchBuffer_TimeoutFlush 测试5ms超时刷新
|
||||
func TestBatchBuffer_TimeoutFlush(t *testing.T) {
|
||||
const batchSize = 100 // 大于我们添加的数量
|
||||
const flushInterval = 5 * time.Millisecond
|
||||
|
||||
buffer := NewBatchBuffer(batchSize, flushInterval)
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
defer buffer.Close()
|
||||
|
||||
// 收集器
|
||||
var receivedBatches [][]*model.AuditEvent
|
||||
var mu sync.Mutex
|
||||
|
||||
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
|
||||
mu.Lock()
|
||||
receivedBatches = append(receivedBatches, events)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
// 只添加3条事件,不满50条
|
||||
for i := 0; i < 3; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-002",
|
||||
EventName: "TEST-TIMEOUT",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 等待5ms超时刷新
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// 验证:应该收到一个批次,包含3条事件
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(receivedBatches) != 1 {
|
||||
t.Errorf("expected 1 batch (timeout flush), got %d", len(receivedBatches))
|
||||
}
|
||||
if len(receivedBatches) > 0 && len(receivedBatches[0]) != 3 {
|
||||
t.Errorf("expected 3 events in batch, got %d", len(receivedBatches[0]))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchBuffer_ConcurrentAccess 测试并发安全性
|
||||
func TestBatchBuffer_ConcurrentAccess(t *testing.T) {
|
||||
const batchSize = 50
|
||||
const numGoroutines = 10
|
||||
const eventsPerGoroutine = 100
|
||||
|
||||
buffer := NewBatchBuffer(batchSize, 10*time.Millisecond)
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
defer buffer.Close()
|
||||
|
||||
var totalReceived int
|
||||
var mu sync.Mutex
|
||||
|
||||
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
|
||||
mu.Lock()
|
||||
totalReceived += len(events)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
// 并发添加事件
|
||||
var wg sync.WaitGroup
|
||||
for g := 0; g < numGoroutines; g++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < eventsPerGoroutine; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-concurrent",
|
||||
EventName: "TEST-CONCURRENT",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
time.Sleep(50 * time.Millisecond) // 等待所有刷新完成
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
expectedTotal := numGoroutines * eventsPerGoroutine
|
||||
if totalReceived != expectedTotal {
|
||||
t.Errorf("expected %d total events, got %d", expectedTotal, totalReceived)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchBuffer_Close 测试关闭
|
||||
func TestBatchBuffer_Close(t *testing.T) {
|
||||
buffer := NewBatchBuffer(50, 10*time.Millisecond)
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
|
||||
// 添加一些事件
|
||||
for i := 0; i < 5; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-close",
|
||||
EventName: "TEST-CLOSE",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭缓冲区
|
||||
err = buffer.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
// 关闭后添加应该失败
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-after-close",
|
||||
EventName: "TEST-AFTER-CLOSE",
|
||||
}
|
||||
if err := buffer.Add(event); err == nil {
|
||||
t.Errorf("Add after Close should fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchBuffer_FlushNow 测试手动刷新
|
||||
func TestBatchBuffer_FlushNow(t *testing.T) {
|
||||
const batchSize = 100 // 足够大,不会自动触发
|
||||
|
||||
buffer := NewBatchBuffer(batchSize, 100*time.Millisecond) // 100ms才自动刷新
|
||||
ctx := context.Background()
|
||||
|
||||
err := buffer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
defer buffer.Close()
|
||||
|
||||
var receivedBatches [][]*model.AuditEvent
|
||||
var mu sync.Mutex
|
||||
|
||||
buffer.SetFlushHandler(func(events []*model.AuditEvent) error {
|
||||
mu.Lock()
|
||||
receivedBatches = append(receivedBatches, events)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
// 添加少量事件
|
||||
for i := 0; i < 3; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventID: "batch-test-manual",
|
||||
EventName: "TEST-MANUAL",
|
||||
}
|
||||
if err := buffer.Add(event); err != nil {
|
||||
t.Errorf("Add failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 立即手动刷新
|
||||
err = buffer.FlushNow()
|
||||
if err != nil {
|
||||
t.Errorf("FlushNow failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(receivedBatches) != 1 {
|
||||
t.Errorf("expected 1 batch after FlushNow, got %d", len(receivedBatches))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user