diff --git a/dorado/read_pipeline/BasecallerNode.cpp b/dorado/read_pipeline/BasecallerNode.cpp
index d5cfa02a3..4566e6c29 100644
--- a/dorado/read_pipeline/BasecallerNode.cpp
+++ b/dorado/read_pipeline/BasecallerNode.cpp
@@ -74,7 +74,7 @@ void BasecallerNode::input_worker_thread() {
         }
 
         for (auto &chunk : read_chunks) {
-            m_chunks_in->try_push(std::move(chunk));
+            m_chunks_in.try_push(std::move(chunk));
         }
 
         break;  // Go back to watching the input reads
@@ -82,7 +82,7 @@ void BasecallerNode::input_worker_thread() {
     }
 
     // Notify the basecaller threads that it is safe to gracefully terminate the basecaller
-    m_chunks_in->terminate();
+    m_chunks_in.terminate();
 }
 
 void BasecallerNode::basecall_current_batch(int worker_id) {
@@ -99,7 +99,7 @@ void BasecallerNode::basecall_current_batch(int worker_id) {
     }
 
     for (auto &complete_chunk : m_batched_chunks[worker_id]) {
-        m_processed_chunks->try_push(std::move(complete_chunk));
+        m_processed_chunks.try_push(std::move(complete_chunk));
     }
 
     m_batched_chunks[worker_id].clear();
@@ -108,7 +108,7 @@ void BasecallerNode::basecall_current_batch(int worker_id) {
 void BasecallerNode::working_reads_manager() {
     std::shared_ptr<Chunk> chunk;
-    while (m_processed_chunks->try_pop(chunk)) {
+    while (m_processed_chunks.try_pop(chunk)) {
         nvtx3::scoped_range loop{"working_reads_manager"};
 
         auto source_read = chunk->source_read.lock();
@@ -148,7 +148,7 @@ void BasecallerNode::basecall_worker_thread(int worker_id) {
     auto last_chunk_reserve_time = std::chrono::system_clock::now();
     int batch_size = m_model_runners[worker_id]->batch_size();
     std::shared_ptr<Chunk> chunk;
-    while (m_chunks_in->try_pop_until(
+    while (m_chunks_in.try_pop_until(
             chunk, last_chunk_reserve_time + std::chrono::milliseconds(m_batch_timeout_ms))) {
         // If chunk is empty, then try_pop timed out without getting a new chunk.
         if (!chunk) {
@@ -221,10 +221,25 @@ void BasecallerNode::basecall_worker_thread(int worker_id) {
         for (auto &runner : m_model_runners) {
             runner->terminate();
         }
-        m_processed_chunks->terminate();
+        m_processed_chunks.terminate();
     }
 }
 
+namespace {
+
+// Calculates the input queue size.
+size_t CalcMaxChunksIn(const std::vector<Runner> &model_runners) {
+    // Allow 5 batches per model runner on the chunks_in queue
+    size_t max_chunks_in = 0;
+    // Allows optimal batch size to be used for every GPU
+    for (auto &runner : model_runners) {
+        max_chunks_in += runner->batch_size() * 5;
+    }
+    return max_chunks_in;
+}
+
+}  // namespace
+
 BasecallerNode::BasecallerNode(std::vector<Runner> model_runners,
                                size_t overlap,
                                int batch_timeout_ms,
@@ -243,22 +258,15 @@ BasecallerNode::BasecallerNode(std::vector<Runner> model_runners,
           m_max_reads(max_reads),
           m_in_duplex_pipeline(in_duplex_pipeline),
           m_mean_qscore_start_pos(read_mean_qscore_start_pos),
+          m_chunks_in(CalcMaxChunksIn(m_model_runners)),
+          m_processed_chunks(CalcMaxChunksIn(m_model_runners)),
           m_node_name(node_name) {
     // Setup worker state
-    size_t const num_workers = m_model_runners.size();
+    const size_t num_workers = m_model_runners.size();
     m_batched_chunks.resize(num_workers);
     m_basecall_workers.resize(num_workers);
     m_num_active_model_runners = num_workers;
 
-    // Allow 5 batches per model runner on the chunks_in queue
-    size_t max_chunks_in = 0;
-    // Allows optimal batch size to be used for every GPU
-    for (auto &runner : m_model_runners) {
-        max_chunks_in += runner->batch_size() * 5;
-    }
-    m_chunks_in = std::make_unique<AsyncQueue<std::shared_ptr<Chunk>>>(max_chunks_in);
-    m_processed_chunks = std::make_unique<AsyncQueue<std::shared_ptr<Chunk>>>(max_chunks_in);
-
     initialization_time = std::chrono::system_clock::now();
 
     // Spin up any workers last so that we're not mutating |this| underneath them
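The BasecallerNode changes above drop a level of indirection: the two queues become direct AsyncQueue data members instead of unique_ptrs, so they are sized in the constructor's initialiser list via the new CalcMaxChunksIn() helper and no worker thread can ever observe a null or half-initialised queue pointer. A minimal sketch of the resulting ownership pattern, assuming the AsyncQueue interface in dorado/utils/AsyncQueue.h (ExampleNode and calc_max_chunks_in are illustrative names, not dorado code):

#include "utils/AsyncQueue.h"

#include <cstddef>
#include <vector>

// Mirrors the sizing rule in CalcMaxChunksIn(): allow five batches per
// runner so every GPU can keep batches of its own optimal size in flight.
size_t calc_max_chunks_in(const std::vector<size_t>& runner_batch_sizes) {
    size_t max_chunks_in = 0;
    for (size_t batch_size : runner_batch_sizes) {
        max_chunks_in += batch_size * 5;
    }
    return max_chunks_in;
}

struct ExampleNode {
    // A value member is constructed in the initialiser list; there is no
    // null state to check for and no separate heap allocation.
    AsyncQueue<int> m_chunks_in;

    explicit ExampleNode(const std::vector<size_t>& runner_batch_sizes)
            : m_chunks_in(calc_max_chunks_in(runner_batch_sizes)) {}
};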
diff --git a/dorado/read_pipeline/BasecallerNode.h b/dorado/read_pipeline/BasecallerNode.h
index 94840f2c3..d7982607b 100644
--- a/dorado/read_pipeline/BasecallerNode.h
+++ b/dorado/read_pipeline/BasecallerNode.h
@@ -65,7 +65,7 @@ class BasecallerNode : public MessageSink {
     // Time when Basecaller Node terminates. Used for benchmarking and debugging
     std::chrono::time_point<std::chrono::system_clock> termination_time;
 
     // Async queue to keep track of basecalling chunks.
-    std::unique_ptr<AsyncQueue<std::shared_ptr<Chunk>>> m_chunks_in;
+    AsyncQueue<std::shared_ptr<Chunk>> m_chunks_in;
 
     std::mutex m_working_reads_mutex;
     // Reads removed from input queue and being basecalled.
@@ -74,7 +74,7 @@ class BasecallerNode : public MessageSink {
     // If we go multi-threaded, there will be one of these batches per thread
     std::vector<std::vector<std::shared_ptr<Chunk>>> m_batched_chunks;
 
-    std::unique_ptr<AsyncQueue<std::shared_ptr<Chunk>>> m_processed_chunks;
+    AsyncQueue<std::shared_ptr<Chunk>> m_processed_chunks;
 
     // Class members are initialised in declaration order regardless of initialiser list order.
     // Class data members whose construction launches threads must therefore have their
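The comment retained at the bottom of BasecallerNode.h is why the new queue members are declared where they are: C++ constructs data members in declaration order, ignoring initialiser-list order, so a member whose constructor launches a thread must be declared after any state that thread touches. A small sketch of the rule, again assuming dorado's AsyncQueue interface (OrderingExample is an illustrative type, not dorado code):

#include "utils/AsyncQueue.h"

#include <thread>

struct OrderingExample {
    AsyncQueue<int> m_queue{16};   // constructed first, per declaration order...
    std::thread m_worker{[this] {  // ...so the worker may use it immediately
        int item;
        while (m_queue.try_pop(item)) {
            // process item
        }
    }};

    ~OrderingExample() {
        m_queue.terminate();  // unblocks try_pop(), which then returns false
        m_worker.join();
    }
};

Declaring m_worker above m_queue would start a thread that races against the queue's own construction.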
diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp
index e4f01fb5c..9a8d4d17d 100644
--- a/dorado/read_pipeline/ModBaseCallerNode.cpp
+++ b/dorado/read_pipeline/ModBaseCallerNode.cpp
@@ -28,7 +28,9 @@ ModBaseCallerNode::ModBaseCallerNode(std::vector<std::unique_ptr<ModBaseRunner>>
         : MessageSink(max_reads),
           m_batch_size(batch_size),
           m_block_stride(block_stride),
-          m_runners(std::move(model_runners)) {
+          m_runners(std::move(model_runners)),
+          // TODO -- more principled calculation of output queue size
+          m_processed_chunks(10 * max_reads) {
     init_modbase_info();
     m_output_worker = std::make_unique<std::thread>(&ModBaseCallerNode::output_worker_thread, this);
@@ -289,8 +291,7 @@ void ModBaseCallerNode::modbasecall_worker_thread(size_t worker_id, size_t caller_id) {
             for (auto& runner : m_runners) {
                 runner->terminate();
             }
-            m_terminate_output.store(true);
-            m_processed_chunks_cv.notify_one();
+            m_processed_chunks.terminate();
         }
         return;
     }
@@ -347,7 +348,6 @@ void ModBaseCallerNode::call_current_batch(
     assert(results_f32.is_contiguous());
     const auto* const results_f32_ptr = results_f32.data_ptr<float>();
 
-    std::unique_lock processed_chunks_lock(m_processed_chunks_mutex);
     auto row_size = results.size(1);
 
     // Put results into chunk
@@ -355,42 +355,35 @@ void ModBaseCallerNode::call_current_batch(
         auto& chunk = batched_chunks[i];
         chunk->scores.resize(row_size);
         std::memcpy(chunk->scores.data(), &results_f32_ptr[i * row_size], row_size * sizeof(float));
-        m_processed_chunks.push_back(chunk);
+        m_processed_chunks.try_push(std::move(chunk));
     }
 
-    processed_chunks_lock.unlock();
-    m_processed_chunks_cv.notify_one();
-
     batched_chunks.clear();
     ++m_num_batches_called;
 }
 
 void ModBaseCallerNode::output_worker_thread() {
-    while (true) {
+    // The m_processed_chunks lock is sufficiently contended that it's worth taking all
+    // chunks available once we obtain it.
+    std::vector<std::shared_ptr<RemoraChunk>> processed_chunks;
+    auto grab_chunk = [&processed_chunks](std::shared_ptr<RemoraChunk>& chunk) {
+        processed_chunks.push_back(std::move(chunk));
+    };
+    while (m_processed_chunks.process_and_pop_all(grab_chunk)) {
         nvtx3::scoped_range range{"modbase_output_worker_thread"};
 
-        // Wait until we are provided with a read
-        std::unique_lock processed_chunks_lock(m_processed_chunks_mutex);
-        m_processed_chunks_cv.wait(processed_chunks_lock, [this] {
-            return !m_processed_chunks.empty() || m_terminate_output.load();
-        });
-        if (m_terminate_output.load() && m_processed_chunks.empty()) {
-            return;
-        }
-
-        for (const auto& chunk : m_processed_chunks) {
+        for (const auto& chunk : processed_chunks) {
             auto source_read = chunk->source_read.lock();
 
             int64_t result_pos = chunk->context_hit;
             int64_t offset =
                     m_base_prob_offsets[RemoraUtils::BASE_IDS[source_read->seq[result_pos]]];
             for (size_t i = 0; i < chunk->scores.size(); ++i) {
                 source_read->base_mod_probs[m_num_states * result_pos + offset + i] =
-                        uint8_t(std::min(std::floor(chunk->scores[i] * 256), 255.0f));
+                        static_cast<uint8_t>(std::min(std::floor(chunk->scores[i] * 256), 255.0f));
             }
-            source_read->num_modbase_chunks_called += 1;
+            ++source_read->num_modbase_chunks_called;
         }
-
-        m_processed_chunks.clear();
-        processed_chunks_lock.unlock();
+        processed_chunks.clear();
 
         // Now move any completed reads to the output queue
         std::vector<std::shared_ptr<Read>> completed_reads;
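The rewritten output_worker_thread() is the first user of the new process_and_pop_all(): instead of paying one mutex acquisition per chunk, it moves every available chunk into a local vector under a single lock and does the scoring work unlocked. The general shape of that drain pattern, as a self-contained sketch (Item and handle are placeholders, not dorado types):

#include "utils/AsyncQueue.h"

#include <utility>
#include <vector>

using Item = int;
void handle(Item&) { /* per-item work */ }

void consumer(AsyncQueue<Item>& queue) {
    std::vector<Item> batch;
    auto grab = [&batch](Item& item) { batch.push_back(std::move(item)); };
    // process_and_pop_all() blocks until items arrive or terminate() is
    // called; it returns false only when terminating with an empty queue.
    while (queue.process_and_pop_all(grab)) {
        for (auto& item : batch) {
            handle(item);  // runs without holding the queue lock
        }
        batch.clear();
    }
}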
diff --git a/dorado/read_pipeline/ModBaseCallerNode.h b/dorado/read_pipeline/ModBaseCallerNode.h
index 3e48ab375..454de52db 100644
--- a/dorado/read_pipeline/ModBaseCallerNode.h
+++ b/dorado/read_pipeline/ModBaseCallerNode.h
@@ -1,11 +1,18 @@
 #pragma once
+
 #include "ReadPipeline.h"
+#include "utils/AsyncQueue.h"
 #include "utils/stats.h"
 
+#include <array>
 #include <atomic>
+#include <condition_variable>
+#include <deque>
 #include <memory>
 #include <mutex>
 #include <string>
+#include <thread>
+#include <unordered_map>
 #include <vector>
 
 namespace dorado {
@@ -72,7 +79,7 @@ class ModBaseCallerNode : public MessageSink {
     std::vector<std::unique_ptr<std::thread>> m_runner_workers;
     std::vector<std::unique_ptr<std::thread>> m_input_worker;
 
-    std::deque<std::shared_ptr<RemoraChunk>> m_processed_chunks;
+    AsyncQueue<std::shared_ptr<RemoraChunk>> m_processed_chunks;
     std::vector<std::deque<std::shared_ptr<RemoraChunk>>> m_chunk_queues;
 
     std::mutex m_working_reads_mutex;
@@ -83,9 +90,6 @@ class ModBaseCallerNode : public MessageSink {
     std::condition_variable m_chunk_queues_cv;
     std::condition_variable m_chunks_added_cv;
 
-    std::mutex m_processed_chunks_mutex;
-    std::condition_variable m_processed_chunks_cv;
-
     std::atomic<int> m_num_active_runner_workers{0};
     std::atomic<int> m_num_active_input_worker{0};
diff --git a/dorado/utils/AsyncQueue.h b/dorado/utils/AsyncQueue.h
index a95eb8a81..a2e3ee5ca 100644
--- a/dorado/utils/AsyncQueue.h
+++ b/dorado/utils/AsyncQueue.h
@@ -105,7 +105,7 @@ class AsyncQueue {
     // Otherwise we block if the queue is empty.
     bool try_pop(Item& item) {
         std::unique_lock lock(m_mutex);
-        // Wait until either an item is added, or we're asked to terminate.
+        // Wait until the queue is non-empty, or we're asked to terminate.
         m_not_empty_cv.wait(lock, [this] { return !m_items.empty() || m_terminate; });
 
         // Termination takes effect once all items have been popped from the queue.
@@ -124,6 +124,34 @@ class AsyncQueue {
         return true;
     }
 
+    // Obtains all items in the queue once the lock is obtained.
+    // Return value is false if we are terminating.
+    // If the lock is contended this could be more efficient than repeated
+    // calls to try_pop.
+    template <typename ProcessFn>
+    bool process_and_pop_all(ProcessFn process_fn) {
+        std::unique_lock lock(m_mutex);
+        // Wait until the queue is non-empty, or we're asked to terminate.
+        m_not_empty_cv.wait(lock, [this] { return !m_items.empty() || m_terminate; });
+
+        // Termination takes effect once all items have been popped from the queue.
+        if (m_terminate && m_items.empty()) {
+            return false;
+        }
+
+        while (!m_items.empty()) {
+            process_fn(m_items.front());
+            m_items.pop();
+            ++m_num_pops;
+        }
+
+        // Inform a waiting thread that the queue is not full.
+        lock.unlock();
+        m_not_full_cv.notify_one();
+
+        return true;
+    }
+
     // Tells the queue to terminate any CV waits.
     void terminate() {
         {
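process_and_pop_all() deliberately copies the termination contract of try_pop(): terminate() wakes waiters, but a pop only reports failure once the queue is also empty, so items queued before termination are still drained. A sketch of that behaviour, assuming this header's interface:

#include "utils/AsyncQueue.h"

void drain_after_terminate() {
    AsyncQueue<int> queue(4);
    queue.try_push(1);
    queue.try_push(2);
    queue.terminate();

    int item = 0;
    bool ok = queue.try_pop(item);  // true: item == 1, queue not yet empty
    ok = queue.try_pop(item);       // true: item == 2
    ok = queue.try_pop(item);       // false: terminating and now empty
    (void)ok;
}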
diff --git a/tests/AsyncQueueTest.cpp b/tests/AsyncQueueTest.cpp
index 222908add..9caa9b7b6 100644
--- a/tests/AsyncQueueTest.cpp
+++ b/tests/AsyncQueueTest.cpp
@@ -4,8 +4,10 @@
 
 #define TEST_GROUP "AsyncQueue "
 
+#include <numeric>
 #include <thread>
+#include <utility>
 #include <vector>
 
 TEST_CASE(TEST_GROUP ": InputsMatchOutputs") {
@@ -19,7 +21,8 @@ TEST_CASE(TEST_GROUP ": InputsMatchOutputs") {
     for (int i = 0; i < n; ++i) {
         int val = -1;
         const bool success = queue.try_pop(val);
-        REQUIRE(val == i);
+        REQUIRE(success);
+        CHECK(val == i);
     }
 }
 
@@ -27,7 +30,7 @@ TEST_CASE(TEST_GROUP ": PushFailsIfTerminating") {
     AsyncQueue<int> queue(1);
     queue.terminate();
     const bool success = queue.try_push(42);
-    REQUIRE(!success);
+    CHECK(!success);
 }
 
@@ -35,7 +38,7 @@ TEST_CASE(TEST_GROUP ": PopFailsIfTerminating") {
     queue.terminate();
     int val;
     const bool success = queue.try_pop(val);
-    REQUIRE(!success);
+    CHECK(!success);
 }
 
 // Spawned thread sits waiting for an item.
@@ -62,7 +65,7 @@ TEST_CASE(TEST_GROUP ": PopFromOtherThread") {
     REQUIRE(success);
 
     popping_thread.join();
-    REQUIRE(try_pop_result);
+    CHECK(try_pop_result);
 }
 
 // Spawned thread sits waiting for an item.
@@ -89,5 +92,23 @@ TEST_CASE(TEST_GROUP ": TerminateFromOtherThread") {
 
     popping_thread.join();
     // This will fail, since the wait is terminated.
-    REQUIRE(!try_pop_result);
+    CHECK(!try_pop_result);
+}
+
+TEST_CASE(TEST_GROUP ": process_and_pop_all") {
+    const int n = 10;
+    AsyncQueue<int> queue(n);
+    for (int i = 0; i < n; ++i) {
+        const bool success = queue.try_push(std::move(i));
+        REQUIRE(success);
+    }
+
+    std::vector<int> popped_items;
+    const bool success = queue.process_and_pop_all(
+            [&popped_items](int popped) { popped_items.push_back(popped); });
+    REQUIRE(success);
+
+    std::vector<int> expected(n);
+    std::iota(expected.begin(), expected.end(), 0);
+    CHECK(popped_items == expected);
+    CHECK(queue.size() == 0);
 }
\ No newline at end of file
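The new test covers the success path only; a natural companion, sketched below rather than taken from the patch, would pin down the termination behaviour of process_and_pop_all() in the same style as PopFailsIfTerminating:

TEST_CASE(TEST_GROUP ": ProcessAndPopAllFailsIfTerminating") {
    AsyncQueue<int> queue(1);
    queue.terminate();
    // Terminating with an empty queue should make the call return false.
    const bool success = queue.process_and_pop_all([](int) {});
    CHECK(!success);
}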