Skip to content

Commit

Permalink
Add initial implementation for Karatsuba multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
OTheDev committed Mar 2, 2024
1 parent ffef98a commit b80e010
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/constants.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ constexpr auto bi_cmp_dbl_size_upper = find_upper();
constexpr auto max_size = dvector::max_size();
constexpr bi_bitcount_t max_bits = max_size * bi_dwidth;

// If both operands of * have size() >= karatsuba_threshold, then use karatsuba.
// The threshold is a digit count (it is compared against bi_t::size()); below
// it, h_::mul_karatsuba() hands the operands to h_::mul() instead of recursing.
constexpr auto karatsuba_threshold = 60;

} // namespace bi

#endif // BI_SRC_CONSTANTS_HPP_
113 changes: 111 additions & 2 deletions src/h_.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ SPDX-License-Identifier: Apache-2.0
#include <cassert>
#include <cmath>
#include <limits>
#include <random>
#include <string>
#include <utility>

Expand Down Expand Up @@ -41,6 +42,7 @@ struct h_ {
static void mul_algo_square(bi_t& result, const bi_t& a, const size_t m);
static void mul_algo_knuth(bi_t& result, const bi_t& a, const bi_t& b,
const size_t m, const size_t n);
static void mul_karatsuba(bi_t& result, const bi_t& a, const bi_t& b);
static void mul(bi_t& result, const bi_t& a, const bi_t& b);
static digit div_algo_digit(bi_t& q, const bi_t& u, digit v) noexcept;
static void div_algo_single(bi_t& q, bi_t& r, const bi_t& n,
Expand Down Expand Up @@ -81,6 +83,7 @@ struct h_ {
// misc.
static dvector to_twos_complement(const dvector& vec);
static void to_twos_complement_in_place(dvector& vec) noexcept;
static void bisect(const bi_t&, bi_t&, bi_t&, size_t m);

// double
static void assign_from_double(bi_t&, double);
Expand All @@ -90,8 +93,14 @@ struct h_ {
// exponentiation
template <std::unsigned_integral T>
static bi_t expo_left_to_right(const bi_t& base, T exp);

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static thread_local std::mt19937 rng_;
static bi_t random_(bi_bitcount_t z);
};

// Definition of the thread-local Mersenne Twister engine declared in h_ and
// consumed by h_::random_(). Each thread's instance is seeded with one value
// drawn from std::random_device when that instance is initialized.
thread_local std::mt19937 h_::rng_{std::random_device{}()}; // NOLINT

void h_::increment_abs(bi_t& x) {
if (x.size() == 0 || x[x.size() - 1] == std::numeric_limits<digit>::max()) {
x.reserve_(x.size() + 1);
Expand Down Expand Up @@ -452,6 +461,73 @@ void h_::mul_algo_knuth(bi_t& w, const bi_t& a, const bi_t& b, const size_t m,
}
}

/**
* @internal
* @page mul_karatsuba Multiplication - Karatsuba
* @ingroup algorithms
* When both operands of multiplication require many base \f$ b \f$ digits,
* we can make use of the following observation to implement a recursive
* algorithm that is more efficient than the standard quadratic
* pencil-and-paper method.
*
* The following is adapted from Knuth Vol. 2, pp. 294-295.
*
* Consider two integers with \f$ 2n \f$ base-b digits and their product:
* \f{align}{
* u &= (u_{2n-1} \cdots u_{0})_{b} = b^{n}U_{1} + U_{0} \\
* v &= (v_{2n-1} \cdots v_{0})_{b} = b^{n}V_{1} + V_{0} \\
* U_{1} &\equiv (u_{2n-1} \cdots u_{n})_{b},
* \; U_{0} \equiv (u_{n-1} \cdots u_{0})_{b} \\
* V_{1} &\equiv (v_{2n-1} \cdots v_{n})_{b},
* \; V_{0} \equiv (v_{n-1} \cdots v_{0})_{b} \\
* uv &= (b^{n}U_{1} + U_{0})(b^{n}V_{1} + V_{0})
* \f}
* It is simple to verify that the product is equivalent to
* \f[
* uv = b^{2n}U_{1}V_{1} + b^{n}\left[(U_{1}+U_{0})(V_{1}+V_{0}) -
* (U_{0}V_{0} + U_{1}V_{1})\right] + U_{0}V_{0}
* \f]
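* The result of multiplying two \f$ 2n \f$ digit integers can thus be found
* via three multiplications of roughly \f$ n \f$ digit integers (the sums
* \f$ U_{1}+U_{0} \f$ and \f$ V_{1}+V_{0} \f$ may need \f$ n+1 \f$ digits),
* plus cheaper bit shifts, additions, and subtractions.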
* @endinternal
*/
/**
 * @brief Splits the digit vector of `x` into a low part and a high part.
 *
 * `lower` receives the first `min(n, x.size())` digits of `x` and `upper`
 * receives the remaining ones, so that (for the digit vectors)
 * `x == upper * base^n + lower` whenever `n <= x.size()`. Only the digit
 * vectors of `lower` and `upper` are assigned here, after which each is passed
 * through trim_trailing_zeros().
 *
 * @param x      Integer whose digits are split.
 * @param lower  Receives the low-order digits.
 * @param upper  Receives the high-order digits (an empty range if `x` has at
 *               most `n` digits).
 * @param n      Requested number of low-order digits.
 */
void h_::bisect(const bi::bi_t& x, bi::bi_t& lower, bi::bi_t& upper, size_t n) {
  // Clamp the cut so we never step past the end of x's digits.
  const size_t cut = std::min(n, x.size());

  {
    const auto first = x.vec_.begin();
    lower.vec_ = dvector(first, first + cut);
  }
  {
    const auto middle = x.vec_.begin() + cut;
    upper.vec_ = dvector(middle, x.vec_.end());
  }

  lower.trim_trailing_zeros();
  upper.trim_trailing_zeros();
}

/**
 * @brief Performs `w = u * v` using Karatsuba's divide-and-conquer scheme.
 *
 * If either operand has fewer than `karatsuba_threshold` digits the work is
 * delegated to h_::mul(). Otherwise both operands are split at `n` digits
 * (`n` = half the size of the smaller operand) into `u = base^n * u1 + u0` and
 * `v = base^n * v1 + v0`, and the product is assembled from three recursive
 * products:
 *
 *   a = u1 * v1,  b = u0 * v0,  c = (u1 + u0)(v1 + v0) - (a + b)
 *   |u * v| = base^(2n) * a + base^n * c + b
 *
 * bisect() copies digits only, so the recursion works on magnitudes; the sign
 * of the result is applied once at the end so that this path agrees with the
 * h_::mul() fallback for negative operands.
 *
 * @param[out] w  Receives the product. On the recursive path it is written
 *                only after `u` and `v` have been fully consumed.
 * @param u       First factor.
 * @param v       Second factor.
 */
void h_::mul_karatsuba(bi_t& w, const bi_t& u, const bi_t& v) {
  if (u.size() < karatsuba_threshold || v.size() < karatsuba_threshold) {
    h_::mul(w, u, v);
    return;
  }

  // Capture the sign of the product before anything is written to w.
  const bool negate_result = u.negative() != v.negative();

  const size_t n = std::min(u.size(), v.size()) >> 1;

  bi_t u0, u1, v0, v1;
  bisect(u, u0, u1, n);
  bisect(v, v0, v1, n);

  bi_t a, b, c;
  mul_karatsuba(a, u1, v1); // a = u1 * v1
  mul_karatsuba(b, u0, v0); // b = u0 * v0

  // u1 and v1 are no longer needed on their own; reuse them for the sums.
  u1 += u0;
  v1 += v0;
  mul_karatsuba(c, u1, v1);
  c -= a + b; // c = (u1 + u0)(v1 + v0) - (u0v0 + u1v1)

  // Accumulate (bi_base ** 2n) * a + (bi_base ** n) * c + b in place in `a`.
  a <<= (2 * n * bi_dbits);
  c <<= (n * bi_dbits);
  a += c;
  a += b; // a = |u| * |v|

  if (negate_result) {
    // A default-constructed bi_t is zero, so this yields -(|u| * |v|).
    bi_t negated;
    negated -= a;
    w = std::move(negated);
  } else {
    w = std::move(a);
  }
}

/**
* @brief Performs `result = a * b`.
* @note mult_helpers.hpp proves that multiplying any two digits followed by
Expand Down Expand Up @@ -1747,8 +1823,41 @@ bi_t h_::expo_left_to_right(const bi_t& base, T exp) {
return ret;
}

} // namespace bi

///@}

/**
 * @brief Returns a pseudo-random integer occupying at most `z` bits.
 *
 * The value is built digit by digit: every digit is drawn from a uniform
 * distribution over `[0, bi_dmax]` using the thread-local engine `rng_`, and
 * when `z` is not a multiple of `bi_dbits` the most significant digit is
 * masked down to the leftover `z % bi_dbits` bits. The result is passed
 * through trim_trailing_zeros() before being returned.
 *
 * @param z  Upper bound on the bit length of the result. `z == 0` yields a
 *           default-constructed bi_t.
 * @return   The generated integer.
 */
bi_t h_::random_(bi_bitcount_t z) {
  bi_t result{};
  if (z == 0) {
    return result;
  }

  const size_t whole_digits = z / bi_dbits;
  const size_t leftover_bits = z % bi_dbits;
  const bool has_partial_digit = leftover_bits != 0;
  const size_t total_digits = has_partial_digit ? whole_digits + 1 : whole_digits;

  result.resize_(total_digits);

  // One draw per digit, least significant first.
  std::uniform_int_distribution<digit> dist(0, bi_dmax);
  for (size_t idx = 0; idx < total_digits; ++idx) {
    result[idx] = dist(rng_);
  }

  // Discard the bits of the top digit that lie beyond bit z.
  if (has_partial_digit) {
    const digit mask = (static_cast<digit>(1) << leftover_bits) - 1;
    result[total_digits - 1] &= mask;
  }

  result.trim_trailing_zeros();
  return result;
}

} // namespace bi

#endif // BI_SRC_H__HPP_
56 changes: 56 additions & 0 deletions test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ SPDX-License-Identifier: Apache-2.0
#include <cmath>
#include <functional>
#include <limits>
#include <numeric>
#include <random>
#include <string>

Expand All @@ -19,6 +20,14 @@ SPDX-License-Identifier: Apache-2.0
#include "int128.hpp"
#include "uints.hpp"

// Test-only view of the library's internal helper struct: it redeclares just
// the static members that the Karatsuba test below calls, so the test can
// reach them without including the internal header.
// NOTE(review): this declaration lists fewer members than the h_ declared in
// src/h_.hpp; two differing definitions of the same class are formally an ODR
// violation -- confirm this is an accepted testing shortcut.
namespace bi {
struct h_ {
static bi_t random_(bi_bitcount_t z);
static void mul_karatsuba(bi_t&, const bi_t&, const bi_t&);
static void mul(bi_t& result, const bi_t& a, const bi_t& b);
};
} // namespace bi

namespace {

using bi::bi_dmax;
Expand Down Expand Up @@ -2245,6 +2254,53 @@ TEST_F(BITest, Exponentiation) {
EXPECT_THROW(bi_t::pow(two, bi_max_bits + 1), bi::overflow_error);
}

// Cross-checks h_::mul_karatsuba() against h_::mul() on random operands whose
// requested sizes range from karatsuba_threshold to 4 * karatsuba_threshold
// digits, and reports the average CPU time spent in each routine.
TEST_F(BITest, Karatsuba) {
  const int num_iterations = 1000;

  std::random_device rdev;
  std::mt19937_64 rng(rdev());
  std::uniform_int_distribution<int> dist(bi::karatsuba_threshold,
                                          bi::karatsuba_threshold * 4);

  // Converts a pair of std::clock() readings into elapsed milliseconds.
  const auto elapsed_ms = [](std::clock_t begin, std::clock_t finish) {
    return 1000.0 * static_cast<double>(finish - begin) / CLOCKS_PER_SEC;
  };

  std::vector<double> karatsuba_times;
  std::vector<double> normal_times;

  for (int iter = 0; iter < num_iterations; ++iter) {
    bi_t x_k, x_n;

    // Operand sizes are requested in bits: digit width times a random digit
    // count.
    bi_t r_1 = bi::h_::random_(bi_dwidth * dist(rng));
    bi_t r_2 = bi::h_::random_(bi_dwidth * dist(rng));

    const std::clock_t karatsuba_begin = std::clock();
    bi::h_::mul_karatsuba(x_k, r_1, r_2);
    const std::clock_t karatsuba_end = std::clock();
    karatsuba_times.push_back(elapsed_ms(karatsuba_begin, karatsuba_end));

    const std::clock_t normal_begin = std::clock();
    bi::h_::mul(x_n, r_1, r_2);
    const std::clock_t normal_end = std::clock();
    normal_times.push_back(elapsed_ms(normal_begin, normal_end));

    // Both routines must agree on every sample.
    ASSERT_EQ(x_k, x_n);
  }

  const double karatsuba_total =
      std::accumulate(karatsuba_times.begin(), karatsuba_times.end(), 0.0);
  const double normal_total =
      std::accumulate(normal_times.begin(), normal_times.end(), 0.0);

  double avg_karatsuba_time = karatsuba_total / num_iterations;
  double avg_normal_time = normal_total / num_iterations;

  std::cout << "Average Mult-Karatsuba time: " << avg_karatsuba_time << " ms\n";
  std::cout << "Average Mult-Pencil time: " << avg_normal_time << " ms\n";
}

// NOLINTEND(cppcoreguidelines-avoid-magic-numbers)

} // namespace

0 comments on commit b80e010

Please sign in to comment.