diff --git a/CMakeLists.txt b/CMakeLists.txt index 202f64c9d..0a1a5bb3d 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -281,6 +281,8 @@ add_library(dorado_lib dorado/poly_tail/plasmid_poly_tail_calculator.h dorado/poly_tail/poly_tail_calculator.cpp dorado/poly_tail/poly_tail_calculator.h + dorado/poly_tail/poly_tail_calculator_selector.cpp + dorado/poly_tail/poly_tail_calculator_selector.h dorado/poly_tail/poly_tail_config.cpp dorado/poly_tail/poly_tail_config.h dorado/poly_tail/rna_poly_tail_calculator.cpp diff --git a/documentation/PolyTailConfig.md b/documentation/PolyTailConfig.md index 03f229aa8..2459d81e2 100644 --- a/documentation/PolyTailConfig.md +++ b/documentation/PolyTailConfig.md @@ -53,6 +53,35 @@ flank_threshold = 0.6 tail_interrupt_length = 10 ``` +### Overrides +Configuration options can be overridden for individual barcodes. We generate a default configuration as normal, and then +add overrides of specific values for each barcode by adding an `[[overrides]]` section labelled by the barcode name. +``` +[anchors] +front_primer = "ATCG" +rear_primer = "CGTA" +[tail] +tail_interrupt_length = 5 + +[[overrides]] +barcode_id = "Custom-Kit_barcode01" +[overrides.threshold] +flank_threshold = 0.5 + +[[overrides]] +barcode_id = "Custom-Kit_barcode02" +[overrides.anchors] +front_primer = "AACC" +rear_primer = "GGTT" +[overrides.tail] +tail_interrupt_length = 10 +``` + +This creates three configurations: +* a default configuration with custom front and rear primers and an interrupt length of 5 +* a configuration to use for `barcode01` from kit `Custom-Kit` identical to the main custom settings (i.e. with the custom front and rear primers and the interrupt length), with an additional change to the `flank_threshold`. +* a configuration to use for `barcode02` from kit `Custom-Kit` with different primers and an interrupt length of 10, but with no change to the flank threshold. + ### Configuration Options | Option | Description | diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index cbddbb6eb..1d6df37d9 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -15,7 +15,7 @@ #include "models/kits.h" #include "models/model_complex.h" #include "models/models.h" -#include "poly_tail/poly_tail_calculator.h" +#include "poly_tail/poly_tail_calculator_selector.h" #include "read_pipeline/AdapterDetectorNode.h" #include "read_pipeline/AlignerNode.h" #include "read_pipeline/BarcodeClassifierNode.h" @@ -490,6 +490,16 @@ void setup(const std::vector& args, auto client_info = std::make_shared(); client_info->contexts().register_context(std::move(adapter_info)); + + if (estimate_poly_a) { + auto poly_tail_calc_selector = + std::make_shared( + polya_config, is_rna_model(model_config), is_rna_adapter); + client_info->contexts().register_context( + std::move(poly_tail_calc_selector)); + current_sink_node = pipeline_desc.add_node( + {current_sink_node}, std::thread::hardware_concurrency(), 1000); + } if (barcoding_info) { client_info->contexts().register_context( std::move(barcoding_info)); @@ -500,14 +510,7 @@ void setup(const std::vector& args, current_sink_node = pipeline_desc.add_node( {current_sink_node}, thread_allocations.adapter_threads); } - if (estimate_poly_a) { - auto poly_tail_calculator = poly_tail::PolyTailCalculatorFactory::create( - is_rna_model(model_config), is_rna_adapter, polya_config); - client_info->contexts().register_context( - std::move(poly_tail_calculator)); - current_sink_node = pipeline_desc.add_node( - {current_sink_node}, std::thread::hardware_concurrency(), 1000); - } + current_sink_node = pipeline_desc.add_node( {current_sink_node}, min_qscore, default_parameters.min_sequence_length, std::unordered_set{}, thread_allocations.read_filter_threads); diff --git a/dorado/poly_tail/poly_tail_calculator.cpp b/dorado/poly_tail/poly_tail_calculator.cpp index 5b0cd39c2..4d419e33e 100644 --- a/dorado/poly_tail/poly_tail_calculator.cpp +++ b/dorado/poly_tail/poly_tail_calculator.cpp @@ -232,18 +232,15 @@ int PolyTailCalculator::calculate_num_bases(const SimplexRead& read, return num_bases; } -std::shared_ptr PolyTailCalculatorFactory::create( - bool is_rna, - bool is_rna_adapter, - const std::string& config_file) { - auto config = prepare_config(config_file); +std::shared_ptr +PolyTailCalculatorFactory::create(const PolyTailConfig& config, bool is_rna, bool is_rna_adapter) { if (is_rna) { - return std::make_unique(std::move(config), is_rna_adapter); + return std::make_unique(config, is_rna_adapter); } if (config.is_plasmid) { - return std::make_unique(std::move(config)); + return std::make_unique(config); } - return std::make_unique(std::move(config)); + return std::make_unique(config); } } // namespace dorado::poly_tail diff --git a/dorado/poly_tail/poly_tail_calculator.h b/dorado/poly_tail/poly_tail_calculator.h index 3974158d9..65c114cac 100644 --- a/dorado/poly_tail/poly_tail_calculator.h +++ b/dorado/poly_tail/poly_tail_calculator.h @@ -71,9 +71,9 @@ class PolyTailCalculator { class PolyTailCalculatorFactory { public: - static std::shared_ptr create(bool is_rna, - bool is_rna_adapter, - const std::string& config_file); + static std::shared_ptr create(const PolyTailConfig& config, + bool is_rna, + bool is_rna_adapter); }; } // namespace dorado::poly_tail diff --git a/dorado/poly_tail/poly_tail_calculator_selector.cpp b/dorado/poly_tail/poly_tail_calculator_selector.cpp new file mode 100644 index 000000000..5661113b6 --- /dev/null +++ b/dorado/poly_tail/poly_tail_calculator_selector.cpp @@ -0,0 +1,66 @@ +#include "poly_tail_calculator_selector.h" + +#include "poly_tail_calculator.h" +#include "poly_tail_config.h" + +#include +#include + +namespace dorado::poly_tail { + +PolyTailCalculatorSelector::PolyTailCalculatorSelector(const std::filesystem::path& config, + bool is_rna, + bool is_rna_adapter) { + if (config.empty()) { + std::stringstream buffer(""); + init(buffer, is_rna, is_rna_adapter); + return; + } + + if (!std::filesystem::exists(config) || !std::filesystem::is_regular_file(config)) { + throw std::runtime_error("PolyA config file doesn't exist at " + config.string()); + } + + std::ifstream file(config); + if (!file.is_open()) { + throw std::runtime_error("Failed to open file " + config.string()); + } + + std::stringstream buffer; + buffer << file.rdbuf(); + init(buffer, is_rna, is_rna_adapter); +} + +PolyTailCalculatorSelector::PolyTailCalculatorSelector(std::istream& config_stream, + bool is_rna, + bool is_rna_adapter) { + init(config_stream, is_rna, is_rna_adapter); +} + +void PolyTailCalculatorSelector::init(std::istream& config_stream, + bool is_rna, + bool is_rna_adapter) { + auto configs = prepare_configs(config_stream); + m_default = PolyTailCalculatorFactory::create(configs.back(), is_rna, is_rna_adapter); + configs.pop_back(); + + std::lock_guard lock(m_lut_mutex); + for (const auto& config : configs) { + m_lut[config.barcode_id] = + PolyTailCalculatorFactory::create(config, is_rna, is_rna_adapter); + } +} + +// Return the barcode-specific configuration if one has been provided, otherwise the default. +// If any barcode-specific configurations are present, do not attempt to estimate +// for unclassified reads - better to give no result than a wrong result in this case. +std::shared_ptr PolyTailCalculatorSelector::get_calculator( + const std::string& name) const { + std::lock_guard lock(m_lut_mutex); + auto it = m_lut.find(name); + return (it == std::end(m_lut)) + ? (name == "unclassified" && !m_lut.empty() ? nullptr : m_default) + : it->second; +} + +} // namespace dorado::poly_tail diff --git a/dorado/poly_tail/poly_tail_calculator_selector.h b/dorado/poly_tail/poly_tail_calculator_selector.h new file mode 100644 index 000000000..9d8d76e11 --- /dev/null +++ b/dorado/poly_tail/poly_tail_calculator_selector.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace dorado::poly_tail { + +class PolyTailCalculator; + +class PolyTailCalculatorSelector { +public: + PolyTailCalculatorSelector(const std::filesystem::path& config, + bool is_rna, + bool is_rna_adapter); + PolyTailCalculatorSelector(std::istream& config_stream, bool is_rna, bool is_rna_adapter); + + std::shared_ptr get_calculator(const std::string& name) const; + +private: + void init(std::istream& config_stream, bool is_rna, bool is_rna_adapter); + + mutable std::mutex m_lut_mutex; + std::unordered_map> m_lut; + std::shared_ptr m_default; +}; + +} // namespace dorado::poly_tail diff --git a/dorado/poly_tail/poly_tail_config.cpp b/dorado/poly_tail/poly_tail_config.cpp index 320361e43..5a512a596 100644 --- a/dorado/poly_tail/poly_tail_config.cpp +++ b/dorado/poly_tail/poly_tail_config.cpp @@ -10,13 +10,15 @@ #include #include #include +#include namespace dorado::poly_tail { +namespace { -PolyTailConfig prepare_config(std::istream& is) { - PolyTailConfig config; - - const toml::value config_toml = toml::parse(is); +PolyTailConfig update_config(const toml::value& config_toml, PolyTailConfig config) { + if (config_toml.contains("barcode_id")) { + config.barcode_id = toml::find(config_toml, "barcode_id"); + } if (config_toml.contains("anchors")) { const auto& anchors = toml::find(config_toml, "anchors"); @@ -76,7 +78,43 @@ PolyTailConfig prepare_config(std::istream& is) { return config; } -PolyTailConfig prepare_config(const std::string& config_file) { +void add_configs(const toml::value& config_toml, std::vector& configs) { + // add the default config + auto default_config = update_config(config_toml, PolyTailConfig{}); + if (!default_config.barcode_id.empty()) { + throw std::runtime_error("Default poly tail config must not specify barcode_id."); + } + + // get override configs + if (config_toml.contains("overrides")) { + const std::vector overrides = toml::find(config_toml, "overrides").as_array(); + std::unordered_set ids; + for (auto& override_toml : overrides) { + auto override = update_config(override_toml, default_config); + ids.insert(override.barcode_id); + configs.push_back(std::move(override)); + } + if (ids.count("") != 0) { + throw std::runtime_error("Missing barcode_id in override poly tail configuration."); + } + if (ids.size() != overrides.size()) { + throw std::runtime_error("Duplicate barcode_id found in poly tail config file."); + } + } + + configs.push_back(std::move(default_config)); +} + +} // namespace + +std::vector prepare_configs(std::istream& is) { + const toml::value config_toml = toml::parse(is); + std::vector configs; + add_configs(config_toml, configs); + return configs; +} + +std::vector prepare_configs(const std::string& config_file) { if (!config_file.empty()) { if (!std::filesystem::exists(config_file) || !std::filesystem::is_regular_file(config_file)) { @@ -90,10 +128,10 @@ PolyTailConfig prepare_config(const std::string& config_file) { // Read the file contents into a string std::stringstream buffer; buffer << file.rdbuf(); - return prepare_config(buffer); + return prepare_configs(buffer); } else { std::stringstream buffer(""); - return prepare_config(buffer); + return prepare_configs(buffer); } } diff --git a/dorado/poly_tail/poly_tail_config.h b/dorado/poly_tail/poly_tail_config.h index 97914c108..58fbf8807 100644 --- a/dorado/poly_tail/poly_tail_config.h +++ b/dorado/poly_tail/poly_tail_config.h @@ -2,6 +2,7 @@ #include #include +#include namespace dorado::poly_tail { @@ -21,15 +22,16 @@ struct PolyTailConfig { bool is_plasmid = false; int tail_interrupt_length = 0; int min_base_count = 10; + std::string barcode_id; }; -// Prepare the PolyA configuration struct. If a configuration -// file is available, parse it to extract parameters. Otherwise -// prepare the default configuration. -PolyTailConfig prepare_config(const std::string& config_file); +// Prepare the PolyA configurations. If a configuration file is available, parse it to extract parameters. +// If barcode-specific overrides are present, the non-specific configuration will be at the back. +// Otherwise prepares a single default configuration. +std::vector prepare_configs(const std::string& config_file); // Overloaded function that parses the configuration passed // in as an input stream. -PolyTailConfig prepare_config(std::istream& is); +std::vector prepare_configs(std::istream& is); } // namespace dorado::poly_tail diff --git a/dorado/read_pipeline/PolyACalculatorNode.cpp b/dorado/read_pipeline/PolyACalculatorNode.cpp index d8938606f..0867657f7 100644 --- a/dorado/read_pipeline/PolyACalculatorNode.cpp +++ b/dorado/read_pipeline/PolyACalculatorNode.cpp @@ -2,6 +2,7 @@ #include "ClientInfo.h" #include "poly_tail/poly_tail_calculator.h" +#include "poly_tail/poly_tail_calculator_selector.h" #include "utils/math_utils.h" #include "utils/sequence_utils.h" @@ -25,9 +26,16 @@ void PolyACalculatorNode::input_thread_fn() { // If this message isn't a read, we'll get a bad_variant_access exception. auto read = std::get(std::move(message)); - auto calculator = read->read_common.client_info->contexts() - .get_ptr(); + auto selector = read->read_common.client_info->contexts() + .get_ptr(); + if (!selector) { + send_message_to_sink(std::move(read)); + num_not_called++; + continue; + } + + auto calculator = selector->get_calculator(read->read_common.barcode); if (!calculator) { send_message_to_sink(std::move(read)); num_not_called++; diff --git a/tests/NodeSmokeTest.cpp b/tests/NodeSmokeTest.cpp index ee5815151..e240088c0 100644 --- a/tests/NodeSmokeTest.cpp +++ b/tests/NodeSmokeTest.cpp @@ -7,7 +7,7 @@ #include "model_downloader/model_downloader.h" #include "models/kits.h" #include "models/models.h" -#include "poly_tail/poly_tail_calculator.h" +#include "poly_tail/poly_tail_calculator_selector.h" #include "read_pipeline/AdapterDetectorNode.h" #include "read_pipeline/BarcodeClassifierNode.h" #include "read_pipeline/BasecallerNode.h" @@ -429,8 +429,9 @@ DEFINE_TEST(NodeSmokeTestRead, "PolyACalculatorNode") { set_pipeline_restart(pipeline_restart); - client_info->contexts().register_context( - dorado::poly_tail::PolyTailCalculatorFactory::create(is_rna, is_rna_adapter, "")); + client_info->contexts().register_context( + std::make_shared("", is_rna, + is_rna_adapter)); set_read_mutator([](dorado::SimplexReadPtr& read) { read->read_common.model_stride = 2; diff --git a/tests/PolyACalculatorTest.cpp b/tests/PolyACalculatorTest.cpp index 27519718a..53f6da666 100644 --- a/tests/PolyACalculatorTest.cpp +++ b/tests/PolyACalculatorTest.cpp @@ -1,6 +1,6 @@ #include "MessageSinkUtils.h" #include "TestUtils.h" -#include "poly_tail/poly_tail_calculator.h" +#include "poly_tail/poly_tail_calculator_selector.h" #include "poly_tail/poly_tail_config.h" #include "read_pipeline/DefaultClientInfo.h" #include "read_pipeline/PolyACalculatorNode.h" @@ -36,6 +36,7 @@ TEST_CASE("PolyACalculator: Test polyT tail estimation", TEST_GROUP) { TestCase{143, "poly_a/r9_rev_cdna", false}, TestCase{35, "poly_a/r10_fwd_cdna", false}, TestCase{37, "poly_a/rna002", true}, TestCase{73, "poly_a/rna004", true}); + CAPTURE(data); dorado::PipelineDescriptor pipeline_desc; std::vector messages; auto sink = pipeline_desc.add_node({}, 100, messages); @@ -56,8 +57,9 @@ TEST_CASE("PolyACalculator: Test polyT tail estimation", TEST_GROUP) { read->read_common.read_id = "read_id"; read->read_common.client_info = std::make_shared(); read->read_common.client_info->contexts() - .register_context( - dorado::poly_tail::PolyTailCalculatorFactory::create(is_rna, false, "")); + .register_context( + std::make_shared("", is_rna, + false)); // Push a Read type. pipeline->push_message(std::move(read)); @@ -93,8 +95,9 @@ TEST_CASE("PolyACalculator: Test polyT tail estimation with custom config", TEST read->read_common.read_id = "read_id"; read->read_common.client_info = std::make_shared(); read->read_common.client_info->contexts() - .register_context( - dorado::poly_tail::PolyTailCalculatorFactory::create(false, false, config)); + .register_context( + std::make_shared(config, false, + false)); // Push a Read type. pipeline->push_message(std::move(read)); @@ -108,38 +111,33 @@ TEST_CASE("PolyACalculator: Test polyT tail estimation with custom config", TEST } TEST_CASE("PolyTailConfig: Test parsing file", TEST_GROUP) { - auto tmp_dir = make_temp_dir("polya_test"); - SECTION("Check failure with non-existent file.") { const std::string missing_file = "foo_bar_baz"; - CHECK_THROWS_WITH(dorado::poly_tail::prepare_config(missing_file), + CHECK_THROWS_WITH(dorado::poly_tail::prepare_configs(missing_file), "PolyA config file doesn't exist at foo_bar_baz"); } SECTION("Only one primer is provided") { - auto path = (tmp_dir.m_path / "only_one_primer.toml").string(); const toml::value data{{"anchors", toml::table{{"front_primer", "ACTG"}}}}; const std::string fmt = toml::format(data); std::stringstream buffer(fmt); - CHECK_THROWS_WITH(dorado::poly_tail::prepare_config(buffer), + CHECK_THROWS_WITH(dorado::poly_tail::prepare_configs(buffer), "Both front_primer and rear_primer must be provided in the PolyA " "configuration file."); } SECTION("Only one plasmid flank is provided") { - auto path = (tmp_dir.m_path / "only_one_flank.toml").string(); const toml::value data{{"anchors", toml::table{{"plasmid_rear_flank", "ACTG"}}}}; const std::string fmt = toml::format(data); std::stringstream buffer(fmt); - CHECK_THROWS_WITH(dorado::poly_tail::prepare_config(buffer), + CHECK_THROWS_WITH(dorado::poly_tail::prepare_configs(buffer), "Both plasmid_front_flank and plasmid_rear_flank must be provided in the " "PolyA configuration file."); } SECTION("Parse all supported configs") { - auto path = (tmp_dir.m_path / "only_one_flank.toml").string(); const toml::value data{{"anchors", toml::table{{"plasmid_front_flank", "CGTA"}, {"plasmid_rear_flank", "ACTG"}, {"front_primer", "AAAAAA"}, @@ -148,7 +146,9 @@ TEST_CASE("PolyTailConfig: Test parsing file", TEST_GROUP) { const std::string fmt = toml::format(data); std::stringstream buffer(fmt); - auto config = dorado::poly_tail::prepare_config(buffer); + auto configs = dorado::poly_tail::prepare_configs(buffer); + REQUIRE(configs.size() == 1); + const auto& config = configs.front(); CHECK(config.front_primer == "AAAAAA"); CHECK(config.rc_front_primer == "TTTTTT"); CHECK(config.rear_primer == "GGGGGG"); @@ -160,4 +160,72 @@ TEST_CASE("PolyTailConfig: Test parsing file", TEST_GROUP) { CHECK(config.is_plasmid); // Since the plasmid flanks were specified CHECK(config.tail_interrupt_length == 10); } + + SECTION("Override config missing id") { + const int NUM_CONFIGS = 3; + toml::array config_toml; + for (int i = 0; i < NUM_CONFIGS; ++i) { + toml::table data{}; + config_toml.push_back(data); + } + const toml::value data{{"overrides", config_toml}}; + const std::string fmt = toml::format(data); + std::stringstream buffer(fmt); + + CHECK_THROWS_WITH(dorado::poly_tail::prepare_configs(buffer), + "Missing barcode_id in override poly tail configuration."); + } + + SECTION("Override config duplicate id") { + const int NUM_CONFIGS = 3; + toml::array config_toml; + for (int i = 0; i < NUM_CONFIGS; ++i) { + toml::table data{{"barcode_id", "duplicate"}}; + config_toml.push_back(data); + } + const toml::value data{{"overrides", config_toml}}; + const std::string fmt = toml::format(data); + std::stringstream buffer(fmt); + + CHECK_THROWS_WITH(dorado::poly_tail::prepare_configs(buffer), + "Duplicate barcode_id found in poly tail config file."); + } + + SECTION("Default config contains barcode id") { + const int NUM_CONFIGS = 3; + toml::array config_toml; + for (int i = 0; i < NUM_CONFIGS; ++i) { + toml::table data{{"barcode_id", "barcode" + std::to_string(i)}}; + config_toml.push_back(data); + } + const toml::value data{{"barcode_id", "error"}, {"overrides", config_toml}}; + const std::string fmt = toml::format(data); + std::stringstream buffer(fmt); + + CHECK_THROWS_WITH(dorado::poly_tail::prepare_configs(buffer), + "Default poly tail config must not specify barcode_id."); + } + + SECTION("Parse override configs") { + const int NUM_CONFIGS = 3; + toml::array config_toml; + for (int i = 0; i < NUM_CONFIGS; ++i) { + toml::table data{{"barcode_id", "barcode" + std::to_string(i)}}; + config_toml.push_back(data); + } + + const toml::value data{{"tail", toml::table{{"tail_interrupt_length", 10}}}, + {"overrides", config_toml}}; + const std::string fmt = toml::format(data); + std::stringstream buffer(fmt); + + auto configs = dorado::poly_tail::prepare_configs(buffer); + REQUIRE(configs.size() == NUM_CONFIGS + 1); // specified configs + default + for (int i = 0; i < NUM_CONFIGS; ++i) { + CHECK(configs[i].barcode_id == + "barcode" + std::to_string(i)); // overridden value per config + CHECK(configs[i].tail_interrupt_length == 10); // default inherited from main config + } + CHECK(configs.back().barcode_id.empty()); + } }