-
Notifications
You must be signed in to change notification settings - Fork 17
Conversation
Following the implementation of BatchNorm, use the underlying Flax implementation for the logic behind the forward pass. This also adds the use of `jax.lax.scan` for performing the sequential calls on the hidden state, making this the implementation of the GRU layer as opposed to a reimplementation of the `GRUCell`.
- `return_state`: whether the final state should be returned - `return_sequences`: whether all the intermediate states should be returned - `go_backwards`: whether the input should be run in reverse order
Swaps the order of the state and input variables as well as the return order of the final state and sequence of states to be more inline with that of Keras API.
This allows for optional passing an initial state and the specifying of a function which can intialize the initial state.
Adds the `stateful` flag which allows for the last state of the GRU to be used as the start state for the next batch.
Thanks @ptigwe for this! Sorry it took so long to respond, I'd seen it but hadn't had the time to review it. Overall it looks very good, I'll just leave a couple of comments. |
As discussed offline, we feel here we can lift Keras restrictions regarding the shape of the input by removing the |
In preparation for the updates to `flax.jax_utils.scan_in_dim`, this changes `time_major` to `time_axis`. Currently allowing for only the specification of a single time dimension via `type hinting` although this is currently not enforced in runtime. But underneath, it stores this as a tuple which would allow for its use in `scan_in_dim`.
Hey @ptigwe! |
@cgarciae, I believe I have fixed most of the comments which you mentioned before, including the Please let me know if there is anything else which I missed and I would make the changes ASAP. |
Thanks @ptigwe ! There is a small comment about changing the default |
@cgarciae, I guess I must have missed that one. I've added it now to the PR. Quick clarification as I'm fixing the errors, by having |
Changes the default `time_axis` to be -2, i.e. by default the expected shape of the input should be of the form [..., time, :, :].
Swaps `jax.lax.scan` for `flax.jax_utils.scan_in_dim` which allows for one to have multiple time dimensions by specifying `time_axis` to be a tuple instead of a single int value.
This reverts commit 4a8f15a.
Keras uses |
OK cool. It has already been set in the previous commit. Also decided not to include the change to |
If you want to update poetry add flax@latest |
Seems the |
@ptigwe Sounds good. I'll merge this for now, we can create a new PR with Thanks a lot for pushing this through! |
Currently this shows a working implementation of GRU layer which follows quite closely the Keras API, using the
flax.linen.GRUCell
as its backbone.Tackling the implementation of GRU (#40)