Skip to content

Commit

Permalink
Merge branch 'r941-v3.3-5mCG-5hmCG-models' into 'master'
Browse files Browse the repository at this point in the history
R9.4.1 v3.3 5mCG and 5hmCG models

See merge request machine-learning/dorado!457
  • Loading branch information
MarkBicknellONT committed Jul 6, 2023
2 parents ba40e53 + 93bf7b8 commit a57987f
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 5 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ The following simplex models are also available:
* dna_r9.4.1_e8_fast@v3.4_5mCG@v0
* dna_r9.4.1_e8_hac@v3.3_5mCG@v0
* dna_r9.4.1_e8_sup@v3.3_5mCG@v0
* dna_r9.4.1_e8_fast@v3.4_5mCG_5hmCG@v0
* dna_r9.4.1_e8_hac@v3.3_5mCG_5hmCG@v0
* dna_r9.4.1_e8_sup@v3.3_5mCG_5hmCG@v0
* dna_r10.4.1_e8.2_260bps_fast@v3.5.2_5mCG@v2
* dna_r10.4.1_e8.2_260bps_hac@v3.5.2_5mCG@v2
* dna_r10.4.1_e8.2_260bps_sup@v3.5.2_5mCG@v2
Expand Down
6 changes: 6 additions & 0 deletions dorado/nn/CRFModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,12 @@ CRFModelConfig load_crf_model_config(const std::filesystem::path &path) {
config.signal_norm_params.scale_multiplier = toml::find<float>(norm, "scale_multiplier");
}

// Set quantile scaling method based on the model filename
std::string model_name = std::filesystem::canonical(config.model_path).filename().string();
if (model_name.rfind("dna_r9.4.1", 0) == 0) {
config.signal_norm_params.quantile_scaling = false;
}

return config;
}

Expand Down
1 change: 1 addition & 0 deletions dorado/nn/CRFModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ struct SignalNormalisationParams {
float quantile_b = 0.9f;
float shift_multiplier = 0.51f;
float scale_multiplier = 0.53f;
bool quantile_scaling = true;
};

// Values extracted from config.toml used in construction of the model module.
Expand Down
2 changes: 1 addition & 1 deletion dorado/read_pipeline/ReadPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ void Read::generate_read_tags(bam1_t *aln, bool emit_moves) const {
float sd = scale;
bam_aux_append(aln, "sd", 'f', sizeof(sd), (uint8_t *)&sd);

bam_aux_append(aln, "sv", 'Z', 9, (uint8_t *)"quantile");
bam_aux_append(aln, "sv", 'Z', scaling_method.size() + 1, (uint8_t *)scaling_method.c_str());

uint32_t duplex = 0;
bam_aux_append(aln, "dx", 'i', sizeof(duplex), (uint8_t *)&duplex);
Expand Down
5 changes: 3 additions & 2 deletions dorado/read_pipeline/ReadPipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ class Read {
uint64_t start_time_ms;
uint64_t get_end_time_ms();

float shift; // To be set by scaler
float scale; // To be set by scaler
float shift; // To be set by scaler
float scale; // To be set by scaler
std::string scaling_method; // To be set by scaler

float scaling; // Scale factor applied to convert raw integers from sequencer into pore current values

Expand Down
19 changes: 17 additions & 2 deletions dorado/read_pipeline/ScalerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <chrono>
#include <utility>

#define EPS 1e-9f

using namespace std::chrono_literals;
using Slice = torch::indexing::Slice;

Expand All @@ -25,14 +27,27 @@ std::pair<float, float> ScalerNode::normalisation(const torch::Tensor& x) {
return {shift, scale};
}

// Computes med/MAD normalisation parameters for a signal.
//
// Returns {shift, scale} where
//   shift = median(x)
//   scale = median(|x - median(x)|) * 1.4826 + EPS
//
// The 1.4826 factor rescales the median absolute deviation so that it is a
// consistent estimator of the standard deviation for normally distributed data.
// See https://en.wikipedia.org/wiki/Median_absolute_deviation
// (specifically the "Relation to standard deviation" section).
// EPS keeps the returned scale non-zero for a constant signal, since the
// caller divides by it.
std::pair<float, float> ScalerNode::med_mad(const torch::Tensor& x) {
    constexpr float factor = 1.4826f;

    // Raw signal arrives as int16. Evaluating `x - med` (and abs) in int16 can
    // wrap around when a sample is far from the median, so do the arithmetic in
    // float32, which represents every int16 value and difference exactly.
    const auto signal = x.to(torch::kFloat);

    // Calculate signal median and median absolute deviation
    const auto med = signal.median();
    const auto mad = torch::median(torch::abs(signal - med)) * factor + EPS;

    return {med.item<float>(), mad.item<float>()};
}

void ScalerNode::worker_thread() {
Message message;
while (m_work_queue.try_pop(message)) {
// If this message isn't a read, we'll get a bad_variant_access exception.
auto read = std::get<std::shared_ptr<Read>>(message);

assert(read->raw_data.dtype() == torch::kInt16);
const auto [shift, scale] = normalisation(read->raw_data);
const auto [shift, scale] = m_scaling_params.quantile_scaling
? normalisation(read->raw_data)
: med_mad(read->raw_data);
read->scaling_method = m_scaling_params.quantile_scaling ? "quantile" : "med_mad";

// raw_data comes from DataLoader with dtype int16. We send it on as float16 after
// shifting/scaling in float32 form.
read->raw_data = ((read->raw_data.to(torch::kFloat) - shift) / scale).to(torch::kFloat16);
Expand Down
1 change: 1 addition & 0 deletions dorado/read_pipeline/ScalerNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ScalerNode : public MessageSink {

SignalNormalisationParams m_scaling_params;

std::pair<float, float> med_mad(const torch::Tensor& x);
std::pair<float, float> normalisation(const torch::Tensor& x);
};

Expand Down
5 changes: 5 additions & 0 deletions dorado/utils/models.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ static const std::vector<std::string> models = {
// RNA002
"rna002_70bps_fast@v3",
"rna002_70bps_hac@v3",

// RNA003
"rna003_120bps_sup@v3",

Expand Down Expand Up @@ -89,6 +90,10 @@ static const std::vector<std::string> models = {
"dna_r9.4.1_e8_hac@v3.3_5mCG@v0",
"dna_r9.4.1_e8_sup@v3.3_5mCG@v0",

"dna_r9.4.1_e8_fast@v3.4_5mCG_5hmCG@v0",
"dna_r9.4.1_e8_hac@v3.3_5mCG_5hmCG@v0",
"dna_r9.4.1_e8_sup@v3.3_5mCG_5hmCG@v0",

// v3.5.2
"dna_r10.4.1_e8.2_260bps_fast@v3.5.2_5mCG@v2",
"dna_r10.4.1_e8.2_260bps_hac@v3.5.2_5mCG@v2",
Expand Down
1 change: 1 addition & 0 deletions tests/ReadTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ TEST_CASE(TEST_GROUP ": Test tag generation", TEST_GROUP) {
test_read.sample_rate = 4000.0;
test_read.shift = 128.3842f;
test_read.scale = 8.258f;
test_read.scaling_method = "quantile";
test_read.num_trimmed_samples = 132;
test_read.attributes.mux = 2;
test_read.attributes.read_number = 18501;
Expand Down

0 comments on commit a57987f

Please sign in to comment.