Skip to content

Commit

Permalink
Merge branch 'smalton/DOR-857-poly-improvements' into 'master'
Browse files Browse the repository at this point in the history
[DOR-857] PolyA estimation updates

Closes DOR-857

See merge request machine-learning/dorado!1185
  • Loading branch information
malton-ont committed Sep 16, 2024
2 parents 8e3a870 + 6666df5 commit 0b79407
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 56 deletions.
2 changes: 1 addition & 1 deletion dorado/demux/AdapterDetector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct Adapter {
const std::vector<Adapter> adapters = {
{"LSK109", "AATGTACTTCGTTCAGTTACGTATTGCT", "AGCAATACGTAACTGAACGAAGT"},
{"LSK110", "CCTGTACTTCGTTCAGTTACGTATTGC", "AGCAATACGTAACTGAAC"},
{"RNA004", "", "GGTTGTTTCTGTTGGTGCTGATATTGC"}};
{"RNA004", "", "GGTTGTTTCTGTTGGTGCTG"}};

// For primers, we look for each primer sequence, and its reverse complement, at both the front and rear of the read.
struct Primer {
Expand Down
9 changes: 9 additions & 0 deletions dorado/poly_tail/dna_poly_tail_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ class DNAPolyTailCalculator : public PolyTailCalculator {
float average_samples_per_base(const std::vector<float>& sizes) const override;
int signal_length_adjustment(int) const override { return 0; };
float min_avg_val() const override { return -3.0f; }
std::pair<int, int> buffer_range(const std::pair<int, int>& interval,
                                 [[maybe_unused]] float samples_per_base) const override {
    // Both sides of the interval get a buffer equal to the interval's own
    // length. This heuristic generally works because a longer detected
    // interval is more likely to be the correct one, so we relax how close
    // it needs to be to the anchor to account for errors in anchor
    // determination.
    const int interval_length = interval.second - interval.first;
    return {interval_length, interval_length};
}
std::pair<int, int> signal_range(int signal_anchor,
int signal_len,
float samples_per_base) const override;
Expand Down
141 changes: 93 additions & 48 deletions dorado/poly_tail/poly_tail_calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,12 @@ std::pair<int, int> PolyTailCalculator::determine_signal_bounds(int signal_ancho
return {avg, std::sqrt(var)};
};

std::pair<float, float> last_interval_stats;

// Maximum variance between consecutive values to be
// considered part of the same interval.
const float kVar = 0.35f;
// How close the mean values should be for consecutive intervals
// to be merged.
const float kMeanValueProximity = 0.2f;
const float kMeanValueProximity = 0.25f;
// Maximum gap between intervals that can be combined.
const int kMaxSampleGap = int(std::round(num_samples_per_base * 5));
// Minimum size of intervals considered for merge.
Expand All @@ -74,48 +72,86 @@ std::pair<int, int> PolyTailCalculator::determine_signal_bounds(int signal_ancho
spdlog::trace("Bounds left {}, right {}", left_end, right_end);

std::vector<std::pair<int, int>> intervals;
std::pair<float, float> last_interval_stats;
const int kStride = 3;

// Tries to fold the most recent interval into the one before it.
// Returns true when a merge happened (the last interval is removed and the
// previous one is extended to cover it), false otherwise.
auto try_merge_latest_intervals = [&, right = right_end]() {  // keep clang-tidy happy
    if (intervals.size() < 2) {
        // Nothing to merge with.
        return false;
    }

    auto& newest = intervals.back();
    auto& older = intervals[intervals.size() - 2];

    // The two intervals are merged only when all of the following hold:
    // 1. the gap between the intervals is small,
    // 2. the averages of the two intervals are close,
    // 3. the first interval is longer than some threshold,
    // 4. the second interval is longer than some threshold or reaches the end of the range.
    spdlog::trace("Evaluate for merge {}-{} with {}-{}", older.first, older.second,
                  newest.first, newest.second);

    const auto [avg_1, stdev_1] = calc_stats(older.first, older.second);
    const auto [avg_2, stdev_2] = calc_stats(newest.first, newest.second);

    const int older_length = older.second - older.first;
    const int newest_length = newest.second - newest.first;

    const bool gap_is_small = (newest.first - older.second) < kMaxSampleGap;
    const bool means_are_close = std::abs(avg_2 - avg_1) < kMeanValueProximity;
    const bool older_long_enough = older_length > kMinIntervalSizeForMerge;
    const bool newest_long_enough_or_at_end =
            (newest_length > kMinIntervalSizeForMerge) || (newest.second >= (right - kStride));

    if (!(gap_is_small && means_are_close && older_long_enough &&
          newest_long_enough_or_at_end)) {
        return false;
    }

    spdlog::trace("Merge interval {}-{} with {}-{}", older.first, older.second, newest.first,
                  newest.second);
    older.second = newest.second;

    // `newest` must not be used past this point: pop_back destroys it.
    intervals.pop_back();
    return true;
};

bool in_range = false;
for (int s = left_end; s < (right_end - kMaxSampleGap); s += kStride) {
const int e = s + kMaxSampleGap;
auto [avg, stdev] = calc_stats(s, e);
if (stdev < kVar) {
// If a new interval overlaps with the previous interval, just extend
// the previous interval.
if (intervals.size() > 1 && intervals.back().second >= s &&
std::abs(avg - last_interval_stats.first) < kMeanValueProximity &&
(avg > kMinAvgVal)) {
auto& last = intervals.back();
spdlog::trace("extend interval {}-{} to {}-{} avg {} stdev {}", last.first,
last.second, s, e, avg, stdev);
last.second = e;
} else {
// Attempt to merge the most recent interval and the one before
// that if the gap between the intervals is small and both of the
// intervals are longer than some threshold.
if (intervals.size() >= 2) {
auto& last = intervals.back();
auto& second_last = intervals[intervals.size() - 2];
spdlog::trace("Evaluate for merge {}-{} with {}-{}", second_last.first,
second_last.second, last.first, last.second);
if ((last.first - second_last.second < kMaxSampleGap) &&
(last.second - last.first > kMinIntervalSizeForMerge) &&
(second_last.second - second_last.first > kMinIntervalSizeForMerge)) {
spdlog::trace("Merge interval {}-{} with {}-{}", second_last.first,
second_last.second, second_last.first, last.second);
second_last.second = last.second;
intervals.pop_back();
} else if (second_last.second - second_last.first <
std::round(num_samples_per_base * m_config.min_base_count)) {
intervals.erase(intervals.end() - 2);
}
}
if (avg > kMinAvgVal && stdev < kVar) {
if (intervals.empty()) {
spdlog::trace("Add new interval {}-{} avg {} stdev {}", s, e, avg, stdev);
intervals.push_back({s, e});
} else {
// If new interval overlaps with the previous interval and
// intervals have a similar mean, just extend the previous interval.
auto last_interval = intervals.rbegin();
if (last_interval->second >= s &&
std::abs(avg - last_interval_stats.first) < kMeanValueProximity) {
// recalc stats for new interval
std::tie(avg, stdev) = calc_stats(last_interval->first, e);
spdlog::trace("extend interval {}-{} to {}-{} avg {} stdev {}",
last_interval->first, last_interval->second, last_interval->first,
e, avg, stdev);
last_interval->second = e;
} else {
try_merge_latest_intervals();
spdlog::trace("Add new interval {}-{} avg {} stdev {}", s, e, avg, stdev);
intervals.push_back({s, e});
}
}
last_interval_stats = {avg, stdev};
in_range = true;
} else if (in_range) {
if (try_merge_latest_intervals()) {
// recalc stats for new interval
auto last_interval = intervals.rbegin();
last_interval_stats = calc_stats(last_interval->first, last_interval->second);
}
in_range = false;
}
}

if (in_range) {
// We won't have attempted to merge the final two ranges if we were still in a range at the end
try_merge_latest_intervals();
}

std::string int_str = "";
for (const auto& in : intervals) {
int_str += std::to_string(in.first) + "-" + std::to_string(in.second) + ", ";
Expand All @@ -124,18 +160,26 @@ std::pair<int, int> PolyTailCalculator::determine_signal_bounds(int signal_ancho

// Cluster intervals if there are interrupted poly tails that should
// be combined. Interruption length is specified through a config file.
// In the example below, tail estimation show include both stretches
// In the example below, tail estimation should include both stretches
// of As along with the small gap in the middle.
// e.g. -----AAAAAAA--AAAAAA-----
const int kMaxInterruption =
static_cast<int>(std::round(num_samples_per_base * m_config.tail_interrupt_length));
const size_t num_bases = read.read_common.seq.length();
const auto num_samples = read.read_common.get_raw_data_samples();
const auto stride = read.read_common.model_stride;
const auto seq_to_sig_map =
dorado::utils::moves_to_map(read.read_common.moves, stride, num_samples, num_bases + 1);

std::vector<std::pair<int, int>> clustered_intervals;
for (const auto& i : intervals) {
if (clustered_intervals.empty()) {
clustered_intervals.push_back(i);
} else {
auto& last = clustered_intervals.back();
if (std::abs(i.first - last.second) < kMaxInterruption) {
auto start = std::lower_bound(std::begin(seq_to_sig_map), std::end(seq_to_sig_map),
last.second);
auto end = std::upper_bound(start, std::end(seq_to_sig_map), i.first);
auto bases = static_cast<int>(std::distance(start, end));
if (bases <= m_config.tail_interrupt_length) {
last.second = i.second;
} else {
clustered_intervals.push_back(i);
Expand All @@ -154,18 +198,19 @@ std::pair<int, int> PolyTailCalculator::determine_signal_bounds(int signal_ancho
std::vector<std::pair<int, int>> filtered_intervals;
std::copy_if(clustered_intervals.begin(), clustered_intervals.end(),
std::back_inserter(filtered_intervals), [&](auto& i) {
int buffer = i.second - i.first;
auto buffer = buffer_range(i, num_samples_per_base);
// Only keep intervals that are close-ish to the signal anchor.
// i.e. the anchor needs to be within the buffer region of
// the interval. The buffer is currently the length of the interval
// itself. This heuristic generally works because a longer interval
// detected is likely to be the correct one so we relax the
// how close it needs to be to the anchor to account for errors
// in anchor determination.
// <----buffer---|--- interval ---|---- buffer---->
bool within_anchor_dist = (signal_anchor >= std::max(0, i.first - buffer)) &&
(signal_anchor <= (i.second + buffer));
return within_anchor_dist;
// the interval
// <----buffer.first---|--- interval ---|---- buffer.second---->
bool within_anchor_dist =
(signal_anchor >= std::max(0, i.first - buffer.first)) &&
(signal_anchor <= (i.second + buffer.second));
bool meets_min_base_count =
(i.second - i.first) >=
std::round(num_samples_per_base * m_config.min_base_count);

return within_anchor_dist && meets_min_base_count;
});

int_str = "";
Expand Down
4 changes: 4 additions & 0 deletions dorado/poly_tail/poly_tail_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class PolyTailCalculator {
// Floor for average signal value of poly tail.
virtual float min_avg_val() const = 0;

// Returns the acceptable distance between the supplied interval and the anchor
virtual std::pair<int, int> buffer_range(const std::pair<int, int>& interval,
float samples_per_base) const = 0;

// Determine the outer boundary of the signal space to consider based on the anchor.
virtual std::pair<int, int> signal_range(int signal_anchor,
int signal_len,
Expand Down
2 changes: 1 addition & 1 deletion dorado/poly_tail/poly_tail_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ PolyTailConfig update_config(const toml::value& config_toml, PolyTailConfig conf

if (config_toml.contains("threshold")) {
const auto& threshold = toml::find(config_toml, "threshold");
if (threshold.contains("flank_threshold ")) {
if (threshold.contains("flank_threshold")) {
config.flank_threshold = toml::find<float>(threshold, "flank_threshold");
}
}
Expand Down
4 changes: 2 additions & 2 deletions dorado/poly_tail/poly_tail_config.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#pragma once

#include <istream>
#include <iosfwd>
#include <string>
#include <vector>

namespace dorado::poly_tail {

struct PolyTailConfig {
std::string rna_adapter = "GGTTGTTTCTGTTGGTGCTGATATTGC"; // RNA
std::string rna_adapter = "GGTTGTTTCTGTTGGTGCTG"; // RNA
std::string front_primer = "TTTCTGTTGGTGCTGATATTGCTTT"; // SSP
std::string rear_primer = "ACTTGCCTGTCGCTCTATCTTCAGAGGAGAGTCCGCCGCCCGCAAGTTTT"; // VNP
std::string rc_front_primer;
Expand Down
22 changes: 20 additions & 2 deletions dorado/poly_tail/rna_poly_tail_calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <spdlog/spdlog.h>

#include <algorithm>
#include <cmath>

namespace {
EdlibAlignConfig init_edlib_config_for_adapter() {
Expand All @@ -24,6 +25,11 @@ RNAPolyTailCalculator::RNAPolyTailCalculator(PolyTailConfig config, bool is_rna_
: PolyTailCalculator(std::move(config)), m_rna_adapter(is_rna_adapter) {}

float RNAPolyTailCalculator::average_samples_per_base(const std::vector<float>& sizes) const {
const auto log_sum =
std::accumulate(std::cbegin(sizes), std::cend(sizes), 0.f,
[](float total, float size) { return total + std::log(size); });
const auto samples_per_base1 = sizes.empty() ? 0 : std::exp(log_sum / sizes.size());

auto quantiles = dorado::utils::quantiles(sizes, {0.1f, 0.9f});
float sum = 0.f;
int count = 0;
Expand All @@ -33,7 +39,8 @@ float RNAPolyTailCalculator::average_samples_per_base(const std::vector<float>&
count++;
}
}
return (count > 0 ? (sum / count) : 0.f);
const auto samples_per_base2 = (count > 0 ? (sum / count) : 0.f);
return (samples_per_base1 + samples_per_base2) / 2;
}

SignalAnchorInfo RNAPolyTailCalculator::determine_signal_anchor_and_strand(
Expand Down Expand Up @@ -69,7 +76,7 @@ SignalAnchorInfo RNAPolyTailCalculator::determine_signal_anchor_and_strand(
read.read_common.moves, stride, read.read_common.get_raw_data_samples(),
read.read_common.seq.size() + 1);

const int base_anchor = bottom_start + align_result.startLocations[0] - m_config.rna_offset;
const int base_anchor = bottom_start + align_result.startLocations[0];
// RNA sequence is reversed wrt the signal and move table
const int signal_anchor =
int(seq_to_sig_map[static_cast<int>(seq_view.length()) - base_anchor]);
Expand Down Expand Up @@ -99,4 +106,15 @@ std::pair<int, int> RNAPolyTailCalculator::signal_range(int signal_anchor,
return {std::max(0, signal_anchor - 50), std::min(signal_len, signal_anchor + kSpread)};
}

std::pair<int, int> RNAPolyTailCalculator::buffer_range(const std::pair<int, int>& interval,
                                                        float samples_per_base) const {
    // By default the buffer on each side is the length of the interval itself.
    const int interval_length = interval.second - interval.first;

    int front_buffer = interval_length;
    if (m_rna_adapter) {
        // Extend the buffer towards the front of the read as there may be
        // something between the adapter and the polytail.
        front_buffer += static_cast<int>(std::round(m_config.rna_offset * samples_per_base));
    }
    return {front_buffer, interval_length};
}

} // namespace dorado::poly_tail
2 changes: 2 additions & 0 deletions dorado/poly_tail/rna_poly_tail_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class RNAPolyTailCalculator : public PolyTailCalculator {
float average_samples_per_base(const std::vector<float>& sizes) const override;
int signal_length_adjustment(int signal_len) const override;
float min_avg_val() const override { return -0.5f; }
std::pair<int, int> buffer_range(const std::pair<int, int>& interval,
float samples_per_base) const override;
std::pair<int, int> signal_range(int signal_anchor,
int signal_len,
float samples_per_base) const override;
Expand Down
4 changes: 2 additions & 2 deletions tests/PolyACalculatorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ struct TestCase {

TEST_CASE("PolyACalculator: Test polyT tail estimation", TEST_GROUP) {
auto [gt, data, is_rna] = GENERATE(
TestCase{143, "poly_a/r9_rev_cdna", false}, TestCase{35, "poly_a/r10_fwd_cdna", false},
TestCase{37, "poly_a/rna002", true}, TestCase{73, "poly_a/rna004", true});
TestCase{149, "poly_a/r9_rev_cdna", false}, TestCase{35, "poly_a/r10_fwd_cdna", false},
TestCase{39, "poly_a/rna002", true}, TestCase{76, "poly_a/rna004", true});

CAPTURE(data);
dorado::PipelineDescriptor pipeline_desc;
Expand Down

0 comments on commit 0b79407

Please sign in to comment.