Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into refactor/nodetype_op
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Oct 27, 2023
2 parents 46efff0 + 0beb165 commit 4896158
Show file tree
Hide file tree
Showing 20 changed files with 508 additions and 391 deletions.
3 changes: 2 additions & 1 deletion src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use thiserror::Error;
#[cfg(feature = "pyo3")]
use pyo3::{create_exception, exceptions::PyException, PyErr};

use crate::hugr::{HugrError, Node, ValidationError, Wire};
use crate::hugr::{HugrError, ValidationError};
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::types::ConstTypeError;
use crate::types::Type;
use crate::{Node, Wire};

pub mod handle;
pub use handle::BuildHandle;
Expand Down
3 changes: 2 additions & 1 deletion src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{IncomingPort, Node, NodeMetadata, OutgoingPort, Port, ValidationError};
use crate::hugr::{NodeMetadata, ValidationError};
use crate::ops::{self, LeafOp, OpTrait, OpType};
use crate::{IncomingPort, Node, OutgoingPort, Port};

use std::iter;

Expand Down
4 changes: 1 addition & 3 deletions src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ use std::collections::HashMap;

use thiserror::Error;

use crate::hugr::CircuitUnit;

use crate::ops::OpType;

use super::{BuildError, Dataflow};
use crate::Wire;
use crate::{CircuitUnit, Wire};

/// Builder to build regions of dataflow graphs that look like Circuits,
/// where some inputs of operations directly correspond to some outputs.
Expand Down
333 changes: 333 additions & 0 deletions src/core.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,333 @@
//! Definitions for the core types used in the Hugr.
//!
//! These types are re-exported in the root of the crate.
use derive_more::From;

#[cfg(feature = "pyo3")]
use pyo3::pyclass;

use crate::hugr::HugrError;

/// A handle to a node in the HUGR.
#[derive(
Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, From, serde::Serialize, serde::Deserialize,
)]
#[serde(transparent)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct Node {
index: portgraph::NodeIndex,
}

/// A handle to a port for a node in the HUGR.
#[derive(
Clone,
Copy,
PartialEq,
PartialOrd,
Eq,
Ord,
Hash,
Default,
From,
serde::Serialize,
serde::Deserialize,
)]
#[serde(transparent)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct Port {
offset: portgraph::PortOffset,
}

/// A trait for getting the undirected index of a port.
pub trait PortIndex {
/// Returns the offset of the port.
fn index(self) -> usize;
}

/// A trait for getting the index of a node.
pub trait NodeIndex {
/// Returns the index of the node.
fn index(self) -> usize;
}

/// A port in the incoming direction.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default)]
pub struct IncomingPort {
index: u16,
}

/// A port in the outgoing direction.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default)]
pub struct OutgoingPort {
index: u16,
}

/// The direction of a port.
pub type Direction = portgraph::Direction;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// A DataFlow wire, defined by a Value-kind output port of a node
// Stores node and offset to output port
pub struct Wire(Node, usize);

impl Node {
/// Returns the node as a portgraph `NodeIndex`.
#[inline]
pub(crate) fn pg_index(self) -> portgraph::NodeIndex {
self.index
}
}

impl Port {
/// Creates a new port.
#[inline]
pub fn new(direction: Direction, port: usize) -> Self {
Self {
offset: portgraph::PortOffset::new(direction, port),
}
}

/// Creates a new incoming port.
#[inline]
pub fn new_incoming(port: impl Into<IncomingPort>) -> Self {
Self::try_new_incoming(port).unwrap()
}

/// Creates a new outgoing port.
#[inline]
pub fn new_outgoing(port: impl Into<OutgoingPort>) -> Self {
Self::try_new_outgoing(port).unwrap()
}

/// Creates a new incoming port.
#[inline]
pub fn try_new_incoming(port: impl TryInto<IncomingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Outgoing));
};
Ok(Self {
offset: portgraph::PortOffset::new_incoming(port.index()),
})
}

/// Creates a new outgoing port.
#[inline]
pub fn try_new_outgoing(port: impl TryInto<OutgoingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Incoming));
};
Ok(Self {
offset: portgraph::PortOffset::new_outgoing(port.index()),
})
}

/// Returns the direction of the port.
#[inline]
pub fn direction(self) -> Direction {
self.offset.direction()
}

/// Returns the port as a portgraph `PortOffset`.
#[inline]
pub(crate) fn pg_offset(self) -> portgraph::PortOffset {
self.offset
}
}

impl PortIndex for Port {
#[inline(always)]
fn index(self) -> usize {
self.offset.index()
}
}

impl PortIndex for usize {
#[inline(always)]
fn index(self) -> usize {
self
}
}

impl PortIndex for IncomingPort {
#[inline(always)]
fn index(self) -> usize {
self.index as usize
}
}

impl PortIndex for OutgoingPort {
#[inline(always)]
fn index(self) -> usize {
self.index as usize
}
}

impl From<usize> for IncomingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}

impl From<usize> for OutgoingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}

impl TryFrom<Port> for IncomingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Incoming => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)),
}
}
}

impl TryFrom<Port> for OutgoingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Outgoing => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)),
}
}
}

impl NodeIndex for Node {
fn index(self) -> usize {
self.index.into()
}
}

impl Wire {
/// Create a new wire from a node and a port.
#[inline]
pub fn new(node: Node, port: impl TryInto<OutgoingPort>) -> Self {
Self(node, Port::try_new_outgoing(port).unwrap().index())
}

/// The node that this wire is connected to.
#[inline]
pub fn node(&self) -> Node {
self.0
}

/// The output port that this wire is connected to.
#[inline]
pub fn source(&self) -> Port {
Port::new_outgoing(self.1)
}
}

/// Enum for uniquely identifying the origin of linear wires in a circuit-like
/// dataflow region.
///
/// Falls back to [`Wire`] if the wire is not linear or if it's not possible to
/// track the origin.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CircuitUnit {
/// Arbitrary input wire.
Wire(Wire),
/// Index to region input.
Linear(usize),
}

impl CircuitUnit {
/// Check if this is a wire.
pub fn is_wire(&self) -> bool {
matches!(self, CircuitUnit::Wire(_))
}

/// Check if this is a linear unit.
pub fn is_linear(&self) -> bool {
matches!(self, CircuitUnit::Linear(_))
}
}

impl From<usize> for CircuitUnit {
fn from(value: usize) -> Self {
CircuitUnit::Linear(value)
}
}

impl From<Wire> for CircuitUnit {
fn from(value: Wire) -> Self {
CircuitUnit::Wire(value)
}
}

impl std::fmt::Debug for Node {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Node").field(&self.index()).finish()
}
}

impl std::fmt::Debug for Port {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Port")
.field(&self.offset.direction())
.field(&self.index())
.finish()
}
}

impl std::fmt::Debug for IncomingPort {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("IncomingPort").field(&self.index).finish()
}
}

impl std::fmt::Debug for OutgoingPort {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("OutgoingPort").field(&self.index).finish()
}
}

impl std::fmt::Debug for Wire {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Wire")
.field("node", &self.0.index())
.field("port", &self.1)
.finish()
}
}

impl std::fmt::Debug for CircuitUnit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Wire(w) => f
.debug_struct("WireUnit")
.field("node", &w.0.index())
.field("port", &w.1)
.finish(),
Self::Linear(id) => f.debug_tuple("LinearUnit").field(id).finish(),
}
}
}

macro_rules! impl_display_from_debug {
($($t:ty),*) => {
$(
impl std::fmt::Display for $t {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
<Self as std::fmt::Debug>::fmt(self, f)
}
}
)*
};
}
impl_display_from_debug!(Node, Port, IncomingPort, OutgoingPort, Wire, CircuitUnit);
3 changes: 1 addition & 2 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
use super::{ExtensionId, ExtensionSet};
use crate::{
hugr::views::HugrView,
hugr::Node,
ops::{OpTag, OpTrait},
types::EdgeKind,
Direction,
Direction, Node,
};

use super::validate::ExtensionError;
Expand Down
5 changes: 5 additions & 0 deletions src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ impl ConstUsize {
pub fn new(value: u64) -> Self {
Self(value)
}

/// Returns the value of the constant.
pub fn value(&self) -> u64 {
self.0
}
}

#[typetag::serde]
Expand Down
Loading

0 comments on commit 4896158

Please sign in to comment.