diff --git a/motif-scanner/Cargo.toml b/motif-scanner/Cargo.toml index cd8cb12..ed52592 100644 --- a/motif-scanner/Cargo.toml +++ b/motif-scanner/Cargo.toml @@ -8,6 +8,7 @@ license = "MIT" [dependencies] tf-binding-rs = { path = "../tf-binding-rs" } -clap = { version = "4.5.21", features = ["derive", "color"] } -polars = { version = "0.44.2", features = ["lazy", "csv", "strings", "regex"] } -thiserror = "2.0.3" +clap = { version = "4.5.23", features = ["derive", "color"] } +polars = { version = "0.45.1", features = ["lazy", "csv", "strings", "regex", "parquet"] } +thiserror = "2.0.8" +rayon = "1.5.1" diff --git a/motif-scanner/src/main.rs b/motif-scanner/src/main.rs index 33397cd..b70769a 100644 --- a/motif-scanner/src/main.rs +++ b/motif-scanner/src/main.rs @@ -1,5 +1,6 @@ use clap::Parser; use polars::prelude::*; +use rayon::prelude::*; use std::fs; use std::path::Path; use tf_binding_rs::occupancy::{read_pwm_to_ewm, total_landscape}; @@ -66,6 +67,35 @@ struct Args { mu: i32, } +trait UnzipN { + fn unzip_n_vec(self) -> (Vec, Vec, Vec, Vec, Vec, Vec); +} + +impl UnzipN for I +where + I: Iterator, +{ + fn unzip_n_vec(self) -> (Vec, Vec, Vec, Vec, Vec, Vec) { + let mut a_vec = Vec::new(); + let mut b_vec = Vec::new(); + let mut c_vec = Vec::new(); + let mut d_vec = Vec::new(); + let mut e_vec = Vec::new(); + let mut f_vec = Vec::new(); + + for (a, b, c, d, e, f) in self { + a_vec.push(a); + b_vec.push(b); + c_vec.push(c); + d_vec.push(d); + e_vec.push(e); + f_vec.push(f); + } + + (a_vec, b_vec, c_vec, d_vec, e_vec, f_vec) + } +} + fn process_sequences( df: &DataFrame, ewm: &EWMCollection, @@ -76,58 +106,67 @@ fn process_sequences( .column("sequence") .map_err(|_| ScannerError::MissingSequenceColumn)?; - let mut labels: Vec = Vec::new(); - let mut positions: Vec = Vec::new(); - let mut motifs: Vec = Vec::new(); - let mut strands: Vec = Vec::new(); - let mut lengths: Vec = Vec::new(); - let mut occupancies: Vec = Vec::new(); - let total_seqs = sequences.len(); println!("{} sequences to scan", total_seqs); - for (idx, seq) in sequences.str()?.into_iter().enumerate() { - if let Some(sequence) = seq { - let landscape = total_landscape(sequence, ewm, mu).map_err(|_| { - ScannerError::PwmError(format!( - "Error calculating occupancy landscape for sequence {}", - idx - )) - })?; - - // Get the height (number of positions) of the landscape DataFrame - let n_positions = landscape.height(); - - // Iterate through each motif in the EWM collection - for (motif_id, motif_df) in ewm.iter() { - // Check both forward and reverse strands - for strand in ["F", "R"] { - let col_name = format!("{}_{}", motif_id, strand); - - // Get the column for this motif+strand from the landscape - if let Ok(motif_col) = landscape.column(&col_name) { - // Iterate through positions - for pos in 0..n_positions { - if let Ok(occ) = motif_col.get(pos).unwrap().try_extract::() { - if occ > cutoff { - labels.push(idx as i32); - positions.push(pos as i32); - motifs.push(motif_id.split('_').next().unwrap().to_string()); - strands.push(strand.to_string()); - lengths.push(motif_df.height() as i32); - occupancies.push(occ); + // convert ChunkedArray to Vec for parallel processing + let sequences_vec: Vec<_> = sequences.str()?.into_iter().collect(); + + // Parallel processing of sequences + let results: Vec<_> = sequences_vec + .into_par_iter() + .enumerate() + .filter_map(|(idx, seq)| { + seq.map(|sequence| { + let landscape = match total_landscape(sequence, ewm, mu) { + Ok(l) => l, + Err(_) => return Vec::new(), + }; + + let n_positions = landscape.height(); + let mut local_results = Vec::new(); + + // Iterate through each motif in the EWM collection + for (motif_id, motif_df) in ewm.iter() { + // Check both forward and reverse strands + for strand in ["F", "R"] { + let col_name = format!("{}_{}", motif_id, strand); + + // Get the column for this motif+strand from the landscape + if let Ok(motif_col) = landscape.column(&col_name) { + // Iterate through positions + for pos in 0..n_positions { + if let Ok(occ) = motif_col.get(pos).unwrap().try_extract::() { + if occ > cutoff { + local_results.push(( + idx as i32, + pos as i32, + motif_id.split('_').next().unwrap().to_string(), + strand.to_string(), + motif_df.height() as i32, + occ, + )); + } } } } } } - } - } - - if (idx + 1) % 5000 == 0 { - println!("\t{} / {} sequences scanned", idx + 1, total_seqs); - } - } + local_results + }) + }) + .flatten() + .collect(); + + // Unzip results into separate vectors + let (labels, positions, motifs, strands, lengths, occupancies): ( + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + ) = results.into_iter().unzip_n_vec(); let df = DataFrame::new(vec![ Column::new("label".into(), labels), @@ -141,6 +180,26 @@ fn process_sequences( Ok(df) } +fn save_results(df: &mut DataFrame, output_file: &str) -> Result<(), ScannerError> { + match Path::new(output_file) + .extension() + .and_then(|ext| ext.to_str()) + { + Some("parquet") => { + let mut file = std::fs::File::create(output_file)?; + ParquetWriter::new(&mut file) + .with_compression(ParquetCompression::Snappy) + .finish(df)?; + } + _ => { + let mut file = std::fs::File::create(output_file)?; + CsvWriter::new(&mut file).include_header(true).finish(df)?; + } + } + + Ok(()) +} + fn main() -> Result<(), ScannerError> { let start_time = std::time::Instant::now(); @@ -161,7 +220,7 @@ fn main() -> Result<(), ScannerError> { // read pwm file and convert to ewm let ewm = read_pwm_to_ewm(&args.pwm_file).map_err(|e| ScannerError::PwmError(e.to_string()))?; - let results_df = process_sequences(&df, &ewm, args.mu as f64, args.cutoff)?; + let mut results_df = process_sequences(&df, &ewm, args.mu as f64, args.cutoff)?; let elapsed = start_time.elapsed(); println!( @@ -169,5 +228,8 @@ fn main() -> Result<(), ScannerError> { elapsed.as_secs_f64() / 60.0 ); + // save results + save_results(&mut results_df, &args.output_file)?; + Ok(()) } diff --git a/tf-binding-rs/src/occupancy.rs b/tf-binding-rs/src/occupancy.rs index 762cc4a..a449fda 100644 --- a/tf-binding-rs/src/occupancy.rs +++ b/tf-binding-rs/src/occupancy.rs @@ -127,7 +127,7 @@ pub fn read_pwm_files(filename: &str) -> Result { let file = File::open(filename)?; let reader = BufReader::new(file); let mut lines = reader.lines().peekable(); - let mut pwms = HashMap::new(); + let mut pwms = PWMCollection::new(); // Skip header until first MOTIF skip_until_motif(&mut lines);