Skip to content

Commit

Permalink
Merge branch 'smalton/DOR-222-pair-cache-strategy' into 'master'
Browse files Browse the repository at this point in the history
Add time ordered reads pair cache strategy

Closes DOR-222

See merge request machine-learning/dorado!459
  • Loading branch information
malton-ont committed Jul 4, 2023
2 parents d2700dd + eb375aa commit d953f33
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 77 deletions.
16 changes: 9 additions & 7 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "utils/log_utils.h"
#include "utils/models.h"
#include "utils/parameters.h"
#include "utils/types.h"

#include <argparse.hpp>
#include <htslib/sam.h>
Expand Down Expand Up @@ -303,18 +304,19 @@ int duplex(int argc, char* argv[]) {
StereoDuplexEncoderNode stereo_node =
StereoDuplexEncoderNode(*stereo_basecaller_node, simplex_model_stride);

PairingNode pairing_node(stereo_node,
template_complement_map.empty()
? std::optional<std::map<std::string, std::string>>{}
: template_complement_map);
std::unique_ptr<PairingNode> pairing_node =
template_complement_map.empty()
? std::make_unique<PairingNode>(stereo_node, ReadOrder::BY_CHANNEL)
: std::make_unique<PairingNode>(stereo_node,
std::move(template_complement_map));

// Initialize duplex split settings and create a duplex split node
// with the given settings and number of devices. If
// splitter_settings.enabled is set to false, the splitter node will
// act as a passthrough, meaning it won't perform any splitting
// operations and will just pass data through.
DuplexSplitSettings splitter_settings;
DuplexSplitNode splitter_node(pairing_node, splitter_settings, num_devices);
DuplexSplitNode splitter_node(*pairing_node, splitter_settings, num_devices);

auto adjusted_simplex_overlap = (overlap / simplex_model_stride) * simplex_model_stride;

Expand All @@ -333,7 +335,7 @@ int duplex(int argc, char* argv[]) {
using dorado::stats::make_stats_reporter;
stats_reporters.push_back(make_stats_reporter(*stereo_basecaller_node));
stats_reporters.push_back(make_stats_reporter(stereo_node));
stats_reporters.push_back(make_stats_reporter(pairing_node));
stats_reporters.push_back(make_stats_reporter(*pairing_node));
stats_reporters.push_back(make_stats_reporter(splitter_node));
stats_reporters.push_back(make_stats_reporter(*basecaller_node));
stats_reporters.push_back(make_stats_reporter(loader));
Expand All @@ -344,7 +346,7 @@ int duplex(int argc, char* argv[]) {
kStatsPeriod, stats_reporters, stats_callables);
// End stats counting setup.

loader.load_reads(reads, parser.get<bool>("--recursive"), DataLoader::BY_CHANNEL);
loader.load_reads(reads, parser.get<bool>("--recursive"), ReadOrder::BY_CHANNEL);
bam_writer->join(); // Explicitly wait for all output rows to be written.
stats_sampler->terminate();
}
Expand Down
8 changes: 4 additions & 4 deletions dorado/data_loader/DataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ void DataLoader::load_reads(const std::string& path,

auto iterate_directory = [&](const auto& iterator_fn) {
switch (traversal_order) {
case BY_CHANNEL:
case ReadOrder::BY_CHANNEL:
// If traversal in channel order is required, the following algorithm
// is used -
// 1. iterate through all the read metadata to collect channel information
Expand Down Expand Up @@ -199,7 +199,7 @@ void DataLoader::load_reads(const std::string& path,
}
}
break;
case UNRESTRICTED:
case ReadOrder::UNRESTRICTED:
for (const auto& entry : iterator_fn(path)) {
if (m_loaded_read_count == m_max_reads) {
break;
Expand All @@ -215,8 +215,8 @@ void DataLoader::load_reads(const std::string& path,
}
break;
default:
throw std::runtime_error("Unsupported traversal order detected " +
std::to_string(traversal_order));
throw std::runtime_error("Unsupported traversal order detected: " +
dorado::to_string(traversal_order));
}
};

Expand Down
8 changes: 2 additions & 6 deletions dorado/data_loader/DataLoader.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "utils/stats.h"
#include "utils/types.h"

#include <array>
#include <map>
Expand Down Expand Up @@ -28,11 +29,6 @@ using Pod5Ptr = std::unique_ptr<Pod5FileReader, Pod5Destructor>;

class DataLoader {
public:
enum ReadOrder {
UNRESTRICTED,
BY_CHANNEL,
};

DataLoader(MessageSink& read_sink,
const std::string& device,
size_t num_worker_threads,
Expand All @@ -42,7 +38,7 @@ class DataLoader {
~DataLoader() = default;
void load_reads(const std::string& path,
bool recursive_file_loading = false,
ReadOrder traversal_order = UNRESTRICTED);
ReadOrder traversal_order = ReadOrder::UNRESTRICTED);

static std::unordered_map<std::string, ReadGroup> load_read_groups(
std::string data_path,
Expand Down
118 changes: 70 additions & 48 deletions dorado/read_pipeline/PairingNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ void PairingNode::pair_list_worker_thread() {
partner_found = true;
} else {
{
tc_lock.unlock();
std::lock_guard<std::mutex> ct_lock(m_ct_map_mutex);
auto it = m_complement_template_map.find(read->read_id);
if (it != m_complement_template_map.end()) {
Expand All @@ -47,14 +48,14 @@ void PairingNode::pair_list_worker_thread() {

if (partner_found) {
std::unique_lock<std::mutex> read_cache_lock(m_read_cache_mutex);
if (read_cache.find(partner_id) == read_cache.end()) {
auto partner_read_itr = m_read_cache.find(partner_id);
if (partner_read_itr == m_read_cache.end()) {
// Partner is not in the read cache
read_cache[read->read_id] = read;
m_read_cache.insert({read->read_id, read});
read_cache_lock.unlock();
} else {
auto partner_read_itr = read_cache.find(partner_id);
auto partner_read = partner_read_itr->second;
read_cache.erase(partner_read_itr);
m_read_cache.erase(partner_read_itr);
read_cache_lock.unlock();

std::shared_ptr<Read> template_read;
Expand Down Expand Up @@ -84,6 +85,11 @@ void PairingNode::pair_list_worker_thread() {
}

void PairingNode::pair_generating_worker_thread() {
auto compare_reads_by_time = [](const std::shared_ptr<Read>& read1,
const std::shared_ptr<Read>& read2) {
return read1->start_time_ms < read2->start_time_ms;
};

Message message;
while (m_work_queue.try_pop(message)) {
// If this message isn't a read, we'll get a bad_variant_access exception.
Expand All @@ -95,43 +101,36 @@ void PairingNode::pair_generating_worker_thread() {
std::string flowcell_id = read->flowcell_id;
int32_t client_id = read->client_id;

int max_num_keys = 10;
std::unique_lock<std::mutex> lock(m_pairing_mtx);
UniquePoreIdentifierKey key = std::make_tuple(channel, mux, run_id, flowcell_id, client_id);
auto found = channel_mux_read_map.find(key);
auto read_list_iter = m_channel_mux_read_map.find(key);
// Check if the key is already in the list
if (found == channel_mux_read_map.end()) {
if (read_list_iter == m_channel_mux_read_map.end()) {
// Key is not in the dequeue
// Add the new key to the end of the list
m_working_channel_mux_keys.push_back(key);
m_channel_mux_read_map.insert({key, {read}});

if (m_working_channel_mux_keys.size() >= max_num_keys) {
if (m_working_channel_mux_keys.size() > m_max_num_keys) {
// Remove the oldest key (front of the list)
auto oldest_key = m_working_channel_mux_keys.front();
m_working_channel_mux_keys.pop_front();

auto oldest_key_it = channel_mux_read_map.find(oldest_key);
auto oldest_key_it = m_channel_mux_read_map.find(oldest_key);

// Remove the oldest key from the map
for (auto read_ptr : oldest_key_it->second) {
m_sink.push_message(read_ptr);
}
channel_mux_read_map.erase(oldest_key);
assert(channel_mux_read_map.size() == m_working_channel_mux_keys.size());
m_channel_mux_read_map.erase(oldest_key);
assert(m_channel_mux_read_map.size() == m_working_channel_mux_keys.size());
}
// Add the new key to the end of the list
m_working_channel_mux_keys.push_back(key);
}

auto compare_reads_by_time = [](const std::shared_ptr<Read>& read1,
const std::shared_ptr<Read>& read2) {
return read1->attributes.start_time < read2->attributes.start_time;
};

if (channel_mux_read_map.count(key)) {
auto later_read =
std::lower_bound(channel_mux_read_map[key].begin(),
channel_mux_read_map[key].end(), read, compare_reads_by_time);
} else {
auto& cached_read_list = read_list_iter->second;
auto later_read = std::lower_bound(cached_read_list.begin(), cached_read_list.end(),
read, compare_reads_by_time);

if (later_read != channel_mux_read_map[key].begin()) {
if (later_read != cached_read_list.begin()) {
auto earlier_read = std::prev(later_read);

if (is_within_time_and_length_criteria(*earlier_read, read)) {
Expand All @@ -141,25 +140,26 @@ void PairingNode::pair_generating_worker_thread() {
}
}

if (later_read != channel_mux_read_map[key].end()) {
if (later_read != cached_read_list.end()) {
if (is_within_time_and_length_criteria(read, *later_read)) {
ReadPair pair = {read, *later_read};
++read->num_duplex_candidate_pairs;
m_sink.push_message(std::make_shared<ReadPair>(pair));
}
}

channel_mux_read_map[key].insert(later_read, read);

} else {
channel_mux_read_map[key].push_back(read);
cached_read_list.insert(later_read, read);
while (cached_read_list.size() > m_max_num_reads) {
cached_read_list.pop_front();
}
}
}

if (--m_num_worker_threads == 0) {
std::unique_lock<std::mutex> lock(m_pairing_mtx);
// There are still reads in channel_mux_read_map. Push them to the sink.
// Last thread alive is responsible for cleaning up the cache.
for (const auto& kv : channel_mux_read_map) {
for (const auto& kv : m_channel_mux_read_map) {
// kv is a std::pair<UniquePoreIdentifierKey, std::list<std::shared_ptr<Read>>>
const auto& reads_list = kv.second;

Expand All @@ -174,26 +174,48 @@ void PairingNode::pair_generating_worker_thread() {
}

PairingNode::PairingNode(MessageSink& sink,
std::optional<std::map<std::string, std::string>> template_complement_map,
std::map<std::string, std::string> template_complement_map,
int num_worker_threads,
size_t max_reads)
: MessageSink(max_reads), m_sink(sink), m_num_worker_threads(num_worker_threads) {
if (template_complement_map.has_value()) {
m_template_complement_map = template_complement_map.value();
// Set up the complement-template_map
for (auto& key : m_template_complement_map) {
m_complement_template_map[key.second] = key.first;
}
: MessageSink(max_reads),
m_sink(sink),
m_num_worker_threads(num_worker_threads),
m_template_complement_map(std::move(template_complement_map)) {
// Set up the complement-template_map
for (auto& key : m_template_complement_map) {
m_complement_template_map[key.second] = key.first;
}

for (size_t i = 0; i < m_num_worker_threads; i++) {
m_workers.push_back(std::make_unique<std::thread>(
std::thread(&PairingNode::pair_list_worker_thread, this)));
}
} else {
for (size_t i = 0; i < m_num_worker_threads; i++) {
m_workers.push_back(std::make_unique<std::thread>(
std::thread(&PairingNode::pair_generating_worker_thread, this)));
}
for (size_t i = 0; i < m_num_worker_threads; i++) {
m_workers.push_back(std::make_unique<std::thread>(
std::thread(&PairingNode::pair_list_worker_thread, this)));
}
}

PairingNode::PairingNode(MessageSink& sink,
ReadOrder read_order,
int num_worker_threads,
size_t max_reads)
: MessageSink(max_reads),
m_sink(sink),
m_num_worker_threads(num_worker_threads),
m_max_num_keys(std::numeric_limits<size_t>::max()),
m_max_num_reads(std::numeric_limits<size_t>::max()) {
switch (read_order) {
case ReadOrder::BY_CHANNEL:
m_max_num_keys = 10;
break;
case ReadOrder::BY_TIME:
m_max_num_reads = 10;
break;
default:
throw std::runtime_error("Unsupported read order detected: " +
dorado::to_string(read_order));
}

for (size_t i = 0; i < m_num_worker_threads; i++) {
m_workers.push_back(std::make_unique<std::thread>(
std::thread(&PairingNode::pair_generating_worker_thread, this)));
}
}

Expand Down
56 changes: 48 additions & 8 deletions dorado/read_pipeline/PairingNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "ReadPipeline.h"
#include "utils/stats.h"
#include "utils/types.h"

#include <atomic>
#include <deque>
Expand All @@ -16,40 +17,79 @@ namespace dorado {

class PairingNode : public MessageSink {
public:
// Template-complement map: uses the pair_list pairing method
PairingNode(MessageSink& sink,
std::optional<std::map<std::string, std::string>> = std::nullopt,
std::map<std::string, std::string> template_complement_map,
int num_worker_threads = 2,
size_t max_reads = 1000);

// No template-complement map: uses the pair_generation pairing method
PairingNode(MessageSink& sink,
ReadOrder read_order,
int num_worker_threads = 2,
size_t max_reads = 1000);
~PairingNode();
std::string get_name() const override { return "PairingNode"; }
stats::NamedStats sample_stats() const override;

private:
/**
* This is a worker thread function for pairing reads based on a specified list of template-complement pairs.
*/
void pair_list_worker_thread();

/**
* This is a worker thread function for generating pairs of reads that fall within pairing criteria.
*
* The function goes through the incoming messages, which are expected to be reads. For each read, it finds its pore
* in the list of active pores. If the pore isn't in the list yet, it is added. If the list of active pores has reached
* its maximum size (m_max_num_keys), the oldest pore is removed from the list, and its associated reads are discarded.
* The function then inserts the new read into the sorted list of reads for its pore, and checks if it can be paired
* with the reads immediately before and after it in the list. If the list of reads for a pore has reached its maximum
* size (m_max_num_reads), the oldest read is removed from the list.
*/
void pair_generating_worker_thread();

// A key for a unique Pore, Duplex reads must have the same UniquePoreIdentifierKey
// The values are channel, mux, run_id, flowcell_id, client_id
using UniquePoreIdentifierKey = std::tuple<int, int, std::string, std::string, int32_t>;

std::vector<std::unique_ptr<std::thread>> m_workers;
MessageSink& m_sink;
std::map<std::string, std::string> m_template_complement_map;
std::map<std::string, std::string> m_complement_template_map;
std::vector<std::unique_ptr<std::thread>> m_workers;
std::atomic<int> m_num_worker_threads;

// Members for pair_list method

std::mutex m_tc_map_mutex;
std::mutex m_ct_map_mutex;
std::mutex m_read_cache_mutex;

std::atomic<int> m_num_worker_threads;
std::map<std::string, std::string> m_template_complement_map;
std::map<std::string, std::string> m_complement_template_map;
std::map<std::string, std::shared_ptr<Read>> m_read_cache;

std::map<std::string, std::shared_ptr<Read>> read_cache;
// Members for pair_generating method

std::map<UniquePoreIdentifierKey, std::list<std::shared_ptr<Read>>> channel_mux_read_map;
std::mutex m_pairing_mtx;

std::map<UniquePoreIdentifierKey, std::list<std::shared_ptr<Read>>> m_channel_mux_read_map;
std::deque<UniquePoreIdentifierKey> m_working_channel_mux_keys;

std::mutex m_pairing_mtx;
/**
* The maximum number of different channels (pores) to keep in memory concurrently.
* This parameter is crucial when reads are expected to be delivered in channel/pore order. In this order,
* once a read from a specific pore is processed, it is guaranteed that no other reads from that pore will appear.
* Thus, the function can limit memory usage by only keeping reads from a fixed number of pores (channels) in memory.
*/
size_t m_max_num_keys;

/**
* The maximum number of reads from a specific pore to keep in memory. This parameter is
* crucial when reads are expected to be delivered in time order. In this order, reads from the same pore could
* appear at any point in the stream. Thus, the function keeps a limited history of reads for each pore in memory.
* It ensures that the memory usage is controlled, while the reads needed for pairing are available.
*/
size_t m_max_num_reads;
};

} // namespace dorado
Loading

0 comments on commit d953f33

Please sign in to comment.