Merge pull request #74 from markusschmitt/fix_parallel_sampling
Fix parallel sampling
markusschmitt authored Aug 8, 2024
2 parents e6accab + 21c3e4b commit 22d7d59
Showing 10 changed files with 91 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/automatic_testing.yml
@@ -22,4 +22,4 @@ jobs:
pip install -e .[dev]
- name: Run tests
-       run: pytest tests/
+       run: mpirun -n 2 python -m pytest --with-mpi tests/ #pytest tests/ #
2 changes: 1 addition & 1 deletion examples/ex2_unitary_time_evolution.py
@@ -52,7 +52,7 @@
sampler = jVMC.sampler.ExactSampler(psi, L)

# Set up TDVP
- tdvpEquation = jVMC.util.tdvp.TDVP(sampler, svdTol=1e-8,
+ tdvpEquation = jVMC.util.tdvp.TDVP(sampler, pinvTol=1e-8,
rhsPrefactor=1.j,
makeReal='imag')

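For reference, a minimal sketch of the updated TDVP setup from this example with the renamed keyword; `psi` and `L` stand for the variational state and system size defined earlier in `ex2_unitary_time_evolution.py` and are not constructed here.

```python
import jVMC

# psi and L are assumed to be defined as in the example script.
sampler = jVMC.sampler.ExactSampler(psi, L)

# The regularization keyword is now pinvTol (formerly svdTol),
# presumably a pseudo-inverse cutoff.
tdvpEquation = jVMC.util.tdvp.TDVP(sampler, pinvTol=1e-8,
                                   rhsPrefactor=1.j,
                                   makeReal='imag')
```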
4 changes: 2 additions & 2 deletions jVMC/mpi_wrapper.py
@@ -82,7 +82,7 @@ def distribute_sampling(numSamples, localDevices=None, numChainsPerDevice=1) ->

globNumSamples = numSamples

- return samplesPerProcess
+ return samplesPerProcess, globNumSamples

numChainsPerProcess = localDevices * numChainsPerDevice

@@ -92,7 +92,7 @@ def spc(spp):
a = numSamples % commSize
globNumSamples = (a * spc(1 + numSamples // commSize) + (commSize - a) * spc(numSamples // commSize)) * numChainsPerProcess

- return spc(samplesPerProcess)
+ return spc(samplesPerProcess), globNumSamples


def first_sample_id():
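After this patch, `distribute_sampling` returns both the per-process sample count and the global sample count, so callers can normalize weights by the global total instead of the local array shape. A minimal sketch of the new call pattern, mirroring the updated test; the module paths and toy data are assumptions for illustration.

```python
import numpy as np
import jax.numpy as jnp

import jVMC.mpi_wrapper as mpi          # module paths assumed from the repository layout
import jVMC.global_defs as global_defs

# Per-rank share and the new global sample count in one call.
myNumSamples, globNumSamples = mpi.distribute_sampling(
    4096,
    localDevices=global_defs.device_count(),
    numChainsPerDevice=1,
)

# Toy data: this rank's slice, reshaped to (devices, samples per device, observables).
myData = jnp.array(
    np.arange(myNumSamples * 4, dtype=np.float64).reshape(
        (global_defs.device_count(), -1, 4)
    )
)

# Uniform weights are now built from the global sample count rather than
# from the local shape, which is what this commit fixes.
weights = jnp.ones(myData.shape[:2]) / globNumSamples
mean = mpi.global_mean(myData, weights)
```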
7 changes: 6 additions & 1 deletion jVMC/operator/base.py
@@ -187,6 +187,8 @@ def get_O_loc(self, samples, psi, logPsiS=None, *args):
else:
sampleOffdConfigs, _ = self.get_s_primes(samples, *args)
logPsiSP = psi(sampleOffdConfigs)
+ if not psi.logarithmic:
+     logPsiSP = jnp.log(logPsiSP)
return self.get_O_loc_unbatched(logPsiS, logPsiSP)

def get_O_loc_unbatched(self, logPsiS, logPsiSP):
@@ -244,8 +246,11 @@ def get_O_loc_batched(self, samples, psi, logPsiS, batchSize, *args):
logPsiSbatch = self._get_logPsi_batch_pmapd(logPsiS, b * batchSize, batchSize)

sp, _ = self.get_s_primes(batch, *args)
+ logPsiSP = psi(sp)
+ if not psi.logarithmic:
+     logPsiSP = jnp.log(logPsiSP)

- OlocBatch = self.get_O_loc_unbatched(logPsiSbatch, psi(sp))
+ OlocBatch = self.get_O_loc_unbatched(logPsiSbatch, logPsiSP)

if Oloc is None:
if OlocBatch.dtype == global_defs.tCpx:
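The two hunks above make the operator code tolerant of non-logarithmic networks by taking `jnp.log` of the returned coefficients before forming local estimators. For orientation, a rough standalone sketch of the kind of reduction `get_O_loc_unbatched` performs on these log-amplitudes; this is an illustration, not the library's exact implementation.

```python
import jax.numpy as jnp

def local_estimator_sketch(matEls, logPsiS, logPsiSP):
    """O_loc(s) = sum_{s'} <s|O|s'> * psi(s') / psi(s), evaluated from log-amplitudes.

    matEls:   matrix elements <s|O|s'> for the connected configurations s'
    logPsiS:  log psi(s) for the sampled configuration s, shape (...,)
    logPsiSP: log psi(s') for the connected configurations, shape (..., n_primes)
    """
    return jnp.sum(matEls * jnp.exp(logPsiSP - logPsiS[..., None]), axis=-1)

# A non-logarithmic ansatz returns psi(s') directly, hence the new conversion
# in get_O_loc / get_O_loc_batched:
#     if not psi.logarithmic:
#         logPsiSP = jnp.log(logPsiSP)
```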
18 changes: 9 additions & 9 deletions jVMC/sampler.py
@@ -181,7 +181,7 @@ def get_last_number_of_samples(self):
Returns:
Number of samples generated by last call to ``sample()`` member function.
"""
- return mpi.globNumSamples
+ return self.globNumSamples

def sample(self, parameters=None, numSamples=None, multipleOf=1):
"""Generate random samples from wave function.
@@ -216,11 +216,10 @@ def sample(self, parameters=None, numSamples=None, multipleOf=1):
if parameters is not None:
tmpP = self.net.params
self.net.set_parameters(parameters)
- configs = self._get_samples_gen(self.net.parameters, numSamples, multipleOf)
- coeffs = self.net(configs)
+ configs, coeffs, ps = self._get_samples_gen(self.net.parameters, numSamples, multipleOf)
if parameters is not None:
self.net.params = tmpP
- return configs, coeffs, jnp.ones(configs.shape[:2]) / jnp.prod(jnp.asarray(configs.shape[:2]))
+ return configs, coeffs, ps

configs, logPsi = self._get_samples_mcmc(parameters, numSamples, multipleOf)
p = jnp.exp((1.0 / self.logProbFactor - self.mu) * jnp.real(logPsi))
@@ -235,7 +234,7 @@ def _randomize_samples(self, samples, key, orbit):

def _get_samples_gen(self, params, numSamples, multipleOf=1):

- numSamples = mpi.distribute_sampling(numSamples, localDevices=global_defs.device_count(), numChainsPerDevice=multipleOf)
+ numSamples, self.globNumSamples = mpi.distribute_sampling(numSamples, localDevices=global_defs.device_count(), numChainsPerDevice=multipleOf)

tmpKeys = random.split(self.key[0], 3 * global_defs.device_count())
self.key = tmpKeys[:global_defs.device_count()]
@@ -248,9 +247,10 @@ def _get_samples_gen(self, params, numSamples, multipleOf=1):
self._randomize_samples_jitd[str(numSamples)] = global_defs.pmap_for_my_devices(self._randomize_samples, static_broadcasted_argnums=(), in_axes=(0, 0, None))

if not self.orbit is None:
- return self._randomize_samples_jitd[str(numSamples)](samples, tmpKey2, self.orbit)
+ samples = self._randomize_samples_jitd[str(numSamples)](samples, tmpKey2, self.orbit)
+ # return self._randomize_samples_jitd[str(numSamples)](samples, tmpKey2, self.orbit)

- return samples
+ return samples, self.net(samples), jnp.ones(samples.shape[:2]) / self.globNumSamples

def _get_samples_mcmc(self, params, numSamples, multipleOf=1):

@@ -267,7 +267,7 @@ def _get_samples_mcmc(self, params, numSamples, multipleOf=1):
# Initialize sampling stuff
self._mc_init(params)

- numSamples = mpi.distribute_sampling(numSamples, localDevices=global_defs.device_count(), numChainsPerDevice=np.lcm(self.numChains, multipleOf))
+ numSamples, self.globNumSamples = mpi.distribute_sampling(numSamples, localDevices=global_defs.device_count(), numChainsPerDevice=np.lcm(self.numChains, multipleOf))
numSamplesStr = str(numSamples)

# check whether _get_samples is already compiled for given number of samples
@@ -415,7 +415,7 @@ def __init__(self, net, sampleShape, lDim=2, logProbFactor=0.5):

def get_basis(self):

- myNumStates = mpi.distribute_sampling(self.lDim**self.N)
+ myNumStates, _ = mpi.distribute_sampling(self.lDim**self.N)
myFirstState = mpi.first_sample_id()

deviceCount = global_defs.device_count()
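With these changes, the direct-sampling path returns configurations, coefficients, and properly normalized probabilities from `_get_samples_gen`, and `get_last_number_of_samples()` now reports the instance attribute `self.globNumSamples`. A hedged usage sketch of the public interface follows; `sampler`, `hamiltonian`, and `psi` stand for already-constructed jVMC objects, and only the call signatures visible in this diff are assumed.

```python
import jVMC.mpi_wrapper as mpi  # module path assumed from the repository layout

def estimate_expectation(sampler, hamiltonian, psi):
    # sample() returns configurations, network coefficients, and weights p;
    # for the direct-sampling path, p is now normalized by the global sample count.
    configs, coeffs, p = sampler.sample(numSamples=4096)

    # Local estimators for the sampled configurations.
    Oloc = hamiltonian.get_O_loc(configs, psi, coeffs)

    # MPI- and device-aware weighted mean; consistent because the weights
    # sum to one across all ranks and devices.
    mean = mpi.global_mean(Oloc, p)

    # Number of samples actually generated across all ranks/devices.
    nTotal = sampler.get_last_number_of_samples()
    return mean, nTotal
```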
2 changes: 1 addition & 1 deletion jVMC/version.py
@@ -1,2 +1,2 @@
"""Current jVMC version at head on Github."""
__version__ = "1.3.0"
__version__ = "1.3.1"
4 changes: 4 additions & 0 deletions jVMC/vqs.py
@@ -114,6 +114,7 @@ class NQS:
"""

def __init__(self, net,
+ logarithmic=True,
batchSize=1000,
seed=1234,
orbit=None,
@@ -138,6 +139,8 @@ def __init__(self, net,
a ``__call__`` function for evaluation. \
If a tuple of two networks is given, the first is used for the logarithmic \
amplitude and the second for the phase of the wave function coefficient.
+ * ``logarithmic``: Boolean variable indicating, whether the ANN returns logarithmic \
+   (:math:`\log\psi_\theta(s)`) or plain (:math:`\psi_\theta(s)`) wave function coefficients.
* ``batchSize``: Batch size for batched network evaluation. Choice \
of this parameter impacts performance: with too small values performance \
is limited by memory access overheads, too large values can lead \
@@ -153,6 +156,7 @@ def __init__(self, net,
self.holomorphic = False
self.flat_gradient_function = flat_gradient_real
self.dict_gradient_function = dict_gradient_real
+ self.logarithmic = logarithmic

self.initialized = False
self.seed = seed
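The new flag records whether the wrapped network returns log-amplitudes (the default) or plain coefficients, which the operator code checks as shown above. A brief sketch of constructing a variational state from a plain-amplitude ansatz; `PlainAmplitudeNet` is a hypothetical module used only for illustration, while the keyword names and defaults are taken from this diff.

```python
import jVMC

# Hypothetical ansatz that outputs psi_theta(s) rather than log psi_theta(s).
net = PlainAmplitudeNet()

# logarithmic=False marks the network output as plain coefficients;
# get_O_loc then applies jnp.log where log-amplitudes are needed.
psi = jVMC.vqs.NQS(net, logarithmic=False, batchSize=1000, seed=1234)
```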
17 changes: 9 additions & 8 deletions tests/mpi_wrapper_test.py
@@ -25,24 +25,25 @@ class TestMPI(unittest.TestCase):
def test_mean(self):

data = jnp.array(np.arange(720 * 4 * global_defs.device_count()).reshape((global_defs.device_count() * 720, 4)))
- myNumSamples = mpi.distribute_sampling(global_defs.device_count() * 720)
+ myNumSamples, globNumSamples = mpi.distribute_sampling(global_defs.device_count() * 720)

myData = data[mpi.rank * myNumSamples:(mpi.rank + 1) * myNumSamples].reshape(get_shape((-1, 4)))

- self.assertTrue(jnp.sum(mpi.global_mean(myData, jnp.ones(myData.shape[:2]) / jnp.prod(jnp.asarray(myData.shape[:2]))) - jnp.mean(data, axis=0)) < 1e-10)
+ self.assertTrue(jnp.sum(mpi.global_mean(myData, jnp.ones(myData.shape[:2]) / globNumSamples) - jnp.mean(data, axis=0)) < 1e-10)

def test_var(self):

data = jnp.array(np.arange(720 * 4 * global_defs.device_count()).reshape((global_defs.device_count() * 720, 4)))
- myNumSamples = mpi.distribute_sampling(global_defs.device_count() * 720)
+ myNumSamples, globNumSamples = mpi.distribute_sampling(global_defs.device_count() * 720)

myData = data[mpi.rank * myNumSamples:(mpi.rank + 1) * myNumSamples].reshape(get_shape((-1, 4)))
- self.assertTrue(jnp.sum(mpi.global_variance(myData, jnp.ones(myData.shape[:2]) / jnp.prod(jnp.asarray(myData.shape[:2]))) - jnp.var(data, axis=0)) < 1e-09)

+ self.assertTrue(jnp.sum(mpi.global_variance(myData, jnp.ones(myData.shape[:2]) / globNumSamples) - jnp.var(data, axis=0)) < 1e-09)

- def test_bcast(self):
-     data = np.zeros(10, dtype=np.int32)
-     with self.assertRaises(TypeError) as context:
-         mpi.bcast_unknown_size(data)
+ # def test_bcast(self):
+ #     data = np.zeros(10, dtype=np.int32)
+ #     with self.assertRaises(TypeError) as context:
+ #         mpi.bcast_unknown_size(data)


if __name__ == "__main__":