Skip to content

Commit

Permalink
Fix stats.py for MPI
Browse files Browse the repository at this point in the history
  • Loading branch information
tszoldra committed Dec 1, 2023
1 parent 62abb98 commit bebec73
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
6 changes: 3 additions & 3 deletions jVMC/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, observations=None, weights=None):
observations = observations[...,None]

self._weights = weights
self._mean = mpi.global_sum( _mean_helper(observations,self._weights)[None,...] )
self._mean = mpi.global_sum( _mean_helper(observations,self._weights)[:, None,...] )
self._data = _data_prep(observations, self._weights, self._mean)
else:
self._weights = weights
Expand All @@ -132,7 +132,7 @@ def covar(self, other=None):
if other is None:
other = self

return mpi.global_sum( _covar_helper(self._data, other._data)[None,...] )
return mpi.global_sum( _covar_helper(self._data, other._data)[:, None,...] )


def var(self):
Expand Down Expand Up @@ -165,7 +165,7 @@ def covar_var(self, other=None):
if other is None:
other = self

return mpi.global_sum( _covar_var_helper(self._data, other._data, self._weights)[None,...] ) \
return mpi.global_sum( _covar_var_helper(self._data, other._data, self._weights)[:, None,...] ) \
- jnp.abs(self.covar(other))**2


Expand Down
7 changes: 4 additions & 3 deletions tests/stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@

from jVMC.stats import SampledObs
import jVMC.mpi_wrapper as mpi
from jVMC.global_defs import device_count


class TestStats(unittest.TestCase):

def test_sampled_obs(self):

Obs1Loc = jnp.array([[1,2,3]])
Obs2Loc = jnp.array([[[1,4],[2,5],[3,7]]])
p = (1./3) * jnp.ones(3)[None,...]
Obs1Loc = jnp.array([[1, 2, 3]] * device_count())
Obs2Loc = jnp.array([[[1, 4], [2, 5], [3, 7]]] * device_count())
p = (1. / (3 * device_count())) * jnp.ones((device_count(), 3))

obs1 = SampledObs(Obs1Loc, p)
obs2 = SampledObs(Obs2Loc, p)
Expand Down

0 comments on commit bebec73

Please sign in to comment.