Skip to content

Commit

Permalink
update product frames
Browse files Browse the repository at this point in the history
  • Loading branch information
timmintam committed Aug 29, 2024
1 parent 7634120 commit 6734c61
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 19 deletions.
50 changes: 34 additions & 16 deletions povm_toolbox/quantum_info/product_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def __init__(self, frames: dict[tuple[int, ...], T]) -> None:
subsystem_indices = set()
self._dimension = 1
self._num_operators = 1
shape: list[int] = []
shape: tuple[int, ...] = tuple()
subshape_ndims = []
for idx, frame in frames.items():
idx_set = set(idx)
if len(idx) != len(idx_set):
Expand All @@ -94,14 +95,16 @@ def __init__(self, frames: dict[tuple[int, ...], T]) -> None:
subsystem_indices.update(idx_set)
self._dimension *= frame.dimension
self._num_operators *= frame.num_operators
shape.append(frame.num_operators)
shape += frame.shape
subshape_ndims.append(len(frame.shape))

self._informationally_complete: bool = all(
[frame.informationally_complete for frame in frames.values()]
)

self._frames = frames
self._shape: tuple[int, ...] = tuple(shape)
self._subshape_ndims = tuple(subshape_ndims)
self._shape = shape

self._check_validity()

Expand Down Expand Up @@ -201,6 +204,31 @@ def _check_validity(self) -> None:
for povm in self._frames.values():
povm._check_validity()

def _ravel_index(self, index: tuple[int, ...]) -> tuple[int, ...]:
"""TODO.
Args:
index: TODO.
Returns:
TODO.
Raises:
ValueError: TODO.
"""
if len(index) != len(self.shape):
raise ValueError("TODO.")

index_processed = []
start = 0
for ndim in self._subshape_ndims:
local_flat_index = np.ravel_multi_index(
index[start : start + ndim], self.shape[start : start + ndim]
)
index_processed.append(int(local_flat_index))
start += ndim
return tuple(index_processed)

def __getitem__(self, sub_system: tuple[int, ...]) -> T:
r"""Return the :class:`.MultiQubitFrame` acting on the specified sub-system.
Expand Down Expand Up @@ -236,6 +264,8 @@ def _trace_of_prod(self, operator: SparsePauliOp, frame_op_idx: tuple[int, ...])
"""
p_idx = 0.0 + 0.0j

index_processed = self._ravel_index(frame_op_idx)

# Second, we iterate over our input operator, ``operator``.
for label, op_coeff in operator.label_iter():
summand = op_coeff
Expand All @@ -249,7 +279,7 @@ def _trace_of_prod(self, operator: SparsePauliOp, frame_op_idx: tuple[int, ...])
sublabel = "".join(label[-(i + 1)] for i in idx)
# Try to obtain the coefficient of the local POVM for this local Pauli term.
try:
local_idx = frame_op_idx[j]
local_idx = index_processed[j]
coeff = povm.pauli_operators[local_idx][sublabel]
except KeyError:
# If it does not exist, the current summand becomes 0 because it would be
Expand All @@ -258,18 +288,6 @@ def _trace_of_prod(self, operator: SparsePauliOp, frame_op_idx: tuple[int, ...])
# In this case we can break the iteration over the remaining local POVMs.
break
except IndexError as exc:
if len(frame_op_idx) <= j:
raise IndexError(
f"The outcome label {frame_op_idx} does not match the expected shape. "
f"It is supposed to contain {len(self._frames)} integers, but has "
f"{len(frame_op_idx)}."
) from exc
if povm.num_operators <= frame_op_idx[j]:
raise IndexError(
f"Outcome index '{frame_op_idx[j]}' is out of range for the local POVM"
f" acting on subsystems {idx}. This POVM has {povm.num_operators}"
" outcomes."
) from exc
raise exc
else:
# If the label does exist, we multiply the coefficient into our summand.
Expand Down
6 changes: 3 additions & 3 deletions test/quantum_info/test_product_povm.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,17 +322,17 @@ def test_errors_analysis(self):
observable = 1.0j * Operator.from_label("XXX")
_ = prod_povm.analysis(observable)
with self.subTest("Test invalid shape for ``frame_op_idx``.") and self.assertRaises(
IndexError
ValueError
):
observable = Operator.from_label("ZZZ")
_ = prod_povm.analysis(observable, frame_op_idx=(0, 0))
with self.subTest(
"Test invalid ``frame_op_idx`` argument (out of range)."
) and self.assertRaises(IndexError):
) and self.assertRaises(ValueError):
observable = Operator.from_label("ZZZ")
_ = prod_povm.analysis(observable, frame_op_idx=(0, 0, 6))
with self.subTest(
"Test invalid ``frame_op_idx`` argument (negative out of range)."
) and self.assertRaises(IndexError):
) and self.assertRaises(ValueError):
observable = Operator.from_label("ZZZ")
_ = prod_povm.analysis(observable, frame_op_idx=(0, 0, -10))

0 comments on commit 6734c61

Please sign in to comment.