Skip to content

Commit

Permalink
Drop tket1/2 converter functions, add k to T2Circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Nov 15, 2023
1 parent 0a80938 commit 19cf309
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 46 deletions.
18 changes: 2 additions & 16 deletions tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,18 @@ use tket2::json::TKETDecode;
use tket2::rewrite::CircuitRewrite;
use tket_json_rs::circuit_json::SerialCircuit;

pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, T2Circuit};
pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, Tk2Circuit};

/// The module definition
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_circuit")?;
m.add_class::<T2Circuit>()?;
m.add_class::<Tk2Circuit>()?;
m.add_class::<PyNode>()?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;

m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;
m.add_function(wrap_pyfunction!(tket1_to_tket2, m)?)?;
m.add_function(wrap_pyfunction!(tket2_to_tket1, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
Expand Down Expand Up @@ -58,18 +56,6 @@ pub fn to_hugr_dot(c: &PyAny) -> PyResult<String> {
with_hugr(c, |hugr, _| hugr.dot_string())
}

/// Cast a python tket1 circuit to a [`T2Circuit`].
#[pyfunction]
pub fn tket1_to_tket2(c: &PyAny) -> PyResult<T2Circuit> {
T2Circuit::from_circuit(c)
}

/// Cast a [`T2Circuit`] to a python tket1 circuit.
#[pyfunction]
pub fn tket2_to_tket1(py: Python, c: T2Circuit) -> PyResult<&PyAny> {
c.finish(py)
}

/// A [`hugr::Node`] wrapper for Python.
#[pyclass]
#[pyo3(name = "Node")]
Expand Down
30 changes: 15 additions & 15 deletions tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ use crate::pattern::rewrite::PyCircuitRewrite;
/// A manager for tket 2 operations on a tket 1 Circuit.
#[pyclass]
#[derive(Clone, Debug, PartialEq, From)]
pub struct T2Circuit {
pub struct Tk2Circuit {
/// Rust representation of the circuit.
pub hugr: Hugr,
}

#[pymethods]
impl T2Circuit {
/// Cast a tket1 circuit to a [`T2Circuit`].
impl Tk2Circuit {
/// Cast a tket1 circuit to a [`Tk2Circuit`].
#[new]
pub fn from_circuit(circ: &PyAny) -> PyResult<Self> {
pub fn from_tket1(circ: &PyAny) -> PyResult<Self> {
Ok(Self {
hugr: with_hugr(circ, |hugr, _| hugr)?,
})
}

/// Cast the [`T2Circuit`] to a tket1 circuit.
pub fn finish<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
/// Cast the [`Tk2Circuit`] to a tket1 circuit.
pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
SerialCircuit::encode(&self.hugr)?.to_tket1(py)
}

Expand All @@ -51,7 +51,7 @@ impl T2Circuit {
pub fn from_hugr_json(json: &str) -> PyResult<Self> {
let hugr = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(T2Circuit { hugr })
Ok(Tk2Circuit { hugr })
}

/// Encode the circuit as a tket1 json string.
Expand All @@ -67,17 +67,17 @@ impl T2Circuit {
pub fn from_tket1_json(json: &str) -> PyResult<Self> {
let tk1: SerialCircuit = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(T2Circuit {
Ok(Tk2Circuit {
hugr: tk1.decode()?,
})
}
}
impl T2Circuit {
/// Tries to extract a T2Circuit from a python object.
impl Tk2Circuit {
/// Tries to extract a Tk2Circuit from a python object.
///
/// Returns an error if the py object is not a T2Circuit.
/// Returns an error if the py object is not a Tk2Circuit.
pub fn try_extract(circ: &PyAny) -> PyResult<Self> {
circ.extract::<T2Circuit>()
circ.extract::<Tk2Circuit>()
}
}

Expand All @@ -86,7 +86,7 @@ impl T2Circuit {
pub enum CircuitType {
/// A `pytket` `Circuit`.
Tket1,
/// A tket2 `T2Circuit`, represented as a HUGR.
/// A tket2 `Tk2Circuit`, represented as a HUGR.
Tket2,
}

Expand All @@ -95,7 +95,7 @@ impl CircuitType {
pub fn convert(self, py: Python, hugr: Hugr) -> PyResult<&PyAny> {
match self {
CircuitType::Tket1 => SerialCircuit::encode(&hugr)?.to_tket1(py),
CircuitType::Tket2 => Ok(Py::new(py, T2Circuit { hugr })?.into_ref(py)),
CircuitType::Tket2 => Ok(Py::new(py, Tk2Circuit { hugr })?.into_ref(py)),
}
}
}
Expand All @@ -106,7 +106,7 @@ where
E: Into<PyErr>,
F: FnOnce(Hugr, CircuitType) -> Result<T, E>,
{
let (hugr, typ) = match T2Circuit::extract(circ) {
let (hugr, typ) = match Tk2Circuit::extract(circ) {
// hugr circuit
Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2),
// tket1 circuit
Expand Down
2 changes: 1 addition & 1 deletion tket2-py/src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::circuit::{try_with_hugr, with_hugr};
#[pyfunction]
pub fn chunks(c: &PyAny, max_chunk_size: usize) -> PyResult<PyCircuitChunks> {
with_hugr(c, |hugr, typ| {
// TODO: Detect if the circuit is in tket1 format or T2Circuit.
// TODO: Detect if the circuit is in tket1 format or Tk2Circuit.
let chunks = CircuitChunks::split(&hugr, max_chunk_size);
(chunks, typ).into()
})
Expand Down
8 changes: 4 additions & 4 deletions tket2-py/src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
pub mod portmatching;
pub mod rewrite;

use crate::circuit::{tket1_to_tket2, T2Circuit};
use crate::circuit::Tk2Circuit;

use hugr::Hugr;
use pyo3::prelude::*;
Expand Down Expand Up @@ -46,8 +46,8 @@ pub struct Rule(pub [Hugr; 2]);
impl Rule {
#[new]
fn new_rule(l: &PyAny, r: &PyAny) -> PyResult<Rule> {
let l = tket1_to_tket2(l)?;
let r = tket1_to_tket2(r)?;
let l = Tk2Circuit::from_tket1(l)?;
let r = Tk2Circuit::from_tket1(r)?;

Ok(Rule([l.hugr, r.hugr]))
}
Expand All @@ -71,7 +71,7 @@ impl RuleMatcher {
Ok(Self { matcher, rights })
}

pub fn find_match(&self, target: &T2Circuit) -> PyResult<Option<PyCircuitRewrite>> {
pub fn find_match(&self, target: &Tk2Circuit) -> PyResult<Option<PyCircuitRewrite>> {
let h = &target.hugr;
if let Some(p_match) = self.matcher.find_matches_iter(h).next() {
let r = self.rights.get(p_match.pattern_id().0).unwrap().clone();
Expand Down
20 changes: 10 additions & 10 deletions tket2-py/test/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@

from tket2 import passes
from tket2.passes import greedy_depth_reduce
from tket2.circuit import T2Circuit, to_hugr_dot, tket1_to_tket2, tket2_to_tket1
from tket2.circuit import Tk2Circuit, to_hugr_dot
from tket2.pattern import Rule, RuleMatcher


def test_conversion():
tk1 = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3)
tk1_dot = to_hugr_dot(tk1)

tk2 = tket1_to_tket2(tk1)
tk2 = Tk2Circuit(tk1)
tk2_dot = to_hugr_dot(tk2)

assert type(tk2) == T2Circuit
assert type(tk2) == Tk2Circuit
assert tk1_dot == tk2_dot

tk1_back = tket2_to_tket1(tk2)
tk1_back = tk2.to_tket1()

assert tk1_back == tk1
assert type(tk1_back) == Circuit
Expand Down Expand Up @@ -54,14 +54,14 @@ def test_chunks():
assert type(c2) == Circuit

# Split and reassemble, with a tket2 circuit
tk2_chunks = passes.chunks(T2Circuit(c2), 2)
tk2_chunks = passes.chunks(Tk2Circuit(c2), 2)
tk2 = tk2_chunks.reassemble()

assert type(tk2) == T2Circuit
assert type(tk2) == Tk2Circuit


def test_cx_rule():
c = T2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2))
c = Tk2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2))

rule = Rule(Circuit(2).CX(0, 1).CX(0, 1), Circuit(2))
matcher = RuleMatcher([rule])
Expand All @@ -70,13 +70,13 @@ def test_cx_rule():

c.apply_match(mtch)

out = c.finish()
out = c.to_tket1()

assert out == Circuit(4).CX(0, 2)


def test_multiple_rules():
circ = T2Circuit(Circuit(3).CX(0, 1).H(0).H(1).H(2).Z(0).H(0).H(1).H(2))
circ = Tk2Circuit(Circuit(3).CX(0, 1).H(0).H(1).H(2).Z(0).H(0).H(1).H(2))

rule1 = Rule(Circuit(1).H(0).Z(0).H(0), Circuit(1).X(0))
rule2 = Rule(Circuit(1).H(0).H(0), Circuit(1))
Expand All @@ -89,5 +89,5 @@ def test_multiple_rules():

assert match_count == 3

out = circ.finish()
out = circ.to_tket1()
assert out == Circuit(3).CX(0, 1).X(0)

0 comments on commit 19cf309

Please sign in to comment.