Skip to content

Commit

Permalink
add parallel processing support for motif scanning
Browse files Browse the repository at this point in the history
  • Loading branch information
peter6866 committed Dec 18, 2024
1 parent 03df91f commit 508ea68
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 49 deletions.
7 changes: 4 additions & 3 deletions motif-scanner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ license = "MIT"

[dependencies]
tf-binding-rs = { path = "../tf-binding-rs" }
clap = { version = "4.5.21", features = ["derive", "color"] }
polars = { version = "0.44.2", features = ["lazy", "csv", "strings", "regex"] }
thiserror = "2.0.3"
clap = { version = "4.5.23", features = ["derive", "color"] }
polars = { version = "0.45.1", features = ["lazy", "csv", "strings", "regex", "parquet"] }
thiserror = "2.0.8"
rayon = "1.5.1"
152 changes: 107 additions & 45 deletions motif-scanner/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use clap::Parser;
use polars::prelude::*;
use rayon::prelude::*;
use std::fs;
use std::path::Path;
use tf_binding_rs::occupancy::{read_pwm_to_ewm, total_landscape};
Expand Down Expand Up @@ -66,6 +67,35 @@ struct Args {
mu: i32,
}

/// Extension trait that splits a stream of 6-tuples into six parallel vectors.
///
/// This is the 6-ary analogue of `Iterator::unzip`, which only handles pairs.
trait UnzipN<A, B, C, D, E, F> {
    /// Consumes `self` and returns one `Vec` per tuple position.
    ///
    /// Element `i` of every returned vector comes from the `i`-th tuple
    /// produced by the iterator, so the six vectors always have equal length.
    fn unzip_n_vec(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>, Vec<E>, Vec<F>);
}

/// Blanket implementation for every iterator whose items are 6-tuples.
impl<I, A, B, C, D, E, F> UnzipN<A, B, C, D, E, F> for I
where
    I: Iterator<Item = (A, B, C, D, E, F)>,
{
    fn unzip_n_vec(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>, Vec<E>, Vec<F>) {
        // Start from six empty columns; element types are inferred from the return type.
        let empty_columns = (
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
        );

        // Thread the columns through the iteration, appending one field to each per row.
        self.fold(empty_columns, |mut columns, (a, b, c, d, e, f)| {
            columns.0.push(a);
            columns.1.push(b);
            columns.2.push(c);
            columns.3.push(d);
            columns.4.push(e);
            columns.5.push(f);
            columns
        })
    }
}

fn process_sequences(
df: &DataFrame,
ewm: &EWMCollection,
Expand All @@ -76,58 +106,67 @@ fn process_sequences(
.column("sequence")
.map_err(|_| ScannerError::MissingSequenceColumn)?;

let mut labels: Vec<i32> = Vec::new();
let mut positions: Vec<i32> = Vec::new();
let mut motifs: Vec<String> = Vec::new();
let mut strands: Vec<String> = Vec::new();
let mut lengths: Vec<i32> = Vec::new();
let mut occupancies: Vec<f64> = Vec::new();

let total_seqs = sequences.len();
println!("{} sequences to scan", total_seqs);

for (idx, seq) in sequences.str()?.into_iter().enumerate() {
if let Some(sequence) = seq {
let landscape = total_landscape(sequence, ewm, mu).map_err(|_| {
ScannerError::PwmError(format!(
"Error calculating occupancy landscape for sequence {}",
idx
))
})?;

// Get the height (number of positions) of the landscape DataFrame
let n_positions = landscape.height();

// Iterate through each motif in the EWM collection
for (motif_id, motif_df) in ewm.iter() {
// Check both forward and reverse strands
for strand in ["F", "R"] {
let col_name = format!("{}_{}", motif_id, strand);

// Get the column for this motif+strand from the landscape
if let Ok(motif_col) = landscape.column(&col_name) {
// Iterate through positions
for pos in 0..n_positions {
if let Ok(occ) = motif_col.get(pos).unwrap().try_extract::<f64>() {
if occ > cutoff {
labels.push(idx as i32);
positions.push(pos as i32);
motifs.push(motif_id.split('_').next().unwrap().to_string());
strands.push(strand.to_string());
lengths.push(motif_df.height() as i32);
occupancies.push(occ);
// Collect the string column into a Vec of optional string slices (Option<&str>) so rayon can iterate it in parallel
let sequences_vec: Vec<_> = sequences.str()?.into_iter().collect();

// Parallel processing of sequences
let results: Vec<_> = sequences_vec
.into_par_iter()
.enumerate()
.filter_map(|(idx, seq)| {
seq.map(|sequence| {
let landscape = match total_landscape(sequence, ewm, mu) {
Ok(l) => l,
Err(_) => return Vec::new(),
};

let n_positions = landscape.height();
let mut local_results = Vec::new();

// Iterate through each motif in the EWM collection
for (motif_id, motif_df) in ewm.iter() {
// Check both forward and reverse strands
for strand in ["F", "R"] {
let col_name = format!("{}_{}", motif_id, strand);

// Get the column for this motif+strand from the landscape
if let Ok(motif_col) = landscape.column(&col_name) {
// Iterate through positions
for pos in 0..n_positions {
if let Ok(occ) = motif_col.get(pos).unwrap().try_extract::<f64>() {
if occ > cutoff {
local_results.push((
idx as i32,
pos as i32,
motif_id.split('_').next().unwrap().to_string(),
strand.to_string(),
motif_df.height() as i32,
occ,
));
}
}
}
}
}
}
}
}

if (idx + 1) % 5000 == 0 {
println!("\t{} / {} sequences scanned", idx + 1, total_seqs);
}
}
local_results
})
})
.flatten()
.collect();

// Unzip results into separate vectors
let (labels, positions, motifs, strands, lengths, occupancies): (
Vec<i32>,
Vec<i32>,
Vec<String>,
Vec<String>,
Vec<i32>,
Vec<f64>,
) = results.into_iter().unzip_n_vec();

let df = DataFrame::new(vec![
Column::new("label".into(), labels),
Expand All @@ -141,6 +180,26 @@ fn process_sequences(
Ok(df)
}

/// Writes `df` to `output_file`, picking the format from the file extension.
///
/// A path whose extension is exactly `parquet` is written as Snappy-compressed
/// Parquet. Any other path (including one with no extension) is written as CSV
/// with a header row.
///
/// # Errors
///
/// Propagates, as `ScannerError`, a failure to create the output file or a
/// failure reported by the selected writer.
fn save_results(df: &mut DataFrame, output_file: &str) -> Result<(), ScannerError> {
    // The comparison is case-sensitive: only a literal "parquet" extension selects Parquet.
    let wants_parquet = Path::new(output_file)
        .extension()
        .and_then(|ext| ext.to_str())
        == Some("parquet");

    // Both formats write to a freshly created (truncated) file at the same path.
    let mut out = std::fs::File::create(output_file)?;

    if wants_parquet {
        ParquetWriter::new(&mut out)
            .with_compression(ParquetCompression::Snappy)
            .finish(df)?;
    } else {
        CsvWriter::new(&mut out).include_header(true).finish(df)?;
    }

    Ok(())
}

fn main() -> Result<(), ScannerError> {
let start_time = std::time::Instant::now();

Expand All @@ -161,13 +220,16 @@ fn main() -> Result<(), ScannerError> {
// read pwm file and convert to ewm
let ewm = read_pwm_to_ewm(&args.pwm_file).map_err(|e| ScannerError::PwmError(e.to_string()))?;

let results_df = process_sequences(&df, &ewm, args.mu as f64, args.cutoff)?;
let mut results_df = process_sequences(&df, &ewm, args.mu as f64, args.cutoff)?;

let elapsed = start_time.elapsed();
println!(
"Total execution time: {:.4} minutes",
elapsed.as_secs_f64() / 60.0
);

// save results
save_results(&mut results_df, &args.output_file)?;

Ok(())
}
2 changes: 1 addition & 1 deletion tf-binding-rs/src/occupancy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ pub fn read_pwm_files(filename: &str) -> Result<PWMCollection, MotifError> {
let file = File::open(filename)?;
let reader = BufReader::new(file);
let mut lines = reader.lines().peekable();
let mut pwms = HashMap::new();
let mut pwms = PWMCollection::new();

// Skip header until first MOTIF
skip_until_motif(&mut lines);
Expand Down

0 comments on commit 508ea68

Please sign in to comment.