diff --git a/motif-scanner/Cargo.toml b/motif-scanner/Cargo.toml
index cd8cb12..ed52592 100644
--- a/motif-scanner/Cargo.toml
+++ b/motif-scanner/Cargo.toml
@@ -8,6 +8,7 @@ license = "MIT"
[dependencies]
tf-binding-rs = { path = "../tf-binding-rs" }
-clap = { version = "4.5.21", features = ["derive", "color"] }
-polars = { version = "0.44.2", features = ["lazy", "csv", "strings", "regex"] }
-thiserror = "2.0.3"
+clap = { version = "4.5.23", features = ["derive", "color"] }
+polars = { version = "0.45.1", features = ["lazy", "csv", "strings", "regex", "parquet"] }
+thiserror = "2.0.8"
+rayon = "1.5.1"
diff --git a/motif-scanner/src/main.rs b/motif-scanner/src/main.rs
index 33397cd..b70769a 100644
--- a/motif-scanner/src/main.rs
+++ b/motif-scanner/src/main.rs
@@ -1,5 +1,6 @@
use clap::Parser;
use polars::prelude::*;
+use rayon::prelude::*;
use std::fs;
use std::path::Path;
use tf_binding_rs::occupancy::{read_pwm_to_ewm, total_landscape};
@@ -66,6 +67,35 @@ struct Args {
mu: i32,
}
+trait UnzipN {
+ fn unzip_n_vec(self) -> (Vec, Vec, Vec, Vec, Vec, Vec);
+}
+
+impl UnzipN for I
+where
+ I: Iterator- ,
+{
+ fn unzip_n_vec(self) -> (Vec, Vec, Vec, Vec, Vec, Vec) {
+ let mut a_vec = Vec::new();
+ let mut b_vec = Vec::new();
+ let mut c_vec = Vec::new();
+ let mut d_vec = Vec::new();
+ let mut e_vec = Vec::new();
+ let mut f_vec = Vec::new();
+
+ for (a, b, c, d, e, f) in self {
+ a_vec.push(a);
+ b_vec.push(b);
+ c_vec.push(c);
+ d_vec.push(d);
+ e_vec.push(e);
+ f_vec.push(f);
+ }
+
+ (a_vec, b_vec, c_vec, d_vec, e_vec, f_vec)
+ }
+}
+
fn process_sequences(
df: &DataFrame,
ewm: &EWMCollection,
@@ -76,58 +106,67 @@ fn process_sequences(
.column("sequence")
.map_err(|_| ScannerError::MissingSequenceColumn)?;
- let mut labels: Vec = Vec::new();
- let mut positions: Vec = Vec::new();
- let mut motifs: Vec = Vec::new();
- let mut strands: Vec = Vec::new();
- let mut lengths: Vec = Vec::new();
- let mut occupancies: Vec = Vec::new();
-
let total_seqs = sequences.len();
println!("{} sequences to scan", total_seqs);
- for (idx, seq) in sequences.str()?.into_iter().enumerate() {
- if let Some(sequence) = seq {
- let landscape = total_landscape(sequence, ewm, mu).map_err(|_| {
- ScannerError::PwmError(format!(
- "Error calculating occupancy landscape for sequence {}",
- idx
- ))
- })?;
-
- // Get the height (number of positions) of the landscape DataFrame
- let n_positions = landscape.height();
-
- // Iterate through each motif in the EWM collection
- for (motif_id, motif_df) in ewm.iter() {
- // Check both forward and reverse strands
- for strand in ["F", "R"] {
- let col_name = format!("{}_{}", motif_id, strand);
-
- // Get the column for this motif+strand from the landscape
- if let Ok(motif_col) = landscape.column(&col_name) {
- // Iterate through positions
- for pos in 0..n_positions {
- if let Ok(occ) = motif_col.get(pos).unwrap().try_extract::() {
- if occ > cutoff {
- labels.push(idx as i32);
- positions.push(pos as i32);
- motifs.push(motif_id.split('_').next().unwrap().to_string());
- strands.push(strand.to_string());
- lengths.push(motif_df.height() as i32);
- occupancies.push(occ);
+ // convert ChunkedArray to Vec for parallel processing
+ let sequences_vec: Vec<_> = sequences.str()?.into_iter().collect();
+
+ // Parallel processing of sequences
+ let results: Vec<_> = sequences_vec
+ .into_par_iter()
+ .enumerate()
+ .filter_map(|(idx, seq)| {
+ seq.map(|sequence| {
+ let landscape = match total_landscape(sequence, ewm, mu) {
+ Ok(l) => l,
+ Err(_) => return Vec::new(),
+ };
+
+ let n_positions = landscape.height();
+ let mut local_results = Vec::new();
+
+ // Iterate through each motif in the EWM collection
+ for (motif_id, motif_df) in ewm.iter() {
+ // Check both forward and reverse strands
+ for strand in ["F", "R"] {
+ let col_name = format!("{}_{}", motif_id, strand);
+
+ // Get the column for this motif+strand from the landscape
+ if let Ok(motif_col) = landscape.column(&col_name) {
+ // Iterate through positions
+ for pos in 0..n_positions {
+ if let Ok(occ) = motif_col.get(pos).unwrap().try_extract::() {
+ if occ > cutoff {
+ local_results.push((
+ idx as i32,
+ pos as i32,
+ motif_id.split('_').next().unwrap().to_string(),
+ strand.to_string(),
+ motif_df.height() as i32,
+ occ,
+ ));
+ }
}
}
}
}
}
- }
- }
-
- if (idx + 1) % 5000 == 0 {
- println!("\t{} / {} sequences scanned", idx + 1, total_seqs);
- }
- }
+ local_results
+ })
+ })
+ .flatten()
+ .collect();
+
+ // Unzip results into separate vectors
+ let (labels, positions, motifs, strands, lengths, occupancies): (
+ Vec,
+ Vec,
+ Vec,
+ Vec,
+ Vec,
+ Vec,
+ ) = results.into_iter().unzip_n_vec();
let df = DataFrame::new(vec![
Column::new("label".into(), labels),
@@ -141,6 +180,26 @@ fn process_sequences(
Ok(df)
}
+fn save_results(df: &mut DataFrame, output_file: &str) -> Result<(), ScannerError> {
+ match Path::new(output_file)
+ .extension()
+ .and_then(|ext| ext.to_str())
+ {
+ Some("parquet") => {
+ let mut file = std::fs::File::create(output_file)?;
+ ParquetWriter::new(&mut file)
+ .with_compression(ParquetCompression::Snappy)
+ .finish(df)?;
+ }
+ _ => {
+ let mut file = std::fs::File::create(output_file)?;
+ CsvWriter::new(&mut file).include_header(true).finish(df)?;
+ }
+ }
+
+ Ok(())
+}
+
fn main() -> Result<(), ScannerError> {
let start_time = std::time::Instant::now();
@@ -161,7 +220,7 @@ fn main() -> Result<(), ScannerError> {
// read pwm file and convert to ewm
let ewm = read_pwm_to_ewm(&args.pwm_file).map_err(|e| ScannerError::PwmError(e.to_string()))?;
- let results_df = process_sequences(&df, &ewm, args.mu as f64, args.cutoff)?;
+ let mut results_df = process_sequences(&df, &ewm, args.mu as f64, args.cutoff)?;
let elapsed = start_time.elapsed();
println!(
@@ -169,5 +228,8 @@ fn main() -> Result<(), ScannerError> {
elapsed.as_secs_f64() / 60.0
);
+ // save results
+ save_results(&mut results_df, &args.output_file)?;
+
Ok(())
}
diff --git a/tf-binding-rs/src/occupancy.rs b/tf-binding-rs/src/occupancy.rs
index 762cc4a..a449fda 100644
--- a/tf-binding-rs/src/occupancy.rs
+++ b/tf-binding-rs/src/occupancy.rs
@@ -127,7 +127,7 @@ pub fn read_pwm_files(filename: &str) -> Result {
let file = File::open(filename)?;
let reader = BufReader::new(file);
let mut lines = reader.lines().peekable();
- let mut pwms = HashMap::new();
+ let mut pwms = PWMCollection::new();
// Skip header until first MOTIF
skip_until_motif(&mut lines);