Skip to content

Commit

Permalink
feat: guppy → pytket conversion (#407)
Browse files Browse the repository at this point in the history
Adds support for converting **flat** **pure** functions defined in guppy
into pytket circuits. See the example in `test_guppy.py`.

This PR just adds a `lower_to_pytket` pass that currently only runs the
tuple erasure from #406.
  • Loading branch information
aborgna-q authored Jun 18, 2024
1 parent 425d6b4 commit 8c5a487
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 32 deletions.
10 changes: 6 additions & 4 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use hugr::{Hugr, HugrView, Wire};
use serde::Serialize;
use tket2::circuit::CircuitHash;
use tket2::extension::REGISTRY;
use tket2::passes::pytket::lower_to_pytket;
use tket2::passes::CircuitChunks;
use tket2::serialize::TKETDecode;
use tket2::{Circuit, Tk2Op};
Expand Down Expand Up @@ -73,9 +74,8 @@ impl Tk2Circuit {

/// Convert the [`Tk2Circuit`] to a tket1 circuit.
pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
SerialCircuit::encode(&self.circ)
.convert_pyerrs()?
.to_tket1(py)
let circ = lower_to_pytket(&self.circ).convert_pyerrs()?;
SerialCircuit::encode(&circ).convert_pyerrs()?.to_tket1(py)
}

/// Apply a rewrite on the circuit.
Expand Down Expand Up @@ -109,7 +109,9 @@ impl Tk2Circuit {

/// Encode the circuit as a tket1 json string.
pub fn to_tket1_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&SerialCircuit::encode(&self.circ).convert_pyerrs()?).unwrap())
// Try to simplify tuple pack-unpack pairs, and other operations not supported by pytket.
let circ = lower_to_pytket(&self.circ).convert_pyerrs()?;
Ok(serde_json::to_string(&SerialCircuit::encode(&circ).convert_pyerrs()?).unwrap())
}

/// Decode a tket1 json string to a circuit.
Expand Down
6 changes: 6 additions & 0 deletions tket2-py/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ create_py_exception!(
"Error from a `PullForward` operation"
);

create_py_exception!(
tket2::passes::pytket::PytketLoweringError,
PyPytketLoweringError,
"Errors that can occur while removing high-level operations from HUGR intended to be encoded as a pytket circuit."
);

#[pyfunction]
fn greedy_depth_reduce<'py>(circ: &Bound<'py, PyAny>) -> PyResult<(Bound<'py, PyAny>, u32)> {
let py = circ.py();
Expand Down
Empty file added tket2-py/test/__init__.py
Empty file.
46 changes: 37 additions & 9 deletions tket2-py/test/test_guppy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import no_type_check
from tket2.circuit import Tk2Circuit
import pytket.circuit
from test.util import guppy_to_circuit

import math

Expand All @@ -9,8 +10,40 @@
from guppylang.prelude.builtins import py
from guppylang.prelude.quantum import measure, phased_x, qubit, rz, zz_max

import pytket

def test_load_compiled_module():

def test_load_pure_circuit():
module = GuppyModule("test")
module.load(quantum)

@guppy(module)
@no_type_check
def my_func(
q0: qubit,
q1: qubit,
) -> tuple[qubit, qubit]: # pragma: no cover
q0 = phased_x(q0, py(math.pi / 2), py(-math.pi / 2))
q0 = rz(q0, py(math.pi))
q1 = phased_x(q1, py(math.pi / 2), py(-math.pi / 2))
q1 = rz(q1, py(math.pi))
q0, q1 = zz_max(q0, q1)
q0 = rz(q0, py(math.pi))
q1 = rz(q1, py(math.pi))
return (q0, q1)

circ = guppy_to_circuit(my_func)
assert circ.num_operations() == 7

tk1 = circ.to_tket1()
assert tk1.n_gates == 7
assert tk1.n_qubits == 2

gates = list(tk1)
assert gates[4].op.type == pytket.circuit.OpType.ZZMax


def test_load_hybrid_circuit():
module = GuppyModule("test")
module.load(quantum)

Expand All @@ -19,7 +52,7 @@ def test_load_compiled_module():
def my_func(
q0: qubit,
q1: qubit,
) -> tuple[bool,]:
) -> tuple[bool,]: # pragma: no cover
q0 = phased_x(q0, py(math.pi / 2), py(-math.pi / 2))
q0 = rz(q0, py(math.pi))
q1 = phased_x(q1, py(math.pi / 2), py(-math.pi / 2))
Expand All @@ -28,12 +61,7 @@ def my_func(
_ = measure(q0)
return (measure(q1),)

# Compile the module, and convert it to a JSON string
hugr = module.compile()
json = hugr.to_raw().to_json()

# Load the module from the JSON string
circ = Tk2Circuit.from_guppy_json(json, "my_func")
circ = guppy_to_circuit(my_func)

# The 7 operations in the function, plus two implicit QFree
assert circ.num_operations() == 9
15 changes: 15 additions & 0 deletions tket2-py/test/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from guppylang.definition.function import RawFunctionDef

from tket2.circuit import Tk2Circuit


def guppy_to_circuit(func_def: RawFunctionDef) -> Tk2Circuit:
"""Convert a Guppy function definition to a `Tk2Circuit`."""
module = func_def.id.module
assert module is not None, "Function definition must belong to a module"

hugr = module.compile()
assert hugr is not None, "Module must be compilable"

json = hugr.to_raw().to_json()
return Tk2Circuit.from_guppy_json(json, func_def.name)
3 changes: 3 additions & 0 deletions tket2/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,8 @@ pub use commutation::{apply_greedy_commutation, PullForwardError};
pub mod chunks;
pub use chunks::CircuitChunks;

pub mod pytket;
pub use pytket::lower_to_pytket;

pub mod tuple_unpack;
pub use tuple_unpack::find_tuple_unpack_rewrites;
39 changes: 39 additions & 0 deletions tket2/src/passes/pytket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//! This module contains routines needed for normalizing a circuit
//! into a form that can be encoded as a pytket legacy circuit.
//!
//! This is a best-effort attempt, and may not always succeed.

use itertools::Itertools;

use crate::serialize::pytket::OpConvertError;
use crate::Circuit;

use super::find_tuple_unpack_rewrites;

/// Try to lower a circuit to a form that can be encoded as a pytket legacy circuit.
pub fn lower_to_pytket(circ: &Circuit) -> Result<Circuit, PytketLoweringError> {
let mut circ = circ
.extract_dfg()
.map_err(|_| PytketLoweringError::NonLocalOperations)?;

// Remove sequences of tuple pack-unpack operations,
// typically generated by guppy.
let rewrites = find_tuple_unpack_rewrites(&circ).collect_vec();
for rewrite in rewrites {
rewrite.apply(&mut circ).unwrap();
}

Ok(circ)
}

/// Errors that can occur during the lowering process.
#[derive(Clone, PartialEq, Debug, thiserror::Error)]
pub enum PytketLoweringError {
/// An error occurred during the conversion of an operation.
#[error("operation conversion error: {0}")]
OpConversionError(#[from] OpConvertError),
/// The circuit is not fully-contained in a region.
/// Function calls are not supported.
#[error("Non-local operations found. Function calls are not supported.")]
NonLocalOperations,
}
61 changes: 51 additions & 10 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use crate::circuit::Circuit;
use self::decoder::JsonDecoder;
use self::encoder::JsonEncoder;

pub use crate::passes::pytket::lower_to_pytket;

/// Prefix used for storing metadata in the hugr nodes.
pub const METADATA_PREFIX: &str = "TKET1_JSON";
/// The global phase specified as metadata.
Expand Down Expand Up @@ -92,7 +94,7 @@ impl TKETDecode for SerialCircuit {
}

/// Error type for conversion between `Op` and `OpType`.
#[derive(Debug, Error)]
#[derive(Clone, PartialEq, Debug, Error)]
pub enum OpConvertError {
/// The serialized operation is not supported.
#[error("Unsupported serialized pytket operation: {0:?}")]
Expand Down Expand Up @@ -123,20 +125,41 @@ pub fn load_tk1_json_str(json: &str) -> Result<Circuit, TK1ConvertError> {
}

/// Save a circuit to file in TK1 JSON format.
///
/// You may need to normalize the circuit using [`lower_to_pytket`] before saving.
///
/// # Errors
///
/// Returns an error if the circuit is not flat or if it contains operations not
/// supported by pytket.
pub fn save_tk1_json_file(circ: &Circuit, path: impl AsRef<Path>) -> Result<(), TK1ConvertError> {
let file = fs::File::create(path)?;
let writer = io::BufWriter::new(file);
save_tk1_json_writer(circ, writer)
}

/// Save a circuit in TK1 JSON format to a writer.
///
/// You may need to normalize the circuit using [`lower_to_pytket`] before saving.
///
/// # Errors
///
/// Returns an error if the circuit is not flat or if it contains operations not
/// supported by pytket.
pub fn save_tk1_json_writer(circ: &Circuit, w: impl io::Write) -> Result<(), TK1ConvertError> {
let serial_circ = SerialCircuit::encode(circ)?;
serde_json::to_writer(w, &serial_circ)?;
Ok(())
}

/// Save a circuit in TK1 JSON format to a String.
///
/// You may need to normalize the circuit using [`lower_to_pytket`] before saving.
///
/// # Errors
///
/// Returns an error if the circuit is not flat or if it contains operations not
/// supported by pytket.
pub fn save_tk1_json_str(circ: &Circuit) -> Result<String, TK1ConvertError> {
let mut buf = io::BufWriter::new(Vec::new());
save_tk1_json_writer(circ, &mut buf)?;
Expand Down Expand Up @@ -167,22 +190,40 @@ pub enum TK1ConvertError {
FileLoadError(#[from] io::Error),
}

#[inline]
fn parse_val(n: &str) -> Option<f64> {
n.parse::<f64>().ok()
}
/// Try to interpret a TKET1 parameter as a constant value.
///
/// Angle parameters in TKET1 are encoded as a number of half-turns,
/// whereas HUGR uses radians.
#[inline]
fn try_param_to_constant(param: &str) -> Option<Value> {
if let Some(f) = parse_val(param) {
Some(ConstF64::new(f).into())
fn parse_val(n: &str) -> Option<f64> {
n.parse::<f64>().ok()
}

let half_turns = if let Some(f) = parse_val(param) {
f
} else if param.split('/').count() == 2 {
// TODO: Use the rational types from `Hugr::extensions::rotation`
let (n, d) = param.split_once('/').unwrap();
let n = parse_val(n)?;
let d = parse_val(d)?;
Some(ConstF64::new(n / d).into())
n / d
} else {
None
}
return None;
};

let radians = half_turns * std::f64::consts::PI;
Some(ConstF64::new(radians).into())
}

/// Convert a HUGR angle constant to a TKET1 parameter.
///
/// Angle parameters in TKET1 are encoded as a number of half-turns,
/// whereas HUGR uses radians.
#[inline]
fn try_constant_to_param(val: &Value) -> Option<String> {
let const_float = val.get_custom_value::<ConstF64>()?;
let radians: f64 = **const_float;
let half_turns = radians / std::f64::consts::PI;
Some(half_turns.to_string())
}
13 changes: 6 additions & 7 deletions tket2/src/serialize/pytket/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::collections::HashMap;

use hugr::extension::prelude::QB_T;
use hugr::ops::{NamedOp, OpType};
use hugr::std_extensions::arithmetic::float_types::ConstF64;
use hugr::{HugrView, Wire};
use itertools::{Either, Itertools};
use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit};
Expand All @@ -18,8 +17,8 @@ use crate::Tk2Op;

use super::op::JsonOp;
use super::{
OpConvertError, METADATA_B_REGISTERS, METADATA_IMPLICIT_PERM, METADATA_PHASE,
METADATA_Q_REGISTERS,
try_constant_to_param, OpConvertError, METADATA_B_REGISTERS, METADATA_IMPLICIT_PERM,
METADATA_PHASE, METADATA_Q_REGISTERS,
};

/// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`].
Expand Down Expand Up @@ -198,10 +197,10 @@ impl JsonEncoder {
let param = match optype {
OpType::Const(const_op) => {
// New constant, register it if it can be interpreted as a parameter.
let Some(const_float) = const_op.value().get_custom_value::<ConstF64>() else {
return false;
};
const_float.to_string()
match try_constant_to_param(const_op.value()) {
Some(param) => param,
None => return false,
}
}
OpType::LoadConstant(_op_type) => {
// Re-use the parameter from the input.
Expand Down
4 changes: 2 additions & 2 deletions tket2/src/serialize/pytket/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ fn circ_add_angles_constants() -> Circuit {

let qb = h.input_wires().next().unwrap();

let point2 = h.add_load_value(ConstF64::new(0.2));
let point3 = h.add_load_value(ConstF64::new(0.3));
let point2 = h.add_load_value(ConstF64::new(0.2 * std::f64::consts::PI));
let point3 = h.add_load_value(ConstF64::new(0.3 * std::f64::consts::PI));
let point5 = h
.add_dataflow_op(Tk2Op::AngleAdd, [point2, point3])
.unwrap()
Expand Down

0 comments on commit 8c5a487

Please sign in to comment.