From 1b6083b9e93ec7d750258da1201ff91cff4b2626 Mon Sep 17 00:00:00 2001 From: Chris Seymour Date: Wed, 28 Jun 2023 15:01:40 +0100 Subject: [PATCH 1/7] adding r941 v3.x 5hmCG+5mCG models --- README.md | 3 +++ dorado/utils/models.h | 5 +++++ 2 files changed, 8 insertions(+) mode change 100755 => 100644 dorado/utils/models.h diff --git a/README.md b/README.md index 48d72ead4..2a4006eee 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,9 @@ The following simplex models are also available: * dna_r9.4.1_e8_fast@v3.4_5mCG@v0 * dna_r9.4.1_e8_hac@v3.3_5mCG@v0 * dna_r9.4.1_e8_sup@v3.3_5mCG@v0 +* dna_r9.4.1_e8_fast@v3.4_5mCG_5hmCG@v0 +* dna_r9.4.1_e8_hac@v3.3_5mCG_5hmCG@v0 +* dna_r9.4.1_e8_sup@v3.3_5mCG_5hmCG@v0 * dna_r10.4.1_e8.2_260bps_fast@v3.5.2_5mCG@v2 * dna_r10.4.1_e8.2_260bps_hac@v3.5.2_5mCG@v2 * dna_r10.4.1_e8.2_260bps_sup@v3.5.2_5mCG@v2 diff --git a/dorado/utils/models.h b/dorado/utils/models.h old mode 100755 new mode 100644 index 59a2a4137..ba214a8a2 --- a/dorado/utils/models.h +++ b/dorado/utils/models.h @@ -57,6 +57,7 @@ static const std::vector models = { // RNA002 "rna002_70bps_fast@v3", "rna002_70bps_hac@v3", + // RNA003 "rna003_120bps_sup@v3", @@ -89,6 +90,10 @@ static const std::vector models = { "dna_r9.4.1_e8_hac@v3.3_5mCG@v0", "dna_r9.4.1_e8_sup@v3.3_5mCG@v0", + "dna_r9.4.1_e8_hac@v3.4_5mCG_5hmCG@v0", + "dna_r9.4.1_e8_hac@v3.3_5mCG_5hmCG@v0", + "dna_r9.4.1_e8_sup@v3.3_5mCG_5hmCG@v0", + // v3.5.2 "dna_r10.4.1_e8.2_260bps_fast@v3.5.2_5mCG@v2", "dna_r10.4.1_e8.2_260bps_hac@v3.5.2_5mCG@v2", From 3a021f4d76018abe2ddcc41cdb3e9bff6bd3ca3b Mon Sep 17 00:00:00 2001 From: Chris Seymour Date: Wed, 28 Jun 2023 15:02:49 +0100 Subject: [PATCH 2/7] hac -> fast --- dorado/utils/models.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dorado/utils/models.h b/dorado/utils/models.h index ba214a8a2..fdbcf1fa4 100644 --- a/dorado/utils/models.h +++ b/dorado/utils/models.h @@ -90,7 +90,7 @@ static const std::vector models = { "dna_r9.4.1_e8_hac@v3.3_5mCG@v0", "dna_r9.4.1_e8_sup@v3.3_5mCG@v0", - "dna_r9.4.1_e8_hac@v3.4_5mCG_5hmCG@v0", + "dna_r9.4.1_e8_fast@v3.4_5mCG_5hmCG@v0", "dna_r9.4.1_e8_hac@v3.3_5mCG_5hmCG@v0", "dna_r9.4.1_e8_sup@v3.3_5mCG_5hmCG@v0", From 2a1f5595c4dcf6219d133dae128bd7f38d6546fd Mon Sep 17 00:00:00 2001 From: Chris Seymour Date: Mon, 3 Jul 2023 18:10:27 +0100 Subject: [PATCH 3/7] switch back to medmad scaling --- dorado/read_pipeline/ScalerNode.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/dorado/read_pipeline/ScalerNode.cpp b/dorado/read_pipeline/ScalerNode.cpp index 3d578fe4d..bdecd08b9 100644 --- a/dorado/read_pipeline/ScalerNode.cpp +++ b/dorado/read_pipeline/ScalerNode.cpp @@ -23,13 +23,23 @@ std::pair ScalerNode::normalisation(torch::Tensor& x) { return {shift, scale}; } +#define EPS 1e-9f + +std::pair calculate_med_mad(torch::Tensor& x, float factor = 1.4826) { + //Calculate signal median and median absolute deviation + auto med = x.median(); + auto mad = torch::median(torch::abs(x - med)) * factor + EPS; + + return {med.item(), mad.item()}; +} + void ScalerNode::worker_thread() { Message message; while (m_work_queue.try_pop(message)) { // If this message isn't a read, we'll get a bad_variant_access exception. auto read = std::get>(message); - const auto [shift, scale] = normalisation(read->raw_data); + const auto [shift, scale] = calculate_med_mad(read->raw_data); // raw_data comes from DataLoader with dtype int16. We send it on as float16 after // shifting/scaling in float32 form. read->raw_data = ((read->raw_data.to(torch::kFloat) - shift) / scale).to(torch::kFloat16); From 69f35c465632ee7d812ac51f8e8040a3770aeac4 Mon Sep 17 00:00:00 2001 From: Chris Seymour Date: Wed, 5 Jul 2023 11:00:57 +0100 Subject: [PATCH 4/7] choose scaling method from model name --- dorado/cli/basecaller.cpp | 6 ++++++ dorado/nn/CRFModel.h | 1 + dorado/read_pipeline/ScalerNode.cpp | 11 ++++++----- dorado/read_pipeline/ScalerNode.h | 1 + 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index b92a0dbb0..170d5adf6 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -268,6 +268,12 @@ void setup(std::vector args, const int kBatchTimeoutMS = 100; BasecallerNode basecaller_node(*basecaller_node_sink, std::move(runners), overlap, kBatchTimeoutMS, model_name, 1000); + + if (model_name.rfind("dna_r9.4.1", 0) == 0) { + spdlog::debug("- using medmad scaling"); + model_config.signal_norm_params.quantile_scaling = false; + } + ScalerNode scaler_node(basecaller_node, model_config.signal_norm_params, thread_allocations.scaler_node_threads); diff --git a/dorado/nn/CRFModel.h b/dorado/nn/CRFModel.h index 77149cc25..88f9e2094 100644 --- a/dorado/nn/CRFModel.h +++ b/dorado/nn/CRFModel.h @@ -15,6 +15,7 @@ struct SignalNormalisationParams { float quantile_b = 0.9f; float shift_multiplier = 0.51f; float scale_multiplier = 0.53f; + bool quantile_scaling = true; }; // Values extracted from config.toml used in construction of the model module. diff --git a/dorado/read_pipeline/ScalerNode.cpp b/dorado/read_pipeline/ScalerNode.cpp index bdecd08b9..a39ef3c92 100644 --- a/dorado/read_pipeline/ScalerNode.cpp +++ b/dorado/read_pipeline/ScalerNode.cpp @@ -7,6 +7,8 @@ #include #include +#define EPS 1e-9f + using namespace std::chrono_literals; using Slice = torch::indexing::Slice; @@ -23,13 +25,10 @@ std::pair ScalerNode::normalisation(torch::Tensor& x) { return {shift, scale}; } -#define EPS 1e-9f - -std::pair calculate_med_mad(torch::Tensor& x, float factor = 1.4826) { +std::pair ScalerNode::med_mad(torch::Tensor& x, float factor = 1.4826) { //Calculate signal median and median absolute deviation auto med = x.median(); auto mad = torch::median(torch::abs(x - med)) * factor + EPS; - return {med.item(), mad.item()}; } @@ -39,7 +38,9 @@ void ScalerNode::worker_thread() { // If this message isn't a read, we'll get a bad_variant_access exception. auto read = std::get>(message); - const auto [shift, scale] = calculate_med_mad(read->raw_data); + const auto [shift, scale] = m_scaling_params.quantile_scaling + ? normalisation(read->raw_data) + : med_mad(read->raw_data); // raw_data comes from DataLoader with dtype int16. We send it on as float16 after // shifting/scaling in float32 form. read->raw_data = ((read->raw_data.to(torch::kFloat) - shift) / scale).to(torch::kFloat16); diff --git a/dorado/read_pipeline/ScalerNode.h b/dorado/read_pipeline/ScalerNode.h index 4dd6be623..999c9f195 100644 --- a/dorado/read_pipeline/ScalerNode.h +++ b/dorado/read_pipeline/ScalerNode.h @@ -29,6 +29,7 @@ class ScalerNode : public MessageSink { SignalNormalisationParams m_scaling_params; + std::pair med_mad(torch::Tensor& x, float factor); std::pair normalisation(torch::Tensor& x); }; From ea76b3b698c67f9e00e495abfbbb1e476d6d1309 Mon Sep 17 00:00:00 2001 From: Mark Bicknell Date: Wed, 5 Jul 2023 18:06:04 +0100 Subject: [PATCH 5/7] Review changes: - Made med_mad factor argument into a constexpr - Added scaling_method to Read class to track med_mad vs quantile --- dorado/read_pipeline/ReadPipeline.cpp | 2 +- dorado/read_pipeline/ReadPipeline.h | 5 +++-- dorado/read_pipeline/ScalerNode.cpp | 6 +++++- dorado/read_pipeline/ScalerNode.h | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index 6467eacd0..ffe5b7b25 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -88,7 +88,7 @@ void Read::generate_read_tags(bam1_t *aln, bool emit_moves) const { float sd = scale; bam_aux_append(aln, "sd", 'f', sizeof(sd), (uint8_t *)&sd); - bam_aux_append(aln, "sv", 'Z', 9, (uint8_t *)"quantile"); + bam_aux_append(aln, "sv", 'Z', scaling_method.size() + 1, (uint8_t *)scaling_method.c_str()); uint32_t duplex = 0; bam_aux_append(aln, "dx", 'i', sizeof(duplex), (uint8_t *)&duplex); diff --git a/dorado/read_pipeline/ReadPipeline.h b/dorado/read_pipeline/ReadPipeline.h index 010951f49..991c2eb62 100644 --- a/dorado/read_pipeline/ReadPipeline.h +++ b/dorado/read_pipeline/ReadPipeline.h @@ -67,8 +67,9 @@ class Read { uint64_t start_time_ms; uint64_t get_end_time_ms(); - float shift; // To be set by scaler - float scale; // To be set by scaler + float shift; // To be set by scaler + float scale; // To be set by scaler + std::string scaling_method; // To be set by scaler float scaling; // Scale factor applied to convert raw integers from sequencer into pore current values diff --git a/dorado/read_pipeline/ScalerNode.cpp b/dorado/read_pipeline/ScalerNode.cpp index a39ef3c92..25d4cc0d2 100644 --- a/dorado/read_pipeline/ScalerNode.cpp +++ b/dorado/read_pipeline/ScalerNode.cpp @@ -25,7 +25,10 @@ std::pair ScalerNode::normalisation(torch::Tensor& x) { return {shift, scale}; } -std::pair ScalerNode::med_mad(torch::Tensor& x, float factor = 1.4826) { +std::pair ScalerNode::med_mad(torch::Tensor& x) { + // See https://en.wikipedia.org/wiki/Median_absolute_deviation + // (specifically the "Relation to standard deviation" section) + constexpr float factor = 1.4826; //Calculate signal median and median absolute deviation auto med = x.median(); auto mad = torch::median(torch::abs(x - med)) * factor + EPS; @@ -41,6 +44,7 @@ void ScalerNode::worker_thread() { const auto [shift, scale] = m_scaling_params.quantile_scaling ? normalisation(read->raw_data) : med_mad(read->raw_data); + read->scaling_method = m_scaling_params.quantile_scaling ? "quantile" : "med_mad"; // raw_data comes from DataLoader with dtype int16. We send it on as float16 after // shifting/scaling in float32 form. read->raw_data = ((read->raw_data.to(torch::kFloat) - shift) / scale).to(torch::kFloat16); diff --git a/dorado/read_pipeline/ScalerNode.h b/dorado/read_pipeline/ScalerNode.h index 999c9f195..5b6b8a7e0 100644 --- a/dorado/read_pipeline/ScalerNode.h +++ b/dorado/read_pipeline/ScalerNode.h @@ -29,7 +29,7 @@ class ScalerNode : public MessageSink { SignalNormalisationParams m_scaling_params; - std::pair med_mad(torch::Tensor& x, float factor); + std::pair med_mad(torch::Tensor& x); std::pair normalisation(torch::Tensor& x); }; From 84f8256d8d36cd2d743b908a85cacd9da63d3e5e Mon Sep 17 00:00:00 2001 From: Mark Bicknell Date: Wed, 5 Jul 2023 18:19:26 +0100 Subject: [PATCH 6/7] Moved setting of signal_norm_params.quantile_scaling into load_crf_model_config --- dorado/cli/basecaller.cpp | 1 - dorado/nn/CRFModel.cpp | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index 0191988cd..e484d155e 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -184,7 +184,6 @@ void setup(std::vector args, if (model_name.rfind("dna_r9.4.1", 0) == 0) { spdlog::debug("- using medmad scaling"); - model_config.signal_norm_params.quantile_scaling = false; } ScalerNode scaler_node(basecaller_node, model_config.signal_norm_params, diff --git a/dorado/nn/CRFModel.cpp b/dorado/nn/CRFModel.cpp index cd2c9a854..71a1764f6 100644 --- a/dorado/nn/CRFModel.cpp +++ b/dorado/nn/CRFModel.cpp @@ -872,6 +872,12 @@ CRFModelConfig load_crf_model_config(const std::filesystem::path &path) { config.signal_norm_params.scale_multiplier = toml::find(norm, "scale_multiplier"); } + // Set quantile scaling method based on the model filename + std::string model_name = std::filesystem::canonical(config.model_path).filename().string(); + if (model_name.rfind("dna_r9.4.1", 0) == 0) { + config.signal_norm_params.quantile_scaling = false; + } + return config; } From 8c62929f8b23e560b57032462d172efe04e4fbe5 Mon Sep 17 00:00:00 2001 From: Mark Bicknell Date: Wed, 5 Jul 2023 18:22:29 +0100 Subject: [PATCH 7/7] Unit test fix. --- tests/ReadTest.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ReadTest.cpp b/tests/ReadTest.cpp index ebc197eac..f0da94a0c 100644 --- a/tests/ReadTest.cpp +++ b/tests/ReadTest.cpp @@ -17,6 +17,7 @@ TEST_CASE(TEST_GROUP ": Test tag generation", TEST_GROUP) { test_read.sample_rate = 4000.0; test_read.shift = 128.3842f; test_read.scale = 8.258f; + test_read.scaling_method = "quantile"; test_read.num_trimmed_samples = 132; test_read.attributes.mux = 2; test_read.attributes.read_number = 18501;