Skip to content

Commit

Permalink
fix(phirgen): make conditionals more robust based on Alec's feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
qartik committed Apr 8, 2024
1 parent 46a9ab1 commit 09970f6
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand Down
49 changes: 29 additions & 20 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,15 @@ def cop_from_op_name(op_name: str) -> str:
def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict | None:
"""Return PHIR dict for a pytket ClassicalEvalOp."""
# Exclude conditional bit from args
args = cmd.args[1:] if isinstance(cmd.op, tk.Conditional) else cmd.args
args = cmd.args[cmd.op.width :] if isinstance(cmd.op, tk.Conditional) else cmd.args
out: JsonDict | None = None
match op:
case tk.CopyBitsOp():
if len(cmd.bits) != len(cmd.args) // 2:
if len(cmd.bits) != len(args) // 2:
logger.warning("LHS and RHS lengths mismatch for CopyBits")
out = assign_cop(
[arg_to_bit(bit) for bit in cmd.bits],
[arg_to_bit(args[i]) for i in range(len(cmd.args) // 2)],
[arg_to_bit(args[i]) for i in range(len(args) // 2)],
)
case tk.SetBitsOp():
if len(cmd.bits) != len(op.values):
Expand All @@ -260,17 +260,17 @@ def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict
case l, u if l == u:
cond = {
"cop": "==",
"args": [cmd.args[0].reg_name, u],
"args": [args[0].reg_name, u],
}
case l, u if u == UINTMAX:
cond = {
"cop": ">=",
"args": [cmd.args[0].reg_name, l],
"args": [args[0].reg_name, l],
}
case 0, u:
cond = {
"cop": "<=",
"args": [cmd.args[0].reg_name, u],
"args": [args[0].reg_name, u],
}
out = {
"block": "if",
Expand All @@ -280,9 +280,9 @@ def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict
case tk.MultiBitOp():
# determine number of register operands involved in the operation
operand_count = (
len(cmd.args) // len(cmd.bits) - 1
len(args) // len(cmd.bits) - 1
if op.basic_op.type == tk.OpType.ExplicitPredicate
else len(cmd.args) // len(cmd.bits)
else len(args) // len(cmd.bits)
)
out = assign_cop(
# Converting to regwise operations that pecos can handle
Expand All @@ -300,25 +300,30 @@ def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict
return out


def multi_bit_condition(args: "list[UnitID]", value: int) -> JsonDict:
"""Construct bitwise condition."""
return {
"cop": "&",
"args": [
{"cop": "==", "args": [arg_to_bit(arg), bval]}
for (arg, bval) in zip(args, map(int, f"{value:b}"), strict=False)
],
}


def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
"""Return PHIR dict given a tket op and its arguments."""
if op.is_gate():
return convert_gate(op, cmd)

out: JsonDict | None = None
match op: # non-quantum op
case tk.Conditional(): # where the condition is equality check
case tk.Conditional():
out = {
"block": "if",
"condition": {
"cop": "==",
"args": [
arg_to_bit(cmd.args[0])
if op.width == 1
else cmd.args[0].reg_name,
op.value,
],
},
"condition": {"cop": "==", "args": [arg_to_bit(cmd.args[0]), op.value]}
if op.width == 1
else multi_bit_condition(cmd.args[: op.width], op.value),
"true_branch": [convert_subcmd(op.op, cmd)],
}

Expand Down Expand Up @@ -360,8 +365,12 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
return create_wasm_op(cmd, op)

case _:
# Exclude conditional bit from args
args = cmd.args[1:] if isinstance(cmd.op, tk.Conditional) else cmd.args
# Exclude conditional bits from args
args = (
cmd.args[cmd.op.width :]
if isinstance(cmd.op, tk.Conditional)
else cmd.args
)
match op.type:
case tk.OpType.ExplicitPredicate | tk.OpType.ExplicitModifier:
# exclude output bit when not modifying in place
Expand Down
82 changes: 52 additions & 30 deletions tests/test_phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,35 @@
from .test_utils import QasmFile, get_qasm_as_circuit


def test_multiple_sleep() -> None:
"""Ensure multiple sleep ops get converted correctly."""
qasm = """
OPENQASM 2.0;
include "hqslib1_dev.inc";
qreg q[2];
sleep(1) q[0];
sleep(2) q[1];
"""
circ = circuit_from_qasm_str(qasm)
phir = json.loads(pytket_to_phir(circ))
assert phir["ops"][2] == {"mop": "Idle", "args": [["q", 0]], "duration": [1.0, "s"]}
assert phir["ops"][4] == {"mop": "Idle", "args": [["q", 1]], "duration": [2.0, "s"]}


def test_simple_cond_classical() -> None:
"""Ensure conditional classical operation are correctly generated."""
circ = get_qasm_as_circuit(QasmFile.simple_cond)
phir = json.loads(pytket_to_phir(circ))
assert phir["ops"][-6] == {"//": "IF ([c[0]] == 1) THEN SetBits(1) z[0];"}
assert phir["ops"][-5] == {
"block": "if",
"condition": {"cop": "==", "args": [["c", 0], 1]},
"true_branch": [{"cop": "=", "returns": [["z", 0]], "args": [1]}],
}


def test_pytket_classical_only() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/61 ."""
c = Circuit(1)
Expand All @@ -29,6 +58,12 @@ def test_pytket_classical_only() -> None:
c.add_c_copybits(
[Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)], condition=Bit("b", 1)
)
c.add_c_copybits(
[Bit("a", 0), Bit("a", 1)], # type: ignore[list-item] # overloaded function
[Bit("b", 0), Bit("b", 1)], # type: ignore[list-item] # overloaded function
condition_bits=[Bit("b", 1), Bit("b", 2)],
condition_value=2,
)

phir = json.loads(pytket_to_phir(c))

Expand All @@ -49,6 +84,22 @@ def test_pytket_classical_only() -> None:
{"cop": "=", "returns": [["a", 0], ["b", 0]], "args": [["b", 2], ["a", 1]]}
],
}
assert phir["ops"][8] == {
"//": "IF ([b[1], b[2]] == 2) THEN CopyBits a[0], a[1], b[0], b[1];"
}
assert phir["ops"][9] == {
"block": "if",
"condition": {
"cop": "&",
"args": [
{"cop": "==", "args": [["b", 1], 1]},
{"cop": "==", "args": [["b", 2], 0]},
],
},
"true_branch": [
{"cop": "=", "returns": [["b", 0], ["b", 1]], "args": [["a", 0], ["a", 1]]}
],
}


def test_classicalexpbox() -> None:
Expand Down Expand Up @@ -121,23 +172,11 @@ def test_conditional_barrier() -> None:
assert phir["ops"][4] == {"//": "IF ([m[0], m[1]] == 0) THEN Barrier q[0], q[1];"}
assert phir["ops"][5] == {
"block": "if",
"condition": {"cop": "==", "args": ["m", 0]},
"condition": {"cop": "&", "args": [{"cop": "==", "args": [["m", 0], 0]}]},
"true_branch": [{"meta": "barrier", "args": [["q", 0], ["q", 1]]}],
}


def test_simple_cond_classical() -> None:
"""Ensure conditional classical operation are correctly generated."""
circ = get_qasm_as_circuit(QasmFile.simple_cond)
phir = json.loads(pytket_to_phir(circ))
assert phir["ops"][-6] == {"//": "IF ([c[0]] == 1) THEN SetBits(1) z[0];"}
assert phir["ops"][-5] == {
"block": "if",
"condition": {"cop": "==", "args": [["c", 0], 1]},
"true_branch": [{"cop": "=", "returns": [["z", 0]], "args": [1]}],
}


def test_nested_bitwise_op() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/133 ."""
circ = Circuit(4)
Expand Down Expand Up @@ -171,23 +210,6 @@ def test_sleep_idle() -> None:
assert phir["ops"][7] == {"mop": "Idle", "args": [["q", 0]], "duration": [1.0, "s"]}


def test_multiple_sleep() -> None:
"""Ensure multiple sleep ops get converted correctly."""
qasm = """
OPENQASM 2.0;
include "hqslib1_dev.inc";
qreg q[2];
sleep(1) q[0];
sleep(2) q[1];
"""
circ = circuit_from_qasm_str(qasm)
phir = json.loads(pytket_to_phir(circ))
assert phir["ops"][2] == {"mop": "Idle", "args": [["q", 0]], "duration": [1.0, "s"]}
assert phir["ops"][4] == {"mop": "Idle", "args": [["q", 1]], "duration": [2.0, "s"]}


def test_reordering_classical_conditional() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/150 ."""
circuit = Circuit(1)
Expand Down

0 comments on commit 09970f6

Please sign in to comment.