Skip to content

Commit

Permalink
Use same enum to set traversal order for dataloader and pairing node
Browse files Browse the repository at this point in the history
  • Loading branch information
malton-ont committed Jun 30, 2023
1 parent a31f874 commit eb375aa
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 20 deletions.
5 changes: 3 additions & 2 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "utils/models.h"
#include "utils/parameters.h"
#include "utils/types.h"

#include <argparse.hpp>
#include <htslib/sam.h>
Expand Down Expand Up @@ -371,7 +372,7 @@ int duplex(int argc, char* argv[]) {

std::unique_ptr<PairingNode> pairing_node =
template_complement_map.empty()
? std::make_unique<PairingNode>(stereo_node)
? std::make_unique<PairingNode>(stereo_node, ReadOrder::BY_CHANNEL)
: std::make_unique<PairingNode>(stereo_node,
std::move(template_complement_map));

Expand Down Expand Up @@ -410,7 +411,7 @@ int duplex(int argc, char* argv[]) {
kStatsPeriod, stats_reporters, stats_callables);
// End stats counting setup.

loader.load_reads(reads, parser.get<bool>("--recursive"), DataLoader::BY_CHANNEL);
loader.load_reads(reads, parser.get<bool>("--recursive"), ReadOrder::BY_CHANNEL);
bam_writer->join(); // Explicitly wait for all output rows to be written.
stats_sampler->terminate();
}
Expand Down
8 changes: 4 additions & 4 deletions dorado/data_loader/DataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ void DataLoader::load_reads(const std::string& path,

auto iterate_directory = [&](const auto& iterator_fn) {
switch (traversal_order) {
case BY_CHANNEL:
case ReadOrder::BY_CHANNEL:
// If traversal in channel order is required, the following algorithm
// is used -
// 1. iterate through all the read metadata to collect channel information
Expand Down Expand Up @@ -199,7 +199,7 @@ void DataLoader::load_reads(const std::string& path,
}
}
break;
case UNRESTRICTED:
case ReadOrder::UNRESTRICTED:
for (const auto& entry : iterator_fn(path)) {
if (m_loaded_read_count == m_max_reads) {
break;
Expand All @@ -215,8 +215,8 @@ void DataLoader::load_reads(const std::string& path,
}
break;
default:
throw std::runtime_error("Unsupported traversal order detected " +
std::to_string(traversal_order));
throw std::runtime_error("Unsupported traversal order detected: " +
dorado::to_string(traversal_order));
}
};

Expand Down
8 changes: 2 additions & 6 deletions dorado/data_loader/DataLoader.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "utils/stats.h"
#include "utils/types.h"

#include <array>
#include <map>
Expand Down Expand Up @@ -28,11 +29,6 @@ using Pod5Ptr = std::unique_ptr<Pod5FileReader, Pod5Destructor>;

class DataLoader {
public:
enum ReadOrder {
UNRESTRICTED,
BY_CHANNEL,
};

DataLoader(MessageSink& read_sink,
const std::string& device,
size_t num_worker_threads,
Expand All @@ -42,7 +38,7 @@ class DataLoader {
~DataLoader() = default;
void load_reads(const std::string& path,
bool recursive_file_loading = false,
ReadOrder traversal_order = UNRESTRICTED);
ReadOrder traversal_order = ReadOrder::UNRESTRICTED);

static std::unordered_map<std::string, ReadGroup> load_read_groups(
std::string data_path,
Expand Down
10 changes: 8 additions & 2 deletions dorado/read_pipeline/PairingNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,16 @@ PairingNode::PairingNode(MessageSink& sink,
m_num_worker_threads(num_worker_threads),
m_max_num_keys(std::numeric_limits<size_t>::max()),
m_max_num_reads(std::numeric_limits<size_t>::max()) {
if (read_order == ReadOrder::pore_order) {
switch (read_order) {
case ReadOrder::BY_CHANNEL:
m_max_num_keys = 10;
} else if (read_order == ReadOrder::time_order) {
break;
case ReadOrder::BY_TIME:
m_max_num_reads = 10;
break;
default:
throw std::runtime_error("Unsupported read order detected: " +
dorado::to_string(read_order));
}

for (size_t i = 0; i < m_num_worker_threads; i++) {
Expand Down
5 changes: 2 additions & 3 deletions dorado/read_pipeline/PairingNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "ReadPipeline.h"
#include "utils/stats.h"
#include "utils/types.h"

#include <atomic>
#include <deque>
Expand All @@ -16,8 +17,6 @@ namespace dorado {

class PairingNode : public MessageSink {
public:
enum class ReadOrder { pore_order = 0, time_order };

// Template-complement map: uses the pair_list pairing method
PairingNode(MessageSink& sink,
std::map<std::string, std::string> template_complement_map,
Expand All @@ -26,7 +25,7 @@ class PairingNode : public MessageSink {

// No template-complement map: uses the pair_generation pairing method
PairingNode(MessageSink& sink,
ReadOrder read_order = ReadOrder::pore_order,
ReadOrder read_order,
int num_worker_threads = 2,
size_t max_reads = 1000);
~PairingNode();
Expand Down
15 changes: 15 additions & 0 deletions dorado/utils/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,19 @@ struct BamDestructor {
};
using BamPtr = std::unique_ptr<bam1_t, BamDestructor>;

// Order in which reads are traversed/processed. One shared enum so that the
// data loader's traversal order and the pairing node's read order are
// expressed with the same type.
enum class ReadOrder {
    UNRESTRICTED,
    BY_CHANNEL,
    BY_TIME,
};

// Converts a ReadOrder value to its enumerator name for use in log and error
// messages. Any value that does not match a known enumerator yields "Unknown".
inline std::string to_string(ReadOrder read_order) {
    if (read_order == ReadOrder::UNRESTRICTED) {
        return "UNRESTRICTED";
    }
    if (read_order == ReadOrder::BY_CHANNEL) {
        return "BY_CHANNEL";
    }
    if (read_order == ReadOrder::BY_TIME) {
        return "BY_TIME";
    }
    // Value outside the declared enumerators (e.g. produced by a cast).
    return "Unknown";
}

} // namespace dorado
2 changes: 1 addition & 1 deletion tests/DuplexSplitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ TEST_CASE("4 subread split tagging", TEST_GROUP) {
MessageSinkToVector<std::shared_ptr<dorado::Read>> sink(3);
dorado::SubreadTaggerNode tag_node(sink);
dorado::StereoDuplexEncoderNode stereo_node(tag_node, read->model_stride);
dorado::PairingNode pairing_node(stereo_node);
dorado::PairingNode pairing_node(stereo_node, dorado::ReadOrder::BY_CHANNEL);

dorado::DuplexSplitSettings splitter_settings;
dorado::DuplexSplitNode splitter_node(pairing_node, splitter_settings, 1);
Expand Down
2 changes: 1 addition & 1 deletion tests/PairingNodeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ TEST_CASE("Split read pairing", TEST_GROUP) {

MessageSinkToVector<dorado::Message> sink(5);
// one thread, one read - force reads through in order
dorado::PairingNode pairing_node(sink, dorado::PairingNode::ReadOrder::pore_order, 1, 1);
dorado::PairingNode pairing_node(sink, dorado::ReadOrder::BY_CHANNEL, 1, 1);
for (auto& read : reads) {
pairing_node.push_message(std::move(read));
}
Expand Down
2 changes: 1 addition & 1 deletion tests/Pod5DataLoaderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ TEST_CASE(TEST_GROUP "Load data sorted by channel id.") {

MessageSinkToVector<std::shared_ptr<dorado::Read>> sink(100);
dorado::DataLoader loader(sink, "cpu", 1, 0);
loader.load_reads(data_path, true, dorado::DataLoader::ReadOrder::BY_CHANNEL);
loader.load_reads(data_path, true, dorado::ReadOrder::BY_CHANNEL);

auto reads = sink.get_messages();
int start_channel_id = -1;
Expand Down

0 comments on commit eb375aa

Please sign in to comment.