Fix complex RNN #66

Merged: markusschmitt merged 2 commits into markusschmitt:master on Dec 6, 2023

tszoldra
Contributor

An RNN with complex parameters could not be created. To reproduce:

>>> import jVMC
>>> import jax
>>> net = jVMC.nets.RNN1DGeneral(L=4, realValuedParams=False)
>>> params = net.init(jax.random.PRNGKey(0), jax.numpy.zeros(4, dtype=jax.numpy.int32))

gives the output

  File "site-packages/jVMC/nets/rnn1d_general.py", line 146, in __call__
    _, probs = self.rnn_cell((self.zero_carry, jnp.zeros(self.inputDim)), jax.nn.one_hot(x, self.inputDim))
  File "site-packages/flax/core/axes_scan.py", line 139, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "site-packages/flax/core/axes_scan.py", line 115, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "site-packages/jVMC/nets/rnn1d_general.py", line 154, in rnn_cell
    newCarry, out = self.rnnCell(carry[0], carry[1])
  File "site-packages/jVMC/nets/rnn1d_general.py", line 52, in __call__
    current_carry, newR = cell(c, newR)
  File "site-packages/jVMC/nets/rnn1d_general.py", line 224, in __call__
    newCarry = (self.actFun(cellCarry(carry[0])) + state)[None, :]
  File "site-packages/flax/linen/linear.py", line 196, in __call__
    kernel = self.param('kernel',
TypeError: cplx_variance_scaling() got an unexpected keyword argument 'scale'

I fixed the arguments passed to cplx_variance_scaling(), which were wrong. In addition, RNN1DGeneral() did not pass its self.dtype argument on to RNNCellStack(), which mixed complex and float datatypes. I also added tests for RNNs with complex parameters. I am aware that comparing the RNN output against a fixed value is not good practice, but the tests at least show whether RNN1DGeneral and RNN2DGeneral compile.
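For readers unfamiliar with this failure mode, here is a minimal pure-Python sketch of how a TypeError like the one in the traceback can arise. It assumes the initializer takes plain positional arguments and is called as if it were a configurable factory. `standin_initializer`, `call_like_factory`, and `accepts_keyword` are hypothetical names for illustration; this is not jVMC's actual `cplx_variance_scaling` and not necessarily the exact call that was fixed.

```python
import inspect


def standin_initializer(rng, shape, dtype):
    """Hypothetical initializer with a plain (rng, shape, dtype) signature."""
    return {"shape": shape, "dtype": dtype}


def call_like_factory(init_fn):
    """Call init_fn with a factory-style keyword; return the error text, if any."""
    try:
        init_fn(scale=1.0)
    except TypeError as err:
        return str(err)
    return None


def accepts_keyword(fn, name):
    """Report whether fn can be called with the keyword `name`."""
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return any(
        p.kind is inspect.Parameter.VAR_KEYWORD
        or (p.name == name and p.kind in keyword_kinds)
        for p in inspect.signature(fn).parameters.values()
    )


# The stand-in rejects `scale`, just as the traceback reports.
message = call_like_factory(standin_initializer)
print(message)
print(accepts_keyword(standin_initializer, "scale"))
```

A signature check such as `accepts_keyword` is one possible way to catch this kind of mismatch in a test before a network is built. Whether that is preferable to a simple smoke test is a judgment call.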

@markusschmitt
Owner

Thanks for this fix!

There seems to be a problem with the tests, though, when using jax/jaxlib 0.4.11. Can you change the dependencies in setup.py (in the top directory) to jax>=0.4.1,<=0.4.20 and jaxlib>=0.4.1,<=0.4.20? Then the CI testing will install the latest version and the tests will hopefully pass.
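As a rough sketch, the requested pins could look like this. Only the two requirement strings come from the comment above. The variable name and the way the list is passed to setup() are assumptions, and the rest of jVMC's setup.py is not shown.

```python
# Sketch of the dependency pins requested above. Only the two jax entries
# come from this thread; how they are embedded in setup.py is an assumption.
install_requires = [
    "jax>=0.4.1,<=0.4.20",
    "jaxlib>=0.4.1,<=0.4.20",
]

# In a setuptools-based setup.py, a list like this is typically passed to
# setup() through its install_requires keyword argument.
```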

@markusschmitt markusschmitt merged commit fc7bc11 into markusschmitt:master Dec 6, 2023
1 check passed