mirror of
https://hub.njuu.cf/TheAlgorithms/Python.git
synced 2023-10-11 13:06:12 +08:00
Fix mypy errors at bidirectional_a_star (#4556)
This commit is contained in:
parent
72aa4cc315
commit
4a2216b69a
@ -8,6 +8,8 @@ import time
|
||||
from math import sqrt
|
||||
|
||||
# 1 for manhattan, 0 for euclidean
|
||||
from typing import Optional
|
||||
|
||||
HEURISTIC = 0
|
||||
|
||||
grid = [
|
||||
@ -22,6 +24,8 @@ grid = [
|
||||
|
||||
delta = [[-1, 0], [0, -1], [1, 0], [0, 1]] # up, left, down, right
|
||||
|
||||
TPosition = tuple[int, int]
|
||||
|
||||
|
||||
class Node:
|
||||
"""
|
||||
@ -39,7 +43,15 @@ class Node:
|
||||
True
|
||||
"""
|
||||
|
||||
def __init__(self, pos_x, pos_y, goal_x, goal_y, g_cost, parent):
|
||||
def __init__(
|
||||
self,
|
||||
pos_x: int,
|
||||
pos_y: int,
|
||||
goal_x: int,
|
||||
goal_y: int,
|
||||
g_cost: int,
|
||||
parent: Optional[Node],
|
||||
) -> None:
|
||||
self.pos_x = pos_x
|
||||
self.pos_y = pos_y
|
||||
self.pos = (pos_y, pos_x)
|
||||
@ -61,7 +73,7 @@ class Node:
|
||||
else:
|
||||
return sqrt(dy ** 2 + dx ** 2)
|
||||
|
||||
def __lt__(self, other) -> bool:
|
||||
def __lt__(self, other: Node) -> bool:
|
||||
return self.f_cost < other.f_cost
|
||||
|
||||
|
||||
@ -81,23 +93,22 @@ class AStar:
|
||||
(4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)]
|
||||
"""
|
||||
|
||||
def __init__(self, start, goal):
|
||||
def __init__(self, start: TPosition, goal: TPosition):
|
||||
self.start = Node(start[1], start[0], goal[1], goal[0], 0, None)
|
||||
self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None)
|
||||
|
||||
self.open_nodes = [self.start]
|
||||
self.closed_nodes = []
|
||||
self.closed_nodes: list[Node] = []
|
||||
|
||||
self.reached = False
|
||||
|
||||
def search(self) -> list[tuple[int]]:
|
||||
def search(self) -> list[TPosition]:
|
||||
while self.open_nodes:
|
||||
# Open Nodes are sorted using __lt__
|
||||
self.open_nodes.sort()
|
||||
current_node = self.open_nodes.pop(0)
|
||||
|
||||
if current_node.pos == self.target.pos:
|
||||
self.reached = True
|
||||
return self.retrace_path(current_node)
|
||||
|
||||
self.closed_nodes.append(current_node)
|
||||
@ -118,8 +129,7 @@ class AStar:
|
||||
else:
|
||||
self.open_nodes.append(better_node)
|
||||
|
||||
if not (self.reached):
|
||||
return [(self.start.pos)]
|
||||
return [self.start.pos]
|
||||
|
||||
def get_successors(self, parent: Node) -> list[Node]:
|
||||
"""
|
||||
@ -147,7 +157,7 @@ class AStar:
|
||||
)
|
||||
return successors
|
||||
|
||||
def retrace_path(self, node: Node) -> list[tuple[int]]:
|
||||
def retrace_path(self, node: Optional[Node]) -> list[TPosition]:
|
||||
"""
|
||||
Retrace the path from parents to parents until start node
|
||||
"""
|
||||
@ -173,12 +183,12 @@ class BidirectionalAStar:
|
||||
(2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)]
|
||||
"""
|
||||
|
||||
def __init__(self, start, goal):
|
||||
def __init__(self, start: TPosition, goal: TPosition) -> None:
|
||||
self.fwd_astar = AStar(start, goal)
|
||||
self.bwd_astar = AStar(goal, start)
|
||||
self.reached = False
|
||||
|
||||
def search(self) -> list[tuple[int]]:
|
||||
def search(self) -> list[TPosition]:
|
||||
while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes:
|
||||
self.fwd_astar.open_nodes.sort()
|
||||
self.bwd_astar.open_nodes.sort()
|
||||
@ -186,7 +196,6 @@ class BidirectionalAStar:
|
||||
current_bwd_node = self.bwd_astar.open_nodes.pop(0)
|
||||
|
||||
if current_bwd_node.pos == current_fwd_node.pos:
|
||||
self.reached = True
|
||||
return self.retrace_bidirectional_path(
|
||||
current_fwd_node, current_bwd_node
|
||||
)
|
||||
@ -220,12 +229,11 @@ class BidirectionalAStar:
|
||||
else:
|
||||
astar.open_nodes.append(better_node)
|
||||
|
||||
if not self.reached:
|
||||
return [self.fwd_astar.start.pos]
|
||||
|
||||
def retrace_bidirectional_path(
|
||||
self, fwd_node: Node, bwd_node: Node
|
||||
) -> list[tuple[int]]:
|
||||
) -> list[TPosition]:
|
||||
fwd_path = self.fwd_astar.retrace_path(fwd_node)
|
||||
bwd_path = self.bwd_astar.retrace_path(bwd_node)
|
||||
bwd_path.pop()
|
||||
@ -236,9 +244,6 @@ class BidirectionalAStar:
|
||||
|
||||
if __name__ == "__main__":
|
||||
# all coordinates are given in format [y,x]
|
||||
import doctest
|
||||
|
||||
doctest.testmod()
|
||||
init = (0, 0)
|
||||
goal = (len(grid) - 1, len(grid[0]) - 1)
|
||||
for elem in grid:
|
||||
@ -252,6 +257,5 @@ if __name__ == "__main__":
|
||||
|
||||
bd_start_time = time.time()
|
||||
bidir_astar = BidirectionalAStar(init, goal)
|
||||
path = bidir_astar.search()
|
||||
bd_end_time = time.time() - bd_start_time
|
||||
print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds")
|
||||
|
Loading…
Reference in New Issue
Block a user