Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: move rewrite inside hugr, Rewrite -> Replace implementing new 'Rewrite' trait #119

Merged
merged 28 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
80703e1
Remove Pattern
acl-cqc Jun 5, 2023
9ef30ad
Remove Pattern.rs, too skeletal to be useful
acl-cqc Jun 5, 2023
75b9f1f
Move rewrite to hugr/replace
acl-cqc Jun 5, 2023
5c5662f
Add RewriteOp enum, move Hugr code into RewriteOp::apply (a big match)
acl-cqc Jun 5, 2023
de70383
Make into a trait. So it has to be public...so what?
acl-cqc Jun 5, 2023
3a534e6
Parametrize by error type
acl-cqc Jun 5, 2023
55f83ba
fmt
acl-cqc Jun 5, 2023
76d13fe
Rename hugr/replace/{rewrite.rs -> replace.rs}
acl-cqc Jun 6, 2023
6e8b152
Rename src/hugr/{replace=>rewrite}(,.rs)
acl-cqc Jun 6, 2023
99b31ce
Rename Rewrite(Error) to Replace(Error)
acl-cqc Jun 6, 2023
070afdd
Rename RewriteOp to Rewrite
acl-cqc Jun 6, 2023
7348e11
Hugr::apply -> apply_rewrite
acl-cqc Jun 6, 2023
9349aa4
Merge remote-tracking branch 'origin/main' into refactor/replace_trait
acl-cqc Jun 6, 2023
ec8354e
Add may_fail_destructively check, default true, and Transactional wra…
acl-cqc Jun 6, 2023
147c937
is_err
acl-cqc Jun 6, 2023
9f1d382
unchanged_on_failure as trait associated constant
acl-cqc Jun 7, 2023
f0b5b82
Rephrase assert/debug_assert
acl-cqc Jun 7, 2023
45c5fc2
Merge remote-tracking branch 'origin/main' into refactor/replace_trait
acl-cqc Jun 7, 2023
8f8fcae
Merge remote-tracking branch 'origin/main' into refactor/replace_trait
acl-cqc Jun 9, 2023
4604c70
Move SimpleReplacement inside rewrite, move Hugr::apply_simple_replac…
acl-cqc Jun 9, 2023
a83a93e
unused variable
acl-cqc Jun 9, 2023
0291bba
Drive-by: simple_replace.rs: change ".ok();"s to unwrap
acl-cqc Jun 9, 2023
3572933
Merge remote-tracking branch 'origin/main' into refactor/replace_trait
acl-cqc Jun 9, 2023
04c8777
WIP
acl-cqc Jun 19, 2023
6070e20
Fix merge
acl-cqc Jun 19, 2023
4202f5e
Review comments
acl-cqc Jun 19, 2023
4553395
todo -> unimplemented, the plan is not necessarily to do these
acl-cqc Jun 19, 2023
0201233
fmt
acl-cqc Jun 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 7 additions & 187 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,25 @@

mod hugrmut;

pub mod rewrite;
pub mod serialize;
pub mod typecheck;
pub mod validate;
pub mod view;

use std::collections::HashMap;

pub(crate) use self::hugrmut::HugrMut;
pub use self::validate::ValidationError;

use derive_more::From;
pub use rewrite::{Replace, ReplaceError, Rewrite, SimpleReplacement, SimpleReplacementError};

use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle};
use portgraph::multiportgraph::MultiPortGraph;
use portgraph::{Hierarchy, LinkView, NodeIndex, PortView, UnmanagedDenseMap};
use portgraph::{Hierarchy, LinkView, PortView, UnmanagedDenseMap};
use thiserror::Error;

pub use self::view::HugrView;
use crate::ops::tag::OpTag;
use crate::ops::{OpName, OpTrait, OpType};
use crate::replacement::{SimpleReplacement, SimpleReplacementError};
use crate::rewrite::{Rewrite, RewriteError};
use crate::ops::{OpName, OpType};
use crate::types::EdgeKind;

/// The Hugr data structure.
Expand Down Expand Up @@ -81,187 +79,9 @@ pub struct Wire(Node, usize);

/// Public API for HUGRs.
impl Hugr {
/// Apply a simple replacement operation to the HUGR.
pub fn apply_simple_replacement(
&mut self,
r: SimpleReplacement,
) -> Result<(), SimpleReplacementError> {
// 1. Check the parent node exists and is a DFG node.
if self.get_optype(r.parent).tag() != OpTag::Dfg {
return Err(SimpleReplacementError::InvalidParentNode());
}
// 2. Check that all the to-be-removed nodes are children of it and are leaves.
for node in &r.removal {
if self.hierarchy.parent(node.index) != Some(r.parent.index)
|| self.hierarchy.has_children(node.index)
{
return Err(SimpleReplacementError::InvalidRemovedNode());
}
}
// 3. Do the replacement.
// 3.1. Add copies of all replacement nodes and edges to self. Exclude Input/Output nodes.
// Create map from old NodeIndex (in r.replacement) to new NodeIndex (in self).
let mut index_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
let replacement_nodes = r
.replacement
.children(r.replacement.root())
.collect::<Vec<Node>>();
// slice of nodes omitting Input and Output:
let replacement_inner_nodes = &replacement_nodes[2..];
for &node in replacement_inner_nodes {
// Check there are no const inputs.
if !r
.replacement
.get_optype(node)
.signature()
.const_input
.is_empty()
{
return Err(SimpleReplacementError::InvalidReplacementNode());
}
}
let self_output_node_index = self.children(r.parent).nth(1).unwrap();
let replacement_output_node = *replacement_nodes.get(1).unwrap();
for &node in replacement_inner_nodes {
// Add the nodes.
let op: &OpType = r.replacement.get_optype(node);
let new_node_index = self
.add_op_after(self_output_node_index, op.clone())
.unwrap();
index_map.insert(node.index, new_node_index.index);
}
// Add edges between all newly added nodes matching those in replacement.
// TODO This will probably change when implicit copies are implemented.
for &node in replacement_inner_nodes {
let new_node_index = index_map.get(&node.index).unwrap();
for node_successor in r.replacement.output_neighbours(node) {
if r.replacement.get_optype(node_successor).tag() != OpTag::Output {
let new_node_successor_index = index_map.get(&node_successor.index).unwrap();
for connection in r
.replacement
.graph
.get_connections(node.index, node_successor.index)
{
let src_offset = r
.replacement
.graph
.port_offset(connection.0)
.unwrap()
.index();
let tgt_offset = r
.replacement
.graph
.port_offset(connection.1)
.unwrap()
.index();
self.graph
.link_nodes(
*new_node_index,
src_offset,
*new_node_successor_index,
tgt_offset,
)
.ok();
}
}
}
}
// 3.2. For each p = r.nu_inp[q] such that q is not an Output port, add an edge from the
// predecessor of p to (the new copy of) q.
for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &r.nu_inp {
if r.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output {
let new_inp_node_index = index_map.get(&rep_inp_node.index).unwrap();
// add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
let rem_inp_port_index = self
.graph
.port_index(rem_inp_node.index, rem_inp_port.offset)
.unwrap();
let rem_inp_predecessor_port_index =
self.graph.port_link(rem_inp_port_index).unwrap().port();
let new_inp_port_index = self
.graph
.port_index(*new_inp_node_index, rep_inp_port.offset)
.unwrap();
self.graph.unlink_port(rem_inp_predecessor_port_index);
self.graph
.link_ports(rem_inp_predecessor_port_index, new_inp_port_index)
.ok();
}
}
// 3.3. For each q = r.nu_out[p] such that the predecessor of q is not an Input port, add an
// edge from (the new copy of) the predecessor of q to p.
for ((rem_out_node, rem_out_port), rep_out_port) in &r.nu_out {
let rem_out_port_index = self
.graph
.port_index(rem_out_node.index, rem_out_port.offset)
.unwrap();
let rep_out_port_index = r
.replacement
.graph
.port_index(replacement_output_node.index, rep_out_port.offset)
.unwrap();
let rep_out_predecessor_port_index =
r.replacement.graph.port_link(rep_out_port_index).unwrap();
let rep_out_predecessor_node_index = r
.replacement
.graph
.port_node(rep_out_predecessor_port_index)
.unwrap();
if r.replacement
.get_optype(rep_out_predecessor_node_index.into())
.tag()
!= OpTag::Input
{
let rep_out_predecessor_port_offset = r
.replacement
.graph
.port_offset(rep_out_predecessor_port_index)
.unwrap();
let new_out_node_index = index_map.get(&rep_out_predecessor_node_index).unwrap();
let new_out_port_index = self
.graph
.port_index(*new_out_node_index, rep_out_predecessor_port_offset)
.unwrap();
self.graph.unlink_port(rem_out_port_index);
self.graph
.link_ports(new_out_port_index, rem_out_port_index)
.ok();
}
}
// 3.4. For each q = r.nu_out[p1], p0 = r.nu_inp[q], add an edge from the predecessor of p0
// to p1.
for ((rem_out_node, rem_out_port), &rep_out_port) in &r.nu_out {
let rem_inp_nodeport = r.nu_inp.get(&(replacement_output_node, rep_out_port));
if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
// add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
let rem_inp_port_index = self
.graph
.port_index(rem_inp_node.index, rem_inp_port.offset)
.unwrap();
let rem_inp_predecessor_port_index =
self.graph.port_link(rem_inp_port_index).unwrap().port();
let rem_out_port_index = self
.graph
.port_index(rem_out_node.index, rem_out_port.offset)
.unwrap();
self.graph.unlink_port(rem_inp_port_index);
self.graph.unlink_port(rem_out_port_index);
self.graph
.link_ports(rem_inp_predecessor_port_index, rem_out_port_index)
.ok();
}
}
// 3.5. Remove all nodes in r.removal and edges between them.
for node in &r.removal {
self.graph.remove_node(node.index);
self.hierarchy.remove(node.index);
}
Ok(())
}

/// Applies a rewrite to the graph.
pub fn apply_rewrite(self, _rewrite: Rewrite) -> Result<(), RewriteError> {
unimplemented!()
pub fn apply_rewrite<E>(&mut self, rw: impl Rewrite<Error = E>) -> Result<(), E> {
rw.apply(self)
}

/// Return dot string showing underlying graph and hierarchy side by side.
Expand Down
62 changes: 62 additions & 0 deletions src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//! Rewrite operations on the HUGR - replacement, outlining, etc.

pub mod replace;
pub mod simple_replace;
use std::mem;

use crate::Hugr;
pub use replace::{OpenHugr, Replace, ReplaceError};
pub use simple_replace::{SimpleReplacement, SimpleReplacementError};

/// An operation that can be applied to mutate a Hugr
pub trait Rewrite {
/// The type of Error with which this Rewrite may fail
type Error: std::error::Error;

/// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err.
/// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned.
const UNCHANGED_ON_FAILURE: bool;

/// Checks whether the rewrite would succeed on the specified Hugr.
/// If this call succeeds, [self.apply] should also succeed on the same `h`
/// If this calls fails, [self.apply] would fail with the same error.
Copy link
Contributor Author

@acl-cqc acl-cqc Jun 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative is to return Result<(), Option<E>> where returning Err(None) means, no guarantee can be given about apply. If a rewrite really has to get partway through in order to complete its own validity checks then said alternative would support that. But I think we should leave as Result<(), E> until/unless we find a Rewrite where that's difficult.

(Compound Rewrites is one such case. For a sequence of rewrites, verify might have to clone() the Hugr, and then step through the sequence, applying each rewrite before verifying the next.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say custom enums with {Recoverable, NonRecoverable} variants rather than Option<E>.
But yeah, let's keep it simple for now.

fn verify(&self, h: &Hugr) -> Result<(), Self::Error>;

/// Mutate the specified Hugr, or fail with an error.
/// If [self.unchanged_on_failure] is true, then `h` must be unchanged if Err is returned.
/// See also [self.verify]
/// # Panics
/// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is,
/// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())`
/// being preferred.
fn apply(self, h: &mut Hugr) -> Result<(), Self::Error>;
}

/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure)
pub struct Transactional<R> {
underlying: R,
}

// Note we might like to constrain R to Rewrite<unchanged_on_failure=false> but this
// is not yet supported, https://github.com/rust-lang/rust/issues/92827
impl<R: Rewrite> Rewrite for Transactional<R> {
type Error = R::Error;
const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &Hugr) -> Result<(), Self::Error> {
self.underlying.verify(h)
}

fn apply(self, h: &mut Hugr) -> Result<(), Self::Error> {
if R::UNCHANGED_ON_FAILURE {
return self.underlying.apply(h);
}
let backup = h.clone();
let r = self.underlying.apply(h);
if r.is_err() {
// drop the old h, it was undefined
let _ = mem::replace(h, backup);
}
r
}
}
42 changes: 28 additions & 14 deletions src/rewrite/rewrite.rs → src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#![allow(missing_docs)]
//! Rewrite operations on Hugr graphs.
//! Replace operations on Hugr graphs. This is a nonfunctional
//! dummy implementation just to demonstrate design principles.

use std::collections::HashMap;

use portgraph::substitute::OpenGraph;
use portgraph::{NodeIndex, PortIndex};
use thiserror::Error;

use super::Rewrite;
use crate::Hugr;

/// A subset of the nodes in a graph, and the ports that it is connected to.
Expand Down Expand Up @@ -77,7 +79,7 @@ pub type ParentsMap = HashMap<NodeIndex, NodeIndex>;
/// Includes the new weights for the nodes in the replacement graph.
#[derive(Debug, Clone)]
#[allow(unused)]
pub struct Rewrite {
pub struct Replace {
/// The subgraph to be replaced.
subgraph: BoundedSubgraph,
/// The replacement graph.
Expand All @@ -86,7 +88,7 @@ pub struct Rewrite {
parents: ParentsMap,
}

impl Rewrite {
impl Replace {
/// Creates a new rewrite operation.
pub fn new(
subgraph: BoundedSubgraph,
Expand Down Expand Up @@ -114,30 +116,42 @@ impl Rewrite {
)
}

pub fn verify_convexity(&self) -> Result<(), ReplaceError> {
unimplemented!()
}

pub fn verify_boundaries(&self) -> Result<(), ReplaceError> {
unimplemented!()
}
}

impl Rewrite for Replace {
type Error = ReplaceError;
const UNCHANGED_ON_FAILURE: bool = false;

/// Checks that the rewrite is valid.
///
/// This includes having a convex subgraph (TODO: include definition), and
/// having matching numbers of ports on the boundaries.
pub fn verify(&self) -> Result<(), RewriteError> {
/// TODO not clear this implementation really provides much guarantee about [self.apply]
/// but this class is not really working anyway.
fn verify(&self, _h: &Hugr) -> Result<(), ReplaceError> {
self.verify_convexity()?;
self.verify_boundaries()?;
Ok(())
}

pub fn verify_convexity(&self) -> Result<(), RewriteError> {
todo!()
}

pub fn verify_boundaries(&self) -> Result<(), RewriteError> {
todo!()
/// Performs a Replace operation on the graph.
fn apply(self, _h: &mut Hugr) -> Result<(), ReplaceError> {
unimplemented!()
}
}

/// Error generated when a rewrite fails.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum RewriteError {
/// The rewrite failed because the boundary defined by the
/// [`Rewrite`] could not be matched to the dangling ports of the
pub enum ReplaceError {
/// The replacement failed because the boundary defined by the
/// [`Replace`] could not be matched to the dangling ports of the
/// [`OpenHugr`].
#[error("The boundary defined by the rewrite could not be matched to the dangling ports of the OpenHugr")]
BoundarySize(#[source] portgraph::substitute::RewriteError),
Expand All @@ -152,7 +166,7 @@ pub enum RewriteError {
NotConvex(),
}

impl From<portgraph::substitute::RewriteError> for RewriteError {
impl From<portgraph::substitute::RewriteError> for ReplaceError {
fn from(e: portgraph::substitute::RewriteError) -> Self {
match e {
portgraph::substitute::RewriteError::BoundarySize => Self::BoundarySize(e),
Expand Down
Loading