diff --git a/internal/repository/device_repository_test.go b/internal/repository/device_repository_test.go new file mode 100644 index 0000000..843742b --- /dev/null +++ b/internal/repository/device_repository_test.go @@ -0,0 +1,486 @@ +package repository + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + _ "modernc.org/sqlite" + gormsqlite "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/user-management-system/internal/domain" +) + +var deviceTestCounter int64 + +// openDeviceTestDB 为每个测试打开独立的内存数据库 +func openDeviceTestDB(t *testing.T) *gorm.DB { + t.Helper() + + id := atomic.AddInt64(&deviceTestCounter, 1) + dsn := fmt.Sprintf("file:devtestdb%d?mode=memory&cache=private", id) + + db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{ + DriverName: "sqlite", + DSN: dsn, + }), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("打开测试数据库失败: %v", err) + } + + if err := db.AutoMigrate(&domain.Device{}); err != nil { + t.Fatalf("数据库迁移失败: %v", err) + } + return db +} + +// setupDeviceTestDB 兼容性别名 +func setupDeviceTestDB(t *testing.T) *gorm.DB { + return openDeviceTestDB(t) +} + +// TestDeviceRepository_Create 测试创建设备 +func TestDeviceRepository_Create(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + device := &domain.Device{ + UserID: 1, + DeviceID: "test-device-001", + DeviceName: "测试手机", + DeviceType: domain.DeviceTypeMobile, + Status: domain.DeviceStatusActive, + } + + if err := repo.Create(ctx, device); err != nil { + t.Fatalf("Create() error = %v", err) + } + if device.ID == 0 { + t.Error("创建后设备ID不应为0") + } +} + +// TestDeviceRepository_GetByID 测试根据ID获取设备 +func TestDeviceRepository_GetByID(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + device := &domain.Device{ + UserID: 1, + DeviceID: "test-device-002", + DeviceName: "测试平板", + Status: domain.DeviceStatusActive, + } + repo.Create(ctx, device) + + found, err := repo.GetByID(ctx, device.ID) + if err != nil { + t.Fatalf("GetByID() error = %v", err) + } + if found.DeviceID != "test-device-002" { + t.Errorf("DeviceID = %v, want test-device-002", found.DeviceID) + } +} + +// TestDeviceRepository_GetByDeviceID 测试根据设备标识查询 +func TestDeviceRepository_GetByDeviceID(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + device := &domain.Device{ + UserID: 1, + DeviceID: "unique-device-id", + DeviceName: "测试设备", + Status: domain.DeviceStatusActive, + } + repo.Create(ctx, device) + + found, err := repo.GetByDeviceID(ctx, 1, "unique-device-id") + if err != nil { + t.Fatalf("GetByDeviceID() error = %v", err) + } + if found.UserID != 1 { + t.Errorf("UserID = %v, want 1", found.UserID) + } +} + +// TestDeviceRepository_Update 测试更新设备 +func TestDeviceRepository_Update(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + device := &domain.Device{ + UserID: 1, + DeviceID: "update-test", + DeviceName: "旧名称", + Status: domain.DeviceStatusActive, + } + repo.Create(ctx, device) + + device.DeviceName = "新名称" + if err := repo.Update(ctx, device); err != nil { + t.Fatalf("Update() error = %v", err) + } + + found, _ := repo.GetByID(ctx, device.ID) + if found.DeviceName != "新名称" { + t.Errorf("DeviceName = %v, want 新名称", found.DeviceName) + } +} + +// TestDeviceRepository_Delete 测试删除设备 +func TestDeviceRepository_Delete(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + device := &domain.Device{ + UserID: 1, + DeviceID: "delete-test", + DeviceName: "待删除", + Status: domain.DeviceStatusActive, + } + repo.Create(ctx, device) + + if err := repo.Delete(ctx, device.ID); err != nil { + t.Fatalf("Delete() error = %v", err) + } + + _, err := repo.GetByID(ctx, device.ID) + if err == nil { + t.Error("删除后查询应返回错误") + } +} + +// TestDeviceRepository_List 测试列表查询 +func TestDeviceRepository_List(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + for i := 0; i < 3; i++ { + repo.Create(ctx, &domain.Device{ + UserID: int64(i + 1), + DeviceID: "list-device-" + string(rune('a'+i)), + Status: domain.DeviceStatusActive, + }) + } + + devices, total, err := repo.List(ctx, 0, 10) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(devices) != 3 { + t.Errorf("len(devices) = %d, want 3", len(devices)) + } + if total != 3 { + t.Errorf("total = %d, want 3", total) + } +} + +// TestDeviceRepository_ListByUserID 测试按用户ID查询设备列表 +func TestDeviceRepository_ListByUserID(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive}) + repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive}) + repo.Create(ctx, &domain.Device{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive}) + + devices, total, err := repo.ListByUserID(ctx, 1, 0, 10) + if err != nil { + t.Fatalf("ListByUserID() error = %v", err) + } + if len(devices) != 2 { + t.Errorf("len(devices) = %d, want 2", len(devices)) + } + if total != 2 { + t.Errorf("total = %d, want 2", total) + } +} + +// TestDeviceRepository_ListByStatus 测试按状态查询设备列表 +func TestDeviceRepository_ListByStatus(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "active1", Status: domain.DeviceStatusActive}) + repo.Create(ctx, &domain.Device{UserID: 2, DeviceID: "active2", Status: domain.DeviceStatusActive}) + repo.Create(ctx, &domain.Device{UserID: 3, DeviceID: "inactive1", Status: domain.DeviceStatusInactive}) + + devices, total, err := repo.ListByStatus(ctx, domain.DeviceStatusActive, 0, 10) + if err != nil { + t.Fatalf("ListByStatus() error = %v", err) + } + if len(devices) != 2 { + t.Errorf("len(devices) = %d, want 2", len(devices)) + } + if total != 2 { + t.Errorf("total = %d, want 2", total) + } +} + +// TestDeviceRepository_UpdateStatus 测试更新设备状态 +func TestDeviceRepository_UpdateStatus(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + device := &domain.Device{ + UserID: 1, + DeviceID: "status-test", + DeviceName: "状态测试", + Status: domain.DeviceStatusActive, + } + repo.Create(ctx, device) + + err := repo.UpdateStatus(ctx, device.ID, domain.DeviceStatusInactive) + if err != nil { + t.Fatalf("UpdateStatus() error = %v", err) + } + + found, _ := repo.GetByID(ctx, device.ID) + if found.Status != domain.DeviceStatusInactive { + t.Errorf("Status = %v, want Inactive", found.Status) + } +} + +// TestDeviceRepository_Exists 测试设备存在性检查 +func TestDeviceRepository_Exists(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + device := &domain.Device{ + UserID: 1, + DeviceID: "exists-test", + DeviceName: "存在性测试", + Status: domain.DeviceStatusActive, + } + repo.Create(ctx, device) + + exists, err := repo.Exists(ctx, 1, "exists-test") + if err != nil { + t.Fatalf("Exists() error = %v", err) + } + if !exists { + t.Error("Exists 应返回 true") + } + + exists, _ = repo.Exists(ctx, 1, "not-exists") + if exists { + t.Error("不存在的设备 Exists 应返回 false") + } +} + +// TestDeviceRepository_DeleteByUserID 测试删除用户的所有设备 +func TestDeviceRepository_DeleteByUserID(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive}) + repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive}) + repo.Create(ctx, &domain.Device{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive}) + + err := repo.DeleteByUserID(ctx, 1) + if err != nil { + t.Fatalf("DeleteByUserID() error = %v", err) + } + + devices, _, _ := repo.ListByUserID(ctx, 1, 0, 10) + if len(devices) != 0 { + t.Errorf("用户1设备数 = %d, want 0", len(devices)) + } + + // 用户2的设备应该还在 + devices, _, _ = repo.ListByUserID(ctx, 2, 0, 10) + if len(devices) != 1 { + t.Errorf("用户2设备数 = %d, want 1", len(devices)) + } +} + +// TestDeviceRepository_GetActiveDevices 测试获取活跃设备 +func TestDeviceRepository_GetActiveDevices(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + now := time.Now() + // 创建设备并设置 LastActiveTime(GetActiveDevices 不检查状态,只检查最近活跃时间) + repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "active-dev1", Status: domain.DeviceStatusActive, LastActiveTime: now}) + repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "recent-dev", Status: domain.DeviceStatusInactive, LastActiveTime: now}) + + devices, err := repo.GetActiveDevices(ctx, 1) + if err != nil { + t.Fatalf("GetActiveDevices() error = %v", err) + } + // GetActiveDevices 只检查 last_active_time > 30天前,不检查 status + if len(devices) != 2 { + t.Errorf("len(devices) = %d, want 2", len(devices)) + } +} + +// TestDeviceRepository_TrustDevice 测试设置设备信任 +func TestDeviceRepository_TrustDevice(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + device := &domain.Device{ + UserID: 1, + DeviceID: "trust-test", + DeviceName: "信任测试", + Status: domain.DeviceStatusActive, + } + repo.Create(ctx, device) + + expiresAt := time.Now().Add(30 * 24 * time.Hour) + err := repo.TrustDevice(ctx, device.ID, &expiresAt) + if err != nil { + t.Fatalf("TrustDevice() error = %v", err) + } + + found, _ := repo.GetByID(ctx, device.ID) + if !found.IsTrusted { + t.Error("IsTrusted 应为 true") + } +} + +// TestDeviceRepository_UntrustDevice 测试取消设备信任 +func TestDeviceRepository_UntrustDevice(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + device := &domain.Device{ + UserID: 1, + DeviceID: "untrust-test", + DeviceName: "取消信任测试", + IsTrusted: true, + Status: domain.DeviceStatusActive, + } + repo.Create(ctx, device) + + err := repo.UntrustDevice(ctx, device.ID) + if err != nil { + t.Fatalf("UntrustDevice() error = %v", err) + } + + found, _ := repo.GetByID(ctx, device.ID) + if found.IsTrusted { + t.Error("IsTrusted 应为 false") + } +} + +// TestDeviceRepository_DeleteAllByUserIDExcept 测试删除用户设备(保留指定设备) +func TestDeviceRepository_DeleteAllByUserIDExcept(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + d1, _ := createDevice(t, repo, ctx, 1, "keep-me") + createDevice(t, repo, ctx, 1, "delete-me1") + createDevice(t, repo, ctx, 1, "delete-me2") + + err := repo.DeleteAllByUserIDExcept(ctx, 1, d1.ID) + if err != nil { + t.Fatalf("DeleteAllByUserIDExcept() error = %v", err) + } + + devices, _, _ := repo.ListByUserID(ctx, 1, 0, 10) + if len(devices) != 1 { + t.Errorf("len(devices) = %d, want 1", len(devices)) + } + if devices[0].ID != d1.ID { + t.Error("应保留指定设备") + } +} + +// TestDeviceRepository_GetTrustedDevices 测试获取信任设备列表 +func TestDeviceRepository_GetTrustedDevices(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + trusted := &domain.Device{ + UserID: 1, + DeviceID: "trusted-device", + IsTrusted: true, + Status: domain.DeviceStatusActive, + } + untrusted := &domain.Device{ + UserID: 1, + DeviceID: "untrusted-device", + IsTrusted: false, + Status: domain.DeviceStatusActive, + } + repo.Create(ctx, trusted) + repo.Create(ctx, untrusted) + + devices, err := repo.GetTrustedDevices(ctx, 1) + if err != nil { + t.Fatalf("GetTrustedDevices() error = %v", err) + } + if len(devices) != 1 { + t.Errorf("len(devices) = %d, want 1", len(devices)) + } +} + +// TestDeviceRepository_ListAll 测试带筛选条件的列表查询 +func TestDeviceRepository_ListAll(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "dev1", Status: domain.DeviceStatusActive}) + repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "dev2", Status: domain.DeviceStatusInactive}) + repo.Create(ctx, &domain.Device{UserID: 2, DeviceID: "dev3", Status: domain.DeviceStatusActive}) + + // 按用户筛选 + params := &ListDevicesParams{UserID: 1, Offset: 0, Limit: 10} + _, total, err := repo.ListAll(ctx, params) + if err != nil { + t.Fatalf("ListAll() error = %v", err) + } + if total != 2 { + t.Errorf("total = %d, want 2", total) + } + + // 按状态筛选 + status := domain.DeviceStatusActive + params2 := &ListDevicesParams{Status: &status, Offset: 0, Limit: 10} + _, total2, err := repo.ListAll(ctx, params2) + if err != nil { + t.Fatalf("ListAll() error = %v", err) + } + if total2 != 2 { + t.Errorf("total = %d, want 2", total2) + } +} + +// createDevice 辅助函数:创建设备 +func createDevice(t *testing.T, repo *DeviceRepository, ctx context.Context, userID int64, deviceID string) (*domain.Device, error) { + d := &domain.Device{ + UserID: userID, + DeviceID: deviceID, + Status: domain.DeviceStatusActive, + } + err := repo.Create(ctx, d) + if err != nil { + t.Fatalf("createDevice() error = %v", err) + } + return d, nil +}