Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(badger): cx and rz const functions and strategies for LexicographicCostFunction #625

Merged
merged 6 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions tket2-py/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::io::BufWriter;
use std::{fs, num::NonZeroUsize, path::PathBuf};

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use tket2::optimiser::badger::BadgerOptions;
use tket2::optimiser::{BadgerLogger, DefaultBadgerOptimiser};
Expand All @@ -24,20 +25,60 @@ pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
#[pyclass(name = "BadgerOptimiser")]
pub struct PyBadgerOptimiser(DefaultBadgerOptimiser);

/// The cost function to use for the Badger optimiser.
#[derive(Debug, Clone, Copy, Default)]
pub enum BadgerCostFunction {
/// Minimise CX count.
#[default]
CXCount,
/// Minimise Rz count.
RzCount,
}

impl<'py> FromPyObject<'py> for BadgerCostFunction {
fn extract(ob: &'py PyAny) -> PyResult<Self> {
let str = ob.extract::<&str>()?;
match str {
"cx" => Ok(BadgerCostFunction::CXCount),
"rz" => Ok(BadgerCostFunction::RzCount),
_ => Err(PyErr::new::<PyValueError, _>(format!(
"Invalid cost function: {}. Expected 'cx' or 'rz'.",
str
))),
}
}
}

#[pymethods]
impl PyBadgerOptimiser {
/// Create a new [`PyDefaultBadgerOptimiser`] from a precompiled rewriter.
#[staticmethod]
pub fn load_precompiled(path: PathBuf) -> Self {
Self(DefaultBadgerOptimiser::default_with_rewriter_binary(path).unwrap())
pub fn load_precompiled(path: PathBuf, cost_fn: Option<BadgerCostFunction>) -> Self {
let opt = match cost_fn.unwrap_or_default() {
BadgerCostFunction::CXCount => {
DefaultBadgerOptimiser::default_with_rewriter_binary(path).unwrap()
}
BadgerCostFunction::RzCount => {
DefaultBadgerOptimiser::rz_opt_with_rewriter_binary(path).unwrap()
}
};
Self(opt)
}

/// Create a new [`PyDefaultBadgerOptimiser`] from ECC sets.
///
/// This will compile the rewriter from the provided ECC JSON file.
#[staticmethod]
pub fn compile_eccs(path: &str) -> Self {
Self(DefaultBadgerOptimiser::default_with_eccs_json_file(path).unwrap())
pub fn compile_eccs(path: &str, cost_fn: Option<BadgerCostFunction>) -> Self {
let opt = match cost_fn.unwrap_or_default() {
BadgerCostFunction::CXCount => {
DefaultBadgerOptimiser::default_with_eccs_json_file(path).unwrap()
}
BadgerCostFunction::RzCount => {
DefaultBadgerOptimiser::rz_opt_with_eccs_json_file(path).unwrap()
}
};
Self(opt)
}

/// Run the optimiser on a circuit.
Expand Down
24 changes: 19 additions & 5 deletions tket2-py/tket2/_tket2/optimiser.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar
from typing import TypeVar, Literal
from .circuit import Tk2Circuit
from pytket._tket.circuit import Circuit

Expand All @@ -8,12 +8,26 @@ CircuitClass = TypeVar("CircuitClass", Circuit, Tk2Circuit)

class BadgerOptimiser:
@staticmethod
def load_precompiled(filename: Path) -> BadgerOptimiser:
"""Load a precompiled rewriter from a file."""
def load_precompiled(
filename: Path, cost_fn: Literal["cx", "rz"] | None = None
) -> BadgerOptimiser:
"""
Load a precompiled rewriter from a file.

:param filename: The path to the file containing the precompiled rewriter.
:param cost_fn: The cost function to use.
"""

@staticmethod
def compile_eccs(filename: Path) -> BadgerOptimiser:
"""Compile a set of ECCs and create a new rewriter ."""
def compile_eccs(
filename: Path, cost_fn: Literal["cx", "rz"] | None = None
) -> BadgerOptimiser:
"""
Compile a set of ECCs and create a new rewriter.

:param filename: The path to the file containing the ECCs.
:param cost_fn: The cost function to use.
"""

def optimise(
self,
Expand Down
8 changes: 6 additions & 2 deletions tket2-py/tket2/passes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional
from typing import Optional, Literal

from pytket import Circuit
from pytket.passes import CustomPass, BasePass
Expand Down Expand Up @@ -37,13 +37,17 @@ def badger_pass(
max_circuit_count: Optional[int] = None,
log_dir: Optional[Path] = None,
rebase: bool = False,
cost_fn: Literal["cx", "rz"] | None = None,
) -> BasePass:
"""Construct a Badger pass.

The Badger optimiser requires a pre-compiled rewriter produced by the
`compile-rewriter <https://github.com/CQCL/tket2/tree/main/badger-optimiser>`_
utility. If `rewriter` is not specified, a default one will be used.

The cost function to minimise can be specified by passing `cost_fn` as `'cx'`
or `'rz'`. If not specified, the default is `'cx'`.

The arguments `max_threads`, `timeout`, `progress_timeout`, `max_circuit_count`,
`log_dir` and `rebase` are optional and will be passed on to the Badger
optimiser if provided."""
Expand All @@ -56,7 +60,7 @@ def badger_pass(
)

rewriter = tket2_eccs.nam_6_3()
opt = optimiser.BadgerOptimiser.load_precompiled(rewriter)
opt = optimiser.BadgerOptimiser.load_precompiled(rewriter, cost_fn=cost_fn)

def apply(circuit: Circuit) -> Circuit:
"""Apply Badger optimisation to the circuit."""
Expand Down
21 changes: 19 additions & 2 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ mod badger_default {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = LexicographicCostFunction::default_cx();
let strategy = LexicographicCostFunction::default_cx_strategy();
Ok(BadgerOptimiser::new(rewriter, strategy))
}

Expand All @@ -528,7 +528,24 @@ mod badger_default {
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = LexicographicCostFunction::default_cx();
let strategy = LexicographicCostFunction::default_cx_strategy();
Ok(BadgerOptimiser::new(rewriter, strategy))
}

/// An optimiser minimising Rz gate count using the given ECC sets.
pub fn rz_opt_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = LexicographicCostFunction::rz_count().into_greedy_strategy();
Ok(BadgerOptimiser::new(rewriter, strategy))
}

/// An optimiser minimising Rz gate count using a precompiled binary rewriter.
#[cfg(feature = "binary-eccs")]
pub fn rz_opt_with_rewriter_binary(
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = LexicographicCostFunction::rz_count().into_greedy_strategy();
Ok(BadgerOptimiser::new(rewriter, strategy))
}
}
Expand Down
69 changes: 61 additions & 8 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
//! not increase some coarse cost function (e.g. CX count), whilst
//! ordering them according to a lexicographic ordering of finer cost
//! functions (e.g. total gate count). See
//! [`LexicographicCostFunction::default_cx`]) for a default implementation.
//! [`LexicographicCostFunction::default_cx_strategy`]) for a default implementation.
//! - [`GammaStrategyCost`] ignores rewrites that increase the cost
//! function beyond a percentage given by a f64 parameter gamma.

Expand All @@ -29,7 +29,7 @@ use hugr::HugrView;
use itertools::Itertools;

use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, LexicographicCost};
use crate::Circuit;
use crate::{op_matches, Circuit, Tk2Op};

use super::trace::RewriteTrace;
use super::CircuitRewrite;
Expand Down Expand Up @@ -345,12 +345,66 @@ impl LexicographicCostFunction<fn(&OpType) -> usize, 2> {
/// is used to rank circuits with equal CX count.
///
/// This is probably a good default for NISQ-y circuit optimisation.
#[inline]
pub fn default_cx_strategy() -> ExhaustiveGreedyStrategy<Self> {
Self::cx_count().into_greedy_strategy()
}

/// Non-increasing rewrite strategy based on CX count.
///
/// A fine-grained cost function given by the total number of quantum gates
/// is used to rank circuits with equal CX count.
///
/// This is probably a good default for NISQ-y circuit optimisation.
///
/// Deprecated: Use `default_cx_strategy` instead.
// TODO: Remove this method in the next breaking release.
#[deprecated(since = "0.5.1", note = "Use `default_cx_strategy` instead.")]
pub fn default_cx() -> ExhaustiveGreedyStrategy<Self> {
Self::default_cx_strategy()
}

/// Non-increasing rewrite cost function based on CX gate count.
///
/// A fine-grained cost function given by the total number of quantum gates
/// is used to rank circuits with equal Rz gate count.
#[inline]
pub fn cx_count() -> Self {
Self {
cost_fns: [|op| is_cx(op) as usize, |op| is_quantum(op) as usize],
}
.into()
}

// TODO: Ideally, do not count Clifford rotations in the cost function.
/// Non-increasing rewrite cost function based on Rz gate count.
///
/// A fine-grained cost function given by the total number of quantum gates
/// is used to rank circuits with equal Rz gate count.
#[inline]
pub fn rz_count() -> Self {
Self {
cost_fns: [
|op| op_matches(op, Tk2Op::Rz) as usize,
|op| is_quantum(op) as usize,
],
}
}

/// Consume the cost function and create a greedy rewrite strategy out of
/// it.
pub fn into_greedy_strategy(self) -> ExhaustiveGreedyStrategy<Self> {
ExhaustiveGreedyStrategy { strat_cost: self }
}

/// Consume the cost function and create a threshold rewrite strategy out
/// of it.
pub fn into_threshold_strategy(self) -> ExhaustiveThresholdStrategy<Self> {
ExhaustiveThresholdStrategy { strat_cost: self }
}
}

impl Default for LexicographicCostFunction<fn(&OpType) -> usize, 2> {
fn default() -> Self {
LexicographicCostFunction::cx_count()
}
}

Expand Down Expand Up @@ -440,7 +494,6 @@ mod tests {
circuit::Circuit,
rewrite::{CircuitRewrite, Subcircuit},
utils::build_simple_circuit,
Tk2Op,
};

fn n_cx(n_gates: usize) -> Circuit {
Expand Down Expand Up @@ -512,7 +565,7 @@ mod tests {
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
];

let strategy = LexicographicCostFunction::default_cx();
let strategy = LexicographicCostFunction::cx_count().into_greedy_strategy();
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_operations()).collect();
Expand Down Expand Up @@ -557,7 +610,7 @@ mod tests {

#[test]
fn test_exhaustive_default_cx_cost() {
let strat = LexicographicCostFunction::default_cx();
let strat = LexicographicCostFunction::cx_count().into_greedy_strategy();
let circ = n_cx(3);
assert_eq!(strat.circuit_cost(&circ), (3, 3).into());
let circ = build_simple_circuit(2, |circ| {
Expand All @@ -572,7 +625,7 @@ mod tests {

#[test]
fn test_exhaustive_default_cx_threshold() {
let strat = LexicographicCostFunction::default_cx().strat_cost;
let strat = LexicographicCostFunction::cx_count();
assert!(strat.under_threshold(&(3, 0).into(), &(3, 0).into()));
assert!(strat.under_threshold(&(3, 0).into(), &(3, 5).into()));
assert!(!strat.under_threshold(&(3, 10).into(), &(4, 0).into()));
Expand Down
Loading