Skip to content

Commit

Permalink
feat(math): compute the costs differentely for the small prime field
Browse files Browse the repository at this point in the history
  • Loading branch information
chokobole committed May 8, 2024
1 parent 362a6e5 commit 36b475b
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions tachyon/math/matrix/prime_field_num_traits.h
Original file line number Diff line number Diff line change
@@ -1,30 +1,53 @@
#ifndef TACHYON_MATH_MATRIX_PRIME_FIELD_NUM_TRAITS_H_
#define TACHYON_MATH_MATRIX_PRIME_FIELD_NUM_TRAITS_H_

#include <type_traits>

#include "third_party/eigen3/Eigen/Core"

#include "tachyon/math/finite_fields/finite_field_forwards.h"

namespace Eigen {

template <typename Config>
struct NumTraits<tachyon::math::PrimeField<Config>>
: GenericNumTraits<tachyon::math::PrimeField<Config>> {
struct CostCalculator;

template <typename Config>
struct CostCalculator<tachyon::math::PrimeField<Config>> {
constexpr static size_t kLimbNums =
tachyon::math::PrimeField<Config>::kLimbNums;
using NumTraitsType =
std::conditional_t<Config::kModulusBits <= 32, NumTraits<uint32_t>,
NumTraits<uint64_t>>;

constexpr static int ComputeReadCost() {
return static_cast<int>(kLimbNums * NumTraitsType::ReadCost);
}
constexpr static int ComputeAddCost() {
// In general, c = (a + b) % M = (a + b) > M ? (a + b) - M : (a + b)
return static_cast<int>(kLimbNums * NumTraitsType::AddCost * 3 / 2);
}
constexpr static int ComputeMulCost() {
// In general, c = (a * b) % M = (a * b) - [(a * b) / M] * M
return static_cast<int>(
kLimbNums * (4 * NumTraitsType::MulCost + NumTraitsType::AddCost));
}
};

template <typename Config>
struct NumTraits<tachyon::math::PrimeField<Config>>
: GenericNumTraits<tachyon::math::PrimeField<Config>> {
enum {
IsInteger = 1,
IsSigned = 0,
IsComplex = 0,
RequireInitialization = 1,
ReadCost = static_cast<int>(kLimbNums * NumTraits<uint64_t>::ReadCost),
// In general, c = (a + b) % M = (a + b) > M ? (a + b) - M : (a + b)
ReadCost =
CostCalculator<tachyon::math::PrimeField<Config>>::ComputeReadCost(),
AddCost =
static_cast<int>(kLimbNums * NumTraits<uint64_t>::AddCost * 3 / 2),
// In general, c = (a * b) % M = (a * b) - [(a * b) / M] * M
MulCost = static_cast<int>(kLimbNums * (4 * NumTraits<uint64_t>::MulCost +
NumTraits<uint64_t>::AddCost)),
CostCalculator<tachyon::math::PrimeField<Config>>::ComputeAddCost(),
MulCost =
CostCalculator<tachyon::math::PrimeField<Config>>::ComputeMulCost(),
};
};

Expand Down

0 comments on commit 36b475b

Please sign in to comment.