Skip to content

Commit

Permalink
feat(compiler): distributed execution - on-demand key transfer to remote nodes.
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniupop committed Feb 23, 2024
1 parent e67aa9b commit bda568a
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ typedef struct FFT {
FFT() = delete;
FFT(size_t polynomial_size);
FFT(FFT &other) = delete;
FFT(const FFT &other) = delete;
FFT(FFT &&other);
~FFT();

Expand All @@ -42,7 +43,7 @@ typedef struct RuntimeContext {

RuntimeContext() = delete;
RuntimeContext(ServerKeyset serverKeyset);
~RuntimeContext() {
virtual ~RuntimeContext() {
#ifdef CONCRETELANG_CUDA_SUPPORT
for (int i = 0; i < num_devices; ++i) {
if (bsk_gpu[i] != nullptr)
Expand All @@ -53,27 +54,30 @@ typedef struct RuntimeContext {
#endif
};

const uint64_t *keyswitch_key_buffer(size_t keyId) {
virtual const uint64_t *keyswitch_key_buffer(size_t keyId) {
return serverKeyset.lweKeyswitchKeys[keyId].getBuffer().data();
}

const std::complex<double> *fourier_bootstrap_key_buffer(size_t keyId) {
virtual const std::complex<double> *
fourier_bootstrap_key_buffer(size_t keyId) {
return fourier_bootstrap_keys[keyId]->data();
}

const uint64_t *fp_keyswitch_key_buffer(size_t keyId) {
virtual const uint64_t *fp_keyswitch_key_buffer(size_t keyId) {
return serverKeyset.packingKeyswitchKeys[keyId].getRawPtr();
}

const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }
virtual const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }

const ServerKeyset getKeys() const { return serverKeyset; }

private:
protected:
ServerKeyset serverKeyset;
std::vector<std::shared_ptr<std::vector<std::complex<double>>>>
fourier_bootstrap_keys;
std::vector<FFT> ffts;
std::pair<FFT, std::shared_ptr<std::vector<std::complex<double>>>>
convert_to_fourier_domain(LweBootstrapKey &bsk);

#ifdef CONCRETELANG_CUDA_SUPPORT
public:
Expand Down Expand Up @@ -144,6 +148,24 @@ typedef struct RuntimeContext {
#endif
} RuntimeContext;

// RuntimeContext specialisation that overrides every virtual key/FFT accessor
// of the base class. Only declarations are visible here; the definitions live
// in another translation unit.
// NOTE(review): judging by the member names and the per-keyId maps below, this
// presumably fetches evaluation keys lazily on a remote node and caches them
// locally - confirm against the implementation.
struct DistributedRuntimeContext : public RuntimeContext {

// Inherit the base-class constructors so that a DistributedRuntimeContext is
// built from the same arguments as a RuntimeContext.
using RuntimeContext::RuntimeContext;
// Overrides of the base accessors; each returns a pointer to the data
// associated with `keyId`.
const uint64_t *keyswitch_key_buffer(size_t keyId) override;
const std::complex<double> *
fourier_bootstrap_key_buffer(size_t keyId) override;
const uint64_t *fp_keyswitch_key_buffer(size_t keyId) override;
const struct Fft *fft(size_t keyId) override;

private:
// Helper for the bootstrap key identified by `keyId`.
// NOTE(review): presumably makes that key (and its FFT) available on the
// local node - verify in the definition.
void getBSKonNode(size_t keyId);
// NOTE(review): presumably serialises access to the maps below - confirm.
std::mutex cm_guard;
// Per-keyId storage: keyswitch keys, Fourier-domain bootstrap key buffers,
// FFT objects and packing keyswitch keys respectively.
std::map<size_t, LweKeyswitchKey> ksks;
std::map<size_t, std::shared_ptr<std::vector<std::complex<double>>>> fbks;
std::map<size_t, FFT> dffts;
std::map<size_t, PackingKeyswitchKey> pksks;
};

} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ namespace concretelang {
namespace dfr {

struct RuntimeContextManager;
namespace {
static void *dl_handle;
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
} // namespace
extern RuntimeContextManager *_dfr_node_level_runtime_context_manager;

template <typename LweKeyType> struct KeyWrapper {
std::vector<LweKeyType> keys;
Expand Down Expand Up @@ -109,21 +106,40 @@ struct RuntimeContextManager {
// TODO: this is only ok so long as we don't change keys. Once we
// use multiple keys, should have a map.
RuntimeContext *context;
bool allocated = false;
bool lazy_key_transfer = false;

RuntimeContextManager() {
RuntimeContextManager(bool lazy = false) : lazy_key_transfer(lazy) {
context = nullptr;
_dfr_node_level_runtime_context_manager = this;
}

void setContext(void *ctx) {
assert(context == nullptr &&
"Only one RuntimeContext can be used at a time.");
context = (RuntimeContext *)ctx;

if (lazy_key_transfer) {
if (!_dfr_is_root_node()) {
context =
new mlir::concretelang::DistributedRuntimeContext(ServerKeyset());
allocated = true;
}
return;
}

// When the root node does not require a context, we still need to
// broadcast an empty keyset to remote nodes as they cannot know
// ahead of time and avoid waiting for the broadcast. Instantiate
// an empty context for this.
if (_dfr_is_root_node() && ctx == nullptr) {
context = new mlir::concretelang::RuntimeContext(ServerKeyset());
allocated = true;
}

// Root node broadcasts the evaluation keys and each remote
// instantiates a local RuntimeContext.
if (_dfr_is_root_node()) {
RuntimeContext *context = (RuntimeContext *)ctx;

KeyWrapper<LweKeyswitchKey> kskw(context->getKeys().lweKeyswitchKeys);
KeyWrapper<LweBootstrapKey> bskw(context->getKeys().lweBootstrapKeys);
KeyWrapper<PackingKeyswitchKey> pkskw(
Expand Down Expand Up @@ -153,12 +169,23 @@ struct RuntimeContextManager {

void clearContext() {
if (context != nullptr)
delete context;
// On root node deallocate only if allocated independently here
if (!_dfr_is_root_node() || allocated)
delete context;
context = nullptr;
}
};

// Free functions returning a KeyWrapper around the key(s) selected by `keyId`,
// one per key kind: keyswitch, bootstrap and packing keyswitch. Definitions
// are not in this header.
KeyWrapper<LweKeyswitchKey> getKsk(size_t keyId);
KeyWrapper<LweBootstrapKey> getBsk(size_t keyId);
KeyWrapper<PackingKeyswitchKey> getPKsk(size_t keyId);

// Wrap each function above in an HPX plain action type (_get_*_action) so it
// can be invoked through the HPX action mechanism.
// NOTE(review): presumably called by remote nodes to pull keys from the root
// node on demand - confirm at the call sites.
HPX_DEFINE_PLAIN_ACTION(getKsk, _get_ksk_action);
HPX_DEFINE_PLAIN_ACTION(getBsk, _get_bsk_action);
HPX_DEFINE_PLAIN_ACTION(getPKsk, _get_pksk_action);

} // namespace dfr
} // namespace concretelang
} // namespace mlir

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ namespace dfr {

struct WorkFunctionRegistry;
namespace {
static void *dl_handle;
static WorkFunctionRegistry *_dfr_node_level_work_function_registry;
}
} // namespace

struct WorkFunctionRegistry {
WorkFunctionRegistry() { _dfr_node_level_work_function_registry = this; }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
add_compile_options(-fsized-deallocation)

if(CONCRETELANG_CUDA_SUPPORT)
add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp GPUDFG.cpp)
add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp key_manager.cpp
GPUDFG.cpp)
target_link_libraries(ConcretelangRuntime PRIVATE hwloc)
else()
add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp)
add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp key_manager.cpp
StreamEmulator.cpp)
endif()

add_dependencies(ConcretelangRuntime concrete_cpu concrete_cpu_noise_model concrete-protocol)
Expand Down
19 changes: 13 additions & 6 deletions compilers/concrete-compiler/compiler/lib/Runtime/DFRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,16 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
num_nodes = hpx::get_num_localities().get();

new WorkFunctionRegistry();
new RuntimeContextManager();

char *env = getenv("DFR_LAZY_KEY_TRANSFER");
bool lazy = false;
if (env != nullptr)
if (!strncmp(env, "True", 4) || !strncmp(env, "true", 4) ||
!strncmp(env, "On", 2) || !strncmp(env, "on", 2) ||
!strncmp(env, "1", 1))
lazy = true;
new RuntimeContextManager(lazy);

_dfr_jit_phase_barrier = new hpx::distributed::barrier(
"phase_barrier", num_nodes, hpx::get_locality_id());
_dfr_startup_barrier = new hpx::distributed::barrier(
Expand Down Expand Up @@ -351,14 +360,12 @@ void _dfr_start(int64_t use_dfr_p, void *ctx) {

assert(init_guard == active && "DFR runtime failed to initialise");

// If DFR is used and a runtime context is needed, and execution is
// distributed, then broadcast from root to all compute nodes.
if (num_nodes > 1 && (ctx || !_dfr_is_root_node())) {
// If execution is distributed, then broadcast (possibly an empty)
// context from root to all compute nodes.
if (num_nodes > 1) {
BEGIN_TIME(&broadcast_timer);
_dfr_node_level_runtime_context_manager->setContext(ctx);
}
// If this is not JIT, then the remote nodes never reach _dfr_stop,
// so root should not instantiate this barrier.
if (_dfr_is_root_node())
_dfr_startup_barrier->wait();

Expand Down
Loading

0 comments on commit bda568a

Please sign in to comment.