
Commit

Merge pull request #57 from markusschmitt/dev_1.2.2
Dev 1.2.2
markusschmitt authored Jun 14, 2023
2 parents 4b297af + ec2f6c3 commit 4a4529d
Showing 7 changed files with 174 additions and 140 deletions.
3 changes: 1 addition & 2 deletions jVMC/global_defs.py
@@ -22,8 +22,7 @@
myDeviceCount = len(myPmapDevices)
pmap_for_my_devices = partial(jax.pmap, devices=myPmapDevices)

pmapDevices = None
def pmap_devices_updated():
def pmap_devices_updated(pmapDevices):
if collections.Counter(pmapDevices) == collections.Counter(myPmapDevices):
return False
return True
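For reference, a minimal sketch (not part of this diff) of the changed contract: callers now pass their own cached device list, and pmap_devices_updated reports whether that cache is stale relative to global_defs.myPmapDevices. The caller-side name cachedDevices is a placeholder, not jVMC API.

import jVMC.global_defs as global_defs

cachedDevices = None                                       # hypothetical caller-side cache, nothing jitted yet
print(global_defs.pmap_devices_updated(cachedDevices))     # True: cache differs from myPmapDevices
cachedDevices = global_defs.myPmapDevices
print(global_defs.pmap_devices_updated(cachedDevices))     # False: cache is up to date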
8 changes: 4 additions & 4 deletions jVMC/mpi_wrapper.py
@@ -29,7 +29,7 @@ def _cov_helper(data, p):
mean_helper = None
cov_helper = None


mpiPmapDevices = None
def jit_my_stuff():
# This is a helper function to make sure that pmap'd functions work with the actual choice of devices
# at all times.
@@ -38,16 +38,16 @@ def jit_my_stuff():
global _sum_sq_pmapd
global mean_helper
global cov_helper
global pmapDevices
global mpiPmapDevices

if global_defs.pmap_devices_updated():
if global_defs.pmap_devices_updated(mpiPmapDevices):
_sum_up_pmapd = global_defs.pmap_for_my_devices(lambda x: jax.lax.psum(jnp.sum(x, axis=0), 'i'), axis_name='i')
# _sum_sq_pmapd = global_defs.pmap_for_my_devices(lambda data, mean, p: jax.lax.psum(jnp.conj(data - mean).dot(p[..., None] * (data - mean)), 'i'), axis_name='i', in_axes=(0, None, 0))
_sum_sq_pmapd = global_defs.pmap_for_my_devices(lambda data, mean, p: jnp.einsum('ij, ij, i -> j', jnp.conj(data - mean[None, ...]), (data - mean[None, ...]), p), in_axes=(0, None, 0))
mean_helper = global_defs.pmap_for_my_devices(lambda data, p: jnp.expand_dims(jnp.dot(p, data), axis=0), in_axes=(0, 0))
cov_helper = global_defs.pmap_for_my_devices(_cov_helper, in_axes=(0, 0))

pmapDevices = global_defs.myPmapDevices
mpiPmapDevices = global_defs.myPmapDevices


def distribute_sampling(numSamples, localDevices=None, numChainsPerDevice=1) -> int:
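The same device-tracking idiom now appears in both jVMC/mpi_wrapper.py and jVMC/stats.py: keep a module-local snapshot of the devices the helpers were jitted for, and re-create the pmap'd helpers only when that snapshot is stale. A minimal stand-alone sketch of the pattern, with hypothetical names (_myPmapDevices, _double_pmapd, _jit_my_stuff are placeholders, not jVMC API):

import jVMC.global_defs as global_defs

_myPmapDevices = None   # module-local snapshot of the devices the helpers were jitted for
_double_pmapd = None    # pmap'd helper created for that snapshot

def _jit_my_stuff():
    # Re-create the pmap'd helper only when the global device set has changed.
    global _myPmapDevices, _double_pmapd
    if global_defs.pmap_devices_updated(_myPmapDevices):
        _double_pmapd = global_defs.pmap_for_my_devices(lambda x: 2 * x)
        _myPmapDevices = global_defs.myPmapDevices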
25 changes: 13 additions & 12 deletions jVMC/operator/base.py
@@ -11,6 +11,11 @@

opDtype = global_defs.tCpx

def expand_batch(batch, batchSize):
outShape = list(batch.shape)
outShape[0] = batchSize
outp = jnp.zeros(tuple(outShape), dtype=batch.dtype)
return outp.at[:batch.shape[0]].set(batch)

class Operator(metaclass=abc.ABCMeta):
"""This class defines an interface and provides functionality to compute operator matrix elements
@@ -75,9 +80,13 @@ def __init__(self, ElocBatchSize=-1):
self._get_config_batch_pmapd = global_defs.pmap_for_my_devices(lambda d, startIdx, sliceSize: jax.lax.dynamic_slice_in_dim(d, startIdx, sliceSize), in_axes=(0, None, None), static_broadcasted_argnums=(2,))
self._get_logPsi_batch_pmapd = global_defs.pmap_for_my_devices(lambda d, startIdx, sliceSize: jax.lax.dynamic_slice_in_dim(d, startIdx, sliceSize), in_axes=(0, None, None), static_broadcasted_argnums=(2,))
self._insert_Oloc_batch_pmapd = global_defs.pmap_for_my_devices(
lambda dst, src, beg: jax.lax.dynamic_update_slice(dst, src, [beg, ]),
in_axes=(0, 0, None)
)
lambda dst, src, beg: jax.lax.dynamic_update_slice(dst, src, [beg, ]),
in_axes=(0, 0, None)
)
self._get_Oloc_slice_pmapd = global_defs.pmap_for_my_devices(
lambda d, startIdx, sliceSize: jax.lax.dynamic_slice_in_dim(d, startIdx, sliceSize),
in_axes=(0, None, None), static_broadcasted_argnums=(2,)
)

def _find_nonzero(self, m):

@@ -248,11 +257,6 @@ def get_O_loc_batched(self, samples, psi, logPsiS, batchSize, *args):

if remainder > 0:

def expand_batch(batch, batchSize):
outShape = list(batch.shape)
outShape[0] = batchSize
outp = jnp.zeros(tuple(outShape), dtype=batch.dtype)
return outp.at[:batch.shape[0]].set(batch)
batch = self._get_config_batch_pmapd(samples, numBatches * batchSize, remainder)
batch = global_defs.pmap_for_my_devices(expand_batch, static_broadcasted_argnums=(1,))(batch, batchSize)
logPsiSbatch = self._get_logPsi_batch_pmapd(logPsiS, numBatches * batchSize, numSamples % batchSize)
@@ -262,10 +266,7 @@ def expand_batch(batch, batchSize):

OlocBatch = self.get_O_loc_unbatched(logPsiSbatch, psi(sp))

OlocBatch = global_defs.pmap_for_my_devices(
lambda d, startIdx, sliceSize: jax.lax.dynamic_slice_in_dim(d, startIdx, sliceSize),
in_axes=(0, None, None), static_broadcasted_argnums=(2,)
)(OlocBatch, 0, remainder)
OlocBatch = self._get_Oloc_slice_pmapd(OlocBatch, 0, remainder)

Oloc = self._insert_Oloc_batch_pmapd(Oloc, OlocBatch, numBatches * batchSize)

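A minimal sketch (illustrative shapes, not the committed code) of what the now module-level expand_batch does inside get_O_loc_batched: the last, partially filled batch is zero-padded along the leading axis up to the static batch size, so the pmap'd evaluation always sees arrays of the same shape, and the valid part is sliced back out afterwards.

import jax.numpy as jnp

def expand_batch(batch, batchSize):            # as added in jVMC/operator/base.py
    outShape = list(batch.shape)
    outShape[0] = batchSize
    outp = jnp.zeros(tuple(outShape), dtype=batch.dtype)
    return outp.at[:batch.shape[0]].set(batch)

remainder = 3
batch = jnp.ones((remainder, 4))               # last batch holds only 3 samples
padded = expand_batch(batch, 8)                # shape (8, 4); rows 3..7 are zero
recovered = padded[:remainder]                 # slice the valid part back out
assert recovered.shape == (3, 4)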
211 changes: 132 additions & 79 deletions jVMC/stats.py
@@ -3,8 +3,84 @@

import jVMC
import jVMC.mpi_wrapper as mpi
import jVMC.global_defs as global_defs
from jVMC.global_defs import pmap_for_my_devices

_mean_helper = None
_data_prep = None
_covar_helper = None
_covar_var_helper = None
_covar_data_helper = None
_trafo_helper_1 = None
_trafo_helper_2 = None
_select_helper = None
_get_subset_helper = None
_subset_mean_helper = None
_subset_data_prep = None

statsPmapDevices = None

def jit_my_stuff():
# This is a helper function to make sure that pmap'd functions work with the actual choice of devices
# at all times.

global _mean_helper
global _covar_helper
global _covar_var_helper
global _covar_data_helper
global _trafo_helper_1
global _trafo_helper_2
global _select_helper
global _data_prep
global _get_subset_helper
global _subset_mean_helper
global _subset_data_prep

global statsPmapDevices

if jVMC.global_defs.pmap_devices_updated(statsPmapDevices):

statsPmapDevices = global_defs.myPmapDevices

_mean_helper = pmap_for_my_devices(lambda data, w: jnp.tensordot(w, data, axes=(0,0)), in_axes=(0, 0))
_data_prep = pmap_for_my_devices(lambda data, w, mean: jax.vmap(lambda d, w, m: jnp.sqrt(w) * (d - m), in_axes=(0,0,None))(data, w, mean), in_axes=(0, 0, None))
_covar_helper = pmap_for_my_devices(
lambda data1, data2:
jnp.tensordot(
jnp.conj(data1),
data2, axes=(0,0)),
in_axes=(0, 0)
)
_covar_var_helper = pmap_for_my_devices(
lambda data1, data2, w:
jnp.sum(
jnp.abs(
jax.vmap(lambda a,b: jnp.outer(a,b))(jnp.conj(data1), data2),
)**2 / w[...,None,None],
axis=0),
in_axes=(0, 0, 0)
)
_covar_data_helper = pmap_for_my_devices(lambda data1, data2, w: jax.vmap(lambda a,b,w: jnp.outer(a,b) / w)(jnp.conj(data1), data2, w), in_axes=(0, 0, 0))
_trafo_helper_1 = pmap_for_my_devices(
lambda data, w, mean, f: f(
jax.vmap(lambda x,y: x/jnp.sqrt(y), in_axes=(0,0))(data, w)
+ mean
),
in_axes=(0, 0, None), static_broadcasted_argnums=(3,))
_trafo_helper_2 = pmap_for_my_devices(
lambda data, w, mean, v, f:
jnp.matmul(v,
f(
jax.vmap(lambda x,y: x/jnp.sqrt(y), in_axes=(0,0))(data, w)
+ mean
)
),
in_axes=(0, 0, None, None), static_broadcasted_argnums=(4,))
_select_helper = pmap_for_my_devices( lambda ix,g: jax.vmap(lambda ix,g: g[ix], in_axes=(None, 0))(ix,g), in_axes=(None, 0) )
_get_subset_helper = pmap_for_my_devices(lambda x, ixs: x[slice(*ixs)], in_axes=(0,), static_broadcasted_argnums=(1,))
_subset_mean_helper = pmap_for_my_devices(lambda d, w, m: jnp.tensordot(jnp.sqrt(w), d, axes=(0,0)) + m, in_axes=(0,0,None))
_subset_data_prep = pmap_for_my_devices(jax.vmap(lambda d, w, m1, m2: d+jnp.sqrt(w)*(m1-m2), in_axes=(0,0,None,None)), in_axes=(0,0,None,None))


class SampledObs():
"""This class implements the computation of statistics from Monte Carlo or exact samples.
@@ -15,7 +91,7 @@ class SampledObs():
* ``weights``: Weights :math:`w_n` associated with observation :math:`O_n`.
"""

def __init__(self, observations, weights):
def __init__(self, observations=None, weights=None):
"""Initializes SampledObs class.
Args:
@@ -24,15 +100,19 @@ def __init__(self, observations, weights):
* ``weights``: Weights :math:`w_n` associated with observation :math:`O_n`.
"""

self.jit_my_stuff()
jit_my_stuff()

if len(observations.shape) == 2:
observations = observations[...,None]
if (observations is not None) and (weights is not None):
if len(observations.shape) == 2:
observations = observations[...,None]

self._weights = weights
self._mean = mpi.global_sum( self._mean_helper(observations,self._weights)[None,...] )
#self._data = self._data_prep(observations, self._mean)
self._data = self._data_prep(observations, self._weights, self._mean)
self._weights = weights
self._mean = mpi.global_sum( _mean_helper(observations,self._weights)[None,...] )
self._data = _data_prep(observations, self._weights, self._mean)
else:
self._weights = weights
self._data = observations
self._mean = None


def mean(self):
@@ -49,20 +129,16 @@ def covar(self, other=None):
* ``other`` [optional]: Another instance of `SampledObs`.
"""

self.jit_my_stuff()

if other is None:
other = self

#return mpi.global_sum( self._covar_helper(self._data, other._data, self._weights)[None,...] )
return mpi.global_sum( self._covar_helper(self._data, other._data)[None,...] )
return mpi.global_sum( _covar_helper(self._data, other._data)[None,...] )


def var(self):
"""Returns the variance.
"""

#return mpi.global_sum( self._mean_helper(jnp.abs(self._data)**2, self._weights)[None,...] )
return mpi.global_sum( jnp.abs(self._data)**2 )


@@ -73,13 +149,10 @@ def covar_data(self, other=None):
* ``other`` [optional]: Another instance of `SampledObs`.
"""

self.jit_my_stuff()

if other is None:
other = self

#return SampledObs( self._covar_data_helper(self._data, other._data), self._weights )
return SampledObs( self._covar_data_helper(self._data, other._data, self._weights), self._weights )
return SampledObs( _covar_data_helper(self._data, other._data, self._weights), self._weights )


def covar_var(self, other=None):
@@ -89,86 +162,66 @@ def covar_var(self, other=None):
* ``other`` [optional]: Another instance of `SampledObs`.
"""

self.jit_my_stuff()

if other is None:
other = self

return mpi.global_sum( self._covar_var_helper(self._data, other._data, self._weights)[None,...] ) \
return mpi.global_sum( _covar_var_helper(self._data, other._data, self._weights)[None,...] ) \
- jnp.abs(self.covar(other))**2


def transform(self, fun=lambda x: x):
def transform(self, nonLinearFun=lambda x: x, linearFun=None):
"""Returns a `SampledObs` for the transformed data.
Args:
* ``fun``: A function.
"""

#return SampledObs( self._trafo_helper(self._data, self._mean, fun), self._weights )
return SampledObs( self._trafo_helper(self._data, self._weights, self._mean, fun), self._weights )
if linearFun is None:
return SampledObs( _trafo_helper_1(self._data, self._weights, self._mean, nonLinearFun), self._weights )

return SampledObs( _trafo_helper_2(self._data, self._weights, self._mean, linearFun, nonLinearFun), self._weights )


def select(self, ixs):
"""Returns a `SampledObs` for the data selection indicated by the given indices.
Args:
* ``ixs``: Indices of selected data.
"""

newObs = SampledObs()
newObs._data = _select_helper(ixs, self._data)
newObs._mean = self._mean[ixs]
newObs._weights = self._weights

return newObs


def subset(self, start=None, end=None, step=None):
"""Returns a `SampledObs` for a subset of the data.
def tangent_kernel(self):
Args:
* ``start``: Start sample index for subset selection
* ``end``: End sample index for subset selection
* ``step``: Sample index step for subset selection
"""

all_data = mpi.gather(self._data)

return jnp.matmul(all_data, jnp.conj(jnp.transpose(all_data)))
newObs = SampledObs()
newObs._weights = _get_subset_helper(self._weights, (start, end, step))
normalization = mpi.global_sum(newObs._weights)
newObs._data = _get_subset_helper(self._data, (start, end, step))
newObs._weights = newObs._weights / normalization
newObs._data = newObs._data / jnp.sqrt(normalization)
newObs._mean = mpi.global_sum( _subset_mean_helper(newObs._data, newObs._weights, self._mean)[None,...] )
newObs._data = _subset_data_prep(newObs._data, newObs._weights, self._mean, newObs._mean)

return newObs

def jit_my_stuff(self):
# This is a helper function to make sure that pmap'd functions work with the actual choice of devices
# at all times.

if jVMC.global_defs.pmap_devices_updated():
self._mean_helper = pmap_for_my_devices(lambda data, w: jnp.tensordot(w, data, axes=(0,0)), in_axes=(0, 0))
#self._data_prep = pmap_for_my_devices(lambda data, mean: data - mean, in_axes=(0, None))
self._data_prep = pmap_for_my_devices(lambda data, w, mean: jax.vmap(lambda d, w, m: jnp.sqrt(w) * (d - m), in_axes=(0,0,None))(data, w, mean), in_axes=(0, 0, None))
# self._covar_helper = pmap_for_my_devices(
# lambda data1, data2, w:
# jnp.tensordot(
# jnp.conj(
# jax.vmap(lambda a,b: a*b, in_axes=(0,0))(w, data1)
# ),
# data2, axes=(0,0)),
# in_axes=(0, 0, 0)
# )
self._covar_helper = pmap_for_my_devices(
lambda data1, data2:
jnp.tensordot(
jnp.conj(data1),
data2, axes=(0,0)),
in_axes=(0, 0)
)
# self._covar_var_helper = pmap_for_my_devices(
# lambda data1, data2, w:
# jnp.sum(
# w[...,None,None] *
# jnp.abs(
# jax.vmap(lambda a,b: jnp.outer(a,b))(jnp.conj(data1), data2),
# )**2,
# axis=0),
# in_axes=(0, 0, 0)
# )
self._covar_var_helper = pmap_for_my_devices(
lambda data1, data2, w:
jnp.sum(
jnp.abs(
jax.vmap(lambda a,b: jnp.outer(a,b))(jnp.conj(data1), data2),
)**2 / w[...,None,None],
axis=0),
in_axes=(0, 0, 0)
)
self._covar_data_helper = pmap_for_my_devices(lambda data1, data2, w: jax.vmap(lambda a,b,w: jnp.outer(a,b) / w)(jnp.conj(data1), data2, w), in_axes=(0, 0, 0))
#self._trafo_helper = pmap_for_my_devices(lambda data, mean, f: f(data + mean), in_axes=(0, None), static_broadcasted_argnums=(2,))
self._trafo_helper = pmap_for_my_devices(
lambda data, w, mean, f: f(
jax.vmap(lambda x,y: x/jnp.sqrt(y), in_axes=(0,0))(data, w)
+ mean
),
in_axes=(0, 0, None), static_broadcasted_argnums=(3,))

def tangent_kernel(self):

all_data = mpi.gather(self._data)

return jnp.matmul(all_data, jnp.conj(jnp.transpose(all_data)))
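For orientation, a minimal single-device numerical sketch (not part of this diff, ignoring the pmap/MPI dimension handled by the helpers above): SampledObs stores sqrt(w_n) * (O_n - mean) as its internal data, so the weighted mean, covariance, and variance follow as plain sums over that data, as in _mean_helper, _data_prep, and _covar_helper.

import jax.numpy as jnp

obs = jnp.array([[1.0], [2.0], [4.0]])          # three scalar observations O_n
w = jnp.array([0.5, 0.25, 0.25])                # normalized weights w_n

mean = jnp.tensordot(w, obs, axes=(0, 0))       # what _mean_helper computes
data = jnp.sqrt(w)[:, None] * (obs - mean)      # what _data_prep computes
covar = jnp.tensordot(jnp.conj(data), data, axes=(0, 0))   # what _covar_helper computes
var = jnp.sum(jnp.abs(data) ** 2)               # what SampledObs.var() sums

# Both agree with the direct weighted estimate sum_n w_n |O_n - mean|^2.
assert jnp.allclose(var, jnp.sum(w * jnp.abs(obs[:, 0] - mean[0]) ** 2))
assert jnp.allclose(covar[0, 0], var)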


