From b76950ed84c1768606d27d5dfad2867998b7a2b5 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 16:47:49 -0800 Subject: [PATCH 01/89] Allow sympy expressions as classical controls --- .../ops/classically_controlled_operation.py | 117 +++++++++++++----- .../classically_controlled_operation_test.py | 2 +- cirq-core/cirq/ops/raw_types.py | 41 +++++- 3 files changed, 126 insertions(+), 34 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 94114e742f1..6a2c4b88e5c 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -16,13 +16,15 @@ Any, Dict, FrozenSet, + List, Optional, - Sequence, TYPE_CHECKING, Tuple, Union, ) +import sympy + from cirq import protocols, value from cirq.ops import raw_types @@ -46,7 +48,7 @@ class ClassicallyControlledOperation(raw_types.Operation): def __init__( self, sub_operation: 'cirq.Operation', - conditions: Sequence[Union[str, 'cirq.MeasurementKey']], + conditions: Tuple[Union[str, 'cirq.MeasurementKey', raw_types.Condition], ...], ): """Initializes a `ClassicallyControlledOperation`. @@ -68,13 +70,26 @@ def __init__( raise ValueError( f'Cannot conditionally run operations with measurements: {sub_operation}' ) - keys = tuple(value.MeasurementKey(c) if isinstance(c, str) else c for c in conditions) if isinstance(sub_operation, ClassicallyControlledOperation): - keys += sub_operation._control_keys + conditions += sub_operation._conditions sub_operation = sub_operation._sub_operation - self._control_keys: Tuple['cirq.MeasurementKey', ...] = keys + conds: List[raw_types.Condition] = [] + for c in conditions: + if isinstance(c, str): + c1 = parse_condition(c) or value.MeasurementKey.parse_serialized(c) + if c1 is None: + raise ValueError(f"'{c}' is not a valid condition") + c = c1 + if isinstance(c, value.MeasurementKey): + c = raw_types.Condition(sympy.sympify('x0'), (c,)) + conds.append(c) + self._conditions = tuple(conds) self._sub_operation: 'cirq.Operation' = sub_operation + @property + def classical_controls(self) -> FrozenSet[raw_types.Condition]: + return frozenset(self._conditions).union(self._sub_operation.classical_controls) + def without_classical_controls(self) -> 'cirq.Operation': return self._sub_operation.without_classical_controls() @@ -84,7 +99,7 @@ def qubits(self): def with_qubits(self, *new_qubits): return self._sub_operation.with_qubits(*new_qubits).with_classical_controls( - *self._control_keys + *self._conditions ) def _decompose_(self): @@ -92,19 +107,19 @@ def _decompose_(self): if result is NotImplemented: return NotImplemented - return [ClassicallyControlledOperation(op, self._control_keys) for op in result] + return [ClassicallyControlledOperation(op, self._conditions) for op in result] def _value_equality_values_(self): - return (frozenset(self._control_keys), self._sub_operation) + return (frozenset(self._conditions), self._sub_operation) def __str__(self) -> str: - keys = ', '.join(map(str, self._control_keys)) + keys = ', '.join(map(str, self._conditions)) return f'{self._sub_operation}.with_classical_controls({keys})' def __repr__(self): return ( f'cirq.ClassicallyControlledOperation(' - f'{self._sub_operation!r}, {list(self._control_keys)!r})' + f'{self._sub_operation!r}, {list(self._conditions)!r})' ) def _is_parameterized_(self) -> bool: @@ -117,7 +132,7 @@ def _resolve_parameters_( self, resolver: 'cirq.ParamResolver', recursive: bool ) -> 'ClassicallyControlledOperation': new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive) - return new_sub_op.with_classical_controls(*self._control_keys) + return new_sub_op.with_classical_controls(*self._conditions) def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' @@ -133,12 +148,12 @@ def _circuit_diagram_info_( if sub_info is None: return NotImplemented # coverage: ignore - wire_symbols = sub_info.wire_symbols + ('^',) * len(self._control_keys) + wire_symbols = sub_info.wire_symbols + ('^',) * len(self._conditions) exponent_qubit_index = None if sub_info.exponent_qubit_index is not None: - exponent_qubit_index = sub_info.exponent_qubit_index + len(self._control_keys) + exponent_qubit_index = sub_info.exponent_qubit_index + len(self._conditions) elif sub_info.exponent is not None: - exponent_qubit_index = len(self._control_keys) + exponent_qubit_index = len(self._conditions) return protocols.CircuitDiagramInfo( wire_symbols=wire_symbols, exponent=sub_info.exponent, @@ -148,39 +163,77 @@ def _circuit_diagram_info_( def _json_dict_(self) -> Dict[str, Any]: return { 'cirq_type': self.__class__.__name__, - 'conditions': self._control_keys, + 'conditions': self._conditions, 'sub_operation': self._sub_operation, } def _act_on_(self, args: 'cirq.ActOnArgs') -> bool: - def not_zero(measurement): - return any(i != 0 for i in measurement) - - measurements = [ - args.log_of_measurement_results.get(str(key), str(key)) for key in self._control_keys - ] - missing = [m for m in measurements if isinstance(m, str)] - if missing: - raise ValueError(f'Measurement keys {missing} missing when performing {self}') - if all(not_zero(measurement) for measurement in measurements): - protocols.act_on(self._sub_operation, args) + for condition in self._conditions: + keys, expr = condition.keys, condition.expr + missing = [str(k) for k in keys if str(k) not in args.log_of_measurement_results] + if missing: + raise ValueError(f'Measurement keys {missing} missing when performing {self}') + replacements = { + f'x{i}': args.log_of_measurement_results[str(k)][0] for i, k in enumerate(keys) + } + result = expr.subs(replacements) + if not result: + return True + protocols.act_on(self._sub_operation, args) return True def _with_measurement_key_mapping_( self, key_map: Dict[str, str] ) -> 'ClassicallyControlledOperation': - keys = [protocols.with_measurement_key_mapping(k, key_map) for k in self._control_keys] - return self._sub_operation.with_classical_controls(*keys) + def map_condition(condition: raw_types.Condition) -> raw_types.Condition: + keys = [protocols.with_measurement_key_mapping(k, key_map) for k in condition.keys] + return condition.with_keys(tuple(keys)) + conditions = [map_condition(c) for c in self._conditions] + return self._sub_operation.with_classical_controls(*conditions) def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation': - keys = [protocols.with_key_path_prefix(k, path) for k in self._control_keys] - return self._sub_operation.with_classical_controls(*keys) + def map_condition(condition: raw_types.Condition) -> raw_types.Condition: + keys = tuple(protocols.with_key_path_prefix(k, path) for k in condition.keys) + return condition.with_keys(keys) + conditions = [map_condition(c) for c in self._conditions] + return self._sub_operation.with_classical_controls(*conditions) def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: - return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation)) + local_keys = frozenset() + for condition in self._conditions: + local_keys = local_keys.union(condition.keys) + return local_keys.union(protocols.control_keys(self._sub_operation)) def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: args.validate_version('2.0') - keys = [f'm_{key}!=0' for key in self._control_keys] + keys = [f'm_{key}!=0' for key in self._conditions] all_keys = " && ".join(keys) return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args)) + + +def parse_condition(s: str) -> Optional[raw_types.Condition]: + in_key = False + key_count = 0 + s_out = '' + key_name = '' + keys = [] + for c in s: + if not in_key: + if c == '{': + in_key = True + else: + s_out += c + else: + if c == '}': + symbol_name = f'x{key_count}' + s_out += symbol_name + keys.append(value.MeasurementKey.parse_serialized(key_name)) + key_name = '' + key_count += 1 + in_key = False + else: + key_name += c + expr = sympy.sympify(s_out) + if len(expr.free_symbols) != len(keys): + return None + return raw_types.Condition(expr, tuple(keys)) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index ff46dccb5fb..70a318eb8ab 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -334,7 +334,7 @@ def test_key_set_in_subcircuit_outer_scope(): def test_condition_flattening(): q0 = cirq.LineQubit(0) op = cirq.X(q0).with_classical_controls('a').with_classical_controls('b') - assert set(map(str, op._control_keys)) == {'a', 'b'} + assert set(map(str, op.classical_controls)) == {'a', 'b'} assert isinstance(op._sub_operation, cirq.GateOperation) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 8212ba23f3e..153843ebd55 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -22,6 +22,7 @@ Callable, Collection, Dict, + FrozenSet, Hashable, Iterable, List, @@ -34,6 +35,7 @@ ) import numpy as np +import sympy from cirq import protocols, value from cirq._import import LazyLoader @@ -421,6 +423,35 @@ def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, attribute_names=[]) +class Condition: + def __init__(self, expr: sympy.Expr, keys: Tuple[value.MeasurementKey, ...]): + self._expr = expr + self._keys = keys + + @property + def keys(self): + return self._keys + + @property + def expr(self): + return self._expr + + def with_keys(self, keys: Tuple[value.MeasurementKey, ...]): + assert len(keys) == len(self._keys) + return Condition(self._expr, keys) + + def __eq__(self, other): + return isinstance(other, Condition) and self._keys == other._keys and self._expr == other._expr + + def __hash__(self): + return hash(self._keys) ^ hash(self._expr) + + def __str__(self): + if self._expr == sympy.symbols('x0') and len(self._keys) == 1: + return str(self._keys[0]) + return f'({self._expr}, {self._keys})' + + TSelf = TypeVar('TSelf', bound='Operation') @@ -590,8 +621,12 @@ def _commutes_( return np.allclose(m12, m21, atol=atol) + @property + def classical_controls(self) -> FrozenSet[Condition]: + return frozenset() + def with_classical_controls( - self, *conditions: Union[str, 'cirq.MeasurementKey'] + self, *conditions: Union[str, 'cirq.MeasurementKey', Condition] ) -> 'cirq.ClassicallyControlledOperation': """Returns a classically controlled version of this operation. @@ -821,6 +856,10 @@ def _equal_up_to_global_phase_( ) -> Union[NotImplementedType, bool]: return protocols.equal_up_to_global_phase(self.sub_operation, other, atol=atol) + @property + def classical_controls(self) -> FrozenSet[Condition]: + return self.sub_operation.classical_controls + def without_classical_controls(self) -> 'cirq.Operation': new_sub_operation = self.sub_operation.without_classical_controls() return self if new_sub_operation is self.sub_operation else new_sub_operation From 5fdff5026d704e43a1076fe339a485768db5969b Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 17:26:23 -0800 Subject: [PATCH 02/89] Format --- .../cirq/ops/classically_controlled_operation.py | 14 +++++++++----- cirq-core/cirq/ops/raw_types.py | 4 ++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 6a2c4b88e5c..0efd99fed38 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -18,6 +18,7 @@ FrozenSet, List, Optional, + Sequence, TYPE_CHECKING, Tuple, Union, @@ -48,7 +49,7 @@ class ClassicallyControlledOperation(raw_types.Operation): def __init__( self, sub_operation: 'cirq.Operation', - conditions: Tuple[Union[str, 'cirq.MeasurementKey', raw_types.Condition], ...], + conditions: Sequence[Union[str, 'cirq.MeasurementKey', raw_types.Condition]], ): """Initializes a `ClassicallyControlledOperation`. @@ -70,6 +71,7 @@ def __init__( raise ValueError( f'Cannot conditionally run operations with measurements: {sub_operation}' ) + conditions = tuple(conditions) if isinstance(sub_operation, ClassicallyControlledOperation): conditions += sub_operation._conditions sub_operation = sub_operation._sub_operation @@ -83,7 +85,7 @@ def __init__( if isinstance(c, value.MeasurementKey): c = raw_types.Condition(sympy.sympify('x0'), (c,)) conds.append(c) - self._conditions = tuple(conds) + self._conditions: Tuple[raw_types.Condition, ...] = tuple(conds) self._sub_operation: 'cirq.Operation' = sub_operation @property @@ -188,6 +190,7 @@ def _with_measurement_key_mapping_( def map_condition(condition: raw_types.Condition) -> raw_types.Condition: keys = [protocols.with_measurement_key_mapping(k, key_map) for k in condition.keys] return condition.with_keys(tuple(keys)) + conditions = [map_condition(c) for c in self._conditions] return self._sub_operation.with_classical_controls(*conditions) @@ -195,13 +198,14 @@ def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlle def map_condition(condition: raw_types.Condition) -> raw_types.Condition: keys = tuple(protocols.with_key_path_prefix(k, path) for k in condition.keys) return condition.with_keys(keys) + conditions = [map_condition(c) for c in self._conditions] return self._sub_operation.with_classical_controls(*conditions) def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: - local_keys = frozenset() - for condition in self._conditions: - local_keys = local_keys.union(condition.keys) + local_keys: FrozenSet[value.MeasurementKey] = frozenset( + k for condition in self._conditions for k in condition.keys + ) return local_keys.union(protocols.control_keys(self._sub_operation)) def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 153843ebd55..6e0e81ff891 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -440,8 +440,8 @@ def with_keys(self, keys: Tuple[value.MeasurementKey, ...]): assert len(keys) == len(self._keys) return Condition(self._expr, keys) - def __eq__(self, other): - return isinstance(other, Condition) and self._keys == other._keys and self._expr == other._expr + def __eq__(self, x): + return isinstance(x, Condition) and self._keys == x._keys and self._expr == x._expr def __hash__(self): return hash(self._keys) ^ hash(self._expr) From ef081d7f9e054d418dbffe17f5cbb0351169538d Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 18:29:58 -0800 Subject: [PATCH 03/89] move Condition to value --- cirq-core/cirq/__init__.py | 1 + .../ops/classically_controlled_operation.py | 18 +++---- cirq-core/cirq/ops/raw_types.py | 35 ++----------- cirq-core/cirq/value/__init__.py | 4 ++ cirq-core/cirq/value/condition.py | 52 +++++++++++++++++++ 5 files changed, 69 insertions(+), 41 deletions(-) create mode 100644 cirq-core/cirq/value/condition.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 0b26d585b33..deba34804be 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -473,6 +473,7 @@ canonicalize_half_turns, chosen_angle_to_canonical_half_turns, chosen_angle_to_half_turns, + Condition, Duration, DURATION_LIKE, GenericMetaImplementAnyOneOf, diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 0efd99fed38..172f7dd9293 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -49,7 +49,7 @@ class ClassicallyControlledOperation(raw_types.Operation): def __init__( self, sub_operation: 'cirq.Operation', - conditions: Sequence[Union[str, 'cirq.MeasurementKey', raw_types.Condition]], + conditions: Sequence[Union[str, 'cirq.MeasurementKey', 'cirq.Condition']], ): """Initializes a `ClassicallyControlledOperation`. @@ -75,7 +75,7 @@ def __init__( if isinstance(sub_operation, ClassicallyControlledOperation): conditions += sub_operation._conditions sub_operation = sub_operation._sub_operation - conds: List[raw_types.Condition] = [] + conds: List['cirq.Condition'] = [] for c in conditions: if isinstance(c, str): c1 = parse_condition(c) or value.MeasurementKey.parse_serialized(c) @@ -83,13 +83,13 @@ def __init__( raise ValueError(f"'{c}' is not a valid condition") c = c1 if isinstance(c, value.MeasurementKey): - c = raw_types.Condition(sympy.sympify('x0'), (c,)) + c = value.Condition(sympy.sympify('x0'), (c,)) conds.append(c) - self._conditions: Tuple[raw_types.Condition, ...] = tuple(conds) + self._conditions: Tuple['cirq.Condition', ...] = tuple(conds) self._sub_operation: 'cirq.Operation' = sub_operation @property - def classical_controls(self) -> FrozenSet[raw_types.Condition]: + def classical_controls(self) -> FrozenSet['cirq.Condition']: return frozenset(self._conditions).union(self._sub_operation.classical_controls) def without_classical_controls(self) -> 'cirq.Operation': @@ -187,7 +187,7 @@ def _act_on_(self, args: 'cirq.ActOnArgs') -> bool: def _with_measurement_key_mapping_( self, key_map: Dict[str, str] ) -> 'ClassicallyControlledOperation': - def map_condition(condition: raw_types.Condition) -> raw_types.Condition: + def map_condition(condition: 'cirq.Condition') -> 'cirq.Condition': keys = [protocols.with_measurement_key_mapping(k, key_map) for k in condition.keys] return condition.with_keys(tuple(keys)) @@ -195,7 +195,7 @@ def map_condition(condition: raw_types.Condition) -> raw_types.Condition: return self._sub_operation.with_classical_controls(*conditions) def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation': - def map_condition(condition: raw_types.Condition) -> raw_types.Condition: + def map_condition(condition: 'cirq.Condition') -> 'cirq.Condition': keys = tuple(protocols.with_key_path_prefix(k, path) for k in condition.keys) return condition.with_keys(keys) @@ -215,7 +215,7 @@ def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args)) -def parse_condition(s: str) -> Optional[raw_types.Condition]: +def parse_condition(s: str) -> Optional['cirq.Condition']: in_key = False key_count = 0 s_out = '' @@ -240,4 +240,4 @@ def parse_condition(s: str) -> Optional[raw_types.Condition]: expr = sympy.sympify(s_out) if len(expr.free_symbols) != len(keys): return None - return raw_types.Condition(expr, tuple(keys)) + return value.Condition(expr, tuple(keys)) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 6e0e81ff891..52b82a2ec3d 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -423,35 +423,6 @@ def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, attribute_names=[]) -class Condition: - def __init__(self, expr: sympy.Expr, keys: Tuple[value.MeasurementKey, ...]): - self._expr = expr - self._keys = keys - - @property - def keys(self): - return self._keys - - @property - def expr(self): - return self._expr - - def with_keys(self, keys: Tuple[value.MeasurementKey, ...]): - assert len(keys) == len(self._keys) - return Condition(self._expr, keys) - - def __eq__(self, x): - return isinstance(x, Condition) and self._keys == x._keys and self._expr == x._expr - - def __hash__(self): - return hash(self._keys) ^ hash(self._expr) - - def __str__(self): - if self._expr == sympy.symbols('x0') and len(self._keys) == 1: - return str(self._keys[0]) - return f'({self._expr}, {self._keys})' - - TSelf = TypeVar('TSelf', bound='Operation') @@ -622,11 +593,11 @@ def _commutes_( return np.allclose(m12, m21, atol=atol) @property - def classical_controls(self) -> FrozenSet[Condition]: + def classical_controls(self) -> FrozenSet['cirq.Condition']: return frozenset() def with_classical_controls( - self, *conditions: Union[str, 'cirq.MeasurementKey', Condition] + self, *conditions: Union[str, 'cirq.MeasurementKey', 'cirq.Condition'] ) -> 'cirq.ClassicallyControlledOperation': """Returns a classically controlled version of this operation. @@ -857,7 +828,7 @@ def _equal_up_to_global_phase_( return protocols.equal_up_to_global_phase(self.sub_operation, other, atol=atol) @property - def classical_controls(self) -> FrozenSet[Condition]: + def classical_controls(self) -> FrozenSet['cirq.Condition']: return self.sub_operation.classical_controls def without_classical_controls(self) -> 'cirq.Operation': diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index 390db1e4a11..a3ffea9507e 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -25,6 +25,10 @@ chosen_angle_to_half_turns, ) +from cirq.value.condition import ( + Condition, +) + from cirq.value.digits import ( big_endian_bits_to_int, big_endian_digits_to_int, diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py new file mode 100644 index 00000000000..bfc52211a28 --- /dev/null +++ b/cirq-core/cirq/value/condition.py @@ -0,0 +1,52 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ( + Tuple, + TYPE_CHECKING, +) + +import sympy + +if TYPE_CHECKING: + import cirq + + +class Condition: + def __init__(self, expr: sympy.Expr, keys: Tuple['cirq.MeasurementKey', ...]): + self._expr = expr + self._keys = keys + + @property + def keys(self): + return self._keys + + @property + def expr(self): + return self._expr + + def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): + assert len(keys) == len(self._keys) + return Condition(self._expr, keys) + + def __eq__(self, x): + return isinstance(x, Condition) and self._keys == x._keys and self._expr == x._expr + + def __hash__(self): + return hash(self._keys) ^ hash(self._expr) + + def __str__(self): + if self._expr == sympy.symbols('x0') and len(self._keys) == 1: + return str(self._keys[0]) + return f'({self._expr}, {self._keys})' From 0dd430e7aa33f29863dfeefc62308d7c84f69afe Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 20:18:37 -0800 Subject: [PATCH 04/89] Condition subclasses --- cirq-core/cirq/__init__.py | 4 + .../ops/classically_controlled_operation.py | 49 +------- .../classically_controlled_operation_test.py | 5 +- cirq-core/cirq/value/__init__.py | 4 + cirq-core/cirq/value/condition.py | 108 +++++++++++++++--- 5 files changed, 104 insertions(+), 66 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index deba34804be..42cd728f381 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -477,12 +477,16 @@ Duration, DURATION_LIKE, GenericMetaImplementAnyOneOf, + KeyCondition, LinearDict, MEASUREMENT_KEY_SEPARATOR, MeasurementKey, + parse_condition, + parse_sympy_condition, PeriodicValue, RANDOM_STATE_OR_SEED_LIKE, state_vector_to_probabilities, + SympyCondition, Timestamp, TParamKey, TParamVal, diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 172f7dd9293..ad80b8aecbc 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -78,12 +78,9 @@ def __init__( conds: List['cirq.Condition'] = [] for c in conditions: if isinstance(c, str): - c1 = parse_condition(c) or value.MeasurementKey.parse_serialized(c) - if c1 is None: - raise ValueError(f"'{c}' is not a valid condition") - c = c1 + c = value.parse_condition(c) if isinstance(c, value.MeasurementKey): - c = value.Condition(sympy.sympify('x0'), (c,)) + c = value.KeyCondition(c) conds.append(c) self._conditions: Tuple['cirq.Condition', ...] = tuple(conds) self._sub_operation: 'cirq.Operation' = sub_operation @@ -170,18 +167,8 @@ def _json_dict_(self) -> Dict[str, Any]: } def _act_on_(self, args: 'cirq.ActOnArgs') -> bool: - for condition in self._conditions: - keys, expr = condition.keys, condition.expr - missing = [str(k) for k in keys if str(k) not in args.log_of_measurement_results] - if missing: - raise ValueError(f'Measurement keys {missing} missing when performing {self}') - replacements = { - f'x{i}': args.log_of_measurement_results[str(k)][0] for i, k in enumerate(keys) - } - result = expr.subs(replacements) - if not result: - return True - protocols.act_on(self._sub_operation, args) + if all(c.resolve(args.log_of_measurement_results) for c in self._conditions): + protocols.act_on(self._sub_operation, args) return True def _with_measurement_key_mapping_( @@ -213,31 +200,3 @@ def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: keys = [f'm_{key}!=0' for key in self._conditions] all_keys = " && ".join(keys) return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args)) - - -def parse_condition(s: str) -> Optional['cirq.Condition']: - in_key = False - key_count = 0 - s_out = '' - key_name = '' - keys = [] - for c in s: - if not in_key: - if c == '{': - in_key = True - else: - s_out += c - else: - if c == '}': - symbol_name = f'x{key_count}' - s_out += symbol_name - keys.append(value.MeasurementKey.parse_serialized(key_name)) - key_name = '' - key_count += 1 - in_key = False - else: - key_name += c - expr = sympy.sympify(s_out) - if len(expr.free_symbols) != len(keys): - return None - return value.Condition(expr, tuple(keys)) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 70a318eb8ab..2e48bc73fa3 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -410,9 +410,6 @@ def test_unmeasured_condition(): q0 = cirq.LineQubit(0) bad_circuit = cirq.Circuit(cirq.X(q0).with_classical_controls('a')) with pytest.raises( - ValueError, - match=re.escape( - "Measurement keys ['a'] missing when performing X(0).with_classical_controls(a)" - ), + ValueError, match='Measurement key a missing when testing classical control' ): _ = cirq.Simulator().simulate(bad_circuit) diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index a3ffea9507e..05e92d0777a 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -27,6 +27,10 @@ from cirq.value.condition import ( Condition, + KeyCondition, + parse_condition, + parse_sympy_condition, + SympyCondition, ) from cirq.value.digits import ( diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index bfc52211a28..506038481e5 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -13,40 +13,114 @@ # limitations under the License. from typing import ( + Dict, + List, + Optional, Tuple, TYPE_CHECKING, ) +import abc +import dataclasses import sympy +from cirq.value import measurement_key + if TYPE_CHECKING: import cirq -class Condition: - def __init__(self, expr: sympy.Expr, keys: Tuple['cirq.MeasurementKey', ...]): - self._expr = expr - self._keys = keys +class Condition(abc.ABC): + @property + @abc.abstractmethod + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + """Gets the control keys.""" + + @abc.abstractmethod + def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): + """Replaces the control keys.""" + + @abc.abstractmethod + def resolve(self, measurements: Dict[str, List[int]]) -> bool: + """Resolves the condition based on the measurements.""" + + +@dataclasses.dataclass(frozen=True) +class KeyCondition(Condition): + key: 'cirq.MeasurementKey' @property def keys(self): - return self._keys + return (self.key,) + + def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): + assert len(keys) == 1 + return KeyCondition(keys[0]) + + def __str__(self): + return str(self.key) + + def resolve(self, measurements: Dict[str, List[int]]) -> bool: + key = str(self.key) + if key not in measurements: + raise ValueError(f'Measurement key {key} missing when testing classical control') + return any(measurements[key]) + + +@dataclasses.dataclass(frozen=True) +class SympyCondition(Condition): + expr: sympy.Expr + control_keys: Tuple['cirq.MeasurementKey', ...] @property - def expr(self): - return self._expr + def keys(self): + return self.control_keys def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): - assert len(keys) == len(self._keys) - return Condition(self._expr, keys) + assert len(keys) == len(self.control_keys) + return dataclasses.replace(self, control_keys=keys) - def __eq__(self, x): - return isinstance(x, Condition) and self._keys == x._keys and self._expr == x._expr + def __str__(self): + return f'({self.expr}, {self.control_keys})' - def __hash__(self): - return hash(self._keys) ^ hash(self._expr) + def resolve(self, measurements: Dict[str, List[int]]) -> bool: + missing = [str(k) for k in self.keys if str(k) not in measurements] + if missing: + raise ValueError(f'Measurement keys {missing} missing when testing classical control') + replacements = {f'x{i}': measurements[str(k)][0] for i, k in enumerate(self.keys)} + return bool(self.expr.subs(replacements)) - def __str__(self): - if self._expr == sympy.symbols('x0') and len(self._keys) == 1: - return str(self._keys[0]) - return f'({self._expr}, {self._keys})' + +def parse_sympy_condition(s: str) -> Optional['cirq.SympyCondition']: + in_key = False + key_count = 0 + s_out = '' + key_name = '' + keys = [] + for c in s: + if not in_key: + if c == '{': + in_key = True + else: + s_out += c + else: + if c == '}': + symbol_name = f'x{key_count}' + s_out += symbol_name + keys.append(measurement_key.MeasurementKey.parse_serialized(key_name)) + key_name = '' + key_count += 1 + in_key = False + else: + key_name += c + expr = sympy.sympify(s_out) + if len(expr.free_symbols) != len(keys): + return None + return SympyCondition(expr, tuple(keys)) + + +def parse_condition(s: str) -> 'cirq.Condition': + c = parse_sympy_condition(s) or measurement_key.MeasurementKey.parse_serialized(s) + if c is None: + raise ValueError(f"'{s}' is not a valid condition") + return c From 3fb7f75423eecdc422593baeabc77b8f854a63e8 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 20:35:10 -0800 Subject: [PATCH 05/89] Fix sympy resolver --- cirq-core/cirq/value/condition.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 506038481e5..ef5f69a2621 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -13,9 +13,9 @@ # limitations under the License. from typing import ( - Dict, - List, + Mapping, Optional, + Sequence, Tuple, TYPE_CHECKING, ) @@ -41,7 +41,7 @@ def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): """Replaces the control keys.""" @abc.abstractmethod - def resolve(self, measurements: Dict[str, List[int]]) -> bool: + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: """Resolves the condition based on the measurements.""" @@ -60,7 +60,7 @@ def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): def __str__(self): return str(self.key) - def resolve(self, measurements: Dict[str, List[int]]) -> bool: + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: key = str(self.key) if key not in measurements: raise ValueError(f'Measurement key {key} missing when testing classical control') @@ -83,11 +83,15 @@ def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): def __str__(self): return f'({self.expr}, {self.control_keys})' - def resolve(self, measurements: Dict[str, List[int]]) -> bool: + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: missing = [str(k) for k in self.keys if str(k) not in measurements] if missing: raise ValueError(f'Measurement keys {missing} missing when testing classical control') - replacements = {f'x{i}': measurements[str(k)][0] for i, k in enumerate(self.keys)} + + def value(k): + return sum(v * 2 ** i for i, v in enumerate(measurements[str(k)])) + + replacements = {f'x{i}': value(k) for i, k in enumerate(self.keys)} return bool(self.expr.subs(replacements)) From 9568ae0630f308d6f8a0f9529c074c5957597d97 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 22:16:03 -0800 Subject: [PATCH 06/89] lint --- cirq-core/cirq/ops/classically_controlled_operation.py | 2 -- cirq-core/cirq/ops/classically_controlled_operation_test.py | 3 +-- cirq-core/cirq/ops/raw_types.py | 1 - 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index ad80b8aecbc..c5df8f38a1f 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -24,8 +24,6 @@ Union, ) -import sympy - from cirq import protocols, value from cirq.ops import raw_types diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 2e48bc73fa3..a7d3f637ffb 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re import pytest import sympy @@ -395,7 +394,7 @@ def test_repr(): op = cirq.X(q0).with_classical_controls('a') assert repr(op) == ( "cirq.ClassicallyControlledOperation(" - "cirq.X(cirq.LineQubit(0)), [cirq.MeasurementKey(name='a')]" + "cirq.X(cirq.LineQubit(0)), [KeyCondition(key=cirq.MeasurementKey(name='a'))]" ")" ) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 52b82a2ec3d..d278674c9d9 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -35,7 +35,6 @@ ) import numpy as np -import sympy from cirq import protocols, value from cirq._import import LazyLoader From 04fcff79e61977452940c8068e3ffc7aee602c2a Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 22:24:44 -0800 Subject: [PATCH 07/89] fix CCO serialization --- cirq-core/cirq/json_resolver_cache.py | 2 ++ .../ClassicallyControlledOperation.json | 18 ++++++++++++------ .../ClassicallyControlledOperation.repr | 2 +- cirq-core/cirq/value/condition.py | 7 +++++++ 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 9c7613de0c3..48b7d49d6d8 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -97,6 +97,7 @@ def _parallel_gate_op(gate, qubits): 'ISwapPowGate': cirq.ISwapPowGate, 'IdentityGate': cirq.IdentityGate, 'InitObsSetting': cirq.work.InitObsSetting, + 'KeyCondition': cirq.KeyCondition, 'KrausChannel': cirq.KrausChannel, 'LinearDict': cirq.LinearDict, 'LineQubit': cirq.LineQubit, @@ -143,6 +144,7 @@ def _parallel_gate_op(gate, qubits): 'StabilizerStateChForm': cirq.StabilizerStateChForm, 'SwapPowGate': cirq.SwapPowGate, 'SymmetricalQidPair': cirq.SymmetricalQidPair, + 'SympyCondition': cirq.SympyCondition, 'TaggedOperation': cirq.TaggedOperation, 'TiltedSquareLattice': cirq.TiltedSquareLattice, 'TrialResult': cirq.Result, # keep support for Cirq < 0.11. diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json index a22c2720095..8fbae9b27c7 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json @@ -2,14 +2,20 @@ "cirq_type": "ClassicallyControlledOperation", "conditions": [ { - "cirq_type": "MeasurementKey", - "name": "a", - "path": [] + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "a", + "path": [] + } }, { - "cirq_type": "MeasurementKey", - "name": "b", - "path": [] + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "b", + "path": [] + } } ], "sub_operation": { diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr index bfc25256cff..423551a8501 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr @@ -1 +1 @@ -cirq.ClassicallyControlledOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.MeasurementKey('a'), cirq.MeasurementKey('b')]) \ No newline at end of file +cirq.ClassicallyControlledOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.KeyCondition(key=cirq.MeasurementKey('a')), cirq.KeyCondition(key=cirq.MeasurementKey('b'))]) \ No newline at end of file diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index ef5f69a2621..0f2a4fb3568 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -24,6 +24,7 @@ import dataclasses import sympy +from cirq.protocols import json_serialization from cirq.value import measurement_key if TYPE_CHECKING: @@ -66,6 +67,9 @@ def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: raise ValueError(f'Measurement key {key} missing when testing classical control') return any(measurements[key]) + def _json_dict_(self): + return json_serialization.dataclass_json_dict(self) + @dataclasses.dataclass(frozen=True) class SympyCondition(Condition): @@ -94,6 +98,9 @@ def value(k): replacements = {f'x{i}': value(k) for i, k in enumerate(self.keys)} return bool(self.expr.subs(replacements)) + def _json_dict_(self): + return json_serialization.dataclass_json_dict(self) + def parse_sympy_condition(s: str) -> Optional['cirq.SympyCondition']: in_key = False From b3c344e4521ef379f8b2a26bfb6f0d302119f72d Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 22:30:53 -0800 Subject: [PATCH 08/89] fix CCO serialization --- cirq-core/cirq/ops/classically_controlled_operation.py | 3 +-- cirq-core/cirq/value/condition.py | 9 +++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index c5df8f38a1f..7984b43a5fe 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -195,6 +195,5 @@ def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: args.validate_version('2.0') - keys = [f'm_{key}!=0' for key in self._conditions] - all_keys = " && ".join(keys) + all_keys = " && ".join(c.qasm for c in self._conditions) return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args)) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 0f2a4fb3568..e64fd8a304f 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -45,6 +45,11 @@ def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: """Resolves the condition based on the measurements.""" + @property + def qasm(self): + """Returns the qasm of this condition.""" + raise NotImplementedError() + @dataclasses.dataclass(frozen=True) class KeyCondition(Condition): @@ -70,6 +75,10 @@ def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: def _json_dict_(self): return json_serialization.dataclass_json_dict(self) + @property + def qasm(self): + return f'm_{self.key}!=0' + @dataclasses.dataclass(frozen=True) class SympyCondition(Condition): From b8ff20a437fd65d314776fc1e47c1260502d531b Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 22:38:41 -0800 Subject: [PATCH 09/89] add json reprs for conditions --- .../cirq/protocols/json_test_data/KeyCondition.json | 8 ++++++++ .../cirq/protocols/json_test_data/KeyCondition.repr | 1 + .../cirq/protocols/json_test_data/SympyCondition.json | 11 +++++++++++ .../cirq/protocols/json_test_data/SympyCondition.repr | 1 + cirq-core/cirq/value/condition.py | 6 +++++- 5 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 cirq-core/cirq/protocols/json_test_data/KeyCondition.json create mode 100644 cirq-core/cirq/protocols/json_test_data/KeyCondition.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/SympyCondition.json create mode 100644 cirq-core/cirq/protocols/json_test_data/SympyCondition.repr diff --git a/cirq-core/cirq/protocols/json_test_data/KeyCondition.json b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json new file mode 100644 index 00000000000..f5b81ba63dc --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json @@ -0,0 +1,8 @@ +{ + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "a", + "path": [] + } +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/KeyCondition.repr b/cirq-core/cirq/protocols/json_test_data/KeyCondition.repr new file mode 100644 index 00000000000..fb9fa3232ec --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KeyCondition.repr @@ -0,0 +1 @@ +cirq.KeyCondition(key=cirq.MeasurementKey('a')) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.json b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json new file mode 100644 index 00000000000..0a8c2522a69 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json @@ -0,0 +1,11 @@ +{ + "cirq_type": "SympyCondition", + "expr": "x > 5", + "control_keys": [ + { + "cirq_type": "MeasurementKey", + "name": "a", + "path": [] + } + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr new file mode 100644 index 00000000000..9b465aaa679 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr @@ -0,0 +1 @@ +cirq.SympyCondition(control_keys=[cirq.MeasurementKey('a')], expr='x > 5') \ No newline at end of file diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index e64fd8a304f..a3b1483e244 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -46,9 +46,9 @@ def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: """Resolves the condition based on the measurements.""" @property + @abc.abstractmethod def qasm(self): """Returns the qasm of this condition.""" - raise NotImplementedError() @dataclasses.dataclass(frozen=True) @@ -110,6 +110,10 @@ def value(k): def _json_dict_(self): return json_serialization.dataclass_json_dict(self) + @property + def qasm(self): + raise NotImplementedError() + def parse_sympy_condition(s: str) -> Optional['cirq.SympyCondition']: in_key = False From aa998051eb70f7a1a0c62b568815e1d624751ec4 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 8 Dec 2021 23:26:04 -0800 Subject: [PATCH 10/89] add support for qudits in conditions --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 7 +++++-- .../ops/classically_controlled_operation.py | 3 ++- cirq-core/cirq/sim/act_on_args.py | 10 +++++++++ cirq-core/cirq/sim/act_on_args_container.py | 19 ++++++++++++----- .../cirq/sim/act_on_density_matrix_args.py | 3 ++- .../cirq/sim/act_on_state_vector_args.py | 3 ++- .../clifford/act_on_clifford_tableau_args.py | 5 +++-- .../act_on_stabilizer_ch_form_args.py | 5 +++-- .../cirq/sim/clifford/clifford_simulator.py | 4 +++- .../cirq/sim/density_matrix_simulator.py | 2 ++ cirq-core/cirq/sim/simulator_base.py | 11 ++++++++-- cirq-core/cirq/sim/simulator_base_test.py | 2 ++ cirq-core/cirq/sim/sparse_simulator.py | 3 +++ cirq-core/cirq/value/condition.py | 21 +++++++++++++++---- .../calibration/engine_simulator.py | 1 + 15 files changed, 78 insertions(+), 21 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 4247bcf6b97..9a11c0664cc 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -19,7 +19,7 @@ import dataclasses import math -from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union import numpy as np import quimb.tensor as qtn @@ -92,6 +92,7 @@ def _create_partial_act_on_args( initial_state: Union[int, 'MPSState'], qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], ) -> 'MPSState': """Creates MPSState args for simulating the Circuit. @@ -116,6 +117,7 @@ def _create_partial_act_on_args( grouping=self.grouping, initial_state=initial_state, log_of_measurement_results=logs, + measured_qubits=measured_qubits, ) def _create_step_result( @@ -229,6 +231,7 @@ def __init__( grouping: Optional[Dict['cirq.Qid', int]] = None, initial_state: int = 0, log_of_measurement_results: Dict[str, Any] = None, + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, ): """Creates and MPSState @@ -246,7 +249,7 @@ def __init__( Raises: ValueError: If the grouping does not cover the qubits. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) qubit_map = self.qubit_map self.grouping = qubit_map if grouping is None else grouping if self.grouping.keys() != self.qubit_map.keys(): diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 7984b43a5fe..b2784298d09 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -165,7 +165,8 @@ def _json_dict_(self) -> Dict[str, Any]: } def _act_on_(self, args: 'cirq.ActOnArgs') -> bool: - if all(c.resolve(args.log_of_measurement_results) for c in self._conditions): + measurements, qubits = args.log_of_measurement_results, args.measured_qubits + if all(c.resolve(measurements, qubits) for c in self._conditions): protocols.act_on(self._sub_operation, args) return True diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index d28cda196f8..c9479375f5c 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -47,6 +47,7 @@ def __init__( prng: np.random.RandomState = None, qubits: Sequence['cirq.Qid'] = None, log_of_measurement_results: Dict[str, List[int]] = None, + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, ): """Inits ActOnArgs. @@ -65,9 +66,12 @@ def __init__( qubits = () if log_of_measurement_results is None: log_of_measurement_results = {} + if measured_qubits is None: + measured_qubits = {} self._set_qubits(qubits) self.prng = prng self._log_of_measurement_results = log_of_measurement_results + self._measured_qubits = measured_qubits def _set_qubits(self, qubits: Sequence['cirq.Qid']): self._qubits = tuple(qubits) @@ -91,6 +95,7 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[ if key in self._log_of_measurement_results: raise ValueError(f"Measurement already logged to key {key!r}") self._log_of_measurement_results[key] = corrected + self._measured_qubits[key] = tuple(qubits) def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]: return [self.qubit_map[q] for q in qubits] @@ -105,6 +110,7 @@ def copy(self: TSelf) -> TSelf: args = copy.copy(self) self._on_copy(args) args._log_of_measurement_results = self.log_of_measurement_results.copy() + args._measured_qubits = self.measured_qubits.copy() return args def _on_copy(self: TSelf, args: TSelf): @@ -184,6 +190,10 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ def log_of_measurement_results(self) -> Dict[str, List[int]]: return self._log_of_measurement_results + @property + def measured_qubits(self) -> Dict[str, Tuple['cirq.Qid', ...]]: + return self._measured_qubits + @property def qubits(self) -> Tuple['cirq.Qid', ...]: return self._qubits diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 4433c0ac2bd..55c20a9c1ef 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -14,15 +14,16 @@ from collections import abc from typing import ( + Any, Dict, - TYPE_CHECKING, Generic, - Sequence, - Optional, Iterator, - Any, - Tuple, List, + Mapping, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, Union, ) @@ -51,6 +52,7 @@ def __init__( qubits: Sequence['cirq.Qid'], split_untangled_states: bool, log_of_measurement_results: Dict[str, Any], + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, ): """Initializes the class. @@ -68,6 +70,7 @@ def __init__( self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states self._log_of_measurement_results = log_of_measurement_results + self._measured_qubits = measured_qubits or {} def create_merged_state(self) -> TActOnArgs: if not self.split_untangled_states: @@ -132,9 +135,11 @@ def _act_on_fallback_( def copy(self) -> 'ActOnArgsContainer[TActOnArgs]': logs = self.log_of_measurement_results.copy() + measured_qubits = self._measured_qubits.copy() copies = {a: a.copy() for a in set(self.args.values())} for copy in copies.values(): copy._log_of_measurement_results = logs + copy._measured_qubits = measured_qubits args = {q: copies[a] for q, a in self.args.items()} return ActOnArgsContainer(args, self.qubits, self.split_untangled_states, logs) @@ -146,6 +151,10 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: def log_of_measurement_results(self) -> Dict[str, Any]: return self._log_of_measurement_results + @property + def measured_qubits(self) -> Mapping[str, Tuple['cirq.Qid', ...]]: + return self._measured_qubits + def sample( self, qubits: List[ops.Qid], diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index d1b529d87a9..38feb74309c 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -40,6 +40,7 @@ def __init__( prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], qubits: Sequence['cirq.Qid'] = None, + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, ): """Inits ActOnDensityMatrixArgs. @@ -60,7 +61,7 @@ def __init__( log_of_measurement_results: A mutable object that measurements are being recorded into. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) self.target_tensor = target_tensor self.available_buffer = available_buffer self.qid_shape = qid_shape diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 11211f491da..815b193ad1b 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -43,6 +43,7 @@ def __init__( prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], qubits: Sequence['cirq.Qid'] = None, + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, ): """Inits ActOnStateVectorArgs. @@ -63,7 +64,7 @@ def __init__( log_of_measurement_results: A mutable object that measurements are being recorded into. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) self.target_tensor = target_tensor self.available_buffer = available_buffer diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 4f5bcce2619..9849ae40b2a 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -14,7 +14,7 @@ """A protocol for implementing high performance clifford tableau evolutions for Clifford Simulator.""" -from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Union +from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Tuple, Union import numpy as np @@ -43,6 +43,7 @@ def __init__( prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], qubits: Sequence['cirq.Qid'] = None, + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, ): """Inits ActOnCliffordTableauArgs. @@ -57,7 +58,7 @@ def __init__( log_of_measurement_results: A mutable object that measurements are being recorded into. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) self.tableau = tableau def _act_on_fallback_( diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 62154568052..a7011a8f1eb 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Union +from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Tuple, Union import numpy as np @@ -42,6 +42,7 @@ def __init__( prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], qubits: Sequence['cirq.Qid'] = None, + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, ): """Initializes with the given state and the axes for the operation. Args: @@ -55,7 +56,7 @@ def __init__( log_of_measurement_results: A mutable object that measurements are being recorded into. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) self.state = state def _act_on_fallback_( diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index f017d07bcf4..7e6175bbbc8 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -29,7 +29,7 @@ to state vector amplitudes. """ -from typing import Any, Dict, List, Sequence, Union +from typing import Any, Dict, List, Sequence, Tuple, Union import numpy as np @@ -69,6 +69,7 @@ def _create_partial_act_on_args( initial_state: Union[int, clifford.ActOnStabilizerCHFormArgs], qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], ) -> clifford.ActOnStabilizerCHFormArgs: """Creates the ActOnStabilizerChFormArgs for a circuit. @@ -94,6 +95,7 @@ def _create_partial_act_on_args( prng=self._prng, log_of_measurement_results=logs, qubits=qubits, + measured_qubits=measured_qubits, ) def _create_step_result( diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index bad0553f0df..72ad6290d54 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -176,6 +176,7 @@ def _create_partial_act_on_args( initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], ) -> 'cirq.ActOnDensityMatrixArgs': """Creates the ActOnDensityMatrixArgs for a circuit. @@ -208,6 +209,7 @@ def _create_partial_act_on_args( qid_shape=qid_shape, prng=self._prng, log_of_measurement_results=logs, + measured_qubits=measured_qubits, ) def _can_be_in_run_prefix(self, val: Any): diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 7d017b20a79..7bbaad1231c 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -124,6 +124,7 @@ def _create_partial_act_on_args( initial_state: Any, qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], ) -> TActOnArgs: """Creates an instance of the TActOnArgs class for the simulator. @@ -340,6 +341,7 @@ def _create_act_on_args( return initial_state log: Dict[str, Any] = {} + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = {} if self._split_untangled_states: args_map: Dict[Optional['cirq.Qid'], TActOnArgs] = {} if isinstance(initial_state, int): @@ -348,6 +350,7 @@ def _create_act_on_args( initial_state=initial_state % q.dimension, qubits=[q], logs=log, + measured_qubits=measured_qubits, ) initial_state = int(initial_state / q.dimension) else: @@ -355,16 +358,20 @@ def _create_act_on_args( initial_state=initial_state, qubits=qubits, logs=log, + measured_qubits=measured_qubits, ) for q in qubits: args_map[q] = args - args_map[None] = self._create_partial_act_on_args(0, (), log) - return ActOnArgsContainer(args_map, qubits, self._split_untangled_states, log) + args_map[None] = self._create_partial_act_on_args(0, (), log, measured_qubits) + return ActOnArgsContainer( + args_map, qubits, self._split_untangled_states, log, measured_qubits + ) else: return self._create_partial_act_on_args( initial_state=initial_state, qubits=qubits, logs=log, + measured_qubits=measured_qubits, ) diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index d98e31d1cd0..ee89d2c1cf8 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -112,6 +112,7 @@ def _create_partial_act_on_args( initial_state: Any, qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], + measured_qubits, ) -> CountingActOnArgs: return CountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) @@ -142,6 +143,7 @@ def _create_partial_act_on_args( initial_state: Any, qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], + measured_qubits, ) -> CountingActOnArgs: return SplittableCountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 7a72f6e2299..b4447602f49 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -23,6 +23,7 @@ TYPE_CHECKING, Union, Sequence, + Tuple, Optional, ) @@ -175,6 +176,7 @@ def _create_partial_act_on_args( initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.ActOnStateVectorArgs'], qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], ): """Creates the ActOnStateVectorArgs for a circuit. @@ -203,6 +205,7 @@ def _create_partial_act_on_args( qubits=qubits, prng=self._prng, log_of_measurement_results=logs, + measured_qubits=measured_qubits, ) def _create_step_result( diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index a3b1483e244..4a22a2a4569 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -42,7 +42,11 @@ def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): """Replaces the control keys.""" @abc.abstractmethod - def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + def resolve( + self, + measurements: Mapping[str, Sequence[int]], + measured_qubits: Mapping[str, Sequence['cirq.Qid']], + ) -> bool: """Resolves the condition based on the measurements.""" @property @@ -66,7 +70,11 @@ def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): def __str__(self): return str(self.key) - def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + def resolve( + self, + measurements: Mapping[str, Sequence[int]], + measured_qubits: Mapping[str, Sequence['cirq.Qid']], + ) -> bool: key = str(self.key) if key not in measurements: raise ValueError(f'Measurement key {key} missing when testing classical control') @@ -96,13 +104,18 @@ def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): def __str__(self): return f'({self.expr}, {self.control_keys})' - def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + def resolve( + self, + measurements: Mapping[str, Sequence[int]], + measured_qubits: Mapping[str, Sequence['cirq.Qid']], + ) -> bool: missing = [str(k) for k in self.keys if str(k) not in measurements] if missing: raise ValueError(f'Measurement keys {missing} missing when testing classical control') def value(k): - return sum(v * 2 ** i for i, v in enumerate(measurements[str(k)])) + zipped = zip(measurements[str(k)], measured_qubits[str(k)]) + return sum(v * q.dimension ** i for i, (v, q) in enumerate(zipped)) replacements = {f'x{i}': value(k) for i, k in enumerate(self.keys)} return bool(self.expr.subs(replacements)) diff --git a/cirq-google/cirq_google/calibration/engine_simulator.py b/cirq-google/cirq_google/calibration/engine_simulator.py index d2c383f059b..573fa7b6f0f 100644 --- a/cirq-google/cirq_google/calibration/engine_simulator.py +++ b/cirq-google/cirq_google/calibration/engine_simulator.py @@ -475,6 +475,7 @@ def _create_partial_act_on_args( initial_state: Union[int, cirq.ActOnStateVectorArgs], qubits: Sequence[cirq.Qid], logs: Dict[str, Any], + measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], ) -> cirq.ActOnStateVectorArgs: # Needs an implementation since it's abstract but will never actually be called. raise NotImplementedError() From cbb029b1be64dc911ace49020c18f51d6ca1931f Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 9 Dec 2021 10:53:27 -0800 Subject: [PATCH 11/89] add test --- .../ops/classically_controlled_operation_test.py | 15 +++++++++++++++ cirq-core/cirq/value/condition.py | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index a7d3f637ffb..4b799ae73b5 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -412,3 +412,18 @@ def test_unmeasured_condition(): ValueError, match='Measurement key a missing when testing classical control' ): _ = cirq.Simulator().simulate(bad_circuit) + + +def test_sympy(): + for i in range(9): + for j in range(8): + bitstring = cirq.big_endian_int_to_bits(j, bit_count=3) + circuit = cirq.Circuit() + for k in range(3): + circuit.append(cirq.X(cirq.LineQubit(k)) ** bitstring[k]) + circuit.append(cirq.measure(*cirq.LineQubit.range(3), key='m')) + circuit.append(cirq.X(cirq.LineQubit(3)).with_classical_controls(f'{{m}} > {i}')) + circuit.append(cirq.measure(cirq.LineQubit(3), key='a')) + result = cirq.Simulator().run(circuit) + expected = 1 if j > i else 0 + assert result.measurements['a'][0][0] == expected diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index a3b1483e244..56703e2a06a 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -25,7 +25,7 @@ import sympy from cirq.protocols import json_serialization -from cirq.value import measurement_key +from cirq.value import digits, measurement_key if TYPE_CHECKING: import cirq @@ -102,7 +102,7 @@ def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: raise ValueError(f'Measurement keys {missing} missing when testing classical control') def value(k): - return sum(v * 2 ** i for i, v in enumerate(measurements[str(k)])) + return digits.big_endian_bits_to_int(measurements[str(k)]) replacements = {f'x{i}': value(k) for i, k in enumerate(self.keys)} return bool(self.expr.subs(replacements)) From efce2f914afd847e0f757179d1c8e50f314bd8d5 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 9 Dec 2021 11:05:54 -0800 Subject: [PATCH 12/89] tests --- cirq-core/cirq/ops/classically_controlled_operation_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 4b799ae73b5..e34ae1bfda4 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -333,7 +333,7 @@ def test_key_set_in_subcircuit_outer_scope(): def test_condition_flattening(): q0 = cirq.LineQubit(0) op = cirq.X(q0).with_classical_controls('a').with_classical_controls('b') - assert set(map(str, op.classical_controls)) == {'a', 'b'} + assert set(map(str, op._conditions)) == {'a', 'b'} assert isinstance(op._sub_operation, cirq.GateOperation) @@ -341,6 +341,7 @@ def test_condition_stacking(): q0 = cirq.LineQubit(0) op = cirq.X(q0).with_classical_controls('a').with_tags('t').with_classical_controls('b') assert set(map(str, cirq.control_keys(op))) == {'a', 'b'} + assert set(map(str, op.classical_controls)) == {'a', 'b'} assert not op.tags @@ -355,6 +356,7 @@ def test_condition_removal(): ) op = op.without_classical_controls() assert not cirq.control_keys(op) + assert not op.classical_controls assert set(map(str, op.tags)) == {'t1'} From b34994bdb80acfc9d26166b6577d76f9242b1aa9 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 9 Dec 2021 11:33:38 -0800 Subject: [PATCH 13/89] tests --- .../ops/classically_controlled_operation.py | 6 ++--- cirq-core/cirq/value/condition.py | 22 ++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 7984b43a5fe..780bbc869bc 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -174,15 +174,15 @@ def _with_measurement_key_mapping_( ) -> 'ClassicallyControlledOperation': def map_condition(condition: 'cirq.Condition') -> 'cirq.Condition': keys = [protocols.with_measurement_key_mapping(k, key_map) for k in condition.keys] - return condition.with_keys(tuple(keys)) + return condition.with_keys(*keys) conditions = [map_condition(c) for c in self._conditions] return self._sub_operation.with_classical_controls(*conditions) def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation': def map_condition(condition: 'cirq.Condition') -> 'cirq.Condition': - keys = tuple(protocols.with_key_path_prefix(k, path) for k in condition.keys) - return condition.with_keys(keys) + keys = [protocols.with_key_path_prefix(k, path) for k in condition.keys] + return condition.with_keys(*keys) conditions = [map_condition(c) for c in self._conditions] return self._sub_operation.with_classical_controls(*conditions) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 56703e2a06a..7a2f0d6db1d 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -38,7 +38,7 @@ def keys(self) -> Tuple['cirq.MeasurementKey', ...]: """Gets the control keys.""" @abc.abstractmethod - def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): + def with_keys(self, *keys: 'cirq.MeasurementKey'): """Replaces the control keys.""" @abc.abstractmethod @@ -59,7 +59,7 @@ class KeyCondition(Condition): def keys(self): return (self.key,) - def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): + def with_keys(self, *keys: 'cirq.MeasurementKey'): assert len(keys) == 1 return KeyCondition(keys[0]) @@ -89,12 +89,13 @@ class SympyCondition(Condition): def keys(self): return self.control_keys - def with_keys(self, keys: Tuple['cirq.MeasurementKey', ...]): + def with_keys(self, *keys: 'cirq.MeasurementKey'): assert len(keys) == len(self.control_keys) return dataclasses.replace(self, control_keys=keys) def __str__(self): - return f'({self.expr}, {self.control_keys})' + replacements = {f'x{i}': str(key) for i, key in enumerate(self.control_keys)} + return f"{self.expr.subs(replacements)}" def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: missing = [str(k) for k in self.keys if str(k) not in measurements] @@ -115,7 +116,7 @@ def qasm(self): raise NotImplementedError() -def parse_sympy_condition(s: str) -> Optional['cirq.SympyCondition']: +def parse_sympy_condition(s: str) -> 'cirq.SympyCondition': in_key = False key_count = 0 s_out = '' @@ -139,12 +140,13 @@ def parse_sympy_condition(s: str) -> Optional['cirq.SympyCondition']: key_name += c expr = sympy.sympify(s_out) if len(expr.free_symbols) != len(keys): - return None + raise ValueError(f"'{s}' is not a valid sympy condition") return SympyCondition(expr, tuple(keys)) def parse_condition(s: str) -> 'cirq.Condition': - c = parse_sympy_condition(s) or measurement_key.MeasurementKey.parse_serialized(s) - if c is None: - raise ValueError(f"'{s}' is not a valid condition") - return c + try: + return parse_sympy_condition(s) + except ValueError: + pass + return measurement_key.MeasurementKey.parse_serialized(s) From 5537397bba2f46d1834bdcf4ba70deca4c1c6445 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 9 Dec 2021 12:24:36 -0800 Subject: [PATCH 14/89] test --- cirq-core/cirq/value/condition_test.py | 81 ++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 cirq-core/cirq/value/condition_test.py diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py new file mode 100644 index 00000000000..2bd15cb0480 --- /dev/null +++ b/cirq-core/cirq/value/condition_test.py @@ -0,0 +1,81 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import pytest + +import cirq + +key_a = cirq.MeasurementKey('a') +key_b = cirq.MeasurementKey('b') +init_key_condition = cirq.KeyCondition(key_a) +init_sympy_condition = cirq.parse_sympy_condition('{a} >= 1') + + +def test_key_condition_with_keys(): + c = init_key_condition.with_keys(key_b) + assert c.key is key_b + + +def test_key_condition_str(): + assert str(init_key_condition) == 'a' + + +def test_key_condition_resolve(): + assert init_key_condition.resolve({'a': [1]}) + assert init_key_condition.resolve({'a': [2]}) + assert init_key_condition.resolve({'a': [0, 1]}) + assert init_key_condition.resolve({'a': [1, 0]}) + assert not init_key_condition.resolve({'a': [0]}) + assert not init_key_condition.resolve({'a': [0, 0]}) + assert not init_key_condition.resolve({'a': []}) + assert not init_key_condition.resolve({'a': [0], 'b': [1]}) + with pytest.raises(ValueError, match='Measurement key a missing when testing classical control'): + _ = init_key_condition.resolve({}) + with pytest.raises(ValueError, match='Measurement key a missing when testing classical control'): + _ = init_key_condition.resolve({'b': [1]}) + + +def test_key_condition_qasm(): + assert init_key_condition.qasm == 'm_a!=0' + + +def test_sympy_condition_with_keys(): + c = init_sympy_condition.with_keys(key_b) + assert c.keys == (key_b,) + + +def test_sympy_condition_str(): + assert str(init_sympy_condition) == "a >= 1" + + +def test_sympy_condition_resolve(): + assert init_sympy_condition.resolve({'a': [1]}) + assert init_sympy_condition.resolve({'a': [2]}) + assert init_sympy_condition.resolve({'a': [0, 1]}) + assert init_sympy_condition.resolve({'a': [1, 0]}) + assert not init_sympy_condition.resolve({'a': [0]}) + assert not init_sympy_condition.resolve({'a': [0, 0]}) + assert not init_sympy_condition.resolve({'a': []}) + assert not init_sympy_condition.resolve({'a': [0], 'b': [1]}) + with pytest.raises(ValueError, match=re.escape("Measurement keys ['a'] missing when testing classical control")): + _ = init_sympy_condition.resolve({}) + with pytest.raises(ValueError, match=re.escape("Measurement keys ['a'] missing when testing classical control")): + _ = init_sympy_condition.resolve({'b': [1]}) + + +def test_sympy_condition_qasm(): + with pytest.raises(NotImplementedError): + _ = init_sympy_condition.qasm From 73ed74b60e98043c8f257bf89b11a139889288f8 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 9 Dec 2021 12:44:42 -0800 Subject: [PATCH 15/89] format --- cirq-core/cirq/value/condition_test.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 2bd15cb0480..94e6a8a4db6 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -42,9 +42,13 @@ def test_key_condition_resolve(): assert not init_key_condition.resolve({'a': [0, 0]}) assert not init_key_condition.resolve({'a': []}) assert not init_key_condition.resolve({'a': [0], 'b': [1]}) - with pytest.raises(ValueError, match='Measurement key a missing when testing classical control'): + with pytest.raises( + ValueError, match='Measurement key a missing when testing classical control' + ): _ = init_key_condition.resolve({}) - with pytest.raises(ValueError, match='Measurement key a missing when testing classical control'): + with pytest.raises( + ValueError, match='Measurement key a missing when testing classical control' + ): _ = init_key_condition.resolve({'b': [1]}) @@ -70,9 +74,13 @@ def test_sympy_condition_resolve(): assert not init_sympy_condition.resolve({'a': [0, 0]}) assert not init_sympy_condition.resolve({'a': []}) assert not init_sympy_condition.resolve({'a': [0], 'b': [1]}) - with pytest.raises(ValueError, match=re.escape("Measurement keys ['a'] missing when testing classical control")): + with pytest.raises( + ValueError, match=re.escape("Measurement keys ['a'] missing when testing classical control") + ): _ = init_sympy_condition.resolve({}) - with pytest.raises(ValueError, match=re.escape("Measurement keys ['a'] missing when testing classical control")): + with pytest.raises( + ValueError, match=re.escape("Measurement keys ['a'] missing when testing classical control") + ): _ = init_sympy_condition.resolve({'b': [1]}) From a415235c5e0165ec5bdc4311d47205d90c9f6b43 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 9 Dec 2021 14:02:34 -0800 Subject: [PATCH 16/89] docstrings --- .../classically_controlled_operation_test.py | 14 +++++--- cirq-core/cirq/ops/raw_types.py | 1 + cirq-core/cirq/value/condition.py | 32 +++++++++++++++++++ 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index e34ae1bfda4..bd8172baeea 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -419,13 +419,17 @@ def test_unmeasured_condition(): def test_sympy(): for i in range(9): for j in range(8): + # Add X gates to put the circuit into a state representing bitstring(j), and measure bitstring = cirq.big_endian_int_to_bits(j, bit_count=3) circuit = cirq.Circuit() for k in range(3): circuit.append(cirq.X(cirq.LineQubit(k)) ** bitstring[k]) - circuit.append(cirq.measure(*cirq.LineQubit.range(3), key='m')) - circuit.append(cirq.X(cirq.LineQubit(3)).with_classical_controls(f'{{m}} > {i}')) - circuit.append(cirq.measure(cirq.LineQubit(3), key='a')) + circuit.append(cirq.measure(*cirq.LineQubit.range(3), key='m_j')) + + # Add a X(q3) conditional on the above measurement (which should be == `j`) being > `i` + circuit.append(cirq.X(cirq.LineQubit(3)).with_classical_controls(f'{{m_j}} > {i}')) + circuit.append(cirq.measure(cirq.LineQubit(3), key='q3')) result = cirq.Simulator().run(circuit) - expected = 1 if j > i else 0 - assert result.measurements['a'][0][0] == expected + + # q3 should now be set iff j > i. + assert result.measurements['q3'][0][0] == (j > i) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index d278674c9d9..57eeff13bd7 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -593,6 +593,7 @@ def _commutes_( @property def classical_controls(self) -> FrozenSet['cirq.Condition']: + """The classical controls gating this operation.""" return frozenset() def with_classical_controls( diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 7a2f0d6db1d..7df95c76fcf 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -32,6 +32,8 @@ class Condition(abc.ABC): + """A classical control condition that can gate an operation.""" + @property @abc.abstractmethod def keys(self) -> Tuple['cirq.MeasurementKey', ...]: @@ -53,6 +55,12 @@ def qasm(self): @dataclasses.dataclass(frozen=True) class KeyCondition(Condition): + """A classical control condition based on a single measurement key. + + This condition resolves to True iff the measurement key is non-zero at the + time of resolution. + """ + key: 'cirq.MeasurementKey' @property @@ -82,6 +90,23 @@ def qasm(self): @dataclasses.dataclass(frozen=True) class SympyCondition(Condition): + """A classical control condition based on a sympy expression. + + This condition resolves to True iff the sympy expression resolves to a + Truthy value when the measurement keys are substituted in as the free + variables. + + To account for the fact that measurement key strings can contain characters + not allowed in sympy variables, we use x0..xN for the free variables and + substitute the measurement keys in by order. For instance a condition with + `expr='x0 > x1', keys=['0:A', '0:B']` would resolve to True iff key `0:A` + is greater than key `0:B`. + + The `cirq.parse_sympy_condition` function automates setting this up + correctly. To create the above expression, one would call + `cirq.parse_sympy_condition('{0:A}, {0:B}')`. + """ + expr: sympy.Expr control_keys: Tuple['cirq.MeasurementKey', ...] @@ -117,6 +142,12 @@ def qasm(self): def parse_sympy_condition(s: str) -> 'cirq.SympyCondition': + """Parses a string into a `cirq.SympyCondition`. + + The measurement keys in a sympy condition string must be wrapped in curly + braces to denote them. For example, to create an expression that checks if + measurement A was greater than measurement B, the proper syntax is + `cirq.parse_sympy_condition('{A} > {B}')`.""" in_key = False key_count = 0 s_out = '' @@ -145,6 +176,7 @@ def parse_sympy_condition(s: str) -> 'cirq.SympyCondition': def parse_condition(s: str) -> 'cirq.Condition': + """Parses a string into a `Condition`.""" try: return parse_sympy_condition(s) except ValueError: From 3a0a56d24e2812b87c291ec1d8a44021bd37fb18 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 9 Dec 2021 14:05:58 -0800 Subject: [PATCH 17/89] subop --- cirq-core/cirq/ops/classically_controlled_operation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 780bbc869bc..3d793a1ebc5 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -177,7 +177,9 @@ def map_condition(condition: 'cirq.Condition') -> 'cirq.Condition': return condition.with_keys(*keys) conditions = [map_condition(c) for c in self._conditions] - return self._sub_operation.with_classical_controls(*conditions) + sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map) + sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation + return sub_operation.with_classical_controls(*conditions) def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation': def map_condition(condition: 'cirq.Condition') -> 'cirq.Condition': @@ -185,7 +187,9 @@ def map_condition(condition: 'cirq.Condition') -> 'cirq.Condition': return condition.with_keys(*keys) conditions = [map_condition(c) for c in self._conditions] - return self._sub_operation.with_classical_controls(*conditions) + sub_operation = protocols.with_key_path_prefix(self._sub_operation, path) + sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation + return sub_operation.with_classical_controls(*conditions) def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: local_keys: FrozenSet[value.MeasurementKey] = frozenset( From f930f6a18266ae7570ce936bf694bc6d4cd5b35d Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 9 Dec 2021 22:16:23 -0800 Subject: [PATCH 18/89] regex --- cirq-core/cirq/value/condition.py | 55 ++++++++++--------------------- 1 file changed, 17 insertions(+), 38 deletions(-) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 7df95c76fcf..0b7735dc8a4 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -11,17 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import ( - Mapping, - Optional, - Sequence, - Tuple, - TYPE_CHECKING, -) - import abc import dataclasses +import re +from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING + import sympy from cirq.protocols import json_serialization @@ -148,37 +142,22 @@ def parse_sympy_condition(s: str) -> 'cirq.SympyCondition': braces to denote them. For example, to create an expression that checks if measurement A was greater than measurement B, the proper syntax is `cirq.parse_sympy_condition('{A} > {B}')`.""" - in_key = False - key_count = 0 - s_out = '' - key_name = '' - keys = [] - for c in s: - if not in_key: - if c == '{': - in_key = True - else: - s_out += c - else: - if c == '}': - symbol_name = f'x{key_count}' - s_out += symbol_name - keys.append(measurement_key.MeasurementKey.parse_serialized(key_name)) - key_name = '' - key_count += 1 - in_key = False - else: - key_name += c - expr = sympy.sympify(s_out) - if len(expr.free_symbols) != len(keys): - raise ValueError(f"'{s}' is not a valid sympy condition") + keys_to_indexes: Dict[str, str] = {} + + def replace(m): + g1 = m.group(1) + if g1 not in keys_to_indexes: + keys_to_indexes[g1] = f'x{len(keys_to_indexes)}' + return keys_to_indexes[g1] + + result = re.sub(r'{([^}]+)}', replace, s) + expr = sympy.sympify(result) + keys = [measurement_key.MeasurementKey.parse_serialized(key) for key in keys_to_indexes.keys()] return SympyCondition(expr, tuple(keys)) def parse_condition(s: str) -> 'cirq.Condition': """Parses a string into a `Condition`.""" - try: - return parse_sympy_condition(s) - except ValueError: - pass - return measurement_key.MeasurementKey.parse_serialized(s) + return ( + parse_sympy_condition(s) if '{' in s else measurement_key.MeasurementKey.parse_serialized(s) + ) From 39a7a951cbabf1ac7b5fa2ca7a5727db09d151ed Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 9 Dec 2021 22:19:05 -0800 Subject: [PATCH 19/89] docs --- cirq-core/cirq/value/condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 0b7735dc8a4..ff19e74cff6 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -98,7 +98,7 @@ class SympyCondition(Condition): The `cirq.parse_sympy_condition` function automates setting this up correctly. To create the above expression, one would call - `cirq.parse_sympy_condition('{0:A}, {0:B}')`. + `cirq.parse_sympy_condition('{0:A} > {0:B}')`. """ expr: sympy.Expr From 42bac3ec8927d01a77cd220f72aa798179359ee1 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 10 Dec 2021 10:44:55 -0800 Subject: [PATCH 20/89] Make test_sympy more intuitive. --- .../classically_controlled_operation_test.py | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index bd8172baeea..cda002da60c 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -417,19 +417,27 @@ def test_unmeasured_condition(): def test_sympy(): - for i in range(9): + q0, q1, q2, q3, q4, q5, q6 = cirq.LineQubit.range(7) + for i in range(8): for j in range(8): - # Add X gates to put the circuit into a state representing bitstring(j), and measure - bitstring = cirq.big_endian_int_to_bits(j, bit_count=3) - circuit = cirq.Circuit() - for k in range(3): - circuit.append(cirq.X(cirq.LineQubit(k)) ** bitstring[k]) - circuit.append(cirq.measure(*cirq.LineQubit.range(3), key='m_j')) - - # Add a X(q3) conditional on the above measurement (which should be == `j`) being > `i` - circuit.append(cirq.X(cirq.LineQubit(3)).with_classical_controls(f'{{m_j}} > {i}')) - circuit.append(cirq.measure(cirq.LineQubit(3), key='q3')) - result = cirq.Simulator().run(circuit) + # Put first three qubits into a state representing bitstring(i), next three qubits + # into a state representing bitstring(j) and measure those into m_i and m_j + # respectively. Then add a conditional X(q6) based on m_i > m_j and measure that. + bitstring_i = cirq.big_endian_int_to_bits(i, bit_count=3) + bitstring_j = cirq.big_endian_int_to_bits(j, bit_count=3) + circuit = cirq.Circuit( + cirq.X(q0) ** bitstring_i[0], + cirq.X(q1) ** bitstring_i[1], + cirq.X(q2) ** bitstring_i[2], + cirq.X(q3) ** bitstring_j[0], + cirq.X(q4) ** bitstring_j[1], + cirq.X(q5) ** bitstring_j[2], + cirq.measure(q0, q1, q2, key='m_i'), + cirq.measure(q3, q4, q5, key='m_j'), + cirq.X(q6).with_classical_controls('{m_j} > {m_i}'), + cirq.measure(q6, key='m_q6'), + ) - # q3 should now be set iff j > i. - assert result.measurements['q3'][0][0] == (j > i) + # m_q6 should now be set iff j > i. + result = cirq.Simulator().run(circuit) + assert result.measurements['m_q6'][0][0] == (j > i) From f4ea9d84b60bb9a542b6c450d690f60d114ce783 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 13 Dec 2021 17:05:08 -0800 Subject: [PATCH 21/89] Sympy str roundtrip --- cirq-core/cirq/value/condition.py | 9 +++++++-- cirq-core/cirq/value/condition_test.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index ff19e74cff6..eb96b150a3e 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -113,8 +113,13 @@ def with_keys(self, *keys: 'cirq.MeasurementKey'): return dataclasses.replace(self, control_keys=keys) def __str__(self): - replacements = {f'x{i}': str(key) for i, key in enumerate(self.control_keys)} - return f"{self.expr.subs(replacements)}" + replacements = {f'x{i}': f'{{{str(key)}}}' for i, key in enumerate(self.control_keys)} + + class CustomCodePrinter(sympy.printing.StrPrinter): + def _print_Symbol(self, expr): + return replacements[expr.name] + + return CustomCodePrinter().doprint(self.expr) def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: missing = [str(k) for k in self.keys if str(k) not in measurements] diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 94e6a8a4db6..cb906550511 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -62,7 +62,7 @@ def test_sympy_condition_with_keys(): def test_sympy_condition_str(): - assert str(init_sympy_condition) == "a >= 1" + assert str(init_sympy_condition) == '{a} >= 1' def test_sympy_condition_resolve(): From 71f61f52bebefc26d7d074365537c2f27bf1c10a Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 15 Dec 2021 16:07:24 -0800 Subject: [PATCH 22/89] Resolve some code review comments --- .../ops/classically_controlled_operation.py | 16 +++------ .../classically_controlled_operation_test.py | 34 +++++++++---------- .../protocols/measurement_key_protocol.py | 13 ++++--- cirq-core/cirq/value/condition.py | 24 +++++++++---- 4 files changed, 46 insertions(+), 41 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 3d793a1ebc5..b6565ebb205 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -172,22 +172,14 @@ def _act_on_(self, args: 'cirq.ActOnArgs') -> bool: def _with_measurement_key_mapping_( self, key_map: Dict[str, str] ) -> 'ClassicallyControlledOperation': - def map_condition(condition: 'cirq.Condition') -> 'cirq.Condition': - keys = [protocols.with_measurement_key_mapping(k, key_map) for k in condition.keys] - return condition.with_keys(*keys) - - conditions = [map_condition(c) for c in self._conditions] + conditions = [protocols.with_measurement_key_mapping(c, key_map) for c in self._conditions] sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map) sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation return sub_operation.with_classical_controls(*conditions) - def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation': - def map_condition(condition: 'cirq.Condition') -> 'cirq.Condition': - keys = [protocols.with_key_path_prefix(k, path) for k in condition.keys] - return condition.with_keys(*keys) - - conditions = [map_condition(c) for c in self._conditions] - sub_operation = protocols.with_key_path_prefix(self._sub_operation, path) + def _with_key_path_prefix_(self, prefix: Tuple[str, ...]) -> 'ClassicallyControlledOperation': + conditions = [protocols.with_key_path_prefix(c, prefix) for c in self._conditions] + sub_operation = protocols.with_key_path_prefix(self._sub_operation, prefix) sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation return sub_operation.with_classical_controls(*conditions) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index cda002da60c..a36385b9794 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -417,27 +417,25 @@ def test_unmeasured_condition(): def test_sympy(): - q0, q1, q2, q3, q4, q5, q6 = cirq.LineQubit.range(7) - for i in range(8): - for j in range(8): - # Put first three qubits into a state representing bitstring(i), next three qubits - # into a state representing bitstring(j) and measure those into m_i and m_j - # respectively. Then add a conditional X(q6) based on m_i > m_j and measure that. - bitstring_i = cirq.big_endian_int_to_bits(i, bit_count=3) - bitstring_j = cirq.big_endian_int_to_bits(j, bit_count=3) + q0, q1, q2, q3, q_result = cirq.LineQubit.range(5) + for i in range(4): + for j in range(4): + # Put first two qubits into a state representing bitstring(i), next two qubits into a + # state representing bitstring(j) and measure those into m_i and m_j respectively. Then + # add a conditional X(q_result) based on m_i > m_j and measure that. + bitstring_i = cirq.big_endian_int_to_bits(i, bit_count=2) + bitstring_j = cirq.big_endian_int_to_bits(j, bit_count=2) circuit = cirq.Circuit( cirq.X(q0) ** bitstring_i[0], cirq.X(q1) ** bitstring_i[1], - cirq.X(q2) ** bitstring_i[2], - cirq.X(q3) ** bitstring_j[0], - cirq.X(q4) ** bitstring_j[1], - cirq.X(q5) ** bitstring_j[2], - cirq.measure(q0, q1, q2, key='m_i'), - cirq.measure(q3, q4, q5, key='m_j'), - cirq.X(q6).with_classical_controls('{m_j} > {m_i}'), - cirq.measure(q6, key='m_q6'), + cirq.X(q2) ** bitstring_j[0], + cirq.X(q3) ** bitstring_j[1], + cirq.measure(q0, q1, key='m_i'), + cirq.measure(q2, q3, key='m_j'), + cirq.X(q_result).with_classical_controls('{m_j} > {m_i}'), + cirq.measure(q_result, key='result'), ) - # m_q6 should now be set iff j > i. + # m_q4 should now be set iff j > i. result = cirq.Simulator().run(circuit) - assert result.measurements['m_q6'][0][0] == (j > i) + assert result.measurements['result'][0][0] == (j > i) diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 09d8b60bdf3..31a60368e5d 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -13,13 +13,16 @@ # limitations under the License. """Protocol for object that have measurement keys.""" -from typing import AbstractSet, Any, Dict, Optional, Tuple +from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING from typing_extensions import Protocol from cirq._doc import doc_private from cirq import value +if TYPE_CHECKING: + import cirq + # This is a special indicator value used by the inverse method to determine # whether or not the caller provided a 'default' argument. RaiseTypeErrorIfNotProvided = ([],) # type: Any @@ -56,7 +59,7 @@ def _is_measurement_(self) -> bool: """Return if this object is (or contains) a measurement.""" @doc_private - def _measurement_key_obj_(self) -> value.MeasurementKey: + def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': """Return the key object that will be used to identify this measurement. When a measurement occurs, either on hardware, or in a simulation, @@ -65,7 +68,7 @@ def _measurement_key_obj_(self) -> value.MeasurementKey: """ @doc_private - def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: + def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: """Return the key objects for measurements performed by the receiving object. When a measurement occurs, either on hardware, or in a simulation, @@ -169,7 +172,7 @@ def measurement_key_name(val: Any, default: Any = RaiseTypeErrorIfNotProvided): def _measurement_key_objs_from_magic_methods( val: Any, -) -> Optional[AbstractSet[value.MeasurementKey]]: +) -> Optional[AbstractSet['cirq.MeasurementKey']]: """Uses the measurement key related magic methods to get the `MeasurementKey`s for this object.""" @@ -201,7 +204,7 @@ def _measurement_key_names_from_magic_methods(val: Any) -> Optional[AbstractSet[ return result -def measurement_key_objs(val: Any) -> AbstractSet[value.MeasurementKey]: +def measurement_key_objs(val: Any) -> AbstractSet['cirq.MeasurementKey']: """Gets the measurement key objects of measurements within the given value. Args: diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index eb96b150a3e..c208fe977ec 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -18,7 +18,7 @@ import sympy -from cirq.protocols import json_serialization +from cirq.protocols import json_serialization, measurement_key_protocol as mkp from cirq.value import digits, measurement_key if TYPE_CHECKING: @@ -46,6 +46,14 @@ def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: def qasm(self): """Returns the qasm of this condition.""" + def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'Condition': + keys = [mkp.with_measurement_key_mapping(k, key_map) for k in self.keys] + return self.with_keys(*keys) + + def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'Condition': + keys = [mkp.with_key_path_prefix(k, path) for k in self.keys] + return self.with_keys(*keys) + @dataclasses.dataclass(frozen=True) class KeyCondition(Condition): @@ -62,7 +70,8 @@ def keys(self): return (self.key,) def with_keys(self, *keys: 'cirq.MeasurementKey'): - assert len(keys) == 1 + if len(keys) != 1: + raise ValueError(f'Cannot apply multiple keys to a KeyCondition') return KeyCondition(keys[0]) def __str__(self): @@ -87,8 +96,8 @@ class SympyCondition(Condition): """A classical control condition based on a sympy expression. This condition resolves to True iff the sympy expression resolves to a - Truthy value when the measurement keys are substituted in as the free - variables. + truthy value (i.e. `bool(x) == True`) when the measurement keys are + substituted in as the free variables. To account for the fact that measurement key strings can contain characters not allowed in sympy variables, we use x0..xN for the free variables and @@ -109,7 +118,8 @@ def keys(self): return self.control_keys def with_keys(self, *keys: 'cirq.MeasurementKey'): - assert len(keys) == len(self.control_keys) + if len(keys) != len(self.control_keys): + raise ValueError(f'Wrong number of keys applied to this condition.') return dataclasses.replace(self, control_keys=keys) def __str__(self): @@ -164,5 +174,7 @@ def replace(m): def parse_condition(s: str) -> 'cirq.Condition': """Parses a string into a `Condition`.""" return ( - parse_sympy_condition(s) if '{' in s else measurement_key.MeasurementKey.parse_serialized(s) + parse_sympy_condition(s) + if '{' in s + else KeyCondition(measurement_key.MeasurementKey.parse_serialized(s)) ) From 22613554b6ef72859e7b295000efa2f5ee8fc0e5 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 15 Dec 2021 17:14:34 -0800 Subject: [PATCH 23/89] Add escape key to parse_sympy_condition --- cirq-core/cirq/value/condition.py | 51 ++++++++++++++++++++------ cirq-core/cirq/value/condition_test.py | 39 ++++++++++++++++++++ 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index c208fe977ec..09f2147cca8 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -156,18 +156,45 @@ def parse_sympy_condition(s: str) -> 'cirq.SympyCondition': The measurement keys in a sympy condition string must be wrapped in curly braces to denote them. For example, to create an expression that checks if measurement A was greater than measurement B, the proper syntax is - `cirq.parse_sympy_condition('{A} > {B}')`.""" - keys_to_indexes: Dict[str, str] = {} - - def replace(m): - g1 = m.group(1) - if g1 not in keys_to_indexes: - keys_to_indexes[g1] = f'x{len(keys_to_indexes)}' - return keys_to_indexes[g1] - - result = re.sub(r'{([^}]+)}', replace, s) - expr = sympy.sympify(result) - keys = [measurement_key.MeasurementKey.parse_serialized(key) for key in keys_to_indexes.keys()] + `cirq.parse_sympy_condition('{A} > {B}')`. + + A backslash can be used to treat the subsequent character as a literal + within the key name, in case braces or backslashes appear in the key name. + """ + in_key = False + key_count = 0 + s_out = '' + key_name = '' + keys = [] + i = 0 + while i < len(s): + c = s[i] + if not in_key: + if c == '{': + in_key = True + else: + s_out += c + else: + if c == '}': + symbol_name = f'x{key_count}' + s_out += symbol_name + keys.append(measurement_key.MeasurementKey.parse_serialized(key_name)) + key_name = '' + key_count += 1 + in_key = False + else: + if c == '\\': + i += 1 + if i == len(c): + raise ValueError(f"'{s}' is not a valid sympy condition") + c = s[i] + key_name += c + i += 1 + if in_key: + raise ValueError(f"'{s}' is not a valid sympy condition") + expr = sympy.sympify(s_out) + if len(expr.free_symbols) != len(keys): + raise ValueError(f"'{s}' is not a valid sympy condition") return SympyCondition(expr, tuple(keys)) diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index cb906550511..b414a97f97c 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -87,3 +87,42 @@ def test_sympy_condition_resolve(): def test_sympy_condition_qasm(): with pytest.raises(NotImplementedError): _ = init_sympy_condition.qasm + + +def test_parse_sympy_condition(): + c = cirq.parse_sympy_condition('{a} > {b}') + assert len(c.keys) == 2 + assert c.keys[0].name == 'a' + assert c.keys[1].name == 'b' + + +def test_parse_sympy_condition_escaping(): + c = cirq.parse_sympy_condition('{a\\{\\}\\\\} + 2') + assert len(c.keys) == 1 + assert c.keys[0].name == 'a{}\\' + + +def test_parse_sympy_condition_errors(): + with pytest.raises(ValueError): + _ = cirq.parse_sympy_condition('{a} > {b') + with pytest.raises(ValueError): + _ = cirq.parse_sympy_condition('{a} > {b}}') + with pytest.raises(ValueError): + _ = cirq.parse_sympy_condition('[]]23[42][][@#{$}') + + +def test_parse_condition(): + c = cirq.parse_condition('{a} > {b}') + assert isinstance(c, cirq.SympyCondition) + assert len(c.keys) == 2 + assert c.keys[0].name == 'a' + assert c.keys[1].name == 'b' + c = cirq.parse_condition('a') + assert isinstance(c, cirq.KeyCondition) + assert len(c.keys) == 1 + assert c.keys[0].name == 'a' + c = cirq.parse_condition('0:a') + assert isinstance(c, cirq.KeyCondition) + assert len(c.keys) == 1 + assert c.keys[0].name == 'a' + assert c.keys[0].path == ('0',) From 6b36357a909dd2b01ddc68738742589957ef3c60 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 15 Dec 2021 17:38:06 -0800 Subject: [PATCH 24/89] repr --- .../cirq/ops/classically_controlled_operation_test.py | 2 +- cirq-core/cirq/value/condition.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index a36385b9794..79d193f6fc4 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -396,7 +396,7 @@ def test_repr(): op = cirq.X(q0).with_classical_controls('a') assert repr(op) == ( "cirq.ClassicallyControlledOperation(" - "cirq.X(cirq.LineQubit(0)), [KeyCondition(key=cirq.MeasurementKey(name='a'))]" + "cirq.X(cirq.LineQubit(0)), [cirq.KeyCondition(cirq.MeasurementKey(name='a'))]" ")" ) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 09f2147cca8..41f2f7477b2 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -13,11 +13,11 @@ # limitations under the License. import abc import dataclasses -import re from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING import sympy +from cirq._compat import proper_repr from cirq.protocols import json_serialization, measurement_key_protocol as mkp from cirq.value import digits, measurement_key @@ -77,6 +77,9 @@ def with_keys(self, *keys: 'cirq.MeasurementKey'): def __str__(self): return str(self.key) + def __repr__(self): + return f'cirq.KeyCondition({self.key!r})' + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: key = str(self.key) if key not in measurements: @@ -131,6 +134,9 @@ def _print_Symbol(self, expr): return CustomCodePrinter().doprint(self.expr) + def __repr__(self): + return f'cirq.SympyCondition({proper_repr(self.expr)}, {self.keys!r})' + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: missing = [str(k) for k in self.keys if str(k) not in measurements] if missing: From afbf3c9b2032d08e1c1e20ee6de8ff9d7b11b819 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 15 Dec 2021 17:56:35 -0800 Subject: [PATCH 25/89] coverage --- .../cirq/ops/classically_controlled_operation_test.py | 9 ++++++++- cirq-core/cirq/value/condition.py | 2 +- cirq-core/cirq/value/condition_test.py | 8 +++++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 79d193f6fc4..7f5f8dda254 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -330,10 +330,17 @@ def test_key_set_in_subcircuit_outer_scope(): assert result.measurements['b'] == 1 +def test_condition_types(): + q0 = cirq.LineQubit(0) + sympy_cond = cirq.parse_sympy_condition('{a} >= 2') + op = cirq.X(q0).with_classical_controls(cirq.MeasurementKey('a'), 'b', '{a} > {b}', sympy_cond) + assert set(map(str, op.classical_controls)) == {'a', 'b', '{a} > {b}', '{a} >= 2'} + + def test_condition_flattening(): q0 = cirq.LineQubit(0) op = cirq.X(q0).with_classical_controls('a').with_classical_controls('b') - assert set(map(str, op._conditions)) == {'a', 'b'} + assert set(map(str, op.classical_controls)) == {'a', 'b'} assert isinstance(op._sub_operation, cirq.GateOperation) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 41f2f7477b2..dcaa3385878 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -191,7 +191,7 @@ def parse_sympy_condition(s: str) -> 'cirq.SympyCondition': else: if c == '\\': i += 1 - if i == len(c): + if i == len(s): raise ValueError(f"'{s}' is not a valid sympy condition") c = s[i] key_name += c diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index b414a97f97c..25e4af4dd35 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -27,6 +27,8 @@ def test_key_condition_with_keys(): c = init_key_condition.with_keys(key_b) assert c.key is key_b + with pytest.raises(ValueError, match='Cannot apply multiple keys to a KeyCondition'): + _ = c.with_keys(key_b, key_a) def test_key_condition_str(): @@ -59,6 +61,8 @@ def test_key_condition_qasm(): def test_sympy_condition_with_keys(): c = init_sympy_condition.with_keys(key_b) assert c.keys == (key_b,) + with pytest.raises(ValueError, match='Wrong number of keys applied to this condition'): + _ = c.with_keys(key_b, key_a) def test_sympy_condition_str(): @@ -108,7 +112,9 @@ def test_parse_sympy_condition_errors(): with pytest.raises(ValueError): _ = cirq.parse_sympy_condition('{a} > {b}}') with pytest.raises(ValueError): - _ = cirq.parse_sympy_condition('[]]23[42][][@#{$}') + _ = cirq.parse_sympy_condition('a + {b}') + with pytest.raises(ValueError): + _ = cirq.parse_sympy_condition('{a\\') def test_parse_condition(): From 58fb2dc647b1275b4385eb116f020da4f9f87a79 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 15 Dec 2021 18:15:50 -0800 Subject: [PATCH 26/89] coverage --- cirq-core/cirq/value/condition_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 25e4af4dd35..3b339d0d229 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -35,6 +35,10 @@ def test_key_condition_str(): assert str(init_key_condition) == 'a' +def test_key_condition_repr(): + assert repr(init_key_condition) == "cirq.KeyCondition(cirq.MeasurementKey(name='a'))" + + def test_key_condition_resolve(): assert init_key_condition.resolve({'a': [1]}) assert init_key_condition.resolve({'a': [2]}) @@ -69,6 +73,14 @@ def test_sympy_condition_str(): assert str(init_sympy_condition) == '{a} >= 1' +def test_sympy_condition_repr(): + assert ( + repr(cirq.parse_sympy_condition('{a} * {b}')) == "cirq.SympyCondition(" + "sympy.Mul(sympy.Symbol('x0'), sympy.Symbol('x1')), " + "(cirq.MeasurementKey(name='a'), cirq.MeasurementKey(name='b')))" + ) + + def test_sympy_condition_resolve(): assert init_sympy_condition.resolve({'a': [1]}) assert init_sympy_condition.resolve({'a': [2]}) From bd80c0b91cf03c97ac8c3c717ea4afee8ead3296 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 17 Dec 2021 11:25:21 -0800 Subject: [PATCH 27/89] parser --- cirq-core/cirq/__init__.py | 2 - .../ops/classically_controlled_operation.py | 8 +- .../classically_controlled_operation_test.py | 9 +- cirq-core/cirq/ops/raw_types.py | 8 +- cirq-core/cirq/value/__init__.py | 2 - cirq-core/cirq/value/condition.py | 111 ++++-------------- cirq-core/cirq/value/condition_test.py | 64 ++-------- 7 files changed, 48 insertions(+), 156 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 42cd728f381..b6f1549496e 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -481,8 +481,6 @@ LinearDict, MEASUREMENT_KEY_SEPARATOR, MeasurementKey, - parse_condition, - parse_sympy_condition, PeriodicValue, RANDOM_STATE_OR_SEED_LIKE, state_vector_to_probabilities, diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index b6565ebb205..f6188f29785 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -24,6 +24,8 @@ Union, ) +import sympy + from cirq import protocols, value from cirq.ops import raw_types @@ -47,7 +49,7 @@ class ClassicallyControlledOperation(raw_types.Operation): def __init__( self, sub_operation: 'cirq.Operation', - conditions: Sequence[Union[str, 'cirq.MeasurementKey', 'cirq.Condition']], + conditions: Sequence[Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Expr]], ): """Initializes a `ClassicallyControlledOperation`. @@ -76,9 +78,11 @@ def __init__( conds: List['cirq.Condition'] = [] for c in conditions: if isinstance(c, str): - c = value.parse_condition(c) + c = value.MeasurementKey.parse_serialized(c) if isinstance(c, value.MeasurementKey): c = value.KeyCondition(c) + if isinstance(c, sympy.Expr): + c = value.SympyCondition(c) conds.append(c) self._conditions: Tuple['cirq.Condition', ...] = tuple(conds) self._sub_operation: 'cirq.Operation' = sub_operation diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 7f5f8dda254..0992c9619fc 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import sympy +from sympy.parsing import sympy_parser import cirq @@ -332,9 +333,9 @@ def test_key_set_in_subcircuit_outer_scope(): def test_condition_types(): q0 = cirq.LineQubit(0) - sympy_cond = cirq.parse_sympy_condition('{a} >= 2') - op = cirq.X(q0).with_classical_controls(cirq.MeasurementKey('a'), 'b', '{a} > {b}', sympy_cond) - assert set(map(str, op.classical_controls)) == {'a', 'b', '{a} > {b}', '{a} >= 2'} + sympy_cond = sympy_parser.parse_expr('a >= 2') + op = cirq.X(q0).with_classical_controls(cirq.MeasurementKey('a'), 'b', 'a > b', sympy_cond) + assert set(map(str, op.classical_controls)) == {'a', 'b', 'a > b', 'a >= 2'} def test_condition_flattening(): @@ -439,7 +440,7 @@ def test_sympy(): cirq.X(q3) ** bitstring_j[1], cirq.measure(q0, q1, key='m_i'), cirq.measure(q2, q3, key='m_j'), - cirq.X(q_result).with_classical_controls('{m_j} > {m_i}'), + cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m_j > m_i')), cirq.measure(q_result, key='result'), ) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 57eeff13bd7..298b213be6b 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -35,6 +35,7 @@ ) import numpy as np +import sympy from cirq import protocols, value from cirq._import import LazyLoader @@ -597,7 +598,7 @@ def classical_controls(self) -> FrozenSet['cirq.Condition']: return frozenset() def with_classical_controls( - self, *conditions: Union[str, 'cirq.MeasurementKey', 'cirq.Condition'] + self, *conditions: Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Expr] ) -> 'cirq.ClassicallyControlledOperation': """Returns a classically controlled version of this operation. @@ -610,8 +611,9 @@ def with_classical_controls( since tags are considered a local attribute. Args: - conditions: A list of measurement keys, or strings that can be - parsed into measurement keys. + conditions: A list of measurement keys, strings that can be parsed + into measurement keys, or sympy expressions where the free + symbols are measurement key strings. Returns: A `ClassicallyControlledOperation` wrapping the operation. diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index 05e92d0777a..da6cfc2b058 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -28,8 +28,6 @@ from cirq.value.condition import ( Condition, KeyCondition, - parse_condition, - parse_sympy_condition, SympyCondition, ) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index dcaa3385878..0d08984b231 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -34,7 +34,7 @@ def keys(self) -> Tuple['cirq.MeasurementKey', ...]: """Gets the control keys.""" @abc.abstractmethod - def with_keys(self, *keys: 'cirq.MeasurementKey'): + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): """Replaces the control keys.""" @abc.abstractmethod @@ -47,12 +47,16 @@ def qasm(self): """Returns the qasm of this condition.""" def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'Condition': - keys = [mkp.with_measurement_key_mapping(k, key_map) for k in self.keys] - return self.with_keys(*keys) + condition = self + for k in self.keys: + condition = condition.replace_key(k, mkp.with_measurement_key_mapping(k, key_map)) + return condition def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'Condition': - keys = [mkp.with_key_path_prefix(k, path) for k in self.keys] - return self.with_keys(*keys) + condition = self + for k in self.keys: + condition = condition.replace_key(k, mkp.with_key_path_prefix(k, path)) + return condition @dataclasses.dataclass(frozen=True) @@ -69,10 +73,8 @@ class KeyCondition(Condition): def keys(self): return (self.key,) - def with_keys(self, *keys: 'cirq.MeasurementKey'): - if len(keys) != 1: - raise ValueError(f'Cannot apply multiple keys to a KeyCondition') - return KeyCondition(keys[0]) + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): + return KeyCondition(replacement) if self.key == current else self def __str__(self): return str(self.key) @@ -101,41 +103,25 @@ class SympyCondition(Condition): This condition resolves to True iff the sympy expression resolves to a truthy value (i.e. `bool(x) == True`) when the measurement keys are substituted in as the free variables. - - To account for the fact that measurement key strings can contain characters - not allowed in sympy variables, we use x0..xN for the free variables and - substitute the measurement keys in by order. For instance a condition with - `expr='x0 > x1', keys=['0:A', '0:B']` would resolve to True iff key `0:A` - is greater than key `0:B`. - - The `cirq.parse_sympy_condition` function automates setting this up - correctly. To create the above expression, one would call - `cirq.parse_sympy_condition('{0:A} > {0:B}')`. """ expr: sympy.Expr - control_keys: Tuple['cirq.MeasurementKey', ...] @property def keys(self): - return self.control_keys + return tuple( + measurement_key.MeasurementKey.parse_serialized(symbol.name) + for symbol in self.expr.free_symbols + ) - def with_keys(self, *keys: 'cirq.MeasurementKey'): - if len(keys) != len(self.control_keys): - raise ValueError(f'Wrong number of keys applied to this condition.') - return dataclasses.replace(self, control_keys=keys) + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): + return dataclasses.replace(self, expr=self.expr.subs({str(current): str(replacement)})) def __str__(self): - replacements = {f'x{i}': f'{{{str(key)}}}' for i, key in enumerate(self.control_keys)} - - class CustomCodePrinter(sympy.printing.StrPrinter): - def _print_Symbol(self, expr): - return replacements[expr.name] - - return CustomCodePrinter().doprint(self.expr) + return str(self.expr) def __repr__(self): - return f'cirq.SympyCondition({proper_repr(self.expr)}, {self.keys!r})' + return f'cirq.SympyCondition({proper_repr(self.expr)})' def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: missing = [str(k) for k in self.keys if str(k) not in measurements] @@ -145,7 +131,7 @@ def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: def value(k): return digits.big_endian_bits_to_int(measurements[str(k)]) - replacements = {f'x{i}': value(k) for i, k in enumerate(self.keys)} + replacements = {str(k): value(k) for k in self.keys} return bool(self.expr.subs(replacements)) def _json_dict_(self): @@ -154,60 +140,3 @@ def _json_dict_(self): @property def qasm(self): raise NotImplementedError() - - -def parse_sympy_condition(s: str) -> 'cirq.SympyCondition': - """Parses a string into a `cirq.SympyCondition`. - - The measurement keys in a sympy condition string must be wrapped in curly - braces to denote them. For example, to create an expression that checks if - measurement A was greater than measurement B, the proper syntax is - `cirq.parse_sympy_condition('{A} > {B}')`. - - A backslash can be used to treat the subsequent character as a literal - within the key name, in case braces or backslashes appear in the key name. - """ - in_key = False - key_count = 0 - s_out = '' - key_name = '' - keys = [] - i = 0 - while i < len(s): - c = s[i] - if not in_key: - if c == '{': - in_key = True - else: - s_out += c - else: - if c == '}': - symbol_name = f'x{key_count}' - s_out += symbol_name - keys.append(measurement_key.MeasurementKey.parse_serialized(key_name)) - key_name = '' - key_count += 1 - in_key = False - else: - if c == '\\': - i += 1 - if i == len(s): - raise ValueError(f"'{s}' is not a valid sympy condition") - c = s[i] - key_name += c - i += 1 - if in_key: - raise ValueError(f"'{s}' is not a valid sympy condition") - expr = sympy.sympify(s_out) - if len(expr.free_symbols) != len(keys): - raise ValueError(f"'{s}' is not a valid sympy condition") - return SympyCondition(expr, tuple(keys)) - - -def parse_condition(s: str) -> 'cirq.Condition': - """Parses a string into a `Condition`.""" - return ( - parse_sympy_condition(s) - if '{' in s - else KeyCondition(measurement_key.MeasurementKey.parse_serialized(s)) - ) diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 3b339d0d229..47ac4788fb5 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -15,20 +15,22 @@ import re import pytest +import sympy.parsing.sympy_parser import cirq key_a = cirq.MeasurementKey('a') key_b = cirq.MeasurementKey('b') +key_c = cirq.MeasurementKey('c') init_key_condition = cirq.KeyCondition(key_a) -init_sympy_condition = cirq.parse_sympy_condition('{a} >= 1') +init_sympy_condition = cirq.SympyCondition(sympy.parsing.sympy_parser.parse_expr('a >= 1')) def test_key_condition_with_keys(): - c = init_key_condition.with_keys(key_b) + c = init_key_condition.replace_key(key_a, key_b) assert c.key is key_b - with pytest.raises(ValueError, match='Cannot apply multiple keys to a KeyCondition'): - _ = c.with_keys(key_b, key_a) + c = init_key_condition.replace_key(key_b, key_c) + assert c.key is key_a def test_key_condition_str(): @@ -63,21 +65,20 @@ def test_key_condition_qasm(): def test_sympy_condition_with_keys(): - c = init_sympy_condition.with_keys(key_b) + c = init_sympy_condition.replace_key(key_a, key_b) assert c.keys == (key_b,) - with pytest.raises(ValueError, match='Wrong number of keys applied to this condition'): - _ = c.with_keys(key_b, key_a) + c = init_sympy_condition.replace_key(key_b, key_c) + assert c.keys == (key_a,) def test_sympy_condition_str(): - assert str(init_sympy_condition) == '{a} >= 1' + assert str(init_sympy_condition) == 'a >= 1' def test_sympy_condition_repr(): assert ( - repr(cirq.parse_sympy_condition('{a} * {b}')) == "cirq.SympyCondition(" - "sympy.Mul(sympy.Symbol('x0'), sympy.Symbol('x1')), " - "(cirq.MeasurementKey(name='a'), cirq.MeasurementKey(name='b')))" + repr(init_sympy_condition) + == "cirq.SympyCondition(GreaterThan(sympy.Symbol('a'), sympy.Integer(1)))" ) @@ -103,44 +104,3 @@ def test_sympy_condition_resolve(): def test_sympy_condition_qasm(): with pytest.raises(NotImplementedError): _ = init_sympy_condition.qasm - - -def test_parse_sympy_condition(): - c = cirq.parse_sympy_condition('{a} > {b}') - assert len(c.keys) == 2 - assert c.keys[0].name == 'a' - assert c.keys[1].name == 'b' - - -def test_parse_sympy_condition_escaping(): - c = cirq.parse_sympy_condition('{a\\{\\}\\\\} + 2') - assert len(c.keys) == 1 - assert c.keys[0].name == 'a{}\\' - - -def test_parse_sympy_condition_errors(): - with pytest.raises(ValueError): - _ = cirq.parse_sympy_condition('{a} > {b') - with pytest.raises(ValueError): - _ = cirq.parse_sympy_condition('{a} > {b}}') - with pytest.raises(ValueError): - _ = cirq.parse_sympy_condition('a + {b}') - with pytest.raises(ValueError): - _ = cirq.parse_sympy_condition('{a\\') - - -def test_parse_condition(): - c = cirq.parse_condition('{a} > {b}') - assert isinstance(c, cirq.SympyCondition) - assert len(c.keys) == 2 - assert c.keys[0].name == 'a' - assert c.keys[1].name == 'b' - c = cirq.parse_condition('a') - assert isinstance(c, cirq.KeyCondition) - assert len(c.keys) == 1 - assert c.keys[0].name == 'a' - c = cirq.parse_condition('0:a') - assert isinstance(c, cirq.KeyCondition) - assert len(c.keys) == 1 - assert c.keys[0].name == 'a' - assert c.keys[0].path == ('0',) From c39a5727ad377a662a7eefbc1b07e862f89002dd Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 17 Dec 2021 12:24:47 -0800 Subject: [PATCH 28/89] Improve sympy repr --- cirq-core/cirq/_compat.py | 17 +++-- .../classically_controlled_operation_test.py | 6 +- .../json_test_data/SympyCondition.json | 22 +++--- .../json_test_data/SympyCondition.repr | 2 +- cirq-core/cirq/value/condition.py | 2 +- cirq-core/cirq/value/condition_test.py | 69 ++++++++++--------- 6 files changed, 66 insertions(+), 52 deletions(-) diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 8327e1d4cf4..c56ff4b853b 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -27,21 +27,24 @@ import numpy as np import pandas as pd import sympy +import sympy.printing.repr def proper_repr(value: Any) -> str: """Overrides sympy and numpy returning repr strings that don't parse.""" if isinstance(value, sympy.Basic): - result = sympy.srepr(value) - # HACK: work around https://github.com/sympy/sympy/issues/16074 - # (only handles a few cases) - fixed_tokens = ['Symbol', 'pi', 'Mul', 'Pow', 'Add', 'Mod', 'Integer', 'Float', 'Rational'] - for token in fixed_tokens: - result = result.replace(token, 'sympy.' + token) + fixed_tokens = dir(sympy) + + class Printer(sympy.printing.repr.ReprPrinter): + def _print(self, expr, **kwargs): + s = super()._print(expr, **kwargs) + if any(s.startswith(t) for t in fixed_tokens): + return 'sympy.' + s + return s - return result + return Printer().doprint(value) if isinstance(value, np.ndarray): if np.issubdtype(value.dtype, np.datetime64): diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 0992c9619fc..7fb783013f9 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -441,9 +441,9 @@ def test_sympy(): cirq.measure(q0, q1, key='m_i'), cirq.measure(q2, q3, key='m_j'), cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m_j > m_i')), - cirq.measure(q_result, key='result'), + cirq.measure(q_result, key='m_result'), ) - # m_q4 should now be set iff j > i. + # m_result should now be set iff j > i. result = cirq.Simulator().run(circuit) - assert result.measurements['result'][0][0] == (j > i) + assert result.measurements['m_result'][0][0] == (j > i) diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.json b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json index 0a8c2522a69..1e42bf5b8ad 100644 --- a/cirq-core/cirq/protocols/json_test_data/SympyCondition.json +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json @@ -1,11 +1,17 @@ { "cirq_type": "SympyCondition", - "expr": "x > 5", - "control_keys": [ - { - "cirq_type": "MeasurementKey", - "name": "a", - "path": [] - } - ] + "expr": + { + "cirq_type": "sympy.Mul", + "args": [ + { + "cirq_type": "sympy.Symbol", + "name": "a" + }, + { + "cirq_type": "sympy.Symbol", + "name": "b" + } + ] + } } \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr index 9b465aaa679..66ad4c5ced0 100644 --- a/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr @@ -1 +1 @@ -cirq.SympyCondition(control_keys=[cirq.MeasurementKey('a')], expr='x > 5') \ No newline at end of file +cirq.SympyCondition(sympy.Mul(sympy.Symbol('a'), sympy.Symbol('b'))) \ No newline at end of file diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 0d08984b231..903abfab6fb 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -115,7 +115,7 @@ def keys(self): ) def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): - return dataclasses.replace(self, expr=self.expr.subs({str(current): str(replacement)})) + return SympyCondition(self.expr.subs({str(current): sympy.Symbol(str(replacement))})) def __str__(self): return str(self.expr) diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 47ac4788fb5..78dccaa683d 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -15,15 +15,16 @@ import re import pytest -import sympy.parsing.sympy_parser +import sympy import cirq -key_a = cirq.MeasurementKey('a') -key_b = cirq.MeasurementKey('b') -key_c = cirq.MeasurementKey('c') +key_a = cirq.MeasurementKey.parse_serialized('0:a') +key_b = cirq.MeasurementKey.parse_serialized('0:b') +key_c = cirq.MeasurementKey.parse_serialized('0:c') +key_x = cirq.MeasurementKey.parse_serialized('0:c') init_key_condition = cirq.KeyCondition(key_a) -init_sympy_condition = cirq.SympyCondition(sympy.parsing.sympy_parser.parse_expr('a >= 1')) +init_sympy_condition = cirq.SympyCondition(sympy.Symbol('0:a') >= 1) def test_key_condition_with_keys(): @@ -34,34 +35,36 @@ def test_key_condition_with_keys(): def test_key_condition_str(): - assert str(init_key_condition) == 'a' + assert str(init_key_condition) == '0:a' def test_key_condition_repr(): - assert repr(init_key_condition) == "cirq.KeyCondition(cirq.MeasurementKey(name='a'))" + assert ( + repr(init_key_condition) == "cirq.KeyCondition(cirq.MeasurementKey(path=('0',), name='a'))" + ) def test_key_condition_resolve(): - assert init_key_condition.resolve({'a': [1]}) - assert init_key_condition.resolve({'a': [2]}) - assert init_key_condition.resolve({'a': [0, 1]}) - assert init_key_condition.resolve({'a': [1, 0]}) - assert not init_key_condition.resolve({'a': [0]}) - assert not init_key_condition.resolve({'a': [0, 0]}) - assert not init_key_condition.resolve({'a': []}) - assert not init_key_condition.resolve({'a': [0], 'b': [1]}) + assert init_key_condition.resolve({'0:a': [1]}) + assert init_key_condition.resolve({'0:a': [2]}) + assert init_key_condition.resolve({'0:a': [0, 1]}) + assert init_key_condition.resolve({'0:a': [1, 0]}) + assert not init_key_condition.resolve({'0:a': [0]}) + assert not init_key_condition.resolve({'0:a': [0, 0]}) + assert not init_key_condition.resolve({'0:a': []}) + assert not init_key_condition.resolve({'0:a': [0], 'b': [1]}) with pytest.raises( - ValueError, match='Measurement key a missing when testing classical control' + ValueError, match='Measurement key 0:a missing when testing classical control' ): _ = init_key_condition.resolve({}) with pytest.raises( - ValueError, match='Measurement key a missing when testing classical control' + ValueError, match='Measurement key 0:a missing when testing classical control' ): - _ = init_key_condition.resolve({'b': [1]}) + _ = init_key_condition.resolve({'0:b': [1]}) def test_key_condition_qasm(): - assert init_key_condition.qasm == 'm_a!=0' + assert cirq.KeyCondition(cirq.MeasurementKey('a')).qasm == 'm_a!=0' def test_sympy_condition_with_keys(): @@ -72,33 +75,35 @@ def test_sympy_condition_with_keys(): def test_sympy_condition_str(): - assert str(init_sympy_condition) == 'a >= 1' + assert str(init_sympy_condition) == '0:a >= 1' def test_sympy_condition_repr(): assert ( repr(init_sympy_condition) - == "cirq.SympyCondition(GreaterThan(sympy.Symbol('a'), sympy.Integer(1)))" + == "cirq.SympyCondition(sympy.GreaterThan(sympy.Symbol('0:a'), sympy.Integer(1)))" ) def test_sympy_condition_resolve(): - assert init_sympy_condition.resolve({'a': [1]}) - assert init_sympy_condition.resolve({'a': [2]}) - assert init_sympy_condition.resolve({'a': [0, 1]}) - assert init_sympy_condition.resolve({'a': [1, 0]}) - assert not init_sympy_condition.resolve({'a': [0]}) - assert not init_sympy_condition.resolve({'a': [0, 0]}) - assert not init_sympy_condition.resolve({'a': []}) - assert not init_sympy_condition.resolve({'a': [0], 'b': [1]}) + assert init_sympy_condition.resolve({'0:a': [1]}) + assert init_sympy_condition.resolve({'0:a': [2]}) + assert init_sympy_condition.resolve({'0:a': [0, 1]}) + assert init_sympy_condition.resolve({'0:a': [1, 0]}) + assert not init_sympy_condition.resolve({'0:a': [0]}) + assert not init_sympy_condition.resolve({'0:a': [0, 0]}) + assert not init_sympy_condition.resolve({'0:a': []}) + assert not init_sympy_condition.resolve({'0:a': [0], 'b': [1]}) with pytest.raises( - ValueError, match=re.escape("Measurement keys ['a'] missing when testing classical control") + ValueError, + match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): _ = init_sympy_condition.resolve({}) with pytest.raises( - ValueError, match=re.escape("Measurement keys ['a'] missing when testing classical control") + ValueError, + match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): - _ = init_sympy_condition.resolve({'b': [1]}) + _ = init_sympy_condition.resolve({'0:b': [1]}) def test_sympy_condition_qasm(): From 12d38ca41583594d9dcf7344ac39afe9e9daa6f8 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 19 Dec 2021 19:58:05 -0800 Subject: [PATCH 29/89] lint --- cirq-core/cirq/value/condition_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 78dccaa683d..73df224a35f 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -22,7 +22,6 @@ key_a = cirq.MeasurementKey.parse_serialized('0:a') key_b = cirq.MeasurementKey.parse_serialized('0:b') key_c = cirq.MeasurementKey.parse_serialized('0:c') -key_x = cirq.MeasurementKey.parse_serialized('0:c') init_key_condition = cirq.KeyCondition(key_a) init_sympy_condition = cirq.SympyCondition(sympy.Symbol('0:a') >= 1) From 724febb12c9a1ca9d6af79f72ab97ec9ce6630c2 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 19 Dec 2021 20:09:15 -0800 Subject: [PATCH 30/89] sympy.basic --- cirq-core/cirq/ops/classically_controlled_operation.py | 4 ++-- cirq-core/cirq/value/condition.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index f6188f29785..20ac131b565 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -49,7 +49,7 @@ class ClassicallyControlledOperation(raw_types.Operation): def __init__( self, sub_operation: 'cirq.Operation', - conditions: Sequence[Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Expr]], + conditions: Sequence[Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Basic]], ): """Initializes a `ClassicallyControlledOperation`. @@ -81,7 +81,7 @@ def __init__( c = value.MeasurementKey.parse_serialized(c) if isinstance(c, value.MeasurementKey): c = value.KeyCondition(c) - if isinstance(c, sympy.Expr): + if isinstance(c, sympy.Basic): c = value.SympyCondition(c) conds.append(c) self._conditions: Tuple['cirq.Condition', ...] = tuple(conds) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 903abfab6fb..41981ff0c23 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -105,7 +105,7 @@ class SympyCondition(Condition): substituted in as the free variables. """ - expr: sympy.Expr + expr: sympy.Basic @property def keys(self): From b59869703ad36e46d41b1be141e38e39b72ef3c6 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 20 Dec 2021 10:13:30 -0800 Subject: [PATCH 31/89] Add sympy json resolvers for comparators --- cirq-core/cirq/json_resolver_cache.py | 6 ++++++ cirq-core/cirq/protocols/json_serialization.py | 15 ++++++++++++++- .../protocols/json_test_data/SympyCondition.json | 2 +- .../protocols/json_test_data/SympyCondition.repr | 2 +- .../protocols/json_test_data/sympy.Equality.json | 13 +++++++++++++ .../protocols/json_test_data/sympy.Equality.repr | 1 + .../json_test_data/sympy.GreaterThan.json | 13 +++++++++++++ .../json_test_data/sympy.GreaterThan.repr | 1 + .../protocols/json_test_data/sympy.LessThan.json | 13 +++++++++++++ .../protocols/json_test_data/sympy.LessThan.repr | 1 + .../json_test_data/sympy.StrictGreaterThan.json | 13 +++++++++++++ .../json_test_data/sympy.StrictGreaterThan.repr | 1 + .../json_test_data/sympy.StrictLessThan.json | 13 +++++++++++++ .../json_test_data/sympy.StrictLessThan.repr | 1 + .../json_test_data/sympy.Unequality.json | 13 +++++++++++++ .../json_test_data/sympy.Unequality.repr | 1 + cirq-core/cirq/value/condition_test.py | 9 ++------- 17 files changed, 108 insertions(+), 10 deletions(-) create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.Equality.json create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.Equality.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.GreaterThan.json create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.GreaterThan.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.LessThan.json create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.LessThan.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.StrictGreaterThan.json create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.StrictGreaterThan.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.StrictLessThan.json create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.StrictLessThan.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.Unequality.json create mode 100644 cirq-core/cirq/protocols/json_test_data/sympy.Unequality.repr diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 48b7d49d6d8..7632105c0cc 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -180,6 +180,12 @@ def _parallel_gate_op(gate, qubits): 'sympy.Add': lambda args: sympy.Add(*args), 'sympy.Mul': lambda args: sympy.Mul(*args), 'sympy.Pow': lambda args: sympy.Pow(*args), + 'sympy.GreaterThan': lambda args: sympy.GreaterThan(*args), + 'sympy.StrictGreaterThan': lambda args: sympy.StrictGreaterThan(*args), + 'sympy.LessThan': lambda args: sympy.LessThan(*args), + 'sympy.StrictLessThan': lambda args: sympy.StrictLessThan(*args), + 'sympy.Equality': lambda args: sympy.Equality(*args), + 'sympy.Unequality': lambda args: sympy.Unequality(*args), 'sympy.Float': lambda approx: sympy.Float(approx), 'sympy.Integer': sympy.Integer, 'sympy.Rational': sympy.Rational, diff --git a/cirq-core/cirq/protocols/json_serialization.py b/cirq-core/cirq/protocols/json_serialization.py index eb425c3f0c0..1c81f70ebbd 100644 --- a/cirq-core/cirq/protocols/json_serialization.py +++ b/cirq-core/cirq/protocols/json_serialization.py @@ -319,7 +319,20 @@ def default(self, o): if isinstance(o, sympy.Symbol): return {'cirq_type': 'sympy.Symbol', 'name': o.name} - if isinstance(o, (sympy.Add, sympy.Mul, sympy.Pow)): + if isinstance( + o, + ( + sympy.Add, + sympy.Mul, + sympy.Pow, + sympy.GreaterThan, + sympy.StrictGreaterThan, + sympy.LessThan, + sympy.StrictLessThan, + sympy.Equality, + sympy.Unequality, + ), + ): return {'cirq_type': f'sympy.{o.__class__.__name__}', 'args': o.args} if isinstance(o, sympy.Integer): diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.json b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json index 1e42bf5b8ad..1dc17ec7710 100644 --- a/cirq-core/cirq/protocols/json_test_data/SympyCondition.json +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json @@ -2,7 +2,7 @@ "cirq_type": "SympyCondition", "expr": { - "cirq_type": "sympy.Mul", + "cirq_type": "sympy.GreaterThan", "args": [ { "cirq_type": "sympy.Symbol", diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr index 66ad4c5ced0..6c961a2a1f6 100644 --- a/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr @@ -1 +1 @@ -cirq.SympyCondition(sympy.Mul(sympy.Symbol('a'), sympy.Symbol('b'))) \ No newline at end of file +cirq.SympyCondition(sympy.GreaterThan(sympy.Symbol('a'), sympy.Symbol('b'))) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.Equality.json b/cirq-core/cirq/protocols/json_test_data/sympy.Equality.json new file mode 100644 index 00000000000..0d32f22cca8 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.Equality.json @@ -0,0 +1,13 @@ +{ + "cirq_type": "sympy.Equality", + "args": [ + { + "cirq_type": "sympy.Symbol", + "name": "s" + }, + { + "cirq_type": "sympy.Symbol", + "name": "t" + } + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.Equality.repr b/cirq-core/cirq/protocols/json_test_data/sympy.Equality.repr new file mode 100644 index 00000000000..2c8528e76ea --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.Equality.repr @@ -0,0 +1 @@ +sympy.Equality(sympy.Symbol('s'), sympy.Symbol('t')) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.GreaterThan.json b/cirq-core/cirq/protocols/json_test_data/sympy.GreaterThan.json new file mode 100644 index 00000000000..1aad226f05b --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.GreaterThan.json @@ -0,0 +1,13 @@ +{ + "cirq_type": "sympy.GreaterThan", + "args": [ + { + "cirq_type": "sympy.Symbol", + "name": "s" + }, + { + "cirq_type": "sympy.Symbol", + "name": "t" + } + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.GreaterThan.repr b/cirq-core/cirq/protocols/json_test_data/sympy.GreaterThan.repr new file mode 100644 index 00000000000..7675d9e3021 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.GreaterThan.repr @@ -0,0 +1 @@ +sympy.GreaterThan(sympy.Symbol('s'), sympy.Symbol('t')) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.LessThan.json b/cirq-core/cirq/protocols/json_test_data/sympy.LessThan.json new file mode 100644 index 00000000000..4bd72168849 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.LessThan.json @@ -0,0 +1,13 @@ +{ + "cirq_type": "sympy.LessThan", + "args": [ + { + "cirq_type": "sympy.Symbol", + "name": "s" + }, + { + "cirq_type": "sympy.Symbol", + "name": "t" + } + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.LessThan.repr b/cirq-core/cirq/protocols/json_test_data/sympy.LessThan.repr new file mode 100644 index 00000000000..03c09c8dd57 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.LessThan.repr @@ -0,0 +1 @@ +sympy.LessThan(sympy.Symbol('s'), sympy.Symbol('t')) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.StrictGreaterThan.json b/cirq-core/cirq/protocols/json_test_data/sympy.StrictGreaterThan.json new file mode 100644 index 00000000000..adfcfc3c110 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.StrictGreaterThan.json @@ -0,0 +1,13 @@ +{ + "cirq_type": "sympy.StrictGreaterThan", + "args": [ + { + "cirq_type": "sympy.Symbol", + "name": "s" + }, + { + "cirq_type": "sympy.Symbol", + "name": "t" + } + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.StrictGreaterThan.repr b/cirq-core/cirq/protocols/json_test_data/sympy.StrictGreaterThan.repr new file mode 100644 index 00000000000..35f94e3bbd1 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.StrictGreaterThan.repr @@ -0,0 +1 @@ +sympy.StrictGreaterThan(sympy.Symbol('s'), sympy.Symbol('t')) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.StrictLessThan.json b/cirq-core/cirq/protocols/json_test_data/sympy.StrictLessThan.json new file mode 100644 index 00000000000..68b4605b90a --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.StrictLessThan.json @@ -0,0 +1,13 @@ +{ + "cirq_type": "sympy.StrictLessThan", + "args": [ + { + "cirq_type": "sympy.Symbol", + "name": "s" + }, + { + "cirq_type": "sympy.Symbol", + "name": "t" + } + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.StrictLessThan.repr b/cirq-core/cirq/protocols/json_test_data/sympy.StrictLessThan.repr new file mode 100644 index 00000000000..dedcbc7e674 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.StrictLessThan.repr @@ -0,0 +1 @@ +sympy.StrictLessThan(sympy.Symbol('s'), sympy.Symbol('t')) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.Unequality.json b/cirq-core/cirq/protocols/json_test_data/sympy.Unequality.json new file mode 100644 index 00000000000..b58edc9fcff --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.Unequality.json @@ -0,0 +1,13 @@ +{ + "cirq_type": "sympy.Unequality", + "args": [ + { + "cirq_type": "sympy.Symbol", + "name": "s" + }, + { + "cirq_type": "sympy.Symbol", + "name": "t" + } + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/sympy.Unequality.repr b/cirq-core/cirq/protocols/json_test_data/sympy.Unequality.repr new file mode 100644 index 00000000000..ea4bce67718 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/sympy.Unequality.repr @@ -0,0 +1 @@ +sympy.Unequality(sympy.Symbol('s'), sympy.Symbol('t')) \ No newline at end of file diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 73df224a35f..fd80033a29a 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -38,9 +38,7 @@ def test_key_condition_str(): def test_key_condition_repr(): - assert ( - repr(init_key_condition) == "cirq.KeyCondition(cirq.MeasurementKey(path=('0',), name='a'))" - ) + cirq.testing.assert_equivalent_repr(init_key_condition) def test_key_condition_resolve(): @@ -78,10 +76,7 @@ def test_sympy_condition_str(): def test_sympy_condition_repr(): - assert ( - repr(init_sympy_condition) - == "cirq.SympyCondition(sympy.GreaterThan(sympy.Symbol('0:a'), sympy.Integer(1)))" - ) + cirq.testing.assert_equivalent_repr(init_sympy_condition) def test_sympy_condition_resolve(): From d167de7894183c5362c4ee05f1565d924aba6030 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 20 Dec 2021 10:35:15 -0800 Subject: [PATCH 32/89] _from_json_dict_ --- cirq-core/cirq/value/condition.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 41981ff0c23..d98783d43ff 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import abc import dataclasses from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING import sympy +import sympy.parsing.sympy_parser as sympy_parser from cirq._compat import proper_repr from cirq.protocols import json_serialization, measurement_key_protocol as mkp @@ -91,6 +93,10 @@ def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: def _json_dict_(self): return json_serialization.dataclass_json_dict(self) + @classmethod + def _from_json_dict_(cls, key, **kwargs): + return cls(key=key) + @property def qasm(self): return f'm_{self.key}!=0' @@ -137,6 +143,10 @@ def value(k): def _json_dict_(self): return json_serialization.dataclass_json_dict(self) + @classmethod + def _from_json_dict_(cls, expr, **kwargs): + return cls(expr=expr) + @property def qasm(self): raise NotImplementedError() From 72d82eb62a8fbbe958a8d02894b9b36ec59d086b Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 20 Dec 2021 10:53:47 -0800 Subject: [PATCH 33/89] lint --- cirq-core/cirq/value/condition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index d98783d43ff..920579b5bca 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -17,7 +17,6 @@ from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING import sympy -import sympy.parsing.sympy_parser as sympy_parser from cirq._compat import proper_repr from cirq.protocols import json_serialization, measurement_key_protocol as mkp From b55188e9c8bee5cacd1c03c8ed4feb53eb80ba87 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 20 Dec 2021 10:56:12 -0800 Subject: [PATCH 34/89] reduce fixed_tokens --- cirq-core/cirq/_compat.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index c56ff4b853b..4cb92a2dff3 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -35,7 +35,23 @@ def proper_repr(value: Any) -> str: if isinstance(value, sympy.Basic): # HACK: work around https://github.com/sympy/sympy/issues/16074 - fixed_tokens = dir(sympy) + fixed_tokens = [ + 'Symbol', + 'pi', + 'Mul', + 'Pow', + 'Add', + 'Mod', + 'Integer', + 'Float', + 'Rational', + 'GreaterThan', + 'StrictGreaterThan', + 'LessThan', + 'StrictLessThan', + 'Equality', + 'Unequality', + ] class Printer(sympy.printing.repr.ReprPrinter): def _print(self, expr, **kwargs): From ca56bd8d98e8255aebeaa7443cb2274a61468edd Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 20 Dec 2021 14:21:41 -0800 Subject: [PATCH 35/89] more tests --- .../classically_controlled_operation_test.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index a797eea2357..955769585d2 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -700,3 +700,91 @@ def test_sympy(): # m_result should now be set iff j > i. result = cirq.Simulator().run(circuit) assert result.measurements['m_result'][0][0] == (j > i) + + +def test_scope_local_sympy(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls(sympy.Symbol('a')), + ) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:0:a', '0:1:a', '1:0:a', '1:1:a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + + +def test_extern_scope_sympy(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls(sympy.Symbol('b')), + ) + middle = cirq.Circuit( + cirq.measure(q, key=cirq.MeasurementKey('b')), + cirq.CircuitOperation(inner.freeze(), repetitions=2), + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:b', '0:b', '1:b', '1:b'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + + +def test_scope_extern_mismatch_sympy(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls(sympy.Symbol('b')), + ) + middle = cirq.Circuit( + cirq.measure(q, key=cirq.MeasurementKey('b', ('0',))), + cirq.CircuitOperation(inner.freeze(), repetitions=2), + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['b', 'b', 'b', 'b'] + assert cirq.control_keys(outer_subcircuit) == {cirq.MeasurementKey('b')} + assert cirq.control_keys(circuit) == {cirq.MeasurementKey('b')} + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + + +def test_scope_root_sympy(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls(sympy.Symbol('b')), + ) + middle = cirq.Circuit( + cirq.measure(q, key=cirq.MeasurementKey('c')), + cirq.CircuitOperation(inner.freeze(), repetitions=2), + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['b', 'b', 'b', 'b'] + assert cirq.control_keys(outer_subcircuit) == {cirq.MeasurementKey('b')} + assert cirq.control_keys(circuit) == {cirq.MeasurementKey('b')} + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + + +def test_sympy_with_path_prefix(): + q = cirq.LineQubit(0) + op = cirq.X(q).with_classical_controls(sympy.Symbol('b')) + prefixed = cirq.with_key_path_prefix(op, ('0',)) + assert cirq.control_keys(prefixed) == {'0:b'} From 6f8e3442e75dbad34dbff9ff6681ea2f41e881f7 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 20 Dec 2021 14:36:34 -0800 Subject: [PATCH 36/89] format --- cirq-core/cirq/ops/classically_controlled_operation.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 4bad181ef97..57bbc8098ac 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -192,12 +192,9 @@ def _with_rescoped_keys_( path: Tuple[str, ...], bindable_keys: FrozenSet['cirq.MeasurementKey'], ) -> 'ClassicallyControlledOperation': - conditions = [ - protocols.with_rescoped_keys(c, path, bindable_keys) for c in self._conditions - ] + conds = [protocols.with_rescoped_keys(c, path, bindable_keys) for c in self._conditions] sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys) - sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation - return sub_operation.with_classical_controls(*conditions) + return sub_operation.with_classical_controls(*conds) def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: local_keys: FrozenSet[value.MeasurementKey] = frozenset( From 689719feea4110da9741bb07ca5dd5f24a2d0833 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Mon, 20 Dec 2021 16:16:34 -0800 Subject: [PATCH 37/89] Key --- cirq-core/cirq/ops/classically_controlled_operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 57bbc8098ac..1eda5ac784e 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -197,7 +197,7 @@ def _with_rescoped_keys_( return sub_operation.with_classical_controls(*conds) def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: - local_keys: FrozenSet[value.MeasurementKey] = frozenset( + local_keys: FrozenSet['cirq.MeasurementKey'] = frozenset( k for condition in self._conditions for k in condition.keys ) return local_keys.union(protocols.control_keys(self._sub_operation)) From 96ba4e97e8d98c68ca864ca3fa6796bd17a61576 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 20 Dec 2021 18:47:21 -0800 Subject: [PATCH 38/89] combined test --- .../classically_controlled_operation_test.py | 86 +++---------------- 1 file changed, 13 insertions(+), 73 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 955769585d2..b7d2d1e0955 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -702,89 +702,29 @@ def test_sympy(): assert result.measurements['m_result'][0][0] == (j > i) -def test_scope_local_sympy(): +def test_sympy_path_prefix(): q = cirq.LineQubit(0) - inner = cirq.Circuit( - cirq.measure(q, key='a'), - cirq.X(q).with_classical_controls(sympy.Symbol('a')), - ) - middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) - outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) - circuit = outer_subcircuit.mapped_circuit(deep=True) - internal_control_keys = [ - str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) - ] - assert internal_control_keys == ['0:0:a', '0:1:a', '1:0:a', '1:1:a'] - assert not cirq.control_keys(outer_subcircuit) - assert not cirq.control_keys(circuit) - assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) - - -def test_extern_scope_sympy(): - q = cirq.LineQubit(0) - inner = cirq.Circuit( - cirq.measure(q, key='a'), - cirq.X(q).with_classical_controls(sympy.Symbol('b')), - ) - middle = cirq.Circuit( - cirq.measure(q, key=cirq.MeasurementKey('b')), - cirq.CircuitOperation(inner.freeze(), repetitions=2), - ) - outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) - circuit = outer_subcircuit.mapped_circuit(deep=True) - internal_control_keys = [ - str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) - ] - assert internal_control_keys == ['0:b', '0:b', '1:b', '1:b'] - assert not cirq.control_keys(outer_subcircuit) - assert not cirq.control_keys(circuit) - assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) - - -def test_scope_extern_mismatch_sympy(): - q = cirq.LineQubit(0) - inner = cirq.Circuit( - cirq.measure(q, key='a'), - cirq.X(q).with_classical_controls(sympy.Symbol('b')), - ) - middle = cirq.Circuit( - cirq.measure(q, key=cirq.MeasurementKey('b', ('0',))), - cirq.CircuitOperation(inner.freeze(), repetitions=2), - ) - outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) - circuit = outer_subcircuit.mapped_circuit(deep=True) - internal_control_keys = [ - str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) - ] - assert internal_control_keys == ['b', 'b', 'b', 'b'] - assert cirq.control_keys(outer_subcircuit) == {cirq.MeasurementKey('b')} - assert cirq.control_keys(circuit) == {cirq.MeasurementKey('b')} - assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + op = cirq.X(q).with_classical_controls(sympy.Symbol('b')) + prefixed = cirq.with_key_path_prefix(op, ('0',)) + assert cirq.control_keys(prefixed) == {'0:b'} -def test_scope_root_sympy(): +def test_sympy_scope(): q = cirq.LineQubit(0) + a, b, c, d = sympy.symbols('a b c d') inner = cirq.Circuit( cirq.measure(q, key='a'), - cirq.X(q).with_classical_controls(sympy.Symbol('b')), + cirq.X(q).with_classical_controls(a + b + c + d), ) middle = cirq.Circuit( - cirq.measure(q, key=cirq.MeasurementKey('c')), + cirq.measure(q, key='b'), + cirq.measure(q, key=cirq.MeasurementKey('c', ('0',))), cirq.CircuitOperation(inner.freeze(), repetitions=2), ) outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) circuit = outer_subcircuit.mapped_circuit(deep=True) - internal_control_keys = [ - str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) - ] - assert internal_control_keys == ['b', 'b', 'b', 'b'] - assert cirq.control_keys(outer_subcircuit) == {cirq.MeasurementKey('b')} - assert cirq.control_keys(circuit) == {cirq.MeasurementKey('b')} + internal_controls = [str(k) for op in circuit.all_operations() for k in cirq.control_keys(op)] + assert set(internal_controls) == {'0:0:a', '0:1:a', '1:0:a', '1:1:a', '0:b', '1:b', 'c', 'd'} + assert cirq.control_keys(outer_subcircuit) == {'c', 'd'} + assert cirq.control_keys(circuit) == {'c', 'd'} assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) - - -def test_sympy_with_path_prefix(): - q = cirq.LineQubit(0) - op = cirq.X(q).with_classical_controls(sympy.Symbol('b')) - prefixed = cirq.with_key_path_prefix(op, ('0',)) - assert cirq.control_keys(prefixed) == {'0:b'} From c56d6bb557c13ab3c9a8125c887c5b66dd91ca67 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 23 Dec 2021 14:04:04 -0800 Subject: [PATCH 39/89] lint --- cirq-core/cirq/protocols/measurement_key_protocol.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 504d8b0c250..639df1aa180 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -20,12 +20,6 @@ from cirq import value from cirq._doc import doc_private -if TYPE_CHECKING: - import cirq - -if TYPE_CHECKING: - import cirq - if TYPE_CHECKING: import cirq From 681a00865412236b26206b5b864c43e28b26d127 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 23 Dec 2021 14:26:10 -0800 Subject: [PATCH 40/89] Docstrings --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 4 ++++ cirq-core/cirq/sim/act_on_args.py | 2 ++ cirq-core/cirq/sim/act_on_args_container.py | 2 ++ cirq-core/cirq/sim/act_on_density_matrix_args.py | 2 ++ cirq-core/cirq/sim/act_on_state_vector_args.py | 2 ++ cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py | 2 ++ cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py | 2 ++ cirq-core/cirq/sim/clifford/clifford_simulator.py | 2 ++ cirq-core/cirq/sim/density_matrix_simulator.py | 2 ++ cirq-core/cirq/sim/simulator_base.py | 2 ++ cirq-core/cirq/sim/sparse_simulator.py | 2 ++ 11 files changed, 24 insertions(+) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 555a829b8f0..170435fd373 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -103,6 +103,8 @@ def _create_partial_act_on_args( is often used in specifying the initial state, i.e. the ordering of the computational basis states. logs: A mutable object that measurements are recorded into. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. Returns: MPSState args for simulating the Circuit. @@ -245,6 +247,8 @@ def __init__( initial_state: An integer representing the initial state. log_of_measurement_results: A mutable object that measurements are being recorded into. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. Raises: ValueError: If the grouping does not cover the qubits. diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 9e6ed26d8a4..48eef5b5721 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -59,6 +59,8 @@ def __init__( ordering of the computational basis states. log_of_measurement_results: A mutable object that measurements are being recorded into. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. """ if prng is None: prng = cast(np.random.RandomState, np.random) diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 3a6c455e57b..107caaa2805 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -65,6 +65,8 @@ def __init__( at the end. log_of_measurement_results: A mutable object that measurements are being recorded into. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. """ self.args = args self._qubits = tuple(qubits) diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 976ea529066..18d039a17be 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -61,6 +61,8 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. """ super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) self.target_tensor = target_tensor diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index b3de22831c0..3bc24978ebd 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -64,6 +64,8 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. """ super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) self.target_tensor = target_tensor diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 34ee040fa40..95c6951ba31 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -56,6 +56,8 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. """ super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) self.tableau = tableau diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 835c5145563..aa1ee5590f1 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -54,6 +54,8 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. """ super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) self.state = state diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index f46fc07c74a..0e1c15ad864 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -80,6 +80,8 @@ def _create_partial_act_on_args( is often used in specifying the initial state, i.e. the ordering of the computational basis states. logs: A log of the results of measurement that is added to. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. Returns: ActOnStabilizerChFormArgs for the circuit. diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index 38877097423..a3ce424f247 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -188,6 +188,8 @@ def _create_partial_act_on_args( is often used in specifying the initial state, i.e. the ordering of the computational basis states. logs: The log of measurement results that is added into. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. Returns: ActOnDensityMatrixArgs for the circuit. diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 40e19e5dbce..a2eb1eddf29 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -138,6 +138,8 @@ def _create_partial_act_on_args( qubits: The sequence of qubits to represent. logs: The structure to hold measurement logs. A single instance should be shared among all ActOnArgs within the simulation. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. """ @abc.abstractmethod diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 548237c1032..e1e4b971bca 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -187,6 +187,8 @@ def _create_partial_act_on_args( is often used in specifying the initial state, i.e. the ordering of the computational basis states. logs: Log of the measurement results. + measured_qubits: A dictionary that contains the qubits that were + measured in each measurement. Returns: ActOnStateVectorArgs for the circuit. From 47d8288b7573a1f5d221d05071dc8af9e3a0d7fa Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 23 Dec 2021 20:52:10 -0800 Subject: [PATCH 41/89] ClassicalData class --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/contrib/quimb/mps_simulator.py | 21 +++--- .../ops/classically_controlled_operation.py | 3 +- cirq-core/cirq/sim/act_on_args.py | 29 +++----- cirq-core/cirq/sim/act_on_args_container.py | 23 +++--- .../cirq/sim/act_on_args_container_test.py | 2 +- .../cirq/sim/act_on_density_matrix_args.py | 8 +-- .../cirq/sim/act_on_state_vector_args.py | 8 +-- .../clifford/act_on_clifford_tableau_args.py | 8 +-- .../act_on_stabilizer_ch_form_args.py | 10 +-- .../cirq/sim/clifford/clifford_simulator.py | 12 ++-- .../cirq/sim/density_matrix_simulator.py | 13 ++-- cirq-core/cirq/sim/operation_target.py | 4 +- cirq-core/cirq/sim/simulator_base.py | 53 +++++++++----- cirq-core/cirq/sim/simulator_base_test.py | 17 +++-- cirq-core/cirq/sim/sparse_simulator.py | 13 ++-- cirq-core/cirq/value/__init__.py | 4 ++ cirq-core/cirq/value/classical_data.py | 72 +++++++++++++++++++ cirq-core/cirq/value/condition.py | 26 +++---- .../calibration/engine_simulator.py | 10 --- 20 files changed, 197 insertions(+), 140 deletions(-) create mode 100644 cirq-core/cirq/value/classical_data.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 3d8c90da603..33440246bd1 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -483,6 +483,7 @@ canonicalize_half_turns, chosen_angle_to_canonical_half_turns, chosen_angle_to_half_turns, + ClassicalData, Condition, Duration, DURATION_LIKE, diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 170435fd373..c9a07293df4 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -87,12 +87,11 @@ def __init__( seed=seed, ) - def _create_partial_act_on_args( + def _create_partial_act_on_args_ex( self, initial_state: Union[int, 'MPSState'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], + classical_data: 'cirq.ClassicalData', ) -> 'MPSState': """Creates MPSState args for simulating the Circuit. @@ -102,9 +101,8 @@ def _create_partial_act_on_args( qubits: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - logs: A mutable object that measurements are recorded into. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. Returns: MPSState args for simulating the Circuit. @@ -118,8 +116,7 @@ def _create_partial_act_on_args( simulation_options=self.simulation_options, grouping=self.grouping, initial_state=initial_state, - log_of_measurement_results=logs, - measured_qubits=measured_qubits, + classical_data=classical_data, ) def _create_step_result( @@ -233,7 +230,7 @@ def __init__( grouping: Optional[Dict['cirq.Qid', int]] = None, initial_state: int = 0, log_of_measurement_results: Dict[str, Any] = None, - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, + classical_data: 'cirq.ClassicalData' = None, ): """Creates and MPSState @@ -247,13 +244,13 @@ def __init__( initial_state: An integer representing the initial state. log_of_measurement_results: A mutable object that measurements are being recorded into. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. Raises: ValueError: If the grouping does not cover the qubits. """ - super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) + super().__init__(prng, qubits, log_of_measurement_results, classical_data) qubit_map = self.qubit_map self.grouping = qubit_map if grouping is None else grouping if self.grouping.keys() != self.qubit_map.keys(): diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 35a8ac50c0e..dfc7ec13c5a 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -176,8 +176,7 @@ def _json_dict_(self) -> Dict[str, Any]: } def _act_on_(self, args: 'cirq.OperationTarget') -> bool: - measurements, qubits = args.log_of_measurement_results, args.measured_qubits - if all(c.resolve(measurements, qubits) for c in self._conditions): + if all(c.resolve(args.classical_data) for c in self._conditions): protocols.act_on(self._sub_operation, args) return True diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 7f2227f4276..8675bdf3b5a 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -29,7 +29,7 @@ import numpy as np -from cirq import protocols, ops +from cirq import protocols, ops, value from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits from cirq.sim.operation_target import OperationTarget @@ -47,7 +47,7 @@ def __init__( prng: np.random.RandomState = None, qubits: Sequence['cirq.Qid'] = None, log_of_measurement_results: Dict[str, List[int]] = None, - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, + classical_data: 'cirq.ClassicalData' = None, ignore_measurement_results: bool = False, ): """Inits ActOnArgs. @@ -60,8 +60,8 @@ def __init__( ordering of the computational basis states. log_of_measurement_results: A mutable object that measurements are being recorded into. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. ignore_measurement_results: If True, then the simulation will treat measurement as dephasing instead of collapsing process, and not log the result. This is only applicable to @@ -71,14 +71,9 @@ def __init__( prng = cast(np.random.RandomState, np.random) if qubits is None: qubits = () - if log_of_measurement_results is None: - log_of_measurement_results = {} - if measured_qubits is None: - measured_qubits = {} self._set_qubits(qubits) self.prng = prng - self._log_of_measurement_results = log_of_measurement_results - self._measured_qubits = measured_qubits + self._classical_data = classical_data or value.ClassicalData(log_of_measurement_results) self._ignore_measurement_results = ignore_measurement_results def _set_qubits(self, qubits: Sequence['cirq.Qid']): @@ -107,10 +102,7 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[ return bits = self._perform_measurement(qubits) corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)] - if key in self._log_of_measurement_results: - raise ValueError(f"Measurement already logged to key {key!r}") - self._log_of_measurement_results[key] = corrected - self._measured_qubits[key] = tuple(qubits) + self._classical_data.record_measurement(value.MeasurementKey.parse_serialized(key), corrected, qubits) def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]: return [self.qubit_map[q] for q in qubits] @@ -124,8 +116,7 @@ def copy(self: TSelf) -> TSelf: """Creates a copy of the object.""" args = copy.copy(self) self._on_copy(args) - args._log_of_measurement_results = self.log_of_measurement_results.copy() - args._measured_qubits = self.measured_qubits.copy() + args._classical_data = self._classical_data.copy() return args def _on_copy(self: TSelf, args: TSelf): @@ -203,11 +194,11 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ @property def log_of_measurement_results(self) -> Dict[str, List[int]]: - return self._log_of_measurement_results + return {k: list(v) for k, v in self._classical_data.measurements.items()} @property - def measured_qubits(self) -> Dict[str, Tuple['cirq.Qid', ...]]: - return self._measured_qubits + def classical_data(self) -> 'cirq.ClassicalData': + return self._classical_data @property def ignore_measurement_results(self) -> bool: diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 107caaa2805..e5009ae2f1c 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -29,7 +29,7 @@ import numpy as np -from cirq import ops, protocols +from cirq import ops, protocols, value from cirq.sim.operation_target import OperationTarget from cirq.sim.simulator import ( TActOnArgs, @@ -51,8 +51,8 @@ def __init__( args: Dict[Optional['cirq.Qid'], TActOnArgs], qubits: Sequence['cirq.Qid'], split_untangled_states: bool, - log_of_measurement_results: Dict[str, Any], - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, + log_of_measurement_results: Dict[str, Any] = None, + classical_data: 'cirq.ClassicalData' = None, ): """Initializes the class. @@ -65,14 +65,13 @@ def __init__( at the end. log_of_measurement_results: A mutable object that measurements are being recorded into. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. """ self.args = args self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states - self._log_of_measurement_results = log_of_measurement_results - self._measured_qubits = measured_qubits if measured_qubits is not None else {} + self._classical_data = classical_data or value.ClassicalData(log_of_measurement_results) # type: ignore def create_merged_state(self) -> TActOnArgs: if not self.split_untangled_states: @@ -137,11 +136,11 @@ def _act_on_fallback_( def copy(self) -> 'cirq.ActOnArgsContainer[TActOnArgs]': logs = self.log_of_measurement_results.copy() - measured_qubits = self._measured_qubits.copy() + classical_data = self._classical_data.copy() copies = {a: a.copy() for a in set(self.args.values())} for copy in copies.values(): copy._log_of_measurement_results = logs - copy._measured_qubits = measured_qubits + copy._classical_data = classical_data args = {q: copies[a] for q, a in self.args.items()} return ActOnArgsContainer(args, self.qubits, self.split_untangled_states, logs) @@ -151,11 +150,11 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: @property def log_of_measurement_results(self) -> Dict[str, Any]: - return self._log_of_measurement_results + return {k: list(v) for k, v in self._classical_data.measurements.items()} @property - def measured_qubits(self) -> Mapping[str, Tuple['cirq.Qid', ...]]: - return self._measured_qubits + def classical_data(self) -> 'cirq.ClassicalData': + return self._classical_data def sample( self, diff --git a/cirq-core/cirq/sim/act_on_args_container_test.py b/cirq-core/cirq/sim/act_on_args_container_test.py index 16380bc592c..e72887aea60 100644 --- a/cirq-core/cirq/sim/act_on_args_container_test.py +++ b/cirq-core/cirq/sim/act_on_args_container_test.py @@ -24,7 +24,7 @@ def __init__(self, qubits, logs): ) def _perform_measurement(self, qubits: Sequence[cirq.Qid]) -> List[int]: - return [] + return [0] * len(qubits) def copy(self) -> 'EmptyActOnArgs': return EmptyActOnArgs( diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index d3cf1e802bc..3a0f214cd4c 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -41,7 +41,7 @@ def __init__( prng: np.random.RandomState = None, log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, + classical_data: 'cirq.ClassicalData' = None, ignore_measurement_results: bool = False, ): """Inits ActOnDensityMatrixArgs. @@ -62,8 +62,8 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. ignore_measurement_results: If True, then the simulation will treat measurement as dephasing instead of collapsing process. This is only applicable to simulators that can @@ -73,7 +73,7 @@ def __init__( prng=prng, qubits=qubits, log_of_measurement_results=log_of_measurement_results, - measured_qubits=measured_qubits, + classical_data=classical_data, ignore_measurement_results=ignore_measurement_results, ) self.target_tensor = target_tensor diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 3bc24978ebd..39361b1c845 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -44,7 +44,7 @@ def __init__( prng: np.random.RandomState = None, log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, + classical_data: 'cirq.ClassicalData' = None, ): """Inits ActOnStateVectorArgs. @@ -64,10 +64,10 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. """ - super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) + super().__init__(prng, qubits, log_of_measurement_results, classical_data) self.target_tensor = target_tensor self.available_buffer = available_buffer diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 95c6951ba31..50ada171dea 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -42,7 +42,7 @@ def __init__( prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], qubits: Sequence['cirq.Qid'] = None, - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, + classical_data: 'cirq.ClassicalData' = None, ): """Inits ActOnCliffordTableauArgs. @@ -56,10 +56,10 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. """ - super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) + super().__init__(prng, qubits, log_of_measurement_results, classical_data) self.tableau = tableau def _act_on_fallback_( diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index aa1ee5590f1..d07c5c1b353 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -39,9 +39,9 @@ def __init__( self, state: 'cirq.StabilizerStateChForm', prng: np.random.RandomState, - log_of_measurement_results: Dict[str, Any], + log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None, + classical_data: 'cirq.ClassicalData' = None, ): """Initializes with the given state and the axes for the operation. Args: @@ -54,10 +54,10 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. """ - super().__init__(prng, qubits, log_of_measurement_results, measured_qubits) + super().__init__(prng, qubits, log_of_measurement_results, classical_data) self.state = state def _act_on_fallback_( diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 0e1c15ad864..308a8a081c5 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -64,12 +64,11 @@ def is_supported_operation(op: 'cirq.Operation') -> bool: # TODO: support more general Pauli measurements return protocols.has_stabilizer_effect(op) - def _create_partial_act_on_args( + def _create_partial_act_on_args_ex( self, initial_state: Union[int, 'cirq.ActOnStabilizerCHFormArgs'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], + classical_data: 'cirq.ClassicalData' = None, ) -> 'cirq.ActOnStabilizerCHFormArgs': """Creates the ActOnStabilizerChFormArgs for a circuit. @@ -80,8 +79,8 @@ def _create_partial_act_on_args( is often used in specifying the initial state, i.e. the ordering of the computational basis states. logs: A log of the results of measurement that is added to. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. Returns: ActOnStabilizerChFormArgs for the circuit. @@ -95,9 +94,8 @@ def _create_partial_act_on_args( return clifford.ActOnStabilizerCHFormArgs( state=state.ch_form, prng=self._prng, - log_of_measurement_results=logs, qubits=qubits, - measured_qubits=measured_qubits, + classical_data=classical_data, ) def _create_step_result( diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index a5cc062eb92..e7e315b353a 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -172,12 +172,11 @@ def __init__( if dtype not in {np.complex64, np.complex128}: raise ValueError(f'dtype must be complex64 or complex128, was {dtype}') - def _create_partial_act_on_args( + def _create_partial_act_on_args_ex( self, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], + classical_data: 'cirq.ClassicalData' = None, ) -> 'cirq.ActOnDensityMatrixArgs': """Creates the ActOnDensityMatrixArgs for a circuit. @@ -187,9 +186,8 @@ def _create_partial_act_on_args( qubits: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - logs: The log of measurement results that is added into. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. Returns: ActOnDensityMatrixArgs for the circuit. @@ -211,8 +209,7 @@ def _create_partial_act_on_args( qubits=qubits, qid_shape=qid_shape, prng=self._prng, - log_of_measurement_results=logs, - measured_qubits=measured_qubits, + classical_data=classical_data, ignore_measurement_results=self._ignore_measurement_results, ) diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index 56b854bd8ff..e3d95ff79e9 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -84,8 +84,8 @@ def log_of_measurement_results(self) -> Dict[str, Any]: @property @abc.abstractmethod - def measured_qubits(self) -> Mapping[str, Tuple['cirq.Qid', ...]]: - """Gets the qubits that were in each measurement.""" + def classical_data(self) -> 'cirq.ClassicalData': + """The shared classical data container for this simulation..""" @abc.abstractmethod def sample( diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 96063d00ce0..7fa61fa9302 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -119,13 +119,11 @@ def __init__( self._ignore_measurement_results = ignore_measurement_results self._split_untangled_states = split_untangled_states - @abc.abstractmethod def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], ) -> TActOnArgs: """Creates an instance of the TActOnArgs class for the simulator. @@ -138,9 +136,34 @@ def _create_partial_act_on_args( qubits: The sequence of qubits to represent. logs: The structure to hold measurement logs. A single instance should be shared among all ActOnArgs within the simulation. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. """ + raise NotImplementedError() + + def _create_partial_act_on_args_ex( + self, + initial_state: Any, + qubits: Sequence['cirq.Qid'], + classical_data: 'cirq.ClassicalData', + ) -> TActOnArgs: + """Creates an instance of the TActOnArgs class for the simulator. + + It represents the supplied qubits initialized to the provided state. + + Args: + initial_state: The initial state to represent. An integer state is + understood to be a pure state. Other state representations are + simulator-dependent. + qubits: The sequence of qubits to represent. + classical_data: The shared classical data container for this + simulation. + """ + # Child classes should override this behavior. We call the old one here by default for + # backwards compatibility, until deprecation cycle is complete. + return self._create_partial_act_on_args( + initial_state, + qubits, + classical_data.measurements() # type: ignore + ) @abc.abstractmethod def _create_step_result( @@ -338,38 +361,34 @@ def _create_act_on_args( if isinstance(initial_state, OperationTarget): return initial_state - log: Dict[str, Any] = {} - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = {} + classical_data = value.ClassicalData() if self._split_untangled_states: args_map: Dict[Optional['cirq.Qid'], TActOnArgs] = {} if isinstance(initial_state, int): for q in reversed(qubits): - args_map[q] = self._create_partial_act_on_args( + args_map[q] = self._create_partial_act_on_args_ex( initial_state=initial_state % q.dimension, qubits=[q], - logs=log, - measured_qubits=measured_qubits, + classical_data=classical_data, ) initial_state = int(initial_state / q.dimension) else: - args = self._create_partial_act_on_args( + args = self._create_partial_act_on_args_ex( initial_state=initial_state, qubits=qubits, - logs=log, - measured_qubits=measured_qubits, + classical_data=classical_data, ) for q in qubits: args_map[q] = args - args_map[None] = self._create_partial_act_on_args(0, (), log, measured_qubits) + args_map[None] = self._create_partial_act_on_args_ex(0, (), classical_data) return ActOnArgsContainer( - args_map, qubits, self._split_untangled_states, log, measured_qubits + args_map, qubits, self._split_untangled_states, classical_data=classical_data ) else: - return self._create_partial_act_on_args( + return self._create_partial_act_on_args_ex( initial_state=initial_state, qubits=qubits, - logs=log, - measured_qubits=measured_qubits, + classical_data=classical_data, ) diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index eb327dd5791..757ee8b694d 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -25,10 +25,10 @@ class CountingActOnArgs(cirq.ActOnArgs): gate_count = 0 measurement_count = 0 - def __init__(self, state, qubits, logs): + def __init__(self, state, qubits, classical_data): super().__init__( qubits=qubits, - log_of_measurement_results=logs, + classical_data=classical_data, ) self.state = state @@ -39,7 +39,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: def copy(self) -> 'CountingActOnArgs': args = CountingActOnArgs( qubits=self.qubits, - logs=self.log_of_measurement_results.copy(), + classical_data=self.classical_data.copy(), state=self.state, ) args.gate_count = self.gate_count @@ -107,14 +107,13 @@ def __init__(self, noise=None, split_untangled_states=False): split_untangled_states=split_untangled_states, ) - def _create_partial_act_on_args( + def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], - measured_qubits, + classical_data: cirq.ClassicalData, ) -> CountingActOnArgs: - return CountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) + return CountingActOnArgs(qubits=qubits, state=initial_state, classical_data=classical_data) def _create_simulator_trial_result( self, @@ -143,9 +142,9 @@ def _create_partial_act_on_args( initial_state: Any, qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], - measured_qubits, + classical_data, ) -> CountingActOnArgs: - return SplittableCountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) + return SplittableCountingActOnArgs(qubits=qubits, state=initial_state, classical_data=classical_data) q0, q1 = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index e1e4b971bca..ed084bbf3d4 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -171,12 +171,11 @@ def __init__( split_untangled_states=split_untangled_states, ) - def _create_partial_act_on_args( + def _create_partial_act_on_args_ex( self, initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.ActOnStateVectorArgs'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], + classical_data: 'cirq.ClassicalData', ): """Creates the ActOnStateVectorArgs for a circuit. @@ -186,9 +185,8 @@ def _create_partial_act_on_args( qubits: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - logs: Log of the measurement results. - measured_qubits: A dictionary that contains the qubits that were - measured in each measurement. + classical_data: The shared classical data container for this + simulation. Returns: ActOnStateVectorArgs for the circuit. @@ -206,8 +204,7 @@ def _create_partial_act_on_args( available_buffer=np.empty(qid_shape, dtype=self._dtype), qubits=qubits, prng=self._prng, - log_of_measurement_results=logs, - measured_qubits=measured_qubits, + classical_data=classical_data, ) def _create_step_result( diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index da6cfc2b058..168dc935b0b 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -25,6 +25,10 @@ chosen_angle_to_half_turns, ) +from cirq.value.classical_data import ( + ClassicalData, +) + from cirq.value.condition import ( Condition, KeyCondition, diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py new file mode 100644 index 00000000000..6b1a991274f --- /dev/null +++ b/cirq-core/cirq/value/classical_data.py @@ -0,0 +1,72 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import dataclasses +from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, FrozenSet, List + +import sympy + +from cirq._compat import proper_repr +from cirq.protocols import json_serialization, measurement_key_protocol as mkp +from cirq.value import digits, measurement_key + +if TYPE_CHECKING: + import cirq + + +class ClassicalData: + """Classical data representing measurements and metadata.""" + + def __init__( + self, + measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, + measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = None, + ): + # if (measurements is None) != (measured_qubits is None): + # raise ValueError( + # 'measurements and measured_qubits must both either be provided or left default.' + # ) + if measurements is None: + measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = {} + if measured_qubits is None: + measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = {} + # if set(measurements.keys()) != set(measured_qubits.keys()): + # raise ValueError('measurements and measured_qubits must contain same keys.') + self._measurements = measurements + self._measured_qubits = measured_qubits + + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + return tuple(self._measurements.keys()) + + @property + def measurements(self) -> Mapping['cirq.MeasurementKey', Tuple[int, ...]]: + return self._measurements + + @property + def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', ...]]: + return self._measured_qubits + + def record_measurement( + self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] + ): + if len(measurement) != len(qubits): + raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') + if key in self._measurements: + raise ValueError(f"Measurement already logged to key {key!r}") + self._measurements[key] = tuple(measurement) + self._measured_qubits[key] = tuple(qubits) + + def copy(self): + return ClassicalData(self._measurements.copy(), self._measured_qubits.copy()) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 7e02173df48..cb12a25deb1 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -41,8 +41,7 @@ def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.Measure @abc.abstractmethod def resolve( self, - measurements: Mapping[str, Sequence[int]], - measured_qubits: Mapping[str, Sequence['cirq.Qid']] = None, + classical_data: 'cirq.ClassicalData', ) -> bool: """Resolves the condition based on the measurements.""" @@ -104,13 +103,11 @@ def __repr__(self): def resolve( self, - measurements: Mapping[str, Sequence[int]], - measured_qubits: Mapping[str, Sequence['cirq.Qid']] = None, + classical_data: 'cirq.ClassicalData', ) -> bool: - key = str(self.key) - if key not in measurements: - raise ValueError(f'Measurement key {key} missing when testing classical control') - return any(measurements[key]) + if self.key not in classical_data.keys(): + raise ValueError(f'Measurement key {self.key} missing when testing classical control') + return any(classical_data.measurements[self.key]) def _json_dict_(self): return json_serialization.dataclass_json_dict(self) @@ -153,23 +150,20 @@ def __repr__(self): def resolve( self, - measurements: Mapping[str, Sequence[int]], - measured_qubits: Mapping[str, Sequence['cirq.Qid']] = None, + classical_data: 'cirq.ClassicalData', ) -> bool: - missing = [str(k) for k in self.keys if str(k) not in measurements] + missing = [str(k) for k in self.keys if k not in classical_data.keys()] if missing: raise ValueError(f'Measurement keys {missing} missing when testing classical control') def value(k): return ( - digits.big_endian_bits_to_int(measurements[k]) - if measured_qubits is None - else digits.big_endian_digits_to_int( - measurements[k], base=[q.dimension for q in measured_qubits[k]] + digits.big_endian_digits_to_int( + classical_data.measurements[k], base=[q.dimension for q in classical_data.measured_qubits[k]] ) ) - replacements = {str(k): value(str(k)) for k in self.keys} + replacements = {str(k): value(k) for k in self.keys} return bool(self.expr.subs(replacements)) def _json_dict_(self): diff --git a/cirq-google/cirq_google/calibration/engine_simulator.py b/cirq-google/cirq_google/calibration/engine_simulator.py index 573fa7b6f0f..db4eeb5c619 100644 --- a/cirq-google/cirq_google/calibration/engine_simulator.py +++ b/cirq-google/cirq_google/calibration/engine_simulator.py @@ -470,16 +470,6 @@ def simulate( converted = _convert_to_circuit_with_drift(self, program) return self._simulator.simulate(converted, param_resolver, qubit_order, initial_state) - def _create_partial_act_on_args( - self, - initial_state: Union[int, cirq.ActOnStateVectorArgs], - qubits: Sequence[cirq.Qid], - logs: Dict[str, Any], - measured_qubits: Dict[str, Tuple['cirq.Qid', ...]], - ) -> cirq.ActOnStateVectorArgs: - # Needs an implementation since it's abstract but will never actually be called. - raise NotImplementedError() - def _create_step_result( self, sim_state: cirq.OperationTarget, From f411cc3e9b147a07aba3f246625e2f37745ed387 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 23 Dec 2021 23:02:08 -0800 Subject: [PATCH 42/89] add get_int, fix bugs --- .../cirq/contrib/quimb/mps_simulator_test.py | 4 +- cirq-core/cirq/sim/act_on_args.py | 4 -- cirq-core/cirq/sim/act_on_args_container.py | 4 -- .../cirq/sim/act_on_state_vector_args.py | 4 +- .../sim/clifford/clifford_simulator_test.py | 2 +- cirq-core/cirq/sim/operation_target.py | 2 +- cirq-core/cirq/sim/simulator_base_test.py | 5 +- cirq-core/cirq/value/classical_data.py | 23 ++++++---- cirq-core/cirq/value/condition.py | 11 +---- cirq-core/cirq/value/condition_test.py | 46 +++++++++++-------- 10 files changed, 50 insertions(+), 55 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index 5d85735eac0..14e7119a3c7 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -539,10 +539,10 @@ def test_state_act_on_args_initializer(): s = ccq.mps_simulator.MPSState( qubits=(cirq.LineQubit(0),), prng=np.random.RandomState(0), - log_of_measurement_results={'test': 4}, + log_of_measurement_results={'test': [4]}, ) assert s.qubits == (cirq.LineQubit(0),) - assert s.log_of_measurement_results == {'test': 4} + assert s.log_of_measurement_results == {'test': [4]} def test_act_on_gate(): diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 8675bdf3b5a..4a1e798d23a 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -192,10 +192,6 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ """Subclasses should implement this with any additional state transpose functionality, if supported.""" - @property - def log_of_measurement_results(self) -> Dict[str, List[int]]: - return {k: list(v) for k, v in self._classical_data.measurements.items()} - @property def classical_data(self) -> 'cirq.ClassicalData': return self._classical_data diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index e5009ae2f1c..404564afc39 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -148,10 +148,6 @@ def copy(self) -> 'cirq.ActOnArgsContainer[TActOnArgs]': def qubits(self) -> Tuple['cirq.Qid', ...]: return self._qubits - @property - def log_of_measurement_results(self) -> Dict[str, Any]: - return {k: list(v) for k, v in self._classical_data.measurements.items()} - @property def classical_data(self) -> 'cirq.ClassicalData': return self._classical_data diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 39361b1c845..7b6b471cad7 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -278,7 +278,7 @@ def _strat_act_on_state_vector_from_mixture( args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.log_of_measurement_results[key] = [index] + args.classical_data.record_measurement(key, [index], qubits) return True @@ -327,5 +327,5 @@ def prepare_into_buffer(k: int): args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.log_of_measurement_results[key] = [index] + args.classical_data.record_measurement(key, [index], qubits) return True diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py index 59156a3b88f..cdd8c5fcdff 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py @@ -546,7 +546,7 @@ def test_valid_apply_measurement(): state = cirq.CliffordState(qubit_map={q0: 0}, initial_state=1) measurements = {} _ = state.apply_measurement(cirq.measure(q0), measurements, np.random.RandomState()) - assert measurements == {'0': [1]} + assert measurements == {'0': (1,)} def test_reset(): diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index e3d95ff79e9..7ff5d6cac17 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -78,9 +78,9 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: """Gets the qubit order maintained by this target.""" @property - @abc.abstractmethod def log_of_measurement_results(self) -> Dict[str, Any]: """Gets the log of measurement results.""" + return {k: list(v) for k, v in self.classical_data.measurements.items()} @property @abc.abstractmethod diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 757ee8b694d..34a876a15d8 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -137,12 +137,11 @@ def __init__(self, noise=None, split_untangled_states=True): split_untangled_states=split_untangled_states, ) - def _create_partial_act_on_args( + def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], - classical_data, + classical_data: cirq.ClassicalData, ) -> CountingActOnArgs: return SplittableCountingActOnArgs(qubits=qubits, state=initial_state, classical_data=classical_data) diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 6b1a991274f..e9c580fbb19 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -12,15 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -import dataclasses -from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, FrozenSet, List +from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING -import sympy - -from cirq._compat import proper_repr -from cirq.protocols import json_serialization, measurement_key_protocol as mkp -from cirq.value import digits, measurement_key +from cirq.value import digits if TYPE_CHECKING: import cirq @@ -34,6 +28,7 @@ def __init__( measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = None, ): + # TODO: Uncomment this after log_of_measurement_results is deprecated and removed # if (measurements is None) != (measured_qubits is None): # raise ValueError( # 'measurements and measured_qubits must both either be provided or left default.' @@ -61,12 +56,22 @@ def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', .. def record_measurement( self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] ): - if len(measurement) != len(qubits): + if len(measurement) != len(qubits) and len(measurement) != 1: + # the latter condition is allowed for keyed channel measurements raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') if key in self._measurements: raise ValueError(f"Measurement already logged to key {key!r}") self._measurements[key] = tuple(measurement) self._measured_qubits[key] = tuple(qubits) + def get_int(self, key: 'cirq.MeasurementKey') -> int: + measurement = self._measurements[key] + # keyed channels + if len(measurement) == 1: + return measurement[0] + return digits.big_endian_digits_to_int( + measurement, base=[q.dimension for q in self._measured_qubits[key]] + ) + def copy(self): return ClassicalData(self._measurements.copy(), self._measured_qubits.copy()) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index cb12a25deb1..2da2bd29cf0 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -107,7 +107,7 @@ def resolve( ) -> bool: if self.key not in classical_data.keys(): raise ValueError(f'Measurement key {self.key} missing when testing classical control') - return any(classical_data.measurements[self.key]) + return classical_data.get_int(self.key) != 0 def _json_dict_(self): return json_serialization.dataclass_json_dict(self) @@ -156,14 +156,7 @@ def resolve( if missing: raise ValueError(f'Measurement keys {missing} missing when testing classical control') - def value(k): - return ( - digits.big_endian_digits_to_int( - classical_data.measurements[k], base=[q.dimension for q in classical_data.measured_qubits[k]] - ) - ) - - replacements = {str(k): value(k) for k in self.keys} + replacements = {str(k): classical_data.get_int(k) for k in self.keys} return bool(self.expr.subs(replacements)) def _json_dict_(self): diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index fd80033a29a..ba9128826bd 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -42,22 +42,25 @@ def test_key_condition_repr(): def test_key_condition_resolve(): - assert init_key_condition.resolve({'0:a': [1]}) - assert init_key_condition.resolve({'0:a': [2]}) - assert init_key_condition.resolve({'0:a': [0, 1]}) - assert init_key_condition.resolve({'0:a': [1, 0]}) - assert not init_key_condition.resolve({'0:a': [0]}) - assert not init_key_condition.resolve({'0:a': [0, 0]}) - assert not init_key_condition.resolve({'0:a': []}) - assert not init_key_condition.resolve({'0:a': [0], 'b': [1]}) + def resolve(measurements): + classical_data = cirq.ClassicalData(measurements, {k: tuple(cirq.LineQubit(i) for i in v) for k, v in measurements.items()}) + return init_key_condition.resolve(classical_data) + assert resolve({'0:a': [1]}) + assert resolve({'0:a': [2]}) + assert resolve({'0:a': [0, 1]}) + assert resolve({'0:a': [1, 0]}) + assert not resolve({'0:a': [0]}) + assert not resolve({'0:a': [0, 0]}) + assert not resolve({'0:a': []}) + assert not resolve({'0:a': [0], 'b': [1]}) with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): - _ = init_key_condition.resolve({}) + _ = resolve({}) with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): - _ = init_key_condition.resolve({'0:b': [1]}) + _ = resolve({'0:b': [1]}) def test_key_condition_qasm(): @@ -80,24 +83,27 @@ def test_sympy_condition_repr(): def test_sympy_condition_resolve(): - assert init_sympy_condition.resolve({'0:a': [1]}) - assert init_sympy_condition.resolve({'0:a': [2]}) - assert init_sympy_condition.resolve({'0:a': [0, 1]}) - assert init_sympy_condition.resolve({'0:a': [1, 0]}) - assert not init_sympy_condition.resolve({'0:a': [0]}) - assert not init_sympy_condition.resolve({'0:a': [0, 0]}) - assert not init_sympy_condition.resolve({'0:a': []}) - assert not init_sympy_condition.resolve({'0:a': [0], 'b': [1]}) + def resolve(measurements): + classical_data = cirq.ClassicalData(measurements, {k: tuple(cirq.LineQubit(i) for i in v) for k, v in measurements.items()}) + return init_sympy_condition.resolve(classical_data) + assert resolve({'0:a': [1]}) + assert resolve({'0:a': [2]}) + assert resolve({'0:a': [0, 1]}) + assert resolve({'0:a': [1, 0]}) + assert not resolve({'0:a': [0]}) + assert not resolve({'0:a': [0, 0]}) + assert not resolve({'0:a': []}) + assert not resolve({'0:a': [0], 'b': [1]}) with pytest.raises( ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): - _ = init_sympy_condition.resolve({}) + _ = resolve({}) with pytest.raises( ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): - _ = init_sympy_condition.resolve({'0:b': [1]}) + _ = resolve({'0:b': [1]}) def test_sympy_condition_qasm(): From f0b0016396e066a6330b9481964785c65408d7b3 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 23 Dec 2021 23:40:30 -0800 Subject: [PATCH 43/89] fix ActOnArgsContainer.copy --- cirq-core/cirq/sim/act_on_args_container.py | 6 ++---- cirq-core/cirq/sim/simulator_base.py | 2 +- cirq-core/cirq/value/measurement_key.py | 5 +++++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 404564afc39..92c3e3cdc6f 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -135,14 +135,12 @@ def _act_on_fallback_( return True def copy(self) -> 'cirq.ActOnArgsContainer[TActOnArgs]': - logs = self.log_of_measurement_results.copy() - classical_data = self._classical_data.copy() + classical_data = self.classical_data.copy() copies = {a: a.copy() for a in set(self.args.values())} for copy in copies.values(): - copy._log_of_measurement_results = logs copy._classical_data = classical_data args = {q: copies[a] for q, a in self.args.items()} - return ActOnArgsContainer(args, self.qubits, self.split_untangled_states, logs) + return ActOnArgsContainer(args, self.qubits, self.split_untangled_states, classical_data=classical_data) @property def qubits(self) -> Tuple['cirq.Qid', ...]: diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 7fa61fa9302..90fb6020afd 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -162,7 +162,7 @@ def _create_partial_act_on_args_ex( return self._create_partial_act_on_args( initial_state, qubits, - classical_data.measurements() # type: ignore + classical_data.measurements # type: ignore ) @abc.abstractmethod diff --git a/cirq-core/cirq/value/measurement_key.py b/cirq-core/cirq/value/measurement_key.py index ee4c12bb051..47c6712a2f0 100644 --- a/cirq-core/cirq/value/measurement_key.py +++ b/cirq-core/cirq/value/measurement_key.py @@ -77,6 +77,11 @@ def __hash__(self): object.__setattr__(self, '_hash', hash(str(self))) return self._hash + def __lt__(self, other): + if isinstance(other, MeasurementKey): + return str(self) < str(other) + return NotImplemented + def _json_dict_(self): return { 'name': self.name, From 334e985a63b808716067680b7008b80246c257a7 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 23 Dec 2021 23:41:28 -0800 Subject: [PATCH 44/89] format --- cirq-core/cirq/sim/act_on_args.py | 4 +++- cirq-core/cirq/sim/act_on_args_container.py | 4 +++- cirq-core/cirq/sim/simulator_base.py | 4 +--- cirq-core/cirq/sim/simulator_base_test.py | 4 +++- cirq-core/cirq/value/condition_test.py | 10 ++++++++-- 5 files changed, 18 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 4a1e798d23a..59eca943da8 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -102,7 +102,9 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[ return bits = self._perform_measurement(qubits) corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)] - self._classical_data.record_measurement(value.MeasurementKey.parse_serialized(key), corrected, qubits) + self._classical_data.record_measurement( + value.MeasurementKey.parse_serialized(key), corrected, qubits + ) def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]: return [self.qubit_map[q] for q in qubits] diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 92c3e3cdc6f..e170287dae3 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -140,7 +140,9 @@ def copy(self) -> 'cirq.ActOnArgsContainer[TActOnArgs]': for copy in copies.values(): copy._classical_data = classical_data args = {q: copies[a] for q, a in self.args.items()} - return ActOnArgsContainer(args, self.qubits, self.split_untangled_states, classical_data=classical_data) + return ActOnArgsContainer( + args, self.qubits, self.split_untangled_states, classical_data=classical_data + ) @property def qubits(self) -> Tuple['cirq.Qid', ...]: diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 90fb6020afd..a5be75c3013 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -160,9 +160,7 @@ def _create_partial_act_on_args_ex( # Child classes should override this behavior. We call the old one here by default for # backwards compatibility, until deprecation cycle is complete. return self._create_partial_act_on_args( - initial_state, - qubits, - classical_data.measurements # type: ignore + initial_state, qubits, classical_data.measurements # type: ignore ) @abc.abstractmethod diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 34a876a15d8..8e613c31e7a 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -143,7 +143,9 @@ def _create_partial_act_on_args_ex( qubits: Sequence['cirq.Qid'], classical_data: cirq.ClassicalData, ) -> CountingActOnArgs: - return SplittableCountingActOnArgs(qubits=qubits, state=initial_state, classical_data=classical_data) + return SplittableCountingActOnArgs( + qubits=qubits, state=initial_state, classical_data=classical_data + ) q0, q1 = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index ba9128826bd..625cd3fd765 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -43,8 +43,11 @@ def test_key_condition_repr(): def test_key_condition_resolve(): def resolve(measurements): - classical_data = cirq.ClassicalData(measurements, {k: tuple(cirq.LineQubit(i) for i in v) for k, v in measurements.items()}) + classical_data = cirq.ClassicalData( + measurements, {k: tuple(cirq.LineQubit(i) for i in v) for k, v in measurements.items()} + ) return init_key_condition.resolve(classical_data) + assert resolve({'0:a': [1]}) assert resolve({'0:a': [2]}) assert resolve({'0:a': [0, 1]}) @@ -84,8 +87,11 @@ def test_sympy_condition_repr(): def test_sympy_condition_resolve(): def resolve(measurements): - classical_data = cirq.ClassicalData(measurements, {k: tuple(cirq.LineQubit(i) for i in v) for k, v in measurements.items()}) + classical_data = cirq.ClassicalData( + measurements, {k: tuple(cirq.LineQubit(i) for i in v) for k, v in measurements.items()} + ) return init_sympy_condition.resolve(classical_data) + assert resolve({'0:a': [1]}) assert resolve({'0:a': [2]}) assert resolve({'0:a': [0, 1]}) From 048da053e4cc47b5214f0e706d577e17fd402c58 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 24 Dec 2021 00:25:33 -0800 Subject: [PATCH 45/89] lint, mypy --- cirq-core/cirq/sim/act_on_args.py | 4 +++- cirq-core/cirq/sim/act_on_args_container.py | 1 - .../cirq/sim/clifford/stabilizer_state_ch_form_test.py | 2 +- cirq-core/cirq/sim/operation_target.py | 3 +-- cirq-core/cirq/sim/simulator_base.py | 1 + cirq-core/cirq/value/classical_data.py | 8 ++++---- cirq-core/cirq/value/condition.py | 4 ++-- 7 files changed, 12 insertions(+), 11 deletions(-) diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 59eca943da8..af9680dd8c4 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -73,7 +73,9 @@ def __init__( qubits = () self._set_qubits(qubits) self.prng = prng - self._classical_data = classical_data or value.ClassicalData(log_of_measurement_results) + # pylint: disable=line-too-long + self._classical_data = classical_data or value.ClassicalData(log_of_measurement_results) # type: ignore + # pylint: enable=line-too-long self._ignore_measurement_results = ignore_measurement_results def _set_qubits(self, qubits: Sequence['cirq.Qid']): diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index e170287dae3..3d2749923a7 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -19,7 +19,6 @@ Generic, Iterator, List, - Mapping, Optional, Sequence, Tuple, diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py index b984b760056..2b3acb09723 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py @@ -73,5 +73,5 @@ def test_run(): log_of_measurement_results=measurements, ) cirq.act_on(op, args) - assert measurements['1'] == [1] + assert measurements['1'] == (1,) assert measurements['0'] != measurements['2'] diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index 7ff5d6cac17..31dda85fb74 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -19,7 +19,6 @@ Generic, Iterator, List, - Mapping, Optional, Sequence, Tuple, @@ -80,7 +79,7 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: @property def log_of_measurement_results(self) -> Dict[str, Any]: """Gets the log of measurement results.""" - return {k: list(v) for k, v in self.classical_data.measurements.items()} + return {str(k): list(v) for k, v in self.classical_data.measurements.items()} @property @abc.abstractmethod diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index a5be75c3013..eb76b701b35 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -159,6 +159,7 @@ def _create_partial_act_on_args_ex( """ # Child classes should override this behavior. We call the old one here by default for # backwards compatibility, until deprecation cycle is complete. + # coverage: ignore return self._create_partial_act_on_args( initial_state, qubits, classical_data.measurements # type: ignore ) diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index e9c580fbb19..9d96ec9cfc2 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -34,13 +34,13 @@ def __init__( # 'measurements and measured_qubits must both either be provided or left default.' # ) if measurements is None: - measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = {} + measurements = {} if measured_qubits is None: - measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = {} + measured_qubits = {} # if set(measurements.keys()) != set(measured_qubits.keys()): # raise ValueError('measurements and measured_qubits must contain same keys.') - self._measurements = measurements - self._measured_qubits = measured_qubits + self._measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = measurements + self._measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = measured_qubits def keys(self) -> Tuple['cirq.MeasurementKey', ...]: return tuple(self._measurements.keys()) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 2da2bd29cf0..dadac63a882 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -14,13 +14,13 @@ import abc import dataclasses -from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, FrozenSet +from typing import Dict, Tuple, TYPE_CHECKING, FrozenSet import sympy from cirq._compat import proper_repr from cirq.protocols import json_serialization, measurement_key_protocol as mkp -from cirq.value import digits, measurement_key +from cirq.value import measurement_key if TYPE_CHECKING: import cirq From 64f22e6487f01fc62e44712938129c7fa17bf0d3 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 24 Dec 2021 01:17:03 -0800 Subject: [PATCH 46/89] json --- cirq-core/cirq/json_resolver_cache.py | 1 + .../json_test_data/ClassicalData.json | 18 +++++++++++++++ .../json_test_data/ClassicalData.repr | 1 + cirq-core/cirq/value/classical_data.py | 23 +++++++++++++++++-- 4 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 cirq-core/cirq/protocols/json_test_data/ClassicalData.json create mode 100644 cirq-core/cirq/protocols/json_test_data/ClassicalData.repr diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 6a017320dad..88ded1b7e48 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -64,6 +64,7 @@ def _parallel_gate_op(gate, qubits): 'Circuit': cirq.Circuit, 'CircuitOperation': cirq.CircuitOperation, 'ClassicallyControlledOperation': cirq.ClassicallyControlledOperation, + 'ClassicalData': cirq.ClassicalData, 'CliffordState': cirq.CliffordState, 'CliffordTableau': cirq.CliffordTableau, 'CNotPowGate': cirq.CNotPowGate, diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalData.json b/cirq-core/cirq/protocols/json_test_data/ClassicalData.json new file mode 100644 index 00000000000..94c2af6e1dc --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalData.json @@ -0,0 +1,18 @@ +{ + "cirq_type": "ClassicalData", + "measurements": { + "m": [0, 1] + }, + "measured_qubits": { + "m": [ + { + "cirq_type": "LineQubit", + "x": 0 + }, + { + "cirq_type": "LineQubit", + "x": 1 + } + ] + } +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr new file mode 100644 index 00000000000..2e4cadece4f --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr @@ -0,0 +1 @@ +cirq.ClassicalData(measurements={'m': [0, 1]}, measured_qubits={'m': [cirq.LineQubit(0), cirq.LineQubit(1)]}) \ No newline at end of file diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 9d96ec9cfc2..8b9e2cc25f5 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -12,17 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING +from cirq.protocols import json_serialization from cirq.value import digits if TYPE_CHECKING: import cirq +@dataclasses.dataclass class ClassicalData: """Classical data representing measurements and metadata.""" + _measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] + _measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] + def __init__( self, measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, @@ -39,8 +45,8 @@ def __init__( measured_qubits = {} # if set(measurements.keys()) != set(measured_qubits.keys()): # raise ValueError('measurements and measured_qubits must contain same keys.') - self._measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = measurements - self._measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = measured_qubits + self._measurements = measurements + self._measured_qubits = measured_qubits def keys(self) -> Tuple['cirq.MeasurementKey', ...]: return tuple(self._measurements.keys()) @@ -75,3 +81,16 @@ def get_int(self, key: 'cirq.MeasurementKey') -> int: def copy(self): return ClassicalData(self._measurements.copy(), self._measured_qubits.copy()) + + def _json_dict_(self): + return json_serialization.obj_to_dict_helper(self, ['measurements', 'measured_qubits']) + + @classmethod + def _from_json_dict_(cls, measurements, measured_qubits, **kwargs): + return cls(measurements=measurements, measured_qubits=measured_qubits) + + def __repr__(self): + return ( + f'cirq.ClassicalData(measurements={self.measurements!r},' + f' measured_qubits={self.measured_qubits!r})' + ) From 38d23e70eac50ab169039051779ef195c1cf69d5 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 24 Dec 2021 01:20:19 -0800 Subject: [PATCH 47/89] lint --- cirq-core/cirq/sim/act_on_args_container.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 3d2749923a7..e150d463e9e 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -70,7 +70,9 @@ def __init__( self.args = args self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states + # pylint: disable=line-too-long self._classical_data = classical_data or value.ClassicalData(log_of_measurement_results) # type: ignore + # pylint: enable=line-too-long def create_merged_state(self) -> TActOnArgs: if not self.split_untangled_states: From 87ecd1c7291be8001da6ea4cf634572f948f3c98 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 24 Dec 2021 01:40:34 -0800 Subject: [PATCH 48/89] mkey compare --- cirq-core/cirq/value/measurement_key.py | 4 +++- cirq-core/cirq/value/measurement_key_test.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/value/measurement_key.py b/cirq-core/cirq/value/measurement_key.py index 47c6712a2f0..9e98cf68c2d 100644 --- a/cirq-core/cirq/value/measurement_key.py +++ b/cirq-core/cirq/value/measurement_key.py @@ -79,7 +79,9 @@ def __hash__(self): def __lt__(self, other): if isinstance(other, MeasurementKey): - return str(self) < str(other) + if self.path != other.path: + return self.path < other.path + return self.name < other.name return NotImplemented def _json_dict_(self): diff --git a/cirq-core/cirq/value/measurement_key_test.py b/cirq-core/cirq/value/measurement_key_test.py index c7f01de7d9a..58454b2b74c 100644 --- a/cirq-core/cirq/value/measurement_key_test.py +++ b/cirq-core/cirq/value/measurement_key_test.py @@ -98,3 +98,11 @@ def test_with_measurement_key_mapping(): mkey3 = cirq.with_measurement_key_mapping(mkey3, {'new_key': 'newer_key'}) assert mkey3.name == 'newer_key' assert mkey3.path == ('a',) + + +def test_compare(): + assert cirq.MeasurementKey('a') < cirq.MeasurementKey('b') + assert cirq.MeasurementKey(path=(), name='b') < cirq.MeasurementKey(path=('0',), name='a') + assert cirq.MeasurementKey(path=('0',), name='n') < cirq.MeasurementKey(path=('1',), name='a') + with pytest.raises(TypeError): + _ = cirq.MeasurementKey('a') < 'b' From 831e76ae745fbfca94cf564560a0dbeb7cae0ae7 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 24 Dec 2021 02:06:42 -0800 Subject: [PATCH 49/89] test class --- cirq-core/cirq/value/classical_data_test.py | 89 +++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 cirq-core/cirq/value/classical_data_test.py diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py new file mode 100644 index 00000000000..3729c35aee4 --- /dev/null +++ b/cirq-core/cirq/value/classical_data_test.py @@ -0,0 +1,89 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import cirq + +mkey_m = cirq.MeasurementKey('m') +two_qubits = tuple(cirq.LineQubit.range(2)) + + +def test_init_empty(): + cd = cirq.ClassicalData() + assert cd.measurements is not None + assert not cd.measurements + assert cd.keys() is not None + assert not cd.keys() + assert cd.measured_qubits is not None + assert not cd.measured_qubits + + +def test_init_properties(): + cd = cirq.ClassicalData({mkey_m: (0, 1)}, {mkey_m: two_qubits}) + assert cd.measurements == {mkey_m: (0, 1)} + assert cd.keys() == (mkey_m,) + assert cd.measured_qubits == {mkey_m: two_qubits} + + +def test_record_measurement(): + cd = cirq.ClassicalData() + cd.record_measurement(mkey_m, (0, 1), two_qubits) + assert cd.measurements == {mkey_m: (0, 1)} + assert cd.keys() == (mkey_m,) + assert cd.measured_qubits == {mkey_m: two_qubits} + + +def test_record_measurement_errors(): + cd = cirq.ClassicalData() + with pytest.raises(ValueError, match='3 measurements but 2 qubits'): + cd.record_measurement(mkey_m, (0, 1, 2), two_qubits) + cd.record_measurement(mkey_m, (0, 1), two_qubits) + with pytest.raises(ValueError, match='Measurement already logged to key'): + cd.record_measurement(mkey_m, (0, 1), two_qubits) + + +def test_get_int(): + cd = cirq.ClassicalData() + cd.record_measurement(mkey_m, (0, 1), two_qubits) + assert cd.get_int(mkey_m) == 1 + cd = cirq.ClassicalData() + cd.record_measurement(mkey_m, (1, 1), two_qubits) + assert cd.get_int(mkey_m) == 3 + cd = cirq.ClassicalData() + cd.record_measurement(mkey_m, (8,), two_qubits) + assert cd.get_int(mkey_m) == 8 + cd = cirq.ClassicalData() + cd.record_measurement(mkey_m, (1, 1), (cirq.LineQid.range(2, dimension=3))) + assert cd.get_int(mkey_m) == 4 + + +def test_copy(): + cd = cirq.ClassicalData({mkey_m: (0, 1)}, {mkey_m: two_qubits}) + cd1 = cd.copy() + assert cd1 is not cd + assert cd1 == cd + assert cd1.measurements is not cd.measurements + assert cd1.measurements == cd.measurements + assert cd1.measured_qubits is not cd.measured_qubits + assert cd1.measured_qubits == cd.measured_qubits + + +def test_repr(): + cd = cirq.ClassicalData({mkey_m: (0, 1)}, {mkey_m: two_qubits}) + assert repr(cd) == ( + "cirq.ClassicalData(" + "measurements={cirq.MeasurementKey(name='m'): (0, 1)}, " + "measured_qubits={cirq.MeasurementKey(name='m'): (cirq.LineQubit(0), cirq.LineQubit(1))})" + ) From e93ffd19584d499634ea0ba5bd4a315647cbf752 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 24 Dec 2021 08:35:56 -0800 Subject: [PATCH 50/89] docstrings, create independent function for measuring channels --- .../classically_controlled_operation_test.py | 12 ++-- .../cirq/sim/act_on_state_vector_args.py | 4 +- cirq-core/cirq/value/classical_data.py | 55 ++++++++++++++++++- cirq-core/cirq/value/classical_data_test.py | 22 +++++++- 4 files changed, 80 insertions(+), 13 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 1307c1150b2..3089d94fcf4 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -706,7 +706,7 @@ def test_sympy(): def test_sympy_qudits(): q0 = cirq.LineQid(0, 3) - q1 = cirq.LineQid(1, 3) + q1 = cirq.LineQid(1, 5) q_result = cirq.LineQubit(2) class PlusGate(cirq.Gate): @@ -724,18 +724,18 @@ def _unitary_(self): u[:inc] = np.eye(self.dimension)[-inc:] return u - for i in range(9): - digits = cirq.big_endian_int_to_digits(i, digit_count=2, base=3) + for i in range(15): + digits = cirq.big_endian_int_to_digits(i, digit_count=2, base=(3, 5)) circuit = cirq.Circuit( PlusGate(3, digits[0]).on(q0), - PlusGate(3, digits[1]).on(q1), + PlusGate(5, digits[1]).on(q1), cirq.measure(q0, q1, key='m'), - cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m > 4')), + cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m % 4 <= 1')), cirq.measure(q_result, key='m_result'), ) result = cirq.Simulator().run(circuit) - assert result.measurements['m_result'][0][0] == (i > 4) + assert result.measurements['m_result'][0][0] == (i % 4 <= 1) def test_sympy_path_prefix(): diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 7b6b471cad7..9fbca455006 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -278,7 +278,7 @@ def _strat_act_on_state_vector_from_mixture( args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.classical_data.record_measurement(key, [index], qubits) + args.classical_data.record_channel_measurement(key, index, qubits) return True @@ -327,5 +327,5 @@ def prepare_into_buffer(k: int): args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.classical_data.record_measurement(key, [index], qubits) + args.classical_data.record_channel_measurement(key, index, qubits) return True diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 8b9e2cc25f5..e6c82d73cfa 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -34,6 +34,12 @@ def __init__( measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = None, ): + """Initializes a `ClassicalData` object. + + Args: + measurements: The measurements to seed with, if any. + measured_qubits: The qubits corresponding to the measurements. + """ # TODO: Uncomment this after log_of_measurement_results is deprecated and removed # if (measurements is None) != (measured_qubits is None): # raise ValueError( @@ -43,43 +49,86 @@ def __init__( measurements = {} if measured_qubits is None: measured_qubits = {} + # TODO: Uncomment this after log_of_measurement_results is deprecated and removed # if set(measurements.keys()) != set(measured_qubits.keys()): # raise ValueError('measurements and measured_qubits must contain same keys.') self._measurements = measurements self._measured_qubits = measured_qubits def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + """Gets the measurement keys in the order they were stored.""" return tuple(self._measurements.keys()) @property def measurements(self) -> Mapping['cirq.MeasurementKey', Tuple[int, ...]]: + """Gets the a mapping from measurement key to measurement.""" return self._measurements @property def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', ...]]: + """Gets the a mapping from measurement key to the qubits measured.""" return self._measured_qubits def record_measurement( self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] ): - if len(measurement) != len(qubits) and len(measurement) != 1: - # the latter condition is allowed for keyed channel measurements + """Records a measurement. + + Args: + key: The measurement key to hold the measurement. + measurement: The measurement result. + qubits: The qubits that were measured. + + Raises: + ValueError: If the measurement shape does not match the qubits + measured, or if the measurement key was already used. + """ + if len(measurement) != len(qubits): raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') if key in self._measurements: raise ValueError(f"Measurement already logged to key {key!r}") self._measurements[key] = tuple(measurement) self._measured_qubits[key] = tuple(qubits) + def record_channel_measurement( + self, key: 'cirq.MeasurementKey', measurement: int, qubits: Sequence['cirq.Qid'] + ): + """Records a channel measurement. + + Args: + key: The measurement key to hold the measurement. + measurement: The measurement result. + qubits: The qubits that were measured. + + Raises: + ValueError: If the measurement key was already used. + """ + if key in self._measurements: + raise ValueError(f"Measurement already logged to key {key!r}") + self._measurements[key] = (measurement,) + self._measured_qubits[key] = tuple(qubits) + def get_int(self, key: 'cirq.MeasurementKey') -> int: + """Gets the integer corresponding to the measurement. + + Args: + key: The measurement key. + + Raises: + ValueError: If the key has not been used. + """ + if key not in self._measurements: + raise KeyError(f'The measurement key {key} is not in {self._measurements}') measurement = self._measurements[key] - # keyed channels if len(measurement) == 1: + # Needed to support keyed channels return measurement[0] return digits.big_endian_digits_to_int( measurement, base=[q.dimension for q in self._measured_qubits[key]] ) def copy(self): + """Creates a copy of the object.""" return ClassicalData(self._measurements.copy(), self._measured_qubits.copy()) def _json_dict_(self): diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index 3729c35aee4..af9be881474 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -54,6 +54,21 @@ def test_record_measurement_errors(): cd.record_measurement(mkey_m, (0, 1), two_qubits) +def test_record_channel_measurement(): + cd = cirq.ClassicalData() + cd.record_channel_measurement(mkey_m, 1, two_qubits) + assert cd.measurements == {mkey_m: (1,)} + assert cd.keys() == (mkey_m,) + assert cd.measured_qubits == {mkey_m: two_qubits} + + +def test_record_channel_measurement_errors(): + cd = cirq.ClassicalData() + cd.record_channel_measurement(mkey_m, 1, two_qubits) + with pytest.raises(ValueError, match='Measurement already logged to key'): + cd.record_channel_measurement(mkey_m, 1, two_qubits) + + def test_get_int(): cd = cirq.ClassicalData() cd.record_measurement(mkey_m, (0, 1), two_qubits) @@ -62,11 +77,14 @@ def test_get_int(): cd.record_measurement(mkey_m, (1, 1), two_qubits) assert cd.get_int(mkey_m) == 3 cd = cirq.ClassicalData() - cd.record_measurement(mkey_m, (8,), two_qubits) - assert cd.get_int(mkey_m) == 8 + cd.record_channel_measurement(mkey_m, 1, two_qubits) + assert cd.get_int(mkey_m) == 1 cd = cirq.ClassicalData() cd.record_measurement(mkey_m, (1, 1), (cirq.LineQid.range(2, dimension=3))) assert cd.get_int(mkey_m) == 4 + cd = cirq.ClassicalData() + with pytest.raises(KeyError, match='The measurement key m is not in {}'): + cd.get_int(mkey_m) def test_copy(): From 721382ad70ba8156e7b724a609d19342f35188d0 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 24 Dec 2021 08:48:11 -0800 Subject: [PATCH 51/89] KeyError --- cirq-core/cirq/value/classical_data.py | 11 ++++++----- cirq-core/cirq/value/classical_data_test.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index e6c82d73cfa..adec4e09ff9 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -81,12 +81,13 @@ def record_measurement( Raises: ValueError: If the measurement shape does not match the qubits - measured, or if the measurement key was already used. + measured. + KeyError: If the measurement key was already used. """ if len(measurement) != len(qubits): raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') if key in self._measurements: - raise ValueError(f"Measurement already logged to key {key!r}") + raise KeyError(f"Measurement already logged to key {key}") self._measurements[key] = tuple(measurement) self._measured_qubits[key] = tuple(qubits) @@ -101,10 +102,10 @@ def record_channel_measurement( qubits: The qubits that were measured. Raises: - ValueError: If the measurement key was already used. + KeyError: If the measurement key was already used. """ if key in self._measurements: - raise ValueError(f"Measurement already logged to key {key!r}") + raise KeyError(f"Measurement already logged to key {key}") self._measurements[key] = (measurement,) self._measured_qubits[key] = tuple(qubits) @@ -115,7 +116,7 @@ def get_int(self, key: 'cirq.MeasurementKey') -> int: key: The measurement key. Raises: - ValueError: If the key has not been used. + KeyError: If the key has not been used. """ if key not in self._measurements: raise KeyError(f'The measurement key {key} is not in {self._measurements}') diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index af9be881474..8d310ee998a 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -50,7 +50,7 @@ def test_record_measurement_errors(): with pytest.raises(ValueError, match='3 measurements but 2 qubits'): cd.record_measurement(mkey_m, (0, 1, 2), two_qubits) cd.record_measurement(mkey_m, (0, 1), two_qubits) - with pytest.raises(ValueError, match='Measurement already logged to key'): + with pytest.raises(KeyError, match='Measurement already logged to key m'): cd.record_measurement(mkey_m, (0, 1), two_qubits) @@ -65,7 +65,7 @@ def test_record_channel_measurement(): def test_record_channel_measurement_errors(): cd = cirq.ClassicalData() cd.record_channel_measurement(mkey_m, 1, two_qubits) - with pytest.raises(ValueError, match='Measurement already logged to key'): + with pytest.raises(KeyError, match='Measurement already logged to key m'): cd.record_channel_measurement(mkey_m, 1, two_qubits) From c22ba2a4f7c56074222d5ad305886901973936f6 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 24 Dec 2021 09:06:46 -0800 Subject: [PATCH 52/89] revert to ValueError --- cirq-core/cirq/value/classical_data.py | 9 ++++----- cirq-core/cirq/value/classical_data_test.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index adec4e09ff9..d4e01c81f98 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -81,13 +81,12 @@ def record_measurement( Raises: ValueError: If the measurement shape does not match the qubits - measured. - KeyError: If the measurement key was already used. + measured or if the measurement key was already used. """ if len(measurement) != len(qubits): raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') if key in self._measurements: - raise KeyError(f"Measurement already logged to key {key}") + raise ValueError(f"Measurement already logged to key {key}") self._measurements[key] = tuple(measurement) self._measured_qubits[key] = tuple(qubits) @@ -102,10 +101,10 @@ def record_channel_measurement( qubits: The qubits that were measured. Raises: - KeyError: If the measurement key was already used. + ValueError: If the measurement key was already used. """ if key in self._measurements: - raise KeyError(f"Measurement already logged to key {key}") + raise ValueError(f"Measurement already logged to key {key}") self._measurements[key] = (measurement,) self._measured_qubits[key] = tuple(qubits) diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index 8d310ee998a..7daeff9e7fa 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -50,7 +50,7 @@ def test_record_measurement_errors(): with pytest.raises(ValueError, match='3 measurements but 2 qubits'): cd.record_measurement(mkey_m, (0, 1, 2), two_qubits) cd.record_measurement(mkey_m, (0, 1), two_qubits) - with pytest.raises(KeyError, match='Measurement already logged to key m'): + with pytest.raises(ValueError, match='Measurement already logged to key m'): cd.record_measurement(mkey_m, (0, 1), two_qubits) @@ -65,7 +65,7 @@ def test_record_channel_measurement(): def test_record_channel_measurement_errors(): cd = cirq.ClassicalData() cd.record_channel_measurement(mkey_m, 1, two_qubits) - with pytest.raises(KeyError, match='Measurement already logged to key m'): + with pytest.raises(ValueError, match='Measurement already logged to key m'): cd.record_channel_measurement(mkey_m, 1, two_qubits) From 9940014cbfc5da062e29927359d4c5b6b77949b0 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sat, 25 Dec 2021 21:46:16 -0800 Subject: [PATCH 53/89] Base class --- cirq-core/cirq/__init__.py | 3 + cirq-core/cirq/contrib/quimb/mps_simulator.py | 4 +- cirq-core/cirq/json_resolver_cache.py | 1 + .../json_test_data/ClassicalData.json | 7 + .../json_test_data/ClassicalData.repr | 2 +- .../json_test_data/MeasurementType.json | 1 + .../json_test_data/MeasurementType.repr | 1 + cirq-core/cirq/sim/act_on_args.py | 6 +- cirq-core/cirq/sim/act_on_args_container.py | 6 +- .../cirq/sim/act_on_density_matrix_args.py | 2 +- .../cirq/sim/act_on_state_vector_args.py | 6 +- .../clifford/act_on_clifford_tableau_args.py | 2 +- .../act_on_stabilizer_ch_form_args.py | 2 +- .../cirq/sim/clifford/clifford_simulator.py | 2 +- .../cirq/sim/density_matrix_simulator.py | 2 +- cirq-core/cirq/sim/operation_target.py | 4 +- cirq-core/cirq/sim/simulator_base.py | 2 +- cirq-core/cirq/sim/simulator_base_test.py | 4 +- cirq-core/cirq/sim/sparse_simulator.py | 2 +- cirq-core/cirq/value/__init__.py | 3 + cirq-core/cirq/value/classical_data.py | 225 ++++++++++++++---- cirq-core/cirq/value/classical_data_test.py | 59 +++-- cirq-core/cirq/value/condition.py | 6 +- cirq-core/cirq/value/condition_test.py | 10 +- 24 files changed, 255 insertions(+), 107 deletions(-) create mode 100644 cirq-core/cirq/protocols/json_test_data/MeasurementType.json create mode 100644 cirq-core/cirq/protocols/json_test_data/MeasurementType.repr diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 33440246bd1..4341fa32cc7 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -484,6 +484,8 @@ chosen_angle_to_canonical_half_turns, chosen_angle_to_half_turns, ClassicalData, + ClassicalDataBase, + ClassicalDataReader, Condition, Duration, DURATION_LIKE, @@ -492,6 +494,7 @@ LinearDict, MEASUREMENT_KEY_SEPARATOR, MeasurementKey, + MeasurementType, PeriodicValue, RANDOM_STATE_OR_SEED_LIKE, state_vector_to_probabilities, diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index c9a07293df4..74f8c082d33 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -91,7 +91,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union[int, 'MPSState'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalData', + classical_data: 'cirq.ClassicalDataBase', ) -> 'MPSState': """Creates MPSState args for simulating the Circuit. @@ -230,7 +230,7 @@ def __init__( grouping: Optional[Dict['cirq.Qid', int]] = None, initial_state: int = 0, log_of_measurement_results: Dict[str, Any] = None, - classical_data: 'cirq.ClassicalData' = None, + classical_data: 'cirq.ClassicalDataBase' = None, ): """Creates and MPSState diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 88ded1b7e48..71c5ced23b4 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -105,6 +105,7 @@ def _parallel_gate_op(gate, qubits): 'MixedUnitaryChannel': cirq.MixedUnitaryChannel, 'MeasurementKey': cirq.MeasurementKey, 'MeasurementGate': cirq.MeasurementGate, + 'MeasurementType': cirq.MeasurementType, '_MeasurementSpec': cirq.work._MeasurementSpec, 'Moment': cirq.Moment, 'MutableDensePauliString': cirq.MutableDensePauliString, diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalData.json b/cirq-core/cirq/protocols/json_test_data/ClassicalData.json index 94c2af6e1dc..1b33c860407 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalData.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalData.json @@ -14,5 +14,12 @@ "x": 1 } ] + }, + "channel_measurements": { + "c": 3 + }, + "measurement_types": { + "c": 2, + "m": 1 } } \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr index 2e4cadece4f..8e0a28007cb 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr @@ -1 +1 @@ -cirq.ClassicalData(measurements={'m': [0, 1]}, measured_qubits={'m': [cirq.LineQubit(0), cirq.LineQubit(1)]}) \ No newline at end of file +cirq.ClassicalData(_measurements={'m': [0, 1]}, _measured_qubits={'m': [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={'c': 3}, _measurement_types={'m': cirq.MeasurementType.MEASUREMENT, 'c': cirq.MeasurementType.CHANNEL}) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/MeasurementType.json b/cirq-core/cirq/protocols/json_test_data/MeasurementType.json new file mode 100644 index 00000000000..fd8ef095787 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/MeasurementType.json @@ -0,0 +1 @@ +[1, 2] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/MeasurementType.repr b/cirq-core/cirq/protocols/json_test_data/MeasurementType.repr new file mode 100644 index 00000000000..edeebfddc51 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/MeasurementType.repr @@ -0,0 +1 @@ +[cirq.MeasurementType.MEASUREMENT, cirq.MeasurementType.CHANNEL] \ No newline at end of file diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index af9680dd8c4..b33ea763e14 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -47,7 +47,7 @@ def __init__( prng: np.random.RandomState = None, qubits: Sequence['cirq.Qid'] = None, log_of_measurement_results: Dict[str, List[int]] = None, - classical_data: 'cirq.ClassicalData' = None, + classical_data: 'cirq.ClassicalDataBase' = None, ignore_measurement_results: bool = False, ): """Inits ActOnArgs. @@ -74,7 +74,7 @@ def __init__( self._set_qubits(qubits) self.prng = prng # pylint: disable=line-too-long - self._classical_data = classical_data or value.ClassicalData(log_of_measurement_results) # type: ignore + self._classical_data = classical_data or value.ClassicalData(_measurements=log_of_measurement_results) # type: ignore # pylint: enable=line-too-long self._ignore_measurement_results = ignore_measurement_results @@ -197,7 +197,7 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ functionality, if supported.""" @property - def classical_data(self) -> 'cirq.ClassicalData': + def classical_data(self) -> 'cirq.ClassicalDataReader': return self._classical_data @property diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index e150d463e9e..8b2d9e08264 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -51,7 +51,7 @@ def __init__( qubits: Sequence['cirq.Qid'], split_untangled_states: bool, log_of_measurement_results: Dict[str, Any] = None, - classical_data: 'cirq.ClassicalData' = None, + classical_data: 'cirq.ClassicalDataBase' = None, ): """Initializes the class. @@ -71,7 +71,7 @@ def __init__( self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states # pylint: disable=line-too-long - self._classical_data = classical_data or value.ClassicalData(log_of_measurement_results) # type: ignore + self._classical_data = classical_data or value.ClassicalData(_measurements=log_of_measurement_results) # type: ignore # pylint: enable=line-too-long def create_merged_state(self) -> TActOnArgs: @@ -150,7 +150,7 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: return self._qubits @property - def classical_data(self) -> 'cirq.ClassicalData': + def classical_data(self) -> 'cirq.ClassicalDataReader': return self._classical_data def sample( diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 3a0f214cd4c..eb2c517cb32 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -41,7 +41,7 @@ def __init__( prng: np.random.RandomState = None, log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, - classical_data: 'cirq.ClassicalData' = None, + classical_data: 'cirq.ClassicalDataBase' = None, ignore_measurement_results: bool = False, ): """Inits ActOnDensityMatrixArgs. diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 9fbca455006..849d37d4c98 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -44,7 +44,7 @@ def __init__( prng: np.random.RandomState = None, log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, - classical_data: 'cirq.ClassicalData' = None, + classical_data: 'cirq.ClassicalDataBase' = None, ): """Inits ActOnStateVectorArgs. @@ -278,7 +278,7 @@ def _strat_act_on_state_vector_from_mixture( args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.classical_data.record_channel_measurement(key, index, qubits) + args._classical_data.record_channel_measurement(key, index) return True @@ -327,5 +327,5 @@ def prepare_into_buffer(k: int): args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.classical_data.record_channel_measurement(key, index, qubits) + args._classical_data.record_channel_measurement(key, index) return True diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 50ada171dea..cfdee35b2a2 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -42,7 +42,7 @@ def __init__( prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], qubits: Sequence['cirq.Qid'] = None, - classical_data: 'cirq.ClassicalData' = None, + classical_data: 'cirq.ClassicalDataBase' = None, ): """Inits ActOnCliffordTableauArgs. diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index d07c5c1b353..7000a0a2d97 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -41,7 +41,7 @@ def __init__( prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, - classical_data: 'cirq.ClassicalData' = None, + classical_data: 'cirq.ClassicalDataBase' = None, ): """Initializes with the given state and the axes for the operation. Args: diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 308a8a081c5..917e3b738c1 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -68,7 +68,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union[int, 'cirq.ActOnStabilizerCHFormArgs'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalData' = None, + classical_data: 'cirq.ClassicalDataBase' = None, ) -> 'cirq.ActOnStabilizerCHFormArgs': """Creates the ActOnStabilizerChFormArgs for a circuit. diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index e7e315b353a..58ca2a6cf39 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -176,7 +176,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalData' = None, + classical_data: 'cirq.ClassicalDataBase' = None, ) -> 'cirq.ActOnDensityMatrixArgs': """Creates the ActOnDensityMatrixArgs for a circuit. diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index 31dda85fb74..5ae7dad8d74 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -79,11 +79,11 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: @property def log_of_measurement_results(self) -> Dict[str, Any]: """Gets the log of measurement results.""" - return {str(k): list(v) for k, v in self.classical_data.measurements.items()} + return {str(k): list(self.classical_data.get_digits(k)) for k in self.classical_data.keys()} @property @abc.abstractmethod - def classical_data(self) -> 'cirq.ClassicalData': + def classical_data(self) -> 'cirq.ClassicalDataReader': """The shared classical data container for this simulation..""" @abc.abstractmethod diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index eb76b701b35..6ceba551e98 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -143,7 +143,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalData', + classical_data: 'cirq.ClassicalDataBase', ) -> TActOnArgs: """Creates an instance of the TActOnArgs class for the simulator. diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 8e613c31e7a..6e6df9991f2 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -111,7 +111,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - classical_data: cirq.ClassicalData, + classical_data: cirq.ClassicalDataBase, ) -> CountingActOnArgs: return CountingActOnArgs(qubits=qubits, state=initial_state, classical_data=classical_data) @@ -141,7 +141,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - classical_data: cirq.ClassicalData, + classical_data: cirq.ClassicalDataBase, ) -> CountingActOnArgs: return SplittableCountingActOnArgs( qubits=qubits, state=initial_state, classical_data=classical_data diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index ed084bbf3d4..0e95b3dd247 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -175,7 +175,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.ActOnStateVectorArgs'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalData', + classical_data: 'cirq.ClassicalDataBase', ): """Creates the ActOnStateVectorArgs for a circuit. diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index 168dc935b0b..498096e7cde 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -27,6 +27,9 @@ from cirq.value.classical_data import ( ClassicalData, + ClassicalDataBase, + ClassicalDataReader, + MeasurementType, ) from cirq.value.condition import ( diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index d4e01c81f98..162977c2860 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -12,63 +12,154 @@ # See the License for the specific language governing permissions and # limitations under the License. -import dataclasses -from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING - +import abc +import enum +from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, TypeVar from cirq.protocols import json_serialization -from cirq.value import digits +from cirq.value import digits, value_equality_attr if TYPE_CHECKING: import cirq -@dataclasses.dataclass -class ClassicalData: +class MeasurementType(enum.IntEnum): + MEASUREMENT = 1 + CHANNEL = 2 + + def __repr__(self): + return f'cirq.{str(self)}' + + +TSelf = TypeVar('TSelf', bound='ClassicalDataReader') + + +class ClassicalDataReader(abc.ABC): + @abc.abstractmethod + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + """Gets the measurement keys in the order they were stored.""" + + @abc.abstractmethod + def get_int(self, key: 'cirq.MeasurementKey') -> int: + """Gets the integer corresponding to the measurement. + + Args: + key: The measurement key. + + Raises: + KeyError: If the key has not been used. + """ + + @abc.abstractmethod + def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + """Gets the digits of the measurement. + + Args: + key: The measurement key. + + Raises: + KeyError: If the key has not been used. + """ + + @abc.abstractmethod + def copy(self: TSelf) -> TSelf: + """Creates a copy of the object.""" + + +class ClassicalDataBase(ClassicalDataReader, abc.ABC): + @abc.abstractmethod + def record_measurement( + self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] + ): + """Records a measurement. + + Args: + key: The measurement key to hold the measurement. + measurement: The measurement result. + qubits: The qubits that were measured. + + Raises: + ValueError: If the measurement shape does not match the qubits + measured or if the measurement key was already used. + """ + + @abc.abstractmethod + def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): + """Records a channel measurement. + + Args: + key: The measurement key to hold the measurement. + measurement: The measurement result. + + Raises: + ValueError: If the measurement key was already used. + """ + + +@value_equality_attr.value_equality(unhashable=True) +class ClassicalData(ClassicalDataBase): """Classical data representing measurements and metadata.""" _measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] _measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] + _channel_measurements: Dict['cirq.MeasurementKey', int] + _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] def __init__( self, - measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, - measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = None, + *, + _measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, + _measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = None, + _channel_measurements: Dict['cirq.MeasurementKey', int] = None, + _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = None, ): - """Initializes a `ClassicalData` object. - - Args: - measurements: The measurements to seed with, if any. - measured_qubits: The qubits corresponding to the measurements. - """ - # TODO: Uncomment this after log_of_measurement_results is deprecated and removed - # if (measurements is None) != (measured_qubits is None): - # raise ValueError( - # 'measurements and measured_qubits must both either be provided or left default.' - # ) - if measurements is None: - measurements = {} - if measured_qubits is None: - measured_qubits = {} - # TODO: Uncomment this after log_of_measurement_results is deprecated and removed - # if set(measurements.keys()) != set(measured_qubits.keys()): - # raise ValueError('measurements and measured_qubits must contain same keys.') - self._measurements = measurements - self._measured_qubits = measured_qubits + """Initializes a `ClassicalData` object.""" + if _measurements is not None: + if _measurement_types is None: + _measurement_types = { + k: MeasurementType.MEASUREMENT for k, v in _measurements.items() + } + if _channel_measurements is not None: + if _measurement_types is None: + _measurement_types = { + k: MeasurementType.CHANNEL for k, v in _channel_measurements.items() + } + if _measurements is None: + _measurements = {} + if _measured_qubits is None: + _measured_qubits = {} + if _channel_measurements is None: + _channel_measurements = {} + if _measurement_types is None: + _measurement_types = {} + self._measurements = _measurements + self._measured_qubits = _measured_qubits + self._channel_measurements = _channel_measurements + self._measurement_types = _measurement_types def keys(self) -> Tuple['cirq.MeasurementKey', ...]: """Gets the measurement keys in the order they were stored.""" - return tuple(self._measurements.keys()) + return tuple(self._measurement_types.keys()) @property def measurements(self) -> Mapping['cirq.MeasurementKey', Tuple[int, ...]]: """Gets the a mapping from measurement key to measurement.""" return self._measurements + @property + def channel_measurements(self) -> Mapping['cirq.MeasurementKey', int]: + """Gets the a mapping from measurement key to channel measurement.""" + return self._channel_measurements + @property def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', ...]]: """Gets the a mapping from measurement key to the qubits measured.""" return self._measured_qubits + @property + def measurement_types(self) -> Mapping['cirq.MeasurementKey', 'cirq.MeasurementType']: + """Gets the a mapping from measurement key to the qubits measured.""" + return self._measurement_types + def record_measurement( self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] ): @@ -87,26 +178,39 @@ def record_measurement( raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') if key in self._measurements: raise ValueError(f"Measurement already logged to key {key}") + self._measurement_types[key] = MeasurementType.MEASUREMENT self._measurements[key] = tuple(measurement) self._measured_qubits[key] = tuple(qubits) - def record_channel_measurement( - self, key: 'cirq.MeasurementKey', measurement: int, qubits: Sequence['cirq.Qid'] - ): + def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): """Records a channel measurement. Args: key: The measurement key to hold the measurement. measurement: The measurement result. - qubits: The qubits that were measured. Raises: ValueError: If the measurement key was already used. """ - if key in self._measurements: + if key in self._measurement_types: raise ValueError(f"Measurement already logged to key {key}") - self._measurements[key] = (measurement,) - self._measured_qubits[key] = tuple(qubits) + self._measurement_types[key] = MeasurementType.CHANNEL + self._channel_measurements[key] = measurement + + def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + """Gets the digits of the measurement. + + Args: + key: The measurement key. + + Raises: + KeyError: If the key has not been used. + """ + return ( + self._measurements[key] + if self._measurement_types[key] == MeasurementType.MEASUREMENT + else (self._channel_measurements[key],) + ) def get_int(self, key: 'cirq.MeasurementKey') -> int: """Gets the integer corresponding to the measurement. @@ -117,29 +221,54 @@ def get_int(self, key: 'cirq.MeasurementKey') -> int: Raises: KeyError: If the key has not been used. """ - if key not in self._measurements: + if key not in self._measurement_types: raise KeyError(f'The measurement key {key} is not in {self._measurements}') - measurement = self._measurements[key] - if len(measurement) == 1: - # Needed to support keyed channels - return measurement[0] + measurement_type = self._measurement_types[key] + if measurement_type == MeasurementType.CHANNEL: + return self._channel_measurements[key] + if key not in self._measured_qubits: + return digits.big_endian_bits_to_int(self._measurements[key]) return digits.big_endian_digits_to_int( - measurement, base=[q.dimension for q in self._measured_qubits[key]] + self._measurements[key], base=[q.dimension for q in self._measured_qubits[key]] ) def copy(self): """Creates a copy of the object.""" - return ClassicalData(self._measurements.copy(), self._measured_qubits.copy()) + return ClassicalData( + _measurements=self._measurements.copy(), + _measured_qubits=self._measured_qubits.copy(), + _channel_measurements=self._channel_measurements.copy(), + _measurement_types=self._measurement_types.copy(), + ) def _json_dict_(self): - return json_serialization.obj_to_dict_helper(self, ['measurements', 'measured_qubits']) + return json_serialization.obj_to_dict_helper( + self, ['measurements', 'measured_qubits', 'channel_measurements', 'measurement_types'] + ) @classmethod - def _from_json_dict_(cls, measurements, measured_qubits, **kwargs): - return cls(measurements=measurements, measured_qubits=measured_qubits) + def _from_json_dict_( + cls, measurements, measured_qubits, channel_measurements, measurement_types, **kwargs + ): + return cls( + _measurements=measurements, + _measured_qubits=measured_qubits, + _channel_measurements=channel_measurements, + _measurement_types=measurement_types, + ) def __repr__(self): return ( - f'cirq.ClassicalData(measurements={self.measurements!r},' - f' measured_qubits={self.measured_qubits!r})' + f'cirq.ClassicalData(_measurements={self.measurements!r},' + f' _measured_qubits={self.measured_qubits!r},' + f' _channel_measurements={self.channel_measurements!r},' + f' _measurement_types={self.measurement_types!r})' + ) + + def _value_equality_values_(self): + return ( + self._measurements, + self._channel_measurements, + self._measurement_types, + self._measured_qubits, ) diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index 7daeff9e7fa..f58334e81eb 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -17,24 +17,15 @@ import cirq mkey_m = cirq.MeasurementKey('m') +mkey_c = cirq.MeasurementKey('c') two_qubits = tuple(cirq.LineQubit.range(2)) -def test_init_empty(): +def test_init(): cd = cirq.ClassicalData() - assert cd.measurements is not None - assert not cd.measurements - assert cd.keys() is not None - assert not cd.keys() - assert cd.measured_qubits is not None - assert not cd.measured_qubits - - -def test_init_properties(): - cd = cirq.ClassicalData({mkey_m: (0, 1)}, {mkey_m: two_qubits}) - assert cd.measurements == {mkey_m: (0, 1)} - assert cd.keys() == (mkey_m,) - assert cd.measured_qubits == {mkey_m: two_qubits} + assert cd.measurements == {} + assert cd.keys() == () + assert cd.measured_qubits == {} def test_record_measurement(): @@ -56,17 +47,16 @@ def test_record_measurement_errors(): def test_record_channel_measurement(): cd = cirq.ClassicalData() - cd.record_channel_measurement(mkey_m, 1, two_qubits) - assert cd.measurements == {mkey_m: (1,)} + cd.record_channel_measurement(mkey_m, 1) + assert cd.channel_measurements == {mkey_m: 1} assert cd.keys() == (mkey_m,) - assert cd.measured_qubits == {mkey_m: two_qubits} def test_record_channel_measurement_errors(): cd = cirq.ClassicalData() - cd.record_channel_measurement(mkey_m, 1, two_qubits) + cd.record_channel_measurement(mkey_m, 1) with pytest.raises(ValueError, match='Measurement already logged to key m'): - cd.record_channel_measurement(mkey_m, 1, two_qubits) + cd.record_channel_measurement(mkey_m, 1) def test_get_int(): @@ -77,7 +67,7 @@ def test_get_int(): cd.record_measurement(mkey_m, (1, 1), two_qubits) assert cd.get_int(mkey_m) == 3 cd = cirq.ClassicalData() - cd.record_channel_measurement(mkey_m, 1, two_qubits) + cd.record_channel_measurement(mkey_m, 1) assert cd.get_int(mkey_m) == 1 cd = cirq.ClassicalData() cd.record_measurement(mkey_m, (1, 1), (cirq.LineQid.range(2, dimension=3))) @@ -88,7 +78,15 @@ def test_get_int(): def test_copy(): - cd = cirq.ClassicalData({mkey_m: (0, 1)}, {mkey_m: two_qubits}) + cd = cirq.ClassicalData( + _measurements={mkey_m: (0, 1)}, + _measured_qubits={mkey_m: two_qubits}, + _channel_measurements={mkey_c: 3}, + _measurement_types={ + mkey_m: cirq.MeasurementType.MEASUREMENT, + mkey_c: cirq.MeasurementType.CHANNEL, + }, + ) cd1 = cd.copy() assert cd1 is not cd assert cd1 == cd @@ -96,12 +94,21 @@ def test_copy(): assert cd1.measurements == cd.measurements assert cd1.measured_qubits is not cd.measured_qubits assert cd1.measured_qubits == cd.measured_qubits + assert cd1.channel_measurements is not cd.channel_measurements + assert cd1.channel_measurements == cd.channel_measurements + assert cd1.measurement_types is not cd.measurement_types + assert cd1.measurement_types == cd.measurement_types def test_repr(): - cd = cirq.ClassicalData({mkey_m: (0, 1)}, {mkey_m: two_qubits}) - assert repr(cd) == ( - "cirq.ClassicalData(" - "measurements={cirq.MeasurementKey(name='m'): (0, 1)}, " - "measured_qubits={cirq.MeasurementKey(name='m'): (cirq.LineQubit(0), cirq.LineQubit(1))})" + cd = cirq.ClassicalData( + _measurements={mkey_m: (0, 1)}, + _measured_qubits={mkey_m: two_qubits}, + _channel_measurements={mkey_c: 3}, + _measurement_types={ + mkey_m: cirq.MeasurementType.MEASUREMENT, + mkey_c: cirq.MeasurementType.CHANNEL, + }, ) + cirq.testing.assert_equivalent_repr(cd) + diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index dadac63a882..5722ce93358 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -41,7 +41,7 @@ def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.Measure @abc.abstractmethod def resolve( self, - classical_data: 'cirq.ClassicalData', + classical_data: 'cirq.ClassicalDataReader', ) -> bool: """Resolves the condition based on the measurements.""" @@ -103,7 +103,7 @@ def __repr__(self): def resolve( self, - classical_data: 'cirq.ClassicalData', + classical_data: 'cirq.ClassicalDataReader', ) -> bool: if self.key not in classical_data.keys(): raise ValueError(f'Measurement key {self.key} missing when testing classical control') @@ -150,7 +150,7 @@ def __repr__(self): def resolve( self, - classical_data: 'cirq.ClassicalData', + classical_data: 'cirq.ClassicalDataReader', ) -> bool: missing = [str(k) for k in self.keys if k not in classical_data.keys()] if missing: diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 625cd3fd765..65669232bd9 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import dataclasses import re import pytest @@ -43,9 +43,7 @@ def test_key_condition_repr(): def test_key_condition_resolve(): def resolve(measurements): - classical_data = cirq.ClassicalData( - measurements, {k: tuple(cirq.LineQubit(i) for i in v) for k, v in measurements.items()} - ) + classical_data = cirq.ClassicalData(_measurements=measurements) return init_key_condition.resolve(classical_data) assert resolve({'0:a': [1]}) @@ -87,9 +85,7 @@ def test_sympy_condition_repr(): def test_sympy_condition_resolve(): def resolve(measurements): - classical_data = cirq.ClassicalData( - measurements, {k: tuple(cirq.LineQubit(i) for i in v) for k, v in measurements.items()} - ) + classical_data = cirq.ClassicalData(_measurements=measurements) return init_sympy_condition.resolve(classical_data) assert resolve({'0:a': [1]}) From f55965fbcfe1e2ad7e2b9148ae7de484250bc851 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sat, 25 Dec 2021 22:04:36 -0800 Subject: [PATCH 54/89] json --- .../json_test_data/ClassicalData.json | 71 ++++++++++++++----- .../json_test_data/ClassicalData.repr | 2 +- cirq-core/cirq/value/classical_data.py | 17 +++-- cirq-core/cirq/value/classical_data_test.py | 1 - 4 files changed, 64 insertions(+), 27 deletions(-) diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalData.json b/cirq-core/cirq/protocols/json_test_data/ClassicalData.json index 1b33c860407..d90424e4206 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalData.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalData.json @@ -1,25 +1,60 @@ { "cirq_type": "ClassicalData", - "measurements": { - "m": [0, 1] - }, - "measured_qubits": { - "m": [ + "measurements": [ + [ { - "cirq_type": "LineQubit", - "x": 0 + "cirq_type": "MeasurementKey", + "name": "m", + "path": [] }, + [0, 1] + ] + ], + "measured_qubits": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "m", + "path": [] + }, + [ + { + "cirq_type": "LineQubit", + "x": 0 + }, + { + "cirq_type": "LineQubit", + "x": 1 + } + ] + ] + ], + "channel_measurements": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "c", + "path": [] + }, + 3 + ] + ], + "measurement_types": [ + [ { - "cirq_type": "LineQubit", - "x": 1 - } + "cirq_type": "MeasurementKey", + "name": "m", + "path": [] + }, + 1 + ], + [ + { + "cirq_type": "MeasurementKey", + "name": "c", + "path": [] + }, + 2 ] - }, - "channel_measurements": { - "c": 3 - }, - "measurement_types": { - "c": 2, - "m": 1 - } + ] } \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr index 8e0a28007cb..f2ae4642f5a 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr @@ -1 +1 @@ -cirq.ClassicalData(_measurements={'m': [0, 1]}, _measured_qubits={'m': [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={'c': 3}, _measurement_types={'m': cirq.MeasurementType.MEASUREMENT, 'c': cirq.MeasurementType.CHANNEL}) \ No newline at end of file +cirq.ClassicalData(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 162977c2860..95f1d332027 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -242,19 +242,22 @@ def copy(self): ) def _json_dict_(self): - return json_serialization.obj_to_dict_helper( - self, ['measurements', 'measured_qubits', 'channel_measurements', 'measurement_types'] - ) + return { + 'measurements': list(self.measurements.items()), + 'measured_qubits': list(self.measured_qubits.items()), + 'channel_measurements': list(self.channel_measurements.items()), + 'measurement_types': list(self.measurement_types.items()), + } @classmethod def _from_json_dict_( cls, measurements, measured_qubits, channel_measurements, measurement_types, **kwargs ): return cls( - _measurements=measurements, - _measured_qubits=measured_qubits, - _channel_measurements=channel_measurements, - _measurement_types=measurement_types, + _measurements=dict(measurements), + _measured_qubits=dict(measured_qubits), + _channel_measurements=dict(channel_measurements), + _measurement_types=dict(measurement_types), ) def __repr__(self): diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index f58334e81eb..7644bef6858 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -111,4 +111,3 @@ def test_repr(): }, ) cirq.testing.assert_equivalent_repr(cd) - From f84565a4fe9b2de494e316a3da3e4800e322aef5 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sat, 25 Dec 2021 22:18:21 -0800 Subject: [PATCH 55/89] mypy --- cirq-core/cirq/sim/act_on_args_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 8b2d9e08264..27396322c85 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -136,7 +136,7 @@ def _act_on_fallback_( return True def copy(self) -> 'cirq.ActOnArgsContainer[TActOnArgs]': - classical_data = self.classical_data.copy() + classical_data = self._classical_data.copy() copies = {a: a.copy() for a in set(self.args.values())} for copy in copies.values(): copy._classical_data = classical_data From a8848fb0f5c73fe4a45eaab0f0a98d14e962bec1 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sat, 25 Dec 2021 23:11:03 -0800 Subject: [PATCH 56/89] mypy --- cirq-core/cirq/value/classical_data.py | 23 +++++++++++---------- cirq-core/cirq/value/classical_data_test.py | 15 ++++++++++++++ cirq-core/cirq/value/condition_test.py | 2 +- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 95f1d332027..41f1548cdf0 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -15,7 +15,7 @@ import abc import enum from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, TypeVar -from cirq.protocols import json_serialization + from cirq.value import digits, value_equality_attr if TYPE_CHECKING: @@ -113,24 +113,25 @@ def __init__( _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = None, ): """Initializes a `ClassicalData` object.""" + _measurement_types_was_none = _measurement_types is None + if _measurement_types is None: + _measurement_types = {} if _measurements is not None: - if _measurement_types is None: - _measurement_types = { - k: MeasurementType.MEASUREMENT for k, v in _measurements.items() - } + if _measurement_types_was_none: + _measurement_types.update( + {k: MeasurementType.MEASUREMENT for k, v in _measurements.items()} + ) if _channel_measurements is not None: - if _measurement_types is None: - _measurement_types = { - k: MeasurementType.CHANNEL for k, v in _channel_measurements.items() - } + if _measurement_types_was_none: + _measurement_types.update( + {k: MeasurementType.CHANNEL for k, v in _channel_measurements.items()} + ) if _measurements is None: _measurements = {} if _measured_qubits is None: _measured_qubits = {} if _channel_measurements is None: _channel_measurements = {} - if _measurement_types is None: - _measurement_types = {} self._measurements = _measurements self._measured_qubits = _measured_qubits self._channel_measurements = _channel_measurements diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index 7644bef6858..a47fc0c0e56 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -26,6 +26,21 @@ def test_init(): assert cd.measurements == {} assert cd.keys() == () assert cd.measured_qubits == {} + assert cd.channel_measurements == {} + assert cd.measurement_types == {} + cd = cirq.ClassicalData( + _measurements={mkey_m: (0, 1)}, + _measured_qubits={mkey_m: two_qubits}, + _channel_measurements={mkey_c: 3}, + ) + assert cd.measurements == {mkey_m: (0, 1)} + assert cd.keys() == (mkey_m, mkey_c) + assert cd.measured_qubits == {mkey_m: two_qubits} + assert cd.channel_measurements == {mkey_c: 3} + assert cd.measurement_types == { + mkey_m: cirq.MeasurementType.MEASUREMENT, + mkey_c: cirq.MeasurementType.CHANNEL, + } def test_record_measurement(): diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 65669232bd9..e5534734156 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import dataclasses + import re import pytest From 26ab128a5e346a5b47dfe67fafacdf3ec3d96e5d Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sat, 25 Dec 2021 23:43:58 -0800 Subject: [PATCH 57/89] rename --- cirq-core/cirq/__init__.py | 6 ++--- cirq-core/cirq/contrib/quimb/mps_simulator.py | 4 +-- cirq-core/cirq/json_resolver_cache.py | 2 +- cirq-core/cirq/sim/act_on_args.py | 6 ++--- cirq-core/cirq/sim/act_on_args_container.py | 6 ++--- .../cirq/sim/act_on_density_matrix_args.py | 2 +- .../cirq/sim/act_on_state_vector_args.py | 2 +- .../clifford/act_on_clifford_tableau_args.py | 2 +- .../act_on_stabilizer_ch_form_args.py | 2 +- .../cirq/sim/clifford/clifford_simulator.py | 2 +- .../cirq/sim/density_matrix_simulator.py | 2 +- cirq-core/cirq/sim/operation_target.py | 2 +- cirq-core/cirq/sim/simulator_base.py | 4 +-- cirq-core/cirq/sim/simulator_base_test.py | 4 +-- cirq-core/cirq/sim/sparse_simulator.py | 2 +- cirq-core/cirq/value/__init__.py | 6 ++--- cirq-core/cirq/value/classical_data.py | 14 +++++----- cirq-core/cirq/value/classical_data_test.py | 26 +++++++++---------- cirq-core/cirq/value/condition.py | 6 ++--- cirq-core/cirq/value/condition_test.py | 4 +-- 20 files changed, 52 insertions(+), 52 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 4341fa32cc7..431dc29d127 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -483,9 +483,9 @@ canonicalize_half_turns, chosen_angle_to_canonical_half_turns, chosen_angle_to_half_turns, - ClassicalData, - ClassicalDataBase, - ClassicalDataReader, + ClassicalDataDictionaryStore, + ClassicalDataStore, + ClassicalDataStoreReader, Condition, Duration, DURATION_LIKE, diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 74f8c082d33..3d72d1a0709 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -91,7 +91,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union[int, 'MPSState'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalDataBase', + classical_data: 'cirq.ClassicalDataStore', ) -> 'MPSState': """Creates MPSState args for simulating the Circuit. @@ -230,7 +230,7 @@ def __init__( grouping: Optional[Dict['cirq.Qid', int]] = None, initial_state: int = 0, log_of_measurement_results: Dict[str, Any] = None, - classical_data: 'cirq.ClassicalDataBase' = None, + classical_data: 'cirq.ClassicalDataStore' = None, ): """Creates and MPSState diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 71c5ced23b4..afccb6f39f1 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -64,7 +64,7 @@ def _parallel_gate_op(gate, qubits): 'Circuit': cirq.Circuit, 'CircuitOperation': cirq.CircuitOperation, 'ClassicallyControlledOperation': cirq.ClassicallyControlledOperation, - 'ClassicalData': cirq.ClassicalData, + 'ClassicalDataDictionaryStore': cirq.ClassicalDataDictionaryStore, 'CliffordState': cirq.CliffordState, 'CliffordTableau': cirq.CliffordTableau, 'CNotPowGate': cirq.CNotPowGate, diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index b33ea763e14..ffd8ee18f9b 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -47,7 +47,7 @@ def __init__( prng: np.random.RandomState = None, qubits: Sequence['cirq.Qid'] = None, log_of_measurement_results: Dict[str, List[int]] = None, - classical_data: 'cirq.ClassicalDataBase' = None, + classical_data: 'cirq.ClassicalDataStore' = None, ignore_measurement_results: bool = False, ): """Inits ActOnArgs. @@ -74,7 +74,7 @@ def __init__( self._set_qubits(qubits) self.prng = prng # pylint: disable=line-too-long - self._classical_data = classical_data or value.ClassicalData(_measurements=log_of_measurement_results) # type: ignore + self._classical_data = classical_data or value.ClassicalDataDictionaryStore(_measurements=log_of_measurement_results) # type: ignore # pylint: enable=line-too-long self._ignore_measurement_results = ignore_measurement_results @@ -197,7 +197,7 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ functionality, if supported.""" @property - def classical_data(self) -> 'cirq.ClassicalDataReader': + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': return self._classical_data @property diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 27396322c85..f80d96e5a7b 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -51,7 +51,7 @@ def __init__( qubits: Sequence['cirq.Qid'], split_untangled_states: bool, log_of_measurement_results: Dict[str, Any] = None, - classical_data: 'cirq.ClassicalDataBase' = None, + classical_data: 'cirq.ClassicalDataStore' = None, ): """Initializes the class. @@ -71,7 +71,7 @@ def __init__( self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states # pylint: disable=line-too-long - self._classical_data = classical_data or value.ClassicalData(_measurements=log_of_measurement_results) # type: ignore + self._classical_data = classical_data or value.ClassicalDataDictionaryStore(_measurements=log_of_measurement_results) # type: ignore # pylint: enable=line-too-long def create_merged_state(self) -> TActOnArgs: @@ -150,7 +150,7 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: return self._qubits @property - def classical_data(self) -> 'cirq.ClassicalDataReader': + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': return self._classical_data def sample( diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index eb2c517cb32..b6f109eda29 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -41,7 +41,7 @@ def __init__( prng: np.random.RandomState = None, log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, - classical_data: 'cirq.ClassicalDataBase' = None, + classical_data: 'cirq.ClassicalDataStore' = None, ignore_measurement_results: bool = False, ): """Inits ActOnDensityMatrixArgs. diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 849d37d4c98..6bb04492191 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -44,7 +44,7 @@ def __init__( prng: np.random.RandomState = None, log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, - classical_data: 'cirq.ClassicalDataBase' = None, + classical_data: 'cirq.ClassicalDataStore' = None, ): """Inits ActOnStateVectorArgs. diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index cfdee35b2a2..e94ed5bd894 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -42,7 +42,7 @@ def __init__( prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], qubits: Sequence['cirq.Qid'] = None, - classical_data: 'cirq.ClassicalDataBase' = None, + classical_data: 'cirq.ClassicalDataStore' = None, ): """Inits ActOnCliffordTableauArgs. diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 7000a0a2d97..d5d7f7d7bc9 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -41,7 +41,7 @@ def __init__( prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, - classical_data: 'cirq.ClassicalDataBase' = None, + classical_data: 'cirq.ClassicalDataStore' = None, ): """Initializes with the given state and the axes for the operation. Args: diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 917e3b738c1..2351e423b4e 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -68,7 +68,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union[int, 'cirq.ActOnStabilizerCHFormArgs'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalDataBase' = None, + classical_data: 'cirq.ClassicalDataStore' = None, ) -> 'cirq.ActOnStabilizerCHFormArgs': """Creates the ActOnStabilizerChFormArgs for a circuit. diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index 58ca2a6cf39..54b3770d0b4 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -176,7 +176,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalDataBase' = None, + classical_data: 'cirq.ClassicalDataStore' = None, ) -> 'cirq.ActOnDensityMatrixArgs': """Creates the ActOnDensityMatrixArgs for a circuit. diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index 5ae7dad8d74..09fd84b49f4 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -83,7 +83,7 @@ def log_of_measurement_results(self) -> Dict[str, Any]: @property @abc.abstractmethod - def classical_data(self) -> 'cirq.ClassicalDataReader': + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': """The shared classical data container for this simulation..""" @abc.abstractmethod diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 6ceba551e98..5ceaa81d141 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -143,7 +143,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalDataBase', + classical_data: 'cirq.ClassicalDataStore', ) -> TActOnArgs: """Creates an instance of the TActOnArgs class for the simulator. @@ -360,7 +360,7 @@ def _create_act_on_args( if isinstance(initial_state, OperationTarget): return initial_state - classical_data = value.ClassicalData() + classical_data = value.ClassicalDataDictionaryStore() if self._split_untangled_states: args_map: Dict[Optional['cirq.Qid'], TActOnArgs] = {} if isinstance(initial_state, int): diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 6e6df9991f2..e47bf3e13dc 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -111,7 +111,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - classical_data: cirq.ClassicalDataBase, + classical_data: cirq.ClassicalDataStore, ) -> CountingActOnArgs: return CountingActOnArgs(qubits=qubits, state=initial_state, classical_data=classical_data) @@ -141,7 +141,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - classical_data: cirq.ClassicalDataBase, + classical_data: cirq.ClassicalDataStore, ) -> CountingActOnArgs: return SplittableCountingActOnArgs( qubits=qubits, state=initial_state, classical_data=classical_data diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 0e95b3dd247..47233e6446e 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -175,7 +175,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.ActOnStateVectorArgs'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalDataBase', + classical_data: 'cirq.ClassicalDataStore', ): """Creates the ActOnStateVectorArgs for a circuit. diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index 498096e7cde..4fbd4294c6a 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -26,9 +26,9 @@ ) from cirq.value.classical_data import ( - ClassicalData, - ClassicalDataBase, - ClassicalDataReader, + ClassicalDataDictionaryStore, + ClassicalDataStore, + ClassicalDataStoreReader, MeasurementType, ) diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 41f1548cdf0..69e550070c3 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -30,10 +30,10 @@ def __repr__(self): return f'cirq.{str(self)}' -TSelf = TypeVar('TSelf', bound='ClassicalDataReader') +TSelf = TypeVar('TSelf', bound='ClassicalDataStoreReader') -class ClassicalDataReader(abc.ABC): +class ClassicalDataStoreReader(abc.ABC): @abc.abstractmethod def keys(self) -> Tuple['cirq.MeasurementKey', ...]: """Gets the measurement keys in the order they were stored.""" @@ -65,7 +65,7 @@ def copy(self: TSelf) -> TSelf: """Creates a copy of the object.""" -class ClassicalDataBase(ClassicalDataReader, abc.ABC): +class ClassicalDataStore(ClassicalDataStoreReader, abc.ABC): @abc.abstractmethod def record_measurement( self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] @@ -96,7 +96,7 @@ def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: in @value_equality_attr.value_equality(unhashable=True) -class ClassicalData(ClassicalDataBase): +class ClassicalDataDictionaryStore(ClassicalDataStore): """Classical data representing measurements and metadata.""" _measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] @@ -112,7 +112,7 @@ def __init__( _channel_measurements: Dict['cirq.MeasurementKey', int] = None, _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = None, ): - """Initializes a `ClassicalData` object.""" + """Initializes a `ClassicalDataDictionaryStore` object.""" _measurement_types_was_none = _measurement_types is None if _measurement_types is None: _measurement_types = {} @@ -235,7 +235,7 @@ def get_int(self, key: 'cirq.MeasurementKey') -> int: def copy(self): """Creates a copy of the object.""" - return ClassicalData( + return ClassicalDataDictionaryStore( _measurements=self._measurements.copy(), _measured_qubits=self._measured_qubits.copy(), _channel_measurements=self._channel_measurements.copy(), @@ -263,7 +263,7 @@ def _from_json_dict_( def __repr__(self): return ( - f'cirq.ClassicalData(_measurements={self.measurements!r},' + f'cirq.ClassicalDataDictionaryStore(_measurements={self.measurements!r},' f' _measured_qubits={self.measured_qubits!r},' f' _channel_measurements={self.channel_measurements!r},' f' _measurement_types={self.measurement_types!r})' diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index a47fc0c0e56..0fe9fb6f8ca 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -22,13 +22,13 @@ def test_init(): - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() assert cd.measurements == {} assert cd.keys() == () assert cd.measured_qubits == {} assert cd.channel_measurements == {} assert cd.measurement_types == {} - cd = cirq.ClassicalData( + cd = cirq.ClassicalDataDictionaryStore( _measurements={mkey_m: (0, 1)}, _measured_qubits={mkey_m: two_qubits}, _channel_measurements={mkey_c: 3}, @@ -44,7 +44,7 @@ def test_init(): def test_record_measurement(): - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() cd.record_measurement(mkey_m, (0, 1), two_qubits) assert cd.measurements == {mkey_m: (0, 1)} assert cd.keys() == (mkey_m,) @@ -52,7 +52,7 @@ def test_record_measurement(): def test_record_measurement_errors(): - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() with pytest.raises(ValueError, match='3 measurements but 2 qubits'): cd.record_measurement(mkey_m, (0, 1, 2), two_qubits) cd.record_measurement(mkey_m, (0, 1), two_qubits) @@ -61,39 +61,39 @@ def test_record_measurement_errors(): def test_record_channel_measurement(): - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() cd.record_channel_measurement(mkey_m, 1) assert cd.channel_measurements == {mkey_m: 1} assert cd.keys() == (mkey_m,) def test_record_channel_measurement_errors(): - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() cd.record_channel_measurement(mkey_m, 1) with pytest.raises(ValueError, match='Measurement already logged to key m'): cd.record_channel_measurement(mkey_m, 1) def test_get_int(): - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() cd.record_measurement(mkey_m, (0, 1), two_qubits) assert cd.get_int(mkey_m) == 1 - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() cd.record_measurement(mkey_m, (1, 1), two_qubits) assert cd.get_int(mkey_m) == 3 - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() cd.record_channel_measurement(mkey_m, 1) assert cd.get_int(mkey_m) == 1 - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() cd.record_measurement(mkey_m, (1, 1), (cirq.LineQid.range(2, dimension=3))) assert cd.get_int(mkey_m) == 4 - cd = cirq.ClassicalData() + cd = cirq.ClassicalDataDictionaryStore() with pytest.raises(KeyError, match='The measurement key m is not in {}'): cd.get_int(mkey_m) def test_copy(): - cd = cirq.ClassicalData( + cd = cirq.ClassicalDataDictionaryStore( _measurements={mkey_m: (0, 1)}, _measured_qubits={mkey_m: two_qubits}, _channel_measurements={mkey_c: 3}, @@ -116,7 +116,7 @@ def test_copy(): def test_repr(): - cd = cirq.ClassicalData( + cd = cirq.ClassicalDataDictionaryStore( _measurements={mkey_m: (0, 1)}, _measured_qubits={mkey_m: two_qubits}, _channel_measurements={mkey_c: 3}, diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 5722ce93358..7c594eb2d95 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -41,7 +41,7 @@ def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.Measure @abc.abstractmethod def resolve( self, - classical_data: 'cirq.ClassicalDataReader', + classical_data: 'cirq.ClassicalDataStoreReader', ) -> bool: """Resolves the condition based on the measurements.""" @@ -103,7 +103,7 @@ def __repr__(self): def resolve( self, - classical_data: 'cirq.ClassicalDataReader', + classical_data: 'cirq.ClassicalDataStoreReader', ) -> bool: if self.key not in classical_data.keys(): raise ValueError(f'Measurement key {self.key} missing when testing classical control') @@ -150,7 +150,7 @@ def __repr__(self): def resolve( self, - classical_data: 'cirq.ClassicalDataReader', + classical_data: 'cirq.ClassicalDataStoreReader', ) -> bool: missing = [str(k) for k in self.keys if k not in classical_data.keys()] if missing: diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index e5534734156..e92029b1bfb 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -43,7 +43,7 @@ def test_key_condition_repr(): def test_key_condition_resolve(): def resolve(measurements): - classical_data = cirq.ClassicalData(_measurements=measurements) + classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) return init_key_condition.resolve(classical_data) assert resolve({'0:a': [1]}) @@ -85,7 +85,7 @@ def test_sympy_condition_repr(): def test_sympy_condition_resolve(): def resolve(measurements): - classical_data = cirq.ClassicalData(_measurements=measurements) + classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) return init_sympy_condition.resolve(classical_data) assert resolve({'0:a': [1]}) From 540489281e81951a3a235a948deb4491d853839c Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 26 Dec 2021 00:07:27 -0800 Subject: [PATCH 58/89] rename --- cirq-core/cirq/protocols/json_test_data/ClassicalData.repr | 1 - .../{ClassicalData.json => ClassicalDataDictionaryStore.json} | 2 +- .../protocols/json_test_data/ClassicalDataDictionaryStore.repr | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) delete mode 100644 cirq-core/cirq/protocols/json_test_data/ClassicalData.repr rename cirq-core/cirq/protocols/json_test_data/{ClassicalData.json => ClassicalDataDictionaryStore.json} (94%) create mode 100644 cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr deleted file mode 100644 index f2ae4642f5a..00000000000 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalData.repr +++ /dev/null @@ -1 +0,0 @@ -cirq.ClassicalData(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalData.json b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json similarity index 94% rename from cirq-core/cirq/protocols/json_test_data/ClassicalData.json rename to cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json index d90424e4206..d5c51d5839c 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalData.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json @@ -1,5 +1,5 @@ { - "cirq_type": "ClassicalData", + "cirq_type": "ClassicalDataDictionaryStore", "measurements": [ [ { diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr new file mode 100644 index 00000000000..c19b8190bfb --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr @@ -0,0 +1 @@ +cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file From 70b4b5e63eb14dfd4f18a0ba0299ec022d28b32a Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 26 Dec 2021 09:49:44 -0800 Subject: [PATCH 59/89] lint --- cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py | 2 +- cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index e94ed5bd894..dffc0d0d6e3 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -14,7 +14,7 @@ """A protocol for implementing high performance clifford tableau evolutions for Clifford Simulator.""" -from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, TYPE_CHECKING, Union import numpy as np diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index d5d7f7d7bc9..2906b2a4baf 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, TYPE_CHECKING, Union import numpy as np From b014135d5ff78c7883c54b8eca8a26bc44a82874 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 29 Dec 2021 00:15:36 -0800 Subject: [PATCH 60/89] docstrings, simplify some logic --- cirq-core/cirq/value/classical_data.py | 83 ++++++++------------------ 1 file changed, 25 insertions(+), 58 deletions(-) diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 69e550070c3..51e89c81fed 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -42,6 +42,12 @@ def keys(self) -> Tuple['cirq.MeasurementKey', ...]: def get_int(self, key: 'cirq.MeasurementKey') -> int: """Gets the integer corresponding to the measurement. + The integer is determined by summing the qubit-dimensional basis value + of each measured qubit. For example, if the measurement of qubits + [q0, q1] produces [0, 1], then the corresponding integer is 2, the big- + endian equivalent. If they are qutrits and the measurement is [1, 2], + then the integer is 2 * 3 + 1 = 7. + Args: key: The measurement key. @@ -51,7 +57,10 @@ def get_int(self, key: 'cirq.MeasurementKey') -> int: @abc.abstractmethod def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: - """Gets the digits of the measurement. + """Gets the values of the qubits that were measured into this key. + + For example, if the measurement of qubits [q0, q1] produces [0, 1], + this function will return (0, 1). Args: key: The measurement key. @@ -99,11 +108,6 @@ def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: in class ClassicalDataDictionaryStore(ClassicalDataStore): """Classical data representing measurements and metadata.""" - _measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] - _measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] - _channel_measurements: Dict['cirq.MeasurementKey', int] - _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] - def __init__( self, *, @@ -113,16 +117,13 @@ def __init__( _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = None, ): """Initializes a `ClassicalDataDictionaryStore` object.""" - _measurement_types_was_none = _measurement_types is None - if _measurement_types is None: + if not _measurement_types: _measurement_types = {} - if _measurements is not None: - if _measurement_types_was_none: + if _measurements: _measurement_types.update( {k: MeasurementType.MEASUREMENT for k, v in _measurements.items()} ) - if _channel_measurements is not None: - if _measurement_types_was_none: + if _channel_measurements: _measurement_types.update( {k: MeasurementType.CHANNEL for k, v in _channel_measurements.items()} ) @@ -132,14 +133,14 @@ def __init__( _measured_qubits = {} if _channel_measurements is None: _channel_measurements = {} - self._measurements = _measurements - self._measured_qubits = _measured_qubits - self._channel_measurements = _channel_measurements - self._measurement_types = _measurement_types - - def keys(self) -> Tuple['cirq.MeasurementKey', ...]: - """Gets the measurement keys in the order they were stored.""" - return tuple(self._measurement_types.keys()) + self._measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = _measurements + self._measured_qubits: Dict[ + 'cirq.MeasurementKey', Tuple['cirq.Qid', ...] + ] = _measured_qubits + self._channel_measurements: Dict['cirq.MeasurementKey', int] = _channel_measurements + self._measurement_types: Dict[ + 'cirq.MeasurementKey', 'cirq.MeasurementType' + ] = _measurement_types @property def measurements(self) -> Mapping['cirq.MeasurementKey', Tuple[int, ...]]: @@ -158,23 +159,15 @@ def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', .. @property def measurement_types(self) -> Mapping['cirq.MeasurementKey', 'cirq.MeasurementType']: - """Gets the a mapping from measurement key to the qubits measured.""" + """Gets the a mapping from measurement key to the measurement type.""" return self._measurement_types + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + return tuple(self._measurement_types.keys()) + def record_measurement( self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] ): - """Records a measurement. - - Args: - key: The measurement key to hold the measurement. - measurement: The measurement result. - qubits: The qubits that were measured. - - Raises: - ValueError: If the measurement shape does not match the qubits - measured or if the measurement key was already used. - """ if len(measurement) != len(qubits): raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') if key in self._measurements: @@ -184,29 +177,12 @@ def record_measurement( self._measured_qubits[key] = tuple(qubits) def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): - """Records a channel measurement. - - Args: - key: The measurement key to hold the measurement. - measurement: The measurement result. - - Raises: - ValueError: If the measurement key was already used. - """ if key in self._measurement_types: raise ValueError(f"Measurement already logged to key {key}") self._measurement_types[key] = MeasurementType.CHANNEL self._channel_measurements[key] = measurement def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: - """Gets the digits of the measurement. - - Args: - key: The measurement key. - - Raises: - KeyError: If the key has not been used. - """ return ( self._measurements[key] if self._measurement_types[key] == MeasurementType.MEASUREMENT @@ -214,14 +190,6 @@ def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: ) def get_int(self, key: 'cirq.MeasurementKey') -> int: - """Gets the integer corresponding to the measurement. - - Args: - key: The measurement key. - - Raises: - KeyError: If the key has not been used. - """ if key not in self._measurement_types: raise KeyError(f'The measurement key {key} is not in {self._measurements}') measurement_type = self._measurement_types[key] @@ -234,7 +202,6 @@ def get_int(self, key: 'cirq.MeasurementKey') -> int: ) def copy(self): - """Creates a copy of the object.""" return ClassicalDataDictionaryStore( _measurements=self._measurements.copy(), _measured_qubits=self._measured_qubits.copy(), From 15fdf993018bec15380dece386da6ed6602f4b4e Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 30 Dec 2021 11:31:05 -0800 Subject: [PATCH 61/89] Deprecate _create_partial_act_on_args --- cirq-core/cirq/sim/simulator_base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 5ceaa81d141..73a6cb5a818 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -34,6 +34,7 @@ import numpy as np from cirq import ops, protocols, study, value, devices +from cirq._compat import deprecated from cirq.sim import ActOnArgsContainer from cirq.sim.operation_target import OperationTarget from cirq.sim.simulator import ( @@ -119,6 +120,10 @@ def __init__( self._ignore_measurement_results = ignore_measurement_results self._split_untangled_states = split_untangled_states + @deprecated( + deadline="v0.15", + fix="Override _create_partial_act_on_args_ex instead", + ) def _create_partial_act_on_args( self, initial_state: Any, @@ -158,7 +163,8 @@ def _create_partial_act_on_args_ex( simulation. """ # Child classes should override this behavior. We call the old one here by default for - # backwards compatibility, until deprecation cycle is complete. + # backwards compatibility, until deprecation cycle is complete. This method should be + # marked abstract once the deprecation is finished. # coverage: ignore return self._create_partial_act_on_args( initial_state, qubits, classical_data.measurements # type: ignore From 07cb0d5583d799f85915593cc892de870fe2e0f3 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 21 Jan 2022 09:57:04 -0800 Subject: [PATCH 62/89] lint --- cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py | 2 +- cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 5f576342c3b..ea00294dbd7 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -14,7 +14,7 @@ """A protocol for implementing high performance clifford tableau evolutions for Clifford Simulator.""" -from typing import Any, Dict, List, Sequence, TYPE_CHECKING, Union +from typing import Dict, List, Optional, Sequence, TYPE_CHECKING import numpy as np diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 119d9462851..70f7e576801 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Sequence, TYPE_CHECKING, Union +from typing import Dict, List, Optional, Sequence, TYPE_CHECKING import numpy as np From 4dc6c26b20a796a5ab7785dcadd492291cfb58a6 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 21 Jan 2022 10:00:03 -0800 Subject: [PATCH 63/89] test --- cirq-core/cirq/sim/simulator_base_test.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 3a20148ad2e..11b97d02225 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -264,13 +264,15 @@ class MockCountingSimulator( MockCountingActOnArgs, ] ): - def _create_partial_act_on_args( + def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: cirq.ClassicalDataStore, ) -> MockCountingActOnArgs: - return MockCountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) + return MockCountingActOnArgs( + qubits=qubits, state=initial_state, classical_data=classical_data + ) def _create_simulator_trial_result( self, From 437a1532b4153d7ccfedfa5e2c1f351db7eb3146 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 21 Jan 2022 10:01:54 -0800 Subject: [PATCH 64/89] lint --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 2 +- cirq-core/cirq/sim/act_on_args_container.py | 5 ++--- cirq-core/cirq/sim/clifford/clifford_simulator.py | 2 +- cirq-core/cirq/sim/sparse_simulator.py | 2 -- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 7b0d53b01dd..4b0e3e6eb32 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -19,7 +19,7 @@ import dataclasses import math -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union import numpy as np import quimb.tensor as qtn diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index e212cf469ec..481e2a19908 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import abc import inspect +import warnings +from collections import abc from typing import ( - Any, Dict, Generic, Iterator, @@ -26,7 +26,6 @@ TYPE_CHECKING, Union, ) -import warnings import numpy as np diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 2351e423b4e..a3acb447329 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -29,7 +29,7 @@ to state vector amplitudes. """ -from typing import Any, Dict, List, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Union import numpy as np diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 47233e6446e..9a67c5aaa2c 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -16,14 +16,12 @@ from typing import ( Any, - Dict, Iterator, List, Type, TYPE_CHECKING, Union, Sequence, - Tuple, Optional, ) From f190ffadbe145775202336465a313f6192bafb52 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 21 Jan 2022 10:38:00 -0800 Subject: [PATCH 65/89] nits --- cirq-core/cirq/sim/clifford/clifford_simulator.py | 2 +- cirq-core/cirq/sim/density_matrix_simulator.py | 2 +- cirq-core/cirq/value/classical_data.py | 2 +- cirq-core/cirq/value/classical_data_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index a3acb447329..00d1ee8f109 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -68,7 +68,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union[int, 'cirq.ActOnStabilizerCHFormArgs'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalDataStore' = None, + classical_data: 'cirq.ClassicalDataStore', ) -> 'cirq.ActOnStabilizerCHFormArgs': """Creates the ActOnStabilizerChFormArgs for a circuit. diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index 54b3770d0b4..d4fbdf24ef4 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -176,7 +176,7 @@ def _create_partial_act_on_args_ex( self, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], qubits: Sequence['cirq.Qid'], - classical_data: 'cirq.ClassicalDataStore' = None, + classical_data: 'cirq.ClassicalDataStore', ) -> 'cirq.ActOnDensityMatrixArgs': """Creates the ActOnDensityMatrixArgs for a circuit. diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 51e89c81fed..1b222f1467d 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -1,4 +1,4 @@ -# Copyright 2021 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index 0fe9fb6f8ca..ab607ed7a61 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 30f7e460142644a736ef70cd308dee64b0cc9e17 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 25 Jan 2022 18:35:17 -0800 Subject: [PATCH 66/89] Code review cleanup --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 2 +- cirq-core/cirq/sim/act_on_args.py | 9 +++-- cirq-core/cirq/sim/act_on_args_container.py | 9 +++-- .../act_on_stabilizer_ch_form_args.py | 8 ++-- .../cirq/sim/clifford/clifford_simulator.py | 6 ++- .../sim/clifford/clifford_simulator_test.py | 2 +- .../clifford/stabilizer_state_ch_form_test.py | 7 ++-- .../cirq/sim/density_matrix_simulator.py | 2 +- cirq-core/cirq/sim/simulator_base.py | 40 +++---------------- cirq-core/cirq/sim/simulator_base_test.py | 6 +-- cirq-core/cirq/sim/sparse_simulator.py | 2 +- cirq-core/cirq/value/classical_data.py | 4 +- .../calibration/engine_simulator.py | 9 +++++ 13 files changed, 48 insertions(+), 58 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 4b0e3e6eb32..a18e0983fa1 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -87,7 +87,7 @@ def __init__( seed=seed, ) - def _create_partial_act_on_args_ex( + def _create_partial_act_on_args( self, initial_state: Union[int, 'MPSState'], qubits: Sequence['cirq.Qid'], diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index a937a11daba..bf8db76e965 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -75,9 +75,12 @@ def __init__( qubits = () self._set_qubits(qubits) self.prng = prng - # pylint: disable=line-too-long - self._classical_data = classical_data or value.ClassicalDataDictionaryStore(_measurements=log_of_measurement_results) # type: ignore - # pylint: enable=line-too-long + self._classical_data = classical_data or value.ClassicalDataDictionaryStore( + _measurements={ + value.MeasurementKey.parse_serialized(k): tuple(v) + for k, v in (log_of_measurement_results or {}).items() + } + ) self._ignore_measurement_results = ignore_measurement_results def _set_qubits(self, qubits: Sequence['cirq.Qid']): diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 481e2a19908..aed03ee6fc2 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -71,9 +71,12 @@ def __init__( self.args = args self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states - # pylint: disable=line-too-long - self._classical_data = classical_data or value.ClassicalDataDictionaryStore(_measurements=log_of_measurement_results) # type: ignore - # pylint: enable=line-too-long + self._classical_data = classical_data or value.ClassicalDataDictionaryStore( + _measurements={ + value.MeasurementKey.parse_serialized(k): tuple(v) + for k, v in (log_of_measurement_results or {}).items() + } + ) def create_merged_state(self) -> TActOnArgs: if not self.split_untangled_states: diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 70f7e576801..65583f78b52 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -67,14 +67,16 @@ def sample( repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: - measurements: Dict[str, List[np.ndarray]] = {} + measurements = value.ClassicalDataDictionaryStore() prng = value.parse_random_state(seed) for i in range(repetitions): op = ops.measure(*qubits, key=str(i)) state = self.state.copy() - ch_form_args = ActOnStabilizerCHFormArgs(state, prng, measurements, self.qubits) + ch_form_args = ActOnStabilizerCHFormArgs( + state, prng, qubits=self.qubits, classical_data=measurements + ) protocols.act_on(op, ch_form_args) - return np.array(list(measurements.values()), dtype=bool) + return np.array(list(measurements.measurements.values()), dtype=bool) def _x(self, g: common_gates.XPowGate, axis: int): exponent = g.exponent diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 00d1ee8f109..d8e67ba0d48 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -64,7 +64,7 @@ def is_supported_operation(op: 'cirq.Operation') -> bool: # TODO: support more general Pauli measurements return protocols.has_stabilizer_effect(op) - def _create_partial_act_on_args_ex( + def _create_partial_act_on_args( self, initial_state: Union[int, 'cirq.ActOnStabilizerCHFormArgs'], qubits: Sequence['cirq.Qid'], @@ -277,7 +277,9 @@ def apply_measurement( else: state = self.copy() + classical_data = value.ClassicalDataDictionaryStore() ch_form_args = clifford.ActOnStabilizerCHFormArgs( - state.ch_form, prng, measurements, self.qubit_map.keys() + state.ch_form, prng, qubits=self.qubit_map.keys(), classical_data=classical_data ) act_on(op, ch_form_args) + measurements.update({str(k): list(v) for k, v in classical_data.measurements.items()}) diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py index cdd8c5fcdff..59156a3b88f 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py @@ -546,7 +546,7 @@ def test_valid_apply_measurement(): state = cirq.CliffordState(qubit_map={q0: 0}, initial_state=1) measurements = {} _ = state.apply_measurement(cirq.measure(q0), measurements, np.random.RandomState()) - assert measurements == {'0': (1,)} + assert measurements == {'0': [1]} def test_reset(): diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py index 2b3acb09723..8607b6491ff 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py @@ -64,14 +64,15 @@ def test_run(): ) for _ in range(10): state = cirq.StabilizerStateChForm(num_qubits=3) - measurements = {} + classical_data = cirq.ClassicalDataDictionaryStore() for op in circuit.all_operations(): args = cirq.ActOnStabilizerCHFormArgs( state, qubits=list(circuit.all_qubits()), prng=np.random.RandomState(), - log_of_measurement_results=measurements, + classical_data=classical_data, ) cirq.act_on(op, args) - assert measurements['1'] == (1,) + measurements = {str(k): list(v) for k, v in classical_data.measurements.items()} + assert measurements['1'] == [1] assert measurements['0'] != measurements['2'] diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index d4fbdf24ef4..6f5866f5b07 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -172,7 +172,7 @@ def __init__( if dtype not in {np.complex64, np.complex128}: raise ValueError(f'dtype must be complex64 or complex128, was {dtype}') - def _create_partial_act_on_args_ex( + def _create_partial_act_on_args( self, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], qubits: Sequence['cirq.Qid'], diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index a37e864c491..1a32f468894 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -122,31 +122,8 @@ def __init__( self._ignore_measurement_results = ignore_measurement_results self._split_untangled_states = split_untangled_states - @deprecated( - deadline="v0.15", - fix="Override _create_partial_act_on_args_ex instead", - ) + @abc.abstractmethod def _create_partial_act_on_args( - self, - initial_state: Any, - qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], - ) -> TActOnArgs: - """Creates an instance of the TActOnArgs class for the simulator. - - It represents the supplied qubits initialized to the provided state. - - Args: - initial_state: The initial state to represent. An integer state is - understood to be a pure state. Other state representations are - simulator-dependent. - qubits: The sequence of qubits to represent. - logs: The structure to hold measurement logs. A single instance - should be shared among all ActOnArgs within the simulation. - """ - raise NotImplementedError() - - def _create_partial_act_on_args_ex( self, initial_state: Any, qubits: Sequence['cirq.Qid'], @@ -164,13 +141,6 @@ def _create_partial_act_on_args_ex( classical_data: The shared classical data container for this simulation. """ - # Child classes should override this behavior. We call the old one here by default for - # backwards compatibility, until deprecation cycle is complete. This method should be - # marked abstract once the deprecation is finished. - # coverage: ignore - return self._create_partial_act_on_args( - initial_state, qubits, classical_data.measurements # type: ignore - ) @abc.abstractmethod def _create_step_result( @@ -388,26 +358,26 @@ def _create_act_on_args( args_map: Dict[Optional['cirq.Qid'], TActOnArgs] = {} if isinstance(initial_state, int): for q in reversed(qubits): - args_map[q] = self._create_partial_act_on_args_ex( + args_map[q] = self._create_partial_act_on_args( initial_state=initial_state % q.dimension, qubits=[q], classical_data=classical_data, ) initial_state = int(initial_state / q.dimension) else: - args = self._create_partial_act_on_args_ex( + args = self._create_partial_act_on_args( initial_state=initial_state, qubits=qubits, classical_data=classical_data, ) for q in qubits: args_map[q] = args - args_map[None] = self._create_partial_act_on_args_ex(0, (), classical_data) + args_map[None] = self._create_partial_act_on_args(0, (), classical_data) return ActOnArgsContainer( args_map, qubits, self._split_untangled_states, classical_data=classical_data ) else: - return self._create_partial_act_on_args_ex( + return self._create_partial_act_on_args( initial_state=initial_state, qubits=qubits, classical_data=classical_data, diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 11b97d02225..93260f74320 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -107,7 +107,7 @@ def __init__(self, noise=None, split_untangled_states=False): split_untangled_states=split_untangled_states, ) - def _create_partial_act_on_args_ex( + def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], @@ -137,7 +137,7 @@ def __init__(self, noise=None, split_untangled_states=True): split_untangled_states=split_untangled_states, ) - def _create_partial_act_on_args_ex( + def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], @@ -264,7 +264,7 @@ class MockCountingSimulator( MockCountingActOnArgs, ] ): - def _create_partial_act_on_args_ex( + def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 9a67c5aaa2c..54f518b4b9e 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -169,7 +169,7 @@ def __init__( split_untangled_states=split_untangled_states, ) - def _create_partial_act_on_args_ex( + def _create_partial_act_on_args( self, initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.ActOnStateVectorArgs'], qubits: Sequence['cirq.Qid'], diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 1b222f1467d..696f252f09f 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -44,8 +44,8 @@ def get_int(self, key: 'cirq.MeasurementKey') -> int: The integer is determined by summing the qubit-dimensional basis value of each measured qubit. For example, if the measurement of qubits - [q0, q1] produces [0, 1], then the corresponding integer is 2, the big- - endian equivalent. If they are qutrits and the measurement is [1, 2], + [q1, q0] produces [1, 0], then the corresponding integer is 2, the big- + endian equivalent. If they are qutrits and the measurement is [2, 1], then the integer is 2 * 3 + 1 = 7. Args: diff --git a/cirq-google/cirq_google/calibration/engine_simulator.py b/cirq-google/cirq_google/calibration/engine_simulator.py index db4eeb5c619..29bd90f0913 100644 --- a/cirq-google/cirq_google/calibration/engine_simulator.py +++ b/cirq-google/cirq_google/calibration/engine_simulator.py @@ -470,6 +470,15 @@ def simulate( converted = _convert_to_circuit_with_drift(self, program) return self._simulator.simulate(converted, param_resolver, qubit_order, initial_state) + def _create_partial_act_on_args( + self, + initial_state: Union[int, cirq.ActOnStateVectorArgs], + qubits: Sequence[cirq.Qid], + classical_data: cirq.ClassicalDataStore, + ) -> cirq.ActOnStateVectorArgs: + # Needs an implementation since it's abstract but will never actually be called. + raise NotImplementedError() + def _create_step_result( self, sim_state: cirq.OperationTarget, From 102303abecf846a7d500dad3f96c8fb35e7dd53f Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 25 Jan 2022 18:59:00 -0800 Subject: [PATCH 67/89] More comparisons in measurement_key --- cirq-core/cirq/value/measurement_key.py | 3 +++ cirq-core/cirq/value/measurement_key_test.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/cirq-core/cirq/value/measurement_key.py b/cirq-core/cirq/value/measurement_key.py index 9e98cf68c2d..e53eac47fde 100644 --- a/cirq-core/cirq/value/measurement_key.py +++ b/cirq-core/cirq/value/measurement_key.py @@ -84,6 +84,9 @@ def __lt__(self, other): return self.name < other.name return NotImplemented + def __le__(self, other): + return self == other or self < other + def _json_dict_(self): return { 'name': self.name, diff --git a/cirq-core/cirq/value/measurement_key_test.py b/cirq-core/cirq/value/measurement_key_test.py index 58454b2b74c..e04a8be9c62 100644 --- a/cirq-core/cirq/value/measurement_key_test.py +++ b/cirq-core/cirq/value/measurement_key_test.py @@ -102,7 +102,18 @@ def test_with_measurement_key_mapping(): def test_compare(): assert cirq.MeasurementKey('a') < cirq.MeasurementKey('b') + assert cirq.MeasurementKey('a') <= cirq.MeasurementKey('b') + assert cirq.MeasurementKey('a') <= cirq.MeasurementKey('a') + assert cirq.MeasurementKey('b') > cirq.MeasurementKey('a') + assert cirq.MeasurementKey('b') >= cirq.MeasurementKey('a') + assert cirq.MeasurementKey('a') >= cirq.MeasurementKey('a') + assert not cirq.MeasurementKey('a') > cirq.MeasurementKey('b') + assert not cirq.MeasurementKey('a') >= cirq.MeasurementKey('b') + assert not cirq.MeasurementKey('b') < cirq.MeasurementKey('a') + assert not cirq.MeasurementKey('b') <= cirq.MeasurementKey('a') assert cirq.MeasurementKey(path=(), name='b') < cirq.MeasurementKey(path=('0',), name='a') assert cirq.MeasurementKey(path=('0',), name='n') < cirq.MeasurementKey(path=('1',), name='a') with pytest.raises(TypeError): _ = cirq.MeasurementKey('a') < 'b' + with pytest.raises(TypeError): + _ = cirq.MeasurementKey('a') <= 'b' From be0e5ff4aed23e97e0c8bb09875ce788dab20ce2 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 25 Jan 2022 19:00:03 -0800 Subject: [PATCH 68/89] lint --- cirq-core/cirq/sim/simulator_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 1a32f468894..610e369611c 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -17,6 +17,7 @@ import abc import collections import inspect +import warnings from typing import ( Any, Dict, @@ -31,12 +32,10 @@ Optional, TypeVar, ) -import warnings import numpy as np from cirq import ops, protocols, study, value, devices -from cirq._compat import deprecated from cirq.sim import ActOnArgsContainer from cirq.sim.operation_target import OperationTarget from cirq.sim.simulator import ( From ba7a1f5ca6bd57ace1afbbbd6ba51d95d103dcbe Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 25 Jan 2022 19:03:38 -0800 Subject: [PATCH 69/89] Additional tests --- cirq-core/cirq/value/classical_data.py | 2 +- cirq-core/cirq/value/classical_data_test.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 696f252f09f..5596be02efd 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -170,7 +170,7 @@ def record_measurement( ): if len(measurement) != len(qubits): raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') - if key in self._measurements: + if key in self._measurement_types: raise ValueError(f"Measurement already logged to key {key}") self._measurement_types[key] = MeasurementType.MEASUREMENT self._measurements[key] = tuple(measurement) diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index ab607ed7a61..00cfe475d0e 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -72,6 +72,14 @@ def test_record_channel_measurement_errors(): cd.record_channel_measurement(mkey_m, 1) with pytest.raises(ValueError, match='Measurement already logged to key m'): cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_measurement(mkey_m, (0, 1), two_qubits) + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (0, 1), two_qubits) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_measurement(mkey_m, (0, 1), two_qubits) def test_get_int(): From 24abf5216394dd204ad2a6720b2584025d1b3935 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 25 Jan 2022 23:14:57 -0800 Subject: [PATCH 70/89] Add extra dimension to classical_data.py --- cirq-core/cirq/ops/measurement_gate_test.py | 9 -- .../ClassicalDataDictionaryStore.json | 22 +++-- .../ClassicalDataDictionaryStore.repr | 2 +- cirq-core/cirq/sim/act_on_args.py | 2 +- cirq-core/cirq/sim/act_on_args_container.py | 2 +- cirq-core/cirq/value/classical_data.py | 96 +++++++++++-------- cirq-core/cirq/value/classical_data_test.py | 33 +++---- cirq-core/cirq/value/condition_test.py | 36 +++---- 8 files changed, 108 insertions(+), 94 deletions(-) diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index eaa071be9c9..7b1522e31ca 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -322,9 +322,6 @@ def test_act_on_state_vector(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 1]} - with pytest.raises(ValueError, match="already logged to key"): - cirq.act_on(m, args) - def test_act_on_clifford_tableau(): a, b = [cirq.LineQubit(3), cirq.LineQubit(1)] @@ -360,9 +357,6 @@ def test_act_on_clifford_tableau(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 1]} - with pytest.raises(ValueError, match="already logged to key"): - cirq.act_on(m, args) - def test_act_on_stabilizer_ch_form(): a, b = [cirq.LineQubit(3), cirq.LineQubit(1)] @@ -398,9 +392,6 @@ def test_act_on_stabilizer_ch_form(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 1]} - with pytest.raises(ValueError, match="already logged to key"): - cirq.act_on(m, args) - def test_act_on_qutrit(): a, b = [cirq.LineQid(3, dimension=3), cirq.LineQid(1, dimension=3)] diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json index d5c51d5839c..ae5bdcb687e 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json @@ -7,7 +7,7 @@ "name": "m", "path": [] }, - [0, 1] + [[0, 1]] ] ], "measured_qubits": [ @@ -18,14 +18,16 @@ "path": [] }, [ - { - "cirq_type": "LineQubit", - "x": 0 - }, - { - "cirq_type": "LineQubit", - "x": 1 - } + [ + { + "cirq_type": "LineQubit", + "x": 0 + }, + { + "cirq_type": "LineQubit", + "x": 1 + } + ] ] ] ], @@ -36,7 +38,7 @@ "name": "c", "path": [] }, - 3 + [3] ] ], "measurement_types": [ diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr index c19b8190bfb..18351d55637 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr @@ -1 +1 @@ -cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file +cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [[0, 1]]}, _measured_qubits={cirq.MeasurementKey('m'): [[cirq.LineQubit(0), cirq.LineQubit(1)]]}, _channel_measurements={cirq.MeasurementKey('c'): [3]}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index bf8db76e965..c30d46157d2 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -77,7 +77,7 @@ def __init__( self.prng = prng self._classical_data = classical_data or value.ClassicalDataDictionaryStore( _measurements={ - value.MeasurementKey.parse_serialized(k): tuple(v) + value.MeasurementKey.parse_serialized(k): [tuple(v)] for k, v in (log_of_measurement_results or {}).items() } ) diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index aed03ee6fc2..e1b9ce3b038 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -73,7 +73,7 @@ def __init__( self.split_untangled_states = split_untangled_states self._classical_data = classical_data or value.ClassicalDataDictionaryStore( _measurements={ - value.MeasurementKey.parse_serialized(k): tuple(v) + value.MeasurementKey.parse_serialized(k): [tuple(v)] for k, v in (log_of_measurement_results or {}).items() } ) diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 5596be02efd..11f28fc96bb 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -14,7 +14,7 @@ import abc import enum -from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, TypeVar +from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, TypeVar, List from cirq.value import digits, value_equality_attr @@ -39,7 +39,7 @@ def keys(self) -> Tuple['cirq.MeasurementKey', ...]: """Gets the measurement keys in the order they were stored.""" @abc.abstractmethod - def get_int(self, key: 'cirq.MeasurementKey') -> int: + def get_int(self, key: 'cirq.MeasurementKey', index=-1) -> int: """Gets the integer corresponding to the measurement. The integer is determined by summing the qubit-dimensional basis value @@ -50,13 +50,18 @@ def get_int(self, key: 'cirq.MeasurementKey') -> int: Args: key: The measurement key. + index: If multiple measurements have the same key, the index + argument can be used to specify which measurement to retrieve. + Here `0` refers to the first measurement, and `-1` refers to + the most recent. Raises: - KeyError: If the key has not been used. + KeyError: If the key has not been used or if the index is out of + bounds. """ @abc.abstractmethod - def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + def get_digits(self, key: 'cirq.MeasurementKey', index=-1) -> Tuple[int, ...]: """Gets the values of the qubits that were measured into this key. For example, if the measurement of qubits [q0, q1] produces [0, 1], @@ -64,9 +69,14 @@ def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: Args: key: The measurement key. + index: If multiple measurements have the same key, the index + argument can be used to specify which measurement to retrieve. + Here `0` refers to the first measurement, and `-1` refers to + the most recent. Raises: - KeyError: If the key has not been used. + KeyError: If the key has not been used or if the index is out of + bounds. """ @abc.abstractmethod @@ -111,9 +121,9 @@ class ClassicalDataDictionaryStore(ClassicalDataStore): def __init__( self, *, - _measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, - _measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = None, - _channel_measurements: Dict['cirq.MeasurementKey', int] = None, + _measurements: Dict['cirq.MeasurementKey', List[Tuple[int, ...]]] = None, + _measured_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]] = None, + _channel_measurements: Dict['cirq.MeasurementKey', List[int]] = None, _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = None, ): """Initializes a `ClassicalDataDictionaryStore` object.""" @@ -133,11 +143,11 @@ def __init__( _measured_qubits = {} if _channel_measurements is None: _channel_measurements = {} - self._measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = _measurements + self._measurements: Dict['cirq.MeasurementKey', List[Tuple[int, ...]]] = _measurements self._measured_qubits: Dict[ - 'cirq.MeasurementKey', Tuple['cirq.Qid', ...] + 'cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]] ] = _measured_qubits - self._channel_measurements: Dict['cirq.MeasurementKey', int] = _channel_measurements + self._channel_measurements: Dict['cirq.MeasurementKey', List[int]] = _channel_measurements self._measurement_types: Dict[ 'cirq.MeasurementKey', 'cirq.MeasurementType' ] = _measurement_types @@ -145,17 +155,17 @@ def __init__( @property def measurements(self) -> Mapping['cirq.MeasurementKey', Tuple[int, ...]]: """Gets the a mapping from measurement key to measurement.""" - return self._measurements + return {k: v[-1] for k, v in self._measurements.items()} @property def channel_measurements(self) -> Mapping['cirq.MeasurementKey', int]: """Gets the a mapping from measurement key to channel measurement.""" - return self._channel_measurements + return {k: v[-1] for k, v in self._channel_measurements.items()} @property def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', ...]]: """Gets the a mapping from measurement key to the qubits measured.""" - return self._measured_qubits + return {k: v[-1] for k, v in self._measured_qubits.items()} @property def measurement_types(self) -> Mapping['cirq.MeasurementKey', 'cirq.MeasurementType']: @@ -170,35 +180,45 @@ def record_measurement( ): if len(measurement) != len(qubits): raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') - if key in self._measurement_types: - raise ValueError(f"Measurement already logged to key {key}") - self._measurement_types[key] = MeasurementType.MEASUREMENT - self._measurements[key] = tuple(measurement) - self._measured_qubits[key] = tuple(qubits) + if key not in self._measurement_types: + self._measurement_types[key] = MeasurementType.MEASUREMENT + self._measurements[key] = [] + self._measured_qubits[key] = [] + if self._measurement_types[key] != MeasurementType.MEASUREMENT: + raise ValueError(f"Channel Measurement already logged to key {key}") + measured_qubits = self._measured_qubits[key] + if measured_qubits: + if [q.dimension for q in qubits] != [q.dimension for q in measured_qubits[-1]]: + raise ValueError(f'Measurements of keys must all be same shape.') + measured_qubits.append(tuple(qubits)) + self._measurements[key].append(tuple(measurement)) def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): - if key in self._measurement_types: + if key not in self._measurement_types: + self._measurement_types[key] = MeasurementType.CHANNEL + self._channel_measurements[key] = [] + if self._measurement_types[key] != MeasurementType.CHANNEL: raise ValueError(f"Measurement already logged to key {key}") - self._measurement_types[key] = MeasurementType.CHANNEL - self._channel_measurements[key] = measurement + self._channel_measurements[key].append(measurement) - def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + def get_digits(self, key: 'cirq.MeasurementKey', index=-1) -> Tuple[int, ...]: return ( - self._measurements[key] + self._measurements[key][index] if self._measurement_types[key] == MeasurementType.MEASUREMENT - else (self._channel_measurements[key],) + else (self._channel_measurements[key][index],) ) - def get_int(self, key: 'cirq.MeasurementKey') -> int: + def get_int(self, key: 'cirq.MeasurementKey', index=-1) -> int: if key not in self._measurement_types: - raise KeyError(f'The measurement key {key} is not in {self._measurements}') + raise KeyError(f'The measurement key {key} is not in {self._measurement_types}') measurement_type = self._measurement_types[key] if measurement_type == MeasurementType.CHANNEL: - return self._channel_measurements[key] + return self._channel_measurements[key][index] if key not in self._measured_qubits: - return digits.big_endian_bits_to_int(self._measurements[key]) + return digits.big_endian_bits_to_int(self._measurements[key][index]) return digits.big_endian_digits_to_int( - self._measurements[key], base=[q.dimension for q in self._measured_qubits[key]] + self._measurements[key][index], + base=[q.dimension for q in self._measured_qubits[key][index]], ) def copy(self): @@ -211,10 +231,10 @@ def copy(self): def _json_dict_(self): return { - 'measurements': list(self.measurements.items()), - 'measured_qubits': list(self.measured_qubits.items()), - 'channel_measurements': list(self.channel_measurements.items()), - 'measurement_types': list(self.measurement_types.items()), + 'measurements': list(self._measurements.items()), + 'measured_qubits': list(self._measured_qubits.items()), + 'channel_measurements': list(self._channel_measurements.items()), + 'measurement_types': list(self._measurement_types.items()), } @classmethod @@ -230,10 +250,10 @@ def _from_json_dict_( def __repr__(self): return ( - f'cirq.ClassicalDataDictionaryStore(_measurements={self.measurements!r},' - f' _measured_qubits={self.measured_qubits!r},' - f' _channel_measurements={self.channel_measurements!r},' - f' _measurement_types={self.measurement_types!r})' + f'cirq.ClassicalDataDictionaryStore(_measurements={self._measurements!r},' + f' _measured_qubits={self._measured_qubits!r},' + f' _channel_measurements={self._channel_measurements!r},' + f' _measurement_types={self._measurement_types!r})' ) def _value_equality_values_(self): diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index 00cfe475d0e..208565e44c6 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -29,9 +29,9 @@ def test_init(): assert cd.channel_measurements == {} assert cd.measurement_types == {} cd = cirq.ClassicalDataDictionaryStore( - _measurements={mkey_m: (0, 1)}, - _measured_qubits={mkey_m: two_qubits}, - _channel_measurements={mkey_c: 3}, + _measurements={mkey_m: [(0, 1)]}, + _measured_qubits={mkey_m: [two_qubits]}, + _channel_measurements={mkey_c: [3]}, ) assert cd.measurements == {mkey_m: (0, 1)} assert cd.keys() == (mkey_m, mkey_c) @@ -56,8 +56,11 @@ def test_record_measurement_errors(): with pytest.raises(ValueError, match='3 measurements but 2 qubits'): cd.record_measurement(mkey_m, (0, 1, 2), two_qubits) cd.record_measurement(mkey_m, (0, 1), two_qubits) - with pytest.raises(ValueError, match='Measurement already logged to key m'): - cd.record_measurement(mkey_m, (0, 1), two_qubits) + cd.record_measurement(mkey_m, (1, 0), two_qubits) + with pytest.raises(ValueError, match='Measurements of keys must all be same shape'): + cd.record_measurement(mkey_m, (1, 0, 4), tuple(cirq.LineQubit.range(3))) + with pytest.raises(ValueError, match='Measurements of keys must all be same shape'): + cd.record_measurement(mkey_m, (1, 0), tuple(cirq.LineQid.range(2, dimension=3))) def test_record_channel_measurement(): @@ -70,16 +73,14 @@ def test_record_channel_measurement(): def test_record_channel_measurement_errors(): cd = cirq.ClassicalDataDictionaryStore() cd.record_channel_measurement(mkey_m, 1) - with pytest.raises(ValueError, match='Measurement already logged to key m'): - cd.record_channel_measurement(mkey_m, 1) - with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Channel Measurement already logged to key m'): cd.record_measurement(mkey_m, (0, 1), two_qubits) cd = cirq.ClassicalDataDictionaryStore() cd.record_measurement(mkey_m, (0, 1), two_qubits) + cd.record_measurement(mkey_m, (0, 1), two_qubits) with pytest.raises(ValueError, match='Measurement already logged to key m'): cd.record_channel_measurement(mkey_m, 1) - with pytest.raises(ValueError, match='Measurement already logged to key m'): - cd.record_measurement(mkey_m, (0, 1), two_qubits) def test_get_int(): @@ -102,9 +103,9 @@ def test_get_int(): def test_copy(): cd = cirq.ClassicalDataDictionaryStore( - _measurements={mkey_m: (0, 1)}, - _measured_qubits={mkey_m: two_qubits}, - _channel_measurements={mkey_c: 3}, + _measurements={mkey_m: [(0, 1)]}, + _measured_qubits={mkey_m: [two_qubits]}, + _channel_measurements={mkey_c: [3]}, _measurement_types={ mkey_m: cirq.MeasurementType.MEASUREMENT, mkey_c: cirq.MeasurementType.CHANNEL, @@ -125,9 +126,9 @@ def test_copy(): def test_repr(): cd = cirq.ClassicalDataDictionaryStore( - _measurements={mkey_m: (0, 1)}, - _measured_qubits={mkey_m: two_qubits}, - _channel_measurements={mkey_c: 3}, + _measurements={mkey_m: [(0, 1)]}, + _measured_qubits={mkey_m: [two_qubits]}, + _channel_measurements={mkey_c: [3]}, _measurement_types={ mkey_m: cirq.MeasurementType.MEASUREMENT, mkey_c: cirq.MeasurementType.CHANNEL, diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index e92029b1bfb..c4c9f4573f2 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -46,14 +46,14 @@ def resolve(measurements): classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) return init_key_condition.resolve(classical_data) - assert resolve({'0:a': [1]}) - assert resolve({'0:a': [2]}) - assert resolve({'0:a': [0, 1]}) - assert resolve({'0:a': [1, 0]}) - assert not resolve({'0:a': [0]}) - assert not resolve({'0:a': [0, 0]}) - assert not resolve({'0:a': []}) - assert not resolve({'0:a': [0], 'b': [1]}) + assert resolve({'0:a': [[1]]}) + assert resolve({'0:a': [[2]]}) + assert resolve({'0:a': [[0, 1]]}) + assert resolve({'0:a': [[1, 0]]}) + assert not resolve({'0:a': [[0]]}) + assert not resolve({'0:a': [[0, 0]]}) + assert not resolve({'0:a': [[]]}) + assert not resolve({'0:a': [[0]], 'b': [[1]]}) with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): @@ -61,7 +61,7 @@ def resolve(measurements): with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): - _ = resolve({'0:b': [1]}) + _ = resolve({'0:b': [[1]]}) def test_key_condition_qasm(): @@ -88,14 +88,14 @@ def resolve(measurements): classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) return init_sympy_condition.resolve(classical_data) - assert resolve({'0:a': [1]}) - assert resolve({'0:a': [2]}) - assert resolve({'0:a': [0, 1]}) - assert resolve({'0:a': [1, 0]}) - assert not resolve({'0:a': [0]}) - assert not resolve({'0:a': [0, 0]}) - assert not resolve({'0:a': []}) - assert not resolve({'0:a': [0], 'b': [1]}) + assert resolve({'0:a': [[1]]}) + assert resolve({'0:a': [[2]]}) + assert resolve({'0:a': [[0, 1]]}) + assert resolve({'0:a': [[1, 0]]}) + assert not resolve({'0:a': [[0]]}) + assert not resolve({'0:a': [[0, 0]]}) + assert not resolve({'0:a': [[]]}) + assert not resolve({'0:a': [[0]], 'b': [[1]]}) with pytest.raises( ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), @@ -105,7 +105,7 @@ def resolve(measurements): ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): - _ = resolve({'0:b': [1]}) + _ = resolve({'0:b': [[1]]}) def test_sympy_condition_qasm(): From c04f998bb005c95261e046412b68adf1a53b28aa Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 26 Jan 2022 10:00:38 -0800 Subject: [PATCH 71/89] Revert ClassicalData to returning whole dict --- .../ClassicalDataDictionaryStore.json | 22 +++++++++---------- .../act_on_stabilizer_ch_form_args.py | 2 +- .../cirq/sim/clifford/clifford_simulator.py | 2 +- .../clifford/stabilizer_state_ch_form_test.py | 2 +- cirq-core/cirq/value/classical_data.py | 12 +++++----- cirq-core/cirq/value/classical_data_test.py | 12 +++++----- 6 files changed, 25 insertions(+), 27 deletions(-) diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json index ae5bdcb687e..54ef613ba29 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json @@ -17,18 +17,16 @@ "name": "m", "path": [] }, - [ - [ - { - "cirq_type": "LineQubit", - "x": 0 - }, - { - "cirq_type": "LineQubit", - "x": 1 - } - ] - ] + [[ + { + "cirq_type": "LineQubit", + "x": 0 + }, + { + "cirq_type": "LineQubit", + "x": 1 + } + ]] ] ], "channel_measurements": [ diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 65583f78b52..af39d9aaa7f 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -76,7 +76,7 @@ def sample( state, prng, qubits=self.qubits, classical_data=measurements ) protocols.act_on(op, ch_form_args) - return np.array(list(measurements.measurements.values()), dtype=bool) + return np.array([v[-1] for v in measurements.measurements.values()], dtype=bool) def _x(self, g: common_gates.XPowGate, axis: int): exponent = g.exponent diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index d8e67ba0d48..945c74be0ae 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -282,4 +282,4 @@ def apply_measurement( state.ch_form, prng, qubits=self.qubit_map.keys(), classical_data=classical_data ) act_on(op, ch_form_args) - measurements.update({str(k): list(v) for k, v in classical_data.measurements.items()}) + measurements.update({str(k): list(v[-1]) for k, v in classical_data.measurements.items()}) diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py index 8607b6491ff..0d3e0cb7cf1 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py @@ -73,6 +73,6 @@ def test_run(): classical_data=classical_data, ) cirq.act_on(op, args) - measurements = {str(k): list(v) for k, v in classical_data.measurements.items()} + measurements = {str(k): list(v[-1]) for k, v in classical_data.measurements.items()} assert measurements['1'] == [1] assert measurements['0'] != measurements['2'] diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 11f28fc96bb..2a9e230c2c3 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -153,19 +153,19 @@ def __init__( ] = _measurement_types @property - def measurements(self) -> Mapping['cirq.MeasurementKey', Tuple[int, ...]]: + def measurements(self) -> Mapping['cirq.MeasurementKey', List[Tuple[int, ...]]]: """Gets the a mapping from measurement key to measurement.""" - return {k: v[-1] for k, v in self._measurements.items()} + return self._measurements @property - def channel_measurements(self) -> Mapping['cirq.MeasurementKey', int]: + def channel_measurements(self) -> Mapping['cirq.MeasurementKey', List[int]]: """Gets the a mapping from measurement key to channel measurement.""" - return {k: v[-1] for k, v in self._channel_measurements.items()} + return self._channel_measurements @property - def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', ...]]: + def measured_qubits(self) -> Mapping['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]]: """Gets the a mapping from measurement key to the qubits measured.""" - return {k: v[-1] for k, v in self._measured_qubits.items()} + return self._measured_qubits @property def measurement_types(self) -> Mapping['cirq.MeasurementKey', 'cirq.MeasurementType']: diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index 208565e44c6..e86b60e2b99 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -33,10 +33,10 @@ def test_init(): _measured_qubits={mkey_m: [two_qubits]}, _channel_measurements={mkey_c: [3]}, ) - assert cd.measurements == {mkey_m: (0, 1)} + assert cd.measurements == {mkey_m: [(0, 1)]} assert cd.keys() == (mkey_m, mkey_c) - assert cd.measured_qubits == {mkey_m: two_qubits} - assert cd.channel_measurements == {mkey_c: 3} + assert cd.measured_qubits == {mkey_m: [two_qubits]} + assert cd.channel_measurements == {mkey_c: [3]} assert cd.measurement_types == { mkey_m: cirq.MeasurementType.MEASUREMENT, mkey_c: cirq.MeasurementType.CHANNEL, @@ -46,9 +46,9 @@ def test_init(): def test_record_measurement(): cd = cirq.ClassicalDataDictionaryStore() cd.record_measurement(mkey_m, (0, 1), two_qubits) - assert cd.measurements == {mkey_m: (0, 1)} + assert cd.measurements == {mkey_m: [(0, 1)]} assert cd.keys() == (mkey_m,) - assert cd.measured_qubits == {mkey_m: two_qubits} + assert cd.measured_qubits == {mkey_m: [two_qubits]} def test_record_measurement_errors(): @@ -66,7 +66,7 @@ def test_record_measurement_errors(): def test_record_channel_measurement(): cd = cirq.ClassicalDataDictionaryStore() cd.record_channel_measurement(mkey_m, 1) - assert cd.channel_measurements == {mkey_m: 1} + assert cd.channel_measurements == {mkey_m: [1]} assert cd.keys() == (mkey_m,) From 96e8eab4c158342746343baa42ed760f911ad149 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 26 Jan 2022 10:37:09 -0800 Subject: [PATCH 72/89] Allow repeated measurements --- .../classically_controlled_operation_test.py | 37 +++++++++++++++++++ cirq-core/cirq/sim/simulator.py | 2 - cirq-core/cirq/sim/simulator_test.py | 15 -------- cirq-core/cirq/value/condition.py | 14 +++++-- 4 files changed, 47 insertions(+), 21 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 3089d94fcf4..4660dd616ae 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -259,6 +259,43 @@ def test_key_set(sim): assert result.measurements['b'] == 1 +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_repeated_measurement_unset(sim): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.X(q0), + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index=-2)), + cirq.measure(q1, key='b'), + cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index=-1)), + cirq.measure(q1, key='c'), + ) + result = sim.run(circuit) + assert result.measurements['a'] == 1 + assert result.measurements['b'] == 0 + assert result.measurements['c'] == 1 + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_repeated_measurement_set(sim): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.X(q0), + cirq.measure(q0, key='a'), + cirq.X(q0), + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index=-2)), + cirq.measure(q1, key='b'), + cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index=-1)), + cirq.measure(q1, key='c'), + ) + result = sim.run(circuit) + assert result.measurements['a'] == 0 + assert result.measurements['b'] == 1 + assert result.measurements['c'] == 1 + + @pytest.mark.parametrize('sim', ALL_SIMULATORS) def test_subcircuit_key_unset(sim): q0, q1 = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index cb9b87785df..cb080aa93ad 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -102,8 +102,6 @@ def run_sweep_iter( if not program.has_measurements(): raise ValueError("Circuit has no measurements to sample.") - _verify_unique_measurement_keys(program) - for param_resolver in study.to_resolvers(params): measurements = {} if repetitions == 0: diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index 35589b3fb17..0ac3a864858 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -389,21 +389,6 @@ def test_simulation_trial_result_qubit_map(): assert result.qubit_map == {q[0]: 0, q[1]: 1} -def test_verify_unique_measurement_keys(): - q = cirq.LineQubit.range(2) - circuit = cirq.Circuit() - circuit.append( - [ - cirq.measure(q[0], key='a'), - cirq.measure(q[1], key='a'), - cirq.measure(q[0], key='b'), - cirq.measure(q[1], key='b'), - ] - ) - with pytest.raises(ValueError, match='Measurement key a,b repeated'): - _ = cirq.sample(circuit) - - def test_simulate_with_invert_mask(): class PlusGate(cirq.Gate): """A qudit gate that increments a qudit state mod its dimension.""" diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 7c594eb2d95..717426b4baa 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -87,6 +87,7 @@ class KeyCondition(Condition): """ key: 'cirq.MeasurementKey' + index: int = -1 @property def keys(self): @@ -96,9 +97,11 @@ def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.Measure return KeyCondition(replacement) if self.key == current else self def __str__(self): - return str(self.key) + return str(self.key) if self.index == -1 else f'{self.key}[{self.index}]' def __repr__(self): + if self.index != -1: + return f'cirq.KeyCondition({self.key!r}, {self.index})' return f'cirq.KeyCondition({self.key!r})' def resolve( @@ -107,17 +110,20 @@ def resolve( ) -> bool: if self.key not in classical_data.keys(): raise ValueError(f'Measurement key {self.key} missing when testing classical control') - return classical_data.get_int(self.key) != 0 + return classical_data.get_int(self.key, self.index) != 0 def _json_dict_(self): - return json_serialization.dataclass_json_dict(self) + fields = ['key'] if self.index == -1 else ['key', 'index'] + return json_serialization.obj_to_dict_helper(self, fields) @classmethod def _from_json_dict_(cls, key, **kwargs): - return cls(key=key) + return cls(key=key, index=kwargs.get('index', -1)) @property def qasm(self): + if self.index != -1: + raise NotImplementedError('Only most recent measurement at key can be used for QASM.') return f'm_{self.key}!=0' From 45df9ffe70d76378540ad7e9d2c00489b48eea00 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 28 Jan 2022 14:03:59 -0800 Subject: [PATCH 73/89] update commutes --- cirq-core/cirq/circuits/circuit.py | 6 +++--- cirq-core/cirq/ops/raw_types.py | 12 +++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index d8ca53cf7f7..3356f369d76 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1957,11 +1957,11 @@ def _prev_moment_available(self, op: 'cirq.Operation', end_moment_index: int) -> while k > 0: k -= 1 moment = self._moments[k] - # This should also validate that measurement keys are disjoint once we allow repeated - # measurements. Search for same message in raw_types.py. + moment_measurement_keys = protocols.measurement_key_objs(moment) if ( moment.operates_on(op_qubits) - or not op_control_keys.isdisjoint(protocols.measurement_key_objs(moment)) + or not op_measurement_keys.isdisjoint(moment_measurement_keys) + or not op_control_keys.isdisjoint(moment_measurement_keys) or not protocols.control_keys(moment).isdisjoint(op_measurement_keys) ): return last_available diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index e5600fec6c6..ad2411a4b1e 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -572,11 +572,13 @@ def _commutes_( if not isinstance(other, Operation): return NotImplemented - # This should also validate that measurement keys are disjoint once we allow repeated - # measurements. Search for same message in circuit.py. - if not protocols.control_keys(self).isdisjoint( - protocols.measurement_key_objs(other) - ) or not protocols.control_keys(other).isdisjoint(protocols.measurement_key_objs(self)): + self_keys = protocols.measurement_key_objs(self) + other_keys = protocols.measurement_key_objs(other) + if ( + not self_keys.isdisjoint(other_keys) + or not protocols.control_keys(self).isdisjoint(other_keys) + or not protocols.control_keys(other).isdisjoint(self_keys) + ): return False if hasattr(other, 'qubits') and set(self.qubits).isdisjoint(other.qubits): From 52614261092bc3b43910071032078fe07cff08c1 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 7 Feb 2022 17:40:03 -0800 Subject: [PATCH 74/89] Fix merge --- .../ClassicalDataDictionaryStore.json | 16 ---------------- .../ClassicalDataDictionaryStore.repr | 6 +----- cirq-core/cirq/sim/act_on_args.py | 2 -- cirq-core/cirq/sim/act_on_density_matrix_args.py | 2 -- .../cirq/sim/clifford/act_on_stabilizer_args.py | 4 ++-- .../cirq/sim/clifford/clifford_simulator.py | 2 +- 6 files changed, 4 insertions(+), 28 deletions(-) diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json index c0f5a219d12..54ef613ba29 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json @@ -7,11 +7,7 @@ "name": "m", "path": [] }, -<<<<<<< HEAD [[0, 1]] -======= - [0, 1] ->>>>>>> master ] ], "measured_qubits": [ @@ -21,11 +17,7 @@ "name": "m", "path": [] }, -<<<<<<< HEAD [[ -======= - [ ->>>>>>> master { "cirq_type": "LineQubit", "x": 0 @@ -34,11 +26,7 @@ "cirq_type": "LineQubit", "x": 1 } -<<<<<<< HEAD ]] -======= - ] ->>>>>>> master ] ], "channel_measurements": [ @@ -48,11 +36,7 @@ "name": "c", "path": [] }, -<<<<<<< HEAD [3] -======= - 3 ->>>>>>> master ] ], "measurement_types": [ diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr index bcc3385a989..18351d55637 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr @@ -1,5 +1 @@ -<<<<<<< HEAD -cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [[0, 1]]}, _measured_qubits={cirq.MeasurementKey('m'): [[cirq.LineQubit(0), cirq.LineQubit(1)]]}, _channel_measurements={cirq.MeasurementKey('c'): [3]}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) -======= -cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) ->>>>>>> master +cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [[0, 1]]}, _measured_qubits={cirq.MeasurementKey('m'): [[cirq.LineQubit(0), cirq.LineQubit(1)]]}, _channel_measurements={cirq.MeasurementKey('c'): [3]}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index c09040c207e..66616472580 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -62,8 +62,6 @@ def __init__( ordering of the computational basis states. log_of_measurement_results: A mutable object that measurements are being recorded into. - classical_data: The shared classical data container for this - simulation. ignore_measurement_results: If True, then the simulation will treat measurement as dephasing instead of collapsing process, and not log the result. This is only applicable to diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 3b4d25adda6..35ed79de68d 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -70,8 +70,6 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. - classical_data: The shared classical data container for this - simulation. ignore_measurement_results: If True, then the simulation will treat measurement as dephasing instead of collapsing process. This is only applicable to simulators that can diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py index d9051f21147..5cc45ef6297 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py @@ -13,7 +13,7 @@ # limitations under the License. import abc -from typing import Any, Dict, Generic, Optional, Sequence, TYPE_CHECKING, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union import numpy as np @@ -38,7 +38,7 @@ def __init__( self, state: TStabilizerState, prng: Optional[np.random.RandomState] = None, - log_of_measurement_results: Optional[Dict[str, Any]] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, qubits: Optional[Sequence['cirq.Qid']] = None, classical_data: Optional['cirq.ClassicalDataStore'] = None, ): diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index c5c98cf77a4..04f6a0da2c4 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -99,8 +99,8 @@ def _create_partial_act_on_args( return clifford.ActOnStabilizerCHFormArgs( prng=self._prng, - qubits=qubits, classical_data=classical_data, + qubits=qubits, initial_state=initial_state, ) From d475e69795308c763906e42aabe23d390d8c09f5 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 7 Feb 2022 17:42:15 -0800 Subject: [PATCH 75/89] Remove rebinding conflict check in measurement_key.py --- cirq-core/cirq/value/measurement_key.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cirq-core/cirq/value/measurement_key.py b/cirq-core/cirq/value/measurement_key.py index e53eac47fde..e888325db81 100644 --- a/cirq-core/cirq/value/measurement_key.py +++ b/cirq-core/cirq/value/measurement_key.py @@ -130,10 +130,7 @@ def _with_rescoped_keys_( path: Tuple[str, ...], bindable_keys: FrozenSet['MeasurementKey'], ): - new_key = self.replace(path=path + self.path) - if new_key in bindable_keys: - raise ValueError(f'Conflicting measurement keys found: {new_key}') - return new_key + return self.replace(path=path + self.path) def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): if self.name not in key_map: From 0df173f8082bcafcbc06e08524db18437843e0ba Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 7 Feb 2022 17:59:01 -0800 Subject: [PATCH 76/89] Remove outdated tests --- .../cirq/circuits/circuit_operation_test.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index e4b8c294e90..f917ffc89cf 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -834,24 +834,4 @@ def test_mapped_circuit_keeps_keys_under_parent_path(): assert cirq.measurement_key_names(op2.mapped_circuit()) == {'X:A', 'X:B', 'X:C', 'X:D'} -def test_keys_conflict_no_repetitions(): - q = cirq.LineQubit(0) - op1 = cirq.CircuitOperation( - cirq.FrozenCircuit( - cirq.measure(q, key='A'), - ) - ) - op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1)) - with pytest.raises(ValueError, match='Conflicting measurement keys found: A'): - _ = op2.mapped_circuit(deep=True) - - -def test_keys_conflict_locally(): - q = cirq.LineQubit(0) - op1 = cirq.measure(q, key='A') - op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1)) - with pytest.raises(ValueError, match='Conflicting measurement keys found: A'): - _ = op2.mapped_circuit() - - # TODO: Operation has a "gate" property. What is this for a CircuitOperation? From 0b1d45f52f79680974c4e018182b859a16ad9846 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 7 Feb 2022 18:02:39 -0800 Subject: [PATCH 77/89] coverage --- cirq-core/cirq/value/condition_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index c4c9f4573f2..88684ce56db 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -35,10 +35,12 @@ def test_key_condition_with_keys(): def test_key_condition_str(): assert str(init_key_condition) == '0:a' + assert str(cirq.KeyCondition(key_a, index=-2)) == '0:a[-2]' def test_key_condition_repr(): cirq.testing.assert_equivalent_repr(init_key_condition) + cirq.testing.assert_equivalent_repr(cirq.KeyCondition(key_a, index=-2)) def test_key_condition_resolve(): From e24c04899c29bee9aa10435d6c236a1b619ca250 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 11 Feb 2022 09:44:30 -0800 Subject: [PATCH 78/89] Revert serialization changes, improve error messages --- .../ClassicallyControlledOperation.json | 6 +++-- .../json_test_data/KeyCondition.json | 3 ++- cirq-core/cirq/value/classical_data.py | 22 ++++++++++--------- cirq-core/cirq/value/classical_data_test.py | 9 ++++++-- cirq-core/cirq/value/condition.py | 5 ++--- 5 files changed, 27 insertions(+), 18 deletions(-) diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json index 8fbae9b27c7..485e0bfa04f 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json @@ -7,7 +7,8 @@ "cirq_type": "MeasurementKey", "name": "a", "path": [] - } + }, + "index": -1 }, { "cirq_type": "KeyCondition", @@ -15,7 +16,8 @@ "cirq_type": "MeasurementKey", "name": "b", "path": [] - } + }, + "index": -1 } ], "sub_operation": { diff --git a/cirq-core/cirq/protocols/json_test_data/KeyCondition.json b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json index f5b81ba63dc..4478ba5c6c1 100644 --- a/cirq-core/cirq/protocols/json_test_data/KeyCondition.json +++ b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json @@ -4,5 +4,6 @@ "cirq_type": "MeasurementKey", "name": "a", "path": [] - } + }, + "index": -1 } \ No newline at end of file diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 24751135126..05ab46d2025 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -188,8 +188,10 @@ def record_measurement( raise ValueError(f"Channel Measurement already logged to key {key}") measured_qubits = self._measured_qubits[key] if measured_qubits: - if [q.dimension for q in qubits] != [q.dimension for q in measured_qubits[-1]]: - raise ValueError(f'Measurements of keys must all be same shape.') + shape = tuple(q.dimension for q in qubits) + key_shape = tuple(q.dimension for q in measured_qubits[-1]) + if shape != key_shape: + raise ValueError(f'Measurement shape {shape} does not match {key_shape} in {key}.') measured_qubits.append(tuple(qubits)) self._measurements[key].append(tuple(measurement)) @@ -231,10 +233,10 @@ def copy(self): def _json_dict_(self): return { - 'measurements': list(self._measurements.items()), - 'measured_qubits': list(self._measured_qubits.items()), - 'channel_measurements': list(self._channel_measurements.items()), - 'measurement_types': list(self._measurement_types.items()), + 'measurements': list(self.measurements.items()), + 'measured_qubits': list(self.measured_qubits.items()), + 'channel_measurements': list(self.channel_measurements.items()), + 'measurement_types': list(self.measurement_types.items()), } @classmethod @@ -250,10 +252,10 @@ def _from_json_dict_( def __repr__(self): return ( - f'cirq.ClassicalDataDictionaryStore(_measurements={self._measurements!r},' - f' _measured_qubits={self._measured_qubits!r},' - f' _channel_measurements={self._channel_measurements!r},' - f' _measurement_types={self._measurement_types!r})' + f'cirq.ClassicalDataDictionaryStore(_measurements={self.measurements!r},' + f' _measured_qubits={self.measured_qubits!r},' + f' _channel_measurements={self.channel_measurements!r},' + f' _measurement_types={self.measurement_types!r})' ) def _value_equality_values_(self): diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index e86b60e2b99..ce4d3b661b9 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re import pytest @@ -57,9 +58,13 @@ def test_record_measurement_errors(): cd.record_measurement(mkey_m, (0, 1, 2), two_qubits) cd.record_measurement(mkey_m, (0, 1), two_qubits) cd.record_measurement(mkey_m, (1, 0), two_qubits) - with pytest.raises(ValueError, match='Measurements of keys must all be same shape'): + with pytest.raises( + ValueError, match=re.escape('Measurement shape (2, 2, 2) does not match (2, 2) in m') + ): cd.record_measurement(mkey_m, (1, 0, 4), tuple(cirq.LineQubit.range(3))) - with pytest.raises(ValueError, match='Measurements of keys must all be same shape'): + with pytest.raises( + ValueError, match=re.escape('Measurement shape (3, 3) does not match (2, 2) in m') + ): cd.record_measurement(mkey_m, (1, 0), tuple(cirq.LineQid.range(2, dimension=3))) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 717426b4baa..081c2d04e71 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -113,12 +113,11 @@ def resolve( return classical_data.get_int(self.key, self.index) != 0 def _json_dict_(self): - fields = ['key'] if self.index == -1 else ['key', 'index'] - return json_serialization.obj_to_dict_helper(self, fields) + return json_serialization.dataclass_json_dict(self) @classmethod def _from_json_dict_(cls, key, **kwargs): - return cls(key=key, index=kwargs.get('index', -1)) + return cls(key=key) @property def qasm(self): From 3b022dcfff129c89ebff8f1bdafa123b5d08c2f8 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 11 Feb 2022 12:26:50 -0800 Subject: [PATCH 79/89] Re-add removed tests, with new expectations --- .../cirq/circuits/circuit_operation_test.py | 24 +++++++++++++++++++ cirq-core/cirq/ops/measurement_gate_test.py | 19 +++++++++++++++ cirq-core/cirq/sim/simulator_test.py | 15 ++++++++++++ 3 files changed, 58 insertions(+) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index f917ffc89cf..7fbf782ef9b 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -834,4 +834,28 @@ def test_mapped_circuit_keeps_keys_under_parent_path(): assert cirq.measurement_key_names(op2.mapped_circuit()) == {'X:A', 'X:B', 'X:C', 'X:D'} +def test_mapped_circuit_allows_repeated_keys(): + q = cirq.LineQubit(0) + op1 = cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.measure(q, key='A'), + ) + ) + op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1)) + circuit = op2.mapped_circuit(deep=True) + cirq.testing.assert_has_diagram( + circuit, + "0: ───M('A')───M('A')───", + use_unicode_characters=True, + ) + op1 = cirq.measure(q, key='A') + op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1)) + circuit = op2.mapped_circuit() + cirq.testing.assert_has_diagram( + circuit, + "0: ───M('A')───M('A')───", + use_unicode_characters=True, + ) + + # TODO: Operation has a "gate" property. What is this for a CircuitOperation? diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index f5f3f16a9f6..cc1a0934b96 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import cast import numpy as np import pytest @@ -323,7 +324,13 @@ def test_act_on_state_vector(): dtype=np.complex64, ) cirq.act_on(m, args) + datastore = cast(cirq.ClassicalDataDictionaryStore, args.classical_data) + out = cirq.MeasurementKey('out') assert args.log_of_measurement_results == {'out': [0, 1]} + assert datastore.measurements[out] == [(0, 1)] + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [0, 1]} + assert datastore.measurements[out] == [(0, 1), (0, 1)] def test_act_on_clifford_tableau(): @@ -358,7 +365,13 @@ def test_act_on_clifford_tableau(): log_of_measurement_results={}, ) cirq.act_on(m, args) + datastore = cast(cirq.ClassicalDataDictionaryStore, args.classical_data) + out = cirq.MeasurementKey('out') + assert args.log_of_measurement_results == {'out': [0, 1]} + assert datastore.measurements[out] == [(0, 1)] + cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 1]} + assert datastore.measurements[out] == [(0, 1), (0, 1)] def test_act_on_stabilizer_ch_form(): @@ -393,7 +406,13 @@ def test_act_on_stabilizer_ch_form(): initial_state=10, ) cirq.act_on(m, args) + datastore = cast(cirq.ClassicalDataDictionaryStore, args.classical_data) + out = cirq.MeasurementKey('out') + assert args.log_of_measurement_results == {'out': [0, 1]} + assert datastore.measurements[out] == [(0, 1)] + cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 1]} + assert datastore.measurements[out] == [(0, 1), (0, 1)] def test_act_on_qutrit(): diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index 1d604c412cb..a6c0829096d 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -389,6 +389,21 @@ def test_simulation_trial_result_qubit_map(): assert result.qubit_map == {q[0]: 0, q[1]: 1} +def test_verify_unique_measurement_keys(): + q = cirq.LineQubit.range(2) + circuit = cirq.Circuit() + circuit.append( + [ + cirq.measure(q[0], key='a'), + cirq.measure(q[1], key='a'), + cirq.measure(q[0], key='b'), + cirq.measure(q[1], key='b'), + ] + ) + with pytest.raises(ValueError, match='Duplicate MeasurementGate with key a'): + _ = cirq.sample(circuit) + + def test_simulate_with_invert_mask(): class PlusGate(cirq.Gate): """A qudit gate that increments a qudit state mod its dimension.""" From 8158ba0ca0d50f255aa36c75a757b2c7fc62f41c Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 11 Feb 2022 12:31:45 -0800 Subject: [PATCH 80/89] Remove unused _verify_unique_measurement_keys --- cirq-core/cirq/sim/simulator.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 1cb2a77279f..c9c5bd7e88f 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -984,18 +984,6 @@ def _qubit_map_to_shape(qubit_map: Dict['cirq.Qid', int]) -> Tuple[int, ...]: return tuple(qid_shape) -def _verify_unique_measurement_keys(circuit: 'cirq.AbstractCircuit'): - result = collections.Counter( - key - for op in ops.flatten_op_tree(iter(circuit)) - for key in protocols.measurement_key_names(op) - ) - if result: - duplicates = [k for k, v in result.most_common() if v > 1] - if duplicates: - raise ValueError(f"Measurement key {','.join(duplicates)} repeated") - - def check_all_resolved(circuit): """Raises if the circuit contains unresolved symbols.""" if protocols.is_parameterized(circuit): From a728dc3cf0fc6aed58d26fb17b04d5209f57483f Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 11 Feb 2022 12:39:30 -0800 Subject: [PATCH 81/89] Improve sampler key error messages --- cirq-core/cirq/sim/simulator.py | 14 +++++++++----- cirq-core/cirq/sim/simulator_test.py | 5 +++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index c9c5bd7e88f..eeaedcaa170 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -836,15 +836,19 @@ def sample_measurement_ops( """ # Sanity checks. - seen_measurement_keys: Set[str] = set() for op in measurement_ops: gate = op.gate if not isinstance(gate, ops.MeasurementGate): raise ValueError(f'{op.gate} was not a MeasurementGate') - key = protocols.measurement_key_name(gate) - if key in seen_measurement_keys: - raise ValueError(f'Duplicate MeasurementGate with key {key}') - seen_measurement_keys.add(key) + result = collections.Counter( + key + for op in measurement_ops + for key in protocols.measurement_key_names(op) + ) + if result: + duplicates = [k for k, v in result.most_common() if v > 1] + if duplicates: + raise ValueError(f"Measurement key {','.join(duplicates)} repeated") # Find measured qubits, ensuring a consistent ordering. measured_qubits = [] diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index a6c0829096d..108b66d00df 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for simulator.py""" import abc +import re from typing import Generic, Dict, Any, List, Sequence, Union from unittest import mock @@ -238,7 +239,7 @@ def test_step_sample_measurement_ops_not_measurement(): def test_step_sample_measurement_ops_repeated_qubit(): q0, q1, q2 = cirq.LineQubit.range(3) step_result = FakeStepResult([q0]) - with pytest.raises(ValueError, match='MeasurementGate'): + with pytest.raises(ValueError, match='Measurement key 0 repeated'): step_result.sample_measurement_ops( [cirq.measure(q0), cirq.measure(q1, q2), cirq.measure(q0)] ) @@ -400,7 +401,7 @@ def test_verify_unique_measurement_keys(): cirq.measure(q[1], key='b'), ] ) - with pytest.raises(ValueError, match='Duplicate MeasurementGate with key a'): + with pytest.raises(ValueError, match=re.escape('Measurement key a,b repeated')): _ = cirq.sample(circuit) From 269fcf616a29439ab03d0824ad57397a9cb2c672 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 11 Feb 2022 12:41:22 -0800 Subject: [PATCH 82/89] format --- cirq-core/cirq/sim/simulator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index eeaedcaa170..f7171cbd16d 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -841,9 +841,7 @@ def sample_measurement_ops( if not isinstance(gate, ops.MeasurementGate): raise ValueError(f'{op.gate} was not a MeasurementGate') result = collections.Counter( - key - for op in measurement_ops - for key in protocols.measurement_key_names(op) + key for op in measurement_ops for key in protocols.measurement_key_names(op) ) if result: duplicates = [k for k, v in result.most_common() if v > 1] From 9df9f15043658b53e3743e939405e51660afcafd Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 11 Feb 2022 12:43:15 -0800 Subject: [PATCH 83/89] Remove unnecessary re.escape --- cirq-core/cirq/sim/simulator_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index 108b66d00df..bda9b956292 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for simulator.py""" import abc -import re from typing import Generic, Dict, Any, List, Sequence, Union from unittest import mock @@ -401,7 +400,7 @@ def test_verify_unique_measurement_keys(): cirq.measure(q[1], key='b'), ] ) - with pytest.raises(ValueError, match=re.escape('Measurement key a,b repeated')): + with pytest.raises(ValueError, match='Measurement key a,b repeated'): _ = cirq.sample(circuit) From 69036c9866cacaf35527b521a538c54af5b77e90 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 14 Feb 2022 13:46:47 -0800 Subject: [PATCH 84/89] Change "measurements" to "records" --- cirq-core/cirq/ops/measurement_gate_test.py | 12 +-- .../ClassicalDataDictionaryStore.json | 4 +- .../ClassicalDataDictionaryStore.repr | 2 +- cirq-core/cirq/sim/act_on_args.py | 2 +- cirq-core/cirq/sim/act_on_args_container.py | 2 +- .../act_on_stabilizer_ch_form_args.py | 2 +- .../cirq/sim/clifford/clifford_simulator.py | 2 +- .../clifford/stabilizer_state_ch_form_test.py | 2 +- cirq-core/cirq/value/classical_data.py | 76 +++++++++---------- cirq-core/cirq/value/classical_data_test.py | 32 ++++---- cirq-core/cirq/value/condition_test.py | 8 +- 11 files changed, 72 insertions(+), 72 deletions(-) diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index cc1a0934b96..2272b9a58b0 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -327,10 +327,10 @@ def test_act_on_state_vector(): datastore = cast(cirq.ClassicalDataDictionaryStore, args.classical_data) out = cirq.MeasurementKey('out') assert args.log_of_measurement_results == {'out': [0, 1]} - assert datastore.measurements[out] == [(0, 1)] + assert datastore.records[out] == [(0, 1)] cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 1]} - assert datastore.measurements[out] == [(0, 1), (0, 1)] + assert datastore.records[out] == [(0, 1), (0, 1)] def test_act_on_clifford_tableau(): @@ -368,10 +368,10 @@ def test_act_on_clifford_tableau(): datastore = cast(cirq.ClassicalDataDictionaryStore, args.classical_data) out = cirq.MeasurementKey('out') assert args.log_of_measurement_results == {'out': [0, 1]} - assert datastore.measurements[out] == [(0, 1)] + assert datastore.records[out] == [(0, 1)] cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 1]} - assert datastore.measurements[out] == [(0, 1), (0, 1)] + assert datastore.records[out] == [(0, 1), (0, 1)] def test_act_on_stabilizer_ch_form(): @@ -409,10 +409,10 @@ def test_act_on_stabilizer_ch_form(): datastore = cast(cirq.ClassicalDataDictionaryStore, args.classical_data) out = cirq.MeasurementKey('out') assert args.log_of_measurement_results == {'out': [0, 1]} - assert datastore.measurements[out] == [(0, 1)] + assert datastore.records[out] == [(0, 1)] cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 1]} - assert datastore.measurements[out] == [(0, 1), (0, 1)] + assert datastore.records[out] == [(0, 1), (0, 1)] def test_act_on_qutrit(): diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json index 54ef613ba29..8d03c5d2059 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json @@ -1,6 +1,6 @@ { "cirq_type": "ClassicalDataDictionaryStore", - "measurements": [ + "records": [ [ { "cirq_type": "MeasurementKey", @@ -29,7 +29,7 @@ ]] ] ], - "channel_measurements": [ + "channel_records": [ [ { "cirq_type": "MeasurementKey", diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr index 18351d55637..156350e2cc0 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr @@ -1 +1 @@ -cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [[0, 1]]}, _measured_qubits={cirq.MeasurementKey('m'): [[cirq.LineQubit(0), cirq.LineQubit(1)]]}, _channel_measurements={cirq.MeasurementKey('c'): [3]}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file +cirq.ClassicalDataDictionaryStore(_records={cirq.MeasurementKey('m'): [[0, 1]]}, _measured_qubits={cirq.MeasurementKey('m'): [[cirq.LineQubit(0), cirq.LineQubit(1)]]}, _channel_records={cirq.MeasurementKey('c'): [3]}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 66616472580..bb656f8a3c7 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -76,7 +76,7 @@ def __init__( self._set_qubits(qubits) self.prng = prng self._classical_data = classical_data or value.ClassicalDataDictionaryStore( - _measurements={ + _records={ value.MeasurementKey.parse_serialized(k): [tuple(v)] for k, v in (log_of_measurement_results or {}).items() } diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 7eed1c8c3a8..f450ad9bbec 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -72,7 +72,7 @@ def __init__( self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states self._classical_data = classical_data or value.ClassicalDataDictionaryStore( - _measurements={ + _records={ value.MeasurementKey.parse_serialized(k): [tuple(v)] for k, v in (log_of_measurement_results or {}).items() } diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index f75aac19a4e..c3c02f4fcc5 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -119,4 +119,4 @@ def sample( initial_state=state, ) protocols.act_on(op, ch_form_args) - return np.array([v[-1] for v in measurements.measurements.values()], dtype=bool) + return np.array([v[-1] for v in measurements.records.values()], dtype=bool) diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 04f6a0da2c4..3f19c253d6f 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -293,4 +293,4 @@ def apply_measurement( initial_state=state.ch_form, ) act_on(op, ch_form_args) - measurements.update({str(k): list(v[-1]) for k, v in classical_data.measurements.items()}) + measurements.update({str(k): list(v[-1]) for k, v in classical_data.records.items()}) diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py index f27eab2ae77..30ef642a745 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py @@ -73,6 +73,6 @@ def test_run(): initial_state=state, ) cirq.act_on(op, args) - measurements = {str(k): list(v[-1]) for k, v in classical_data.measurements.items()} + measurements = {str(k): list(v[-1]) for k, v in classical_data.records.items()} assert measurements['1'] == [1] assert measurements['0'] != measurements['2'] diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 05ab46d2025..5806f5f69f9 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -121,46 +121,46 @@ class ClassicalDataDictionaryStore(ClassicalDataStore): def __init__( self, *, - _measurements: Dict['cirq.MeasurementKey', List[Tuple[int, ...]]] = None, + _records: Dict['cirq.MeasurementKey', List[Tuple[int, ...]]] = None, _measured_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]] = None, - _channel_measurements: Dict['cirq.MeasurementKey', List[int]] = None, + _channel_records: Dict['cirq.MeasurementKey', List[int]] = None, _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = None, ): """Initializes a `ClassicalDataDictionaryStore` object.""" if not _measurement_types: _measurement_types = {} - if _measurements: + if _records: _measurement_types.update( - {k: MeasurementType.MEASUREMENT for k, v in _measurements.items()} + {k: MeasurementType.MEASUREMENT for k, v in _records.items()} ) - if _channel_measurements: + if _channel_records: _measurement_types.update( - {k: MeasurementType.CHANNEL for k, v in _channel_measurements.items()} + {k: MeasurementType.CHANNEL for k, v in _channel_records.items()} ) - if _measurements is None: - _measurements = {} + if _records is None: + _records = {} if _measured_qubits is None: _measured_qubits = {} - if _channel_measurements is None: - _channel_measurements = {} - self._measurements: Dict['cirq.MeasurementKey', List[Tuple[int, ...]]] = _measurements + if _channel_records is None: + _channel_records = {} + self._records: Dict['cirq.MeasurementKey', List[Tuple[int, ...]]] = _records self._measured_qubits: Dict[ 'cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]] ] = _measured_qubits - self._channel_measurements: Dict['cirq.MeasurementKey', List[int]] = _channel_measurements + self._channel_records: Dict['cirq.MeasurementKey', List[int]] = _channel_records self._measurement_types: Dict[ 'cirq.MeasurementKey', 'cirq.MeasurementType' ] = _measurement_types @property - def measurements(self) -> Mapping['cirq.MeasurementKey', List[Tuple[int, ...]]]: - """Gets the a mapping from measurement key to measurement.""" - return self._measurements + def records(self) -> Mapping['cirq.MeasurementKey', List[Tuple[int, ...]]]: + """Gets the a mapping from measurement key to measurement records.""" + return self._records @property - def channel_measurements(self) -> Mapping['cirq.MeasurementKey', List[int]]: - """Gets the a mapping from measurement key to channel measurement.""" - return self._channel_measurements + def channel_records(self) -> Mapping['cirq.MeasurementKey', List[int]]: + """Gets the a mapping from measurement key to channel measurement records.""" + return self._channel_records @property def measured_qubits(self) -> Mapping['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]]: @@ -182,7 +182,7 @@ def record_measurement( raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') if key not in self._measurement_types: self._measurement_types[key] = MeasurementType.MEASUREMENT - self._measurements[key] = [] + self._records[key] = [] self._measured_qubits[key] = [] if self._measurement_types[key] != MeasurementType.MEASUREMENT: raise ValueError(f"Channel Measurement already logged to key {key}") @@ -193,21 +193,21 @@ def record_measurement( if shape != key_shape: raise ValueError(f'Measurement shape {shape} does not match {key_shape} in {key}.') measured_qubits.append(tuple(qubits)) - self._measurements[key].append(tuple(measurement)) + self._records[key].append(tuple(measurement)) def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): if key not in self._measurement_types: self._measurement_types[key] = MeasurementType.CHANNEL - self._channel_measurements[key] = [] + self._channel_records[key] = [] if self._measurement_types[key] != MeasurementType.CHANNEL: raise ValueError(f"Measurement already logged to key {key}") - self._channel_measurements[key].append(measurement) + self._channel_records[key].append(measurement) def get_digits(self, key: 'cirq.MeasurementKey', index=-1) -> Tuple[int, ...]: return ( - self._measurements[key][index] + self._records[key][index] if self._measurement_types[key] == MeasurementType.MEASUREMENT - else (self._channel_measurements[key][index],) + else (self._channel_records[key][index],) ) def get_int(self, key: 'cirq.MeasurementKey', index=-1) -> int: @@ -215,53 +215,53 @@ def get_int(self, key: 'cirq.MeasurementKey', index=-1) -> int: raise KeyError(f'The measurement key {key} is not in {self._measurement_types}') measurement_type = self._measurement_types[key] if measurement_type == MeasurementType.CHANNEL: - return self._channel_measurements[key][index] + return self._channel_records[key][index] if key not in self._measured_qubits: - return digits.big_endian_bits_to_int(self._measurements[key][index]) + return digits.big_endian_bits_to_int(self._records[key][index]) return digits.big_endian_digits_to_int( - self._measurements[key][index], + self._records[key][index], base=[q.dimension for q in self._measured_qubits[key][index]], ) def copy(self): return ClassicalDataDictionaryStore( - _measurements=self._measurements.copy(), + _records=self._records.copy(), _measured_qubits=self._measured_qubits.copy(), - _channel_measurements=self._channel_measurements.copy(), + _channel_records=self._channel_records.copy(), _measurement_types=self._measurement_types.copy(), ) def _json_dict_(self): return { - 'measurements': list(self.measurements.items()), + 'records': list(self.records.items()), 'measured_qubits': list(self.measured_qubits.items()), - 'channel_measurements': list(self.channel_measurements.items()), + 'channel_records': list(self.channel_records.items()), 'measurement_types': list(self.measurement_types.items()), } @classmethod def _from_json_dict_( - cls, measurements, measured_qubits, channel_measurements, measurement_types, **kwargs + cls, records, measured_qubits, channel_records, measurement_types, **kwargs ): return cls( - _measurements=dict(measurements), + _records=dict(records), _measured_qubits=dict(measured_qubits), - _channel_measurements=dict(channel_measurements), + _channel_records=dict(channel_records), _measurement_types=dict(measurement_types), ) def __repr__(self): return ( - f'cirq.ClassicalDataDictionaryStore(_measurements={self.measurements!r},' + f'cirq.ClassicalDataDictionaryStore(_records={self.records!r},' f' _measured_qubits={self.measured_qubits!r},' - f' _channel_measurements={self.channel_measurements!r},' + f' _channel_records={self.channel_records!r},' f' _measurement_types={self.measurement_types!r})' ) def _value_equality_values_(self): return ( - self._measurements, - self._channel_measurements, + self._records, + self._channel_records, self._measurement_types, self._measured_qubits, ) diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index ce4d3b661b9..9b4fa156bec 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -24,20 +24,20 @@ def test_init(): cd = cirq.ClassicalDataDictionaryStore() - assert cd.measurements == {} + assert cd.records == {} assert cd.keys() == () assert cd.measured_qubits == {} - assert cd.channel_measurements == {} + assert cd.channel_records == {} assert cd.measurement_types == {} cd = cirq.ClassicalDataDictionaryStore( - _measurements={mkey_m: [(0, 1)]}, + _records={mkey_m: [(0, 1)]}, _measured_qubits={mkey_m: [two_qubits]}, - _channel_measurements={mkey_c: [3]}, + _channel_records={mkey_c: [3]}, ) - assert cd.measurements == {mkey_m: [(0, 1)]} + assert cd.records == {mkey_m: [(0, 1)]} assert cd.keys() == (mkey_m, mkey_c) assert cd.measured_qubits == {mkey_m: [two_qubits]} - assert cd.channel_measurements == {mkey_c: [3]} + assert cd.channel_records == {mkey_c: [3]} assert cd.measurement_types == { mkey_m: cirq.MeasurementType.MEASUREMENT, mkey_c: cirq.MeasurementType.CHANNEL, @@ -47,7 +47,7 @@ def test_init(): def test_record_measurement(): cd = cirq.ClassicalDataDictionaryStore() cd.record_measurement(mkey_m, (0, 1), two_qubits) - assert cd.measurements == {mkey_m: [(0, 1)]} + assert cd.records == {mkey_m: [(0, 1)]} assert cd.keys() == (mkey_m,) assert cd.measured_qubits == {mkey_m: [two_qubits]} @@ -71,7 +71,7 @@ def test_record_measurement_errors(): def test_record_channel_measurement(): cd = cirq.ClassicalDataDictionaryStore() cd.record_channel_measurement(mkey_m, 1) - assert cd.channel_measurements == {mkey_m: [1]} + assert cd.channel_records == {mkey_m: [1]} assert cd.keys() == (mkey_m,) @@ -108,9 +108,9 @@ def test_get_int(): def test_copy(): cd = cirq.ClassicalDataDictionaryStore( - _measurements={mkey_m: [(0, 1)]}, + _records={mkey_m: [(0, 1)]}, _measured_qubits={mkey_m: [two_qubits]}, - _channel_measurements={mkey_c: [3]}, + _channel_records={mkey_c: [3]}, _measurement_types={ mkey_m: cirq.MeasurementType.MEASUREMENT, mkey_c: cirq.MeasurementType.CHANNEL, @@ -119,21 +119,21 @@ def test_copy(): cd1 = cd.copy() assert cd1 is not cd assert cd1 == cd - assert cd1.measurements is not cd.measurements - assert cd1.measurements == cd.measurements + assert cd1.records is not cd.records + assert cd1.records == cd.records assert cd1.measured_qubits is not cd.measured_qubits assert cd1.measured_qubits == cd.measured_qubits - assert cd1.channel_measurements is not cd.channel_measurements - assert cd1.channel_measurements == cd.channel_measurements + assert cd1.channel_records is not cd.channel_records + assert cd1.channel_records == cd.channel_records assert cd1.measurement_types is not cd.measurement_types assert cd1.measurement_types == cd.measurement_types def test_repr(): cd = cirq.ClassicalDataDictionaryStore( - _measurements={mkey_m: [(0, 1)]}, + _records={mkey_m: [(0, 1)]}, _measured_qubits={mkey_m: [two_qubits]}, - _channel_measurements={mkey_c: [3]}, + _channel_records={mkey_c: [3]}, _measurement_types={ mkey_m: cirq.MeasurementType.MEASUREMENT, mkey_c: cirq.MeasurementType.CHANNEL, diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 88684ce56db..29148853994 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -44,8 +44,8 @@ def test_key_condition_repr(): def test_key_condition_resolve(): - def resolve(measurements): - classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) + def resolve(records): + classical_data = cirq.ClassicalDataDictionaryStore(_records=records) return init_key_condition.resolve(classical_data) assert resolve({'0:a': [[1]]}) @@ -86,8 +86,8 @@ def test_sympy_condition_repr(): def test_sympy_condition_resolve(): - def resolve(measurements): - classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) + def resolve(records): + classical_data = cirq.ClassicalDataDictionaryStore(_records=records) return init_sympy_condition.resolve(classical_data) assert resolve({'0:a': [[1]]}) From f3fd9fe908e64097f71f56c866202c7672ffefc9 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 14 Feb 2022 14:04:34 -0800 Subject: [PATCH 85/89] Simplify ch-form sampling --- .../clifford/act_on_stabilizer_ch_form_args.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index c3c02f4fcc5..44689e4cc48 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -16,7 +16,7 @@ import numpy as np -from cirq import _compat, value, ops, protocols +from cirq import _compat, value from cirq.sim.clifford import stabilizer_state_ch_form from cirq.sim.clifford.act_on_stabilizer_args import ActOnStabilizerArgs @@ -107,16 +107,10 @@ def sample( repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: - measurements = value.ClassicalDataDictionaryStore() prng = value.parse_random_state(seed) - for i in range(repetitions): - op = ops.measure(*qubits, key=str(i)) + axes = self.get_axes(qubits) + measurements = [] + for _ in range(repetitions): state = self.state.copy() - ch_form_args = ActOnStabilizerCHFormArgs( - classical_data=measurements, - prng=prng, - qubits=self.qubits, - initial_state=state, - ) - protocols.act_on(op, ch_form_args) - return np.array([v[-1] for v in measurements.records.values()], dtype=bool) + measurements.append([state._measure(i, prng) for i in axes]) + return np.array(measurements, dtype=bool) From e70010889ff6d139eec6e5302bba927924ebfc50 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 14 Feb 2022 18:18:45 -0800 Subject: [PATCH 86/89] Push 3d samples into result --- .../classically_controlled_operation_test.py | 14 ++++---- cirq-core/cirq/sim/operation_target.py | 3 +- cirq-core/cirq/sim/simulator.py | 23 +++++++------ cirq-core/cirq/sim/simulator_base.py | 34 ++++++++++++------- cirq-core/cirq/sim/simulator_test.py | 14 ++++---- cirq-core/cirq/value/classical_data.py | 10 ++++++ 6 files changed, 59 insertions(+), 39 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index c7e9608f293..0daf9f327e7 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -272,9 +272,10 @@ def test_repeated_measurement_unset(sim): cirq.measure(q1, key='c'), ) result = sim.run(circuit) - assert result.measurements['a'] == 1 - assert result.measurements['b'] == 0 - assert result.measurements['c'] == 1 + assert result.records['a'][0][0][0] == 0 + assert result.records['a'][0][1][0] == 1 + assert result.records['b'][0][0][0] == 0 + assert result.records['c'][0][0][0] == 1 @pytest.mark.parametrize('sim', ALL_SIMULATORS) @@ -291,9 +292,10 @@ def test_repeated_measurement_set(sim): cirq.measure(q1, key='c'), ) result = sim.run(circuit) - assert result.measurements['a'] == 0 - assert result.measurements['b'] == 1 - assert result.measurements['c'] == 1 + assert result.records['a'][0][0][0] == 1 + assert result.records['a'][0][1][0] == 0 + assert result.records['b'][0][0][0] == 1 + assert result.records['c'][0][0][0] == 1 @pytest.mark.parametrize('sim', ALL_SIMULATORS) diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index e54916303cd..c0381f7e9f9 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -14,7 +14,6 @@ """An interface for quantum states as targets for operations.""" import abc from typing import ( - Any, Dict, Generic, Iterator, @@ -86,7 +85,7 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: """Gets the qubit order maintained by this target.""" @property - def log_of_measurement_results(self) -> Dict[str, Any]: + def log_of_measurement_results(self) -> Dict[str, List[int]]: """Gets the log of measurement results.""" return {str(k): list(self.classical_data.get_digits(k)) for k in self.classical_data.keys()} diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index f7171cbd16d..1fafab53b8b 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -37,7 +37,6 @@ Generic, Iterator, List, - Optional, Sequence, Set, Tuple, @@ -103,15 +102,15 @@ def run_sweep_iter( raise ValueError("Circuit has no measurements to sample.") for param_resolver in study.to_resolvers(params): - measurements = {} + records = {} if repetitions == 0: for _, op, _ in program.findall_operations_with_gate_type(ops.MeasurementGate): - measurements[protocols.measurement_key_name(op)] = np.empty([0, 1]) + records[protocols.measurement_key_name(op)] = np.empty([0, 1, 1]) else: - measurements = self._run( + records = self._run( circuit=program, param_resolver=param_resolver, repetitions=repetitions ) - yield study.ResultDict(params=param_resolver, measurements=measurements) + yield study.ResultDict(params=param_resolver, records=records) @abc.abstractmethod def _run( @@ -130,10 +129,11 @@ def _run( Returns: A dictionary from measurement gate key to measurement - results. Measurement results are stored in a 2-dimensional - numpy array, the first dimension corresponding to the repetition - and the second to the actual boolean measurement results (ordered - by the qubits being measured.) + results. Measurement results are stored in a 3-dimensional + numpy array, the first dimension corresponding to the repetition. + the second to the instance of that key in the circuit, and the + third to the actual boolean measurement results (ordered by the + qubits being measured.) """ raise NotImplementedError() @@ -761,8 +761,9 @@ class StepResult(Generic[TSimulatorState], metaclass=abc.ABCMeta): results, ordered by the qubits that the measurement operates on. """ - def __init__(self, measurements: Optional[Dict[str, List[int]]] = None) -> None: - self.measurements = measurements or collections.defaultdict(list) + def __init__(self, sim_state: 'cirq.OperationTarget') -> None: + self.measurements = sim_state.log_of_measurement_results + self._classical_data = sim_state.classical_data @abc.abstractmethod def _simulator_state(self) -> TSimulatorState: diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 610e369611c..5a474c20104 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -20,17 +20,18 @@ import warnings from typing import ( Any, + cast, Dict, Iterator, + Generic, List, + Mapping, + Optional, + Sequence, Tuple, - TYPE_CHECKING, - cast, - Generic, Type, - Sequence, - Optional, TypeVar, + TYPE_CHECKING, ) import numpy as np @@ -264,9 +265,12 @@ def _run( pass assert step_result is not None measurement_ops = [cast(ops.GateOperation, op) for op in general_ops] - return step_result.sample_measurement_ops(measurement_ops, repetitions, seed=self._prng) + samples = step_result.sample_measurement_ops( + measurement_ops, repetitions, seed=self._prng + ) + return {k: np.array([[x] for x in v], dtype=np.uint8) for k, v in samples.items()} - measurements: Dict[str, List[np.ndarray]] = {} + records: Dict['cirq.MeasurementKey', List[np.ndarray]] = {} for i in range(repetitions): if 'deep_copy_buffers' in inspect.signature(act_on_args.copy).parameters: all_step_results = self._core_iterator( @@ -289,11 +293,15 @@ def _run( ) for step_result in all_step_results: pass - for k, v in step_result.measurements.items(): - if k not in measurements: - measurements[k] = [] - measurements[k].append(np.array(v, dtype=np.uint8)) - return {k: np.array(v) for k, v in measurements.items()} + for k, r in step_result._classical_data.records.items(): + if k not in records: + records[k] = [] + records[k].append(r) + for k, cr in step_result._classical_data.channel_records.items(): + if k not in records: + records[k] = [] + records[k].append([cr]) + return {str(k): np.array(v, dtype=np.uint8) for k, v in records.items()} def simulate_sweep_iter( self, @@ -397,7 +405,7 @@ def __init__( """ self._sim_state = sim_state self._merged_sim_state_cache: Optional[TActOnArgs] = None - super().__init__(sim_state.log_of_measurement_results) + super().__init__(sim_state) qubits = sim_state.qubits self._qubits = qubits self._qubit_mapping = {q: i for i, q in enumerate(qubits)} diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index bda9b956292..13cf73b577f 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -65,13 +65,13 @@ def _create_simulator_trial_result( @mock.patch.multiple(cirq.SimulatesSamples, __abstractmethods__=set(), _run=mock.Mock()) def test_run_simulator_run(): simulator = cirq.SimulatesSamples() - expected_measurements = {'a': np.array([[1]])} + expected_measurements = {'a': np.array([[[1]]])} simulator._run.return_value = expected_measurements circuit = mock.Mock(cirq.Circuit) circuit.__iter__ = mock.Mock(return_value=iter([])) param_resolver = mock.Mock(cirq.ParamResolver) param_resolver.param_dict = {} - expected_result = cirq.ResultDict(measurements=expected_measurements, params=param_resolver) + expected_result = cirq.ResultDict(records=expected_measurements, params=param_resolver) assert expected_result == simulator.run( program=circuit, repetitions=10, param_resolver=param_resolver ) @@ -83,7 +83,7 @@ def test_run_simulator_run(): @mock.patch.multiple(cirq.SimulatesSamples, __abstractmethods__=set(), _run=mock.Mock()) def test_run_simulator_sweeps(): simulator = cirq.SimulatesSamples() - expected_measurements = {'a': np.array([[1]])} + expected_measurements = {'a': np.array([[[1]]])} simulator._run.return_value = expected_measurements circuit = mock.Mock(cirq.Circuit) circuit.__iter__ = mock.Mock(return_value=iter([])) @@ -91,8 +91,8 @@ def test_run_simulator_sweeps(): for resolver in param_resolvers: resolver.param_dict = {} expected_results = [ - cirq.ResultDict(measurements=expected_measurements, params=param_resolvers[0]), - cirq.ResultDict(measurements=expected_measurements, params=param_resolvers[1]), + cirq.ResultDict(records=expected_measurements, params=param_resolvers[0]), + cirq.ResultDict(records=expected_measurements, params=param_resolvers[1]), ] assert expected_results == simulator.run_sweep( program=circuit, repetitions=10, params=param_resolvers @@ -368,7 +368,7 @@ def text(self, to_print): @duet.sync async def test_async_sample(): - m = {'mock': np.array([[0], [1]])} + m = {'mock': np.array([[[0]], [[1]]])} class MockSimulator(cirq.SimulatesSamples): def _run(self, circuit, param_resolver, repetitions): @@ -377,7 +377,7 @@ def _run(self, circuit, param_resolver, repetitions): q = cirq.LineQubit(0) f = MockSimulator().run_async(cirq.Circuit(cirq.measure(q)), repetitions=10) result = await f - np.testing.assert_equal(result.measurements, m) + np.testing.assert_equal(result.records, m) def test_simulation_trial_result_qubit_map(): diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 5806f5f69f9..7aa8dedabf0 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -38,6 +38,16 @@ class ClassicalDataStoreReader(abc.ABC): def keys(self) -> Tuple['cirq.MeasurementKey', ...]: """Gets the measurement keys in the order they were stored.""" + @property + @abc.abstractmethod + def records(self) -> Mapping['cirq.MeasurementKey', List[Tuple[int, ...]]]: + """Gets the a mapping from measurement key to measurement records.""" + + @property + @abc.abstractmethod + def channel_records(self) -> Mapping['cirq.MeasurementKey', List[int]]: + """Gets the a mapping from measurement key to channel measurement records.""" + @abc.abstractmethod def get_int(self, key: 'cirq.MeasurementKey', index=-1) -> int: """Gets the integer corresponding to the measurement. From 3b18f764bedccbeb1db9c2ef68b22a108c7db16a Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 14 Feb 2022 19:31:19 -0800 Subject: [PATCH 87/89] Fix sampling --- cirq-core/cirq/sim/simulator.py | 10 +++++++--- cirq-core/cirq/sim/simulator_base.py | 5 ++--- cirq-core/cirq/sim/simulator_test.py | 9 ++++++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 1fafab53b8b..4c5fd93a6b6 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -37,6 +37,7 @@ Generic, Iterator, List, + Mapping, Sequence, Set, Tuple, @@ -807,6 +808,8 @@ def sample_measurement_ops( measurement_ops: List['cirq.GateOperation'], repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + *, + _allow_repeated=False, ) -> Dict[str, np.ndarray]: """Samples from the system at this point in the computation. @@ -844,7 +847,7 @@ def sample_measurement_ops( result = collections.Counter( key for op in measurement_ops for key in protocols.measurement_key_names(op) ) - if result: + if result and not _allow_repeated: duplicates = [k for k, v in result.most_common() if v > 1] if duplicates: raise ValueError(f"Measurement key {','.join(duplicates)} repeated") @@ -873,8 +876,9 @@ def sample_measurement_ops( if inv_mask[i]: out[:, i] ^= out[:, i] < 2 results[gate.key] = out - - return results + if not _allow_repeated: + return results + return {k: [[x] * result[k] for x in v] for k, v in results.items()} @value.value_equality(unhashable=True) diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 5a474c20104..a343cc25b97 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -25,7 +25,6 @@ Iterator, Generic, List, - Mapping, Optional, Sequence, Tuple, @@ -266,9 +265,9 @@ def _run( assert step_result is not None measurement_ops = [cast(ops.GateOperation, op) for op in general_ops] samples = step_result.sample_measurement_ops( - measurement_ops, repetitions, seed=self._prng + measurement_ops, repetitions, seed=self._prng, _allow_repeated=True ) - return {k: np.array([[x] for x in v], dtype=np.uint8) for k, v in samples.items()} + return {k: np.array(v, dtype=np.uint8) for k, v in samples.items()} records: Dict['cirq.MeasurementKey', List[np.ndarray]] = {} for i in range(repetitions): diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index 13cf73b577f..276dc1e55e7 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -389,7 +389,7 @@ def test_simulation_trial_result_qubit_map(): assert result.qubit_map == {q[0]: 0, q[1]: 1} -def test_verify_unique_measurement_keys(): +def test_sample_repeated_measurement_keys(): q = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append( @@ -400,8 +400,11 @@ def test_verify_unique_measurement_keys(): cirq.measure(q[1], key='b'), ] ) - with pytest.raises(ValueError, match='Measurement key a,b repeated'): - _ = cirq.sample(circuit) + result = cirq.sample(circuit) + assert len(result.records['a']) == 1 + assert len(result.records['b']) == 1 + assert len(result.records['a'][0]) == 2 + assert len(result.records['b'][0]) == 2 def test_simulate_with_invert_mask(): From ff7bc6720b26d3c79525f2a33704fb644f7a158c Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 14 Feb 2022 19:38:31 -0800 Subject: [PATCH 88/89] numpy --- cirq-core/cirq/sim/simulator.py | 5 +++-- cirq-core/cirq/sim/simulator_base.py | 3 +-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 4c5fd93a6b6..e0e3b32b41f 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -37,7 +37,6 @@ Generic, Iterator, List, - Mapping, Sequence, Set, Tuple, @@ -826,6 +825,8 @@ def sample_measurement_ops( `MeasurementGate` instances to be sampled form. repetitions: The number of samples to take. seed: A seed for the pseudorandom number generator. + _allow_repeated: If True, adds extra dimension to the result, + corresponding to the number of times a key is repeated. Returns: A dictionary from measurement gate key to measurement results. Measurement results are stored in a 2-dimensional @@ -878,7 +879,7 @@ def sample_measurement_ops( results[gate.key] = out if not _allow_repeated: return results - return {k: [[x] * result[k] for x in v] for k, v in results.items()} + return {k: np.array([[x] * result[k] for x in v]) for k, v in results.items()} @value.value_equality(unhashable=True) diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index a343cc25b97..e9ec879cc7d 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -264,10 +264,9 @@ def _run( pass assert step_result is not None measurement_ops = [cast(ops.GateOperation, op) for op in general_ops] - samples = step_result.sample_measurement_ops( + return step_result.sample_measurement_ops( measurement_ops, repetitions, seed=self._prng, _allow_repeated=True ) - return {k: np.array(v, dtype=np.uint8) for k, v in samples.items()} records: Dict['cirq.MeasurementKey', List[np.ndarray]] = {} for i in range(repetitions): From 4dcc2ea8acd38deb48559e5d5bd38fbe6207a875 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 14 Feb 2022 19:57:08 -0800 Subject: [PATCH 89/89] Fix repetitions --- cirq-core/cirq/sim/simulator.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index e0e3b32b41f..62c30e6044b 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -866,20 +866,28 @@ def sample_measurement_ops( indexed_sample = self.sample(measured_qubits, repetitions, seed=seed) # Extract results for each measurement. - results: Dict[str, np.ndarray] = {} + results: Dict[str, Any] = {} qubits_to_index = {q: i for i, q in enumerate(measured_qubits)} for op in measurement_ops: gate = cast(ops.MeasurementGate, op.gate) + key = gate.key out = np.zeros(shape=(repetitions, len(op.qubits)), dtype=np.int8) inv_mask = gate.full_invert_mask() for i, q in enumerate(op.qubits): out[:, i] = indexed_sample[:, qubits_to_index[q]] if inv_mask[i]: out[:, i] ^= out[:, i] < 2 - results[gate.key] = out - if not _allow_repeated: - return results - return {k: np.array([[x] * result[k] for x in v]) for k, v in results.items()} + if _allow_repeated: + if key not in results: + results[key] = [] + results[key].append(out) + else: + results[gate.key] = out + return ( + results + if not _allow_repeated + else {k: np.array(v).swapaxes(0, 1) for k, v in results.items()} + ) @value.value_equality(unhashable=True)