Skip to content

Commit

Permalink
refactor!: Replace Circuit trait with a struct
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed May 31, 2024
1 parent bc17af7 commit 7edcf8c
Show file tree
Hide file tree
Showing 38 changed files with 829 additions and 550 deletions.
54 changes: 42 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ missing_docs = "warn"
[workspace.dependencies]

tket2 = { path = "./tket2" }
hugr = "0.4.0"
hugr = "0.5.0"
portgraph = "0.12"
pyo3 = "0.21.2"
itertools = "0.13.0"
Expand Down Expand Up @@ -60,3 +60,4 @@ tracing-subscriber = "0.3.17"
typetag = "0.2.8"
urlencoding = "2.1.2"
webbrowser = "1.0.0"
cool_asserts = "2.0.3"
1 change: 0 additions & 1 deletion badger-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use tket2::json::{load_tk1_json_file, save_tk1_json_file};
use tket2::optimiser::badger::log::BadgerLogger;
use tket2::optimiser::badger::BadgerOptions;
use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser};
use tket2::rewrite::trace::RewriteTracer;

#[cfg(feature = "peak_alloc")]
#[global_allocator]
Expand Down
12 changes: 7 additions & 5 deletions tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ impl CircuitType {
/// Converts a `Hugr` into the format indicated by the flag.
pub fn convert(self, py: Python, hugr: Hugr) -> PyResult<Bound<PyAny>> {
match self {
CircuitType::Tket1 => SerialCircuit::encode(&hugr).convert_pyerrs()?.to_tket1(py),
CircuitType::Tket2 => Ok(Bound::new(py, Tk2Circuit { hugr })?.into_any()),
CircuitType::Tket1 => SerialCircuit::encode(&hugr.into())
.convert_pyerrs()?
.to_tket1(py),
CircuitType::Tket2 => Ok(Bound::new(py, Tk2Circuit { circ: hugr.into() })?.into_any()),
}
}
}
Expand All @@ -58,16 +60,16 @@ where
E: ConvertPyErr<Output = PyErr>,
F: FnOnce(Hugr, CircuitType) -> Result<T, E>,
{
let (hugr, typ) = match Tk2Circuit::extract_bound(circ) {
let (circ, typ) = match Tk2Circuit::extract_bound(circ) {
// hugr circuit
Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2),
Ok(t2circ) => (t2circ.circ, CircuitType::Tket2),
// tket1 circuit
Err(_) => (
SerialCircuit::from_tket1(circ)?.decode().convert_pyerrs()?,
CircuitType::Tket1,
),
};
(f)(hugr, typ).map_err(|e| e.convert_pyerrs())
(f)(circ.into_hugr(), typ).map_err(|e| e.convert_pyerrs())
}

/// Apply a function expecting a hugr on a python circuit.
Expand Down
40 changes: 22 additions & 18 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ use super::{cost, with_hugr, PyCircuitCost, PyCustom, PyHugrType, PyNode, PyWire
#[derive(Clone, Debug, PartialEq, From)]
pub struct Tk2Circuit {
/// Rust representation of the circuit.
pub hugr: Hugr,
pub circ: Circuit,
}

#[pymethods]
Expand All @@ -67,40 +67,40 @@ impl Tk2Circuit {
#[new]
pub fn new(circ: &Bound<PyAny>) -> PyResult<Self> {
Ok(Self {
hugr: with_hugr(circ, |hugr, _| hugr)?,
circ: with_hugr(circ, |hugr, _| hugr)?.into(),
})
}

/// Convert the [`Tk2Circuit`] to a tket1 circuit.
pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
SerialCircuit::encode(&self.hugr)
SerialCircuit::encode(&self.circ)
.convert_pyerrs()?
.to_tket1(py)
}

/// Apply a rewrite on the circuit.
pub fn apply_rewrite(&mut self, rw: PyCircuitRewrite) {
rw.rewrite.apply(&mut self.hugr).expect("Apply error.");
rw.rewrite.apply(&mut self.circ).expect("Apply error.");
}

/// Encode the circuit as a HUGR json string.
//
// TODO: Bind a messagepack encoder/decoder too.
pub fn to_hugr_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&self.hugr).unwrap())
Ok(serde_json::to_string(self.circ.hugr()).unwrap())
}

/// Decode a HUGR json string to a circuit.
#[staticmethod]
pub fn from_hugr_json(json: &str) -> PyResult<Self> {
let hugr = serde_json::from_str(json)
let hugr: Hugr = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit { hugr })
Ok(Tk2Circuit { circ: hugr.into() })
}

/// Encode the circuit as a tket1 json string.
pub fn to_tket1_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr).convert_pyerrs()?).unwrap())
Ok(serde_json::to_string(&SerialCircuit::encode(&self.circ).convert_pyerrs()?).unwrap())
}

/// Decode a tket1 json string to a circuit.
Expand All @@ -109,7 +109,7 @@ impl Tk2Circuit {
let tk1: SerialCircuit = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit {
hugr: tk1.decode().convert_pyerrs()?,
circ: tk1.decode().convert_pyerrs()?,
})
}

Expand All @@ -134,13 +134,13 @@ impl Tk2Circuit {
cost: cost.to_object(py),
})
};
let circ_cost = self.hugr.circuit_cost(cost_fn)?;
let circ_cost = self.circ.circuit_cost(cost_fn)?;
Ok(circ_cost.cost.into_bound(py))
}

/// Returns a hash of the circuit.
pub fn hash(&self) -> u64 {
self.hugr.circuit_hash().unwrap()
self.circ.circuit_hash().unwrap()
}

/// Hash the circuit
Expand All @@ -160,7 +160,8 @@ impl Tk2Circuit {

fn node_op(&self, node: PyNode) -> PyResult<PyCustom> {
let custom: CustomOp = self
.hugr
.circ
.hugr()
.get_optype(node.node)
.clone()
.try_into()
Expand All @@ -174,25 +175,27 @@ impl Tk2Circuit {
}

fn node_inputs(&self, node: PyNode) -> Vec<PyWire> {
self.hugr
self.circ
.hugr()
.all_linked_outputs(node.node)
.map(|(n, p)| Wire::new(n, p).into())
.collect()
}

fn node_outputs(&self, node: PyNode) -> Vec<PyWire> {
self.hugr
self.circ
.hugr()
.node_outputs(node.node)
.map(|p| Wire::new(node.node, p).into())
.collect()
}

fn input_node(&self) -> PyNode {
self.hugr.input().into()
self.circ.input_node().into()
}

fn output_node(&self) -> PyNode {
self.hugr.output().into()
self.circ.output_node().into()
}
}
impl Tk2Circuit {
Expand Down Expand Up @@ -236,11 +239,12 @@ impl Dfg {

fn finish(&mut self, outputs: Vec<PyWire>) -> PyResult<Tk2Circuit> {
Ok(Tk2Circuit {
hugr: self
circ: self
.builder
.clone()
.finish_hugr_with_outputs(outputs.into_iter().map_into(), &REGISTRY)
.convert_pyerrs()?,
.convert_pyerrs()?
.into(),
})
}
}
Expand Down
11 changes: 7 additions & 4 deletions tket2-py/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
use std::io::BufWriter;
use std::{fs, num::NonZeroUsize, path::PathBuf};

use hugr::Hugr;
use pyo3::prelude::*;
use tket2::optimiser::badger::BadgerOptions;
use tket2::optimiser::{BadgerLogger, DefaultBadgerOptimiser};
use tket2::Circuit;

use crate::circuit::update_hugr;

Expand Down Expand Up @@ -96,18 +96,21 @@ impl PyBadgerOptimiser {
split_circuit: split_circ.unwrap_or(false),
queue_size: queue_size.unwrap_or(100),
};
update_hugr(circ, |circ, _| self.optimise(circ, log_progress, options))
update_hugr(circ, |circ, _| {
self.optimise(circ.into(), log_progress, options)
.into_hugr()
})
}
}

impl PyBadgerOptimiser {
/// The Python optimise method, but on Hugrs.
pub(super) fn optimise(
&self,
circ: Hugr,
circ: Circuit,
log_progress: Option<PathBuf>,
options: BadgerOptions,
) -> Hugr {
) -> Circuit {
let badger_logger = log_progress
.map(|file_name| {
let log_file = fs::File::create(file_name).unwrap();
Expand Down
12 changes: 7 additions & 5 deletions tket2-py/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ create_py_exception!(
#[pyfunction]
fn greedy_depth_reduce<'py>(circ: &Bound<'py, PyAny>) -> PyResult<(Bound<'py, PyAny>, u32)> {
let py = circ.py();
try_with_hugr(circ, |mut h, typ| {
let n_moves = apply_greedy_commutation(&mut h).convert_pyerrs()?;
let circ = typ.convert(py, h)?;
try_with_hugr(circ, |h, typ| {
let mut circ: Circuit = h.into();
let n_moves = apply_greedy_commutation(&mut circ).convert_pyerrs()?;
let circ = typ.convert(py, circ.into_hugr())?;
PyResult::Ok((circ, n_moves))
})
}
Expand Down Expand Up @@ -117,7 +118,8 @@ fn badger_optimise<'py>(
_ => unreachable!(),
};
// Optimise
try_update_hugr(circ, |mut circ, _| {
try_update_hugr(circ, |hugr, _| {
let mut circ: Circuit = hugr.into();
let n_cx = circ
.commands()
.filter(|c| op_matches(c.optype(), Tk2Op::CX))
Expand All @@ -142,6 +144,6 @@ fn badger_optimise<'py>(
};
circ = optimiser.optimise(circ, log_file, options);
}
PyResult::Ok(circ)
PyResult::Ok(circ.into_hugr())
})
}
Loading

0 comments on commit 7edcf8c

Please sign in to comment.