Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Use tket1 and tket2 circuits interchangeably everywhere #243

Merged
merged 5 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ authors = [] # TODO
maintainers = [] # TODO
include = ["pyproject.toml"]
license = "Apache-2.0"
license_file = "LICENCE"
readme = "README.md"

packages = [{ include = "tket2-py" }]
Expand Down
19 changes: 6 additions & 13 deletions tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,18 @@ use tket2::json::TKETDecode;
use tket2::rewrite::CircuitRewrite;
use tket_json_rs::circuit_json::SerialCircuit;

pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, T2Circuit};
pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, Tk2Circuit};

/// The module definition
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_circuit")?;
m.add_class::<T2Circuit>()?;
m.add_class::<Tk2Circuit>()?;
m.add_class::<PyNode>()?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;

m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
Expand All @@ -47,20 +46,14 @@ pub fn module(py: Python) -> PyResult<&PyModule> {

/// Run the validation checks on a circuit.
#[pyfunction]
pub fn validate_hugr(c: Py<PyAny>) -> PyResult<()> {
try_with_hugr(c, |hugr| hugr.validate(&REGISTRY))
pub fn validate_hugr(c: &PyAny) -> PyResult<()> {
try_with_hugr(c, |hugr, _| hugr.validate(&REGISTRY))
}

/// Return a Graphviz DOT string representation of the circuit.
#[pyfunction]
pub fn to_hugr_dot(c: Py<PyAny>) -> PyResult<String> {
with_hugr(c, |hugr| hugr.dot_string())
}

/// Downcast a python object to a [`Hugr`].
#[pyfunction]
pub fn to_hugr(c: Py<PyAny>) -> PyResult<T2Circuit> {
with_hugr(c, |hugr| hugr.into())
pub fn to_hugr_dot(c: &PyAny) -> PyResult<String> {
with_hugr(c, |hugr, _| hugr.dot_string())
}

/// A [`hugr::Node`] wrapper for Python.
Expand Down
151 changes: 112 additions & 39 deletions tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! Utilities for calling Hugr functions on generic python objects.

use pyo3::exceptions::PyAttributeError;
use pyo3::{prelude::*, PyTypeInfo};

use derive_more::From;
use hugr::{Hugr, HugrView};
use serde::Serialize;
use tket2::extension::REGISTRY;
use tket2::json::TKETDecode;
use tket2::passes::CircuitChunks;
Expand All @@ -14,78 +16,149 @@ use crate::pattern::rewrite::PyCircuitRewrite;
/// A manager for tket 2 operations on a tket 1 Circuit.
#[pyclass]
#[derive(Clone, Debug, PartialEq, From)]
pub struct T2Circuit {
pub struct Tk2Circuit {
/// Rust representation of the circuit.
pub hugr: Hugr,
}

#[pymethods]
impl T2Circuit {
impl Tk2Circuit {
/// Cast a tket1 circuit to a [`Tk2Circuit`].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still a couple mentions of "cast"?

#[new]
fn from_circuit(circ: PyObject) -> PyResult<Self> {
pub fn from_tket1(circ: &PyAny) -> PyResult<Self> {
Ok(Self {
hugr: with_hugr(circ, |hugr| hugr)?,
hugr: with_hugr(circ, |hugr, _| hugr)?,
})
}

fn finish(&self) -> PyResult<PyObject> {
SerialCircuit::encode(&self.hugr)?.to_tket1_with_gil()
/// Cast the [`Tk2Circuit`] to a tket1 circuit.
pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
SerialCircuit::encode(&self.hugr)?.to_tket1(py)
}

fn apply_match(&mut self, rw: PyCircuitRewrite) {
/// Apply a rewrite on the circuit.
pub fn apply_match(&mut self, rw: PyCircuitRewrite) {
rw.rewrite.apply(&mut self.hugr).expect("Apply error.");
}

/// Encode the circuit as a HUGR json string.
//
// TODO: Bind a messagepack encoder/decoder too.
pub fn to_hugr_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&self.hugr).unwrap())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a note that previously we have used messagepack (and the behaviour between json and messagepack is not always identical), but I'm ok avoiding the extra python dependency here for now - at least until the serialized format gets a bit more stable (hopefully soon)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a TODO

}

/// Decode a HUGR json string to a circuit.
#[staticmethod]
pub fn from_hugr_json(json: &str) -> PyResult<Self> {
let hugr = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit { hugr })
}

/// Encode the circuit as a tket1 json string.
///
/// FIXME: Currently the encoded circuit cannot be loaded back due to
/// [https://github.com/CQCL/hugr/issues/683]
pub fn to_tket1_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr)?).unwrap())
}

/// Decode a tket1 json string to a circuit.
#[staticmethod]
pub fn from_tket1_json(json: &str) -> PyResult<Self> {
let tk1: SerialCircuit = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit {
hugr: tk1.decode()?,
})
}
}
impl T2Circuit {
/// Tries to extract a T2Circuit from a python object.
impl Tk2Circuit {
/// Tries to extract a Tk2Circuit from a python object.
///
/// Returns an error if the py object is not a T2Circuit.
pub fn try_extract(circ: Py<PyAny>) -> PyResult<Self> {
Python::with_gil(|py| circ.as_ref(py).extract::<T2Circuit>())
/// Returns an error if the py object is not a Tk2Circuit.
pub fn try_extract(circ: &PyAny) -> PyResult<Self> {
circ.extract::<Tk2Circuit>()
}
}

/// A flag to indicate the encoding of a circuit.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CircuitType {
/// A `pytket` `Circuit`.
Tket1,
/// A tket2 `Tk2Circuit`, represented as a HUGR.
Tket2,
}

impl CircuitType {
/// Converts a `Hugr` into the format indicated by the flag.
pub fn convert(self, py: Python, hugr: Hugr) -> PyResult<&PyAny> {
match self {
CircuitType::Tket1 => SerialCircuit::encode(&hugr)?.to_tket1(py),
CircuitType::Tket2 => Ok(Py::new(py, Tk2Circuit { hugr })?.into_ref(py)),
}
}
}

/// Apply a fallible function expecting a hugr on a pytket circuit.
pub fn try_with_hugr<T, E, F>(circ: Py<PyAny>, f: F) -> PyResult<T>
/// Apply a fallible function expecting a hugr on a python circuit.
///
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
pub fn try_with_hugr<T, E, F>(circ: &PyAny, f: F) -> PyResult<T>
where
E: Into<PyErr>,
F: FnOnce(Hugr) -> Result<T, E>,
F: FnOnce(Hugr, CircuitType) -> Result<T, E>,
{
let hugr = Python::with_gil(|py| -> PyResult<Hugr> {
let circ = circ.as_ref(py);
match T2Circuit::extract(circ) {
// hugr circuit
Ok(t2circ) => Ok(t2circ.hugr),
// tket1 circuit
Err(_) => Ok(SerialCircuit::from_tket1(circ)?.decode()?),
}
})?;
(f)(hugr).map_err(|e| e.into())
let (hugr, typ) = match Tk2Circuit::extract(circ) {
// hugr circuit
Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2),
// tket1 circuit
Err(_) => (
SerialCircuit::from_tket1(circ)?.decode()?,
CircuitType::Tket1,
),
};
(f)(hugr, typ).map_err(|e| e.into())
}

/// Apply a function expecting a hugr on a pytket circuit.
pub fn with_hugr<T, F>(circ: Py<PyAny>, f: F) -> PyResult<T>
/// Apply a function expecting a hugr on a python circuit.
///
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
pub fn with_hugr<T, F>(circ: &PyAny, f: F) -> PyResult<T>
where
F: FnOnce(Hugr) -> T,
F: FnOnce(Hugr, CircuitType) -> T,
{
try_with_hugr(circ, |hugr| Ok::<T, PyErr>((f)(hugr)))
try_with_hugr(circ, |hugr, typ| Ok::<T, PyErr>((f)(hugr, typ)))
}

/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit.
pub fn try_update_hugr<E, F>(circ: Py<PyAny>, f: F) -> PyResult<Py<PyAny>>
/// Apply a fallible hugr-to-hugr function on a python circuit, and return the modified circuit.
///
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
/// The returned Hugr is converted to the matching python object.
pub fn try_update_hugr<E, F>(circ: &PyAny, f: F) -> PyResult<&PyAny>
where
E: Into<PyErr>,
F: FnOnce(Hugr) -> Result<Hugr, E>,
F: FnOnce(Hugr, CircuitType) -> Result<Hugr, E>,
{
let hugr = try_with_hugr(circ, f)?;
SerialCircuit::encode(&hugr)?.to_tket1_with_gil()
let py = circ.py();
try_with_hugr(circ, |hugr, typ| {
let hugr = f(hugr, typ).map_err(|e| e.into())?;
typ.convert(py, hugr)
})
}

/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit.
pub fn update_hugr<F>(circ: Py<PyAny>, f: F) -> PyResult<Py<PyAny>>
/// Apply a hugr-to-hugr function on a python circuit, and return the modified circuit.
///
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
/// The returned Hugr is converted to the matching python object.
pub fn update_hugr<F>(circ: &PyAny, f: F) -> PyResult<&PyAny>
where
F: FnOnce(Hugr) -> Hugr,
F: FnOnce(Hugr, CircuitType) -> Hugr,
{
let hugr = with_hugr(circ, f)?;
SerialCircuit::encode(&hugr)?.to_tket1_with_gil()
let py = circ.py();
try_with_hugr(circ, |hugr, typ| {
let hugr = f(hugr, typ);
typ.convert(py, hugr)
})
}
8 changes: 4 additions & 4 deletions tket2-py/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ impl PyBadgerOptimiser {
/// * `log_progress`: The path to a CSV file to log progress to.
///
#[pyo3(name = "optimise")]
pub fn py_optimise(
pub fn py_optimise<'py>(
&self,
circ: PyObject,
circ: &'py PyAny,
timeout: Option<u64>,
n_threads: Option<NonZeroUsize>,
split_circ: Option<bool>,
log_progress: Option<PathBuf>,
queue_size: Option<usize>,
) -> PyResult<PyObject> {
update_hugr(circ, |circ| {
) -> PyResult<&'py PyAny> {
update_hugr(circ, |circ, _| {
self.optimise(
circ,
timeout,
Expand Down
45 changes: 21 additions & 24 deletions tket2-py/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ pub mod chunks;
use std::{cmp::min, convert::TryInto, fs, num::NonZeroUsize, path::PathBuf};

use pyo3::{prelude::*, types::IntoPyDict};
use tket2::{json::TKETDecode, op_matches, passes::apply_greedy_commutation, Circuit, T2Op};
use tket_json_rs::circuit_json::SerialCircuit;
use tket2::{op_matches, passes::apply_greedy_commutation, Circuit, T2Op};

use crate::{
circuit::{try_update_hugr, try_with_hugr},
Expand All @@ -30,35 +29,33 @@ pub fn module(py: Python) -> PyResult<&PyModule> {
}

#[pyfunction]
fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> {
try_with_hugr(py_c, |mut h| {
fn greedy_depth_reduce(circ: &PyAny) -> PyResult<(&PyAny, u32)> {
let py = circ.py();
try_with_hugr(circ, |mut h, typ| {
let n_moves = apply_greedy_commutation(&mut h)?;
let py_c = SerialCircuit::encode(&h)?.to_tket1_with_gil()?;
PyResult::Ok((py_c, n_moves))
let circ = typ.convert(py, h)?;
PyResult::Ok((circ, n_moves))
})
}

/// Rebase a circuit to the Nam gate set (CX, Rz, H) using TKET1.
///
/// Acquires the python GIL to call TKET's `auto_rebase_pass`.
///
/// Equivalent to running the following code:
/// ```python
/// from pytket.passes.auto_rebase import auto_rebase_pass
/// from pytket import OpType
/// auto_rebase_pass({OpType.CX, OpType.Rz, OpType.H}).apply(circ)"
// ```
fn rebase_nam(circ: &PyObject) -> PyResult<()> {
Python::with_gil(|py| {
let auto_rebase = py
.import("pytket.passes.auto_rebase")?
.getattr("auto_rebase_pass")?;
let optype = py.import("pytket")?.getattr("OpType")?;
let locals = [("OpType", &optype)].into_py_dict(py);
let op_set = py.eval("{OpType.CX, OpType.Rz, OpType.H}", None, Some(locals))?;
let rebase_pass = auto_rebase.call1((op_set,))?.getattr("apply")?;
rebase_pass.call1((circ,)).map(|_| ())
})
fn rebase_nam(circ: &PyAny) -> PyResult<()> {
let py = circ.py();
let auto_rebase = py
.import("pytket.passes.auto_rebase")?
.getattr("auto_rebase_pass")?;
let optype = py.import("pytket")?.getattr("OpType")?;
let locals = [("OpType", &optype)].into_py_dict(py);
let op_set = py.eval("{OpType.CX, OpType.Rz, OpType.H}", None, Some(locals))?;
let rebase_pass = auto_rebase.call1((op_set,))?.getattr("apply")?;
rebase_pass.call1((circ,)).map(|_| ())
}

/// Badger optimisation pass.
Expand All @@ -76,14 +73,14 @@ fn rebase_nam(circ: &PyObject) -> PyResult<()> {
///
/// Log files will be written to the directory `log_dir` if specified.
#[pyfunction]
fn badger_optimise(
circ: PyObject,
fn badger_optimise<'py>(
circ: &'py PyAny,
optimiser: &PyBadgerOptimiser,
max_threads: Option<NonZeroUsize>,
timeout: Option<u64>,
log_dir: Option<PathBuf>,
rebase: Option<bool>,
) -> PyResult<PyObject> {
) -> PyResult<&'py PyAny> {
// Default parameter values
let rebase = rebase.unwrap_or(true);
let max_threads = max_threads.unwrap_or(num_cpus::get().try_into().unwrap());
Expand All @@ -94,7 +91,7 @@ fn badger_optimise(
}
// Rebase circuit
if rebase {
rebase_nam(&circ)?;
rebase_nam(circ)?;
}
// Logic to choose how to split the circuit
let badger_splits = |n_threads: NonZeroUsize| match n_threads.get() {
Expand All @@ -111,7 +108,7 @@ fn badger_optimise(
_ => unreachable!(),
};
// Optimise
try_update_hugr(circ, |mut circ| {
try_update_hugr(circ, |mut circ, _| {
let n_cx = circ
.commands()
.filter(|c| op_matches(c.optype(), T2Op::CX))
Expand Down
Loading