Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Utilities for loading compiled guppy circuits #393

Merged
merged 9 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ serde = "1.0"
serde_json = "1.0"
serde_yaml = "0.9.22"
smol_str = "0.2.0"
stringreader = "0.1.1"
strum = "0.26.1"
strum_macros = "0.26.4"
thiserror = "1.0.28"
Expand Down
298 changes: 224 additions & 74 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ readme = "README.md"
packages = [{ include = "tket2-py" }]

[tool.poetry.dependencies]
python = ">=3.10"
python = "^3.10"
pytket = "1.28.0"

[tool.poetry.group.dev.dependencies]
Expand All @@ -36,6 +36,7 @@ mypy = "^1.9.0"
hypothesis = "^6.103.1"
graphviz = "^0.20"
pre-commit = "^3.7.1"
guppylang = "^0.5.0"

[build-system]
requires = ["maturin~=1.5.1"]
Expand Down
27 changes: 22 additions & 5 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ impl Tk2Circuit {
Ok(Tk2Circuit { circ: hugr.into() })
}

/// Load a function from a compiled guppy module, encoded as a json string.
#[staticmethod]
pub fn from_guppy_json(json: &str, function: &str) -> PyResult<Self> {
let circ = tket2::serialize::load_guppy_json_str(json, function).map_err(|e| {
PyErr::new::<PyAttributeError, _>(format!("Invalid encoded circuit: {e}"))
})?;
Ok(Tk2Circuit { circ })
}

/// Encode the circuit as a tket1 json string.
pub fn to_tket1_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&SerialCircuit::encode(&self.circ).convert_pyerrs()?).unwrap())
Expand All @@ -106,11 +115,10 @@ impl Tk2Circuit {
/// Decode a tket1 json string to a circuit.
#[staticmethod]
pub fn from_tket1_json(json: &str) -> PyResult<Self> {
let tk1: SerialCircuit = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit {
circ: tk1.decode().convert_pyerrs()?,
})
let circ = tket2::serialize::load_tk1_json_str(json).map_err(|e| {
PyErr::new::<PyAttributeError, _>(format!("Could not load pytket circuit: {e}"))
})?;
Ok(Tk2Circuit { circ })
}

/// Compute the cost of the circuit based on a per-operation cost function.
Expand Down Expand Up @@ -138,6 +146,15 @@ impl Tk2Circuit {
Ok(circ_cost.cost.into_bound(py))
}

/// Returns the number of operations in the circuit.
///
/// This includes [`Tk2Op`]s, pytket ops, and any other custom operations.
///
/// Nested circuits are traversed to count their operations.
pub fn num_operations(&self) -> usize {
self.circ.num_operations()
}

/// Returns a hash of the circuit.
pub fn hash(&self) -> u64 {
self.circ.circuit_hash().unwrap()
Expand Down
39 changes: 39 additions & 0 deletions tket2-py/test/test_guppy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import no_type_check
from tket2.circuit import Tk2Circuit

import math

from guppylang import guppy
from guppylang.module import GuppyModule
from guppylang.prelude import quantum
from guppylang.prelude.builtins import py
from guppylang.prelude.quantum import measure, phased_x, qubit, rz, zz_max


def test_load_compiled_module():
module = GuppyModule("test")
module.load(quantum)

@guppy(module)
@no_type_check
def my_func(
q0: qubit,
q1: qubit,
) -> tuple[bool,]:
q0 = phased_x(q0, py(math.pi / 2), py(-math.pi / 2))
q0 = rz(q0, py(math.pi))
q1 = phased_x(q1, py(math.pi / 2), py(-math.pi / 2))
q1 = rz(q1, py(math.pi))
q0, q1 = zz_max(q0, q1)
_ = measure(q0)
return (measure(q1),)

Check warning on line 29 in tket2-py/test/test_guppy.py

View check run for this annotation

Codecov / codecov/patch

tket2-py/test/test_guppy.py#L23-L29

Added lines #L23 - L29 were not covered by tests

# Compile the module, and convert it to a JSON string
hugr = module.compile()
json = hugr.to_raw().to_json()

# Load the module from the JSON string
circ = Tk2Circuit.from_guppy_json(json, "my_func")

# The 7 operations in the function, plus two implicit QFree
assert circ.num_operations() == 9
12 changes: 12 additions & 0 deletions tket2-py/tket2/_tket2/circuit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ class Tk2Circuit:
def circuit_cost(self, cost_fn: Callable[[Tk2Op], Any]) -> int:
"""Compute the cost of the circuit. Return value must implement __add__."""

def num_operations(self) -> int:
"""The number of operations in the circuit.

This includes [`Tk2Op`]s, pytket ops, and any other custom operations.

Nested circuits are traversed to count their operations.
"""

def node_op(self, node: Node) -> CustomOp:
"""If the node corresponds to a custom op, return it. Otherwise, raise an error."""

Expand Down Expand Up @@ -55,6 +63,10 @@ class Tk2Circuit:
def to_tket1_json(self) -> str:
"""Encode the circuit as a pytket json string."""

@staticmethod
def from_guppy_json(json: str, function: str) -> Tk2Circuit:
"""Load a function from a compiled guppy module, encoded as a json string."""

@staticmethod
def from_tket1_json(json: str) -> Tk2Circuit:
"""Decode a pytket json string to a Tk2Circuit."""
Expand Down
1 change: 0 additions & 1 deletion tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ delegate = { workspace = true }
csv = { workspace = true }
chrono = { workspace = true }
bytemuck = { workspace = true }
stringreader = { workspace = true }
crossbeam-channel = { workspace = true }
tracing = { workspace = true }

Expand Down
2 changes: 1 addition & 1 deletion tket2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ pub mod portmatching;

mod utils;

pub use circuit::Circuit;
pub use circuit::{Circuit, CircuitError, CircuitMutError};
pub use hugr::Hugr;
pub use ops::{op_matches, symbolic_constant_op, Pauli, Tk2Op};
4 changes: 4 additions & 0 deletions tket2/src/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
//! Utilities for serializing circuits.
//!
//! See [`crate::serialize::pytket`] for serialization to and from the legacy pytket format.
pub mod guppy;
pub mod pytket;

pub use guppy::{
load_guppy_json_file, load_guppy_json_reader, load_guppy_json_str, CircuitLoadError,
};
pub use pytket::{
load_tk1_json_file, load_tk1_json_reader, load_tk1_json_str, save_tk1_json_file,
save_tk1_json_str, save_tk1_json_writer, TKETDecode,
Expand Down
142 changes: 142 additions & 0 deletions tket2/src/serialize/guppy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
//! Load pre-compiled guppy functions.

use std::path::Path;
use std::{fs, io};

use hugr::ops::{NamedOp, OpTag, OpTrait, OpType};
use hugr::{Hugr, HugrView};
use itertools::Itertools;
use thiserror::Error;

use crate::{Circuit, CircuitError};

/// Loads a pre-compiled guppy file.
pub fn load_guppy_json_file(
path: impl AsRef<Path>,
function: &str,
) -> Result<Circuit, CircuitLoadError> {
let file = fs::File::open(path)?;
let reader = io::BufReader::new(file);
load_guppy_json_reader(reader, function)
}

/// Loads a pre-compiled guppy file from a json string.
pub fn load_guppy_json_str(json: &str, function: &str) -> Result<Circuit, CircuitLoadError> {
let reader = json.as_bytes();
load_guppy_json_reader(reader, function)
}

/// Loads a pre-compiled guppy file from a reader.
pub fn load_guppy_json_reader(
reader: impl io::Read,
function: &str,
) -> Result<Circuit, CircuitLoadError> {
let hugr: Hugr = serde_json::from_reader(reader)?;
find_function(hugr, function)
}

/// Looks for the required function in a HUGR compiled from a guppy module.
///
/// Guppy functions are compiled into a root module, with each function as a `FuncDecl` child.
/// Each `FuncDecl` contains a `CFG` operation that defines the function.
///
/// Currently we only support functions where the CFG operation has a single `DataflowBlock` child,
/// which we use as the root of the circuit. We (currently) do not support control flow primitives.
///
/// # Errors
///
/// - If the root of the HUGR is not a module operation.
/// - If the function is not found in the module.
/// - If the function has control flow primitives.
fn find_function(hugr: Hugr, function_name: &str) -> Result<Circuit, CircuitLoadError> {
// Find the root module.
let module = hugr.root();
if !OpTag::ModuleRoot.is_superset(hugr.get_optype(module).tag()) {
return Err(CircuitLoadError::NonModuleRoot {
root_op: hugr.get_optype(module).clone(),
});
}

// Find the function declaration.
fn func_name(op: &OpType) -> &str {
match op {
OpType::FuncDefn(decl) => &decl.name,
_ => "",
}
}

let Some(function) = hugr
.children(module)
.find(|&n| func_name(hugr.get_optype(n)) == function_name)
else {
let available_functions = hugr
.children(module)
.map(|n| func_name(hugr.get_optype(n)).to_string())
.collect();
return Err(CircuitLoadError::FunctionNotFound {
function: function_name.to_string(),
available_functions,
});
};

// Find the CFG operation.
let invalid_cfg = CircuitLoadError::InvalidControlFlow {
function: function_name.to_string(),
};
let Ok(cfg) = hugr.children(function).skip(2).exactly_one() else {
return Err(invalid_cfg);
};

// Find the single dataflow block to use as the root of the circuit.
// The cfg node should only have the dataflow block and an exit node as children.
let mut cfg_children = hugr.children(cfg);
let Some(dataflow) = cfg_children.next() else {
return Err(invalid_cfg);
};
if cfg_children.nth(1).is_some() {
return Err(invalid_cfg);
}

let circ = Circuit::try_new(hugr, dataflow)?;
Ok(circ)
}

/// Error type for conversion between `Op` and `OpType`.
#[derive(Debug, Error)]
pub enum CircuitLoadError {
/// Cannot load the circuit file.
#[error("Cannot load the circuit file: {0}")]
InvalidFile(#[from] io::Error),
/// Invalid JSON
#[error("Invalid JSON. {0}")]
InvalidJson(#[from] serde_json::Error),
/// The root node is not a module operation.
#[error(
"Expected a HUGR with a module at the root, but found a {} instead.",
root_op.name()
)]
NonModuleRoot {
/// The root operation.
root_op: OpType,
},
/// The function is not found in the module.
#[error(
"Function '{function}' not found in the loaded module. Available functions: [{}]",
available_functions.join(", ")
)]
FunctionNotFound {
/// The function name.
function: String,
/// The available functions.
available_functions: Vec<String>,
},
/// The function has an invalid control flow structure.
#[error("Function '{function}' has an invalid control flow structure. Currently only flat functions with no control flow primitives are supported.")]
InvalidControlFlow {
/// The function name.
function: String,
},
/// Error loading the circuit.
#[error("Error loading the circuit: {0}")]
CircuitLoadError(#[from] CircuitError),
}
3 changes: 1 addition & 2 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use std::{fs, io};
use hugr::ops::{OpType, Value};
use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};

use stringreader::StringReader;
use thiserror::Error;
use tket_json_rs::circuit_json::SerialCircuit;
use tket_json_rs::optype::OpType as JsonOpType;
Expand Down Expand Up @@ -120,7 +119,7 @@ pub fn load_tk1_json_reader(json: impl io::Read) -> Result<Circuit, TK1ConvertEr

/// Load a TKET1 circuit from a JSON string.
pub fn load_tk1_json_str(json: &str) -> Result<Circuit, TK1ConvertError> {
let reader = StringReader::new(json);
let reader = json.as_bytes();
load_tk1_json_reader(reader)
}

Expand Down