package repository

import (
	"context"
	"errors"

	"github.com/user-management-system/internal/domain"
	"gorm.io/gorm"
)

// UserRoleRepository is the data-access layer for user/role associations.
type UserRoleRepository struct {
	db *gorm.DB
}

// NewUserRoleRepository creates a UserRoleRepository backed by db.
func NewUserRoleRepository(db *gorm.DB) *UserRoleRepository {
	return &UserRoleRepository{db: db}
}

// DB returns the underlying GORM DB for transaction support.
func (r *UserRoleRepository) DB() *gorm.DB {
	return r.db
}

// WithTx returns a new repository instance that uses the given transaction.
// The receiver is left unchanged.
func (r *UserRoleRepository) WithTx(tx *gorm.DB) *UserRoleRepository {
	return &UserRoleRepository{db: tx}
}

// Create inserts a single user/role association.
func (r *UserRoleRepository) Create(ctx context.Context, userRole *domain.UserRole) error {
	return r.db.WithContext(ctx).Create(userRole).Error
}

// Delete removes the user/role association with the given primary key.
func (r *UserRoleRepository) Delete(ctx context.Context, id int64) error {
	return r.db.WithContext(ctx).Delete(&domain.UserRole{}, id).Error
}

// DeleteByUserID removes every role association of the given user.
func (r *UserRoleRepository) DeleteByUserID(ctx context.Context, userID int64) error {
	return r.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Delete(&domain.UserRole{}).Error
}

// DeleteByUserAndRole removes the association between the given user and role.
func (r *UserRoleRepository) DeleteByUserAndRole(ctx context.Context, userID, roleID int64) error {
	return r.db.WithContext(ctx).
		Where("user_id = ? AND role_id = ?", userID, roleID).
		Delete(&domain.UserRole{}).Error
}

// DeleteByRoleID removes every user association of the given role.
func (r *UserRoleRepository) DeleteByRoleID(ctx context.Context, roleID int64) error {
	return r.db.WithContext(ctx).
		Where("role_id = ?", roleID).
		Delete(&domain.UserRole{}).Error
}

// GetByUserID returns the user/role associations of the given user.
// On error the returned slice is nil.
func (r *UserRoleRepository) GetByUserID(ctx context.Context, userID int64) ([]*domain.UserRole, error) {
	var userRoles []*domain.UserRole
	if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&userRoles).Error; err != nil {
		return nil, err
	}
	return userRoles, nil
}

// GetByRoleID returns the user/role associations of the given role.
// On error the returned slice is nil.
func (r *UserRoleRepository) GetByRoleID(ctx context.Context, roleID int64) ([]*domain.UserRole, error) {
	var userRoles []*domain.UserRole
	if err := r.db.WithContext(ctx).Where("role_id = ?", roleID).Find(&userRoles).Error; err != nil {
		return nil, err
	}
	return userRoles, nil
}

// GetRoleIDsByUserID returns the IDs of the roles directly assigned to the
// given user. On error the returned slice is nil.
func (r *UserRoleRepository) GetRoleIDsByUserID(ctx context.Context, userID int64) ([]int64, error) {
	var roleIDs []int64
	err := r.db.WithContext(ctx).
		Model(&domain.UserRole{}).
		Where("user_id = ?", userID).
		Pluck("role_id", &roleIDs).Error
	if err != nil {
		return nil, err
	}
	return roleIDs, nil
}

// getRoleAncestorIDs walks the parent chain of roleID and returns the visited
// role IDs, starting with roleID itself.
//
// The walk stops when a role has no parent, when a role row does not exist,
// when an ID repeats (cycle detection), when the current ID is not positive,
// or after maxDepth roles have been visited. Any database error other than
// gorm.ErrRecordNotFound is returned to the caller instead of being treated
// as the end of the chain, so a failed lookup cannot silently shrink the
// result.
func (r *UserRoleRepository) getRoleAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
	const maxDepth = 5

	var ancestors []int64
	visited := make(map[int64]bool, maxDepth)

	current := roleID
	for depth := 0; current > 0 && depth < maxDepth; depth++ {
		if visited[current] {
			break // cycle detected
		}
		visited[current] = true
		ancestors = append(ancestors, current)

		var role domain.Role
		err := r.db.WithContext(ctx).Select("parent_id").First(&role, current).Error
		if errors.Is(err, gorm.ErrRecordNotFound) {
			// The ID has already been recorded above; a missing row simply
			// ends the chain.
			break
		}
		if err != nil {
			return nil, err
		}
		if role.ParentID == nil {
			break
		}
		current = *role.ParentID
	}
	return ancestors, nil
}

// GetUserRolesAndPermissions returns the roles and permissions of a user,
// including those inherited through parent roles.
//
// Roles are the enabled roles among the user's directly assigned roles and
// their ancestors. When the user has no roles at all, both returned slices
// are empty and non-nil. On error both slices are nil.
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
	// Role IDs assigned directly to the user.
	directRoleIDs, err := r.GetRoleIDsByUserID(ctx, userID)
	if err != nil {
		return nil, nil, err
	}

	// Expand every direct role to itself plus its ancestors, de-duplicated.
	roleIDSet := make(map[int64]struct{}, len(directRoleIDs))
	for _, roleID := range directRoleIDs {
		ancestors, err := r.getRoleAncestorIDs(ctx, roleID)
		if err != nil {
			return nil, nil, err
		}
		for _, id := range ancestors {
			roleIDSet[id] = struct{}{}
		}
	}

	allRoleIDs := make([]int64, 0, len(roleIDSet))
	for id := range roleIDSet {
		allRoleIDs = append(allRoleIDs, id)
	}
	if len(allRoleIDs) == 0 {
		return []*domain.Role{}, []*domain.Permission{}, nil
	}

	// Load the role records, keeping only enabled roles.
	var roles []*domain.Role
	err = r.db.WithContext(ctx).
		Where("id IN ? AND status = ?", allRoleIDs, domain.RoleStatusEnabled).
		Find(&roles).Error
	if err != nil {
		return nil, nil, err
	}

	// Collect the permission IDs attached to the roles.
	// NOTE(review): this uses every role ID in the set, including roles that
	// the status filter above dropped — confirm whether disabled roles are
	// meant to keep granting permissions.
	var permissionIDs []int64
	err = r.db.WithContext(ctx).
		Model(&domain.RolePermission{}).
		Where("role_id IN ?", allRoleIDs).
		Pluck("permission_id", &permissionIDs).Error
	if err != nil {
		return nil, nil, err
	}

	// Load the permission records. The slice starts out empty rather than nil
	// so that the result shape matches the no-roles path above.
	permissions := []*domain.Permission{}
	if len(permissionIDs) > 0 {
		if err := r.db.WithContext(ctx).Where("id IN ?", permissionIDs).Find(&permissions).Error; err != nil {
			return nil, nil, err
		}
	}
	return roles, permissions, nil
}

// GetUserIDByRoleID returns the IDs of the users that are directly assigned
// the given role. On error the returned slice is nil.
func (r *UserRoleRepository) GetUserIDByRoleID(ctx context.Context, roleID int64) ([]int64, error) {
	var userIDs []int64
	err := r.db.WithContext(ctx).
		Model(&domain.UserRole{}).
		Where("role_id = ?", roleID).
		Pluck("user_id", &userIDs).Error
	if err != nil {
		return nil, err
	}
	return userIDs, nil
}

// Exists reports whether an association between the given user and role
// exists. On error it returns false together with the error.
func (r *UserRoleRepository) Exists(ctx context.Context, userID, roleID int64) (bool, error) {
	var count int64
	err := r.db.WithContext(ctx).
		Model(&domain.UserRole{}).
		Where("user_id = ? AND role_id = ?", userID, roleID).
		Count(&count).Error
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// BatchCreate inserts the given user/role associations in one call.
// An empty input is a no-op.
func (r *UserRoleRepository) BatchCreate(ctx context.Context, userRoles []*domain.UserRole) error {
	if len(userRoles) == 0 {
		return nil
	}
	return r.db.WithContext(ctx).Create(&userRoles).Error
}

// BatchDelete deletes the given user/role associations by their ID field.
// An empty input is a no-op.
func (r *UserRoleRepository) BatchDelete(ctx context.Context, userRoles []*domain.UserRole) error {
	if len(userRoles) == 0 {
		return nil
	}
	ids := make([]int64, 0, len(userRoles))
	for _, ur := range userRoles {
		ids = append(ids, ur.ID)
	}
	return r.db.WithContext(ctx).Delete(&domain.UserRole{}, ids).Error
}

// ReplaceUserRoles replaces all roles of a user inside a single transaction.
// It deletes the user's existing associations and then creates one
// association per entry in roleIDs, so the two steps commit or roll back
// together. An empty roleIDs leaves the user with no roles.
func (r *UserRoleRepository) ReplaceUserRoles(ctx context.Context, userID int64, roleIDs []int64) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Drop every existing association of the user.
		if err := tx.Where("user_id = ?", userID).Delete(&domain.UserRole{}).Error; err != nil {
			return err
		}
		if len(roleIDs) == 0 {
			return nil
		}

		// Create the new associations.
		userRoles := make([]*domain.UserRole, 0, len(roleIDs))
		for _, roleID := range roleIDs {
			userRoles = append(userRoles, &domain.UserRole{
				UserID: userID,
				RoleID: roleID,
			})
		}
		return tx.Create(&userRoles).Error
	})
}