diff --git a/icicle_v3/backend/cpu/include/ntt_template.h b/icicle_v3/backend/cpu/include/ntt_template.h index 98130650d..8d9c9b0eb 100644 --- a/icicle_v3/backend/cpu/include/ntt_template.h +++ b/icicle_v3/backend/cpu/include/ntt_template.h @@ -16,32 +16,33 @@ using namespace field_config; using namespace icicle; namespace ntt_template { template - class CpuNttDomain { - //TODO - coset, mixed-radix NTT - int max_size = 0; - int max_log_size = 0; - S* twiddles = nullptr; - - public: - template - friend eIcicleError cpu_ntt_init_domain(const Device& device, const U& primitive_root, const NTTInitDomainConfig& config); - - template - friend eIcicleError generate_twiddles(const U& primitive_root, U* twiddles, int logn); + class CpuNttDomain + { + // TODO - coset, mixed-radix NTT + int max_size = 0; + int max_log_size = 0; + S* twiddles = nullptr; - template - friend eIcicleError cpu_ntt_release_domain(const Device& device); + public: + template + friend eIcicleError + cpu_ntt_init_domain(const Device& device, const U& primitive_root, const NTTInitDomainConfig& config); - template - friend eIcicleError cpu_ntt_ref(const Device& device, const E* input, int size, NTTDir dir, NTTConfig& config, E* output); - - template - friend eIcicleError cpu_ntt(const Device& device, const E* input, int size, NTTDir dir, NTTConfig& config, E* output); + template + friend eIcicleError generate_twiddles(const U& primitive_root, U* twiddles, int logn); - S* get_twiddles() { - return twiddles; - } + template + friend eIcicleError cpu_ntt_release_domain(const Device& device); + + template + friend eIcicleError + cpu_ntt_ref(const Device& device, const E* input, int size, NTTDir dir, NTTConfig& config, E* output); + template + friend eIcicleError + cpu_ntt(const Device& device, const E* input, int size, NTTDir dir, NTTConfig& config, E* output); + + S* get_twiddles() { return twiddles; } }; template @@ -50,218 +51,210 @@ namespace ntt_template { template eIcicleError generate_twiddles(const S& primitive_root, S* twiddles, int n) { - S omega = primitive_root; - twiddles[0] = S::one(); - for (int i = 1; i <= n; i++) { - twiddles[i] = twiddles[i - 1] * omega; - } - return eIcicleError::SUCCESS; + S omega = primitive_root; + twiddles[0] = S::one(); + for (int i = 1; i <= n; i++) { + twiddles[i] = twiddles[i - 1] * omega; + } + return eIcicleError::SUCCESS; } template eIcicleError cpu_ntt_init_domain(const Device& device, const S& primitive_root, const NTTInitDomainConfig& config) { - // (1) check if need to refresh domain. TODO - reusing domain for different ntt sizes if possible - if (s_ntt_domain != nullptr) { - return eIcicleError::SUCCESS; - } - // (2) build the domain - delete s_ntt_domain; - s_ntt_domain = new CpuNttDomain(); - - bool found_logn = false; - S omega = primitive_root; - unsigned omegas_count = S::get_omegas_count(); - for (int i = 0; i < omegas_count; i++) { - omega = S::sqr(omega); - if (!found_logn) { - ++s_ntt_domain->max_log_size; - found_logn = omega == S::one(); - if (found_logn) break; - } - } + // (1) check if need to refresh domain. TODO - reusing domain for different ntt sizes if possible + if (s_ntt_domain != nullptr) { return eIcicleError::SUCCESS; } + // (2) build the domain + delete s_ntt_domain; + s_ntt_domain = new CpuNttDomain(); - s_ntt_domain->max_size = (int)pow(2, s_ntt_domain->max_log_size); - if (omega != S::one()) { - ICICLE_LOG_ERROR << "Primitive root provided to the InitDomain function is not in the subgroup"; - return eIcicleError::INVALID_ARGUMENT; + bool found_logn = false; + S omega = primitive_root; + unsigned omegas_count = S::get_omegas_count(); + for (int i = 0; i < omegas_count; i++) { + omega = S::sqr(omega); + if (!found_logn) { + ++s_ntt_domain->max_log_size; + found_logn = omega == S::one(); + if (found_logn) break; } + } + + s_ntt_domain->max_size = (int)pow(2, s_ntt_domain->max_log_size); + if (omega != S::one()) { + ICICLE_LOG_ERROR << "Primitive root provided to the InitDomain function is not in the subgroup"; + return eIcicleError::INVALID_ARGUMENT; + } - // calculate twiddles - // Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements + // calculate twiddles + // Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements - s_ntt_domain->twiddles = (S*)malloc((s_ntt_domain->max_size + 1) * sizeof(S)); - generate_twiddles(primitive_root, s_ntt_domain->twiddles, s_ntt_domain->max_size); - return eIcicleError::SUCCESS; + s_ntt_domain->twiddles = (S*)malloc((s_ntt_domain->max_size + 1) * sizeof(S)); + generate_twiddles(primitive_root, s_ntt_domain->twiddles, s_ntt_domain->max_size); + return eIcicleError::SUCCESS; } template eIcicleError cpu_ntt_release_domain(const Device& device) { - // release the memory - free(s_ntt_domain->twiddles); - delete s_ntt_domain; - s_ntt_domain = nullptr; - return eIcicleError::SUCCESS; + // release the memory + free(s_ntt_domain->twiddles); + delete s_ntt_domain; + s_ntt_domain = nullptr; + return eIcicleError::SUCCESS; } template eIcicleError bit_reverse(int size, int logn, E* output, int batch_size) { - ICICLE_LOG_DEBUG << "BIT REVERSE"; - int total_size = size * batch_size; - for (int batch = 0; batch < batch_size; ++batch) { - E* current_output = output + batch * size; - for (int i = 0; i < size; ++i) { - int rev = 0; - for (int j = 0; j < logn; ++j) { - if (i & (1 << j)) { - rev |= 1 << (logn - 1 - j); - } - } - if (i < rev) { - std::swap(current_output[i], current_output[rev]); - } - } + ICICLE_LOG_DEBUG << "BIT REVERSE"; + int total_size = size * batch_size; + for (int batch = 0; batch < batch_size; ++batch) { + E* current_output = output + batch * size; + for (int i = 0; i < size; ++i) { + int rev = 0; + for (int j = 0; j < logn; ++j) { + if (i & (1 << j)) { rev |= 1 << (logn - 1 - j); } + } + if (i < rev) { std::swap(current_output[i], current_output[rev]); } } - return eIcicleError::SUCCESS; + } + return eIcicleError::SUCCESS; } template eIcicleError cpu_ntt_ref(const Device& device, const E* input, int size, NTTDir dir, NTTConfig& config, E* output) { - if (size & (size - 1)) { - return eIcicleError::INVALID_ARGUMENT; - } - - int total_size = size * config.batch_size; - E* temp_elements = new E[total_size]; + if (size & (size - 1)) { return eIcicleError::INVALID_ARGUMENT; } - // ICICLE_LOG_DEBUG << "INITIAL INPUT"; - // for (int i = 0; i < total_size; ++i) { - // ICICLE_LOG_DEBUG << "input[" << i << "]: " << input[i]; - // } + int total_size = size * config.batch_size; + E* temp_elements = new E[total_size]; - if (config.columns_batch) { - // Distribute input into columns if columns_batch is set - for (int col = 0; col < config.batch_size; ++col) { - for (int row = 0; row < size; ++row) { - temp_elements[col * size + row] = input[row * config.batch_size + col]; - } - } - } else { - std::copy(input, input + total_size, temp_elements); - } - // ICICLE_LOG_DEBUG << "INITIAL ELEMENTS ARRAY"; - // for (int i = 0; i < total_size; ++i) { - // ICICLE_LOG_DEBUG << "temp_elements[" << i << "]: " << temp_elements[i]; - // } - const int logn = int(log2(size)); - - bool dit = true; - switch(config.ordering) { //kNN, kNR, kRN, kRR, kNM, kMN - case Ordering::kNN: - bit_reverse(size, logn, temp_elements, config.batch_size); - break; - case Ordering::kNR: - case Ordering::kNM: - dit = false; //dif - break; - case Ordering::kRR: - bit_reverse(size, logn, temp_elements, config.batch_size); - dit = false; //dif - break; - case Ordering::kRN: - case Ordering::kMN: - break; - default: - return eIcicleError::INVALID_ARGUMENT; + // ICICLE_LOG_DEBUG << "INITIAL INPUT"; + // for (int i = 0; i < total_size; ++i) { + // ICICLE_LOG_DEBUG << "input[" << i << "]: " << input[i]; + // } + + if (config.columns_batch) { + // Distribute input into columns if columns_batch is set + for (int col = 0; col < config.batch_size; ++col) { + for (int row = 0; row < size; ++row) { + temp_elements[col * size + row] = input[row * config.batch_size + col]; + } } - - // ICICLE_LOG_DEBUG << "AFTER BIT REVERSE"; - // for (int i = 0; i < total_size; ++i) { - // ICICLE_LOG_DEBUG << "temp_elements[" << i << "]: " << temp_elements[i]; - // } + } else { + std::copy(input, input + total_size, temp_elements); + } + // ICICLE_LOG_DEBUG << "INITIAL ELEMENTS ARRAY"; + // for (int i = 0; i < total_size; ++i) { + // ICICLE_LOG_DEBUG << "temp_elements[" << i << "]: " << temp_elements[i]; + // } + const int logn = int(log2(size)); + + bool dit = true; + switch (config.ordering) { // kNN, kNR, kRN, kRR, kNM, kMN + case Ordering::kNN: + bit_reverse(size, logn, temp_elements, config.batch_size); + break; + case Ordering::kNR: + case Ordering::kNM: + dit = false; // dif + break; + case Ordering::kRR: + bit_reverse(size, logn, temp_elements, config.batch_size); + dit = false; // dif + break; + case Ordering::kRN: + case Ordering::kMN: + break; + default: + return eIcicleError::INVALID_ARGUMENT; + } + + // ICICLE_LOG_DEBUG << "AFTER BIT REVERSE"; + // for (int i = 0; i < total_size; ++i) { + // ICICLE_LOG_DEBUG << "temp_elements[" << i << "]: " << temp_elements[i]; + // } - S* twiddles = s_ntt_domain->get_twiddles(); + S* twiddles = s_ntt_domain->get_twiddles(); - // NTT/INTT - int ntt_step = 0; - if (dit) { - ICICLE_LOG_DEBUG << "DIT"; - for (int batch = 0; batch < config.batch_size; ++batch) { - E* current_temp_elements = temp_elements + batch * size; - for (int len = 2; len <= size; len <<= 1) { - // ICICLE_LOG_DEBUG << "ntt_step: " << ntt_step++; - int half_len = len / 2; - int step = size / len; - int tw_idx = 0; - for (int i = 0; i < size; i += len) { - for (int j = 0; j < half_len; ++j) { - tw_idx = (dir == NTTDir::kForward)? j * step : size - j * step; - S u = current_temp_elements[i + j]; - S v = current_temp_elements[i + j + half_len] * twiddles[tw_idx]; - // ICICLE_LOG_DEBUG << "tw_idx=" << tw_idx; - // ICICLE_LOG_DEBUG << "current_temp_elements[" << i + j << "] <-- " << current_temp_elements[i + j] << " + " << current_temp_elements[i + j + half_len] << "*" << twiddles[tw_idx]; - // ICICLE_LOG_DEBUG << "current_temp_elements[" << i + j + half_len << "] <-- " << current_temp_elements[i + j] << " - " << current_temp_elements[i + j + half_len] << "*" << twiddles[tw_idx]; - current_temp_elements[i + j] = u + v; - current_temp_elements[i + j + half_len] = u - v; - // ICICLE_LOG_DEBUG << i + j << " <-- " << i + j << " + " << i + j + half_len << "*" << tw_idx; - // ICICLE_LOG_DEBUG << i + j + half_len << " <-- " << i + j << " - " << i + j + half_len << "*" << tw_idx; - - } - } - } + // NTT/INTT + int ntt_step = 0; + if (dit) { + ICICLE_LOG_DEBUG << "DIT"; + for (int batch = 0; batch < config.batch_size; ++batch) { + E* current_temp_elements = temp_elements + batch * size; + for (int len = 2; len <= size; len <<= 1) { + // ICICLE_LOG_DEBUG << "ntt_step: " << ntt_step++; + int half_len = len / 2; + int step = size / len; + int tw_idx = 0; + for (int i = 0; i < size; i += len) { + for (int j = 0; j < half_len; ++j) { + tw_idx = (dir == NTTDir::kForward) ? j * step : size - j * step; + E u = current_temp_elements[i + j]; + E v = current_temp_elements[i + j + half_len] * twiddles[tw_idx]; + // ICICLE_LOG_DEBUG << "tw_idx=" << tw_idx; + // ICICLE_LOG_DEBUG << "current_temp_elements[" << i + j << "] <-- " << current_temp_elements[i + j] << " + // + " << current_temp_elements[i + j + half_len] << "*" << twiddles[tw_idx]; ICICLE_LOG_DEBUG << + // "current_temp_elements[" << i + j + half_len << "] <-- " << current_temp_elements[i + j] << " - " << + // current_temp_elements[i + j + half_len] << "*" << twiddles[tw_idx]; + current_temp_elements[i + j] = u + v; + current_temp_elements[i + j + half_len] = u - v; + // ICICLE_LOG_DEBUG << i + j << " <-- " << i + j << " + " << i + j + half_len << "*" << tw_idx; + // ICICLE_LOG_DEBUG << i + j + half_len << " <-- " << i + j << " - " << i + j + half_len << "*" << tw_idx; + } } - } else { //dif - ICICLE_LOG_DEBUG << "DIF"; - for (int batch = 0; batch < config.batch_size; ++batch) { - E* current_temp_elements = temp_elements + batch * size; - for (int len = size; len >= 2; len >>= 1) { - // ICICLE_LOG_DEBUG << "ntt_step: " << ntt_step++; - int half_len = len / 2; - int step = size / len; - int tw_idx = 0; - for (int i = 0; i < size; i += len) { - for (int j = 0; j < half_len; ++j) { - tw_idx = (dir == NTTDir::kForward)? j * step : size - j * step; - S u = current_temp_elements[i + j]; - S v = current_temp_elements[i + j + half_len]; - current_temp_elements[i + j] = u + v; - current_temp_elements[i + j + half_len] = (u - v)*twiddles[tw_idx]; - } - } - } + } + } + } else { // dif + ICICLE_LOG_DEBUG << "DIF"; + for (int batch = 0; batch < config.batch_size; ++batch) { + E* current_temp_elements = temp_elements + batch * size; + for (int len = size; len >= 2; len >>= 1) { + // ICICLE_LOG_DEBUG << "ntt_step: " << ntt_step++; + int half_len = len / 2; + int step = size / len; + int tw_idx = 0; + for (int i = 0; i < size; i += len) { + for (int j = 0; j < half_len; ++j) { + tw_idx = (dir == NTTDir::kForward) ? j * step : size - j * step; + E u = current_temp_elements[i + j]; + E v = current_temp_elements[i + j + half_len]; + current_temp_elements[i + j] = u + v; + current_temp_elements[i + j + half_len] = (u - v) * twiddles[tw_idx]; + } } + } } + } - if (dir == NTTDir::kInverse) { - // Normalize results - S inv_size = S::inv_log_size(logn); - for (int i = 0; i < total_size; ++i) { - temp_elements[i] = temp_elements[i] * inv_size; - } + if (dir == NTTDir::kInverse) { + // Normalize results + S inv_size = S::inv_log_size(logn); + for (int i = 0; i < total_size; ++i) { + temp_elements[i] = temp_elements[i] * inv_size; } + } - if (config.columns_batch) { - // Distribute output into columns if columns_batch is set - for (int col = 0; col < config.batch_size; ++col) { - for (int row = 0; row < size; ++row) { - output[row * config.batch_size + col] = temp_elements[col * size + row]; - } - } - } else { - std::copy(temp_elements, temp_elements + total_size, output); + if (config.columns_batch) { + // Distribute output into columns if columns_batch is set + for (int col = 0; col < config.batch_size; ++col) { + for (int row = 0; row < size; ++row) { + output[row * config.batch_size + col] = temp_elements[col * size + row]; + } } - // ICICLE_LOG_DEBUG << "FINAL OUTPUT"; - // for (int i = 0; i < total_size; ++i) { - // ICICLE_LOG_DEBUG << "output[" << i << "]: " << output[i]; - // } - - delete[] temp_elements; - return eIcicleError::SUCCESS; - } + } else { + std::copy(temp_elements, temp_elements + total_size, output); + } + // ICICLE_LOG_DEBUG << "FINAL OUTPUT"; + // for (int i = 0; i < total_size; ++i) { + // ICICLE_LOG_DEBUG << "output[" << i << "]: " << output[i]; + // } + delete[] temp_elements; + return eIcicleError::SUCCESS; + } template eIcicleError cpu_ntt(const Device& device, const E* input, int size, NTTDir dir, NTTConfig& config, E* output) diff --git a/icicle_v3/backend/cpu/src/cpu_device_api.cpp b/icicle_v3/backend/cpu/src/cpu_device_api.cpp index 370922749..f144e6e1e 100644 --- a/icicle_v3/backend/cpu/src/cpu_device_api.cpp +++ b/icicle_v3/backend/cpu/src/cpu_device_api.cpp @@ -93,6 +93,8 @@ class CpuDeviceAPI : public DeviceAPI REGISTER_DEVICE_API("CPU", CpuDeviceAPI); -class CpuRefDevice: public CpuDeviceAPI {}; +class CpuRefDevice : public CpuDeviceAPI +{ +}; REGISTER_DEVICE_API("CPU_REF", CpuRefDevice); diff --git a/icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp b/icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp index 1d1971388..ed0b21d53 100644 --- a/icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp +++ b/icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp @@ -6,7 +6,13 @@ #include "icicle/curves/curve_config.h" -using namespace field_config; +using namespace curve_config; using namespace icicle; +template +eIcicleError cpu_ntt(const Device& device, const E* input, int size, NTTDir dir, NTTConfig& config, E* output) +{ + return eIcicleError::API_NOT_IMPLEMENTED; +} + REGISTER_ECNTT_BACKEND("CPU", (cpu_ntt)); diff --git a/icicle_v3/backend/cpu/src/field/cpu_ntt.cpp b/icicle_v3/backend/cpu/src/field/cpu_ntt.cpp index 983dc8779..9e3b70dab 100644 --- a/icicle_v3/backend/cpu/src/field/cpu_ntt.cpp +++ b/icicle_v3/backend/cpu/src/field/cpu_ntt.cpp @@ -3,29 +3,31 @@ using namespace field_config; using namespace icicle; -eIcicleError cpu_ntt_init_domain(const Device& device, const scalar_t& primitive_root, const NTTInitDomainConfig& config) +eIcicleError +cpu_ntt_init_domain(const Device& device, const scalar_t& primitive_root, const NTTInitDomainConfig& config) { - auto err = ntt_template::cpu_ntt_init_domain(device, primitive_root, config); - return err; + auto err = ntt_template::cpu_ntt_init_domain(device, primitive_root, config); + return err; } eIcicleError cpu_ntt_release_domain(const Device& device) { - auto err = ntt_template::cpu_ntt_release_domain(device); - return err; + auto err = ntt_template::cpu_ntt_release_domain(device); + return err; } template eIcicleError cpu_ntt(const Device& device, const E* input, int size, NTTDir dir, NTTConfig& config, E* output) { - auto err = ntt_template::cpu_ntt(device, input, size, dir, config, output); - return err; + auto err = ntt_template::cpu_ntt(device, input, size, dir, config, output); + return err; } -eIcicleError cpu_ntt_ref(const Device& device, const scalar_t* input, int size, NTTDir dir, NTTConfig& config, scalar_t* output) +eIcicleError cpu_ntt_ref( + const Device& device, const scalar_t* input, int size, NTTDir dir, NTTConfig& config, scalar_t* output) { - auto err = ntt_template::cpu_ntt_ref(device, input, size, dir, config, output); - return err; + auto err = ntt_template::cpu_ntt_ref(device, input, size, dir, config, output); + return err; } REGISTER_NTT_INIT_DOMAIN_BACKEND("CPU", (cpu_ntt_init_domain)); @@ -34,7 +36,6 @@ REGISTER_NTT_INIT_DOMAIN_BACKEND("CPU_REF", (cpu_ntt_init_domain)); REGISTER_NTT_RELEASE_DOMAIN_BACKEND("CPU", cpu_ntt_release_domain); REGISTER_NTT_RELEASE_DOMAIN_BACKEND("CPU_REF", cpu_ntt_release_domain); - REGISTER_NTT_BACKEND("CPU", (cpu_ntt)); REGISTER_NTT_BACKEND("CPU_REF", (cpu_ntt_ref)); diff --git a/icicle_v3/tests/test_field_api.cpp b/icicle_v3/tests/test_field_api.cpp index 0448e48bb..ac04357e5 100644 --- a/icicle_v3/tests/test_field_api.cpp +++ b/icicle_v3/tests/test_field_api.cpp @@ -12,7 +12,6 @@ #include "icicle/fields/field_config.h" #include "icicle/utils/log.h" - using namespace field_config; using namespace icicle; @@ -211,29 +210,35 @@ TYPED_TEST(FieldApiTest, ntt) const int logn = rand() % 13 + 3; const int N = 1 << logn; - // Randomize config - const int batch_size = rand() % 15 + 1; - const Ordering ordering = static_cast(rand() % 6); - bool columns_batch; - if (ordering == Ordering::kMN || ordering == Ordering::kNM || logn == 7 || logn < 4) { - columns_batch = false; //FIXME: currently not supported (icicle_v3/backend/cuda/src/ntt/ntt.cuh line 578) - } - else { - columns_batch = rand() % 2; - } - const NTTDir dir = static_cast(rand() % 2); // 0: forward, 1: inverse - - // // Set fixed config values for debugging - // const int batch_size = 2; - // const bool columns_batch = false; - // const Ordering ordering = Ordering::kNM; - // const NTTDir dir = NTTDir::kForward; - - // Print config + // // Randomize config + // TODO - iterate over different configs + // const int batch_size = rand() % 15 + 1; + // const Ordering ordering = static_cast(rand() % 6); + // bool columns_batch; + // if (ordering == Ordering::kMN || ordering == Ordering::kNM || logn == 7 || logn < 4) { + // columns_batch = false; // FIXME: currently not supported (icicle_v3/backend/cuda/src/ntt/ntt.cuh line 578) + // } else { + // columns_batch = rand() % 2; + // } + // const NTTDir dir = static_cast(rand() % 2); // 0: forward, 1: inverse + + // Set fixed config values for debugging + const int batch_size = 2; + const bool columns_batch = false; + const Ordering ordering = Ordering::kNM; + const NTTDir dir = NTTDir::kForward; + + // Print config ICICLE_LOG_DEBUG << "logn: " << logn; ICICLE_LOG_DEBUG << "batch_size: " << batch_size; ICICLE_LOG_DEBUG << "columns_batch: " << columns_batch; - ICICLE_LOG_DEBUG << "ordering: " << (ordering == Ordering::kNN? "kNN" : ordering == Ordering::kNR? "kNR" : ordering == Ordering::kRN? "kRN" : ordering == Ordering::kRR? "kRR" : ordering == Ordering::kNM? "kNM" : "kMN"); + ICICLE_LOG_DEBUG << "ordering: " + << (ordering == Ordering::kNN ? "kNN" + : ordering == Ordering::kNR ? "kNR" + : ordering == Ordering::kRN ? "kRN" + : ordering == Ordering::kRR ? "kRR" + : ordering == Ordering::kNM ? "kNM" + : "kMN"); ICICLE_LOG_DEBUG << "dir: " << (dir == NTTDir::kForward ? "Forward" : "Inverse"); const int total_size = N * batch_size; @@ -255,19 +260,17 @@ TYPED_TEST(FieldApiTest, ntt) init_domain_config.is_async = false; init_domain_config.ext.set(CUDA_NTT_FAST_TWIDDLES_MODE, true); - auto config = default_ntt_config(); // config.stream = stream; // config.coset_gen = coset_gen; // TODO - Implement. default: S::one() - config.batch_size = batch_size; // default: 1 + config.batch_size = batch_size; // default: 1 config.columns_batch = columns_batch; // default: false - config.ordering = ordering; // default: kNN - config.are_inputs_on_device = true; // TODO, ask yuval why set to true? - config.are_outputs_on_device = true;// TODO, ask yuval why set to true? + config.ordering = ordering; // default: kNN + config.are_inputs_on_device = true; // TODO, ask yuval why set to true? + config.are_outputs_on_device = true; // TODO, ask yuval why set to true? // config.is_async = false; // TODO - Implement. default: false ICICLE_CHECK(ntt_init_domain(scalar_t::omega(logn), init_domain_config)); - - + TypeParam *d_in, *d_out; icicle_malloc_async((void**)&d_in, total_size * sizeof(TypeParam), config.stream); icicle_malloc_async((void**)&d_out, total_size * sizeof(TypeParam), config.stream);