Skip to content

Commit

Permalink
Merge pull request #55 from markusschmitt/fix_mean_var_outp
Browse files Browse the repository at this point in the history
Output mean and var as scalar in TDVP and MinSR classes.
  • Loading branch information
markusschmitt authored Jun 11, 2023
2 parents 1056c12 + 97128bc commit 4b297af
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions jVMC/util/minsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def stop_timing(outp, name, waitFor=None):
if "intStep" in rhsArgs:
if rhsArgs["intStep"] == 0:

self.ElocMean0 = Eloc.mean()
self.ElocVar0 = Eloc.var()
self.ElocMean0 = Eloc.mean()[0]
self.ElocVar0 = Eloc.var()[0]

self.metaData = {}

Expand Down
6 changes: 3 additions & 3 deletions jVMC/util/tdvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def get_S(self):

def get_tdvp_equation(self, Eloc, gradients):

self.ElocMean = Eloc.mean()
self.ElocVar = Eloc.var()
self.ElocMean = Eloc.mean()[0]
self.ElocVar = Eloc.var()[0]

self.F0 = (-self.rhsPrefactor) * gradients.covar(Eloc).ravel() #* EOdata.mean()
self.F0 = (-self.rhsPrefactor) * gradients.covar(Eloc).ravel()
F = self.makeReal(self.F0)

self.S0 = gradients.covar()
Expand Down
2 changes: 1 addition & 1 deletion jVMC/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Current jVMC version at head on Github."""
__version__ = "1.2.0"
__version__ = "1.2.1"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
long_description = fh.read()


DEFAULT_DEPENDENCIES = ["setuptools", "wheel", "numpy", "jax>=0.4.1,<=0.4.11", "jaxlib>=0.4.1,<=0.4.11", "flax>=0.6.4", "mpi4py", "h5py", "PyYAML"]
DEFAULT_DEPENDENCIES = ["setuptools", "wheel", "numpy", "jax>=0.4.1,<=0.4.11", "jaxlib>=0.4.1,<=0.4.11", "flax>=0.6.4", "mpi4py", "h5py", "PyYAML", "matplotlib"]
#CUDA_DEPENDENCIES = ["setuptools", "wheel", "numpy", "jax[cuda]>=0.2.11,<=0.2.25", "flax>=0.3.6,<=0.3.6", "mpi4py", "h5py"]
DEV_DEPENDENCIES = DEFAULT_DEPENDENCIES + ["sphinx", "mock", "sphinx_rtd_theme", "pytest", "pytest-mpi"]

Expand Down

0 comments on commit 4b297af

Please sign in to comment.