Skip to content

Commit

Permalink
Fix complex RNN
Browse files Browse the repository at this point in the history
  • Loading branch information
tszoldra committed Nov 27, 2023
1 parent 62abb98 commit 48bdf94
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
4 changes: 2 additions & 2 deletions jVMC/nets/rnn1d_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def setup(self):
self.initFunction = jax.nn.initializers.variance_scaling(scale=self.initScale, mode="fan_avg", distribution="uniform")
else:
self.dtype = global_defs.tCpx
self.initFunction = partial(jVMC.nets.initializers.cplx_variance_scaling, scale=self.initScale)
self.initFunction = jVMC.nets.initializers.cplx_variance_scaling

if isinstance(self.cell, str):
self.zero_carry = jnp.zeros((self.depth, 1, self.hiddenSize), dtype=self.dtype)
Expand All @@ -129,7 +129,7 @@ def setup(self):
self.cells = self.cell[0]
self.zero_carry = self.cell[1]

self.rnnCell = RNNCellStack(self.cells, actFun=self.actFun)
self.rnnCell = RNNCellStack(self.cells, actFun=self.actFun, dtype=self.dtype)
init_args = init_fn_args(dtype=self.dtype, bias_init=jax.nn.initializers.zeros, kernel_init=self.initFunction)
self.outputDense = nn.Dense(features=(self.inputDim-1) * (2 - self.realValuedOutput),
use_bias=True, **init_args)
Expand Down
2 changes: 1 addition & 1 deletion jVMC/nets/rnn2d_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def setup(self):
self.initFunction = jax.nn.initializers.variance_scaling(scale=self.initScale, mode="fan_avg", distribution="uniform")
else:
self.dtype = global_defs.tCpx
self.initFunction = partial(jVMC.nets.initializers.cplx_variance_scaling, scale=self.initScale)
self.initFunction = jVMC.nets.initializers.cplx_variance_scaling

if isinstance(self.cell, str):
if self.cell in ["LSTM", "GRU"] and not self.realValuedParams:
Expand Down
35 changes: 30 additions & 5 deletions tests/nets_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

import sys
sys.path.append(sys.path[0] + "../..")
import jVMC
import jVMC.nets as nets

Expand Down Expand Up @@ -68,20 +69,44 @@ def test_sym_net(self):

def test_sym_net_generative(self):
L=5
rbm = nets.RNN1DGeneral(L=5)
rnn = nets.RNN1DGeneral(L=5)
orbit = jVMC.util.symmetries.get_orbit_1D(L, "translation")
rbm_sym = nets.SymNet(net=rbm, orbit=orbit)
params = rbm_sym.init(random.PRNGKey(0), jnp.zeros((5,), dtype=np.int32))
rnn_sym = nets.SymNet(net=rnn, orbit=orbit)
params = rnn_sym.init(random.PRNGKey(0), jnp.zeros((5,), dtype=np.int32))

S0 = jnp.pad(jnp.array([1, 0, 1, 1, 0]), (0, 4), 'wrap')
S = jnp.array(
[S0[i:i + 5]for i in range(5)]
)
psiS = jax.vmap(lambda s: rbm_sym.apply(params, s))(S)
psiS = jax.vmap(lambda s: rnn_sym.apply(params, s))(S)
psiS = psiS - psiS[0]

self.assertTrue(jnp.max(jnp.abs(psiS)) < 1e-12)


class TestCpxNet(unittest.TestCase):

def test_cpx_rnn_1d(self):
rnn = nets.RNN1DGeneral(L=5, realValuedParams=False)
params = rnn.init(random.PRNGKey(0), jnp.zeros((5,), dtype=np.int32))

S0 = jnp.array([1, 0, 1, 1, 0])
psiS0 = rnn.apply(params, S0)
self.assertTrue(jnp.max(jnp.abs(psiS0 - (-1.7393452561818394+0.025880153799492975j))) < 1e-12)

def test_cpx_rnn_2d(self):
rnn = nets.RNN2DGeneral(L=4, realValuedParams=False)
params = rnn.init(random.PRNGKey(0), jnp.zeros((4, 4), dtype=np.int32))

S0 = jnp.array(
[[1, 0, 1, 1],
[0, 1, 1, 1],
[0, 0, 1, 0],
[1, 0, 0, 1]]
)
psiS0 = rnn.apply(params, S0)
self.assertTrue(jnp.max(jnp.abs(psiS0 - (-5.549380111605981-0.0316078980423882j))) < 1e-12)


if __name__ == "__main__":
unittest.main()

0 comments on commit 48bdf94

Please sign in to comment.