Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Portmatching for StaticSizeCircuit #560

Open
wants to merge 5 commits into
base: feat/badgerv2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
580 changes: 59 additions & 521 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ lto = "thin"
resolver = "2"
members = [
"tket2",
"tket2-py",
"compile-rewriter",
"badger-optimiser",
# "tket2-py",
# "compile-rewriter",
# "badger-optimiser",
"tket2-hseries",
]
default-members = ["tket2", "tket2-hseries"]
Expand All @@ -33,7 +33,7 @@ pyo3 = "0.21.2"
itertools = "0.13.0"
tket-json-rs = "0.5.1"
tracing = "0.1.37"
portmatching = "0.3.1"
portmatching = "0.4.0-rc.1"
bytemuck = "1.17.0"
cgmath = "0.18.0"
chrono = "0.4.30"
Expand Down
6 changes: 3 additions & 3 deletions tket2-py/src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::utils::{create_py_exception, ConvertPyErr};

use hugr::HugrView;
use pyo3::prelude::*;
use tket2::portmatching::{CircuitPattern, PatternMatch, PatternMatcher};
use tket2::portmatching::{CircuitPattern, PatternMatch, CircuitMatcher};
use tket2::Circuit;

/// The module definition
Expand Down Expand Up @@ -80,7 +80,7 @@ impl Rule {
}
#[pyclass]
struct RuleMatcher {
matcher: PatternMatcher,
matcher: CircuitMatcher,
rights: Vec<Circuit>,
}

Expand All @@ -92,7 +92,7 @@ impl RuleMatcher {
rules.into_iter().map(|Rule([l, r])| (l, r)).unzip();
let patterns: Result<Vec<CircuitPattern>, _> =
lefts.iter().map(CircuitPattern::try_from_circuit).collect();
let matcher = PatternMatcher::from_patterns(patterns.convert_pyerrs()?);
let matcher = CircuitMatcher::from_patterns(patterns.convert_pyerrs()?);

Ok(Self { matcher, rights })
}
Expand Down
6 changes: 3 additions & 3 deletions tket2-py/src/pattern/portmatching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use itertools::Itertools;
use portmatching::PatternID;
use pyo3::{prelude::*, types::PyIterator};

use tket2::portmatching::{CircuitPattern, PatternMatch, PatternMatcher};
use tket2::portmatching::{CircuitPattern, PatternMatch, CircuitMatcher};

use crate::circuit::{try_with_circ, with_circ, PyNode};

Expand Down Expand Up @@ -54,15 +54,15 @@ impl PyCircuitPattern {
#[derive(Debug, Clone, From)]
pub struct PyPatternMatcher {
/// Rust representation of the matcher
pub matcher: PatternMatcher,
pub matcher: CircuitMatcher,
}

#[pymethods]
impl PyPatternMatcher {
/// Construct a matcher from a list of patterns.
#[new]
pub fn py_from_patterns(patterns: &Bound<PyIterator>) -> PyResult<Self> {
Ok(PatternMatcher::from_patterns(
Ok(CircuitMatcher::from_patterns(
patterns
.iter()?
.map(|p| {
Expand Down
1 change: 1 addition & 0 deletions tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ criterion = { workspace = true, features = ["html_reports"] }
webbrowser = { workspace = true }
urlencoding = { workspace = true }
cool_asserts = { workspace = true }
insta = "1.39.0"

[[bench]]
name = "bench_main"
Expand Down
214 changes: 130 additions & 84 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub mod units;
use std::iter::Sum;

pub use command::{Command, CommandIterator};
pub use hash::CircuitHash;
pub use hash::{CircuitHash, HashError};
use hugr::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView};
use itertools::Either::{Left, Right};

Expand Down Expand Up @@ -253,31 +253,6 @@ impl<T: HugrView> Circuit<T> {
self.commands().filter(|cmd| cmd.optype().is_custom_op())
}

/// Compute the cost of the circuit based on a per-operation cost function.
#[inline]
pub fn circuit_cost<F, C>(&self, op_cost: F) -> C
where
Self: Sized,
C: Sum,
F: Fn(&OpType) -> C,
{
self.commands().map(|cmd| op_cost(cmd.optype())).sum()
}

/// Compute the cost of a group of nodes in a circuit based on a
/// per-operation cost function.
#[inline]
pub fn nodes_cost<F, C>(&self, nodes: impl IntoIterator<Item = Node>, op_cost: F) -> C
where
C: Sum,
F: Fn(&OpType) -> C,
{
nodes
.into_iter()
.map(|n| op_cost(self.hugr.get_optype(n)))
.sum()
}

/// Return the graphviz representation of the underlying graph and hierarchy side by side.
///
/// For a simpler representation, use the [`Circuit::mermaid_string`] format instead.
Expand Down Expand Up @@ -321,6 +296,48 @@ impl<T: HugrView> Circuit<T> {
}
}

pub trait CircuitCostTrait {
/// Compute the cost of the circuit based on a per-operation cost function.
fn circuit_cost<F, C>(&self, op_cost: F) -> C
where
Self: Sized,
C: Sum,
F: Fn(&OpType) -> C;

/// Compute the cost of a group of nodes in a circuit based on a
/// per-operation cost function.
fn nodes_cost<F, C>(&self, nodes: impl IntoIterator<Item = Node>, op_cost: F) -> C
where
C: Sum,
F: Fn(&OpType) -> C;
}

impl<H: HugrView> CircuitCostTrait for Circuit<H> {
#[inline]
fn circuit_cost<F, C>(&self, op_cost: F) -> C
where
Self: Sized,
C: Sum,
F: Fn(&OpType) -> C,
{
self.commands().map(|cmd| op_cost(cmd.optype())).sum()
}

/// Compute the cost of a group of nodes in a circuit based on a
/// per-operation cost function.
#[inline]
fn nodes_cost<F, C>(&self, nodes: impl IntoIterator<Item = Node>, op_cost: F) -> C
where
C: Sum,
F: Fn(&OpType) -> C,
{
nodes
.into_iter()
.map(|n| op_cost(self.hugr.get_optype(n)))
.sum()
}
}

impl<T: HugrView> From<T> for Circuit<T> {
fn from(hugr: T) -> Self {
let parent = hugr.root();
Expand Down Expand Up @@ -360,61 +377,90 @@ fn check_hugr(hugr: &impl HugrView, parent: Node) -> Result<(), CircuitError> {
}
}

/// Remove an empty wire in a dataflow HUGR.
///
/// The wire to be removed is identified by the index of the outgoing port
/// at the circuit input node.
///
/// This will change the circuit signature and will shift all ports after
/// the removed wire by -1. If the wire is connected to the output node,
/// this will also change the signature output and shift the ports after
/// the removed wire by -1.
///
/// This will return an error if the wire is not empty or if a HugrError
/// occurs.
#[allow(dead_code)]
pub(crate) fn remove_empty_wire(
circ: &mut Circuit<impl HugrMut>,
input_port: usize,
) -> Result<(), CircuitMutError> {
let parent = circ.parent();
let hugr = circ.hugr_mut();

let [inp, out] = hugr.get_io(parent).expect("no IO nodes found at parent");
if input_port >= hugr.num_outputs(inp) {
return Err(CircuitMutError::InvalidPortOffset(input_port));
}
let input_port = OutgoingPort::from(input_port);
let link = hugr
.linked_inputs(inp, input_port)
.at_most_one()
.map_err(|_| CircuitMutError::DeleteNonEmptyWire(input_port.index()))?;
if link.is_some() && link.unwrap().0 != out {
return Err(CircuitMutError::DeleteNonEmptyWire(input_port.index()));
}
if link.is_some() {
hugr.disconnect(inp, input_port);
}

// Shift ports at input
shift_ports(hugr, inp, input_port, hugr.num_outputs(inp))?;
// Shift ports at output
if let Some((out, output_port)) = link {
shift_ports(hugr, out, output_port, hugr.num_inputs(out))?;
}
// Update input node, output node (if necessary) and parent signatures.
update_signature(
hugr,
parent,
input_port.index(),
link.map(|(_, p)| p.index()),
)?;
// Resize ports at input/output node
hugr.set_num_ports(inp, 0, hugr.num_outputs(inp) - 1);
if let Some((out, _)) = link {
hugr.set_num_ports(out, hugr.num_inputs(out) - 1, 0);
pub(crate) trait RemoveEmptyWire {
/// Remove an empty wire in a dataflow HUGR.
///
/// The wire to be removed is identified by the index of the outgoing port
/// at the circuit input node.
///
/// This will change the circuit signature and will shift all ports after
/// the removed wire by -1. If the wire is connected to the output node,
/// this will also change the signature output and shift the ports after
/// the removed wire by -1.
///
/// This will return an error if the wire is not empty or if a HugrError
/// occurs.
fn remove_empty_wire(&mut self, input_port: usize) -> Result<(), CircuitMutError>;

/// The port offsets of wires that are empty.
fn empty_wires(&self) -> Vec<usize>;
}

impl<H: HugrMut> RemoveEmptyWire for Circuit<H> {
#[allow(dead_code)]
fn remove_empty_wire(&mut self, input_port: usize) -> Result<(), CircuitMutError> {
let parent = self.parent();
let hugr = self.hugr_mut();

let [inp, out] = hugr.get_io(parent).expect("no IO nodes found at parent");
if input_port >= hugr.num_outputs(inp) {
return Err(CircuitMutError::InvalidPortOffset(input_port));
}
let input_port = OutgoingPort::from(input_port);
let link = hugr
.linked_inputs(inp, input_port)
.at_most_one()
.map_err(|_| CircuitMutError::DeleteNonEmptyWire(input_port.index()))?;
if link.is_some() && link.unwrap().0 != out {
return Err(CircuitMutError::DeleteNonEmptyWire(input_port.index()));
}
if link.is_some() {
hugr.disconnect(inp, input_port);
}

// Shift ports at input
shift_ports(hugr, inp, input_port, hugr.num_outputs(inp))?;
// Shift ports at output
if let Some((out, output_port)) = link {
shift_ports(hugr, out, output_port, hugr.num_inputs(out))?;
}
// Update input node, output node (if necessary) and parent signatures.
update_signature(
hugr,
parent,
input_port.index(),
link.map(|(_, p)| p.index()),
)?;
// Resize ports at input/output node
hugr.set_num_ports(inp, 0, hugr.num_outputs(inp) - 1);
if let Some((out, _)) = link {
hugr.set_num_ports(out, hugr.num_inputs(out) - 1, 0);
}
Ok(())
}

/// The port offsets of wires that are empty.
fn empty_wires(&self) -> Vec<usize> {
let hugr = self.hugr();
let input = self.input_node();
let input_sig = hugr.signature(input).unwrap();
hugr.node_outputs(input)
// Only consider dataflow edges
.filter(|&p| input_sig.out_port_type(p).is_some())
// Only consider ports linked to at most one other port
.filter_map(|p| Some((p, hugr.linked_ports(input, p).at_most_one().ok()?)))
// Ports are either connected to output or nothing
.filter_map(|(from, to)| {
if let Some((n, _)) = to {
// Wires connected to output
(n == self.output_node()).then_some(from.index())
} else {
// Wires connected to nothing
Some(from.index())
}
})
.collect()
}
Ok(())
}

/// Errors that can occur when mutating a circuit.
Expand Down Expand Up @@ -690,10 +736,10 @@ mod tests {
.unwrap();

assert_eq!(circ.qubit_count(), 2);
assert!(remove_empty_wire(&mut circ, 1).is_ok());
assert!(circ.remove_empty_wire(1).is_ok());
assert_eq!(circ.qubit_count(), 1);
assert_eq!(
remove_empty_wire(&mut circ, 0).unwrap_err(),
circ.remove_empty_wire(0).unwrap_err(),
CircuitMutError::DeleteNonEmptyWire(0)
);
}
Expand All @@ -717,10 +763,10 @@ mod tests {
.into();

assert_eq!(circ.units().count(), 1);
assert!(remove_empty_wire(&mut circ, 0).is_ok());
assert!(circ.remove_empty_wire(0).is_ok());
assert_eq!(circ.units().count(), 0);
assert_eq!(
remove_empty_wire(&mut circ, 2).unwrap_err(),
circ.remove_empty_wire(2).unwrap_err(),
CircuitMutError::InvalidPortOffset(2)
);
}
Expand Down
2 changes: 2 additions & 0 deletions tket2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ pub mod optimiser;
pub mod passes;
pub mod rewrite;
pub mod serialize;
#[cfg(feature = "portmatching")]
pub mod static_circ;

#[cfg(feature = "portmatching")]
pub mod portmatching;
Expand Down
Loading