Skip to content

Commit

Permalink
Merge branch 'pipeline' into 'master'
Browse files Browse the repository at this point in the history
Pipeline management reworking

See merge request machine-learning/dorado!419
  • Loading branch information
tijyojwad committed Jul 6, 2023
2 parents 90c805a + a9610a0 commit ba40e53
Show file tree
Hide file tree
Showing 53 changed files with 1,060 additions and 767 deletions.
52 changes: 30 additions & 22 deletions dorado/cli/aligner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,43 +101,51 @@ int aligner(int argc, char* argv[]) {

spdlog::info("> loading index {}", index);

std::vector<dorado::stats::StatsCallable> stats_callables;
ProgressTracker tracker(0, false);
stats_callables.push_back(
[&tracker](const stats::NamedStats& stats) { tracker.update_progress_bar(stats); });

HtsWriter writer("-", HtsWriter::OutputMode::BAM, writer_threads, 0);
Aligner aligner(writer, index, kmer_size, window_size, index_batch_size, aligner_threads);
HtsReader reader(reads[0]);

spdlog::debug("> input fmt: {} aligned: {}", reader.format, reader.is_aligned);
auto header = sam_hdr_dup(reader.header);
add_pg_hdr(header);
utils::add_sq_hdr(header, aligner.get_sequence_records_for_header());
writer.write_header(header);

// Setup stats counting.
std::unique_ptr<dorado::stats::StatsSampler> stats_sampler;
PipelineDescriptor pipeline_desc;
auto aligner = pipeline_desc.add_node<Aligner>({}, index, kmer_size, window_size,
index_batch_size, aligner_threads);
auto hts_writer = pipeline_desc.add_node<HtsWriter>({}, "-", HtsWriter::OutputMode::BAM,
writer_threads, 0);
pipeline_desc.add_node_sink(aligner, hts_writer);

// Create the Pipeline from our description.
std::vector<dorado::stats::StatsReporter> stats_reporters;
using dorado::stats::make_stats_reporter;
stats_reporters.push_back(make_stats_reporter(writer));
stats_reporters.push_back(make_stats_reporter(aligner));
auto pipeline = Pipeline::create(std::move(pipeline_desc), &stats_reporters);
if (pipeline == nullptr) {
spdlog::error("Failed to create pipeline");
std::exit(EXIT_FAILURE);
}

// At present, output file header writing relies on direct node method calls
// rather than the pipeline framework.
const auto& aligner_ref = dynamic_cast<Aligner&>(pipeline->get_node_ref(aligner));
utils::add_sq_hdr(header, aligner_ref.get_sequence_records_for_header());
auto& hts_writer_ref = dynamic_cast<HtsWriter&>(pipeline->get_node_ref(hts_writer));
hts_writer_ref.set_and_write_header(header);

// Set up stats counting
std::vector<dorado::stats::StatsCallable> stats_callables;
ProgressTracker tracker(0, false);
stats_callables.push_back(
[&tracker](const stats::NamedStats& stats) { tracker.update_progress_bar(stats); });
constexpr auto kStatsPeriod = 100ms;
stats_sampler = std::make_unique<dorado::stats::StatsSampler>(kStatsPeriod, stats_reporters,
stats_callables);
// End stats counting setup.
auto stats_sampler = std::make_unique<dorado::stats::StatsSampler>(
kStatsPeriod, stats_reporters, stats_callables);

spdlog::info("> starting alignment");
reader.read(aligner, max_reads);
writer.join();
reader.read(*pipeline, max_reads);

stats_sampler->terminate();
auto final_stats = pipeline->terminate();
tracker.update_progress_bar(final_stats);
tracker.summarize();

spdlog::info("> finished alignment");
spdlog::info("> total/primary/unmapped {}/{}/{}", writer.total, writer.primary,
writer.unmapped);

return 0;
}
Expand Down
136 changes: 71 additions & 65 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,26 +122,67 @@ void setup(std::vector<std::string> args,
std::unique_ptr<sam_hdr_t, void (*)(sam_hdr_t*)> hdr(sam_hdr_init(), sam_hdr_destroy);
utils::add_pg_hdr(hdr.get(), args);
utils::add_rg_hdr(hdr.get(), read_groups);
std::shared_ptr<HtsWriter> bam_writer;
std::shared_ptr<Aligner> aligner;
MessageSink* converted_reads_sink = nullptr;

PipelineDescriptor pipeline_desc;
auto hts_writer = PipelineDescriptor::InvalidNodeHandle;
auto aligner = PipelineDescriptor::InvalidNodeHandle;
auto converted_reads_sink = PipelineDescriptor::InvalidNodeHandle;
std::unordered_set<std::string> reads_already_processed;
// TODO -- refactor to avoid repeated code here.
if (ref.empty()) {
bam_writer = std::make_shared<HtsWriter>("-", output_mode,
thread_allocations.writer_threads, num_reads);
bam_writer->write_header(hdr.get());
converted_reads_sink = bam_writer.get();
hts_writer = pipeline_desc.add_node<HtsWriter>(
{}, "-", output_mode, thread_allocations.writer_threads, num_reads);
converted_reads_sink = hts_writer;
} else {
bam_writer = std::make_shared<HtsWriter>("-", output_mode,
thread_allocations.writer_threads, num_reads);
aligner =
std::make_shared<Aligner>(*bam_writer, ref, kmer_size, window_size,
mm2_index_batch_size, thread_allocations.aligner_threads);
utils::add_sq_hdr(hdr.get(), aligner->get_sequence_records_for_header());
bam_writer->write_header(hdr.get());
converted_reads_sink = aligner.get();
hts_writer = pipeline_desc.add_node<HtsWriter>(
{}, "-", output_mode, thread_allocations.writer_threads, num_reads);
aligner = pipeline_desc.add_node<Aligner>({hts_writer}, ref, kmer_size, window_size,
mm2_index_batch_size,
thread_allocations.aligner_threads);
converted_reads_sink = aligner;
}
auto read_converter = pipeline_desc.add_node<ReadToBamType>(
{converted_reads_sink}, emit_moves, rna, thread_allocations.read_converter_threads,
methylation_threshold_pct);
auto read_filter_node = pipeline_desc.add_node<ReadFilterNode>(
{read_converter}, min_qscore, default_parameters.min_sequence_length,
std::unordered_set<std::string>{}, thread_allocations.read_filter_threads);

auto mod_base_caller_node = PipelineDescriptor::InvalidNodeHandle;
auto basecaller_node_sink = read_filter_node;

if (!remora_runners.empty()) {
mod_base_caller_node = pipeline_desc.add_node<ModBaseCallerNode>(
{read_filter_node}, std::move(remora_runners),
thread_allocations.remora_threads * num_devices, model_stride, remora_batch_size);
basecaller_node_sink = mod_base_caller_node;
}
const int kBatchTimeoutMS = 100;
auto basecaller_node = pipeline_desc.add_node<BasecallerNode>(
{basecaller_node_sink}, std::move(runners), overlap, kBatchTimeoutMS, model_name, 1000,
"BasecallerNode", false, get_model_mean_qscore_start_pos(model_config));

auto scaler_node =
pipeline_desc.add_node<ScalerNode>({basecaller_node}, model_config.signal_norm_params,
thread_allocations.scaler_node_threads);

// Create the Pipeline from our description.
std::vector<dorado::stats::StatsReporter> stats_reporters;
auto pipeline = Pipeline::create(std::move(pipeline_desc), &stats_reporters);
if (pipeline == nullptr) {
spdlog::error("Failed to create pipeline");
std::exit(EXIT_FAILURE);
}

// At present, output file header writing relies on direct node method calls
// rather than the pipeline framework.
auto& hts_writer_ref = dynamic_cast<HtsWriter&>(pipeline->get_node_ref(hts_writer));
if (!ref.empty()) {
const auto& aligner_ref = dynamic_cast<Aligner&>(pipeline->get_node_ref(aligner));
utils::add_sq_hdr(hdr.get(), aligner_ref.get_sequence_records_for_header());
}
hts_writer_ref.set_and_write_header(hdr.get());

std::unordered_set<std::string> reads_already_processed;
if (!resume_from_file.empty()) {
spdlog::info("> Inspecting resume file...");
// Turn off warning logging as header info is fetched.
Expand All @@ -157,69 +198,34 @@ void setup(std::vector<std::string> args,
"Resume only works if the same model is used. Resume model was " +
resume_model_name + " and current model is " + model_name);
}
ResumeLoaderNode resume_loader(*bam_writer, resume_from_file);
// Resume functionality injects reads directly into the writer node.
ResumeLoaderNode resume_loader(hts_writer_ref, resume_from_file);
resume_loader.copy_completed_reads();
reads_already_processed = resume_loader.get_processed_read_ids();
}

ReadToBamType read_converter(*converted_reads_sink, emit_moves, rna,
thread_allocations.read_converter_threads,
methylation_threshold_pct);
ReadFilterNode read_filter_node(read_converter, min_qscore,
default_parameters.min_seqeuence_length, {},
thread_allocations.read_filter_threads);

std::unique_ptr<ModBaseCallerNode> mod_base_caller_node;
MessageSink* basecaller_node_sink = static_cast<MessageSink*>(&read_filter_node);
if (!remora_runners.empty()) {
mod_base_caller_node = std::make_unique<ModBaseCallerNode>(
read_filter_node, std::move(remora_runners),
thread_allocations.remora_threads * num_devices, model_stride, remora_batch_size);
basecaller_node_sink = static_cast<MessageSink*>(mod_base_caller_node.get());
}
const int kBatchTimeoutMS = 100;
BasecallerNode basecaller_node(*basecaller_node_sink, std::move(runners), overlap,
kBatchTimeoutMS, model_name, 1000, "BasecallerNode", false,
get_model_mean_qscore_start_pos(model_config));
ScalerNode scaler_node(basecaller_node, model_config.signal_norm_params,
thread_allocations.scaler_node_threads);

DataLoader loader(scaler_node, "cpu", thread_allocations.loader_threads, max_reads, read_list,
reads_already_processed);

// Setup stats counting
std::unique_ptr<dorado::stats::StatsSampler> stats_sampler;
std::vector<dorado::stats::StatsReporter> stats_reporters;
using dorado::stats::make_stats_reporter;
stats_reporters.push_back(make_stats_reporter(basecaller_node));
if (mod_base_caller_node) {
stats_reporters.push_back(make_stats_reporter(*mod_base_caller_node));
}
if (aligner) {
stats_reporters.push_back(make_stats_reporter(*aligner));
}
stats_reporters.push_back(make_stats_reporter(*bam_writer));
stats_reporters.push_back(make_stats_reporter(loader));
stats_reporters.push_back(make_stats_reporter(scaler_node));
stats_reporters.push_back(make_stats_reporter(read_filter_node));

std::vector<dorado::stats::StatsCallable> stats_callables;
ProgressTracker tracker(num_reads, duplex);
stats_callables.push_back(
[&tracker](const stats::NamedStats& stats) { tracker.update_progress_bar(stats); });

constexpr auto kStatsPeriod = 100ms;
stats_sampler = std::make_unique<dorado::stats::StatsSampler>(kStatsPeriod, stats_reporters,
stats_callables);
// End stats counting setup.
auto stats_sampler = std::make_unique<dorado::stats::StatsSampler>(
kStatsPeriod, stats_reporters, stats_callables);

DataLoader loader(*pipeline, "cpu", thread_allocations.loader_threads, max_reads, read_list,
reads_already_processed);

// Run pipeline.
loader.load_reads(data_path, recursive_file_loading);

bam_writer->join();
// End pipeline

// Stop the stats sampler thread before tearing down any pipeline objects.
stats_sampler->terminate();

// Stop the pipeline, as we do so collecting final processing stats.
// Then update progress tracking one more time from this thread, to
// allow accurate summarisation.
auto final_stats = pipeline->terminate();
tracker.update_progress_bar(final_stats);
tracker.summarize();
if (!dump_stats_file.empty()) {
std::ofstream stats_file(dump_stats_file);
Expand Down
Loading

0 comments on commit ba40e53

Please sign in to comment.