Skip to content

Commit

Permalink
Merge pull request #441 from kroma-network/feat/impl-radix2ditparallel
Browse files Browse the repository at this point in the history
feat: impl `Radix2DitParallel`
  • Loading branch information
chokobole authored Jun 27, 2024
2 parents 5de7501 + fa4c33f commit bbbabd0
Show file tree
Hide file tree
Showing 21 changed files with 605 additions and 54 deletions.
1 change: 1 addition & 0 deletions tachyon/math/finite_fields/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ tachyon_cc_library(
deps = [
":finite_field",
":legendre_symbol",
":packed_prime_field_traits_forward",
":prime_field_util",
"//tachyon/base:bits",
"//tachyon/base/json",
Expand Down
11 changes: 9 additions & 2 deletions tachyon/math/finite_fields/baby_bear/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library")
load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp4s")
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields")
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "SUBGROUP_GENERATOR", "generate_fft_prime_fields")

package(default_visibility = ["//visibility:public"])

generate_prime_fields(
string_flag(
name = SUBGROUP_GENERATOR,
build_setting_default = "31",
)

generate_fft_prime_fields(
name = "baby_bear",
class_name = "BabyBear",
# 2³¹ - 2²⁷ + 1
# Hex: 0x78000001
modulus = "2013265921",
namespace = "tachyon::math",
subgroup_generator = ":" + SUBGROUP_GENERATOR,
use_montgomery = True,
)

Expand Down
5 changes: 5 additions & 0 deletions tachyon/math/finite_fields/baby_bear/packed_baby_bear.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ struct FiniteFieldTraits<PackedBabyBear> {
using Config = BabyBear::Config;
};

template <>
struct PackedPrimeFieldTraits<BabyBear> {
using PackedPrimeField = PackedBabyBear;
};

} // namespace tachyon::math

namespace Eigen {
Expand Down
11 changes: 9 additions & 2 deletions tachyon/math/finite_fields/koala_bear/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library")
load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp2s", "generate_fp4s")
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields")
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "SUBGROUP_GENERATOR", "generate_fft_prime_fields")

package(default_visibility = ["//visibility:public"])

generate_prime_fields(
string_flag(
name = SUBGROUP_GENERATOR,
build_setting_default = "3",
)

generate_fft_prime_fields(
name = "koala_bear",
class_name = "KoalaBear",
# 2³¹ - 2²⁴ + 1
# Hex: 0x7f000001
modulus = "2130706433",
namespace = "tachyon::math",
subgroup_generator = ":" + SUBGROUP_GENERATOR,
use_montgomery = True,
)

Expand Down
5 changes: 5 additions & 0 deletions tachyon/math/finite_fields/koala_bear/packed_koala_bear.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ struct FiniteFieldTraits<PackedKoalaBear> {
using Config = KoalaBear::Config;
};

template <>
struct PackedPrimeFieldTraits<KoalaBear> {
using PackedPrimeField = PackedKoalaBear;
};

} // namespace tachyon::math

namespace Eigen {
Expand Down
5 changes: 5 additions & 0 deletions tachyon/math/finite_fields/mersenne31/packed_mersenne31.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ struct FiniteFieldTraits<PackedMersenne31> {
using Config = Mersenne31::Config;
};

template <>
struct PackedPrimeFieldTraits<Mersenne31> {
using PackedPrimeField = PackedMersenne31;
};

} // namespace tachyon::math

namespace Eigen {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace tachyon::math {

template <typename T>
template <typename T, typename SFINAE = void>
struct PackedPrimeFieldTraits;

} // namespace tachyon::math
Expand Down
7 changes: 7 additions & 0 deletions tachyon/math/finite_fields/prime_field_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "tachyon/math/base/gmp/gmp_util.h"
#include "tachyon/math/finite_fields/finite_field.h"
#include "tachyon/math/finite_fields/legendre_symbol.h"
#include "tachyon/math/finite_fields/packed_prime_field_traits_forward.h"
#include "tachyon/math/finite_fields/prime_field_util.h"

namespace tachyon {
Expand Down Expand Up @@ -160,6 +161,12 @@ H AbslHashValue(H h, const F& prime_field) {
return h;
}

template <typename T>
struct PackedPrimeFieldTraits<
T, std::enable_if_t<std::is_base_of_v<math::PrimeFieldBase<T>, T>>> {
using PackedPrimeField = T;
};

} // namespace math

namespace base {
Expand Down
2 changes: 2 additions & 0 deletions tachyon/math/matrix/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ tachyon_cc_library(
name = "matrix_utils",
hdrs = ["matrix_utils.h"],
deps = [
"//tachyon/base:bits",
"//tachyon/base:openmp_util",
"//tachyon/base/containers:container_util",
"//tachyon/math/finite_fields:packed_prime_field_traits_forward",
"@eigen_archive//:eigen3",
Expand Down
74 changes: 62 additions & 12 deletions tachyon/math/matrix/matrix_utils.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#ifndef TACHYON_MATH_MATRIX_MATRIX_UTILS_H_
#define TACHYON_MATH_MATRIX_MATRIX_UTILS_H_

#include <utility>
#include <vector>

#include "third_party/eigen3/Eigen/Core"

#include "tachyon/base/bits.h"
#include "tachyon/base/containers/container_util.h"
#include "tachyon/base/openmp_util.h"
#include "tachyon/math/finite_fields/packed_prime_field_traits_forward.h"

namespace tachyon::math {
Expand Down Expand Up @@ -45,22 +48,25 @@ MakeCirculant(const Eigen::MatrixBase<ArgType>& arg) {
CirculantFunctor<ArgType>(arg.derived()));
}

// NOTE(ashjeong): Important! |matrix| should carry the same amount of rows as
// the parent matrix it is a block from. |PackRowHorizontally| currently only
// supports row-major matrices.
template <typename PackedPrimeField, typename Derived, typename PrimeField>
std::vector<PackedPrimeField> PackRowHorizontally(
const Eigen::MatrixBase<Derived>& matrix, size_t row,
std::vector<PrimeField>& remaining_values) {
std::vector<PackedPrimeField*> PackRowHorizontally(
Eigen::Block<Derived>& matrix, size_t row,
std::vector<PrimeField*>& remaining_values) {
static_assert(Derived::Options & Eigen::RowMajorBit);
size_t num_packed = matrix.cols() / PackedPrimeField::N;
size_t remaining_start_idx = num_packed * PackedPrimeField::N;
remaining_values =
base::CreateVector(matrix.cols() - remaining_start_idx,
[row, remaining_start_idx, &matrix](size_t col) {
return matrix(row, remaining_start_idx + col);
});

remaining_values = base::CreateVector(
matrix.cols() - remaining_start_idx,
[row, remaining_start_idx, &matrix](size_t col) {
return reinterpret_cast<PrimeField*>(
matrix.data() + row * matrix.cols() + remaining_start_idx + col);
});
return base::CreateVector(num_packed, [row, &matrix](size_t col) {
return PackedPrimeField::From([row, col, &matrix](size_t i) {
return matrix(row, PackedPrimeField::N * col + i);
});
return reinterpret_cast<PackedPrimeField*>(
matrix.data() + row * matrix.cols() + PackedPrimeField::N * col);
});
}

Expand All @@ -74,6 +80,50 @@ std::vector<PackedPrimeField> PackRowVertically(
});
}

// Expands a |Eigen::MatrixBase|'s rows from |rows| to |rows|^(|added_bits|),
// moving values from row |i| to row |i|^(|added_bits|). All new entries are set
// to |F::Zero()|.
template <typename Derived>
void ExpandInPlaceWithZeroPad(Eigen::MatrixBase<Derived>& mat,
size_t added_bits) {
if (added_bits == 0) {
return;
}

Eigen::Index original_rows = mat.rows();
Eigen::Index new_rows = mat.rows() << added_bits;
Eigen::Index cols = mat.cols();

Derived padded = Derived::Zero(new_rows, cols);

OPENMP_PARALLEL_FOR(Eigen::Index row = 0; row < original_rows; ++row) {
Eigen::Index padded_row_index = row << added_bits;
// TODO(ashjeong): Check if moved properly
padded.row(padded_row_index) = std::move(mat.row(row));
}
mat = std::move(padded);
}

// Swaps rows of a |Eigen::MatrixBase| such that each row is changed to the row
// accessed with the reversed bits of the current index. Crashes if the number
// of rows is not a power of two.
template <typename Derived>
void ReverseMatrixIndexBits(Eigen::MatrixBase<Derived>& mat) {
size_t rows = static_cast<size_t>(mat.rows());
if (rows == 0) {
return;
}
CHECK(base::bits::IsPowerOfTwo(rows));
size_t log_n = base::bits::Log2Ceiling(rows);

OPENMP_PARALLEL_FOR(size_t row = 1; row < rows; ++row) {
size_t ridx = base::bits::BitRev(row) >> (sizeof(size_t) * 8 - log_n);
if (row < ridx) {
mat.row(row).swap(mat.row(ridx));
}
}
}

} // namespace tachyon::math

#endif // TACHYON_MATH_MATRIX_MATRIX_UTILS_H_
54 changes: 31 additions & 23 deletions tachyon/math/matrix/matrix_utils_unittest.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "tachyon/math/matrix/matrix_utils.h"

#include "tachyon/base/strings/string_util.h"
#include "tachyon/build/build_config.h"
#include "tachyon/math/finite_fields/baby_bear/packed_baby_bear.h"
#include "tachyon/math/finite_fields/test/finite_field_test.h"
#include "tachyon/math/finite_fields/test/gf7.h"
Expand All @@ -27,30 +26,39 @@ TEST_F(MatrixPackingTest, PackRowHorizontally) {
constexpr size_t N = PackedBabyBear::N;
constexpr size_t R = 3;

Matrix<BabyBear> matrix = Matrix<BabyBear>::Random(2 * N, 2 * N);
std::vector<BabyBear> remaining_values;
std::vector<PackedBabyBear> packed_values =
PackRowHorizontally<PackedBabyBear>(matrix, R, remaining_values);
ASSERT_TRUE(remaining_values.empty());
ASSERT_EQ(packed_values.size(), 2);
for (size_t i = 0; i < packed_values.size(); ++i) {
for (size_t j = 0; j < N; ++j) {
EXPECT_EQ(packed_values[i][j], matrix(R, i * N + j));
{
RowMajorMatrix<BabyBear> matrix =
RowMajorMatrix<BabyBear>::Random(2 * N, 2 * N);
Eigen::Block<RowMajorMatrix<BabyBear>> mat =
matrix.block(0, 0, matrix.rows(), matrix.cols());
std::vector<BabyBear*> remaining_values;
std::vector<PackedBabyBear*> packed_values =
PackRowHorizontally<PackedBabyBear>(mat, R, remaining_values);
ASSERT_TRUE(remaining_values.empty());
ASSERT_EQ(packed_values.size(), 2);
for (size_t i = 0; i < packed_values.size(); ++i) {
for (size_t j = 0; j < N; ++j) {
EXPECT_EQ((*packed_values[i])[j], matrix(R, i * N + j));
}
}
}

matrix = Matrix<BabyBear>::Random(2 * N - 1, 2 * N - 1);
remaining_values.clear();
packed_values =
PackRowHorizontally<PackedBabyBear>(matrix, R, remaining_values);
ASSERT_EQ(remaining_values.size(), N - 1);
ASSERT_EQ(packed_values.size(), 1);
for (size_t i = 0; i < remaining_values.size(); ++i) {
EXPECT_EQ(remaining_values[i], matrix(R, packed_values.size() * N + i));
}
for (size_t i = 0; i < packed_values.size(); ++i) {
for (size_t j = 0; j < N; ++j) {
EXPECT_EQ(packed_values[i][j], matrix(R, i * N + j));
{
RowMajorMatrix<BabyBear> matrix =
RowMajorMatrix<BabyBear>::Random(2 * N - 1, 2 * N - 1);
Eigen::Block<RowMajorMatrix<BabyBear>> mat =
matrix.block(0, 0, matrix.rows(), matrix.cols());
std::vector<BabyBear*> remaining_values;
std::vector<PackedBabyBear*> packed_values =
PackRowHorizontally<PackedBabyBear>(mat, R, remaining_values);
ASSERT_EQ(remaining_values.size(), N - 1);
ASSERT_EQ(packed_values.size(), 1);
for (size_t i = 0; i < remaining_values.size(); ++i) {
EXPECT_EQ(*remaining_values[i], matrix(R, packed_values.size() * N + i));
}
for (size_t i = 0; i < packed_values.size(); ++i) {
for (size_t j = 0; j < N; ++j) {
EXPECT_EQ((*packed_values[i])[j], matrix(R, i * N + j));
}
}
}
}
Expand Down
29 changes: 29 additions & 0 deletions tachyon/math/polynomials/univariate/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,41 @@ tachyon_cc_library(
],
)

tachyon_cc_library(
name = "naive_batch_fft",
hdrs = ["naive_batch_fft.h"],
deps = [
":two_adic_subgroup",
"//tachyon/base:bits",
],
)

tachyon_cc_library(
name = "radix2_evaluation_domain",
hdrs = ["radix2_evaluation_domain.h"],
deps = [
":two_adic_subgroup",
":univariate_evaluation_domain",
"//tachyon/base:bits",
"//tachyon/base:openmp_util",
"//tachyon/base:parallelize",
"//tachyon/base/containers:container_util",
"//tachyon/math/finite_fields:packed_prime_field_base",
"//tachyon/math/matrix:matrix_types",
"//tachyon/math/matrix:matrix_utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_prod",
"@eigen_archive//:eigen3",
],
)

tachyon_cc_library(
name = "two_adic_subgroup",
hdrs = ["two_adic_subgroup.h"],
deps = [
"//tachyon/base:optional",
"//tachyon/math/matrix:matrix_types",
],
)

Expand Down Expand Up @@ -118,6 +143,7 @@ tachyon_cc_unittest(
name = "univariate_unittests",
srcs = [
"lagrange_interpolation_unittest.cc",
"radix2_evaluation_domain_unittest.cc",
"univariate_dense_polynomial_unittest.cc",
"univariate_evaluation_domain_unittest.cc",
"univariate_evaluations_unittest.cc",
Expand All @@ -126,6 +152,7 @@ tachyon_cc_unittest(
deps = [
":lagrange_interpolation",
":mixed_radix_evaluation_domain",
":naive_batch_fft",
":radix2_evaluation_domain",
":univariate_polynomial",
"//tachyon/base:optional",
Expand All @@ -136,6 +163,8 @@ tachyon_cc_unittest(
"//tachyon/math/elliptic_curves/bls12/bls12_381:fr",
"//tachyon/math/elliptic_curves/bn/bn254:fr",
"//tachyon/math/elliptic_curves/bn/bn384_small_two_adicity:fq",
"//tachyon/math/finite_fields/baby_bear:packed_baby_bear",
"//tachyon/math/finite_fields/koala_bear:packed_koala_bear",
"//tachyon/math/finite_fields/test:finite_field_test",
"//tachyon/math/finite_fields/test:gf7",
"@com_google_absl//absl/hash:hash_testing",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class MixedRadixEvaluationDomain
for (size_t k = 0; k < n; k += 2 * m) {
F w = F::One();
for (size_t j = 0; j < m; ++j) {
UnivariateEvaluationDomain<F, MaxDegree>::ButterflyFnOutIn(
UnivariateEvaluationDomain<F, MaxDegree>::template ButterflyFnOutIn(
a.at(k + j), a.at((k + m) + j), w);
w *= w_m;
}
Expand Down
Loading

0 comments on commit bbbabd0

Please sign in to comment.