From 2626722f23dc53c24016cb5364cf9d48cbb8e18b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 6 Sep 2024 13:46:27 +0100 Subject: [PATCH] build hugr table ahead of time --- tket2-hseries/src/extension/hseries/lower.rs | 21 ++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tket2-hseries/src/extension/hseries/lower.rs b/tket2-hseries/src/extension/hseries/lower.rs index 402cb0cb..e45956d6 100644 --- a/tket2-hseries/src/extension/hseries/lower.rs +++ b/tket2-hseries/src/extension/hseries/lower.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use hugr::{ builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, hugr::{hugrmut::HugrMut, HugrError}, @@ -6,6 +8,7 @@ use hugr::{ types::Signature, Hugr, HugrView, Node, Wire, }; +use strum::IntoEnumIterator; use thiserror::Error; use tket2::{extension::angle::AngleOpBuilder, Tk2Op}; @@ -73,9 +76,6 @@ fn op_to_hugr(op: Tk2Op) -> Result { b.build_crz(*c, *t, float)?.into() } (Tk2Op::Toffoli, [a, b_, c]) => b.build_toffoli(*a, *b_, *c)?.into(), - (Tk2Op::QAlloc | Tk2Op::QFree | Tk2Op::Reset | Tk2Op::Measure, _) => { - unreachable!("should be covered by lower_direct") - } _ => return Err(LowerTk2Error::UnknownOp(op, inputs.len())), // non-exhaustive }; Ok(b.finish_hugr_with_outputs(outputs, ®ISTRY)?) @@ -84,7 +84,20 @@ fn op_to_hugr(op: Tk2Op) -> Result { /// Lower `Tk2Op` operations to `HSeriesOp` operations. pub fn lower_tk2_op(mut hugr: impl HugrMut) -> Result, LowerTk2Error> { let replaced_nodes = lower_direct(&mut hugr)?; - let lowered_nodes = hugr::algorithms::lower_ops(&mut hugr, |op| op_to_hugr(op.cast()?).ok())?; + let mut hugr_map: HashMap = HashMap::new(); + for op in Tk2Op::iter() { + match op_to_hugr(op) { + Ok(h) => hugr_map.insert(op, h), + // filter out unknown ops, includes those covered by direct lowering + Err(LowerTk2Error::UnknownOp(_, _)) => continue, + Err(e) => return Err(e), + }; + } + + let lowered_nodes = hugr::algorithms::lower_ops(&mut hugr, |op| { + let op: Tk2Op = op.cast()?; + hugr_map.get(&op).cloned() + })?; Ok([replaced_nodes, lowered_nodes].concat()) }