From 5d40559f9b62768692dd563d565eecc6c35e5ca5 Mon Sep 17 00:00:00 2001 From: Mark Bicknell Date: Thu, 19 Sep 2024 09:33:43 +0000 Subject: [PATCH] Merge branch 'INSTX-6538-priorityqueue-fix' into 'master' [INSTX-6538] Fix crash in PriorityTaskQueue when popping a work item with an explicit priority Closes INSTX-6538 See merge request machine-learning/dorado!1201 --- .../utils/concurrency/async_task_executor.cpp | 4 +- .../utils/concurrency/async_task_executor.h | 6 +- .../detail/priority_task_queue.cpp | 35 ++++------ .../concurrency/detail/priority_task_queue.h | 48 +++++++------- .../concurrency/multi_queue_thread_pool.cpp | 39 +++++------ .../concurrency/multi_queue_thread_pool.h | 31 ++++++--- tests/multi_queue_thread_pool_test.cpp | 32 +++++----- tests/priority_task_queue_test.cpp | 64 +++++++++++++++++++ 8 files changed, 160 insertions(+), 99 deletions(-) diff --git a/dorado/utils/concurrency/async_task_executor.cpp b/dorado/utils/concurrency/async_task_executor.cpp index 8f693e29c..2d9432b86 100644 --- a/dorado/utils/concurrency/async_task_executor.cpp +++ b/dorado/utils/concurrency/async_task_executor.cpp @@ -13,7 +13,7 @@ AsyncTaskExecutor::~AsyncTaskExecutor() { flush(); } void AsyncTaskExecutor::send_impl(TaskType task) { increment_tasks_in_flight(); - m_thread_pool_queue->push([task = std::move(task), this] { + m_thread_pool_queue.push([task = std::move(task), this] { task(); decrement_tasks_in_flight(); }); @@ -23,7 +23,7 @@ std::unique_ptr AsyncTaskExecutor::send_async(TaskType task) { increment_tasks_in_flight(); auto sending_thread = std::make_unique([this, task = std::move(task)]() mutable { - m_thread_pool_queue->push([task = std::move(task), this] { + m_thread_pool_queue.push([task = std::move(task), this] { task(); decrement_tasks_in_flight(); }); diff --git a/dorado/utils/concurrency/async_task_executor.h b/dorado/utils/concurrency/async_task_executor.h index 4dd2bfe69..394bde47c 100644 --- a/dorado/utils/concurrency/async_task_executor.h +++ b/dorado/utils/concurrency/async_task_executor.h @@ -25,9 +25,9 @@ namespace dorado::utils::concurrency { // to flush it's tasks without waiting for other pipleines to // also flush. class AsyncTaskExecutor { - std::unique_ptr m_thread_pool_queue; - std::mutex m_mutex{}; - std::condition_variable m_tasks_in_flight_cv{}; + MultiQueueThreadPool::ThreadPoolQueue& m_thread_pool_queue; + std::mutex m_mutex; + std::condition_variable m_tasks_in_flight_cv; std::size_t m_num_tasks_in_flight{}; std::unique_ptr m_flushing_counter; std::size_t m_max_tasks_in_flight; diff --git a/dorado/utils/concurrency/detail/priority_task_queue.cpp b/dorado/utils/concurrency/detail/priority_task_queue.cpp index 6e7bb7570..8895f90db 100644 --- a/dorado/utils/concurrency/detail/priority_task_queue.cpp +++ b/dorado/utils/concurrency/detail/priority_task_queue.cpp @@ -5,7 +5,7 @@ namespace dorado::utils::concurrency::detail { -void PriorityTaskQueue::queue_producer_task(ProducerQueue* producer_queue) { +void PriorityTaskQueue::queue_producer_task(TaskQueue* producer_queue) { m_producer_queue_list.push_back(producer_queue); auto task_itr = std::prev(m_producer_queue_list.end()); if (producer_queue->priority() == TaskPriority::high) { @@ -22,32 +22,26 @@ std::size_t PriorityTaskQueue::size(TaskPriority priority) const { } WaitingTask PriorityTaskQueue::pop() { - auto producer_queue_itr = m_producer_queue_list.begin(); - TaskPriority popped_priority{TaskPriority::normal}; - if (!m_low_producer_queue.empty() && producer_queue_itr == m_low_producer_queue.front()) { - m_low_producer_queue.pop(); - } else { - popped_priority = TaskPriority::high; - m_high_producer_queue.pop(); - } - - WaitingTask result{(*producer_queue_itr)->pop(), popped_priority}; - m_producer_queue_list.pop_front(); - return result; + assert(!m_producer_queue_list.empty()); + const auto next_priority = m_producer_queue_list.front()->priority(); + return pop(next_priority); } WaitingTask PriorityTaskQueue::pop(TaskPriority priority) { - ProducerQueueList::iterator producer_queue_itr; + TaskQueueList::iterator producer_queue_itr; if (priority == TaskPriority::high) { + assert(!m_high_producer_queue.empty()); producer_queue_itr = m_high_producer_queue.front(); m_high_producer_queue.pop(); } else { + assert(!m_low_producer_queue.empty()); producer_queue_itr = m_low_producer_queue.front(); m_low_producer_queue.pop(); } + assert(priority == (*producer_queue_itr)->priority()); WaitingTask result{(*producer_queue_itr)->pop(), priority}; - m_producer_queue_list.pop_front(); + m_producer_queue_list.erase(producer_queue_itr); return result; } @@ -55,10 +49,10 @@ bool PriorityTaskQueue::empty() const { return size() == 0; } bool PriorityTaskQueue::empty(TaskPriority priority) const { return size(priority) == 0; } -PriorityTaskQueue::ProducerQueue::ProducerQueue(PriorityTaskQueue* parent, TaskPriority priority) +PriorityTaskQueue::TaskQueue::TaskQueue(PriorityTaskQueue* parent, TaskPriority priority) : m_parent(parent), m_priority(priority) {} -void PriorityTaskQueue::ProducerQueue::push(TaskType task) { +void PriorityTaskQueue::TaskQueue::push(TaskType task) { m_producer_queue.push(std::move(task)); if (m_priority == TaskPriority::normal) { ++m_parent->m_num_normal_prio; @@ -70,7 +64,7 @@ void PriorityTaskQueue::ProducerQueue::push(TaskType task) { } } -TaskType PriorityTaskQueue::ProducerQueue::pop() { +TaskType PriorityTaskQueue::TaskQueue::pop() { assert(!m_producer_queue.empty() && "Cannot pop an empty producer queue."); auto result = std::move(m_producer_queue.front()); m_producer_queue.pop(); @@ -86,8 +80,7 @@ TaskType PriorityTaskQueue::ProducerQueue::pop() { } PriorityTaskQueue::TaskQueue& PriorityTaskQueue::create_task_queue(TaskPriority priority) { - m_queue_repository.emplace_back(std::make_unique(this, priority)); - return *m_queue_repository.back(); + return *m_queue_repository.emplace_back(new TaskQueue(this, priority)); } -} // namespace dorado::utils::concurrency::detail \ No newline at end of file +} // namespace dorado::utils::concurrency::detail diff --git a/dorado/utils/concurrency/detail/priority_task_queue.h b/dorado/utils/concurrency/detail/priority_task_queue.h index 4624dc030..bee87bfc4 100644 --- a/dorado/utils/concurrency/detail/priority_task_queue.h +++ b/dorado/utils/concurrency/detail/priority_task_queue.h @@ -28,11 +28,7 @@ struct WaitingTask { // The interface supports popping by priority. class PriorityTaskQueue { public: - class TaskQueue { - public: - virtual ~TaskQueue() = default; - virtual void push(TaskType task) = 0; - }; + class TaskQueue; TaskQueue& create_task_queue(TaskPriority priority); WaitingTask pop(); @@ -44,35 +40,39 @@ class PriorityTaskQueue { bool empty() const; bool empty(TaskPriority priority) const; -private: - class ProducerQueue : public TaskQueue { + class TaskQueue { + friend class PriorityTaskQueue; + PriorityTaskQueue* m_parent; TaskPriority m_priority; - std::queue m_producer_queue{}; + std::queue m_producer_queue; - public: - ProducerQueue(PriorityTaskQueue* parent, TaskPriority priority); + TaskQueue(PriorityTaskQueue* parent, TaskPriority priority); + TaskType pop(); - TaskPriority priority() const { return m_priority; }; + TaskQueue(const TaskQueue&) = delete; + TaskQueue& operator=(const TaskQueue&) = delete; - void push(TaskType task) override; - TaskType pop(); + public: + TaskPriority priority() const { return m_priority; } + void push(TaskType task); }; - std::vector> - m_queue_repository{}; // ownership of producer queues - using ProducerQueueList = std::list; - ProducerQueueList m_producer_queue_list{}; - std::queue m_low_producer_queue{}; - std::queue m_high_producer_queue{}; + +private: + std::vector> m_queue_repository; // ownership of producer queues + using TaskQueueList = std::list; + TaskQueueList m_producer_queue_list; + std::queue m_low_producer_queue; + std::queue m_high_producer_queue; std::size_t m_num_normal_prio{}; std::size_t m_num_high_prio{}; using WaitingTaskList = std::list>; - WaitingTaskList m_task_list{}; - std::queue m_low_queue{}; - std::queue m_high_queue{}; + WaitingTaskList m_task_list; + std::queue m_low_queue; + std::queue m_high_queue; - void queue_producer_task(ProducerQueue* producer_queue); + void queue_producer_task(TaskQueue* producer_queue); }; -} // namespace dorado::utils::concurrency::detail \ No newline at end of file +} // namespace dorado::utils::concurrency::detail diff --git a/dorado/utils/concurrency/multi_queue_thread_pool.cpp b/dorado/utils/concurrency/multi_queue_thread_pool.cpp index f216f019f..e0f11e398 100644 --- a/dorado/utils/concurrency/multi_queue_thread_pool.cpp +++ b/dorado/utils/concurrency/multi_queue_thread_pool.cpp @@ -34,11 +34,15 @@ MultiQueueThreadPool::~MultiQueueThreadPool() { join(); } void MultiQueueThreadPool::join() { // post as many done messages as there are threads to make sure all waiting threads will receive a wakeup - auto& terminate_task_queue = m_priority_task_queue.create_task_queue(TaskPriority::normal); + detail::PriorityTaskQueue::TaskQueue* terminate_task_queue; + { + std::lock_guard lock(m_mutex); + terminate_task_queue = &m_priority_task_queue.create_task_queue(TaskPriority::normal); + } for (uint32_t thread_index{0}; thread_index < m_num_threads * 2; ++thread_index) { { std::lock_guard lock(m_mutex); - terminate_task_queue.push([this] { m_done.store(true, std::memory_order_relaxed); }); + terminate_task_queue->push([this] { m_done.store(true, std::memory_order_relaxed); }); } m_message_received.notify_one(); } @@ -135,31 +139,20 @@ void MultiQueueThreadPool::process_task_queue() { } } -namespace { - -class ThreadPoolQueueImpl : public MultiQueueThreadPool::ThreadPoolQueue { - MultiQueueThreadPool* m_parent; - detail::PriorityTaskQueue::TaskQueue& m_task_queue; - -public: - ThreadPoolQueueImpl(MultiQueueThreadPool* parent, - detail::PriorityTaskQueue::TaskQueue& task_queue); - void push(TaskType task) override; -}; - -ThreadPoolQueueImpl::ThreadPoolQueueImpl(MultiQueueThreadPool* parent, - detail::PriorityTaskQueue::TaskQueue& task_queue) +MultiQueueThreadPool::ThreadPoolQueue::ThreadPoolQueue( + MultiQueueThreadPool* parent, + detail::PriorityTaskQueue::TaskQueue& task_queue) : m_parent(parent), m_task_queue(task_queue) {} -void ThreadPoolQueueImpl::push(TaskType task) { m_parent->send(std::move(task), m_task_queue); } - -} // namespace +void MultiQueueThreadPool::ThreadPoolQueue::push(TaskType task) { + m_parent->send(std::move(task), m_task_queue); +} -std::unique_ptr MultiQueueThreadPool::create_task_queue( +MultiQueueThreadPool::ThreadPoolQueue& MultiQueueThreadPool::create_task_queue( TaskPriority priority) { std::lock_guard lock(m_mutex); - return std::make_unique(this, - m_priority_task_queue.create_task_queue(priority)); + auto& task_queue = m_priority_task_queue.create_task_queue(priority); + return *m_queues.emplace_back(new ThreadPoolQueue(this, task_queue)); } -} // namespace dorado::utils::concurrency \ No newline at end of file +} // namespace dorado::utils::concurrency diff --git a/dorado/utils/concurrency/multi_queue_thread_pool.h b/dorado/utils/concurrency/multi_queue_thread_pool.h index a5762558d..77dc86b44 100644 --- a/dorado/utils/concurrency/multi_queue_thread_pool.h +++ b/dorado/utils/concurrency/multi_queue_thread_pool.h @@ -41,27 +41,40 @@ class MultiQueueThreadPool { MultiQueueThreadPool(std::size_t num_threads, std::string name); ~MultiQueueThreadPool(); - void send(TaskType task, detail::PriorityTaskQueue::TaskQueue& task_queue); - void join(); + class ThreadPoolQueue; + ThreadPoolQueue& create_task_queue(TaskPriority priority); + class ThreadPoolQueue { + friend class MultiQueueThreadPool; + + MultiQueueThreadPool* m_parent; + detail::PriorityTaskQueue::TaskQueue& m_task_queue; + + ThreadPoolQueue(MultiQueueThreadPool* parent, + detail::PriorityTaskQueue::TaskQueue& task_queue); + + ThreadPoolQueue(const ThreadPoolQueue&) = delete; + ThreadPoolQueue& operator=(const ThreadPoolQueue&) = delete; + public: - virtual ~ThreadPoolQueue() = default; - virtual void push(TaskType task) = 0; + void push(TaskType task); }; - std::unique_ptr create_task_queue(TaskPriority priority); private: + void send(TaskType task, detail::PriorityTaskQueue::TaskQueue& task_queue); + std::string m_name{"async_task_exec"}; const std::size_t m_num_threads; const std::size_t m_num_expansion_low_prio_threads; - std::vector m_threads{}; + std::vector m_threads; std::atomic_bool m_done{false}; // Note that this flag is only accessed by the managed threads. - std::mutex m_mutex{}; - detail::PriorityTaskQueue m_priority_task_queue{}; - std::condition_variable m_message_received{}; + std::mutex m_mutex; + std::vector> m_queues; + detail::PriorityTaskQueue m_priority_task_queue; + std::condition_variable m_message_received; std::size_t m_normal_prio_tasks_in_flight{}; std::size_t m_high_prio_tasks_in_flight{}; diff --git a/tests/multi_queue_thread_pool_test.cpp b/tests/multi_queue_thread_pool_test.cpp index 9c6e7a8b7..5d49b8935 100644 --- a/tests/multi_queue_thread_pool_test.cpp +++ b/tests/multi_queue_thread_pool_test.cpp @@ -83,40 +83,38 @@ class NoQueueThreadPoolTestFixture { } // namespace -DEFINE_TEST("create_task_queue returns non-null object") { +DEFINE_TEST("create_task_queue doesn't throw") { MultiQueueThreadPool cut{1, "test_executor"}; - auto task_queue = cut.create_task_queue(TaskPriority::normal); - - REQUIRE(task_queue != nullptr); + CHECK_NOTHROW(cut.create_task_queue(TaskPriority::normal)); } DEFINE_TEST("ThreadPoolQueue::push() with valid task_queue does not throw") { MultiQueueThreadPool cut{2, "test_executor"}; - auto task_queue = cut.create_task_queue(TaskPriority::normal); + auto& task_queue = cut.create_task_queue(TaskPriority::normal); - REQUIRE_NOTHROW(task_queue->push([] {})); + REQUIRE_NOTHROW(task_queue.push([] {})); } DEFINE_TEST("ThreadPoolQueue::push() with valid task_queue invokes the task") { MultiQueueThreadPool cut{2, "test_executor"}; - auto task_queue = cut.create_task_queue(TaskPriority::normal); + auto& task_queue = cut.create_task_queue(TaskPriority::normal); Flag invoked{}; - task_queue->push([&invoked] { invoked.signal(); }); + task_queue.push([&invoked] { invoked.signal(); }); REQUIRE(invoked.wait_for(TIMEOUT)); } DEFINE_TEST("ThreadPoolQueue::push() invokes task on separate thread") { MultiQueueThreadPool cut{1, "test_executor"}; - auto task_queue = cut.create_task_queue(TaskPriority::normal); + auto& task_queue = cut.create_task_queue(TaskPriority::normal); Flag thread_id_assigned{}; auto invocation_thread{std::this_thread::get_id()}; - task_queue->push([&thread_id_assigned, &invocation_thread] { + task_queue.push([&thread_id_assigned, &invocation_thread] { invocation_thread = std::this_thread::get_id(); thread_id_assigned.signal(); }); @@ -128,10 +126,10 @@ DEFINE_TEST("ThreadPoolQueue::push() invokes task on separate thread") { DEFINE_TEST("MultiQueueThreadPool::join() with 2 active threads completes") { constexpr std::size_t num_threads{2}; MultiQueueThreadPool cut{num_threads, "test_executor"}; - auto task_queue = cut.create_task_queue(TaskPriority::normal); + auto& task_queue = cut.create_task_queue(TaskPriority::normal); Flag release_busy_tasks{}; Latch all_busy_tasks_started{num_threads}; - auto producer_threads = create_producer_threads(*task_queue, num_threads, + auto producer_threads = create_producer_threads(task_queue, num_threads, [&release_busy_tasks, &all_busy_tasks_started] { all_busy_tasks_started.count_down(); release_busy_tasks.wait(); @@ -165,12 +163,12 @@ DEFINE_TEST_FIXTURE_METHOD( "ThreadPoolQueue::push() high priority with pool size 2 and 2 busy normal tasks then high " "priority is " "invoked") { - auto normal_task_queue = cut->create_task_queue(TaskPriority::normal); - normal_task_queue->push(create_task(0)); - normal_task_queue->push(create_task(1)); + auto& normal_task_queue = cut->create_task_queue(TaskPriority::normal); + normal_task_queue.push(create_task(0)); + normal_task_queue.push(create_task(1)); - auto high_task_queue = cut->create_task_queue(TaskPriority::high); - high_task_queue->push(create_task(2)); + auto& high_task_queue = cut->create_task_queue(TaskPriority::high); + high_task_queue.push(create_task(2)); REQUIRE(task_started_flags[2]->wait_for(TIMEOUT)); } diff --git a/tests/priority_task_queue_test.cpp b/tests/priority_task_queue_test.cpp index 8d750e978..b85b68a6c 100644 --- a/tests/priority_task_queue_test.cpp +++ b/tests/priority_task_queue_test.cpp @@ -110,4 +110,68 @@ DEFINE_SCENARIO("prioritised pushing and popping with 2 high queues and one norm } } +DEFINE_SCENARIO("Popping tasks") { + GIVEN("A high and a normal priority queue") { + PriorityTaskQueue queue; + auto& normal_prio_queue = queue.create_task_queue(TaskPriority::normal); + auto& high_prio_queue = queue.create_task_queue(TaskPriority::high); + + WHEN("A task is pushed to each queue") { + // Check both orderings + const auto normal_first = GENERATE(true, false); + CAPTURE(normal_first); + const auto push_order = normal_first ? std::pair(&normal_prio_queue, &high_prio_queue) + : std::pair(&high_prio_queue, &normal_prio_queue); + push_order.first->push([] {}); + push_order.second->push([] {}); + + THEN("Queue sizes match") { + CHECK(queue.size() == 2); + CHECK(queue.size(TaskPriority::high) == 1); + CHECK(queue.size(TaskPriority::normal) == 1); + } + + THEN("Popping explicit priorities match their priority") { + // Check both orderings + const auto pop_order = + GENERATE(std::pair(TaskPriority::normal, TaskPriority::high), + std::pair(TaskPriority::high, TaskPriority::normal)); + CAPTURE(pop_order.first, pop_order.second); + + CHECK(queue.pop(pop_order.first).priority == pop_order.first); + CHECK(queue.size() == 1); + CHECK(queue.size(pop_order.first) == 0); + CHECK(queue.size(pop_order.second) == 1); + + CHECK(queue.pop(pop_order.second).priority == pop_order.second); + CHECK(queue.size() == 0); + CHECK(queue.size(pop_order.first) == 0); + CHECK(queue.size(pop_order.second) == 0); + } + + THEN("Popping 1 explicit priority and 1 arbitrary matches priorities") { + // Check both orderings + const auto priority = GENERATE(TaskPriority::normal, TaskPriority::high); + CAPTURE(priority); + + CHECK(queue.pop(priority).priority == priority); + CHECK(queue.size() == 1); + CHECK(queue.size(priority) == 0); + + CHECK(queue.pop().priority != priority); + CHECK(queue.size() == 0); + } + + THEN("Popping arbitrary tasks don't match each other") { + const auto first_priority = queue.pop().priority; + CHECK(queue.size() == 1); + CHECK(queue.size(first_priority) == 0); + + CHECK(queue.pop().priority != first_priority); + CHECK(queue.size() == 0); + } + } + } +} + } // namespace dorado::utils::concurrency::detail::priority_task_queue_test \ No newline at end of file