From 2641a8a194bf93dd867376e85b5fabc08f3e7dbd Mon Sep 17 00:00:00 2001 From: Asa Kosto Date: Fri, 19 Jan 2024 10:22:19 -0700 Subject: [PATCH] sharder now checks for WAW and WAR instead of just one, fixed typo in cli --- pytket/phir/cli.py | 2 +- pytket/phir/sharding/sharder.py | 2 +- tests/data/qasm/classical_ordering.qasm | 11 +++ tests/test_sharder.py | 95 ++++++++++++++++++++++++- tests/test_utils.py | 1 + 5 files changed, 107 insertions(+), 4 deletions(-) create mode 100644 tests/data/qasm/classical_ordering.qasm diff --git a/pytket/phir/cli.py b/pytket/phir/cli.py index b7704d2..1ef0a04 100644 --- a/pytket/phir/cli.py +++ b/pytket/phir/cli.py @@ -59,7 +59,7 @@ def main() -> None: for file in args.qasm_files: print(f"Processing {file}") # noqa: T201 c = circuit_from_qasm(file) - tket_opt_level = int(args.tk) + tket_opt_level = int(args.tket_opt_level) rc = rebase_to_qtm_machine(c, args.machine, tket_opt_level) qasm = circuit_to_qasm_str(rc, header="hqslib1") circ = circuit_from_qasm_str(qasm) diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 5271542..2a0fbde 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -194,7 +194,7 @@ def _resolve_shard_dependencies( "...adding shard dep %s -> WAW", self._bit_written_by[bit_written] ) depends_upon.add(self._bit_written_by[bit_written]) - elif bit_written in self._bit_read_by: + if bit_written in self._bit_read_by: logger.debug( "...adding shard dep %s -> WAR", self._bit_read_by[bit_written] ) diff --git a/tests/data/qasm/classical_ordering.qasm b/tests/data/qasm/classical_ordering.qasm new file mode 100644 index 0000000..73e4d93 --- /dev/null +++ b/tests/data/qasm/classical_ordering.qasm @@ -0,0 +1,11 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[1]; +creg a[4]; +creg b[4]; +creg c[4]; +a = 3; +b = a; +c = (a - b); +a = (a << 1); diff --git a/tests/test_sharder.py b/tests/test_sharder.py index e5fdabb..af9d130 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -143,7 +143,7 @@ def test_simple_conditional(self) -> None: assert shards[3].qubits_used == {circuit.qubits[0]} assert shards[3].bits_written == {circuit.bits[0]} assert shards[3].bits_read == {circuit.bits[0]} - assert shards[3].depends_upon == {shards[0].ID, shards[1].ID} + assert shards[3].depends_upon == {shards[0].ID, shards[1].ID, shards[2].ID} assert len(shards[3].sub_commands.items()) == 1 s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items())) assert s2_qubit == circuit.qubits[0] @@ -274,7 +274,7 @@ def test_classical_hazards(self) -> None: assert shards[3].qubits_used == set() assert shards[3].bits_written == {circuit.bits[0]} assert shards[3].bits_read == {circuit.bits[0]} - assert shards[3].depends_upon == {shards[0].ID} + assert shards[3].depends_upon == {shards[0].ID, shards[2].ID} # shard 4: [] if(c[2]==1) c[0]=1; assert shards[4].primary_command.op.type == OpType.Conditional @@ -307,3 +307,94 @@ def test_with_big_gate(self) -> None: assert len(shards[1].sub_commands) == 0 assert shards[1].qubits_used == {circuit.qubits[3]} assert shards[1].bits_written == {circuit.bits[0]} + + def test_classical_ordering_breaking_circuit(self) -> None: + circuit = get_qasm_as_circuit(QasmFile.classical_ordering) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 4 + + # shard 0 + assert shards[0].primary_command.op.type == OpType.SetBits + assert len(shards[0].sub_commands.items()) == 0 + assert shards[0].qubits_used == set() + assert shards[0].bits_written == { + circuit.bits[0], + circuit.bits[1], + circuit.bits[2], + circuit.bits[3], + } + assert shards[0].bits_read == { + circuit.bits[0], + circuit.bits[1], + circuit.bits[2], + circuit.bits[3], + } + assert shards[0].depends_upon == set() + + # shard 1 + assert shards[1].primary_command.op.type == OpType.CopyBits + assert len(shards[1].sub_commands.items()) == 0 + assert shards[1].qubits_used == set() + assert shards[1].bits_written == { + circuit.bits[4], + circuit.bits[5], + circuit.bits[6], + circuit.bits[7], + } + assert shards[1].bits_read == { + circuit.bits[0], + circuit.bits[1], + circuit.bits[2], + circuit.bits[3], + circuit.bits[4], + circuit.bits[5], + circuit.bits[6], + circuit.bits[7], + } + assert shards[1].depends_upon == {shards[0].ID} + + # shard 2 + assert shards[2].primary_command.op.type == OpType.ClassicalExpBox + assert len(shards[2].sub_commands.items()) == 0 + assert shards[2].qubits_used == set() + assert shards[2].bits_written == { + circuit.bits[8], + circuit.bits[9], + circuit.bits[10], + circuit.bits[11], + } + assert shards[2].bits_read == { + circuit.bits[0], + circuit.bits[1], + circuit.bits[2], + circuit.bits[3], + circuit.bits[4], + circuit.bits[5], + circuit.bits[6], + circuit.bits[7], + circuit.bits[8], + circuit.bits[9], + circuit.bits[10], + circuit.bits[11], + } + assert shards[2].depends_upon == {shards[0].ID, shards[1].ID} + + # shard 2 + assert shards[3].primary_command.op.type == OpType.ClassicalExpBox + assert len(shards[3].sub_commands.items()) == 0 + assert shards[3].qubits_used == set() + assert shards[3].bits_written == { + circuit.bits[0], + circuit.bits[1], + circuit.bits[2], + circuit.bits[3], + } + assert shards[3].bits_read == { + circuit.bits[0], + circuit.bits[1], + circuit.bits[2], + circuit.bits[3], + } + assert shards[3].depends_upon == {shards[0].ID, shards[2].ID} diff --git a/tests/test_utils.py b/tests/test_utils.py index f9c7b53..55fd646 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -42,6 +42,7 @@ class QasmFile(Enum): tk2_same_angle = auto() tk2_diff_angles = auto() rxrz = auto() + classical_ordering = auto() def get_qasm_as_circuit(qasm_file: QasmFile) -> "Circuit":