Skip to content

Commit

Permalink
feat: lowering tk2ops -> hseriesops
Browse files Browse the repository at this point in the history
add test for hadamard lowering
  • Loading branch information
ss2165 committed Sep 3, 2024
1 parent b2313d9 commit 487cb62
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 3 deletions.
111 changes: 108 additions & 3 deletions tket2-hseries/src/extension/hseries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,22 @@
//! laziness is represented by returning `tket2.futures.Future` classical
//! values. Qubits are never lazy.
use hugr::{
builder::{BuildError, Dataflow},
builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
extension::{
prelude::{BOOL_T, QB_T},
simple_op::{try_from_name, MakeOpDef, MakeRegisteredOp, OpLoadError},
ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, Version, PRELUDE,
},
ops::{NamedOp as _, OpType},
std_extensions::arithmetic::float_types::{EXTENSION as FLOAT_TYPES, FLOAT64_TYPE},
ops::{self, NamedOp as _, OpTrait, OpType},
std_extensions::arithmetic::float_types::{ConstF64, EXTENSION as FLOAT_TYPES, FLOAT64_TYPE},
type_row,
types::Signature,
Extension, Wire,
};

use lazy_static::lazy_static;
use strum_macros::{EnumIter, EnumString, IntoStaticStr};
use tket2::Tk2Op;

use crate::extension::futures;

Expand Down Expand Up @@ -202,6 +203,51 @@ pub trait HSeriesOpBuilder: Dataflow {

impl<D: Dataflow> HSeriesOpBuilder for D {}

/// Lower `Tk2Op` operations to `HSeriesOp` operations.
pub fn lower_tk2_op(
mut hugr: impl hugr::hugr::hugrmut::HugrMut,
) -> Result<(), Box<dyn std::error::Error>> {
tket2::passes::replace_ops(&mut hugr, |op| {
let op: Tk2Op = op.cast()?;
Some(match op {
Tk2Op::QAlloc => HSeriesOp::QAlloc,
Tk2Op::QFree => HSeriesOp::QFree,
Tk2Op::Reset => HSeriesOp::Reset,
Tk2Op::Measure => HSeriesOp::Measure,
Tk2Op::Rz => HSeriesOp::Rz,
_ => return None,
})
})?;
fn pi_mul(builder: &mut impl Dataflow, multiplier: f64) -> Wire {
builder.add_load_const(ops::Const::new(
ConstF64::new(multiplier * std::f64::consts::PI).into(),
))
}
tket2::passes::lower_ops(&mut hugr, |op| {
let sig = op.dataflow_signature()?;
let sig = Signature::new(sig.input, sig.output); // ignore extension delta
let op = op.cast()?;
let mut b = DFGBuilder::new(sig).ok()?;
Some(match op {
Tk2Op::H => {
let pi = pi_mul(&mut b, 1.0);
let pi_2 = pi_mul(&mut b, 0.5);
let pi_minus_2 = pi_mul(&mut b, -0.5);

let [q] = b.input_wires_arr();

let q = b.add_phased_x(q, pi_2, pi_minus_2).ok()?;
let q = b.add_rz(q, pi).ok()?;

b.finish_hugr_with_outputs([q], &REGISTRY).ok()?
}
_ => return None,
})
})?;

Ok(())
}

#[cfg(test)]
mod test {
use std::sync::Arc;
Expand All @@ -211,8 +257,10 @@ mod test {
use hugr::{
builder::{DataflowHugr, FunctionBuilder},
ops::NamedOp,
HugrView,
};
use strum::IntoEnumIterator as _;
use tket2::Circuit;

use super::*;

Expand Down Expand Up @@ -267,4 +315,61 @@ mod test {
};
assert_matches!(hugr.validate(&REGISTRY), Ok(_));
}

#[test]
fn test_lower_direct() {
let mut b = FunctionBuilder::new(
"circuit",
Signature::new(type_row![FLOAT64_TYPE], type_row![]),
)
.unwrap();
let [angle] = b.input_wires_arr();
let [q] = b.add_dataflow_op(Tk2Op::QAlloc, []).unwrap().outputs_arr();
let [q] = b.add_dataflow_op(Tk2Op::Reset, [q]).unwrap().outputs_arr();
let [q] = b
.add_dataflow_op(Tk2Op::Rz, [q, angle])
.unwrap()
.outputs_arr();
let [q, _] = b
.add_dataflow_op(Tk2Op::Measure, [q])
.unwrap()
.outputs_arr();
b.add_dataflow_op(Tk2Op::QFree, [q]).unwrap();
// TODO remaining ops
let mut h = b.finish_hugr_with_outputs([], &REGISTRY).unwrap();
lower_tk2_op(&mut h).unwrap();
let circ = Circuit::new(&h, h.root());
let ops: Vec<HSeriesOp> = circ
.commands()
.map(|com| com.optype().cast().unwrap())
.collect();
assert_eq!(
ops,
vec![
HSeriesOp::QAlloc,
HSeriesOp::Reset,
HSeriesOp::Rz,
HSeriesOp::Measure,
HSeriesOp::QFree
]
);
}

#[test]
fn test_lower_circuit() {
let mut b = DFGBuilder::new(Signature::new_endo(QB_T)).unwrap();
let [q] = b
.add_dataflow_op(Tk2Op::H, [b.input_wires().next().unwrap()])
.unwrap()
.outputs_arr();
let mut h = b.finish_hugr_with_outputs([q], &REGISTRY).unwrap();

lower_tk2_op(&mut h).unwrap();
let circ = Circuit::new(&h, h.root());
let ops: Vec<HSeriesOp> = circ
.commands()
.filter_map(|com| com.optype().cast())
.collect();
assert_eq!(ops, vec![HSeriesOp::PhasedX, HSeriesOp::Rz]);
}
}
51 changes: 51 additions & 0 deletions tket2/src/passes.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,64 @@
//! Optimisation passes and related utilities for circuits.
mod commutation;
use std::error::Error;

pub use commutation::{apply_greedy_commutation, PullForwardError};

pub mod chunks;
pub use chunks::CircuitChunks;

pub mod pytket;
use hugr::{
hugr::{hugrmut::HugrMut, views::SiblingSubgraph, HugrError},
ops::OpType,
Hugr,
};
pub use pytket::lower_to_pytket;

pub mod tuple_unpack;
pub use tuple_unpack::find_tuple_unpack_rewrites;

// TODO use HUGR versions once they are available

/// Replace all operations in a HUGR according to a mapping.
pub fn replace_ops<S: Into<OpType>>(
hugr: &mut impl HugrMut,
mapping: impl Fn(&OpType) -> Option<S>,
) -> Result<(), HugrError> {
let replacements = hugr
.nodes()
.filter_map(|node| {
let new_op = mapping(hugr.get_optype(node))?;
Some((node, new_op))
})
.collect::<Vec<_>>();

for (node, new_op) in replacements {
hugr.replace_op(node, new_op)?;
}

Ok(())
}

/// Lower operations in a circuit according to a mapping to a new HUGR.
pub fn lower_ops(
hugr: &mut impl HugrMut,
lowering: impl Fn(&OpType) -> Option<Hugr>,
) -> Result<(), Box<dyn Error>> {
let replacements = hugr
.nodes()
.filter_map(|node| {
let hugr = lowering(hugr.get_optype(node))?;
Some((node, hugr))
})
.collect::<Vec<_>>();

for (node, replacement) in replacements {
let subcirc = SiblingSubgraph::try_from_nodes([node], hugr)?;
let rw = subcirc.create_simple_replacement(hugr, replacement)?;
hugr.apply_rewrite(rw)?;
}

Ok(())
}

0 comments on commit 487cb62

Please sign in to comment.