Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-qubit measurements in deferred measurement transformer #5787

Merged
merged 8 commits into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

from cirq import ops, protocols, value
Expand Down Expand Up @@ -81,9 +82,7 @@ def defer_measurements(
A circuit with equivalent logic, but all measurements at the end of the
circuit.
Raises:
ValueError: If sympy-based classical conditions are used, or if
conditions based on multi-qubit measurements exist. (The latter of
these is planned to be implemented soon).
ValueError: If sympy-based classical conditions are used.
NotImplementedError: When attempting to defer a measurement with a
confusion map. (https://github.com/quantumlib/Cirq/issues/5482)
"""
Expand Down Expand Up @@ -111,23 +110,22 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
elif protocols.is_measurement(op):
return [defer(op, None) for op in protocols.decompose_once(op)]
elif op.classical_controls:
controls = []
new_op = op.without_classical_controls()
for c in op.classical_controls:
if isinstance(c, value.KeyCondition):
if c.key not in measurement_qubits:
raise ValueError(f'Deferred measurement for key={c.key} not found.')
qubits = measurement_qubits[c.key]
if len(qubits) != 1:
# TODO: Multi-qubit conditions require
# https://github.com/quantumlib/Cirq/issues/4512
# Remember to update docstring above once this works.
raise ValueError('Only single qubit conditions are allowed.')
controls.extend(qubits)
qs = measurement_qubits[c.key]
if len(qs) == 1:
control_values: Any = range(1, qs[0].dimension)
else:
all_values = itertools.product(*[range(q.dimension) for q in qs])
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
control_values = ops.SumOfProducts(anything_but_all_zeros)
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
new_op = new_op.controlled_by(*qs, control_values=control_values)
else:
raise ValueError('Only KeyConditions are allowed.')
return op.without_classical_controls().controlled_by(
*controls, control_values=[tuple(range(1, q.dimension)) for q in controls]
)
return new_op
return op

circuit = transformer_primitives.map_operations_and_unroll(
Expand Down
32 changes: 25 additions & 7 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,31 @@ def test_multi_qubit_measurements():
)


def test_multi_qubit_control():
q0, q1, q2 = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
cirq.measure(q0, q1, key='a'),
cirq.X(q2).with_classical_controls('a'),
cirq.measure(q2, key='b'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma0 = _MeasurementQid('a', q0)
q_ma1 = _MeasurementQid('a', q1)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma0),
cirq.CX(q1, q_ma1),
cirq.X(q2).controlled_by(
q_ma0, q_ma1, control_values=cirq.SumOfProducts(((0, 1), (1, 0), (1, 1)))
),
cirq.measure(q_ma0, q_ma1, key='a'),
cirq.measure(q2, key='b'),
),
)


def test_diagram():
q0, q1, q2, q3 = cirq.LineQubit.range(4)
circuit = cirq.Circuit(
Expand Down Expand Up @@ -270,13 +295,6 @@ def test_repr(qid: _MeasurementQid):
test_repr(_MeasurementQid('0:1:a', cirq.LineQid(9, 4)))


def test_multi_qubit_control():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.measure(q0, q1, key='a'), cirq.X(q1).with_classical_controls('a'))
with pytest.raises(ValueError, match='Only single qubit conditions are allowed'):
_ = cirq.defer_measurements(circuit)


def test_sympy_control():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down