Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add schmidt_decomposition function to quantum_info #10104

Merged
merged 25 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
121de7f
add schmidt_decomposition function to states/utils.py
diemilio May 11, 2023
aa44954
doc: add docstrings to schmidt_decomposition function
diemilio May 12, 2023
4c63ba2
tox -eblack
diemilio May 12, 2023
0bcbe8f
lint changes
diemilio May 12, 2023
f139fdf
fix qubit ordering definitions in docstring
diemilio May 13, 2023
897935d
add abstol condition to save singular values
diemilio May 16, 2023
c633288
add schmidt_decomposition to __init__.py files in .quantum_info and .…
diemilio May 16, 2023
017f506
Merge branch 'main' into add-schmidt-decomp
diemilio May 16, 2023
c1102c9
add tests for schmidt_decomposition function
diemilio May 16, 2023
9ba232d
Merge remote-tracking branch 'origin/add-schmidt-decomp' into add-sch…
diemilio May 16, 2023
b92f0d2
tox lint test
diemilio May 16, 2023
e4c35f8
correct test ordering
diemilio May 16, 2023
e3a8fc2
(docs) add schmidt_decomposition to __init__.py docstring
diemilio May 20, 2023
14bcee6
relabel subsystems in schmidt_decomposition for little-endian consist…
diemilio May 22, 2023
54b6e6c
(test) add test for individual elements of Schmidt decomposition
diemilio May 23, 2023
a6084c6
tox and lint
diemilio May 23, 2023
29224cc
Merge branch 'main' into add-schmidt-decomp
diemilio May 23, 2023
c613771
(test) add schmidt component check for 3-level system
diemilio May 23, 2023
000ca7f
(docs) add reno
diemilio May 23, 2023
ccc364c
(test) change assertions for schmidt_decomposition tests to AlmostEqual
diemilio May 23, 2023
7973c99
(docs) add note to docstring to clarify system partition to perform t…
diemilio May 29, 2023
0e20673
Merge branch 'main' into add-schmidt-decomp
diemilio May 29, 2023
a66ab76
(chore) fix typos, lint changes
diemilio May 29, 2023
448802a
separate test functions, check for state with diffprob amps
diemilio Jun 19, 2023
c2d355d
eblack lint changes
diemilio Jun 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qiskit/quantum_info/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
mutual_information,
partial_trace,
purity,
schmidt_decomposition,
diemilio marked this conversation as resolved.
Show resolved Hide resolved
shannon_entropy,
state_fidelity,
)
Expand Down
2 changes: 1 addition & 1 deletion qiskit/quantum_info/states/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .statevector import Statevector
from .stabilizerstate import StabilizerState
from .densitymatrix import DensityMatrix
from .utils import partial_trace, shannon_entropy
from .utils import partial_trace, schmidt_decomposition, shannon_entropy
from .measures import (
state_fidelity,
purity,
Expand Down
78 changes: 78 additions & 0 deletions qiskit/quantum_info/states/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from qiskit.quantum_info.states.statevector import Statevector
from qiskit.quantum_info.states.densitymatrix import DensityMatrix
from qiskit.quantum_info.operators.channel import SuperOp
from qiskit.quantum_info.operators.predicates import ATOL_DEFAULT


def partial_trace(state, qargs):
Expand Down Expand Up @@ -118,6 +119,83 @@ def logfn(x):
return h_val


def schmidt_decomposition(state, qargs):
r"""Return the Schmidt Decomposition of a pure quantum state.

For an arbitrary bipartite state::

.. math::
|\psi\rangle_{AB} = \sum_{i,j} c_{ij}
|x_i\rangle_A \otimes |y_j\rangle_B,

its Schmidt Decomposition is given by the single-index sum over k:

.. math::
|\psi\rangle_{AB} = \sum_{k} \lambda_{k}
|u_k\rangle_A \otimes |v_k\rangle_B

where :math:`|u_k\rangle_A` and :math:`|v_k\rangle_B` are an
orthonormal set of vectors in their respective spaces :math:`A` and :math:`B`,
and the Schmidt coefficients :math:`\lambda_k` are positive real values.

Args:
state (Statevector or DensityMatrix): the input state.
qargs (list): the list of Input state positions corresponding to subsystem :math:`B`.

Returns:
list: list of tuples ``(s, u, v)``, where ``s`` (float) are the
Schmidt coefficients :math:`\lambda_k`, and ``u`` (Statevector),
``v`` (Statevector) are the Schmidt vectors
:math:`|u_k\rangle_A`, :math:`|u_k\rangle_B`, respectively.

Raises:
QiskitError: if Input qargs is not a list of positions of the Input state.
QiskitError: if Input qargs is not a proper subset of Input state.
"""
state = _format_state(state, validate=False)

# convert to statevector if state is density matrix. Errors if state is mixed.
if isinstance(state, DensityMatrix):
state = state.to_statevector()

# reshape statevector into state tensor
dims = state.dims()
state_tens = state._data.reshape(dims[::-1])
ndim = state_tens.ndim

# check if qargs are valid
if not isinstance(qargs, (list, np.ndarray)):
raise QiskitError("Input qargs is not a list of positions of the Input state")
qudits = list(range(ndim))
qargs = set(qargs)
if qargs == set(qudits) or not qargs.issubset(qudits):
raise QiskitError("Input qargs is not a proper subset of Input state")

# define subsystem A and B qargs and dims
qargs_a = list(qargs)
qargs_b = [i for i in qudits if i not in qargs_a]
dims_a = state.dims(qargs_a)
dims_b = state.dims(qargs_b)
ndim_a = np.prod(dims_a)
ndim_b = np.prod(dims_b)

# permute state for desired qargs order
qargs_axes = [list(qudits)[::-1].index(i) for i in qargs_a + qargs_b][::-1]
state_tens = state_tens.transpose(qargs_axes)

# convert state tensor to matrix of prob amplitudes and perform svd.
state_mat = state_tens.reshape([ndim_a, ndim_b])
u_mat, s_arr, vh_mat = np.linalg.svd(state_mat, full_matrices=False)

schmidt_components = [
(s, Statevector(u, dims=dims_a), Statevector(v, dims=dims_b))
for s, u, v in zip(s_arr, u_mat.T, vh_mat)
if s > ATOL_DEFAULT
]

return schmidt_components


def _format_state(state, validate=True):
"""Format input state into class object"""
if isinstance(state, list):
Expand Down
31 changes: 30 additions & 1 deletion test/python/quantum_info/states/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from qiskit.test import QiskitTestCase
from qiskit.quantum_info.states import Statevector, DensityMatrix
from qiskit.quantum_info.states import partial_trace, shannon_entropy
from qiskit.quantum_info.states import partial_trace, shannon_entropy, schmidt_decomposition


class TestStateUtils(QiskitTestCase):
Expand Down Expand Up @@ -55,6 +55,35 @@ def test_shannon_entropy(self):
# Base 10
self.assertAlmostEqual(0.533908120973504, shannon_entropy(input_pvec, 10))

def test_schmidt_decomposition(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you test to see if you are getting the correct decomposition, not just the return to the original.

Copy link
Contributor Author

@diemilio diemilio May 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ikkoham. I had already tested that the correct decomposition was being returned, but working on the implementation of these tests made me realize something. As you probably know, the SVD is a unique decomposition but only up to the preservation of the sign parity of the singular vectors. So for example, the state |++⟩ has these two decompositions which are equally valid:
λ|u⟩|v⟩ = 1.0 (1.0 |+⟩) (1.0 |+⟩)
λ|u⟩|v⟩ = 1.0 (-1.0 |+⟩) (-1.0 |+⟩)

The global phase of the reconstructed total state remains the same, but the individual "global" phases of each singular vector can be different as long as they preserve the parity.

The sign selection for the singular vectors depends on the underlying algorithm used for the SVD which, in this case is numpy.linalg.svd. This really isn't a big problem, except that for simple cases of separable states (like the ones you can construct from the Statevector.from_label method), I am always getting singular vectors with the negative sign, so for the self.assertEqual to work, I am going to have to premultiply some terms by -1 in the test function. Again, there is nothing incorrect about this as long as it is done keeping in mind that the parity must be preserved, my worry is that if someone is going thru the code later on, it might be confusing.

What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went ahead and implemented the test as described above, but let me know if you would like to see it implemented in a different way. Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. AFAIK, there are no rules or conventions on that point. I'm fine with this implementation. (I'd like to hear other's opinion.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ikkoham. I found this very helpful report on the sign ambiguity of the SVD:
https://www.osti.gov/servlets/purl/920802

What they suggest to deal with the sign ambiguity, is:
"...in order to identify the sign of a singular vector, it is suggested that it be similar to the sign of the majority of vectors it is representing."

They propose a function to do this, but the issue is that they only consider real-valued matrices. I am not familiar with how SVD algorithms work, so I don't know if in the case of complex-valued data, the SVD could also return singular vectors with arbitrary phases as long long as the overall sign of the sum term is preserved. If this is the case, the implementation of this function can get complicated.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of checking for equality between individual statevectors perhaps it would be better to just verify that the result does indeed recompose and satisfy all properties as described in your documentation of a Schmidt decomposition.

The question of numerical stability of the decomposition due to the sign issue could be treated in a separate test since it addresses a somewhat distinct issue. I'm not yet familiar with what the best convention might be at the moment so I'd be ok with any convention, which might include unhandled, and clearly documenting the choice.

Also in your current test function it would be better to split it up into several tests along the lines of each commented block you already have.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ewinston. Thanks for the feedback.

  1. Currently there are tests for both: 1) check for equality between individual singular vectors and singular values, and 2) test to verify that the result recomposes. Is your suggestion to get rid of 1)?
  2. Regarding sign ambiguity testing, do you mean we should just keep the test we have right now but move it under a different test function? Or are you referring to having a different set of tests altogether?
  3. Regarding sign convention selection, I am not quite sure what algorithm is used in the numpy implementation of the svd, but in the paper I shared it is mentioned that if it is a Lanczos-based method, there is a random component to it, so the selected sign might not be guaranteed. The best solution would be to implement the sign correction process suggested in that paper, but again, here we're just trying to correct for some "global" phase factor which is less critical in quantum computing compared to other applications where the sign might play an important role.
  4. I will split the test functions just as you suggested. I will probably wait to hear back from you regarding items 1 and 2 before doing so, so I can make the final changes all together.

"""Test schmidt_decomposition function"""

# separable 2-level system without subsystem permutation
target = Statevector.from_label("10l")
schmidt_comps = schmidt_decomposition(Statevector.from_label("10l"), [0])
state = Statevector(sum(suv[0] * np.kron(suv[1], suv[2]) for suv in schmidt_comps))
self.assertEqual(state, target)

# separable 2-level system with subsystem permutation
target = Statevector.from_label("0l1")
schmidt_comps = schmidt_decomposition(Statevector.from_label("l10"), [2, 1])
state = Statevector(sum(suv[0] * np.kron(suv[1], suv[2]) for suv in schmidt_comps))
self.assertEqual(state, target)

# entangled 2-level system
target = 1 / np.sqrt(2) * (Statevector.from_label("00") + Statevector.from_label("11"))
schmidt_comps = schmidt_decomposition(target, [0])
state = Statevector(sum(suv[0] * np.kron(suv[1], suv[2]) for suv in schmidt_comps))
self.assertEqual(state, target)

# entangled 3-level system
target = Statevector(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]) * 1 / np.sqrt(3), dims=(3, 3))
schmidt_comps = schmidt_decomposition(target, [0])
state = Statevector(
sum(suv[0] * np.kron(suv[1], suv[2]) for suv in schmidt_comps), dims=(3, 3)
)
self.assertEqual(state, target)


if __name__ == "__main__":
unittest.main()