Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: move the core types to their own module #627

Merged
merged 1 commit into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use thiserror::Error;
#[cfg(feature = "pyo3")]
use pyo3::{create_exception, exceptions::PyException, PyErr};

use crate::hugr::{HugrError, Node, ValidationError, Wire};
use crate::hugr::{HugrError, ValidationError};
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::types::ConstTypeError;
use crate::types::Type;
use crate::{Node, Wire};

pub mod handle;
pub use handle::BuildHandle;
Expand Down
3 changes: 2 additions & 1 deletion src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{IncomingPort, Node, NodeMetadata, OutgoingPort, Port, ValidationError};
use crate::hugr::{NodeMetadata, ValidationError};
use crate::ops::{self, LeafOp, OpTrait, OpType};
use crate::{IncomingPort, Node, OutgoingPort, Port};

use std::iter;

Expand Down
4 changes: 1 addition & 3 deletions src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ use std::collections::HashMap;

use thiserror::Error;

use crate::hugr::CircuitUnit;

use crate::ops::OpType;

use super::{BuildError, Dataflow};
use crate::Wire;
use crate::{CircuitUnit, Wire};

/// Builder to build regions of dataflow graphs that look like Circuits,
/// where some inputs of operations directly correspond to some outputs.
Expand Down
282 changes: 282 additions & 0 deletions src/core.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
//! Definitions for the core types used in the Hugr.
//!
//! These types are re-exported in the root of the crate.

use derive_more::From;

#[cfg(feature = "pyo3")]
use pyo3::pyclass;

use crate::hugr::HugrError;

/// A handle to a node in the HUGR.
#[derive(
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Debug,
From,
serde::Serialize,
serde::Deserialize,
)]
#[serde(transparent)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct Node {
index: portgraph::NodeIndex,
}

/// A handle to a port for a node in the HUGR.
#[derive(
Clone,
Copy,
PartialEq,
PartialOrd,
Eq,
Ord,
Hash,
Default,
Debug,
From,
serde::Serialize,
serde::Deserialize,
)]
#[serde(transparent)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct Port {
offset: portgraph::PortOffset,
}

/// A trait for getting the undirected index of a port.
pub trait PortIndex {
/// Returns the offset of the port.
fn index(self) -> usize;
}

/// A trait for getting the index of a node.
pub trait NodeIndex {
/// Returns the index of the node.
fn index(self) -> usize;
}

/// A port in the incoming direction.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)]
pub struct IncomingPort {
index: u16,
}

/// A port in the outgoing direction.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)]
pub struct OutgoingPort {
index: u16,
}

/// The direction of a port.
pub type Direction = portgraph::Direction;

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// A DataFlow wire, defined by a Value-kind output port of a node
// Stores node and offset to output port
pub struct Wire(Node, usize);

impl Node {
/// Returns the node as a portgraph `NodeIndex`.
#[inline]
pub(crate) fn pg_index(self) -> portgraph::NodeIndex {
self.index
}
}

impl Port {
/// Creates a new port.
#[inline]
pub fn new(direction: Direction, port: usize) -> Self {
Self {
offset: portgraph::PortOffset::new(direction, port),
}
}

/// Creates a new incoming port.
#[inline]
pub fn new_incoming(port: impl Into<IncomingPort>) -> Self {
Self::try_new_incoming(port).unwrap()
}

/// Creates a new outgoing port.
#[inline]
pub fn new_outgoing(port: impl Into<OutgoingPort>) -> Self {
Self::try_new_outgoing(port).unwrap()
}

/// Creates a new incoming port.
#[inline]
pub fn try_new_incoming(port: impl TryInto<IncomingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Outgoing));
};
Ok(Self {
offset: portgraph::PortOffset::new_incoming(port.index()),
})
}

/// Creates a new outgoing port.
#[inline]
pub fn try_new_outgoing(port: impl TryInto<OutgoingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Incoming));
};
Ok(Self {
offset: portgraph::PortOffset::new_outgoing(port.index()),
})
}

/// Returns the direction of the port.
#[inline]
pub fn direction(self) -> Direction {
self.offset.direction()
}

/// Returns the port as a portgraph `PortOffset`.
#[inline]
pub(crate) fn pg_offset(self) -> portgraph::PortOffset {
self.offset
}
}

impl PortIndex for Port {
#[inline(always)]
fn index(self) -> usize {
self.offset.index()
}
}

impl PortIndex for usize {
#[inline(always)]
fn index(self) -> usize {
self
}
}

impl PortIndex for IncomingPort {
#[inline(always)]
fn index(self) -> usize {
self.index as usize
}
}

impl PortIndex for OutgoingPort {
#[inline(always)]
fn index(self) -> usize {
self.index as usize
}
}

impl From<usize> for IncomingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}

impl From<usize> for OutgoingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}

impl TryFrom<Port> for IncomingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Incoming => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)),
}
}
}

impl TryFrom<Port> for OutgoingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Outgoing => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)),
}
}
}

impl NodeIndex for Node {
fn index(self) -> usize {
self.index.into()
}
}

impl Wire {
/// Create a new wire from a node and a port.
#[inline]
pub fn new(node: Node, port: impl TryInto<OutgoingPort>) -> Self {
Self(node, Port::try_new_outgoing(port).unwrap().index())
}

/// The node that this wire is connected to.
#[inline]
pub fn node(&self) -> Node {
self.0
}

/// The output port that this wire is connected to.
#[inline]
pub fn source(&self) -> Port {
Port::new_outgoing(self.1)
}
}

/// Enum for uniquely identifying the origin of linear wires in a circuit-like
/// dataflow region.
///
/// Falls back to [`Wire`] if the wire is not linear or if it's not possible to
/// track the origin.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CircuitUnit {
/// Arbitrary input wire.
Wire(Wire),
/// Index to region input.
Linear(usize),
}

impl CircuitUnit {
/// Check if this is a wire.
pub fn is_wire(&self) -> bool {
matches!(self, CircuitUnit::Wire(_))
}

/// Check if this is a linear unit.
pub fn is_linear(&self) -> bool {
matches!(self, CircuitUnit::Linear(_))
}
}

impl From<usize> for CircuitUnit {
fn from(value: usize) -> Self {
CircuitUnit::Linear(value)
}
}

impl From<Wire> for CircuitUnit {
fn from(value: Wire) -> Self {
CircuitUnit::Wire(value)
}
}
3 changes: 1 addition & 2 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
use super::{ExtensionId, ExtensionSet};
use crate::{
hugr::views::HugrView,
hugr::Node,
ops::{OpTag, OpTrait, OpType},
types::EdgeKind,
Direction,
Direction, Node,
};

use super::validate::ExtensionError;
Expand Down
Loading