From 92a31a8620770ec4bef49aa9f06aff49a23eb719 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Fri, 17 Nov 2023 16:28:57 +0000 Subject: [PATCH] feat: EccRewriter bindings --- tket2-py/src/circuit/convert.rs | 2 +- tket2-py/src/lib.rs | 2 + tket2-py/src/pattern.rs | 5 +-- tket2-py/src/pattern/rewrite.rs | 19 --------- tket2-py/src/rewrite.rs | 75 +++++++++++++++++++++++++++++++++ tket2-py/tket2/__init__.py | 4 +- tket2-py/tket2/rewrite.py | 18 ++++++++ 7 files changed, 99 insertions(+), 26 deletions(-) delete mode 100644 tket2-py/src/pattern/rewrite.rs create mode 100644 tket2-py/src/rewrite.rs create mode 100644 tket2-py/tket2/rewrite.py diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index 47993ccd..e63167d7 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -11,7 +11,7 @@ use tket2::json::TKETDecode; use tket2::passes::CircuitChunks; use tket_json_rs::circuit_json::SerialCircuit; -use crate::pattern::rewrite::PyCircuitRewrite; +use crate::rewrite::PyCircuitRewrite; /// A manager for tket 2 operations on a tket 1 Circuit. #[pyclass] diff --git a/tket2-py/src/lib.rs b/tket2-py/src/lib.rs index f6008fa0..bb9ec8ef 100644 --- a/tket2-py/src/lib.rs +++ b/tket2-py/src/lib.rs @@ -3,6 +3,7 @@ pub mod circuit; pub mod optimiser; pub mod passes; pub mod pattern; +pub mod rewrite; use pyo3::prelude::*; @@ -14,6 +15,7 @@ fn tket2_py(py: Python, m: &PyModule) -> PyResult<()> { add_submodule(py, m, optimiser::module(py)?)?; add_submodule(py, m, passes::module(py)?)?; add_submodule(py, m, pattern::module(py)?)?; + add_submodule(py, m, rewrite::module(py)?)?; Ok(()) } diff --git a/tket2-py/src/pattern.rs b/tket2-py/src/pattern.rs index 322b5220..3aa8f14e 100644 --- a/tket2-py/src/pattern.rs +++ b/tket2-py/src/pattern.rs @@ -1,20 +1,17 @@ //! Pattern matching on circuits. pub mod portmatching; -pub mod rewrite; use crate::circuit::Tk2Circuit; +use crate::rewrite::PyCircuitRewrite; use hugr::Hugr; use pyo3::prelude::*; use tket2::portmatching::{CircuitPattern, PatternMatcher}; -use self::rewrite::PyCircuitRewrite; - /// The module definition pub fn module(py: Python) -> PyResult<&PyModule> { let m = PyModule::new(py, "_pattern")?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/tket2-py/src/pattern/rewrite.rs b/tket2-py/src/pattern/rewrite.rs deleted file mode 100644 index 4bafd459..00000000 --- a/tket2-py/src/pattern/rewrite.rs +++ /dev/null @@ -1,19 +0,0 @@ -//! Bindings for circuit rewrites. - -use derive_more::From; -use pyo3::prelude::*; -use tket2::rewrite::CircuitRewrite; - -/// A rewrite rule for circuits. -/// -/// Python equivalent of [`CircuitRewrite`]. -/// -/// [`CircuitRewrite`]: tket2::rewrite::CircuitRewrite -#[pyclass] -#[pyo3(name = "CircuitRewrite")] -#[derive(Debug, Clone, From)] -#[repr(transparent)] -pub struct PyCircuitRewrite { - /// Rust representation of the circuit chunks. - pub rewrite: CircuitRewrite, -} diff --git a/tket2-py/src/rewrite.rs b/tket2-py/src/rewrite.rs new file mode 100644 index 00000000..ff833cc1 --- /dev/null +++ b/tket2-py/src/rewrite.rs @@ -0,0 +1,75 @@ +//! PyO3 wrapper for rewriters. + +use derive_more::From; +use itertools::Itertools; +use pyo3::prelude::*; +use std::path::PathBuf; +use tket2::rewrite::{CircuitRewrite, ECCRewriter, Rewriter}; + +use crate::circuit::Tk2Circuit; + +/// The module definition +pub fn module(py: Python) -> PyResult<&PyModule> { + let m = PyModule::new(py, "_rewrite")?; + m.add_class::()?; + m.add_class::()?; + Ok(m) +} + +/// A rewrite rule for circuits. +/// +/// Python equivalent of [`CircuitRewrite`]. +/// +/// [`CircuitRewrite`]: tket2::rewrite::CircuitRewrite +#[pyclass] +#[pyo3(name = "CircuitRewrite")] +#[derive(Debug, Clone, From)] +#[repr(transparent)] +pub struct PyCircuitRewrite { + /// Rust representation of the circuit chunks. + pub rewrite: CircuitRewrite, +} + +#[pymethods] +impl PyCircuitRewrite { + /// Number of nodes added or removed by the rewrite. + /// + /// The difference between the new number of nodes minus the old. A positive + /// number is an increase in node count, a negative number is a decrease. + pub fn node_count_delta(&self) -> isize { + self.rewrite.node_count_delta() + } + + /// The replacement subcircuit. + pub fn replacement(&self) -> Tk2Circuit { + self.rewrite.replacement().clone().into() + } +} + +/// A rewriter based on circuit equivalence classes. +/// +/// In every equivalence class, one circuit is chosen as the representative. +/// Valid rewrites turn a non-representative circuit into its representative, +/// or a representative circuit into any of the equivalent non-representative +#[pyclass(name = "ECCRewriter")] +pub struct PyECCRewriter(ECCRewriter); + +#[pymethods] +impl PyECCRewriter { + /// Load a precompiled ecc rewriter from a file. + #[staticmethod] + pub fn load_precompiled(path: PathBuf) -> PyResult { + Ok(Self(ECCRewriter::load_binary(path).map_err(|e| { + PyErr::new::(e.to_string()) + })?)) + } + + /// Returns a list of circuit rewrites that can be applied to the given Tk2Circuit. + pub fn get_rewrites(&self, circ: &Tk2Circuit) -> Vec { + self.0 + .get_rewrites(&circ.hugr) + .into_iter() + .map_into() + .collect() + } +} diff --git a/tket2-py/tket2/__init__.py b/tket2-py/tket2/__init__.py index 918ca898..74ae59fd 100644 --- a/tket2-py/tket2/__init__.py +++ b/tket2-py/tket2/__init__.py @@ -1,3 +1,3 @@ -from . import passes, circuit, optimiser, pattern +from . import passes, circuit, optimiser, pattern, rewrite -__all__ = [circuit, optimiser, passes, pattern] +__all__ = [circuit, optimiser, passes, pattern, rewrite] diff --git a/tket2-py/tket2/rewrite.py b/tket2-py/tket2/rewrite.py new file mode 100644 index 00000000..9566c6ff --- /dev/null +++ b/tket2-py/tket2/rewrite.py @@ -0,0 +1,18 @@ +# Re-export native bindings +from .tket2._rewrite import * # noqa: F403 +from .tket2 import _rewrite + +from pathlib import Path +import importlib + +__all__ = [ + "default_ecc_rewriter", + *_rewrite.__all__, +] + + +def default_ecc_rewriter() -> _rewrite.ECCRewriter: + """Load the default ecc rewriter.""" + # TODO: Cite, explain what this is + rewriter = Path(importlib.resources.files("tket2").joinpath("data/nam_6_3.rwr")) + return _rewrite.ECCRewriter.load_precompiled(rewriter)