diff --git a/tachyon/math/base/BUILD.bazel b/tachyon/math/base/BUILD.bazel index aca3104a0..2e149f944 100644 --- a/tachyon/math/base/BUILD.bazel +++ b/tachyon/math/base/BUILD.bazel @@ -141,6 +141,7 @@ tachyon_cc_unittest( ":sign", "//tachyon/base/buffer:vector_buffer", "//tachyon/base/containers:container_util", + "//tachyon/base/strings:string_number_conversions", "//tachyon/math/elliptic_curves/msm/test:variable_base_msm_test_set", "//tachyon/math/elliptic_curves/short_weierstrass/test:sw_curve_config", "//tachyon/math/finite_fields/test:finite_field_test", diff --git a/tachyon/math/base/field.h b/tachyon/math/base/field.h index 3f3ab4290..99026c4cb 100644 --- a/tachyon/math/base/field.h +++ b/tachyon/math/base/field.h @@ -1,6 +1,7 @@ #ifndef TACHYON_MATH_BASE_FIELD_H_ #define TACHYON_MATH_BASE_FIELD_H_ +#include #include #include "tachyon/math/base/ring.h" @@ -32,6 +33,12 @@ class Field : public AdditiveGroup, public MultiplicativeGroup { } }; +template +std::ostream& operator<<(std::ostream& os, const Field& f) { + const F& derived = static_cast(f); + return os << derived.ToString(); +} + } // namespace tachyon::math #endif // TACHYON_MATH_BASE_FIELD_H_ diff --git a/tachyon/math/base/groups.h b/tachyon/math/base/groups.h index 763ff174a..23178fc7a 100644 --- a/tachyon/math/base/groups.h +++ b/tachyon/math/base/groups.h @@ -2,6 +2,7 @@ #define TACHYON_MATH_BASE_GROUPS_H_ #include +#include #include #include #include @@ -158,6 +159,12 @@ class MultiplicativeGroup : public MultiplicativeSemigroup { } }; +template +std::ostream& operator<<(std::ostream& os, const MultiplicativeGroup& g) { + const G& derived = static_cast(g); + return os << derived.ToString(); +} + // AdditiveGroup is a group with the group operation '+'. // AdditiveGroup supports subtraction and negation, inheriting the // properties of AdditiveSemigroup. @@ -203,6 +210,12 @@ class AdditiveGroup : public AdditiveSemigroup { } }; +template +std::ostream& operator<<(std::ostream& os, const AdditiveGroup& g) { + const G& derived = static_cast(g); + return os << derived.ToString(); +} + } // namespace tachyon::math #endif // TACHYON_MATH_BASE_GROUPS_H_ diff --git a/tachyon/math/base/groups_unittest.cc b/tachyon/math/base/groups_unittest.cc index 6bd854fde..89a0d89bc 100644 --- a/tachyon/math/base/groups_unittest.cc +++ b/tachyon/math/base/groups_unittest.cc @@ -1,9 +1,12 @@ #include "tachyon/math/base/groups.h" +#include + #include "gmock/gmock.h" #include "gtest/gtest.h" #include "tachyon/base/containers/container_util.h" +#include "tachyon/base/strings/string_number_conversions.h" #include "tachyon/math/finite_fields/test/gf7.h" namespace tachyon::math { @@ -20,6 +23,8 @@ TEST(GroupsTest, Div) { bool operator==(const Int& other) const { return value_ == other.value_; } + std::string ToString() const { return base::NumberToString(value_); } + private: int value_ = 0; }; @@ -44,6 +49,8 @@ TEST(GroupsTest, InverseOverride) { return denominator_ == other.denominator_; } + std::string ToString() const { return base::NumberToString(denominator_); } + private: int denominator_ = 0; }; @@ -104,6 +111,8 @@ TEST(GroupsTest, Sub) { bool operator==(const Int& other) const { return value_ == other.value_; } + std::string ToString() const { return base::NumberToString(value_); } + private: int value_ = 0; }; @@ -130,6 +139,8 @@ TEST(GroupsTest, SubOverAdd) { bool operator==(const Int& other) const { return value_ == other.value_; } + std::string ToString() const { return base::NumberToString(value_); } + private: int value_ = 0; }; diff --git a/tachyon/math/matrix/BUILD.bazel b/tachyon/math/matrix/BUILD.bazel index 4bc74b4c1..45eaafd77 100644 --- a/tachyon/math/matrix/BUILD.bazel +++ b/tachyon/math/matrix/BUILD.bazel @@ -34,6 +34,7 @@ tachyon_cc_unittest( srcs = ["matrix_types_unittest.cc"], deps = [ ":matrix_types", + ":prime_field_num_traits", "//tachyon/base/buffer:vector_buffer", "//tachyon/math/finite_fields/test:finite_field_test", "//tachyon/math/finite_fields/test:gf7", diff --git a/tachyon/math/matrix/matrix_types.h b/tachyon/math/matrix/matrix_types.h index 22f7cfbcc..1d812fd42 100644 --- a/tachyon/math/matrix/matrix_types.h +++ b/tachyon/math/matrix/matrix_types.h @@ -10,26 +10,30 @@ namespace tachyon { namespace math { -template -using Matrix = Eigen::Matrix; +template +using Matrix = Eigen::Matrix; -template -using Vector = Eigen::Matrix; +template +using DiagonalMatrix = Eigen::DiagonalMatrix; -template -using RowVector = Eigen::Matrix; +template +using Vector = Eigen::Matrix; + +template +using RowVector = Eigen::Matrix; } // namespace math namespace base { -template -class Copyable< - Eigen::Matrix> { +class Copyable> { public: - using Matrix = - Eigen::Matrix; + using Matrix = Eigen::Matrix; static bool WriteTo(const Matrix& matrix, Buffer* buffer) { if (!buffer->WriteMany(matrix.rows(), matrix.cols())) return false; @@ -62,7 +66,43 @@ class Copyable< } static size_t EstimateSize(const Matrix& matrix) { - return matrix.size() * sizeof(PrimeField) + sizeof(Eigen::Index) * 2; + return matrix.size() * sizeof(Field) + sizeof(Eigen::Index) * 2; + } +}; + +template +class Copyable> { + public: + using DiagonalMatrix = Eigen::DiagonalMatrix; + using DiagonalVector = typename DiagonalMatrix::DiagonalVectorType; + + static bool WriteTo(const DiagonalMatrix& matrix, Buffer* buffer) { + if (!buffer->WriteMany(matrix.rows())) return false; + const DiagonalVector& diagonal = matrix.diagonal(); + for (Eigen::Index i = 0; i < diagonal.size(); ++i) { + if (!buffer->Write(diagonal.data()[i])) return false; + } + return true; + } + + static bool ReadFrom(const ReadOnlyBuffer& buffer, DiagonalMatrix* matrix) { + Eigen::Index size; + DiagonalVector vector_tmp; + if (!buffer.ReadMany(&size)) return false; + if (Size != Eigen::Dynamic) { + if (size != Size) return false; + } else { + vector_tmp.resize(size); + } + for (Eigen::Index i = 0; i < vector_tmp.size(); ++i) { + if (!buffer.Read(&vector_tmp.data()[i])) return false; + } + *matrix = DiagonalMatrix(std::move(vector_tmp)); + return true; + } + + static size_t EstimateSize(const DiagonalMatrix& matrix) { + return matrix.rows() * sizeof(Field) + sizeof(Eigen::Index); } }; diff --git a/tachyon/math/matrix/matrix_types_unittest.cc b/tachyon/math/matrix/matrix_types_unittest.cc index 71595ac8a..2b71b29fc 100644 --- a/tachyon/math/matrix/matrix_types_unittest.cc +++ b/tachyon/math/matrix/matrix_types_unittest.cc @@ -3,6 +3,7 @@ #include "tachyon/base/buffer/vector_buffer.h" #include "tachyon/math/finite_fields/test/finite_field_test.h" #include "tachyon/math/finite_fields/test/gf7.h" +#include "tachyon/math/matrix/prime_field_num_traits.h" namespace tachyon::math { @@ -22,30 +23,30 @@ TEST_F(MatrixTypesTest, CopyableDynamicMatrix) { { write_buf.set_buffer_offset(0); - Eigen::Matrix value; + Matrix value; ASSERT_FALSE(write_buf.Read(&value)); } { write_buf.set_buffer_offset(0); - Eigen::Matrix value; + Matrix value; ASSERT_FALSE(write_buf.Read(&value)); } { write_buf.set_buffer_offset(0); - Eigen::Matrix value; + Matrix value; ASSERT_TRUE(write_buf.Read(&value)); - EXPECT_TRUE(value == expected); + EXPECT_EQ(value, expected); } { write_buf.set_buffer_offset(0); Matrix value; ASSERT_TRUE(write_buf.Read(&value)); - EXPECT_TRUE(value == expected); + EXPECT_EQ(value, expected); } } TEST_F(MatrixTypesTest, Copyable3x3Matrix) { - Eigen::Matrix expected{ + Matrix expected{ {GF7(0), GF7(1), GF7(2)}, {GF7(3), GF7(4), GF7(5)}, {GF7(6), GF7(0), GF7(1)}, @@ -60,25 +61,79 @@ TEST_F(MatrixTypesTest, Copyable3x3Matrix) { { write_buf.set_buffer_offset(0); - Eigen::Matrix value; + Matrix value; ASSERT_FALSE(write_buf.Read(&value)); } { write_buf.set_buffer_offset(0); - Eigen::Matrix value; + Matrix value; ASSERT_FALSE(write_buf.Read(&value)); } { write_buf.set_buffer_offset(0); - Eigen::Matrix value; + Matrix value; ASSERT_TRUE(write_buf.Read(&value)); - EXPECT_TRUE(value == expected); + EXPECT_EQ(value, expected); } { write_buf.set_buffer_offset(0); Matrix value; ASSERT_TRUE(write_buf.Read(&value)); - EXPECT_TRUE(value == expected); + EXPECT_EQ(value, expected); + } +} + +TEST_F(MatrixTypesTest, CopyableDynamicDiagonalMatrix) { + DiagonalMatrix expected{{GF7(1), GF7(2), GF7(3)}}; + + base::Uint8VectorBuffer write_buf; + ASSERT_TRUE(write_buf.Grow(base::EstimateSize(expected))); + ASSERT_TRUE(write_buf.Write(expected)); + ASSERT_TRUE(write_buf.Done()); + + { + write_buf.set_buffer_offset(0); + DiagonalMatrix value; + ASSERT_FALSE(write_buf.Read(&value)); + } + { + write_buf.set_buffer_offset(0); + DiagonalMatrix value; + ASSERT_TRUE(write_buf.Read(&value)); + EXPECT_EQ(value.diagonal(), expected.diagonal()); + } + { + write_buf.set_buffer_offset(0); + DiagonalMatrix value; + ASSERT_TRUE(write_buf.Read(&value)); + EXPECT_EQ(value.diagonal(), expected.diagonal()); + } +} + +TEST_F(MatrixTypesTest, Copyable3x3DiagonalMatrix) { + DiagonalMatrix expected{GF7(1), GF7(2), GF7(3)}; + + base::Uint8VectorBuffer write_buf; + ASSERT_TRUE(write_buf.Grow(base::EstimateSize(expected))); + ASSERT_TRUE(write_buf.Write(expected)); + ASSERT_TRUE(write_buf.Done()); + + { + write_buf.set_buffer_offset(0); + DiagonalMatrix value; + ASSERT_FALSE(write_buf.Read(&value)); + } + { + write_buf.set_buffer_offset(0); + DiagonalMatrix value; + ASSERT_TRUE(write_buf.Read(&value)); + EXPECT_EQ(value.diagonal(), expected.diagonal()); + } + { + write_buf.set_buffer_offset(0); + DiagonalMatrix value; + ASSERT_TRUE(write_buf.Read(&value)); + EXPECT_EQ(value.diagonal(), expected.diagonal()); } } diff --git a/tachyon/math/matrix/prime_field_num_traits.h b/tachyon/math/matrix/prime_field_num_traits.h index edfb0cabc..6aa46085e 100644 --- a/tachyon/math/matrix/prime_field_num_traits.h +++ b/tachyon/math/matrix/prime_field_num_traits.h @@ -1,6 +1,8 @@ #ifndef TACHYON_MATH_MATRIX_PRIME_FIELD_NUM_TRAITS_H_ #define TACHYON_MATH_MATRIX_PRIME_FIELD_NUM_TRAITS_H_ +#include + #include "third_party/eigen3/Eigen/Core" #include "tachyon/math/finite_fields/finite_field_forwards.h" @@ -8,23 +10,44 @@ namespace Eigen { template -struct NumTraits> - : GenericNumTraits> { +struct CostCalculator; + +template +struct CostCalculator> { constexpr static size_t kLimbNums = tachyon::math::PrimeField::kLimbNums; + using NumTraitsType = + std::conditional_t, + NumTraits>; + + constexpr static int ComputeReadCost() { + return static_cast(kLimbNums * NumTraitsType::ReadCost); + } + constexpr static int ComputeAddCost() { + // In general, c = (a + b) % M = (a + b) > M ? (a + b) - M : (a + b) + return static_cast(kLimbNums * NumTraitsType::AddCost * 3 / 2); + } + constexpr static int ComputeMulCost() { + // In general, c = (a * b) % M = (a * b) - [(a * b) / M] * M + return static_cast( + kLimbNums * (4 * NumTraitsType::MulCost + NumTraitsType::AddCost)); + } +}; +template +struct NumTraits> + : GenericNumTraits> { enum { IsInteger = 1, IsSigned = 0, IsComplex = 0, RequireInitialization = 1, - ReadCost = static_cast(kLimbNums * NumTraits::ReadCost), - // In general, c = (a + b) % M = (a + b) > M ? (a + b) - M : (a + b) + ReadCost = + CostCalculator>::ComputeReadCost(), AddCost = - static_cast(kLimbNums * NumTraits::AddCost * 3 / 2), - // In general, c = (a * b) % M = (a * b) - [(a * b) / M] * M - MulCost = static_cast(kLimbNums * (4 * NumTraits::MulCost + - NumTraits::AddCost)), + CostCalculator>::ComputeAddCost(), + MulCost = + CostCalculator>::ComputeMulCost(), }; };