test: 补齐 handler/repository/domain 层单元测试
This commit is contained in:
95
internal/repository/pagination_test.go
Normal file
95
internal/repository/pagination_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
func TestPaginationResultFromTotal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
total int64
|
||||
params pagination.PaginationParams
|
||||
wantPages int
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "exact division",
|
||||
total: 100,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||
wantPages: 5,
|
||||
wantTotal: 100,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "with remainder",
|
||||
total: 105,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||
wantPages: 6,
|
||||
wantTotal: 105,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "zero total",
|
||||
total: 0,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||
wantPages: 0,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "single page",
|
||||
total: 5,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||
wantPages: 1,
|
||||
wantTotal: 5,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page 2",
|
||||
total: 50,
|
||||
params: pagination.PaginationParams{Page: 2, PageSize: 20},
|
||||
wantPages: 3,
|
||||
wantTotal: 50,
|
||||
wantPage: 2,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "small page size",
|
||||
total: 10,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 3},
|
||||
wantPages: 4,
|
||||
wantTotal: 10,
|
||||
wantPage: 1,
|
||||
wantPageSize: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := paginationResultFromTotal(tc.total, tc.params)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.Total != tc.wantTotal {
|
||||
t.Errorf("expected total %d, got %d", tc.wantTotal, result.Total)
|
||||
}
|
||||
if result.Page != tc.wantPage {
|
||||
t.Errorf("expected page %d, got %d", tc.wantPage, result.Page)
|
||||
}
|
||||
if result.PageSize != tc.wantPageSize {
|
||||
t.Errorf("expected page_size %d, got %d", tc.wantPageSize, result.PageSize)
|
||||
}
|
||||
if result.Pages != tc.wantPages {
|
||||
t.Errorf("expected pages %d, got %d", tc.wantPages, result.Pages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
224
internal/repository/password_history_test.go
Normal file
224
internal/repository/password_history_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func TestPasswordHistoryRepository_Create(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
history := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash1",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := repo.Create(ctx, history); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
if history.ID == 0 {
|
||||
t.Error("expected ID to be set after create")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_GetByUserID(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple records for user 1
|
||||
for i := 0; i < 5; i++ {
|
||||
h := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash",
|
||||
CreatedAt: time.Now().Add(time.Duration(i) * time.Second),
|
||||
}
|
||||
if err := repo.Create(ctx, h); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create record for user 2
|
||||
if err := repo.Create(ctx, &domain.PasswordHistory{UserID: 2, PasswordHash: "hash", CreatedAt: time.Now()}); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
limit int
|
||||
wantLen int
|
||||
wantUser int64
|
||||
}{
|
||||
{"get all for user 1", 1, 10, 5, 1},
|
||||
{"limit 3 for user 1", 1, 3, 3, 1},
|
||||
{"get for user 2", 2, 10, 1, 2},
|
||||
{"get for nonexistent user", 999, 10, 0, 999},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
histories, err := repo.GetByUserID(ctx, tc.userID, tc.limit)
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if len(histories) != tc.wantLen {
|
||||
t.Errorf("expected %d histories, got %d", tc.wantLen, len(histories))
|
||||
}
|
||||
for _, h := range histories {
|
||||
if h.UserID != tc.wantUser {
|
||||
t.Errorf("expected user_id %d, got %d", tc.wantUser, h.UserID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_GetByUserID_Order(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create records with different timestamps
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
h := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash",
|
||||
CreatedAt: now.Add(time.Duration(i) * time.Hour),
|
||||
}
|
||||
if err := repo.Create(ctx, h); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
histories, err := repo.GetByUserID(ctx, 1, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if len(histories) != 3 {
|
||||
t.Fatalf("expected 3 histories, got %d", len(histories))
|
||||
}
|
||||
|
||||
// Should be ordered by created_at DESC (newest first)
|
||||
for i := 0; i < len(histories)-1; i++ {
|
||||
if !histories[i].CreatedAt.After(histories[i+1].CreatedAt) && !histories[i].CreatedAt.Equal(histories[i+1].CreatedAt) {
|
||||
t.Errorf("expected descending order, got %v before %v", histories[i].CreatedAt, histories[i+1].CreatedAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_DeleteOldRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 5 records for user 1
|
||||
now := time.Now()
|
||||
for i := 0; i < 5; i++ {
|
||||
h := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash",
|
||||
CreatedAt: now.Add(time.Duration(i) * time.Hour),
|
||||
}
|
||||
if err := repo.Create(ctx, h); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete old records, keep only 3
|
||||
if err := repo.DeleteOldRecords(ctx, 1, 3); err != nil {
|
||||
t.Fatalf("delete old records failed: %v", err)
|
||||
}
|
||||
|
||||
histories, err := repo.GetByUserID(ctx, 1, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if len(histories) != 3 {
|
||||
t.Errorf("expected 3 histories after cleanup, got %d", len(histories))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_DeleteOldRecords_NoRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Should not error when no records exist
|
||||
if err := repo.DeleteOldRecords(ctx, 999, 3); err != nil {
|
||||
t.Fatalf("delete old records on empty table should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_KeepsNewestRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 5 records with different timestamps
|
||||
now := time.Now()
|
||||
var createdIDs []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
h := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash",
|
||||
CreatedAt: now.Add(time.Duration(i) * time.Hour),
|
||||
}
|
||||
if err := repo.Create(ctx, h); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
createdIDs = append(createdIDs, h.ID)
|
||||
}
|
||||
|
||||
// Delete old records, keep only 2
|
||||
if err := repo.DeleteOldRecords(ctx, 1, 2); err != nil {
|
||||
t.Fatalf("delete old records failed: %v", err)
|
||||
}
|
||||
|
||||
histories, err := repo.GetByUserID(ctx, 1, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if len(histories) != 2 {
|
||||
t.Fatalf("expected 2 histories after cleanup, got %d", len(histories))
|
||||
}
|
||||
|
||||
// The remaining records should be the newest (last 2 created)
|
||||
expectedIDs := map[int64]bool{createdIDs[3]: true, createdIDs[4]: true}
|
||||
for _, h := range histories {
|
||||
if !expectedIDs[h.ID] {
|
||||
t.Errorf("expected remaining IDs to be %v, got %d", expectedIDs, h.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
117
internal/repository/sql_scan_test.go
Normal file
117
internal/repository/sql_scan_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockQueryer implements sqlQueryer for testing
|
||||
type mockQueryer struct {
|
||||
rows *sql.Rows
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockQueryer) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
return m.rows, m.err
|
||||
}
|
||||
|
||||
func TestScanSingleRow_QueryError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockErr := errors.New("query failed")
|
||||
q := &mockQueryer{err: mockErr}
|
||||
|
||||
var dest int
|
||||
err := scanSingleRow(ctx, q, "SELECT 1", nil, &dest)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, mockErr) {
|
||||
t.Errorf("expected query error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSingleRow_NoRows(t *testing.T) {
|
||||
// This test requires a real database connection to create sql.Rows.
|
||||
// scanSingleRow is designed to work with any sqlQueryer, but creating
|
||||
// a mock sql.Rows without a real driver is complex.
|
||||
// We test the behavior through integration with the test database.
|
||||
db := openTestDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Use the raw sql.DB from gorm
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("get sql.DB failed: %v", err)
|
||||
}
|
||||
|
||||
var dest int
|
||||
err = scanSingleRow(ctx, sqlDB, "SELECT 1 WHERE 1=0", nil, &dest)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for no rows, got nil")
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Errorf("expected sql.ErrNoRows, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSingleRow_Success(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("get sql.DB failed: %v", err)
|
||||
}
|
||||
|
||||
var dest int
|
||||
err = scanSingleRow(ctx, sqlDB, "SELECT 42", nil, &dest)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if dest != 42 {
|
||||
t.Errorf("expected 42, got %d", dest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSingleRow_MultipleColumns(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("get sql.DB failed: %v", err)
|
||||
}
|
||||
|
||||
var a, b int
|
||||
err = scanSingleRow(ctx, sqlDB, "SELECT 1, 2", nil, &a, &b)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if a != 1 {
|
||||
t.Errorf("expected a=1, got %d", a)
|
||||
}
|
||||
if b != 2 {
|
||||
t.Errorf("expected b=2, got %d", b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSingleRow_StringResult(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("get sql.DB failed: %v", err)
|
||||
}
|
||||
|
||||
var dest string
|
||||
err = scanSingleRow(ctx, sqlDB, "SELECT 'hello'", nil, &dest)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if dest != "hello" {
|
||||
t.Errorf("expected 'hello', got %q", dest)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user