diff --git a/dorado/read_pipeline/DuplexSplitNode.cpp b/dorado/read_pipeline/DuplexSplitNode.cpp index 053f9e208..40e3ddb9e 100644 --- a/dorado/read_pipeline/DuplexSplitNode.cpp +++ b/dorado/read_pipeline/DuplexSplitNode.cpp @@ -32,7 +32,7 @@ auto filter_ranges(const PosRanges& ranges, FilterF filter_f) { //merges overlapping ranges and ranges separated by merge_dist or less //ranges supposed to be sorted by start coordinate -PosRanges merge_ranges(const PosRanges& ranges, size_t merge_dist) { +PosRanges merge_ranges(const PosRanges& ranges, uint64_t merge_dist) { PosRanges merged; for (auto& r : ranges) { assert(merged.empty() || r.first >= merged.back().first); @@ -45,16 +45,16 @@ PosRanges merge_ranges(const PosRanges& ranges, size_t merge_dist) { return merged; } -std::vector> detect_pore_signal(const torch::Tensor& signal, - float threshold, - size_t cluster_dist, - size_t ignore_prefix) { - std::vector> ans; +std::vector> detect_pore_signal(const torch::Tensor& signal, + float threshold, + uint64_t cluster_dist, + uint64_t ignore_prefix) { + std::vector> ans; auto pore_a = signal.accessor(); int64_t cl_start = -1; int64_t cl_end = -1; - for (size_t i = ignore_prefix; i < pore_a.size(0); i++) { + for (auto i = ignore_prefix; i < pore_a.size(0); i++) { if (pore_a[i] > threshold) { //check if we need to start new cluster if (cl_end == -1 || i > cl_end + cluster_dist) { @@ -121,7 +121,11 @@ std::vector find_adapter_matches(const std::string& adapter, } //semi-global alignment of "template region" to "complement region" -bool check_rc_match(const std::string& seq, PosRange templ_r, PosRange compl_r, int dist_thr) { +//returns range in the compl_r +std::optional check_rc_match(const std::string& seq, + PosRange templ_r, + PosRange compl_r, + int dist_thr) { assert(templ_r.second > templ_r.first); assert(compl_r.second > compl_r.first); assert(dist_thr >= 0); @@ -129,17 +133,24 @@ bool check_rc_match(const std::string& seq, PosRange templ_r, PosRange compl_r, auto rc_compl = dorado::utils::reverse_complement( seq.substr(compl_r.first, compl_r.second - compl_r.first)); - auto edlib_cfg = edlibNewAlignConfig(dist_thr, EDLIB_MODE_HW, EDLIB_TASK_DISTANCE, NULL, 0); + auto edlib_cfg = edlibNewAlignConfig(dist_thr, EDLIB_MODE_HW, EDLIB_TASK_LOC, NULL, 0); auto edlib_result = edlibAlign(seq.c_str() + templ_r.first, templ_r.second - templ_r.first, rc_compl.c_str(), rc_compl.size(), edlib_cfg); assert(edlib_result.status == EDLIB_STATUS_OK); bool match = (edlib_result.status == EDLIB_STATUS_OK) && (edlib_result.editDistance != -1); - assert(!match || edlib_result.editDistance <= dist_thr); + std::optional res = std::nullopt; + if (match) { + assert(edlib_result.editDistance <= dist_thr); + assert(edlib_result.numLocations > 0 && edlib_result.endLocations[0] < compl_r.second && + edlib_result.startLocations[0] < compl_r.second); + res = PosRange(compl_r.second - edlib_result.endLocations[0], + compl_r.second - edlib_result.startLocations[0]); + } edlibFreeAlignResult(edlib_result); - return match; + return res; } //TODO end_reason access? @@ -170,16 +181,17 @@ std::shared_ptr subread(const Read& read, PosRange seq_range, PosRange sig {torch::indexing::Slice(signal_range.first, signal_range.second)}); subread->attributes.read_number = -1; + //we adjust for it in new start time + subread->attributes.num_samples = signal_range.second - signal_range.first; + subread->num_trimmed_samples = 0; subread->start_sample = read.start_sample + read.num_trimmed_samples + signal_range.first; - subread->end_sample = read.start_sample + read.num_trimmed_samples + signal_range.second; + subread->end_sample = subread->start_sample + subread->attributes.num_samples; + auto start_time_ms = read.run_acquisition_start_time_ms + uint64_t(std::round(subread->start_sample * 1000. / subread->sample_rate)); subread->attributes.start_time = utils::get_string_timestamp_from_unix_time(start_time_ms); subread->start_time_ms = start_time_ms; - //we adjust for it in new start time above - subread->num_trimmed_samples = 0; - subread->seq = subread->seq.substr(seq_range.first, seq_range.second - seq_range.first); subread->qstring = subread->qstring.substr(seq_range.first, seq_range.second - seq_range.first); subread->moves = std::vector(subread->moves.begin() + signal_range.first / stride, @@ -199,28 +211,25 @@ std::shared_ptr subread(const Read& read, PosRange seq_range, PosRange sig namespace dorado { -DuplexSplitNode::ExtRead::ExtRead(std::shared_ptr r) - : read(std::move(r)), - data_as_float32(read->raw_data.to(torch::kFloat)), - move_sums(utils::move_cum_sums(read->moves)) { - assert(!move_sums.empty()); - assert(move_sums.back() == read->seq.length()); +DuplexSplitNode::ExtRead DuplexSplitNode::create_ext_read(std::shared_ptr r) const { + ExtRead ext_read; + ext_read.read = r; + ext_read.move_sums = utils::move_cum_sums(r->moves); + assert(!ext_read.move_sums.empty()); + assert(ext_read.move_sums.back() == r->seq.length()); + ext_read.data_as_float32 = r->raw_data.to(torch::kFloat); + ext_read.possible_pore_regions = possible_pore_regions(ext_read); + return ext_read; } -PosRanges DuplexSplitNode::possible_pore_regions(const DuplexSplitNode::ExtRead& read, - float pore_thr) const { - PosRanges pore_regions; - - //pA formula before scaling: - //pA = read->scaling * (raw + read->offset); - //pA formula after scaling: - //pA = read->scale * raw + read->shift +PosRanges DuplexSplitNode::possible_pore_regions(const DuplexSplitNode::ExtRead& read) const { spdlog::trace("Analyzing signal in read {}", read.read->read_id); - auto pore_sample_ranges = detect_pore_signal( - read.data_as_float32, (pore_thr - read.read->shift) / read.read->scale, - m_settings.pore_cl_dist, m_settings.expect_pore_prefix); + auto pore_sample_ranges = + detect_pore_signal(read.data_as_float32, m_settings.pore_thr, m_settings.pore_cl_dist, + m_settings.expect_pore_prefix); + PosRanges pore_regions; for (auto pore_sample_range : pore_sample_ranges) { auto move_start = pore_sample_range.first / read.read->model_stride; auto move_end = pore_sample_range.second / read.read->model_stride; @@ -235,7 +244,9 @@ PosRanges DuplexSplitNode::possible_pore_regions(const DuplexSplitNode::ExtRead& //NB. adding adapter length auto end_pos = read.move_sums[move_end]; assert(end_pos > start_pos); - pore_regions.push_back({start_pos, end_pos}); + if (end_pos <= start_pos + m_settings.max_pore_region) { + pore_regions.push_back({start_pos, end_pos}); + } } return pore_regions; @@ -249,21 +260,51 @@ bool DuplexSplitNode::check_nearby_adapter(const Read& read, PosRange r, int ada .has_value(); } -//r is potential spacer region -bool DuplexSplitNode::check_flank_match(const Read& read, PosRange r, int dist_thr) const { - return r.first >= m_settings.end_flank && - r.second + m_settings.start_flank <= read.seq.length() && - check_rc_match(read.seq, {r.first - m_settings.end_flank, r.first - m_settings.end_trim}, - //including spacer region in search - {r.first, r.second + m_settings.start_flank}, dist_thr); +//'spacer' is region potentially containing templ/compl strand boundary +//returns optional pair of matching ranges (first strictly to the left of spacer region) +std::optional> +DuplexSplitNode::check_flank_match(const Read& read, PosRange spacer, float err_thr) const { + const uint64_t rlen = read.seq.length(); + assert(spacer.first <= spacer.second && spacer.second <= rlen); + if (spacer.first <= m_settings.strand_end_trim || spacer.second == rlen) { + return std::nullopt; + } + + const uint64_t left_start = (spacer.first > m_settings.strand_end_flank) + ? spacer.first - m_settings.strand_end_flank + : 0; + const uint64_t left_end = spacer.first - m_settings.strand_end_trim; + assert(left_start < left_end); + const uint64_t left_span = left_end - left_start; + + //including spacer region in search + const uint64_t right_start = spacer.first; + //(r.second - r.first) adjusts for potentially incorrectly detected split region + //, shifting into correct sequence + const uint64_t right_end = std::min( + spacer.second + m_settings.strand_start_flank + (spacer.second - spacer.first), rlen); + assert(right_start < right_end); + const uint64_t right_span = right_end - right_start; + + const int dist_thr = std::round(err_thr * left_span); + if (left_span >= m_settings.min_flank && right_span >= left_span) { + if (auto match = check_rc_match(read.seq, {left_start, left_end}, + //including spacer region in search + {right_start, right_end}, dist_thr)) { + return std::pair{PosRange{left_start, left_end}, *match}; + } + } + return std::nullopt; } -std::optional DuplexSplitNode::identify_extra_middle_split( +std::optional DuplexSplitNode::identify_middle_adapter_split( const Read& read) const { - const auto r_l = read.seq.size(); - const auto search_span = std::max(m_settings.middle_adapter_search_span, - int(std::round(m_settings.middle_adapter_search_frac * r_l))); - if (r_l < m_settings.end_flank + m_settings.start_flank || r_l < search_span) { + assert(m_settings.strand_end_flank > m_settings.strand_end_trim + m_settings.min_flank); + const uint64_t r_l = read.seq.size(); + const uint64_t search_span = + std::max(m_settings.middle_adapter_search_span, + uint64_t(std::round(m_settings.middle_adapter_search_frac * r_l))); + if (r_l < search_span) { return std::nullopt; } @@ -271,13 +312,73 @@ std::optional DuplexSplitNode::identify_extra_middle_ if (auto adapter_match = find_best_adapter_match( m_settings.adapter, read.seq, m_settings.relaxed_adapter_edist, {r_l / 2 - search_span / 2, r_l / 2 + search_span / 2})) { - auto adapter_start = adapter_match->first; + const uint64_t adapter_start = adapter_match->first; + const uint64_t adapter_end = adapter_match->second; spdlog::trace("Checking middle match & start/end match"); - if (check_flank_match(read, {adapter_start, adapter_start}, - m_settings.relaxed_flank_edist) && - check_rc_match(read.seq, {r_l - m_settings.end_flank, r_l - m_settings.end_trim}, - {0, m_settings.start_flank}, m_settings.relaxed_flank_edist)) { - return PosRange{adapter_start - 1, adapter_start}; + //Checking match around adapter + if (check_flank_match(read, {adapter_start, adapter_start}, m_settings.flank_err)) { + //Checking start/end match + //some initializations might 'overflow' and not make sense, but not if check_rc_match below actually ends up checked! + const uint64_t query_start = r_l - m_settings.strand_end_flank; + const uint64_t query_end = r_l - m_settings.strand_end_trim; + const uint64_t query_span = query_end - query_start; + const int dist_thr = std::round(m_settings.flank_err * query_span); + + const uint64_t template_start = 0; + const uint64_t template_end = std::min(m_settings.strand_start_flank, adapter_start); + const uint64_t template_span = template_end - template_start; + + if (adapter_end + m_settings.strand_end_flank > r_l || template_span < query_span || + check_rc_match( + read.seq, + {r_l - m_settings.strand_end_flank, r_l - m_settings.strand_end_trim}, + {0, std::min(m_settings.strand_start_flank, r_l)}, dist_thr)) { + return PosRange{adapter_start - 1, adapter_start}; + } + } + } + return std::nullopt; +} + +std::optional DuplexSplitNode::identify_extra_middle_split( + const Read& read) const { + const uint64_t r_l = read.seq.size(); + //TODO parameterize + const float ext_start_frac = 0.1; + //extend to tolerate some extra length difference + const uint64_t ext_start_flank = + std::max(uint64_t(ext_start_frac * r_l), m_settings.strand_start_flank); + //further consider only reasonably long reads + if (ext_start_flank + m_settings.strand_end_flank > r_l) { + return std::nullopt; + } + + int flank_edist = std::round(m_settings.flank_err * + (m_settings.strand_end_flank - m_settings.strand_end_trim)); + + spdlog::trace("Checking start/end match"); + if (auto templ_start_match = check_rc_match( + read.seq, {r_l - m_settings.strand_end_flank, r_l - m_settings.strand_end_trim}, + {0, std::min(r_l, ext_start_flank)}, flank_edist)) { + //check if matched region and query overlap + if (templ_start_match->second + m_settings.strand_end_flank > r_l) { + return std::nullopt; + } + uint64_t est_middle = (templ_start_match->second + (r_l - m_settings.strand_end_flank)) / 2; + spdlog::trace("Middle estimate {}", est_middle); + //TODO parameterize + const int min_split_margin = 100; + const float split_margin_frac = 0.05; + const auto split_margin = std::max(min_split_margin, int(split_margin_frac * r_l)); + + spdlog::trace("Checking approx middle match"); + if (auto middle_match_ranges = + check_flank_match(read, {est_middle - split_margin, est_middle + split_margin}, + m_settings.flank_err)) { + est_middle = + (middle_match_ranges->first.second + middle_match_ranges->second.first) / 2; + spdlog::trace("Middle re-estimate {}", est_middle); + return PosRange{est_middle - 1, est_middle}; } } return std::nullopt; @@ -320,39 +421,38 @@ std::vector> DuplexSplitNode::subreads( std::vector> DuplexSplitNode::build_split_finders() const { std::vector> split_finders; - split_finders.push_back( - {"PORE_ADAPTER", [&](const ExtRead& read) { - return filter_ranges( - possible_pore_regions(read, m_settings.pore_thr), [&](PosRange r) { - return check_nearby_adapter(*read.read, r, m_settings.adapter_edist); - }); - }}); + split_finders.push_back({"PORE_ADAPTER", [&](const ExtRead& read) { + return filter_ranges(read.possible_pore_regions, [&](PosRange r) { + return check_nearby_adapter(*read.read, r, + m_settings.adapter_edist); + }); + }}); if (!m_settings.simplex_mode) { split_finders.push_back( {"PORE_FLANK", [&](const ExtRead& read) { return merge_ranges( - filter_ranges(possible_pore_regions(read, m_settings.pore_thr), + filter_ranges(read.possible_pore_regions, [&](PosRange r) { return check_flank_match(*read.read, r, - m_settings.flank_edist); + m_settings.flank_err); }), - m_settings.end_flank + m_settings.start_flank); + m_settings.strand_end_flank + m_settings.strand_start_flank); }}); split_finders.push_back( {"PORE_ALL", [&](const ExtRead& read) { return merge_ranges( - filter_ranges(possible_pore_regions(read, m_settings.relaxed_pore_thr), + filter_ranges(read.possible_pore_regions, [&](PosRange r) { return check_nearby_adapter( *read.read, r, m_settings.relaxed_adapter_edist) && check_flank_match( *read.read, r, - m_settings.relaxed_flank_edist); + m_settings.relaxed_flank_err); }), - m_settings.end_flank + m_settings.start_flank); + m_settings.strand_end_flank + m_settings.strand_start_flank); }}); split_finders.push_back( @@ -363,11 +463,19 @@ DuplexSplitNode::build_split_finders() const { [&](PosRange r) { return check_flank_match(*read.read, {r.first, r.first}, - m_settings.flank_edist); + m_settings.flank_err); }); }}); split_finders.push_back({"ADAPTER_MIDDLE", [&](const ExtRead& read) { + if (auto split = identify_middle_adapter_split(*read.read)) { + return PosRanges{*split}; + } else { + return PosRanges(); + } + }}); + + split_finders.push_back({"SPLIT_MIDDLE", [&](const ExtRead& read) { if (auto split = identify_extra_middle_split(*read.read)) { return PosRanges{*split}; } else { @@ -392,7 +500,7 @@ std::vector> DuplexSplitNode::split(std::shared_ptr return std::vector>{std::move(init_read)}; } - std::vector to_split{ExtRead(init_read)}; + std::vector to_split{create_ext_read(init_read)}; for (const auto& [description, split_f] : m_split_finders) { spdlog::trace("Running {}", description); std::vector split_round_result; @@ -405,7 +513,7 @@ std::vector> DuplexSplitNode::split(std::shared_ptr split_round_result.push_back(std::move(r)); } else { for (auto sr : subreads(r.read, spacers)) { - split_round_result.emplace_back(sr); + split_round_result.push_back(create_ext_read(sr)); } } } diff --git a/dorado/read_pipeline/DuplexSplitNode.h b/dorado/read_pipeline/DuplexSplitNode.h index 342457361..179886275 100644 --- a/dorado/read_pipeline/DuplexSplitNode.h +++ b/dorado/read_pipeline/DuplexSplitNode.h @@ -17,25 +17,28 @@ namespace dorado { struct DuplexSplitSettings { bool enabled = true; bool simplex_mode = false; - float pore_thr = 160.; - size_t pore_cl_dist = 4000; // TODO maybe use frequency * 1sec here? - float relaxed_pore_thr = 150.; + float pore_thr = 2.2; + uint64_t pore_cl_dist = 4000; // TODO maybe use frequency * 1sec here? + //maximal 'open pore' region to consider (bp) + uint64_t max_pore_region = 500; //usually template read region to the left of potential spacer region - size_t end_flank = 1200; + uint64_t strand_end_flank = 1200; //trim potentially erroneous (and/or PCR adapter) bases at end of query - size_t end_trim = 200; + uint64_t strand_end_trim = 200; //adjusted to adapter presense and potential loss of bases on query, leading to 'shift' - size_t start_flank = 1700; - int flank_edist = 150; - int relaxed_flank_edist = 250; + uint64_t strand_start_flank = 1700; + //minimal query size to consider in "short read" case + uint64_t min_flank = 300; + float flank_err = 0.15; + float relaxed_flank_err = 0.275; int adapter_edist = 4; - int relaxed_adapter_edist = 6; + int relaxed_adapter_edist = 8; uint64_t pore_adapter_range = 100; //bp //in bases uint64_t expect_adapter_prefix = 200; //in samples uint64_t expect_pore_prefix = 5000; - int middle_adapter_search_span = 1000; + uint64_t middle_adapter_search_span = 1000; float middle_adapter_search_frac = 0.2; //TODO put in config @@ -68,15 +71,18 @@ class DuplexSplitNode : public MessageSink { std::shared_ptr read; torch::Tensor data_as_float32; std::vector move_sums; - - explicit ExtRead(std::shared_ptr r); + PosRanges possible_pore_regions; }; typedef std::function SplitFinderF; - std::vector possible_pore_regions(const ExtRead& read, float pore_thr) const; + ExtRead create_ext_read(std::shared_ptr r) const; + std::vector possible_pore_regions(const ExtRead& read) const; bool check_nearby_adapter(const Read& read, PosRange r, int adapter_edist) const; - bool check_flank_match(const Read& read, PosRange r, int dist_thr) const; + std::optional> check_flank_match(const Read& read, + PosRange r, + float err_thr) const; + std::optional identify_middle_adapter_split(const Read& read) const; std::optional identify_extra_middle_split(const Read& read) const; std::vector> subreads(std::shared_ptr read, diff --git a/tests/DuplexSplitTest.cpp b/tests/DuplexSplitTest.cpp index 42964248d..c275597c2 100644 --- a/tests/DuplexSplitTest.cpp +++ b/tests/DuplexSplitTest.cpp @@ -75,6 +75,19 @@ TEST_CASE("4 subread splitting test", TEST_GROUP) { "2023-02-21T12:46:39.607+00:00", "2023-02-21T12:46:53.105+00:00"}); + std::vector start_time_mss; + for (auto &r : split_res) { + start_time_mss.push_back(r->start_time_ms); + } + REQUIRE(start_time_mss == + std::vector{1676983561529, 1676983585837, 1676983599607, 1676983613105}); + + std::vector num_sampless; + for (auto &r : split_res) { + num_sampless.push_back(r->attributes.num_samples); + } + REQUIRE(num_sampless == std::vector{97125, 55055, 53940, 50475}); + std::set names; for (auto &r : split_res) { names.insert(r->read_id);