package repository

import (
	"context"
	"testing"

	"github.com/user-management-system/internal/domain"
)

// TestRoleRepository_GetAncestorIDs builds a three-level role hierarchy
// (grandchild -> child -> parent) and checks that GetAncestorIDs returns the
// grandchild's ancestor IDs ordered from the nearest ancestor (child) up to
// the root (parent).
func TestRoleRepository_GetAncestorIDs(t *testing.T) {
	db := setupTestDB(t)
	repo := NewRoleRepository(db)
	ctx := context.Background()

	// Build the hierarchy: grandchild -> child -> parent.
	// The parent is a root role, so it has no ParentID.
	parent := &domain.Role{Name: "parent", Code: "parent", ParentID: nil}
	if err := repo.Create(ctx, parent); err != nil {
		t.Fatalf("Create parent failed: %v", err)
	}
	// Create is expected to populate parent.ID; the child links to it.
	parentID := parent.ID

	child := &domain.Role{Name: "child", Code: "child", ParentID: &parentID}
	if err := repo.Create(ctx, child); err != nil {
		t.Fatalf("Create child failed: %v", err)
	}
	childID := child.ID

	grandchild := &domain.Role{Name: "grandchild", Code: "grandchild", ParentID: &childID}
	if err := repo.Create(ctx, grandchild); err != nil {
		t.Fatalf("Create grandchild failed: %v", err)
	}

	// Fetch the ancestor ID list of the grandchild.
	ancestorIDs, err := repo.GetAncestorIDs(ctx, grandchild.ID)
	if err != nil {
		t.Fatalf("GetAncestorIDs failed: %v", err)
	}
	// Fatal rather than Error: the index checks below would panic with
	// "index out of range" if fewer than two IDs were returned.
	if len(ancestorIDs) != 2 {
		t.Fatalf("len(ancestorIDs) = %d, want 2 (got %v)", len(ancestorIDs), ancestorIDs)
	}
	// Nearest ancestor first, root last.
	if ancestorIDs[0] != childID {
		t.Errorf("ancestorIDs[0] = %d, want %d", ancestorIDs[0], childID)
	}
	if ancestorIDs[1] != parentID {
		t.Errorf("ancestorIDs[1] = %d, want %d", ancestorIDs[1], parentID)
	}
}

// TestRoleRepository_GetAncestors builds a three-level hierarchy of enabled
// roles and checks that GetAncestors returns the grandchild's full
// inheritance chain ordered from the root (parent-role) down to the nearest
// ancestor (child-role).
func TestRoleRepository_GetAncestors(t *testing.T) {
	db := setupTestDB(t)
	repo := NewRoleRepository(db)
	ctx := context.Background()

	// Build the hierarchy: grandchild-role -> child-role -> parent-role.
	parent := &domain.Role{Name: "parent-role", Code: "parent-role", Status: domain.RoleStatusEnabled}
	if err := repo.Create(ctx, parent); err != nil {
		t.Fatalf("Create parent failed: %v", err)
	}
	// Create is expected to populate parent.ID; the child links to it.
	parentID := parent.ID

	child := &domain.Role{Name: "child-role", Code: "child-role", ParentID: &parentID, Status: domain.RoleStatusEnabled}
	if err := repo.Create(ctx, child); err != nil {
		t.Fatalf("Create child failed: %v", err)
	}
	childID := child.ID

	grandchild := &domain.Role{Name: "grandchild-role", Code: "grandchild-role", ParentID: &childID, Status: domain.RoleStatusEnabled}
	if err := repo.Create(ctx, grandchild); err != nil {
		t.Fatalf("Create grandchild failed: %v", err)
	}

	// Fetch the complete inheritance chain of the grandchild.
	ancestors, err := repo.GetAncestors(ctx, grandchild.ID)
	if err != nil {
		t.Fatalf("GetAncestors failed: %v", err)
	}
	// Fatal rather than Error: the index checks below would panic with
	// "index out of range" if fewer than two roles were returned.
	if len(ancestors) != 2 {
		t.Fatalf("len(ancestors) = %d, want 2", len(ancestors))
	}
	// The first element should be the root (parent-role).
	if ancestors[0].Code != "parent-role" {
		t.Errorf("ancestors[0].Code = %q, want parent-role", ancestors[0].Code)
	}
	// The second element should be the nearest ancestor (child-role).
	if ancestors[1].Code != "child-role" {
		t.Errorf("ancestors[1].Code = %q, want child-role", ancestors[1].Code)
	}
}