Skip to content

Commit

Permalink
Merge branch 'smalton/DOR-750-per-barcode-polya-config' into 'master'
Browse files Browse the repository at this point in the history
[DOR-750] Per-barcode polyA config overrides

Closes DOR-750

See merge request machine-learning/dorado!1165
  • Loading branch information
malton-ont committed Sep 13, 2024
2 parents df861db + ecb2682 commit 6646701
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 51 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ add_library(dorado_lib
dorado/poly_tail/plasmid_poly_tail_calculator.h
dorado/poly_tail/poly_tail_calculator.cpp
dorado/poly_tail/poly_tail_calculator.h
dorado/poly_tail/poly_tail_calculator_selector.cpp
dorado/poly_tail/poly_tail_calculator_selector.h
dorado/poly_tail/poly_tail_config.cpp
dorado/poly_tail/poly_tail_config.h
dorado/poly_tail/rna_poly_tail_calculator.cpp
Expand Down
29 changes: 29 additions & 0 deletions documentation/PolyTailConfig.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,35 @@ flank_threshold = 0.6
tail_interrupt_length = 10
```

### Overrides
Configuration options can be overridden for individual barcodes. We generate a default configuration as normal, and then
add overrides of specific values for each barcode by adding an `[[overrides]]` section labelled by the barcode name.
```
[anchors]
front_primer = "ATCG"
rear_primer = "CGTA"
[tail]
tail_interrupt_length = 5
[[overrides]]
barcode_id = "Custom-Kit_barcode01"
[overrides.threshold]
flank_threshold = 0.5
[[overrides]]
barcode_id = "Custom-Kit_barcode02"
[overrides.anchors]
front_primer = "AACC"
rear_primer = "GGTT"
[overrides.tail]
tail_interrupt_length = 10
```

This creates three configurations:
* a default configuration with custom front and rear primers and an interrupt length of 5
* a configuration to use for `barcode01` from kit `Custom-Kit` identical to the main custom settings (i.e. with the custom front and rear primers and the interrupt length), with an additional change to the `flank_threshold`.
* a configuration to use for `barcode02` from kit `Custom-Kit` with different primers and an interrupt length of 10, but with no change to the flank threshold.

### Configuration Options

| Option | Description |
Expand Down
21 changes: 12 additions & 9 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "models/kits.h"
#include "models/model_complex.h"
#include "models/models.h"
#include "poly_tail/poly_tail_calculator.h"
#include "poly_tail/poly_tail_calculator_selector.h"
#include "read_pipeline/AdapterDetectorNode.h"
#include "read_pipeline/AlignerNode.h"
#include "read_pipeline/BarcodeClassifierNode.h"
Expand Down Expand Up @@ -490,6 +490,16 @@ void setup(const std::vector<std::string>& args,

auto client_info = std::make_shared<DefaultClientInfo>();
client_info->contexts().register_context<const demux::AdapterInfo>(std::move(adapter_info));

if (estimate_poly_a) {
auto poly_tail_calc_selector =
std::make_shared<const poly_tail::PolyTailCalculatorSelector>(
polya_config, is_rna_model(model_config), is_rna_adapter);
client_info->contexts().register_context<const poly_tail::PolyTailCalculatorSelector>(
std::move(poly_tail_calc_selector));
current_sink_node = pipeline_desc.add_node<PolyACalculatorNode>(
{current_sink_node}, std::thread::hardware_concurrency(), 1000);
}
if (barcoding_info) {
client_info->contexts().register_context<const demux::BarcodingInfo>(
std::move(barcoding_info));
Expand All @@ -500,14 +510,7 @@ void setup(const std::vector<std::string>& args,
current_sink_node = pipeline_desc.add_node<AdapterDetectorNode>(
{current_sink_node}, thread_allocations.adapter_threads);
}
if (estimate_poly_a) {
auto poly_tail_calculator = poly_tail::PolyTailCalculatorFactory::create(
is_rna_model(model_config), is_rna_adapter, polya_config);
client_info->contexts().register_context<const poly_tail::PolyTailCalculator>(
std::move(poly_tail_calculator));
current_sink_node = pipeline_desc.add_node<PolyACalculatorNode>(
{current_sink_node}, std::thread::hardware_concurrency(), 1000);
}

current_sink_node = pipeline_desc.add_node<ReadFilterNode>(
{current_sink_node}, min_qscore, default_parameters.min_sequence_length,
std::unordered_set<std::string>{}, thread_allocations.read_filter_threads);
Expand Down
13 changes: 5 additions & 8 deletions dorado/poly_tail/poly_tail_calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,18 +232,15 @@ int PolyTailCalculator::calculate_num_bases(const SimplexRead& read,
return num_bases;
}

std::shared_ptr<const PolyTailCalculator> PolyTailCalculatorFactory::create(
bool is_rna,
bool is_rna_adapter,
const std::string& config_file) {
auto config = prepare_config(config_file);
std::shared_ptr<const PolyTailCalculator>
PolyTailCalculatorFactory::create(const PolyTailConfig& config, bool is_rna, bool is_rna_adapter) {
if (is_rna) {
return std::make_unique<RNAPolyTailCalculator>(std::move(config), is_rna_adapter);
return std::make_unique<RNAPolyTailCalculator>(config, is_rna_adapter);
}
if (config.is_plasmid) {
return std::make_unique<PlasmidPolyTailCalculator>(std::move(config));
return std::make_unique<PlasmidPolyTailCalculator>(config);
}
return std::make_unique<DNAPolyTailCalculator>(std::move(config));
return std::make_unique<DNAPolyTailCalculator>(config);
}

} // namespace dorado::poly_tail
6 changes: 3 additions & 3 deletions dorado/poly_tail/poly_tail_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ class PolyTailCalculator {

class PolyTailCalculatorFactory {
public:
static std::shared_ptr<const PolyTailCalculator> create(bool is_rna,
bool is_rna_adapter,
const std::string& config_file);
static std::shared_ptr<const PolyTailCalculator> create(const PolyTailConfig& config,
bool is_rna,
bool is_rna_adapter);
};

} // namespace dorado::poly_tail
66 changes: 66 additions & 0 deletions dorado/poly_tail/poly_tail_calculator_selector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "poly_tail_calculator_selector.h"

#include "poly_tail_calculator.h"
#include "poly_tail_config.h"

#include <fstream>
#include <sstream>

namespace dorado::poly_tail {

PolyTailCalculatorSelector::PolyTailCalculatorSelector(const std::filesystem::path& config,
bool is_rna,
bool is_rna_adapter) {
if (config.empty()) {
std::stringstream buffer("");
init(buffer, is_rna, is_rna_adapter);
return;
}

if (!std::filesystem::exists(config) || !std::filesystem::is_regular_file(config)) {
throw std::runtime_error("PolyA config file doesn't exist at " + config.string());
}

std::ifstream file(config);
if (!file.is_open()) {
throw std::runtime_error("Failed to open file " + config.string());
}

std::stringstream buffer;
buffer << file.rdbuf();
init(buffer, is_rna, is_rna_adapter);
}

PolyTailCalculatorSelector::PolyTailCalculatorSelector(std::istream& config_stream,
bool is_rna,
bool is_rna_adapter) {
init(config_stream, is_rna, is_rna_adapter);
}

void PolyTailCalculatorSelector::init(std::istream& config_stream,
bool is_rna,
bool is_rna_adapter) {
auto configs = prepare_configs(config_stream);
m_default = PolyTailCalculatorFactory::create(configs.back(), is_rna, is_rna_adapter);
configs.pop_back();

std::lock_guard<std::mutex> lock(m_lut_mutex);
for (const auto& config : configs) {
m_lut[config.barcode_id] =
PolyTailCalculatorFactory::create(config, is_rna, is_rna_adapter);
}
}

// Return the barcode-specific configuration if one has been provided, otherwise the default.
// If any barcode-specific configurations are present, do not attempt to estimate
// for unclassified reads - better to give no result than a wrong result in this case.
std::shared_ptr<const PolyTailCalculator> PolyTailCalculatorSelector::get_calculator(
const std::string& name) const {
std::lock_guard<std::mutex> lock(m_lut_mutex);
auto it = m_lut.find(name);
return (it == std::end(m_lut))
? (name == "unclassified" && !m_lut.empty() ? nullptr : m_default)
: it->second;
}

} // namespace dorado::poly_tail
31 changes: 31 additions & 0 deletions dorado/poly_tail/poly_tail_calculator_selector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include <filesystem>
#include <iosfwd>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

namespace dorado::poly_tail {

class PolyTailCalculator;

class PolyTailCalculatorSelector {
public:
PolyTailCalculatorSelector(const std::filesystem::path& config,
bool is_rna,
bool is_rna_adapter);
PolyTailCalculatorSelector(std::istream& config_stream, bool is_rna, bool is_rna_adapter);

std::shared_ptr<const PolyTailCalculator> get_calculator(const std::string& name) const;

private:
void init(std::istream& config_stream, bool is_rna, bool is_rna_adapter);

mutable std::mutex m_lut_mutex;
std::unordered_map<std::string, std::shared_ptr<const PolyTailCalculator>> m_lut;
std::shared_ptr<const PolyTailCalculator> m_default;
};

} // namespace dorado::poly_tail
52 changes: 45 additions & 7 deletions dorado/poly_tail/poly_tail_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
#include <istream>
#include <sstream>
#include <string>
#include <unordered_set>

namespace dorado::poly_tail {
namespace {

PolyTailConfig prepare_config(std::istream& is) {
PolyTailConfig config;

const toml::value config_toml = toml::parse(is);
PolyTailConfig update_config(const toml::value& config_toml, PolyTailConfig config) {
if (config_toml.contains("barcode_id")) {
config.barcode_id = toml::find<std::string>(config_toml, "barcode_id");
}

if (config_toml.contains("anchors")) {
const auto& anchors = toml::find(config_toml, "anchors");
Expand Down Expand Up @@ -76,7 +78,43 @@ PolyTailConfig prepare_config(std::istream& is) {
return config;
}

PolyTailConfig prepare_config(const std::string& config_file) {
void add_configs(const toml::value& config_toml, std::vector<PolyTailConfig>& configs) {
// add the default config
auto default_config = update_config(config_toml, PolyTailConfig{});
if (!default_config.barcode_id.empty()) {
throw std::runtime_error("Default poly tail config must not specify barcode_id.");
}

// get override configs
if (config_toml.contains("overrides")) {
const std::vector<toml::value> overrides = toml::find(config_toml, "overrides").as_array();
std::unordered_set<std::string> ids;
for (auto& override_toml : overrides) {
auto override = update_config(override_toml, default_config);
ids.insert(override.barcode_id);
configs.push_back(std::move(override));
}
if (ids.count("") != 0) {
throw std::runtime_error("Missing barcode_id in override poly tail configuration.");
}
if (ids.size() != overrides.size()) {
throw std::runtime_error("Duplicate barcode_id found in poly tail config file.");
}
}

configs.push_back(std::move(default_config));
}

} // namespace

std::vector<PolyTailConfig> prepare_configs(std::istream& is) {
const toml::value config_toml = toml::parse(is);
std::vector<PolyTailConfig> configs;
add_configs(config_toml, configs);
return configs;
}

std::vector<PolyTailConfig> prepare_configs(const std::string& config_file) {
if (!config_file.empty()) {
if (!std::filesystem::exists(config_file) ||
!std::filesystem::is_regular_file(config_file)) {
Expand All @@ -90,10 +128,10 @@ PolyTailConfig prepare_config(const std::string& config_file) {
// Read the file contents into a string
std::stringstream buffer;
buffer << file.rdbuf();
return prepare_config(buffer);
return prepare_configs(buffer);
} else {
std::stringstream buffer("");
return prepare_config(buffer);
return prepare_configs(buffer);
}
}

Expand Down
12 changes: 7 additions & 5 deletions dorado/poly_tail/poly_tail_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <istream>
#include <string>
#include <vector>

namespace dorado::poly_tail {

Expand All @@ -21,15 +22,16 @@ struct PolyTailConfig {
bool is_plasmid = false;
int tail_interrupt_length = 0;
int min_base_count = 10;
std::string barcode_id;
};

// Prepare the PolyA configuration struct. If a configuration
// file is available, parse it to extract parameters. Otherwise
// prepare the default configuration.
PolyTailConfig prepare_config(const std::string& config_file);
// Prepare the PolyA configurations. If a configuration file is available, parse it to extract parameters.
// If barcode-specific overrides are present, the non-specific configuration will be at the back.
// Otherwise prepares a single default configuration.
std::vector<PolyTailConfig> prepare_configs(const std::string& config_file);

// Overloaded function that parses the configuration passed
// in as an input stream.
PolyTailConfig prepare_config(std::istream& is);
std::vector<PolyTailConfig> prepare_configs(std::istream& is);

} // namespace dorado::poly_tail
12 changes: 10 additions & 2 deletions dorado/read_pipeline/PolyACalculatorNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "ClientInfo.h"
#include "poly_tail/poly_tail_calculator.h"
#include "poly_tail/poly_tail_calculator_selector.h"
#include "utils/math_utils.h"
#include "utils/sequence_utils.h"

Expand All @@ -25,9 +26,16 @@ void PolyACalculatorNode::input_thread_fn() {
// If this message isn't a read, we'll get a bad_variant_access exception.
auto read = std::get<SimplexReadPtr>(std::move(message));

auto calculator = read->read_common.client_info->contexts()
.get_ptr<const poly_tail::PolyTailCalculator>();
auto selector = read->read_common.client_info->contexts()
.get_ptr<const poly_tail::PolyTailCalculatorSelector>();

if (!selector) {
send_message_to_sink(std::move(read));
num_not_called++;
continue;
}

auto calculator = selector->get_calculator(read->read_common.barcode);
if (!calculator) {
send_message_to_sink(std::move(read));
num_not_called++;
Expand Down
7 changes: 4 additions & 3 deletions tests/NodeSmokeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "model_downloader/model_downloader.h"
#include "models/kits.h"
#include "models/models.h"
#include "poly_tail/poly_tail_calculator.h"
#include "poly_tail/poly_tail_calculator_selector.h"
#include "read_pipeline/AdapterDetectorNode.h"
#include "read_pipeline/BarcodeClassifierNode.h"
#include "read_pipeline/BasecallerNode.h"
Expand Down Expand Up @@ -429,8 +429,9 @@ DEFINE_TEST(NodeSmokeTestRead, "PolyACalculatorNode") {

set_pipeline_restart(pipeline_restart);

client_info->contexts().register_context<const dorado::poly_tail::PolyTailCalculator>(
dorado::poly_tail::PolyTailCalculatorFactory::create(is_rna, is_rna_adapter, ""));
client_info->contexts().register_context<const dorado::poly_tail::PolyTailCalculatorSelector>(
std::make_shared<dorado::poly_tail::PolyTailCalculatorSelector>("", is_rna,
is_rna_adapter));

set_read_mutator([](dorado::SimplexReadPtr& read) {
read->read_common.model_stride = 2;
Expand Down
Loading

0 comments on commit 6646701

Please sign in to comment.