Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Povm dev #31

Merged
merged 14 commits into from
Jul 18, 2022
Merged
5 changes: 3 additions & 2 deletions documentation/source/operator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ resulting in

:math:`\dot{P}^\textbf{a} = \mathrm{tr}(\mathcal{L}[M^\textbf{c}]M^\textbf{a}) T^{-1\textbf{cb}} P^\textbf{b} = \mathcal{L}^\textbf{ab}P^\textbf{b}`.

Typically :math:`\mathcal{L}[\rho]` consists of 2-body (1-body) operators which translate to (real) :math:`16\times 16` (:math:`4\times 4`) matrices in the POVM-picture. Importantly, only objects of this size need to be stored.
Typically :math:`\mathcal{L}[\rho]` consists of 2-body (1-body) operators which translate to (real) :math:`16\times 16` (:math:`4\times 4`) matrices in the POVM-picture. Importantly, only objects of this size need to be stored, but :math:`n-body operators for :math:`n>2` are also supported.
Frequently encountered unitary and dissipative operators are pre-defined and can be constructed as explained below.

Assembling operators
Expand All @@ -165,7 +165,7 @@ Using the POVM-operator class, the expression for the operator that corresponds
Lindbladian.add({"name": "decaydown", "strength": gamma, "sites": (l,)})

Adding terms to the Lindbladian is done using dictionaries, which have three entries: The name of the operator to be added, its prefactor and the site-ids that are involved.
Valid names that are recognized are the unitary 2-body (1-body) operators [``"XX"``, ``"YY"``, ``"ZZ"``] ([``"X"``, ``"Y"``, ``"Z"``]) corresponding to couplings and external magnetic fields aswell as the single-particle dissipation terms [``"dephasing"``, ``"decaydown"``, ``"decayup"``] corresponding to the dissipation operators [:math:`\sigma^z`, :math:`\sigma^-`, :math:`\sigma^+`].
Valid names that are recognized by default are the unitary 2-body (1-body) operators [``"XX"``, ``"YY"``, ``"ZZ"``] ([``"X"``, ``"Y"``, ``"Z"``]) corresponding to couplings and external magnetic fields aswell as the single-particle dissipation terms [``"dephasing"``, ``"decaydown"``, ``"decayup"``] corresponding to the dissipation operators [:math:`\sigma^z`, :math:`\sigma^-`, :math:`\sigma^+`].

Detailed documentation
----------------------
Expand All @@ -183,3 +183,4 @@ Detailed documentation
.. autofunction:: jVMC.operator.get_observables
.. autofunction:: jVMC.operator.get_1_particle_distributions
.. autofunction:: jVMC.operator.get_paulis
.. autofunction:: jVMC.operator.matrix_to_povm
2 changes: 1 addition & 1 deletion examples/ex5_dissipative_Lindblad.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def norm_fun(v, df=lambda x: x):
tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=-1.,
svdTol=1e-6, diagonalShift=0, makeReal='real', crossValidation=False)

stepper = jVMC.util.stepper.AdaptiveHeun(timeStep=1e-3, tol=1e-3) # ODE integrator
stepper = jVMC.util.stepper.AdaptiveHeun(timeStep=1e-3, tol=1e-4) # ODE integrator

res = {"X": [], "Y": [], "Z": [], "X_corr_L1": [],
"Y_corr_L1": [], "Z_corr_L1": []}
Expand Down
100 changes: 58 additions & 42 deletions jVMC/operator/povm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jVMC.operator import Operator

import functools
import itertools

opDtype = global_defs.tReal

Expand Down Expand Up @@ -351,64 +352,79 @@ def add(self, opDescr):
Args:
* ``opDescr``: Operator dictionary to be added to the operator.
"""
id = len(self.ops) + 1
opDescr["id"] = id
self.ops.append(opDescr)
self.compiled = False

def _get_s_primes(self, s, *args, stateCouplings, matEls, siteCouplings):
def apply_on_singleSiteCoupling(s, stateCouplings, matEls, siteCoupling):
def _get_s_primes(self, s, *args):
def apply_on_singleSiteCoupling(s, stateCouplings, matEls, siteCoupling, max_int_size):
stateIndices = tuple(s[siteCoupling])
OffdConfig = jnp.vstack([s] * 16)
OffdConfig = OffdConfig.at[:, siteCoupling].set(stateCouplings[stateIndices].reshape(-1, 2))
return OffdConfig.reshape((4, 4, -1)), matEls[stateIndices]

sample_shape = s.shape
OffdConfigs, matEls = jax.vmap(apply_on_singleSiteCoupling, in_axes=(None, None, 0, 0))(s.reshape(-1), stateCouplings, matEls, siteCouplings)
OffdConfig = jnp.vstack([s] * 4**max_int_size)
OffdConfig = OffdConfig.at[:, siteCoupling].set(stateCouplings[stateIndices].reshape(4**max_int_size, max_int_size))
return OffdConfig.reshape((4,) * max_int_size + (-1,)), matEls[stateIndices]

OffdConfigs, matEls = jax.vmap(apply_on_singleSiteCoupling, in_axes=(None, None, 0, 0, None))(s.reshape(-1),
self.stateCouplings,
self.matEls,
self.siteCouplings,
self.max_int_size)
return OffdConfigs.reshape((-1,) + s.shape), matEls.reshape(-1)

def compile(self):
"""Compiles an operator mapping function from the previously added dictionaries.
"""
self.siteCouplings = []
self.matEls = []
self.idxbase = jnp.array([[[i, j] for j in range(4)] for i in range(4)])
self.stateCouplings = jnp.array([[self.idxbase for j in range(4)] for i in range(4)])

# find the highest index of the (many) local Hilbert spaces
# Get maximum interaction size (max_int_size)
self.max_int_size = max([len(op["sites"]) for op in self.ops])

self.idxbase = jnp.array(list(itertools.product([0, 1, 2, 3],
repeat=self.max_int_size))).reshape((4,)*self.max_int_size+(self.max_int_size,))
self.stateCouplings = jnp.tile(self.idxbase, (4,)*self.max_int_size + (1,)*self.max_int_size + (1,))

# Find the highest index of the (many) local Hilbert spaces
self.max_site = max([max(op["sites"]) for op in self.ops])

# loop over all local Hilbert spaces
for idx in range(self.max_site + 1):
# sort interactions that involve the current local Hilbert space in the 0-th index into one- and two-body interactions
ops_oneBody = [op for op in self.ops if op["sites"][0] == idx and len(op["sites"]) == 1]
ops_twoBody = [op for op in self.ops if op["sites"][0] == idx and len(op["sites"]) == 2]

# find all the local Hilbert spaces that are coupled to the one we currently selected ("idx")
neighbour_indices = set([op["sites"][1] for op in ops_twoBody])
if len(ops_twoBody) == 0 and len(ops_oneBody) > 0:
# the current local Hilbert space is not listed in the 0-th index of any of the two body interactions, but 1-body interactions still need to respected
# create artificial coupling to a second spin with a unity operator acting on spin #2
neighbour_op_comp = jnp.zeros((16, 16), dtype=opDtype)
for op_oneBody in ops_oneBody:
neighbour_op_comp += op_oneBody["strength"] * jnp.kron(self.povm.operators[op_oneBody["name"]], jnp.eye(4, dtype=opDtype))

self.matEls.append(neighbour_op_comp.reshape((4,) * 4))
self.siteCouplings.append([idx, (idx + 1) % (self.max_site + 1)])
else:
for neighbour_idx in neighbour_indices:
# get all the operators that include the neighbour index
neighbour_ops = [op for op in ops_twoBody if op["sites"][1] == neighbour_idx]

# obtain 16x16 matrices for these operators and add the single particle operators with weight 1 / # of neighbours
neighbour_op_comp = jnp.zeros((16, 16), dtype=opDtype)
for neighbour_op in neighbour_ops:
neighbour_op_comp += neighbour_op["strength"] * self.povm.operators[neighbour_op["name"]]
for op_oneBody in ops_oneBody:
neighbour_op_comp += op_oneBody["strength"] * jnp.kron(self.povm.operators[op_oneBody["name"]], jnp.eye(4, dtype=opDtype)) / len(neighbour_indices)

# for the computed operator (16x16) obtain a representation in which a matrix element and the connected indices are stored
self.matEls.append(neighbour_op_comp.reshape((4,) * 4))
self.siteCouplings.append([idx, neighbour_idx])
# Sort interactions that involve the current local Hilbert space in the 0-th index according to number of
# sites
ops_ordered = [[op for op in self.ops if op["sites"][0] == idx and len(op["sites"]) == (n + 1)]
for n in range(self.max_int_size)]

for n in range(self.max_int_size - 1, -1, -1):
all_indices = set(op["sites"] for op in ops_ordered[n])
for i, indices in enumerate(all_indices):
# Search for operators acting on the same indices, also operators acting on fewer sites are
# accounted for here
ops_same_indices = [[op for op in ops_ordered[_n] if op["sites"] == indices[:_n+1]]
for _n in range(n+1)]
used_op_ids = [[op["id"] for op in ops_same_indices[_n]] for _n in range(n+1)]

# Add contribution of all operators in ops_same_indices, if necessary multiply unity interaction
# to additional sites to make them `max_int_size`-body interactions
neighbour_op_comp = jnp.zeros((4**self.max_int_size, 4**self.max_int_size), dtype=opDtype)
for j, ops in enumerate(ops_same_indices):
for op in ops:
op_matrix = self.povm.operators[op["name"]]
for _ in range(self.max_int_size - j - 1):
op_matrix = jnp.kron(op_matrix, jnp.eye(4, dtype=opDtype))
neighbour_op_comp += op["strength"] * op_matrix

# Avoid counting operators multiple times
ops_ordered = [[op for op in ops_ordered[k] if op["id"] not in used_op_ids[k]]
for k in range(self.max_int_size)]

while len(indices) < self.max_int_size:
empty_idx = (indices[-1] + 1) % (self.max_site + 1)
while empty_idx in indices:
empty_idx = (empty_idx + 1) % (self.max_site + 1)
indices += (empty_idx,)
self.matEls.append(neighbour_op_comp.reshape((4,) * 2 * self.max_int_size))
self.siteCouplings.append(indices)

self.siteCouplings = jnp.array(self.siteCouplings)
self.matEls = jnp.array(self.matEls)
return functools.partial(self._get_s_primes, stateCouplings=self.stateCouplings, matEls=self.matEls, siteCouplings=self.siteCouplings)
return self._get_s_primes
2 changes: 1 addition & 1 deletion jVMC/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Current jVMC version at head on Github."""
__version__ = "1.1.0"
__version__ = "1.1.1"
170 changes: 170 additions & 0 deletions tests/povm_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,52 @@

import jVMC
import jVMC.operator as op
import jax
import jax.numpy as jnp


class TestPOVM(unittest.TestCase):
def prepare_net(self, L, dt, hiddenSize=1, depth=1, cell="RNN"):
def copy_dict(a):
b = {}
for key, value in a.items():
if type(value) == type(a):
b[key] = copy_dict(value)
else:
b[key] = value
return b

sample_shape = (L,)
self.psi = jVMC.util.util.init_net({"gradient_batch_size": 5000, "net1":
{"type": "RNN",
"translation": True,
"parameters": {"inputDim": 4,
"realValuedOutput": True,
"realValuedParams": True,
"logProbFactor": 1, "hiddenSize": hiddenSize, "L": L, "depth": depth, "cell": cell}}},
sample_shape, 1234)

system_data = {"dim": "1D", "L": L}
self.povm = op.POVM(system_data)

prob_dist = jVMC.operator.povm.get_1_particle_distributions("y_up", self.povm)
prob_dist /= prob_dist[0]
biases = jnp.log(prob_dist[1:])
params = copy_dict(self.psi._param_unflatten(self.psi.get_parameters()))

params["outputDense"]["bias"] = biases
params["outputDense"]["kernel"] = 1e-15 * params["outputDense"]["kernel"]
params = jnp.concatenate([p.ravel()
for p in jax.tree_util.tree_flatten(params)[0]])
self.psi.set_parameters(params)

self.sampler = jVMC.sampler.ExactSampler(self.psi, (L,), lDim=4, logProbFactor=1)

self.tdvpEquation = jVMC.util.tdvp.TDVP(self.sampler, rhsPrefactor=-1.,
svdTol=1e-6, diagonalShift=0, makeReal='real', crossValidation=False)

self.stepper = jVMC.util.stepper.Euler(timeStep=dt) # ODE integrator

def test_matrix_to_povm(self):
unity = jnp.eye(2)
zero_matrix = jnp.zeros((2, 2))
Expand Down Expand Up @@ -53,6 +95,134 @@ def test_adding_operator(self):
self.assertRaises(ValueError, povm.add_dissipator, "unity", op.matrix_to_povm(unity, povm.M,
povm.T_inv, mode='dissipative'))

def test_time_evolution_one_site(self):
# This tests the time evolution of a sample system and compares it with the analytical solution

L = 3
Tmax = 2
dt = 1E-3

self.prepare_net(L, dt, hiddenSize=1, depth=1)

Lindbladian = op.POVMOperator(self.povm)
for l in range(L):
Lindbladian.add({"name": "X", "strength": 3.0, "sites": (l,)})
Lindbladian.add({"name": "dephasing", "strength": 1.0, "sites": (l,)})


res = {"X": [], "Y": [], "Z": []}

times = jnp.linspace(0, Tmax, int(Tmax / dt))
for i in range(int(Tmax / dt)):
result = jVMC.operator.povm.measure_povm(Lindbladian.povm, self.sampler)
for dim in ["X", "Y", "Z"]:
res[dim].append(result[dim]["mean"])

dp, _ = self.stepper.step(0, self.tdvpEquation, self.psi.get_parameters(), hamiltonian=Lindbladian,
psi=self.psi)
self.psi.set_parameters(dp)

# Analytical solution
w = jnp.sqrt(35)
Sx_avg = jnp.zeros_like(times)
Sy_avg = (w*jnp.cos(w*times)-jnp.sin(w*times))/w*jnp.exp(-times)
Sz_avg = 6/w*jnp.sin(w*times)*jnp.exp(-times)

self.assertTrue(jnp.allclose(Sx_avg, jnp.asarray(res["X"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sy_avg, jnp.asarray(res["Y"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sz_avg, jnp.asarray(res["Z"]), atol=1e-2))

def test_time_evolution_two_site(self):
# This tests the time evolution of a sample system and compares it with the analytical solution

L = 3
Tmax = 2
dt = 1E-3

self.prepare_net(L, dt, hiddenSize=3, depth=1)

sx = op.get_paulis()[0]
XX_ = jnp.kron(sx, sx)
M_2_body = jnp.array(
[[jnp.kron(self.povm.M[i], self.povm.M[j]) for j in range(4)] for i in range(4)]).reshape(16, 4, 4)
T_inv_2_body = jnp.kron(self.povm.T_inv, self.povm.T_inv)

self.povm.add_dissipator("XX_", op.matrix_to_povm(XX_, M_2_body, T_inv_2_body, mode="dissipative"))

Lindbladian = op.POVMOperator(self.povm)
Lindbladian.add({"name": "XX_", "strength": 1.0, "sites": (0, 1)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (0,)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (1,)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (2,)})

res = {"X": [], "Y": [], "Z": []}

times = jnp.linspace(0, Tmax, int(Tmax / dt))
for i in range(int(Tmax / dt)):
result = jVMC.operator.povm.measure_povm(Lindbladian.povm, self.sampler)
for dim in ["X", "Y", "Z"]:
res[dim].append(result[dim]["mean"])

dp, _ = self.stepper.step(0, self.tdvpEquation, self.psi.get_parameters(), hamiltonian=Lindbladian,
psi=self.psi)
self.psi.set_parameters(dp)

# Analytical solution
w = jnp.sqrt(35)
Sx_avg = -jnp.sin(6*times)/3 - 4/w*jnp.sin(w*times)*jnp.exp(-times)
Sy_avg = jnp.cos(6*times)/3 + (2/3*jnp.cos(w*times) - 2/3/w*jnp.sin(w*times))*jnp.exp(-times)
Sz_avg = jnp.zeros_like(times)

self.assertTrue(jnp.allclose(Sx_avg, jnp.asarray(res["X"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sy_avg, jnp.asarray(res["Y"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sz_avg, jnp.asarray(res["Z"]), atol=1e-2))

def test_time_evolution_three_site(self):
# This tests the time evolution of a sample system and compares it with the analytical solution

L = 3
Tmax = 2
dt = 1E-3

self.prepare_net(L, dt, hiddenSize=3, depth=1)

sx = op.get_paulis()[0]
XXX = jnp.kron(jnp.kron(sx, sx), sx)
M_3_body = jnp.array(
[[[jnp.kron(jnp.kron(self.povm.M[i], self.povm.M[j]), self.povm.M[k]) for j in range(4)] for i in range(4)]
for k in range(4)]).reshape(64, 8, 8)
T_inv_3_body = jnp.kron(jnp.kron(self.povm.T_inv, self.povm.T_inv), self.povm.T_inv)

self.povm.add_dissipator("XXX", op.matrix_to_povm(XXX, M_3_body, T_inv_3_body, mode="dissipative"))

Lindbladian = op.POVMOperator(self.povm)
Lindbladian.add({"name": "XXX", "strength": 1.0, "sites": (0, 1, 2)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (0,)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (1,)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (2,)})

res = {"X": [], "Y": [], "Z": []}

times = jnp.linspace(0, Tmax, int(Tmax / dt))
for i in range(int(Tmax / dt)):
result = jVMC.operator.povm.measure_povm(Lindbladian.povm, self.sampler)
for dim in ["X", "Y", "Z"]:
res[dim].append(result[dim]["mean"])

dp, _ = self.stepper.step(0, self.tdvpEquation, self.psi.get_parameters(), hamiltonian=Lindbladian,
psi=self.psi)
self.psi.set_parameters(dp)

# Analytical solution
w = jnp.sqrt(35)
Sx_avg = -6*jnp.sin(w*times)*jnp.exp(-times)/w
Sy_avg = jnp.cos(w*times)*jnp.exp(-times) - jnp.sin(w*times)*jnp.exp(-times)/w
Sz_avg = jnp.zeros_like(times)

self.assertTrue(jnp.allclose(Sx_avg, jnp.asarray(res["X"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sy_avg, jnp.asarray(res["Y"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sz_avg, jnp.asarray(res["Z"]), atol=1e-2))


if __name__ == '__main__':
unittest.main()