// File: user-system/internal/service/device.go (422 lines, 13 KiB, Go)

package service
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
apierrors "github.com/user-management-system/internal/pkg/errors"
"github.com/user-management-system/internal/repository"
"gorm.io/gorm"
)
// deviceRepository lists the device persistence operations that DeviceService
// calls. It is declared next to its consumer so that any store providing these
// methods can back the service.
type deviceRepository interface {
	// Record writes.
	Create(ctx context.Context, device *domain.Device) error
	Update(ctx context.Context, device *domain.Device) error
	Delete(ctx context.Context, id int64) error
	DeleteAllByUserIDExcept(ctx context.Context, userID, exceptDeviceID int64) error

	// Targeted field updates, addressed by the record's int64 id.
	UpdateStatus(ctx context.Context, id int64, status domain.DeviceStatus) error
	UpdateLastActiveTime(ctx context.Context, id int64) error
	TrustDevice(ctx context.Context, id int64, expiresAt *time.Time) error
	UntrustDevice(ctx context.Context, id int64) error

	// Lookups, either by record id or by the (userID, deviceID string) pair.
	GetByID(ctx context.Context, id int64) (*domain.Device, error)
	GetByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error)
	Exists(ctx context.Context, userID int64, deviceID string) (bool, error)
	GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error)
	CountTrustedDevices(ctx context.Context, userID int64) (int64, error)

	// Listings, offset-based or cursor-based.
	ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.Device, int64, error)
	ListByStatus(ctx context.Context, status domain.DeviceStatus, offset, limit int) ([]*domain.Device, int64, error)
	ListAll(ctx context.Context, params *repository.ListDevicesParams) ([]*domain.Device, int64, error)
	ListAllCursor(ctx context.Context, params *repository.ListDevicesParams, limit int, cursor *pagination.Cursor) ([]*domain.Device, bool, error)
}
// deviceUserRepository is the single user lookup DeviceService depends on.
// CreateDevice calls it to check the owning user before registering a device.
type deviceUserRepository interface {
GetByID(ctx context.Context, id int64) (*domain.User, error)
}
// DeviceService implements device management on top of a device repository
// and a user repository. Construct it with NewDeviceService.
type DeviceService struct {
// deviceRepo stores and queries device records.
deviceRepo deviceRepository
// userRepo is consulted by CreateDevice to look up the owning user.
userRepo deviceUserRepository
}
// NewDeviceService returns a DeviceService wired to the given repositories.
// Neither argument is validated here; both are stored as passed.
func NewDeviceService(deviceRepo deviceRepository, userRepo deviceUserRepository) *DeviceService {
	svc := new(DeviceService)
	svc.deviceRepo = deviceRepo
	svc.userRepo = userRepo
	return svc
}
// CreateDeviceRequest carries the fields CreateDevice copies onto a newly
// registered device.
type CreateDeviceRequest struct {
// DeviceID, together with the user ID, is what CreateDevice uses to detect an
// already registered device. It is the only field tagged binding:"required".
DeviceID string `json:"device_id" binding:"required"`
DeviceName string `json:"device_name"`
// DeviceType is converted to domain.DeviceType by CreateDevice without
// range validation.
DeviceType int `json:"device_type"`
DeviceOS string `json:"device_os"`
DeviceBrowser string `json:"device_browser"`
IP string `json:"ip"`
Location string `json:"location"`
}
// UpdateDeviceRequest carries the fields persistDeviceUpdate merges into an
// existing device. String fields are applied only when non-empty.
//
// NOTE(review): DeviceType and Status are plain ints and persistDeviceUpdate
// applies them whenever they are >= 0, so a request that omits them (zero
// value) still writes 0 to the device. Confirm this is intended.
type UpdateDeviceRequest struct {
DeviceName string `json:"device_name"`
DeviceType int `json:"device_type"`
DeviceOS string `json:"device_os"`
DeviceBrowser string `json:"device_browser"`
IP string `json:"ip"`
Location string `json:"location"`
Status int `json:"status"`
}
// CreateDevice registers a device for userID, or refreshes it when the
// (userID, req.DeviceID) pair is already known.
//
// Behavior:
//   - A nil req is rejected with an error instead of panicking.
//   - The user is looked up first; any lookup error is reported as
//     "user not found".
//   - If the device already exists, only its LastActiveTime is set to now and
//     the record is saved; the other request fields are not applied.
//   - Otherwise a new device is created from req with status
//     domain.DeviceStatusActive.
//
// On any error the returned device is nil, so callers never receive a device
// whose persistence failed.
func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *CreateDeviceRequest) (*domain.Device, error) {
	if req == nil {
		return nil, errors.New("create device request is required")
	}

	// NOTE(review): every failure of the user lookup (including infrastructure
	// errors) is collapsed into "user not found"; the message is kept as-is
	// because callers may depend on it.
	if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
		return nil, errors.New("user not found")
	}

	exists, err := s.deviceRepo.Exists(ctx, userID, req.DeviceID)
	if err != nil {
		return nil, err
	}
	if exists {
		device, err := s.deviceRepo.GetByDeviceID(ctx, userID, req.DeviceID)
		if err != nil {
			return nil, err
		}
		device.LastActiveTime = time.Now()
		// Do not hand back the device when saving it failed.
		if err := s.deviceRepo.Update(ctx, device); err != nil {
			return nil, err
		}
		return device, nil
	}

	device := &domain.Device{
		UserID:        userID,
		DeviceID:      req.DeviceID,
		DeviceName:    req.DeviceName,
		DeviceType:    domain.DeviceType(req.DeviceType),
		DeviceOS:      req.DeviceOS,
		DeviceBrowser: req.DeviceBrowser,
		IP:            req.IP,
		Location:      req.Location,
		Status:        domain.DeviceStatusActive,
	}
	if err := s.deviceRepo.Create(ctx, device); err != nil {
		return nil, err
	}
	return device, nil
}
// isDeviceNotFoundError reports whether err should be treated as "the device
// does not exist".
//
// It returns false for nil, true when err wraps gorm.ErrRecordNotFound, and
// otherwise falls back to a case-insensitive substring match on the message.
//
// NOTE(review): the last marker, "not found", is a superset of the other two,
// so any error whose message contains "not found" is classified as a missing
// device. Confirm that breadth is intended.
func isDeviceNotFoundError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, gorm.ErrRecordNotFound):
		return true
	}

	msg := strings.ToLower(strings.TrimSpace(err.Error()))
	for _, marker := range []string{"record not found", "device not found", "not found"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
// getDeviceByID loads a device by its record ID.
//
// Errors recognized by isDeviceNotFoundError are converted into an
// apierrors.NotFound with code "device_not_found", keeping the original error
// as the cause. Every other error is returned unchanged.
func (s *DeviceService) getDeviceByID(ctx context.Context, deviceID int64) (*domain.Device, error) {
	device, err := s.deviceRepo.GetByID(ctx, deviceID)
	if err == nil {
		return device, nil
	}
	if !isDeviceNotFoundError(err) {
		return nil, err
	}
	return nil, apierrors.NotFound("device_not_found", "device not found").WithCause(err)
}
// getAuthorizedDevice loads a device by record ID and checks that the actor
// may act on it.
//
// Access is granted when isAdmin is true or when the device's UserID equals
// actorUserID; otherwise an apierrors.Forbidden with code "device_forbidden"
// is returned. Lookup errors from getDeviceByID are returned as-is.
func (s *DeviceService) getAuthorizedDevice(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) (*domain.Device, error) {
	device, err := s.getDeviceByID(ctx, deviceID)
	if err != nil {
		return nil, err
	}
	if isAdmin || device.UserID == actorUserID {
		return device, nil
	}
	return nil, apierrors.Forbidden("device_forbidden", "permission denied")
}
// persistDeviceUpdate merges req into device and saves the result.
//
// A nil req is a no-op: device is returned unchanged and nothing is written.
// Otherwise string fields are copied only when non-empty, DeviceType and
// Status are copied whenever they are >= 0, and the device is saved through
// the repository's Update. If saving fails the returned device is nil.
//
// NOTE(review): the zero value of DeviceType and Status satisfies ">= 0", so
// a request that leaves them unset still writes 0 to both fields. Telling
// "unset" apart would need pointer fields on UpdateDeviceRequest.
func (s *DeviceService) persistDeviceUpdate(ctx context.Context, device *domain.Device, req *UpdateDeviceRequest) (*domain.Device, error) {
	if req == nil {
		return device, nil
	}

	// Text fields: an empty string means "leave the stored value alone".
	if name := req.DeviceName; name != "" {
		device.DeviceName = name
	}
	if osName := req.DeviceOS; osName != "" {
		device.DeviceOS = osName
	}
	if browser := req.DeviceBrowser; browser != "" {
		device.DeviceBrowser = browser
	}
	if ip := req.IP; ip != "" {
		device.IP = ip
	}
	if location := req.Location; location != "" {
		device.Location = location
	}

	// Numeric fields: only negative values are skipped.
	if kind := req.DeviceType; kind >= 0 {
		device.DeviceType = domain.DeviceType(kind)
	}
	if status := req.Status; status >= 0 {
		device.Status = domain.DeviceStatus(status)
	}

	if saveErr := s.deviceRepo.Update(ctx, device); saveErr != nil {
		return nil, saveErr
	}
	return device, nil
}
// maxTrustedDevicesPerUser is the maximum number of trusted devices allowed
// per user (P2 security hardening).
const maxTrustedDevicesPerUser = 10
// trustDeviceRecord marks device as trusted, enforcing the per-user limit.
//
// The user's trusted-device count is read first; if it is already at
// maxTrustedDevicesPerUser the call fails without touching the device
// (P2 security hardening). A positive trustDuration sets an expiry of
// now + trustDuration; zero or negative passes a nil expiry to the repository.
//
// NOTE(review): the limit is checked on every call and the device's current
// trust state is not inspected, so re-trusting an already trusted device at
// the limit is rejected too. Confirm that is intended.
func (s *DeviceService) trustDeviceRecord(ctx context.Context, device *domain.Device, trustDuration time.Duration) error {
	count, err := s.deviceRepo.CountTrustedDevices(ctx, device.UserID)
	if err != nil {
		return fmt.Errorf("count trusted devices failed: %w", err)
	}
	if count >= maxTrustedDevicesPerUser {
		return fmt.Errorf("trusted device limit reached (max %d), please untrust an existing device first", maxTrustedDevicesPerUser)
	}

	if trustDuration <= 0 {
		return s.deviceRepo.TrustDevice(ctx, device.ID, nil)
	}
	deadline := time.Now().Add(trustDuration)
	return s.deviceRepo.TrustDevice(ctx, device.ID, &deadline)
}
// UpdateDevice applies req to the device with the given record ID and saves
// it. It performs no ownership check; UpdateDeviceForActor is the variant
// that does. A missing device yields the not-found error from getDeviceByID.
func (s *DeviceService) UpdateDevice(ctx context.Context, deviceID int64, req *UpdateDeviceRequest) (*domain.Device, error) {
	existing, lookupErr := s.getDeviceByID(ctx, deviceID)
	if lookupErr != nil {
		return nil, lookupErr
	}
	updated, err := s.persistDeviceUpdate(ctx, existing, req)
	return updated, err
}
// DeleteDevice deletes the device with the given record ID. The device is
// loaded first, so a missing record is reported through getDeviceByID's
// not-found error. No ownership check is performed here.
func (s *DeviceService) DeleteDevice(ctx context.Context, deviceID int64) error {
	target, lookupErr := s.getDeviceByID(ctx, deviceID)
	if lookupErr != nil {
		return lookupErr
	}
	return s.deviceRepo.Delete(ctx, target.ID)
}
// GetDevice returns the device with the given record ID, without any
// ownership check. Errors are those of getDeviceByID.
func (s *DeviceService) GetDevice(ctx context.Context, deviceID int64) (*domain.Device, error) {
	device, err := s.getDeviceByID(ctx, deviceID)
	return device, err
}
// GetUserDevices returns one page of the devices belonging to userID, plus
// the int64 count returned by the repository (presumably the total number of
// matching devices).
//
// page is 1-based and values <= 0 fall back to 1. pageSize <= 0 falls back to
// 20, and values above 100 are clamped to 100. That is the same ceiling
// GetAllDevices applies, so a caller-supplied size cannot request an
// unbounded result set.
func (s *DeviceService) GetUserDevices(ctx context.Context, userID int64, page, pageSize int) ([]*domain.Device, int64, error) {
	const (
		defaultPageSize = 20
		maxPageSize     = 100
	)

	if page <= 0 {
		page = 1
	}
	switch {
	case pageSize <= 0:
		pageSize = defaultPageSize
	case pageSize > maxPageSize:
		pageSize = maxPageSize
	}

	offset := (page - 1) * pageSize
	return s.deviceRepo.ListByUserID(ctx, userID, offset, pageSize)
}
// GetDeviceForActor returns the device with the given record ID if the actor
// is allowed to see it: admins always are, other users only for devices whose
// UserID matches actorUserID. Errors are those of getAuthorizedDevice.
func (s *DeviceService) GetDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) (*domain.Device, error) {
	device, err := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
	return device, err
}
// UpdateDeviceForActor applies req to a device after checking that the actor
// may act on it (admin, or owner of the device). Authorization and lookup
// errors are returned before anything is written.
func (s *DeviceService) UpdateDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool, req *UpdateDeviceRequest) (*domain.Device, error) {
	authorized, authErr := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
	if authErr != nil {
		return nil, authErr
	}
	updated, err := s.persistDeviceUpdate(ctx, authorized, req)
	return updated, err
}
// DeleteDeviceForActor deletes a device after checking that the actor may act
// on it (admin, or owner of the device). Authorization and lookup errors are
// returned before the repository delete is attempted.
func (s *DeviceService) DeleteDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) error {
	authorized, authErr := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
	if authErr != nil {
		return authErr
	}
	return s.deviceRepo.Delete(ctx, authorized.ID)
}
// UpdateDeviceStatus sets the status of the device with the given record ID.
// The device is loaded first, so a missing record surfaces as the not-found
// error from getDeviceByID. No ownership check is performed and status is
// passed to the repository without validation.
func (s *DeviceService) UpdateDeviceStatus(ctx context.Context, deviceID int64, status domain.DeviceStatus) error {
	found, lookupErr := s.getDeviceByID(ctx, deviceID)
	if lookupErr != nil {
		return lookupErr
	}
	return s.deviceRepo.UpdateStatus(ctx, found.ID, status)
}
// UpdateDeviceStatusForActor sets a device's status after checking that the
// actor may act on it (admin, or owner of the device). status is passed to
// the repository without validation.
func (s *DeviceService) UpdateDeviceStatusForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool, status domain.DeviceStatus) error {
	authorized, authErr := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
	if authErr != nil {
		return authErr
	}
	return s.deviceRepo.UpdateStatus(ctx, authorized.ID, status)
}
// UpdateLastActiveTime forwards to the repository's UpdateLastActiveTime for
// the given record ID. Unlike most methods on DeviceService it does not load
// the device first, so how a missing record is reported is up to the
// repository.
func (s *DeviceService) UpdateLastActiveTime(ctx context.Context, deviceID int64) error {
	if err := s.deviceRepo.UpdateLastActiveTime(ctx, deviceID); err != nil {
		return err
	}
	return nil
}
// GetActiveDevices returns one page of devices whose status is
// domain.DeviceStatusActive, plus the int64 count returned by the repository
// (presumably the total number of matching devices).
//
// page is 1-based and values <= 0 fall back to 1. pageSize <= 0 falls back to
// 20, and values above 100 are clamped to 100. That is the same ceiling
// GetAllDevices applies, so a caller-supplied size cannot request an
// unbounded result set.
func (s *DeviceService) GetActiveDevices(ctx context.Context, page, pageSize int) ([]*domain.Device, int64, error) {
	const (
		defaultPageSize = 20
		maxPageSize     = 100
	)

	if page <= 0 {
		page = 1
	}
	switch {
	case pageSize <= 0:
		pageSize = defaultPageSize
	case pageSize > maxPageSize:
		pageSize = maxPageSize
	}

	offset := (page - 1) * pageSize
	return s.deviceRepo.ListByStatus(ctx, domain.DeviceStatusActive, offset, pageSize)
}
// TrustDevice marks the device with the given record ID as trusted for
// trustDuration (no expiry when trustDuration <= 0). The per-user trusted
// device limit is enforced by trustDeviceRecord. No ownership check is
// performed here.
func (s *DeviceService) TrustDevice(ctx context.Context, deviceID int64, trustDuration time.Duration) error {
	target, lookupErr := s.getDeviceByID(ctx, deviceID)
	if lookupErr != nil {
		return lookupErr
	}
	return s.trustDeviceRecord(ctx, target, trustDuration)
}
// TrustDeviceForActor marks a device as trusted after checking that the actor
// may act on it (admin, or owner of the device). trustDuration <= 0 means no
// expiry; the per-user trusted device limit is enforced by trustDeviceRecord.
func (s *DeviceService) TrustDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool, trustDuration time.Duration) error {
	authorized, authErr := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
	if authErr != nil {
		return authErr
	}
	return s.trustDeviceRecord(ctx, authorized, trustDuration)
}
// TrustDeviceByDeviceID marks a device as trusted, addressing it by the
// (userID, deviceID string) pair instead of the record ID.
//
// Lookup errors recognized by isDeviceNotFoundError become an
// apierrors.NotFound with code "device_not_found" carrying the original error
// as cause; other lookup errors are returned unchanged. Trusting itself,
// including the per-user limit, is handled by trustDeviceRecord.
func (s *DeviceService) TrustDeviceByDeviceID(ctx context.Context, userID int64, deviceID string, trustDuration time.Duration) error {
	device, err := s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
	switch {
	case err == nil:
		return s.trustDeviceRecord(ctx, device, trustDuration)
	case isDeviceNotFoundError(err):
		return apierrors.NotFound("device_not_found", "device not found").WithCause(err)
	default:
		return err
	}
}
// UntrustDevice removes the trusted mark from the device with the given
// record ID. The device is loaded first, so a missing record surfaces as the
// not-found error from getDeviceByID. No ownership check is performed here.
func (s *DeviceService) UntrustDevice(ctx context.Context, deviceID int64) error {
	target, lookupErr := s.getDeviceByID(ctx, deviceID)
	if lookupErr != nil {
		return lookupErr
	}
	return s.deviceRepo.UntrustDevice(ctx, target.ID)
}
// UntrustDeviceForActor removes the trusted mark from a device after checking
// that the actor may act on it (admin, or owner of the device).
func (s *DeviceService) UntrustDeviceForActor(ctx context.Context, actorUserID, deviceID int64, isAdmin bool) error {
	authorized, authErr := s.getAuthorizedDevice(ctx, actorUserID, deviceID, isAdmin)
	if authErr != nil {
		return authErr
	}
	return s.deviceRepo.UntrustDevice(ctx, authorized.ID)
}
// LogoutAllOtherDevices forwards to the repository's DeleteAllByUserIDExcept
// with userID and currentDeviceID (a record ID). Judging by that name it
// removes the user's other device records; any effect on sessions or tokens
// is not visible from this method.
func (s *DeviceService) LogoutAllOtherDevices(ctx context.Context, userID int64, currentDeviceID int64) error {
	if err := s.deviceRepo.DeleteAllByUserIDExcept(ctx, userID, currentDeviceID); err != nil {
		return err
	}
	return nil
}
// GetTrustedDevices returns whatever the repository's GetTrustedDevices
// yields for userID, unmodified and without pagination.
func (s *DeviceService) GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
	trusted, err := s.deviceRepo.GetTrustedDevices(ctx, userID)
	return trusted, err
}
// GetAllDevicesRequest holds the query parameters shared by GetAllDevices
// (offset paging) and GetAllDevicesCursor (cursor paging).
type GetAllDevicesRequest struct {
// Page and PageSize drive offset paging in GetAllDevices, which defaults
// them to 1 and 20 and caps PageSize at 100. GetAllDevicesCursor also honors
// PageSize, but only when Cursor is empty.
Page int `form:"page"`
PageSize int `form:"page_size"`
// UserID and Keyword are passed to the repository as filters unchanged.
UserID int64 `form:"user_id"`
// Status is applied as a filter only when it is 0 or 1; other values are
// ignored.
Status *int `form:"status"`
// IsTrusted is applied as a filter only when non-nil.
IsTrusted *bool `form:"is_trusted"`
Keyword string `form:"keyword"`
// Cursor and Size drive cursor paging in GetAllDevicesCursor.
Cursor string `form:"cursor"`
Size int `form:"size"`
}
// GetAllDevices lists devices across users with offset paging and optional
// filters, returning the page plus the int64 count from the repository
// (presumably the total number of matches).
//
// A nil req is treated as an empty request (first page, no filters) rather
// than panicking. Page defaults to 1 and PageSize to 20, capped at 100; the
// normalized values are written back into req, so callers can read the
// effective paging after the call. Status is used as a filter only when it is
// 0 or 1, and IsTrusted only when non-nil.
func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesRequest) ([]*domain.Device, int64, error) {
	if req == nil {
		req = &GetAllDevicesRequest{}
	}

	if req.Page <= 0 {
		req.Page = 1
	}
	if req.PageSize <= 0 {
		req.PageSize = 20
	}
	if req.PageSize > 100 {
		req.PageSize = 100
	}

	offset := (req.Page - 1) * req.PageSize
	params := &repository.ListDevicesParams{
		UserID:  req.UserID,
		Keyword: req.Keyword,
		Offset:  offset,
		Limit:   req.PageSize,
	}
	// Out-of-range status values are ignored rather than rejected.
	if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
		status := domain.DeviceStatus(*req.Status)
		params.Status = &status
	}
	if req.IsTrusted != nil {
		params.IsTrusted = req.IsTrusted
	}

	return s.deviceRepo.ListAll(ctx, params)
}
// GetAllDevicesCursor lists devices across users with cursor paging and the
// same optional filters as GetAllDevices.
//
// A nil req is treated as an empty request (first page, no filters) rather
// than panicking. The page size comes from req.Size via
// pagination.ClampPageSize; on the first page (empty Cursor) a positive
// req.PageSize takes precedence. An undecodable cursor yields an
// "invalid cursor" error wrapping the decode error.
//
// NextCursor is built from the ID and LastActiveTime of the last returned
// device whenever at least one device is returned, independent of HasMore;
// it is empty only for an empty page.
func (s *DeviceService) GetAllDevicesCursor(ctx context.Context, req *GetAllDevicesRequest) (*CursorResult, error) {
	if req == nil {
		req = &GetAllDevicesRequest{}
	}

	size := pagination.ClampPageSize(req.Size)
	if req.PageSize > 0 && req.Cursor == "" {
		size = pagination.ClampPageSize(req.PageSize)
	}

	cursor, err := pagination.Decode(req.Cursor)
	if err != nil {
		return nil, fmt.Errorf("invalid cursor: %w", err)
	}

	params := &repository.ListDevicesParams{
		UserID:  req.UserID,
		Keyword: req.Keyword,
	}
	// Out-of-range status values are ignored rather than rejected.
	if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
		status := domain.DeviceStatus(*req.Status)
		params.Status = &status
	}
	if req.IsTrusted != nil {
		params.IsTrusted = req.IsTrusted
	}

	devices, hasMore, err := s.deviceRepo.ListAllCursor(ctx, params, size, cursor)
	if err != nil {
		return nil, err
	}

	nextCursor := ""
	if len(devices) > 0 {
		last := devices[len(devices)-1]
		nextCursor = pagination.BuildNextCursor(last.ID, last.LastActiveTime)
	}

	return &CursorResult{
		Items:      devices,
		NextCursor: nextCursor,
		HasMore:    hasMore,
		PageSize:   size,
	}, nil
}
// GetDeviceByDeviceID returns the device identified by the
// (userID, deviceID string) pair, exactly as the repository reports it.
//
// NOTE(review): unlike getDeviceByID and TrustDeviceByDeviceID, repository
// errors are returned raw here and are not mapped to apierrors.NotFound.
// Confirm callers expect the raw error.
func (s *DeviceService) GetDeviceByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
	device, err := s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
	return device, err
}