From 4d9806fe00b0ff5bbd75794b3ea0ab651a2ae437 Mon Sep 17 00:00:00 2001 From: Fabian Boemer Date: Wed, 1 Sep 2021 11:46:36 -0700 Subject: [PATCH] Fboemer/ntt fix (#58) * Fix NTT AVX512 implementation --- CHANGES.md | 4 +- benchmark/bench-eltwise-fma-mod.cpp | 6 +- benchmark/bench-eltwise-mult-mod.cpp | 2 +- benchmark/bench-ntt.cpp | 30 +-- hexl/CMakeLists.txt | 2 + .../hexl/number-theory/number-theory.hpp | 7 +- hexl/ntt/fwd-ntt-avx512.cpp | 2 +- hexl/ntt/inv-ntt-avx512.cpp | 2 +- hexl/number-theory/number-theory.cpp | 32 ++- test/test-eltwise-add-mod-avx512.cpp | 4 +- test/test-eltwise-add-mod.cpp | 4 +- test/test-eltwise-cmp-sub-mod-avx512.cpp | 2 +- test/test-eltwise-fma-mod-avx512.cpp | 2 +- test/test-eltwise-mult-mod-avx512.cpp | 2 +- test/test-eltwise-mult-mod.cpp | 8 +- test/test-eltwise-reduce-mod-avx512.cpp | 8 +- test/test-eltwise-sub-mod-avx512.cpp | 4 +- test/test-eltwise-sub-mod.cpp | 4 +- test/test-ntt-avx512.cpp | 204 ++++++++++++------ test/test-ntt.cpp | 10 +- test/test-number-theory.cpp | 12 +- 21 files changed, 229 insertions(+), 122 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index f0678207..290055c2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,10 +2,12 @@ ## Version 1.2.1 - Fixes a bug in AVX512 floating-point implementation of element-wise vector-vector modular multiplication (https://github.com/microsoft/SEAL/issues/385) -- Fixes a bug in the NTT default allocator (https://gitlab.com/palisade/palisade-development/-/issues/323#note_662270512) +- Fixes a bug in the NTT default constructor (https://gitlab.com/palisade/palisade-development/-/issues/329) +- Fixes a bug in the AVX512 NTT (https://github.com/intel/hexl/pull/58) - Improves performance of EltwiseFMAModAVX512 on ICX (https://github.com/intel/hexl/pull/42) - Improves performance of the native NTT - Adds reference implementations for the radix-4 NTT +- Enables support for pre-built easylogging (https://github.com/intel/hexl/pull/57) ## Version 1.2.0 - Large performance improvement in large (N >= 16384) AVX512 NTTs via recursive implementations diff --git a/benchmark/bench-eltwise-fma-mod.cpp b/benchmark/bench-eltwise-fma-mod.cpp index 7c697f68..212ac594 100644 --- a/benchmark/bench-eltwise-fma-mod.cpp +++ b/benchmark/bench-eltwise-fma-mod.cpp @@ -35,7 +35,7 @@ static void BM_EltwiseFMAModAddNative(benchmark::State& state) { // NOLINT BENCHMARK(BM_EltwiseFMAModAddNative) ->Unit(benchmark::kMicrosecond) - ->ArgsProduct({{1024, 8192, 16384}, {false, true}}); + ->ArgsProduct({{1024, 4096, 16384}, {false, true}}); //================================================================= @@ -59,7 +59,7 @@ static void BM_EltwiseFMAModAVX512DQ(benchmark::State& state) { // NOLINT BENCHMARK(BM_EltwiseFMAModAVX512DQ) ->Unit(benchmark::kMicrosecond) - ->ArgsProduct({{1024, 8192, 16384}, {false, true}}); + ->ArgsProduct({{1024, 4096, 16384}, {false, true}}); #endif //================================================================= @@ -84,7 +84,7 @@ static void BM_EltwiseFMAModAVX512IFMA(benchmark::State& state) { // NOLINT BENCHMARK(BM_EltwiseFMAModAVX512IFMA) ->Unit(benchmark::kMicrosecond) - ->ArgsProduct({{1024, 8192, 16384}, {false, true}}); + ->ArgsProduct({{1024, 4096, 16384}, {false, true}}); #endif diff --git a/benchmark/bench-eltwise-mult-mod.cpp b/benchmark/bench-eltwise-mult-mod.cpp index 720978f2..dfe6bf3f 100644 --- a/benchmark/bench-eltwise-mult-mod.cpp +++ b/benchmark/bench-eltwise-mult-mod.cpp @@ -36,7 +36,7 @@ static void BM_EltwiseMultMod(benchmark::State& state) { // NOLINT BENCHMARK(BM_EltwiseMultMod) ->Unit(benchmark::kMicrosecond) - ->ArgsProduct({{1024, 8192, 16384}, {48, 60}, {1, 2, 4}}); + ->ArgsProduct({{1024, 4096, 16384}, {48, 60}, {1, 2, 4}}); //================================================================= diff --git a/benchmark/bench-ntt.cpp b/benchmark/bench-ntt.cpp index 0187f024..58846ca5 100644 --- a/benchmark/bench-ntt.cpp +++ b/benchmark/bench-ntt.cpp @@ -22,7 +22,7 @@ namespace hexl { static void BM_FwdNTTNativeRadix2(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); - size_t modulus = GeneratePrimes(1, 45, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -43,7 +43,7 @@ BENCHMARK(BM_FwdNTTNativeRadix2) static void BM_FwdNTTNativeRadix4(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); - size_t modulus = GeneratePrimes(1, 45, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -67,7 +67,7 @@ BENCHMARK(BM_FwdNTTNativeRadix4) static void BM_FwdNTT_AVX512IFMA(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); size_t modulus_bits = 49; - size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -96,7 +96,7 @@ BENCHMARK(BM_FwdNTT_AVX512IFMA) static void BM_FwdNTT_AVX512IFMALazy(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); size_t modulus_bits = 49; - size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -132,7 +132,7 @@ static void BM_FwdNTT_AVX512DQ_32(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); uint64_t output_mod_factor = state.range(1); size_t modulus_bits = 29; - size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -163,7 +163,7 @@ static void BM_FwdNTT_AVX512DQ_64(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); uint64_t output_mod_factor = state.range(1); size_t modulus_bits = 55; - size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -195,7 +195,7 @@ BENCHMARK(BM_FwdNTT_AVX512DQ_64) // state[0] is the degree static void BM_FwdNTTInPlace(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); - size_t modulus = GeneratePrimes(1, 61, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 61, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -216,7 +216,7 @@ BENCHMARK(BM_FwdNTTInPlace) // state[0] is the degree static void BM_FwdNTTCopy(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); - size_t modulus = GeneratePrimes(1, 45, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); AlignedVector64 output(ntt_size, 1); @@ -236,7 +236,7 @@ BENCHMARK(BM_FwdNTTCopy) // state[0] is the degree static void BM_InvNTTCopy(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); - size_t modulus = GeneratePrimes(1, 45, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); AlignedVector64 output(ntt_size, 1); @@ -259,7 +259,7 @@ BENCHMARK(BM_InvNTTCopy) static void BM_InvNTTNativeRadix2(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); - size_t modulus = GeneratePrimes(1, 45, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -284,7 +284,7 @@ BENCHMARK(BM_InvNTTNativeRadix2) static void BM_InvNTTNativeRadix4(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); - size_t modulus = GeneratePrimes(1, 45, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -311,7 +311,7 @@ BENCHMARK(BM_InvNTTNativeRadix4) // state[0] is the degree static void BM_InvNTT_AVX512IFMA(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); - size_t modulus = GeneratePrimes(1, 49, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 49, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -337,7 +337,7 @@ BENCHMARK(BM_InvNTT_AVX512IFMA) // state[0] is the degree static void BM_InvNTT_AVX512IFMALazy(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); - size_t modulus = GeneratePrimes(1, 49, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 49, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -367,7 +367,7 @@ BENCHMARK(BM_InvNTT_AVX512IFMALazy) static void BM_InvNTT_AVX512DQ_32(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); uint64_t output_mod_factor = state.range(1); - size_t modulus = GeneratePrimes(1, 30, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 30, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); @@ -395,7 +395,7 @@ BENCHMARK(BM_InvNTT_AVX512DQ_32) static void BM_InvNTT_AVX512DQ_64(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); uint64_t output_mod_factor = state.range(1); - size_t modulus = GeneratePrimes(1, 61, ntt_size)[0]; + size_t modulus = GeneratePrimes(1, 61, true, ntt_size)[0]; AlignedVector64 input(ntt_size, 1); NTT ntt(ntt_size, modulus); diff --git a/hexl/CMakeLists.txt b/hexl/CMakeLists.txt index 7b505373..764229f4 100644 --- a/hexl/CMakeLists.txt +++ b/hexl/CMakeLists.txt @@ -76,6 +76,8 @@ endif() if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") target_compile_options(hexl PRIVATE -Wall -Wconversion -Wshadow -pedantic -Wextra -Wno-unknown-pragmas -march=native -O3 -fomit-frame-pointer + -Wno-sign-conversion + -Wno-implicit-int-conversion ) # Avoid 3rd-party dependency warnings when including HEXL as a dependency target_compile_options(hexl PUBLIC diff --git a/hexl/include/hexl/number-theory/number-theory.hpp b/hexl/include/hexl/number-theory/number-theory.hpp index 16e70fad..f003b918 100644 --- a/hexl/include/hexl/number-theory/number-theory.hpp +++ b/hexl/include/hexl/number-theory/number-theory.hpp @@ -176,14 +176,17 @@ inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2, /// @brief Returns whether or not the input is prime bool IsPrime(uint64_t n); -/// @brief Generates a list of num_primes primes in the range [2^(bit_size, +/// @brief Generates a list of num_primes primes in the range [2^(bit_size), // 2^(bit_size+1)]. Ensures each prime q satisfies // q % (2*ntt_size+1)) == 1 /// @param[in] num_primes Number of primes to generate /// @param[in] bit_size Bit size of each prime +/// @param[in] prefer_small_primes When true, returns primes starting from +/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1) /// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must -/// be a power of two +/// be a power of two less than 2^bit_size. std::vector GeneratePrimes(size_t num_primes, size_t bit_size, + bool prefer_small_primes, size_t ntt_size = 1); /// @brief Returns input mod modulus, computed via 64-bit Barrett reduction diff --git a/hexl/ntt/fwd-ntt-avx512.cpp b/hexl/ntt/fwd-ntt-avx512.cpp index d099ccc1..02be569f 100644 --- a/hexl/ntt/fwd-ntt-avx512.cpp +++ b/hexl/ntt/fwd-ntt-avx512.cpp @@ -266,7 +266,7 @@ void ForwardTransformToBitReverseAVX512( const uint64_t* W = &root_of_unity_powers[W_idx]; const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; - if (input_mod_factor <= 2) { + if ((input_mod_factor <= 2) && (recursion_depth == 0)) { FwdT8(operand, v_neg_modulus, v_twice_mod, t, m, W, W_precon); } else { diff --git a/hexl/ntt/inv-ntt-avx512.cpp b/hexl/ntt/inv-ntt-avx512.cpp index 646b2b14..a548a770 100644 --- a/hexl/ntt/inv-ntt-avx512.cpp +++ b/hexl/ntt/inv-ntt-avx512.cpp @@ -260,7 +260,7 @@ void InverseTransformFromBitReverseAVX512( // t = 1 const uint64_t* W = &inv_root_of_unity_powers[W_idx]; const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx]; - if (input_mod_factor == 1) { + if ((input_mod_factor == 1) && (recursion_depth == 0)) { InvT1(operand, v_neg_modulus, v_twice_mod, m, W, W_precon); } else { diff --git a/hexl/number-theory/number-theory.cpp b/hexl/number-theory/number-theory.cpp index 3df8bf8e..967c64e6 100644 --- a/hexl/number-theory/number-theory.cpp +++ b/hexl/number-theory/number-theory.cpp @@ -223,6 +223,7 @@ bool IsPrime(uint64_t n) { } std::vector GeneratePrimes(size_t num_primes, size_t bit_size, + bool prefer_small_primes, size_t ntt_size) { HEXL_CHECK(num_primes > 0, "num_primes == 0"); HEXL_CHECK(IsPowerOfTwo(ntt_size), @@ -231,18 +232,39 @@ std::vector GeneratePrimes(size_t num_primes, size_t bit_size, "log2(ntt_size) " << Log2(ntt_size) << " should be less than bit_size " << bit_size); - uint64_t value = (1ULL << bit_size) + 1; + int64_t prime_lower_bound = (1LL << bit_size) + 1LL; + int64_t prime_upper_bound = (1LL << (bit_size + 1LL)) - 1LL; + + // Keep signed to enable negative step + int64_t prime_candidate = + prefer_small_primes + ? prime_lower_bound + : prime_upper_bound - (prime_upper_bound % (2 * ntt_size)) + 1; + HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate"); + + // Ensure prime % 2 * ntt_size == 1 + int64_t prime_candidate_step = + (prefer_small_primes ? 1 : -1) * 2 * static_cast(ntt_size); + + auto continue_condition = [&](int64_t local_candidate_prime) { + if (prefer_small_primes) { + return local_candidate_prime < prime_upper_bound; + } else { + return local_candidate_prime > prime_lower_bound; + } + }; std::vector ret; - while (value < (1ULL << (bit_size + 1))) { - if (IsPrime(value)) { - ret.emplace_back(value); + while (continue_condition(prime_candidate)) { + if (IsPrime(prime_candidate)) { + HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate"); + ret.emplace_back(static_cast(prime_candidate)); if (ret.size() == num_primes) { return ret; } } - value += 2 * ntt_size; + prime_candidate += prime_candidate_step; } HEXL_CHECK(false, "Failed to find enough primes"); diff --git a/test/test-eltwise-add-mod-avx512.cpp b/test/test-eltwise-add-mod-avx512.cpp index cc14ef9a..de5c114b 100644 --- a/test/test-eltwise-add-mod-avx512.cpp +++ b/test/test-eltwise-add-mod-avx512.cpp @@ -52,7 +52,7 @@ TEST(EltwiseAddMod, vector_vector_avx512_big) { GTEST_SKIP(); } - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2, modulus - 3, modulus - 3, modulus - 4, modulus - 4}; @@ -72,7 +72,7 @@ TEST(EltwiseAddMod, vector_scalar_avx512_big) { GTEST_SKIP(); } - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2, modulus - 3, modulus - 3, modulus - 4, modulus - 4}; diff --git a/test/test-eltwise-add-mod.cpp b/test/test-eltwise-add-mod.cpp index a3366e33..72e90693 100644 --- a/test/test-eltwise-add-mod.cpp +++ b/test/test-eltwise-add-mod.cpp @@ -81,7 +81,7 @@ TEST(EltwiseAddMod, vector_scalar_native_small) { } TEST(EltwiseAddMod, vector_vector_native_big) { - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2, modulus - 3, modulus - 3, modulus - 4, modulus - 4}; @@ -97,7 +97,7 @@ TEST(EltwiseAddMod, vector_vector_native_big) { } TEST(EltwiseAddMod, vector_scalar_native_big) { - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2, modulus - 3, modulus - 3, modulus - 4, modulus - 4}; diff --git a/test/test-eltwise-cmp-sub-mod-avx512.cpp b/test/test-eltwise-cmp-sub-mod-avx512.cpp index 394856c2..17afcf28 100644 --- a/test/test-eltwise-cmp-sub-mod-avx512.cpp +++ b/test/test-eltwise-cmp-sub-mod-avx512.cpp @@ -31,7 +31,7 @@ TEST(EltwiseCmpSubMod, AVX512) { for (size_t cmp = 0; cmp < 8; ++cmp) { for (size_t bits = 48; bits <= 51; ++bits) { - uint64_t modulus = GeneratePrimes(1, bits, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, bits, true, 1024)[0]; std::uniform_int_distribution distrib(0, modulus - 1); for (size_t trial = 0; trial < 200; ++trial) { diff --git a/test/test-eltwise-fma-mod-avx512.cpp b/test/test-eltwise-fma-mod-avx512.cpp index 46263ae2..986aa543 100644 --- a/test/test-eltwise-fma-mod-avx512.cpp +++ b/test/test-eltwise-fma-mod-avx512.cpp @@ -229,7 +229,7 @@ TEST(EltwiseFMAMod, AVX512IFMA) { constexpr uint64_t input_mod_factor = 8; for (size_t bits = 48; bits <= 51; ++bits) { - uint64_t modulus = GeneratePrimes(1, bits, length)[0]; + uint64_t modulus = GeneratePrimes(1, bits, true, length)[0]; std::uniform_int_distribution distrib( 0, input_mod_factor * modulus - 1); diff --git a/test/test-eltwise-mult-mod-avx512.cpp b/test/test-eltwise-mult-mod-avx512.cpp index 3285ba73..17fd3ae2 100644 --- a/test/test-eltwise-mult-mod-avx512.cpp +++ b/test/test-eltwise-mult-mod-avx512.cpp @@ -39,7 +39,7 @@ TEST(EltwiseMultMod, avx512_int2) { if (!has_avx512dq) { GTEST_SKIP(); } - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{modulus - 3, 1, 1, 1, 1, 1, 1, 1}; std::vector op2{modulus - 4, 1, 1, 1, 1, 1, 1, 1}; diff --git a/test/test-eltwise-mult-mod.cpp b/test/test-eltwise-mult-mod.cpp index f119434d..6a5ffcb7 100644 --- a/test/test-eltwise-mult-mod.cpp +++ b/test/test-eltwise-mult-mod.cpp @@ -77,7 +77,7 @@ TEST(EltwiseMultModInPlace, 8_bounds) { #endif TEST(EltwiseMultModInPlace, 9) { - uint64_t modulus = GeneratePrimes(1, 51, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 51, true, 1024)[0]; std::vector op1{modulus - 3, 1, 2, 3, 4, 5, 6, 7, 8}; std::vector op2{modulus - 4, 8, 7, 6, 5, 4, 3, 2, 1}; @@ -105,7 +105,7 @@ TEST(EltwiseMultMod, native_mult2) { } TEST(EltwiseMultMod, native2_big) { - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{modulus - 3, 1, 1, 1, 1, 1, 1, 1}; std::vector op2{modulus - 4, 1, 1, 1, 1, 1, 1, 1}; @@ -119,7 +119,7 @@ TEST(EltwiseMultMod, native2_big) { } TEST(EltwiseMultMod, 8big) { - uint64_t modulus = GeneratePrimes(1, 48, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 48, true, 1024)[0]; std::vector op1{modulus - 1, 1, 1, 1, 1, 1, 1, 1}; std::vector op2{modulus - 1, 1, 1, 1, 1, 1, 1, 1}; @@ -198,7 +198,7 @@ TEST(EltwiseMultMod, 8_bounds) { #endif TEST(EltwiseMultMod, 9) { - uint64_t modulus = GeneratePrimes(1, 51, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 51, true, 1024)[0]; std::vector op1{modulus - 3, 1, 2, 3, 4, 5, 6, 7, 8}; std::vector op2{modulus - 4, 8, 7, 6, 5, 4, 3, 2, 1}; diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index 044de5b9..c2ac38e7 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -100,7 +100,7 @@ TEST(EltwiseReduceMod, AVX512Big_0_1) { size_t length = 1024; for (size_t bits = 50; bits <= 62; ++bits) { - uint64_t modulus = GeneratePrimes(1, bits, length)[0]; + uint64_t modulus = GeneratePrimes(1, bits, true, length)[0]; std::uniform_int_distribution distrib(0, modulus - 1); #ifdef HEXL_DEBUG @@ -139,7 +139,7 @@ TEST(EltwiseReduceMod, AVX512Big_4_1) { size_t length = 1024; for (size_t bits = 50; bits <= 62; ++bits) { - uint64_t modulus = GeneratePrimes(1, bits, length)[0]; + uint64_t modulus = GeneratePrimes(1, bits, true, length)[0]; std::uniform_int_distribution distrib(0, (4 * modulus) - 1); #ifdef HEXL_DEBUG @@ -178,7 +178,7 @@ TEST(EltwiseReduceMod, AVX512Big_4_2) { size_t length = 1024; for (size_t bits = 50; bits <= 62; ++bits) { - uint64_t modulus = GeneratePrimes(1, bits, length)[0]; + uint64_t modulus = GeneratePrimes(1, bits, true, length)[0]; std::uniform_int_distribution distrib(0, (4 * modulus) - 1); #ifdef HEXL_DEBUG @@ -217,7 +217,7 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) { size_t length = 1024; for (size_t bits = 50; bits <= 62; ++bits) { - uint64_t modulus = GeneratePrimes(1, bits, length)[0]; + uint64_t modulus = GeneratePrimes(1, bits, true, length)[0]; std::uniform_int_distribution distrib(0, (2 * modulus) - 1); #ifdef HEXL_DEBUG diff --git a/test/test-eltwise-sub-mod-avx512.cpp b/test/test-eltwise-sub-mod-avx512.cpp index c436a70b..abcfb00e 100644 --- a/test/test-eltwise-sub-mod-avx512.cpp +++ b/test/test-eltwise-sub-mod-avx512.cpp @@ -52,7 +52,7 @@ TEST(EltwiseSubMod, vector_vector_avx512_big) { GTEST_SKIP(); } - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{0, 1, 2, 3, modulus - 1, modulus - 2, modulus - 3, modulus - 4}; @@ -71,7 +71,7 @@ TEST(EltwiseSubMod, vector_scalar_avx512_big) { GTEST_SKIP(); } - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{0, 1, 2, 3, modulus - 1, modulus - 2, modulus - 3, modulus - 4}; diff --git a/test/test-eltwise-sub-mod.cpp b/test/test-eltwise-sub-mod.cpp index 6be1226a..922f9ee6 100644 --- a/test/test-eltwise-sub-mod.cpp +++ b/test/test-eltwise-sub-mod.cpp @@ -81,7 +81,7 @@ TEST(EltwiseSubMod, vector_scalar_native_small) { } TEST(EltwiseSubMod, vector_vector_native_big) { - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{0, 1, 2, 3, modulus - 1, modulus - 2, modulus - 3, modulus - 4}; @@ -96,7 +96,7 @@ TEST(EltwiseSubMod, vector_vector_native_big) { } TEST(EltwiseSubMod, vector_scalar_native_big) { - uint64_t modulus = GeneratePrimes(1, 60, 1024)[0]; + uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0]; std::vector op1{0, 1, 2, 3, modulus - 1, modulus - 2, modulus - 3, modulus - 4}; diff --git a/test/test-ntt-avx512.cpp b/test/test-ntt-avx512.cpp index 12a0b446..602c8d0f 100644 --- a/test/test-ntt-avx512.cpp +++ b/test/test-ntt-avx512.cpp @@ -21,70 +21,6 @@ namespace intel { namespace hexl { -#ifdef HEXL_HAS_AVX512IFMA -class ModulusTest - : public ::testing::TestWithParam> { - protected: - void SetUp() {} - - void TearDown() {} - - public: -}; - -// Test modulus around 50 bits to check IFMA behavior -// Parameters = (degree, modulus_bits) -TEST_P(ModulusTest, IFMAModuli) { - if (!has_avx512ifma) { - GTEST_SKIP(); - } - uint64_t N = std::get<0>(GetParam()); - uint64_t modulus_bits = std::get<1>(GetParam()); - uint64_t modulus = GeneratePrimes(1, modulus_bits, N)[0]; - - std::vector input64(N, 0); - for (size_t i = 0; i < N; ++i) { - input64[i] = i % modulus; - } - std::vector input_ifma = input64; - std::vector input_ifma_lazy = input64; - - std::vector exp_output(N, 0); - - // Compute reference - NTT ntt64(N, modulus); - ReferenceForwardTransformToBitReverse(input64.data(), N, modulus, - ntt64.GetRootOfUnityPowers().data()); - - // Compute with 52-bit bit shift - NTT ntt_ifma(N, modulus); - - ForwardTransformToBitReverseAVX512<52>( - input_ifma.data(), N, ntt_ifma.GetModulus(), - ntt_ifma.GetAVX512RootOfUnityPowers().data(), - ntt_ifma.GetAVX512Precon52RootOfUnityPowers().data(), 2, 1); - - // Compute lazy - ForwardTransformToBitReverseAVX512<52>( - input_ifma_lazy.data(), N, ntt_ifma.GetModulus(), - ntt_ifma.GetAVX512RootOfUnityPowers().data(), - ntt_ifma.GetAVX512Precon52RootOfUnityPowers().data(), 2, 4); - for (auto& elem : input_ifma_lazy) { - elem = elem % modulus; - } - - AssertEqual(input64, input_ifma); - AssertEqual(input64, input_ifma_lazy); -} - -INSTANTIATE_TEST_SUITE_P(NTT, ModulusTest, - ::testing::Values(std::make_tuple(1 << 4, 48), - std::make_tuple(1 << 5, 49), - std::make_tuple(1 << 6, 49), - std::make_tuple(1 << 7, 49), - std::make_tuple(1 << 8, 49))); -#endif - #ifdef HEXL_HAS_AVX512DQ TEST(NTT, LoadFwdInterleavedT1) { if (!has_avx512dq) { @@ -232,6 +168,138 @@ TEST(NTT, WriteInvInterleavedT4) { AssertEqual(exp, out); } +#ifdef HEXL_HAS_AVX512IFMA +// First parameter is the NTT degree +// Second parameter is the number of bits in the NTT modulus +// Third parameter is whether or not to prefer small primes +class DegreeModulusBoolTest + : public ::testing::TestWithParam> { + protected: + void SetUp() {} + + void TearDown() {} + + public: +}; + +TEST_P(DegreeModulusBoolTest, FwdNTTAVX512IFMA) { + if (!has_avx512ifma) { + GTEST_SKIP(); + } + uint64_t N = std::get<0>(GetParam()); + uint64_t modulus_bits = std::get<1>(GetParam()); + bool prefer_small_primes = std::get<2>(GetParam()); + uint64_t modulus = GeneratePrimes(1, modulus_bits, prefer_small_primes, N)[0]; + +#ifdef HEXL_DEBUG + size_t num_trials = 1; +#else + size_t num_trials = 20; +#endif + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution distrib(1, modulus - 1); + + for (size_t trial = 0; trial < num_trials; ++trial) { + std::vector input64(N, 0); + for (size_t i = 0; i < N; ++i) { + input64[i] = distrib(gen); + } + + std::vector input_ifma = input64; + std::vector input_ifma_lazy = input64; + std::vector exp_output(N, 0); + + // Compute reference + NTT ntt64(N, modulus); + ReferenceForwardTransformToBitReverse(input64.data(), N, modulus, + ntt64.GetRootOfUnityPowers().data()); + + ForwardTransformToBitReverseAVX512<52>( + input_ifma.data(), N, ntt64.GetModulus(), + ntt64.GetAVX512RootOfUnityPowers().data(), + ntt64.GetAVX512Precon52RootOfUnityPowers().data(), 1, 1); + + // Compute lazy + ForwardTransformToBitReverseAVX512<52>( + input_ifma_lazy.data(), N, ntt64.GetModulus(), + ntt64.GetAVX512RootOfUnityPowers().data(), + ntt64.GetAVX512Precon52RootOfUnityPowers().data(), 2, 4); + for (auto& elem : input_ifma_lazy) { + elem = elem % modulus; + } + + AssertEqual(input64, input_ifma); + AssertEqual(input64, input_ifma_lazy); + } +} + +TEST_P(DegreeModulusBoolTest, InvNTTAVX512IFMA) { + if (!has_avx512ifma) { + GTEST_SKIP(); + } + uint64_t N = std::get<0>(GetParam()); + uint64_t modulus_bits = std::get<1>(GetParam()); + bool prefer_small_primes = std::get<2>(GetParam()); + uint64_t modulus = GeneratePrimes(1, modulus_bits, prefer_small_primes, N)[0]; + +#ifdef HEXL_DEBUG + size_t num_trials = 1; +#else + size_t num_trials = 20; +#endif + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution distrib(1, modulus - 1); + + for (size_t trial = 0; trial < num_trials; ++trial) { + std::vector input64(N, 0); + for (size_t i = 0; i < N; ++i) { + input64[i] = distrib(gen); + } + + std::vector input_ifma = input64; + std::vector input_ifma_lazy = input64; + + std::vector exp_output(N, 0); + + // Compute reference + NTT ntt(N, modulus); + InverseTransformFromBitReverseRadix2( + input64.data(), N, modulus, ntt.GetInvRootOfUnityPowers().data(), + ntt.GetPrecon64InvRootOfUnityPowers().data(), 1, 1); + + InverseTransformFromBitReverseAVX512<52>( + input_ifma.data(), N, ntt.GetModulus(), + ntt.GetInvRootOfUnityPowers().data(), + ntt.GetPrecon52InvRootOfUnityPowers().data(), 1, 1); + + // Compute lazy + InverseTransformFromBitReverseAVX512<52>( + input_ifma_lazy.data(), N, ntt.GetModulus(), + ntt.GetInvRootOfUnityPowers().data(), + ntt.GetPrecon52InvRootOfUnityPowers().data(), 1, 2); + for (auto& elem : input_ifma_lazy) { + elem = elem % modulus; + } + + AssertEqual(input64, input_ifma); + AssertEqual(input64, input_ifma_lazy); + } +} + +// Test modulus around 50 bits to check IFMA behavior +INSTANTIATE_TEST_SUITE_P( + NTT, DegreeModulusBoolTest, + ::testing::Combine(::testing::ValuesIn(std::vector{ + 1 << 11, 1 << 12, 1 << 13}), + ::testing::ValuesIn(std::vector{48, 49}), + ::testing::ValuesIn(std::vector{false, true}))); + +#endif + // Checks AVX512 and native forward NTT implementations match TEST(NTT, FwdNTT_AVX512_32) { if (!has_avx512dq) { @@ -247,7 +315,7 @@ TEST(NTT, FwdNTT_AVX512_32) { #endif for (size_t N = 512; N <= 65536; N *= 2) { - uint64_t modulus = GeneratePrimes(1, 27, N)[0]; + uint64_t modulus = GeneratePrimes(1, 27, true, N)[0]; std::uniform_int_distribution distrib(0, modulus - 1); NTT ntt(N, modulus); @@ -299,7 +367,7 @@ TEST(NTT, FwdNTT_AVX512_64) { #endif for (size_t N = 512; N <= 65536; N *= 2) { - uint64_t modulus = GeneratePrimes(1, 55, N)[0]; + uint64_t modulus = GeneratePrimes(1, 55, true, N)[0]; std::uniform_int_distribution distrib(0, modulus - 1); NTT ntt(N, modulus); @@ -351,7 +419,7 @@ TEST(NTT, InvNTT_AVX512_32) { #endif for (size_t N = 512; N <= 65536; N *= 2) { - uint64_t modulus = GeneratePrimes(1, 27, N)[0]; + uint64_t modulus = GeneratePrimes(1, 27, true, N)[0]; std::uniform_int_distribution distrib(0, modulus - 1); NTT ntt(N, modulus); @@ -403,7 +471,7 @@ TEST(NTT, InvNTT_AVX512_64) { #endif for (size_t N = 512; N <= 65536; N *= 2) { - uint64_t modulus = GeneratePrimes(1, 55, N)[0]; + uint64_t modulus = GeneratePrimes(1, 55, true, N)[0]; std::uniform_int_distribution distrib(0, modulus - 1); NTT ntt(N, modulus); diff --git a/test/test-ntt.cpp b/test/test-ntt.cpp index 7b7875e3..763ec3fa 100644 --- a/test/test-ntt.cpp +++ b/test/test-ntt.cpp @@ -368,7 +368,7 @@ class DegreeModulusTest TEST_P(DegreeModulusTest, ForwardZeros) { uint64_t N = std::get<0>(GetParam()); uint64_t modulus_bits = std::get<1>(GetParam()); - uint64_t modulus = GeneratePrimes(1, modulus_bits, N)[0]; + uint64_t modulus = GeneratePrimes(1, modulus_bits, true, N)[0]; std::vector input(N, 0); std::vector exp_output(N, 0); @@ -382,7 +382,7 @@ TEST_P(DegreeModulusTest, ForwardZeros) { TEST_P(DegreeModulusTest, InverseZeros) { uint64_t N = std::get<0>(GetParam()); uint64_t modulus_bits = std::get<1>(GetParam()); - uint64_t modulus = GeneratePrimes(1, modulus_bits, N)[0]; + uint64_t modulus = GeneratePrimes(1, modulus_bits, true, N)[0]; std::vector input(N, 0); std::vector exp_output(N, 0); @@ -396,7 +396,7 @@ TEST_P(DegreeModulusTest, InverseZeros) { TEST_P(DegreeModulusTest, ForwardRadix4Random) { uint64_t N = std::get<0>(GetParam()); uint64_t modulus_bits = std::get<1>(GetParam()); - uint64_t modulus = GeneratePrimes(1, modulus_bits, N)[0]; + uint64_t modulus = GeneratePrimes(1, modulus_bits, true, N)[0]; std::random_device rd; std::mt19937 gen(rd()); @@ -423,10 +423,10 @@ TEST_P(DegreeModulusTest, ForwardRadix4Random) { TEST_P(DegreeModulusTest, InverseRadix4Random) { uint64_t N = std::get<0>(GetParam()); uint64_t modulus_bits = std::get<1>(GetParam()); - uint64_t modulus = GeneratePrimes(1, modulus_bits, N)[0]; + uint64_t modulus = GeneratePrimes(1, modulus_bits, true, N)[0]; std::random_device rd; - std::mt19937 gen(42); // rd()); + std::mt19937 gen(rd()); std::uniform_int_distribution distrib(1, modulus - 1); std::vector input(N); diff --git a/test/test-number-theory.cpp b/test/test-number-theory.cpp index d51e0684..4c317633 100644 --- a/test/test-number-theory.cpp +++ b/test/test-number-theory.cpp @@ -336,7 +336,17 @@ TEST(NumberTheory, IsPrime) { TEST(NumberTheory, GeneratePrimes) { for (int bit_size = 40; bit_size < 62; ++bit_size) { - std::vector primes = GeneratePrimes(10, bit_size, 4096); + std::vector primes = GeneratePrimes(10, bit_size, true, 4096); + ASSERT_EQ(primes.size(), 10); + for (const auto& prime : primes) { + ASSERT_EQ(prime % 8192, 1); + ASSERT_TRUE(IsPrime(prime)); + ASSERT_TRUE(prime <= (1ULL << (bit_size + 1))); + ASSERT_TRUE(prime >= (1ULL << bit_size)); + } + } + for (int bit_size = 40; bit_size < 62; ++bit_size) { + std::vector primes = GeneratePrimes(10, bit_size, false, 4096); ASSERT_EQ(primes.size(), 10); for (const auto& prime : primes) { ASSERT_EQ(prime % 8192, 1);