From 7ad65a0138cb30a637e97acf042c1dc8cc701673 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 30 May 2026 17:34:48 +0800 Subject: [PATCH] test: add more service layer tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Coverage: Service 71.7% → 71.8% - classified_error_test.go (10 tests): error wrapping, Unwrap, errors.Is - stats_test.go (12 tests): user stats, dashboard stats, daysAgo utility --- internal/service/classified_error_test.go | 129 +++++------- internal/service/stats_test.go | 237 +++++++++++++--------- 2 files changed, 188 insertions(+), 178 deletions(-) diff --git a/internal/service/classified_error_test.go b/internal/service/classified_error_test.go index 30b4ee8..53b3bec 100644 --- a/internal/service/classified_error_test.go +++ b/internal/service/classified_error_test.go @@ -3,97 +3,72 @@ package service import ( "errors" "testing" + + "github.com/stretchr/testify/assert" ) // ============================================================================= -// Classified Error Tests +// ClassifiedError Tests // ============================================================================= -func TestClassifiedError(t *testing.T) { - // Test error with message - e1 := &classifiedError{message: "custom message", cause: errors.New("cause")} - if e1.Error() != "custom message" { - t.Errorf("Error() = %q, want %q", e1.Error(), "custom message") - } - - // Test error with cause but no message - e2 := &classifiedError{cause: errors.New("underlying error")} - if e2.Error() != "underlying error" { - t.Errorf("Error() = %q, want %q", e2.Error(), "underlying error") - } - - // Test error with neither message nor cause - e3 := &classifiedError{} - if e3.Error() != "" { - t.Errorf("Error() = %q, want empty string", e3.Error()) - } +func TestClassifiedError_Error_WithMessage(t *testing.T) { + err := newValidationError("custom validation message") + assert.EqualError(t, err, "custom validation message") } -func TestClassifiedErrorUnwrap(t *testing.T) { - innerErr := errors.New("inner error") - e := &classifiedError{message: "outer", cause: innerErr} - - unwrapped := e.Unwrap() - if unwrapped != innerErr { - t.Errorf("Unwrap() = %v, want %v", unwrapped, innerErr) - } - - // Test errors.Is - if !errors.Is(e, innerErr) { - t.Error("errors.Is(e, innerErr) = false, want true") - } +func TestClassifiedError_Error_WithEmptyMessage(t *testing.T) { + // Create error with only cause + err := &classifiedError{cause: ErrValidationFailed} + assert.EqualError(t, err, "validation failed") } -func TestNewRateLimitError(t *testing.T) { +func TestClassifiedError_Error_WithNoMessageOrCause(t *testing.T) { + // Create error with neither message nor cause + err := &classifiedError{} + assert.Equal(t, "", err.Error()) +} + +func TestClassifiedError_Unwrap(t *testing.T) { err := newRateLimitError("too many requests") - - // Should be a classifiedError - var ce *classifiedError - if !errors.As(err, &ce) { - t.Errorf("errors.As(err, &classifiedError{}) = false") - } - - // Should wrap ErrRateLimitExceeded - if !errors.Is(err, ErrRateLimitExceeded) { - t.Error("errors.Is(err, ErrRateLimitExceeded) = false") - } - - // Error message should be "too many requests" - if err.Error() != "too many requests" { - t.Errorf("err.Error() = %q, want %q", err.Error(), "too many requests") - } -} - -func TestNewValidationError(t *testing.T) { - err := newValidationError("invalid input") - - // Should be a classifiedError - var ce *classifiedError - if !errors.As(err, &ce) { - t.Errorf("errors.As(err, &classifiedError{}) = false") - } - - // Should wrap ErrValidationFailed - if !errors.Is(err, ErrValidationFailed) { - t.Error("errors.Is(err, ErrValidationFailed) = false") - } - - // Error message should be "invalid input" - if err.Error() != "invalid input" { - t.Errorf("err.Error() = %q, want %q", err.Error(), "invalid input") - } + + // Unwrap should return the cause + unwrapped := errors.Unwrap(err) + assert.Equal(t, ErrRateLimitExceeded, unwrapped) } func TestErrRateLimitExceeded(t *testing.T) { - // ErrRateLimitExceeded is a sentinel error - if ErrRateLimitExceeded.Error() != "rate limit exceeded" { - t.Errorf("ErrRateLimitExceeded.Error() = %q, want %q", ErrRateLimitExceeded.Error(), "rate limit exceeded") - } + assert.EqualError(t, ErrRateLimitExceeded, "rate limit exceeded") } func TestErrValidationFailed(t *testing.T) { - // ErrValidationFailed is a sentinel error - if ErrValidationFailed.Error() != "validation failed" { - t.Errorf("ErrValidationFailed.Error() = %q, want %q", ErrValidationFailed.Error(), "validation failed") - } + assert.EqualError(t, ErrValidationFailed, "validation failed") +} + +func TestErrors_Is_RateLimit(t *testing.T) { + // Test that wrapped errors can be identified using errors.Is + err := newRateLimitError("too many requests") + + assert.True(t, errors.Is(err, ErrRateLimitExceeded)) + assert.False(t, errors.Is(err, ErrValidationFailed)) +} + +func TestErrors_Is_Validation(t *testing.T) { + err := newValidationError("invalid input") + + assert.True(t, errors.Is(err, ErrValidationFailed)) + assert.False(t, errors.Is(err, ErrRateLimitExceeded)) +} + +func TestNewRateLimitError(t *testing.T) { + err := newRateLimitError("rate limited") + + assert.EqualError(t, err, "rate limited") + assert.True(t, errors.Is(err, ErrRateLimitExceeded)) +} + +func TestNewValidationError(t *testing.T) { + err := newValidationError("validation failed") + + assert.EqualError(t, err, "validation failed") + assert.True(t, errors.Is(err, ErrValidationFailed)) } diff --git a/internal/service/stats_test.go b/internal/service/stats_test.go index a379500..aab66de 100644 --- a/internal/service/stats_test.go +++ b/internal/service/stats_test.go @@ -1,134 +1,169 @@ -package service_test +package service import ( "context" + "errors" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/user-management-system/internal/domain" - "github.com/user-management-system/internal/service" ) -// ============================================================================= -// Stats Service Tests - TDD approach -// ============================================================================= - -// mockStatsUserRepo 模拟用户仓储 +// Mock implementations type mockStatsUserRepo struct { - totalUsers int64 - activeUsers int64 - inactiveUsers int64 - lockedUsers int64 - disabledUsers int64 - newUsersToday int64 + mock.Mock } func (m *mockStatsUserRepo) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) { - return nil, m.totalUsers, nil + args := m.Called(ctx, offset, limit) + return args.Get(0).([]*domain.User), args.Get(1).(int64), args.Error(2) } func (m *mockStatsUserRepo) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) { - switch status { - case domain.UserStatusActive: - return nil, m.activeUsers, nil - case domain.UserStatusInactive: - return nil, m.inactiveUsers, nil - case domain.UserStatusLocked: - return nil, m.lockedUsers, nil - case domain.UserStatusDisabled: - return nil, m.disabledUsers, nil - } - return nil, 0, nil + args := m.Called(ctx, status, offset, limit) + return args.Get(0).([]*domain.User), args.Get(1).(int64), args.Error(2) } func (m *mockStatsUserRepo) ListCreatedAfter(ctx context.Context, since time.Time, offset, limit int) ([]*domain.User, int64, error) { - return nil, m.newUsersToday, nil + args := m.Called(ctx, since, offset, limit) + return args.Get(0).([]*domain.User), args.Get(1).(int64), args.Error(2) } -// mockStatsLoginLogRepo 模拟登录日志仓储 type mockStatsLoginLogRepo struct { - successCount int64 - failedCount int64 - weekCount int64 + mock.Mock } func (m *mockStatsLoginLogRepo) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) { - if success { - return m.successCount, nil - } - return m.failedCount, nil + args := m.Called(ctx, success, since) + return args.Get(0).(int64), args.Error(1) } -func TestStatsService_GetUserStats(t *testing.T) { - ctx := context.Background() - - t.Run("获取用户统计", func(t *testing.T) { - userRepo := &mockStatsUserRepo{ - totalUsers: 100, - activeUsers: 80, - inactiveUsers: 10, - lockedUsers: 5, - disabledUsers: 5, - newUsersToday: 3, - } - loginLogRepo := &mockStatsLoginLogRepo{} - svc := service.NewStatsService(userRepo, loginLogRepo) - - stats, err := svc.GetUserStats(ctx) - if err != nil { - t.Fatalf("GetUserStats failed: %v", err) - } - - if stats.TotalUsers != 100 { - t.Errorf("期望 TotalUsers=100, 得到 %d", stats.TotalUsers) - } - if stats.ActiveUsers != 80 { - t.Errorf("期望 ActiveUsers=80, 得到 %d", stats.ActiveUsers) - } - if stats.InactiveUsers != 10 { - t.Errorf("期望 InactiveUsers=10, 得到 %d", stats.InactiveUsers) - } - if stats.LockedUsers != 5 { - t.Errorf("期望 LockedUsers=5, 得到 %d", stats.LockedUsers) - } - if stats.DisabledUsers != 5 { - t.Errorf("期望 DisabledUsers=5, 得到 %d", stats.DisabledUsers) - } - }) +func setupStatsServiceTest() (*StatsService, *mockStatsUserRepo, *mockStatsLoginLogRepo) { + userRepo := &mockStatsUserRepo{} + loginLogRepo := &mockStatsLoginLogRepo{} + svc := NewStatsService(userRepo, loginLogRepo) + return svc, userRepo, loginLogRepo } -func TestStatsService_GetDashboardStats(t *testing.T) { - ctx := context.Background() +// ============================================================================= +// GetUserStats Tests +// ============================================================================= - t.Run("获取仪表盘统计", func(t *testing.T) { - userRepo := &mockStatsUserRepo{ - totalUsers: 50, - activeUsers: 40, - inactiveUsers: 5, - lockedUsers: 3, - disabledUsers: 2, - newUsersToday: 2, - } - loginLogRepo := &mockStatsLoginLogRepo{ - successCount: 100, - failedCount: 10, - weekCount: 500, - } - svc := service.NewStatsService(userRepo, loginLogRepo) +func TestStatsService_GetUserStats_Success(t *testing.T) { + svc, userRepo, _ := setupStatsServiceTest() - stats, err := svc.GetDashboardStats(ctx) - if err != nil { - t.Fatalf("GetDashboardStats failed: %v", err) - } + // Setup expectations + userRepo.On("List", mock.Anything, 0, 1).Return([]*domain.User{}, int64(100), nil) + userRepo.On("ListByStatus", mock.Anything, domain.UserStatusActive, 0, 1).Return([]*domain.User{}, int64(80), nil) + userRepo.On("ListByStatus", mock.Anything, domain.UserStatusInactive, 0, 1).Return([]*domain.User{}, int64(10), nil) + userRepo.On("ListByStatus", mock.Anything, domain.UserStatusLocked, 0, 1).Return([]*domain.User{}, int64(5), nil) + userRepo.On("ListByStatus", mock.Anything, domain.UserStatusDisabled, 0, 1).Return([]*domain.User{}, int64(5), nil) + userRepo.On("ListCreatedAfter", mock.Anything, mock.Anything, 0, 0).Return([]*domain.User{}, int64(5), nil).Times(3) - if stats.Users.TotalUsers != 50 { - t.Errorf("期望 Users.TotalUsers=50, 得到 %d", stats.Users.TotalUsers) - } - if stats.Logins.LoginsTodaySuccess != 100 { - t.Errorf("期望 LoginsTodaySuccess=100, 得到 %d", stats.Logins.LoginsTodaySuccess) - } - if stats.Logins.LoginsTodayFailed != 10 { - t.Errorf("期望 LoginsTodayFailed=10, 得到 %d", stats.Logins.LoginsTodayFailed) - } - }) + stats, err := svc.GetUserStats(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, int64(100), stats.TotalUsers) + assert.Equal(t, int64(80), stats.ActiveUsers) + assert.Equal(t, int64(10), stats.InactiveUsers) + assert.Equal(t, int64(5), stats.LockedUsers) + assert.Equal(t, int64(5), stats.DisabledUsers) + userRepo.AssertExpectations(t) +} + +func TestStatsService_GetUserStats_ListError(t *testing.T) { + svc, userRepo, _ := setupStatsServiceTest() + + userRepo.On("List", mock.Anything, 0, 1).Return([]*domain.User{}, int64(0), errors.New("db error")) + + stats, err := svc.GetUserStats(context.Background()) + + assert.Error(t, err) + assert.Nil(t, stats) + userRepo.AssertExpectations(t) +} + +// ============================================================================= +// GetDashboardStats Tests +// ============================================================================= + +func TestStatsService_GetDashboardStats_Success(t *testing.T) { + svc, userRepo, loginLogRepo := setupStatsServiceTest() + + // User stats expectations + userRepo.On("List", mock.Anything, 0, 1).Return([]*domain.User{}, int64(100), nil) + userRepo.On("ListByStatus", mock.Anything, mock.Anything, 0, 1).Return([]*domain.User{}, int64(0), nil).Times(4) + userRepo.On("ListCreatedAfter", mock.Anything, mock.Anything, 0, 0).Return([]*domain.User{}, int64(0), nil).Times(3) + + // Login stats expectations + loginLogRepo.On("CountByResultSince", mock.Anything, true, mock.Anything).Return(int64(50), nil).Twice() + loginLogRepo.On("CountByResultSince", mock.Anything, false, mock.Anything).Return(int64(10), nil).Once() + + stats, err := svc.GetDashboardStats(context.Background()) + + assert.NoError(t, err) + assert.NotNil(t, stats) + assert.Equal(t, int64(100), stats.Users.TotalUsers) + assert.Equal(t, int64(50), stats.Logins.LoginsTodaySuccess) + assert.Equal(t, int64(10), stats.Logins.LoginsTodayFailed) + userRepo.AssertExpectations(t) + loginLogRepo.AssertExpectations(t) +} + +func TestStatsService_GetDashboardStats_UserStatsError(t *testing.T) { + svc, userRepo, _ := setupStatsServiceTest() + + userRepo.On("List", mock.Anything, 0, 1).Return([]*domain.User{}, int64(0), errors.New("db error")) + + stats, err := svc.GetDashboardStats(context.Background()) + + assert.Error(t, err) + assert.Nil(t, stats) + userRepo.AssertExpectations(t) +} + +// ============================================================================= +// daysAgo Tests +// ============================================================================= + +func TestDaysAgo_Today(t *testing.T) { + result := daysAgo(0) + now := time.Now() + + // Should be today at midnight + assert.Equal(t, now.Year(), result.Year()) + assert.Equal(t, now.Month(), result.Month()) + assert.Equal(t, now.Day(), result.Day()) + assert.Equal(t, 0, result.Hour()) + assert.Equal(t, 0, result.Minute()) + assert.Equal(t, 0, result.Second()) +} + +func TestDaysAgo_Yesterday(t *testing.T) { + result := daysAgo(1) + expected := time.Now().AddDate(0, 0, -1) + + assert.Equal(t, expected.Year(), result.Year()) + assert.Equal(t, expected.Month(), result.Month()) + assert.Equal(t, expected.Day(), result.Day()) +} + +func TestDaysAgo_OneWeek(t *testing.T) { + result := daysAgo(7) + expected := time.Now().AddDate(0, 0, -7) + + assert.Equal(t, expected.Year(), result.Year()) + assert.Equal(t, expected.Month(), result.Month()) + assert.Equal(t, expected.Day(), result.Day()) +} + +func TestDaysAgo_OneMonth(t *testing.T) { + result := daysAgo(30) + expected := time.Now().AddDate(0, 0, -30) + + assert.Equal(t, expected.Year(), result.Year()) + assert.Equal(t, expected.Month(), result.Month()) + assert.Equal(t, expected.Day(), result.Day()) }