diff --git a/pydatastructs/miscellaneous_data_structures/disjoint_set.py b/pydatastructs/miscellaneous_data_structures/disjoint_set.py index 2c2533a27..cac4f80e4 100644 --- a/pydatastructs/miscellaneous_data_structures/disjoint_set.py +++ b/pydatastructs/miscellaneous_data_structures/disjoint_set.py @@ -16,6 +16,9 @@ class DisjointSetForest(object): >>> dst.union(1, 2) >>> dst.find_root(2).key 1 + >>> dst.make_root(2) + >>> dst.find_root(2).key + 2 References ========== @@ -74,3 +77,29 @@ def union(self, key1, key2): y_root.parent = x_root x_root.size += y_root.size + + def make_root(self, key): + """ + Finds the set to which the key belongs + and makes it as the root of the set. + """ + if self.tree.get(key, None) is None: + raise KeyError("Invalid key, %s"%(key)) + + key_set = self.tree[key] + if key_set.parent is not key_set: + current_parent = key_set.parent + # Remove this key subtree size from all its ancestors + while current_parent.parent is not current_parent: + current_parent.size -= key_set.size + current_parent = current_parent.parent + + all_set_size = current_parent.size # This is the root node + current_parent.size -= key_set.size + + # Make parent of current root as key + current_parent.parent = key_set + # size of new root will be same as previous root's size + key_set.size = all_set_size + # Make parent of key as itself + key_set.parent = key_set diff --git a/pydatastructs/miscellaneous_data_structures/tests/test_disjoint_set.py b/pydatastructs/miscellaneous_data_structures/tests/test_disjoint_set.py index bc056b076..307a69af0 100644 --- a/pydatastructs/miscellaneous_data_structures/tests/test_disjoint_set.py +++ b/pydatastructs/miscellaneous_data_structures/tests/test_disjoint_set.py @@ -21,3 +21,44 @@ def test_DisjointSetForest(): assert raises(KeyError, lambda: dst.find_root(9)) dst.union(3, 1) assert dst.find_root(3).key == 1 + assert dst.find_root(5).key == 1 + dst.make_root(6) + assert dst.find_root(3).key == 6 + assert dst.find_root(5).key == 6 + dst.make_root(5) + assert dst.find_root(1).key == 5 + assert dst.find_root(5).key == 5 + assert raises(KeyError, lambda: dst.make_root(9)) + + dst = DisjointSetForest() + for i in range(6): + dst.make_set(i) + assert dst.tree[2].size == 1 + dst.union(2, 3) + assert dst.tree[2].size == 2 + assert dst.tree[3].size == 1 + dst.union(1, 4) + dst.union(2, 4) + # current tree + ############### + # 2 + # / \ + # 1 3 + # / + # 4 + ############### + assert dst.tree[2].size == 4 + assert dst.tree[1].size == 2 + assert dst.tree[3].size == dst.tree[4].size == 1 + dst.make_root(4) + # New tree + ############### + # 4 + # | + # 2 + # / \ + # 1 3 + ############### + assert dst.tree[4].size == 4 + assert dst.tree[2].size == 3 + assert dst.tree[1].size == dst.tree[3].size == 1