Fix mypy errors at bidirectional_a_star (#4556)

This commit is contained in:
Hasanul Islam 2021-07-20 13:36:14 +06:00 committed by GitHub
parent 72aa4cc315
commit 4a2216b69a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,6 +8,8 @@ import time
from math import sqrt
# 1 for manhattan, 0 for euclidean
from typing import Optional
HEURISTIC = 0
grid = [
@ -22,6 +24,8 @@ grid = [
delta = [[-1, 0], [0, -1], [1, 0], [0, 1]] # up, left, down, right
TPosition = tuple[int, int]
class Node:
"""
@ -39,7 +43,15 @@ class Node:
True
"""
def __init__(self, pos_x, pos_y, goal_x, goal_y, g_cost, parent):
def __init__(
self,
pos_x: int,
pos_y: int,
goal_x: int,
goal_y: int,
g_cost: int,
parent: Optional[Node],
) -> None:
self.pos_x = pos_x
self.pos_y = pos_y
self.pos = (pos_y, pos_x)
@ -61,7 +73,7 @@ class Node:
else:
return sqrt(dy ** 2 + dx ** 2)
def __lt__(self, other) -> bool:
def __lt__(self, other: Node) -> bool:
return self.f_cost < other.f_cost
@ -81,23 +93,22 @@ class AStar:
(4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)]
"""
def __init__(self, start, goal):
def __init__(self, start: TPosition, goal: TPosition):
self.start = Node(start[1], start[0], goal[1], goal[0], 0, None)
self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None)
self.open_nodes = [self.start]
self.closed_nodes = []
self.closed_nodes: list[Node] = []
self.reached = False
def search(self) -> list[tuple[int]]:
def search(self) -> list[TPosition]:
while self.open_nodes:
# Open Nodes are sorted using __lt__
self.open_nodes.sort()
current_node = self.open_nodes.pop(0)
if current_node.pos == self.target.pos:
self.reached = True
return self.retrace_path(current_node)
self.closed_nodes.append(current_node)
@ -118,8 +129,7 @@ class AStar:
else:
self.open_nodes.append(better_node)
if not (self.reached):
return [(self.start.pos)]
return [self.start.pos]
def get_successors(self, parent: Node) -> list[Node]:
"""
@ -147,7 +157,7 @@ class AStar:
)
return successors
def retrace_path(self, node: Node) -> list[tuple[int]]:
def retrace_path(self, node: Optional[Node]) -> list[TPosition]:
"""
Retrace the path from parents to parents until start node
"""
@ -173,12 +183,12 @@ class BidirectionalAStar:
(2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)]
"""
def __init__(self, start, goal):
def __init__(self, start: TPosition, goal: TPosition) -> None:
self.fwd_astar = AStar(start, goal)
self.bwd_astar = AStar(goal, start)
self.reached = False
def search(self) -> list[tuple[int]]:
def search(self) -> list[TPosition]:
while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes:
self.fwd_astar.open_nodes.sort()
self.bwd_astar.open_nodes.sort()
@ -186,7 +196,6 @@ class BidirectionalAStar:
current_bwd_node = self.bwd_astar.open_nodes.pop(0)
if current_bwd_node.pos == current_fwd_node.pos:
self.reached = True
return self.retrace_bidirectional_path(
current_fwd_node, current_bwd_node
)
@ -220,12 +229,11 @@ class BidirectionalAStar:
else:
astar.open_nodes.append(better_node)
if not self.reached:
return [self.fwd_astar.start.pos]
def retrace_bidirectional_path(
self, fwd_node: Node, bwd_node: Node
) -> list[tuple[int]]:
) -> list[TPosition]:
fwd_path = self.fwd_astar.retrace_path(fwd_node)
bwd_path = self.bwd_astar.retrace_path(bwd_node)
bwd_path.pop()
@ -236,9 +244,6 @@ class BidirectionalAStar:
if __name__ == "__main__":
# all coordinates are given in format [y,x]
import doctest
doctest.testmod()
init = (0, 0)
goal = (len(grid) - 1, len(grid[0]) - 1)
for elem in grid:
@ -252,6 +257,5 @@ if __name__ == "__main__":
bd_start_time = time.time()
bidir_astar = BidirectionalAStar(init, goal)
path = bidir_astar.search()
bd_end_time = time.time() - bd_start_time
print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds")