Skip to content

Commit

Permalink
chore: Move parallel evaluation code to CircuitChunks (#528)
Browse files Browse the repository at this point in the history
resolves #266

---------

Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com>
  • Loading branch information
potatoboiler and aborgna-q committed Aug 5, 2024
1 parent 7c14494 commit 209b342
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
5 changes: 3 additions & 2 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use fxhash::FxHashSet;
use hugr::hugr::HugrError;
use hugr::HugrView;
pub use log::BadgerLogger;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};

use std::num::NonZeroUsize;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -131,7 +132,7 @@ impl<R, S> BadgerOptimiser<R, S> {

impl<R, S> BadgerOptimiser<R, S>
where
R: Rewriter + Send + Clone + 'static,
R: Rewriter + Send + Clone + Sync + 'static,
S: RewriteStrategy + Send + Sync + Clone + 'static,
S::Cost: serde::Serialize + Send + Sync,
{
Expand Down Expand Up @@ -440,7 +441,7 @@ where
logger.log_best(circ_cost.clone(), num_rewrites);

let (joins, rx_work): (Vec<_>, Vec<_>) = chunks
.iter_mut()
.par_iter_mut()
.enumerate()
.map(|(i, chunk)| {
let (tx, rx) = crossbeam_channel::unbounded();
Expand Down
7 changes: 4 additions & 3 deletions tket2/src/optimiser/badger/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{fmt::Debug, io};

/// Logging configuration for the Badger optimiser.
pub struct BadgerLogger<'w> {
circ_candidates_csv: Option<csv::Writer<Box<dyn io::Write + 'w>>>,
circ_candidates_csv: Option<csv::Writer<Box<dyn io::Write + Send + Sync + 'w>>>,
last_circ_processed: usize,
last_progress_time: Instant,
branching_factor: UsizeAverage,
Expand Down Expand Up @@ -41,8 +41,9 @@ impl<'w> BadgerLogger<'w> {
/// or [`PROGRESS_TARGET`].
///
/// [`log`]: <https://docs.rs/log/latest/log/>
pub fn new(best_progress_csv_writer: impl io::Write + 'w) -> Self {
let boxed_candidates_writer: Box<dyn io::Write + 'w> = Box::new(best_progress_csv_writer);
pub fn new(best_progress_csv_writer: impl io::Write + Send + Sync + 'w) -> Self {
let boxed_candidates_writer: Box<dyn io::Write + Send + Sync + 'w> =
Box::new(best_progress_csv_writer);
Self {
circ_candidates_csv: Some(csv::Writer::from_writer(boxed_candidates_writer)),
..Default::default()
Expand Down
27 changes: 27 additions & 0 deletions tket2/src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use hugr::types::Signature;
use hugr::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire};
use itertools::Itertools;
use portgraph::algorithms::ConvexChecker;
use rayon::iter::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;

use crate::Circuit;

Expand Down Expand Up @@ -442,6 +444,19 @@ impl CircuitChunks {
pub fn is_empty(&self) -> bool {
self.chunks.is_empty()
}

/// Supports implementation of rayon::iter::IntoParallelRefMutIterator
fn par_iter_mut(
&mut self,
) -> rayon::iter::Map<
rayon::slice::IterMut<'_, Chunk>,
for<'a> fn(&'a mut Chunk) -> &'a mut Circuit,
> {
self.chunks
.as_parallel_slice_mut()
.into_par_iter()
.map(|chunk| &mut chunk.circ)
}
}

impl Index<usize> for CircuitChunks {
Expand All @@ -458,6 +473,18 @@ impl IndexMut<usize> for CircuitChunks {
}
}

impl<'data> IntoParallelRefMutIterator<'data> for CircuitChunks {
type Item = &'data mut Circuit;
type Iter = rayon::iter::Map<
rayon::slice::IterMut<'data, Chunk>,
for<'a> fn(&'a mut Chunk) -> &'a mut Circuit,
>;

fn par_iter_mut(&'data mut self) -> Self::Iter {
self.par_iter_mut()
}
}

#[cfg(test)]
mod test {
use crate::circuit::CircuitHash;
Expand Down

0 comments on commit 209b342

Please sign in to comment.