fix: karatsuba's algorithm is not compiling (#2115)

* fix: karatsuba's algorithm is not compiling
doc: improved comments

* fix: continuous integration issues

Co-authored-by: David Leal <halfpacho@gmail.com>
This commit is contained in:
Ameer Carlo Lubang 2022-10-06 01:18:56 +08:00 committed by GitHub
parent 84ff18e0ac
commit 97c7d91878
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,11 +4,12 @@
* multiplication](https://en.wikipedia.org/wiki/Karatsuba_algorithm) * multiplication](https://en.wikipedia.org/wiki/Karatsuba_algorithm)
* @details * @details
* Given two strings in binary notation we want to multiply them and return the * Given two strings in binary notation we want to multiply them and return the
* value Simple approach is to multiply bits one by one which will give the time * value. Simple approach is to multiply bits one by one which will give the time
* complexity of around O(n^2). To make it more efficient we will be using * complexity of around O(n^2). To make it more efficient we will be using
* Karatsuba' algorithm to find the product which will solve the problem * Karatsuba algorithm to find the product which will solve the problem
* O(nlogn) of time. * O(nlogn) of time.
* @author [Swastika Gupta](https://github.com/Swastyy) * @author [Swastika Gupta](https://github.com/Swastyy)
* @author [Ameer Carlo Lubang](https://github.com/poypoyan)
*/ */
#include <cassert> /// for assert #include <cassert> /// for assert
@ -24,101 +25,117 @@ namespace divide_and_conquer {
/** /**
* @namespace karatsuba_algorithm * @namespace karatsuba_algorithm
* @brief Functions for the [Karatsuba algorithm for fast * @brief Functions for the [Karatsuba algorithm for fast
* multiplication](https://en.wikipedia.org/wiki/Karatsuba_algorithm) * multiplication](https://en.wikipedia.org/wiki/Karatsuba_algorithm) implementation
*/ */
namespace karatsuba_algorithm { namespace karatsuba_algorithm {
/** /**
* @brief Helper function for the main function, that implements Karatsuba's * @brief Binary addition
* algorithm for fast multiplication * @param first, the input string 1
* @param first the input string 1 * @param second, the input string 2
* @param second the input string 2 * @returns the sum binary string
* @returns the concatenated string
*/ */
std::string addStrings(std::string first, std::string second) { std::string add_strings(std::string first, std::string second) {
std::string result; // To store the resulting sum bits std::string result; // to store the resulting sum bits
// make the string lengths equal
int64_t len1 = first.size(); int64_t len1 = first.size();
int64_t len2 = second.size(); int64_t len2 = second.size();
int64_t length = std::max(len1, len2);
std::string zero = "0"; std::string zero = "0";
if (len1 < len2) // make the string lengths equal if (len1 < len2) {
{
for (int64_t i = 0; i < len2 - len1; i++) { for (int64_t i = 0; i < len2 - len1; i++) {
zero += first; zero += first;
first = zero; first = zero;
zero = "0"; // Prevents CI from failing
} }
} else if (len1 > len2) { } else if (len1 > len2) {
zero = "0";
for (int64_t i = 0; i < len1 - len2; i++) { for (int64_t i = 0; i < len1 - len2; i++) {
zero += second; zero += second;
second = zero; second = zero;
zero = "0"; // Prevents CI from failing
} }
} }
int64_t length = std::max(len1, len2);
int64_t carry = 0; int64_t carry = 0;
for (int64_t i = length - 1; i >= 0; i--) { for (int64_t i = length - 1; i >= 0; i--) {
int64_t firstBit = first.at(i) - '0'; int64_t firstBit = first.at(i) - '0';
int64_t secondBit = second.at(i) - '0'; int64_t secondBit = second.at(i) - '0';
int64_t sum = (firstBit ^ secondBit ^ carry) + '0'; // sum of 3 bits int64_t sum = (char(firstBit ^ secondBit ^ carry)) + '0'; // sum of 3 bits
std::string temp; result.insert(result.begin(), sum);
temp = std::to_string(sum);
temp += result;
result = temp;
carry = (firstBit & secondBit) | (secondBit & carry) | carry = char((firstBit & secondBit) | (secondBit & carry) |
(firstBit & carry); // sum of 3 bits (firstBit & carry)); // sum of 3 bits
} }
if (carry) { if (carry) {
result = '1' + result; // adding 1 incase of overflow result.insert(result.begin(), '1'); // adding 1 incase of overflow
} }
return result; return result;
} }
/**
* @brief Wrapper function for substr that considers leading zeros.
* @param str, the binary input string.
* @param x1, the substr parameter integer 1
* @param x2, the substr parameter integer 2
* @param n, is the length of the "whole" string: leading zeros + str
* @returns the "safe" substring for the algorithm *without* leading zeros
* @returns "0" if substring spans to leading zeros only
*/
std::string safe_substr(const std::string &str, int64_t x1, int64_t x2, int64_t n) {
int64_t len = str.size();
if (len >= n) {
return str.substr(x1, x2);
}
int64_t y1 = x1 - (n - len); // index in str of first char of substring of "whole" string
int64_t y2 = (x1 + x2 - 1) - (n - len); // index in str of last char of substring of "whole" string
if (y2 < 0) {
return "0";
} else if (y1 < 0) {
return str.substr(0, y2 + 1);
} else {
return str.substr(y1, x2);
}
}
/** /**
* @brief The main function implements Karatsuba's algorithm for fast * @brief The main function implements Karatsuba's algorithm for fast
* multiplication * multiplication
* @param str1 the input string 1 * @param str1 the input string 1
* @param str2 the input string 2 * @param str2 the input string 2
* @returns the multiplicative number value * @returns the product number value
*/ */
int64_t karatsuba_algorithm(std::string str1, std::string str2) { int64_t karatsuba_algorithm(std::string str1, std::string str2) {
int64_t len1 = str1.size(); int64_t len1 = str1.size();
int64_t len2 = str2.size(); int64_t len2 = str2.size();
int64_t n = std::max(len1, len2); int64_t n = std::max(len1, len2);
std::string zero = "0";
if (len1 < len2) {
for (int64_t i = 0; i < len2 - len1; i++) {
zero += str1;
str1 = zero;
}
} else if (len1 > len2) {
zero = "0";
for (int64_t i = 0; i < len1 - len2; i++) {
zero += str2;
str2 = zero;
}
}
if (n == 0) { if (n == 0) {
return 0; return 0;
} }
if (n == 1) { if (n == 1) {
return (str1[0] - '0') * (str2[0] - '0'); return (str1[0] - '0') * (str2[0] - '0');
} }
int64_t fh = n / 2; // first half of string int64_t fh = n / 2; // first half of string
int64_t sh = (n - fh); // second half of string int64_t sh = n - fh; // second half of string
std::string Xl = str1.substr(0, fh); // first half of first string std::string Xl = divide_and_conquer::karatsuba_algorithm::safe_substr(str1, 0, fh, n); // first half of first string
std::string Xr = str1.substr(fh, sh); // second half of first string std::string Xr = divide_and_conquer::karatsuba_algorithm::safe_substr(str1, fh, sh, n); // second half of first string
std::string Yl = str2.substr(0, fh); // first half of second string std::string Yl = divide_and_conquer::karatsuba_algorithm::safe_substr(str2, 0, fh, n); // first half of second string
std::string Yr = str2.substr(fh, sh); // second half of second string std::string Yr = divide_and_conquer::karatsuba_algorithm::safe_substr(str2, fh, sh, n); // second half of second string
// Calculating the three products of inputs of size n/2 recursively // calculating the three products of inputs of size n/2 recursively
int64_t product1 = karatsuba_algorithm(Xl, Yl); int64_t product1 = karatsuba_algorithm(Xl, Yl);
int64_t product2 = karatsuba_algorithm(Xr, Yr); int64_t product2 = karatsuba_algorithm(Xr, Yr);
int64_t product3 = karatsuba_algorithm( int64_t product3 = karatsuba_algorithm(
divide_and_conquer::karatsuba_algorithm::addStrings(Xl, Xr), divide_and_conquer::karatsuba_algorithm::add_strings(Xl, Xr),
divide_and_conquer::karatsuba_algorithm::addStrings(Yl, Yr)); divide_and_conquer::karatsuba_algorithm::add_strings(Yl, Yr));
return product1 * (1 << (2 * sh)) + return product1 * (1 << (2 * sh)) +
(product3 - product1 - product2) * (1 << sh) + (product3 - product1 - product2) * (1 << sh) +
@ -133,27 +150,27 @@ int64_t karatsuba_algorithm(std::string str1, std::string str2) {
*/ */
static void test() { static void test() {
// 1st test // 1st test
std::string s11 = "1"; std::string s11 = "1"; // 1
std::string s12 = "1010"; std::string s12 = "1010"; // 10
std::cout << "1st test... "; std::cout << "1st test... ";
assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm( assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm(
s11, s12) == 10); // here the multiplication is 10 s11, s12) == 10);
std::cout << "passed" << std::endl; std::cout << "passed" << std::endl;
// 2nd test // 2nd test
std::string s21 = "11"; std::string s21 = "11"; // 3
std::string s22 = "1010"; std::string s22 = "1010"; // 10
std::cout << "2nd test... "; std::cout << "2nd test... ";
assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm( assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm(
s21, s22) == 30); // here the multiplication is 30 s21, s22) == 30);
std::cout << "passed" << std::endl; std::cout << "passed" << std::endl;
// 3rd test // 3rd test
std::string s31 = "110"; std::string s31 = "110"; // 6
std::string s32 = "1010"; std::string s32 = "1010"; // 10
std::cout << "3rd test... "; std::cout << "3rd test... ";
assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm( assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm(
s31, s32) == 60); // here the multiplication is 60 s31, s32) == 60);
std::cout << "passed" << std::endl; std::cout << "passed" << std::endl;
} }