From 8c2d004d71c9c21fb7bfbe283ba44bc100a67793 Mon Sep 17 00:00:00 2001 From: Joyjit Daw Date: Thu, 11 Jan 2024 17:15:05 +0000 Subject: [PATCH] Merge branch 'jdaw/fix-modbase-trim-reverse' into 'master' Correctly trim modbase tags for reverse strand alignments Closes DOR-523 See merge request machine-learning/dorado!797 (cherry picked from commit 59a445b75c05000c0e5b371f83e8b60034859ce5) c562c9ab Correctly trim modbase tags for reverse strand alignments --- dorado/demux/Trimmer.cpp | 17 +++++++++-- dorado/read_pipeline/HtsWriter.cpp | 7 +++++ tests/TrimTest.cpp | 30 +++++++++++++++++++ tests/data/trimmer/reverse_strand_record.bam | Bin 0 -> 1531 bytes 4 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 tests/data/trimmer/reverse_strand_record.bam diff --git a/dorado/demux/Trimmer.cpp b/dorado/demux/Trimmer.cpp index 1c3879ab1..3c3276086 100644 --- a/dorado/demux/Trimmer.cpp +++ b/dorado/demux/Trimmer.cpp @@ -1,6 +1,7 @@ #include "Trimmer.h" #include "utils/bam_utils.h" +#include "utils/sequence_utils.h" #include "utils/trim.h" #include @@ -30,6 +31,13 @@ void trim_torch_tensor(at::Tensor& raw_data, std::pair sample raw_data = raw_data.index({Slice(sample_trim_interval.first, sample_trim_interval.second)}); } +// For alignments that are reverse complemented, the trim interval derived from adapters/barcodes +// will need to be reverse complemented when applied to the trimming of modbase tags because +// modbase tags are all relative to the original sequence that was basecalled. +std::pair reverse_complement_interval(const std::pair& interval, int seqlen) { + return {seqlen - interval.second, seqlen - interval.first}; +} + } // namespace namespace dorado { @@ -122,6 +130,8 @@ std::pair Trimmer::determine_trim_interval(const AdapterScoreResult& r BamPtr Trimmer::trim_sequence(BamPtr input, std::pair trim_interval) { bam1_t* input_record = input.get(); + bool is_seq_reversed = input_record->core.flag & BAM_FREVERSE; + // Fetch components that need to be trimmed. std::string seq = utils::extract_sequence(input_record); std::vector qual = utils::extract_quality(input_record); @@ -141,8 +151,10 @@ BamPtr Trimmer::trim_sequence(BamPtr input, std::pair trim_interval) { // |---------------------- ns ------------------| // |----ts----|--------moves signal-------------| ns = int(trimmed_moves.size() * stride) + ts; - auto [trimmed_modbase_str, trimmed_modbase_probs] = - utils::trim_modbase_info(seq, modbase_str, modbase_probs, trim_interval); + auto [trimmed_modbase_str, trimmed_modbase_probs] = utils::trim_modbase_info( + is_seq_reversed ? utils::reverse_complement(seq) : seq, modbase_str, modbase_probs, + is_seq_reversed ? reverse_complement_interval(trim_interval, int(seq.length())) + : trim_interval); auto n_cigar = input_record->core.n_cigar; std::vector ops; uint32_t ref_pos_consumed = 0; @@ -180,6 +192,7 @@ BamPtr Trimmer::trim_sequence(BamPtr input, std::pair trim_interval) { bam_aux_del(out_record, bam_aux_get(out_record, "ML")); bam_aux_update_array(out_record, "ML", 'C', int(trimmed_modbase_probs.size()), (uint8_t*)trimmed_modbase_probs.data()); + bam_aux_update_int(out_record, "MN", trimmed_seq.length()); } bam_aux_update_int(out_record, "ts", ts); diff --git a/dorado/read_pipeline/HtsWriter.cpp b/dorado/read_pipeline/HtsWriter.cpp index 0055c1937..960d8a88d 100644 --- a/dorado/read_pipeline/HtsWriter.cpp +++ b/dorado/read_pipeline/HtsWriter.cpp @@ -136,6 +136,13 @@ int HtsWriter::write(bam1_t* const record) { } m_primary = m_total - m_secondary - m_supplementary - m_unmapped; + // Verify that the MN tag, if it exists, and the sequence length are in sync. + if (auto tag = bam_aux_get(record, "MN"); tag != nullptr) { + if (bam_aux2i(tag) != record->core.l_qseq) { + throw std::runtime_error("MN tag and sequence length are not in sync."); + }; + } + // FIXME -- HtsWriter is constructed in a state where attempting to write // will segfault, since set_and_write_header has to have been called // in order to set m_header. diff --git a/tests/TrimTest.cpp b/tests/TrimTest.cpp index c70651656..f5761205d 100644 --- a/tests/TrimTest.cpp +++ b/tests/TrimTest.cpp @@ -1,15 +1,24 @@ #include "utils/trim.h" +#include "TestUtils.h" +#include "demux/Trimmer.h" +#include "read_pipeline/HtsReader.h" + #include #include +#include +#include #include +using Catch::Matchers::Equals; using Slice = at::indexing::Slice; using namespace dorado; #define TEST_GROUP "[utils][trim]" +namespace fs = std::filesystem; + TEST_CASE("Test trim signal", TEST_GROUP) { constexpr int signal_len = 2000; @@ -161,3 +170,24 @@ TEST_CASE("Test trim mod base info", TEST_GROUP) { CHECK(probs.size() == 0); } } + +// This test case is useful because trimming of reverse strand requires +// the modbase tags to be treated differently since they are written +// relative to the original sequence that was basecalled. +TEST_CASE("Test trim of reverse strand record in BAM", TEST_GROUP) { + const auto data_dir = fs::path(get_data_dir("trimmer")); + const auto bam_file = data_dir / "reverse_strand_record.bam"; + HtsReader reader(bam_file.string(), std::nullopt); + reader.read(); + auto &record = reader.record; + + Trimmer trimmer; + const std::pair trim_interval = {72, 647}; + auto trimmed_record = trimmer.trim_sequence(std::move(record), trim_interval); + auto seqlen = trimmed_record->core.l_qseq; + + CHECK(seqlen == (trim_interval.second - trim_interval.first)); + CHECK(bam_aux2i(bam_aux_get(trimmed_record.get(), "MN")) == seqlen); + CHECK_THAT(bam_aux2Z(bam_aux_get(trimmed_record.get(), "MM")), + Equals("C+h?,28,24;C+m?,28,24;")); +} diff --git a/tests/data/trimmer/reverse_strand_record.bam b/tests/data/trimmer/reverse_strand_record.bam new file mode 100644 index 0000000000000000000000000000000000000000..b2d688981b253c16e0b2695561e425a13843babd GIT binary patch literal 1531 zcmV=}kX570}(DITJ1!Hc?X z%I#hqwdtoJdJ1@v!?0j(?iCEE2c@f+qD(nCI6&#Zd(~Py;4oS4m24=%Hl|X$7^F=| zerRQTI{0TB113nUxEQZ&Hncq?Vjiw1#0Hn3V99~NpiiMs*pdpT+DAtWNc zNR}wbc&=JqsH&2?y42O^|EE^BcoBUI+8Y#-?WUn*o38uk0j5@S+HcFztwx||!sj6& zAk#F?BoonWEz_BxD2x9dn11$*vV<3#hN39pV7p%D5rFB7YX+^_9qYPkkw|{=H8kp= zY4pQc=dFIT`njcWN10u^#$8vJyT<+d$HNz;)05@&0RR9WiwFb&00000{{{d;LjnMb z1VvQKZrer_Wn;%y61|3^mC=KbnIUC5u>_K$EGa3FP$EzQt&Oy5H`)lYL?JC=#j%_; zTN1OFqT4R|1KsBjx@fq+&DM}*XWD_9}3Q~m#%teL-MFAI-3d(s(I2RBSQ6wm#30Js+ z2q$(VLI{nhh=L>%Bd~HD0^djw$Vq8JP$=U-PBHKc9NHrc^(YQ`l5kFCEK-3|64~I7 zct8XXfsvxf7NZ16vBK>UN)Sm*&Y_7^B*{o291Dz6TZ9qDGF1{B27JWB6yXrmBvG0K zP+LLRq_WIqL1n}vj%-CFBmx15L=i$!l-gL>&;?DQW!a9nNbSIeNq`E9(?sS`6a+lN znRZYbBq6s0ISC0F1(D4ok;rzkvzsfIOU2@y4EaJKUnmv|2%(Bn$mjFLa)wg*R<2ko zl?vc4mdfZZ#<*6)wR*i)t;ygnmt|Q-@GB#9P(3K$zO}`je7>L@RI7LEnyTum+Eg`z z5k~77bi*LT&^0Z$4*)`)lC6Uf%WNIqGwU?~+TYvT%hjp}l?qk~rAD)yhfu@ZyZ?`gW)l1x|U^eY7|FOSi!?OIKQa$nEfW&J_kkmm(QIZK_JPKNqF^%bK>{JV0qvM3G-w!G zZO{}my;j?Um}Xr&f@!cU#DZ`N!yrb+kkzIzEmdVL2Be0b83?2r1iFK8pagN2)yk5g z5T*DopxJ=&nCysImR>!83*Fz(<#KgsQ?npf)l^MmHe`TS4YX<*w+7RLOJShVWW>yf z4Mq>)zSJT>gBKWzacEomHU%_UFgUBVRb;|gO#*+}xS_Y~if|ovqExOoRl~@7H?tZ# zJ2siFt64b%xSQ4O#;J{Z1(v5=Ds9&gHl|ptRFooYVIi}@R#hq$88$L|QREk!vAle_ z^1l7!(Rwra*u36M-hXuQdc3=MI{AHPb3OGOuiJLLw$uC4^&a&62mVQ?-yPg{oCl60 zU9Bg-sjKz$N${25JMo9J!AZP`2lK^jaqgWgdV}+s>vsp8%je6Uw7FRQ_5H7-&2pvu z{4UtMn!b2BUtfRyES8q9$I=&{O#jOnE_$9j^p=BpZ|?Ty{d2GHxo&s3isQKN^?S4B zdOo{yosQpeXYrup&3wl>zue3=FE1az@;hF~kxtJhZ*H8PPRIA3K059mcYDX4|8RW& z;{Stm`q^l_2j9dUT-vqXot=(%E;f^Ic7w^}&)xNUEG^^J)1Q8Pczqc>SzL_k8+ZJ6 zMWx3D h001A02m}BC000301^_}s0stET0{{R300000004~a);$0K literal 0 HcmV?d00001