Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Track circuit extensions and read/write packages #680

Merged
merged 10 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ license = "Apache-2.0"
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(ci_run)'] }
missing_docs = "warn"

[patch.crates-io]
hugr-core = { git = "https://github.com/CQCL/hugr.git", branch = "ab/package-from-hugr" }
hugr = { git = "https://github.com/CQCL/hugr.git", branch = "ab/package-from-hugr" }
hugr-cli = { git = "https://github.com/CQCL/hugr.git", branch = "ab/package-from-hugr" }

[workspace.dependencies]

# Make sure to run `just recompile-eccs` if the hugr serialisation format changes.
Expand Down
2 changes: 1 addition & 1 deletion tket2-py/examples/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def guppy_to_circuit(func_def: RawFunctionDef) -> Tk2Circuit:
pkg = module.compile()

json = pkg.to_json()
circ = Tk2Circuit.from_guppy_json(json, func_def.name)
circ = Tk2Circuit.from_package_json(json, func_def.name)

return lower_to_pytket(circ)
59 changes: 41 additions & 18 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Rust-backed representation of circuits

use std::borrow::{Borrow, Cow};
use std::fmt::Display;
use std::mem;

use hugr::builder::{CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr};
Expand Down Expand Up @@ -91,32 +92,54 @@ impl Tk2Circuit {
//
// TODO: Bind a messagepack encoder/decoder too.
pub fn to_hugr_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(self.circ.hugr()).unwrap())
fn err(e: impl Display) -> PyErr {
PyErr::new::<PyAttributeError, _>(format!("Could not encode circuit: {e}"))
};
let mut buf = Vec::new();
self.circ.to_hugr_writer(&mut buf).map_err(err)?;
let res = std::str::from_utf8(&buf).map_err(err)?;
Ok(res.to_string())
}

/// Decode a HUGR json string to a circuit.
/// Encode the circuit as a Hugr Package json string.
//
// TODO: Bind a messagepack encoder/decoder too.
pub fn to_package_json(&self) -> PyResult<String> {
fn err(e: impl Display) -> PyErr {
PyErr::new::<PyAttributeError, _>(format!("Could not encode circuit: {e}"))
};
let mut buf = Vec::new();
self.circ.to_package_writer(&mut buf).map_err(err)?;
let res = std::str::from_utf8(&buf).map_err(err)?;
Ok(res.to_string())
}

/// Decode a HUGR json to a circuit.
#[staticmethod]
pub fn from_hugr_json(json: &str) -> PyResult<Self> {
let mut pkg: Package = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
let mut reg = REGISTRY.clone();
pkg.update_validate(&mut reg).map_err(|e| {
PyErr::new::<PyAttributeError, _>(format!("Invalid encoded circuit: {e}"))
})?;
let Ok(hugr) = pkg.modules.into_iter().exactly_one() else {
return Err(PyValueError::new_err(
"Invalid HUGR json: Package must contain exactly one hugr.",
));
fn err(e: impl Display) -> PyErr {
PyErr::new::<PyAttributeError, _>(format!("Could not read hugr: {e}"))
};
Ok(Tk2Circuit { circ: hugr.into() })
let circ = Circuit::load_hugr_reader(json.as_bytes()).map_err(err)?;
Ok(Tk2Circuit { circ })
}

/// Load a function from a compiled guppy module, encoded as a json string.
/// Decode a HUGR Package json to a circuit.
///
/// Traverses the package's modules in order until it finds one containing a
/// function named `function_name`, and loads it as a circuit.
///
/// If the json is a hugr json, it will be decoded as a `main` function in an empty module.
///
/// When `function_name` is not given, it defaults to `main`.
#[staticmethod]
pub fn from_guppy_json(json: &str, function: &str) -> PyResult<Self> {
let circ = tket2::serialize::load_guppy_json_str(json, function).map_err(|e| {
PyErr::new::<PyAttributeError, _>(format!("Invalid encoded circuit: {e}"))
})?;
#[pyo3(signature = (json, function_name = None))]
pub fn from_package_json(json: &str, function_name: Option<String>) -> PyResult<Self> {
fn err(e: impl Display) -> PyErr {
PyErr::new::<PyAttributeError, _>(format!("Could not read package: {e}"))
};
let name = function_name.unwrap_or_else(|| "main".to_string());
let circ = Circuit::load_function_reader(json.as_bytes(), &name).map_err(err)?;
Ok(Tk2Circuit { circ })
}

Expand Down
25 changes: 19 additions & 6 deletions tket2-py/tket2/_tket2/circuit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,31 @@ class Tk2Circuit:
"""The output node of the circuit."""

def to_hugr_json(self) -> str:
"""Encode the circuit as a HUGR json string."""
"""Encode the circuit as a HUGR json."""

def to_package_json(self) -> str:
"""Encode the circuit as a HUGR Package json."""

@staticmethod
def from_hugr_json(json: str) -> Tk2Circuit:
"""Decode a HUGR json string to a Tk2Circuit."""

def to_tket1_json(self) -> str:
"""Encode the circuit as a pytket json string."""

@staticmethod
def from_guppy_json(json: str, function: str) -> Tk2Circuit:
"""Load a function from a compiled guppy module, encoded as a json string."""
def from_package_json(json: str, function_name: str | None = None) -> Tk2Circuit:
"""Decode a HUGR Package json to a circuit.

Traverses the package's modules in order until it finds one containing a
function named `function_name`, and loads it as a circuit.

If the json is a hugr json, it will be decoded as a `main` function in an empty module.

When `function_name` is not given, it defaults to `main`.
"""

def to_tket1_json(
self,
) -> str:
"""Encode the circuit as a pytket json string."""

@staticmethod
def from_tket1_json(json: str) -> Tk2Circuit:
Expand Down
42 changes: 34 additions & 8 deletions tket2-py/tket2/circuit/build.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations
from typing import Iterable

from hugr import tys, ops
from hugr import Hugr, tys, ops
from hugr.package import Package
from hugr.ext import Extension
from hugr.ops import ComWire, Command
from hugr.std.float import FLOAT_T
from hugr.build.function import Module
from hugr.build.tracked_dfg import TrackedDfg
from tket2.circuit import Tk2Circuit

Expand All @@ -20,17 +21,33 @@ class CircBuild(TrackedDfg):
def with_nqb(cls, n_qb: int) -> CircBuild:
return cls(*[tys.Qubit] * n_qb, track_inputs=True)

def finish_hugr(self) -> Hugr:
"""Finish building the package by setting all the qubits as the output
and wrap it in a hugr package with the required extensions.
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved

Returns:
The finished Hugr.
"""
return self.hugr

def finish_package(
self, other_extensions: Iterable[Extension] | None = None
self,
*,
other_extensions: Iterable[Extension] | None = None,
function_name="main",
) -> Package:
"""Finish building the package by setting all the qubits as the output
and wrap it in a hugr package with the required extensions.

Args:
other_extensions: Other extensions to include in the package.
function_name: The name of the function containing the circuit in
the package's module. Defaults to "main".
Returns:
The finished package.
"""
# TODO: Replace with `finish_hugr` once extensions are included in the hugr itself.
# See https://github.com/CQCL/hugr/pull/1621
import tket2.extensions as ext

extensions = [
Expand All @@ -42,13 +59,26 @@ def finish_package(
*(other_extensions or []),
]

return Package(modules=[self.hugr], extensions=extensions)
# Convert the DFG into a Function definition
dfg_op = self.hugr[self.hugr.root].op
assert type(dfg_op) is ops.DFG, "CircBuild must have a Dfg root"
self.hugr[self.hugr.root].op = ops.FuncDefn(
function_name, inputs=dfg_op.inputs, _outputs=dfg_op.outputs
)
Comment on lines +64 to +66
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hugr rewriting is so much easier in python 🤔


# Insert it into a module, as required by the package.
module = Module()
module.hugr.insert_hugr(self.hugr)

return Package(modules=[module.hugr], extensions=extensions)

def finish(self, other_extensions: list[Extension] | None = None) -> Tk2Circuit:
"""Finish building the circuit by setting all the qubits as the output
and validate."""

return load_hugr_pkg(self.finish_package(other_extensions))
return Tk2Circuit.from_package_json(
self.finish_package(other_extensions=other_extensions).to_json()
)


def from_coms(*args: Command) -> Tk2Circuit:
Expand All @@ -68,10 +98,6 @@ def from_coms(*args: Command) -> Tk2Circuit:
return build.finish()


def load_hugr_pkg(package: Package) -> Tk2Circuit:
return Tk2Circuit.from_hugr_json(package.to_json())


def load_custom(serialized: bytes) -> ops.Custom:
import hugr._serialization.ops as sops
import json
Expand Down
Loading
Loading