Skip to content

Commit

Permalink
feat: Use tket1 and tket2 circuits interchangeably everywhere (#243)
Browse files Browse the repository at this point in the history
Adds a `tket1`/`tket2` to the `with_hugr` helpers, so we always know
what format to output afterwards.

The more noisy part of this commit changing all the GIL-independent
types to capturing references, so we don't have to manually lock
multiple times per call.

Closes #178.
  • Loading branch information
aborgna-q committed Nov 15, 2023
1 parent b76a21a commit eac7acf
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 120 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ authors = [] # TODO
maintainers = [] # TODO
include = ["pyproject.toml"]
license = "Apache-2.0"
license_file = "LICENCE"
readme = "README.md"

packages = [{ include = "tket2-py" }]
Expand Down
19 changes: 6 additions & 13 deletions tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,18 @@ use tket2::json::TKETDecode;
use tket2::rewrite::CircuitRewrite;
use tket_json_rs::circuit_json::SerialCircuit;

pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, T2Circuit};
pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, Tk2Circuit};

/// The module definition
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_circuit")?;
m.add_class::<T2Circuit>()?;
m.add_class::<Tk2Circuit>()?;
m.add_class::<PyNode>()?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;

m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
Expand All @@ -47,20 +46,14 @@ pub fn module(py: Python) -> PyResult<&PyModule> {

/// Run the validation checks on a circuit.
#[pyfunction]
pub fn validate_hugr(c: Py<PyAny>) -> PyResult<()> {
try_with_hugr(c, |hugr| hugr.validate(&REGISTRY))
pub fn validate_hugr(c: &PyAny) -> PyResult<()> {
try_with_hugr(c, |hugr, _| hugr.validate(&REGISTRY))
}

/// Return a Graphviz DOT string representation of the circuit.
#[pyfunction]
pub fn to_hugr_dot(c: Py<PyAny>) -> PyResult<String> {
with_hugr(c, |hugr| hugr.dot_string())
}

/// Downcast a python object to a [`Hugr`].
#[pyfunction]
pub fn to_hugr(c: Py<PyAny>) -> PyResult<T2Circuit> {
with_hugr(c, |hugr| hugr.into())
pub fn to_hugr_dot(c: &PyAny) -> PyResult<String> {
with_hugr(c, |hugr, _| hugr.dot_string())
}

/// A [`hugr::Node`] wrapper for Python.
Expand Down
151 changes: 112 additions & 39 deletions tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! Utilities for calling Hugr functions on generic python objects.

use pyo3::exceptions::PyAttributeError;
use pyo3::{prelude::*, PyTypeInfo};

use derive_more::From;
use hugr::{Hugr, HugrView};
use serde::Serialize;
use tket2::extension::REGISTRY;
use tket2::json::TKETDecode;
use tket2::passes::CircuitChunks;
Expand All @@ -14,78 +16,149 @@ use crate::pattern::rewrite::PyCircuitRewrite;
/// A manager for tket 2 operations on a tket 1 Circuit.
#[pyclass]
#[derive(Clone, Debug, PartialEq, From)]
pub struct T2Circuit {
pub struct Tk2Circuit {
/// Rust representation of the circuit.
pub hugr: Hugr,
}

#[pymethods]
impl T2Circuit {
impl Tk2Circuit {
/// Convert a tket1 circuit to a [`Tk2Circuit`].
#[new]
fn from_circuit(circ: PyObject) -> PyResult<Self> {
pub fn from_tket1(circ: &PyAny) -> PyResult<Self> {
Ok(Self {
hugr: with_hugr(circ, |hugr| hugr)?,
hugr: with_hugr(circ, |hugr, _| hugr)?,
})
}

fn finish(&self) -> PyResult<PyObject> {
SerialCircuit::encode(&self.hugr)?.to_tket1_with_gil()
/// Convert the [`Tk2Circuit`] to a tket1 circuit.
pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
SerialCircuit::encode(&self.hugr)?.to_tket1(py)
}

fn apply_match(&mut self, rw: PyCircuitRewrite) {
/// Apply a rewrite on the circuit.
pub fn apply_match(&mut self, rw: PyCircuitRewrite) {
rw.rewrite.apply(&mut self.hugr).expect("Apply error.");
}

/// Encode the circuit as a HUGR json string.
//
// TODO: Bind a messagepack encoder/decoder too.
pub fn to_hugr_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&self.hugr).unwrap())
}

/// Decode a HUGR json string to a circuit.
#[staticmethod]
pub fn from_hugr_json(json: &str) -> PyResult<Self> {
let hugr = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit { hugr })
}

/// Encode the circuit as a tket1 json string.
///
/// FIXME: Currently the encoded circuit cannot be loaded back due to
/// [https://github.com/CQCL/hugr/issues/683]
pub fn to_tket1_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr)?).unwrap())
}

/// Decode a tket1 json string to a circuit.
#[staticmethod]
pub fn from_tket1_json(json: &str) -> PyResult<Self> {
let tk1: SerialCircuit = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit {
hugr: tk1.decode()?,
})
}
}
impl T2Circuit {
/// Tries to extract a T2Circuit from a python object.
impl Tk2Circuit {
/// Tries to extract a Tk2Circuit from a python object.
///
/// Returns an error if the py object is not a T2Circuit.
pub fn try_extract(circ: Py<PyAny>) -> PyResult<Self> {
Python::with_gil(|py| circ.as_ref(py).extract::<T2Circuit>())
/// Returns an error if the py object is not a Tk2Circuit.
pub fn try_extract(circ: &PyAny) -> PyResult<Self> {
circ.extract::<Tk2Circuit>()
}
}

/// A flag to indicate the encoding of a circuit.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CircuitType {
/// A `pytket` `Circuit`.
Tket1,
/// A tket2 `Tk2Circuit`, represented as a HUGR.
Tket2,
}

impl CircuitType {
/// Converts a `Hugr` into the format indicated by the flag.
pub fn convert(self, py: Python, hugr: Hugr) -> PyResult<&PyAny> {
match self {
CircuitType::Tket1 => SerialCircuit::encode(&hugr)?.to_tket1(py),
CircuitType::Tket2 => Ok(Py::new(py, Tk2Circuit { hugr })?.into_ref(py)),
}
}
}

/// Apply a fallible function expecting a hugr on a pytket circuit.
pub fn try_with_hugr<T, E, F>(circ: Py<PyAny>, f: F) -> PyResult<T>
/// Apply a fallible function expecting a hugr on a python circuit.
///
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
pub fn try_with_hugr<T, E, F>(circ: &PyAny, f: F) -> PyResult<T>
where
E: Into<PyErr>,
F: FnOnce(Hugr) -> Result<T, E>,
F: FnOnce(Hugr, CircuitType) -> Result<T, E>,
{
let hugr = Python::with_gil(|py| -> PyResult<Hugr> {
let circ = circ.as_ref(py);
match T2Circuit::extract(circ) {
// hugr circuit
Ok(t2circ) => Ok(t2circ.hugr),
// tket1 circuit
Err(_) => Ok(SerialCircuit::from_tket1(circ)?.decode()?),
}
})?;
(f)(hugr).map_err(|e| e.into())
let (hugr, typ) = match Tk2Circuit::extract(circ) {
// hugr circuit
Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2),
// tket1 circuit
Err(_) => (
SerialCircuit::from_tket1(circ)?.decode()?,
CircuitType::Tket1,
),
};
(f)(hugr, typ).map_err(|e| e.into())
}

/// Apply a function expecting a hugr on a pytket circuit.
pub fn with_hugr<T, F>(circ: Py<PyAny>, f: F) -> PyResult<T>
/// Apply a function expecting a hugr on a python circuit.
///
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
pub fn with_hugr<T, F>(circ: &PyAny, f: F) -> PyResult<T>
where
F: FnOnce(Hugr) -> T,
F: FnOnce(Hugr, CircuitType) -> T,
{
try_with_hugr(circ, |hugr| Ok::<T, PyErr>((f)(hugr)))
try_with_hugr(circ, |hugr, typ| Ok::<T, PyErr>((f)(hugr, typ)))
}

/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit.
pub fn try_update_hugr<E, F>(circ: Py<PyAny>, f: F) -> PyResult<Py<PyAny>>
/// Apply a fallible hugr-to-hugr function on a python circuit, and return the modified circuit.
///
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
/// The returned Hugr is converted to the matching python object.
pub fn try_update_hugr<E, F>(circ: &PyAny, f: F) -> PyResult<&PyAny>
where
E: Into<PyErr>,
F: FnOnce(Hugr) -> Result<Hugr, E>,
F: FnOnce(Hugr, CircuitType) -> Result<Hugr, E>,
{
let hugr = try_with_hugr(circ, f)?;
SerialCircuit::encode(&hugr)?.to_tket1_with_gil()
let py = circ.py();
try_with_hugr(circ, |hugr, typ| {
let hugr = f(hugr, typ).map_err(|e| e.into())?;
typ.convert(py, hugr)
})
}

/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit.
pub fn update_hugr<F>(circ: Py<PyAny>, f: F) -> PyResult<Py<PyAny>>
/// Apply a hugr-to-hugr function on a python circuit, and return the modified circuit.
///
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
/// The returned Hugr is converted to the matching python object.
pub fn update_hugr<F>(circ: &PyAny, f: F) -> PyResult<&PyAny>
where
F: FnOnce(Hugr) -> Hugr,
F: FnOnce(Hugr, CircuitType) -> Hugr,
{
let hugr = with_hugr(circ, f)?;
SerialCircuit::encode(&hugr)?.to_tket1_with_gil()
let py = circ.py();
try_with_hugr(circ, |hugr, typ| {
let hugr = f(hugr, typ);
typ.convert(py, hugr)
})
}
8 changes: 4 additions & 4 deletions tket2-py/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ impl PyBadgerOptimiser {
/// * `log_progress`: The path to a CSV file to log progress to.
///
#[pyo3(name = "optimise")]
pub fn py_optimise(
pub fn py_optimise<'py>(
&self,
circ: PyObject,
circ: &'py PyAny,
timeout: Option<u64>,
n_threads: Option<NonZeroUsize>,
split_circ: Option<bool>,
log_progress: Option<PathBuf>,
queue_size: Option<usize>,
) -> PyResult<PyObject> {
update_hugr(circ, |circ| {
) -> PyResult<&'py PyAny> {
update_hugr(circ, |circ, _| {
self.optimise(
circ,
timeout,
Expand Down
45 changes: 21 additions & 24 deletions tket2-py/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ pub mod chunks;
use std::{cmp::min, convert::TryInto, fs, num::NonZeroUsize, path::PathBuf};

use pyo3::{prelude::*, types::IntoPyDict};
use tket2::{json::TKETDecode, op_matches, passes::apply_greedy_commutation, Circuit, T2Op};
use tket_json_rs::circuit_json::SerialCircuit;
use tket2::{op_matches, passes::apply_greedy_commutation, Circuit, T2Op};

use crate::{
circuit::{try_update_hugr, try_with_hugr},
Expand All @@ -30,35 +29,33 @@ pub fn module(py: Python) -> PyResult<&PyModule> {
}

#[pyfunction]
fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> {
try_with_hugr(py_c, |mut h| {
fn greedy_depth_reduce(circ: &PyAny) -> PyResult<(&PyAny, u32)> {
let py = circ.py();
try_with_hugr(circ, |mut h, typ| {
let n_moves = apply_greedy_commutation(&mut h)?;
let py_c = SerialCircuit::encode(&h)?.to_tket1_with_gil()?;
PyResult::Ok((py_c, n_moves))
let circ = typ.convert(py, h)?;
PyResult::Ok((circ, n_moves))
})
}

/// Rebase a circuit to the Nam gate set (CX, Rz, H) using TKET1.
///
/// Acquires the python GIL to call TKET's `auto_rebase_pass`.
///
/// Equivalent to running the following code:
/// ```python
/// from pytket.passes.auto_rebase import auto_rebase_pass
/// from pytket import OpType
/// auto_rebase_pass({OpType.CX, OpType.Rz, OpType.H}).apply(circ)"
// ```
fn rebase_nam(circ: &PyObject) -> PyResult<()> {
Python::with_gil(|py| {
let auto_rebase = py
.import("pytket.passes.auto_rebase")?
.getattr("auto_rebase_pass")?;
let optype = py.import("pytket")?.getattr("OpType")?;
let locals = [("OpType", &optype)].into_py_dict(py);
let op_set = py.eval("{OpType.CX, OpType.Rz, OpType.H}", None, Some(locals))?;
let rebase_pass = auto_rebase.call1((op_set,))?.getattr("apply")?;
rebase_pass.call1((circ,)).map(|_| ())
})
fn rebase_nam(circ: &PyAny) -> PyResult<()> {
let py = circ.py();
let auto_rebase = py
.import("pytket.passes.auto_rebase")?
.getattr("auto_rebase_pass")?;
let optype = py.import("pytket")?.getattr("OpType")?;
let locals = [("OpType", &optype)].into_py_dict(py);
let op_set = py.eval("{OpType.CX, OpType.Rz, OpType.H}", None, Some(locals))?;
let rebase_pass = auto_rebase.call1((op_set,))?.getattr("apply")?;
rebase_pass.call1((circ,)).map(|_| ())
}

/// Badger optimisation pass.
Expand All @@ -76,14 +73,14 @@ fn rebase_nam(circ: &PyObject) -> PyResult<()> {
///
/// Log files will be written to the directory `log_dir` if specified.
#[pyfunction]
fn badger_optimise(
circ: PyObject,
fn badger_optimise<'py>(
circ: &'py PyAny,
optimiser: &PyBadgerOptimiser,
max_threads: Option<NonZeroUsize>,
timeout: Option<u64>,
log_dir: Option<PathBuf>,
rebase: Option<bool>,
) -> PyResult<PyObject> {
) -> PyResult<&'py PyAny> {
// Default parameter values
let rebase = rebase.unwrap_or(true);
let max_threads = max_threads.unwrap_or(num_cpus::get().try_into().unwrap());
Expand All @@ -94,7 +91,7 @@ fn badger_optimise(
}
// Rebase circuit
if rebase {
rebase_nam(&circ)?;
rebase_nam(circ)?;
}
// Logic to choose how to split the circuit
let badger_splits = |n_threads: NonZeroUsize| match n_threads.get() {
Expand All @@ -111,7 +108,7 @@ fn badger_optimise(
_ => unreachable!(),
};
// Optimise
try_update_hugr(circ, |mut circ| {
try_update_hugr(circ, |mut circ, _| {
let n_cx = circ
.commands()
.filter(|c| op_matches(c.optype(), T2Op::CX))
Expand Down
Loading

0 comments on commit eac7acf

Please sign in to comment.