mirror of
https://hub.njuu.cf/TheAlgorithms/Python.git
synced 2023-10-11 13:06:12 +08:00
Fix mypy errors at kruskal_2 (#4528)
This commit is contained in:
parent
4412eafaac
commit
256c319ce2
@ -1,78 +1,93 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
class DisjointSetTreeNode:
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class DisjointSetTreeNode(Generic[T]):
|
||||||
# Disjoint Set Node to store the parent and rank
|
# Disjoint Set Node to store the parent and rank
|
||||||
def __init__(self, key: int) -> None:
|
def __init__(self, data: T) -> None:
|
||||||
self.key = key
|
self.data = data
|
||||||
self.parent = self
|
self.parent = self
|
||||||
self.rank = 0
|
self.rank = 0
|
||||||
|
|
||||||
|
|
||||||
class DisjointSetTree:
|
class DisjointSetTree(Generic[T]):
|
||||||
# Disjoint Set DataStructure
|
# Disjoint Set DataStructure
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
# map from node name to the node object
|
# map from node name to the node object
|
||||||
self.map = {}
|
self.map: dict[T, DisjointSetTreeNode[T]] = {}
|
||||||
|
|
||||||
def make_set(self, x: int) -> None:
|
def make_set(self, data: T) -> None:
|
||||||
# create a new set with x as its member
|
# create a new set with x as its member
|
||||||
self.map[x] = DisjointSetTreeNode(x)
|
self.map[data] = DisjointSetTreeNode(data)
|
||||||
|
|
||||||
def find_set(self, x: int) -> DisjointSetTreeNode:
|
def find_set(self, data: T) -> DisjointSetTreeNode[T]:
|
||||||
# find the set x belongs to (with path-compression)
|
# find the set x belongs to (with path-compression)
|
||||||
elem_ref = self.map[x]
|
elem_ref = self.map[data]
|
||||||
if elem_ref != elem_ref.parent:
|
if elem_ref != elem_ref.parent:
|
||||||
elem_ref.parent = self.find_set(elem_ref.parent.key)
|
elem_ref.parent = self.find_set(elem_ref.parent.data)
|
||||||
return elem_ref.parent
|
return elem_ref.parent
|
||||||
|
|
||||||
def link(self, x: int, y: int) -> None:
|
def link(
|
||||||
|
self, node1: DisjointSetTreeNode[T], node2: DisjointSetTreeNode[T]
|
||||||
|
) -> None:
|
||||||
# helper function for union operation
|
# helper function for union operation
|
||||||
if x.rank > y.rank:
|
if node1.rank > node2.rank:
|
||||||
y.parent = x
|
node2.parent = node1
|
||||||
else:
|
else:
|
||||||
x.parent = y
|
node1.parent = node2
|
||||||
if x.rank == y.rank:
|
if node1.rank == node2.rank:
|
||||||
y.rank += 1
|
node2.rank += 1
|
||||||
|
|
||||||
def union(self, x: int, y: int) -> None:
|
def union(self, data1: T, data2: T) -> None:
|
||||||
# merge 2 disjoint sets
|
# merge 2 disjoint sets
|
||||||
self.link(self.find_set(x), self.find_set(y))
|
self.link(self.find_set(data1), self.find_set(data2))
|
||||||
|
|
||||||
|
|
||||||
class GraphUndirectedWeighted:
|
class GraphUndirectedWeighted(Generic[T]):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
# connections: map from the node to the neighbouring nodes (with weights)
|
# connections: map from the node to the neighbouring nodes (with weights)
|
||||||
self.connections = {}
|
self.connections: dict[T, dict[T, int]] = {}
|
||||||
|
|
||||||
def add_node(self, node: int) -> None:
|
def add_node(self, node: T) -> None:
|
||||||
# add a node ONLY if its not present in the graph
|
# add a node ONLY if its not present in the graph
|
||||||
if node not in self.connections:
|
if node not in self.connections:
|
||||||
self.connections[node] = {}
|
self.connections[node] = {}
|
||||||
|
|
||||||
def add_edge(self, node1: int, node2: int, weight: int) -> None:
|
def add_edge(self, node1: T, node2: T, weight: int) -> None:
|
||||||
# add an edge with the given weight
|
# add an edge with the given weight
|
||||||
self.add_node(node1)
|
self.add_node(node1)
|
||||||
self.add_node(node2)
|
self.add_node(node2)
|
||||||
self.connections[node1][node2] = weight
|
self.connections[node1][node2] = weight
|
||||||
self.connections[node2][node1] = weight
|
self.connections[node2][node1] = weight
|
||||||
|
|
||||||
def kruskal(self) -> GraphUndirectedWeighted:
|
def kruskal(self) -> GraphUndirectedWeighted[T]:
|
||||||
# Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph
|
# Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph
|
||||||
"""
|
"""
|
||||||
Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
|
Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
>>> g1 = GraphUndirectedWeighted[int]()
|
||||||
>>> graph = GraphUndirectedWeighted()
|
>>> g1.add_edge(1, 2, 1)
|
||||||
>>> graph.add_edge(1, 2, 1)
|
>>> g1.add_edge(2, 3, 2)
|
||||||
>>> graph.add_edge(2, 3, 2)
|
>>> g1.add_edge(3, 4, 1)
|
||||||
>>> graph.add_edge(3, 4, 1)
|
>>> g1.add_edge(3, 5, 100) # Removed in MST
|
||||||
>>> graph.add_edge(3, 5, 100) # Removed in MST
|
>>> g1.add_edge(4, 5, 5)
|
||||||
>>> graph.add_edge(4, 5, 5)
|
>>> assert 5 in g1.connections[3]
|
||||||
>>> assert 5 in graph.connections[3]
|
>>> mst = g1.kruskal()
|
||||||
>>> mst = graph.kruskal()
|
|
||||||
>>> assert 5 not in mst.connections[3]
|
>>> assert 5 not in mst.connections[3]
|
||||||
|
|
||||||
|
>>> g2 = GraphUndirectedWeighted[str]()
|
||||||
|
>>> g2.add_edge('A', 'B', 1)
|
||||||
|
>>> g2.add_edge('B', 'C', 2)
|
||||||
|
>>> g2.add_edge('C', 'D', 1)
|
||||||
|
>>> g2.add_edge('C', 'E', 100) # Removed in MST
|
||||||
|
>>> g2.add_edge('D', 'E', 5)
|
||||||
|
>>> assert 'E' in g2.connections["C"]
|
||||||
|
>>> mst = g2.kruskal()
|
||||||
|
>>> assert 'E' not in mst.connections['C']
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# getting the edges in ascending order of weights
|
# getting the edges in ascending order of weights
|
||||||
@ -84,26 +99,23 @@ class GraphUndirectedWeighted:
|
|||||||
seen.add((end, start))
|
seen.add((end, start))
|
||||||
edges.append((start, end, self.connections[start][end]))
|
edges.append((start, end, self.connections[start][end]))
|
||||||
edges.sort(key=lambda x: x[2])
|
edges.sort(key=lambda x: x[2])
|
||||||
|
|
||||||
# creating the disjoint set
|
# creating the disjoint set
|
||||||
disjoint_set = DisjointSetTree()
|
disjoint_set = DisjointSetTree[T]()
|
||||||
[disjoint_set.make_set(node) for node in self.connections]
|
for node in self.connections:
|
||||||
|
disjoint_set.make_set(node)
|
||||||
|
|
||||||
# MST generation
|
# MST generation
|
||||||
num_edges = 0
|
num_edges = 0
|
||||||
index = 0
|
index = 0
|
||||||
graph = GraphUndirectedWeighted()
|
graph = GraphUndirectedWeighted[T]()
|
||||||
while num_edges < len(self.connections) - 1:
|
while num_edges < len(self.connections) - 1:
|
||||||
u, v, w = edges[index]
|
u, v, w = edges[index]
|
||||||
index += 1
|
index += 1
|
||||||
parentu = disjoint_set.find_set(u)
|
parent_u = disjoint_set.find_set(u)
|
||||||
parentv = disjoint_set.find_set(v)
|
parent_v = disjoint_set.find_set(v)
|
||||||
if parentu != parentv:
|
if parent_u != parent_v:
|
||||||
num_edges += 1
|
num_edges += 1
|
||||||
graph.add_edge(u, v, w)
|
graph.add_edge(u, v, w)
|
||||||
disjoint_set.union(u, v)
|
disjoint_set.union(u, v)
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import doctest
|
|
||||||
|
|
||||||
doctest.testmod()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user