diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index f45004b9347..ac57b1b07dd 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -483,15 +483,18 @@ canonicalize_half_turns, chosen_angle_to_canonical_half_turns, chosen_angle_to_half_turns, + Condition, Duration, DURATION_LIKE, GenericMetaImplementAnyOneOf, + KeyCondition, LinearDict, MEASUREMENT_KEY_SEPARATOR, MeasurementKey, PeriodicValue, RANDOM_STATE_OR_SEED_LIKE, state_vector_to_probabilities, + SympyCondition, Timestamp, TParamKey, TParamVal, diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 6dc089e0f6b..facfdd8e45a 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -27,6 +27,7 @@ import numpy as np import pandas as pd import sympy +import sympy.printing.repr def proper_repr(value: Any) -> str: diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 8213f007868..6a017320dad 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -94,6 +94,7 @@ def _parallel_gate_op(gate, qubits): 'ISwapPowGate': cirq.ISwapPowGate, 'IdentityGate': cirq.IdentityGate, 'InitObsSetting': cirq.work.InitObsSetting, + 'KeyCondition': cirq.KeyCondition, 'KrausChannel': cirq.KrausChannel, 'LinearDict': cirq.LinearDict, 'LineQubit': cirq.LineQubit, @@ -150,6 +151,7 @@ def _parallel_gate_op(gate, qubits): 'StatePreparationChannel': cirq.StatePreparationChannel, 'SwapPowGate': cirq.SwapPowGate, 'SymmetricalQidPair': cirq.SymmetricalQidPair, + 'SympyCondition': cirq.SympyCondition, 'TaggedOperation': cirq.TaggedOperation, 'TiltedSquareLattice': cirq.TiltedSquareLattice, 'TrialResult': cirq.Result, # keep support for Cirq < 0.11. diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index fe3093a3d4e..74a4c3dbb53 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -16,6 +16,7 @@ Any, Dict, FrozenSet, + List, Optional, Sequence, TYPE_CHECKING, @@ -23,6 +24,8 @@ Union, ) +import sympy + from cirq import protocols, value from cirq.ops import raw_types @@ -46,7 +49,7 @@ class ClassicallyControlledOperation(raw_types.Operation): def __init__( self, sub_operation: 'cirq.Operation', - conditions: Sequence[Union[str, 'cirq.MeasurementKey']], + conditions: Sequence[Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Basic]], ): """Initializes a `ClassicallyControlledOperation`. @@ -68,13 +71,26 @@ def __init__( raise ValueError( f'Cannot conditionally run operations with measurements: {sub_operation}' ) - keys = tuple(value.MeasurementKey(c) if isinstance(c, str) else c for c in conditions) + conditions = tuple(conditions) if isinstance(sub_operation, ClassicallyControlledOperation): - keys += sub_operation._control_keys + conditions += sub_operation._conditions sub_operation = sub_operation._sub_operation - self._control_keys: Tuple['cirq.MeasurementKey', ...] = keys + conds: List['cirq.Condition'] = [] + for c in conditions: + if isinstance(c, str): + c = value.MeasurementKey.parse_serialized(c) + if isinstance(c, value.MeasurementKey): + c = value.KeyCondition(c) + if isinstance(c, sympy.Basic): + c = value.SympyCondition(c) + conds.append(c) + self._conditions: Tuple['cirq.Condition', ...] = tuple(conds) self._sub_operation: 'cirq.Operation' = sub_operation + @property + def classical_controls(self) -> FrozenSet['cirq.Condition']: + return frozenset(self._conditions).union(self._sub_operation.classical_controls) + def without_classical_controls(self) -> 'cirq.Operation': return self._sub_operation.without_classical_controls() @@ -84,7 +100,7 @@ def qubits(self): def with_qubits(self, *new_qubits): return self._sub_operation.with_qubits(*new_qubits).with_classical_controls( - *self._control_keys + *self._conditions ) def _decompose_(self): @@ -92,19 +108,19 @@ def _decompose_(self): if result is NotImplemented: return NotImplemented - return [ClassicallyControlledOperation(op, self._control_keys) for op in result] + return [ClassicallyControlledOperation(op, self._conditions) for op in result] def _value_equality_values_(self): - return (frozenset(self._control_keys), self._sub_operation) + return (frozenset(self._conditions), self._sub_operation) def __str__(self) -> str: - keys = ', '.join(map(str, self._control_keys)) + keys = ', '.join(map(str, self._conditions)) return f'{self._sub_operation}.with_classical_controls({keys})' def __repr__(self): return ( f'cirq.ClassicallyControlledOperation(' - f'{self._sub_operation!r}, {list(self._control_keys)!r})' + f'{self._sub_operation!r}, {list(self._conditions)!r})' ) def _is_parameterized_(self) -> bool: @@ -117,7 +133,7 @@ def _resolve_parameters_( self, resolver: 'cirq.ParamResolver', recursive: bool ) -> 'ClassicallyControlledOperation': new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive) - return new_sub_op.with_classical_controls(*self._control_keys) + return new_sub_op.with_classical_controls(*self._conditions) def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' @@ -133,12 +149,20 @@ def _circuit_diagram_info_( if sub_info is None: return NotImplemented # coverage: ignore - wire_symbols = sub_info.wire_symbols + ('^',) * len(self._control_keys) + control_count = len({k for c in self._conditions for k in c.keys}) + wire_symbols = sub_info.wire_symbols + ('^',) * control_count + if any(not isinstance(c, value.KeyCondition) for c in self._conditions): + wire_symbols = ( + wire_symbols[0] + + '(conditions=[' + + ', '.join(str(c) for c in self._conditions) + + '])', + ) + wire_symbols[1:] exponent_qubit_index = None if sub_info.exponent_qubit_index is not None: - exponent_qubit_index = sub_info.exponent_qubit_index + len(self._control_keys) + exponent_qubit_index = sub_info.exponent_qubit_index + control_count elif sub_info.exponent is not None: - exponent_qubit_index = len(self._control_keys) + exponent_qubit_index = control_count return protocols.CircuitDiagramInfo( wire_symbols=wire_symbols, exponent=sub_info.exponent, @@ -148,58 +172,45 @@ def _circuit_diagram_info_( def _json_dict_(self) -> Dict[str, Any]: return { 'cirq_type': self.__class__.__name__, - 'conditions': self._control_keys, + 'conditions': self._conditions, 'sub_operation': self._sub_operation, } def _act_on_(self, args: 'cirq.ActOnArgs') -> bool: - def not_zero(measurement): - return any(i != 0 for i in measurement) - - measurements = [ - args.log_of_measurement_results.get(str(key), str(key)) for key in self._control_keys - ] - missing = [m for m in measurements if isinstance(m, str)] - if missing: - raise ValueError(f'Measurement keys {missing} missing when performing {self}') - if all(not_zero(measurement) for measurement in measurements): + if all(c.resolve(args.log_of_measurement_results) for c in self._conditions): protocols.act_on(self._sub_operation, args) return True def _with_measurement_key_mapping_( self, key_map: Dict[str, str] ) -> 'ClassicallyControlledOperation': + conditions = [protocols.with_measurement_key_mapping(c, key_map) for c in self._conditions] sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map) sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation - return sub_operation.with_classical_controls( - *[protocols.with_measurement_key_mapping(k, key_map) for k in self._control_keys] - ) + return sub_operation.with_classical_controls(*conditions) - def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation': - keys = [protocols.with_key_path_prefix(k, path) for k in self._control_keys] - return self._sub_operation.with_classical_controls(*keys) + def _with_key_path_prefix_(self, prefix: Tuple[str, ...]) -> 'ClassicallyControlledOperation': + conditions = [protocols.with_key_path_prefix(c, prefix) for c in self._conditions] + sub_operation = protocols.with_key_path_prefix(self._sub_operation, prefix) + sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation + return sub_operation.with_classical_controls(*conditions) def _with_rescoped_keys_( self, path: Tuple[str, ...], bindable_keys: FrozenSet['cirq.MeasurementKey'], ) -> 'ClassicallyControlledOperation': - def map_key(key: 'cirq.MeasurementKey') -> 'cirq.MeasurementKey': - for i in range(len(path) + 1): - back_path = path[: len(path) - i] - new_key = key.with_key_path_prefix(*back_path) - if new_key in bindable_keys: - return new_key - return key - + conds = [protocols.with_rescoped_keys(c, path, bindable_keys) for c in self._conditions] sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys) - return sub_operation.with_classical_controls(*[map_key(k) for k in self._control_keys]) + return sub_operation.with_classical_controls(*conds) def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: - return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation)) + local_keys: FrozenSet['cirq.MeasurementKey'] = frozenset( + k for condition in self._conditions for k in condition.keys + ) + return local_keys.union(protocols.control_keys(self._sub_operation)) def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: args.validate_version('2.0') - keys = [f'm_{key}!=0' for key in self._control_keys] - all_keys = " && ".join(keys) + all_keys = " && ".join(c.qasm for c in self._conditions) return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args)) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index a9896dbed0d..1baf7612f37 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re import pytest import sympy +from sympy.parsing import sympy_parser import cirq @@ -331,10 +331,17 @@ def test_key_set_in_subcircuit_outer_scope(): assert result.measurements['b'] == 1 +def test_condition_types(): + q0 = cirq.LineQubit(0) + sympy_cond = sympy_parser.parse_expr('a >= 2') + op = cirq.X(q0).with_classical_controls(cirq.MeasurementKey('a'), 'b', 'a > b', sympy_cond) + assert set(map(str, op.classical_controls)) == {'a', 'b', 'a > b', 'a >= 2'} + + def test_condition_flattening(): q0 = cirq.LineQubit(0) op = cirq.X(q0).with_classical_controls('a').with_classical_controls('b') - assert set(map(str, op._control_keys)) == {'a', 'b'} + assert set(map(str, op.classical_controls)) == {'a', 'b'} assert isinstance(op._sub_operation, cirq.GateOperation) @@ -342,6 +349,7 @@ def test_condition_stacking(): q0 = cirq.LineQubit(0) op = cirq.X(q0).with_classical_controls('a').with_tags('t').with_classical_controls('b') assert set(map(str, cirq.control_keys(op))) == {'a', 'b'} + assert set(map(str, op.classical_controls)) == {'a', 'b'} assert not op.tags @@ -356,6 +364,7 @@ def test_condition_removal(): ) op = op.without_classical_controls() assert not cirq.control_keys(op) + assert not op.classical_controls assert set(map(str, op.tags)) == {'t1'} @@ -604,7 +613,7 @@ def test_repr(): op = cirq.X(q0).with_classical_controls('a') assert repr(op) == ( "cirq.ClassicallyControlledOperation(" - "cirq.X(cirq.LineQubit(0)), [cirq.MeasurementKey(name='a')]" + "cirq.X(cirq.LineQubit(0)), [cirq.KeyCondition(cirq.MeasurementKey(name='a'))]" ")" ) @@ -619,10 +628,7 @@ def test_unmeasured_condition(): q0 = cirq.LineQubit(0) bad_circuit = cirq.Circuit(cirq.X(q0).with_classical_controls('a')) with pytest.raises( - ValueError, - match=re.escape( - "Measurement keys ['a'] missing when performing X(0).with_classical_controls(a)" - ), + ValueError, match='Measurement key a missing when testing classical control' ): _ = cirq.Simulator().simulate(bad_circuit) @@ -669,3 +675,141 @@ def test_layered_circuit_operations_with_controls_in_between(): """, use_unicode_characters=True, ) + + +def test_sympy(): + q0, q1, q2, q3, q_result = cirq.LineQubit.range(5) + for i in range(4): + for j in range(4): + # Put first two qubits into a state representing bitstring(i), next two qubits into a + # state representing bitstring(j) and measure those into m_i and m_j respectively. Then + # add a conditional X(q_result) based on m_i > m_j and measure that. + bitstring_i = cirq.big_endian_int_to_bits(i, bit_count=2) + bitstring_j = cirq.big_endian_int_to_bits(j, bit_count=2) + circuit = cirq.Circuit( + cirq.X(q0) ** bitstring_i[0], + cirq.X(q1) ** bitstring_i[1], + cirq.X(q2) ** bitstring_j[0], + cirq.X(q3) ** bitstring_j[1], + cirq.measure(q0, q1, key='m_i'), + cirq.measure(q2, q3, key='m_j'), + cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m_j > m_i')), + cirq.measure(q_result, key='m_result'), + ) + + # m_result should now be set iff j > i. + result = cirq.Simulator().run(circuit) + assert result.measurements['m_result'][0][0] == (j > i) + + +def test_sympy_path_prefix(): + q = cirq.LineQubit(0) + op = cirq.X(q).with_classical_controls(sympy.Symbol('b')) + prefixed = cirq.with_key_path_prefix(op, ('0',)) + assert cirq.control_keys(prefixed) == {'0:b'} + + +def test_sympy_scope(): + q = cirq.LineQubit(0) + a, b, c, d = sympy.symbols('a b c d') + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls(a & b).with_classical_controls(c | d), + ) + middle = cirq.Circuit( + cirq.measure(q, key='b'), + cirq.measure(q, key=cirq.MeasurementKey('c', ('0',))), + cirq.CircuitOperation(inner.freeze(), repetitions=2), + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_controls = [str(k) for op in circuit.all_operations() for k in cirq.control_keys(op)] + assert set(internal_controls) == {'0:0:a', '0:1:a', '1:0:a', '1:1:a', '0:b', '1:b', 'c', 'd'} + assert cirq.control_keys(outer_subcircuit) == {'c', 'd'} + assert cirq.control_keys(circuit) == {'c', 'd'} + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X(conditions=[c | d, a & b])─── ] ] + [ [ ║ ║ ] ] + [ [ a: ═══@═══^══════════════════════════════ ] ] + [ [ ║ ] ] + [ 0: ───M───M('0:c')───[ b: ═══════^══════════════════════════════ ]──────────── ] + [ ║ [ ║ ] ] + [ ║ [ c: ═══════^══════════════════════════════ ] ] +0: ───[ ║ [ ║ ] ]──────────── + [ ║ [ d: ═══════^══════════════════════════════ ](loops=2) ] + [ ║ ║ ] + [ b: ═══@══════════════╬════════════════════════════════════════════════════════ ] + [ ║ ] + [ c: ══════════════════╬════════════════════════════════════════════════════════ ] + [ ║ ] + [ d: ══════════════════╩════════════════════════════════════════════════════════ ](loops=2) + ║ +c: ═══╬═════════════════════════════════════════════════════════════════════════════════════════════ + ║ +d: ═══╩═════════════════════════════════════════════════════════════════════════════════════════════ +""", + use_unicode_characters=True, + ) + + # pylint: disable=line-too-long + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───────M───M('0:0:c')───M───X(conditions=[c | d, 0:0:a & 0:b])───M───X(conditions=[c | d, 0:1:a & 0:b])───M───M('1:0:c')───M───X(conditions=[c | d, 1:0:a & 1:b])───M───X(conditions=[c | d, 1:1:a & 1:b])─── + ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ +0:0:a: ═══╬════════════════@═══^════════════════════════════════════╬═══╬════════════════════════════════════╬════════════════╬═══╬════════════════════════════════════╬═══╬════════════════════════════════════ + ║ ║ ║ ║ ║ ║ ║ ║ ║ +0:1:a: ═══╬════════════════════╬════════════════════════════════════@═══^════════════════════════════════════╬════════════════╬═══╬════════════════════════════════════╬═══╬════════════════════════════════════ + ║ ║ ║ ║ ║ ║ ║ ║ +0:b: ═════@════════════════════^════════════════════════════════════════^════════════════════════════════════╬════════════════╬═══╬════════════════════════════════════╬═══╬════════════════════════════════════ + ║ ║ ║ ║ ║ ║ ║ +1:0:a: ════════════════════════╬════════════════════════════════════════╬════════════════════════════════════╬════════════════@═══^════════════════════════════════════╬═══╬════════════════════════════════════ + ║ ║ ║ ║ ║ ║ +1:1:a: ════════════════════════╬════════════════════════════════════════╬════════════════════════════════════╬════════════════════╬════════════════════════════════════@═══^════════════════════════════════════ + ║ ║ ║ ║ ║ +1:b: ══════════════════════════╬════════════════════════════════════════╬════════════════════════════════════@════════════════════^════════════════════════════════════════^════════════════════════════════════ + ║ ║ ║ ║ +c: ════════════════════════════^════════════════════════════════════════^═════════════════════════════════════════════════════════^════════════════════════════════════════^════════════════════════════════════ + ║ ║ ║ ║ +d: ════════════════════════════^════════════════════════════════════════^═════════════════════════════════════════════════════════^════════════════════════════════════════^════════════════════════════════════ +""", + use_unicode_characters=True, + ) + # pylint: enable=line-too-long + + +def test_sympy_scope_simulation(): + q0, q1, q2, q3, q_ignored, q_result = cirq.LineQubit.range(6) + condition = sympy_parser.parse_expr('a & b | c & d') + # We set up condition (a & b | c & d) plus an ignored measurement key, and run through the + # combinations of possible values of those (by doing X(q_i)**bits[i] on each), then verify + # that the final measurement into m_result is True iff that condition was met. + for i in range(32): + bits = cirq.big_endian_int_to_bits(i, bit_count=5) + inner = cirq.Circuit( + cirq.X(q0) ** bits[0], + cirq.measure(q0, key='a'), + cirq.X(q_result).with_classical_controls(condition), + cirq.measure(q_result, key='m_result'), + ) + middle = cirq.Circuit( + cirq.X(q1) ** bits[1], + cirq.measure(q1, key='b'), + cirq.X(q_ignored) ** bits[4], + cirq.measure(q_ignored, key=cirq.MeasurementKey('c', ('0',))), + cirq.CircuitOperation(inner.freeze(), repetition_ids=['0']), + ) + circuit = cirq.Circuit( + cirq.X(q2) ** bits[2], + cirq.measure(q2, key='c'), + cirq.X(q3) ** bits[3], + cirq.measure(q3, key='d'), + cirq.CircuitOperation(middle.freeze(), repetition_ids=['0']), + ) + result = cirq.CliffordSimulator().run(circuit) + assert result.measurements['0:0:m_result'][0][0] == ( + bits[0] and bits[1] or bits[2] and bits[3] # bits[4] irrelevant + ) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index c88ba65eeb8..ca665008655 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -22,6 +22,7 @@ Callable, Collection, Dict, + FrozenSet, Hashable, Iterable, List, @@ -34,6 +35,7 @@ ) import numpy as np +import sympy from cirq import protocols, value from cirq._import import LazyLoader @@ -590,8 +592,13 @@ def _commutes_( return np.allclose(m12, m21, atol=atol) + @property + def classical_controls(self) -> FrozenSet['cirq.Condition']: + """The classical controls gating this operation.""" + return frozenset() + def with_classical_controls( - self, *conditions: Union[str, 'cirq.MeasurementKey'] + self, *conditions: Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Expr] ) -> 'cirq.ClassicallyControlledOperation': """Returns a classically controlled version of this operation. @@ -604,8 +611,9 @@ def with_classical_controls( since tags are considered a local attribute. Args: - conditions: A list of measurement keys, or strings that can be - parsed into measurement keys. + conditions: A list of measurement keys, strings that can be parsed + into measurement keys, or sympy expressions where the free + symbols are measurement key strings. Returns: A `ClassicallyControlledOperation` wrapping the operation. @@ -821,6 +829,10 @@ def _equal_up_to_global_phase_( ) -> Union[NotImplementedType, bool]: return protocols.equal_up_to_global_phase(self.sub_operation, other, atol=atol) + @property + def classical_controls(self) -> FrozenSet['cirq.Condition']: + return self.sub_operation.classical_controls + def without_classical_controls(self) -> 'cirq.Operation': new_sub_operation = self.sub_operation.without_classical_controls() return self if new_sub_operation is self.sub_operation else new_sub_operation diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json index a22c2720095..8fbae9b27c7 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json @@ -2,14 +2,20 @@ "cirq_type": "ClassicallyControlledOperation", "conditions": [ { - "cirq_type": "MeasurementKey", - "name": "a", - "path": [] + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "a", + "path": [] + } }, { - "cirq_type": "MeasurementKey", - "name": "b", - "path": [] + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "b", + "path": [] + } } ], "sub_operation": { diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr index bbc3a1dc22b..423551a8501 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr @@ -1 +1 @@ -cirq.ClassicallyControlledOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.MeasurementKey('a'), cirq.MeasurementKey('b')]) +cirq.ClassicallyControlledOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.KeyCondition(key=cirq.MeasurementKey('a')), cirq.KeyCondition(key=cirq.MeasurementKey('b'))]) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/KeyCondition.json b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json new file mode 100644 index 00000000000..f5b81ba63dc --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json @@ -0,0 +1,8 @@ +{ + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "a", + "path": [] + } +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/KeyCondition.repr b/cirq-core/cirq/protocols/json_test_data/KeyCondition.repr new file mode 100644 index 00000000000..fb9fa3232ec --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KeyCondition.repr @@ -0,0 +1 @@ +cirq.KeyCondition(key=cirq.MeasurementKey('a')) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.json b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json new file mode 100644 index 00000000000..1dc17ec7710 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json @@ -0,0 +1,17 @@ +{ + "cirq_type": "SympyCondition", + "expr": + { + "cirq_type": "sympy.GreaterThan", + "args": [ + { + "cirq_type": "sympy.Symbol", + "name": "a" + }, + { + "cirq_type": "sympy.Symbol", + "name": "b" + } + ] + } +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr new file mode 100644 index 00000000000..6c961a2a1f6 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr @@ -0,0 +1 @@ +cirq.SympyCondition(sympy.GreaterThan(sympy.Symbol('a'), sympy.Symbol('b'))) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 639df1aa180..f94ca105148 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -20,6 +20,9 @@ from cirq import value from cirq._doc import doc_private +if TYPE_CHECKING: + import cirq + if TYPE_CHECKING: import cirq diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index 390db1e4a11..da6cfc2b058 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -25,6 +25,12 @@ chosen_angle_to_half_turns, ) +from cirq.value.condition import ( + Condition, + KeyCondition, + SympyCondition, +) + from cirq.value.digits import ( big_endian_bits_to_int, big_endian_digits_to_int, diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py new file mode 100644 index 00000000000..ef432b7506f --- /dev/null +++ b/cirq-core/cirq/value/condition.py @@ -0,0 +1,166 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import dataclasses +from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, FrozenSet + +import sympy + +from cirq._compat import proper_repr +from cirq.protocols import json_serialization, measurement_key_protocol as mkp +from cirq.value import digits, measurement_key + +if TYPE_CHECKING: + import cirq + + +class Condition(abc.ABC): + """A classical control condition that can gate an operation.""" + + @property + @abc.abstractmethod + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + """Gets the control keys.""" + + @abc.abstractmethod + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): + """Replaces the control keys.""" + + @abc.abstractmethod + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + """Resolves the condition based on the measurements.""" + + @property + @abc.abstractmethod + def qasm(self): + """Returns the qasm of this condition.""" + + def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.Condition': + condition = self + for k in self.keys: + condition = condition.replace_key(k, mkp.with_measurement_key_mapping(k, key_map)) + return condition + + def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'cirq.Condition': + condition = self + for k in self.keys: + condition = condition.replace_key(k, mkp.with_key_path_prefix(k, path)) + return condition + + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['cirq.MeasurementKey'], + ) -> 'cirq.Condition': + condition = self + for key in self.keys: + for i in range(len(path) + 1): + back_path = path[: len(path) - i] + new_key = key.with_key_path_prefix(*back_path) + if new_key in bindable_keys: + condition = condition.replace_key(key, new_key) + break + return condition + + +@dataclasses.dataclass(frozen=True) +class KeyCondition(Condition): + """A classical control condition based on a single measurement key. + + This condition resolves to True iff the measurement key is non-zero at the + time of resolution. + """ + + key: 'cirq.MeasurementKey' + + @property + def keys(self): + return (self.key,) + + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): + return KeyCondition(replacement) if self.key == current else self + + def __str__(self): + return str(self.key) + + def __repr__(self): + return f'cirq.KeyCondition({self.key!r})' + + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + key = str(self.key) + if key not in measurements: + raise ValueError(f'Measurement key {key} missing when testing classical control') + return any(measurements[key]) + + def _json_dict_(self): + return json_serialization.dataclass_json_dict(self) + + @classmethod + def _from_json_dict_(cls, key, **kwargs): + return cls(key=key) + + @property + def qasm(self): + return f'm_{self.key}!=0' + + +@dataclasses.dataclass(frozen=True) +class SympyCondition(Condition): + """A classical control condition based on a sympy expression. + + This condition resolves to True iff the sympy expression resolves to a + truthy value (i.e. `bool(x) == True`) when the measurement keys are + substituted in as the free variables. + """ + + expr: sympy.Basic + + @property + def keys(self): + return tuple( + measurement_key.MeasurementKey.parse_serialized(symbol.name) + for symbol in self.expr.free_symbols + ) + + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): + return SympyCondition(self.expr.subs({str(current): sympy.Symbol(str(replacement))})) + + def __str__(self): + return str(self.expr) + + def __repr__(self): + return f'cirq.SympyCondition({proper_repr(self.expr)})' + + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + missing = [str(k) for k in self.keys if str(k) not in measurements] + if missing: + raise ValueError(f'Measurement keys {missing} missing when testing classical control') + + def value(k): + return digits.big_endian_bits_to_int(measurements[str(k)]) + + replacements = {str(k): value(k) for k in self.keys} + return bool(self.expr.subs(replacements)) + + def _json_dict_(self): + return json_serialization.dataclass_json_dict(self) + + @classmethod + def _from_json_dict_(cls, expr, **kwargs): + return cls(expr=expr) + + @property + def qasm(self): + raise NotImplementedError() diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py new file mode 100644 index 00000000000..fd80033a29a --- /dev/null +++ b/cirq-core/cirq/value/condition_test.py @@ -0,0 +1,105 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import pytest +import sympy + +import cirq + +key_a = cirq.MeasurementKey.parse_serialized('0:a') +key_b = cirq.MeasurementKey.parse_serialized('0:b') +key_c = cirq.MeasurementKey.parse_serialized('0:c') +init_key_condition = cirq.KeyCondition(key_a) +init_sympy_condition = cirq.SympyCondition(sympy.Symbol('0:a') >= 1) + + +def test_key_condition_with_keys(): + c = init_key_condition.replace_key(key_a, key_b) + assert c.key is key_b + c = init_key_condition.replace_key(key_b, key_c) + assert c.key is key_a + + +def test_key_condition_str(): + assert str(init_key_condition) == '0:a' + + +def test_key_condition_repr(): + cirq.testing.assert_equivalent_repr(init_key_condition) + + +def test_key_condition_resolve(): + assert init_key_condition.resolve({'0:a': [1]}) + assert init_key_condition.resolve({'0:a': [2]}) + assert init_key_condition.resolve({'0:a': [0, 1]}) + assert init_key_condition.resolve({'0:a': [1, 0]}) + assert not init_key_condition.resolve({'0:a': [0]}) + assert not init_key_condition.resolve({'0:a': [0, 0]}) + assert not init_key_condition.resolve({'0:a': []}) + assert not init_key_condition.resolve({'0:a': [0], 'b': [1]}) + with pytest.raises( + ValueError, match='Measurement key 0:a missing when testing classical control' + ): + _ = init_key_condition.resolve({}) + with pytest.raises( + ValueError, match='Measurement key 0:a missing when testing classical control' + ): + _ = init_key_condition.resolve({'0:b': [1]}) + + +def test_key_condition_qasm(): + assert cirq.KeyCondition(cirq.MeasurementKey('a')).qasm == 'm_a!=0' + + +def test_sympy_condition_with_keys(): + c = init_sympy_condition.replace_key(key_a, key_b) + assert c.keys == (key_b,) + c = init_sympy_condition.replace_key(key_b, key_c) + assert c.keys == (key_a,) + + +def test_sympy_condition_str(): + assert str(init_sympy_condition) == '0:a >= 1' + + +def test_sympy_condition_repr(): + cirq.testing.assert_equivalent_repr(init_sympy_condition) + + +def test_sympy_condition_resolve(): + assert init_sympy_condition.resolve({'0:a': [1]}) + assert init_sympy_condition.resolve({'0:a': [2]}) + assert init_sympy_condition.resolve({'0:a': [0, 1]}) + assert init_sympy_condition.resolve({'0:a': [1, 0]}) + assert not init_sympy_condition.resolve({'0:a': [0]}) + assert not init_sympy_condition.resolve({'0:a': [0, 0]}) + assert not init_sympy_condition.resolve({'0:a': []}) + assert not init_sympy_condition.resolve({'0:a': [0], 'b': [1]}) + with pytest.raises( + ValueError, + match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), + ): + _ = init_sympy_condition.resolve({}) + with pytest.raises( + ValueError, + match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), + ): + _ = init_sympy_condition.resolve({'0:b': [1]}) + + +def test_sympy_condition_qasm(): + with pytest.raises(NotImplementedError): + _ = init_sympy_condition.qasm