Skip to content

Commit

Permalink
Merge branch 'INSTX-6538-priorityqueue-fix-backport-0.8' into 'releas…
Browse files Browse the repository at this point in the history
…e-v0.8'

Backport INSTX-6538-priorityqueue-fix to 0.8 release

See merge request machine-learning/dorado!1202
  • Loading branch information
blawrence-ont committed Sep 19, 2024
2 parents acec121 + 5d40559 commit 922ab14
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 99 deletions.
4 changes: 2 additions & 2 deletions dorado/utils/concurrency/async_task_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ AsyncTaskExecutor::~AsyncTaskExecutor() { flush(); }
void AsyncTaskExecutor::send_impl(TaskType task) {
increment_tasks_in_flight();

m_thread_pool_queue->push([task = std::move(task), this] {
m_thread_pool_queue.push([task = std::move(task), this] {
task();
decrement_tasks_in_flight();
});
Expand All @@ -23,7 +23,7 @@ std::unique_ptr<std::thread> AsyncTaskExecutor::send_async(TaskType task) {
increment_tasks_in_flight();

auto sending_thread = std::make_unique<std::thread>([this, task = std::move(task)]() mutable {
m_thread_pool_queue->push([task = std::move(task), this] {
m_thread_pool_queue.push([task = std::move(task), this] {
task();
decrement_tasks_in_flight();
});
Expand Down
6 changes: 3 additions & 3 deletions dorado/utils/concurrency/async_task_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ namespace dorado::utils::concurrency {
// to flush its tasks without waiting for other pipelines to
// also flush.
class AsyncTaskExecutor {
std::unique_ptr<MultiQueueThreadPool::ThreadPoolQueue> m_thread_pool_queue;
std::mutex m_mutex{};
std::condition_variable m_tasks_in_flight_cv{};
MultiQueueThreadPool::ThreadPoolQueue& m_thread_pool_queue;
std::mutex m_mutex;
std::condition_variable m_tasks_in_flight_cv;
std::size_t m_num_tasks_in_flight{};
std::unique_ptr<Latch> m_flushing_counter;
std::size_t m_max_tasks_in_flight;
Expand Down
35 changes: 14 additions & 21 deletions dorado/utils/concurrency/detail/priority_task_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace dorado::utils::concurrency::detail {

void PriorityTaskQueue::queue_producer_task(ProducerQueue* producer_queue) {
void PriorityTaskQueue::queue_producer_task(TaskQueue* producer_queue) {
m_producer_queue_list.push_back(producer_queue);
auto task_itr = std::prev(m_producer_queue_list.end());
if (producer_queue->priority() == TaskPriority::high) {
Expand All @@ -22,43 +22,37 @@ std::size_t PriorityTaskQueue::size(TaskPriority priority) const {
}

WaitingTask PriorityTaskQueue::pop() {
auto producer_queue_itr = m_producer_queue_list.begin();
TaskPriority popped_priority{TaskPriority::normal};
if (!m_low_producer_queue.empty() && producer_queue_itr == m_low_producer_queue.front()) {
m_low_producer_queue.pop();
} else {
popped_priority = TaskPriority::high;
m_high_producer_queue.pop();
}

WaitingTask result{(*producer_queue_itr)->pop(), popped_priority};
m_producer_queue_list.pop_front();
return result;
assert(!m_producer_queue_list.empty());
const auto next_priority = m_producer_queue_list.front()->priority();
return pop(next_priority);
}

WaitingTask PriorityTaskQueue::pop(TaskPriority priority) {
ProducerQueueList::iterator producer_queue_itr;
TaskQueueList::iterator producer_queue_itr;
if (priority == TaskPriority::high) {
assert(!m_high_producer_queue.empty());
producer_queue_itr = m_high_producer_queue.front();
m_high_producer_queue.pop();
} else {
assert(!m_low_producer_queue.empty());
producer_queue_itr = m_low_producer_queue.front();
m_low_producer_queue.pop();
}
assert(priority == (*producer_queue_itr)->priority());

WaitingTask result{(*producer_queue_itr)->pop(), priority};
m_producer_queue_list.pop_front();
m_producer_queue_list.erase(producer_queue_itr);
return result;
}

// Returns true when size() reports that no tasks are queued.
bool PriorityTaskQueue::empty() const {
    const auto total_queued = size();
    return total_queued == 0;
}

// Returns true when size(priority) reports that no tasks of the given
// priority are queued.
bool PriorityTaskQueue::empty(TaskPriority priority) const {
    const auto queued_at_priority = size(priority);
    return queued_at_priority == 0;
}

PriorityTaskQueue::ProducerQueue::ProducerQueue(PriorityTaskQueue* parent, TaskPriority priority)
PriorityTaskQueue::TaskQueue::TaskQueue(PriorityTaskQueue* parent, TaskPriority priority)
: m_parent(parent), m_priority(priority) {}

void PriorityTaskQueue::ProducerQueue::push(TaskType task) {
void PriorityTaskQueue::TaskQueue::push(TaskType task) {
m_producer_queue.push(std::move(task));
if (m_priority == TaskPriority::normal) {
++m_parent->m_num_normal_prio;
Expand All @@ -70,7 +64,7 @@ void PriorityTaskQueue::ProducerQueue::push(TaskType task) {
}
}

TaskType PriorityTaskQueue::ProducerQueue::pop() {
TaskType PriorityTaskQueue::TaskQueue::pop() {
assert(!m_producer_queue.empty() && "Cannot pop an empty producer queue.");
auto result = std::move(m_producer_queue.front());
m_producer_queue.pop();
Expand All @@ -86,8 +80,7 @@ TaskType PriorityTaskQueue::ProducerQueue::pop() {
}

PriorityTaskQueue::TaskQueue& PriorityTaskQueue::create_task_queue(TaskPriority priority) {
m_queue_repository.emplace_back(std::make_unique<ProducerQueue>(this, priority));
return *m_queue_repository.back();
return *m_queue_repository.emplace_back(new TaskQueue(this, priority));
}

} // namespace dorado::utils::concurrency::detail
} // namespace dorado::utils::concurrency::detail
48 changes: 24 additions & 24 deletions dorado/utils/concurrency/detail/priority_task_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ struct WaitingTask {
// The interface supports popping by priority.
class PriorityTaskQueue {
public:
class TaskQueue {
public:
virtual ~TaskQueue() = default;
virtual void push(TaskType task) = 0;
};
class TaskQueue;
TaskQueue& create_task_queue(TaskPriority priority);

WaitingTask pop();
Expand All @@ -44,35 +40,39 @@ class PriorityTaskQueue {
bool empty() const;
bool empty(TaskPriority priority) const;

private:
class ProducerQueue : public TaskQueue {
class TaskQueue {
friend class PriorityTaskQueue;

PriorityTaskQueue* m_parent;
TaskPriority m_priority;
std::queue<TaskType> m_producer_queue{};
std::queue<TaskType> m_producer_queue;

public:
ProducerQueue(PriorityTaskQueue* parent, TaskPriority priority);
TaskQueue(PriorityTaskQueue* parent, TaskPriority priority);
TaskType pop();

TaskPriority priority() const { return m_priority; };
TaskQueue(const TaskQueue&) = delete;
TaskQueue& operator=(const TaskQueue&) = delete;

void push(TaskType task) override;
TaskType pop();
public:
TaskPriority priority() const { return m_priority; }
void push(TaskType task);
};
std::vector<std::unique_ptr<ProducerQueue>>
m_queue_repository{}; // ownership of producer queues
using ProducerQueueList = std::list<ProducerQueue*>;
ProducerQueueList m_producer_queue_list{};
std::queue<ProducerQueueList::iterator> m_low_producer_queue{};
std::queue<ProducerQueueList::iterator> m_high_producer_queue{};

private:
std::vector<std::unique_ptr<TaskQueue>> m_queue_repository; // ownership of producer queues
using TaskQueueList = std::list<TaskQueue*>;
TaskQueueList m_producer_queue_list;
std::queue<TaskQueueList::iterator> m_low_producer_queue;
std::queue<TaskQueueList::iterator> m_high_producer_queue;
std::size_t m_num_normal_prio{};
std::size_t m_num_high_prio{};

using WaitingTaskList = std::list<std::shared_ptr<detail::WaitingTask>>;
WaitingTaskList m_task_list{};
std::queue<WaitingTaskList::iterator> m_low_queue{};
std::queue<WaitingTaskList::iterator> m_high_queue{};
WaitingTaskList m_task_list;
std::queue<WaitingTaskList::iterator> m_low_queue;
std::queue<WaitingTaskList::iterator> m_high_queue;

void queue_producer_task(ProducerQueue* producer_queue);
void queue_producer_task(TaskQueue* producer_queue);
};

} // namespace dorado::utils::concurrency::detail
} // namespace dorado::utils::concurrency::detail
39 changes: 16 additions & 23 deletions dorado/utils/concurrency/multi_queue_thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ MultiQueueThreadPool::~MultiQueueThreadPool() { join(); }

void MultiQueueThreadPool::join() {
// post done messages (m_num_threads * 2 of them, per the loop below) to make sure all waiting threads will receive a wakeup
auto& terminate_task_queue = m_priority_task_queue.create_task_queue(TaskPriority::normal);
detail::PriorityTaskQueue::TaskQueue* terminate_task_queue;
{
std::lock_guard lock(m_mutex);
terminate_task_queue = &m_priority_task_queue.create_task_queue(TaskPriority::normal);
}
for (uint32_t thread_index{0}; thread_index < m_num_threads * 2; ++thread_index) {
{
std::lock_guard lock(m_mutex);
terminate_task_queue.push([this] { m_done.store(true, std::memory_order_relaxed); });
terminate_task_queue->push([this] { m_done.store(true, std::memory_order_relaxed); });
}
m_message_received.notify_one();
}
Expand Down Expand Up @@ -135,31 +139,20 @@ void MultiQueueThreadPool::process_task_queue() {
}
}

namespace {

class ThreadPoolQueueImpl : public MultiQueueThreadPool::ThreadPoolQueue {
MultiQueueThreadPool* m_parent;
detail::PriorityTaskQueue::TaskQueue& m_task_queue;

public:
ThreadPoolQueueImpl(MultiQueueThreadPool* parent,
detail::PriorityTaskQueue::TaskQueue& task_queue);
void push(TaskType task) override;
};

ThreadPoolQueueImpl::ThreadPoolQueueImpl(MultiQueueThreadPool* parent,
detail::PriorityTaskQueue::TaskQueue& task_queue)
MultiQueueThreadPool::ThreadPoolQueue::ThreadPoolQueue(
MultiQueueThreadPool* parent,
detail::PriorityTaskQueue::TaskQueue& task_queue)
: m_parent(parent), m_task_queue(task_queue) {}

void ThreadPoolQueueImpl::push(TaskType task) { m_parent->send(std::move(task), m_task_queue); }

} // namespace
void MultiQueueThreadPool::ThreadPoolQueue::push(TaskType task) {
m_parent->send(std::move(task), m_task_queue);
}

std::unique_ptr<MultiQueueThreadPool::ThreadPoolQueue> MultiQueueThreadPool::create_task_queue(
MultiQueueThreadPool::ThreadPoolQueue& MultiQueueThreadPool::create_task_queue(
TaskPriority priority) {
std::lock_guard lock(m_mutex);
return std::make_unique<ThreadPoolQueueImpl>(this,
m_priority_task_queue.create_task_queue(priority));
auto& task_queue = m_priority_task_queue.create_task_queue(priority);
return *m_queues.emplace_back(new ThreadPoolQueue(this, task_queue));
}

} // namespace dorado::utils::concurrency
} // namespace dorado::utils::concurrency
31 changes: 22 additions & 9 deletions dorado/utils/concurrency/multi_queue_thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,27 +41,40 @@ class MultiQueueThreadPool {
MultiQueueThreadPool(std::size_t num_threads, std::string name);
~MultiQueueThreadPool();

void send(TaskType task, detail::PriorityTaskQueue::TaskQueue& task_queue);

void join();

class ThreadPoolQueue;
ThreadPoolQueue& create_task_queue(TaskPriority priority);

class ThreadPoolQueue {
friend class MultiQueueThreadPool;

MultiQueueThreadPool* m_parent;
detail::PriorityTaskQueue::TaskQueue& m_task_queue;

ThreadPoolQueue(MultiQueueThreadPool* parent,
detail::PriorityTaskQueue::TaskQueue& task_queue);

ThreadPoolQueue(const ThreadPoolQueue&) = delete;
ThreadPoolQueue& operator=(const ThreadPoolQueue&) = delete;

public:
virtual ~ThreadPoolQueue() = default;
virtual void push(TaskType task) = 0;
void push(TaskType task);
};
std::unique_ptr<ThreadPoolQueue> create_task_queue(TaskPriority priority);

private:
void send(TaskType task, detail::PriorityTaskQueue::TaskQueue& task_queue);

std::string m_name{"async_task_exec"};
const std::size_t m_num_threads;
const std::size_t m_num_expansion_low_prio_threads;
std::vector<std::thread> m_threads{};
std::vector<std::thread> m_threads;
std::atomic_bool m_done{false}; // Note that this flag is only accessed by the managed threads.

std::mutex m_mutex{};
detail::PriorityTaskQueue m_priority_task_queue{};
std::condition_variable m_message_received{};
std::mutex m_mutex;
std::vector<std::unique_ptr<ThreadPoolQueue>> m_queues;
detail::PriorityTaskQueue m_priority_task_queue;
std::condition_variable m_message_received;
std::size_t m_normal_prio_tasks_in_flight{};
std::size_t m_high_prio_tasks_in_flight{};

Expand Down
32 changes: 15 additions & 17 deletions tests/multi_queue_thread_pool_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,40 +83,38 @@ class NoQueueThreadPoolTestFixture {

} // namespace

DEFINE_TEST("create_task_queue returns non-null object") {
DEFINE_TEST("create_task_queue doesn't throw") {
MultiQueueThreadPool cut{1, "test_executor"};

auto task_queue = cut.create_task_queue(TaskPriority::normal);

REQUIRE(task_queue != nullptr);
CHECK_NOTHROW(cut.create_task_queue(TaskPriority::normal));
}

DEFINE_TEST("ThreadPoolQueue::push() with valid task_queue does not throw") {
MultiQueueThreadPool cut{2, "test_executor"};
auto task_queue = cut.create_task_queue(TaskPriority::normal);
auto& task_queue = cut.create_task_queue(TaskPriority::normal);

REQUIRE_NOTHROW(task_queue->push([] {}));
REQUIRE_NOTHROW(task_queue.push([] {}));
}

DEFINE_TEST("ThreadPoolQueue::push() with valid task_queue invokes the task") {
MultiQueueThreadPool cut{2, "test_executor"};
auto task_queue = cut.create_task_queue(TaskPriority::normal);
auto& task_queue = cut.create_task_queue(TaskPriority::normal);

Flag invoked{};
task_queue->push([&invoked] { invoked.signal(); });
task_queue.push([&invoked] { invoked.signal(); });

REQUIRE(invoked.wait_for(TIMEOUT));
}

DEFINE_TEST("ThreadPoolQueue::push() invokes task on separate thread") {
MultiQueueThreadPool cut{1, "test_executor"};
auto task_queue = cut.create_task_queue(TaskPriority::normal);
auto& task_queue = cut.create_task_queue(TaskPriority::normal);

Flag thread_id_assigned{};

auto invocation_thread{std::this_thread::get_id()};

task_queue->push([&thread_id_assigned, &invocation_thread] {
task_queue.push([&thread_id_assigned, &invocation_thread] {
invocation_thread = std::this_thread::get_id();
thread_id_assigned.signal();
});
Expand All @@ -128,10 +126,10 @@ DEFINE_TEST("ThreadPoolQueue::push() invokes task on separate thread") {
DEFINE_TEST("MultiQueueThreadPool::join() with 2 active threads completes") {
constexpr std::size_t num_threads{2};
MultiQueueThreadPool cut{num_threads, "test_executor"};
auto task_queue = cut.create_task_queue(TaskPriority::normal);
auto& task_queue = cut.create_task_queue(TaskPriority::normal);
Flag release_busy_tasks{};
Latch all_busy_tasks_started{num_threads};
auto producer_threads = create_producer_threads(*task_queue, num_threads,
auto producer_threads = create_producer_threads(task_queue, num_threads,
[&release_busy_tasks, &all_busy_tasks_started] {
all_busy_tasks_started.count_down();
release_busy_tasks.wait();
Expand Down Expand Up @@ -165,12 +163,12 @@ DEFINE_TEST_FIXTURE_METHOD(
"ThreadPoolQueue::push() high priority with pool size 2 and 2 busy normal tasks then high "
"priority is "
"invoked") {
auto normal_task_queue = cut->create_task_queue(TaskPriority::normal);
normal_task_queue->push(create_task(0));
normal_task_queue->push(create_task(1));
auto& normal_task_queue = cut->create_task_queue(TaskPriority::normal);
normal_task_queue.push(create_task(0));
normal_task_queue.push(create_task(1));

auto high_task_queue = cut->create_task_queue(TaskPriority::high);
high_task_queue->push(create_task(2));
auto& high_task_queue = cut->create_task_queue(TaskPriority::high);
high_task_queue.push(create_task(2));

REQUIRE(task_started_flags[2]->wait_for(TIMEOUT));
}
Expand Down
Loading

0 comments on commit 922ab14

Please sign in to comment.