From f7f2c2600a732385caf3874108c90c3f0e6fe7fc Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Tue, 9 Apr 2024 13:19:22 -0500 Subject: [PATCH] feat(phirgen): add support for irregular multibit ops --- .pre-commit-config.yaml | 2 +- pytket/phir/phirgen.py | 64 +++++++++++++++++++++++++++++++---------- tests/test_phirgen.py | 35 ++++++++++++++++++++++ 3 files changed, 85 insertions(+), 16 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 60c8cd9..8a0cfbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - id: debug-statements - repo: https://github.com/crate-ci/typos - rev: v1.20.4 + rev: v1.20.5 hooks: - id: typos diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index d3edc98..a389883 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -10,6 +10,7 @@ import json import logging +from copy import deepcopy from importlib.metadata import version from typing import TYPE_CHECKING, Any, TypeAlias @@ -278,22 +279,55 @@ def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict "true_branch": [assign_cop([arg_to_bit(cmd.bits[0])], [1])], } case tk.MultiBitOp(): + cop = cop_from_op_name(op.basic_op.get_name()) + is_explicit = op.basic_op.type == tk.OpType.ExplicitPredicate + # determine number of register operands involved in the operation - operand_count = ( - len(args) // len(cmd.bits) - 1 - if op.basic_op.type == tk.OpType.ExplicitPredicate - else len(args) // len(cmd.bits) - ) - out = assign_cop( - # Converting to regwise operations that pecos can handle - [cmd.bits[0].reg_name], - [ - { - "cop": cop_from_op_name(op.basic_op.get_name()), - "args": [arg.reg_name for arg in args[:operand_count]], - } - ], - ) + operand_count = len(args) // len(cmd.bits) - (1 if is_explicit else 0) + + iters = [iter(args)] * (operand_count + (1 if is_explicit else 0)) + iter2 = deepcopy(iters) + + # Columns of expressions, e.g., + # AND (*2) a[0], b[0], c[0] + # , a[1], b[1], c[1] + # would be [(a[0], a[1]), (b[0], b[1]), (c[0], c[1])] + # and AND (*2) a[0], a[1], b[0] + # , b[1], c[0], c[1] + # would be [(a[0], b[1]), (a[1], c[0]), (b[0], c[1])] + cols = zip(*zip(*iters, strict=True), strict=True) + + if all( + all(col[0].reg_name == bit.reg_name for bit in col) for col in cols + ): # expression can be applied register-wise + out = assign_cop( + [cmd.bits[0].reg_name], + [ + { + "cop": cop, + "args": [arg.reg_name for arg in args[:operand_count]], + } + ], + ) + else: # apply a sequence of bit-wise ops + exps = zip(*iter2, strict=True) + out = { + "block": "sequence", + "ops": [ + assign_cop( + [arg_to_bit(bit)], + [ + { + "cop": cop, + "args": [ + arg_to_bit(arg) for arg in exp[:operand_count] + ], + } + ], + ) + for bit, exp in zip(cmd.bits, exps, strict=True) + ], + } case _: raise NotImplementedError(op) diff --git a/tests/test_phirgen.py b/tests/test_phirgen.py index bb56158..090c798 100644 --- a/tests/test_phirgen.py +++ b/tests/test_phirgen.py @@ -396,3 +396,38 @@ def test_multi_bit_ops() -> None: {"cop": "=", "returns": ["c1"], "args": [{"cop": "~", "args": ["c1"]}]} ], } + + +def test_irregular_multibit_ops() -> None: + """From https://github.com/CQCL/pytket-phir/pull/162#discussion_r1555807863 .""" + c = Circuit() + areg = c.add_c_register("a", 2) + breg = c.add_c_register("b", 2) + creg = c.add_c_register("c", 2) + c.add_c_and_to_registers(areg, breg, creg) + mbop = c.get_commands()[0].op + c.add_gate(mbop, [areg[0], areg[1], breg[0], breg[1], creg[0], creg[1]]) + + phir = json.loads(pytket_to_phir(c)) + assert phir["ops"][3] == {"//": "AND (*2) a[0], b[0], c[0], a[1], b[1], c[1];"} + assert phir["ops"][4] == { + "cop": "=", + "returns": ["c"], + "args": [{"cop": "&", "args": ["a", "b"]}], + } + assert phir["ops"][5] == {"//": "AND (*2) a[0], a[1], b[0], b[1], c[0], c[1];"} + assert phir["ops"][6] == { + "block": "sequence", + "ops": [ + { + "cop": "=", + "returns": [["b", 0]], + "args": [{"cop": "&", "args": [["a", 0], ["a", 1]]}], + }, + { + "cop": "=", + "returns": [["c", 1]], + "args": [{"cop": "&", "args": [["b", 1], ["c", 0]]}], + }, + ], + }