Skip to content

Commit

Permalink
feat(phirgen): tket to phir-specific gate conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
qartik committed Nov 3, 2023
1 parent 47a2ba9 commit aa3a01d
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 31 deletions.
212 changes: 190 additions & 22 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,203 @@
import json
import logging
from collections.abc import Sequence
from typing import Any

import pytket.circuit as tk
from phir.model import PHIRModel
from pytket.circuit import Command
from pytket.phir.sharding.shard import Cost, Layer, Ordering
from pytket.circuit.logic_exp import RegWiseOp
from pytket.unit_id import UnitID

from .sharding.shard import Cost, Layer, Ordering

def write_cmd(cmd: Command, ops: list[dict[str, Any]]) -> None:
"""Write a pytket command to PHIR qop.
logger = logging.getLogger(__name__)

UINTMAX = 2**32 - 1

tket_gate_to_phir = {
tk.OpType.noop: "I",

tk.OpType.CX: "CX",
tk.OpType.CY: "CY",
tk.OpType.CZ: "CZ",
tk.OpType.H: "H",
tk.OpType.PhasedX: "R1XY",
tk.OpType.Reset: "Reset", # TODO(kartik): confirm with Ciaran
tk.OpType.Rx: "RX",
tk.OpType.Ry: "RY",
tk.OpType.Rz: "RZ",
tk.OpType.S: "SZ",
tk.OpType.Sdg: "SZdg",
tk.OpType.SWAP: "SWAP",
tk.OpType.SX: "SX",
tk.OpType.SXdg: "SXdg",
tk.OpType.T: "T",
tk.OpType.Tdg: "Tdg",
tk.OpType.TK2: "R2XXYYZZ",
tk.OpType.U1: "RZ",
tk.OpType.V: "SX",
tk.OpType.Vdg: "SXdg",
tk.OpType.X: "X",
tk.OpType.XXPhase: "RXX",
tk.OpType.Y: "Y",
tk.OpType.YYPhase: "RYY",
tk.OpType.Z: "Z",
tk.OpType.ZZMax: "SZZ",
tk.OpType.ZZPhase: "RZZ",

tk.OpType.Measure: "Measure",
} # fmt: skip


def arg_to_bit(arg: UnitID) -> list[str | int]:
"""Convert tket arg to Bit."""
return [arg.reg_name, arg.index[0]]


def assign_cop(into: str | list[str | int], what: Sequence[int]) -> dict[str, Any]:
"""PHIR for assign classical operation."""
return {
"cop": "=",
"returns": [into],
"args": what,
}


def convert_subcmd(op: tk.Op, cmd: tk.Command) -> dict[str, Any]:
"""Return PHIR dict give op and its arguments."""
if op.is_gate():
try:
gate = tket_gate_to_phir[op.type]
except KeyError:
logging.exception(f"Gate {op.get_name()} unsupported by PHIR")
raise
angles = (op.params, "pi") if op.params else None
qop: dict[str, Any]
match op.type:
case tk.OpType.Measure:
qop = {
"cop": "Measure",
"returns": [arg_to_bit(cmd.bits[0])],
"args": [arg_to_bit(cmd.args[0])],
}

case _: # a regular quantum gate
qop = {
"angles": angles,
"qop": gate,
"args": [arg_to_bit(qbit) for qbit in cmd.qubits],
}
return qop

match op: # non-quantum op
case tk.SetBitsOp():
return assign_cop(arg_to_bit(cmd.bits[0]), op.values)

case _:
# TODO(kartik): NYI
raise NotImplementedError


def append_cmd(cmd: tk.Command, ops: list[dict[str, Any]]) -> None:
"""Convert a pytket command to a PHIR command and append to `ops`.
Args:
cmd: pytket command obtained from pytket-phir
ops: the list of ops to append to
"""
gate = cmd.op.get_name().split("(", 1)[0]
angles = (cmd.op.params, "pi") if cmd.op.is_gate() and cmd.op.params else None
ops.append({"//": str(cmd)})
if cmd.op.is_gate():
ops.append(convert_subcmd(cmd.op, cmd))
else:
op: dict[str, Any] | None = None
match cmd.op:
case tk.SetBitsOp():
op = convert_subcmd(cmd.op, cmd)

qop: dict[str, Any] = {
"angles": angles,
"qop": gate,
"args": [],
}
for qbit in cmd.args:
qop["args"].append([qbit.reg_name, qbit.index[0]])
if gate == "Measure":
break
if cmd.bits:
qop["returns"] = []
for cbit in cmd.bits:
qop["returns"].append([cbit.reg_name, cbit.index[0]])
ops.extend(({"//": str(cmd)}, qop))
case tk.BarrierOp():
# TODO(kartik): confirm with Ciaran
logger.debug("Skipping Barrier instruction")

case tk.Conditional(): # where the condition is equality check
op = {
"block": "if",
"condition": {
"cop": "==",
"args": [
arg_to_bit(cmd.args[0])
if cmd.op.width == 1
else cmd.args[0].reg_name,
cmd.op.value,
],
},
"true_branch": [convert_subcmd(cmd.op.op, cmd)],
}

case tk.RangePredicateOp(): # where the condition is a range
cond: dict[str, Any]
match cmd.op.lower, cmd.op.upper:
case l, u if l == u:
cond = {
"cop": "==",
"args": [cmd.args[0].reg_name, u],
}
case l, u if u == UINTMAX:
cond = {
"cop": ">=",
"args": [cmd.args[0].reg_name, l],
}
case 0, u:
cond = {
"cop": "<=",
"args": [cmd.args[0].reg_name, u],
}
op = {
"block": "if",
"condition": cond,
"true_branch": [assign_cop(arg_to_bit(cmd.bits[0]), [1])],
}
case tk.ClassicalExpBox():
exp = cmd.op.get_exp()
match exp.op:
case RegWiseOp.XOR:
cop = "^"
case RegWiseOp.ADD:
cop = "+"
case RegWiseOp.SUB:
cop = "-"
case RegWiseOp.MUL:
cop = "*"
case RegWiseOp.DIV:
cop = "/"
case RegWiseOp.LSH:
cop = "<<"
case RegWiseOp.RSH:
cop = ">>"
case RegWiseOp.EQ:
cop = "=="
case RegWiseOp.NEQ:
cop = "!="
case RegWiseOp.LT:
cop = "<"
case RegWiseOp.GT:
cop = ">"
case RegWiseOp.LEQ:
cop = "<="
case RegWiseOp.GEQ:
cop = ">="
case RegWiseOp.NOT:
cop = "~"
case other:
logging.exception(f"Unsupported classical operator {other}")
raise ValueError
op = {
"cop": cop,
"args": [arg["name"] for arg in exp.to_dict()["args"]],
}
case m:
raise NotImplementedError(m)
if op:
ops.append(op)


def genphir(inp: list[tuple[Ordering, Layer, Cost]]) -> str:
Expand All @@ -53,8 +221,8 @@ def genphir(inp: list[tuple[Ordering, Layer, Cost]]) -> str:
cbits |= shard.bits_read | shard.bits_written
for sub_commands in shard.sub_commands.values():
for sc in sub_commands:
write_cmd(sc, ops)
write_cmd(shard.primary_command, ops)
append_cmd(sc, ops)
append_cmd(shard.primary_command, ops)
ops.append(
{
"mop": "Transport",
Expand Down
18 changes: 9 additions & 9 deletions tests/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@


class QasmFile(Enum):
simple = auto()
cond_1 = auto()
bv_n10 = auto()
baby = auto()
simple = auto()
eztest = auto()
baby_with_rollup = auto()
simple_cond = auto()
cond_classical = auto()
barrier_complex = auto()
classical_hazards = auto()
big_gate = auto()
simple_cond = auto()
n10_test = auto()
qv20_0 = auto()
classical_hazards = auto()
cond_1 = auto()
barrier_complex = auto()
cond_classical = auto()
bv_n10 = auto()
oned_brickwork_circuit_n20 = auto()
eztest = auto()
qv20_0 = auto()


def get_qasm_as_circuit(qasm_file: QasmFile) -> Circuit:
Expand Down

0 comments on commit aa3a01d

Please sign in to comment.