diff --git a/jVMC/operator/base.py b/jVMC/operator/base.py
index 702ec3b..db5d1dc 100644
--- a/jVMC/operator/base.py
+++ b/jVMC/operator/base.py
@@ -53,12 +53,13 @@ class Operator(metaclass=abc.ABCMeta):
 
     """
 
-    def __init__(self):
+    def __init__(self, ElocBatchSize=-1):
         """Initialize ``Operator``.
         """
         self.compiled = False
         self.compiled_argnum = -1
+        self.ElocBatchSize = ElocBatchSize
 
         # pmap'd member functions
         self._get_s_primes_pmapd = None
@@ -146,13 +147,37 @@ def id_fun(*args):
 
         return self._flatten_pmapd(self.sp), self.matEl
 
-
    def _get_O_loc(self, matEl, logPsiS, logPsiSP):
 
        return jax.vmap(lambda x, y, z: jnp.sum(x * jnp.exp(z - y)), in_axes=(0, 0, 0))(matEl, logPsiS, logPsiSP.reshape(matEl.shape))
 
+    def get_O_loc(self, samples, psi, logPsiS=None, *args):
+        """Compute :math:`O_{loc}(s)`.
+
+        If the instance parameter ``ElocBatchSize`` is larger than 0, :math:`O_{loc}(s)` is computed in a batch-wise manner
+        to avoid out-of-memory issues.
+
+        Arguments:
+            * ``samples``: Sample of computational basis configurations :math:`s`.
+            * ``psi``: Neural quantum state.
+            * ``logPsiS``: Logarithmic amplitudes :math:`\\ln(\\psi(s))`.
+            * ``*args``: Further positional arguments for the operator.
+
+        Returns:
+            :math:`O_{loc}(s)` for each configuration :math:`s`.
+        """
-    def get_O_loc(self, logPsiS, logPsiSP):
+        if logPsiS is None:
+            logPsiS = psi(samples)
+
+        if self.ElocBatchSize > 0:
+            return self.get_O_loc_batched(samples, psi, logPsiS, self.ElocBatchSize, *args)
+        else:
+            sampleOffdConfigs, _ = self.get_s_primes(samples, *args)
+            logPsiSP = psi(sampleOffdConfigs)
+            return self.get_O_loc_unbatched(logPsiS, logPsiSP)
+
+    def get_O_loc_unbatched(self, logPsiS, logPsiSP):
         """Compute :math:`O_{loc}(s)`.
 
         This member function assumes that ``get_s_primes(s)`` has been called before, as \
@@ -196,7 +221,7 @@ def get_O_loc_batched(self, samples, psi, logPsiS, batchSize, *args):
         remainder = numSamples % batchSize
 
         # Minimize mismatch
-        if remainder>0:
+        if remainder > 0:
             batchSize = numSamples // (numBatches+1)
             numBatches = numSamples // batchSize
             remainder = numSamples % batchSize
@@ -208,7 +233,7 @@ def get_O_loc_batched(self, samples, psi, logPsiS, batchSize, *args):
 
             sp, _ = self.get_s_primes(batch, *args)
 
-            OlocBatch = self.get_O_loc(logPsiSbatch, psi(sp))
+            OlocBatch = self.get_O_loc_unbatched(logPsiSbatch, psi(sp))
 
             Oloc = self._insert_Oloc_batch_pmapd(Oloc, OlocBatch, b * batchSize)
@@ -226,7 +251,7 @@ def expand_batch(batch, batchSize):
 
             sp, _ = self.get_s_primes(batch, *args)
 
-            OlocBatch = self.get_O_loc(logPsiSbatch, psi(sp))
+            OlocBatch = self.get_O_loc_unbatched(logPsiSbatch, psi(sp))
 
             OlocBatch = global_defs.pmap_for_my_devices(
                 lambda d, startIdx, sliceSize: jax.lax.dynamic_slice_in_dim(d, startIdx, sliceSize),
diff --git a/jVMC/operator/branch_free.py b/jVMC/operator/branch_free.py
index a5a9881..860c5ad 100644
--- a/jVMC/operator/branch_free.py
+++ b/jVMC/operator/branch_free.py
@@ -147,7 +147,7 @@ class BranchFreeOperator(Operator):
         * ``lDim``: Dimension of local Hilbert space.
     """
 
-    def __init__(self, lDim=2):
+    def __init__(self, lDim=2, **kwargs):
         """Initialize ``Operator``.
 
         Arguments:
@@ -156,7 +156,7 @@ def __init__(self, lDim=2):
         self.ops = []
         self.lDim = lDim
 
-        super().__init__()
+        super().__init__(**kwargs)
 
     def add(self, opDescr):
         """Add another operator to the operator
diff --git a/jVMC/operator/povm.py b/jVMC/operator/povm.py
index d1e1f2c..9a1da49 100644
--- a/jVMC/operator/povm.py
+++ b/jVMC/operator/povm.py
@@ -338,13 +338,13 @@ class POVMOperator(Operator):
         * ``lDim``: Dimension of local Hilbert space.
     """
 
-    def __init__(self, povm, ldim=4):
+    def __init__(self, povm, ldim=4, **kwargs):
         """Initialize ``Operator``.
         """
         self.povm = povm
         self.ldim = ldim
         self.ops = []
-        super().__init__()
+        super().__init__(**kwargs)
 
     def add(self, opDescr):
         """Add another operator to the operator.
diff --git a/jVMC/util/tdvp.py b/jVMC/util/tdvp.py
index b35dfed..80a5bef 100644
--- a/jVMC/util/tdvp.py
+++ b/jVMC/util/tdvp.py
@@ -266,11 +266,7 @@ def stop_timing(outp, name, waitFor=None):
 
         # Evaluate local energy
         start_timing(outp, "compute Eloc")
-        sampleOffdConfigs, matEls = hamiltonian.get_s_primes(sampleConfigs, t)
-        start_timing(outp, "evaluate off-diagonal")
-        sampleLogPsiOffd = psi(sampleOffdConfigs)
-        stop_timing(outp, "evaluate off-diagonal", waitFor=sampleLogPsiOffd)
-        Eloc = hamiltonian.get_O_loc(sampleLogPsi, sampleLogPsiOffd)
+        Eloc = hamiltonian.get_O_loc(sampleConfigs, psi, sampleLogPsi, t)
         stop_timing(outp, "compute Eloc", waitFor=Eloc)
 
         # Evaluate gradients
diff --git a/jVMC/util/util.py b/jVMC/util/util.py
index 7a76a3b..15f2fe6 100644
--- a/jVMC/util/util.py
+++ b/jVMC/util/util.py
@@ -157,10 +157,8 @@ def measure(observables, psi, sampler, numSamples=None):
         if isinstance(op, collections.abc.Iterable):
             args = tuple(op[1:])
             op = op[0]
-
-        sampleOffdConfigs, matEls = op.get_s_primes(sampleConfigs, *args)
-        sampleLogPsiOffd = psi(sampleOffdConfigs)
-        Oloc = op.get_O_loc(sampleLogPsi, sampleLogPsiOffd)
+
+        Oloc = op.get_O_loc(sampleConfigs, psi, sampleLogPsi, *args)
 
         if p is not None:
             tmpMeans.append(mpi.global_mean(Oloc, p))
diff --git a/tests/operator_t.py b/tests/operator_t.py
index 0574231..59b7840 100644
--- a/tests/operator_t.py
+++ b/tests/operator_t.py
@@ -39,7 +39,7 @@ def test_nonzeros(self):
         logPsi=jnp.ones(s.shape[:-1])
         logPsiSP=jnp.ones(sp.shape[:-1])
 
-        tmp = h.get_O_loc(logPsi,logPsiSP)
+        tmp = h.get_O_loc_unbatched(logPsi, logPsiSP)
 
         self.assertTrue( jnp.sum(jnp.abs( tmp - 2. * jnp.sum(-(s[...,:3]-1), axis=-1) )) < 1e-7 )
 
@@ -64,7 +64,7 @@ def f(t):
         logPsi=jnp.ones(s.shape[:-1])
         logPsiSP=jnp.ones(sp.shape[:-1])
 
-        tmp = h.get_O_loc(logPsi,logPsiSP)
+        tmp = h.get_O_loc_unbatched(logPsi, logPsiSP)
 
         self.assertTrue( jnp.sum(jnp.abs( tmp - f(t) * jnp.sum(-(s[...,:3]-1), axis=-1) )) < 1e-7 )
 
@@ -87,13 +87,39 @@ def test_batched_Oloc(self):
         sp, matEl = h.get_s_primes(s)
         logPsiSp = psi(sp)
 
-        Oloc1 = h.get_O_loc(logPsi, logPsiSp)
+        Oloc1 = h.get_O_loc_unbatched(logPsi, logPsiSp)
 
         batchSize = 13
         Oloc2 = h.get_O_loc_batched(s, psi, logPsi, batchSize)
 
         self.assertTrue(jnp.abs(jnp.sum(Oloc1) - jnp.sum(Oloc2)) < 1e-5)
 
+    def test_batched_Oloc2(self):
+        L = 4
+
+        hamilton_unbatched = op.BranchFreeOperator()
+        hamilton_batched = op.BranchFreeOperator(ElocBatchSize=13)
+        for i in range(L):
+            hamilton_unbatched.add(op.scal_opstr(2., (op.Sx(i),)))
+            hamilton_unbatched.add(op.scal_opstr(2., (op.Sy(i), op.Sz((i + 1) % L))))
+            hamilton_batched.add(op.scal_opstr(2., (op.Sx(i),)))
+            hamilton_batched.add(op.scal_opstr(2., (op.Sy(i), op.Sz((i + 1) % L))))
+
+        rbm = nets.CpxRBM(numHidden=2, bias=False)
+        psi = NQS(rbm)
+
+        mcSampler = jVMC.sampler.MCSampler(psi, (L,), random.PRNGKey(0), updateProposer=jVMC.sampler.propose_spin_flip,
+                                           numChains=1)
+
+        numSamples = 100
+        s, logPsi, _ = mcSampler.sample(numSamples=numSamples)
+
+        Oloc1 = hamilton_unbatched.get_O_loc(s, psi, logPsi)
+
+        Oloc2 = hamilton_batched.get_O_loc(s, psi, logPsi)
+
+        self.assertTrue(jnp.abs(jnp.sum(Oloc1) - jnp.sum(Oloc2)) < 1e-5)
+
     def test_td_prefactor(self):
         hamiltonian = op.BranchFreeOperator()
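Usage note (not part of the patch): after this change an operator evaluates :math:`O_{loc}` through a single ``get_O_loc(samples, psi, logPsiS, *args)`` call, and batching is opted into at construction time via ``ElocBatchSize``. The sketch below mirrors the new ``test_batched_Oloc2`` test; the import layout (``import jVMC.operator as op``, ``import jVMC.nets as nets``, ``from jVMC.vqs import NQS``) is assumed rather than shown in the patch, and the system size, sample count, and batch size are illustrative values only.

    import jax.numpy as jnp
    from jax import random

    import jVMC
    import jVMC.nets as nets
    import jVMC.operator as op
    from jVMC.vqs import NQS

    L = 4

    # ElocBatchSize > 0 makes get_O_loc evaluate the local energy in batches of
    # 13 samples; the default of -1 keeps the previous single-shot evaluation.
    hamiltonian = op.BranchFreeOperator(ElocBatchSize=13)
    for i in range(L):
        hamiltonian.add(op.scal_opstr(2., (op.Sx(i),)))
        hamiltonian.add(op.scal_opstr(2., (op.Sy(i), op.Sz((i + 1) % L))))

    psi = NQS(nets.CpxRBM(numHidden=2, bias=False))
    sampler = jVMC.sampler.MCSampler(psi, (L,), random.PRNGKey(0),
                                     updateProposer=jVMC.sampler.propose_spin_flip,
                                     numChains=1)
    s, logPsi, _ = sampler.sample(numSamples=100)

    # One call replaces the former get_s_primes / psi(sp) / get_O_loc sequence;
    # passing logPsi is optional, get_O_loc recomputes it as psi(s) if omitted.
    Eloc = hamiltonian.get_O_loc(s, psi, logPsi)
    print(jnp.mean(Eloc))  # plain sample mean for illustration; measure() uses mpi.global_mean

Because ``ElocBatchSize`` defaults to ``-1``, operators constructed without the new argument keep the unbatched behaviour; existing callers of the old ``get_O_loc(logPsiS, logPsiSP)`` signature only need to switch to ``get_O_loc_unbatched`` or to the new sample-based signature, as done in ``tdvp.py``, ``util.py``, and the tests above.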