Merge branch 'fix_operator_prefactors'
Markus committed Mar 21, 2022
2 parents 25ebc79 + a0b8e77 commit 453449d
Showing 3 changed files with 82 additions and 17 deletions.
39 changes: 36 additions & 3 deletions jVMC/operator/base.py
@@ -28,7 +28,11 @@ class Operator(metaclass=abc.ABCMeta):
A tuple ``sp, matEls``, where ``sp`` is the list of connected basis configurations \
(as ``jax.numpy.array``) and ``matEls`` the corresponding matrix elements.
Important: Any child class inheriting from ``Operator`` has to call ``super().__init__()`` in \
Alternatively, ``compile()`` can return a tuple of two functions, the first as described above
and the second a preprocessor for the additional positional arguments ``*args``. Assuming that ``compile()``
returns the tuple ``(f1, f2)``, ``f1`` will be called as ``f1(s, f2(*args))``.
*Important:* Any child class inheriting from ``Operator`` has to call ``super().__init__()`` in \
its constructor.
**Example:**
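(Aside, not part of the file above.) A minimal sketch of a child class whose ``compile()`` returns the ``(f1, f2)`` tuple described in the docstring; the class name and its contents are purely illustrative, and it assumes that ``f2`` returns a tuple which is unpacked into the extra positional arguments of ``f1``:

```python
import jax.numpy as jnp
import jVMC.operator as op

class ScaledFlip(op.Operator):
    """Illustrative operator: flips site 0 with a weight evaluated from *args."""

    def __init__(self):
        super().__init__()  # required for any child class, see docstring above

    def compile(self):
        def f1(s, weight):
            # one connected configuration: site 0 flipped (assumes 0/1 encoding)
            sp = jnp.array([s.at[0].set(1 - s[0])])
            matEls = jnp.array([weight])
            return sp, matEls

        def f2(*args):
            # preprocess the extra positional arguments, e.g. evaluate a
            # time-dependent amplitude once instead of inside the vmapped kernel
            t = args[0] if len(args) > 0 else 0.0
            return (jnp.cos(t),)

        return f1, f2
```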
@@ -112,13 +116,27 @@ def get_s_primes(self, s, *args):
"""

def id_fun(*args):
return args

if (not self.compiled) or self.compiled_argnum!=len(args):
_get_s_primes = jax.vmap(self.compile(), in_axes=(0,)+(None,)*len(args))
self._get_s_primes_pmapd = global_defs.pmap_for_my_devices(_get_s_primes, in_axes=(0,)+(None,)*len(args))
fun = self.compile()
self.compiled = True
self.compiled_argnum = len(args)
if type(fun) is tuple:
self.arg_fun = fun[1]
args = self.arg_fun(*args)
fun = fun[0]
else:
self.arg_fun = id_fun
_get_s_primes = jax.vmap(fun, in_axes=(0,)+(None,)*len(args))
#_get_s_primes = jax.vmap(self.compile(), in_axes=(0,)+(None,)*len(args))
self._get_s_primes_pmapd = global_defs.pmap_for_my_devices(_get_s_primes, in_axes=(0,)+(None,)*len(args))
else:
args = self.arg_fun(*args)

# Compute matrix elements
#self.sp, self.matEl = self._get_s_primes_pmapd(s, *args)
self.sp, self.matEl = self._get_s_primes_pmapd(s, *args)

# Get only non-zero contributions
@@ -155,6 +173,21 @@ def get_O_loc(self, logPsiS, logPsiSP):
return self._get_O_loc_pmapd(self.matEl, logPsiS, logPsiSP)

def get_O_loc_batched(self, samples, psi, logPsiS, batchSize, *args):
"""Compute :math:`O_{loc}(s)` in batches.
Computes :math:`O_{loc}(s)=\sum_{s'} O_{s,s'}\\frac{\psi(s')}{\psi(s)}` in a batch-wise manner
to avoid out-of-memory issues.
Arguments:
* ``samples``: Sample of computational basis configurations :math:`s`.
* ``psi``: Neural quantum state.
* ``logPsiS``: Logarithmic amplitudes :math:`\\ln(\psi(s))`
* ``batchSize``: Batch size.
* ``*args``: Further positional arguments for the operator.
Returns:
:math:`O_{loc}(s)` for each configuration :math:`s`.
"""

Oloc = self._alloc_Oloc_pmapd(samples)

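A hedged usage sketch of the batched local-estimator call documented in ``get_O_loc_batched`` above; ``sampler``, ``psi``, and ``hamiltonian`` are placeholder objects not defined in this diff, and the sampler call pattern is an assumption based on typical jVMC usage:

```python
# Draw samples s together with their log-amplitudes ln(psi(s)); the exact
# return signature of sampler.sample() is assumed here, not taken from this diff.
sampleConfigs, sampleLogPsi, p = sampler.sample(numSamples=2**12)

# O_loc(s) = sum_{s'} O_{s,s'} psi(s') / psi(s), evaluated in chunks of 256
# samples to keep the number of connected configurations held in memory bounded.
Oloc = hamiltonian.get_O_loc_batched(sampleConfigs, psi, sampleLogPsi, 256)
```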
56 changes: 44 additions & 12 deletions jVMC/operator/branch_free.py
@@ -4,9 +4,11 @@
import numpy as np

import jVMC.global_defs as global_defs
from mpi4py import MPI
from . import Operator

import functools
import sys

opDtype = global_defs.tCpx

@@ -107,6 +109,7 @@ def Sm(idx):

import copy

@jax.jit
def _id_prefactor(*args, val=1.0, **kwargs):
return val

@@ -192,10 +195,10 @@ def compile(self):
# check whether string contains prefactor
k0=0
if callable(op[0]):
self.prefactor.append(op[0])
self.prefactor.append((o, jax.jit(op[0])))
k0=1
else:
self.prefactor.append(_id_prefactor)
#else:
# self.prefactor.append(_id_prefactor)
isDiagonal = True
for k in range(k0, k0+self.maxOpStrLength):
if k < len(op):
@@ -218,32 +221,61 @@
self.matElsC = jnp.array(self.matEls, dtype=opDtype)
self.diag = jnp.array(self.diag, dtype=np.int32)

return functools.partial(self._get_s_primes, idxC=self.idxC, mapC=self.mapC, matElsC=self.matElsC, diag=self.diag, prefactor=self.prefactor)
def arg_fun(*args, prefactor, init):
N = len(prefactor)
if N<50:
res = init
for i,f in prefactor:
res[i] = f(*args)
else:
# parallelize this, because jit compilation for each element can be slow
comm = MPI.COMM_WORLD
commSize = comm.Get_size()
rank = comm.Get_rank()
nEls = (N + commSize - 1) // commSize
myStart = nEls * rank
myEnd = min(myStart+nEls, N)
res = init[myStart:myEnd]
for i,f in prefactor[myStart:myEnd]:
res[i-myStart] = f(*args)

res = np.concatenate(comm.allgather(res), axis=0)

return (jnp.array(res), )

return functools.partial(self._get_s_primes, idxC=self.idxC, mapC=self.mapC, matElsC=self.matElsC, diag=self.diag, prefactor=self.prefactor),\
functools.partial(arg_fun, prefactor=self.prefactor, init=np.ones(self.idxC.shape[0], dtype=self.matElsC.dtype))

def _get_s_primes(self, s, *args, idxC, mapC, matElsC, diag, prefactor):

numOps = idxC.shape[0]
matEl = jnp.ones(numOps, dtype=matElsC.dtype)
#matEl = jnp.ones(numOps, dtype=matElsC.dtype)
matEl = args[0]

sp = jnp.array([s] * numOps)

def apply_fun(config, configMatEl, idx, sMap, matEls):
def apply_fun(c, x):
config, configMatEl = c
idx, sMap, matEls = x

configShape = config.shape
config = config.ravel()
configMatEl = configMatEl * matEls[config[idx]]
config = config.at[idx].set(sMap[config[idx]])

return config.reshape(configShape), configMatEl
return (config.reshape(configShape), configMatEl), None

def apply_multi(config, configMatEl, opIdx, opMap, opMatEls, prefactor):
#def apply_multi(config, configMatEl, opIdx, opMap, opMatEls, prefactor):
def apply_multi(config, configMatEl, opIdx, opMap, opMatEls):

for idx, mp, me in zip(opIdx, opMap, opMatEls):
config, configMatEl = apply_fun(config, configMatEl, idx, mp, me)
(config, configMatEl), _ = jax.lax.scan(apply_fun, (config, configMatEl), (opIdx, opMap, opMatEls))

return config, prefactor*configMatEl
#return config, prefactor*configMatEl
return config, configMatEl

# vmap over operators
sp, matEl = vmap(apply_multi, in_axes=(0, 0, 0, 0, 0, 0))(sp, matEl, idxC, mapC, matElsC, jnp.array([f(*args) for f in prefactor]))
#sp, matEl = vmap(apply_multi, in_axes=(0, 0, 0, 0, 0, 0))(sp, matEl, idxC, mapC, matElsC, jnp.array([f(*args) for f in prefactor]))
sp, matEl = vmap(apply_multi, in_axes=(0, 0, 0, 0, 0))(sp, matEl, idxC, mapC, matElsC)
if len(diag) > 1:
matEl = matEl.at[diag[0]].set(jnp.sum(matEl[diag], axis=0))
matEl = matEl.at[diag[1:]].set(jnp.zeros((diag.shape[0] - 1,), dtype=matElsC.dtype))
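Taken together, the changes above move callable prefactors out of the vmapped kernel: ``arg_fun`` evaluates them once per call (split across MPI ranks when there are many of them) and the resulting array seeds the matrix elements in ``_get_s_primes`` instead of being re-evaluated inside ``vmap``. A usage sketch, under the assumptions that the concrete class in ``branch_free.py`` is exposed as ``BranchFreeOperator``, that elementary operators such as ``Sx``/``Sz`` are available from ``jVMC.operator``, and that an operator string may carry a callable as its first element (as checked in ``compile()``); the drive function and frequency are illustrative:

```python
import jax.numpy as jnp
import jVMC.operator as op

def drive(t, omega=2.0):
    # Callable prefactor: receives the extra positional arguments that are
    # passed to get_s_primes / get_O_loc_batched (here a single time t).
    return jnp.cos(omega * t)

H = op.BranchFreeOperator()      # assumed name of the concrete operator class
H.add((drive, op.Sx(0)))         # operator string with a callable prefactor
H.add((op.Sz(0), op.Sz(1)))      # plain operator string, implicit prefactor 1

# In a time-evolution loop the current time would be forwarded so that every
# callable prefactor is evaluated at t, e.g.
#   Oloc = H.get_O_loc_batched(configs, psi, logPsiS, 128, t)
```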
4 changes: 2 additions & 2 deletions jVMC/version.py
@@ -1,2 +1,2 @@
"""Current Flax version at head on Github."""
__version__ = "0.1.6"
"""Current jVMC version at head on Github."""
__version__ = "0.1.7"
