diff --git a/tachyon/math/matrix/prime_field_num_traits.h b/tachyon/math/matrix/prime_field_num_traits.h index edfb0cabc0..6aa46085e2 100644 --- a/tachyon/math/matrix/prime_field_num_traits.h +++ b/tachyon/math/matrix/prime_field_num_traits.h @@ -1,6 +1,8 @@ #ifndef TACHYON_MATH_MATRIX_PRIME_FIELD_NUM_TRAITS_H_ #define TACHYON_MATH_MATRIX_PRIME_FIELD_NUM_TRAITS_H_ +#include + #include "third_party/eigen3/Eigen/Core" #include "tachyon/math/finite_fields/finite_field_forwards.h" @@ -8,23 +10,44 @@ namespace Eigen { template -struct NumTraits> - : GenericNumTraits> { +struct CostCalculator; + +template +struct CostCalculator> { constexpr static size_t kLimbNums = tachyon::math::PrimeField::kLimbNums; + using NumTraitsType = + std::conditional_t, + NumTraits>; + + constexpr static int ComputeReadCost() { + return static_cast(kLimbNums * NumTraitsType::ReadCost); + } + constexpr static int ComputeAddCost() { + // In general, c = (a + b) % M = (a + b) > M ? (a + b) - M : (a + b) + return static_cast(kLimbNums * NumTraitsType::AddCost * 3 / 2); + } + constexpr static int ComputeMulCost() { + // In general, c = (a * b) % M = (a * b) - [(a * b) / M] * M + return static_cast( + kLimbNums * (4 * NumTraitsType::MulCost + NumTraitsType::AddCost)); + } +}; +template +struct NumTraits> + : GenericNumTraits> { enum { IsInteger = 1, IsSigned = 0, IsComplex = 0, RequireInitialization = 1, - ReadCost = static_cast(kLimbNums * NumTraits::ReadCost), - // In general, c = (a + b) % M = (a + b) > M ? (a + b) - M : (a + b) + ReadCost = + CostCalculator>::ComputeReadCost(), AddCost = - static_cast(kLimbNums * NumTraits::AddCost * 3 / 2), - // In general, c = (a * b) % M = (a * b) - [(a * b) / M] * M - MulCost = static_cast(kLimbNums * (4 * NumTraits::MulCost + - NumTraits::AddCost)), + CostCalculator>::ComputeAddCost(), + MulCost = + CostCalculator>::ComputeMulCost(), }; };