From 0fe12a33541bf3a915310d3528bdfcb5c16888b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Duque=20Mesa?= <675763+sduquemesa@users.noreply.github.com> Date: Thu, 28 Apr 2022 17:51:10 -0500 Subject: [PATCH] Concatenate any state (#130) Co-authored-by: ziofil Co-authored-by: Theodor --- .github/CHANGELOG.md | 20 +++++++++ mrmustard/lab/abstract/state.py | 48 +++++++++++++++++---- mrmustard/physics/fock.py | 43 +++++++++++++++++-- tests/test_lab/test_states.py | 52 +++++++++++++++++++++-- tests/test_physics/test_fock/test_fock.py | 34 +++++++++++++++ 5 files changed, 181 insertions(+), 16 deletions(-) diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index b7dd2832f..fc087b400 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -4,6 +4,26 @@ * Can switch progress bar on and off (default is on) from the settings via `settings.PROGRESSBAR = True/False`. [(#128)](https://github.com/XanaduAI/MrMustard/issues/128) +* States in Gaussian and Fock representation now can be concatenated. + [(#130)](https://github.com/XanaduAI/MrMustard/pull/130) + + ```python + from mrmustard.lab.states import Gaussian, Fock' + from mrmustard.lab.gates import Attenuator + + # concatenate pure states + fock_state = Fock(4) + gaussian_state = Gaussian(1) + pure_state = fock_state & gaussian_state + + # also can concatenate mixed states + mixed1 = fock_state >> Attenuator(0.8) + mixed2 = gaussian_state >> Attenuator(0.5) + mixed_state = mixed1 & mixed2 + + mixed_state.dm() + ``` + ### Breaking changes ### Improvements diff --git a/mrmustard/lab/abstract/state.py b/mrmustard/lab/abstract/state.py index 4598bb0da..6a2c781f2 100644 --- a/mrmustard/lab/abstract/state.py +++ b/mrmustard/lab/abstract/state.py @@ -242,17 +242,23 @@ def ket(self, cutoffs: List[int] = None) -> Optional[Tensor]: """ if self.is_mixed: return None + cutoffs = ( self.cutoffs if cutoffs is None else [c if c is not None else self.cutoffs[i] for i, c in enumerate(cutoffs)] ) + if self.is_gaussian: self._ket = fock.fock_representation( self.cov, self.means, shape=cutoffs, return_dm=False ) else: # only fock representation is available if self._ket is None: + # if state is pure and has a density matrix, calculate the ket + if self.is_pure: + self._ket = fock.dm_to_ket(self.dm) + return self._ket return None current_cutoffs = list(self._ket.shape[: self.num_modes]) if cutoffs != current_cutoffs: @@ -390,11 +396,34 @@ def primal(self, other: Union[State, Transformation]) -> State: def __and__(self, other: State) -> State: r"""Concatenates two states.""" - if not (self.is_gaussian and other.is_gaussian): - raise NotImplementedError( - "Concatenation of non-gaussian states is not implemented yet." + if not self.is_gaussian or not other.is_gaussian: # convert all to fock now + # TODO: would be more efficient if we could keep pure states as kets + if self.is_mixed or other.is_mixed: + self_fock = self.dm() + other_fock = other.dm() + dm = fock.math.tensordot(self_fock, other_fock, [[], []]) + # e.g. self has shape [1,3,1,3] and other has shape [2,2] + # we want self & other to have shape [1,3,2,1,3,2] + # before transposing shape is [1,3,1,3]+[2,2] + self_idx = list(range(len(self_fock.shape))) + other_idx = list(range(len(self_idx), len(self_idx) + len(other_fock.shape))) + return State( + dm=math.transpose( + dm, + self_idx[: len(self_idx) // 2] + + other_idx[: len(other_idx) // 2] + + self_idx[len(self_idx) // 2 :] + + other_idx[len(other_idx) // 2 :], + ), + modes=self.modes + [m + max(self.modes) + 1 for m in other.modes], + ) + # else, all states are pure + self_fock = self.ket() + other_fock = other.ket() + return State( + ket=fock.math.tensordot(self_fock, other_fock, [[], []]), + modes=self.modes + [m + max(self.modes) + 1 for m in other.modes], ) - cov = gaussian.join_covs([self.cov, other.cov]) means = gaussian.join_means([self.means, other.means]) return State( @@ -432,7 +461,7 @@ def get_modes(self, item): # if not gaussian fock_partitioned = fock.trace( - self.dm(self.cutoffs), [m for m in range(self.num_modes) if m not in item] + self.dm(self.cutoffs), keep=[m for m in range(self.num_modes) if m in item] ) return State(dm=fock_partitioned, modes=item) @@ -449,13 +478,14 @@ def __eq__(self, other): if not np.allclose(self.cov, other.cov, atol=1e-6): return False return True - if self.is_pure and other.is_pure: + try: return np.allclose( self.ket(cutoffs=other.cutoffs), other.ket(cutoffs=other.cutoffs), atol=1e-6 ) - return np.allclose( - self.dm(cutoffs=other.cutoffs), other.dm(cutoffs=other.cutoffs), atol=1e-6 - ) + except TypeError: + return np.allclose( + self.dm(cutoffs=other.cutoffs), other.dm(cutoffs=other.cutoffs), atol=1e-6 + ) def __rshift__(self, other): r"""Applies other (a Transformation) to self (a State), e.g., ``Coherent(x=0.1) >> Sgate(r=0.1)``.""" diff --git a/mrmustard/physics/fock.py b/mrmustard/physics/fock.py index cbd01420c..7c543a8d3 100644 --- a/mrmustard/physics/fock.py +++ b/mrmustard/physics/fock.py @@ -123,6 +123,40 @@ def ket_to_dm(ket: Tensor) -> Tensor: return math.outer(ket, math.conj(ket)) +def dm_to_ket(dm: Tensor) -> Tensor: + r"""Maps a density matrix to a ket if the state is pure. + + If the state is pure :math:`\hat \rho= |\psi\rangle\langle \psi|` then the + ket is the eigenvector of :math:`\rho` corresponding to the eigenvalue 1. + + Args: + dm (Tensor): the density matrix + + Returns: + Tensor: the ket + + Raises: + ValueError: if ket for mixed states cannot be calculated + """ + + is_pure_dm = np.isclose(purity(dm), 1.0, atol=1e-6) + if not is_pure_dm: + raise ValueError("Cannot calculate ket for mixed states.") + + cutoffs = dm.shape[: len(dm.shape) // 2] + d = int(np.prod(cutoffs)) + dm = math.reshape(dm, (d, d)) + dm = normalize(dm, is_dm=True) + + _, eigvecs = math.eigh(dm) + # eigenvalues and related eigenvectors are sorted in non-decreasing order, + # meaning the associated eigvec to eigval 1 is stored last. + ket = eigvecs[:, -1] + ket = math.reshape(ket, cutoffs) + + return ket + + def ket_to_probs(ket: Tensor) -> Tensor: r"""Maps a ket to probabilities. @@ -423,9 +457,12 @@ def trace(dm, keep: List[int]): N = len(dm.shape) // 2 trace = [m for m in range(N) if m not in keep] # put at the end all of the indices to trace over - dm = math.transpose( - dm, [i for pair in [(k, k + N) for k in keep] + [(t, t + N) for t in trace] for i in pair] - ) + keep_idx = [i for pair in [(k, k + N) for k in keep] for i in pair] + keep_idx = keep_idx[::2] + keep_idx[1::2] + trace_idx = [i for pair in [(t, t + N) for t in trace] for i in pair] + trace_idx = trace_idx[::2] + trace_idx[1::2] # stagger the indices + dm = math.transpose(dm, keep_idx + trace_idx) + d = int(np.prod(dm.shape[-len(trace) :])) # make it square on those indices dm = math.reshape(dm, dm.shape[: 2 * len(keep)] + (d, d)) diff --git a/tests/test_lab/test_states.py b/tests/test_lab/test_states.py index 5afc9ccb1..3a709d7b9 100644 --- a/tests/test_lab/test_states.py +++ b/tests/test_lab/test_states.py @@ -17,11 +17,22 @@ from hypothesis import given, strategies as st, assume from hypothesis.extra.numpy import arrays from mrmustard.physics import gaussian as gp -from mrmustard.lab.states import * -from mrmustard.lab.gates import * +from mrmustard.lab.states import ( + Fock, + Coherent, + Vacuum, + Gaussian, + SqueezedVacuum, + DisplacedSqueezed, + Thermal, +) +from mrmustard.lab.gates import Attenuator, Sgate, Dgate, Ggate from mrmustard.lab.abstract import State from mrmustard import settings -from tests import random + +from mrmustard.math import Math + +math = Math() @st.composite @@ -98,6 +109,7 @@ def test_the_purity_of_a_mixed_state(nbar): phi2=st.floats(0.0, 2 * np.pi), ) def test_join_two_states(r1, phi1, r2, phi2): + """Test Sgate acts the same in parallel or individually for two states.""" S1 = Vacuum(1) >> Sgate(r=r1, phi=phi1) S2 = Vacuum(1) >> Sgate(r=r2, phi=phi2) S12 = Vacuum(2) >> Sgate(r=[r1, r2], phi=[phi1, phi2]) @@ -113,6 +125,7 @@ def test_join_two_states(r1, phi1, r2, phi2): phi3=st.floats(0.0, 2 * np.pi), ) def test_join_three_states(r1, phi1, r2, phi2, r3, phi3): + """Test Sgate acts the same in parallel or individually for three states.""" S1 = Vacuum(1) >> Sgate(r=r1, phi=phi1) S2 = Vacuum(1) >> Sgate(r=r2, phi=phi2) S3 = Vacuum(1) >> Sgate(r=r3, phi=phi3) @@ -122,12 +135,14 @@ def test_join_three_states(r1, phi1, r2, phi2, r3, phi3): @given(xy=xy_arrays()) def test_coh_state(xy): + """Test coherent state preparation.""" x, y = xy assert Vacuum(len(x)) >> Dgate(x, y) == Coherent(x, y) @given(r=st.floats(0.0, 1.0), phi=st.floats(0.0, 2 * np.pi)) def test_sq_state(r, phi): + """Test squeezed vacuum preparation.""" assert Vacuum(1) >> Sgate(r, phi) == SqueezedVacuum(r, phi) @@ -138,10 +153,12 @@ def test_sq_state(r, phi): phi=st.floats(0.0, 2 * np.pi), ) def test_dispsq_state(x, y, r, phi): + """Test displaced squeezed state.""" assert Vacuum(1) >> Sgate(r, phi) >> Dgate(x, y) == DisplacedSqueezed(r, phi, x, y) def test_get_modes(): + """Test get_modes returns the states as expected.""" a = Gaussian(2) b = Gaussian(2) assert a == (a & b).get_modes([0, 1]) @@ -150,6 +167,7 @@ def test_get_modes(): @given(m=st.integers(0, 3)) def test_modes_after_projection(m): + """Test number of modes is correct after single projection.""" a = Gaussian(4) << Fock(1)[m] assert np.allclose(a.modes, [k for k in range(4) if k != m]) assert len(a.modes) == 3 @@ -157,6 +175,7 @@ def test_modes_after_projection(m): @given(n=st.integers(0, 3), m=st.integers(0, 3)) def test_modes_after_double_projection(n, m): + """Test number of modes is correct after double projection.""" assume(n != m) a = Gaussian(4) << Fock([1, 2])[n, m] assert np.allclose(a.modes, [k for k in range(4) if k != m and k != n]) @@ -164,7 +183,7 @@ def test_modes_after_double_projection(n, m): def test_random_state_is_entangled(): - """Tests that a Gaussian state generated at random is entangled""" + """Tests that a Gaussian state generated at random is entangled.""" state = Vacuum(2) >> Ggate(num_modes=2) mat = state.cov assert np.allclose(gp.log_negativity(mat, 2), 0.0) @@ -191,3 +210,28 @@ def test_getitem_set_modes(modes): state2 = State(ket=ket, modes=modes) assert state1.modes == state2.modes + + +@pytest.mark.parametrize("pure", [True, False]) +def test_concat_pure_states(pure): + """Test that fock states concatenate correctly and are separable""" + state1 = Fock(1, cutoffs=[15]) + state2 = Fock(4, cutoffs=[15]) + + if not pure: + state1 >>= Attenuator(transmissivity=0.95) + state2 >>= Attenuator(transmissivity=0.9) + + psi = state1 & state2 + + # test concatenated state + psi_dm = math.transpose(math.tensordot(state1.dm(), state2.dm(), [[], []]), [0, 2, 1, 3]) + assert np.allclose(psi.dm(), psi_dm) + + # trace state2 and check resulting dm corresponds to state 1 + dm1 = math.trace(math.transpose(psi.dm(), [0, 2, 1, 3])) + assert np.allclose(state1.dm(), dm1) + + # trace state1 and check resulting dm corresponds to state 2 + dm2 = math.trace(math.transpose(psi.dm(), [1, 3, 0, 2])) + assert np.allclose(state2.dm(), dm2) diff --git a/tests/test_physics/test_fock/test_fock.py b/tests/test_physics/test_fock/test_fock.py index b59bd60ec..c188c5290 100644 --- a/tests/test_physics/test_fock/test_fock.py +++ b/tests/test_physics/test_fock/test_fock.py @@ -13,11 +13,13 @@ # limitations under the License. from hypothesis import settings, given, strategies as st +import pytest import numpy as np from scipy.special import factorial from thewalrus.quantum import total_photon_number_distribution from mrmustard.lab import * +from mrmustard.physics.fock import dm_to_ket, ket_to_dm # helper strategies @@ -139,3 +141,35 @@ def test_density_matrix(num_modes): # rho_legit = L[modes](G(Vacuum(num_modes))).dm(cutoffs=cutoffs) # rho_built = G(Vacuum(num_modes=num_modes)).dm(cutoffs=cutoffs) assert np.allclose(rho_legit, rho_made) + + +@pytest.mark.parametrize( + "state", + [ + Vacuum(num_modes=2), + Fock(4), + Coherent(x=0.1, y=-0.4, cutoffs=[15]), + Gaussian(num_modes=2, cutoffs=[15]), + ], +) +def test_dm_to_ket(state): + """Tests pure state density matrix conversion to ket""" + dm = state.dm() + + ket = dm_to_ket(dm) + # check if ket is normalized + assert np.allclose(np.linalg.norm(ket), 1) + # check kets are equivalent + assert np.allclose(ket, state.ket()) + + dm_reconstructed = ket_to_dm(ket) + # check ket leads to same dm + assert np.allclose(dm, dm_reconstructed) + + +def test_dm_to_ket_error(): + """Test dm_to_ket raises an error when state is mixed""" + state = Coherent(x=0.1, y=-0.4, cutoffs=[15]) >> Attenuator(0.5) + + with pytest.raises(ValueError): + dm_to_ket(state)