From 3ab7b8ba571f58734e28979baa97418aea9e2bf5 Mon Sep 17 00:00:00 2001 From: Daniel Voogsgerd Date: Thu, 14 Dec 2023 15:00:03 +0100 Subject: [PATCH] fix(SW): Sequential and sequential_straight algs --- rust/src/algorithm/mod.rs | 32 +++++------- rust/src/lib.rs | 107 ++++++++++++++++++-------------------- rust/src/main.rs | 81 +++++++++++------------------ 3 files changed, 93 insertions(+), 127 deletions(-) diff --git a/rust/src/algorithm/mod.rs b/rust/src/algorithm/mod.rs index d4fd093..0f9fe00 100644 --- a/rust/src/algorithm/mod.rs +++ b/rust/src/algorithm/mod.rs @@ -37,13 +37,10 @@ pub fn string_scores_sequential( if x >= y { continue; } - if y > target.len() + x - 1 { + if y > target.len() + x { continue; } - let lhs = query[x - 1]; - let rhs = target[y - x]; - - let sub_score = if lhs == rhs { + let sub_score = if query[x - 1] == target[y - x - 1] { scores.r#match } else { scores.miss @@ -233,10 +230,7 @@ where .collect(); // Padding the target to the next whole number of LANES - let target_u16: Vec<_> = target - .into_iter() - .map(|x| *x as u16 + 1) - .collect(); + let target_u16: Vec<_> = target.into_iter().map(|x| *x as u16 + 1).collect(); let width = query_u16.len() + 1; assert!(query_u16.len() >= query.len()); @@ -461,6 +455,8 @@ pub fn string_scores_parallel( let width = query.len() + 1; let height = query.len() + target.len() + 1; + let threads = min(threads, query.len()); + let mut data = Vec::with_capacity(width * height); let data_ptr = SendPtr(data.as_mut_ptr()); @@ -479,19 +475,16 @@ pub fn string_scores_parallel( let data_ref = unsafe { std::slice::from_raw_parts_mut(data_ptr.0, width * height) }; - for y in 1..height { - let max_x = min(right, y); - for x in left..max_x { - // HOT LOOP - // PERF: Probably faster to kickstart the top and bottom - // On a single thread and save ourselves the branching - // inside the hot loop - - if target.len() + x - 1 < y { + for y in 2..height { + for x in left..right { + if x >= y { + continue; + } + if y > target.len() + x { continue; } - let sub_score = if query[x - 1] == target[y - x] { + let sub_score = if query[x - 1] == target[y - x - 1] { scores.r#match } else { scores.miss @@ -516,7 +509,6 @@ pub fn string_scores_parallel( child }); - // TODO: Probably not necessary anymore due to the thread scope for handle in handles.collect::>() { handle.join().unwrap(); } diff --git a/rust/src/lib.rs b/rust/src/lib.rs index fc1d549..9d715ce 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -25,20 +25,19 @@ pub fn find_alignment_sequential( let width = query.len() + 1; // TODO: Find max index - if let Some((index, _value)) = data.iter().enumerate().max_by_key(|(_i, x)| *x) { - let (x, y) = coord(index, width); - traceback_straight( - &data, - query, - target, - x, - y, - width, - &mut query_result, - &mut target_result, - scores, - ); - } + let max_index = data.argmax(); + let (x, y) = coord(max_index, width); + traceback( + &data, + query, + target, + x, + y, + width, + &mut query_result, + &mut target_result, + scores, + ); (query_result, target_result, 0) } @@ -51,24 +50,22 @@ pub fn find_alignment_sequential_straight( let data = string_scores_straight(query, target, scores); let mut query_result = Vec::with_capacity(query.len()); let mut target_result = Vec::with_capacity(target.len()); - visualize_straight(&data, query, target); let width = query.len() + 1; - // TODO: Find max index - if let Some((index, _value)) = data.iter().enumerate().max_by_key(|(_i, x)| *x) { - let (x, y) = coord(index, width); - traceback_straight( - &data, - query, - target, - x, - y, - width, - &mut query_result, - &mut target_result, - scores, - ); - } + let (_, max_index) = data.argminmax(); + + let (x, y) = coord(max_index, width); + traceback_straight( + &data, + query, + target, + x, + y, + width, + &mut query_result, + &mut target_result, + scores, + ); (query_result, target_result, 0) } @@ -84,21 +81,22 @@ pub fn find_alignment_parallel( let mut target_result = Vec::with_capacity(target.len()); let width = query.len() + 1; - // TODO: Find max index - if let Some((index, _value)) = data.iter().enumerate().max_by_key(|(_i, x)| *x) { - let (x, y) = coord(index, width); - traceback( - &data, - query, - target, - x, - y, - width, - &mut query_result, - &mut target_result, - scores, - ); + if data.is_empty() { + return (query_result, target_result, 0); } + let argmax = data.argmax(); + let (x, y) = coord(argmax, width); + traceback( + &data, + query, + target, + x, + y, + width, + &mut query_result, + &mut target_result, + scores, + ); (query_result, target_result, 0) } @@ -418,16 +416,15 @@ mod tests { } // TODO: Fix algorihm - // #[test] - // fn test_all_sequential() { - // test_all(find_alignment_sequential); - // } + #[test] + fn test_all_sequential() { + test_all(find_alignment_sequential); + } - // TODO: Fix algorihm - // #[test] - // fn test_all_sequential_straight() { - // test_all(find_alignment_sequential_straight); - // } + #[test] + fn test_all_sequential_straight() { + test_all(find_alignment_sequential_straight); + } // fn find_alignment_parallel_wrapper( // query: &[char], @@ -436,8 +433,8 @@ mod tests { // ) -> AlignResult { // return find_alignment_parallel(query, target, THREADS, scores); // } - - // TODO: Fix algorihm + // + // // TODO: Fix algorihm // #[test] // fn test_all_parallel() { // test_all(find_alignment_parallel_wrapper::<2>); diff --git a/rust/src/main.rs b/rust/src/main.rs index a91bf25..27f898f 100644 --- a/rust/src/main.rs +++ b/rust/src/main.rs @@ -1,18 +1,16 @@ -use std::iter::repeat; +use rand::Rng; +use sw::{ + algorithm::{find_alignment_simd_lowmem, AlignmentScores}, + find_alignment_simd, +}; + #[allow(unused_imports)] use sw::algorithm::{ string_scores_parallel, string_scores_simd, string_scores_straight, traceback, }; -use argminmax::ArgMinMax; - const LANES: usize = 64; -use sw::{ - algorithm::AlignmentScores, - utils::{coord, roundup}, -}; - fn main() { let scores = AlignmentScores { gap: -2, @@ -20,55 +18,34 @@ fn main() { miss: -3, }; - let mut query = repeat('A').take(128).collect::>(); - let mut target = repeat('T').take(20000000).collect::>(); - - target[67] = 'Z'; - target[68] = 'Z'; - query[13] = 'Z'; - query[14] = 'Z'; - - let start = std::time::Instant::now(); - - let data = string_scores_simd::(&query, &target, scores); + let mut rng = rand::thread_rng(); - // PERF: We might be able to eek out a bit more perf here by ignoring the min in a custom - // implementation - let (_min_index, max_index) = data.argminmax(); - let max = data[max_index]; - // let (max_index, max) = data.iter().enumerate().max_by_key(|(_, x)| *x).unwrap(); + let charset = ['A', 'T', 'C', 'G']; - let end = std::time::Instant::now(); + for _i in 0..10000 { + let query: Vec<_> = (0..) + .map(|_| { + let i: u8 = rng.gen(); + let char = charset[i as usize % charset.len()]; - let duration = end - start; - println!("Query * Target: {}", query.len() * target.len()); - println!("Duration: {:?}", duration.as_micros()); + char + }) + .take(1000) + .collect(); - println!( - "MCUPS: {}", - (query.len() * target.len()) / (duration.as_micros() as usize) - ); + let target: Vec<_> = (0..) + .map(|_| { + let i: u8 = rng.gen(); + let char = charset[i as usize % charset.len()]; - println!("Max: {}", max); - println!("Max index: {}", max_index); - let data_width = roundup(query.len(), LANES) + 1; - let (x, y) = coord(max_index, data_width); - println!("x: {}; y: {}", x - 1, y - x); - let mut query_result = Vec::new(); - let mut target_result = Vec::new(); + char + }) + .take(1000) + .collect(); - traceback( - &data, - &query, - &target, - x, - y, - data_width, - &mut query_result, - &mut target_result, - scores, - ); + let lm = find_alignment_simd_lowmem::(&query, &target, scores); + let sd = find_alignment_simd::(&query, &target, scores); - println!("Query: {:?}", query_result); - println!("Target: {:?}", target_result); + assert_eq!(lm, sd); + } }