Skip to content

Commit

Permalink
Merge pull request #31 from markusschmitt/povm_dev
Browse files Browse the repository at this point in the history
Povm dev
  • Loading branch information
laurinbrunner authored Jul 18, 2022
2 parents 91cae31 + a7b8c1b commit d9af4c3
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 46 deletions.
5 changes: 3 additions & 2 deletions documentation/source/operator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ resulting in

:math:`\dot{P}^\textbf{a} = \mathrm{tr}(\mathcal{L}[M^\textbf{c}]M^\textbf{a}) T^{-1\textbf{cb}} P^\textbf{b} = \mathcal{L}^\textbf{ab}P^\textbf{b}`.

Typically :math:`\mathcal{L}[\rho]` consists of 2-body (1-body) operators which translate to (real) :math:`16\times 16` (:math:`4\times 4`) matrices in the POVM-picture. Importantly, only objects of this size need to be stored.
Typically :math:`\mathcal{L}[\rho]` consists of 2-body (1-body) operators which translate to (real) :math:`16\times 16` (:math:`4\times 4`) matrices in the POVM-picture. Importantly, only objects of this size need to be stored, but :math:`n-body operators for :math:`n>2` are also supported.
Frequently encountered unitary and dissipative operators are pre-defined and can be constructed as explained below.

Assembling operators
Expand All @@ -165,7 +165,7 @@ Using the POVM-operator class, the expression for the operator that corresponds
Lindbladian.add({"name": "decaydown", "strength": gamma, "sites": (l,)})

Adding terms to the Lindbladian is done using dictionaries, which have three entries: The name of the operator to be added, its prefactor and the site-ids that are involved.
Valid names that are recognized are the unitary 2-body (1-body) operators [``"XX"``, ``"YY"``, ``"ZZ"``] ([``"X"``, ``"Y"``, ``"Z"``]) corresponding to couplings and external magnetic fields aswell as the single-particle dissipation terms [``"dephasing"``, ``"decaydown"``, ``"decayup"``] corresponding to the dissipation operators [:math:`\sigma^z`, :math:`\sigma^-`, :math:`\sigma^+`].
Valid names that are recognized by default are the unitary 2-body (1-body) operators [``"XX"``, ``"YY"``, ``"ZZ"``] ([``"X"``, ``"Y"``, ``"Z"``]) corresponding to couplings and external magnetic fields aswell as the single-particle dissipation terms [``"dephasing"``, ``"decaydown"``, ``"decayup"``] corresponding to the dissipation operators [:math:`\sigma^z`, :math:`\sigma^-`, :math:`\sigma^+`].

Detailed documentation
----------------------
Expand All @@ -183,3 +183,4 @@ Detailed documentation
.. autofunction:: jVMC.operator.get_observables
.. autofunction:: jVMC.operator.get_1_particle_distributions
.. autofunction:: jVMC.operator.get_paulis
.. autofunction:: jVMC.operator.matrix_to_povm
2 changes: 1 addition & 1 deletion examples/ex5_dissipative_Lindblad.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def norm_fun(v, df=lambda x: x):
tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=-1.,
svdTol=1e-6, diagonalShift=0, makeReal='real', crossValidation=False)

stepper = jVMC.util.stepper.AdaptiveHeun(timeStep=1e-3, tol=1e-3) # ODE integrator
stepper = jVMC.util.stepper.AdaptiveHeun(timeStep=1e-3, tol=1e-4) # ODE integrator

res = {"X": [], "Y": [], "Z": [], "X_corr_L1": [],
"Y_corr_L1": [], "Z_corr_L1": []}
Expand Down
100 changes: 58 additions & 42 deletions jVMC/operator/povm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jVMC.operator import Operator

import functools
import itertools

opDtype = global_defs.tReal

Expand Down Expand Up @@ -351,64 +352,79 @@ def add(self, opDescr):
Args:
* ``opDescr``: Operator dictionary to be added to the operator.
"""
id = len(self.ops) + 1
opDescr["id"] = id
self.ops.append(opDescr)
self.compiled = False

def _get_s_primes(self, s, *args, stateCouplings, matEls, siteCouplings):
def apply_on_singleSiteCoupling(s, stateCouplings, matEls, siteCoupling):
def _get_s_primes(self, s, *args):
def apply_on_singleSiteCoupling(s, stateCouplings, matEls, siteCoupling, max_int_size):
stateIndices = tuple(s[siteCoupling])
OffdConfig = jnp.vstack([s] * 16)
OffdConfig = OffdConfig.at[:, siteCoupling].set(stateCouplings[stateIndices].reshape(-1, 2))
return OffdConfig.reshape((4, 4, -1)), matEls[stateIndices]

sample_shape = s.shape
OffdConfigs, matEls = jax.vmap(apply_on_singleSiteCoupling, in_axes=(None, None, 0, 0))(s.reshape(-1), stateCouplings, matEls, siteCouplings)
OffdConfig = jnp.vstack([s] * 4**max_int_size)
OffdConfig = OffdConfig.at[:, siteCoupling].set(stateCouplings[stateIndices].reshape(4**max_int_size, max_int_size))
return OffdConfig.reshape((4,) * max_int_size + (-1,)), matEls[stateIndices]

OffdConfigs, matEls = jax.vmap(apply_on_singleSiteCoupling, in_axes=(None, None, 0, 0, None))(s.reshape(-1),
self.stateCouplings,
self.matEls,
self.siteCouplings,
self.max_int_size)
return OffdConfigs.reshape((-1,) + s.shape), matEls.reshape(-1)

def compile(self):
"""Compiles an operator mapping function from the previously added dictionaries.
"""
self.siteCouplings = []
self.matEls = []
self.idxbase = jnp.array([[[i, j] for j in range(4)] for i in range(4)])
self.stateCouplings = jnp.array([[self.idxbase for j in range(4)] for i in range(4)])

# find the highest index of the (many) local Hilbert spaces
# Get maximum interaction size (max_int_size)
self.max_int_size = max([len(op["sites"]) for op in self.ops])

self.idxbase = jnp.array(list(itertools.product([0, 1, 2, 3],
repeat=self.max_int_size))).reshape((4,)*self.max_int_size+(self.max_int_size,))
self.stateCouplings = jnp.tile(self.idxbase, (4,)*self.max_int_size + (1,)*self.max_int_size + (1,))

# Find the highest index of the (many) local Hilbert spaces
self.max_site = max([max(op["sites"]) for op in self.ops])

# loop over all local Hilbert spaces
for idx in range(self.max_site + 1):
# sort interactions that involve the current local Hilbert space in the 0-th index into one- and two-body interactions
ops_oneBody = [op for op in self.ops if op["sites"][0] == idx and len(op["sites"]) == 1]
ops_twoBody = [op for op in self.ops if op["sites"][0] == idx and len(op["sites"]) == 2]

# find all the local Hilbert spaces that are coupled to the one we currently selected ("idx")
neighbour_indices = set([op["sites"][1] for op in ops_twoBody])
if len(ops_twoBody) == 0 and len(ops_oneBody) > 0:
# the current local Hilbert space is not listed in the 0-th index of any of the two body interactions, but 1-body interactions still need to respected
# create artificial coupling to a second spin with a unity operator acting on spin #2
neighbour_op_comp = jnp.zeros((16, 16), dtype=opDtype)
for op_oneBody in ops_oneBody:
neighbour_op_comp += op_oneBody["strength"] * jnp.kron(self.povm.operators[op_oneBody["name"]], jnp.eye(4, dtype=opDtype))

self.matEls.append(neighbour_op_comp.reshape((4,) * 4))
self.siteCouplings.append([idx, (idx + 1) % (self.max_site + 1)])
else:
for neighbour_idx in neighbour_indices:
# get all the operators that include the neighbour index
neighbour_ops = [op for op in ops_twoBody if op["sites"][1] == neighbour_idx]

# obtain 16x16 matrices for these operators and add the single particle operators with weight 1 / # of neighbours
neighbour_op_comp = jnp.zeros((16, 16), dtype=opDtype)
for neighbour_op in neighbour_ops:
neighbour_op_comp += neighbour_op["strength"] * self.povm.operators[neighbour_op["name"]]
for op_oneBody in ops_oneBody:
neighbour_op_comp += op_oneBody["strength"] * jnp.kron(self.povm.operators[op_oneBody["name"]], jnp.eye(4, dtype=opDtype)) / len(neighbour_indices)

# for the computed operator (16x16) obtain a representation in which a matrix element and the connected indices are stored
self.matEls.append(neighbour_op_comp.reshape((4,) * 4))
self.siteCouplings.append([idx, neighbour_idx])
# Sort interactions that involve the current local Hilbert space in the 0-th index according to number of
# sites
ops_ordered = [[op for op in self.ops if op["sites"][0] == idx and len(op["sites"]) == (n + 1)]
for n in range(self.max_int_size)]

for n in range(self.max_int_size - 1, -1, -1):
all_indices = set(op["sites"] for op in ops_ordered[n])
for i, indices in enumerate(all_indices):
# Search for operators acting on the same indices, also operators acting on fewer sites are
# accounted for here
ops_same_indices = [[op for op in ops_ordered[_n] if op["sites"] == indices[:_n+1]]
for _n in range(n+1)]
used_op_ids = [[op["id"] for op in ops_same_indices[_n]] for _n in range(n+1)]

# Add contribution of all operators in ops_same_indices, if necessary multiply unity interaction
# to additional sites to make them `max_int_size`-body interactions
neighbour_op_comp = jnp.zeros((4**self.max_int_size, 4**self.max_int_size), dtype=opDtype)
for j, ops in enumerate(ops_same_indices):
for op in ops:
op_matrix = self.povm.operators[op["name"]]
for _ in range(self.max_int_size - j - 1):
op_matrix = jnp.kron(op_matrix, jnp.eye(4, dtype=opDtype))
neighbour_op_comp += op["strength"] * op_matrix

# Avoid counting operators multiple times
ops_ordered = [[op for op in ops_ordered[k] if op["id"] not in used_op_ids[k]]
for k in range(self.max_int_size)]

while len(indices) < self.max_int_size:
empty_idx = (indices[-1] + 1) % (self.max_site + 1)
while empty_idx in indices:
empty_idx = (empty_idx + 1) % (self.max_site + 1)
indices += (empty_idx,)
self.matEls.append(neighbour_op_comp.reshape((4,) * 2 * self.max_int_size))
self.siteCouplings.append(indices)

self.siteCouplings = jnp.array(self.siteCouplings)
self.matEls = jnp.array(self.matEls)
return functools.partial(self._get_s_primes, stateCouplings=self.stateCouplings, matEls=self.matEls, siteCouplings=self.siteCouplings)
return self._get_s_primes
2 changes: 1 addition & 1 deletion jVMC/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Current jVMC version at head on Github."""
__version__ = "1.1.0"
__version__ = "1.1.1"
170 changes: 170 additions & 0 deletions tests/povm_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,52 @@

import jVMC
import jVMC.operator as op
import jax
import jax.numpy as jnp


class TestPOVM(unittest.TestCase):
def prepare_net(self, L, dt, hiddenSize=1, depth=1, cell="RNN"):
def copy_dict(a):
b = {}
for key, value in a.items():
if type(value) == type(a):
b[key] = copy_dict(value)
else:
b[key] = value
return b

sample_shape = (L,)
self.psi = jVMC.util.util.init_net({"gradient_batch_size": 5000, "net1":
{"type": "RNN",
"translation": True,
"parameters": {"inputDim": 4,
"realValuedOutput": True,
"realValuedParams": True,
"logProbFactor": 1, "hiddenSize": hiddenSize, "L": L, "depth": depth, "cell": cell}}},
sample_shape, 1234)

system_data = {"dim": "1D", "L": L}
self.povm = op.POVM(system_data)

prob_dist = jVMC.operator.povm.get_1_particle_distributions("y_up", self.povm)
prob_dist /= prob_dist[0]
biases = jnp.log(prob_dist[1:])
params = copy_dict(self.psi._param_unflatten(self.psi.get_parameters()))

params["outputDense"]["bias"] = biases
params["outputDense"]["kernel"] = 1e-15 * params["outputDense"]["kernel"]
params = jnp.concatenate([p.ravel()
for p in jax.tree_util.tree_flatten(params)[0]])
self.psi.set_parameters(params)

self.sampler = jVMC.sampler.ExactSampler(self.psi, (L,), lDim=4, logProbFactor=1)

self.tdvpEquation = jVMC.util.tdvp.TDVP(self.sampler, rhsPrefactor=-1.,
svdTol=1e-6, diagonalShift=0, makeReal='real', crossValidation=False)

self.stepper = jVMC.util.stepper.Euler(timeStep=dt) # ODE integrator

def test_matrix_to_povm(self):
unity = jnp.eye(2)
zero_matrix = jnp.zeros((2, 2))
Expand Down Expand Up @@ -53,6 +95,134 @@ def test_adding_operator(self):
self.assertRaises(ValueError, povm.add_dissipator, "unity", op.matrix_to_povm(unity, povm.M,
povm.T_inv, mode='dissipative'))

def test_time_evolution_one_site(self):
# This tests the time evolution of a sample system and compares it with the analytical solution

L = 3
Tmax = 2
dt = 1E-3

self.prepare_net(L, dt, hiddenSize=1, depth=1)

Lindbladian = op.POVMOperator(self.povm)
for l in range(L):
Lindbladian.add({"name": "X", "strength": 3.0, "sites": (l,)})
Lindbladian.add({"name": "dephasing", "strength": 1.0, "sites": (l,)})


res = {"X": [], "Y": [], "Z": []}

times = jnp.linspace(0, Tmax, int(Tmax / dt))
for i in range(int(Tmax / dt)):
result = jVMC.operator.povm.measure_povm(Lindbladian.povm, self.sampler)
for dim in ["X", "Y", "Z"]:
res[dim].append(result[dim]["mean"])

dp, _ = self.stepper.step(0, self.tdvpEquation, self.psi.get_parameters(), hamiltonian=Lindbladian,
psi=self.psi)
self.psi.set_parameters(dp)

# Analytical solution
w = jnp.sqrt(35)
Sx_avg = jnp.zeros_like(times)
Sy_avg = (w*jnp.cos(w*times)-jnp.sin(w*times))/w*jnp.exp(-times)
Sz_avg = 6/w*jnp.sin(w*times)*jnp.exp(-times)

self.assertTrue(jnp.allclose(Sx_avg, jnp.asarray(res["X"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sy_avg, jnp.asarray(res["Y"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sz_avg, jnp.asarray(res["Z"]), atol=1e-2))

def test_time_evolution_two_site(self):
# This tests the time evolution of a sample system and compares it with the analytical solution

L = 3
Tmax = 2
dt = 1E-3

self.prepare_net(L, dt, hiddenSize=3, depth=1)

sx = op.get_paulis()[0]
XX_ = jnp.kron(sx, sx)
M_2_body = jnp.array(
[[jnp.kron(self.povm.M[i], self.povm.M[j]) for j in range(4)] for i in range(4)]).reshape(16, 4, 4)
T_inv_2_body = jnp.kron(self.povm.T_inv, self.povm.T_inv)

self.povm.add_dissipator("XX_", op.matrix_to_povm(XX_, M_2_body, T_inv_2_body, mode="dissipative"))

Lindbladian = op.POVMOperator(self.povm)
Lindbladian.add({"name": "XX_", "strength": 1.0, "sites": (0, 1)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (0,)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (1,)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (2,)})

res = {"X": [], "Y": [], "Z": []}

times = jnp.linspace(0, Tmax, int(Tmax / dt))
for i in range(int(Tmax / dt)):
result = jVMC.operator.povm.measure_povm(Lindbladian.povm, self.sampler)
for dim in ["X", "Y", "Z"]:
res[dim].append(result[dim]["mean"])

dp, _ = self.stepper.step(0, self.tdvpEquation, self.psi.get_parameters(), hamiltonian=Lindbladian,
psi=self.psi)
self.psi.set_parameters(dp)

# Analytical solution
w = jnp.sqrt(35)
Sx_avg = -jnp.sin(6*times)/3 - 4/w*jnp.sin(w*times)*jnp.exp(-times)
Sy_avg = jnp.cos(6*times)/3 + (2/3*jnp.cos(w*times) - 2/3/w*jnp.sin(w*times))*jnp.exp(-times)
Sz_avg = jnp.zeros_like(times)

self.assertTrue(jnp.allclose(Sx_avg, jnp.asarray(res["X"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sy_avg, jnp.asarray(res["Y"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sz_avg, jnp.asarray(res["Z"]), atol=1e-2))

def test_time_evolution_three_site(self):
# This tests the time evolution of a sample system and compares it with the analytical solution

L = 3
Tmax = 2
dt = 1E-3

self.prepare_net(L, dt, hiddenSize=3, depth=1)

sx = op.get_paulis()[0]
XXX = jnp.kron(jnp.kron(sx, sx), sx)
M_3_body = jnp.array(
[[[jnp.kron(jnp.kron(self.povm.M[i], self.povm.M[j]), self.povm.M[k]) for j in range(4)] for i in range(4)]
for k in range(4)]).reshape(64, 8, 8)
T_inv_3_body = jnp.kron(jnp.kron(self.povm.T_inv, self.povm.T_inv), self.povm.T_inv)

self.povm.add_dissipator("XXX", op.matrix_to_povm(XXX, M_3_body, T_inv_3_body, mode="dissipative"))

Lindbladian = op.POVMOperator(self.povm)
Lindbladian.add({"name": "XXX", "strength": 1.0, "sites": (0, 1, 2)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (0,)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (1,)})
Lindbladian.add({"name": "Z", "strength": 3.0, "sites": (2,)})

res = {"X": [], "Y": [], "Z": []}

times = jnp.linspace(0, Tmax, int(Tmax / dt))
for i in range(int(Tmax / dt)):
result = jVMC.operator.povm.measure_povm(Lindbladian.povm, self.sampler)
for dim in ["X", "Y", "Z"]:
res[dim].append(result[dim]["mean"])

dp, _ = self.stepper.step(0, self.tdvpEquation, self.psi.get_parameters(), hamiltonian=Lindbladian,
psi=self.psi)
self.psi.set_parameters(dp)

# Analytical solution
w = jnp.sqrt(35)
Sx_avg = -6*jnp.sin(w*times)*jnp.exp(-times)/w
Sy_avg = jnp.cos(w*times)*jnp.exp(-times) - jnp.sin(w*times)*jnp.exp(-times)/w
Sz_avg = jnp.zeros_like(times)

self.assertTrue(jnp.allclose(Sx_avg, jnp.asarray(res["X"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sy_avg, jnp.asarray(res["Y"]), atol=1e-2))
self.assertTrue(jnp.allclose(Sz_avg, jnp.asarray(res["Z"]), atol=1e-2))


if __name__ == '__main__':
unittest.main()

0 comments on commit d9af4c3

Please sign in to comment.