Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Measurement confusion maps #5480

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion cirq-core/cirq/experiments/readout_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def __init__(
the corresponding confusion matrix.
repetitions: The number of repetitions that were used to estimate the confusion
matrices.
timestamp: The time the data was taken, in seconds since the epoch.
timestamp: The time the data was taken, in seconds since the epoch. This will be
zero for fake data (i.e. data not generated from an experiment).

Raises:
ValueError: If length of `confusion_matrices` and `measure_qubits` is different or if
Expand Down Expand Up @@ -113,6 +114,34 @@ def __init__(
if sum(len(q) for q in self._measure_qubits) != len(self._qubits):
raise ValueError(f"Repeated qubits not allowed in measure_qubits: {measure_qubits}.")

@classmethod
def from_measurement(
cls, gate: ops.MeasurementGate, qubits: Sequence['cirq.Qid']
) -> 'TensoredConfusionMatrices':
"""Generates TCM for the confusion map in a MeasurementGate.

This ignores any invert_mask defined for the gate - it only replicates the confusion map.

Args:
gate: the MeasurementGate to match.
qubits: qubits the gate is applied to.

Returns:
TensoredConfusionMatrices matching the confusion map of the given gate.

Raises:
ValueError: if the gate has no confusion map.
"""
if not gate.confusion_map:
raise ValueError(f"Measurement has no confusion matrices: {gate}")
confusion_matrices = []
ordered_qubits = []
for indices, cm in gate.confusion_map.items():
confusion_matrices.append(cm)
ordered_qubits.append(tuple(qubits[idx] for idx in indices))
# Use zero for reps/timestamp to mark fake data.
95-martin-orion marked this conversation as resolved.
Show resolved Hide resolved
return cls(confusion_matrices, ordered_qubits, repetitions=0, timestamp=0)

@property
def repetitions(self) -> int:
"""The number of repetitions that were used to estimate the confusion matrices."""
Expand Down
21 changes: 21 additions & 0 deletions cirq-core/cirq/experiments/readout_confusion_matrix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,27 @@ def l2norm(result: np.ndarray):
assert l2norm(corrected_result) <= l2norm(sampled_result)


def test_from_measurement():
qubits = cirq.LineQubit.range(3)
confuse_02 = np.array([[0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0]])
confuse_1 = np.array([[0, 1], [1, 0]])
op = cirq.measure(
*qubits,
key='a',
invert_mask=(True, False),
confusion_map={(0, 2): confuse_02, (1,): confuse_1},
)
tcm = cirq.TensoredConfusionMatrices.from_measurement(op.gate, op.qubits)
expected_tcm = cirq.TensoredConfusionMatrices(
[confuse_02, confuse_1], ((qubits[0], qubits[2]), (qubits[1],)), repetitions=0, timestamp=0
)
assert tcm == expected_tcm

no_cm_op = cirq.measure(*qubits, key='a')
with pytest.raises(ValueError, match="Measurement has no confusion matrices"):
_ = cirq.TensoredConfusionMatrices.from_measurement(no_cm_op.gate, no_cm_op.qubits)


def test_readout_confusion_matrix_raises():
num_qubits = 2
confusion_matrix = get_expected_cm(num_qubits, 0.1, 0.2)
Expand Down
9 changes: 7 additions & 2 deletions cirq-core/cirq/ops/measure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Iterable, List, overload, Optional, Tuple, TYPE_CHECKING, Union
from typing import Callable, Dict, Iterable, List, overload, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np

Expand Down Expand Up @@ -107,6 +107,7 @@ def measure(
*target,
key: Optional[Union[str, 'cirq.MeasurementKey']] = None,
invert_mask: Tuple[bool, ...] = (),
confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None,
) -> raw_types.Operation:
"""Returns a single MeasurementGate applied to all the given qubits.

Expand All @@ -121,6 +122,10 @@ def measure(
invert_mask: A list of Truthy or Falsey values indicating whether
the corresponding qubits should be flipped. None indicates no
inverting should be done.
confusion_map: A map of qubit index sets (using indices in
`target`) to the 2D confusion matrix for those qubits. Indices
not included use the identity. Applied before invert_mask if both
are provided.

Returns:
An operation targeting the given qubits with a measurement.
Expand All @@ -146,7 +151,7 @@ def measure(
if key is None:
key = _default_measurement_key(targets)
qid_shape = protocols.qid_shape(targets)
return MeasurementGate(len(targets), key, invert_mask, qid_shape).on(*targets)
return MeasurementGate(len(targets), key, invert_mask, qid_shape, confusion_map).on(*targets)


@overload
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/ops/measure_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def test_measure_qubits():
assert cirq.measure(cirq.LineQid.for_qid_shape((1, 2, 3)), key='a') == cirq.MeasurementGate(
num_qubits=3, key='a', qid_shape=(1, 2, 3)
).on(*cirq.LineQid.for_qid_shape((1, 2, 3)))
cmap = {(0,): np.array([[0, 1], [1, 0]])}
assert cirq.measure(a, confusion_map=cmap) == cirq.MeasurementGate(
num_qubits=1, key='a', confusion_map=cmap
).on(a)

with pytest.raises(ValueError, match='ndarray'):
_ = cirq.measure(np.array([1, 0]))
Expand Down
98 changes: 71 additions & 27 deletions cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np

from cirq import protocols, value
from cirq import _compat, protocols, value
from cirq.ops import raw_types

if TYPE_CHECKING:
Expand All @@ -40,6 +40,7 @@ def __init__(
key: Union[str, 'cirq.MeasurementKey'] = '',
invert_mask: Tuple[bool, ...] = (),
qid_shape: Tuple[int, ...] = None,
confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None,
) -> None:
"""Inits MeasurementGate.

Expand All @@ -52,10 +53,15 @@ def __init__(
Qubits with indices past the end of the mask are not flipped.
qid_shape: Specifies the dimension of each qid the measurement
applies to. The default is 2 for every qubit.
confusion_map: A map of qubit index sets (using indices in the
operation generated from this gate) to the 2D confusion matrix
for those qubits. Indices not included use the identity.
Applied before invert_mask if both are provided.

Raises:
ValueError: If the length of invert_mask is greater than num_qubits.
or if the length of qid_shape doesn't equal num_qubits.
ValueError: If invert_mask or confusion_map have indices
greater than the available qubit indices, or if the length of
qid_shape doesn't equal num_qubits.
"""
if qid_shape is None:
if num_qubits is None:
Expand All @@ -74,6 +80,9 @@ def __init__(
self._invert_mask = invert_mask or ()
if self.invert_mask is not None and len(self.invert_mask) > self.num_qubits():
raise ValueError('len(invert_mask) > num_qubits')
self._confusion_map = confusion_map or {}
if any(x >= self.num_qubits() for idx in self._confusion_map for x in idx):
raise ValueError('Confusion matrices have index out of bounds.')

@property
def key(self) -> str:
Expand All @@ -87,6 +96,10 @@ def mkey(self) -> 'cirq.MeasurementKey':
def invert_mask(self) -> Tuple[bool, ...]:
return self._invert_mask

@property
def confusion_map(self) -> Dict[Tuple[int, ...], np.ndarray]:
return self._confusion_map

def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

Expand All @@ -98,7 +111,11 @@ def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate':
if key == self.key:
return self
return MeasurementGate(
self.num_qubits(), key=key, invert_mask=self.invert_mask, qid_shape=self._qid_shape
self.num_qubits(),
key=key,
invert_mask=self.invert_mask,
qid_shape=self._qid_shape,
confusion_map=self.confusion_map,
)

def _with_key_path_(self, path: Tuple[str, ...]):
Expand All @@ -116,14 +133,22 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))

def with_bits_flipped(self, *bit_positions: int) -> 'MeasurementGate':
"""Toggles whether or not the measurement inverts various outputs."""
"""Toggles whether or not the measurement inverts various outputs.

This only affects the invert_mask, which is applied after confusion
matrices if any are defined.
"""
old_mask = self.invert_mask or ()
n = max(len(old_mask) - 1, *bit_positions) + 1
new_mask = [k < len(old_mask) and old_mask[k] for k in range(n)]
for b in bit_positions:
new_mask[b] = not new_mask[b]
return MeasurementGate(
self.num_qubits(), key=self.key, invert_mask=tuple(new_mask), qid_shape=self._qid_shape
self.num_qubits(),
key=self.key,
invert_mask=tuple(new_mask),
qid_shape=self._qid_shape,
confusion_map=self.confusion_map,
)

def full_invert_mask(self) -> Tuple[bool, ...]:
Expand Down Expand Up @@ -166,12 +191,17 @@ def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
symbols = ['M'] * self.num_qubits()

# Show which output bits are negated.
if self.invert_mask:
for i, b in enumerate(self.invert_mask):
if b:
symbols[i] = '!M'
flipped_indices = {i for i, x in enumerate(self.full_invert_mask()) if x}
confused_indices = {x for idxs in self.confusion_map for x in idxs}

# Show which output bits are negated and/or confused.
for i in range(self.num_qubits()):
prefix = ''
if i in flipped_indices:
prefix += '!'
if i in confused_indices:
prefix += '?'
symbols[i] = prefix + symbols[i]

# Mention the measurement key.
label_map = args.label_map or {}
Expand All @@ -184,7 +214,7 @@ def _circuit_diagram_info_(
return protocols.CircuitDiagramInfo(symbols)

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
if not all(d == 2 for d in self._qid_shape):
if self.confusion_map or not all(d == 2 for d in self._qid_shape):
return NotImplemented
args.validate_version('2.0')
invert_mask = self.invert_mask
Expand All @@ -202,7 +232,7 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
def _quil_(
self, qubits: Tuple['cirq.Qid', ...], formatter: 'cirq.QuilFormatter'
) -> Optional[str]:
if not all(d == 2 for d in self._qid_shape):
if self.confusion_map or not all(d == 2 for d in self._qid_shape):
return NotImplemented
invert_mask = self.invert_mask
if len(invert_mask) < len(qubits):
Expand All @@ -222,28 +252,39 @@ def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
args.append(f'key={self.mkey!r}')
if self.invert_mask:
args.append(f'invert_mask={self.invert_mask!r}')
if self.confusion_map:
proper_map_str = ', '.join(
f"{k!r}: {_compat.proper_repr(v)}" for k, v in self.confusion_map.items()
)
args.append(f'confusion_map={{{proper_map_str}}}')
arg_list = ', '.join(args)
return f'cirq.measure({arg_list})'

def __repr__(self):
qid_shape_arg = ''
args = [f'{self.num_qubits()!r}', f'{self.mkey!r}', f'{self.invert_mask}']
if any(d != 2 for d in self._qid_shape):
qid_shape_arg = f', {self._qid_shape!r}'
return (
f'cirq.MeasurementGate('
f'{self.num_qubits()!r}, '
f'{self.mkey!r}, '
f'{self.invert_mask}'
f'{qid_shape_arg})'
)
args.append(f'qid_shape={self._qid_shape!r}')
if self.confusion_map:
proper_map_str = ', '.join(
f"{k!r}: {_compat.proper_repr(v)}" for k, v in self.confusion_map.items()
)
args.append(f'confusion_map={{{proper_map_str}}}')
return f'cirq.MeasurementGate({", ".join(args)})'

def _value_equality_values_(self) -> Any:
return self.key, self.invert_mask, self._qid_shape
hashable_cmap = frozenset(
(idxs, tuple(v for _, v in np.ndenumerate(cmap)))
for idxs, cmap in self._confusion_map.items()
)
return self.key, self.invert_mask, self._qid_shape, hashable_cmap

def _json_dict_(self) -> Dict[str, Any]:
other = {}
other: Dict[str, Any] = {}
if not all(d == 2 for d in self._qid_shape):
other['qid_shape'] = self._qid_shape
if self.confusion_map:
json_cmap = [(k, v.tolist()) for k, v in self.confusion_map.items()]
other['confusion_map'] = json_cmap
return {
'num_qubits': len(self._qid_shape),
'key': self.key,
Expand All @@ -252,12 +293,15 @@ def _json_dict_(self) -> Dict[str, Any]:
}

@classmethod
def _from_json_dict_(cls, num_qubits, key, invert_mask, qid_shape=None, **kwargs):
def _from_json_dict_(
cls, num_qubits, key, invert_mask, qid_shape=None, confusion_map=None, **kwargs
):
return cls(
num_qubits=num_qubits,
key=value.MeasurementKey.parse_serialized(key),
invert_mask=tuple(invert_mask),
qid_shape=None if qid_shape is None else tuple(qid_shape),
confusion_map={tuple(k): np.array(v) for k, v in confusion_map or []},
)

def _has_stabilizer_effect_(self) -> Optional[bool]:
Expand All @@ -268,7 +312,7 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq

if not isinstance(sim_state, SimulationState):
return NotImplemented
sim_state.measure(qubits, self.key, self.full_invert_mask())
sim_state.measure(qubits, self.key, self.full_invert_mask(), self.confusion_map)
return True


Expand Down
Loading