Skip to content

Commit

Permalink
feat: dataflow builder methods for angle ops (#596)
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 authored Sep 6, 2024
1 parent dd9592f commit dcc562d
Showing 1 changed file with 107 additions and 2 deletions.
109 changes: 107 additions & 2 deletions tket2/src/extension/angle.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use hugr::builder::{BuildError, Dataflow};
use hugr::extension::prelude::{sum_with_error, BOOL_T, USIZE_T};
use hugr::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use hugr::extension::{ExtensionId, ExtensionSet, Version};
use hugr::ops::constant::{downcast_equal_consts, CustomConst};
use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE;
use hugr::type_row;
use hugr::{type_row, Wire};
use hugr::{
types::{ConstTypeError, CustomType, Signature, Type, TypeBound},
Extension,
Expand Down Expand Up @@ -261,11 +262,94 @@ pub(super) fn add_to_extension(extension: &mut Extension) {
AngleOp::load_all_ops(extension).expect("add fail");
}

/// An extension trait for [Dataflow] providing methods to add
/// "tket2.angle" operations.
pub trait AngleOpBuilder: Dataflow {
/// Add a "tket2.angle.atrunc" op.
fn add_atrunc(&mut self, angle: Wire, log_denom: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(AngleOp::atrunc, [angle, log_denom])?
.out_wire(0))
}
/// Add a "tket2.angle.aadd" op.
fn add_aadd(&mut self, angle1: Wire, angle2: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(AngleOp::aadd, [angle1, angle2])?
.out_wire(0))
}

/// Add a "tket2.angle.asub" op.
fn add_asub(&mut self, angle1: Wire, angle2: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(AngleOp::asub, [angle1, angle2])?
.out_wire(0))
}

/// Add a "tket2.angle.aneg" op.
fn add_aneg(&mut self, angle: Wire) -> Result<Wire, BuildError> {
Ok(self.add_dataflow_op(AngleOp::aneg, [angle])?.out_wire(0))
}

/// Add a "tket2.angle.anew" op.
fn add_anew(&mut self, numerator: Wire, log_denominator: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(AngleOp::anew, [numerator, log_denominator])?
.out_wire(0))
}

/// Add a "tket2.angle.aparts" op.
fn add_aparts(&mut self, angle: Wire) -> Result<[Wire; 2], BuildError> {
Ok(self
.add_dataflow_op(AngleOp::aparts, [angle])?
.outputs_arr())
}

/// Add a "tket2.angle.afromrad" op.
fn add_afromrad(&mut self, log_denominator: Wire, radians: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(AngleOp::afromrad, [log_denominator, radians])?
.out_wire(0))
}

/// Add a "tket2.angle.atorad" op.
fn add_atorad(&mut self, angle: Wire) -> Result<Wire, BuildError> {
Ok(self.add_dataflow_op(AngleOp::atorad, [angle])?.out_wire(0))
}

/// Add a "tket2.angle.aeq" op.
fn add_aeq(&mut self, angle1: Wire, angle2: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(AngleOp::aeq, [angle1, angle2])?
.out_wire(0))
}

/// Add a "tket2.angle.amul" op.
fn add_amul(&mut self, angle: Wire, scalar: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(AngleOp::amul, [angle, scalar])?
.out_wire(0))
}

/// Add a "tket2.angle.adiv" op.
fn add_adiv(&mut self, angle: Wire, scalar: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(AngleOp::adiv, [angle, scalar])?
.out_wire(0))
}
}

impl<D: Dataflow> AngleOpBuilder for D {}

#[cfg(test)]
mod test {
use hugr::ops::OpType;
use hugr::{
builder::{DFGBuilder, DataflowHugr},
ops::OpType,
};
use strum::IntoEnumIterator;

use crate::extension::REGISTRY;

use super::*;

#[test]
Expand Down Expand Up @@ -306,4 +390,25 @@ mod test {
assert_eq!(optype.cast(), Some(op));
}
}

#[test]
fn test_builder() {
let mut builder =
DFGBuilder::new(Signature::new(vec![ANGLE_TYPE, USIZE_T], vec![BOOL_T])).unwrap();

let [angle, scalar] = builder.input_wires_arr();
let radians = builder.add_atorad(angle).unwrap();
let angle = builder.add_afromrad(scalar, radians).unwrap();
let angle = builder.add_amul(angle, scalar).unwrap();
let angle = builder.add_adiv(angle, scalar).unwrap();
let angle = builder.add_aadd(angle, angle).unwrap();
let angle = builder.add_asub(angle, angle).unwrap();
let [num, log_denom] = builder.add_aparts(angle).unwrap();
let _angle_sum = builder.add_anew(num, log_denom).unwrap();
let angle = builder.add_aneg(angle).unwrap();
let angle = builder.add_atrunc(angle, log_denom).unwrap();
let bool = builder.add_aeq(angle, angle).unwrap();

let _hugr = builder.finish_hugr_with_outputs([bool], &REGISTRY).unwrap();
}
}

0 comments on commit dcc562d

Please sign in to comment.