diff --git a/jVMC/operator/base.py b/jVMC/operator/base.py
index d3f05f0..702ec3b 100644
--- a/jVMC/operator/base.py
+++ b/jVMC/operator/base.py
@@ -28,7 +28,11 @@ class Operator(metaclass=abc.ABCMeta):
         A tuple ``sp, matEls``, where ``sp`` is the list of connected basis configurations \
         (as ``jax.numpy.array``) and ``matEls`` the corresponding matrix elements.
 
-    Important: Any child class inheriting from ``Operator`` has to call ``super().__init__()`` in \
+    Alternatively, ``compile()`` can return a tuple of two functions, the first as described above
+    and the second a preprocessor for the additional positional arguments ``*args``. Assuming that
+    ``compile()`` returns the tuple ``(f1, f2)``, ``f1`` will be called as ``f1(s, *f2(*args))``.
+
+    *Important:* Any child class inheriting from ``Operator`` has to call ``super().__init__()`` in \
     its constructor.
 
     **Example:**
@@ -112,13 +116,27 @@ def get_s_primes(self, s, *args):
 
         """
 
+        def id_fun(*args):
+            return args
+
         if (not self.compiled) or self.compiled_argnum!=len(args):
-            _get_s_primes = jax.vmap(self.compile(), in_axes=(0,)+(None,)*len(args))
-            self._get_s_primes_pmapd = global_defs.pmap_for_my_devices(_get_s_primes, in_axes=(0,)+(None,)*len(args))
+            fun = self.compile()
             self.compiled = True
             self.compiled_argnum = len(args)
+            if type(fun) is tuple:
+                self.arg_fun = fun[1]
+                args = self.arg_fun(*args)
+                fun = fun[0]
+            else:
+                self.arg_fun = id_fun
+            _get_s_primes = jax.vmap(fun, in_axes=(0,)+(None,)*len(args))
+            #_get_s_primes = jax.vmap(self.compile(), in_axes=(0,)+(None,)*len(args))
+            self._get_s_primes_pmapd = global_defs.pmap_for_my_devices(_get_s_primes, in_axes=(0,)+(None,)*len(args))
+        else:
+            args = self.arg_fun(*args)
 
         # Compute matrix elements
+        #self.sp, self.matEl = self._get_s_primes_pmapd(s, *args)
         self.sp, self.matEl = self._get_s_primes_pmapd(s, *args)
 
         # Get only non-zero contributions
@@ -155,6 +173,21 @@ def get_O_loc(self, logPsiS, logPsiSP):
 
         return self._get_O_loc_pmapd(self.matEl, logPsiS, logPsiSP)
 
     def get_O_loc_batched(self, samples, psi, logPsiS, batchSize, *args):
+        """Compute :math:`O_{loc}(s)` in batches.
+
+        Computes :math:`O_{loc}(s)=\sum_{s'} O_{s,s'}\\frac{\psi(s')}{\psi(s)}` in a batch-wise manner
+        to avoid out-of-memory issues.
+
+        Arguments:
+            * ``samples``: Sample of computational basis configurations :math:`s`.
+            * ``psi``: Neural quantum state.
+            * ``logPsiS``: Logarithmic amplitudes :math:`\\ln(\psi(s))`.
+            * ``batchSize``: Batch size.
+            * ``*args``: Further positional arguments for the operator.
+
+        Returns:
+            :math:`O_{loc}(s)` for each configuration :math:`s`.
+        """
 
         Oloc = self._alloc_Oloc_pmapd(samples)
diff --git a/jVMC/operator/branch_free.py b/jVMC/operator/branch_free.py
index baa1614..a5a9881 100644
--- a/jVMC/operator/branch_free.py
+++ b/jVMC/operator/branch_free.py
@@ -4,9 +4,11 @@
 import numpy as np
 
 import jVMC.global_defs as global_defs
+from mpi4py import MPI
 
 from . import Operator
 
 import functools
+import sys
 
 opDtype = global_defs.tCpx
@@ -107,6 +109,7 @@ def Sm(idx):
 
 import copy
 
+@jax.jit
 def _id_prefactor(*args, val=1.0, **kwargs):
     return val
 
@@ -192,10 +195,10 @@ def compile(self):
             # check whether string contains prefactor
             k0=0
             if callable(op[0]):
-                self.prefactor.append(op[0])
+                self.prefactor.append((o, jax.jit(op[0])))
                 k0=1
-            else:
-                self.prefactor.append(_id_prefactor)
+            #else:
+            #    self.prefactor.append(_id_prefactor)
             isDiagonal = True
             for k in range(k0, k0+self.maxOpStrLength):
                 if k < len(op):
@@ -218,32 +221,61 @@ def compile(self):
         self.matElsC = jnp.array(self.matEls, dtype=opDtype)
         self.diag = jnp.array(self.diag, dtype=np.int32)
 
-        return functools.partial(self._get_s_primes, idxC=self.idxC, mapC=self.mapC, matElsC=self.matElsC, diag=self.diag, prefactor=self.prefactor)
+        def arg_fun(*args, prefactor, init):
+            N = len(prefactor)
+            if N<50:
+                res = init
+                for i,f in prefactor:
+                    res[i] = f(*args)
+            else:
+                # parallelize this, because jit compilation for each element can be slow
+                comm = MPI.COMM_WORLD
+                commSize = comm.Get_size()
+                rank = comm.Get_rank()
+                nEls = (N + commSize - 1) // commSize
+                myStart = nEls * rank
+                myEnd = min(myStart+nEls, N)
+                res = init[myStart:myEnd]
+                for i,f in prefactor[myStart:myEnd]:
+                    res[i-myStart] = f(*args)
+
+                res = np.concatenate(comm.allgather(res), axis=0)
+
+            return (jnp.array(res), )
+
+        return functools.partial(self._get_s_primes, idxC=self.idxC, mapC=self.mapC, matElsC=self.matElsC, diag=self.diag, prefactor=self.prefactor),\
+               functools.partial(arg_fun, prefactor=self.prefactor, init=np.ones(self.idxC.shape[0], dtype=self.matElsC.dtype))
 
     def _get_s_primes(self, s, *args, idxC, mapC, matElsC, diag, prefactor):
 
         numOps = idxC.shape[0]
-        matEl = jnp.ones(numOps, dtype=matElsC.dtype)
+        #matEl = jnp.ones(numOps, dtype=matElsC.dtype)
+        matEl = args[0]
+
         sp = jnp.array([s] * numOps)
 
-        def apply_fun(config, configMatEl, idx, sMap, matEls):
+        def apply_fun(c, x):
+            config, configMatEl = c
+            idx, sMap, matEls = x
 
             configShape = config.shape
             config = config.ravel()
             configMatEl = configMatEl * matEls[config[idx]]
             config = config.at[idx].set(sMap[config[idx]])
 
-            return config.reshape(configShape), configMatEl
+            return (config.reshape(configShape), configMatEl), None
 
-        def apply_multi(config, configMatEl, opIdx, opMap, opMatEls, prefactor):
+        #def apply_multi(config, configMatEl, opIdx, opMap, opMatEls, prefactor):
+        def apply_multi(config, configMatEl, opIdx, opMap, opMatEls):
-            for idx, mp, me in zip(opIdx, opMap, opMatEls):
-                config, configMatEl = apply_fun(config, configMatEl, idx, mp, me)
+            (config, configMatEl), _ = jax.lax.scan(apply_fun, (config, configMatEl), (opIdx, opMap, opMatEls))
 
-            return config, prefactor*configMatEl
+            #return config, prefactor*configMatEl
+            return config, configMatEl
 
         # vmap over operators
-        sp, matEl = vmap(apply_multi, in_axes=(0, 0, 0, 0, 0, 0))(sp, matEl, idxC, mapC, matElsC, jnp.array([f(*args) for f in prefactor]))
+        #sp, matEl = vmap(apply_multi, in_axes=(0, 0, 0, 0, 0, 0))(sp, matEl, idxC, mapC, matElsC, jnp.array([f(*args) for f in prefactor]))
+        sp, matEl = vmap(apply_multi, in_axes=(0, 0, 0, 0, 0))(sp, matEl, idxC, mapC, matElsC)
 
         if len(diag) > 1:
             matEl = matEl.at[diag[0]].set(jnp.sum(matEl[diag], axis=0))
             matEl = matEl.at[diag[1:]].set(jnp.zeros((diag.shape[0] - 1,), dtype=matElsC.dtype))
diff --git a/jVMC/version.py b/jVMC/version.py
index 6c9bc00..9523f15 100644
--- a/jVMC/version.py
+++ b/jVMC/version.py
@@ -1,2 +1,2 @@
-"""Current Flax version at head on Github."""
-__version__ = "0.1.6"
+"""Current jVMC version at head on Github."""
+__version__ = "0.1.7"
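
A minimal sketch of how a custom operator can opt into the new two-function `compile()` protocol. The operator below and its cosine drive are hypothetical illustrations, not part of this patch; only the `(kernel, arg_fun)` return shape follows the new `Operator.get_s_primes` logic:

```python
import jax.numpy as jnp

from jVMC.operator import Operator


class DrivenSpinFlip(Operator):
    """Hypothetical operator: flips spin 0 with a time-dependent amplitude."""

    def compile(self):

        def kernel(s, prefactor):
            # Acts on a single configuration s; get_s_primes adds the sample
            # and device axes via jax.vmap and pmap_for_my_devices.
            sp = jnp.array([s.at[0].set(1 - s[0])])  # connected configuration
            matEls = prefactor * jnp.ones(1)          # its matrix element
            return sp, matEls

        def arg_fun(t):
            # Runs once per get_s_primes(s, t) call, outside of jit/vmap, so
            # arbitrary Python is allowed here. It must return a tuple whose
            # entries become the kernel's extra positional arguments, i.e.
            # the kernel is invoked as kernel(s, *arg_fun(t)).
            return (jnp.cos(t),)

        return kernel, arg_fun
```

Returning the tuple means the argument preprocessing happens once on the host per `get_s_primes` call instead of being traced into the vmapped kernel.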
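
The `apply_multi` change replaces a Python loop over operator factors, which was unrolled at trace time, with a single `jax.lax.scan` that threads the `(config, matEl)` carry through the stacked factors. A self-contained toy analogue, with hypothetical names and data:

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # One operator factor: scale the matrix element and update one site,
    # mirroring apply_fun's (carry, x) -> (carry, None) scan signature.
    config, matEl = carry
    idx, newVal, weight = x
    return (config.at[idx].set(newVal), matEl * weight), None


config0 = jnp.array([0, 1, 0, 1])
idxs = jnp.array([0, 2])         # sites touched by the operator string
newVals = jnp.array([1, 1])      # values written at those sites
weights = jnp.array([0.5, 2.0])  # per-factor matrix elements

(config, matEl), _ = jax.lax.scan(step, (config0, jnp.asarray(1.0)), (idxs, newVals, weights))
# config -> [1, 1, 1, 1], matEl -> 1.0
```

Because scan compiles one loop body instead of unrolling `maxOpStrLength` copies of it, trace and compile times stay flat as operator strings grow.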
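
The MPI branch of the new `arg_fun` distributes the prefactor evaluations, each of which may trigger a slow jit compilation, over ranks by plain ceil-division block partitioning; a standalone illustration with a hypothetical helper:

```python
def block_ranges(N, commSize):
    # Same arithmetic as arg_fun: ceil-divide N prefactor evaluations into
    # commSize contiguous blocks; the last rank may receive fewer elements.
    nEls = (N + commSize - 1) // commSize
    return [(nEls * rank, min(nEls * rank + nEls, N)) for rank in range(commSize)]


print(block_ranges(103, 4))  # [(0, 26), (26, 52), (52, 78), (78, 103)]
```

Each rank evaluates only its slice; `comm.allgather` followed by `np.concatenate` then reassembles the full prefactor array in rank order on every process.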