diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index f82394a..fc8bc18 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -1,35 +1,203 @@ import json +import logging +from collections.abc import Sequence from typing import Any +import pytket.circuit as tk from phir.model import PHIRModel -from pytket.circuit import Command -from pytket.phir.sharding.shard import Cost, Layer, Ordering +from pytket.circuit.logic_exp import RegWiseOp +from pytket.unit_id import UnitID +from .sharding.shard import Cost, Layer, Ordering -def write_cmd(cmd: Command, ops: list[dict[str, Any]]) -> None: - """Write a pytket command to PHIR qop. +logger = logging.getLogger(__name__) + +UINTMAX = 2**32 - 1 + +tket_gate_to_phir = { + tk.OpType.noop: "I", + + tk.OpType.CX: "CX", + tk.OpType.CY: "CY", + tk.OpType.CZ: "CZ", + tk.OpType.H: "H", + tk.OpType.PhasedX: "R1XY", + tk.OpType.Reset: "Reset", # TODO(kartik): confirm with Ciaran + tk.OpType.Rx: "RX", + tk.OpType.Ry: "RY", + tk.OpType.Rz: "RZ", + tk.OpType.S: "SZ", + tk.OpType.Sdg: "SZdg", + tk.OpType.SWAP: "SWAP", + tk.OpType.SX: "SX", + tk.OpType.SXdg: "SXdg", + tk.OpType.T: "T", + tk.OpType.Tdg: "Tdg", + tk.OpType.TK2: "R2XXYYZZ", + tk.OpType.U1: "RZ", + tk.OpType.V: "SX", + tk.OpType.Vdg: "SXdg", + tk.OpType.X: "X", + tk.OpType.XXPhase: "RXX", + tk.OpType.Y: "Y", + tk.OpType.YYPhase: "RYY", + tk.OpType.Z: "Z", + tk.OpType.ZZMax: "SZZ", + tk.OpType.ZZPhase: "RZZ", + + tk.OpType.Measure: "Measure", +} # fmt: skip + + +def arg_to_bit(arg: UnitID) -> list[str | int]: + """Convert tket arg to Bit.""" + return [arg.reg_name, arg.index[0]] + + +def assign_cop(into: str | list[str | int], what: Sequence[int]) -> dict[str, Any]: + """PHIR for assign classical operation.""" + return { + "cop": "=", + "returns": [into], + "args": what, + } + + +def convert_subcmd(op: tk.Op, cmd: tk.Command) -> dict[str, Any]: + """Return PHIR dict give op and its arguments.""" + if op.is_gate(): + try: + gate = tket_gate_to_phir[op.type] + except KeyError: + logging.exception(f"Gate {op.get_name()} unsupported by PHIR") + raise + angles = (op.params, "pi") if op.params else None + qop: dict[str, Any] + match op.type: + case tk.OpType.Measure: + qop = { + "cop": "Measure", + "returns": [arg_to_bit(cmd.bits[0])], + "args": [arg_to_bit(cmd.args[0])], + } + + case _: # a regular quantum gate + qop = { + "angles": angles, + "qop": gate, + "args": [arg_to_bit(qbit) for qbit in cmd.qubits], + } + return qop + + match op: # non-quantum op + case tk.SetBitsOp(): + return assign_cop(arg_to_bit(cmd.bits[0]), op.values) + + case _: + # TODO(kartik): NYI + raise NotImplementedError + + +def append_cmd(cmd: tk.Command, ops: list[dict[str, Any]]) -> None: + """Convert a pytket command to a PHIR command and append to `ops`. Args: cmd: pytket command obtained from pytket-phir ops: the list of ops to append to """ - gate = cmd.op.get_name().split("(", 1)[0] - angles = (cmd.op.params, "pi") if cmd.op.is_gate() and cmd.op.params else None + ops.append({"//": str(cmd)}) + if cmd.op.is_gate(): + ops.append(convert_subcmd(cmd.op, cmd)) + else: + op: dict[str, Any] | None = None + match cmd.op: + case tk.SetBitsOp(): + op = convert_subcmd(cmd.op, cmd) - qop: dict[str, Any] = { - "angles": angles, - "qop": gate, - "args": [], - } - for qbit in cmd.args: - qop["args"].append([qbit.reg_name, qbit.index[0]]) - if gate == "Measure": - break - if cmd.bits: - qop["returns"] = [] - for cbit in cmd.bits: - qop["returns"].append([cbit.reg_name, cbit.index[0]]) - ops.extend(({"//": str(cmd)}, qop)) + case tk.BarrierOp(): + # TODO(kartik): confirm with Ciaran + logger.debug("Skipping Barrier instruction") + + case tk.Conditional(): # where the condition is equality check + op = { + "block": "if", + "condition": { + "cop": "==", + "args": [ + arg_to_bit(cmd.args[0]) + if cmd.op.width == 1 + else cmd.args[0].reg_name, + cmd.op.value, + ], + }, + "true_branch": [convert_subcmd(cmd.op.op, cmd)], + } + + case tk.RangePredicateOp(): # where the condition is a range + cond: dict[str, Any] + match cmd.op.lower, cmd.op.upper: + case l, u if l == u: + cond = { + "cop": "==", + "args": [cmd.args[0].reg_name, u], + } + case l, u if u == UINTMAX: + cond = { + "cop": ">=", + "args": [cmd.args[0].reg_name, l], + } + case 0, u: + cond = { + "cop": "<=", + "args": [cmd.args[0].reg_name, u], + } + op = { + "block": "if", + "condition": cond, + "true_branch": [assign_cop(arg_to_bit(cmd.bits[0]), [1])], + } + case tk.ClassicalExpBox(): + exp = cmd.op.get_exp() + match exp.op: + case RegWiseOp.XOR: + cop = "^" + case RegWiseOp.ADD: + cop = "+" + case RegWiseOp.SUB: + cop = "-" + case RegWiseOp.MUL: + cop = "*" + case RegWiseOp.DIV: + cop = "/" + case RegWiseOp.LSH: + cop = "<<" + case RegWiseOp.RSH: + cop = ">>" + case RegWiseOp.EQ: + cop = "==" + case RegWiseOp.NEQ: + cop = "!=" + case RegWiseOp.LT: + cop = "<" + case RegWiseOp.GT: + cop = ">" + case RegWiseOp.LEQ: + cop = "<=" + case RegWiseOp.GEQ: + cop = ">=" + case RegWiseOp.NOT: + cop = "~" + case other: + logging.exception(f"Unsupported classical operator {other}") + raise ValueError + op = { + "cop": cop, + "args": [arg["name"] for arg in exp.to_dict()["args"]], + } + case m: + raise NotImplementedError(m) + if op: + ops.append(op) def genphir(inp: list[tuple[Ordering, Layer, Cost]]) -> str: @@ -53,8 +221,8 @@ def genphir(inp: list[tuple[Ordering, Layer, Cost]]) -> str: cbits |= shard.bits_read | shard.bits_written for sub_commands in shard.sub_commands.values(): for sc in sub_commands: - write_cmd(sc, ops) - write_cmd(shard.primary_command, ops) + append_cmd(sc, ops) + append_cmd(shard.primary_command, ops) ops.append( { "mop": "Transport", diff --git a/tests/sample_data.py b/tests/sample_data.py index b740b89..c11fc88 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -6,20 +6,20 @@ class QasmFile(Enum): - simple = auto() - cond_1 = auto() - bv_n10 = auto() baby = auto() + simple = auto() + eztest = auto() baby_with_rollup = auto() - simple_cond = auto() - cond_classical = auto() - barrier_complex = auto() - classical_hazards = auto() big_gate = auto() + simple_cond = auto() n10_test = auto() - qv20_0 = auto() + classical_hazards = auto() + cond_1 = auto() + barrier_complex = auto() + cond_classical = auto() + bv_n10 = auto() oned_brickwork_circuit_n20 = auto() - eztest = auto() + qv20_0 = auto() def get_qasm_as_circuit(qasm_file: QasmFile) -> Circuit: