diff --git a/examples/ex0_ground_state_search.ipynb b/examples/ex0_ground_state_search.ipynb index c05441d..b173ddb 100644 --- a/examples/ex0_ground_state_search.ipynb +++ b/examples/ex0_ground_state_search.ipynb @@ -14,8 +14,7 @@ "sys.path.append(sys.path[0] + \"/..\")\n", "\n", "import jax\n", - "from jax.config import config\n", - "config.update(\"jax_enable_x64\", True)\n", + "jax.config.update(\"jax_enable_x64\", True)\n", "\n", "import jax.random as random\n", "import jax.numpy as jnp\n", diff --git a/examples/ex0_ground_state_search.py b/examples/ex0_ground_state_search.py index 17c8090..b916e7c 100644 --- a/examples/ex0_ground_state_search.py +++ b/examples/ex0_ground_state_search.py @@ -2,8 +2,7 @@ # coding: utf-8 import jax -from jax.config import config -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) import jax.random as random import jax.numpy as jnp diff --git a/examples/ex2_unitary_time_evolution.py b/examples/ex2_unitary_time_evolution.py index 37c7fd3..e1f31aa 100644 --- a/examples/ex2_unitary_time_evolution.py +++ b/examples/ex2_unitary_time_evolution.py @@ -1,12 +1,7 @@ import os import jax -from jax.config import config -config.update("jax_enable_x64", True) - -import jax.random as random -import flax -import jax.numpy as jnp +jax.config.update("jax_enable_x64", True) import numpy as np diff --git a/examples/ex3_custom_net.py b/examples/ex3_custom_net.py index 746437e..d8a93e5 100644 --- a/examples/ex3_custom_net.py +++ b/examples/ex3_custom_net.py @@ -1,7 +1,6 @@ import jax import flax import jVMC -import numpy as np # This class defines the network structure of a complex RBM diff --git a/examples/ex4_benchmarking.py b/examples/ex4_benchmarking.py index 5833274..8971f8e 100644 --- a/examples/ex4_benchmarking.py +++ b/examples/ex4_benchmarking.py @@ -1,8 +1,7 @@ import os import jax -from jax.config import config -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) import jax.random as random import jax.numpy as jnp diff --git a/examples/ex5_dissipative_Lindblad.py b/examples/ex5_dissipative_Lindblad.py index fe638e6..571f2f0 100644 --- a/examples/ex5_dissipative_Lindblad.py +++ b/examples/ex5_dissipative_Lindblad.py @@ -1,13 +1,9 @@ import matplotlib.pyplot as plt -import numpy as np -import flax -import jax.random as random -from jax.config import config import jax.numpy as jnp import jax import jVMC from functools import partial -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) def copy_dict(a): diff --git a/examples/ex6_dissipative_Lindblad_2D.py b/examples/ex6_dissipative_Lindblad_2D.py index 7ed97fe..3397273 100644 --- a/examples/ex6_dissipative_Lindblad_2D.py +++ b/examples/ex6_dissipative_Lindblad_2D.py @@ -1,12 +1,8 @@ import matplotlib.pyplot as plt -import numpy as np -import flax -import jax.random as random -from jax.config import config import jax.numpy as jnp import jax import jVMC -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) from functools import partial diff --git a/examples/ex7_fermions.ipynb b/examples/ex7_fermions.ipynb index ee4100a..362bb0a 100644 --- a/examples/ex7_fermions.ipynb +++ b/examples/ex7_fermions.ipynb @@ -27,8 +27,7 @@ "\n", "# jax\n", "import jax\n", - "from jax.config import config\n", - "config.update(\"jax_enable_x64\", True)\n", + "jax.config.update(\"jax_enable_x64\", True)\n", "import jax.numpy as jnp\n", "import flax.linen as nn\n", "\n", @@ -138,11 +137,11 @@ " if i == flavourL//flavour-1:\n", " continue\n", " # up chain hopping\n", - " hamiltonian.add(op.scal_opstr( t, ( annihilation(site1UP + i) ,creation(site1UP + i + 1) ) ) )\n", - " hamiltonian.add(op.scal_opstr( t, ( annihilation(site1UP + i + 1) ,creation(site1UP + i) ) ) )\n", + " hamiltonian.add(op.scal_opstr( t, ( creation(site1UP + i + 1), annihilation(site1UP + i) , ) ) )\n", + " hamiltonian.add(op.scal_opstr( t, ( creation(site1UP + i), annihilation(site1UP + i + 1) , ) ) )\n", " # down chain hopping\n", - " hamiltonian.add(op.scal_opstr( t, ( annihilation(site1DO - i) ,creation(site1DO - i - 1) ) ) )\n", - " hamiltonian.add(op.scal_opstr( t, ( annihilation(site1DO - i - 1) ,creation(site1DO - i) ) ) )" + " hamiltonian.add(op.scal_opstr( t, ( creation(site1DO - i - 1), annihilation(site1DO - i) , ) ) )\n", + " hamiltonian.add(op.scal_opstr( t, ( creation(site1DO - i), annihilation(site1DO - i - 1) , ) ) )" ] }, { diff --git a/jVMC/operator/branch_free.py b/jVMC/operator/branch_free.py index 555ba6d..6db54f9 100644 --- a/jVMC/operator/branch_free.py +++ b/jVMC/operator/branch_free.py @@ -1,5 +1,5 @@ import jax -from jax import jit, vmap, grad +from jax import vmap import jax.numpy as jnp import numpy as np @@ -8,7 +8,6 @@ from . import Operator import functools -import sys opDtype = global_defs.tCpx @@ -28,8 +27,10 @@ def Id(idx=0, lDim=2): """ - return {'idx': idx, 'map': jnp.array([j for j in range(lDim)], dtype=np.int32), - 'matEls': jnp.array([1. for j in range(lDim)], dtype=opDtype), 'diag': True} + return LocalOp(idx = idx, + map = jnp.array([j for j in range(lDim)], dtype=np.int32), + matEls = jnp.array([1. for j in range(lDim)], dtype=opDtype), + diag = True) def Sx(idx): @@ -44,7 +45,10 @@ def Sx(idx): """ - return {'idx': idx, 'map': jnp.array([1, 0], dtype=np.int32), 'matEls': jnp.array([1.0, 1.0], dtype=opDtype), 'diag': False} + return LocalOp(idx = idx, + map = jnp.array([1, 0], dtype=np.int32), + matEls = jnp.array([1.0, 1.0], dtype=opDtype), + diag = False) def Sy(idx): @@ -59,7 +63,10 @@ def Sy(idx): """ - return {'idx': idx, 'map': jnp.array([1, 0], dtype=np.int32), 'matEls': jnp.array([1.j, -1.j], dtype=opDtype), 'diag': False} + return LocalOp(idx = idx, + map = jnp.array([1, 0], dtype=np.int32), + matEls = jnp.array([1.j, -1.j], dtype=opDtype), + diag = False) def Sz(idx): @@ -74,7 +81,10 @@ def Sz(idx): """ - return {'idx': idx, 'map': jnp.array([0, 1], dtype=np.int32), 'matEls': jnp.array([-1.0, 1.0], dtype=opDtype), 'diag': True} + return LocalOp(idx = idx, + map = jnp.array([0, 1], dtype=np.int32), + matEls = jnp.array([-1.0, 1.0], dtype=opDtype), + diag = True) def Sp(idx): @@ -89,7 +99,10 @@ def Sp(idx): """ - return {'idx': idx, 'map': jnp.array([1, 0], dtype=np.int32), 'matEls': jnp.array([1.0, 0.0], dtype=opDtype), 'diag': False} + return LocalOp(idx = idx, + map = jnp.array([1, 0], dtype=np.int32), + matEls = jnp.array([1.0, 0.0], dtype=opDtype), + diag = False) def Sm(idx): @@ -104,74 +117,77 @@ def Sm(idx): """ - return {'idx': idx, 'map': jnp.array([0, 0], dtype=np.int32), 'matEls': jnp.array([0.0, 1.0], dtype=opDtype), 'diag': False} + return LocalOp(idx = idx, + map = jnp.array([0, 0], dtype=np.int32), + matEls = jnp.array([0.0, 1.0], dtype=opDtype), + diag = False) ###################### # fermionic number operator def number(idx): - """Returns a :math:`c^\dagger c` femrioic number operator + """Returns a :math:`c^\dagger c` fermionic number operator Args: * ``idx``: Index of the local Hilbert space. Returns: - Dictionary defining :math:`c^\dagger c` femrioic number operator + Dictionary defining :math:`c^\dagger c` fermionic number operator """ - return { - 'idx': idx, - 'map': jax.numpy.array([0,1],dtype=np.int32), - 'matEls': jax.numpy.array([0.,1.],dtype=opDtype), - 'diag': True, - 'fermionic': False - } + return LocalOp( + idx = idx, + map = jax.numpy.array([0,1],dtype=np.int32), + matEls = jax.numpy.array([0.,1.],dtype=opDtype), + diag = True, + fermionic = False + ) ###################### # fermionic creation operator def creation(idx): - """Returns a :math:`c^\dagger` femrioic creation operator + """Returns a :math:`c^\dagger` fermionic creation operator Args: * ``idx``: Index of the local Hilbert space. Returns: - Dictionary defining :math:`c^\dagger` femrioic creation operator + Dictionary defining :math:`c^\dagger` fermionic creation operator """ - return { - 'idx': idx, - 'map': jax.numpy.array([1,0],dtype=np.int32), - 'matEls': jax.numpy.array([1.,0.],dtype=opDtype), - 'diag': False, - "fermionic": True - } + return LocalOp( + idx = idx, + map = jax.numpy.array([1,0],dtype=np.int32), + matEls = jax.numpy.array([1.,0.],dtype=opDtype), + diag = False, + fermionic = True + ) ###################### # fermionic annihilation operator def annihilation(idx): - """Returns a :math:`c` femrioic creation operator + """Returns a :math:`c` fermionic creation operator Args: * ``idx``: Index of the local Hilbert space. Returns: - Dictionary defining :math:`c` femrioic creation operator + Dictionary defining :math:`c` fermionic creation operator """ - return { - 'idx': idx, - 'map': jax.numpy.array([1,0],dtype=np.int32), - 'matEls': jax.numpy.array([0.,1.],dtype=opDtype), - 'diag': False, - "fermionic": True - } + return LocalOp( + idx = idx, + map = jax.numpy.array([1,0],dtype=np.int32), + matEls = jax.numpy.array([0.,1.],dtype=opDtype), + diag = False, + fermionic = True + ) import copy @@ -183,6 +199,62 @@ def _id_prefactor(*args, val=1.0, **kwargs): def _prod_fun(f1, f2, *args, **kwargs): return f1(*args) * f2(*args) + +class OpStr(tuple): + """This class provides the interface for operator strings + """ + + def __init__(self, *args): + + super(OpStr, self).__init__() + + + def __new__(cls, *args): + + factors = [] + ops = [] + for o in args: + if isinstance(o, (LocalOp, dict)): + ops.append(o) + else: + if callable(o): + factors.append(o) + else: + factors.append(functools.partial(_id_prefactor, val=o)) + + while len(factors)>1: + factors[0] = functools.partial(_prod_fun, f1=factors[0], f2=factors.pop()) + + return super(OpStr, cls).__new__(cls, tuple(factors + ops)) + + + def __mul__(self, other): + + if not isinstance(other, (tuple, OpStr)): + other = OpStr(other) + + if callable(other[0]): + return OpStr(*(other[0] * self), *(other[1:])) + + return OpStr(*self, *other) + + def __rmul__(self, a): + + if isinstance(a, dict): + return OpStr(LocalOp(**a), *self) + + newOp = [copy.deepcopy(o) for o in self] + if not callable(a): + a = functools.partial(_id_prefactor, val=a) + + if callable(newOp[0]): + newOp[0] = functools.partial(_prod_fun, f1=a, f2=newOp[0]) + else: + newOp = [a] + newOp + + return OpStr(*tuple(newOp)) + + def scal_opstr(a, op): """Add prefactor to operator string @@ -195,16 +267,46 @@ def scal_opstr(a, op): """ - newOp = [copy.deepcopy(o) for o in op] - if not callable(a): - a = functools.partial(_id_prefactor, val=a) + if not isinstance(op, (tuple, OpStr)): + raise RuntimeError("Can add prefactors only to OpStr or tuple objects.") + + if isinstance(op, tuple): + op = OpStr(*op) + + return a * op + + +class LocalOp(dict): + """This class provides the interface for operators acting on a local Hilbert space + + Initializer arguments: + + * "idx": Lattice site index, + * "map": Indices of non-zero matrix elements, + * "matEls": Non-zero matrix elements, + * "diag": Boolean indicating, whether the operator is diagonal, + * "fermionic": Boolean indicating, whether this is a fermionic operator + """ + + def __init__(self, **kwargs): + for k in kwargs.keys(): + self[k] = kwargs[k] - if callable(newOp[0]): - newOp[0] = functools.partial(_prod_fun, f1=a, f2=newOp[0]) - else: - newOp = [a] + newOp - return tuple(newOp) + def __mul__(self, other): + + if isinstance(other, dict): + return OpStr(self, LocalOp(**other)) + + if isinstance(other, OpStr): + return OpStr(self, *other) + + return OpStr(self, other) + + def __rmul__(self, other): + + return other * OpStr(self) + class BranchFreeOperator(Operator): """This class provides functionality to compute operator matrix elements @@ -236,6 +338,10 @@ def add(self, opDescr): self.ops.append(opDescr) self.compiled = False + def __iadd__(self, opDescr): + self.add(opDescr) + return self + def compile(self): """Compiles a operator mapping function from the given operator strings. @@ -268,20 +374,19 @@ def compile(self): if callable(op[0]): self.prefactor.append((o, jax.jit(op[0]))) k0=1 - #else: - # self.prefactor.append(_id_prefactor) isDiagonal = True - for k in range(k0, k0+self.maxOpStrLength): - if k < len(op): - if not op[k]['diag']: + for k in range(self.maxOpStrLength): + kRev = len(op) - k - 1 + if kRev >= k0: + if not op[kRev]['diag']: isDiagonal = False - self.idx[o].append(op[k]['idx']) - self.map[o].append(op[k]['map']) - self.matEls[o].append(op[k]['matEls']) + self.idx[o].append(op[kRev]['idx']) + self.map[o].append(op[kRev]['map']) + self.matEls[o].append(op[kRev]['matEls']) ######## fermions ######## fermi_check = True - if "fermionic" in op[k]: - if op[k]["fermionic"]: + if "fermionic" in op[kRev]: + if op[kRev]["fermionic"]: fermi_check = False self.fermionic[o].append(1.) if fermi_check: @@ -319,9 +424,14 @@ def arg_fun(*args, prefactor, init): nEls = (N + commSize - 1) // commSize myStart = nEls * rank myEnd = min(myStart+nEls, N) - res = init[myStart:myEnd] + + firstIdx = [0] + [prefactor[nEls * r][0]-1 for r in range(1,commSize)] + lastIdx = [prefactor[min(nEls * (r+1), N-1)][0]-1 for r in range(commSize-1)] + [len(init)] + + res = init[firstIdx[rank]:lastIdx[rank]] + for i,f in prefactor[myStart:myEnd]: - res[i-myStart] = f(*args) + res[i-firstIdx[rank]] = f(*args) res = np.concatenate(comm.allgather(res), axis=0) diff --git a/jVMC/sampler.py b/jVMC/sampler.py index fd48bc1..6376bd4 100644 --- a/jVMC/sampler.py +++ b/jVMC/sampler.py @@ -2,16 +2,13 @@ import jax.numpy as jnp import jax.random as random import numpy as np -from jax import vmap, jit -import flax +from jax import vmap import jVMC.mpi_wrapper as mpi from jVMC.nets.sym_wrapper import SymNet from functools import partial -import time - import jVMC.global_defs as global_defs diff --git a/jVMC/util/minsr.py b/jVMC/util/minsr.py index 7d6703a..6a83d08 100644 --- a/jVMC/util/minsr.py +++ b/jVMC/util/minsr.py @@ -64,7 +64,7 @@ def solve(self, eloc, gradients): """ T = gradients.tangent_kernel() - T_inv = jnp.linalg.pinv(T, rcond=self.pinvTol, hermitian=True) + T_inv = jnp.linalg.pinv(T, rtol=self.pinvTol, hermitian=True) eloc_all = mpi.gather(eloc._data).reshape((-1,)) gradients_all = mpi.gather(gradients._data) diff --git a/jVMC/version.py b/jVMC/version.py index fa01ca4..7c1af2b 100644 --- a/jVMC/version.py +++ b/jVMC/version.py @@ -1,2 +1,2 @@ """Current jVMC version at head on Github.""" -__version__ = "1.4.0" +__version__ = "1.5.0" diff --git a/pytest.ini b/pytest.ini index 05a2282..5fa13fd 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,3 +3,4 @@ filterwarnings = error ignore::UserWarning ignore::DeprecationWarning + ignore::FutureWarning diff --git a/tests/data_ref/fermion_ref.txt b/tests/data_ref/fermion_ref.txt new file mode 100644 index 0000000..4e7c71f --- /dev/null +++ b/tests/data_ref/fermion_ref.txt @@ -0,0 +1,256 @@ + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (1.012405569205572800e-14+0.000000000000000000e+00j) + (9.992429911120988466e-15+0.000000000000000000e+00j) + (1.018107827316719497e-14+0.000000000000000000e+00j) + (1.001554814630909135e-14+0.000000000000000000e+00j) + (9.884563579205653332e-15+0.000000000000000000e+00j) + (1.008694658250692440e-14+0.000000000000000000e+00j) + (9.884514953390900512e-15+0.000000000000000000e+00j) + (1.000683538783617658e-14+0.000000000000000000e+00j) + (9.946116298366358861e-15+0.000000000000000000e+00j) + (1.000799583478743204e-14+0.000000000000000000e+00j) + (1.002929008086214697e-14+0.000000000000000000e+00j) + (1.000737640176805697e-14+0.000000000000000000e+00j) + (1.002464086983243822e-14+0.000000000000000000e+00j) + (1.014875162650758221e-14+0.000000000000000000e+00j) + (9.668575924364690636e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.992429911120988466e-15+0.000000000000000000e+00j) + (9.856030343947204453e-15+0.000000000000000000e+00j) + (9.892150263399298250e-15+0.000000000000000000e+00j) + (1.000948787777628991e-14+0.000000000000000000e+00j) + (9.954283356922028726e-15+0.000000000000000000e+00j) + (1.002686385991066692e-14+0.000000000000000000e+00j) + (1.022252311620020033e-14+0.000000000000000000e+00j) + (1.001554814630909135e-14+0.000000000000000000e+00j) + (1.002238895218893926e-14+0.000000000000000000e+00j) + (9.983150525305456827e-15+0.000000000000000000e+00j) + (1.002888069389126038e-14+0.000000000000000000e+00j) + (9.964414878238126149e-15+0.000000000000000000e+00j) + (9.778699805553789976e-15+0.000000000000000000e+00j) + (9.921350660454092063e-15+0.000000000000000000e+00j) + (9.800043828099024970e-15+0.000000000000000000e+00j) + (1.005647457659877795e-14+0.000000000000000000e+00j) + (1.004963176930474261e-14+0.000000000000000000e+00j) + (9.976453401598552481e-15+0.000000000000000000e+00j) + (1.169473103279208420e-01+0.000000000000000000e+00j) + (1.005134811118076924e-14+0.000000000000000000e+00j) + (1.142075452651279827e-01+0.000000000000000000e+00j) + (5.237071983649793649e-02+0.000000000000000000e+00j) + (1.004624547804380435e-14+0.000000000000000000e+00j) + (9.801156937621899771e-15+0.000000000000000000e+00j) + (5.237071983649789486e-02+0.000000000000000000e+00j) + (3.093131902174355713e-02+0.000000000000000000e+00j) + (9.743846605124971913e-15+0.000000000000000000e+00j) + (6.215385802675736995e-03+0.000000000000000000e+00j) + (9.872559084226558761e-15+0.000000000000000000e+00j) + (9.610747684769895433e-15+0.000000000000000000e+00j) + (9.957321596339594247e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (1.001554814630909135e-14+0.000000000000000000e+00j) + (1.000952631933381154e-14+0.000000000000000000e+00j) + (9.950482633728291774e-15+0.000000000000000000e+00j) + (1.015641917049555495e-14+0.000000000000000000e+00j) + (1.001105185731426122e-14+0.000000000000000000e+00j) + (1.005228932557884958e-14+0.000000000000000000e+00j) + (9.864242167336837070e-15+0.000000000000000000e+00j) + (9.992429911120988466e-15+0.000000000000000000e+00j) + (1.009598174887286302e-14+0.000000000000000000e+00j) + (9.964395953139177211e-15+0.000000000000000000e+00j) + (1.001590440683868420e-14+0.000000000000000000e+00j) + (9.899404601594539389e-15+0.000000000000000000e+00j) + (1.001893910518010840e-14+0.000000000000000000e+00j) + (1.011646600491084927e-14+0.000000000000000000e+00j) + (9.885677836861881077e-15+0.000000000000000000e+00j) + (9.920054770346611154e-15+0.000000000000000000e+00j) + (9.999940393487268171e-15+0.000000000000000000e+00j) + (1.008598962297082857e-14+0.000000000000000000e+00j) + (1.142075452651278578e-01+0.000000000000000000e+00j) + (1.009529221811425285e-14+0.000000000000000000e+00j) + (4.582053623105379891e-01+0.000000000000000000e+00j) + (1.666316408201792276e-01+0.000000000000000000e+00j) + (9.713404183019009794e-15+0.000000000000000000e+00j) + (9.973601688587729818e-15+0.000000000000000000e+00j) + (1.666316408201793386e-01+0.000000000000000000e+00j) + (7.318181125002692911e-02+0.000000000000000000e+00j) + (1.003156940823014553e-14+0.000000000000000000e+00j) + (3.093131902174340794e-02+0.000000000000000000e+00j) + (1.008701223400030831e-14+0.000000000000000000e+00j) + (9.802407387794949514e-15+0.000000000000000000e+00j) + (1.003733269103751425e-14+0.000000000000000000e+00j) + (1.000074240342488738e-14+0.000000000000000000e+00j) + (1.001881603435488266e-14+0.000000000000000000e+00j) + (9.987142702059044029e-15+0.000000000000000000e+00j) + (5.237071983649779078e-02+0.000000000000000000e+00j) + (9.958728624327174480e-15+0.000000000000000000e+00j) + (1.666316408201789501e-01+0.000000000000000000e+00j) + (6.696642544736111646e-02+0.000000000000000000e+00j) + (9.877974903794605311e-15+0.000000000000000000e+00j) + (9.873062225234007634e-15+0.000000000000000000e+00j) + (3.412580519826258207e-01+0.000000000000000000e+00j) + (1.666316408201790888e-01+0.000000000000000000e+00j) + (1.015305103960146056e-14+0.000000000000000000e+00j) + (5.237071983649770751e-02+0.000000000000000000e+00j) + (9.818757372019085450e-15+0.000000000000000000e+00j) + (9.892764745687594218e-15+0.000000000000000000e+00j) + (9.996813570831133770e-15+0.000000000000000000e+00j) + (1.000000000000533900e-14+0.000000000000000000e+00j) + (1.000000000104220831e-14+0.000000000000000000e+00j) + (1.000000000071505980e-14+0.000000000000000000e+00j) + (1.000000000003459785e-14+0.000000000000000000e+00j) + (1.000000000033691775e-14+0.000000000000000000e+00j) + (1.000000000001106928e-14+0.000000000000000000e+00j) + (1.000000000007643588e-14+0.000000000000000000e+00j) + (1.000000000000270420e-14+0.000000000000000000e+00j) + (1.000000000012334155e-14+0.000000000000000000e+00j) + (1.000000000006407285e-14+0.000000000000000000e+00j) + (9.999999999931935492e-15+0.000000000000000000e+00j) + (9.999999999991611241e-15+0.000000000000000000e+00j) + (1.000000000006094107e-14+0.000000000000000000e+00j) + (1.000000000001678537e-14+0.000000000000000000e+00j) + (9.999999999976449335e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (1.000679694627865496e-14+0.000000000000000000e+00j) + (1.001554814630909135e-14+0.000000000000000000e+00j) + (9.980677438256133008e-15+0.000000000000000000e+00j) + (9.992429911120988466e-15+0.000000000000000000e+00j) + (9.938769036971834898e-15+0.000000000000000000e+00j) + (1.007128603902573344e-14+0.000000000000000000e+00j) + (9.928224767797384550e-15+0.000000000000000000e+00j) + (9.823666865507379077e-15+0.000000000000000000e+00j) + (9.892445471042393772e-15+0.000000000000000000e+00j) + (9.829751334798516039e-15+0.000000000000000000e+00j) + (1.014617686817310711e-14+0.000000000000000000e+00j) + (1.025971279402843961e-14+0.000000000000000000e+00j) + (9.905744382799989120e-15+0.000000000000000000e+00j) + (1.002305939686780660e-14+0.000000000000000000e+00j) + (1.002135951846055144e-14+0.000000000000000000e+00j) + (1.000074792749339864e-14+0.000000000000000000e+00j) + (1.005359678784152954e-14+0.000000000000000000e+00j) + (1.018976680940818047e-14+0.000000000000000000e+00j) + (5.237071983649791568e-02+0.000000000000000000e+00j) + (9.732260597886158843e-15+0.000000000000000000e+00j) + (1.666316408201793942e-01+0.000000000000000000e+00j) + (3.412580519826272640e-01+0.000000000000000000e+00j) + (1.014150231331177771e-14+0.000000000000000000e+00j) + (1.001387406630726283e-14+0.000000000000000000e+00j) + (6.696642544736125524e-02+0.000000000000000000e+00j) + (1.666316408201795607e-01+0.000000000000000000e+00j) + (1.017132658965099075e-14+0.000000000000000000e+00j) + (5.237071983649790874e-02+0.000000000000000000e+00j) + (1.015691993588858495e-14+0.000000000000000000e+00j) + (9.876147300480524217e-15+0.000000000000000000e+00j) + (9.996813570831083283e-15+0.000000000000000000e+00j) + (1.009364927243126021e-14+0.000000000000000000e+00j) + (9.996983102860466859e-15+0.000000000000000000e+00j) + (9.865119003336554879e-15+0.000000000000000000e+00j) + (3.093131902174344958e-02+0.000000000000000000e+00j) + (1.003485643405988742e-14+0.000000000000000000e+00j) + (7.318181125002692911e-02+0.000000000000000000e+00j) + (1.666316408201792554e-01+0.000000000000000000e+00j) + (9.822789425956947954e-15+0.000000000000000000e+00j) + (1.023813826766765529e-14+0.000000000000000000e+00j) + (1.666316408201790888e-01+0.000000000000000000e+00j) + (4.582053623105369344e-01+0.000000000000000000e+00j) + (9.884237794695145747e-15+0.000000000000000000e+00j) + (1.142075452651274553e-01+0.000000000000000000e+00j) + (1.010987642599157761e-14+0.000000000000000000e+00j) + (1.018174716635376494e-14+0.000000000000000000e+00j) + (9.947465467718161384e-15+0.000000000000000000e+00j) + (1.000000000000065632e-14+0.000000000000000000e+00j) + (1.000000000071473952e-14+0.000000000000000000e+00j) + (1.000000000245886050e-14+0.000000000000000000e+00j) + (1.000000000000567663e-14+0.000000000000000000e+00j) + (1.000000000099483090e-14+0.000000000000000000e+00j) + (1.000000000012496187e-14+0.000000000000000000e+00j) + (1.000000000002113988e-14+0.000000000000000000e+00j) + (9.999999999991609664e-15+0.000000000000000000e+00j) + (1.000000000031858935e-14+0.000000000000000000e+00j) + (9.999999999905472364e-15+0.000000000000000000e+00j) + (1.000000000010784990e-14+0.000000000000000000e+00j) + (1.000000000002495639e-14+0.000000000000000000e+00j) + (9.999999999954440115e-15+0.000000000000000000e+00j) + (9.999999999939721549e-15+0.000000000000000000e+00j) + (1.000000000001706463e-14+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.866735545680255152e-15+0.000000000000000000e+00j) + (9.993443954883948516e-15+0.000000000000000000e+00j) + (1.003145126982223010e-14+0.000000000000000000e+00j) + (6.215385802675729189e-03+0.000000000000000000e+00j) + (1.014422707513329393e-14+0.000000000000000000e+00j) + (3.093131902174352937e-02+0.000000000000000000e+00j) + (5.237071983649774221e-02+0.000000000000000000e+00j) + (9.845993553485651949e-15+0.000000000000000000e+00j) + (9.775910074652939773e-15+0.000000000000000000e+00j) + (5.237071983649781853e-02+0.000000000000000000e+00j) + (1.142075452651274137e-01+0.000000000000000000e+00j) + (9.885831689495509686e-15+0.000000000000000000e+00j) + (1.169473103279204396e-01+0.000000000000000000e+00j) + (1.011621870238967856e-14+0.000000000000000000e+00j) + (9.815184548709895538e-15+0.000000000000000000e+00j) + (1.010538552908876145e-14+0.000000000000000000e+00j) + (9.999999999990913888e-15+0.000000000000000000e+00j) + (1.000000000033755830e-14+0.000000000000000000e+00j) + (1.000000000099414459e-14+0.000000000000000000e+00j) + (1.000000000004920755e-14+0.000000000000000000e+00j) + (1.000000000225102720e-14+0.000000000000000000e+00j) + (1.000000000000000788e-14+0.000000000000000000e+00j) + (1.000000000008931166e-14+0.000000000000000000e+00j) + (1.000000000001678379e-14+0.000000000000000000e+00j) + (1.000000000062591694e-14+0.000000000000000000e+00j) + (1.000000000012502182e-14+0.000000000000000000e+00j) + (9.999999999931211317e-15+0.000000000000000000e+00j) + (9.999999999939726282e-15+0.000000000000000000e+00j) + (1.000000000004380543e-14+0.000000000000000000e+00j) + (1.000000000002551491e-14+0.000000000000000000e+00j) + (9.999999999991431381e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999989566514e-15+0.000000000000000000e+00j) + (1.000000000012265208e-14+0.000000000000000000e+00j) + (1.000000000031918573e-14+0.000000000000000000e+00j) + (9.999999999971766656e-15+0.000000000000000000e+00j) + (1.000000000062555249e-14+0.000000000000000000e+00j) + (1.000000000006667451e-14+0.000000000000000000e+00j) + (9.999999999974756439e-15+0.000000000000000000e+00j) + (9.999999999976452490e-15+0.000000000000000000e+00j) + (1.000000000083510865e-14+0.000000000000000000e+00j) + (9.999999999994104042e-15+0.000000000000000000e+00j) + (1.000000000005448503e-14+0.000000000000000000e+00j) + (1.000000000001706305e-14+0.000000000000000000e+00j) + (1.000000000000999012e-14+0.000000000000000000e+00j) + (9.999999999991431381e-15+0.000000000000000000e+00j) + (1.000000000000327061e-14+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) + (9.999999999999999988e-15+0.000000000000000000e+00j) diff --git a/tests/operator_test.py b/tests/operator_test.py index 928bf64..a1cf05c 100644 --- a/tests/operator_test.py +++ b/tests/operator_test.py @@ -15,6 +15,29 @@ from jVMC.vqs import NQS import jVMC.global_defs as global_defs +import flax.linen as nn +class Target(nn.Module): + """Target wave function, returns a vector with the same dimension as the Hilbert space + + Initialization arguments: + * ``L``: System size + * ``d``: local Hilbert space dimension + * ``delta``: small number to avoid log(0) + + """ + L: int + d: float = 2.00 + delta: float = 1e-15 + + @nn.compact + def __call__(self, s): + kernel = self.param('kernel', + nn.initializers.constant(1), + (int(self.d**self.L))) + # return amplitude for state s + idx = ((self.d**jnp.arange(self.L)).dot(s[::-1])).astype(int) # NOTE that the state is reversed to account for different bit conventions used in openfermion + return jnp.log(abs(kernel[idx]+self.delta)) + 1.j*jnp.angle(kernel[idx]) + def get_shape(shape): return (global_defs.device_count(),) + shape @@ -31,9 +54,9 @@ def test_nonzeros(self): h = op.BranchFreeOperator() - h.add(op.scal_opstr(2., (op.Sp(0),))) - h.add(op.scal_opstr(2., (op.Sp(1),))) - h.add(op.scal_opstr(2., (op.Sp(2),))) + h += 2. * op.Sp(0) + h += 2. * op.Sp(1) + h += 2. * op.Sp(2) sp, matEl = h.get_s_primes(s) @@ -141,8 +164,10 @@ def test_fermionic_operators(self): def commutator(i,j): Comm = op.BranchFreeOperator() - Comm.add(op.scal_opstr( 1., (op.annihilation(i), op.creation(j), ) ) ) - Comm.add(op.scal_opstr( 1., (op.creation(j), op.annihilation(i), ) ) ) + # Comm.add(op.scal_opstr( 1., (op.annihilation(i), op.creation(j), ) ) ) + # Comm.add(op.scal_opstr( 1., (op.creation(j), op.annihilation(i), ) ) ) + Comm.add(op.scal_opstr( 1., (op.creation(j), op.annihilation(i)) ) ) + Comm.add(op.scal_opstr( 1., (op.annihilation(i), op.creation(j)) ) ) return Comm observalbes_dict = { @@ -168,6 +193,59 @@ def commutator(i,j): jnp.array([0.,0.,0.,0.]), rtol=1e-15) ) + + + t = - 1.0 # hopping + mu = -2.0 # chemical potential + V = 4.0 # interaction + L = 4 # number of sites + flavour = 2 # number of flavours + flavourL = flavour*L # number of spins times sites + + # initalize the Hamitonian + hamiltonian = op.BranchFreeOperator() + # impurity definitions + site1UP = 0 + site1DO = flavourL-1#//flavour + # loop over the 1d lattice + for i in range(0,flavourL//flavour): + # interaction + hamiltonian.add(op.scal_opstr( V, ( op.number(site1UP + i) , op.number(site1DO - i) ) ) ) + # chemical potential + hamiltonian.add(op.scal_opstr(mu , ( op.number(site1UP + i) ,) ) ) + hamiltonian.add(op.scal_opstr(mu , ( op.number(site1DO - i) ,) ) ) + if i == flavourL//flavour-1: + continue + # up chain hopping + hamiltonian.add(op.scal_opstr( t, ( op.creation(site1UP + i + 1), op.annihilation(site1UP + i) , ) ) ) + hamiltonian.add(op.scal_opstr( t, ( op.creation(site1UP + i), op.annihilation(site1UP + i + 1) , ) ) ) + # down chain hopping + hamiltonian.add(op.scal_opstr( t, ( op.creation(site1DO - i - 1), op.annihilation(site1DO - i) , ) ) ) + hamiltonian.add(op.scal_opstr( t, ( op.creation(site1DO - i), op.annihilation(site1DO - i - 1) , ) ) ) + + b = np.loadtxt("tests/data_ref/fermion_ref.txt", dtype=np.complex128) + chi_model = Target(L=flavourL, d=2) + chi = NQS(chi_model) + chi(jnp.array(jnp.ones((1, 1, flavourL)))) + chi.set_parameters(b) + chiSampler = jVMC.sampler.ExactSampler(chi, (flavourL,)) + s, logPsi, p = chiSampler.sample() + sPrime, _ = hamiltonian.get_s_primes(s) + Oloc = hamiltonian.get_O_loc(s, chi, logPsi) + Omean = jVMC.mpi_wrapper.global_mean(Oloc,p) + + self.assertTrue(jnp.allclose(Omean, -9.95314531)) + + def test_opstr(self): + op1 = op.Sz(3) + op2 = op.Sx(5) + + opstr1 = 13. * op1 * op2 + opstr2 = 1.j * opstr1 * op1 + + self.assertTrue(jnp.allclose(opstr2[0](), 13.j)) + for o in opstr2[1:]: + self.assertTrue(isinstance(o, (op.LocalOp, dict))) if __name__ == "__main__": diff --git a/tests/povm_test.py b/tests/povm_test.py index c481bd0..d6f02aa 100644 --- a/tests/povm_test.py +++ b/tests/povm_test.py @@ -143,7 +143,7 @@ def test_time_evolution_one_site(self): Sy_avg = (w * jnp.cos(w * times) - jnp.sin(w * times)) / w * jnp.exp(-times) Sz_avg = 6 / w * jnp.sin(w * times) * jnp.exp(-times) - print(Sz_avg-jnp.asarray(res["Z"])) + # print(Sz_avg-jnp.asarray(res["Z"])) self.assertTrue(jnp.allclose(Sx_avg, jnp.asarray(res["X"]), atol=1e-2)) self.assertTrue(jnp.allclose(Sy_avg, jnp.asarray(res["Y"]), atol=1e-2)) self.assertTrue(jnp.allclose(Sz_avg, jnp.asarray(res["Z"]), atol=1e-2)) diff --git a/tests/sampler_test.py b/tests/sampler_test.py index 8980016..02f4c5c 100644 --- a/tests/sampler_test.py +++ b/tests/sampler_test.py @@ -585,7 +585,6 @@ def test_exact_sampler(self): s, psi_s, pex = exactSampler.sample() import flax - print(isinstance(psi.parameters, flax.core.frozen_dict.FrozenDict)) self.assertTrue(jnp.max((psi(s) - psi_s) / psi_s) < 1e-14) s, psi_s, pex = exactSampler.sample(parameters=p0)