Skip to content

Commit

Permalink
feat: Exhaustive greedy rewrite strategy (#151)
Browse files Browse the repository at this point in the history
This is a straight combination of the greedy and exhaustive strategies.
Instead of applying each strategy separately, apply each one once but
also include all the following strategies that do not modify the same
nodes.

This is a draft, since for some reason combining multiple strategies at
once is invalidating the hugr.
They may not be correctly reporting their modified nodes ?
  • Loading branch information
aborgna-q authored Oct 11, 2023
1 parent fb58b44 commit cf6ccf9
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 80 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ members = ["pyrs", "compile-rewriter", "taso-optimiser"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "09494f1" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "9254ac7" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
Expand Down
19 changes: 14 additions & 5 deletions src/circuit/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use hugr::ops::OpType;
use std::fmt::{Debug, Display};
use std::iter::Sum;
use std::num::NonZeroUsize;
use std::ops::Add;
use std::ops::{Add, AddAssign};

use crate::ops::op_matches;
use crate::T2Op;
Expand All @@ -29,7 +29,9 @@ pub trait CircuitCost: Add<Output = Self> + Sum<Self> + Debug + Default + Clone
}

/// The cost for a group of operations in a circuit, each with cost `OpCost`.
pub trait CostDelta: Sum<Self> + Debug + Default + Clone + Ord {
pub trait CostDelta:
AddAssign + Add<Output = Self> + Sum<Self> + Debug + Default + Clone + Ord
{
/// Return the delta as a `isize`. This may discard some of the cost delta information.
fn as_isize(&self) -> isize;
}
Expand Down Expand Up @@ -62,14 +64,21 @@ impl<T: Display> Debug for MajorMinorCost<T> {
}
}

impl Add for MajorMinorCost {
type Output = MajorMinorCost;
impl<T: Add<Output = T>> Add for MajorMinorCost<T> {
type Output = MajorMinorCost<T>;

fn add(self, rhs: MajorMinorCost) -> Self::Output {
fn add(self, rhs: MajorMinorCost<T>) -> Self::Output {
(self.major + rhs.major, self.minor + rhs.minor).into()
}
}

impl<T: AddAssign> AddAssign for MajorMinorCost<T> {
fn add_assign(&mut self, rhs: Self) {
self.major += rhs.major;
self.minor += rhs.minor;
}
}

impl<T: Add<Output = T> + Default> Sum for MajorMinorCost<T> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(|a, b| (a.major + b.major, a.minor + b.minor).into())
Expand Down
15 changes: 7 additions & 8 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ where
// Ignore this circuit: we've already seen it
continue;
}
circ_cnt += 1;
logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len());
let new_circ_cost = cost.add_delta(&cost_delta);
pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost);
Expand Down Expand Up @@ -390,22 +389,22 @@ mod taso_default {
use hugr::ops::OpType;

use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::NonIncreasingGateCountStrategy;
use crate::rewrite::strategy::{ExhaustiveGreedyStrategy, NonIncreasingGateCountCost};
use crate::rewrite::ECCRewriter;

use super::*;

pub type StrategyCost = NonIncreasingGateCountCost<fn(&OpType) -> usize, fn(&OpType) -> usize>;

/// The default TASO optimiser using ECC sets.
pub type DefaultTasoOptimiser = TasoOptimiser<
ECCRewriter,
NonIncreasingGateCountStrategy<fn(&OpType) -> usize, fn(&OpType) -> usize>,
>;
pub type DefaultTasoOptimiser =
TasoOptimiser<ECCRewriter, ExhaustiveGreedyStrategy<StrategyCost>>;

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = NonIncreasingGateCountStrategy::default_cx();
let strategy = NonIncreasingGateCountCost::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}

Expand All @@ -414,7 +413,7 @@ mod taso_default {
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = NonIncreasingGateCountStrategy::default_cx();
let strategy = NonIncreasingGateCountCost::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ impl CircuitRewrite {
self.0.replacement()
}

/// Returns a set of nodes referenced by the rewrite. Modifying any these
/// nodes will invalidate it.
///
/// Two `CircuitRewrite`s can be composed if their invalidation sets are
/// disjoint.
#[inline]
pub fn invalidation_set(&self) -> impl Iterator<Item = Node> + '_ {
self.0.invalidation_set()
}

delegate! {
to self.0 {
/// Apply the rewrite rule to a circuit.
Expand Down
Loading

0 comments on commit cf6ccf9

Please sign in to comment.