diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 65ccc40..c62e3c2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,4 +43,5 @@ repos: pytket-quantinuum, pytket, types-setuptools, + wasmtime, ] diff --git a/Makefile b/Makefile index 19129b6..38a8c7f 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: install dev tests lint docs clean build +.PHONY: install dev dev-all tests lint docs clean build install: pip install . @@ -6,6 +6,9 @@ install: dev: pip install -e . +dev-all: + pip install -e .[phirc] + tests: pytest -s -x -vv tests/test*.py diff --git a/README.md b/README.md index efeb91d..2e158ef 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Install additional dependencies needed for the CLI using `pip install pytket-phi ```sh ❯ phirc -h -usage: phirc [-h] [-m {H1-1,H1-2}] [-v] qasm_files [qasm_files ...] +usage: phirc [-h] [-w WASM_FILE] [-m {H1-1,H1-2}] [-o {0,1,2}] [-v] [--version] qasm_files [qasm_files ...] Emulates QASM program execution via PECOS @@ -35,9 +35,14 @@ positional arguments: options: -h, --help show this help message and exit + -w WASM_FILE, --wasm-file WASM_FILE + Optional WASM file for use by the QASM programs -m {H1-1,H1-2}, --machine {H1-1,H1-2} - machine name, H1-1 by default - -v, --version show program's version number and exit + Machine name, H1-1 by default + -o {0,1,2}, --tket-opt-level {0,1,2} + TKET optimization level, 0 by default + -v, --verbose + --version show program's version number and exit ``` ## Development diff --git a/pyproject.toml b/pyproject.toml index d0c5c7d..4825e13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "phir>=0.2.1", "pytket-quantinuum>=0.25.0", "pytket>=1.21.0", + "wasmtime>=15.0.0", ] [project.optional-dependencies] @@ -55,8 +56,9 @@ pythonpath = [ ] log_cli = true log_cli_level = "INFO" +log_level = "DEBUG" filterwarnings = ["ignore:::lark.s*"] -log_format = "%(asctime)s.%(msecs)03d %(levelname)s %(message)s" +log_format = "%(asctime)s.%(msecs)03d %(levelname)s %(name)s:%(lineno)s %(message)s" log_date_format = "%Y-%m-%d %H:%M:%S" [tool.setuptools_scm] diff --git a/pytket/phir/api.py b/pytket/phir/api.py index 54cc329..958f732 100644 --- a/pytket/phir/api.py +++ b/pytket/phir/api.py @@ -6,13 +6,17 @@ # ############################################################################## +# mypy: disable-error-code="misc" + import logging +from pathlib import Path +from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING from rich import print from phir.model import PHIRModel -from pytket.qasm.qasm import circuit_from_qasm_str +from pytket.qasm.qasm import circuit_from_qasm_str, circuit_from_qasm_wasm from .phirgen import genphir from .phirgen_parallel import genphir_parallel @@ -74,7 +78,7 @@ def pytket_to_phir( else: phir_json = genphir(placed, machine_ops=bool(machine)) if logger.getEffectiveLevel() <= logging.INFO: - print(PHIRModel.model_validate_json(phir_json)) # type: ignore[misc] + print(PHIRModel.model_validate_json(phir_json)) return phir_json @@ -82,6 +86,7 @@ def qasm_to_phir( qasm: str, qtm_machine: QtmMachine | None = None, tket_optimization_level: int = DEFAULT_TKET_OPT_LEVEL, + wasm_bytes: bytes | None = None, ) -> str: """Converts a QASM circuit string into its PHIR representation. @@ -91,6 +96,24 @@ def qasm_to_phir( :param circuit: Circuit object to be converted :param qtm_machine: (Optional) Quantinuum machine architecture to rebase against :param tket_optimization_level: (Default=0) TKET circuit optimization level + :param wasm_bytes (Optional) WASM as bytes to include as part of circuit """ - circuit = circuit_from_qasm_str(qasm) + circuit: Circuit + if wasm_bytes: + try: + qasm_file = NamedTemporaryFile(suffix=".qasm", delete=False) + wasm_file = NamedTemporaryFile(suffix=".wasm", delete=False) + qasm_file.write(qasm.encode()) + qasm_file.flush() + qasm_file.close() + wasm_file.write(wasm_bytes) + wasm_file.flush() + wasm_file.close() + + circuit = circuit_from_qasm_wasm(qasm_file.name, wasm_file.name) + finally: + Path.unlink(Path(qasm_file.name)) + Path.unlink(Path(wasm_file.name)) + else: + circuit = circuit_from_qasm_str(qasm) return pytket_to_phir(circuit, qtm_machine, tket_optimization_level) diff --git a/pytket/phir/cli.py b/pytket/phir/cli.py index 1ef0a04..f501bdd 100644 --- a/pytket/phir/cli.py +++ b/pytket/phir/cli.py @@ -7,22 +7,21 @@ ############################################################################## # mypy: disable-error-code="misc" +# ruff: noqa: T201 from argparse import ArgumentParser from importlib.metadata import version from pecos.engines.hybrid_engine import HybridEngine # type:ignore [import-not-found] +from pecos.foreign_objects.wasmtime import WasmtimeObj # type:ignore [import-not-found] -from phir.model import PHIRModel from pytket.qasm.qasm import ( circuit_from_qasm, - circuit_from_qasm_str, - circuit_to_qasm_str, + circuit_from_qasm_wasm, ) from .api import pytket_to_phir from .qtm_machine import QtmMachine -from .rebasing.rebaser import rebase_to_qtm_machine def main() -> None: @@ -34,12 +33,18 @@ def main() -> None: parser.add_argument( "qasm_files", nargs="+", default=None, help="One or more QASM files to emulate" ) + parser.add_argument( + "-w", + "--wasm-file", + default=None, + help="Optional WASM file for use by the QASM programs", + ) parser.add_argument( "-m", "--machine", choices=["H1-1", "H1-2"], default="H1-1", - help="machine name, H1-1 by default", + help="Machine name, H1-1 by default", ) parser.add_argument( "-o", @@ -48,8 +53,8 @@ def main() -> None: default="0", help="TKET optimization level, 0 by default", ) + parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument( - "-v", "--version", action="version", version=f'{version("pytket-phir")}', @@ -57,19 +62,31 @@ def main() -> None: args = parser.parse_args() for file in args.qasm_files: - print(f"Processing {file}") # noqa: T201 - c = circuit_from_qasm(file) - tket_opt_level = int(args.tket_opt_level) - rc = rebase_to_qtm_machine(c, args.machine, tket_opt_level) - qasm = circuit_to_qasm_str(rc, header="hqslib1") - circ = circuit_from_qasm_str(qasm) + print(f"Processing {file}") + circuit = None + if args.wasm_file: + print(f"Including WASM from file {args.wasm_file}") + circuit = circuit_from_qasm_wasm(file, args.wasm_file) + wasm_pecos_obj = WasmtimeObj(args.wasm_file) + else: + circuit = circuit_from_qasm(file) match args.machine: case "H1-1": machine = QtmMachine.H1_1 case "H1-2": machine = QtmMachine.H1_2 - phir = pytket_to_phir(circ, machine) - PHIRModel.model_validate_json(phir) - HybridEngine(qsim="state-vector").run(program=phir, shots=10) + phir = pytket_to_phir(circuit, machine, int(args.tket_opt_level)) + if args.verbose: + print("\nPHIR to be simulated:") + print(phir) + + print("\nPECOS results:") + print( + HybridEngine(qsim="state-vector").run( + program=phir, + shots=10, + foreign_object=wasm_pecos_obj if args.wasm_file else None, + ) + ) diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index eb4a364..0dab676 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -281,6 +281,9 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: [arg_to_bit(cmd.args[i]) for i in range(len(cmd.args) // 2)], ) + case tk.WASMOp(): + return create_wasm_op(cmd, op) + case _: # TODO(kartik): NYI # https://github.com/CQCL/pytket-phir/issues/25 @@ -296,12 +299,63 @@ def append_cmd(cmd: tk.Command, ops: list[JsonDict]) -> None: cmd: pytket command obtained from pytket-phir ops: the list of ops to append to """ - ops.append({"//": str(cmd)}) + ops.append({"//": make_comment_text(cmd, cmd.op)}) op: JsonDict | None = convert_subcmd(cmd.op, cmd) if op: ops.append(op) +def create_wasm_op(cmd: tk.Command, wasm_op: tk.WASMOp) -> JsonDict: + """Creates a PHIR operation for a WASM command.""" + args, returns = extract_wasm_args_and_returns(cmd, wasm_op) + op = { + "cop": "ffcall", + "function": wasm_op.func_name, + "args": args, + "metadata": { + "ff_object": f"WASM module uid: {wasm_op.wasm_uid}", + }, + } + if cmd.bits: + op["returns"] = returns + + return op + + +def extract_wasm_args_and_returns( + command: tk.Command, op: tk.WASMOp +) -> tuple[list[str], list[str]]: + """Extract the wasm args and return values as whole register names.""" + # This slice removes the extra `_w` cregs (wires) that are not part of the + # circuit, and the output args which are appended after the input args + slice_index = op.num_w + sum(op.output_widths) + only_args = command.args[:-slice_index] + return ( + dedupe_bits_to_registers(only_args), + dedupe_bits_to_registers(command.bits), + ) + + +def dedupe_bits_to_registers(bits: "Sequence[UnitID]") -> list[str]: + """Dedupes a list of bits to their registers, keeping order intact.""" + return list(dict.fromkeys([bit.reg_name for bit in bits])) + + +def make_comment_text(command: tk.Command, op: tk.Op) -> str: + """Converts a command + op to the PHIR comment spec.""" + match op: + case tk.Conditional(): + conditional_text = str(command) + cleaned = conditional_text[: conditional_text.find("THEN") + 4] + return f"{cleaned} {make_comment_text(command, op.op)}" + + case tk.WASMOp(): + args, returns = extract_wasm_args_and_returns(command, op) + return f"WASM function={op.func_name} args={args} returns={returns}" + case _: + return str(command) + + def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | int]]: """Format the qvar and cvar define PHIR elements.""" # TODO(kartik): this may not always be accurate @@ -334,6 +388,7 @@ def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | in "size": dim, } for cvar, dim in cvar_dim.items() + if cvar != "_w" ] return decls diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 2a0fbde..ffa7be9 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -14,7 +14,7 @@ from .shard import Shard -NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM] +NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox] SHARD_TRIGGER_OP_TYPES = [ OpType.Measure, @@ -25,6 +25,7 @@ OpType.RangePredicate, OpType.ExplicitPredicate, OpType.CopyBits, + OpType.WASM, ] logger = logging.getLogger(__name__) @@ -98,7 +99,6 @@ def _process_command(self, command: Command) -> None: return if self.should_op_create_shard(command.op): - logger.debug("Building shard for command: %s", command) self._build_shard(command) else: self._add_pending_sub_command(command) @@ -112,6 +112,7 @@ def _build_shard(self, command: Command) -> None: Args: command: tket command (operation, bits, etc) """ + logger.debug("Building shard for command: %s", command) # Rollup any sub commands (SQ gates) that interact with the same qubits sub_commands: dict[UnitID, list[Command]] = {} for key in ( @@ -123,6 +124,7 @@ def _build_shard(self, command: Command) -> None: for sub_command_list in sub_commands.values(): all_commands.extend(sub_command_list) + logger.debug("All shard commands: %s", all_commands) qubits_used = set(command.qubits) bits_written = set(command.bits) bits_read: set[Bit] = set() @@ -185,7 +187,9 @@ def _resolve_shard_dependencies( for bit_read in bits_read: if bit_read in self._bit_written_by: - logger.debug("...adding shard dep %s -> RAW") + logger.debug( + "...adding shard dep %s -> RAW", self._bit_written_by[bit_read] + ) depends_upon.add(self._bit_written_by[bit_read]) for bit_written in bits_written: @@ -220,6 +224,7 @@ def _mark_dependencies( self._bit_written_by[bit] = shard.ID for bit in shard.bits_read: self._bit_read_by[bit] = shard.ID + logger.debug("... dependencies marked") def _cleanup_remaining_commands(self) -> None: """Cleans up any remaining subcommands. @@ -228,7 +233,11 @@ def _cleanup_remaining_commands(self) -> None: to roll up lingering subcommands. """ remaining_qubits = [k for k, v in self._pending_commands.items() if v] + logger.debug( + "Cleaning up remaining subcommands for qubits %s", remaining_qubits + ) for qubit in remaining_qubits: + logger.debug("Adding barrier for subcommands for qubit %s", qubit) self._circuit.add_barrier([qubit]) # Easiest way to get to a command, since there's no constructor. Could # create an entire orphan circuit with the matching qubits and the barrier @@ -249,7 +258,7 @@ def _add_pending_sub_command(self, command: Command) -> None: if qubit_key not in self._pending_commands: self._pending_commands[qubit_key] = [] self._pending_commands[qubit_key].append(command) - logger.debug("Adding pending command %s", command) + logger.debug("Added pending sub-command %s", command) @staticmethod def should_op_create_shard(op: Op) -> bool: diff --git a/requirements.txt b/requirements.txt index 16c92e3..005e465 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ pytket==1.24.0 ruff==0.1.14 setuptools_scm==8.0.4 sphinx==7.2.6 +wasmtime==15.0.0 wheel==0.42.0 diff --git a/tests/data/wasm/add.wat b/tests/data/wasm/add.wat new file mode 100644 index 0000000..d3d1e02 --- /dev/null +++ b/tests/data/wasm/add.wat @@ -0,0 +1,17 @@ +(module + (type (;0;) (func)) + (type (;1;) (func (param i32 i32) (result i32))) + (func $init (type 0)) + (func $add (type 1) (param i32 i32) (result i32) + local.get 1 + local.get 0 + i32.add) + (memory (;0;) 16) + (global $__stack_pointer (mut i32) (i32.const 1048576)) + (global (;1;) i32 (i32.const 1048576)) + (global (;2;) i32 (i32.const 1048576)) + (export "memory" (memory 0)) + (export "init" (func $init)) + (export "add" (func $add)) + (export "__data_end" (global 1)) + (export "__heap_base" (global 2))) diff --git a/tests/data/wasm/testfile.wat b/tests/data/wasm/testfile.wat new file mode 100644 index 0000000..d762875 --- /dev/null +++ b/tests/data/wasm/testfile.wat @@ -0,0 +1,38 @@ +(module + (type $t0 (func)) + (type $t1 (func (param i32) (result i32))) + (type $t2 (func (param i32 i32) (result i32))) + (type $t3 (func (param i64) (result i64))) + (type $t4 (func (param i32))) + (type $t5 (func (result i32))) + (func $init (export "init") (type $t0)) + (func $add_one (export "add_one") (type $t1) (param $p0 i32) (result i32) + (i32.add + (local.get $p0) + (i32.const 1))) + (func $multi (export "multi") (type $t2) (param $p0 i32) (param $p1 i32) (result i32) + (i32.mul + (local.get $p1) + (local.get $p0))) + (func $add_two (export "add_two") (type $t1) (param $p0 i32) (result i32) + (i32.add + (local.get $p0) + (i32.const 2))) + (func $add_something (export "add_something") (type $t3) (param $p0 i64) (result i64) + (i64.add + (local.get $p0) + (i64.const 11))) + (func $add_eleven (export "add_eleven") (type $t1) (param $p0 i32) (result i32) + (i32.add + (local.get $p0) + (i32.const 11))) + (func $no_return (export "no_return") (type $t4) (param $p0 i32)) + (func $no_parameters (export "no_parameters") (type $t5) (result i32) + (i32.const 11)) + (func $new_function (export "new_function") (type $t5) (result i32) + (i32.const 13)) + (table $T0 1 1 funcref) + (memory $memory (export "memory") 16) + (global $__stack_pointer (mut i32) (i32.const 1048576)) + (global $__data_end (export "__data_end") i32 (i32.const 1048576)) + (global $__heap_base (export "__heap_base") i32 (i32.const 1048576))) diff --git a/tests/test_api.py b/tests/test_api.py index bf2dab2..766c8c9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,16 +6,23 @@ # ############################################################################## +# mypy: disable-error-code="misc" + +import base64 +import hashlib import json import logging +from pathlib import Path +from tempfile import NamedTemporaryFile import pytest from pytket.circuit import Bit, Circuit from pytket.phir.api import pytket_to_phir, qasm_to_phir from pytket.phir.qtm_machine import QtmMachine +from pytket.wasm.wasm import WasmFileHandler -from .test_utils import QasmFile, get_qasm_as_circuit +from .test_utils import QasmFile, WatFile, get_qasm_as_circuit, get_wat_as_wasm_bytes logger = logging.getLogger(__name__) @@ -24,7 +31,6 @@ class TestApi: def test_pytket_to_phir_no_machine(self) -> None: """Test case when no machine is present.""" circuit = get_qasm_as_circuit(QasmFile.baby) - assert pytket_to_phir(circuit) @pytest.mark.parametrize("test_file", list(QasmFile)) @@ -57,14 +63,14 @@ def test_pytket_classical_only(self) -> None: c.add_c_copyreg(a, b) c.add_c_copybits([Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)]) - phir = json.loads(pytket_to_phir(c)) # type: ignore[misc] + phir = json.loads(pytket_to_phir(c)) - assert phir["ops"][3] == { # type: ignore[misc] + assert phir["ops"][3] == { "cop": "=", "returns": [["b", 0], ["b", 1]], "args": [["a", 0], ["a", 1]], } - assert phir["ops"][5] == { # type: ignore[misc] + assert phir["ops"][5] == { "cop": "=", "returns": [["a", 0], ["b", 0]], "args": [["b", 2], ["a", 1]], @@ -85,3 +91,111 @@ def test_qasm_to_phir(self) -> None: """ assert qasm_to_phir(qasm, QtmMachine.H1_1) + + def test_qasm_to_phir_with_wasm(self) -> None: + """Test the qasm string entrypoint works with WASM.""" + qasm = """ + OPENQASM 2.0; + include "qelib1.inc"; + + qreg q[2]; + h q; + ZZ q[1],q[0]; + creg cr[3]; + creg cs[3]; + creg co[3]; + measure q[0]->cr[0]; + measure q[1]->cr[1]; + + cs = cr; + co = add(cr, cs); + """ + + wasm_bytes = get_wat_as_wasm_bytes(WatFile.add) + + wasm_uid = hashlib.sha256(base64.b64encode(wasm_bytes)).hexdigest() + + phir_str = qasm_to_phir(qasm, QtmMachine.H1_1, wasm_bytes=wasm_bytes) + phir = json.loads(phir_str) + + expected_metadata = {"ff_object": (f"WASM module uid: {wasm_uid}")} + + assert phir["ops"][21] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "add", + "args": ["cr", "cs"], + "returns": ["co"], + } + + def test_pytket_with_wasm(self) -> None: + wasm_bytes = get_wat_as_wasm_bytes(WatFile.testfile) + phir_str: str + try: + wasm_file = NamedTemporaryFile(suffix=".wasm", delete=False) + wasm_file.write(wasm_bytes) + wasm_file.flush() + wasm_file.close() + + w = WasmFileHandler(wasm_file.name) + + c = Circuit(6, 6) + c0 = c.add_c_register("c0", 3) + c1 = c.add_c_register("c1", 4) + c2 = c.add_c_register("c2", 5) + + c.add_wasm_to_reg("multi", w, [c0, c1], [c2]) + c.add_wasm_to_reg("add_one", w, [c2], [c2]) + c.add_wasm_to_reg("no_return", w, [c2], []) + c.add_wasm_to_reg("no_parameters", w, [], [c2]) + + c.add_wasm_to_reg("add_one", w, [c0], [c0], condition=c1[0]) + + phir_str = pytket_to_phir(c, QtmMachine.H1_1) + finally: + Path.unlink(Path(wasm_file.name)) + + phir = json.loads(phir_str) + + expected_metadata = {"ff_object": (f"WASM module uid: {w!s}")} + + assert phir["ops"][4] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "multi", + "args": ["c0", "c1"], + "returns": ["c2"], + } + assert phir["ops"][7] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "add_one", + "args": ["c2"], + "returns": ["c2"], + } + assert phir["ops"][9] == { + "block": "if", + "condition": {"cop": "==", "args": [["c1", 0], 1]}, + "true_branch": [ + { + "metadata": expected_metadata, + "cop": "ffcall", + "returns": ["c0"], + "function": "add_one", + "args": ["c1", "c0"], + } + ], + } + assert phir["ops"][12] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "no_return", + "args": ["c2"], + } + assert phir["ops"][15] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "no_parameters", + "args": [], + "returns": ["c2"], + } diff --git a/tests/test_utils.py b/tests/test_utils.py index 55fd646..abeb8e0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,6 +11,8 @@ from pathlib import Path from typing import TYPE_CHECKING +from wasmtime import wat2wasm + from pytket.phir.phirgen_parallel import genphir_parallel from pytket.phir.place_and_route import place_and_route from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine @@ -45,6 +47,11 @@ class QasmFile(Enum): classical_ordering = auto() +class WatFile(Enum): + add = auto() + testfile = auto() + + def get_qasm_as_circuit(qasm_file: QasmFile) -> "Circuit": """Utility function to convert a QASM file to Circuit. @@ -70,3 +77,9 @@ def get_phir_json(qasmfile: QasmFile, *, rebase: bool) -> "JsonDict": shards = sharder.shard() placed = place_and_route(shards, machine) return json.loads(genphir_parallel(placed, machine)) # type: ignore[misc, no-any-return] + + +def get_wat_as_wasm_bytes(wat_file: WatFile) -> bytes: + """Gets a given wat file, converted to WASM bytes by wasmtime.""" + this_dir = Path(Path(__file__).resolve()).parent + return wat2wasm(Path(f"{this_dir}/data/wasm/{wat_file.name}.wat").read_text())