Sampling with an Exponent different from two. #41

Merged · 7 commits · Jun 9, 2023
23 changes: 12 additions & 11 deletions documentation/source/index.rst
@@ -16,17 +16,18 @@
parallelism

.. toctree::
:hidden:
:glob:
:maxdepth: 2
:caption: API documentation

vqs
operator
sampler
nets
mpi
util
:hidden:
:glob:
:maxdepth: 2
:caption: API documentation

vqs
operator
sampler
nets
mpi
stats
util

.. toctree::
:hidden:
21 changes: 21 additions & 0 deletions documentation/source/stats.rst
@@ -0,0 +1,21 @@
.. _stats:

Sample statistics module
========================

The ``SampledObs`` class provides functionality to conveniently compute sample statistics.

**Example:**

Assuming that ``sampler`` is an instance of a sampler class and ``psi`` is a variational quantum state,
the quantum geometric tensor
:math:`S_{k,k'}=\langle(\partial_{\theta_k}\log\psi_\theta)^*\partial_{\theta_{k'}}\log\psi_\theta\rangle_c`
can be computed as

>>> from jVMC.stats import SampledObs
>>> s, logPsi, p = sampler.sample()
>>> grads = SampledObs( psi.gradients(s), p)
>>> S = grads.covar()

.. automodule:: jVMC.stats
:members:
:special-members: __call__
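
For orientation, a minimal, self-contained sketch of the probability-weighted covariance that ``SampledObs.covar`` evaluates for the quantum geometric tensor above (plain ``jax.numpy``, single process; the names ``weighted_covar``, ``O``, and ``p`` are illustrative, not part of jVMC's API):

    import jax
    import jax.numpy as jnp

    def weighted_covar(O, p):
        # S_{k,k'} = sum_j p_j conj(O_jk - <O>_k) * (O_jk' - <O>_k')
        mean = jnp.dot(p, O)                    # <O>_p, shape (nParams,)
        dO = O - mean[None, :]                  # centered samples
        return jnp.matmul(jnp.conj(dO.T), p[:, None] * dO)

    # illustrative data: 100 samples of a 5-parameter log-derivative
    key = jax.random.PRNGKey(0)
    O = jax.random.normal(key, (100, 5))
    p = jnp.full((100,), 1.0 / 100)
    S = weighted_covar(O, p)                    # (5, 5), Hermitian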
1 change: 1 addition & 0 deletions jVMC/__init__.py
@@ -5,6 +5,7 @@
from . import mpi_wrapper
from . import vqs
from . import sampler
from . import stats

from .version import __version__
from .global_defs import set_pmap_devices
7 changes: 6 additions & 1 deletion jVMC/global_defs.py
@@ -10,6 +10,7 @@
import jax

from functools import partial
import collections

try:
myDevice = jax.devices()[MPI.COMM_WORLD.Get_rank() % len(jax.devices())]
@@ -21,7 +22,11 @@
myDeviceCount = len(myPmapDevices)
pmap_for_my_devices = partial(jax.pmap, devices=myPmapDevices)

import collections
pmapDevices = None
def pmap_devices_updated():
if collections.Counter(pmapDevices) == collections.Counter(myPmapDevices):
return False
return True


def get_iterable(x):
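The relocated ``pmap_devices_updated`` supports a cache-invalidation pattern: pmapped helpers are compiled once and rebuilt only when the device set changes (see ``jit_my_stuff`` below). A hedged sketch of that consumer pattern, with illustrative names ``cachedDevices`` and ``ensure_compiled``:

    import collections
    import jax

    cachedDevices = None   # devices the helper was last compiled for
    _sum_helper = None     # a pmapped helper, rebuilt on device changes

    def ensure_compiled(devices):
        global cachedDevices, _sum_helper
        if collections.Counter(cachedDevices) != collections.Counter(devices):
            # first call or device set changed: re-create the pmapped function
            _sum_helper = jax.pmap(lambda x: x.sum(axis=0), devices=devices)
            cachedDevices = devices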
94 changes: 25 additions & 69 deletions jVMC/mpi_wrapper.py
@@ -17,38 +17,28 @@
communicationTime = 0.


def _cov_helper_with_p(data, p):
def _cov_helper(data, p):
return jnp.expand_dims(
jnp.matmul(jnp.conj(jnp.transpose(data)), jnp.multiply(p[:, None], data)),
axis=0
)


def _cov_helper_without_p(data):
return jnp.expand_dims(
jnp.matmul(jnp.conj(jnp.transpose(data)), data),
axis=0
)


_sum_up_pmapd = None
_sum_sq_pmapd = None
_sum_sq_withp_pmapd = None
mean_helper = None
cov_helper_with_p = None
cov_helper_without_p = None

pmapDevices = None
cov_helper = None

import collections
#pmapDevices = None


def pmap_devices_updated():
import collections

if collections.Counter(pmapDevices) == collections.Counter(global_defs.myPmapDevices):
return False

return True
# def pmap_devices_updated():
# if collections.Counter(pmapDevices) == collections.Counter(global_defs.myPmapDevices):
# return False
# return True


def jit_my_stuff():
@@ -57,19 +47,16 @@ def jit_my_stuff():

global _sum_up_pmapd
global _sum_sq_pmapd
global _sum_sq_withp_pmapd
global mean_helper
global cov_helper_with_p
global cov_helper_without_p
global cov_helper
global pmapDevices

if pmap_devices_updated():
if global_defs.pmap_devices_updated():
_sum_up_pmapd = global_defs.pmap_for_my_devices(lambda x: jax.lax.psum(jnp.sum(x, axis=0), 'i'), axis_name='i')
_sum_sq_pmapd = global_defs.pmap_for_my_devices(lambda data, mean: jax.lax.psum(jnp.sum(jnp.conj(data - mean) * (data - mean), axis=0), 'i'), axis_name='i', in_axes=(0, None))
_sum_sq_withp_pmapd = global_defs.pmap_for_my_devices(lambda data, mean, p: jax.lax.psum(jnp.conj(data - mean).dot(p * (data - mean)), 'i'), axis_name='i', in_axes=(0, None, 0))
# _sum_sq_pmapd = global_defs.pmap_for_my_devices(lambda data, mean, p: jax.lax.psum(jnp.conj(data - mean).dot(p[..., None] * (data - mean)), 'i'), axis_name='i', in_axes=(0, None, 0))
_sum_sq_pmapd = global_defs.pmap_for_my_devices(lambda data, mean, p: jnp.einsum('ij, ij, i -> j', jnp.conj(data - mean[None, ...]), (data - mean[None, ...]), p), in_axes=(0, None, 0))
mean_helper = global_defs.pmap_for_my_devices(lambda data, p: jnp.expand_dims(jnp.dot(p, data), axis=0), in_axes=(0, 0))
cov_helper_with_p = global_defs.pmap_for_my_devices(_cov_helper_with_p, in_axes=(0, 0))
cov_helper_without_p = global_defs.pmap_for_my_devices(_cov_helper_without_p)
cov_helper = global_defs.pmap_for_my_devices(_cov_helper, in_axes=(0, 0))

pmapDevices = global_defs.myPmapDevices

@@ -111,7 +98,7 @@ def distribute_sampling(numSamples, localDevices=None, numChainsPerDevice=1) ->
numChainsPerProcess = localDevices * numChainsPerDevice

def spc(spp):
return int( (spp + numChainsPerProcess - 1) // numChainsPerProcess )
return int((spp + numChainsPerProcess - 1) // numChainsPerProcess)

a = numSamples % commSize
globNumSamples = (a * spc(1 + numSamples // commSize) + (commSize - a) * spc(numSamples // commSize)) * numChainsPerProcess
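
The ``spc`` helper is an integer ceiling division, so each process rounds its share of samples up to a full sweep of its chains. A small worked example with illustrative numbers:

    numChainsPerProcess = 4   # e.g. one device running four chains
    spp = 10                  # samples assigned to this process
    spc = (spp + numChainsPerProcess - 1) // numChainsPerProcess
    # spc == 3: each chain draws 3 samples, so the process yields 12 >= 10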
@@ -171,19 +158,15 @@ def global_sum(data):
return jax.device_put(res, global_defs.myDevice)


def global_mean(data, p=None):
def global_mean(data, p):
""" Computes the mean of input data across MPI processes and device/batch dimensions.

On each MPI process the input data is assumed to be a ``jax.numpy.array`` with a leading
device dimension followed by a batch dimension. The data is reduced by computing the mean
along device and batch dimensions as well as across MPI processes. Hence, the result is
an array of shape ``data.shape[2:]``.

If no probabilities ``p`` are given, the empirical mean is computed, i.e.,

:math:`\\langle X\\rangle=\\frac{1}{N_S}\sum_{j=1}^{N_S} X_j`

Otherwise, the mean is computed using the given probabilities, i.e.,
The mean is computed using the given probabilities, i.e.,

:math:`\\langle X\\rangle=\sum_{j=1}^{N_S} p_jX_j`

@@ -195,29 +178,22 @@ def global_mean(data, p=None):
Mean of data across MPI processes and device/batch dimensions.
"""

jit_my_stuff()

if p is not None:
return global_sum(mean_helper(data, p))

global globNumSamples
jit_my_stuff()

return global_sum(data) / globNumSamples
return global_sum(mean_helper(data, p))
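
With the empirical branch removed, ``global_mean`` always weights by ``p``. A plain-JAX sketch of the local reduction that ``mean_helper`` performs before the MPI sum (single process, illustrative shapes):

    import jax.numpy as jnp

    data = jnp.arange(12.0).reshape(1, 4, 3)   # (devices, batch, features)
    p = jnp.full((1, 4), 0.25)                 # normalized probabilities
    local = jnp.einsum('db,dbf->f', p, data)   # sum_j p_j X_j on this process
    # global_sum would then add these partial results across MPI processes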


def global_variance(data, p=None):
def global_variance(data, p):
""" Computes the variance of input data across MPI processes and device/batch dimensions.

On each MPI process the input data is assumed to be a ``jax.numpy.array`` with a leading
device dimension followed by a batch dimension. The data is reduced by computing the variance
along device and batch dimensions as well as across MPI processes. Hence, the result is
an array of shape ``data.shape[2:]``.

If no probabilities ``p`` are given, the empirical element-wise variance is computed, i.e.,

:math:`\\text{Var}(X)=\\frac{1}{N_S}\sum_{j=1}^{N_S} |X_j-\\langle X\\rangle|^2`

Otherwise, the mean is computed using the given probabilities, i.e.,
The variance is computed using the given probabilities, i.e.,

:math:`\\text{Var}(X)=\sum_{j=1}^{N_S} p_j |X_j-\\langle X\\rangle|^2`

@@ -234,15 +210,8 @@ def global_variance(data, p=None):
data.block_until_ready()

mean = global_mean(data, p)

# Compute sum locally
localSum = None
if p is not None:
localSum = np.array(_sum_sq_withp_pmapd(data, mean, p)[0])
else:
res = _sum_sq_pmapd(data, mean)[0]
res.block_until_ready()
localSum = np.array(res)
localSum = np.array(_sum_sq_pmapd(data, mean, p)[0])

# Allocate memory for result
res = np.empty_like(localSum, dtype=localSum.dtype)
@@ -256,15 +225,10 @@ def global_variance(data, p=None):
global communicationTime
communicationTime += time.perf_counter() - t0

if p is not None:
# return jnp.array(res)
return jax.device_put(res, global_defs.myDevice)
else:
# return jnp.array(res) / globNumSamples
return jax.device_put(res / globNumSamples, global_defs.myDevice)
return jax.device_put(res, global_defs.myDevice)
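
The variance reduces analogously with a single probability-weighted sum. A minimal sketch of the docstring formula (plain JAX, no MPI or device dimension; names are illustrative):

    import jax.numpy as jnp

    def weighted_variance(X, p):
        # Var(X) = sum_j p_j |X_j - <X>|^2, elementwise over trailing axes
        mean = jnp.einsum('j,j...->...', p, X)
        d = X - mean
        return jnp.einsum('j,j...->...', p, jnp.abs(d) ** 2)

    X = jnp.array([1.0, 2.0, 4.0])
    p = jnp.array([0.25, 0.5, 0.25])
    var = weighted_variance(X, p)   # 1.1875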


def global_covariance(data, p=None):
def global_covariance(data, p):
""" Computes the covariance matrix of input data across MPI processes and device/batch dimensions.

On each MPI process the input data is assumed to be a ``jax.numpy.array`` with a leading
@@ -273,11 +237,7 @@ def global_covariance(data, p=None):
matrix along device and batch dimensions as well as across MPI processes. Hence, the result is
an array of shape ``data.shape[2]`` :math:`\\times` ``data.shape[2]``.

If no probabilities ``p`` are given, the empirical covariance is computed, i.e.,

:math:`\\text{Cov}(X)=\\frac{1}{N_S}\sum_{j=1}^{N_S} X_j\\cdot X_j^\\dagger - \\bigg(\\frac{1}{N_S}\sum_{j=1}^{N_S} X_j\\bigg)\\cdot\\bigg(\\frac{1}{N_S}\sum_{j=1}^{N_S}X_j^\\dagger\\bigg)`

Otherwise, the mean is computed using the given probabilities, i.e.,
The covariance is computed using the given probabilities, i.e.,

:math:`\\text{Cov}(X)=\sum_{j=1}^{N_S} p_jX_j\\cdot X_j^\\dagger - \\bigg(\sum_{j=1}^{N_S} p_jX_j\\bigg)\\cdot\\bigg(\sum_{j=1}^{N_S}p_jX_j^\\dagger\\bigg)`

@@ -291,11 +251,7 @@ def global_covariance(data, p=None):

jit_my_stuff()

if p is not None:

return global_sum(cov_helper_with_p(data, p))

return global_mean(cov_helper_without_p(data))
return global_sum(cov_helper(data, p))
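
A quick algebra check of the unified helper (plain JAX, single device, illustrative values). Note that ``_cov_helper`` accumulates the weighted second moment :math:`\sum_j p_j X_j\cdot X_j^\dagger`; the mean dyad in the docstring formula vanishes for centered samples, so it is assumed to be handled by the caller:

    import jax.numpy as jnp

    X = jnp.array([[1.0, 2.0], [3.0, 4.0]])             # (nSamples, nParams)
    p = jnp.array([0.5, 0.5])
    second = jnp.matmul(jnp.conj(X.T), p[:, None] * X)  # sum_j p_j X_j X_j^dag
    m = jnp.dot(p, X)
    cov = second - jnp.outer(jnp.conj(m), m)            # [[1., 1.], [1., 1.]]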


def bcast_unknown_size(data, root=0):
Expand Down
23 changes: 6 additions & 17 deletions jVMC/operator/povm.py
@@ -35,27 +35,16 @@ def measure_povm(povm, sampler, sampleConfigs=None, probs=None, observables=None
for name, ops in observables.items():
results = povm.evaluate_observable(ops, sampleConfigs)
result[name] = {}
if probs is not None:
result[name]["mean"] = jnp.array(mpi.global_mean(results[0], probs))
result[name]["variance"] = jnp.array(mpi.global_variance(results[0], probs))
result[name]["MC_error"] = jnp.array(0)

else:
result[name]["mean"] = jnp.array(mpi.global_mean(results[0]))
result[name]["variance"] = jnp.array(mpi.global_variance(results[0]))
result[name]["MC_error"] = jnp.array(result[name]["variance"] / jnp.sqrt(sampler.get_last_number_of_samples()))
result[name]["mean"] = jnp.array(mpi.global_mean(results[0][..., None], probs)[0])
result[name]["variance"] = jnp.array(mpi.global_variance(results[0][..., None], probs)[0])
result[name]["MC_error"] = jnp.array(result[name]["variance"] / jnp.sqrt(sampler.get_last_number_of_samples()))

for key, value in results[1].items():
result_name = name + "_corr_L" + str(key)
result[result_name] = {}
if probs is not None:
result[result_name]["mean"] = jnp.array(mpi.global_mean(value, probs) - result[name]["mean"]**2)
result[result_name]["variance"] = jnp.array(mpi.global_variance(value, probs))
result[result_name]["MC_error"] = jnp.array(0.)
else:
result[result_name]["mean"] = jnp.array(mpi.global_mean(value) - result[name]["mean"]**2)
result[result_name]["variance"] = jnp.array(mpi.global_variance(value))
result[result_name]["MC_error"] = jnp.array(result[result_name]["variance"] / jnp.sqrt(sampler.get_last_number_of_samples()))
result[result_name]["mean"] = jnp.array(mpi.global_mean(value[..., None], probs)[0] - result[name]["mean"]**2)
result[result_name]["variance"] = jnp.array(mpi.global_variance(value[..., None], probs)[0])
result[result_name]["MC_error"] = jnp.array(result[name]["variance"] / jnp.sqrt(sampler.get_last_number_of_samples()))

return result
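
The ``[..., None]`` / ``[0]`` pattern above adapts scalar per-sample observables to the array-valued reducers: a trailing axis of length one is appended before the call and stripped from the result. A tiny illustration (plain JAX, illustrative values):

    import jax.numpy as jnp

    vals = jnp.array([[0.2, 0.4, 0.6]])   # (devices, batch) scalar observable
    p = jnp.array([[0.5, 0.25, 0.25]])
    # append a trailing axis, reduce, then take component 0 of the result
    mean = jnp.einsum('db,dbf->f', p, vals[..., None])[0]   # 0.35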
