* Initial implementation of GRU layer. Following the implementation of BatchNorm, use the underlying Flax implementation for the logic behind the forward pass. This also adds the use of `jax.lax.scan` for performing the sequential calls on the hidden state, making this an implementation of the GRU layer rather than a reimplementation of the `GRUCell`.
* Add key functionality to the GRU implementation:
  - `return_state`: whether the final state should be returned
  - `return_sequences`: whether all the intermediate states should be returned
  - `go_backwards`: whether the input should be processed in reverse order
* Swap order of state and input variables. Swaps the order of the state and input variables, as well as the return order of the final state and the sequence of states, to be more in line with the Keras API.
* Allow optional passing of an initial state. An initial state can now be passed in optionally, and a function can be specified to initialize the initial state.
* Add `stateful` flag. The `stateful` flag allows the last state of the GRU to be used as the start state for the next batch.
* Add documentation to GRU.
* Change `time_major` to `time_axis`. In preparation for the updates to `flax.jax_utils.scan_in_dim`, this changes `time_major` to `time_axis`. Currently only a single time dimension can be specified (via type hinting, although this is not enforced at runtime), but underneath it is stored as a tuple, which would allow its use in `scan_in_dim`.
* Set default `time_axis=-2`, i.e. by default the expected shape of the input is of the form `[..., time, :]`.
* Swap `jax.lax.scan` for `jax_utils.scan_in_dim`. `flax.jax_utils.scan_in_dim` allows multiple time dimensions by specifying `time_axis` as a tuple instead of a single int value.
* Revert "Swap `jax.lax.scan` for `jax_utils.scan_in_dim`". This reverts commit 4a8f15a.
* Fix test cases due to the different default time axis.
* Fix black formatting issues.
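For reference, a minimal usage sketch of the resulting layer, pieced together from the tests added in this commit (the constructor arguments and the `tx.Key`/`init` calls mirror the test code; the shapes and hyperparameters are illustrative, not part of the commit):

import jax.numpy as jnp
import treex as tx
from treex.nn import recurrent

batch_size, timesteps, features, hidden_dim = 8, 16, 4, 32

# GRU that returns the full sequence of hidden states and the final state,
# with the time dimension as the leading axis.
gru = recurrent.GRU(
    hidden_dim,
    return_state=True,
    return_sequences=True,
    time_axis=0,
)

# As in the tests, `init` receives a sample (inputs, initial_state) pair.
gru = gru.init(tx.Key(42), (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))

x = jnp.ones((timesteps, batch_size, features))
# The initial state is optional; if omitted it comes from `initial_state_init`.
sequences, final_state = gru(x)

assert sequences.shape == (timesteps, batch_size, hidden_dim)
assert final_state.shape == (batch_size, hidden_dim)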
Showing 2 changed files with 356 additions and 0 deletions.
@@ -0,0 +1,170 @@
import hypothesis as hp
import hypothesis.strategies as st
import jax
import jax.numpy as jnp
import numpy as np
import pytest

import treex as tx
from treex.nn import recurrent


class TestGRU:
    @hp.given(
        batch_size=st.integers(min_value=1, max_value=32),
        hidden_dim=st.integers(min_value=1, max_value=32),
    )
    @hp.settings(deadline=None, max_examples=20)
    def test_init_carry(self, batch_size, hidden_dim):
        next_key = tx.KeySeq().init(42)
        carry = recurrent.GRU(hidden_dim).module.initialize_carry(
            next_key, (batch_size,), hidden_dim
        )
        assert carry.shape == (batch_size, hidden_dim)

    @hp.given(
        batch_size=st.integers(min_value=1, max_value=32),
        hidden_dim=st.integers(min_value=1, max_value=32),
        features=st.integers(min_value=1, max_value=32),
        timesteps=st.integers(min_value=1, max_value=32),
        time_axis=st.integers(min_value=0, max_value=1),
    )
    @hp.settings(deadline=None, max_examples=20)
    def test_forward(self, batch_size, hidden_dim, features, timesteps, time_axis):
        key = tx.Key(8)

        gru = recurrent.GRU(
            hidden_dim, return_state=True, return_sequences=True, time_axis=time_axis
        )
        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.ones((1, hidden_dim))))

        carry = gru.module.initialize_carry(key, (batch_size,), hidden_dim)

        dims = (batch_size, timesteps, features)
        if time_axis == 0:
            dims = (timesteps, batch_size, features)

        sequences, final_state = gru(jnp.ones(dims), carry)

        assert final_state.shape == (batch_size, hidden_dim)

        if time_axis == 0:
            assert sequences.shape == (timesteps, batch_size, hidden_dim)
        else:
            assert sequences.shape == (batch_size, timesteps, hidden_dim)

    def test_jit(self):
        x = np.random.uniform(size=(20, 10, 2))
        module = recurrent.GRU(3, time_axis=0).init(42, (x, jnp.zeros((10, 3))))

        @jax.jit
        def f(module, x):
            return module, module(x)

        module2, y = f(module, x)
        assert y.shape == (10, 3)
        print(jax.tree_leaves(module))
        print(jax.tree_leaves(module2))
        assert all(
            np.allclose(a, b)
            for a, b in zip(
                jax.tree_leaves(module.parameters()),
                jax.tree_leaves(module2.parameters()),
            )
        )

    @hp.given(return_state=st.booleans(), return_sequences=st.booleans())
    @hp.settings(deadline=None, max_examples=20)
    def test_return_state_and_sequences(self, return_state, return_sequences):
        key = tx.Key(8)
        hidden_dim = 5
        features = 10
        batch_size = 32
        time = 10

        gru = recurrent.GRU(
            hidden_dim,
            time_axis=0,
            return_state=return_state,
            return_sequences=return_sequences,
        )
        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))

        output = gru(
            jnp.ones((time, batch_size, features)), jnp.zeros((batch_size, hidden_dim))
        )

        sequence_shape = (time, batch_size, hidden_dim)
        state_shape = (batch_size, hidden_dim)
        if return_sequences and not return_state:
            assert output.shape == sequence_shape
        elif return_state and return_sequences:
            assert output[0].shape == sequence_shape and output[1].shape == state_shape
        else:
            assert output.shape == state_shape

    def test_backward_mode(self):
        key = tx.Key(8)
        hidden_dim = 5
        features = 10
        batch_size = 32
        time = 10

        gru_fwd = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=False)
        gru_fwd = gru_fwd.init(
            key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim)))
        )

        gru_bwd = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=True)
        gru_bwd = gru_bwd.init(
            key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim)))
        )
        gru_bwd.params = gru_fwd.params
        inputs, init_carry = (
            jnp.ones((time, batch_size, features)),
            jnp.zeros((batch_size, hidden_dim)),
        )

        # With shared parameters, running the backward GRU over `inputs` should match
        # running the forward GRU over the time-reversed inputs (time is axis 0).
        assert np.allclose(
            gru_bwd(inputs, init_carry), gru_fwd(inputs[::-1, :, :], init_carry)
        )

    def test_optional_initial_state(self):
        key = tx.Key(8)
        hidden_dim = 5
        features = 10
        batch_size = 32
        time = 10

        gru = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=False)
        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))

        inputs = np.random.rand(time, batch_size, features)
        assert np.allclose(gru(inputs), gru(inputs, np.zeros((batch_size, hidden_dim))))
        assert np.allclose(gru(inputs), gru(inputs, gru.initialize_state(batch_size)))

    def test_stateful(self):
        key = tx.Key(8)
        hidden_dim = 5
        features = 10
        batch_size = 32
        time = 10

        gru = recurrent.GRU(hidden_dim, time_axis=0, stateful=True)
        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
        base = recurrent.GRU(hidden_dim, time_axis=0, stateful=False)
        base = base.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
        base.params = gru.params

        inputs = np.random.rand(time, batch_size, features)

        # Initial state with zeros
        last_state = gru(inputs)
        assert np.allclose(last_state, base(inputs, np.zeros((batch_size, hidden_dim))))

        # Subsequent calls starting from `last_state`
        state = last_state
        last_state = gru(inputs)
        assert np.allclose(last_state, base(inputs, state))

        # Subsequent calls starting from `last_state`
        state = last_state
        last_state = gru(inputs)
        assert np.allclose(last_state, base(inputs, state))
@@ -0,0 +1,186 @@
import typing as tp
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import treeo as to
from flax.linen import recurrent as flax_module

from treex import types
from treex.key_seq import KeySeq
from treex.module import Module, next_key
from treex.nn.linear import Linear

CallableModule = tp.Callable[..., jnp.ndarray]


class GRU(Module):
"""Gated Recurrent Unit - Cho et al. 2014 | ||
`GRU` is implemented as a wrapper on top of `flax.linen.GRUCell`, providing higher level | ||
functionality and features similar to what can be found in that of `tf.keras.layers.GRU`. | ||
""" | ||
    gate_fn: CallableModule = flax_module.sigmoid
    activation_fn: CallableModule = flax_module.tanh
    kernel_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
    ]
    recurrent_kernel_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
    ]
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
    ]
    initial_state_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
    ]
    params: tp.Dict[str, tp.Dict[str, flax_module.Array]] = types.Parameter.node()
    last_state: flax_module.Array = types.Cache.node()

    # static
    hidden_units: int
    return_state: bool
    return_sequences: bool
    go_backwards: bool
    stateful: bool
    time_axis: tp.Tuple[int]
    unroll: int

    def __init__(
        self,
        units: int,
        *,
        gate_fn: CallableModule = flax_module.sigmoid,
        activation_fn: CallableModule = flax_module.tanh,
        kernel_init: CallableModule = flax_module.default_kernel_init,
        recurrent_kernel_init: CallableModule = flax_module.orthogonal(),
        bias_init: CallableModule = flax_module.zeros,
        initial_state_init: CallableModule = flax_module.zeros,
        return_sequences: bool = False,
        return_state: bool = False,
        go_backwards: bool = False,
        stateful: bool = False,
        time_axis: int = -2,
        unroll: int = 1
    ):
""" | ||
Arguments: | ||
units: dimensionality of the state space | ||
gate_fn: activation function used for gates. (default: `sigmoid`) | ||
kernel_init: initializer function for the kernels that transform the input | ||
(default: `lecun_normal`) | ||
recurrent_kernel_init: initializer function for the kernels that transform | ||
the hidden state (default: `orthogonal`) | ||
bias_init: initializer function for the bias parameters (default: `zeros`) | ||
initial_state_init: initializer function for the hidden state (default: `zeros`) | ||
return_sequences: whether to return the last state or the sequences (default: `False`) | ||
return_state: whether to return the last state in addition to the sequences | ||
(default: `False`) | ||
go_backwards: whether to process the input sequence backwards and return the | ||
reversed sequence (default: `False`) | ||
stateful: whether to use the last state of the current batch as the start_state | ||
of the next batch (default: `False`) | ||
time_axis: specifies which axis of the input corresponds to the timesteps. By default, | ||
`time_axis = -2` which corresponds to the input being of shape `[..., timesteps, :]` | ||
unroll: number of iterations to be unrolled into a single XLA iteration using | ||
`jax.lax.scan` (default: `1`) | ||
""" | ||
        self.hidden_units = units
        self.gate_fn = gate_fn
        self.activation_fn = activation_fn
        self.kernel_init = kernel_init
        self.recurrent_kernel_init = recurrent_kernel_init
        self.bias_init = bias_init
        self.initial_state_init = initial_state_init
        self.params = {}
        self.return_sequences = return_sequences
        self.return_state = return_state
        self.go_backwards = go_backwards
        self.stateful = stateful
        self.time_axis = (time_axis,)
        self.unroll = unroll

        self.next_key = KeySeq()
        self.last_state = None

    @property
    def module(self) -> flax_module.GRUCell:
        return flax_module.GRUCell(
            gate_fn=self.gate_fn,
            activation_fn=self.activation_fn,
            kernel_init=self.kernel_init,
            recurrent_kernel_init=self.recurrent_kernel_init,
            bias_init=self.bias_init,
        )

    def initialize_state(self, batch_dim: tp.Union[tp.Tuple[int], int]) -> jnp.ndarray:
        """Initializes the hidden state of the GRU.

        Arguments:
            batch_dim: number of elements in a batch (or a tuple of batch dimensions)
        Returns:
            The initial hidden state as specified by `initial_state_init`
        """
        if not isinstance(batch_dim, tp.Iterable):
            batch_dim = (batch_dim,)
        return self.module.initialize_carry(
            self.next_key(), batch_dim, self.hidden_units, self.initial_state_init
        )

    def __call__(
        self, x: jnp.ndarray, initial_state: tp.Optional[jnp.ndarray] = None
    ) -> tp.Union[jnp.ndarray, tp.Tuple[jnp.ndarray, jnp.ndarray]]:
        """Applies the GRU to the sequence of inputs `x`, starting from `initial_state`.

        Arguments:
            x: sequence of inputs to the GRU
            initial_state: optional initial hidden state. If not specified, the
                `last_state` (i.e. the output of the previous batch) is used when
                `stateful == True`; otherwise the initial state is obtained from the
                `initial_state_init` function
        Returns:
            - The final state of the GRU by default
            - The full sequence of states (if `return_sequences == True`)
            - A tuple of both the sequence of states and the final state (if both
              `return_state` and `return_sequences` are `True`)
        """
        # Move time axis to be the first so it can be looped over
        # Note: not needed with jax_utils.scan_in_dim
        x = jnp.swapaxes(x, self.time_axis, 0)

        if self.go_backwards:
            x = x[::-1]

        if initial_state is None:
            initial_state = self.initialize_state(x.shape[len(self.time_axis) : -1])
            if self.stateful and self.last_state is not None:
                initial_state = self.last_state

        if self.initializing():
            _variables = self.module.init(
                next_key(), initial_state, x[(0,) * len(self.time_axis)]
            )
            self.params = _variables["params"]

        variables = dict(params=self.params)

        def iter_fn(state, x):
            return self.module.apply(variables, state, x)

        final_state, sequences = jax.lax.scan(
            iter_fn, initial_state, x, unroll=self.unroll
        )
        if self.stateful and not self.initializing():
            self.last_state = final_state

        # Note: Not needed with jax_utils.scan_in_dim
        sequences = jnp.swapaxes(sequences, self.time_axis, 0)

        if self.return_sequences and not self.return_state:
            return sequences
        if self.return_sequences and self.return_state:
            return sequences, final_state
        return final_state
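A small sketch of the `stateful` behavior and the optional initial state documented above, following `test_stateful` and `test_optional_initial_state` from this commit (shapes and hyperparameters are illustrative):

import jax.numpy as jnp
import treex as tx
from treex.nn import recurrent

time, batch_size, features, hidden_dim = 10, 32, 10, 5
x = jnp.ones((time, batch_size, features))

key = tx.Key(8)
gru = recurrent.GRU(hidden_dim, time_axis=0, stateful=True)
gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))

# Passing no initial state is equivalent to starting from `initial_state_init`
# (zeros by default), i.e. from `gru.initialize_state(batch_size)`.
state_1 = gru(x)

# Because `stateful=True`, the cached `last_state` from the previous call is used
# as the starting state for this call instead of zeros.
state_2 = gru(x)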