From 88394665b4896227cd9e7b894e98e84b2942ba97 Mon Sep 17 00:00:00 2001 From: Ashish Bhanu Daulatabad Date: Fri, 27 Aug 2021 23:20:40 +0530 Subject: [PATCH] [feat/fix]: A Star Search Improvement (#1566) * A Star Search Improvement * A Star Search Improvement - 2 * A Star Search Improvement - 3 Co-authored-by: David Leal --- machine_learning/a_star_search.cpp | 139 +++++++++++++++-------------- 1 file changed, 74 insertions(+), 65 deletions(-) diff --git a/machine_learning/a_star_search.cpp b/machine_learning/a_star_search.cpp index b9769a308..9f713883a 100644 --- a/machine_learning/a_star_search.cpp +++ b/machine_learning/a_star_search.cpp @@ -25,6 +25,7 @@ #include /// for `std::function` STL #include /// for IO operations #include /// for `std::map` STL +#include /// for `std::shared_ptr` #include /// for `std::set` STL #include /// for `std::vector` STL /** @@ -60,7 +61,7 @@ class EightPuzzle { std::array, N> board; /// N x N array to store the current state of the Puzzle. - std::vector> moves = { + std::vector> moves = { {0, 1}, {1, 0}, {0, -1}, @@ -86,9 +87,7 @@ class EightPuzzle { * @param value index for the current board * @returns `true` if index is within the board, else `false` */ - inline bool in_range(const uint32_t value) const { - return value >= 0 && value < N; - } + inline bool in_range(const uint32_t value) const { return value < N; } public: /** @@ -292,9 +291,9 @@ class AyStarSearch { * state. */ typedef struct Info { - Puzzle state; /// Holds the current state. - size_t heuristic_value = 0; /// stores h score - size_t depth = 0; /// stores g score + std::shared_ptr state; /// Holds the current state. + size_t heuristic_value = 0; /// stores h score + size_t depth = 0; /// stores g score /** * @brief Default constructor @@ -305,7 +304,7 @@ class AyStarSearch { * @brief constructor having Puzzle as parameter * @param A a puzzle object */ - explicit Info(const Puzzle &A) : state(std::move(A)) {} + explicit Info(const Puzzle &A) : state(std::make_shared(A)) {} /** * @brief constructor having three parameters @@ -314,14 +313,16 @@ class AyStarSearch { * @param depth the depth at which this node was found during traversal */ Info(const Puzzle &A, size_t h_value, size_t d) - : state(std::move(A)), heuristic_value(h_value), depth(d) {} + : state(std::make_shared(A)), + heuristic_value(h_value), + depth(d) {} /** * @brief Copy constructor * @param A Info object reference */ Info(const Info &A) - : state(A.state), + : state(std::make_shared(A.state)), heuristic_value(A.heuristic_value), depth(A.depth) {} @@ -330,7 +331,7 @@ class AyStarSearch { * @param A Info object reference */ Info(const Info &&A) noexcept - : state(std::move(A.state)), + : state(std::make_shared(std::move(A.state))), heuristic_value(std::move(A.heuristic_value)), depth(std::move(A.depth)) {} @@ -361,26 +362,36 @@ class AyStarSearch { ~Info() = default; } Info; - Info Initial; // Initial state of the AyStarSearch - Info Final; // Final state of the AyStarSearch + std::shared_ptr Initial; // Initial state of the AyStarSearch + std::shared_ptr Final; // Final state of the AyStarSearch /** * @brief Custom comparator for open_list */ struct comparison_operator { - bool operator()(const Info &a, const Info &b) const { - return a.state < b.state; + bool operator()(const std::shared_ptr &a, + const std::shared_ptr &b) const { + return *(a->state) < *(b->state); } }; public: + using MapOfPuzzleInfoWithPuzzleInfo = + std::map, std::shared_ptr, + comparison_operator>; + + using MapOfPuzzleInfoWithInteger = + std::map, uint32_t, comparison_operator>; + + using SetOfPuzzleInfo = + std::set, comparison_operator>; /** * @brief Parameterized constructor for AyStarSearch * @param initial denoting initial state of the puzzle * @param final denoting final state of the puzzle */ AyStarSearch(const Puzzle &initial, const Puzzle &final) { - Initial = Info(initial); - Final = Info(final); + Initial = std::make_shared(initial); + Final = std::make_shared(final); } /** * @brief A helper solution: launches when a solution for AyStarSearch @@ -392,18 +403,18 @@ class AyStarSearch { * state (in reverse) */ std::vector Solution( - Info *FinalState, - const std::map &parent_of) { + std::shared_ptr FinalState, + const MapOfPuzzleInfoWithPuzzleInfo &parent_of) { // Useful for traversing from final state to current state. - Info *current_state = FinalState; + auto current_state = FinalState; /* * For storing the solution tree starting from initial state to * final state */ std::vector answer; while (current_state != nullptr) { - answer.emplace_back(current_state->state); - current_state = parent_of.find(*current_state)->second; + answer.emplace_back(*current_state->state); + current_state = parent_of.find(current_state)->second; } return answer; } @@ -418,14 +429,11 @@ class AyStarSearch { std::vector a_star_search( const std::function &dist, const uint32_t permissible_depth = 30) { - std::map - parent_of; /// Stores the parent of the states - std::map - g_score; /// Stores the g_score - std::set - open_list; /// Stores the list to explore - std::set - closed_list; /// Stores the list that are explored + MapOfPuzzleInfoWithPuzzleInfo + parent_of; /// Stores the parent of the states + MapOfPuzzleInfoWithInteger g_score; /// Stores the g_score + SetOfPuzzleInfo open_list; /// Stores the list to explore + SetOfPuzzleInfo closed_list; /// Stores the list that are explored // Before starting the AyStartSearch, initialize the set and maps open_list.emplace(Initial); @@ -434,14 +442,13 @@ class AyStarSearch { while (!open_list.empty()) { // Iterator for state having having lowest f_score. - typename std::set::iterator - it_low_f_score; + typename SetOfPuzzleInfo::iterator it_low_f_score; uint32_t min_f_score = 1e9; for (auto iter = open_list.begin(); iter != open_list.end(); ++iter) { // f score here is evaluated by g score (depth) and h score // (distance between current state and final state) - uint32_t f_score = iter->heuristic_value + iter->depth; + uint32_t f_score = (*iter)->heuristic_value + (*iter)->depth; if (f_score < min_f_score) { min_f_score = f_score; it_low_f_score = iter; @@ -449,10 +456,10 @@ class AyStarSearch { } // current_state, stores lowest f score so far for this state. - Info *current_state = new Info(*it_low_f_score); + std::shared_ptr current_state = *it_low_f_score; // if this current state is equal to final, return - if (current_state->state == Final.state) { + if (*(current_state->state) == *(Final->state)) { return Solution(current_state, parent_of); } // else remove from open list as visited. @@ -465,14 +472,15 @@ class AyStarSearch { // Generate all possible moves (neighbors) given the current // state std::vector total_possible_moves = - current_state->state.generate_possible_moves(); + current_state->state->generate_possible_moves(); for (Puzzle &neighbor : total_possible_moves) { // calculate score of neighbors with respect to // current_state - Info Neighbor = {neighbor, dist(neighbor, Final.state), - current_state->depth + 1}; - uint32_t temp_g_score = Neighbor.depth; + std::shared_ptr Neighbor = std::make_shared( + neighbor, dist(neighbor, *(Final->state)), + current_state->depth + 1U); + uint32_t temp_g_score = Neighbor->depth; // Check whether this state is explored. // If this state is discovered at greater depth, then discard, @@ -482,7 +490,7 @@ class AyStarSearch { // 1. If state in closed list has higher depth, then remove // from list since we have found better option, // 2. Else don't explore this state. - if (Neighbor.depth < closed_list_iter->depth) { + if (Neighbor->depth < (*closed_list_iter)->depth) { closed_list.erase(closed_list_iter); } else { continue; @@ -506,12 +514,11 @@ class AyStarSearch { auto iter = open_list.find(Neighbor); if (iter == open_list.end()) { open_list.emplace(Neighbor); - } else if (iter->depth > Neighbor.depth) { - open_list.erase(iter); - open_list.emplace(Neighbor); + } else if ((*iter)->depth > Neighbor->depth) { + (*iter)->depth = Neighbor->depth; } } - closed_list.emplace(*current_state); + closed_list.emplace(current_state); } // Cannot find the solution, return empty vector return std::vector(0); @@ -548,23 +555,24 @@ static void test() { [](const machine_learning::aystar_search::EightPuzzle<> &first, const machine_learning::aystar_search::EightPuzzle<> &second) { uint32_t ret = 0; - for (int i = 0; i < first.get_size(); ++i) { - for (int j = 0; j < first.get_size(); ++j) { + for (size_t i = 0; i < first.get_size(); ++i) { + for (size_t j = 0; j < first.get_size(); ++j) { uint32_t find = first.get(i, j); - int m = -1, n = -1; - for (int k = 0; k < second.get_size(); ++k) { - for (int l = 0; l < second.get_size(); ++l) { + size_t m = first.get_size(), n = first.get_size(); + for (size_t k = 0; k < second.get_size(); ++k) { + for (size_t l = 0; l < second.get_size(); ++l) { if (find == second.get(k, l)) { std::tie(m, n) = std::make_pair(k, l); break; } } - if (m != -1) { + if (m != first.get_size()) { break; } } - if (m != -1) { - ret += abs(m - i) + abs(n - j); + if (m != first.get_size()) { + ret += (std::max(m, i) - std::min(m, i)) + + (std::max(n, j) - std::min(n, j)); } } } @@ -631,10 +639,10 @@ static void test() { // 3rd test: A* search for 15-Puzzle // Initial State of the puzzle matrix4 puzzle2; - puzzle2[0] = row4({5, 1, 2, 3}); - puzzle2[1] = row4({9, 6, 8, 4}); - puzzle2[2] = row4({13, 10, 7, 11}); - puzzle2[3] = row4({14, 15, 12, 0}); + puzzle2[0] = row4({10, 1, 6, 2}); + puzzle2[1] = row4({5, 8, 4, 3}); + puzzle2[2] = row4({13, 0, 7, 11}); + puzzle2[3] = row4({14, 9, 15, 12}); // Final state of the puzzle matrix4 ideal2; ideal2[0] = row4({1, 2, 3, 4}); @@ -656,23 +664,24 @@ static void test() { [](const machine_learning::aystar_search::EightPuzzle<4> &first, const machine_learning::aystar_search::EightPuzzle<4> &second) { uint32_t ret = 0; - for (int i = 0; i < first.get_size(); ++i) { - for (int j = 0; j < first.get_size(); ++j) { + for (size_t i = 0; i < first.get_size(); ++i) { + for (size_t j = 0; j < first.get_size(); ++j) { uint32_t find = first.get(i, j); - int m = -1, n = -1; - for (int k = 0; k < second.get_size(); ++k) { - for (int l = 0; l < second.get_size(); ++l) { + size_t m = first.get_size(), n = first.get_size(); + for (size_t k = 0; k < second.get_size(); ++k) { + for (size_t l = 0; l < second.get_size(); ++l) { if (find == second.get(k, l)) { std::tie(m, n) = std::make_pair(k, l); break; } } - if (m != -1) { + if (m != first.get_size()) { break; } } - if (m != -1) { - ret += abs(m - i) + abs(n - j); + if (m != first.get_size()) { + ret += (std::max(m, i) - std::min(m, i)) + + (std::max(n, j) - std::min(n, j)); } } } @@ -683,7 +692,7 @@ static void test() { std::cout << sol2.size() << std::endl; // Static assertion due to large solution - assert(15 == sol2.size()); + assert(24 == sol2.size()); // Check whether the final state is equal to expected one assert(sol2[0].get_state() == ideal2);