test: 补齐 handler/repository/domain 层单元测试

This commit is contained in:
2026-05-10 12:54:13 +08:00
parent b8e9af001f
commit 28012140cb
21 changed files with 5837 additions and 1 deletions

View File

@@ -0,0 +1,95 @@
package repository
import (
"testing"
"github.com/user-management-system/internal/pkg/pagination"
)
func TestPaginationResultFromTotal(t *testing.T) {
tests := []struct {
name string
total int64
params pagination.PaginationParams
wantPages int
wantTotal int64
wantPage int
wantPageSize int
}{
{
name: "exact division",
total: 100,
params: pagination.PaginationParams{Page: 1, PageSize: 20},
wantPages: 5,
wantTotal: 100,
wantPage: 1,
wantPageSize: 20,
},
{
name: "with remainder",
total: 105,
params: pagination.PaginationParams{Page: 1, PageSize: 20},
wantPages: 6,
wantTotal: 105,
wantPage: 1,
wantPageSize: 20,
},
{
name: "zero total",
total: 0,
params: pagination.PaginationParams{Page: 1, PageSize: 20},
wantPages: 0,
wantTotal: 0,
wantPage: 1,
wantPageSize: 20,
},
{
name: "single page",
total: 5,
params: pagination.PaginationParams{Page: 1, PageSize: 20},
wantPages: 1,
wantTotal: 5,
wantPage: 1,
wantPageSize: 20,
},
{
name: "page 2",
total: 50,
params: pagination.PaginationParams{Page: 2, PageSize: 20},
wantPages: 3,
wantTotal: 50,
wantPage: 2,
wantPageSize: 20,
},
{
name: "small page size",
total: 10,
params: pagination.PaginationParams{Page: 1, PageSize: 3},
wantPages: 4,
wantTotal: 10,
wantPage: 1,
wantPageSize: 3,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := paginationResultFromTotal(tc.total, tc.params)
if result == nil {
t.Fatal("expected non-nil result")
}
if result.Total != tc.wantTotal {
t.Errorf("expected total %d, got %d", tc.wantTotal, result.Total)
}
if result.Page != tc.wantPage {
t.Errorf("expected page %d, got %d", tc.wantPage, result.Page)
}
if result.PageSize != tc.wantPageSize {
t.Errorf("expected page_size %d, got %d", tc.wantPageSize, result.PageSize)
}
if result.Pages != tc.wantPages {
t.Errorf("expected pages %d, got %d", tc.wantPages, result.Pages)
}
})
}
}

View File

@@ -0,0 +1,224 @@
package repository
import (
"context"
"testing"
"time"
"github.com/user-management-system/internal/domain"
)
func TestPasswordHistoryRepository_Create(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
history := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash1",
CreatedAt: time.Now(),
}
if err := repo.Create(ctx, history); err != nil {
t.Fatalf("create failed: %v", err)
}
if history.ID == 0 {
t.Error("expected ID to be set after create")
}
}
func TestPasswordHistoryRepository_GetByUserID(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Create multiple records for user 1
for i := 0; i < 5; i++ {
h := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash",
CreatedAt: time.Now().Add(time.Duration(i) * time.Second),
}
if err := repo.Create(ctx, h); err != nil {
t.Fatalf("create failed: %v", err)
}
}
// Create record for user 2
if err := repo.Create(ctx, &domain.PasswordHistory{UserID: 2, PasswordHash: "hash", CreatedAt: time.Now()}); err != nil {
t.Fatalf("create failed: %v", err)
}
tests := []struct {
name string
userID int64
limit int
wantLen int
wantUser int64
}{
{"get all for user 1", 1, 10, 5, 1},
{"limit 3 for user 1", 1, 3, 3, 1},
{"get for user 2", 2, 10, 1, 2},
{"get for nonexistent user", 999, 10, 0, 999},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
histories, err := repo.GetByUserID(ctx, tc.userID, tc.limit)
if err != nil {
t.Fatalf("get failed: %v", err)
}
if len(histories) != tc.wantLen {
t.Errorf("expected %d histories, got %d", tc.wantLen, len(histories))
}
for _, h := range histories {
if h.UserID != tc.wantUser {
t.Errorf("expected user_id %d, got %d", tc.wantUser, h.UserID)
}
}
})
}
}
func TestPasswordHistoryRepository_GetByUserID_Order(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Create records with different timestamps
now := time.Now()
for i := 0; i < 3; i++ {
h := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash",
CreatedAt: now.Add(time.Duration(i) * time.Hour),
}
if err := repo.Create(ctx, h); err != nil {
t.Fatalf("create failed: %v", err)
}
}
histories, err := repo.GetByUserID(ctx, 1, 10)
if err != nil {
t.Fatalf("get failed: %v", err)
}
if len(histories) != 3 {
t.Fatalf("expected 3 histories, got %d", len(histories))
}
// Should be ordered by created_at DESC (newest first)
for i := 0; i < len(histories)-1; i++ {
if !histories[i].CreatedAt.After(histories[i+1].CreatedAt) && !histories[i].CreatedAt.Equal(histories[i+1].CreatedAt) {
t.Errorf("expected descending order, got %v before %v", histories[i].CreatedAt, histories[i+1].CreatedAt)
}
}
}
func TestPasswordHistoryRepository_DeleteOldRecords(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Create 5 records for user 1
now := time.Now()
for i := 0; i < 5; i++ {
h := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash",
CreatedAt: now.Add(time.Duration(i) * time.Hour),
}
if err := repo.Create(ctx, h); err != nil {
t.Fatalf("create failed: %v", err)
}
}
// Delete old records, keep only 3
if err := repo.DeleteOldRecords(ctx, 1, 3); err != nil {
t.Fatalf("delete old records failed: %v", err)
}
histories, err := repo.GetByUserID(ctx, 1, 10)
if err != nil {
t.Fatalf("get failed: %v", err)
}
if len(histories) != 3 {
t.Errorf("expected 3 histories after cleanup, got %d", len(histories))
}
}
func TestPasswordHistoryRepository_DeleteOldRecords_NoRecords(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Should not error when no records exist
if err := repo.DeleteOldRecords(ctx, 999, 3); err != nil {
t.Fatalf("delete old records on empty table should not error: %v", err)
}
}
func TestPasswordHistoryRepository_KeepsNewestRecords(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Create 5 records with different timestamps
now := time.Now()
var createdIDs []int64
for i := 0; i < 5; i++ {
h := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash",
CreatedAt: now.Add(time.Duration(i) * time.Hour),
}
if err := repo.Create(ctx, h); err != nil {
t.Fatalf("create failed: %v", err)
}
createdIDs = append(createdIDs, h.ID)
}
// Delete old records, keep only 2
if err := repo.DeleteOldRecords(ctx, 1, 2); err != nil {
t.Fatalf("delete old records failed: %v", err)
}
histories, err := repo.GetByUserID(ctx, 1, 10)
if err != nil {
t.Fatalf("get failed: %v", err)
}
if len(histories) != 2 {
t.Fatalf("expected 2 histories after cleanup, got %d", len(histories))
}
// The remaining records should be the newest (last 2 created)
expectedIDs := map[int64]bool{createdIDs[3]: true, createdIDs[4]: true}
for _, h := range histories {
if !expectedIDs[h.ID] {
t.Errorf("expected remaining IDs to be %v, got %d", expectedIDs, h.ID)
}
}
}

View File

@@ -0,0 +1,117 @@
package repository
import (
"context"
"database/sql"
"errors"
"testing"
)
// mockQueryer implements sqlQueryer for testing
type mockQueryer struct {
rows *sql.Rows
err error
}
func (m *mockQueryer) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return m.rows, m.err
}
func TestScanSingleRow_QueryError(t *testing.T) {
ctx := context.Background()
mockErr := errors.New("query failed")
q := &mockQueryer{err: mockErr}
var dest int
err := scanSingleRow(ctx, q, "SELECT 1", nil, &dest)
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, mockErr) {
t.Errorf("expected query error, got %v", err)
}
}
func TestScanSingleRow_NoRows(t *testing.T) {
// This test requires a real database connection to create sql.Rows.
// scanSingleRow is designed to work with any sqlQueryer, but creating
// a mock sql.Rows without a real driver is complex.
// We test the behavior through integration with the test database.
db := openTestDB(t)
ctx := context.Background()
// Use the raw sql.DB from gorm
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("get sql.DB failed: %v", err)
}
var dest int
err = scanSingleRow(ctx, sqlDB, "SELECT 1 WHERE 1=0", nil, &dest)
if err == nil {
t.Fatal("expected error for no rows, got nil")
}
if !errors.Is(err, sql.ErrNoRows) {
t.Errorf("expected sql.ErrNoRows, got %v", err)
}
}
func TestScanSingleRow_Success(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("get sql.DB failed: %v", err)
}
var dest int
err = scanSingleRow(ctx, sqlDB, "SELECT 42", nil, &dest)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if dest != 42 {
t.Errorf("expected 42, got %d", dest)
}
}
func TestScanSingleRow_MultipleColumns(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("get sql.DB failed: %v", err)
}
var a, b int
err = scanSingleRow(ctx, sqlDB, "SELECT 1, 2", nil, &a, &b)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if a != 1 {
t.Errorf("expected a=1, got %d", a)
}
if b != 2 {
t.Errorf("expected b=2, got %d", b)
}
}
func TestScanSingleRow_StringResult(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("get sql.DB failed: %v", err)
}
var dest string
err = scanSingleRow(ctx, sqlDB, "SELECT 'hello'", nil, &dest)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if dest != "hello" {
t.Errorf("expected 'hello', got %q", dest)
}
}