
Initial implementation of GRU layers #48

Merged · 12 commits · Jan 14, 2022

Conversation

@ptigwe (Contributor) commented Nov 27, 2021

Currently this shows a working implementation of a GRU layer which follows the Keras API quite closely, using flax.linen.GRUCell as its backbone.
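For context, a minimal sketch of the pattern in plain Flax (assuming the current flax.linen API, where GRUCell takes features=...; the actual treex code differs, and since treex modules are pytrees the PR can use jax.lax.scan directly rather than the lifted nn.scan used here):

```python
import flax.linen as nn
import jax.numpy as jnp


class GRU(nn.Module):
    """Sketch: a GRU layer that scans a GRUCell over the time axis."""

    hidden_units: int

    @nn.compact
    def __call__(self, x):  # x: [batch, time, channels]
        # nn.scan lifts the cell so its parameters are shared across steps.
        ScanGRU = nn.scan(
            nn.GRUCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1,   # consume the time axis of x
            out_axes=1,  # stack the per-step outputs back on axis 1
        )
        # The GRU carry is just the hidden state; zeros matches the default init.
        carry = jnp.zeros((x.shape[0], self.hidden_units))
        carry, outputs = ScanGRU(features=self.hidden_units)(carry, x)
        return outputs  # [batch, time, hidden_units]
```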

Tackling the implementation of GRU (#40)

Following the implementation of BatchNorm, use the underlying Flax
implementation for the logic behind the forward pass. This also adds the
use of `jax.lax.scan` for performing the sequential calls on the hidden
state, making this an implementation of the GRU layer rather than a
reimplementation of the `GRUCell`. It also adds three Keras-style flags:
- `return_state`: whether the final state should be returned
- `return_sequences`: whether all the intermediate states should be
  returned
- `go_backwards`: whether the input should be run in reverse order
Swaps the order of the state and input variables, as well as the return
order of the final state and the sequence of states, to be more in line
with the Keras API.
This allows for optionally passing an initial state and for specifying a
function which initializes the initial state.
Adds the `stateful` flag which allows for the last state of the GRU to
be used as the start state for the next batch.
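Taken together, these give the layer a Keras-flavoured surface. A hypothetical usage sketch of the options described above (the flag names come from the commits; the exact treex signature may differ, so check the merged treex/nn/recurrent.py):

```python
import treex as tx

# Hypothetical constructor mirroring the flags above, not a confirmed API.
gru = tx.GRU(
    units=4,
    return_state=True,      # also return the final hidden state
    return_sequences=True,  # return all the intermediate hidden states
    go_backwards=False,     # set True to run the input in reverse order
    stateful=False,         # set True to reuse the last state on the next batch
)
# With both return flags set, the call yields (sequences, final_state),
# matching the Keras return order described above.
```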
@cgarciae (Owner) commented Dec 1, 2021

Thanks @ptigwe for this! Sorry it took so long to respond, I'd seen it but hadn't had the time to review it. Overall it looks very good, I'll just leave a couple of comments.

(3 resolved review threads on treex/nn/recurrent.py)
@cgarciae (Owner) commented Dec 2, 2021

As discussed offline, we feel we can lift Keras's restrictions on the shape of the input by removing the time_major argument in favor of something like a time_axis, which lets you select the dimension you are going to scan over. This change requires a couple of changes to the logic, including figuring out the proper batch_dims to pass to initialize_carry.
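A sketch of the scanning logic under that design (a hypothetical helper, not the PR's code: jax.lax.scan always consumes axis 0, so an arbitrary time_axis can be handled by moving that axis to the front and back):

```python
import jax
import jax.numpy as jnp


def scan_over_axis(step, carry, xs, time_axis):
    # Move the chosen time axis to the front, scan, then restore it
    # on the stacked outputs.
    xs = jnp.moveaxis(xs, time_axis, 0)
    carry, ys = jax.lax.scan(step, carry, xs)
    return carry, jnp.moveaxis(ys, 0, time_axis)


x = jnp.ones((8, 10, 3))      # [batch, time, channels]
h0 = jnp.zeros((8, 4))        # carry: [batch, hidden]; the batch_dims for
                              # initialize_carry are every axis of x except
                              # the time and feature axes
step = lambda h, x_t: (h, h)  # toy step; x_t: [batch, channels]
h, ys = scan_over_axis(step, h0, x, time_axis=-2)
assert ys.shape == (8, 10, 4)  # [batch, time, hidden]
```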

In preparation for the updates to `flax.jax_utils.scan_in_dim`, this
changes `time_major` to `time_axis`. The type hint currently allows only
a single time dimension to be specified, although this is not enforced
at runtime. Underneath, it is stored as a tuple, which would allow its
use in `scan_in_dim`.
@cgarciae (Owner)

Hey @ptigwe!
There are a few minor comments / changes left if you'd like to finish them; otherwise I'd be very happy to continue with the PR.

@ptigwe (Contributor, Author) commented Jan 11, 2022

@cgarciae, I believe I have addressed most of the comments you mentioned, including the time_axis change as discussed. The only thing I didn't update was the switch from jax.lax.scan to jax_utils.scan_in_dim, as I was waiting for the change in flax to be merged and the upstream version updated.

Please let me know if there is anything else I missed and I will make the changes ASAP.

@cgarciae (Owner)

Thanks @ptigwe! There is a small comment about changing the default time_axis from 0 to -2.

@ptigwe (Contributor, Author) commented Jan 12, 2022

@cgarciae, I guess I must have missed that one; I've added it to the PR now. Quick clarification as I'm fixing this: with time_axis = -2, the default expected input shape would be [..., T, C], as opposed to the current default of 0 / -3 with [..., T, B, C].
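Concretely, a shape-only illustration of the two defaults (sizes made up):

```python
import jax.numpy as jnp

x_new = jnp.ones((8, 10, 3))  # time_axis = -2: [..., T, C], here [B, T, C]
assert x_new.shape[-2] == 10  # T sits second-to-last

x_old = jnp.ones((10, 8, 3))  # time_axis = 0 / -3: [T, B, C] (time-major)
assert x_old.shape[0] == 10   # T sits in front
```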

Changes the default `time_axis` to -2, i.e. by default the expected
shape of the input should be of the form [..., time, :].
Swaps `jax.lax.scan` for `flax.jax_utils.scan_in_dim`, which allows one
to have multiple time dimensions by specifying `time_axis` as a tuple
instead of a single int value.
@cgarciae (Owner)

Keras uses [B, T, C]; we are going for [..., B, T, C], which is even better. I think this is nicer than [..., T, B, C] because 1D convolutions and transformers also use [..., B, T, C], so the ops would chain better.
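A quick sketch of that chaining, using Flax's nn.Conv (which treats the middle axis of a [B, T, C] input as the 1D spatial axis):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

x = jnp.ones((8, 10, 3))                       # [B, T, C]
conv = nn.Conv(features=16, kernel_size=(5,))  # 1D convolution over T
variables = conv.init(jax.random.PRNGKey(0), x)
y = conv.apply(variables, x)                   # [B, T, 16]
# y can feed the GRU with time_axis=-2 without any transpose in between.
```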

@ptigwe (Contributor, Author) commented Jan 12, 2022

OK cool, it was already set in the previous commit. I also decided not to include the change to jax_utils.scan_in_dim, as the version currently pointed to (0.3.6) does not yet have the unrolled addition.

@cgarciae (Owner)

If you want to update flax to the latest version to add this feature now, you can run:

```
poetry add flax@latest
```

@ptigwe (Contributor, Author) commented Jan 14, 2022

Seems the latest is indeed 0.3.6, which does not have the updated scan_in_dim. Whenever it gets updated, we can always update the code to make use of it. I've also included some comments on things that might need changing in that case.

@cgarciae (Owner)

@ptigwe Sounds good. I'll merge this for now; we can create a new PR with scan_in_dim later.

Thanks a lot for pushing this through!

cgarciae merged commit fce1175 into cgarciae:master on Jan 14, 2022