// Source: user-system/internal/repository/device_repository_test.go
// (Go, 576 lines, 16 KiB; header captured from a repository web view)

package repository
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
)
// deviceTestCounter gives every test database a unique name.
var deviceTestCounter int64

// openDeviceTestDB opens an isolated in-memory SQLite database for a single
// test and migrates the domain.Device schema into it.
//
// Each call uses a unique database name (derived from deviceTestCounter), so
// tests never share state. The database is closed automatically when the test
// finishes.
func openDeviceTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	id := atomic.AddInt64(&deviceTestCounter, 1)
	dsn := fmt.Sprintf("file:devtestdb%d?mode=memory&cache=private", id)
	db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
		DriverName: "sqlite",
		DSN:        dsn,
	}), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Fatalf("打开测试数据库失败: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get underlying sql.DB: %v", err)
	}
	// With cache=private, every new pooled connection would open its own,
	// empty in-memory database (without the migrated tables). Keep exactly one
	// connection so all queries see the same database.
	sqlDB.SetMaxOpenConns(1)
	t.Cleanup(func() {
		if err := sqlDB.Close(); err != nil {
			t.Errorf("close test database: %v", err)
		}
	})
	if err := db.AutoMigrate(&domain.Device{}); err != nil {
		t.Fatalf("数据库迁移失败: %v", err)
	}
	return db
}
// setupDeviceTestDB is a compatibility alias for openDeviceTestDB.
// It is marked as a helper so failures are reported at the calling test.
func setupDeviceTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	return openDeviceTestDB(t)
}
// TestDeviceRepository_Create 测试创建设备
func TestDeviceRepository_Create(t *testing.T) {
db := setupDeviceTestDB(t)
repo := NewDeviceRepository(db)
ctx := context.Background()
device := &domain.Device{
UserID: 1,
DeviceID: "test-device-001",
DeviceName: "测试手机",
DeviceType: domain.DeviceTypeMobile,
Status: domain.DeviceStatusActive,
}
if err := repo.Create(ctx, device); err != nil {
t.Fatalf("Create() error = %v", err)
}
if device.ID == 0 {
t.Error("创建后设备ID不应为0")
}
}
// TestDeviceRepository_GetByID checks that a created device can be read back
// by its primary key.
func TestDeviceRepository_GetByID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "test-device-002",
		DeviceName: "测试平板",
		Status:     domain.DeviceStatusActive,
	}
	// Fail fast on setup errors; otherwise device.ID stays 0 and the lookup
	// below fails with a misleading message.
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.ID != device.ID {
		t.Errorf("ID = %v, want %v", found.ID, device.ID)
	}
	if found.DeviceID != "test-device-002" {
		t.Errorf("DeviceID = %v, want test-device-002", found.DeviceID)
	}
}
// TestDeviceRepository_GetByDeviceID checks lookup by (user ID, device
// identifier).
func TestDeviceRepository_GetByDeviceID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "unique-device-id",
		DeviceName: "测试设备",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	found, err := repo.GetByDeviceID(ctx, 1, "unique-device-id")
	if err != nil {
		t.Fatalf("GetByDeviceID() error = %v", err)
	}
	if found.UserID != 1 {
		t.Errorf("UserID = %v, want 1", found.UserID)
	}
	if found.DeviceID != "unique-device-id" {
		t.Errorf("DeviceID = %v, want unique-device-id", found.DeviceID)
	}
}
// TestDeviceRepository_Update checks that Update persists a changed device
// name.
func TestDeviceRepository_Update(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "update-test",
		DeviceName: "旧名称",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	device.DeviceName = "新名称"
	if err := repo.Update(ctx, device); err != nil {
		t.Fatalf("Update() error = %v", err)
	}
	// Check the error so a failed read aborts instead of dereferencing nil.
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.DeviceName != "新名称" {
		t.Errorf("DeviceName = %v, want 新名称", found.DeviceName)
	}
}
// TestDeviceRepository_Delete checks that a deleted device can no longer be
// fetched by ID.
func TestDeviceRepository_Delete(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "delete-test",
		DeviceName: "待删除",
		Status:     domain.DeviceStatusActive,
	}
	// If Create silently failed, ID would be 0 and the "lookup fails after
	// delete" assertion below would pass without testing anything.
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if device.ID == 0 {
		t.Fatal("Create() left device ID at 0")
	}
	if err := repo.Delete(ctx, device.ID); err != nil {
		t.Fatalf("Delete() error = %v", err)
	}
	_, err := repo.GetByID(ctx, device.ID)
	if err == nil {
		t.Error("删除后查询应返回错误")
	}
}
// TestDeviceRepository_List checks unfiltered listing returns every device
// along with the correct total count.
func TestDeviceRepository_List(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	for i := 0; i < 3; i++ {
		d := &domain.Device{
			UserID:   int64(i + 1),
			DeviceID: "list-device-" + string(rune('a'+i)),
			Status:   domain.DeviceStatusActive,
		}
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
		}
	}
	devices, total, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if len(devices) != 3 {
		t.Errorf("len(devices) = %d, want 3", len(devices))
	}
	if total != 3 {
		t.Errorf("total = %d, want 3", total)
	}
}
// TestDeviceRepository_ListByUserID checks that listing by user returns only
// that user's devices.
func TestDeviceRepository_ListByUserID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	for _, d := range []*domain.Device{
		{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive},
		{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive},
		{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive},
	} {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
		}
	}
	devices, total, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID() error = %v", err)
	}
	if len(devices) != 2 {
		t.Errorf("len(devices) = %d, want 2", len(devices))
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
}
// TestDeviceRepository_ListByStatus checks that listing by status returns only
// devices in that status.
func TestDeviceRepository_ListByStatus(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	for _, d := range []*domain.Device{
		{UserID: 1, DeviceID: "active1", Status: domain.DeviceStatusActive},
		{UserID: 2, DeviceID: "active2", Status: domain.DeviceStatusActive},
		{UserID: 3, DeviceID: "inactive1", Status: domain.DeviceStatusInactive},
	} {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
		}
	}
	devices, total, err := repo.ListByStatus(ctx, domain.DeviceStatusActive, 0, 10)
	if err != nil {
		t.Fatalf("ListByStatus() error = %v", err)
	}
	if len(devices) != 2 {
		t.Errorf("len(devices) = %d, want 2", len(devices))
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
}
// TestDeviceRepository_UpdateStatus checks that UpdateStatus persists the new
// status.
func TestDeviceRepository_UpdateStatus(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "status-test",
		DeviceName: "状态测试",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	err := repo.UpdateStatus(ctx, device.ID, domain.DeviceStatusInactive)
	if err != nil {
		t.Fatalf("UpdateStatus() error = %v", err)
	}
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.Status != domain.DeviceStatusInactive {
		t.Errorf("Status = %v, want Inactive", found.Status)
	}
}
// TestDeviceRepository_Exists checks Exists for both a present and an absent
// device identifier.
func TestDeviceRepository_Exists(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "exists-test",
		DeviceName: "存在性测试",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	exists, err := repo.Exists(ctx, 1, "exists-test")
	if err != nil {
		t.Fatalf("Exists() error = %v", err)
	}
	if !exists {
		t.Error("Exists 应返回 true")
	}
	// A query error would also yield false, so check it to keep the negative
	// case meaningful.
	exists, err = repo.Exists(ctx, 1, "not-exists")
	if err != nil {
		t.Fatalf("Exists(not-exists) error = %v", err)
	}
	if exists {
		t.Error("不存在的设备 Exists 应返回 false")
	}
}
// TestDeviceRepository_DeleteByUserID checks that deleting by user removes that
// user's devices and leaves other users' devices intact.
func TestDeviceRepository_DeleteByUserID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	for _, d := range []*domain.Device{
		{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive},
		{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive},
		{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive},
	} {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
		}
	}
	if err := repo.DeleteByUserID(ctx, 1); err != nil {
		t.Fatalf("DeleteByUserID() error = %v", err)
	}
	// A failed query would return no devices and mask a broken delete.
	devices, _, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID(1) error = %v", err)
	}
	if len(devices) != 0 {
		t.Errorf("用户1设备数 = %d, want 0", len(devices))
	}
	// User 2's devices must still be there.
	devices, _, err = repo.ListByUserID(ctx, 2, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID(2) error = %v", err)
	}
	if len(devices) != 1 {
		t.Errorf("用户2设备数 = %d, want 1", len(devices))
	}
}
// TestDeviceRepository_GetActiveDevices checks that GetActiveDevices selects
// devices by recent activity rather than by status.
func TestDeviceRepository_GetActiveDevices(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	now := time.Now()
	// Create devices with LastActiveTime set (GetActiveDevices does not check
	// status, only the most recent activity time).
	for _, d := range []*domain.Device{
		{UserID: 1, DeviceID: "active-dev1", Status: domain.DeviceStatusActive, LastActiveTime: now},
		{UserID: 1, DeviceID: "recent-dev", Status: domain.DeviceStatusInactive, LastActiveTime: now},
	} {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
		}
	}
	devices, err := repo.GetActiveDevices(ctx, 1)
	if err != nil {
		t.Fatalf("GetActiveDevices() error = %v", err)
	}
	// GetActiveDevices only checks last_active_time > (now - 30 days); it does
	// not check status, so the inactive device is returned as well.
	if len(devices) != 2 {
		t.Errorf("len(devices) = %d, want 2", len(devices))
	}
}
// TestDeviceRepository_TrustDevice checks that TrustDevice marks a device as
// trusted.
func TestDeviceRepository_TrustDevice(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "trust-test",
		DeviceName: "信任测试",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	expiresAt := time.Now().Add(30 * 24 * time.Hour)
	if err := repo.TrustDevice(ctx, device.ID, &expiresAt); err != nil {
		t.Fatalf("TrustDevice() error = %v", err)
	}
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if !found.IsTrusted {
		t.Error("IsTrusted 应为 true")
	}
}
// TestDeviceRepository_UntrustDevice checks that UntrustDevice clears the
// trusted flag on a previously trusted device.
func TestDeviceRepository_UntrustDevice(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "untrust-test",
		DeviceName: "取消信任测试",
		IsTrusted:  true,
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	// Verify the precondition; if the trusted flag was never stored, the
	// assertion below would pass without exercising UntrustDevice.
	before, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() before untrust error = %v", err)
	}
	if !before.IsTrusted {
		t.Fatal("precondition: device should be trusted after Create")
	}
	if err := repo.UntrustDevice(ctx, device.ID); err != nil {
		t.Fatalf("UntrustDevice() error = %v", err)
	}
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.IsTrusted {
		t.Error("IsTrusted 应为 false")
	}
}
// TestDeviceRepository_DeleteAllByUserIDExcept checks that all of a user's
// devices are deleted except the one explicitly kept.
func TestDeviceRepository_DeleteAllByUserIDExcept(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	// createDevice aborts the test itself if Create fails.
	d1, err := createDevice(t, repo, ctx, 1, "keep-me")
	if err != nil {
		t.Fatalf("createDevice(keep-me) error = %v", err)
	}
	createDevice(t, repo, ctx, 1, "delete-me1")
	createDevice(t, repo, ctx, 1, "delete-me2")
	if err := repo.DeleteAllByUserIDExcept(ctx, 1, d1.ID); err != nil {
		t.Fatalf("DeleteAllByUserIDExcept() error = %v", err)
	}
	devices, _, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID() error = %v", err)
	}
	// Fatal, not Error: devices[0] is indexed below.
	if len(devices) != 1 {
		t.Fatalf("len(devices) = %d, want 1", len(devices))
	}
	if devices[0].ID != d1.ID {
		t.Error("应保留指定设备")
	}
}
// TestDeviceRepository_GetTrustedDevices checks that only trusted devices are
// returned for a user.
func TestDeviceRepository_GetTrustedDevices(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	trusted := &domain.Device{
		UserID:    1,
		DeviceID:  "trusted-device",
		IsTrusted: true,
		Status:    domain.DeviceStatusActive,
	}
	untrusted := &domain.Device{
		UserID:    1,
		DeviceID:  "untrusted-device",
		IsTrusted: false,
		Status:    domain.DeviceStatusActive,
	}
	for _, d := range []*domain.Device{trusted, untrusted} {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
		}
	}
	devices, err := repo.GetTrustedDevices(ctx, 1)
	if err != nil {
		t.Fatalf("GetTrustedDevices() error = %v", err)
	}
	if len(devices) != 1 {
		t.Fatalf("len(devices) = %d, want 1", len(devices))
	}
	if devices[0].DeviceID != "trusted-device" {
		t.Errorf("DeviceID = %v, want trusted-device", devices[0].DeviceID)
	}
}
// TestDeviceRepository_ListAll checks ListAll's user and status filters
// independently.
func TestDeviceRepository_ListAll(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	for _, d := range []*domain.Device{
		{UserID: 1, DeviceID: "dev1", Status: domain.DeviceStatusActive},
		{UserID: 1, DeviceID: "dev2", Status: domain.DeviceStatusInactive},
		{UserID: 2, DeviceID: "dev3", Status: domain.DeviceStatusActive},
	} {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
		}
	}
	// Filter by user.
	params := &ListDevicesParams{UserID: 1, Offset: 0, Limit: 10}
	_, total, err := repo.ListAll(ctx, params)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
	// Filter by status.
	status := domain.DeviceStatusActive
	params2 := &ListDevicesParams{Status: &status, Offset: 0, Limit: 10}
	_, total2, err := repo.ListAll(ctx, params2)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if total2 != 2 {
		t.Errorf("total2 = %d, want 2", total2)
	}
}
// createDevice is a test helper that creates an active device with the given
// user ID and device identifier.
//
// It aborts the test via t.Fatalf if Create fails, so the returned error is
// always nil; the error result is kept for compatibility with existing callers.
func createDevice(t *testing.T, repo *DeviceRepository, ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
	t.Helper()
	d := &domain.Device{
		UserID:   userID,
		DeviceID: deviceID,
		Status:   domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, d); err != nil {
		t.Fatalf("createDevice() error = %v", err)
	}
	return d, nil
}
// TestDeviceRepository_ListAllCursor checks cursor pagination across two pages:
// a full first page with hasMore=true, then the remainder with hasMore=false.
func TestDeviceRepository_ListAllCursor(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	// Devices need distinct LastActiveTime values to support cursor paging.
	now := time.Now()
	for i := 0; i < 5; i++ {
		d := &domain.Device{
			UserID:         int64(i + 1),
			DeviceID:       "cursor-device-" + string(rune('a'+i)),
			DeviceName:     "设备" + string(rune('0'+i)),
			Status:         domain.DeviceStatusActive,
			LastActiveTime: now.Add(-time.Duration(i) * time.Minute),
		}
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
		}
	}
	// First page: the first 3 devices.
	devices, hasMore, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, nil)
	if err != nil {
		t.Fatalf("ListAllCursor() error = %v", err)
	}
	// Fatal, not Error: the last element is indexed below to build the cursor.
	if len(devices) != 3 {
		t.Fatalf("len(devices) = %d, want 3", len(devices))
	}
	if !hasMore {
		t.Error("hasMore should be true when more devices exist")
	}
	// Continue from the cursor.
	lastDevice := devices[len(devices)-1]
	cursor := &pagination.Cursor{
		LastID:    lastDevice.ID,
		LastValue: lastDevice.LastActiveTime,
	}
	devices2, hasMore2, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, cursor)
	if err != nil {
		t.Fatalf("ListAllCursor() error = %v", err)
	}
	if len(devices2) != 2 {
		t.Errorf("len(devices2) = %d, want 2", len(devices2))
	}
	if hasMore2 {
		t.Error("hasMore2 should be false")
	}
}
// TestDeviceRepository_ListAllCursor_WithFilters checks that cursor pagination
// honours combined user-ID and status filters.
func TestDeviceRepository_ListAllCursor_WithFilters(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	now := time.Now()
	for _, d := range []*domain.Device{
		{
			UserID:         1,
			DeviceID:       "filter-dev1",
			DeviceName:     "用户1设备",
			Status:         domain.DeviceStatusActive,
			LastActiveTime: now,
		},
		{
			UserID:         2,
			DeviceID:       "filter-dev2",
			DeviceName:     "用户2设备",
			Status:         domain.DeviceStatusActive,
			LastActiveTime: now,
		},
		{
			UserID:         1,
			DeviceID:       "filter-dev3",
			DeviceName:     "用户1禁用设备",
			Status:         domain.DeviceStatusInactive,
			LastActiveTime: now,
		},
	} {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
		}
	}
	// Filter by user ID and status together: only user 1's active device matches.
	status := domain.DeviceStatusActive
	devices, _, err := repo.ListAllCursor(ctx, &ListDevicesParams{UserID: 1, Status: &status, Offset: 0, Limit: 10}, 10, nil)
	if err != nil {
		t.Fatalf("ListAllCursor() error = %v", err)
	}
	if len(devices) != 1 {
		t.Fatalf("len(devices) = %d, want 1", len(devices))
	}
	if devices[0].DeviceID != "filter-dev1" {
		t.Errorf("DeviceID = %v, want filter-dev1", devices[0].DeviceID)
	}
}