package repository import ( "context" "testing" "time" "gorm.io/gorm" "github.com/user-management-system/internal/domain" "github.com/user-management-system/internal/pagination" ) func setupTestDB(t *testing.T) *gorm.DB { return openTestDB(t) } // TestUserRepository_Create 测试创建用户 func TestUserRepository_Create(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "testuser", Email: domain.StrPtr("test@example.com"), Phone: domain.StrPtr("13800138000"), Password: "hashedpassword", Status: domain.UserStatusActive, } if err := repo.Create(ctx, user); err != nil { t.Fatalf("Create() error = %v", err) } if user.ID == 0 { t.Error("创建后用户ID不应为0") } } // TestUserRepository_GetByUsername 测试根据用户名查询 func TestUserRepository_GetByUsername(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "findme", Email: domain.StrPtr("findme@example.com"), Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user) found, err := repo.GetByUsername(ctx, "findme") if err != nil { t.Fatalf("GetByUsername() error = %v", err) } if found.Username != "findme" { t.Errorf("Username = %v, want findme", found.Username) } _, err = repo.GetByUsername(ctx, "notexist") if err == nil { t.Error("查找不存在的用户应返回错误") } } // TestUserRepository_GetByEmail 测试根据邮箱查询 func TestUserRepository_GetByEmail(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "emailuser", Email: domain.StrPtr("email@example.com"), Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user) found, err := repo.GetByEmail(ctx, "email@example.com") if err != nil { t.Fatalf("GetByEmail() error = %v", err) } if domain.DerefStr(found.Email) != "email@example.com" { t.Errorf("Email = %v, want email@example.com", domain.DerefStr(found.Email)) } } // TestUserRepository_Update 测试更新用户 func TestUserRepository_Update(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "updateme", Email: domain.StrPtr("update@example.com"), Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user) user.Nickname = "已更新" if err := repo.Update(ctx, user); err != nil { t.Fatalf("Update() error = %v", err) } found, _ := repo.GetByID(ctx, user.ID) if found.Nickname != "已更新" { t.Errorf("Nickname = %v, want 已更新", found.Nickname) } } // TestUserRepository_Delete 测试删除用户 func TestUserRepository_Delete(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "deleteme", Email: domain.StrPtr("delete@example.com"), Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user) if err := repo.Delete(ctx, user.ID); err != nil { t.Fatalf("Delete() error = %v", err) } _, err := repo.GetByID(ctx, user.ID) if err == nil { t.Error("删除后查询应返回错误") } } // TestUserRepository_ExistsBy 测试存在性检查 func TestUserRepository_ExistsBy(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "existsuser", Email: domain.StrPtr("exists@example.com"), Phone: domain.StrPtr("13900139000"), Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user) ok, _ := repo.ExistsByUsername(ctx, "existsuser") if !ok { t.Error("ExistsByUsername 应返回 true") } ok, _ = repo.ExistsByEmail(ctx, "exists@example.com") if !ok { t.Error("ExistsByEmail 应返回 true") } ok, _ = repo.ExistsByPhone(ctx, "13900139000") if !ok { t.Error("ExistsByPhone 应返回 true") } ok, _ = repo.ExistsByUsername(ctx, "notexist") if ok { t.Error("不存在的用户 ExistsByUsername 应返回 false") } } // TestUserRepository_List 测试列表查询 func TestUserRepository_List(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() for i := 0; i < 5; i++ { repo.Create(ctx, &domain.User{ Username: "listuser" + string(rune('0'+i)), Password: "hash", Status: domain.UserStatusActive, }) } users, total, err := repo.List(ctx, 0, 10) if err != nil { t.Fatalf("List() error = %v", err) } if len(users) != 5 { t.Errorf("len(users) = %d, want 5", len(users)) } if total != 5 { t.Errorf("total = %d, want 5", total) } } // TestUserRepository_GetByPhone tests phone lookup func TestUserRepository_GetByPhone(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "phoneuser", Email: domain.StrPtr("phone@example.com"), Phone: domain.StrPtr("13700137000"), Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user) found, err := repo.GetByPhone(ctx, "13700137000") if err != nil { t.Fatalf("GetByPhone() error = %v", err) } if found.Username != "phoneuser" { t.Errorf("Username = %v, want phoneuser", found.Username) } } // TestUserRepository_ListByStatus tests status filtering func TestUserRepository_ListByStatus(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() repo.Create(ctx, &domain.User{ Username: "active1", Password: "hash", Status: domain.UserStatusActive, }) repo.Create(ctx, &domain.User{ Username: "active2", Password: "hash", Status: domain.UserStatusActive, }) repo.Create(ctx, &domain.User{ Username: "inactive1", Password: "hash", Status: domain.UserStatusInactive, }) users, total, err := repo.ListByStatus(ctx, domain.UserStatusActive, 0, 10) if err != nil { t.Fatalf("ListByStatus() error = %v", err) } if len(users) != 2 { t.Errorf("len(users) = %d, want 2", len(users)) } if total != 2 { t.Errorf("total = %d, want 2", total) } } // TestUserRepository_UpdateStatus tests status update func TestUserRepository_UpdateStatus(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "statususer", Email: domain.StrPtr("status@example.com"), Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user) err := repo.UpdateStatus(ctx, user.ID, domain.UserStatusInactive) if err != nil { t.Fatalf("UpdateStatus() error = %v", err) } found, _ := repo.GetByID(ctx, user.ID) if found.Status != domain.UserStatusInactive { t.Errorf("Status = %v, want Inactive", found.Status) } } // TestUserRepository_BatchUpdateStatus tests batch status update func TestUserRepository_BatchUpdateStatus(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user1 := &domain.User{ Username: "batch1", Email: domain.StrPtr("batch1@example.com"), Password: "hash", Status: domain.UserStatusActive, } user2 := &domain.User{ Username: "batch2", Email: domain.StrPtr("batch2@example.com"), Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user1) repo.Create(ctx, user2) err := repo.BatchUpdateStatus(ctx, []int64{user1.ID, user2.ID}, domain.UserStatusInactive) if err != nil { t.Fatalf("BatchUpdateStatus() error = %v", err) } found1, _ := repo.GetByID(ctx, user1.ID) found2, _ := repo.GetByID(ctx, user2.ID) if found1.Status != domain.UserStatusInactive || found2.Status != domain.UserStatusInactive { t.Error("BatchUpdateStatus failed") } } // TestUserRepository_BatchDelete tests batch delete func TestUserRepository_BatchDelete(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user1 := &domain.User{ Username: "del1", Email: domain.StrPtr("del1@example.com"), Password: "hash", Status: domain.UserStatusActive, } user2 := &domain.User{ Username: "del2", Email: domain.StrPtr("del2@example.com"), Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user1) repo.Create(ctx, user2) err := repo.BatchDelete(ctx, []int64{user1.ID, user2.ID}) if err != nil { t.Fatalf("BatchDelete() error = %v", err) } _, err1 := repo.GetByID(ctx, user1.ID) _, err2 := repo.GetByID(ctx, user2.ID) if err1 == nil || err2 == nil { t.Error("BatchDelete should have deleted users") } } // TestUserRepository_Search tests user search func TestUserRepository_Search(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() repo.Create(ctx, &domain.User{ Username: "searchuser1", Nickname: "张三", Email: domain.StrPtr("zhangsan@example.com"), Password: "hash", Status: domain.UserStatusActive, }) repo.Create(ctx, &domain.User{ Username: "searchuser2", Nickname: "李四", Email: domain.StrPtr("lisi@example.com"), Password: "hash", Status: domain.UserStatusActive, }) users, total, err := repo.Search(ctx, "zhang", 0, 10) if err != nil { t.Fatalf("Search() error = %v", err) } if len(users) != 1 { t.Errorf("len(users) = %d, want 1", len(users)) } if total != 1 { t.Errorf("total = %d, want 1", total) } } // TestUserRepository_Search_LikePattern tests search with LIKE special chars func TestUserRepository_Search_LikePattern(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() repo.Create(ctx, &domain.User{ Username: "user%with%percent", Nickname: "测试用户", Email: domain.StrPtr("percent@example.com"), Password: "hash", Status: domain.UserStatusActive, }) // Search should handle LIKE special chars safely users, _, err := repo.Search(ctx, "%", 0, 10) if err != nil { t.Fatalf("Search() error = %v", err) } // Should not error and should escape properly _ = users } // TestUserRepository_GetByIDs 测试批量获取用户 func TestUserRepository_GetByIDs(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() u1 := &domain.User{Username: "batchuser1", Password: "hash", Status: domain.UserStatusActive} u2 := &domain.User{Username: "batchuser2", Password: "hash", Status: domain.UserStatusActive} u3 := &domain.User{Username: "batchuser3", Password: "hash", Status: domain.UserStatusActive} repo.Create(ctx, u1) repo.Create(ctx, u2) repo.Create(ctx, u3) users, err := repo.GetByIDs(ctx, []int64{u1.ID, u3.ID}) if err != nil { t.Fatalf("GetByIDs() error = %v", err) } if len(users) != 2 { t.Errorf("len(users) = %d, want 2", len(users)) } } // TestUserRepository_GetByIDs_Empty 测试空ID列表 func TestUserRepository_GetByIDs_Empty(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() users, err := repo.GetByIDs(ctx, []int64{}) if err != nil { t.Fatalf("GetByIDs() error = %v", err) } if len(users) != 0 { t.Errorf("len(users) = %d, want 0", len(users)) } } // TestUserRepository_UpdatePassword 测试更新密码 func TestUserRepository_UpdatePassword(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "pwduser", Password: "oldpassword", Status: domain.UserStatusActive, } repo.Create(ctx, user) err := repo.UpdatePassword(ctx, user.ID, "newpasswordhash") if err != nil { t.Fatalf("UpdatePassword() error = %v", err) } found, _ := repo.GetByID(ctx, user.ID) if found.Password != "newpasswordhash" { t.Errorf("Password = %v, want newpasswordhash", found.Password) } } // TestUserRepository_UpdateTOTP 测试更新TOTP func TestUserRepository_UpdateTOTP(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "totpuser", Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user) user.TOTPEnabled = true user.TOTPSecret = "JBSWY3DPEHPK3PXP" err := repo.UpdateTOTP(ctx, user) if err != nil { t.Fatalf("UpdateTOTP() error = %v", err) } found, _ := repo.GetByID(ctx, user.ID) if !found.TOTPEnabled { t.Error("TOTPEnabled should be true") } if found.TOTPSecret != "JBSWY3DPEHPK3PXP" { t.Errorf("TOTPSecret = %v, want JBSWY3DPEHPK3PXP", found.TOTPSecret) } } // TestUserRepository_ListCreatedAfter 测试查询创建时间之后的用户 func TestUserRepository_ListCreatedAfter(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() user := &domain.User{ Username: "afteruser", Password: "hash", Status: domain.UserStatusActive, } repo.Create(ctx, user) since := user.CreatedAt.Add(-1 * time.Hour) users, total, err := repo.ListCreatedAfter(ctx, since, 0, 10) if err != nil { t.Fatalf("ListCreatedAfter() error = %v", err) } if total < 1 { t.Errorf("total = %d, want at least 1", total) } _ = users } // TestUserRepository_ListCreatedAfter_Limited 测试带limit的查询 func TestUserRepository_ListCreatedAfter_Limited(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() for i := 0; i < 5; i++ { repo.Create(ctx, &domain.User{ Username: "limituser" + string(rune('0'+i)), Password: "hash", Status: domain.UserStatusActive, }) } since := time.Now().Add(-1 * time.Hour) users, total, err := repo.ListCreatedAfter(ctx, since, 0, 3) if err != nil { t.Fatalf("ListCreatedAfter() error = %v", err) } if len(users) != 3 { t.Errorf("len(users) = %d, want 3", len(users)) } if total < 5 { t.Errorf("total = %d, want at least 5", total) } } // TestUserRepository_AdvancedSearch 测试高级搜索 func TestUserRepository_AdvancedSearch(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() repo.Create(ctx, &domain.User{ Username: "searchuser1", Nickname: "张三", Email: domain.StrPtr("zhangsan@example.com"), Password: "hash", Status: domain.UserStatusActive, }) repo.Create(ctx, &domain.User{ Username: "searchuser2", Nickname: "李四", Email: domain.StrPtr("lisi@example.com"), Password: "hash", Status: domain.UserStatusActive, }) repo.Create(ctx, &domain.User{ Username: "searchuser3", Nickname: "王五", Email: domain.StrPtr("wangwu@example.com"), Password: "hash", Status: domain.UserStatusInactive, }) // 按关键字搜索(Status=-1 表示全部状态) filter := &AdvancedFilter{Keyword: "searchuser1", Status: -1, Offset: 0, Limit: 10} users, total, err := repo.AdvancedSearch(ctx, filter) if err != nil { t.Fatalf("AdvancedSearch() error = %v", err) } if len(users) != 1 { t.Errorf("len(users) = %d, want 1", len(users)) } if total != 1 { t.Errorf("total = %d, want 1", total) } // 按状态筛选 filter2 := &AdvancedFilter{Status: int(domain.UserStatusActive), Offset: 0, Limit: 10} users2, total2, err := repo.AdvancedSearch(ctx, filter2) if err != nil { t.Fatalf("AdvancedSearch() error = %v", err) } if len(users2) != 2 { t.Errorf("len(users2) = %d, want 2", len(users2)) } if total2 != 2 { t.Errorf("total2 = %d, want 2", total2) } // 按状态筛选 - 禁用用户 filter3 := &AdvancedFilter{Status: int(domain.UserStatusInactive), Offset: 0, Limit: 10} users3, total3, err := repo.AdvancedSearch(ctx, filter3) if err != nil { t.Fatalf("AdvancedSearch() error = %v", err) } if len(users3) != 1 { t.Errorf("len(users3) = %d, want 1", len(users3)) } if total3 != 1 { t.Errorf("total3 = %d, want 1", total3) } } // TestUserRepository_AdvancedSearch_AllStatus 测试状态为-1返回全部 func TestUserRepository_AdvancedSearch_AllStatus(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() repo.Create(ctx, &domain.User{Username: "active", Password: "hash", Status: domain.UserStatusActive}) repo.Create(ctx, &domain.User{Username: "inactive", Password: "hash", Status: domain.UserStatusInactive}) filter := &AdvancedFilter{Status: -1, Offset: 0, Limit: 10} users, total, err := repo.AdvancedSearch(ctx, filter) if err != nil { t.Fatalf("AdvancedSearch() error = %v", err) } if len(users) != 2 { t.Errorf("len(users) = %d, want 2", len(users)) } if total != 2 { t.Errorf("total = %d, want 2", total) } } // TestUserRepository_AdvancedSearch_LikeSpecialChars 测试搜索LIKE特殊字符转义 func TestUserRepository_AdvancedSearch_LikeSpecialChars(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() repo.Create(ctx, &domain.User{ Username: "user%with%percent", Nickname: "测试用户", Password: "hash", Status: domain.UserStatusActive, }) // 搜索%应该不匹配任何记录(被转义) filter := &AdvancedFilter{Keyword: "%", Offset: 0, Limit: 10} users, _, err := repo.AdvancedSearch(ctx, filter) if err != nil { t.Fatalf("AdvancedSearch() error = %v", err) } if len(users) != 0 { t.Errorf("len(users) = %d, want 0 for escaped percent", len(users)) } } // TestUserRepository_ListCursor 测试用户游标分页查询 func TestUserRepository_ListCursor(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() // 创建多个用户 for i := 0; i < 5; i++ { repo.Create(ctx, &domain.User{ Username: "cursoruser" + string(rune('a'+i)), Password: "hash", Status: domain.UserStatusActive, }) } // 第一次查询 filter := &AdvancedFilter{Status: int(domain.UserStatusActive), Offset: 0, Limit: 3} users, hasMore, err := repo.ListCursor(ctx, filter, 3, nil) if err != nil { t.Fatalf("ListCursor() error = %v", err) } if len(users) != 3 { t.Errorf("len(users) = %d, want 3", len(users)) } if !hasMore { t.Error("hasMore should be true when more users exist") } // 使用游标继续查询 lastUser := users[len(users)-1] cursor := &pagination.Cursor{ LastID: lastUser.ID, LastValue: lastUser.CreatedAt, } users2, hasMore2, err := repo.ListCursor(ctx, filter, 3, cursor) if err != nil { t.Fatalf("ListCursor() error = %v", err) } if len(users2) != 2 { t.Errorf("len(users2) = %d, want 2", len(users2)) } if hasMore2 { t.Error("hasMore2 should be false") } } // TestUserRepository_ListCursor_WithKeyword 测试带关键字过滤的游标分页 func TestUserRepository_ListCursor_WithKeyword(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() repo.Create(ctx, &domain.User{ Username: "keyworduser1", Nickname: "张三", Password: "hash", Status: domain.UserStatusActive, }) repo.Create(ctx, &domain.User{ Username: "keyworduser2", Nickname: "李四", Password: "hash", Status: domain.UserStatusActive, }) repo.Create(ctx, &domain.User{ Username: "otheruser", Nickname: "王五", Password: "hash", Status: domain.UserStatusActive, }) filter := &AdvancedFilter{Keyword: "keyword", Status: -1, Offset: 0, Limit: 10} users, _, err := repo.ListCursor(ctx, filter, 10, nil) if err != nil { t.Fatalf("ListCursor() error = %v", err) } if len(users) != 2 { t.Errorf("len(users) = %d, want 2", len(users)) } } // TestUserRepository_ListCursor_WithCreatedRange 测试带创建时间范围的游标分页 func TestUserRepository_ListCursor_WithCreatedRange(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() repo.Create(ctx, &domain.User{ Username: "timeuser1", Password: "hash", Status: domain.UserStatusActive, }) repo.Create(ctx, &domain.User{ Username: "timeuser2", Password: "hash", Status: domain.UserStatusActive, }) now := time.Now() filter := &AdvancedFilter{ Status: -1, CreatedFrom: func() *time.Time { t := now.Add(-time.Hour); return &t }(), CreatedTo: func() *time.Time { return &now }(), Offset: 0, Limit: 10, } users, _, err := repo.ListCursor(ctx, filter, 10, nil) if err != nil { t.Fatalf("ListCursor() error = %v", err) } if len(users) != 2 { t.Errorf("len(users) = %d, want 2", len(users)) } } // TestUserRepository_ListCursor_WithRoleIDs 测试带角色过滤的游标分页 func TestUserRepository_ListCursor_WithRoleIDs(t *testing.T) { db := setupTestDB(t) repo := NewUserRepository(db) ctx := context.Background() // 创建用户 user1 := &domain.User{Username: "roleuser1", Password: "hash", Status: domain.UserStatusActive} user2 := &domain.User{Username: "roleuser2", Password: "hash", Status: domain.UserStatusActive} repo.Create(ctx, user1) repo.Create(ctx, user2) // 创建角色 role := &domain.Role{Code: "testrole", Name: "测试角色", Status: domain.RoleStatusEnabled} db.WithContext(ctx).Create(role) // 分配角色给user1 urRepo := NewUserRoleRepository(db) urRepo.Create(ctx, &domain.UserRole{UserID: user1.ID, RoleID: role.ID}) filter := &AdvancedFilter{RoleIDs: []int64{role.ID}, Status: -1, Offset: 0, Limit: 10} users, _, err := repo.ListCursor(ctx, filter, 10, nil) if err != nil { t.Fatalf("ListCursor() error = %v", err) } if len(users) != 1 { t.Errorf("len(users) = %d, want 1", len(users)) } if users[0].Username != "roleuser1" { t.Errorf("users[0].Username = %s, want roleuser1", users[0].Username) } }