Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PauliMeasurementGate respect sign of the pauli observable. #4836

Merged
merged 2 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 38 additions & 15 deletions cirq-core/cirq/ops/pauli_measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union
from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union, cast

from cirq import protocols, value
from cirq.ops import (
raw_types,
measurement_gate,
op_tree,
dense_pauli_string,
dense_pauli_string as dps,
pauli_gates,
pauli_string_phasor,
)
Expand All @@ -38,25 +38,36 @@ class PauliMeasurementGate(raw_types.Gate):

def __init__(
self,
observable: Iterable['cirq.Pauli'],
observable: Union['cirq.BaseDensePauliString', Iterable['cirq.Pauli']],
key: Union[str, 'cirq.MeasurementKey'] = '',
) -> None:
"""Inits PauliMeasurementGate.

Args:
observable: Pauli observable to measure. Any `Iterable[cirq.Pauli]`
is a valid Pauli observable, including `cirq.DensePauliString`
instances, which do not contain any identity gates.
is a valid Pauli observable (with a +1 coefficient by default).
If you wish to measure pauli observables with coefficient -1,
then pass a `cirq.DensePauliString` as observable.
key: The string key of the measurement.

Raises:
ValueError: If the observable is empty.
"""
if not observable:
raise ValueError(f'Pauli observable {observable} is empty.')
if not all(isinstance(p, pauli_gates.Pauli) for p in observable):
if not all(
isinstance(p, pauli_gates.Pauli) for p in cast(Iterable['cirq.Gate'], observable)
):
raise ValueError(f'Pauli observable {observable} must be Iterable[`cirq.Pauli`].')
self._observable = tuple(observable)
coefficient = (
observable.coefficient if isinstance(observable, dps.BaseDensePauliString) else 1
)
if coefficient not in [+1, -1]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: can just do [1, -1]

raise ValueError(
f'`cirq.DensePauliString` observable {observable} must have coefficient +1/-1.'
)

self._observable = dps.DensePauliString(observable, coefficient=coefficient)
self.key = key # type: ignore

@property
Expand Down Expand Up @@ -94,9 +105,15 @@ def _with_rescoped_keys_(
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'PauliMeasurementGate':
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))

def with_observable(self, observable: Iterable['cirq.Pauli']) -> 'PauliMeasurementGate':
def with_observable(
self, observable: Union['cirq.BaseDensePauliString', Iterable['cirq.Pauli']]
) -> 'PauliMeasurementGate':
"""Creates a pauli measurement gate with the new observable and same key."""
if tuple(observable) == self._observable:
if (
observable
if isinstance(observable, dps.BaseDensePauliString)
else dps.DensePauliString(observable)
) == self._observable:
return self
return PauliMeasurementGate(observable, key=self.key)

Expand All @@ -111,24 +128,30 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':

def observable(self) -> 'cirq.DensePauliString':
"""Pauli observable which should be measured by the gate."""
return dense_pauli_string.DensePauliString(self._observable)
return self._observable

def _decompose_(
self, qubits: Tuple['cirq.Qid', ...]
) -> 'protocols.decompose_protocol.DecomposeResult':
any_qubit = qubits[0]
to_z_ops = op_tree.freeze_op_tree(self.observable().on(*qubits).to_z_basis_ops())
to_z_ops = op_tree.freeze_op_tree(self._observable.on(*qubits).to_z_basis_ops())
xor_decomp = tuple(pauli_string_phasor.xor_nonlocal_decompose(qubits, any_qubit))
yield to_z_ops
yield xor_decomp
yield measurement_gate.MeasurementGate(1, self.mkey).on(any_qubit)
yield measurement_gate.MeasurementGate(
1, self.mkey, invert_mask=(self._observable.coefficient != 1,)
).on(any_qubit)
yield protocols.inverse(xor_decomp)
yield protocols.inverse(to_z_ops)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
symbols = [f'M({g})' for g in self._observable]
coefficient = '' if self._observable.coefficient == 1 else '-'
symbols = [
f'M({"" if i else coefficient}{self._observable[i]})'
for i in range(len(self._observable))
]

# Mention the measurement key.
label_map = args.label_map or {}
Expand All @@ -141,14 +164,14 @@ def _circuit_diagram_info_(
return protocols.CircuitDiagramInfo(tuple(symbols))

def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
args = [repr(self.observable().on(*qubits))]
args = [repr(self._observable.on(*qubits))]
if self.key != _default_measurement_key(qubits):
args.append(f'key={self.mkey!r}')
arg_list = ', '.join(args)
return f'cirq.measure_single_paulistring({arg_list})'

def __repr__(self) -> str:
return f'cirq.PauliMeasurementGate(' f'{self._observable!r}, ' f'{self.mkey!r})'
return f'cirq.PauliMeasurementGate({self._observable!r}, {self.mkey!r})'

def _value_equality_values_(self) -> Any:
return self.key, self._observable
Expand Down
22 changes: 21 additions & 1 deletion cirq-core/cirq/ops/pauli_measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_init(observable, key):
assert g.num_qubits() == len(observable)
assert g.key == 'a'
assert g.mkey == cirq.MeasurementKey('a')
assert g._observable == tuple(observable)
assert g._observable == cirq.DensePauliString(observable)
assert cirq.qid_shape(g) == (2,) * len(observable)


Expand Down Expand Up @@ -162,6 +162,9 @@ def test_bad_observable_raises():
with pytest.raises(ValueError, match=r'Pauli observable .* must be Iterable\[`cirq.Pauli`\]'):
_ = cirq.PauliMeasurementGate(cirq.DensePauliString('XYZI'))

with pytest.raises(ValueError, match=r'must have coefficient \+1/-1.'):
_ = cirq.PauliMeasurementGate(cirq.DensePauliString('XYZ', coefficient=1j))


def test_with_observable():
o1 = [cirq.Z, cirq.Y, cirq.X]
Expand All @@ -170,3 +173,20 @@ def test_with_observable():
g2 = cirq.PauliMeasurementGate(o2, key='a')
assert g1.with_observable(o2) == g2
assert g1.with_observable(o1) is g1


@pytest.mark.parametrize(
'rot, obs, out',
[
(cirq.I, cirq.DensePauliString("Z", coefficient=+1), 0),
(cirq.I, cirq.DensePauliString("Z", coefficient=-1), 1),
(cirq.Y ** 0.5, cirq.DensePauliString("X", coefficient=+1), 0),
(cirq.Y ** 0.5, cirq.DensePauliString("X", coefficient=-1), 1),
(cirq.X ** -0.5, cirq.DensePauliString("Y", coefficient=+1), 0),
(cirq.X ** -0.5, cirq.DensePauliString("Y", coefficient=-1), 1),
],
)
def test_pauli_measurement_gate_samples(rot, obs, out):
q = cirq.NamedQubit("q")
c = cirq.Circuit(rot(q), cirq.PauliMeasurementGate(obs, key='out').on(q))
assert cirq.Simulator().sample(c)['out'][0] == out
89 changes: 50 additions & 39 deletions cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.json
Original file line number Diff line number Diff line change
@@ -1,42 +1,53 @@
[{
"cirq_type": "PauliMeasurementGate",
"observable": [
{
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
[
{
"cirq_type": "PauliMeasurementGate",
"observable": {
"cirq_type": "DensePauliString",
"pauli_mask": [
1,
2,
3
],
"coefficient": {
"cirq_type": "complex",
"real": 1.0,
"imag": 0.0
}
},
{
"cirq_type": "_PauliY",
"exponent": 1.0,
"global_shift": 0.0
"key": "key"
},
{
"cirq_type": "PauliMeasurementGate",
"observable": {
"cirq_type": "DensePauliString",
"pauli_mask": [
1,
2,
3
],
"coefficient": {
"cirq_type": "complex",
"real": 1.0,
"imag": 0.0
}
},
{
"cirq_type": "_PauliZ",
"exponent": 1.0,
"global_shift": 0.0
}
],
"key": "key"
},
{
"cirq_type": "PauliMeasurementGate",
"observable": [
{
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
"key": "p:q:key"
},
{
"cirq_type": "PauliMeasurementGate",
"observable": {
"cirq_type": "DensePauliString",
"pauli_mask": [
1,
2,
3
],
"coefficient": {
"cirq_type": "complex",
"real": -1.0,
"imag": 0.0
}
},
{
"cirq_type": "_PauliY",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliZ",
"exponent": 1.0,
"global_shift": 0.0
}
],
"key": "p:q:key"
}]
"key": "key"
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
[{
"cirq_type": "PauliMeasurementGate",
"observable": [
{
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliY",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliZ",
"exponent": 1.0,
"global_shift": 0.0
}
],
"key": "key"
},
{
"cirq_type": "PauliMeasurementGate",
"observable": [
{
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliY",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliZ",
"exponent": 1.0,
"global_shift": 0.0
}
],
"key": "p:q:key"
}]
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(name='key')),
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(path=('p', 'q'), name='key')),
cirq.PauliMeasurementGate(cirq.DensePauliString("XYZ"), cirq.MeasurementKey(path=('p', 'q'), name='key')),
cirq.PauliMeasurementGate(cirq.DensePauliString("XYZ", coefficient=-1), cirq.MeasurementKey(name='key')),
]

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(name='key')),
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(path=('p', 'q'), name='key')),
]