Skip to content

Commit

Permalink
Adding WASM support (#77)
Browse files Browse the repository at this point in the history
Adding wasm support to the API and CLI.

---------

Co-authored-by: Neal Erickson <neal.erickson@quantinuum.com>
Co-authored-by: Kartik Singhal <kartik.singhal@quantinuum.com>
  • Loading branch information
3 people authored Jan 23, 2024
1 parent df5957a commit a199f3c
Show file tree
Hide file tree
Showing 13 changed files with 331 additions and 33 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ repos:
pytket-quantinuum,
pytket,
types-setuptools,
wasmtime,
]
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
.PHONY: install dev tests lint docs clean build
.PHONY: install dev dev-all tests lint docs clean build

install:
pip install .

dev:
pip install -e .

dev-all:
pip install -e .[phirc]

tests:
pytest -s -x -vv tests/test*.py

Expand Down
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Install additional dependencies needed for the CLI using `pip install pytket-phi

```sh
❯ phirc -h
usage: phirc [-h] [-m {H1-1,H1-2}] [-v] qasm_files [qasm_files ...]
usage: phirc [-h] [-w WASM_FILE] [-m {H1-1,H1-2}] [-o {0,1,2}] [-v] [--version] qasm_files [qasm_files ...]

Emulates QASM program execution via PECOS

Expand All @@ -35,9 +35,14 @@ positional arguments:

options:
-h, --help show this help message and exit
-w WASM_FILE, --wasm-file WASM_FILE
Optional WASM file for use by the QASM programs
-m {H1-1,H1-2}, --machine {H1-1,H1-2}
machine name, H1-1 by default
-v, --version show program's version number and exit
Machine name, H1-1 by default
-o {0,1,2}, --tket-opt-level {0,1,2}
TKET optimization level, 0 by default
-v, --verbose
--version show program's version number and exit
```
## Development
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"phir>=0.2.1",
"pytket-quantinuum>=0.25.0",
"pytket>=1.21.0",
"wasmtime>=15.0.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -55,8 +56,9 @@ pythonpath = [
]
log_cli = true
log_cli_level = "INFO"
log_level = "DEBUG"
filterwarnings = ["ignore:::lark.s*"]
log_format = "%(asctime)s.%(msecs)03d %(levelname)s %(message)s"
log_format = "%(asctime)s.%(msecs)03d %(levelname)s %(name)s:%(lineno)s %(message)s"
log_date_format = "%Y-%m-%d %H:%M:%S"

[tool.setuptools_scm]
Expand Down
29 changes: 26 additions & 3 deletions pytket/phir/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
#
##############################################################################

# mypy: disable-error-code="misc"

import logging
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING

from rich import print

from phir.model import PHIRModel
from pytket.qasm.qasm import circuit_from_qasm_str
from pytket.qasm.qasm import circuit_from_qasm_str, circuit_from_qasm_wasm

from .phirgen import genphir
from .phirgen_parallel import genphir_parallel
Expand Down Expand Up @@ -74,14 +78,15 @@ def pytket_to_phir(
else:
phir_json = genphir(placed, machine_ops=bool(machine))
if logger.getEffectiveLevel() <= logging.INFO:
print(PHIRModel.model_validate_json(phir_json)) # type: ignore[misc]
print(PHIRModel.model_validate_json(phir_json))
return phir_json


def qasm_to_phir(
qasm: str,
qtm_machine: QtmMachine | None = None,
tket_optimization_level: int = DEFAULT_TKET_OPT_LEVEL,
wasm_bytes: bytes | None = None,
) -> str:
"""Converts a QASM circuit string into its PHIR representation.
Expand All @@ -91,6 +96,24 @@ def qasm_to_phir(
:param circuit: Circuit object to be converted
:param qtm_machine: (Optional) Quantinuum machine architecture to rebase against
:param tket_optimization_level: (Default=0) TKET circuit optimization level
:param wasm_bytes (Optional) WASM as bytes to include as part of circuit
"""
circuit = circuit_from_qasm_str(qasm)
circuit: Circuit
if wasm_bytes:
try:
qasm_file = NamedTemporaryFile(suffix=".qasm", delete=False)
wasm_file = NamedTemporaryFile(suffix=".wasm", delete=False)
qasm_file.write(qasm.encode())
qasm_file.flush()
qasm_file.close()
wasm_file.write(wasm_bytes)
wasm_file.flush()
wasm_file.close()

circuit = circuit_from_qasm_wasm(qasm_file.name, wasm_file.name)
finally:
Path.unlink(Path(qasm_file.name))
Path.unlink(Path(wasm_file.name))
else:
circuit = circuit_from_qasm_str(qasm)
return pytket_to_phir(circuit, qtm_machine, tket_optimization_level)
47 changes: 32 additions & 15 deletions pytket/phir/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,21 @@
##############################################################################

# mypy: disable-error-code="misc"
# ruff: noqa: T201

from argparse import ArgumentParser
from importlib.metadata import version

from pecos.engines.hybrid_engine import HybridEngine # type:ignore [import-not-found]
from pecos.foreign_objects.wasmtime import WasmtimeObj # type:ignore [import-not-found]

from phir.model import PHIRModel
from pytket.qasm.qasm import (
circuit_from_qasm,
circuit_from_qasm_str,
circuit_to_qasm_str,
circuit_from_qasm_wasm,
)

from .api import pytket_to_phir
from .qtm_machine import QtmMachine
from .rebasing.rebaser import rebase_to_qtm_machine


def main() -> None:
Expand All @@ -34,12 +33,18 @@ def main() -> None:
parser.add_argument(
"qasm_files", nargs="+", default=None, help="One or more QASM files to emulate"
)
parser.add_argument(
"-w",
"--wasm-file",
default=None,
help="Optional WASM file for use by the QASM programs",
)
parser.add_argument(
"-m",
"--machine",
choices=["H1-1", "H1-2"],
default="H1-1",
help="machine name, H1-1 by default",
help="Machine name, H1-1 by default",
)
parser.add_argument(
"-o",
Expand All @@ -48,28 +53,40 @@ def main() -> None:
default="0",
help="TKET optimization level, 0 by default",
)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
"-v",
"--version",
action="version",
version=f'{version("pytket-phir")}',
)
args = parser.parse_args()

for file in args.qasm_files:
print(f"Processing {file}") # noqa: T201
c = circuit_from_qasm(file)
tket_opt_level = int(args.tket_opt_level)
rc = rebase_to_qtm_machine(c, args.machine, tket_opt_level)
qasm = circuit_to_qasm_str(rc, header="hqslib1")
circ = circuit_from_qasm_str(qasm)
print(f"Processing {file}")
circuit = None
if args.wasm_file:
print(f"Including WASM from file {args.wasm_file}")
circuit = circuit_from_qasm_wasm(file, args.wasm_file)
wasm_pecos_obj = WasmtimeObj(args.wasm_file)
else:
circuit = circuit_from_qasm(file)

match args.machine:
case "H1-1":
machine = QtmMachine.H1_1
case "H1-2":
machine = QtmMachine.H1_2
phir = pytket_to_phir(circ, machine)
PHIRModel.model_validate_json(phir)

HybridEngine(qsim="state-vector").run(program=phir, shots=10)
phir = pytket_to_phir(circuit, machine, int(args.tket_opt_level))
if args.verbose:
print("\nPHIR to be simulated:")
print(phir)

print("\nPECOS results:")
print(
HybridEngine(qsim="state-vector").run(
program=phir,
shots=10,
foreign_object=wasm_pecos_obj if args.wasm_file else None,
)
)
57 changes: 56 additions & 1 deletion pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,9 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
[arg_to_bit(cmd.args[i]) for i in range(len(cmd.args) // 2)],
)

case tk.WASMOp():
return create_wasm_op(cmd, op)

case _:
# TODO(kartik): NYI
# https://github.com/CQCL/pytket-phir/issues/25
Expand All @@ -296,12 +299,63 @@ def append_cmd(cmd: tk.Command, ops: list[JsonDict]) -> None:
cmd: pytket command obtained from pytket-phir
ops: the list of ops to append to
"""
ops.append({"//": str(cmd)})
ops.append({"//": make_comment_text(cmd, cmd.op)})
op: JsonDict | None = convert_subcmd(cmd.op, cmd)
if op:
ops.append(op)


def create_wasm_op(cmd: tk.Command, wasm_op: tk.WASMOp) -> JsonDict:
"""Creates a PHIR operation for a WASM command."""
args, returns = extract_wasm_args_and_returns(cmd, wasm_op)
op = {
"cop": "ffcall",
"function": wasm_op.func_name,
"args": args,
"metadata": {
"ff_object": f"WASM module uid: {wasm_op.wasm_uid}",
},
}
if cmd.bits:
op["returns"] = returns

return op


def extract_wasm_args_and_returns(
command: tk.Command, op: tk.WASMOp
) -> tuple[list[str], list[str]]:
"""Extract the wasm args and return values as whole register names."""
# This slice removes the extra `_w` cregs (wires) that are not part of the
# circuit, and the output args which are appended after the input args
slice_index = op.num_w + sum(op.output_widths)
only_args = command.args[:-slice_index]
return (
dedupe_bits_to_registers(only_args),
dedupe_bits_to_registers(command.bits),
)


def dedupe_bits_to_registers(bits: "Sequence[UnitID]") -> list[str]:
"""Dedupes a list of bits to their registers, keeping order intact."""
return list(dict.fromkeys([bit.reg_name for bit in bits]))


def make_comment_text(command: tk.Command, op: tk.Op) -> str:
"""Converts a command + op to the PHIR comment spec."""
match op:
case tk.Conditional():
conditional_text = str(command)
cleaned = conditional_text[: conditional_text.find("THEN") + 4]
return f"{cleaned} {make_comment_text(command, op.op)}"

case tk.WASMOp():
args, returns = extract_wasm_args_and_returns(command, op)
return f"WASM function={op.func_name} args={args} returns={returns}"
case _:
return str(command)


def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | int]]:
"""Format the qvar and cvar define PHIR elements."""
# TODO(kartik): this may not always be accurate
Expand Down Expand Up @@ -334,6 +388,7 @@ def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | in
"size": dim,
}
for cvar, dim in cvar_dim.items()
if cvar != "_w"
]

return decls
Expand Down
17 changes: 13 additions & 4 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .shard import Shard

NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM]
NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox]

SHARD_TRIGGER_OP_TYPES = [
OpType.Measure,
Expand All @@ -25,6 +25,7 @@
OpType.RangePredicate,
OpType.ExplicitPredicate,
OpType.CopyBits,
OpType.WASM,
]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -98,7 +99,6 @@ def _process_command(self, command: Command) -> None:
return

if self.should_op_create_shard(command.op):
logger.debug("Building shard for command: %s", command)
self._build_shard(command)
else:
self._add_pending_sub_command(command)
Expand All @@ -112,6 +112,7 @@ def _build_shard(self, command: Command) -> None:
Args:
command: tket command (operation, bits, etc)
"""
logger.debug("Building shard for command: %s", command)
# Rollup any sub commands (SQ gates) that interact with the same qubits
sub_commands: dict[UnitID, list[Command]] = {}
for key in (
Expand All @@ -123,6 +124,7 @@ def _build_shard(self, command: Command) -> None:
for sub_command_list in sub_commands.values():
all_commands.extend(sub_command_list)

logger.debug("All shard commands: %s", all_commands)
qubits_used = set(command.qubits)
bits_written = set(command.bits)
bits_read: set[Bit] = set()
Expand Down Expand Up @@ -185,7 +187,9 @@ def _resolve_shard_dependencies(

for bit_read in bits_read:
if bit_read in self._bit_written_by:
logger.debug("...adding shard dep %s -> RAW")
logger.debug(
"...adding shard dep %s -> RAW", self._bit_written_by[bit_read]
)
depends_upon.add(self._bit_written_by[bit_read])

for bit_written in bits_written:
Expand Down Expand Up @@ -220,6 +224,7 @@ def _mark_dependencies(
self._bit_written_by[bit] = shard.ID
for bit in shard.bits_read:
self._bit_read_by[bit] = shard.ID
logger.debug("... dependencies marked")

def _cleanup_remaining_commands(self) -> None:
"""Cleans up any remaining subcommands.
Expand All @@ -228,7 +233,11 @@ def _cleanup_remaining_commands(self) -> None:
to roll up lingering subcommands.
"""
remaining_qubits = [k for k, v in self._pending_commands.items() if v]
logger.debug(
"Cleaning up remaining subcommands for qubits %s", remaining_qubits
)
for qubit in remaining_qubits:
logger.debug("Adding barrier for subcommands for qubit %s", qubit)
self._circuit.add_barrier([qubit])
# Easiest way to get to a command, since there's no constructor. Could
# create an entire orphan circuit with the matching qubits and the barrier
Expand All @@ -249,7 +258,7 @@ def _add_pending_sub_command(self, command: Command) -> None:
if qubit_key not in self._pending_commands:
self._pending_commands[qubit_key] = []
self._pending_commands[qubit_key].append(command)
logger.debug("Adding pending command %s", command)
logger.debug("Added pending sub-command %s", command)

@staticmethod
def should_op_create_shard(op: Op) -> bool:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ pytket==1.24.0
ruff==0.1.14
setuptools_scm==8.0.4
sphinx==7.2.6
wasmtime==15.0.0
wheel==0.42.0
Loading

0 comments on commit a199f3c

Please sign in to comment.