Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for ExplicitPredicate, ExplicitModifier, MultiBitOp #162

Merged
merged 12 commits into from
Apr 10, 2024
Merged
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- id: debug-statements

- repo: https://github.com/crate-ci/typos
rev: v1.20.4
rev: v1.20.7
hooks:
- id: typos

Expand Down
223 changes: 166 additions & 57 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import json
import logging
from copy import deepcopy
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, TypeAlias

Expand Down Expand Up @@ -218,13 +219,155 @@ def convert_gate(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
return qop


def cop_from_op_name(op_name: str) -> str:
"""Get PHIR classical op name from pytket op name."""
match op_name:
case "AND":
cop = "&"
case "OR":
cop = "|"
case "XOR":
cop = "^"
case "NOT":
cop = "~"
case name:
raise NotImplementedError(name)
return cop


def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict | None:
"""Return PHIR dict for a pytket ClassicalEvalOp."""
# Exclude conditional bits from args
args = cmd.args[cmd.op.width :] if isinstance(cmd.op, tk.Conditional) else cmd.args
out: JsonDict | None = None
match op:
case tk.CopyBitsOp():
if len(cmd.bits) != len(args) // 2:
msg = "LHS and RHS lengths mismatch for CopyBits"
raise TypeError(msg)
out = assign_cop(
[arg_to_bit(bit) for bit in cmd.bits],
[arg_to_bit(args[i]) for i in range(len(args) // 2)],
)
case tk.SetBitsOp():
if len(cmd.bits) != len(op.values):
logger.error("LHS and RHS lengths mismatch for classical assignment")
raise ValueError
out = assign_cop(
[arg_to_bit(bit) for bit in cmd.bits], list(map(int, op.values))
)
case tk.RangePredicateOp(): # where the condition is a range
cond: JsonDict
match op.lower, op.upper:
case l, u if l == u:
cond = {
"cop": "==",
"args": [args[0].reg_name, u],
}
case l, u if u == UINTMAX:
qartik marked this conversation as resolved.
Show resolved Hide resolved
cond = {
"cop": ">=",
"args": [args[0].reg_name, l],
}
case 0, u:
cond = {
"cop": "<=",
"args": [args[0].reg_name, u],
}
out = {
"block": "if",
"condition": cond,
"true_branch": [assign_cop([arg_to_bit(cmd.bits[0])], [1])],
}
case tk.MultiBitOp():
if len(args) % len(cmd.bits) != 0:
msg = "Input bit- and output bit lengths mismatch."
raise TypeError(msg)

cop = cop_from_op_name(op.basic_op.get_name())
is_explicit = op.basic_op.type == tk.OpType.ExplicitPredicate

# determine number of register operands involved in the operation
operand_count = len(args) // len(cmd.bits) - is_explicit

iters = [iter(args)] * (operand_count + is_explicit)
iter2 = deepcopy(iters)

# Columns of expressions, e.g.,
# AND (*2) a[0], b[0], c[0]
# , a[1], b[1], c[1]
# would be [(a[0], a[1]), (b[0], b[1]), (c[0], c[1])]
# and AND (*2) a[0], a[1], b[0]
# , b[1], c[0], c[1]
# would be [(a[0], b[1]), (a[1], c[0]), (b[0], c[1])]
cols = zip(*zip(*iters, strict=True), strict=True)

if all(
all(col[0].reg_name == bit.reg_name for bit in col) for col in cols
): # expression can be applied register-wise
out = assign_cop(
[cmd.bits[0].reg_name],
[
{
"cop": cop,
"args": [arg.reg_name for arg in args[:operand_count]],
}
],
)
else: # apply a sequence of bit-wise ops
exps = zip(*iter2, strict=True)
out = {
"block": "sequence",
"ops": [
assign_cop(
[arg_to_bit(bit)],
[
{
"cop": cop,
"args": [
arg_to_bit(arg) for arg in exp[:operand_count]
],
}
],
)
for bit, exp in zip(cmd.bits, exps, strict=True)
],
}
case _:
raise NotImplementedError(op)

return out


def multi_bit_condition(args: "list[UnitID]", value: int) -> JsonDict:
"""Construct bitwise condition."""
return {
"cop": "&",
"args": [
{"cop": "==", "args": [arg_to_bit(arg), bval]}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to be careful of endianness here. The conditional value in pytket is little-endian, e.g. a value of 2 on the bits [a[0], a[1]] means a[0] == 0 && a[1] == 1. Does PHIR use the same convention?

Copy link
Member

@qartik qartik Apr 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qciaran would you be able to confirm. I forgot to mention you at #162 (comment), I also had my doubts.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree though that the default choice should be little-endian here. While Ciaran confirms, I am pushing the change to update the behavior.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cqc-alec although I wonder if the bits are explicitly specified like:

        condition_bits=[Bit("b", 1), Bit("b", 2)],
        condition_value=2,

ie, [b[1], b[2]] == 2 (pytket translation of above in the phir comment), is the intention from pytket still that they will be reordered and considered little-endian? I'd imagine the user to write condition_bits=[Bit("b", 2), Bit("b", 1)] if they wanted b[2]=1 and b[1]=0.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that is the intention: the order in which the bits are given corresponds to the little-endian binary expansion of the value. In your example, the condition is b[1] == 0 && b[2] == 1.

Copy link
Collaborator

@qciaran qciaran Apr 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, in PHIR a[0]==0 && a[1]==1 is the integer value 2. Or: {"cop": "=", "args": [2], "returns": ["b"]} results in b[0]==0 & b[1]==1
Sounds like that should be mentioned in the spec, womp womp.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed CQCL/phir#77, this PR seems ready to be merged.

for (arg, bval) in zip(
args[::-1], map(int, f"{value:0{len(args)}b}"), strict=True
)
],
}


def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
"""Return PHIR dict given a tket op and its arguments."""
if op.is_gate():
return convert_gate(op, cmd)

out: JsonDict | None = None
match op: # non-quantum op
case tk.Conditional():
out = {
"block": "if",
"condition": {"cop": "==", "args": [arg_to_bit(cmd.args[0]), op.value]}
if op.width == 1
else multi_bit_condition(cmd.args[: op.width], op.value),
"true_branch": [convert_subcmd(op.op, cmd)],
}

case tk.BarrierOp():
if op.data:
# See https://github.com/CQCL/tket/blob/0ec603986821d994caa3a0fb9c4640e5bc6c0a24/pytket/pytket/qasm/qasm.py#L419-L459
Expand All @@ -246,45 +389,6 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
"args": [arg_to_bit(qbit) for qbit in cmd.qubits],
}

case tk.Conditional(): # where the condition is equality check
out = {
"block": "if",
"condition": {
"cop": "==",
"args": [
arg_to_bit(cmd.args[0])
if op.width == 1
else cmd.args[0].reg_name,
op.value,
],
},
"true_branch": [convert_subcmd(op.op, cmd)],
}

case tk.RangePredicateOp(): # where the condition is a range
cond: JsonDict
match op.lower, op.upper:
case l, u if l == u:
cond = {
"cop": "==",
"args": [cmd.args[0].reg_name, u],
}
case l, u if u == UINTMAX:
cond = {
"cop": ">=",
"args": [cmd.args[0].reg_name, l],
}
case 0, u:
cond = {
"cop": "<=",
"args": [cmd.args[0].reg_name, u],
}
out = {
"block": "if",
"condition": cond,
"true_branch": [assign_cop([arg_to_bit(cmd.bits[0])], [1])],
}

case tk.ClassicalExpBox():
exp = op.get_exp()
match exp:
Expand All @@ -295,29 +399,34 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
rhs = [classical_op(exp)]
out = assign_cop([cmd.bits[0].reg_name], rhs)

case tk.SetBitsOp():
if len(cmd.bits) != len(op.values):
logger.error("LHS and RHS lengths mismatch for classical assignment")
raise ValueError
out = assign_cop(
[arg_to_bit(bit) for bit in cmd.bits], list(map(int, op.values))
)

case tk.CopyBitsOp():
if len(cmd.bits) != len(cmd.args) // 2:
logger.warning("LHS and RHS lengths mismatch for CopyBits")
out = assign_cop(
[arg_to_bit(bit) for bit in cmd.bits],
[arg_to_bit(cmd.args[i]) for i in range(len(cmd.args) // 2)],
)
case tk.ClassicalEvalOp():
return convert_classicalevalop(op, cmd)

case tk.WASMOp():
return create_wasm_op(cmd, op)

case _:
# TODO(kartik): NYI
# https://github.com/CQCL/pytket-phir/issues/25
raise NotImplementedError
# Exclude conditional bits from args
args = (
cmd.args[cmd.op.width :]
if isinstance(cmd.op, tk.Conditional)
else cmd.args
)
match op.type:
case tk.OpType.ExplicitPredicate | tk.OpType.ExplicitModifier:
# exclude output bit when not modifying in place
args = args[:-1] if op.type == tk.OpType.ExplicitPredicate else args
out = assign_cop(
[arg_to_bit(cmd.bits[0])],
[
{
"cop": cop_from_op_name(op.get_name()),
"args": [arg_to_bit(arg) for arg in args],
}
],
)
case _:
raise NotImplementedError(op.type)

return out

Expand Down
2 changes: 2 additions & 0 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
OpType.ClassicalExpBox, # some classical operations are rolled up into a box
OpType.RangePredicate,
OpType.ExplicitPredicate,
OpType.ExplicitModifier,
OpType.MultiBit,
OpType.CopyBits,
OpType.WASM,
]
Expand Down
23 changes: 0 additions & 23 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@

# mypy: disable-error-code="misc"

import json
import logging

import pytest

from pytket.circuit import Bit, Circuit
from pytket.phir.api import pytket_to_phir, qasm_to_phir
from pytket.phir.qtm_machine import QtmMachine

Expand Down Expand Up @@ -50,27 +48,6 @@ def test_pytket_to_phir_h1_1_all(self, test_file: QasmFile) -> None:

assert pytket_to_phir(circuit, QtmMachine.H1)

def test_pytket_classical_only(self) -> None:
c = Circuit(1)
a = c.add_c_register("a", 2)
b = c.add_c_register("b", 3)

c.add_c_copyreg(a, b)
c.add_c_copybits([Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)])

phir = json.loads(pytket_to_phir(c))

assert phir["ops"][3] == {
"cop": "=",
"returns": [["b", 0], ["b", 1]],
"args": [["a", 0], ["a", 1]],
}
assert phir["ops"][5] == {
"cop": "=",
"returns": [["a", 0], ["b", 0]],
"args": [["b", 2], ["a", 1]],
}
qartik marked this conversation as resolved.
Show resolved Hide resolved

def test_qasm_to_phir(self) -> None:
"""Test the qasm string entrypoint works."""
qasm = """
Expand Down
Loading