diff --git a/cirq-core/cirq/ops/pauli_measurement_gate.py b/cirq-core/cirq/ops/pauli_measurement_gate.py index 8f9a7358f75..2b2667a00a3 100644 --- a/cirq-core/cirq/ops/pauli_measurement_gate.py +++ b/cirq-core/cirq/ops/pauli_measurement_gate.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union +from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union, cast from cirq import protocols, value from cirq.ops import ( raw_types, measurement_gate, op_tree, - dense_pauli_string, + dense_pauli_string as dps, pauli_gates, pauli_string_phasor, ) @@ -38,15 +38,16 @@ class PauliMeasurementGate(raw_types.Gate): def __init__( self, - observable: Iterable['cirq.Pauli'], + observable: Union['cirq.BaseDensePauliString', Iterable['cirq.Pauli']], key: Union[str, 'cirq.MeasurementKey'] = '', ) -> None: """Inits PauliMeasurementGate. Args: observable: Pauli observable to measure. Any `Iterable[cirq.Pauli]` - is a valid Pauli observable, including `cirq.DensePauliString` - instances, which do not contain any identity gates. + is a valid Pauli observable (with a +1 coefficient by default). + If you wish to measure pauli observables with coefficient -1, + then pass a `cirq.DensePauliString` as observable. key: The string key of the measurement. Raises: @@ -54,9 +55,19 @@ def __init__( """ if not observable: raise ValueError(f'Pauli observable {observable} is empty.') - if not all(isinstance(p, pauli_gates.Pauli) for p in observable): + if not all( + isinstance(p, pauli_gates.Pauli) for p in cast(Iterable['cirq.Gate'], observable) + ): raise ValueError(f'Pauli observable {observable} must be Iterable[`cirq.Pauli`].') - self._observable = tuple(observable) + coefficient = ( + observable.coefficient if isinstance(observable, dps.BaseDensePauliString) else 1 + ) + if coefficient not in [+1, -1]: + raise ValueError( + f'`cirq.DensePauliString` observable {observable} must have coefficient +1/-1.' + ) + + self._observable = dps.DensePauliString(observable, coefficient=coefficient) self.key = key # type: ignore @property @@ -94,9 +105,15 @@ def _with_rescoped_keys_( def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'PauliMeasurementGate': return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map)) - def with_observable(self, observable: Iterable['cirq.Pauli']) -> 'PauliMeasurementGate': + def with_observable( + self, observable: Union['cirq.BaseDensePauliString', Iterable['cirq.Pauli']] + ) -> 'PauliMeasurementGate': """Creates a pauli measurement gate with the new observable and same key.""" - if tuple(observable) == self._observable: + if ( + observable + if isinstance(observable, dps.BaseDensePauliString) + else dps.DensePauliString(observable) + ) == self._observable: return self return PauliMeasurementGate(observable, key=self.key) @@ -111,24 +128,30 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': def observable(self) -> 'cirq.DensePauliString': """Pauli observable which should be measured by the gate.""" - return dense_pauli_string.DensePauliString(self._observable) + return self._observable def _decompose_( self, qubits: Tuple['cirq.Qid', ...] ) -> 'protocols.decompose_protocol.DecomposeResult': any_qubit = qubits[0] - to_z_ops = op_tree.freeze_op_tree(self.observable().on(*qubits).to_z_basis_ops()) + to_z_ops = op_tree.freeze_op_tree(self._observable.on(*qubits).to_z_basis_ops()) xor_decomp = tuple(pauli_string_phasor.xor_nonlocal_decompose(qubits, any_qubit)) yield to_z_ops yield xor_decomp - yield measurement_gate.MeasurementGate(1, self.mkey).on(any_qubit) + yield measurement_gate.MeasurementGate( + 1, self.mkey, invert_mask=(self._observable.coefficient != 1,) + ).on(any_qubit) yield protocols.inverse(xor_decomp) yield protocols.inverse(to_z_ops) def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' ) -> 'cirq.CircuitDiagramInfo': - symbols = [f'M({g})' for g in self._observable] + coefficient = '' if self._observable.coefficient == 1 else '-' + symbols = [ + f'M({"" if i else coefficient}{self._observable[i]})' + for i in range(len(self._observable)) + ] # Mention the measurement key. label_map = args.label_map or {} @@ -141,14 +164,14 @@ def _circuit_diagram_info_( return protocols.CircuitDiagramInfo(tuple(symbols)) def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str: - args = [repr(self.observable().on(*qubits))] + args = [repr(self._observable.on(*qubits))] if self.key != _default_measurement_key(qubits): args.append(f'key={self.mkey!r}') arg_list = ', '.join(args) return f'cirq.measure_single_paulistring({arg_list})' def __repr__(self) -> str: - return f'cirq.PauliMeasurementGate(' f'{self._observable!r}, ' f'{self.mkey!r})' + return f'cirq.PauliMeasurementGate({self._observable!r}, {self.mkey!r})' def _value_equality_values_(self) -> Any: return self.key, self._observable diff --git a/cirq-core/cirq/ops/pauli_measurement_gate_test.py b/cirq-core/cirq/ops/pauli_measurement_gate_test.py index e888f63d066..030192bc00c 100644 --- a/cirq-core/cirq/ops/pauli_measurement_gate_test.py +++ b/cirq-core/cirq/ops/pauli_measurement_gate_test.py @@ -43,7 +43,7 @@ def test_init(observable, key): assert g.num_qubits() == len(observable) assert g.key == 'a' assert g.mkey == cirq.MeasurementKey('a') - assert g._observable == tuple(observable) + assert g._observable == cirq.DensePauliString(observable) assert cirq.qid_shape(g) == (2,) * len(observable) @@ -162,6 +162,9 @@ def test_bad_observable_raises(): with pytest.raises(ValueError, match=r'Pauli observable .* must be Iterable\[`cirq.Pauli`\]'): _ = cirq.PauliMeasurementGate(cirq.DensePauliString('XYZI')) + with pytest.raises(ValueError, match=r'must have coefficient \+1/-1.'): + _ = cirq.PauliMeasurementGate(cirq.DensePauliString('XYZ', coefficient=1j)) + def test_with_observable(): o1 = [cirq.Z, cirq.Y, cirq.X] @@ -170,3 +173,20 @@ def test_with_observable(): g2 = cirq.PauliMeasurementGate(o2, key='a') assert g1.with_observable(o2) == g2 assert g1.with_observable(o1) is g1 + + +@pytest.mark.parametrize( + 'rot, obs, out', + [ + (cirq.I, cirq.DensePauliString("Z", coefficient=+1), 0), + (cirq.I, cirq.DensePauliString("Z", coefficient=-1), 1), + (cirq.Y ** 0.5, cirq.DensePauliString("X", coefficient=+1), 0), + (cirq.Y ** 0.5, cirq.DensePauliString("X", coefficient=-1), 1), + (cirq.X ** -0.5, cirq.DensePauliString("Y", coefficient=+1), 0), + (cirq.X ** -0.5, cirq.DensePauliString("Y", coefficient=-1), 1), + ], +) +def test_pauli_measurement_gate_samples(rot, obs, out): + q = cirq.NamedQubit("q") + c = cirq.Circuit(rot(q), cirq.PauliMeasurementGate(obs, key='out').on(q)) + assert cirq.Simulator().sample(c)['out'][0] == out diff --git a/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.json b/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.json index 448008f99f8..8154106d323 100644 --- a/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.json +++ b/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.json @@ -1,42 +1,53 @@ -[{ - "cirq_type": "PauliMeasurementGate", - "observable": [ - { - "cirq_type": "_PauliX", - "exponent": 1.0, - "global_shift": 0.0 +[ + { + "cirq_type": "PauliMeasurementGate", + "observable": { + "cirq_type": "DensePauliString", + "pauli_mask": [ + 1, + 2, + 3 + ], + "coefficient": { + "cirq_type": "complex", + "real": 1.0, + "imag": 0.0 + } }, - { - "cirq_type": "_PauliY", - "exponent": 1.0, - "global_shift": 0.0 + "key": "key" + }, + { + "cirq_type": "PauliMeasurementGate", + "observable": { + "cirq_type": "DensePauliString", + "pauli_mask": [ + 1, + 2, + 3 + ], + "coefficient": { + "cirq_type": "complex", + "real": 1.0, + "imag": 0.0 + } }, - { - "cirq_type": "_PauliZ", - "exponent": 1.0, - "global_shift": 0.0 - } - ], - "key": "key" -}, -{ - "cirq_type": "PauliMeasurementGate", - "observable": [ - { - "cirq_type": "_PauliX", - "exponent": 1.0, - "global_shift": 0.0 + "key": "p:q:key" + }, + { + "cirq_type": "PauliMeasurementGate", + "observable": { + "cirq_type": "DensePauliString", + "pauli_mask": [ + 1, + 2, + 3 + ], + "coefficient": { + "cirq_type": "complex", + "real": -1.0, + "imag": 0.0 + } }, - { - "cirq_type": "_PauliY", - "exponent": 1.0, - "global_shift": 0.0 - }, - { - "cirq_type": "_PauliZ", - "exponent": 1.0, - "global_shift": 0.0 - } - ], - "key": "p:q:key" -}] \ No newline at end of file + "key": "key" + } +] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.json_inward b/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.json_inward new file mode 100644 index 00000000000..448008f99f8 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.json_inward @@ -0,0 +1,42 @@ +[{ + "cirq_type": "PauliMeasurementGate", + "observable": [ + { + "cirq_type": "_PauliX", + "exponent": 1.0, + "global_shift": 0.0 + }, + { + "cirq_type": "_PauliY", + "exponent": 1.0, + "global_shift": 0.0 + }, + { + "cirq_type": "_PauliZ", + "exponent": 1.0, + "global_shift": 0.0 + } + ], + "key": "key" +}, +{ + "cirq_type": "PauliMeasurementGate", + "observable": [ + { + "cirq_type": "_PauliX", + "exponent": 1.0, + "global_shift": 0.0 + }, + { + "cirq_type": "_PauliY", + "exponent": 1.0, + "global_shift": 0.0 + }, + { + "cirq_type": "_PauliZ", + "exponent": 1.0, + "global_shift": 0.0 + } + ], + "key": "p:q:key" +}] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.repr b/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.repr index 7ea72c781f7..02adaa9842e 100644 --- a/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.repr +++ b/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.repr @@ -1,5 +1,6 @@ [ cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(name='key')), - cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(path=('p', 'q'), name='key')), + cirq.PauliMeasurementGate(cirq.DensePauliString("XYZ"), cirq.MeasurementKey(path=('p', 'q'), name='key')), + cirq.PauliMeasurementGate(cirq.DensePauliString("XYZ", coefficient=-1), cirq.MeasurementKey(name='key')), ] diff --git a/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.repr_inward b/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.repr_inward new file mode 100644 index 00000000000..7ea72c781f7 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.repr_inward @@ -0,0 +1,5 @@ +[ + cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(name='key')), + cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(path=('p', 'q'), name='key')), +] +