Fix mypy at prims_algo_2 (#4527)

This commit is contained in:
Hasanul Islam 2021-07-05 12:23:18 +06:00 committed by GitHub
parent 86baec0bc9
commit 95862303a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,7 +8,9 @@ connection from the tree to another vertex.
"""
from sys import maxsize
from typing import Dict, Optional, Tuple, Union
from typing import Generic, Optional, TypeVar
T = TypeVar("T")
def get_parent_position(position: int) -> int:
@ -43,7 +45,7 @@ def get_child_right_position(position: int) -> int:
return (2 * position) + 2
class MinPriorityQueue:
class MinPriorityQueue(Generic[T]):
"""
Minimum Priority Queue Class
@ -80,9 +82,9 @@ class MinPriorityQueue:
"""
def __init__(self) -> None:
self.heap = []
self.position_map = {}
self.elements = 0
self.heap: list[tuple[T, int]] = []
self.position_map: dict[T, int] = {}
self.elements: int = 0
def __len__(self) -> int:
return self.elements
@ -94,14 +96,14 @@ class MinPriorityQueue:
# Check if the priority queue is empty
return self.elements == 0
def push(self, elem: Union[int, str], weight: int) -> None:
def push(self, elem: T, weight: int) -> None:
# Add an element with given priority to the queue
self.heap.append((elem, weight))
self.position_map[elem] = self.elements
self.elements += 1
self._bubble_up(elem)
def extract_min(self) -> Union[int, str]:
def extract_min(self) -> T:
# Remove and return the element with lowest weight (highest priority)
if self.elements > 1:
self._swap_nodes(0, self.elements - 1)
@ -113,7 +115,7 @@ class MinPriorityQueue:
self._bubble_down(bubble_down_elem)
return elem
def update_key(self, elem: Union[int, str], weight: int) -> None:
def update_key(self, elem: T, weight: int) -> None:
# Update the weight of the given key
position = self.position_map[elem]
self.heap[position] = (elem, weight)
@ -127,7 +129,7 @@ class MinPriorityQueue:
else:
self._bubble_down(elem)
def _bubble_up(self, elem: Union[int, str]) -> None:
def _bubble_up(self, elem: T) -> None:
# Place a node at the proper position (upward movement) [to be used internally
# only]
curr_pos = self.position_map[elem]
@ -141,7 +143,7 @@ class MinPriorityQueue:
return self._bubble_up(elem)
return
def _bubble_down(self, elem: Union[int, str]) -> None:
def _bubble_down(self, elem: T) -> None:
# Place a node at the proper position (downward movement) [to be used
# internally only]
curr_pos = self.position_map[elem]
@ -182,7 +184,7 @@ class MinPriorityQueue:
self.position_map[node2_elem] = node1_pos
class GraphUndirectedWeighted:
class GraphUndirectedWeighted(Generic[T]):
"""
Graph Undirected Weighted Class
@ -192,8 +194,8 @@ class GraphUndirectedWeighted:
"""
def __init__(self) -> None:
self.connections = {}
self.nodes = 0
self.connections: dict[T, dict[T, int]] = {}
self.nodes: int = 0
def __repr__(self) -> str:
return str(self.connections)
@ -201,15 +203,13 @@ class GraphUndirectedWeighted:
def __len__(self) -> int:
return self.nodes
def add_node(self, node: Union[int, str]) -> None:
def add_node(self, node: T) -> None:
# Add a node in the graph if it is not in the graph
if node not in self.connections:
self.connections[node] = {}
self.nodes += 1
def add_edge(
self, node1: Union[int, str], node2: Union[int, str], weight: int
) -> None:
def add_edge(self, node1: T, node2: T, weight: int) -> None:
# Add an edge between 2 nodes in the graph
self.add_node(node1)
self.add_node(node2)
@ -218,8 +218,8 @@ class GraphUndirectedWeighted:
def prims_algo(
graph: GraphUndirectedWeighted,
) -> Tuple[Dict[str, int], Dict[str, Optional[str]]]:
graph: GraphUndirectedWeighted[T],
) -> tuple[dict[T, int], dict[T, Optional[T]]]:
"""
>>> graph = GraphUndirectedWeighted()
@ -239,10 +239,13 @@ def prims_algo(
13
"""
# prim's algorithm for minimum spanning tree
dist = {node: maxsize for node in graph.connections}
parent = {node: None for node in graph.connections}
priority_queue = MinPriorityQueue()
[priority_queue.push(node, weight) for node, weight in dist.items()]
dist: dict[T, int] = {node: maxsize for node in graph.connections}
parent: dict[T, Optional[T]] = {node: None for node in graph.connections}
priority_queue: MinPriorityQueue[T] = MinPriorityQueue()
for node, weight in dist.items():
priority_queue.push(node, weight)
if priority_queue.is_empty():
return dist, parent
@ -254,6 +257,7 @@ def prims_algo(
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
priority_queue.update_key(neighbour, dist[neighbour])
parent[neighbour] = node
# running prim's algorithm
while not priority_queue.is_empty():
node = priority_queue.extract_min()
@ -263,9 +267,3 @@ def prims_algo(
priority_queue.update_key(neighbour, dist[neighbour])
parent[neighbour] = node
return dist, parent
if __name__ == "__main__":
from doctest import testmod
testmod()