From ccc35913f18fc4f9384434941c5a01bb1888ac52 Mon Sep 17 00:00:00 2001 From: Douglas Wilson <141026920+doug-q@users.noreply.github.com> Date: Mon, 9 Sep 2024 15:17:34 +0100 Subject: [PATCH] feat!: `HSeriesPass` lowers `Tk2Op`s into `HSeriesOp`s (#602) BREAKING CHANGE: The output of `HSeriesPass` has changed. All `Tk2Op`s are now converted to `HSeriesOp`s. --------- Co-authored-by: Seyon Sivarajah --- tket2-hseries/src/extension/hseries.rs | 2 +- tket2-hseries/src/extension/hseries/lower.rs | 40 ++++++++++++++++++++ tket2-hseries/src/lib.rs | 31 ++++++++++----- 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/tket2-hseries/src/extension/hseries.rs b/tket2-hseries/src/extension/hseries.rs index f1d7a898..c6e846ab 100644 --- a/tket2-hseries/src/extension/hseries.rs +++ b/tket2-hseries/src/extension/hseries.rs @@ -30,7 +30,7 @@ use super::futures::future_type; mod lower; use lower::pi_mul_f64; -pub use lower::{check_lowered, lower_tk2_op}; +pub use lower::{check_lowered, lower_tk2_op, LowerTk2Error, LowerTket2ToHSeriesPass}; /// The "tket2.hseries" extension id. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("tket2.hseries"); diff --git a/tket2-hseries/src/extension/hseries/lower.rs b/tket2-hseries/src/extension/hseries/lower.rs index a7f755cb..02427c90 100644 --- a/tket2-hseries/src/extension/hseries/lower.rs +++ b/tket2-hseries/src/extension/hseries/lower.rs @@ -1,7 +1,9 @@ use std::collections::HashMap; use hugr::{ + algorithms::validation::{ValidatePassError, ValidationLevel}, builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, + extension::ExtensionRegistry, hugr::{hugrmut::HugrMut, HugrError}, ops::{self, DataflowOpTrait}, std_extensions::arithmetic::float_types::ConstF64, @@ -26,6 +28,7 @@ fn const_f64(builder: &mut T, value: f64) -> Wire { /// Errors produced by lowering [Tk2Op]s. #[derive(Debug, Error)] +#[allow(missing_docs)] pub enum LowerTk2Error { #[error("Error when building the circuit: {0}")] BuildError(#[from] BuildError), @@ -38,6 +41,12 @@ pub enum LowerTk2Error { #[error("Error when lowering ops: {0}")] CircuitReplacement(#[from] hugr::algorithms::lower::LowerError), + + #[error("Tk2Ops were not lowered: {0:?}")] + Unlowered(Vec), + + #[error(transparent)] + ValidationError(#[from] ValidatePassError), } fn op_to_hugr(op: Tk2Op) -> Result { @@ -136,6 +145,37 @@ pub fn check_lowered(hugr: &impl HugrView) -> Result<(), Vec> { } } +/// A `Hugr -> Hugr` pass that replaces [tket2::Tk2Op] nodes to +/// equivalent graphs made of [HSeriesOp]s. +/// +/// Invokes [lower_tk2_op]. If validation is enabled the resulting HUGR is +/// checked with [check_lowered]. +#[derive(Default, Debug, Clone)] +pub struct LowerTket2ToHSeriesPass(ValidationLevel); + +impl LowerTket2ToHSeriesPass { + /// Run `LowerTket2ToHSeriesPass` on the given [HugrMut]. `registry` is used + /// for validation, if enabled. + pub fn run( + &self, + hugr: &mut impl HugrMut, + registry: &ExtensionRegistry, + ) -> Result<(), LowerTk2Error> { + self.0.run_validated_pass(hugr, registry, |hugr, level| { + lower_tk2_op(hugr)?; + if *level != ValidationLevel::None { + check_lowered(hugr).map_err(LowerTk2Error::Unlowered)?; + } + Ok(()) + }) + } + + /// Returns a new `LowerTket2ToHSeriesPass` with the given [ValidationLevel]. + pub fn with_validation_level(&self, level: ValidationLevel) -> Self { + Self(level) + } +} + #[cfg(test)] mod test { use hugr::{builder::FunctionBuilder, type_row, HugrView}; diff --git a/tket2-hseries/src/lib.rs b/tket2-hseries/src/lib.rs index 777b7c28..3c778919 100644 --- a/tket2-hseries/src/lib.rs +++ b/tket2-hseries/src/lib.rs @@ -12,7 +12,10 @@ use tket2::Tk2Op; use thiserror::Error; -use extension::{futures::FutureOpDef, hseries::HSeriesOp}; +use extension::{ + futures::FutureOpDef, + hseries::{HSeriesOp, LowerTk2Error, LowerTket2ToHSeriesPass}, +}; use lazify_measure::{LazifyMeasurePass, LazifyMeasurePassError}; #[cfg(feature = "cli")] @@ -32,7 +35,6 @@ pub struct HSeriesPass { #[derive(Error, Debug)] /// An error reported from [HSeriesPass]. - pub enum HSeriesPassError { /// The [hugr::Hugr] was invalid either before or after a pass ran. #[error(transparent)] @@ -43,6 +45,9 @@ pub enum HSeriesPassError { /// An error from the component [force_order()] pass. #[error(transparent)] ForceOrderError(#[from] HugrError), + /// An error from the component [LowerTket2ToHSeriesPass] pass. + #[error(transparent)] + LowerTk2Error(#[from] LowerTk2Error), } impl HSeriesPass { @@ -53,6 +58,7 @@ impl HSeriesPass { hugr: &mut impl HugrMut, registry: &ExtensionRegistry, ) -> Result<(), HSeriesPassError> { + self.lower_tk2().run(hugr, registry)?; self.lazify_measure().run(hugr, registry)?; self.validation_level .run_validated_pass(hugr, registry, |hugr, _| { @@ -73,6 +79,10 @@ impl HSeriesPass { Ok(()) } + fn lower_tk2(&self) -> LowerTket2ToHSeriesPass { + LowerTket2ToHSeriesPass::default().with_validation_level(self.validation_level) + } + fn lazify_measure(&self) -> LazifyMeasurePass { LazifyMeasurePass::default().with_validation_level(self.validation_level) } @@ -90,15 +100,18 @@ mod test { builder::{Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, extension::prelude::{BOOL_T, QB_T}, ops::handle::NodeHandle, + std_extensions::arithmetic::float_types::ConstF64, type_row, types::Signature, HugrView as _, }; use itertools::Itertools as _; use petgraph::visit::{Topo, Walker as _}; - use tket2::{extension::angle::ConstAngle, Tk2Op}; - use crate::{extension::futures::FutureOpDef, HSeriesPass}; + use crate::{ + extension::{futures::FutureOpDef, hseries::HSeriesOp}, + HSeriesPass, + }; #[test] fn hseries_pass() { @@ -121,19 +134,19 @@ mod test { .node(); // this LoadConstant should be pushed below the quantum ops where possible - let angle = builder.add_load_value(ConstAngle::PI); + let angle = builder.add_load_value(ConstF64::new(0.0)); let f_node = angle.node(); - // with no dependencies, this H should be lifted to the start + // with no dependencies, this Reset should be lifted to the start let [qb] = builder - .add_dataflow_op(Tk2Op::H, [qb]) + .add_dataflow_op(HSeriesOp::Reset, [qb]) .unwrap() .outputs_arr(); let h_node = qb.node(); // depending on the angle means this op can't be lifted above the angle ops let [qb] = builder - .add_dataflow_op(Tk2Op::Rx, [qb, angle]) + .add_dataflow_op(HSeriesOp::Rz, [qb, angle]) .unwrap() .outputs_arr(); let rx_node = qb.node(); @@ -142,7 +155,7 @@ mod test { // Reads will be added. The Lazy Measure will be lifted and the // reads will be sunk. let [qb, measure_result] = builder - .add_dataflow_op(Tk2Op::Measure, [qb]) + .add_dataflow_op(HSeriesOp::Measure, [qb]) .unwrap() .outputs_arr();