Skip to content

Commit

Permalink
Fboemer/ntt fix (#58)
Browse files Browse the repository at this point in the history
* Fix NTT AVX512 implementation
  • Loading branch information
fboemer committed Sep 2, 2021
1 parent 0c26425 commit 4d9806f
Show file tree
Hide file tree
Showing 21 changed files with 229 additions and 122 deletions.
4 changes: 3 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

## Version 1.2.1
- Fixes a bug in AVX512 floating-point implementation of element-wise vector-vector modular multiplication (https://github.com/microsoft/SEAL/issues/385)
- Fixes a bug in the NTT default allocator (https://gitlab.com/palisade/palisade-development/-/issues/323#note_662270512)
- Fixes a bug in the NTT default constructor (https://gitlab.com/palisade/palisade-development/-/issues/329)
- Fixes a bug in the AVX512 NTT (https://github.com/intel/hexl/pull/58)
- Improves performance of EltwiseFMAModAVX512 on ICX (https://github.com/intel/hexl/pull/42)
- Improves performance of the native NTT
- Adds reference implementations for the radix-4 NTT
- Enables support for pre-built easylogging (https://github.com/intel/hexl/pull/57)

## Version 1.2.0
- Large performance improvement in large (N >= 16384) AVX512 NTTs via recursive implementations
Expand Down
6 changes: 3 additions & 3 deletions benchmark/bench-eltwise-fma-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ static void BM_EltwiseFMAModAddNative(benchmark::State& state) { // NOLINT

BENCHMARK(BM_EltwiseFMAModAddNative)
->Unit(benchmark::kMicrosecond)
->ArgsProduct({{1024, 8192, 16384}, {false, true}});
->ArgsProduct({{1024, 4096, 16384}, {false, true}});

//=================================================================

Expand All @@ -59,7 +59,7 @@ static void BM_EltwiseFMAModAVX512DQ(benchmark::State& state) { // NOLINT

BENCHMARK(BM_EltwiseFMAModAVX512DQ)
->Unit(benchmark::kMicrosecond)
->ArgsProduct({{1024, 8192, 16384}, {false, true}});
->ArgsProduct({{1024, 4096, 16384}, {false, true}});
#endif

//=================================================================
Expand All @@ -84,7 +84,7 @@ static void BM_EltwiseFMAModAVX512IFMA(benchmark::State& state) { // NOLINT

BENCHMARK(BM_EltwiseFMAModAVX512IFMA)
->Unit(benchmark::kMicrosecond)
->ArgsProduct({{1024, 8192, 16384}, {false, true}});
->ArgsProduct({{1024, 4096, 16384}, {false, true}});

#endif

Expand Down
2 changes: 1 addition & 1 deletion benchmark/bench-eltwise-mult-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static void BM_EltwiseMultMod(benchmark::State& state) { // NOLINT

BENCHMARK(BM_EltwiseMultMod)
->Unit(benchmark::kMicrosecond)
->ArgsProduct({{1024, 8192, 16384}, {48, 60}, {1, 2, 4}});
->ArgsProduct({{1024, 4096, 16384}, {48, 60}, {1, 2, 4}});

//=================================================================

Expand Down
30 changes: 15 additions & 15 deletions benchmark/bench-ntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace hexl {

static void BM_FwdNTTNativeRadix2(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand All @@ -43,7 +43,7 @@ BENCHMARK(BM_FwdNTTNativeRadix2)

static void BM_FwdNTTNativeRadix4(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand All @@ -67,7 +67,7 @@ BENCHMARK(BM_FwdNTTNativeRadix4)
static void BM_FwdNTT_AVX512IFMA(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus_bits = 49;
size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0];
size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand Down Expand Up @@ -96,7 +96,7 @@ BENCHMARK(BM_FwdNTT_AVX512IFMA)
static void BM_FwdNTT_AVX512IFMALazy(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus_bits = 49;
size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0];
size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand Down Expand Up @@ -132,7 +132,7 @@ static void BM_FwdNTT_AVX512DQ_32(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
uint64_t output_mod_factor = state.range(1);
size_t modulus_bits = 29;
size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0];
size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand Down Expand Up @@ -163,7 +163,7 @@ static void BM_FwdNTT_AVX512DQ_64(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
uint64_t output_mod_factor = state.range(1);
size_t modulus_bits = 55;
size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0];
size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand Down Expand Up @@ -195,7 +195,7 @@ BENCHMARK(BM_FwdNTT_AVX512DQ_64)
// state[0] is the degree
static void BM_FwdNTTInPlace(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus = GeneratePrimes(1, 61, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 61, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand All @@ -216,7 +216,7 @@ BENCHMARK(BM_FwdNTTInPlace)
// state[0] is the degree
static void BM_FwdNTTCopy(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
AlignedVector64<uint64_t> output(ntt_size, 1);
Expand All @@ -236,7 +236,7 @@ BENCHMARK(BM_FwdNTTCopy)
// state[0] is the degree
static void BM_InvNTTCopy(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
AlignedVector64<uint64_t> output(ntt_size, 1);
Expand All @@ -259,7 +259,7 @@ BENCHMARK(BM_InvNTTCopy)

static void BM_InvNTTNativeRadix2(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand All @@ -284,7 +284,7 @@ BENCHMARK(BM_InvNTTNativeRadix2)

static void BM_InvNTTNativeRadix4(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand All @@ -311,7 +311,7 @@ BENCHMARK(BM_InvNTTNativeRadix4)
// state[0] is the degree
static void BM_InvNTT_AVX512IFMA(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus = GeneratePrimes(1, 49, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 49, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand All @@ -337,7 +337,7 @@ BENCHMARK(BM_InvNTT_AVX512IFMA)
// state[0] is the degree
static void BM_InvNTT_AVX512IFMALazy(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
size_t modulus = GeneratePrimes(1, 49, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 49, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand Down Expand Up @@ -367,7 +367,7 @@ BENCHMARK(BM_InvNTT_AVX512IFMALazy)
static void BM_InvNTT_AVX512DQ_32(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
uint64_t output_mod_factor = state.range(1);
size_t modulus = GeneratePrimes(1, 30, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 30, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand Down Expand Up @@ -395,7 +395,7 @@ BENCHMARK(BM_InvNTT_AVX512DQ_32)
static void BM_InvNTT_AVX512DQ_64(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
uint64_t output_mod_factor = state.range(1);
size_t modulus = GeneratePrimes(1, 61, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 61, true, ntt_size)[0];

AlignedVector64<uint64_t> input(ntt_size, 1);
NTT ntt(ntt_size, modulus);
Expand Down
2 changes: 2 additions & 0 deletions hexl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ endif()
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
target_compile_options(hexl PRIVATE -Wall -Wconversion -Wshadow -pedantic -Wextra
-Wno-unknown-pragmas -march=native -O3 -fomit-frame-pointer
-Wno-sign-conversion
-Wno-implicit-int-conversion
)
# Avoid 3rd-party dependency warnings when including HEXL as a dependency
target_compile_options(hexl PUBLIC
Expand Down
7 changes: 5 additions & 2 deletions hexl/include/hexl/number-theory/number-theory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,17 @@ inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2,
/// @brief Returns whether or not the input is prime
bool IsPrime(uint64_t n);

/// @brief Generates a list of num_primes primes in the range [2^(bit_size,
/// @brief Generates a list of num_primes primes in the range [2^(bit_size),
/// 2^(bit_size+1)]. Ensures each prime q satisfies
/// q % (2*ntt_size) == 1
/// @param[in] num_primes Number of primes to generate
/// @param[in] bit_size Bit size of each prime
/// @param[in] prefer_small_primes When true, returns primes in increasing
/// order, searching upward from 2^(bit_size); when false, returns primes in
/// decreasing order, searching downward from 2^(bit_size+1)
/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must
/// be a power of two
/// be a power of two less than 2^bit_size.
std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
bool prefer_small_primes,
size_t ntt_size = 1);

/// @brief Returns input mod modulus, computed via 64-bit Barrett reduction
Expand Down
2 changes: 1 addition & 1 deletion hexl/ntt/fwd-ntt-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ void ForwardTransformToBitReverseAVX512(
const uint64_t* W = &root_of_unity_powers[W_idx];
const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx];

if (input_mod_factor <= 2) {
if ((input_mod_factor <= 2) && (recursion_depth == 0)) {
FwdT8<BitShift, true>(operand, v_neg_modulus, v_twice_mod, t, m, W,
W_precon);
} else {
Expand Down
2 changes: 1 addition & 1 deletion hexl/ntt/inv-ntt-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ void InverseTransformFromBitReverseAVX512(
// t = 1
const uint64_t* W = &inv_root_of_unity_powers[W_idx];
const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx];
if (input_mod_factor == 1) {
if ((input_mod_factor == 1) && (recursion_depth == 0)) {
InvT1<BitShift, true>(operand, v_neg_modulus, v_twice_mod, m, W,
W_precon);
} else {
Expand Down
32 changes: 27 additions & 5 deletions hexl/number-theory/number-theory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ bool IsPrime(uint64_t n) {
}

std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
bool prefer_small_primes,
size_t ntt_size) {
HEXL_CHECK(num_primes > 0, "num_primes == 0");
HEXL_CHECK(IsPowerOfTwo(ntt_size),
Expand All @@ -231,18 +232,39 @@ std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
"log2(ntt_size) " << Log2(ntt_size)
<< " should be less than bit_size " << bit_size);

uint64_t value = (1ULL << bit_size) + 1;
int64_t prime_lower_bound = (1LL << bit_size) + 1LL;
int64_t prime_upper_bound = (1LL << (bit_size + 1LL)) - 1LL;

// Keep signed to enable negative step
int64_t prime_candidate =
prefer_small_primes
? prime_lower_bound
: prime_upper_bound - (prime_upper_bound % (2 * ntt_size)) + 1;
HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate");

// Ensure prime % 2 * ntt_size == 1
int64_t prime_candidate_step =
(prefer_small_primes ? 1 : -1) * 2 * static_cast<int64_t>(ntt_size);

auto continue_condition = [&](int64_t local_candidate_prime) {
if (prefer_small_primes) {
return local_candidate_prime < prime_upper_bound;
} else {
return local_candidate_prime > prime_lower_bound;
}
};

std::vector<uint64_t> ret;

while (value < (1ULL << (bit_size + 1))) {
if (IsPrime(value)) {
ret.emplace_back(value);
while (continue_condition(prime_candidate)) {
if (IsPrime(prime_candidate)) {
HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate");
ret.emplace_back(static_cast<uint64_t>(prime_candidate));
if (ret.size() == num_primes) {
return ret;
}
}
value += 2 * ntt_size;
prime_candidate += prime_candidate_step;
}

HEXL_CHECK(false, "Failed to find enough primes");
Expand Down
4 changes: 2 additions & 2 deletions test/test-eltwise-add-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ TEST(EltwiseAddMod, vector_vector_avx512_big) {
GTEST_SKIP();
}

uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];

std::vector<uint64_t> op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2,
modulus - 3, modulus - 3, modulus - 4, modulus - 4};
Expand All @@ -72,7 +72,7 @@ TEST(EltwiseAddMod, vector_scalar_avx512_big) {
GTEST_SKIP();
}

uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];

std::vector<uint64_t> op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2,
modulus - 3, modulus - 3, modulus - 4, modulus - 4};
Expand Down
4 changes: 2 additions & 2 deletions test/test-eltwise-add-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ TEST(EltwiseAddMod, vector_scalar_native_small) {
}

TEST(EltwiseAddMod, vector_vector_native_big) {
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];

std::vector<uint64_t> op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2,
modulus - 3, modulus - 3, modulus - 4, modulus - 4};
Expand All @@ -97,7 +97,7 @@ TEST(EltwiseAddMod, vector_vector_native_big) {
}

TEST(EltwiseAddMod, vector_scalar_native_big) {
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];

std::vector<uint64_t> op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2,
modulus - 3, modulus - 3, modulus - 4, modulus - 4};
Expand Down
2 changes: 1 addition & 1 deletion test/test-eltwise-cmp-sub-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ TEST(EltwiseCmpSubMod, AVX512) {

for (size_t cmp = 0; cmp < 8; ++cmp) {
for (size_t bits = 48; bits <= 51; ++bits) {
uint64_t modulus = GeneratePrimes(1, bits, 1024)[0];
uint64_t modulus = GeneratePrimes(1, bits, true, 1024)[0];
std::uniform_int_distribution<uint64_t> distrib(0, modulus - 1);

for (size_t trial = 0; trial < 200; ++trial) {
Expand Down
2 changes: 1 addition & 1 deletion test/test-eltwise-fma-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ TEST(EltwiseFMAMod, AVX512IFMA) {
constexpr uint64_t input_mod_factor = 8;

for (size_t bits = 48; bits <= 51; ++bits) {
uint64_t modulus = GeneratePrimes(1, bits, length)[0];
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
std::uniform_int_distribution<uint64_t> distrib(
0, input_mod_factor * modulus - 1);

Expand Down
2 changes: 1 addition & 1 deletion test/test-eltwise-mult-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TEST(EltwiseMultMod, avx512_int2) {
if (!has_avx512dq) {
GTEST_SKIP();
}
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];

std::vector<uint64_t> op1{modulus - 3, 1, 1, 1, 1, 1, 1, 1};
std::vector<uint64_t> op2{modulus - 4, 1, 1, 1, 1, 1, 1, 1};
Expand Down
8 changes: 4 additions & 4 deletions test/test-eltwise-mult-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ TEST(EltwiseMultModInPlace, 8_bounds) {
#endif

TEST(EltwiseMultModInPlace, 9) {
uint64_t modulus = GeneratePrimes(1, 51, 1024)[0];
uint64_t modulus = GeneratePrimes(1, 51, true, 1024)[0];

std::vector<uint64_t> op1{modulus - 3, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<uint64_t> op2{modulus - 4, 8, 7, 6, 5, 4, 3, 2, 1};
Expand Down Expand Up @@ -105,7 +105,7 @@ TEST(EltwiseMultMod, native_mult2) {
}

TEST(EltwiseMultMod, native2_big) {
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];

std::vector<uint64_t> op1{modulus - 3, 1, 1, 1, 1, 1, 1, 1};
std::vector<uint64_t> op2{modulus - 4, 1, 1, 1, 1, 1, 1, 1};
Expand All @@ -119,7 +119,7 @@ TEST(EltwiseMultMod, native2_big) {
}

TEST(EltwiseMultMod, 8big) {
uint64_t modulus = GeneratePrimes(1, 48, 1024)[0];
uint64_t modulus = GeneratePrimes(1, 48, true, 1024)[0];

std::vector<uint64_t> op1{modulus - 1, 1, 1, 1, 1, 1, 1, 1};
std::vector<uint64_t> op2{modulus - 1, 1, 1, 1, 1, 1, 1, 1};
Expand Down Expand Up @@ -198,7 +198,7 @@ TEST(EltwiseMultMod, 8_bounds) {
#endif

TEST(EltwiseMultMod, 9) {
uint64_t modulus = GeneratePrimes(1, 51, 1024)[0];
uint64_t modulus = GeneratePrimes(1, 51, true, 1024)[0];

std::vector<uint64_t> op1{modulus - 3, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<uint64_t> op2{modulus - 4, 8, 7, 6, 5, 4, 3, 2, 1};
Expand Down
Loading

0 comments on commit 4d9806f

Please sign in to comment.