Skip to content

Commit

Permalink
fix: workaround for bshoshany/thread-pool/issues/100
Browse files Browse the repository at this point in the history
  • Loading branch information
Antares0982 committed Mar 16, 2023
1 parent 6a47ec1 commit 11f3d18
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 18 deletions.
4 changes: 0 additions & 4 deletions src/Allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// Created by antares on 3/15/23.
//

#pragma clang diagnostic push
#pragma ide diagnostic ignored "HidingNonVirtualFunction"
#ifndef LOCKFREE_THREADPOOL_ALLOCATOR_H
#define LOCKFREE_THREADPOOL_ALLOCATOR_H

Expand Down Expand Up @@ -43,5 +41,3 @@ namespace Antares {
};
}
#endif //LOCKFREE_THREADPOOL_ALLOCATOR_H

#pragma clang diagnostic pop
53 changes: 41 additions & 12 deletions src/ThreadPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,47 @@ inline void platform_get_thread_name(decltype(platform_thread_self()) id, char *
#endif

namespace Antares {
// void ThreadPoolBase::PoolWorker::worker() {
// platform_set_thread_name(platform_thread_self(), "Worker");
// std::unique_lock<std::mutex> tasks_lock(cv_mtx);
// while (pool.running.load(std::memory_order_relaxed)) {
// std::function<void()> task;
// while (pool.tasks_total == 0 && pool.running) {
// pool.task_available_cv.wait_for(tasks_lock, std::chrono::milliseconds(100));
// }
//
// while (!pool.paused && pool.tasks_total > 0) {
// auto popResult = pool.tasks.pop(task);
// if (!popResult) continue;
// task();
// --pool.tasks_total;
// if (pool.waiting) {
// std::lock_guard _lk(pool.task_done_mtx);
// pool.task_done_cv.notify_one();
// }
// }
// }
// }

void ThreadPoolBase::worker() {
platform_set_thread_name(platform_thread_self(), "Worker");
std::mutex t_mtx;
while (running) {
std::mutex cv_mtx;
std::unique_lock<std::mutex> tasks_lock(cv_mtx);
while (running.load(std::memory_order_relaxed)) {
std::function<void()> task;
std::unique_lock<std::mutex> tasks_lock(t_mtx);
task_available_cv.wait(tasks_lock, [this] { return !tasks.empty() || !running; });
tasks_lock.unlock();
if (!paused) {
auto popresult = tasks.pop(task);
if (!popresult) continue;
while (tasks_total == 0 && running) {
task_available_cv.wait_for(tasks_lock, std::chrono::milliseconds(100));
}

while (!paused && tasks_total > 0) {
auto popResult = tasks.pop(task);
if (!popResult) continue;
task();
--tasks_total;
if (waiting)
if (waiting) {
std::lock_guard _lk(task_done_mtx);
task_done_cv.notify_one();
}
}
}
}
Expand All @@ -94,10 +119,14 @@ namespace Antares {
}

void ThreadPoolBase::wait_for_tasks() {
std::mutex unused_mtx;
waiting = true;
static std::mutex unused_mtx;
std::unique_lock<std::mutex> tasks_lock(unused_mtx);
task_done_cv.wait(tasks_lock, [this] { return (tasks_total == (paused ? tasks.size() : 0)); });
        {
            // notify_one() can only fire before this lock is acquired or after wait() has begun,
            // so pairing the wait with task_done_mtx (held by the notifier) prevents a missed wakeup
std::unique_lock<std::mutex> tasks_lock(unused_mtx);
task_done_cv.wait(tasks_lock,
[this] { return (tasks_total == (paused ? tasks.size() : 0)); });
}
waiting = false;
}

Expand Down
37 changes: 36 additions & 1 deletion src/ThreadPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ namespace Antares {

class ThreadPoolBase {
protected:
// struct PoolWorker {
// ThreadPoolBase &pool;
// std::mutex cv_mtx;
// // std::condition_variable task_cv;
//
// PoolWorker(ThreadPoolBase &inPool) : pool(inPool) {}
//
// PoolWorker(const PoolWorker &) = delete;
//
// PoolWorker(PoolWorker &&) = delete;
//
// PoolWorker &operator=(const PoolWorker &) = delete;
//
// PoolWorker &operator=(PoolWorker &&) = delete;
//
// void worker();
// };

template<typename T1, typename T2, typename T = std::common_type_t<T1, T2>>
class [[nodiscard]] blocks {
public:
Expand Down Expand Up @@ -94,6 +112,7 @@ namespace Antares {
std::condition_variable task_available_cv = {};
std::atomic<bool> running = false;
std::atomic<bool> waiting = false;
std::mutex task_done_mtx{};
std::condition_variable task_done_cv = {};
std::atomic<bool> paused = false;
LockfreeQueue tasks; // this class implements its own traits
Expand Down Expand Up @@ -138,6 +157,7 @@ namespace Antares {

private:
std::vector<std::thread, Allocator<std::thread, Traits>> threads;
// std::vector<std::unique_ptr<PoolWorker>, Allocator<std::unique_ptr<PoolWorker>, Traits>> workers;

public:
ThreadPool(concurrency_t thread_count_ = 0)
Expand Down Expand Up @@ -169,8 +189,15 @@ namespace Antares {
std::function<void()> task_function = std::bind(std::forward<F>(task), std::forward<A>(args)...);

tasks.push(std::move(task_function));

++tasks_total;
// for (size_t i = 0; i < threads.size(); ++i) {
// auto w = workers[i].get();
// if (w->cv_mtx.try_lock()) {
// w->task_cv.notify_one();
// w->cv_mtx.unlock();
// return;
// }
// }
task_available_cv.notify_one();
}

Expand Down Expand Up @@ -246,24 +273,32 @@ namespace Antares {
destroy_threads();
auto thread_count = determine_thread_count(thread_count_);
threads.resize(thread_count);
// workers.reserve(thread_count);
// int remainSize = std::max(0, int(thread_count) - int(workers.size()));
// for (int i = 0; i < remainSize; ++i) workers.emplace_back(std::make_unique<PoolWorker>(*this));
paused = was_paused;
create_threads();
}

private:
void create_threads() {
// workers.reserve(threads.size());
// while (workers.size() < threads.size()) { workers.emplace_back(std::make_unique<PoolWorker>(*this)); }
running = true;
for (concurrency_t i = 0; i < threads.size(); ++i) {
threads[i] = std::thread(&ThreadPool::worker, this);
}
}

void destroy_threads() {
//
// for (auto &w: workers) { w->cv_mtx.lock(); }
running = false;
task_available_cv.notify_all();
for (concurrency_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}
// for (auto &w: workers) { w->cv_mtx.unlock(); }
}
};
}
Expand Down
20 changes: 19 additions & 1 deletion src/test_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,21 @@ void check_reset() {
dual_println(
"Checking that after a second reset() the manually counted number of unique thread IDs is equal to the reported number of threads...");
check(pool.get_thread_count(), count_unique_threads());

// https://github.com/bshoshany/thread-pool/issues/100
    dual_println(
            "Checking that after many reset() calls the pool does not deadlock...");
for (int i = 0; i < 100; ++i) {
pool.pause();
for (int j = 0; j < 32; ++j) {
pool.push_task([]() {return 0; });
}
pool.unpause();
for (int j = 0; j < 32; ++j) {
pool.reset();
}
pool.wait_for_tasks();
}
}

// =======================================
Expand Down Expand Up @@ -1035,7 +1050,7 @@ void check_performance() {
thread_count * 4};

// How many times to repeat each run of the test in order to collect reliable statistics.
constexpr size_t repeat = 20;
constexpr size_t repeat = 50;
dual_println("Each test will be repeated ", repeat, " times to collect reliable statistics.");

// The target execution time, in milliseconds, of the multi-threaded test with the number of blocks equal to the number of threads. The total time spent on that test will be approximately equal to repeat * target_ms.
Expand All @@ -1058,6 +1073,7 @@ void check_performance() {
num_vectors *= 2;
vector_size *= 2;
vectors = std::vector<std::vector<double>>(num_vectors, std::vector<double>(vector_size));
std::this_thread::sleep_for(std::chrono::milliseconds(100));
tmr.start();
pool.push_loop(num_vectors, loop);
pool.wait_for_tasks();
Expand All @@ -1078,6 +1094,8 @@ void check_performance() {
dual_println("Generating ", num_vectors, " vectors with ", vector_size, " elements each:");
for (Antares::concurrency_t n: try_tasks) {
for (size_t r = 0; r < repeat; ++r) {
// let the pool rest for a while before starting the next test
std::this_thread::sleep_for(std::chrono::milliseconds(500));
tmr.start();
if (n > 1) {
pool.push_loop(num_vectors, loop, n);
Expand Down

0 comments on commit 11f3d18

Please sign in to comment.