diff --git a/pydatastructs/trees/__init__.py b/pydatastructs/trees/__init__.py index 42f919c35..fa08fe1d5 100644 --- a/pydatastructs/trees/__init__.py +++ b/pydatastructs/trees/__init__.py @@ -12,7 +12,8 @@ BinarySearchTree, BinaryTreeTraversal, AVLTree, - BinaryIndexedTree + BinaryIndexedTree, + SplayTree ) __all__.extend(binary_trees.__all__) diff --git a/pydatastructs/trees/binary_trees.py b/pydatastructs/trees/binary_trees.py index 6e329cc71..8059fe005 100644 --- a/pydatastructs/trees/binary_trees.py +++ b/pydatastructs/trees/binary_trees.py @@ -10,7 +10,8 @@ 'BinaryTree', 'BinarySearchTree', 'BinaryTreeTraversal', - 'BinaryIndexedTree' + 'BinaryIndexedTree', + 'SplayTree' ] class BinaryTree(object): @@ -1059,3 +1060,123 @@ def get_sum(self, left_index, right_index): self.get_prefix_sum(left_index - 1) else: return self.get_prefix_sum(right_index) + +class SplayTree(SelfBalancingBinaryTree): + """ + Represents Splay Trees. + + References + ========== + .. [1] https://en.wikipedia.org/wiki/Splay_tree + + """ + def _zig(self, x, p): + if self.tree[p].left == x: + super(SplayTree, self)._right_rotate(p, x) + else: + super(SplayTree, self)._left_rotate(p, x) + + def _zig_zig(self, x, p): + super(SplayTree, self)._right_rotate(self.tree[p].parent, p) + super(SplayTree, self)._right_rotate(p, x) + + def _zig_zag(self, x, p): + super(SplayTree, self)._left_right_rotate(self.tree[p].parent, p) + + def _zag_zag(self, x, p): + super(SplayTree, self)._left_rotate(self.tree[p].parent, p) + super(SplayTree, self)._left_rotate(p, x) + + def _zag_zig(self, x, p): + super(SplayTree, self)._right_left_rotate(self.tree[p].parent, p) + + def splay(self, x, p): + while self.tree[x].parent is not None: + if self.tree[p].parent is None: + self._zig(x, p) + elif self.tree[p].left == x and self.tree[self.tree[p].parent].left == p: + self._zig_zig(x, p) + elif self.tree[p].right == x and self.tree[self.tree[p].parent].right == p: + self._zag_zag(x, p) + elif self.tree[p].left == x and self.tree[self.tree[p].parent].right == p: + self._zag_zig(x, p) + else: + self._zig_zag(x, p) + p = self.tree[x].parent + + def insert(self, key, x): + super(SelfBalancingBinaryTree, self).insert(key, x) + e, p = super(SelfBalancingBinaryTree, self).search(key, parent=True) + self.tree[self.size-1].parent = p; + self.splay(e, p) + + def delete(self, x): + e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True) + if e is None: + return + self.splay(e, p) + b = super(SelfBalancingBinaryTree, self).delete(x, balancing_info=True) + return True + + def join(self, other): + """ + Joins two trees current and other such that all elements of + the current splay tree are smaller than the elements of the other tree. + + Parameters + ========== + + other: SplayTree + SplayTree which needs to be joined with the self tree. + + """ + maxm = self.root_idx + while self.tree[maxm].right is not None: + maxm = self.tree[maxm].right + self.splay(maxm, self.tree[maxm].parent) + traverse = BinaryTreeTraversal(other) + elements = traverse.depth_first_search(order='pre_order', node=other.root_idx) + for i in range(len(elements)): + super(SelfBalancingBinaryTree, self).insert(elements[i].key, elements[i].data) + j = len(elements)-1 + while j>=0: + e, p = super(SelfBalancingBinaryTree, other).search(elements[j].key, parent=True) + other.tree[e] = None + j-=1 + + def split(self, x): + """ + Splits current splay tree into two trees such that one tree contains nodes + with key less than or equal to x and the other tree containing + nodes with key greater than x. + + Parameters + ========== + + x: key + Key of the element on the basis of which split is performed. + + Returns + ======= + + other: SplayTree + SplayTree containing elements with key greater than x. + + """ + e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True) + if e is None: + return + self.splay(e, p) + other = SplayTree(None, None) + if self.tree[self.root_idx].right is not None: + traverse = BinaryTreeTraversal(self) + elements = traverse.depth_first_search(order='pre_order', node=self.tree[self.root_idx].right) + for i in range(len(elements)): + super(SelfBalancingBinaryTree, other).insert(elements[i].key, elements[i].data) + j = len(elements)-1 + while j>=0: + e, p = super(SelfBalancingBinaryTree, self).search(elements[j].key, parent=True) + self.tree[e] = None + j-=1 + self.tree[self.root_idx].right = None + return other diff --git a/pydatastructs/trees/tests/test_binary_trees.py b/pydatastructs/trees/tests/test_binary_trees.py index b516895e4..04a7f3347 100644 --- a/pydatastructs/trees/tests/test_binary_trees.py +++ b/pydatastructs/trees/tests/test_binary_trees.py @@ -1,6 +1,6 @@ from pydatastructs.trees.binary_trees import ( BinarySearchTree, BinaryTreeTraversal, AVLTree, - ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree) + ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree, SplayTree) from pydatastructs.utils.raises_util import raises from pydatastructs.utils.misc_util import TreeNode from copy import deepcopy @@ -348,3 +348,29 @@ def test_issue_234(): tree.insert(4.56, 4.56) tree._left_rotate(5, 8) assert tree.tree[tree.tree[8].parent].left == 8 + +def test_SplayTree(): + t = SplayTree(100, 100) + t.insert(50, 50) + t.insert(200, 200) + t.insert(40, 40) + t.insert(30, 30) + t.insert(20, 20) + t.insert(55, 55) + + assert str(t) == ("[(None, 100, 100, None), (None, 50, 50, None), (0, 200, 200, None), (None, 40, 40, 1), (5, 30, 30, 3), (None, 20, 20, None), (4, 55, 55, 2)]") + t.delete(40) + assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), (4, 50, 50, 6), (5, 30, 30, None), (None, 20, 20, None), (None, 55, 55, 2)]") + t.delete(150) + assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), (4, 50, 50, 6), (5, 30, 30, None), (None, 20, 20, None), (None, 55, 55, 2)]") + t1 = SplayTree(1000, 1000) + t1.insert(2000, 2000) + + assert str(t1) == ("[(None, 1000, 1000, None), (0, 2000, 2000, None)]") + + t.join(t1) + assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, 7), (4, 50, 50, None), (5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), (8, 2000, 2000, None), (None, 1000, 1000, None)]") + s = t.split(200) + + assert str(s) == ("[(1, 2000, 2000, None), (None, 1000, 1000, None)]") + assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, None), (4, 50, 50, None), (5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), '', '']")