diff --git a/tket2-py/src/optimiser.rs b/tket2-py/src/optimiser.rs index ac04c667..fdba25a8 100644 --- a/tket2-py/src/optimiser.rs +++ b/tket2-py/src/optimiser.rs @@ -3,6 +3,7 @@ use std::io::BufWriter; use std::{fs, num::NonZeroUsize, path::PathBuf}; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use tket2::optimiser::badger::BadgerOptions; use tket2::optimiser::{BadgerLogger, DefaultBadgerOptimiser}; @@ -24,20 +25,60 @@ pub fn module(py: Python<'_>) -> PyResult> { #[pyclass(name = "BadgerOptimiser")] pub struct PyBadgerOptimiser(DefaultBadgerOptimiser); +/// The cost function to use for the Badger optimiser. +#[derive(Debug, Clone, Copy, Default)] +pub enum BadgerCostFunction { + /// Minimise CX count. + #[default] + CXCount, + /// Minimise Rz count. + RzCount, +} + +impl<'py> FromPyObject<'py> for BadgerCostFunction { + fn extract(ob: &'py PyAny) -> PyResult { + let str = ob.extract::<&str>()?; + match str { + "cx" => Ok(BadgerCostFunction::CXCount), + "rz" => Ok(BadgerCostFunction::RzCount), + _ => Err(PyErr::new::(format!( + "Invalid cost function: {}. Expected 'cx' or 'rz'.", + str + ))), + } + } +} + #[pymethods] impl PyBadgerOptimiser { /// Create a new [`PyDefaultBadgerOptimiser`] from a precompiled rewriter. #[staticmethod] - pub fn load_precompiled(path: PathBuf) -> Self { - Self(DefaultBadgerOptimiser::default_with_rewriter_binary(path).unwrap()) + pub fn load_precompiled(path: PathBuf, cost_fn: Option) -> Self { + let opt = match cost_fn.unwrap_or_default() { + BadgerCostFunction::CXCount => { + DefaultBadgerOptimiser::default_with_rewriter_binary(path).unwrap() + } + BadgerCostFunction::RzCount => { + DefaultBadgerOptimiser::rz_opt_with_rewriter_binary(path).unwrap() + } + }; + Self(opt) } /// Create a new [`PyDefaultBadgerOptimiser`] from ECC sets. /// /// This will compile the rewriter from the provided ECC JSON file. #[staticmethod] - pub fn compile_eccs(path: &str) -> Self { - Self(DefaultBadgerOptimiser::default_with_eccs_json_file(path).unwrap()) + pub fn compile_eccs(path: &str, cost_fn: Option) -> Self { + let opt = match cost_fn.unwrap_or_default() { + BadgerCostFunction::CXCount => { + DefaultBadgerOptimiser::default_with_eccs_json_file(path).unwrap() + } + BadgerCostFunction::RzCount => { + DefaultBadgerOptimiser::rz_opt_with_eccs_json_file(path).unwrap() + } + }; + Self(opt) } /// Run the optimiser on a circuit. diff --git a/tket2-py/tket2/_tket2/optimiser.pyi b/tket2-py/tket2/_tket2/optimiser.pyi index aef94ad0..0675210e 100644 --- a/tket2-py/tket2/_tket2/optimiser.pyi +++ b/tket2-py/tket2/_tket2/optimiser.pyi @@ -1,4 +1,4 @@ -from typing import TypeVar +from typing import TypeVar, Literal from .circuit import Tk2Circuit from pytket._tket.circuit import Circuit @@ -8,12 +8,26 @@ CircuitClass = TypeVar("CircuitClass", Circuit, Tk2Circuit) class BadgerOptimiser: @staticmethod - def load_precompiled(filename: Path) -> BadgerOptimiser: - """Load a precompiled rewriter from a file.""" + def load_precompiled( + filename: Path, cost_fn: Literal["cx", "rz"] | None = None + ) -> BadgerOptimiser: + """ + Load a precompiled rewriter from a file. + + :param filename: The path to the file containing the precompiled rewriter. + :param cost_fn: The cost function to use. + """ @staticmethod - def compile_eccs(filename: Path) -> BadgerOptimiser: - """Compile a set of ECCs and create a new rewriter .""" + def compile_eccs( + filename: Path, cost_fn: Literal["cx", "rz"] | None = None + ) -> BadgerOptimiser: + """ + Compile a set of ECCs and create a new rewriter. + + :param filename: The path to the file containing the ECCs. + :param cost_fn: The cost function to use. + """ def optimise( self, diff --git a/tket2-py/tket2/passes.py b/tket2-py/tket2/passes.py index ae593329..032c565a 100644 --- a/tket2-py/tket2/passes.py +++ b/tket2-py/tket2/passes.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Optional, Literal from pytket import Circuit from pytket.passes import CustomPass, BasePass @@ -37,6 +37,7 @@ def badger_pass( max_circuit_count: Optional[int] = None, log_dir: Optional[Path] = None, rebase: bool = False, + cost_fn: Literal["cx", "rz"] | None = None, ) -> BasePass: """Construct a Badger pass. @@ -44,6 +45,9 @@ def badger_pass( `compile-rewriter `_ utility. If `rewriter` is not specified, a default one will be used. + The cost function to minimise can be specified by passing `cost_fn` as `'cx'` + or `'rz'`. If not specified, the default is `'cx'`. + The arguments `max_threads`, `timeout`, `progress_timeout`, `max_circuit_count`, `log_dir` and `rebase` are optional and will be passed on to the Badger optimiser if provided.""" @@ -56,7 +60,7 @@ def badger_pass( ) rewriter = tket2_eccs.nam_6_3() - opt = optimiser.BadgerOptimiser.load_precompiled(rewriter) + opt = optimiser.BadgerOptimiser.load_precompiled(rewriter, cost_fn=cost_fn) def apply(circuit: Circuit) -> Circuit: """Apply Badger optimisation to the circuit.""" diff --git a/tket2/src/optimiser/badger.rs b/tket2/src/optimiser/badger.rs index 54d62f7a..a692abc3 100644 --- a/tket2/src/optimiser/badger.rs +++ b/tket2/src/optimiser/badger.rs @@ -518,7 +518,7 @@ mod badger_default { /// A sane default optimiser using the given ECC sets. pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = LexicographicCostFunction::default_cx(); + let strategy = LexicographicCostFunction::default_cx_strategy(); Ok(BadgerOptimiser::new(rewriter, strategy)) } @@ -528,7 +528,24 @@ mod badger_default { rewriter_path: impl AsRef, ) -> Result { let rewriter = ECCRewriter::load_binary(rewriter_path)?; - let strategy = LexicographicCostFunction::default_cx(); + let strategy = LexicographicCostFunction::default_cx_strategy(); + Ok(BadgerOptimiser::new(rewriter, strategy)) + } + + /// An optimiser minimising Rz gate count using the given ECC sets. + pub fn rz_opt_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { + let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; + let strategy = LexicographicCostFunction::rz_count().into_greedy_strategy(); + Ok(BadgerOptimiser::new(rewriter, strategy)) + } + + /// An optimiser minimising Rz gate count using a precompiled binary rewriter. + #[cfg(feature = "binary-eccs")] + pub fn rz_opt_with_rewriter_binary( + rewriter_path: impl AsRef, + ) -> Result { + let rewriter = ECCRewriter::load_binary(rewriter_path)?; + let strategy = LexicographicCostFunction::rz_count().into_greedy_strategy(); Ok(BadgerOptimiser::new(rewriter, strategy)) } } diff --git a/tket2/src/rewrite/strategy.rs b/tket2/src/rewrite/strategy.rs index 98020cff..a55159ad 100644 --- a/tket2/src/rewrite/strategy.rs +++ b/tket2/src/rewrite/strategy.rs @@ -16,7 +16,7 @@ //! not increase some coarse cost function (e.g. CX count), whilst //! ordering them according to a lexicographic ordering of finer cost //! functions (e.g. total gate count). See -//! [`LexicographicCostFunction::default_cx`]) for a default implementation. +//! [`LexicographicCostFunction::default_cx_strategy`]) for a default implementation. //! - [`GammaStrategyCost`] ignores rewrites that increase the cost //! function beyond a percentage given by a f64 parameter gamma. @@ -29,7 +29,7 @@ use hugr::HugrView; use itertools::Itertools; use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, LexicographicCost}; -use crate::Circuit; +use crate::{op_matches, Circuit, Tk2Op}; use super::trace::RewriteTrace; use super::CircuitRewrite; @@ -345,12 +345,66 @@ impl LexicographicCostFunction usize, 2> { /// is used to rank circuits with equal CX count. /// /// This is probably a good default for NISQ-y circuit optimisation. - #[inline] + pub fn default_cx_strategy() -> ExhaustiveGreedyStrategy { + Self::cx_count().into_greedy_strategy() + } + + /// Non-increasing rewrite strategy based on CX count. + /// + /// A fine-grained cost function given by the total number of quantum gates + /// is used to rank circuits with equal CX count. + /// + /// This is probably a good default for NISQ-y circuit optimisation. + /// + /// Deprecated: Use `default_cx_strategy` instead. + // TODO: Remove this method in the next breaking release. + #[deprecated(since = "0.5.1", note = "Use `default_cx_strategy` instead.")] pub fn default_cx() -> ExhaustiveGreedyStrategy { + Self::default_cx_strategy() + } + + /// Non-increasing rewrite cost function based on CX gate count. + /// + /// A fine-grained cost function given by the total number of quantum gates + /// is used to rank circuits with equal Rz gate count. + #[inline] + pub fn cx_count() -> Self { Self { cost_fns: [|op| is_cx(op) as usize, |op| is_quantum(op) as usize], } - .into() + } + + // TODO: Ideally, do not count Clifford rotations in the cost function. + /// Non-increasing rewrite cost function based on Rz gate count. + /// + /// A fine-grained cost function given by the total number of quantum gates + /// is used to rank circuits with equal Rz gate count. + #[inline] + pub fn rz_count() -> Self { + Self { + cost_fns: [ + |op| op_matches(op, Tk2Op::Rz) as usize, + |op| is_quantum(op) as usize, + ], + } + } + + /// Consume the cost function and create a greedy rewrite strategy out of + /// it. + pub fn into_greedy_strategy(self) -> ExhaustiveGreedyStrategy { + ExhaustiveGreedyStrategy { strat_cost: self } + } + + /// Consume the cost function and create a threshold rewrite strategy out + /// of it. + pub fn into_threshold_strategy(self) -> ExhaustiveThresholdStrategy { + ExhaustiveThresholdStrategy { strat_cost: self } + } +} + +impl Default for LexicographicCostFunction usize, 2> { + fn default() -> Self { + LexicographicCostFunction::cx_count() } } @@ -440,7 +494,6 @@ mod tests { circuit::Circuit, rewrite::{CircuitRewrite, Subcircuit}, utils::build_simple_circuit, - Tk2Op, }; fn n_cx(n_gates: usize) -> Circuit { @@ -512,7 +565,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = LexicographicCostFunction::default_cx(); + let strategy = LexicographicCostFunction::cx_count().into_greedy_strategy(); let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec(); let exp_circ_lens = HashSet::from_iter([3, 7, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_operations()).collect(); @@ -557,7 +610,7 @@ mod tests { #[test] fn test_exhaustive_default_cx_cost() { - let strat = LexicographicCostFunction::default_cx(); + let strat = LexicographicCostFunction::cx_count().into_greedy_strategy(); let circ = n_cx(3); assert_eq!(strat.circuit_cost(&circ), (3, 3).into()); let circ = build_simple_circuit(2, |circ| { @@ -572,7 +625,7 @@ mod tests { #[test] fn test_exhaustive_default_cx_threshold() { - let strat = LexicographicCostFunction::default_cx().strat_cost; + let strat = LexicographicCostFunction::cx_count(); assert!(strat.under_threshold(&(3, 0).into(), &(3, 0).into())); assert!(strat.under_threshold(&(3, 0).into(), &(3, 5).into())); assert!(!strat.under_threshold(&(3, 10).into(), &(4, 0).into()));