From e97768c6a29e1e9d2d1dcc8af86e219e82b808fd Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Wed, 8 Dec 2021 08:53:07 -0800 Subject: [PATCH] Classical control (#4631) Sits on top of #4627 Creates `ConditionalOperation` class and executes operations conditionally upon the classical bits. Most of this is done in commit https://github.com/quantumlib/Cirq/pull/4631/commits/06883d5d3d767eb08fe77bc0e4a0cf6ae7e0d64c. Reimplements quantum teleportation example based off this class. Parts 8, 9, 10 of https://tinyurl.com/cirq-feedforward. --- cirq-core/cirq/__init__.py | 2 + cirq-core/cirq/circuits/circuit.py | 4 +- cirq-core/cirq/circuits/circuit_test.py | 205 +-------- cirq-core/cirq/json_resolver_cache.py | 1 + cirq-core/cirq/ops/__init__.py | 4 + .../ops/classically_controlled_operation.py | 186 ++++++++ .../classically_controlled_operation_test.py | 418 ++++++++++++++++++ cirq-core/cirq/ops/moment.py | 6 +- cirq-core/cirq/ops/raw_types.py | 51 +++ cirq-core/cirq/protocols/__init__.py | 1 + .../circuit_diagram_info_protocol.py | 3 +- .../cirq/protocols/control_key_protocol.py | 16 + .../ClassicallyControlledOperation.json | 27 ++ .../ClassicallyControlledOperation.repr | 1 + cirq-core/cirq/sim/act_on_args.py | 4 +- .../cirq/sim/clifford/stabilizer_sampler.py | 4 +- .../cirq/sim/density_matrix_simulator.py | 2 +- examples/quantum_teleportation.py | 27 +- 18 files changed, 741 insertions(+), 221 deletions(-) create mode 100644 cirq-core/cirq/ops/classically_controlled_operation.py create mode 100644 cirq-core/cirq/ops/classically_controlled_operation_test.py create mode 100644 cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json create mode 100644 cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 35b84ca21da..0b26d585b33 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -190,6 +190,7 @@ CCZPowGate, CCNOT, CCNotPowGate, + ClassicallyControlledOperation, CNOT, CNotPowGate, ControlledGate, @@ -538,6 +539,7 @@ measurement_key_obj, measurement_key_names, measurement_key_objs, + measurement_keys_touched, mixture, mul, num_qubits, diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 499199dc3cb..baa558dc175 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -2422,9 +2422,7 @@ def _draw_moment_in_diagram( max_x = x0 for op in non_global_ops: qubits = tuple(op.qubits) - cbits = tuple( - (protocols.measurement_key_objs(op) | protocols.control_keys(op)) & label_map.keys() - ) + cbits = tuple(protocols.measurement_keys_touched(op) & label_map.keys()) labels = qubits + cbits indices = [label_map[label] for label in labels] y1 = min(indices) diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 63cda64ee94..58018f0d955 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -86,28 +86,6 @@ def validate_moment(self, moment): moment_and_op_type_validating_device = _MomentAndOpTypeValidatingDeviceType() -class ControlOp(cirq.Operation): - def __init__(self, keys, qubits=None): - self._keys = [cirq.MeasurementKey(k) if isinstance(k, str) else k for k in keys] - self._qubits = qubits or [] - - def with_qubits(self, *new_qids): - pass # coverage: ignore - - @property - def qubits(self): - return self._qubits - - def _control_keys_(self): - return self._keys - - def _circuit_diagram_info_( - self, args: 'cirq.CircuitDiagramInfoArgs' - ) -> 'cirq.CircuitDiagramInfo': - symbols = ['X'] * len(self._qubits) + ['^'] * len(self._keys) - return cirq.CircuitDiagramInfo(symbols) - - def test_alignment(): assert repr(cirq.Alignment.LEFT) == 'cirq.Alignment.LEFT' assert repr(cirq.Alignment.RIGHT) == 'cirq.Alignment.RIGHT' @@ -250,190 +228,19 @@ def test_append_single(): def test_append_control_key(): - q = cirq.LineQubit(0) + q0, q1, q2 = cirq.LineQubit.range(3) c = cirq.Circuit() - c.append(cirq.measure(q, key='a')) - c.append(ControlOp(['a'])) + c.append(cirq.measure(q0, key='a')) + c.append(cirq.X(q1).with_classical_controls('a')) assert len(c) == 2 c = cirq.Circuit() - c.append(cirq.measure(q, key='a')) - c.append(ControlOp(['b'])) - c.append(ControlOp(['b'])) + c.append(cirq.measure(q0, key='a')) + c.append(cirq.X(q1).with_classical_controls('b')) + c.append(cirq.X(q2).with_classical_controls('b')) assert len(c) == 1 -def test_control_key_diagram(): - q0, q1 = cirq.LineQubit.range(2) - c = cirq.Circuit(cirq.measure(q0, key='a'), ControlOp(qubits=[q1], keys=['a'])) - - cirq.testing.assert_has_diagram( - c, - """ -0: ───M─────── - ║ -1: ───╫───X─── - ║ ║ -a: ═══@═══^═══ -""", - use_unicode_characters=True, - ) - - -def test_control_key_diagram_pauli(): - q0, q1 = cirq.LineQubit.range(2) - c = cirq.Circuit( - cirq.measure_single_paulistring(cirq.X(q0), key='a'), ControlOp(qubits=[q1], keys=['a']) - ) - - cirq.testing.assert_has_diagram( - c, - """ -0: ───M(X)─────── - ║ -1: ───╫──────X─── - ║ ║ -a: ═══@══════^═══ -""", - use_unicode_characters=True, - ) - - -def test_control_key_diagram_extra_measurements(): - q0, q1 = cirq.LineQubit.range(2) - c = cirq.Circuit( - cirq.measure(q0, key='a'), cirq.measure(q0, key='b'), ControlOp(qubits=[q1], keys=['a']) - ) - - cirq.testing.assert_has_diagram( - c, - """ -0: ───M───M('b')─── - ║ -1: ───╫───X──────── - ║ ║ -a: ═══@═══^════════ -""", - use_unicode_characters=True, - ) - - -def test_control_key_diagram_extra_controlled_bits(): - q0, q1 = cirq.LineQubit.range(2) - c = cirq.Circuit(cirq.measure(q0, key='a'), ControlOp(qubits=[q0, q1], keys=['a'])) - - cirq.testing.assert_has_diagram( - c, - """ -0: ───M───X─── - ║ ║ -1: ───╫───X─── - ║ ║ -a: ═══@═══^═══ -""", - use_unicode_characters=True, - ) - - -def test_control_key_diagram_extra_control_bits(): - q0, q1 = cirq.LineQubit.range(2) - c = cirq.Circuit( - cirq.measure(q0, key='a'), - cirq.measure(q0, key='b'), - ControlOp(qubits=[q1], keys=['a', 'b']), - ) - - cirq.testing.assert_has_diagram( - c, - """ -0: ───M───M─────── - ║ ║ -1: ───╫───╫───X─── - ║ ║ ║ -a: ═══@═══╬═══^═══ - ║ ║ -b: ═══════@═══^═══ -""", - use_unicode_characters=True, - ) - - -def test_control_key_diagram_multiple_ops_single_moment(): - q0, q1 = cirq.LineQubit.range(2) - c = cirq.Circuit( - cirq.measure(q0, key='a'), - cirq.measure(q1, key='b'), - ControlOp(qubits=[q0], keys=['a']), - ControlOp(qubits=[q1], keys=['b']), - ) - - cirq.testing.assert_has_diagram( - c, - """ - ┌──┐ ┌──┐ -0: ────M──────X───── - ║ ║ -1: ────╫M─────╫X──── - ║║ ║║ -a: ════@╬═════^╬════ - ║ ║ -b: ═════@══════^════ - └──┘ └──┘ -""", - use_unicode_characters=True, - ) - - -def test_control_key_diagram_subcircuit(): - q0, q1 = cirq.LineQubit.range(2) - c = cirq.Circuit( - cirq.CircuitOperation( - cirq.FrozenCircuit(cirq.measure(q0, key='a'), ControlOp(qubits=[q1], keys=['a'])) - ) - ) - - cirq.testing.assert_has_diagram( - c, - """ - [ 0: ───M─────── ] - [ ║ ] -0: ───[ 1: ───╫───X─── ]─── - [ ║ ║ ] - [ a: ═══@═══^═══ ] - │ -1: ───#2─────────────────── -""", - use_unicode_characters=True, - ) - - -def test_control_key_diagram_subcircuit_layered(): - q0, q1 = cirq.LineQubit.range(2) - c = cirq.Circuit( - cirq.measure(q0, key='a'), - cirq.CircuitOperation( - cirq.FrozenCircuit(cirq.measure(q0, key='a'), ControlOp(qubits=[q1], keys=['a'])), - ), - ControlOp(qubits=[q1], keys=['a']), - ) - - cirq.testing.assert_has_diagram( - c, - """ - [ 0: ───M─────── ] - [ ║ ] -0: ───M───[ 1: ───╫───X─── ]─────── - ║ [ ║ ║ ] - ║ [ a: ═══@═══^═══ ] - ║ ║ -1: ───╫───#2───────────────────X─── - ║ ║ ║ -a: ═══@═══╩════════════════════^═══ -""", - use_unicode_characters=True, - ) - - def test_append_multiple(): a = cirq.NamedQubit('a') b = cirq.NamedQubit('b') diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index f18fbc90c91..9c7613de0c3 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -63,6 +63,7 @@ def _parallel_gate_op(gate, qubits): 'CCXPowGate': cirq.CCXPowGate, 'CCZPowGate': cirq.CCZPowGate, 'CNotPowGate': cirq.CNotPowGate, + 'ClassicallyControlledOperation': cirq.ClassicallyControlledOperation, 'ControlledGate': cirq.ControlledGate, 'ControlledOperation': cirq.ControlledOperation, 'CSwapGate': cirq.CSwapGate, diff --git a/cirq-core/cirq/ops/__init__.py b/cirq-core/cirq/ops/__init__.py index 0b4041e5641..30c892b6982 100644 --- a/cirq-core/cirq/ops/__init__.py +++ b/cirq-core/cirq/ops/__init__.py @@ -82,6 +82,10 @@ ParallelGateFamily, ) +from cirq.ops.classically_controlled_operation import ( + ClassicallyControlledOperation, +) + from cirq.ops.controlled_gate import ( ControlledGate, ) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py new file mode 100644 index 00000000000..94114e742f1 --- /dev/null +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -0,0 +1,186 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import ( + AbstractSet, + Any, + Dict, + FrozenSet, + Optional, + Sequence, + TYPE_CHECKING, + Tuple, + Union, +) + +from cirq import protocols, value +from cirq.ops import raw_types + +if TYPE_CHECKING: + import cirq + + +@value.value_equality +class ClassicallyControlledOperation(raw_types.Operation): + """Augments existing operations to be conditionally executed. + + An operation that is classically controlled is executed iff all conditions + evaluate to True. Currently the only condition type is a measurement key. + A measurement key evaluates to True iff any qubit in the corresponding + measurement operation evaluated to a non-zero value. + + This object is typically created via + `operation.with_classical_controls(*conditions)`. + """ + + def __init__( + self, + sub_operation: 'cirq.Operation', + conditions: Sequence[Union[str, 'cirq.MeasurementKey']], + ): + """Initializes a `ClassicallyControlledOperation`. + + Multiple consecutive `ClassicallyControlledOperation` layers are + squashed when possible, so one should not depend on a specific number + of layers. + + Args: + sub_operation: The operation to gate with a classical control + condition. + conditions: A sequence of measurement keys, or strings that can be + parsed into measurement keys. + + Raises: + ValueError: If an unsupported gate is being classically + controlled. + """ + if protocols.measurement_key_objs(sub_operation): + raise ValueError( + f'Cannot conditionally run operations with measurements: {sub_operation}' + ) + keys = tuple(value.MeasurementKey(c) if isinstance(c, str) else c for c in conditions) + if isinstance(sub_operation, ClassicallyControlledOperation): + keys += sub_operation._control_keys + sub_operation = sub_operation._sub_operation + self._control_keys: Tuple['cirq.MeasurementKey', ...] = keys + self._sub_operation: 'cirq.Operation' = sub_operation + + def without_classical_controls(self) -> 'cirq.Operation': + return self._sub_operation.without_classical_controls() + + @property + def qubits(self): + return self._sub_operation.qubits + + def with_qubits(self, *new_qubits): + return self._sub_operation.with_qubits(*new_qubits).with_classical_controls( + *self._control_keys + ) + + def _decompose_(self): + result = protocols.decompose_once(self._sub_operation, NotImplemented) + if result is NotImplemented: + return NotImplemented + + return [ClassicallyControlledOperation(op, self._control_keys) for op in result] + + def _value_equality_values_(self): + return (frozenset(self._control_keys), self._sub_operation) + + def __str__(self) -> str: + keys = ', '.join(map(str, self._control_keys)) + return f'{self._sub_operation}.with_classical_controls({keys})' + + def __repr__(self): + return ( + f'cirq.ClassicallyControlledOperation(' + f'{self._sub_operation!r}, {list(self._control_keys)!r})' + ) + + def _is_parameterized_(self) -> bool: + return protocols.is_parameterized(self._sub_operation) + + def _parameter_names_(self) -> AbstractSet[str]: + return protocols.parameter_names(self._sub_operation) + + def _resolve_parameters_( + self, resolver: 'cirq.ParamResolver', recursive: bool + ) -> 'ClassicallyControlledOperation': + new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive) + return new_sub_op.with_classical_controls(*self._control_keys) + + def _circuit_diagram_info_( + self, args: 'cirq.CircuitDiagramInfoArgs' + ) -> Optional['protocols.CircuitDiagramInfo']: + sub_args = protocols.CircuitDiagramInfoArgs( + known_qubit_count=args.known_qubit_count, + known_qubits=args.known_qubits, + use_unicode_characters=args.use_unicode_characters, + precision=args.precision, + label_map=args.label_map, + ) + sub_info = protocols.circuit_diagram_info(self._sub_operation, sub_args, None) + if sub_info is None: + return NotImplemented # coverage: ignore + + wire_symbols = sub_info.wire_symbols + ('^',) * len(self._control_keys) + exponent_qubit_index = None + if sub_info.exponent_qubit_index is not None: + exponent_qubit_index = sub_info.exponent_qubit_index + len(self._control_keys) + elif sub_info.exponent is not None: + exponent_qubit_index = len(self._control_keys) + return protocols.CircuitDiagramInfo( + wire_symbols=wire_symbols, + exponent=sub_info.exponent, + exponent_qubit_index=exponent_qubit_index, + ) + + def _json_dict_(self) -> Dict[str, Any]: + return { + 'cirq_type': self.__class__.__name__, + 'conditions': self._control_keys, + 'sub_operation': self._sub_operation, + } + + def _act_on_(self, args: 'cirq.ActOnArgs') -> bool: + def not_zero(measurement): + return any(i != 0 for i in measurement) + + measurements = [ + args.log_of_measurement_results.get(str(key), str(key)) for key in self._control_keys + ] + missing = [m for m in measurements if isinstance(m, str)] + if missing: + raise ValueError(f'Measurement keys {missing} missing when performing {self}') + if all(not_zero(measurement) for measurement in measurements): + protocols.act_on(self._sub_operation, args) + return True + + def _with_measurement_key_mapping_( + self, key_map: Dict[str, str] + ) -> 'ClassicallyControlledOperation': + keys = [protocols.with_measurement_key_mapping(k, key_map) for k in self._control_keys] + return self._sub_operation.with_classical_controls(*keys) + + def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation': + keys = [protocols.with_key_path_prefix(k, path) for k in self._control_keys] + return self._sub_operation.with_classical_controls(*keys) + + def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: + return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation)) + + def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: + args.validate_version('2.0') + keys = [f'm_{key}!=0' for key in self._control_keys] + all_keys = " && ".join(keys) + return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args)) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py new file mode 100644 index 00000000000..ff46dccb5fb --- /dev/null +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -0,0 +1,418 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +import pytest +import sympy + +import cirq + +ALL_SIMULATORS = ( + cirq.Simulator(), + cirq.DensityMatrixSimulator(), + cirq.CliffordSimulator(), +) + + +def test_diagram(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit(cirq.measure(q0, key='a'), cirq.X(q1).with_classical_controls('a')) + + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M─────── + ║ +1: ───╫───X─── + ║ ║ +a: ═══@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_diagram_pauli(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure_single_paulistring(cirq.X(q0), key='a'), + cirq.X(q1).with_classical_controls('a'), + ) + + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M(X)─────── + ║ +1: ───╫──────X─── + ║ ║ +a: ═══@══════^═══ +""", + use_unicode_characters=True, + ) + + +def test_diagram_extra_measurements(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q0, key='b'), + cirq.X(q1).with_classical_controls('a'), + ) + + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M───M('b')─── + ║ +1: ───╫───X──────── + ║ ║ +a: ═══@═══^════════ +""", + use_unicode_characters=True, + ) + + +def test_diagram_extra_controlled_bits(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.CX(q0, q1).with_classical_controls('a'), + ) + + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M───@─── + ║ ║ +1: ───╫───X─── + ║ ║ +a: ═══@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_diagram_extra_control_bits(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q0, key='b'), + cirq.X(q1).with_classical_controls('a', 'b'), + ) + + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M───M─────── + ║ ║ +1: ───╫───╫───X─── + ║ ║ ║ +a: ═══@═══╬═══^═══ + ║ ║ +b: ═══════@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_diagram_multiple_ops_single_moment(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q1, key='b'), + cirq.X(q0).with_classical_controls('a'), + cirq.X(q1).with_classical_controls('b'), + ) + + cirq.testing.assert_has_diagram( + circuit, + """ + ┌──┐ ┌──┐ +0: ────M──────X───── + ║ ║ +1: ────╫M─────╫X──── + ║║ ║║ +a: ════@╬═════^╬════ + ║ ║ +b: ═════@══════^════ + └──┘ └──┘ +""", + use_unicode_characters=True, + ) + + +def test_diagram_subcircuit(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls('a'), + ) + ) + ) + + cirq.testing.assert_has_diagram( + circuit, + """ + [ 0: ───M─────── ] + [ ║ ] +0: ───[ 1: ───╫───X─── ]─── + [ ║ ║ ] + [ a: ═══@═══^═══ ] + │ +1: ───#2─────────────────── +""", + use_unicode_characters=True, + ) + + +def test_diagram_subcircuit_layered(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls('a'), + ), + ), + cirq.X(q1).with_classical_controls('a'), + ) + + cirq.testing.assert_has_diagram( + circuit, + """ + [ 0: ───M─────── ] + [ ║ ] +0: ───M───[ 1: ───╫───X─── ]─────── + ║ [ ║ ║ ] + ║ [ a: ═══@═══^═══ ] + ║ ║ +1: ───╫───#2───────────────────X─── + ║ ║ ║ +a: ═══@═══╩════════════════════^═══ +""", + use_unicode_characters=True, + ) + + +def test_qasm(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit(cirq.measure(q0, key='a'), cirq.X(q1).with_classical_controls('a')) + qasm = cirq.qasm(circuit) + assert ( + qasm + == """// Generated from Cirq v0.14.0.dev + +OPENQASM 2.0; +include "qelib1.inc"; + + +// Qubits: [0, 1] +qreg q[2]; +creg m_a[1]; + + +measure q[0] -> m_a[0]; +if (m_a!=0) x q[1]; +""" + ) + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_key_unset(sim): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + result = sim.run(circuit) + assert result.measurements['a'] == 0 + assert result.measurements['b'] == 0 + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_key_set(sim): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.X(q0), + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + result = sim.run(circuit) + assert result.measurements['a'] == 1 + assert result.measurements['b'] == 1 + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_subcircuit_key_unset(sim): + q0, q1 = cirq.LineQubit.range(2) + inner = cirq.Circuit( + cirq.measure(q0, key='c'), + cirq.X(q1).with_classical_controls('c'), + cirq.measure(q1, key='b'), + ) + circuit = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, measurement_key_map={'c': 'a'}) + ) + result = sim.run(circuit) + assert result.measurements['0:a'] == 0 + assert result.measurements['0:b'] == 0 + assert result.measurements['1:a'] == 0 + assert result.measurements['1:b'] == 0 + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_subcircuit_key_set(sim): + q0, q1 = cirq.LineQubit.range(2) + inner = cirq.Circuit( + cirq.X(q0), + cirq.measure(q0, key='c'), + cirq.X(q1).with_classical_controls('c'), + cirq.measure(q1, key='b'), + ) + circuit = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=4, measurement_key_map={'c': 'a'}) + ) + result = sim.run(circuit) + assert result.measurements['0:a'] == 1 + assert result.measurements['0:b'] == 1 + assert result.measurements['1:a'] == 0 + assert result.measurements['1:b'] == 1 + assert result.measurements['2:a'] == 1 + assert result.measurements['2:b'] == 0 + assert result.measurements['3:a'] == 0 + assert result.measurements['3:b'] == 0 + + +def test_key_unset_in_subcircuit_outer_scope(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + ) + # TODO (daxfohl): This will not need an InsertStrategy after scope PR. + circuit.append( + cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q1).with_classical_controls('a'))), + strategy=cirq.InsertStrategy.NEW, + ) + circuit.append(cirq.measure(q1, key='b')) + result = cirq.Simulator().run(circuit) + assert result.measurements['a'] == 0 + assert result.measurements['b'] == 0 + + +def test_key_set_in_subcircuit_outer_scope(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.X(q0), + cirq.measure(q0, key='a'), + ) + # TODO (daxfohl): This will not need an InsertStrategy after scope PR. + circuit.append( + cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q1).with_classical_controls('a'))), + strategy=cirq.InsertStrategy.NEW, + ) + circuit.append(cirq.measure(q1, key='b')) + result = cirq.Simulator().run(circuit) + assert result.measurements['a'] == 1 + assert result.measurements['b'] == 1 + + +def test_condition_flattening(): + q0 = cirq.LineQubit(0) + op = cirq.X(q0).with_classical_controls('a').with_classical_controls('b') + assert set(map(str, op._control_keys)) == {'a', 'b'} + assert isinstance(op._sub_operation, cirq.GateOperation) + + +def test_condition_stacking(): + q0 = cirq.LineQubit(0) + op = cirq.X(q0).with_classical_controls('a').with_tags('t').with_classical_controls('b') + assert set(map(str, cirq.control_keys(op))) == {'a', 'b'} + assert not op.tags + + +def test_condition_removal(): + q0 = cirq.LineQubit(0) + op = ( + cirq.X(q0) + .with_tags('t1') + .with_classical_controls('a') + .with_tags('t2') + .with_classical_controls('b') + ) + op = op.without_classical_controls() + assert not cirq.control_keys(op) + assert set(map(str, op.tags)) == {'t1'} + + +def test_qubit_mapping(): + q0, q1 = cirq.LineQubit.range(2) + op = cirq.X(q0).with_classical_controls('a') + assert op.with_qubits(q1).qubits == (q1,) + + +def test_parameterizable(): + s = sympy.Symbol('s') + q0 = cirq.LineQubit(0) + op = cirq.X(q0).with_classical_controls('a') + opa = cirq.XPowGate(exponent=s).on(q0).with_classical_controls('a') + assert cirq.is_parameterized(opa) + assert not cirq.is_parameterized(op) + assert cirq.resolve_parameters(opa, cirq.ParamResolver({'s': 1})) == op + + +def test_decompose(): + q0 = cirq.LineQubit(0) + op = cirq.H(q0).with_classical_controls('a') + assert cirq.decompose(op) == [ + (cirq.Y(q0) ** 0.5).with_classical_controls('a'), + cirq.XPowGate(exponent=1.0, global_shift=-0.25).on(q0).with_classical_controls('a'), + ] + + +def test_str(): + q0 = cirq.LineQubit(0) + op = cirq.X(q0).with_classical_controls('a') + assert str(op) == 'X(0).with_classical_controls(a)' + + +def test_repr(): + q0 = cirq.LineQubit(0) + op = cirq.X(q0).with_classical_controls('a') + assert repr(op) == ( + "cirq.ClassicallyControlledOperation(" + "cirq.X(cirq.LineQubit(0)), [cirq.MeasurementKey(name='a')]" + ")" + ) + + +def test_no_measurement_gates(): + q0 = cirq.LineQubit(0) + with pytest.raises(ValueError, match='with measurements'): + _ = cirq.measure(q0).with_classical_controls('a') + + +def test_unmeasured_condition(): + q0 = cirq.LineQubit(0) + bad_circuit = cirq.Circuit(cirq.X(q0).with_classical_controls('a')) + with pytest.raises( + ValueError, + match=re.escape( + "Measurement keys ['a'] missing when performing X(0).with_classical_controls(a)" + ), + ): + _ = cirq.Simulator().simulate(bad_circuit) diff --git a/cirq-core/cirq/ops/moment.py b/cirq-core/cirq/ops/moment.py index 7c1f132ca2a..69f1e22a5fb 100644 --- a/cirq-core/cirq/ops/moment.py +++ b/cirq-core/cirq/ops/moment.py @@ -225,7 +225,7 @@ def without_operations_touching(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Mom def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): return Moment( protocols.with_measurement_key_mapping(op, key_map) - if protocols.is_measurement(op) + if protocols.measurement_keys_touched(op) else op for op in self.operations ) @@ -248,7 +248,9 @@ def _with_key_path_(self, path: Tuple[str, ...]): def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): return Moment( - protocols.with_key_path_prefix(op, prefix) if protocols.is_measurement(op) else op + protocols.with_key_path_prefix(op, prefix) + if protocols.measurement_keys_touched(op) + else op for op in self.operations ) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 717dfcfbe1b..8212ba23f3e 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -590,6 +590,50 @@ def _commutes_( return np.allclose(m12, m21, atol=atol) + def with_classical_controls( + self, *conditions: Union[str, 'cirq.MeasurementKey'] + ) -> 'cirq.ClassicallyControlledOperation': + """Returns a classically controlled version of this operation. + + An operation that is classically controlled is executed iff all + conditions evaluate to True. Currently the only condition type is a + measurement key. A measurement key evaluates to True iff any qubit in + the corresponding measurement operation evaluated to a non-zero value. + + The classical control will hide any tags on the existing operation, + since tags are considered a local attribute. + + Args: + conditions: A list of measurement keys, or strings that can be + parsed into measurement keys. + + Returns: + A `ClassicallyControlledOperation` wrapping the operation. + """ + from cirq.ops.classically_controlled_operation import ClassicallyControlledOperation + + return ClassicallyControlledOperation(self, conditions) + + def without_classical_controls(self) -> 'cirq.Operation': + """Removes all classical controls from the operation. + + This function removes all classical controls gating the operation. It + acts recursively, so that all classical control wrappers are always + removed from the current operation. + + If there are no classical controls on the operation, it will return + `self`. + + Since tags are considered local, this will also remove any tags from + the operation (unless there are no classical controls on it). If a + `TaggedOperation` is under all the classical control layers, that + `TaggedOperation` will be returned from this function. + + Returns: + The operation with all classical controls removed. + """ + return self + @value.value_equality class TaggedOperation(Operation): @@ -777,6 +821,13 @@ def _equal_up_to_global_phase_( ) -> Union[NotImplementedType, bool]: return protocols.equal_up_to_global_phase(self.sub_operation, other, atol=atol) + def without_classical_controls(self) -> 'cirq.Operation': + new_sub_operation = self.sub_operation.without_classical_controls() + return self if new_sub_operation is self.sub_operation else new_sub_operation + + def _control_keys_(self) -> AbstractSet[value.MeasurementKey]: + return protocols.control_keys(self.sub_operation) + @value.value_equality class _InverseCompositeGate(Gate): diff --git a/cirq-core/cirq/protocols/__init__.py b/cirq-core/cirq/protocols/__init__.py index a56f0f7965a..3595ca5107b 100644 --- a/cirq-core/cirq/protocols/__init__.py +++ b/cirq-core/cirq/protocols/__init__.py @@ -51,6 +51,7 @@ ) from cirq.protocols.control_key_protocol import ( control_keys, + measurement_keys_touched, SupportsControlKey, ) from cirq.protocols.circuit_diagram_info_protocol import ( diff --git a/cirq-core/cirq/protocols/circuit_diagram_info_protocol.py b/cirq-core/cirq/protocols/circuit_diagram_info_protocol.py index 21437ca959d..37bff8a299d 100644 --- a/cirq-core/cirq/protocols/circuit_diagram_info_protocol.py +++ b/cirq-core/cirq/protocols/circuit_diagram_info_protocol.py @@ -331,8 +331,7 @@ def _op_info_with_fallback( info = protocols.circuit_diagram_info(op, args, None) rows: List[LabelEntity] = list(op.qubits) if args.label_map is not None: - rows += protocols.measurement_key_objs(op) & args.label_map.keys() - rows += protocols.control_keys(op) & args.label_map.keys() + rows += protocols.measurement_keys_touched(op) & args.label_map.keys() if info is not None: if max(1, len(rows)) != len(info.wire_symbols): raise ValueError(f'Wanted diagram info from {op!r} for {rows!r}) but got {info!r}') diff --git a/cirq-core/cirq/protocols/control_key_protocol.py b/cirq-core/cirq/protocols/control_key_protocol.py index 9faacc859db..ef39362eed7 100644 --- a/cirq-core/cirq/protocols/control_key_protocol.py +++ b/cirq-core/cirq/protocols/control_key_protocol.py @@ -18,6 +18,7 @@ from typing_extensions import Protocol from cirq._doc import doc_private +from cirq.protocols import measurement_key_protocol if TYPE_CHECKING: import cirq @@ -58,3 +59,18 @@ def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']: return set(result) return set() + + +def measurement_keys_touched(val: Any) -> AbstractSet['cirq.MeasurementKey']: + """Returns all the measurement keys used by the value. + + This would be the case if the value is or contains a measurement gate, or + if the value is or contains a conditional operation. + + Args: + val: The object that may interact with measurements. + + Returns: + The measurement keys used by the value.. + """ + return measurement_key_protocol.measurement_key_objs(val) | control_keys(val) diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json new file mode 100644 index 00000000000..a22c2720095 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json @@ -0,0 +1,27 @@ +{ + "cirq_type": "ClassicallyControlledOperation", + "conditions": [ + { + "cirq_type": "MeasurementKey", + "name": "a", + "path": [] + }, + { + "cirq_type": "MeasurementKey", + "name": "b", + "path": [] + } + ], + "sub_operation": { + "cirq_type": "SingleQubitPauliStringGateOperation", + "pauli": { + "cirq_type": "_PauliY", + "exponent": 1, + "global_shift": 0.0 + }, + "qubit": { + "cirq_type": "NamedQubit", + "name": "target" + } + } +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr new file mode 100644 index 00000000000..bfc25256cff --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr @@ -0,0 +1 @@ +cirq.ClassicallyControlledOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.MeasurementKey('a'), cirq.MeasurementKey('b')]) \ No newline at end of file diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 870eaf3b9c5..d28cda196f8 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -46,7 +46,7 @@ def __init__( self, prng: np.random.RandomState = None, qubits: Sequence['cirq.Qid'] = None, - log_of_measurement_results: Dict[str, Any] = None, + log_of_measurement_results: Dict[str, List[int]] = None, ): """Inits ActOnArgs. @@ -181,7 +181,7 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ functionality, if supported.""" @property - def log_of_measurement_results(self) -> Dict[str, Any]: + def log_of_measurement_results(self) -> Dict[str, List[int]]: return self._log_of_measurement_results @property diff --git a/cirq-core/cirq/sim/clifford/stabilizer_sampler.py b/cirq-core/cirq/sim/clifford/stabilizer_sampler.py index bb0e32da8f5..e2ef5ce514e 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_sampler.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_sampler.py @@ -53,7 +53,7 @@ def run_sweep( def _run(self, circuit: circuits.AbstractCircuit, repetitions: int) -> Dict[str, np.ndarray]: - measurements: Dict[str, List[int]] = { + measurements: Dict[str, List[np.ndarray]] = { key: [] for key in protocols.measurement_key_names(circuit) } qubits = circuit.all_qubits() @@ -69,6 +69,6 @@ def _run(self, circuit: circuits.AbstractCircuit, repetitions: int) -> Dict[str, protocols.act_on(op, state) for k, v in state.log_of_measurement_results.items(): - measurements[k].append(v) + measurements[k].append(np.array(v, dtype=np.uint8)) return {k: np.array(v) for k, v in measurements.items()} diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index ffcd69af42e..bad0553f0df 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -211,7 +211,7 @@ def _create_partial_act_on_args( ) def _can_be_in_run_prefix(self, val: Any): - return not protocols.is_measurement(val) + return not protocols.measurement_keys_touched(val) def _create_step_result( self, diff --git a/examples/quantum_teleportation.py b/examples/quantum_teleportation.py index 5c737d09b21..30c8d2e3fb2 100644 --- a/examples/quantum_teleportation.py +++ b/examples/quantum_teleportation.py @@ -20,18 +20,23 @@ === EXAMPLE OUTPUT === Circuit: -0: -----------X^0.25---Y^0.125---@---H---M-------@--- - | | | -1: ---H---@----------------------X-------M---@---|--- - | | | -2: -------X----------------------------------X---@--- + ┌──┐ +0: ───────X^0.559───Y^0.647───@───H────M───────── + │ ║ +1: ───────H─────────@─────────X───M────╫───────── + │ ║ ║ +2: ─────────────────X─────────────╫────╫X────Z─── + ║ ║║ ║ +alice: ═══════════════════════════@════╬^════╬═══ + ║ ║ +msg: ══════════════════════════════════@═════^═══ + └──┘ Bloch Sphere of Message After Random X and Y Gates: -x: 0.2706 y: -0.7071 z: 0.6533 +x: -0.1647 y: -0.9829 z: 0.082 Bloch Sphere of Qubit 2 at Final State: -x: 0.2706 y: -0.7071 z: 0.6533 - +x: -0.1647 y: -0.9829 z: 0.082 """ import random @@ -49,10 +54,12 @@ def make_quantum_teleportation_circuit(ranX, ranY): circuit.append([cirq.X(msg) ** ranX, cirq.Y(msg) ** ranY]) # Bell measurement of the Message and Alice's entangled qubit. circuit.append([cirq.CNOT(msg, alice), cirq.H(msg)]) - circuit.append(cirq.measure(msg, alice)) + circuit.append(cirq.measure(msg, key='msg')) + circuit.append(cirq.measure(alice, key='alice')) # Uses the two classical bits from the Bell measurement to recover the # original quantum Message on Bob's entangled qubit. - circuit.append([cirq.CNOT(alice, bob), cirq.CZ(msg, bob)]) + circuit.append(cirq.X(bob).with_classical_controls('alice')) + circuit.append(cirq.Z(bob).with_classical_controls('msg')) return circuit