Skip to content

Commit

Permalink
Fix classical conditions getting optimized ahead of their measurements (
Browse files Browse the repository at this point in the history
#6872)

* Fix classical conditions getting optimized ahead of their measurements

* format

* Improve test

* format
  • Loading branch information
daxfohl authored Jan 22, 2025
1 parent 870860f commit 4206cb1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
11 changes: 6 additions & 5 deletions cirq-core/cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,22 +973,23 @@ def split_into_matching_protocol_then_general(
qubit A will cause later operations on A to be part of the non-matching
suffix, but later operations on other qubits will continue to be put into
the matching part (as long as those qubits have had no non-matching operation
up to that point).
up to that point). Measurement keys are handled equivalently.
"""
blocked_qubits: Set[cirq.Qid] = set()
blocked_keys: Set[cirq.MeasurementKey] = set()
matching_prefix = circuits.Circuit()
general_suffix = circuits.Circuit()
for moment in circuit:
matching_part = []
general_part = []
for op in moment:
qs = set(op.qubits)
if not predicate(op) or not qs.isdisjoint(blocked_qubits):
blocked_qubits |= qs

if qs.isdisjoint(blocked_qubits):
keys = protocols.measurement_keys_touched(op)
if predicate(op) and qs.isdisjoint(blocked_qubits) and keys.isdisjoint(blocked_keys):
matching_part.append(op)
else:
blocked_qubits |= qs
blocked_keys |= keys
general_part.append(op)
if matching_part:
matching_prefix.append(circuits.Moment(matching_part))
Expand Down
11 changes: 11 additions & 0 deletions cirq-core/cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,17 @@ def test_sample_repeated_measurement_keys():
assert len(result.records['b'][0]) == 2


def test_classical_controls_go_to_suffix_if_corresponding_measurement_does():
subcircuit = cirq.CircuitOperation(cirq.FrozenCircuit()).with_classical_controls('a')
m = cirq.measure(cirq.LineQubit(0), key='a')
circuit = cirq.Circuit(m, subcircuit)
prefix, suffix = cirq.sim.simulator.split_into_matching_protocol_then_general(
circuit, lambda op: op != m # any op but m goes into prefix
)
assert not prefix
assert suffix == circuit


def test_simulate_with_invert_mask():
q0, q1, q2, q3, q4 = cirq.LineQid.for_qid_shape((2, 3, 3, 3, 4))
c = cirq.Circuit(
Expand Down

0 comments on commit 4206cb1

Please sign in to comment.