mirror of
https://hub.njuu.cf/TheAlgorithms/C-Plus-Plus.git
synced 2023-10-11 13:05:55 +08:00
[feat/fix]: A Star Search Improvement (#1566)
* A Star Search Improvement * A Star Search Improvement - 2 * A Star Search Improvement - 3 Co-authored-by: David Leal <halfpacho@gmail.com>
This commit is contained in:
parent
c3b07aed22
commit
88394665b4
@ -25,6 +25,7 @@
|
|||||||
#include <functional> /// for `std::function` STL
|
#include <functional> /// for `std::function` STL
|
||||||
#include <iostream> /// for IO operations
|
#include <iostream> /// for IO operations
|
||||||
#include <map> /// for `std::map` STL
|
#include <map> /// for `std::map` STL
|
||||||
|
#include <memory> /// for `std::shared_ptr`
|
||||||
#include <set> /// for `std::set` STL
|
#include <set> /// for `std::set` STL
|
||||||
#include <vector> /// for `std::vector` STL
|
#include <vector> /// for `std::vector` STL
|
||||||
/**
|
/**
|
||||||
@ -60,7 +61,7 @@ class EightPuzzle {
|
|||||||
std::array<std::array<uint32_t, N>, N>
|
std::array<std::array<uint32_t, N>, N>
|
||||||
board; /// N x N array to store the current state of the Puzzle.
|
board; /// N x N array to store the current state of the Puzzle.
|
||||||
|
|
||||||
std::vector<std::pair<int, int>> moves = {
|
std::vector<std::pair<int8_t, int8_t>> moves = {
|
||||||
{0, 1},
|
{0, 1},
|
||||||
{1, 0},
|
{1, 0},
|
||||||
{0, -1},
|
{0, -1},
|
||||||
@ -86,9 +87,7 @@ class EightPuzzle {
|
|||||||
* @param value index for the current board
|
* @param value index for the current board
|
||||||
* @returns `true` if index is within the board, else `false`
|
* @returns `true` if index is within the board, else `false`
|
||||||
*/
|
*/
|
||||||
inline bool in_range(const uint32_t value) const {
|
inline bool in_range(const uint32_t value) const { return value < N; }
|
||||||
return value >= 0 && value < N;
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
@ -292,9 +291,9 @@ class AyStarSearch {
|
|||||||
* state.
|
* state.
|
||||||
*/
|
*/
|
||||||
typedef struct Info {
|
typedef struct Info {
|
||||||
Puzzle state; /// Holds the current state.
|
std::shared_ptr<Puzzle> state; /// Holds the current state.
|
||||||
size_t heuristic_value = 0; /// stores h score
|
size_t heuristic_value = 0; /// stores h score
|
||||||
size_t depth = 0; /// stores g score
|
size_t depth = 0; /// stores g score
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Default constructor
|
* @brief Default constructor
|
||||||
@ -305,7 +304,7 @@ class AyStarSearch {
|
|||||||
* @brief constructor having Puzzle as parameter
|
* @brief constructor having Puzzle as parameter
|
||||||
* @param A a puzzle object
|
* @param A a puzzle object
|
||||||
*/
|
*/
|
||||||
explicit Info(const Puzzle &A) : state(std::move(A)) {}
|
explicit Info(const Puzzle &A) : state(std::make_shared<Puzzle>(A)) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief constructor having three parameters
|
* @brief constructor having three parameters
|
||||||
@ -314,14 +313,16 @@ class AyStarSearch {
|
|||||||
* @param depth the depth at which this node was found during traversal
|
* @param depth the depth at which this node was found during traversal
|
||||||
*/
|
*/
|
||||||
Info(const Puzzle &A, size_t h_value, size_t d)
|
Info(const Puzzle &A, size_t h_value, size_t d)
|
||||||
: state(std::move(A)), heuristic_value(h_value), depth(d) {}
|
: state(std::make_shared<Puzzle>(A)),
|
||||||
|
heuristic_value(h_value),
|
||||||
|
depth(d) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Copy constructor
|
* @brief Copy constructor
|
||||||
* @param A Info object reference
|
* @param A Info object reference
|
||||||
*/
|
*/
|
||||||
Info(const Info &A)
|
Info(const Info &A)
|
||||||
: state(A.state),
|
: state(std::make_shared<Puzzle>(A.state)),
|
||||||
heuristic_value(A.heuristic_value),
|
heuristic_value(A.heuristic_value),
|
||||||
depth(A.depth) {}
|
depth(A.depth) {}
|
||||||
|
|
||||||
@ -330,7 +331,7 @@ class AyStarSearch {
|
|||||||
* @param A Info object reference
|
* @param A Info object reference
|
||||||
*/
|
*/
|
||||||
Info(const Info &&A) noexcept
|
Info(const Info &&A) noexcept
|
||||||
: state(std::move(A.state)),
|
: state(std::make_shared<Puzzle>(std::move(A.state))),
|
||||||
heuristic_value(std::move(A.heuristic_value)),
|
heuristic_value(std::move(A.heuristic_value)),
|
||||||
depth(std::move(A.depth)) {}
|
depth(std::move(A.depth)) {}
|
||||||
|
|
||||||
@ -361,26 +362,36 @@ class AyStarSearch {
|
|||||||
~Info() = default;
|
~Info() = default;
|
||||||
} Info;
|
} Info;
|
||||||
|
|
||||||
Info Initial; // Initial state of the AyStarSearch
|
std::shared_ptr<Info> Initial; // Initial state of the AyStarSearch
|
||||||
Info Final; // Final state of the AyStarSearch
|
std::shared_ptr<Info> Final; // Final state of the AyStarSearch
|
||||||
/**
|
/**
|
||||||
* @brief Custom comparator for open_list
|
* @brief Custom comparator for open_list
|
||||||
*/
|
*/
|
||||||
struct comparison_operator {
|
struct comparison_operator {
|
||||||
bool operator()(const Info &a, const Info &b) const {
|
bool operator()(const std::shared_ptr<Info> &a,
|
||||||
return a.state < b.state;
|
const std::shared_ptr<Info> &b) const {
|
||||||
|
return *(a->state) < *(b->state);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
using MapOfPuzzleInfoWithPuzzleInfo =
|
||||||
|
std::map<std::shared_ptr<Info>, std::shared_ptr<Info>,
|
||||||
|
comparison_operator>;
|
||||||
|
|
||||||
|
using MapOfPuzzleInfoWithInteger =
|
||||||
|
std::map<std::shared_ptr<Info>, uint32_t, comparison_operator>;
|
||||||
|
|
||||||
|
using SetOfPuzzleInfo =
|
||||||
|
std::set<std::shared_ptr<Info>, comparison_operator>;
|
||||||
/**
|
/**
|
||||||
* @brief Parameterized constructor for AyStarSearch
|
* @brief Parameterized constructor for AyStarSearch
|
||||||
* @param initial denoting initial state of the puzzle
|
* @param initial denoting initial state of the puzzle
|
||||||
* @param final denoting final state of the puzzle
|
* @param final denoting final state of the puzzle
|
||||||
*/
|
*/
|
||||||
AyStarSearch(const Puzzle &initial, const Puzzle &final) {
|
AyStarSearch(const Puzzle &initial, const Puzzle &final) {
|
||||||
Initial = Info(initial);
|
Initial = std::make_shared<Info>(initial);
|
||||||
Final = Info(final);
|
Final = std::make_shared<Info>(final);
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* @brief A helper solution: launches when a solution for AyStarSearch
|
* @brief A helper solution: launches when a solution for AyStarSearch
|
||||||
@ -392,18 +403,18 @@ class AyStarSearch {
|
|||||||
* state (in reverse)
|
* state (in reverse)
|
||||||
*/
|
*/
|
||||||
std::vector<Puzzle> Solution(
|
std::vector<Puzzle> Solution(
|
||||||
Info *FinalState,
|
std::shared_ptr<Info> FinalState,
|
||||||
const std::map<Info, Info *, comparison_operator> &parent_of) {
|
const MapOfPuzzleInfoWithPuzzleInfo &parent_of) {
|
||||||
// Useful for traversing from final state to current state.
|
// Useful for traversing from final state to current state.
|
||||||
Info *current_state = FinalState;
|
auto current_state = FinalState;
|
||||||
/*
|
/*
|
||||||
* For storing the solution tree starting from initial state to
|
* For storing the solution tree starting from initial state to
|
||||||
* final state
|
* final state
|
||||||
*/
|
*/
|
||||||
std::vector<Puzzle> answer;
|
std::vector<Puzzle> answer;
|
||||||
while (current_state != nullptr) {
|
while (current_state != nullptr) {
|
||||||
answer.emplace_back(current_state->state);
|
answer.emplace_back(*current_state->state);
|
||||||
current_state = parent_of.find(*current_state)->second;
|
current_state = parent_of.find(current_state)->second;
|
||||||
}
|
}
|
||||||
return answer;
|
return answer;
|
||||||
}
|
}
|
||||||
@ -418,14 +429,11 @@ class AyStarSearch {
|
|||||||
std::vector<Puzzle> a_star_search(
|
std::vector<Puzzle> a_star_search(
|
||||||
const std::function<uint32_t(const Puzzle &, const Puzzle &)> &dist,
|
const std::function<uint32_t(const Puzzle &, const Puzzle &)> &dist,
|
||||||
const uint32_t permissible_depth = 30) {
|
const uint32_t permissible_depth = 30) {
|
||||||
std::map<Info, Info *, comparison_operator>
|
MapOfPuzzleInfoWithPuzzleInfo
|
||||||
parent_of; /// Stores the parent of the states
|
parent_of; /// Stores the parent of the states
|
||||||
std::map<Info, uint32_t, comparison_operator>
|
MapOfPuzzleInfoWithInteger g_score; /// Stores the g_score
|
||||||
g_score; /// Stores the g_score
|
SetOfPuzzleInfo open_list; /// Stores the list to explore
|
||||||
std::set<Info, comparison_operator>
|
SetOfPuzzleInfo closed_list; /// Stores the list that are explored
|
||||||
open_list; /// Stores the list to explore
|
|
||||||
std::set<Info, comparison_operator>
|
|
||||||
closed_list; /// Stores the list that are explored
|
|
||||||
|
|
||||||
// Before starting the AyStartSearch, initialize the set and maps
|
// Before starting the AyStartSearch, initialize the set and maps
|
||||||
open_list.emplace(Initial);
|
open_list.emplace(Initial);
|
||||||
@ -434,14 +442,13 @@ class AyStarSearch {
|
|||||||
|
|
||||||
while (!open_list.empty()) {
|
while (!open_list.empty()) {
|
||||||
// Iterator for state having having lowest f_score.
|
// Iterator for state having having lowest f_score.
|
||||||
typename std::set<Info, comparison_operator>::iterator
|
typename SetOfPuzzleInfo::iterator it_low_f_score;
|
||||||
it_low_f_score;
|
|
||||||
uint32_t min_f_score = 1e9;
|
uint32_t min_f_score = 1e9;
|
||||||
for (auto iter = open_list.begin(); iter != open_list.end();
|
for (auto iter = open_list.begin(); iter != open_list.end();
|
||||||
++iter) {
|
++iter) {
|
||||||
// f score here is evaluated by g score (depth) and h score
|
// f score here is evaluated by g score (depth) and h score
|
||||||
// (distance between current state and final state)
|
// (distance between current state and final state)
|
||||||
uint32_t f_score = iter->heuristic_value + iter->depth;
|
uint32_t f_score = (*iter)->heuristic_value + (*iter)->depth;
|
||||||
if (f_score < min_f_score) {
|
if (f_score < min_f_score) {
|
||||||
min_f_score = f_score;
|
min_f_score = f_score;
|
||||||
it_low_f_score = iter;
|
it_low_f_score = iter;
|
||||||
@ -449,10 +456,10 @@ class AyStarSearch {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// current_state, stores lowest f score so far for this state.
|
// current_state, stores lowest f score so far for this state.
|
||||||
Info *current_state = new Info(*it_low_f_score);
|
std::shared_ptr<Info> current_state = *it_low_f_score;
|
||||||
|
|
||||||
// if this current state is equal to final, return
|
// if this current state is equal to final, return
|
||||||
if (current_state->state == Final.state) {
|
if (*(current_state->state) == *(Final->state)) {
|
||||||
return Solution(current_state, parent_of);
|
return Solution(current_state, parent_of);
|
||||||
}
|
}
|
||||||
// else remove from open list as visited.
|
// else remove from open list as visited.
|
||||||
@ -465,14 +472,15 @@ class AyStarSearch {
|
|||||||
// Generate all possible moves (neighbors) given the current
|
// Generate all possible moves (neighbors) given the current
|
||||||
// state
|
// state
|
||||||
std::vector<Puzzle> total_possible_moves =
|
std::vector<Puzzle> total_possible_moves =
|
||||||
current_state->state.generate_possible_moves();
|
current_state->state->generate_possible_moves();
|
||||||
|
|
||||||
for (Puzzle &neighbor : total_possible_moves) {
|
for (Puzzle &neighbor : total_possible_moves) {
|
||||||
// calculate score of neighbors with respect to
|
// calculate score of neighbors with respect to
|
||||||
// current_state
|
// current_state
|
||||||
Info Neighbor = {neighbor, dist(neighbor, Final.state),
|
std::shared_ptr<Info> Neighbor = std::make_shared<Info>(
|
||||||
current_state->depth + 1};
|
neighbor, dist(neighbor, *(Final->state)),
|
||||||
uint32_t temp_g_score = Neighbor.depth;
|
current_state->depth + 1U);
|
||||||
|
uint32_t temp_g_score = Neighbor->depth;
|
||||||
|
|
||||||
// Check whether this state is explored.
|
// Check whether this state is explored.
|
||||||
// If this state is discovered at greater depth, then discard,
|
// If this state is discovered at greater depth, then discard,
|
||||||
@ -482,7 +490,7 @@ class AyStarSearch {
|
|||||||
// 1. If state in closed list has higher depth, then remove
|
// 1. If state in closed list has higher depth, then remove
|
||||||
// from list since we have found better option,
|
// from list since we have found better option,
|
||||||
// 2. Else don't explore this state.
|
// 2. Else don't explore this state.
|
||||||
if (Neighbor.depth < closed_list_iter->depth) {
|
if (Neighbor->depth < (*closed_list_iter)->depth) {
|
||||||
closed_list.erase(closed_list_iter);
|
closed_list.erase(closed_list_iter);
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
@ -506,12 +514,11 @@ class AyStarSearch {
|
|||||||
auto iter = open_list.find(Neighbor);
|
auto iter = open_list.find(Neighbor);
|
||||||
if (iter == open_list.end()) {
|
if (iter == open_list.end()) {
|
||||||
open_list.emplace(Neighbor);
|
open_list.emplace(Neighbor);
|
||||||
} else if (iter->depth > Neighbor.depth) {
|
} else if ((*iter)->depth > Neighbor->depth) {
|
||||||
open_list.erase(iter);
|
(*iter)->depth = Neighbor->depth;
|
||||||
open_list.emplace(Neighbor);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
closed_list.emplace(*current_state);
|
closed_list.emplace(current_state);
|
||||||
}
|
}
|
||||||
// Cannot find the solution, return empty vector
|
// Cannot find the solution, return empty vector
|
||||||
return std::vector<Puzzle>(0);
|
return std::vector<Puzzle>(0);
|
||||||
@ -548,23 +555,24 @@ static void test() {
|
|||||||
[](const machine_learning::aystar_search::EightPuzzle<> &first,
|
[](const machine_learning::aystar_search::EightPuzzle<> &first,
|
||||||
const machine_learning::aystar_search::EightPuzzle<> &second) {
|
const machine_learning::aystar_search::EightPuzzle<> &second) {
|
||||||
uint32_t ret = 0;
|
uint32_t ret = 0;
|
||||||
for (int i = 0; i < first.get_size(); ++i) {
|
for (size_t i = 0; i < first.get_size(); ++i) {
|
||||||
for (int j = 0; j < first.get_size(); ++j) {
|
for (size_t j = 0; j < first.get_size(); ++j) {
|
||||||
uint32_t find = first.get(i, j);
|
uint32_t find = first.get(i, j);
|
||||||
int m = -1, n = -1;
|
size_t m = first.get_size(), n = first.get_size();
|
||||||
for (int k = 0; k < second.get_size(); ++k) {
|
for (size_t k = 0; k < second.get_size(); ++k) {
|
||||||
for (int l = 0; l < second.get_size(); ++l) {
|
for (size_t l = 0; l < second.get_size(); ++l) {
|
||||||
if (find == second.get(k, l)) {
|
if (find == second.get(k, l)) {
|
||||||
std::tie(m, n) = std::make_pair(k, l);
|
std::tie(m, n) = std::make_pair(k, l);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (m != -1) {
|
if (m != first.get_size()) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (m != -1) {
|
if (m != first.get_size()) {
|
||||||
ret += abs(m - i) + abs(n - j);
|
ret += (std::max(m, i) - std::min(m, i)) +
|
||||||
|
(std::max(n, j) - std::min(n, j));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -631,10 +639,10 @@ static void test() {
|
|||||||
// 3rd test: A* search for 15-Puzzle
|
// 3rd test: A* search for 15-Puzzle
|
||||||
// Initial State of the puzzle
|
// Initial State of the puzzle
|
||||||
matrix4 puzzle2;
|
matrix4 puzzle2;
|
||||||
puzzle2[0] = row4({5, 1, 2, 3});
|
puzzle2[0] = row4({10, 1, 6, 2});
|
||||||
puzzle2[1] = row4({9, 6, 8, 4});
|
puzzle2[1] = row4({5, 8, 4, 3});
|
||||||
puzzle2[2] = row4({13, 10, 7, 11});
|
puzzle2[2] = row4({13, 0, 7, 11});
|
||||||
puzzle2[3] = row4({14, 15, 12, 0});
|
puzzle2[3] = row4({14, 9, 15, 12});
|
||||||
// Final state of the puzzle
|
// Final state of the puzzle
|
||||||
matrix4 ideal2;
|
matrix4 ideal2;
|
||||||
ideal2[0] = row4({1, 2, 3, 4});
|
ideal2[0] = row4({1, 2, 3, 4});
|
||||||
@ -656,23 +664,24 @@ static void test() {
|
|||||||
[](const machine_learning::aystar_search::EightPuzzle<4> &first,
|
[](const machine_learning::aystar_search::EightPuzzle<4> &first,
|
||||||
const machine_learning::aystar_search::EightPuzzle<4> &second) {
|
const machine_learning::aystar_search::EightPuzzle<4> &second) {
|
||||||
uint32_t ret = 0;
|
uint32_t ret = 0;
|
||||||
for (int i = 0; i < first.get_size(); ++i) {
|
for (size_t i = 0; i < first.get_size(); ++i) {
|
||||||
for (int j = 0; j < first.get_size(); ++j) {
|
for (size_t j = 0; j < first.get_size(); ++j) {
|
||||||
uint32_t find = first.get(i, j);
|
uint32_t find = first.get(i, j);
|
||||||
int m = -1, n = -1;
|
size_t m = first.get_size(), n = first.get_size();
|
||||||
for (int k = 0; k < second.get_size(); ++k) {
|
for (size_t k = 0; k < second.get_size(); ++k) {
|
||||||
for (int l = 0; l < second.get_size(); ++l) {
|
for (size_t l = 0; l < second.get_size(); ++l) {
|
||||||
if (find == second.get(k, l)) {
|
if (find == second.get(k, l)) {
|
||||||
std::tie(m, n) = std::make_pair(k, l);
|
std::tie(m, n) = std::make_pair(k, l);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (m != -1) {
|
if (m != first.get_size()) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (m != -1) {
|
if (m != first.get_size()) {
|
||||||
ret += abs(m - i) + abs(n - j);
|
ret += (std::max(m, i) - std::min(m, i)) +
|
||||||
|
(std::max(n, j) - std::min(n, j));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -683,7 +692,7 @@ static void test() {
|
|||||||
std::cout << sol2.size() << std::endl;
|
std::cout << sol2.size() << std::endl;
|
||||||
|
|
||||||
// Static assertion due to large solution
|
// Static assertion due to large solution
|
||||||
assert(15 == sol2.size());
|
assert(24 == sol2.size());
|
||||||
// Check whether the final state is equal to expected one
|
// Check whether the final state is equal to expected one
|
||||||
assert(sol2[0].get_state() == ideal2);
|
assert(sol2[0].get_state() == ideal2);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user