Skip to content

Commit

Permalink
feat!: HSeriesPass lowers Tk2Ops into HSeriesOps (#602)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: The output of `HSeriesPass` has changed. All `Tk2Op`s
are now converted to `HSeriesOp`s.

---------

Co-authored-by: Seyon Sivarajah <seyon.sivarajah@cambridgequantum.com>
  • Loading branch information
doug-q and ss2165 committed Sep 9, 2024
1 parent d90b815 commit ccc3591
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tket2-hseries/src/extension/hseries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use super::futures::future_type;

mod lower;
use lower::pi_mul_f64;
pub use lower::{check_lowered, lower_tk2_op};
pub use lower::{check_lowered, lower_tk2_op, LowerTk2Error, LowerTket2ToHSeriesPass};

/// The "tket2.hseries" extension id.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("tket2.hseries");
Expand Down
40 changes: 40 additions & 0 deletions tket2-hseries/src/extension/hseries/lower.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use hugr::{
algorithms::validation::{ValidatePassError, ValidationLevel},
builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
extension::ExtensionRegistry,
hugr::{hugrmut::HugrMut, HugrError},
ops::{self, DataflowOpTrait},
std_extensions::arithmetic::float_types::ConstF64,
Expand All @@ -26,6 +28,7 @@ fn const_f64<T: Dataflow + ?Sized>(builder: &mut T, value: f64) -> Wire {

/// Errors produced by lowering [Tk2Op]s.
#[derive(Debug, Error)]
#[allow(missing_docs)]
pub enum LowerTk2Error {
#[error("Error when building the circuit: {0}")]
BuildError(#[from] BuildError),
Expand All @@ -38,6 +41,12 @@ pub enum LowerTk2Error {

#[error("Error when lowering ops: {0}")]
CircuitReplacement(#[from] hugr::algorithms::lower::LowerError),

#[error("Tk2Ops were not lowered: {0:?}")]
Unlowered(Vec<Node>),

#[error(transparent)]
ValidationError(#[from] ValidatePassError),
}

fn op_to_hugr(op: Tk2Op) -> Result<Hugr, LowerTk2Error> {
Expand Down Expand Up @@ -136,6 +145,37 @@ pub fn check_lowered(hugr: &impl HugrView) -> Result<(), Vec<Node>> {
}
}

/// A `Hugr -> Hugr` pass that replaces [tket2::Tk2Op] nodes to
/// equivalent graphs made of [HSeriesOp]s.
///
/// Invokes [lower_tk2_op]. If validation is enabled the resulting HUGR is
/// checked with [check_lowered].
#[derive(Default, Debug, Clone)]
pub struct LowerTket2ToHSeriesPass(ValidationLevel);

impl LowerTket2ToHSeriesPass {
/// Run `LowerTket2ToHSeriesPass` on the given [HugrMut]. `registry` is used
/// for validation, if enabled.
pub fn run(
&self,
hugr: &mut impl HugrMut,
registry: &ExtensionRegistry,
) -> Result<(), LowerTk2Error> {
self.0.run_validated_pass(hugr, registry, |hugr, level| {
lower_tk2_op(hugr)?;
if *level != ValidationLevel::None {
check_lowered(hugr).map_err(LowerTk2Error::Unlowered)?;
}
Ok(())
})
}

/// Returns a new `LowerTket2ToHSeriesPass` with the given [ValidationLevel].
pub fn with_validation_level(&self, level: ValidationLevel) -> Self {
Self(level)
}
}

#[cfg(test)]
mod test {
use hugr::{builder::FunctionBuilder, type_row, HugrView};
Expand Down
31 changes: 22 additions & 9 deletions tket2-hseries/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use tket2::Tk2Op;

use thiserror::Error;

use extension::{futures::FutureOpDef, hseries::HSeriesOp};
use extension::{
futures::FutureOpDef,
hseries::{HSeriesOp, LowerTk2Error, LowerTket2ToHSeriesPass},
};
use lazify_measure::{LazifyMeasurePass, LazifyMeasurePassError};

#[cfg(feature = "cli")]
Expand All @@ -32,7 +35,6 @@ pub struct HSeriesPass {

#[derive(Error, Debug)]
/// An error reported from [HSeriesPass].

pub enum HSeriesPassError {
/// The [hugr::Hugr] was invalid either before or after a pass ran.
#[error(transparent)]
Expand All @@ -43,6 +45,9 @@ pub enum HSeriesPassError {
/// An error from the component [force_order()] pass.
#[error(transparent)]
ForceOrderError(#[from] HugrError),
/// An error from the component [LowerTket2ToHSeriesPass] pass.
#[error(transparent)]
LowerTk2Error(#[from] LowerTk2Error),
}

impl HSeriesPass {
Expand All @@ -53,6 +58,7 @@ impl HSeriesPass {
hugr: &mut impl HugrMut,
registry: &ExtensionRegistry,
) -> Result<(), HSeriesPassError> {
self.lower_tk2().run(hugr, registry)?;
self.lazify_measure().run(hugr, registry)?;
self.validation_level
.run_validated_pass(hugr, registry, |hugr, _| {
Expand All @@ -73,6 +79,10 @@ impl HSeriesPass {
Ok(())
}

fn lower_tk2(&self) -> LowerTket2ToHSeriesPass {
LowerTket2ToHSeriesPass::default().with_validation_level(self.validation_level)
}

fn lazify_measure(&self) -> LazifyMeasurePass {
LazifyMeasurePass::default().with_validation_level(self.validation_level)
}
Expand All @@ -90,15 +100,18 @@ mod test {
builder::{Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer},
extension::prelude::{BOOL_T, QB_T},
ops::handle::NodeHandle,
std_extensions::arithmetic::float_types::ConstF64,
type_row,
types::Signature,
HugrView as _,
};
use itertools::Itertools as _;
use petgraph::visit::{Topo, Walker as _};
use tket2::{extension::angle::ConstAngle, Tk2Op};

use crate::{extension::futures::FutureOpDef, HSeriesPass};
use crate::{
extension::{futures::FutureOpDef, hseries::HSeriesOp},
HSeriesPass,
};

#[test]
fn hseries_pass() {
Expand All @@ -121,19 +134,19 @@ mod test {
.node();

// this LoadConstant should be pushed below the quantum ops where possible
let angle = builder.add_load_value(ConstAngle::PI);
let angle = builder.add_load_value(ConstF64::new(0.0));
let f_node = angle.node();

// with no dependencies, this H should be lifted to the start
// with no dependencies, this Reset should be lifted to the start
let [qb] = builder
.add_dataflow_op(Tk2Op::H, [qb])
.add_dataflow_op(HSeriesOp::Reset, [qb])
.unwrap()
.outputs_arr();
let h_node = qb.node();

// depending on the angle means this op can't be lifted above the angle ops
let [qb] = builder
.add_dataflow_op(Tk2Op::Rx, [qb, angle])
.add_dataflow_op(HSeriesOp::Rz, [qb, angle])
.unwrap()
.outputs_arr();
let rx_node = qb.node();
Expand All @@ -142,7 +155,7 @@ mod test {
// Reads will be added. The Lazy Measure will be lifted and the
// reads will be sunk.
let [qb, measure_result] = builder
.add_dataflow_op(Tk2Op::Measure, [qb])
.add_dataflow_op(HSeriesOp::Measure, [qb])
.unwrap()
.outputs_arr();

Expand Down

0 comments on commit ccc3591

Please sign in to comment.