From 4206cb156d228dc6b0f4be9571183af4fcc9b554 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Wed, 22 Jan 2025 11:17:07 -0800 Subject: [PATCH] Fix classical conditions getting optimized ahead of their measurements (#6872) * Fix classical conditions getting optimized ahead of their measurements * format * Improve test * format --- cirq-core/cirq/sim/simulator.py | 11 ++++++----- cirq-core/cirq/sim/simulator_test.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 66104a43f5e..404658230f7 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -973,9 +973,10 @@ def split_into_matching_protocol_then_general( qubit A will cause later operations on A to be part of the non-matching suffix, but later operations on other qubits will continue to be put into the matching part (as long as those qubits have had no non-matching operation - up to that point). + up to that point). Measurement keys are handled equivalently. """ blocked_qubits: Set[cirq.Qid] = set() + blocked_keys: Set[cirq.MeasurementKey] = set() matching_prefix = circuits.Circuit() general_suffix = circuits.Circuit() for moment in circuit: @@ -983,12 +984,12 @@ def split_into_matching_protocol_then_general( general_part = [] for op in moment: qs = set(op.qubits) - if not predicate(op) or not qs.isdisjoint(blocked_qubits): - blocked_qubits |= qs - - if qs.isdisjoint(blocked_qubits): + keys = protocols.measurement_keys_touched(op) + if predicate(op) and qs.isdisjoint(blocked_qubits) and keys.isdisjoint(blocked_keys): matching_part.append(op) else: + blocked_qubits |= qs + blocked_keys |= keys general_part.append(op) if matching_part: matching_prefix.append(circuits.Moment(matching_part)) diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index f6b4fdb1b61..c3a6d7e2b53 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -399,6 +399,17 @@ def test_sample_repeated_measurement_keys(): assert len(result.records['b'][0]) == 2 +def test_classical_controls_go_to_suffix_if_corresponding_measurement_does(): + subcircuit = cirq.CircuitOperation(cirq.FrozenCircuit()).with_classical_controls('a') + m = cirq.measure(cirq.LineQubit(0), key='a') + circuit = cirq.Circuit(m, subcircuit) + prefix, suffix = cirq.sim.simulator.split_into_matching_protocol_then_general( + circuit, lambda op: op != m # any op but m goes into prefix + ) + assert not prefix + assert suffix == circuit + + def test_simulate_with_invert_mask(): q0, q1, q2, q3, q4 = cirq.LineQid.for_qid_shape((2, 3, 3, 3, 4)) c = cirq.Circuit(