From f680806894d39265f810e7257d50aa0beaf2152e Mon Sep 17 00:00:00 2001 From: Hao LI <8520588+Leo-LiHao@users.noreply.github.com> Date: Mon, 22 Feb 2021 07:58:17 +0800 Subject: [PATCH] add type hints for avl_tree (#4214) Co-authored-by: LiHao --- data_structures/binary_tree/avl_tree.py | 155 +++++++++++++----------- 1 file changed, 87 insertions(+), 68 deletions(-) diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py index 3362610b9..e0d3e4d43 100644 --- a/data_structures/binary_tree/avl_tree.py +++ b/data_structures/binary_tree/avl_tree.py @@ -8,84 +8,85 @@ python avl_tree.py import math import random +from typing import Any, List, Optional class my_queue: - def __init__(self): - self.data = [] - self.head = 0 - self.tail = 0 + def __init__(self) -> None: + self.data: List[Any] = [] + self.head: int = 0 + self.tail: int = 0 - def is_empty(self): + def is_empty(self) -> bool: return self.head == self.tail - def push(self, data): + def push(self, data: Any) -> None: self.data.append(data) self.tail = self.tail + 1 - def pop(self): + def pop(self) -> Any: ret = self.data[self.head] self.head = self.head + 1 return ret - def count(self): + def count(self) -> int: return self.tail - self.head - def print(self): + def print(self) -> None: print(self.data) print("**************") print(self.data[self.head : self.tail]) class my_node: - def __init__(self, data): + def __init__(self, data: Any) -> None: self.data = data - self.left = None - self.right = None - self.height = 1 + self.left: Optional[my_node] = None + self.right: Optional[my_node] = None + self.height: int = 1 - def get_data(self): + def get_data(self) -> Any: return self.data - def get_left(self): + def get_left(self) -> Optional["my_node"]: return self.left - def get_right(self): + def get_right(self) -> Optional["my_node"]: return self.right - def get_height(self): + def get_height(self) -> int: return self.height - def set_data(self, data): + def set_data(self, data: Any) -> None: self.data = data return - def set_left(self, node): + def set_left(self, node: Optional["my_node"]) -> None: self.left = node return - def set_right(self, node): + def set_right(self, node: Optional["my_node"]) -> None: self.right = node return - def set_height(self, height): + def set_height(self, height: int) -> None: self.height = height return -def get_height(node): +def get_height(node: Optional["my_node"]) -> int: if node is None: return 0 return node.get_height() -def my_max(a, b): +def my_max(a: int, b: int) -> int: if a > b: return a return b -def right_rotation(node): +def right_rotation(node: my_node) -> my_node: r""" A B / \ / \ @@ -98,6 +99,7 @@ def right_rotation(node): """ print("left rotation node:", node.get_data()) ret = node.get_left() + assert ret is not None node.set_left(ret.get_right()) ret.set_right(node) h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1 @@ -107,12 +109,13 @@ def right_rotation(node): return ret -def left_rotation(node): +def left_rotation(node: my_node) -> my_node: """ a mirror symmetry rotation of the left_rotation """ print("right rotation node:", node.get_data()) ret = node.get_right() + assert ret is not None node.set_right(ret.get_left()) ret.set_left(node) h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1 @@ -122,7 +125,7 @@ def left_rotation(node): return ret -def lr_rotation(node): +def lr_rotation(node: my_node) -> my_node: r""" A A Br / \ / \ / \ @@ -133,16 +136,20 @@ def lr_rotation(node): UB Bl RR = right_rotation LR = left_rotation """ - node.set_left(left_rotation(node.get_left())) + left_child = node.get_left() + assert left_child is not None + node.set_left(left_rotation(left_child)) return right_rotation(node) -def rl_rotation(node): - node.set_right(right_rotation(node.get_right())) +def rl_rotation(node: my_node) -> my_node: + right_child = node.get_right() + assert right_child is not None + node.set_right(right_rotation(right_child)) return left_rotation(node) -def insert_node(node, data): +def insert_node(node: Optional["my_node"], data: Any) -> Optional["my_node"]: if node is None: return my_node(data) if data < node.get_data(): @@ -150,8 +157,10 @@ def insert_node(node, data): if ( get_height(node.get_left()) - get_height(node.get_right()) == 2 ): # an unbalance detected + left_child = node.get_left() + assert left_child is not None if ( - data < node.get_left().get_data() + data < left_child.get_data() ): # new node is the left child of the left child node = right_rotation(node) else: @@ -159,7 +168,9 @@ def insert_node(node, data): else: node.set_right(insert_node(node.get_right(), data)) if get_height(node.get_right()) - get_height(node.get_left()) == 2: - if data < node.get_right().get_data(): + right_child = node.get_right() + assert right_child is not None + if data < right_child.get_data(): node = rl_rotation(node) else: node = left_rotation(node) @@ -168,52 +179,59 @@ def insert_node(node, data): return node -def get_rightMost(root): - while root.get_right() is not None: - root = root.get_right() +def get_rightMost(root: my_node) -> Any: + while True: + right_child = root.get_right() + if right_child is None: + break + root = right_child return root.get_data() -def get_leftMost(root): - while root.get_left() is not None: - root = root.get_left() +def get_leftMost(root: my_node) -> Any: + while True: + left_child = root.get_left() + if left_child is None: + break + root = left_child return root.get_data() -def del_node(root, data): +def del_node(root: my_node, data: Any) -> Optional["my_node"]: + left_child = root.get_left() + right_child = root.get_right() if root.get_data() == data: - if root.get_left() is not None and root.get_right() is not None: - temp_data = get_leftMost(root.get_right()) + if left_child is not None and right_child is not None: + temp_data = get_leftMost(right_child) root.set_data(temp_data) - root.set_right(del_node(root.get_right(), temp_data)) - elif root.get_left() is not None: - root = root.get_left() + root.set_right(del_node(right_child, temp_data)) + elif left_child is not None: + root = left_child + elif right_child is not None: + root = right_child else: - root = root.get_right() + return None elif root.get_data() > data: - if root.get_left() is None: + if left_child is None: print("No such data") return root else: - root.set_left(del_node(root.get_left(), data)) - elif root.get_data() < data: - if root.get_right() is None: + root.set_left(del_node(left_child, data)) + else: # root.get_data() < data + if right_child is None: return root else: - root.set_right(del_node(root.get_right(), data)) - if root is None: - return root - if get_height(root.get_right()) - get_height(root.get_left()) == 2: - if get_height(root.get_right().get_right()) > get_height( - root.get_right().get_left() - ): + root.set_right(del_node(right_child, data)) + + if get_height(right_child) - get_height(left_child) == 2: + assert right_child is not None + if get_height(right_child.get_right()) > get_height(right_child.get_left()): root = left_rotation(root) else: root = rl_rotation(root) - elif get_height(root.get_right()) - get_height(root.get_left()) == -2: - if get_height(root.get_left().get_left()) > get_height( - root.get_left().get_right() - ): + elif get_height(right_child) - get_height(left_child) == -2: + assert left_child is not None + if get_height(left_child.get_left()) > get_height(left_child.get_right()): root = right_rotation(root) else: root = lr_rotation(root) @@ -256,25 +274,26 @@ class AVLtree: ************************************* """ - def __init__(self): - self.root = None + def __init__(self) -> None: + self.root: Optional[my_node] = None - def get_height(self): - # print("yyy") + def get_height(self) -> int: return get_height(self.root) - def insert(self, data): + def insert(self, data: Any) -> None: print("insert:" + str(data)) self.root = insert_node(self.root, data) - def del_node(self, data): + def del_node(self, data: Any) -> None: print("delete:" + str(data)) if self.root is None: print("Tree is empty!") return self.root = del_node(self.root, data) - def __str__(self): # a level traversale, gives a more intuitive look on the tree + def __str__( + self, + ) -> str: # a level traversale, gives a more intuitive look on the tree output = "" q = my_queue() q.push(self.root) @@ -308,7 +327,7 @@ class AVLtree: return output -def _test(): +def _test() -> None: import doctest doctest.testmod()