From 09970f634914ea45ef7f814aaf221dc8b3d17a41 Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Mon, 8 Apr 2024 08:39:15 -0500 Subject: [PATCH] fix(phirgen): make conditionals more robust based on Alec's feedback --- .pre-commit-config.yaml | 2 +- pytket/phir/phirgen.py | 49 ++++++++++++++---------- tests/test_phirgen.py | 82 ++++++++++++++++++++++++++--------------- 3 files changed, 82 insertions(+), 51 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d1bcbd..60c8cd9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index 672679b..81ce731 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -237,15 +237,15 @@ def cop_from_op_name(op_name: str) -> str: def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict | None: """Return PHIR dict for a pytket ClassicalEvalOp.""" # Exclude conditional bit from args - args = cmd.args[1:] if isinstance(cmd.op, tk.Conditional) else cmd.args + args = cmd.args[cmd.op.width :] if isinstance(cmd.op, tk.Conditional) else cmd.args out: JsonDict | None = None match op: case tk.CopyBitsOp(): - if len(cmd.bits) != len(cmd.args) // 2: + if len(cmd.bits) != len(args) // 2: logger.warning("LHS and RHS lengths mismatch for CopyBits") out = assign_cop( [arg_to_bit(bit) for bit in cmd.bits], - [arg_to_bit(args[i]) for i in range(len(cmd.args) // 2)], + [arg_to_bit(args[i]) for i in range(len(args) // 2)], ) case tk.SetBitsOp(): if len(cmd.bits) != len(op.values): @@ -260,17 +260,17 @@ def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict case l, u if l == u: cond = { "cop": "==", - "args": [cmd.args[0].reg_name, u], + "args": [args[0].reg_name, u], } case l, u if u == UINTMAX: cond = { "cop": ">=", - "args": [cmd.args[0].reg_name, l], + "args": [args[0].reg_name, l], } case 0, u: cond = { "cop": "<=", - "args": [cmd.args[0].reg_name, u], + "args": [args[0].reg_name, u], } out = { "block": "if", @@ -280,9 +280,9 @@ def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict case tk.MultiBitOp(): # determine number of register operands involved in the operation operand_count = ( - len(cmd.args) // len(cmd.bits) - 1 + len(args) // len(cmd.bits) - 1 if op.basic_op.type == tk.OpType.ExplicitPredicate - else len(cmd.args) // len(cmd.bits) + else len(args) // len(cmd.bits) ) out = assign_cop( # Converting to regwise operations that pecos can handle @@ -300,6 +300,17 @@ def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict return out +def multi_bit_condition(args: "list[UnitID]", value: int) -> JsonDict: + """Construct bitwise condition.""" + return { + "cop": "&", + "args": [ + {"cop": "==", "args": [arg_to_bit(arg), bval]} + for (arg, bval) in zip(args, map(int, f"{value:b}"), strict=False) + ], + } + + def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: """Return PHIR dict given a tket op and its arguments.""" if op.is_gate(): @@ -307,18 +318,12 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: out: JsonDict | None = None match op: # non-quantum op - case tk.Conditional(): # where the condition is equality check + case tk.Conditional(): out = { "block": "if", - "condition": { - "cop": "==", - "args": [ - arg_to_bit(cmd.args[0]) - if op.width == 1 - else cmd.args[0].reg_name, - op.value, - ], - }, + "condition": {"cop": "==", "args": [arg_to_bit(cmd.args[0]), op.value]} + if op.width == 1 + else multi_bit_condition(cmd.args[: op.width], op.value), "true_branch": [convert_subcmd(op.op, cmd)], } @@ -360,8 +365,12 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: return create_wasm_op(cmd, op) case _: - # Exclude conditional bit from args - args = cmd.args[1:] if isinstance(cmd.op, tk.Conditional) else cmd.args + # Exclude conditional bits from args + args = ( + cmd.args[cmd.op.width :] + if isinstance(cmd.op, tk.Conditional) + else cmd.args + ) match op.type: case tk.OpType.ExplicitPredicate | tk.OpType.ExplicitModifier: # exclude output bit when not modifying in place diff --git a/tests/test_phirgen.py b/tests/test_phirgen.py index 61c43b6..7de6d57 100644 --- a/tests/test_phirgen.py +++ b/tests/test_phirgen.py @@ -18,6 +18,35 @@ from .test_utils import QasmFile, get_qasm_as_circuit +def test_multiple_sleep() -> None: + """Ensure multiple sleep ops get converted correctly.""" + qasm = """ + OPENQASM 2.0; + include "hqslib1_dev.inc"; + + qreg q[2]; + + sleep(1) q[0]; + sleep(2) q[1]; + """ + circ = circuit_from_qasm_str(qasm) + phir = json.loads(pytket_to_phir(circ)) + assert phir["ops"][2] == {"mop": "Idle", "args": [["q", 0]], "duration": [1.0, "s"]} + assert phir["ops"][4] == {"mop": "Idle", "args": [["q", 1]], "duration": [2.0, "s"]} + + +def test_simple_cond_classical() -> None: + """Ensure conditional classical operation are correctly generated.""" + circ = get_qasm_as_circuit(QasmFile.simple_cond) + phir = json.loads(pytket_to_phir(circ)) + assert phir["ops"][-6] == {"//": "IF ([c[0]] == 1) THEN SetBits(1) z[0];"} + assert phir["ops"][-5] == { + "block": "if", + "condition": {"cop": "==", "args": [["c", 0], 1]}, + "true_branch": [{"cop": "=", "returns": [["z", 0]], "args": [1]}], + } + + def test_pytket_classical_only() -> None: """From https://github.com/CQCL/pytket-phir/issues/61 .""" c = Circuit(1) @@ -29,6 +58,12 @@ def test_pytket_classical_only() -> None: c.add_c_copybits( [Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)], condition=Bit("b", 1) ) + c.add_c_copybits( + [Bit("a", 0), Bit("a", 1)], # type: ignore[list-item] # overloaded function + [Bit("b", 0), Bit("b", 1)], # type: ignore[list-item] # overloaded function + condition_bits=[Bit("b", 1), Bit("b", 2)], + condition_value=2, + ) phir = json.loads(pytket_to_phir(c)) @@ -49,6 +84,22 @@ def test_pytket_classical_only() -> None: {"cop": "=", "returns": [["a", 0], ["b", 0]], "args": [["b", 2], ["a", 1]]} ], } + assert phir["ops"][8] == { + "//": "IF ([b[1], b[2]] == 2) THEN CopyBits a[0], a[1], b[0], b[1];" + } + assert phir["ops"][9] == { + "block": "if", + "condition": { + "cop": "&", + "args": [ + {"cop": "==", "args": [["b", 1], 1]}, + {"cop": "==", "args": [["b", 2], 0]}, + ], + }, + "true_branch": [ + {"cop": "=", "returns": [["b", 0], ["b", 1]], "args": [["a", 0], ["a", 1]]} + ], + } def test_classicalexpbox() -> None: @@ -121,23 +172,11 @@ def test_conditional_barrier() -> None: assert phir["ops"][4] == {"//": "IF ([m[0], m[1]] == 0) THEN Barrier q[0], q[1];"} assert phir["ops"][5] == { "block": "if", - "condition": {"cop": "==", "args": ["m", 0]}, + "condition": {"cop": "&", "args": [{"cop": "==", "args": [["m", 0], 0]}]}, "true_branch": [{"meta": "barrier", "args": [["q", 0], ["q", 1]]}], } -def test_simple_cond_classical() -> None: - """Ensure conditional classical operation are correctly generated.""" - circ = get_qasm_as_circuit(QasmFile.simple_cond) - phir = json.loads(pytket_to_phir(circ)) - assert phir["ops"][-6] == {"//": "IF ([c[0]] == 1) THEN SetBits(1) z[0];"} - assert phir["ops"][-5] == { - "block": "if", - "condition": {"cop": "==", "args": [["c", 0], 1]}, - "true_branch": [{"cop": "=", "returns": [["z", 0]], "args": [1]}], - } - - def test_nested_bitwise_op() -> None: """From https://github.com/CQCL/pytket-phir/issues/133 .""" circ = Circuit(4) @@ -171,23 +210,6 @@ def test_sleep_idle() -> None: assert phir["ops"][7] == {"mop": "Idle", "args": [["q", 0]], "duration": [1.0, "s"]} -def test_multiple_sleep() -> None: - """Ensure multiple sleep ops get converted correctly.""" - qasm = """ - OPENQASM 2.0; - include "hqslib1_dev.inc"; - - qreg q[2]; - - sleep(1) q[0]; - sleep(2) q[1]; - """ - circ = circuit_from_qasm_str(qasm) - phir = json.loads(pytket_to_phir(circ)) - assert phir["ops"][2] == {"mop": "Idle", "args": [["q", 0]], "duration": [1.0, "s"]} - assert phir["ops"][4] == {"mop": "Idle", "args": [["q", 1]], "duration": [2.0, "s"]} - - def test_reordering_classical_conditional() -> None: """From https://github.com/CQCL/pytket-phir/issues/150 .""" circuit = Circuit(1)