From bda568ab6b6514cef090e4b5caaf975b67094932 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Mon, 5 Feb 2024 13:23:24 +0000 Subject: [PATCH] feat(compiler): distributed execution - on-demand key transfer to remote nodes. --- .../include/concretelang/Runtime/context.h | 34 +++- .../concretelang/Runtime/key_manager.hpp | 43 +++- .../Runtime/workfunction_registry.hpp | 3 +- .../compiler/lib/Runtime/CMakeLists.txt | 6 +- .../compiler/lib/Runtime/DFRuntime.cpp | 19 +- .../compiler/lib/Runtime/context.cpp | 185 +++++++++++++----- .../compiler/lib/Runtime/key_manager.cpp | 39 ++++ 7 files changed, 258 insertions(+), 71 deletions(-) create mode 100644 compilers/concrete-compiler/compiler/lib/Runtime/key_manager.cpp diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/context.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/context.h index 8613f7558f..28cc3bb3f8 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/context.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/context.h @@ -31,6 +31,7 @@ typedef struct FFT { FFT() = delete; FFT(size_t polynomial_size); FFT(FFT &other) = delete; + FFT(const FFT &other) = delete; FFT(FFT &&other); ~FFT(); @@ -42,7 +43,7 @@ typedef struct RuntimeContext { RuntimeContext() = delete; RuntimeContext(ServerKeyset serverKeyset); - ~RuntimeContext() { + virtual ~RuntimeContext() { #ifdef CONCRETELANG_CUDA_SUPPORT for (int i = 0; i < num_devices; ++i) { if (bsk_gpu[i] != nullptr) @@ -53,27 +54,30 @@ typedef struct RuntimeContext { #endif }; - const uint64_t *keyswitch_key_buffer(size_t keyId) { + virtual const uint64_t *keyswitch_key_buffer(size_t keyId) { return serverKeyset.lweKeyswitchKeys[keyId].getBuffer().data(); } - const std::complex *fourier_bootstrap_key_buffer(size_t keyId) { + virtual const std::complex * + fourier_bootstrap_key_buffer(size_t keyId) { return fourier_bootstrap_keys[keyId]->data(); } - const uint64_t *fp_keyswitch_key_buffer(size_t keyId) { + virtual const uint64_t *fp_keyswitch_key_buffer(size_t keyId) { return serverKeyset.packingKeyswitchKeys[keyId].getRawPtr(); } - const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; } + virtual const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; } const ServerKeyset getKeys() const { return serverKeyset; } -private: +protected: ServerKeyset serverKeyset; std::vector>>> fourier_bootstrap_keys; std::vector ffts; + std::pair>>> + convert_to_fourier_domain(LweBootstrapKey &bsk); #ifdef CONCRETELANG_CUDA_SUPPORT public: @@ -144,6 +148,24 @@ typedef struct RuntimeContext { #endif } RuntimeContext; +struct DistributedRuntimeContext : public RuntimeContext { + + using RuntimeContext::RuntimeContext; + const uint64_t *keyswitch_key_buffer(size_t keyId) override; + const std::complex * + fourier_bootstrap_key_buffer(size_t keyId) override; + const uint64_t *fp_keyswitch_key_buffer(size_t keyId) override; + const struct Fft *fft(size_t keyId) override; + +private: + void getBSKonNode(size_t keyId); + std::mutex cm_guard; + std::map ksks; + std::map>>> fbks; + std::map dffts; + std::map pksks; +}; + } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/key_manager.hpp b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/key_manager.hpp index 9b2f8004d5..7b68837e59 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/key_manager.hpp +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/key_manager.hpp @@ -33,10 +33,7 @@ namespace concretelang { namespace dfr { struct RuntimeContextManager; -namespace { -static void *dl_handle; -static RuntimeContextManager *_dfr_node_level_runtime_context_manager; -} // namespace +extern RuntimeContextManager *_dfr_node_level_runtime_context_manager; template struct KeyWrapper { std::vector keys; @@ -109,8 +106,10 @@ struct RuntimeContextManager { // TODO: this is only ok so long as we don't change keys. Once we // use multiple keys, should have a map. RuntimeContext *context; + bool allocated = false; + bool lazy_key_transfer = false; - RuntimeContextManager() { + RuntimeContextManager(bool lazy = false) : lazy_key_transfer(lazy) { context = nullptr; _dfr_node_level_runtime_context_manager = this; } @@ -118,12 +117,29 @@ struct RuntimeContextManager { void setContext(void *ctx) { assert(context == nullptr && "Only one RuntimeContext can be used at a time."); + context = (RuntimeContext *)ctx; + + if (lazy_key_transfer) { + if (!_dfr_is_root_node()) { + context = + new mlir::concretelang::DistributedRuntimeContext(ServerKeyset()); + allocated = true; + } + return; + } + + // When the root node does not require a context, we still need to + // broadcast an empty keyset to remote nodes as they cannot know + // ahead of time and avoid waiting for the broadcast. Instantiate + // an empty context for this. + if (_dfr_is_root_node() && ctx == nullptr) { + context = new mlir::concretelang::RuntimeContext(ServerKeyset()); + allocated = true; + } // Root node broadcasts the evaluation keys and each remote // instantiates a local RuntimeContext. if (_dfr_is_root_node()) { - RuntimeContext *context = (RuntimeContext *)ctx; - KeyWrapper kskw(context->getKeys().lweKeyswitchKeys); KeyWrapper bskw(context->getKeys().lweBootstrapKeys); KeyWrapper pkskw( @@ -153,12 +169,23 @@ struct RuntimeContextManager { void clearContext() { if (context != nullptr) - delete context; + // On root node deallocate only if allocated independently here + if (!_dfr_is_root_node() || allocated) + delete context; context = nullptr; } }; +KeyWrapper getKsk(size_t keyId); +KeyWrapper getBsk(size_t keyId); +KeyWrapper getPKsk(size_t keyId); + +HPX_DEFINE_PLAIN_ACTION(getKsk, _get_ksk_action); +HPX_DEFINE_PLAIN_ACTION(getBsk, _get_bsk_action); +HPX_DEFINE_PLAIN_ACTION(getPKsk, _get_pksk_action); + } // namespace dfr } // namespace concretelang } // namespace mlir + #endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/workfunction_registry.hpp b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/workfunction_registry.hpp index 41d14d5f19..fc65778708 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/workfunction_registry.hpp +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/workfunction_registry.hpp @@ -22,8 +22,9 @@ namespace dfr { struct WorkFunctionRegistry; namespace { +static void *dl_handle; static WorkFunctionRegistry *_dfr_node_level_work_function_registry; -} +} // namespace struct WorkFunctionRegistry { WorkFunctionRegistry() { _dfr_node_level_work_function_registry = this; } diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Runtime/CMakeLists.txt index e19b974d3d..13d4775002 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Runtime/CMakeLists.txt @@ -1,10 +1,12 @@ add_compile_options(-fsized-deallocation) if(CONCRETELANG_CUDA_SUPPORT) - add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp GPUDFG.cpp) + add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp key_manager.cpp + GPUDFG.cpp) target_link_libraries(ConcretelangRuntime PRIVATE hwloc) else() - add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp) + add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp key_manager.cpp + StreamEmulator.cpp) endif() add_dependencies(ConcretelangRuntime concrete_cpu concrete_cpu_noise_model concrete-protocol) diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/DFRuntime.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/DFRuntime.cpp index 2bf5bb0cd5..92eb0a4e79 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/DFRuntime.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/DFRuntime.cpp @@ -316,7 +316,16 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { num_nodes = hpx::get_num_localities().get(); new WorkFunctionRegistry(); - new RuntimeContextManager(); + + char *env = getenv("DFR_LAZY_KEY_TRANSFER"); + bool lazy = false; + if (env != nullptr) + if (!strncmp(env, "True", 4) || !strncmp(env, "true", 4) || + !strncmp(env, "On", 2) || !strncmp(env, "on", 2) || + !strncmp(env, "1", 1)) + lazy = true; + new RuntimeContextManager(lazy); + _dfr_jit_phase_barrier = new hpx::distributed::barrier( "phase_barrier", num_nodes, hpx::get_locality_id()); _dfr_startup_barrier = new hpx::distributed::barrier( @@ -351,14 +360,12 @@ void _dfr_start(int64_t use_dfr_p, void *ctx) { assert(init_guard == active && "DFR runtime failed to initialise"); - // If DFR is used and a runtime context is needed, and execution is - // distributed, then broadcast from root to all compute nodes. - if (num_nodes > 1 && (ctx || !_dfr_is_root_node())) { + // If execution is distributed, then broadcast (possibly an empty) + // context from root to all compute nodes. + if (num_nodes > 1) { BEGIN_TIME(&broadcast_timer); _dfr_node_level_runtime_context_manager->setContext(ctx); } - // If this is not JIT, then the remote nodes never reach _dfr_stop, - // so root should not instantiate this barrier. if (_dfr_is_root_node()) _dfr_startup_barrier->wait(); diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/context.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/context.cpp index 879fb5f415..11cdcbedc8 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/context.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/context.cpp @@ -31,59 +31,148 @@ FFT::~FFT() { RuntimeContext::RuntimeContext(ServerKeyset serverKeyset) : serverKeyset(serverKeyset) { - { - - // Initialize for each bootstrap key the fourier one - for (size_t i = 0; i < serverKeyset.lweBootstrapKeys.size(); i++) { - - auto bsk = serverKeyset.lweBootstrapKeys[i]; - auto info = bsk.getInfo().asReader(); - - size_t decomposition_level_count = info.getParams().getLevelCount(); - size_t decomposition_base_log = info.getParams().getBaseLog(); - size_t glwe_dimension = info.getParams().getGlweDimension(); - size_t polynomial_size = info.getParams().getPolynomialSize(); - size_t input_lwe_dimension = info.getParams().getInputLweDimension(); - - // Create the FFT - FFT fft(polynomial_size); - - // Allocate scratch for key conversion - size_t scratch_size; - size_t scratch_align; - concrete_cpu_bootstrap_key_convert_u64_to_fourier_scratch( - &scratch_size, &scratch_align, fft.fft); - auto scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size); - - // Allocate the fourier_bootstrap_key - auto &bsk_buffer = bsk.getBuffer(); - auto fourier_data = std::make_shared>>(); - fourier_data->resize(bsk_buffer.size() / 2); - auto bsk_data = bsk_buffer.data(); - - // Convert bootstrap_key to the fourier domain - concrete_cpu_bootstrap_key_convert_u64_to_fourier( - bsk_data, fourier_data->data(), decomposition_level_count, - decomposition_base_log, glwe_dimension, polynomial_size, - input_lwe_dimension, fft.fft, scratch, scratch_size); - - // Store the fourier_bootstrap_key in the context - fourier_bootstrap_keys.push_back(fourier_data); - ffts.push_back(std::move(fft)); - free(scratch); - } + + // Initialize for each bootstrap key the fourier one + for (size_t i = 0; i < serverKeyset.lweBootstrapKeys.size(); i++) { + auto fdbsk = convert_to_fourier_domain(serverKeyset.lweBootstrapKeys[i]); + // Store the fourier_bootstrap_key in the context + fourier_bootstrap_keys.push_back(fdbsk.second); + ffts.push_back(std::move(fdbsk.first)); + } #ifdef CONCRETELANG_CUDA_SUPPORT - assert(cudaGetDeviceCount(&num_devices) == cudaSuccess); - bsk_gpu.resize(num_devices, nullptr); - ksk_gpu.resize(num_devices, nullptr); - for (int i = 0; i < num_devices; ++i) { - bsk_gpu_mutex.push_back(std::make_unique()); - ksk_gpu_mutex.push_back(std::make_unique()); - } + assert(cudaGetDeviceCount(&num_devices) == cudaSuccess); + bsk_gpu.resize(num_devices, nullptr); + ksk_gpu.resize(num_devices, nullptr); + for (int i = 0; i < num_devices; ++i) { + bsk_gpu_mutex.push_back(std::make_unique()); + ksk_gpu_mutex.push_back(std::make_unique()); + } #endif +} + +std::pair>>> +RuntimeContext::convert_to_fourier_domain(LweBootstrapKey &bsk) { + auto info = bsk.getInfo().asReader(); + + size_t decomposition_level_count = info.getParams().getLevelCount(); + size_t decomposition_base_log = info.getParams().getBaseLog(); + size_t glwe_dimension = info.getParams().getGlweDimension(); + size_t polynomial_size = info.getParams().getPolynomialSize(); + size_t input_lwe_dimension = info.getParams().getInputLweDimension(); + + // Create the FFT + FFT fft(polynomial_size); + + // Allocate scratch for key conversion + size_t scratch_size; + size_t scratch_align; + concrete_cpu_bootstrap_key_convert_u64_to_fourier_scratch( + &scratch_size, &scratch_align, fft.fft); + auto scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size); + + // Allocate the fourier_bootstrap_key + auto &bsk_buffer = bsk.getBuffer(); + auto fourier_data = std::make_shared>>(); + fourier_data->resize(bsk_buffer.size() / 2); + auto bsk_data = bsk_buffer.data(); + + // Convert bootstrap_key to the fourier domain + concrete_cpu_bootstrap_key_convert_u64_to_fourier( + bsk_data, fourier_data->data(), decomposition_level_count, + decomposition_base_log, glwe_dimension, polynomial_size, + input_lwe_dimension, fft.fft, scratch, scratch_size); + free(scratch); + + return std::pair>>>( + std::move(fft), fourier_data); +} +} // namespace concretelang +} // namespace mlir + +#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED +#include "concretelang/Runtime/key_manager.hpp" + +// Register the HPX actions for retrieving the evaluation keys from +// the master node (must be in global namespace) +HPX_PLAIN_ACTION(mlir::concretelang::dfr::getKsk, _dfr_get_ksk_action) +HPX_PLAIN_ACTION(mlir::concretelang::dfr::getBsk, _dfr_get_bsk_action) +HPX_PLAIN_ACTION(mlir::concretelang::dfr::getPKsk, _dfr_get_pksk_action) + +namespace mlir { +namespace concretelang { +const uint64_t *DistributedRuntimeContext::keyswitch_key_buffer(size_t keyId) { + if (dfr::_dfr_is_root_node()) + return RuntimeContext::keyswitch_key_buffer(keyId); + + std::lock_guard guard(cm_guard); + if (ksks.find(keyId) == ksks.end()) { + _dfr_get_ksk_action getKskAction; + dfr::KeyWrapper kskw = + getKskAction(hpx::find_root_locality(), keyId); + ksks.insert(std::pair(keyId, kskw.keys[0])); + } + auto it = ksks.find(keyId); + assert(it != ksks.end()); + return it->second.getBuffer().data(); +} + +void DistributedRuntimeContext::getBSKonNode(size_t keyId) { + assert(fbks.find(keyId) == fbks.end()); + assert(dffts.find(keyId) == dffts.end()); + _dfr_get_bsk_action getBskAction; + dfr::KeyWrapper bskw = + getBskAction(hpx::find_root_locality(), keyId); + + auto fdbsk = convert_to_fourier_domain(bskw.keys[0]); + fbks.insert( + std::pair>>>( + keyId, fdbsk.second)); + dffts.insert(std::pair(keyId, std::move(fdbsk.first))); +} + +const std::complex * +DistributedRuntimeContext::fourier_bootstrap_key_buffer(size_t keyId) { + if (dfr::_dfr_is_root_node()) + return RuntimeContext::fourier_bootstrap_key_buffer(keyId); + + std::lock_guard guard(cm_guard); + if (fbks.find(keyId) == fbks.end()) + getBSKonNode(keyId); + auto it = fbks.find(keyId); + assert(it != fbks.end()); + return it->second->data(); +} + +const uint64_t * +DistributedRuntimeContext::fp_keyswitch_key_buffer(size_t keyId) { + if (dfr::_dfr_is_root_node()) + return RuntimeContext::fp_keyswitch_key_buffer(keyId); + + std::lock_guard guard(cm_guard); + if (ksks.find(keyId) == ksks.end()) { + _dfr_get_pksk_action getPKskAction; + dfr::KeyWrapper pkskw = + getPKskAction(hpx::find_root_locality(), keyId); + pksks.insert(std::pair(keyId, pkskw.keys[0])); } + auto it = pksks.find(keyId); + assert(it != pksks.end()); + return it->second.getRawPtr(); +} + +const struct Fft *DistributedRuntimeContext::fft(size_t keyId) { + if (dfr::_dfr_is_root_node()) + return RuntimeContext::fft(keyId); + + std::lock_guard guard(cm_guard); + if (dffts.find(keyId) == dffts.end()) + getBSKonNode(keyId); + auto it = dffts.find(keyId); + assert(it != dffts.end()); + return it->second.fft; } } // namespace concretelang } // namespace mlir +#endif diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/key_manager.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/key_manager.cpp new file mode 100644 index 0000000000..11af106ccf --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Runtime/key_manager.cpp @@ -0,0 +1,39 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED + +#include "concretelang/Runtime/key_manager.hpp" +#include "concretelang/Common/Keysets.h" +#include "concretelang/Runtime/context.h" + +namespace mlir { +namespace concretelang { +namespace dfr { + +RuntimeContextManager *_dfr_node_level_runtime_context_manager; + +KeyWrapper getKsk(size_t keyId) { + return KeyWrapper(std::vector{ + _dfr_node_level_runtime_context_manager->context->getKeys() + .lweKeyswitchKeys[keyId]}); +} + +KeyWrapper getBsk(size_t keyId) { + return KeyWrapper(std::vector{ + _dfr_node_level_runtime_context_manager->context->getKeys() + .lweBootstrapKeys[keyId]}); +} + +KeyWrapper getPKsk(size_t keyId) { + return KeyWrapper(std::vector{ + _dfr_node_level_runtime_context_manager->context->getKeys() + .packingKeyswitchKeys[keyId]}); +} + +} // namespace dfr +} // namespace concretelang +} // namespace mlir +#endif