Skip to content

Commit

Permalink
fix: Support non-DFG circuits (#391)
Browse files Browse the repository at this point in the history
Followup to #390.
Fixes #389.
Closes #385.
~~Blocked by CQCL/hugr#1175

Fixes support for non-`Dfg` circuits and circuits with a non-root
parent:

Adds a `Circuit::extract_dfg(&self)` function that extracts the circuit
into a new hugr with a DFG operation at the root.
In some cases, like in a CFG DataflowBlock node, this requires some
changes to the definition to eliminate the output sum type.
Here I only implemented it for the kind of blocks produced by guppy. We
could replace the manual implementation once
CQCL/hugr#818 gets implemented.

With this we can now fix #389 by extracting the circuit before using it
as a replacement in `SimpleReplacement::create_simple_replacement`.

Replaces `DfgBuilder` with `FunctionBuilder` where possible, so we can
use named circuits in the tests.
(This failed before due to the bug in CircuitRewrite).
  • Loading branch information
aborgna-q authored Jun 17, 2024
1 parent f18aa4d commit 39057d2
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 31 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ lto = "thin"

[workspace]
resolver = "2"
members = ["tket2", "tket2-py", "compile-rewriter", "badger-optimiser", "tket2-hseries"]
members = [
"tket2",
"tket2-py",
"compile-rewriter",
"badger-optimiser",
"tket2-hseries",
]
default-members = ["tket2", "tket2-hseries"]

[workspace.package]
Expand All @@ -21,6 +27,7 @@ missing_docs = "warn"
tket2 = { path = "./tket2" }
hugr = "0.5.1"
hugr-cli = "0.1.1"
hugr-core = "0.2.0"
portgraph = "0.12"
pyo3 = "0.21.2"
itertools = "0.13.0"
Expand Down
1 change: 1 addition & 0 deletions tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ serde_yaml = { workspace = true }
portmatching = { workspace = true, optional = true, features = ["serde"] }
derive_more = { workspace = true }
hugr = { workspace = true }
hugr-core = { workspace = true }
portgraph = { workspace = true, features = ["serde"] }
strum_macros = { workspace = true }
strum = { workspace = true }
Expand Down
30 changes: 25 additions & 5 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

pub mod command;
pub mod cost;
mod extract_dfg;
mod hash;
pub mod units;

use std::iter::Sum;

pub use command::{Command, CommandIterator};
pub use hash::CircuitHash;
use hugr::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView};
use itertools::Either::{Left, Right};

use hugr::hugr::hugrmut::HugrMut;
Expand Down Expand Up @@ -299,6 +301,27 @@ impl<T: HugrView> Circuit<T> {
// TODO: See comment in `dot_string`.
self.hugr.mermaid_string()
}

/// Extracts the circuit into a new owned HUGR containing the circuit at the root.
/// Replaces the circuit container operation with an [`OpType::DFG`].
///
/// Regions that are not descendants of the parent node are not included in the new HUGR.
/// This may invalidate calls to functions defined elsewhere. Make sure to inline any
/// external functions before calling this method.
pub fn extract_dfg(&self) -> Result<Circuit<Hugr>, CircuitMutError>
where
T: ExtractHugr,
{
let mut circ = if self.parent == self.hugr.root() {
self.to_owned()
} else {
let view: DescendantsGraph = DescendantsGraph::try_new(&self.hugr, self.parent)
.expect("Circuit parent was not a dataflow container.");
view.extract_hugr().into()
};
extract_dfg::rewrite_into_dfg(&mut circ)?;
Ok(circ)
}
}

impl<T: HugrView> From<T> for Circuit<T> {
Expand Down Expand Up @@ -648,12 +671,9 @@ mod tests {
#[case] circ: Circuit,
#[case] qubits: usize,
#[case] bits: usize,
#[case] _name: Option<&str>,
#[case] name: Option<&str>,
) {
// TODO: The decoder discards the circuit name.
// This requires decoding circuits into `FuncDefn` nodes instead of `Dfg`,
// but currently that causes errors with the replacement methods.
//assert_eq!(circ.name(), name);
assert_eq!(circ.name(), name);
assert_eq!(circ.circuit_signature().input_count(), qubits + bits);
assert_eq!(circ.circuit_signature().output_count(), qubits + bits);
assert_eq!(circ.qubit_count(), qubits);
Expand Down
111 changes: 111 additions & 0 deletions tket2/src/circuit/extract_dfg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
//! Internal implementation of `Circuit::extract_dfg`.

use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::NodeType;
use hugr::ops::{OpTrait, OpType, Output, DFG};
use hugr::types::{FunctionType, SumType, TypeEnum};
use hugr::HugrView;
use hugr_core::hugr::internal::HugrMutInternals;
use itertools::Itertools;

use crate::{Circuit, CircuitMutError};

/// Internal method used by [`extract_dfg`] to replace the parent node with a DFG node.
pub(super) fn rewrite_into_dfg(circ: &mut Circuit) -> Result<(), CircuitMutError> {
// Replace the parent node with a DFG node, if necessary.
let old_optype = circ.hugr.get_optype(circ.parent());
if matches!(old_optype, OpType::DFG(_)) {
return Ok(());
}

// If the region was a cfg with a single successor, unpack the output sum type.
let signature = circ.circuit_signature();
let signature = match old_optype {
OpType::DataflowBlock(_) => remove_cfg_empty_output_tuple(circ, signature)?,
_ => signature,
};

let dfg = DFG { signature };
let nodetype = circ.hugr.get_nodetype(circ.parent());
let input_extensions = nodetype.input_extensions().cloned();
let nodetype = NodeType::new(OpType::DFG(dfg), input_extensions);
circ.hugr.replace_op(circ.parent(), nodetype)?;

Ok(())
}

/// Remove an empty sum from a cfg's DataflowBlock output node, if possible.
///
/// Bails out if it cannot match the exact pattern, without modifying the
/// circuit.
///
/// TODO: This function is specialized towards the specific functions generated
/// by guppy. We should generalize this to work with non-empty sum types
/// when possible.
fn remove_cfg_empty_output_tuple(
circ: &mut Circuit,
signature: FunctionType,
) -> Result<FunctionType, CircuitMutError> {
let sig = signature;
let parent = circ.parent();

let output_node = circ.output_node();
let output_nodetype = circ.hugr.get_nodetype(output_node).clone();
let output_op = output_nodetype.op();

let output_sig = output_op
.dataflow_signature()
.expect("Exit node with no dataflow signature.");

// Only remove the port if it's an empty sum type.
if !matches!(
output_sig.input[0].as_type_enum(),
TypeEnum::Sum(SumType::Unit { size: 1 })
) {
return Ok(sig);
}

// There must be a zero-sized `Tag` operation.
let Some((tag_node, _)) = circ.hugr.single_linked_output(output_node, 0) else {
return Ok(sig);
};

let tag_op = circ.hugr.get_optype(tag_node);
if !matches!(tag_op, OpType::Tag(_)) {
return Ok(sig);
}

// Hacky replacement for the nodes.

// Drop the old nodes
let hugr = circ.hugr_mut();
let input_neighs = hugr.all_linked_outputs(output_node).skip(1).collect_vec();

hugr.remove_node(output_node);
hugr.remove_node(tag_node);

// Add a new output node.
let new_types = output_sig.input[1..].to_vec();
let new_op = Output {
types: new_types.clone().into(),
};
let new_node = hugr.add_node_with_parent(
parent,
NodeType::new(
new_op,
output_nodetype
.input_extensions()
.cloned()
.unwrap_or_default(),
),
);

// Reconnect the outputs.
for (i, (neigh, port)) in input_neighs.into_iter().enumerate() {
hugr.connect(neigh, port, new_node, i);
}

// Return the updated circuit signature.
let sig = FunctionType::new(sig.input, new_types);
Ok(sig)
}
2 changes: 1 addition & 1 deletion tket2/src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ impl From<InvalidSubgraph> for InvalidPatternMatch {
InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _),
) => InvalidPatternMatch::NotConvex,
InvalidSubgraph::EmptySubgraph => InvalidPatternMatch::EmptyMatch,
InvalidSubgraph::NoSharedParent | InvalidSubgraph::InvalidBoundary(_) => {
InvalidSubgraph::NoSharedParent { .. } | InvalidSubgraph::InvalidBoundary(_) => {
InvalidPatternMatch::InvalidSubcircuit
}
other => InvalidPatternMatch::Other(other),
Expand Down
40 changes: 27 additions & 13 deletions tket2/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub use ecc_rewriter::ECCRewriter;
use derive_more::{From, Into};
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::views::sibling_subgraph::{InvalidReplacement, InvalidSubgraph};
use hugr::hugr::views::ExtractHugr;
use hugr::{
hugr::{views::SiblingSubgraph, Rewrite, SimpleReplacementError},
SimpleReplacement,
Expand All @@ -33,7 +34,7 @@ impl Subcircuit {
/// Create a new subcircuit induced from a set of nodes.
pub fn try_from_nodes(
nodes: impl Into<Vec<Node>>,
circ: &Circuit,
circ: &Circuit<impl HugrView>,
) -> Result<Self, InvalidSubgraph> {
let subgraph = SiblingSubgraph::try_from_nodes(nodes, circ.hugr())?;
Ok(Self { subgraph })
Expand All @@ -49,16 +50,25 @@ impl Subcircuit {
self.subgraph.node_count()
}

/// Create a rewrite rule to replace the subcircuit.
/// Create a rewrite rule to replace the subcircuit with a new circuit.
///
/// # Parameters
/// * `circuit` - The base circuit that contains the subcircuit.
/// * `replacement` - The new circuit to replace the subcircuit with.
pub fn create_rewrite(
&self,
source: &Circuit,
target: Circuit,
circuit: &Circuit<impl HugrView>,
replacement: Circuit<impl ExtractHugr>,
) -> Result<CircuitRewrite, InvalidReplacement> {
Ok(CircuitRewrite(self.subgraph.create_simple_replacement(
source.hugr(),
target.into_hugr(),
)?))
// The replacement must be a Dfg rooted hugr.
let replacement = replacement
.extract_dfg()
.unwrap_or_else(|e| panic!("{}", e))
.into_hugr();
Ok(CircuitRewrite(
self.subgraph
.create_simple_replacement(circuit.hugr(), replacement)?,
))
}
}

Expand All @@ -69,13 +79,17 @@ pub struct CircuitRewrite(SimpleReplacement);
impl CircuitRewrite {
/// Create a new rewrite rule.
pub fn try_new(
source_position: &Subcircuit,
source: &Circuit<impl HugrView>,
target: Circuit,
circuit_position: &Subcircuit,
circuit: &Circuit<impl HugrView>,
replacement: Circuit<impl ExtractHugr>,
) -> Result<Self, InvalidReplacement> {
source_position
let replacement = replacement
.extract_dfg()
.unwrap_or_else(|e| panic!("{}", e))
.into_hugr();
circuit_position
.subgraph
.create_simple_replacement(source.hugr(), target.into_hugr())
.create_simple_replacement(circuit.hugr(), replacement)
.map(Self)
}

Expand Down
12 changes: 6 additions & 6 deletions tket2/src/serialize/pytket/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::mem;

use hugr::builder::{CircuitBuilder, Container, DFGBuilder, Dataflow, DataflowHugr};
use hugr::builder::{CircuitBuilder, Container, Dataflow, DataflowHugr, FunctionBuilder};
use hugr::extension::prelude::QB_T;

use hugr::types::FunctionType;
Expand All @@ -22,13 +22,13 @@ use super::{METADATA_B_REGISTERS, METADATA_Q_REGISTERS};
use crate::extension::{LINEAR_BIT, REGISTRY};
use crate::symbolic_constant_op;

/// The state of an in-progress [`DFGBuilder`] being built from a [`SerialCircuit`].
/// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`].
///
/// Mostly used to define helper internal methods.
#[derive(Debug, PartialEq)]
pub(super) struct JsonDecoder {
/// The Hugr being built.
pub hugr: DFGBuilder<Hugr>,
pub hugr: FunctionBuilder<Hugr>,
/// The dangling wires of the builder.
/// Used to generate [`CircuitBuilder`]s.
dangling_wires: Vec<Wire>,
Expand Down Expand Up @@ -66,8 +66,8 @@ impl JsonDecoder {
);
// .with_extension_delta(&ExtensionSet::singleton(&TKET1_EXTENSION_ID));

// TODO: Use a FunctionBuilder and store the circuit name there.
let mut dfg = DFGBuilder::new(sig).unwrap();
let name = serialcirc.name.clone().unwrap_or_default();
let mut dfg = FunctionBuilder::new(name, sig.into()).unwrap();

// Metadata. The circuit requires "name", and we store other things that
// should pass through the serialization roundtrip.
Expand Down Expand Up @@ -128,7 +128,7 @@ impl JsonDecoder {
}

/// Apply a function to the internal hugr builder viewed as a [`CircuitBuilder`].
fn with_circ_builder(&mut self, f: impl FnOnce(&mut CircuitBuilder<DFGBuilder<Hugr>>)) {
fn with_circ_builder(&mut self, f: impl FnOnce(&mut CircuitBuilder<FunctionBuilder<Hugr>>)) {
let mut circ = self.hugr.as_circuit(mem::take(&mut self.dangling_wires));
f(&mut circ);
self.dangling_wires = circ.finish();
Expand Down
14 changes: 10 additions & 4 deletions tket2/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
use hugr::builder::{Container, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder};
use hugr::extension::PRELUDE_REGISTRY;
use hugr::ops::handle::NodeHandle;
use hugr::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use hugr::std_extensions::arithmetic::float_types;
use hugr::types::{Type, TypeBound};
use hugr::Hugr;
use hugr::{
builder::{BuildError, CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr},
builder::{BuildError, CircuitBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
types::FunctionType,
};
Expand All @@ -21,10 +23,12 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool {
#[allow(unused)]
pub(crate) fn build_simple_circuit<F>(num_qubits: usize, f: F) -> Result<Circuit, BuildError>
where
F: FnOnce(&mut CircuitBuilder<DFGBuilder<Hugr>>) -> Result<(), BuildError>,
F: FnOnce(&mut CircuitBuilder<FunctionBuilder<Hugr>>) -> Result<(), BuildError>,
{
let qb_row = vec![QB_T; num_qubits];
let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?;
let signature =
FunctionType::new(qb_row.clone(), qb_row).with_extension_delta(float_types::EXTENSION_ID);
let mut h = FunctionBuilder::new("main", signature.into())?;

let qbs = h.input_wires();

Expand All @@ -33,7 +37,9 @@ where
f(&mut circ)?;

let qbs = circ.finish();
let hugr = h.finish_hugr_with_outputs(qbs, &PRELUDE_REGISTRY)?;

// The float ops registry is required to define constant float values.
let hugr = h.finish_hugr_with_outputs(qbs, &FLOAT_OPS_REGISTRY)?;
Ok(hugr.into())
}

Expand Down

0 comments on commit 39057d2

Please sign in to comment.