Skip to content

Commit

Permalink
Merge pull request #75 from markusschmitt/jax_flax_version_compatibility
Browse files Browse the repository at this point in the history
Compatibility with newest Jax and Flax versions
  • Loading branch information
markusschmitt authored Sep 11, 2024
2 parents 22d7d59 + 806640d commit 1a4c10e
Show file tree
Hide file tree
Showing 17 changed files with 45 additions and 47 deletions.
4 changes: 2 additions & 2 deletions jVMC/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
from .version import __version__
from .global_defs import set_pmap_devices

from jax.config import config
config.update("jax_enable_x64", True)
import jax
jax.config.update("jax_enable_x64", True)
3 changes: 1 addition & 2 deletions jVMC/nets/cnn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import flax
import flax.linen as nn
import jax.numpy as jnp
Expand Down
3 changes: 1 addition & 2 deletions jVMC/nets/initializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import flax
import flax.linen as nn

Expand Down
3 changes: 1 addition & 2 deletions jVMC/nets/rbm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import flax
#from flax import nn
import flax.linen as nn
Expand Down
15 changes: 9 additions & 6 deletions jVMC/nets/rnn1d_general.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import flax
import flax.linen as nn
import numpy as np
Expand Down Expand Up @@ -119,10 +118,10 @@ def setup(self):
if self.cell == "RNN":
self.cells = [RNNCell(actFun=self.actFun, initFun=self.initFunction, dtype=self.dtype) for _ in range(self.depth)]
elif self.cell == "LSTM":
self.cells = [LSTMCell() for _ in range(self.depth)]
self.cells = [LSTMCell(features=self.hiddenSize) for _ in range(self.depth)]
self.zero_carry = jnp.zeros((self.depth, 2, self.hiddenSize), dtype=self.dtype)
elif self.cell == "GRU":
self.cells = [GRUCell() for _ in range(self.depth)]
self.cells = [GRUCell(features=self.hiddenSize) for _ in range(self.depth)]
else:
ValueError("Cell name not recognized.")
else:
Expand Down Expand Up @@ -179,16 +178,20 @@ def rnn_cell_sample(self, carry, x):


class GRUCell(nn.Module):
features: int

@nn.compact
def __call__(self, carry, state):
current_carry, newR = nn.GRUCell(**init_fn_args(recurrent_kernel_init=jax.nn.initializers.orthogonal(dtype=global_defs.tReal)))(carry, state)
current_carry, newR = nn.GRUCell(features=self.features, **init_fn_args(recurrent_kernel_init=jax.nn.initializers.orthogonal(dtype=global_defs.tReal)))(carry, state)
return current_carry, newR[0]


class LSTMCell(nn.Module):
features: int

@nn.compact
def __call__(self, carry, state):
current_carry, newR = nn.OptimizedLSTMCell(**init_fn_args(recurrent_kernel_init=jax.nn.initializers.orthogonal(dtype=global_defs.tReal)))(carry, state)
current_carry, newR = nn.OptimizedLSTMCell(features=self.features, **init_fn_args(recurrent_kernel_init=jax.nn.initializers.orthogonal(dtype=global_defs.tReal)))(carry, state)
return jnp.asarray(current_carry), newR


Expand Down
13 changes: 8 additions & 5 deletions jVMC/nets/rnn2d_general.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
# config.update('jax_disable_jit', True)
import flax
import flax.linen as nn
Expand Down Expand Up @@ -127,7 +126,7 @@ def setup(self):
if self.cell == "RNN":
self.cells = [RNNCell(actFun=self.actFun, initFun=self.initFunction, dtype=self.dtype) for _ in range(self.depth)]
elif self.cell == "LSTM":
self.cells = [LSTMCell() for _ in range(self.depth)]
self.cells = [LSTMCell(features=self.hiddenSize) for _ in range(self.depth)]
self.zero_carry = jnp.zeros((self.L, self.depth, 2, self.hiddenSize), dtype=self.dtype)
elif self.cell == "GRU":
self.cells = [GRUCell() for _ in range(self.depth)]
Expand Down Expand Up @@ -241,6 +240,8 @@ def rnn_cell_H_sample(self, carry, x):


class GRUCell(nn.Module):
features: int

@nn.compact
def __call__(self, carryH, carryV, state):
cellCarryH = nn.Dense(features=carryH.shape[-1],
Expand All @@ -249,12 +250,14 @@ def __call__(self, carryH, carryV, state):
cellCarryV = nn.Dense(features=carryV.shape[-1],
use_bias=False,
dtype=global_defs.tReal)
current_carry, newR = nn.GRUCell(param_dtype=global_defs.tReal)(cellCarryH(carryH) + cellCarryV(carryV), state)
current_carry, newR = nn.GRUCell(features=self.features, param_dtype=global_defs.tReal)(cellCarryH(carryH) + cellCarryV(carryV), state)

return current_carry, newR[0]


class LSTMCell(nn.Module):
features: int

@nn.compact
def __call__(self, carryH, carryV, state):
cellCarryH = nn.Dense(features=carryH.shape[-1],
Expand All @@ -263,7 +266,7 @@ def __call__(self, carryH, carryV, state):
cellCarryV = nn.Dense(features=carryV.shape[-1],
use_bias=False,
param_dtype=global_defs.tReal)
current_carry, newR = nn.OptimizedLSTMCell(param_dtype=global_defs.tReal)(cellCarryH(carryH) + cellCarryV(carryV), state)
current_carry, newR = nn.OptimizedLSTMCell(features=self.features, param_dtype=global_defs.tReal)(cellCarryH(carryH) + cellCarryV(carryV), state)

return jnp.asarray(current_carry), newR

Expand Down
3 changes: 1 addition & 2 deletions jVMC/util/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as random
Expand Down
2 changes: 1 addition & 1 deletion jVMC/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Current jVMC version at head on Github."""
__version__ = "1.3.1"
__version__ = "1.4.0"
16 changes: 10 additions & 6 deletions jVMC/vqs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
from jax import jit, grad, vmap
from jax import numpy as jnp
from jax import random
Expand Down Expand Up @@ -481,8 +480,13 @@ def params(self):
@params.setter
def params(self, val):
# Replace 'params' in parameters by `val`
self.parameters = freeze({
**unfreeze(self.parameters.pop("params")[0]),
"params": unfreeze(val)
})
ps = unfreeze(self.parameters)
ps["params"] = unfreeze(val)
if isinstance(self.parameters, flax.core.frozen_dict.FrozenDict):
ps = freeze(ps)
self.parameters = ps
# self.parameters = freeze({
# **unfreeze(self.parameters.pop("params")[0]),
# "params": unfreeze(val)
# })

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open("README.md", "r") as fh:
long_description = fh.read()

DEFAULT_DEPENDENCIES = ["setuptools", "wheel", "numpy", "openfermion", "jax>=0.4.1,<=0.4.20", "jaxlib>=0.4.1,<=0.4.20", "flax>=0.6.4,<=0.6.11", "mpi4py", "h5py", "PyYAML", "matplotlib", "scipy<1.13"] # Scipy version restricted, because jax is currently incompatible with new function namespace scipy.sparse.tril
DEFAULT_DEPENDENCIES = ["setuptools", "wheel", "numpy", "openfermion", "jax>=0.4.12,<=0.4.31", "jaxlib>=0.4.12,<=0.4.31", "flax>=0.7.0", "mpi4py", "h5py", "PyYAML", "matplotlib", "scipy<1.13"] # Scipy version restricted, because jax is currently incompatible with new function namespace scipy.sparse.tril
#CUDA_DEPENDENCIES = ["setuptools", "wheel", "numpy", "jax[cuda]>=0.2.11,<=0.2.25", "flax>=0.3.6,<=0.3.6", "mpi4py", "h5py"]
DEV_DEPENDENCIES = DEFAULT_DEPENDENCIES + ["sphinx", "mock", "sphinx_rtd_theme", "pytest", "pytest-mpi"]

Expand Down
3 changes: 1 addition & 2 deletions tests/mpi_wrapper_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest

import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

Expand Down
5 changes: 1 addition & 4 deletions tests/nets_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import unittest
import sys
sys.path.append(sys.path[0] + "../..")
import jVMC
import jVMC.nets as nets

import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import jax.random as random
import jax.numpy as jnp
import numpy as np
Expand Down
3 changes: 1 addition & 2 deletions tests/operator_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest

import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

import jax.random as random
import jax.numpy as jnp
Expand Down
8 changes: 4 additions & 4 deletions tests/povm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def copy_dict(a):
self.stepper = jVMC.util.stepper.Heun(timeStep=dt) # ODE integrator

def test_matrix_to_povm(self):
unity = jnp.eye(2)
zero_matrix = jnp.zeros((2, 2))
unity = jnp.eye(2, dtype=jVMC.global_defs.tCpx)
zero_matrix = jnp.zeros((2, 2), dtype=jVMC.global_defs.tCpx)

system_data = {"dim": "1D", "L": 2}
povm = jVMC.operator.POVM(system_data)
Expand All @@ -78,8 +78,8 @@ def test_matrix_to_povm(self):
self.assertRaises(ValueError, op.matrix_to_povm, zero_matrix, povm.M, povm.T_inv, mode='wrong_mode')

def test_adding_operator(self):
unity = jnp.eye(2)
zero_matrix = jnp.zeros((2, 2))
unity = jnp.eye(2, dtype=jVMC.global_defs.tCpx)
zero_matrix = jnp.zeros((2, 2), dtype=jVMC.global_defs.tCpx)

system_data = {"dim": "1D", "L": 2}
povm = jVMC.operator.POVM(system_data)
Expand Down
3 changes: 1 addition & 2 deletions tests/sampler_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest

import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import jax.random as random
import jax.numpy as jnp

Expand Down
3 changes: 1 addition & 2 deletions tests/symmetries_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest

import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import numpy as np
Expand Down
3 changes: 1 addition & 2 deletions tests/tdvp_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest

import jax
from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

import jax.random as random
import jax.numpy as jnp
Expand Down

0 comments on commit 1a4c10e

Please sign in to comment.