Skip to content

Commit

Permalink
Feat: add a method to easily access the state "snapshots" (#100)
Browse files Browse the repository at this point in the history
* add a method to easily access the state "snapshots" after sampling

* add unittests

* add release note

* Apply suggestions from code review

* Update povm_toolbox/quantum_info/product_frame.py

* update release note

---------

Co-authored-by: Max Rossmannek <oss@zurich.ibm.com>
  • Loading branch information
timmintam and mrossinek authored Sep 12, 2024
1 parent 24507da commit 4308b4b
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 3 deletions.
23 changes: 22 additions & 1 deletion povm_toolbox/post_processor/povm_post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from typing import Any, cast

import numpy as np
from qiskit.quantum_info import SparsePauliOp
from qiskit.quantum_info import Operator, SparsePauliOp

from povm_toolbox.quantum_info import ProductDual
from povm_toolbox.quantum_info.base import BaseDual, BasePOVM
from povm_toolbox.sampler import POVMPubResult

Expand Down Expand Up @@ -220,3 +221,23 @@ def _single_exp_value_and_std(
std = float("NaN")

return exp_val, std

def get_state_snapshot(self, outcome: tuple[int, ...]) -> dict[tuple[int, ...], Operator]:
"""Return the snapshot of the state associated with ``outcome``.
Args:
outcome: the label specifying the snapshot. The outcome is a tuple of integers (one
index per local frame).
Returns:
The snapshot associated with ``outcome``. The snapshot is a product operator, which is
returned as a dictionary mapping the subsystems of the Hilbert space (e.g. qubits) to
the corresponding local operators forming the product operator.
Raises:
NotImplementedError: if the dual frame associated with the post-processor is not
product.
"""
if isinstance(self.dual, ProductDual):
return self.dual.get_operator(outcome)
raise NotImplementedError("This method is only implemented for `ProductDual` objects.")
17 changes: 17 additions & 0 deletions povm_toolbox/quantum_info/product_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,23 @@ def __len__(self) -> int:
"""Return the number of outcomes of the product frame."""
return self.num_operators

def get_operator(self, frame_op_idx: tuple[int, ...]) -> dict[tuple[int, ...], Operator]:
"""Return a product frame operator in a product form.
Args:
frame_op_idx: the label specifying the frame operator to get. The frame operator is
labeled by a tuple of integers (one index per local frame).
Returns:
The product frame operator specified by ``frame_op_idx``. The operator is returned in a
product form. More specifically, is it a dictionary mapping the subsystems to the
corresponding local frame operators forming the product frame operator.
"""
product_operator = {}
for local_idx, (subsystem, povm) in zip(frame_op_idx, self._frames.items()):
product_operator[subsystem] = povm.operators[local_idx]
return product_operator

def _trace_of_prod(self, operator: SparsePauliOp, frame_op_idx: tuple[int, ...]) -> float:
"""Return the trace of the product of a Hermitian operator with a specific frame operator.
Expand Down
10 changes: 10 additions & 0 deletions releasenotes/notes/add-get-snapshot-method-ee4f85a7b06cb373.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
features:
- |
A new method, :meth:`.POVMPostProcessor.get_state_snapshot`, is implemented to easily access the
classical "snapshots" of the state after taking some measurements. A classical "snapshot" of the
state is the dual frame operator associated with the corresponding outcome. The new method
returns the snapshot (typically as a product of local operators) associated with the queried
outcome (typically a tuple of integers).
In parallel, the :meth:`.ProductFrame.get_operator` method has been added for more general
access to frame operators in a product form.
24 changes: 22 additions & 2 deletions test/post_processor/test_post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
import numpy as np
from povm_toolbox.library import ClassicalShadows, LocallyBiasedClassicalShadows
from povm_toolbox.post_processor import POVMPostProcessor
from povm_toolbox.quantum_info.product_dual import ProductDual
from povm_toolbox.quantum_info import MultiQubitDual, ProductDual
from povm_toolbox.sampler import POVMSampler
from qiskit import QuantumCircuit
from qiskit.circuit import Parameter
from qiskit.primitives import StatevectorSampler as Sampler
from qiskit.quantum_info import SparsePauliOp
from qiskit.quantum_info import Operator, SparsePauliOp


class TestPostProcessor(TestCase):
Expand Down Expand Up @@ -187,3 +187,23 @@ def test_catch_zero_division(self):

self.assertAlmostEqual(exp_val, 0.0)
self.assertTrue(np.isnan(std))

def test_get_state_snapshot(self):
"""Test that the ``get_state_snapshot`` method works correctly."""
post_processor = POVMPostProcessor(self.pub_result)

with self.subTest("Test method works correctly"):
outcome = self.pub_result.get_samples()[0]
# check outcome first
self.assertEqual(outcome, (5, 1))
expected_snapshot = {
(0,): Operator([[0.5, 1.5j], [-1.5j, 0.5]]),
(1,): Operator([[-1, 0.0], [0, 2.0]]),
}
snapshot = post_processor.get_state_snapshot(outcome)
# check snapshot
self.assertDictEqual(snapshot, expected_snapshot)

with self.subTest("Test raises errors") and self.assertRaises(NotImplementedError):
post_processor._dual = MultiQubitDual([Operator(np.eye(4))])
_ = post_processor.get_state_snapshot(outcome)
21 changes: 21 additions & 0 deletions test/quantum_info/test_product_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,24 @@ def test_analysis(self):
check = np.ones(len(product_paulis)) * 2**num_qubit
self.assertTrue(np.allclose(decomposition_weights_n_qubit, check))
self.assertTrue(np.allclose(decomposition_weights_product, check))

def test_get_operator(self):
"""Test that the ``get_operator`` method works correctly."""
frame_0 = MultiQubitFrame([Operator.from_label(label) for label in ["I", "X", "Y", "Z"]])
frame_1 = MultiQubitFrame([Operator.from_label(label) for label in ["0", "1"]])
frame_product = ProductFrame.from_list(frames=[frame_0, frame_1])

with self.subTest("Test method works correctly"):
frame_op_idx = (0, 1)
expected_snapshot = {(0,): Operator.from_label("I"), (1,): Operator.from_label("1")}
snapshot = frame_product.get_operator(frame_op_idx)
self.assertDictEqual(snapshot, expected_snapshot)

frame_op_idx = (2, 0)
expected_snapshot = {(0,): Operator.from_label("Y"), (1,): Operator.from_label("0")}
snapshot = frame_product.get_operator(frame_op_idx)
self.assertDictEqual(snapshot, expected_snapshot)

with self.subTest("invalid frame_op_idx") and self.assertRaises(IndexError):
frame_op_idx = (10, 20)
frame_product.get_operator(frame_op_idx)

0 comments on commit 4308b4b

Please sign in to comment.