diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 3415337a49f..7a7e85498eb 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1980,11 +1980,11 @@ def earliest_available_moment( while k > 0: k -= 1 moment = self._moments[k] - # This should also validate that measurement keys are disjoint once we allow repeated - # measurements. Search for same message in raw_types.py. + moment_measurement_keys = protocols.measurement_key_objs(moment) if ( moment.operates_on(op_qubits) - or not op_control_keys.isdisjoint(protocols.measurement_key_objs(moment)) + or not op_measurement_keys.isdisjoint(moment_measurement_keys) + or not op_control_keys.isdisjoint(moment_measurement_keys) or not protocols.control_keys(moment).isdisjoint(op_measurement_keys) ): return last_available diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index e4b8c294e90..7fbf782ef9b 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -834,7 +834,7 @@ def test_mapped_circuit_keeps_keys_under_parent_path(): assert cirq.measurement_key_names(op2.mapped_circuit()) == {'X:A', 'X:B', 'X:C', 'X:D'} -def test_keys_conflict_no_repetitions(): +def test_mapped_circuit_allows_repeated_keys(): q = cirq.LineQubit(0) op1 = cirq.CircuitOperation( cirq.FrozenCircuit( @@ -842,16 +842,20 @@ def test_keys_conflict_no_repetitions(): ) ) op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1)) - with pytest.raises(ValueError, match='Conflicting measurement keys found: A'): - _ = op2.mapped_circuit(deep=True) - - -def test_keys_conflict_locally(): - q = cirq.LineQubit(0) + circuit = op2.mapped_circuit(deep=True) + cirq.testing.assert_has_diagram( + circuit, + "0: ───M('A')───M('A')───", + use_unicode_characters=True, + ) op1 = cirq.measure(q, key='A') op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1)) - with pytest.raises(ValueError, match='Conflicting measurement keys found: A'): - _ = op2.mapped_circuit() + circuit = op2.mapped_circuit() + cirq.testing.assert_has_diagram( + circuit, + "0: ───M('A')───M('A')───", + use_unicode_characters=True, + ) # TODO: Operation has a "gate" property. What is this for a CircuitOperation? diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index e3cf0f7b1c0..0daf9f327e7 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -259,6 +259,45 @@ def test_key_set(sim): assert result.measurements['b'] == 1 +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_repeated_measurement_unset(sim): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.X(q0), + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index=-2)), + cirq.measure(q1, key='b'), + cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index=-1)), + cirq.measure(q1, key='c'), + ) + result = sim.run(circuit) + assert result.records['a'][0][0][0] == 0 + assert result.records['a'][0][1][0] == 1 + assert result.records['b'][0][0][0] == 0 + assert result.records['c'][0][0][0] == 1 + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_repeated_measurement_set(sim): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.X(q0), + cirq.measure(q0, key='a'), + cirq.X(q0), + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index=-2)), + cirq.measure(q1, key='b'), + cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index=-1)), + cirq.measure(q1, key='c'), + ) + result = sim.run(circuit) + assert result.records['a'][0][0][0] == 1 + assert result.records['a'][0][1][0] == 0 + assert result.records['b'][0][0][0] == 1 + assert result.records['c'][0][0][0] == 1 + + @pytest.mark.parametrize('sim', ALL_SIMULATORS) def test_subcircuit_key_unset(sim): q0, q1 = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index 2be81a296bb..2272b9a58b0 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import cast import numpy as np import pytest @@ -323,10 +324,13 @@ def test_act_on_state_vector(): dtype=np.complex64, ) cirq.act_on(m, args) + datastore = cast(cirq.ClassicalDataDictionaryStore, args.classical_data) + out = cirq.MeasurementKey('out') assert args.log_of_measurement_results == {'out': [0, 1]} - - with pytest.raises(ValueError, match="already logged to key"): - cirq.act_on(m, args) + assert datastore.records[out] == [(0, 1)] + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [0, 1]} + assert datastore.records[out] == [(0, 1), (0, 1)] def test_act_on_clifford_tableau(): @@ -361,10 +365,13 @@ def test_act_on_clifford_tableau(): log_of_measurement_results={}, ) cirq.act_on(m, args) + datastore = cast(cirq.ClassicalDataDictionaryStore, args.classical_data) + out = cirq.MeasurementKey('out') assert args.log_of_measurement_results == {'out': [0, 1]} - - with pytest.raises(ValueError, match="already logged to key"): - cirq.act_on(m, args) + assert datastore.records[out] == [(0, 1)] + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [0, 1]} + assert datastore.records[out] == [(0, 1), (0, 1)] def test_act_on_stabilizer_ch_form(): @@ -399,10 +406,13 @@ def test_act_on_stabilizer_ch_form(): initial_state=10, ) cirq.act_on(m, args) + datastore = cast(cirq.ClassicalDataDictionaryStore, args.classical_data) + out = cirq.MeasurementKey('out') assert args.log_of_measurement_results == {'out': [0, 1]} - - with pytest.raises(ValueError, match="already logged to key"): - cirq.act_on(m, args) + assert datastore.records[out] == [(0, 1)] + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [0, 1]} + assert datastore.records[out] == [(0, 1), (0, 1)] def test_act_on_qutrit(): diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index e5600fec6c6..ad2411a4b1e 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -572,11 +572,13 @@ def _commutes_( if not isinstance(other, Operation): return NotImplemented - # This should also validate that measurement keys are disjoint once we allow repeated - # measurements. Search for same message in circuit.py. - if not protocols.control_keys(self).isdisjoint( - protocols.measurement_key_objs(other) - ) or not protocols.control_keys(other).isdisjoint(protocols.measurement_key_objs(self)): + self_keys = protocols.measurement_key_objs(self) + other_keys = protocols.measurement_key_objs(other) + if ( + not self_keys.isdisjoint(other_keys) + or not protocols.control_keys(self).isdisjoint(other_keys) + or not protocols.control_keys(other).isdisjoint(self_keys) + ): return False if hasattr(other, 'qubits') and set(self.qubits).isdisjoint(other.qubits): diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json index d5c51d5839c..8d03c5d2059 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json @@ -1,13 +1,13 @@ { "cirq_type": "ClassicalDataDictionaryStore", - "measurements": [ + "records": [ [ { "cirq_type": "MeasurementKey", "name": "m", "path": [] }, - [0, 1] + [[0, 1]] ] ], "measured_qubits": [ @@ -17,7 +17,7 @@ "name": "m", "path": [] }, - [ + [[ { "cirq_type": "LineQubit", "x": 0 @@ -26,17 +26,17 @@ "cirq_type": "LineQubit", "x": 1 } - ] + ]] ] ], - "channel_measurements": [ + "channel_records": [ [ { "cirq_type": "MeasurementKey", "name": "c", "path": [] }, - 3 + [3] ] ], "measurement_types": [ diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr index c19b8190bfb..156350e2cc0 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr @@ -1 +1 @@ -cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file +cirq.ClassicalDataDictionaryStore(_records={cirq.MeasurementKey('m'): [[0, 1]]}, _measured_qubits={cirq.MeasurementKey('m'): [[cirq.LineQubit(0), cirq.LineQubit(1)]]}, _channel_records={cirq.MeasurementKey('c'): [3]}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json index 8fbae9b27c7..485e0bfa04f 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json @@ -7,7 +7,8 @@ "cirq_type": "MeasurementKey", "name": "a", "path": [] - } + }, + "index": -1 }, { "cirq_type": "KeyCondition", @@ -15,7 +16,8 @@ "cirq_type": "MeasurementKey", "name": "b", "path": [] - } + }, + "index": -1 } ], "sub_operation": { diff --git a/cirq-core/cirq/protocols/json_test_data/KeyCondition.json b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json index f5b81ba63dc..4478ba5c6c1 100644 --- a/cirq-core/cirq/protocols/json_test_data/KeyCondition.json +++ b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json @@ -4,5 +4,6 @@ "cirq_type": "MeasurementKey", "name": "a", "path": [] - } + }, + "index": -1 } \ No newline at end of file diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 8694067131c..bb656f8a3c7 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -76,8 +76,8 @@ def __init__( self._set_qubits(qubits) self.prng = prng self._classical_data = classical_data or value.ClassicalDataDictionaryStore( - _measurements={ - value.MeasurementKey.parse_serialized(k): tuple(v) + _records={ + value.MeasurementKey.parse_serialized(k): [tuple(v)] for k, v in (log_of_measurement_results or {}).items() } ) diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index b00960b2b49..f450ad9bbec 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -72,8 +72,8 @@ def __init__( self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states self._classical_data = classical_data or value.ClassicalDataDictionaryStore( - _measurements={ - value.MeasurementKey.parse_serialized(k): tuple(v) + _records={ + value.MeasurementKey.parse_serialized(k): [tuple(v)] for k, v in (log_of_measurement_results or {}).items() } ) diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py index d9051f21147..5cc45ef6297 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py @@ -13,7 +13,7 @@ # limitations under the License. import abc -from typing import Any, Dict, Generic, Optional, Sequence, TYPE_CHECKING, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union import numpy as np @@ -38,7 +38,7 @@ def __init__( self, state: TStabilizerState, prng: Optional[np.random.RandomState] = None, - log_of_measurement_results: Optional[Dict[str, Any]] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, qubits: Optional[Sequence['cirq.Qid']] = None, classical_data: Optional['cirq.ClassicalDataStore'] = None, ): diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 6775fafa2ce..44689e4cc48 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING, Union +from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Union import numpy as np -from cirq import _compat, value, ops, protocols +from cirq import _compat, value from cirq.sim.clifford import stabilizer_state_ch_form from cirq.sim.clifford.act_on_stabilizer_args import ActOnStabilizerArgs @@ -39,7 +39,7 @@ def __init__( self, state: Optional['cirq.StabilizerStateChForm'] = None, prng: Optional[np.random.RandomState] = None, - log_of_measurement_results: Optional[Dict[str, Any]] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, qubits: Optional[Sequence['cirq.Qid']] = None, initial_state: Union[int, 'cirq.StabilizerStateChForm'] = 0, classical_data: Optional['cirq.ClassicalDataStore'] = None, @@ -107,16 +107,10 @@ def sample( repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: - measurements = value.ClassicalDataDictionaryStore() prng = value.parse_random_state(seed) - for i in range(repetitions): - op = ops.measure(*qubits, key=str(i)) + axes = self.get_axes(qubits) + measurements = [] + for _ in range(repetitions): state = self.state.copy() - ch_form_args = ActOnStabilizerCHFormArgs( - classical_data=measurements, - prng=prng, - qubits=self.qubits, - initial_state=state, - ) - protocols.act_on(op, ch_form_args) - return np.array(list(measurements.measurements.values()), dtype=bool) + measurements.append([state._measure(i, prng) for i in axes]) + return np.array(measurements, dtype=bool) diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index e04559272d7..3f19c253d6f 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -293,4 +293,4 @@ def apply_measurement( initial_state=state.ch_form, ) act_on(op, ch_form_args) - measurements.update({str(k): list(v) for k, v in classical_data.measurements.items()}) + measurements.update({str(k): list(v[-1]) for k, v in classical_data.records.items()}) diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py index 6f1dc38b4bd..30ef642a745 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py @@ -73,6 +73,6 @@ def test_run(): initial_state=state, ) cirq.act_on(op, args) - measurements = {str(k): list(v) for k, v in classical_data.measurements.items()} + measurements = {str(k): list(v[-1]) for k, v in classical_data.records.items()} assert measurements['1'] == [1] assert measurements['0'] != measurements['2'] diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index e54916303cd..c0381f7e9f9 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -14,7 +14,6 @@ """An interface for quantum states as targets for operations.""" import abc from typing import ( - Any, Dict, Generic, Iterator, @@ -86,7 +85,7 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: """Gets the qubit order maintained by this target.""" @property - def log_of_measurement_results(self) -> Dict[str, Any]: + def log_of_measurement_results(self) -> Dict[str, List[int]]: """Gets the log of measurement results.""" return {str(k): list(self.classical_data.get_digits(k)) for k in self.classical_data.keys()} diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 4afc35d7c89..62c30e6044b 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -37,7 +37,6 @@ Generic, Iterator, List, - Optional, Sequence, Set, Tuple, @@ -102,18 +101,16 @@ def run_sweep_iter( if not program.has_measurements(): raise ValueError("Circuit has no measurements to sample.") - _verify_unique_measurement_keys(program) - for param_resolver in study.to_resolvers(params): - measurements = {} + records = {} if repetitions == 0: for _, op, _ in program.findall_operations_with_gate_type(ops.MeasurementGate): - measurements[protocols.measurement_key_name(op)] = np.empty([0, 1]) + records[protocols.measurement_key_name(op)] = np.empty([0, 1, 1]) else: - measurements = self._run( + records = self._run( circuit=program, param_resolver=param_resolver, repetitions=repetitions ) - yield study.ResultDict(params=param_resolver, measurements=measurements) + yield study.ResultDict(params=param_resolver, records=records) @abc.abstractmethod def _run( @@ -132,10 +129,11 @@ def _run( Returns: A dictionary from measurement gate key to measurement - results. Measurement results are stored in a 2-dimensional - numpy array, the first dimension corresponding to the repetition - and the second to the actual boolean measurement results (ordered - by the qubits being measured.) + results. Measurement results are stored in a 3-dimensional + numpy array, the first dimension corresponding to the repetition. + the second to the instance of that key in the circuit, and the + third to the actual boolean measurement results (ordered by the + qubits being measured.) """ raise NotImplementedError() @@ -763,8 +761,9 @@ class StepResult(Generic[TSimulatorState], metaclass=abc.ABCMeta): results, ordered by the qubits that the measurement operates on. """ - def __init__(self, measurements: Optional[Dict[str, List[int]]] = None) -> None: - self.measurements = measurements or collections.defaultdict(list) + def __init__(self, sim_state: 'cirq.OperationTarget') -> None: + self.measurements = sim_state.log_of_measurement_results + self._classical_data = sim_state.classical_data @abc.abstractmethod def _simulator_state(self) -> TSimulatorState: @@ -808,6 +807,8 @@ def sample_measurement_ops( measurement_ops: List['cirq.GateOperation'], repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + *, + _allow_repeated=False, ) -> Dict[str, np.ndarray]: """Samples from the system at this point in the computation. @@ -824,6 +825,8 @@ def sample_measurement_ops( `MeasurementGate` instances to be sampled form. repetitions: The number of samples to take. seed: A seed for the pseudorandom number generator. + _allow_repeated: If True, adds extra dimension to the result, + corresponding to the number of times a key is repeated. Returns: A dictionary from measurement gate key to measurement results. Measurement results are stored in a 2-dimensional @@ -838,15 +841,17 @@ def sample_measurement_ops( """ # Sanity checks. - seen_measurement_keys: Set[str] = set() for op in measurement_ops: gate = op.gate if not isinstance(gate, ops.MeasurementGate): raise ValueError(f'{op.gate} was not a MeasurementGate') - key = protocols.measurement_key_name(gate) - if key in seen_measurement_keys: - raise ValueError(f'Duplicate MeasurementGate with key {key}') - seen_measurement_keys.add(key) + result = collections.Counter( + key for op in measurement_ops for key in protocols.measurement_key_names(op) + ) + if result and not _allow_repeated: + duplicates = [k for k, v in result.most_common() if v > 1] + if duplicates: + raise ValueError(f"Measurement key {','.join(duplicates)} repeated") # Find measured qubits, ensuring a consistent ordering. measured_qubits = [] @@ -861,19 +866,28 @@ def sample_measurement_ops( indexed_sample = self.sample(measured_qubits, repetitions, seed=seed) # Extract results for each measurement. - results: Dict[str, np.ndarray] = {} + results: Dict[str, Any] = {} qubits_to_index = {q: i for i, q in enumerate(measured_qubits)} for op in measurement_ops: gate = cast(ops.MeasurementGate, op.gate) + key = gate.key out = np.zeros(shape=(repetitions, len(op.qubits)), dtype=np.int8) inv_mask = gate.full_invert_mask() for i, q in enumerate(op.qubits): out[:, i] = indexed_sample[:, qubits_to_index[q]] if inv_mask[i]: out[:, i] ^= out[:, i] < 2 - results[gate.key] = out - - return results + if _allow_repeated: + if key not in results: + results[key] = [] + results[key].append(out) + else: + results[gate.key] = out + return ( + results + if not _allow_repeated + else {k: np.array(v).swapaxes(0, 1) for k, v in results.items()} + ) @value.value_equality(unhashable=True) @@ -986,18 +1000,6 @@ def _qubit_map_to_shape(qubit_map: Dict['cirq.Qid', int]) -> Tuple[int, ...]: return tuple(qid_shape) -def _verify_unique_measurement_keys(circuit: 'cirq.AbstractCircuit'): - result = collections.Counter( - key - for op in ops.flatten_op_tree(iter(circuit)) - for key in protocols.measurement_key_names(op) - ) - if result: - duplicates = [k for k, v in result.most_common() if v > 1] - if duplicates: - raise ValueError(f"Measurement key {','.join(duplicates)} repeated") - - def check_all_resolved(circuit): """Raises if the circuit contains unresolved symbols.""" if protocols.is_parameterized(circuit): diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 610e369611c..e9ec879cc7d 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -20,17 +20,17 @@ import warnings from typing import ( Any, + cast, Dict, Iterator, + Generic, List, + Optional, + Sequence, Tuple, - TYPE_CHECKING, - cast, - Generic, Type, - Sequence, - Optional, TypeVar, + TYPE_CHECKING, ) import numpy as np @@ -264,9 +264,11 @@ def _run( pass assert step_result is not None measurement_ops = [cast(ops.GateOperation, op) for op in general_ops] - return step_result.sample_measurement_ops(measurement_ops, repetitions, seed=self._prng) + return step_result.sample_measurement_ops( + measurement_ops, repetitions, seed=self._prng, _allow_repeated=True + ) - measurements: Dict[str, List[np.ndarray]] = {} + records: Dict['cirq.MeasurementKey', List[np.ndarray]] = {} for i in range(repetitions): if 'deep_copy_buffers' in inspect.signature(act_on_args.copy).parameters: all_step_results = self._core_iterator( @@ -289,11 +291,15 @@ def _run( ) for step_result in all_step_results: pass - for k, v in step_result.measurements.items(): - if k not in measurements: - measurements[k] = [] - measurements[k].append(np.array(v, dtype=np.uint8)) - return {k: np.array(v) for k, v in measurements.items()} + for k, r in step_result._classical_data.records.items(): + if k not in records: + records[k] = [] + records[k].append(r) + for k, cr in step_result._classical_data.channel_records.items(): + if k not in records: + records[k] = [] + records[k].append([cr]) + return {str(k): np.array(v, dtype=np.uint8) for k, v in records.items()} def simulate_sweep_iter( self, @@ -397,7 +403,7 @@ def __init__( """ self._sim_state = sim_state self._merged_sim_state_cache: Optional[TActOnArgs] = None - super().__init__(sim_state.log_of_measurement_results) + super().__init__(sim_state) qubits = sim_state.qubits self._qubits = qubits self._qubit_mapping = {q: i for i, q in enumerate(qubits)} diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index eeedb83f2b2..276dc1e55e7 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -65,13 +65,13 @@ def _create_simulator_trial_result( @mock.patch.multiple(cirq.SimulatesSamples, __abstractmethods__=set(), _run=mock.Mock()) def test_run_simulator_run(): simulator = cirq.SimulatesSamples() - expected_measurements = {'a': np.array([[1]])} + expected_measurements = {'a': np.array([[[1]]])} simulator._run.return_value = expected_measurements circuit = mock.Mock(cirq.Circuit) circuit.__iter__ = mock.Mock(return_value=iter([])) param_resolver = mock.Mock(cirq.ParamResolver) param_resolver.param_dict = {} - expected_result = cirq.ResultDict(measurements=expected_measurements, params=param_resolver) + expected_result = cirq.ResultDict(records=expected_measurements, params=param_resolver) assert expected_result == simulator.run( program=circuit, repetitions=10, param_resolver=param_resolver ) @@ -83,7 +83,7 @@ def test_run_simulator_run(): @mock.patch.multiple(cirq.SimulatesSamples, __abstractmethods__=set(), _run=mock.Mock()) def test_run_simulator_sweeps(): simulator = cirq.SimulatesSamples() - expected_measurements = {'a': np.array([[1]])} + expected_measurements = {'a': np.array([[[1]]])} simulator._run.return_value = expected_measurements circuit = mock.Mock(cirq.Circuit) circuit.__iter__ = mock.Mock(return_value=iter([])) @@ -91,8 +91,8 @@ def test_run_simulator_sweeps(): for resolver in param_resolvers: resolver.param_dict = {} expected_results = [ - cirq.ResultDict(measurements=expected_measurements, params=param_resolvers[0]), - cirq.ResultDict(measurements=expected_measurements, params=param_resolvers[1]), + cirq.ResultDict(records=expected_measurements, params=param_resolvers[0]), + cirq.ResultDict(records=expected_measurements, params=param_resolvers[1]), ] assert expected_results == simulator.run_sweep( program=circuit, repetitions=10, params=param_resolvers @@ -238,7 +238,7 @@ def test_step_sample_measurement_ops_not_measurement(): def test_step_sample_measurement_ops_repeated_qubit(): q0, q1, q2 = cirq.LineQubit.range(3) step_result = FakeStepResult([q0]) - with pytest.raises(ValueError, match='MeasurementGate'): + with pytest.raises(ValueError, match='Measurement key 0 repeated'): step_result.sample_measurement_ops( [cirq.measure(q0), cirq.measure(q1, q2), cirq.measure(q0)] ) @@ -368,7 +368,7 @@ def text(self, to_print): @duet.sync async def test_async_sample(): - m = {'mock': np.array([[0], [1]])} + m = {'mock': np.array([[[0]], [[1]]])} class MockSimulator(cirq.SimulatesSamples): def _run(self, circuit, param_resolver, repetitions): @@ -377,7 +377,7 @@ def _run(self, circuit, param_resolver, repetitions): q = cirq.LineQubit(0) f = MockSimulator().run_async(cirq.Circuit(cirq.measure(q)), repetitions=10) result = await f - np.testing.assert_equal(result.measurements, m) + np.testing.assert_equal(result.records, m) def test_simulation_trial_result_qubit_map(): @@ -389,7 +389,7 @@ def test_simulation_trial_result_qubit_map(): assert result.qubit_map == {q[0]: 0, q[1]: 1} -def test_verify_unique_measurement_keys(): +def test_sample_repeated_measurement_keys(): q = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append( @@ -400,8 +400,11 @@ def test_verify_unique_measurement_keys(): cirq.measure(q[1], key='b'), ] ) - with pytest.raises(ValueError, match='Measurement key a,b repeated'): - _ = cirq.sample(circuit) + result = cirq.sample(circuit) + assert len(result.records['a']) == 1 + assert len(result.records['b']) == 1 + assert len(result.records['a'][0]) == 2 + assert len(result.records['b'][0]) == 2 def test_simulate_with_invert_mask(): diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py index 5596be02efd..7aa8dedabf0 100644 --- a/cirq-core/cirq/value/classical_data.py +++ b/cirq-core/cirq/value/classical_data.py @@ -14,7 +14,7 @@ import abc import enum -from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, TypeVar +from typing import Dict, List, Mapping, Sequence, Tuple, TYPE_CHECKING, TypeVar from cirq.value import digits, value_equality_attr @@ -38,8 +38,18 @@ class ClassicalDataStoreReader(abc.ABC): def keys(self) -> Tuple['cirq.MeasurementKey', ...]: """Gets the measurement keys in the order they were stored.""" + @property @abc.abstractmethod - def get_int(self, key: 'cirq.MeasurementKey') -> int: + def records(self) -> Mapping['cirq.MeasurementKey', List[Tuple[int, ...]]]: + """Gets the a mapping from measurement key to measurement records.""" + + @property + @abc.abstractmethod + def channel_records(self) -> Mapping['cirq.MeasurementKey', List[int]]: + """Gets the a mapping from measurement key to channel measurement records.""" + + @abc.abstractmethod + def get_int(self, key: 'cirq.MeasurementKey', index=-1) -> int: """Gets the integer corresponding to the measurement. The integer is determined by summing the qubit-dimensional basis value @@ -50,13 +60,18 @@ def get_int(self, key: 'cirq.MeasurementKey') -> int: Args: key: The measurement key. + index: If multiple measurements have the same key, the index + argument can be used to specify which measurement to retrieve. + Here `0` refers to the first measurement, and `-1` refers to + the most recent. Raises: - KeyError: If the key has not been used. + KeyError: If the key has not been used or if the index is out of + bounds. """ @abc.abstractmethod - def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + def get_digits(self, key: 'cirq.MeasurementKey', index=-1) -> Tuple[int, ...]: """Gets the values of the qubits that were measured into this key. For example, if the measurement of qubits [q0, q1] produces [0, 1], @@ -64,9 +79,14 @@ def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: Args: key: The measurement key. + index: If multiple measurements have the same key, the index + argument can be used to specify which measurement to retrieve. + Here `0` refers to the first measurement, and `-1` refers to + the most recent. Raises: - KeyError: If the key has not been used. + KeyError: If the key has not been used or if the index is out of + bounds. """ @abc.abstractmethod @@ -111,49 +131,49 @@ class ClassicalDataDictionaryStore(ClassicalDataStore): def __init__( self, *, - _measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, - _measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = None, - _channel_measurements: Dict['cirq.MeasurementKey', int] = None, + _records: Dict['cirq.MeasurementKey', List[Tuple[int, ...]]] = None, + _measured_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]] = None, + _channel_records: Dict['cirq.MeasurementKey', List[int]] = None, _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = None, ): """Initializes a `ClassicalDataDictionaryStore` object.""" if not _measurement_types: _measurement_types = {} - if _measurements: + if _records: _measurement_types.update( - {k: MeasurementType.MEASUREMENT for k, v in _measurements.items()} + {k: MeasurementType.MEASUREMENT for k, v in _records.items()} ) - if _channel_measurements: + if _channel_records: _measurement_types.update( - {k: MeasurementType.CHANNEL for k, v in _channel_measurements.items()} + {k: MeasurementType.CHANNEL for k, v in _channel_records.items()} ) - if _measurements is None: - _measurements = {} + if _records is None: + _records = {} if _measured_qubits is None: _measured_qubits = {} - if _channel_measurements is None: - _channel_measurements = {} - self._measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = _measurements + if _channel_records is None: + _channel_records = {} + self._records: Dict['cirq.MeasurementKey', List[Tuple[int, ...]]] = _records self._measured_qubits: Dict[ - 'cirq.MeasurementKey', Tuple['cirq.Qid', ...] + 'cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]] ] = _measured_qubits - self._channel_measurements: Dict['cirq.MeasurementKey', int] = _channel_measurements + self._channel_records: Dict['cirq.MeasurementKey', List[int]] = _channel_records self._measurement_types: Dict[ 'cirq.MeasurementKey', 'cirq.MeasurementType' ] = _measurement_types @property - def measurements(self) -> Mapping['cirq.MeasurementKey', Tuple[int, ...]]: - """Gets the a mapping from measurement key to measurement.""" - return self._measurements + def records(self) -> Mapping['cirq.MeasurementKey', List[Tuple[int, ...]]]: + """Gets the a mapping from measurement key to measurement records.""" + return self._records @property - def channel_measurements(self) -> Mapping['cirq.MeasurementKey', int]: - """Gets the a mapping from measurement key to channel measurement.""" - return self._channel_measurements + def channel_records(self) -> Mapping['cirq.MeasurementKey', List[int]]: + """Gets the a mapping from measurement key to channel measurement records.""" + return self._channel_records @property - def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', ...]]: + def measured_qubits(self) -> Mapping['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]]: """Gets the a mapping from measurement key to the qubits measured.""" return self._measured_qubits @@ -170,76 +190,88 @@ def record_measurement( ): if len(measurement) != len(qubits): raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') - if key in self._measurement_types: - raise ValueError(f"Measurement already logged to key {key}") - self._measurement_types[key] = MeasurementType.MEASUREMENT - self._measurements[key] = tuple(measurement) - self._measured_qubits[key] = tuple(qubits) + if key not in self._measurement_types: + self._measurement_types[key] = MeasurementType.MEASUREMENT + self._records[key] = [] + self._measured_qubits[key] = [] + if self._measurement_types[key] != MeasurementType.MEASUREMENT: + raise ValueError(f"Channel Measurement already logged to key {key}") + measured_qubits = self._measured_qubits[key] + if measured_qubits: + shape = tuple(q.dimension for q in qubits) + key_shape = tuple(q.dimension for q in measured_qubits[-1]) + if shape != key_shape: + raise ValueError(f'Measurement shape {shape} does not match {key_shape} in {key}.') + measured_qubits.append(tuple(qubits)) + self._records[key].append(tuple(measurement)) def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): - if key in self._measurement_types: + if key not in self._measurement_types: + self._measurement_types[key] = MeasurementType.CHANNEL + self._channel_records[key] = [] + if self._measurement_types[key] != MeasurementType.CHANNEL: raise ValueError(f"Measurement already logged to key {key}") - self._measurement_types[key] = MeasurementType.CHANNEL - self._channel_measurements[key] = measurement + self._channel_records[key].append(measurement) - def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + def get_digits(self, key: 'cirq.MeasurementKey', index=-1) -> Tuple[int, ...]: return ( - self._measurements[key] + self._records[key][index] if self._measurement_types[key] == MeasurementType.MEASUREMENT - else (self._channel_measurements[key],) + else (self._channel_records[key][index],) ) - def get_int(self, key: 'cirq.MeasurementKey') -> int: + def get_int(self, key: 'cirq.MeasurementKey', index=-1) -> int: if key not in self._measurement_types: - raise KeyError(f'The measurement key {key} is not in {self._measurements}') + raise KeyError(f'The measurement key {key} is not in {self._measurement_types}') measurement_type = self._measurement_types[key] if measurement_type == MeasurementType.CHANNEL: - return self._channel_measurements[key] + return self._channel_records[key][index] if key not in self._measured_qubits: - return digits.big_endian_bits_to_int(self._measurements[key]) + return digits.big_endian_bits_to_int(self._records[key][index]) return digits.big_endian_digits_to_int( - self._measurements[key], base=[q.dimension for q in self._measured_qubits[key]] + self._records[key][index], + base=[q.dimension for q in self._measured_qubits[key][index]], ) def copy(self): return ClassicalDataDictionaryStore( - _measurements=self._measurements.copy(), + _records=self._records.copy(), _measured_qubits=self._measured_qubits.copy(), - _channel_measurements=self._channel_measurements.copy(), + _channel_records=self._channel_records.copy(), _measurement_types=self._measurement_types.copy(), ) def _json_dict_(self): return { - 'measurements': list(self.measurements.items()), + 'records': list(self.records.items()), 'measured_qubits': list(self.measured_qubits.items()), - 'channel_measurements': list(self.channel_measurements.items()), + 'channel_records': list(self.channel_records.items()), 'measurement_types': list(self.measurement_types.items()), } @classmethod def _from_json_dict_( - cls, measurements, measured_qubits, channel_measurements, measurement_types, **kwargs + cls, records, measured_qubits, channel_records, measurement_types, **kwargs ): return cls( - _measurements=dict(measurements), + _records=dict(records), _measured_qubits=dict(measured_qubits), - _channel_measurements=dict(channel_measurements), + _channel_records=dict(channel_records), _measurement_types=dict(measurement_types), ) def __repr__(self): return ( - f'cirq.ClassicalDataDictionaryStore(_measurements={self.measurements!r},' + f'cirq.ClassicalDataDictionaryStore(_records={self.records!r},' f' _measured_qubits={self.measured_qubits!r},' - f' _channel_measurements={self.channel_measurements!r},' + f' _channel_records={self.channel_records!r},' f' _measurement_types={self.measurement_types!r})' ) def _value_equality_values_(self): return ( - self._measurements, - self._channel_measurements, + self._records, + self._channel_records, self._measurement_types, self._measured_qubits, ) diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py index 00cfe475d0e..9b4fa156bec 100644 --- a/cirq-core/cirq/value/classical_data_test.py +++ b/cirq-core/cirq/value/classical_data_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re import pytest @@ -23,20 +24,20 @@ def test_init(): cd = cirq.ClassicalDataDictionaryStore() - assert cd.measurements == {} + assert cd.records == {} assert cd.keys() == () assert cd.measured_qubits == {} - assert cd.channel_measurements == {} + assert cd.channel_records == {} assert cd.measurement_types == {} cd = cirq.ClassicalDataDictionaryStore( - _measurements={mkey_m: (0, 1)}, - _measured_qubits={mkey_m: two_qubits}, - _channel_measurements={mkey_c: 3}, + _records={mkey_m: [(0, 1)]}, + _measured_qubits={mkey_m: [two_qubits]}, + _channel_records={mkey_c: [3]}, ) - assert cd.measurements == {mkey_m: (0, 1)} + assert cd.records == {mkey_m: [(0, 1)]} assert cd.keys() == (mkey_m, mkey_c) - assert cd.measured_qubits == {mkey_m: two_qubits} - assert cd.channel_measurements == {mkey_c: 3} + assert cd.measured_qubits == {mkey_m: [two_qubits]} + assert cd.channel_records == {mkey_c: [3]} assert cd.measurement_types == { mkey_m: cirq.MeasurementType.MEASUREMENT, mkey_c: cirq.MeasurementType.CHANNEL, @@ -46,9 +47,9 @@ def test_init(): def test_record_measurement(): cd = cirq.ClassicalDataDictionaryStore() cd.record_measurement(mkey_m, (0, 1), two_qubits) - assert cd.measurements == {mkey_m: (0, 1)} + assert cd.records == {mkey_m: [(0, 1)]} assert cd.keys() == (mkey_m,) - assert cd.measured_qubits == {mkey_m: two_qubits} + assert cd.measured_qubits == {mkey_m: [two_qubits]} def test_record_measurement_errors(): @@ -56,30 +57,35 @@ def test_record_measurement_errors(): with pytest.raises(ValueError, match='3 measurements but 2 qubits'): cd.record_measurement(mkey_m, (0, 1, 2), two_qubits) cd.record_measurement(mkey_m, (0, 1), two_qubits) - with pytest.raises(ValueError, match='Measurement already logged to key m'): - cd.record_measurement(mkey_m, (0, 1), two_qubits) + cd.record_measurement(mkey_m, (1, 0), two_qubits) + with pytest.raises( + ValueError, match=re.escape('Measurement shape (2, 2, 2) does not match (2, 2) in m') + ): + cd.record_measurement(mkey_m, (1, 0, 4), tuple(cirq.LineQubit.range(3))) + with pytest.raises( + ValueError, match=re.escape('Measurement shape (3, 3) does not match (2, 2) in m') + ): + cd.record_measurement(mkey_m, (1, 0), tuple(cirq.LineQid.range(2, dimension=3))) def test_record_channel_measurement(): cd = cirq.ClassicalDataDictionaryStore() cd.record_channel_measurement(mkey_m, 1) - assert cd.channel_measurements == {mkey_m: 1} + assert cd.channel_records == {mkey_m: [1]} assert cd.keys() == (mkey_m,) def test_record_channel_measurement_errors(): cd = cirq.ClassicalDataDictionaryStore() cd.record_channel_measurement(mkey_m, 1) - with pytest.raises(ValueError, match='Measurement already logged to key m'): - cd.record_channel_measurement(mkey_m, 1) - with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Channel Measurement already logged to key m'): cd.record_measurement(mkey_m, (0, 1), two_qubits) cd = cirq.ClassicalDataDictionaryStore() cd.record_measurement(mkey_m, (0, 1), two_qubits) + cd.record_measurement(mkey_m, (0, 1), two_qubits) with pytest.raises(ValueError, match='Measurement already logged to key m'): cd.record_channel_measurement(mkey_m, 1) - with pytest.raises(ValueError, match='Measurement already logged to key m'): - cd.record_measurement(mkey_m, (0, 1), two_qubits) def test_get_int(): @@ -102,9 +108,9 @@ def test_get_int(): def test_copy(): cd = cirq.ClassicalDataDictionaryStore( - _measurements={mkey_m: (0, 1)}, - _measured_qubits={mkey_m: two_qubits}, - _channel_measurements={mkey_c: 3}, + _records={mkey_m: [(0, 1)]}, + _measured_qubits={mkey_m: [two_qubits]}, + _channel_records={mkey_c: [3]}, _measurement_types={ mkey_m: cirq.MeasurementType.MEASUREMENT, mkey_c: cirq.MeasurementType.CHANNEL, @@ -113,21 +119,21 @@ def test_copy(): cd1 = cd.copy() assert cd1 is not cd assert cd1 == cd - assert cd1.measurements is not cd.measurements - assert cd1.measurements == cd.measurements + assert cd1.records is not cd.records + assert cd1.records == cd.records assert cd1.measured_qubits is not cd.measured_qubits assert cd1.measured_qubits == cd.measured_qubits - assert cd1.channel_measurements is not cd.channel_measurements - assert cd1.channel_measurements == cd.channel_measurements + assert cd1.channel_records is not cd.channel_records + assert cd1.channel_records == cd.channel_records assert cd1.measurement_types is not cd.measurement_types assert cd1.measurement_types == cd.measurement_types def test_repr(): cd = cirq.ClassicalDataDictionaryStore( - _measurements={mkey_m: (0, 1)}, - _measured_qubits={mkey_m: two_qubits}, - _channel_measurements={mkey_c: 3}, + _records={mkey_m: [(0, 1)]}, + _measured_qubits={mkey_m: [two_qubits]}, + _channel_records={mkey_c: [3]}, _measurement_types={ mkey_m: cirq.MeasurementType.MEASUREMENT, mkey_c: cirq.MeasurementType.CHANNEL, diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 7c594eb2d95..081c2d04e71 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -87,6 +87,7 @@ class KeyCondition(Condition): """ key: 'cirq.MeasurementKey' + index: int = -1 @property def keys(self): @@ -96,9 +97,11 @@ def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.Measure return KeyCondition(replacement) if self.key == current else self def __str__(self): - return str(self.key) + return str(self.key) if self.index == -1 else f'{self.key}[{self.index}]' def __repr__(self): + if self.index != -1: + return f'cirq.KeyCondition({self.key!r}, {self.index})' return f'cirq.KeyCondition({self.key!r})' def resolve( @@ -107,7 +110,7 @@ def resolve( ) -> bool: if self.key not in classical_data.keys(): raise ValueError(f'Measurement key {self.key} missing when testing classical control') - return classical_data.get_int(self.key) != 0 + return classical_data.get_int(self.key, self.index) != 0 def _json_dict_(self): return json_serialization.dataclass_json_dict(self) @@ -118,6 +121,8 @@ def _from_json_dict_(cls, key, **kwargs): @property def qasm(self): + if self.index != -1: + raise NotImplementedError('Only most recent measurement at key can be used for QASM.') return f'm_{self.key}!=0' diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index e92029b1bfb..29148853994 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -35,25 +35,27 @@ def test_key_condition_with_keys(): def test_key_condition_str(): assert str(init_key_condition) == '0:a' + assert str(cirq.KeyCondition(key_a, index=-2)) == '0:a[-2]' def test_key_condition_repr(): cirq.testing.assert_equivalent_repr(init_key_condition) + cirq.testing.assert_equivalent_repr(cirq.KeyCondition(key_a, index=-2)) def test_key_condition_resolve(): - def resolve(measurements): - classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) + def resolve(records): + classical_data = cirq.ClassicalDataDictionaryStore(_records=records) return init_key_condition.resolve(classical_data) - assert resolve({'0:a': [1]}) - assert resolve({'0:a': [2]}) - assert resolve({'0:a': [0, 1]}) - assert resolve({'0:a': [1, 0]}) - assert not resolve({'0:a': [0]}) - assert not resolve({'0:a': [0, 0]}) - assert not resolve({'0:a': []}) - assert not resolve({'0:a': [0], 'b': [1]}) + assert resolve({'0:a': [[1]]}) + assert resolve({'0:a': [[2]]}) + assert resolve({'0:a': [[0, 1]]}) + assert resolve({'0:a': [[1, 0]]}) + assert not resolve({'0:a': [[0]]}) + assert not resolve({'0:a': [[0, 0]]}) + assert not resolve({'0:a': [[]]}) + assert not resolve({'0:a': [[0]], 'b': [[1]]}) with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): @@ -61,7 +63,7 @@ def resolve(measurements): with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): - _ = resolve({'0:b': [1]}) + _ = resolve({'0:b': [[1]]}) def test_key_condition_qasm(): @@ -84,18 +86,18 @@ def test_sympy_condition_repr(): def test_sympy_condition_resolve(): - def resolve(measurements): - classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) + def resolve(records): + classical_data = cirq.ClassicalDataDictionaryStore(_records=records) return init_sympy_condition.resolve(classical_data) - assert resolve({'0:a': [1]}) - assert resolve({'0:a': [2]}) - assert resolve({'0:a': [0, 1]}) - assert resolve({'0:a': [1, 0]}) - assert not resolve({'0:a': [0]}) - assert not resolve({'0:a': [0, 0]}) - assert not resolve({'0:a': []}) - assert not resolve({'0:a': [0], 'b': [1]}) + assert resolve({'0:a': [[1]]}) + assert resolve({'0:a': [[2]]}) + assert resolve({'0:a': [[0, 1]]}) + assert resolve({'0:a': [[1, 0]]}) + assert not resolve({'0:a': [[0]]}) + assert not resolve({'0:a': [[0, 0]]}) + assert not resolve({'0:a': [[]]}) + assert not resolve({'0:a': [[0]], 'b': [[1]]}) with pytest.raises( ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), @@ -105,7 +107,7 @@ def resolve(measurements): ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): - _ = resolve({'0:b': [1]}) + _ = resolve({'0:b': [[1]]}) def test_sympy_condition_qasm(): diff --git a/cirq-core/cirq/value/measurement_key.py b/cirq-core/cirq/value/measurement_key.py index e53eac47fde..e888325db81 100644 --- a/cirq-core/cirq/value/measurement_key.py +++ b/cirq-core/cirq/value/measurement_key.py @@ -130,10 +130,7 @@ def _with_rescoped_keys_( path: Tuple[str, ...], bindable_keys: FrozenSet['MeasurementKey'], ): - new_key = self.replace(path=path + self.path) - if new_key in bindable_keys: - raise ValueError(f'Conflicting measurement keys found: {new_key}') - return new_key + return self.replace(path=path + self.path) def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): if self.name not in key_map: