From 256c319ce231eb0a158ec3506e2236d48ca4d6a5 Mon Sep 17 00:00:00 2001 From: Hasanul Islam Date: Thu, 8 Jul 2021 12:46:43 +0600 Subject: [PATCH] Fix mypy errors at kruskal_2 (#4528) --- graphs/minimum_spanning_tree_kruskal2.py | 104 +++++++++++++---------- 1 file changed, 58 insertions(+), 46 deletions(-) diff --git a/graphs/minimum_spanning_tree_kruskal2.py b/graphs/minimum_spanning_tree_kruskal2.py index dfb87efeb..0ddb43ce8 100644 --- a/graphs/minimum_spanning_tree_kruskal2.py +++ b/graphs/minimum_spanning_tree_kruskal2.py @@ -1,78 +1,93 @@ from __future__ import annotations +from typing import Generic, TypeVar -class DisjointSetTreeNode: +T = TypeVar("T") + + +class DisjointSetTreeNode(Generic[T]): # Disjoint Set Node to store the parent and rank - def __init__(self, key: int) -> None: - self.key = key + def __init__(self, data: T) -> None: + self.data = data self.parent = self self.rank = 0 -class DisjointSetTree: +class DisjointSetTree(Generic[T]): # Disjoint Set DataStructure - def __init__(self): + def __init__(self) -> None: # map from node name to the node object - self.map = {} + self.map: dict[T, DisjointSetTreeNode[T]] = {} - def make_set(self, x: int) -> None: + def make_set(self, data: T) -> None: # create a new set with x as its member - self.map[x] = DisjointSetTreeNode(x) + self.map[data] = DisjointSetTreeNode(data) - def find_set(self, x: int) -> DisjointSetTreeNode: + def find_set(self, data: T) -> DisjointSetTreeNode[T]: # find the set x belongs to (with path-compression) - elem_ref = self.map[x] + elem_ref = self.map[data] if elem_ref != elem_ref.parent: - elem_ref.parent = self.find_set(elem_ref.parent.key) + elem_ref.parent = self.find_set(elem_ref.parent.data) return elem_ref.parent - def link(self, x: int, y: int) -> None: + def link( + self, node1: DisjointSetTreeNode[T], node2: DisjointSetTreeNode[T] + ) -> None: # helper function for union operation - if x.rank > y.rank: - y.parent = x + if node1.rank > node2.rank: + node2.parent = node1 else: - x.parent = y - if x.rank == y.rank: - y.rank += 1 + node1.parent = node2 + if node1.rank == node2.rank: + node2.rank += 1 - def union(self, x: int, y: int) -> None: + def union(self, data1: T, data2: T) -> None: # merge 2 disjoint sets - self.link(self.find_set(x), self.find_set(y)) + self.link(self.find_set(data1), self.find_set(data2)) -class GraphUndirectedWeighted: - def __init__(self): +class GraphUndirectedWeighted(Generic[T]): + def __init__(self) -> None: # connections: map from the node to the neighbouring nodes (with weights) - self.connections = {} + self.connections: dict[T, dict[T, int]] = {} - def add_node(self, node: int) -> None: + def add_node(self, node: T) -> None: # add a node ONLY if its not present in the graph if node not in self.connections: self.connections[node] = {} - def add_edge(self, node1: int, node2: int, weight: int) -> None: + def add_edge(self, node1: T, node2: T, weight: int) -> None: # add an edge with the given weight self.add_node(node1) self.add_node(node2) self.connections[node1][node2] = weight self.connections[node2][node1] = weight - def kruskal(self) -> GraphUndirectedWeighted: + def kruskal(self) -> GraphUndirectedWeighted[T]: # Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph """ Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm Example: - - >>> graph = GraphUndirectedWeighted() - >>> graph.add_edge(1, 2, 1) - >>> graph.add_edge(2, 3, 2) - >>> graph.add_edge(3, 4, 1) - >>> graph.add_edge(3, 5, 100) # Removed in MST - >>> graph.add_edge(4, 5, 5) - >>> assert 5 in graph.connections[3] - >>> mst = graph.kruskal() + >>> g1 = GraphUndirectedWeighted[int]() + >>> g1.add_edge(1, 2, 1) + >>> g1.add_edge(2, 3, 2) + >>> g1.add_edge(3, 4, 1) + >>> g1.add_edge(3, 5, 100) # Removed in MST + >>> g1.add_edge(4, 5, 5) + >>> assert 5 in g1.connections[3] + >>> mst = g1.kruskal() >>> assert 5 not in mst.connections[3] + + >>> g2 = GraphUndirectedWeighted[str]() + >>> g2.add_edge('A', 'B', 1) + >>> g2.add_edge('B', 'C', 2) + >>> g2.add_edge('C', 'D', 1) + >>> g2.add_edge('C', 'E', 100) # Removed in MST + >>> g2.add_edge('D', 'E', 5) + >>> assert 'E' in g2.connections["C"] + >>> mst = g2.kruskal() + >>> assert 'E' not in mst.connections['C'] """ # getting the edges in ascending order of weights @@ -84,26 +99,23 @@ class GraphUndirectedWeighted: seen.add((end, start)) edges.append((start, end, self.connections[start][end])) edges.sort(key=lambda x: x[2]) + # creating the disjoint set - disjoint_set = DisjointSetTree() - [disjoint_set.make_set(node) for node in self.connections] + disjoint_set = DisjointSetTree[T]() + for node in self.connections: + disjoint_set.make_set(node) + # MST generation num_edges = 0 index = 0 - graph = GraphUndirectedWeighted() + graph = GraphUndirectedWeighted[T]() while num_edges < len(self.connections) - 1: u, v, w = edges[index] index += 1 - parentu = disjoint_set.find_set(u) - parentv = disjoint_set.find_set(v) - if parentu != parentv: + parent_u = disjoint_set.find_set(u) + parent_v = disjoint_set.find_set(v) + if parent_u != parent_v: num_edges += 1 graph.add_edge(u, v, w) disjoint_set.union(u, v) return graph - - -if __name__ == "__main__": - import doctest - - doctest.testmod()