Skip to content

Commit

Permalink
Merge branch 'aligner' into 'master'
Browse files Browse the repository at this point in the history
Fix aligner regressions from pipeline change

See merge request machine-learning/dorado!470
  • Loading branch information
iiSeymour committed Jul 6, 2023
2 parents a57987f + 4d3a265 commit ddb6f71
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 9 deletions.
8 changes: 4 additions & 4 deletions dorado/cli/aligner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,10 @@ int aligner(int argc, char* argv[]) {
add_pg_hdr(header);

PipelineDescriptor pipeline_desc;
auto aligner = pipeline_desc.add_node<Aligner>({}, index, kmer_size, window_size,
index_batch_size, aligner_threads);
auto hts_writer = pipeline_desc.add_node<HtsWriter>({}, "-", HtsWriter::OutputMode::BAM,
writer_threads, 0);
pipeline_desc.add_node_sink(aligner, hts_writer);
auto aligner = pipeline_desc.add_node<Aligner>({hts_writer}, index, kmer_size, window_size,
index_batch_size, aligner_threads);

// Create the Pipeline from our description.
std::vector<dorado::stats::StatsReporter> stats_reporters;
Expand Down Expand Up @@ -146,7 +145,8 @@ int aligner(int argc, char* argv[]) {
tracker.summarize();

spdlog::info("> finished alignment");

spdlog::info("> total/primary/unmapped {}/{}/{}", hts_writer_ref.get_total(),
hts_writer_ref.get_primary(), hts_writer_ref.get_unmapped());
return 0;
}

Expand Down
3 changes: 1 addition & 2 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,10 @@ void setup(std::vector<std::string> args,
{read_converter}, min_qscore, default_parameters.min_sequence_length,
std::unordered_set<std::string>{}, thread_allocations.read_filter_threads);

auto mod_base_caller_node = PipelineDescriptor::InvalidNodeHandle;
auto basecaller_node_sink = read_filter_node;

if (!remora_runners.empty()) {
mod_base_caller_node = pipeline_desc.add_node<ModBaseCallerNode>(
auto mod_base_caller_node = pipeline_desc.add_node<ModBaseCallerNode>(
{read_filter_node}, std::move(remora_runners),
thread_allocations.remora_threads * num_devices, model_stride, remora_batch_size);
basecaller_node_sink = mod_base_caller_node;
Expand Down
1 change: 0 additions & 1 deletion dorado/read_pipeline/HtsWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ HtsWriter::~HtsWriter() {
terminate_impl();
sam_hdr_destroy(m_header);
hts_close(m_file);
spdlog::info("> total/primary/unmapped {}/{}/{}", m_total, m_primary, m_unmapped);
}

HtsWriter::OutputMode HtsWriter::get_output_mode(const std::string& mode) {
Expand Down
3 changes: 3 additions & 0 deletions dorado/read_pipeline/HtsWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class HtsWriter : public MessageSink {

int set_and_write_header(const sam_hdr_t* header);
static OutputMode get_output_mode(const std::string& mode);
// Returns the current value of m_total.
// NOTE(review): presumably the number of records handled by this writer in total - confirm.
size_t get_total() const { return m_total; }
// Returns the current value of m_primary.
// NOTE(review): presumably the number of primary alignments seen by this writer - confirm.
size_t get_primary() const { return m_primary; }
// Returns the current value of m_unmapped.
// NOTE(review): presumably the number of unmapped records seen by this writer - confirm.
size_t get_unmapped() const { return m_unmapped; }

private:
void terminate_impl();
Expand Down
3 changes: 2 additions & 1 deletion dorado/read_pipeline/ReadPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ void MessageSink::send_message_to_sink(int sink_index, Message &&message) {

// Feeds a message into the pipeline through its source node.
//
// The entry point is the node whose index is first in m_source_to_sink_order,
// not simply the last node held in m_nodes, so delivery does not depend on the
// order in which nodes were added when the pipeline was described.
//
// The message is forwarded exactly once: it is moved into the source node, so
// it must not be pushed anywhere else beforehand.
void Pipeline::push_message(Message &&message) {
    assert(!m_nodes.empty());
    // front() on an empty container is undefined behaviour; catch it in debug builds.
    assert(!m_source_to_sink_order.empty());
    const auto source_node_index = m_source_to_sink_order.front();
    // at() is used so that an out-of-range index fails loudly rather than silently.
    dynamic_cast<MessageSink &>(*m_nodes.at(source_node_index)).push_message(std::move(message));
}

stats::NamedStats Pipeline::terminate() {
Expand Down
35 changes: 34 additions & 1 deletion tests/PipelineTest.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "MessageSinkUtils.h"
#include "read_pipeline/NullNode.h"
#include "read_pipeline/ReadPipeline.h"

Expand Down Expand Up @@ -89,7 +90,7 @@ TEST_CASE("Creation", TEST_GROUP) {
}

// Tests destruction order of a random linear pipeline.
TEST_CASE("LinearDestructionOrder") {
TEST_CASE("LinearDestructionOrder", TEST_GROUP) {
// Node that records destruction order.
class OrderTestNode : public MessageSink {
public:
Expand Down Expand Up @@ -136,3 +137,35 @@ TEST_CASE("LinearDestructionOrder") {
// Verify that nodes were destroyed in source-to-sink order.
CHECK(std::equal(destruction_order.cbegin(), destruction_order.cend(), indices.cbegin()));
}

// Test inputs flow in the expected way from the source node.
// Checks that a message pushed into a pipeline enters at the source node,
// regardless of the order in which nodes were added to the descriptor.
TEST_CASE("PipelineFlow", TEST_GROUP) {
    // In both scenarios the source is a NullNode, which passes nothing on.
    // If the pushed message enters at that node first, the downstream sink
    // must end up having received no messages at all.
    {
        // Scenario 1 - natural construction order: the sink is added first,
        // then the source node that feeds it.
        PipelineDescriptor descriptor;
        std::vector<dorado::Message> received;
        auto collector = descriptor.add_node<MessageSinkToVector>({}, 100, received);
        descriptor.add_node<NullNode>({collector});
        auto built = dorado::Pipeline::create(std::move(descriptor));
        REQUIRE(built != nullptr);
        built->push_message(std::shared_ptr<dorado::Read>());
        // Release the pipeline before inspecting what the sink collected.
        built.reset();
        CHECK(received.size() == 0);
    }

    {
        // Scenario 2 - reverse construction order: the source is added first,
        // then the sink, and the two are connected explicitly afterwards.
        PipelineDescriptor descriptor;
        std::vector<dorado::Message> received;
        auto source = descriptor.add_node<NullNode>({});
        auto collector = descriptor.add_node<MessageSinkToVector>({}, 100, received);
        descriptor.add_node_sink(source, collector);
        auto built = dorado::Pipeline::create(std::move(descriptor));
        REQUIRE(built != nullptr);
        built->push_message(std::shared_ptr<dorado::Read>());
        // Release the pipeline before inspecting what the sink collected.
        built.reset();
        CHECK(received.size() == 0);
    }
}

0 comments on commit ddb6f71

Please sign in to comment.