Files

214 lines
5.8 KiB
Go
Raw Permalink Normal View History

package repository
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// RoleRepository 角色数据访问层
type RoleRepository struct {
db *gorm.DB
}
// NewRoleRepository 创建角色数据访问层
func NewRoleRepository(db *gorm.DB) *RoleRepository {
return &RoleRepository{db: db}
}
// Create 创建角色
func (r *RoleRepository) Create(ctx context.Context, role *domain.Role) error {
// GORM omits zero values on insert for fields with DB defaults. Explicitly
// backfill disabled status so callers can persist status=0 roles.
requestedStatus := role.Status
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Create(role).Error; err != nil {
return err
}
if requestedStatus == domain.RoleStatusDisabled {
if err := tx.Model(&domain.Role{}).Where("id = ?", role.ID).Update("status", requestedStatus).Error; err != nil {
return err
}
role.Status = requestedStatus
}
return nil
})
}
// Update 更新角色
func (r *RoleRepository) Update(ctx context.Context, role *domain.Role) error {
return r.db.WithContext(ctx).Save(role).Error
}
// Delete 删除角色
func (r *RoleRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.Role{}, id).Error
}
// GetByID 根据ID获取角色
func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) {
var role domain.Role
err := r.db.WithContext(ctx).First(&role, id).Error
if err != nil {
return nil, err
}
return &role, nil
}
// GetByCode 根据代码获取角色
func (r *RoleRepository) GetByCode(ctx context.Context, code string) (*domain.Role, error) {
var role domain.Role
err := r.db.WithContext(ctx).Where("code = ?", code).First(&role).Error
if err != nil {
return nil, err
}
return &role, nil
}
// List 获取角色列表
func (r *RoleRepository) List(ctx context.Context, offset, limit int) ([]*domain.Role, int64, error) {
var roles []*domain.Role
var total int64
query := r.db.WithContext(ctx).Model(&domain.Role{})
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
return nil, 0, err
}
return roles, total, nil
}
// ListByStatus 根据状态获取角色列表
func (r *RoleRepository) ListByStatus(ctx context.Context, status domain.RoleStatus, offset, limit int) ([]*domain.Role, int64, error) {
var roles []*domain.Role
var total int64
query := r.db.WithContext(ctx).Model(&domain.Role{}).Where("status = ?", status)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
return nil, 0, err
}
return roles, total, nil
}
// GetDefaultRoles 获取默认角色
func (r *RoleRepository) GetDefaultRoles(ctx context.Context) ([]*domain.Role, error) {
var roles []*domain.Role
err := r.db.WithContext(ctx).Where("is_default = ?", true).Find(&roles).Error
if err != nil {
return nil, err
}
return roles, nil
}
// ExistsByCode 检查角色代码是否存在
func (r *RoleRepository) ExistsByCode(ctx context.Context, code string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.Role{}).Where("code = ?", code).Count(&count).Error
return count > 0, err
}
// UpdateStatus 更新角色状态
func (r *RoleRepository) UpdateStatus(ctx context.Context, id int64, status domain.RoleStatus) error {
return r.db.WithContext(ctx).Model(&domain.Role{}).Where("id = ?", id).Update("status", status).Error
}
// Search 搜索角色
func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Role, int64, error) {
var roles []*domain.Role
var total int64
query := r.db.WithContext(ctx).Model(&domain.Role{}).
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
return nil, 0, err
}
return roles, total, nil
}
// ListByParentID 根据父ID获取角色列表
func (r *RoleRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Role, error) {
var roles []*domain.Role
err := r.db.WithContext(ctx).Where("parent_id = ?", parentID).Find(&roles).Error
if err != nil {
return nil, err
}
return roles, nil
}
// GetByIDs 根据ID列表批量获取角色
func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Role, error) {
if len(ids) == 0 {
return []*domain.Role{}, nil
}
var roles []*domain.Role
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&roles).Error
if err != nil {
return nil, err
}
return roles, nil
}
// GetAncestorIDs 获取角色的所有祖先角色ID用于权限继承
func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
var ancestorIDs []int64
currentID := roleID
// 循环向上查找父角色,直到没有父角色为止
for {
var role domain.Role
err := r.db.WithContext(ctx).Select("id", "parent_id").First(&role, currentID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
break
}
return nil, err
}
if role.ParentID == nil {
break
}
ancestorIDs = append(ancestorIDs, *role.ParentID)
currentID = *role.ParentID
}
return ancestorIDs, nil
}
// GetAncestors 获取角色的完整继承链(从父到子)
func (r *RoleRepository) GetAncestors(ctx context.Context, roleID int64) ([]*domain.Role, error) {
ancestorIDs, err := r.GetAncestorIDs(ctx, roleID)
if err != nil {
return nil, err
}
if len(ancestorIDs) == 0 {
return []*domain.Role{}, nil
}
return r.GetByIDs(ctx, ancestorIDs)
}