Skip to content

Commit

Permalink
Fix instance pool (#1881)
Browse files Browse the repository at this point in the history
* Fix instance pool multithreading

* Fix locks on compiling_modules

* Don't create redundant promises
  • Loading branch information
Harrm authored Nov 28, 2023
1 parent 3130861 commit cbbee3b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 43 deletions.
43 changes: 18 additions & 25 deletions core/runtime/common/runtime_instances_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,58 +79,51 @@ namespace kagome::runtime {

if (!pool_opt) {
lock.unlock();
if (auto future = getFutureCompiledModule(code_hash)) {
lock.lock();
pool_opt = pools_.get(code_hash);
} else {
OUTCOME_TRY(module, tryCompileModule(code_hash, code_zstd));
BOOST_ASSERT(module != nullptr);
lock.lock();
OUTCOME_TRY(module, tryCompileModule(code_hash, code_zstd));
lock.lock();
pool_opt = pools_.get(code_hash);
if (!pool_opt) {
pool_opt = std::ref(pools_.put(code_hash, InstancePool{module, {}}));
}
}
BOOST_ASSERT(pool_opt);
auto instance = pool_opt->get().instantiate(lock);
return std::make_shared<BorrowedInstance>(
weak_from_this(), code_hash, std::move(instance));
}

std::optional<std::shared_future<RuntimeInstancesPool::CompilationResult>>
RuntimeInstancesPool::getFutureCompiledModule(
const CodeHash &code_hash) const {
std::unique_lock l{compiling_modules_mtx_};
auto iter = compiling_modules_.find(code_hash);
if (iter == compiling_modules_.end()) {
return std::nullopt;
}
auto future = iter->second;
l.unlock();
return future;
}

RuntimeInstancesPool::CompilationResult
RuntimeInstancesPool::tryCompileModule(const CodeHash &code_hash,
common::BufferView code_zstd) {
std::unique_lock l{compiling_modules_mtx_};
if (auto iter = compiling_modules_.find(code_hash);
iter != compiling_modules_.end()) {
std::shared_future<CompilationResult> future = iter->second;
l.unlock();
return future.get();
}
std::promise<CompilationResult> promise;
auto [iter, inserted] =
auto [iter, is_inserted] =
compiling_modules_.insert({code_hash, promise.get_future()});
BOOST_ASSERT(inserted);
BOOST_ASSERT(is_inserted);
BOOST_ASSERT(iter != compiling_modules_.end());
l.unlock();

common::Buffer code;
CompilationResult res{nullptr};
std::optional<CompilationResult> res;
if (!uncompressCodeIfNeeded(code_zstd, code)) {
res = CompilationError{"Failed to uncompress code"};
} else {
res = common::map_result(module_factory_->make(code), [](auto &&module) {
return std::shared_ptr<const Module>(module);
});
}
promise.set_value(res);
BOOST_ASSERT(res);

l.lock();
compiling_modules_.erase(iter);
return res;
promise.set_value(*res);
return *res;
}

outcome::result<std::shared_ptr<ModuleInstance>>
Expand Down
3 changes: 0 additions & 3 deletions core/runtime/common/runtime_instances_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ namespace kagome::runtime {
CompilationResult tryCompileModule(const CodeHash &code_hash,
common::BufferView code_zstd);

std::optional<std::shared_future<CompilationResult>> getFutureCompiledModule(
const CodeHash &code_hash) const;

std::shared_ptr<ModuleFactory> module_factory_;

std::mutex pools_mtx_;
Expand Down
29 changes: 14 additions & 15 deletions test/core/runtime/instance_pool_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

#include <gtest/gtest.h>

#if defined(BACKWARD_HAS_BACKTRACE)
#include <backward.hpp>
#endif
#include <algorithm>
#include <random>
#include <ranges>

#include "testutil/literals.hpp"
#include "testutil/outcome.hpp"
Expand All @@ -33,10 +33,6 @@ RuntimeInstancesPool::CodeHash make_code_hash(int i) {
TEST(InstancePoolTest, HeavilyMultithreadedCompilation) {
using namespace std::chrono_literals;

#if defined(BACKWARD_HAS_BACKTRACE)
backward::SignalHandling sh;
#endif

auto module_instance_mock = std::make_shared<ModuleInstanceMock>();

auto module_mock = std::make_shared<ModuleMock>();
Expand All @@ -54,13 +50,16 @@ TEST(InstancePoolTest, HeavilyMultithreadedCompilation) {
return module_mock;
}));

RuntimeInstancesPool pool{module_factory, 5};
static constexpr int THREAD_NUM = 100;
static constexpr int POOL_SIZE = 10;

RuntimeInstancesPool pool{module_factory, POOL_SIZE};

std::vector<std::thread> threads;
for (int i = 0; i < 10; i++) {
threads.emplace_back(std::thread([&pool, i, &code]() {
for (int i = 0; i < THREAD_NUM; i++) {
threads.emplace_back(std::thread([&pool, &code, i]() {
ASSERT_OUTCOME_SUCCESS_TRY(
pool.instantiateFromCode(make_code_hash(i % 5), code.view()));
pool.instantiateFromCode(make_code_hash(i % POOL_SIZE), code));
}));
}

Expand All @@ -69,12 +68,12 @@ TEST(InstancePoolTest, HeavilyMultithreadedCompilation) {
}

// check that 'make' was only called 5 times
ASSERT_EQ(times_make_called.load(), 5);
ASSERT_EQ(times_make_called.load(), POOL_SIZE);

// check that all 10 instances are in cache
for (int i = 0; i < 5; i++) {
// check that all POOL_SIZE instances are in cache
for (int i = 0; i < POOL_SIZE; i++) {
ASSERT_OUTCOME_SUCCESS_TRY(
pool.instantiateFromCode(make_code_hash(i), code.view()));
}
ASSERT_EQ(times_make_called.load(), 5);
ASSERT_EQ(times_make_called.load(), POOL_SIZE);
}

0 comments on commit cbbee3b

Please sign in to comment.