Fix mypy at prims_algo_2 (#4527)

This commit is contained in:
Hasanul Islam 2021-07-05 12:23:18 +06:00 committed by GitHub
parent 86baec0bc9
commit 95862303a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,7 +8,9 @@ connection from the tree to another vertex.
""" """
from sys import maxsize from sys import maxsize
from typing import Dict, Optional, Tuple, Union from typing import Generic, Optional, TypeVar
T = TypeVar("T")
def get_parent_position(position: int) -> int: def get_parent_position(position: int) -> int:
@ -43,7 +45,7 @@ def get_child_right_position(position: int) -> int:
return (2 * position) + 2 return (2 * position) + 2
class MinPriorityQueue: class MinPriorityQueue(Generic[T]):
""" """
Minimum Priority Queue Class Minimum Priority Queue Class
@ -80,9 +82,9 @@ class MinPriorityQueue:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.heap = [] self.heap: list[tuple[T, int]] = []
self.position_map = {} self.position_map: dict[T, int] = {}
self.elements = 0 self.elements: int = 0
def __len__(self) -> int: def __len__(self) -> int:
return self.elements return self.elements
@ -94,14 +96,14 @@ class MinPriorityQueue:
# Check if the priority queue is empty # Check if the priority queue is empty
return self.elements == 0 return self.elements == 0
def push(self, elem: Union[int, str], weight: int) -> None: def push(self, elem: T, weight: int) -> None:
# Add an element with given priority to the queue # Add an element with given priority to the queue
self.heap.append((elem, weight)) self.heap.append((elem, weight))
self.position_map[elem] = self.elements self.position_map[elem] = self.elements
self.elements += 1 self.elements += 1
self._bubble_up(elem) self._bubble_up(elem)
def extract_min(self) -> Union[int, str]: def extract_min(self) -> T:
# Remove and return the element with lowest weight (highest priority) # Remove and return the element with lowest weight (highest priority)
if self.elements > 1: if self.elements > 1:
self._swap_nodes(0, self.elements - 1) self._swap_nodes(0, self.elements - 1)
@ -113,7 +115,7 @@ class MinPriorityQueue:
self._bubble_down(bubble_down_elem) self._bubble_down(bubble_down_elem)
return elem return elem
def update_key(self, elem: Union[int, str], weight: int) -> None: def update_key(self, elem: T, weight: int) -> None:
# Update the weight of the given key # Update the weight of the given key
position = self.position_map[elem] position = self.position_map[elem]
self.heap[position] = (elem, weight) self.heap[position] = (elem, weight)
@ -127,7 +129,7 @@ class MinPriorityQueue:
else: else:
self._bubble_down(elem) self._bubble_down(elem)
def _bubble_up(self, elem: Union[int, str]) -> None: def _bubble_up(self, elem: T) -> None:
# Place a node at the proper position (upward movement) [to be used internally # Place a node at the proper position (upward movement) [to be used internally
# only] # only]
curr_pos = self.position_map[elem] curr_pos = self.position_map[elem]
@ -141,7 +143,7 @@ class MinPriorityQueue:
return self._bubble_up(elem) return self._bubble_up(elem)
return return
def _bubble_down(self, elem: Union[int, str]) -> None: def _bubble_down(self, elem: T) -> None:
# Place a node at the proper position (downward movement) [to be used # Place a node at the proper position (downward movement) [to be used
# internally only] # internally only]
curr_pos = self.position_map[elem] curr_pos = self.position_map[elem]
@ -182,7 +184,7 @@ class MinPriorityQueue:
self.position_map[node2_elem] = node1_pos self.position_map[node2_elem] = node1_pos
class GraphUndirectedWeighted: class GraphUndirectedWeighted(Generic[T]):
""" """
Graph Undirected Weighted Class Graph Undirected Weighted Class
@ -192,8 +194,8 @@ class GraphUndirectedWeighted:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.connections = {} self.connections: dict[T, dict[T, int]] = {}
self.nodes = 0 self.nodes: int = 0
def __repr__(self) -> str: def __repr__(self) -> str:
return str(self.connections) return str(self.connections)
@ -201,15 +203,13 @@ class GraphUndirectedWeighted:
def __len__(self) -> int: def __len__(self) -> int:
return self.nodes return self.nodes
def add_node(self, node: Union[int, str]) -> None: def add_node(self, node: T) -> None:
# Add a node in the graph if it is not in the graph # Add a node in the graph if it is not in the graph
if node not in self.connections: if node not in self.connections:
self.connections[node] = {} self.connections[node] = {}
self.nodes += 1 self.nodes += 1
def add_edge( def add_edge(self, node1: T, node2: T, weight: int) -> None:
self, node1: Union[int, str], node2: Union[int, str], weight: int
) -> None:
# Add an edge between 2 nodes in the graph # Add an edge between 2 nodes in the graph
self.add_node(node1) self.add_node(node1)
self.add_node(node2) self.add_node(node2)
@ -218,8 +218,8 @@ class GraphUndirectedWeighted:
def prims_algo( def prims_algo(
graph: GraphUndirectedWeighted, graph: GraphUndirectedWeighted[T],
) -> Tuple[Dict[str, int], Dict[str, Optional[str]]]: ) -> tuple[dict[T, int], dict[T, Optional[T]]]:
""" """
>>> graph = GraphUndirectedWeighted() >>> graph = GraphUndirectedWeighted()
@ -239,10 +239,13 @@ def prims_algo(
13 13
""" """
# prim's algorithm for minimum spanning tree # prim's algorithm for minimum spanning tree
dist = {node: maxsize for node in graph.connections} dist: dict[T, int] = {node: maxsize for node in graph.connections}
parent = {node: None for node in graph.connections} parent: dict[T, Optional[T]] = {node: None for node in graph.connections}
priority_queue = MinPriorityQueue()
[priority_queue.push(node, weight) for node, weight in dist.items()] priority_queue: MinPriorityQueue[T] = MinPriorityQueue()
for node, weight in dist.items():
priority_queue.push(node, weight)
if priority_queue.is_empty(): if priority_queue.is_empty():
return dist, parent return dist, parent
@ -254,6 +257,7 @@ def prims_algo(
dist[neighbour] = dist[node] + graph.connections[node][neighbour] dist[neighbour] = dist[node] + graph.connections[node][neighbour]
priority_queue.update_key(neighbour, dist[neighbour]) priority_queue.update_key(neighbour, dist[neighbour])
parent[neighbour] = node parent[neighbour] = node
# running prim's algorithm # running prim's algorithm
while not priority_queue.is_empty(): while not priority_queue.is_empty():
node = priority_queue.extract_min() node = priority_queue.extract_min()
@ -263,9 +267,3 @@ def prims_algo(
priority_queue.update_key(neighbour, dist[neighbour]) priority_queue.update_key(neighbour, dist[neighbour])
parent[neighbour] = node parent[neighbour] = node
return dist, parent return dist, parent
if __name__ == "__main__":
from doctest import testmod
testmod()