Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor get_O_loc #36

Merged
merged 4 commits into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions jVMC/operator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ class Operator(metaclass=abc.ABCMeta):

"""

def __init__(self):
def __init__(self, ElocBatchSize=-1):
"""Initialize ``Operator``.
"""

self.compiled = False
self.compiled_argnum = -1
self.ElocBatchSize = ElocBatchSize

# pmap'd member functions
self._get_s_primes_pmapd = None
Expand Down Expand Up @@ -146,13 +147,37 @@ def id_fun(*args):

return self._flatten_pmapd(self.sp), self.matEl


def _get_O_loc(self, matEl, logPsiS, logPsiSP):

return jax.vmap(lambda x, y, z: jnp.sum(x * jnp.exp(z - y)), in_axes=(0, 0, 0))(matEl, logPsiS, logPsiSP.reshape(matEl.shape))

def get_O_loc(self, samples, psi, logPsiS=None, *args):
"""Compute :math:`O_{loc}(s)`.

If the instance parameter ElocBatchSize is larger than 0 :math:`O_{loc}(s)` is computed in a batch-wise manner
to avoid out-of-memory issues.

Arguments:
* ``samples``: Sample of computational basis configurations :math:`s`.
* ``psi``: Neural quantum state.
* ``logPsiS``: Logarithmic amplitudes :math:`\\ln(\psi(s))`
* ``*args``: Further positional arguments for the operator.

Returns:
:math:`O_{loc}(s)` for each configuration :math:`s`.
"""

def get_O_loc(self, logPsiS, logPsiSP):
if logPsiS is None:
logPsiS = psi(samples)

if self.ElocBatchSize > 0:
return self.get_O_loc_batched(samples, psi, logPsiS, self.ElocBatchSize, *args)
else:
sampleOffdConfigs, _ = self.get_s_primes(samples, *args)
logPsiSP = psi(sampleOffdConfigs)
return self.get_O_loc_unbatched(logPsiS, logPsiSP)

def get_O_loc_unbatched(self, logPsiS, logPsiSP):
"""Compute :math:`O_{loc}(s)`.

This member function assumes that ``get_s_primes(s)`` has been called before, as \
Expand Down Expand Up @@ -196,7 +221,7 @@ def get_O_loc_batched(self, samples, psi, logPsiS, batchSize, *args):
remainder = numSamples % batchSize

# Minimize mismatch
if remainder>0:
if remainder > 0:
batchSize = numSamples // (numBatches+1)
numBatches = numSamples // batchSize
remainder = numSamples % batchSize
Expand All @@ -208,7 +233,7 @@ def get_O_loc_batched(self, samples, psi, logPsiS, batchSize, *args):

sp, _ = self.get_s_primes(batch, *args)

OlocBatch = self.get_O_loc(logPsiSbatch, psi(sp))
OlocBatch = self.get_O_loc_unbatched(logPsiSbatch, psi(sp))

Oloc = self._insert_Oloc_batch_pmapd(Oloc, OlocBatch, b * batchSize)

Expand All @@ -226,7 +251,7 @@ def expand_batch(batch, batchSize):

sp, _ = self.get_s_primes(batch, *args)

OlocBatch = self.get_O_loc(logPsiSbatch, psi(sp))
OlocBatch = self.get_O_loc_unbatched(logPsiSbatch, psi(sp))

OlocBatch = global_defs.pmap_for_my_devices(
lambda d, startIdx, sliceSize: jax.lax.dynamic_slice_in_dim(d, startIdx, sliceSize),
Expand Down
4 changes: 2 additions & 2 deletions jVMC/operator/branch_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class BranchFreeOperator(Operator):
* ``lDim``: Dimension of local Hilbert space.
"""

def __init__(self, lDim=2):
def __init__(self, lDim=2, **kwargs):
"""Initialize ``Operator``.

Arguments:
Expand All @@ -156,7 +156,7 @@ def __init__(self, lDim=2):
self.ops = []
self.lDim = lDim

super().__init__()
super().__init__(**kwargs)

def add(self, opDescr):
"""Add another operator to the operator
Expand Down
4 changes: 2 additions & 2 deletions jVMC/operator/povm.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,13 @@ class POVMOperator(Operator):
* ``lDim``: Dimension of local Hilbert space.
"""

def __init__(self, povm, ldim=4):
def __init__(self, povm, ldim=4, **kwargs):
"""Initialize ``Operator``.
"""
self.povm = povm
self.ldim = ldim
self.ops = []
super().__init__()
super().__init__(**kwargs)

def add(self, opDescr):
"""Add another operator to the operator.
Expand Down
6 changes: 1 addition & 5 deletions jVMC/util/tdvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,7 @@ def stop_timing(outp, name, waitFor=None):

# Evaluate local energy
start_timing(outp, "compute Eloc")
sampleOffdConfigs, matEls = hamiltonian.get_s_primes(sampleConfigs, t)
start_timing(outp, "evaluate off-diagonal")
sampleLogPsiOffd = psi(sampleOffdConfigs)
stop_timing(outp, "evaluate off-diagonal", waitFor=sampleLogPsiOffd)
Eloc = hamiltonian.get_O_loc(sampleLogPsi, sampleLogPsiOffd)
Eloc = hamiltonian.get_O_loc(sampleConfigs, psi, sampleLogPsi, t)
stop_timing(outp, "compute Eloc", waitFor=Eloc)

# Evaluate gradients
Expand Down
6 changes: 2 additions & 4 deletions jVMC/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,8 @@ def measure(observables, psi, sampler, numSamples=None):
if isinstance(op, collections.abc.Iterable):
args = tuple(op[1:])
op = op[0]

sampleOffdConfigs, matEls = op.get_s_primes(sampleConfigs, *args)
sampleLogPsiOffd = psi(sampleOffdConfigs)
Oloc = op.get_O_loc(sampleLogPsi, sampleLogPsiOffd)

Oloc = op.get_O_loc(sampleConfigs, psi, sampleLogPsi, *args)

if p is not None:
tmpMeans.append(mpi.global_mean(Oloc, p))
Expand Down
32 changes: 29 additions & 3 deletions tests/operator_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_nonzeros(self):
logPsi=jnp.ones(s.shape[:-1])
logPsiSP=jnp.ones(sp.shape[:-1])

tmp = h.get_O_loc(logPsi,logPsiSP)
tmp = h.get_O_loc_unbatched(logPsi, logPsiSP)

self.assertTrue( jnp.sum(jnp.abs( tmp - 2. * jnp.sum(-(s[...,:3]-1), axis=-1) )) < 1e-7 )

Expand All @@ -64,7 +64,7 @@ def f(t):
logPsi=jnp.ones(s.shape[:-1])
logPsiSP=jnp.ones(sp.shape[:-1])

tmp = h.get_O_loc(logPsi,logPsiSP)
tmp = h.get_O_loc_unbatched(logPsi, logPsiSP)

self.assertTrue( jnp.sum(jnp.abs( tmp - f(t) * jnp.sum(-(s[...,:3]-1), axis=-1) )) < 1e-7 )

Expand All @@ -87,13 +87,39 @@ def test_batched_Oloc(self):

sp, matEl = h.get_s_primes(s)
logPsiSp = psi(sp)
Oloc1 = h.get_O_loc(logPsi, logPsiSp)
Oloc1 = h.get_O_loc_unbatched(logPsi, logPsiSp)

batchSize = 13
Oloc2 = h.get_O_loc_batched(s, psi, logPsi, batchSize)

self.assertTrue(jnp.abs(jnp.sum(Oloc1) - jnp.sum(Oloc2)) < 1e-5)

def test_batched_Oloc2(self):
L = 4

hamilton_unbatched = op.BranchFreeOperator()
hamilton_batched = op.BranchFreeOperator(ElocBatchSize=13)
for i in range(L):
hamilton_unbatched.add(op.scal_opstr(2., (op.Sx(i),)))
hamilton_unbatched.add(op.scal_opstr(2., (op.Sy(i), op.Sz((i + 1) % L))))
hamilton_batched.add(op.scal_opstr(2., (op.Sx(i),)))
hamilton_batched.add(op.scal_opstr(2., (op.Sy(i), op.Sz((i + 1) % L))))

rbm = nets.CpxRBM(numHidden=2, bias=False)
psi = NQS(rbm)

mcSampler = jVMC.sampler.MCSampler(psi, (L,), random.PRNGKey(0), updateProposer=jVMC.sampler.propose_spin_flip,
numChains=1)

numSamples = 100
s, logPsi, _ = mcSampler.sample(numSamples=numSamples)

Oloc1 = hamilton_unbatched.get_O_loc(s, psi, logPsi)

Oloc2 = hamilton_batched.get_O_loc(s, psi, logPsi)

self.assertTrue(jnp.abs(jnp.sum(Oloc1) - jnp.sum(Oloc2)) < 1e-5)

def test_td_prefactor(self):

hamiltonian = op.BranchFreeOperator()
Expand Down