This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

Commit

Adds GRU layer (#48)
* Initial Implementation of GRU layer

Following the implementation of BatchNorm, the underlying Flax implementation
is used for the logic behind the forward pass. This also adds the use of
`jax.lax.scan` for performing the sequential updates of the hidden state,
making this an implementation of the GRU layer rather than just a
reimplementation of `GRUCell`.
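
Below is a minimal sketch (not part of the diff) of the scan pattern described above, assuming the `flax.linen.GRUCell` API available at the time of this commit, where `initialize_carry` is a class-level helper and the cell maps `(carry, input)` to `(new_carry, output)`:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

batch, time, features, hidden = 4, 7, 3, 5
key = jax.random.PRNGKey(0)

# One GRUCell handles a single timestep; lax.scan threads the hidden state
# over the time axis (assumed to be axis 0 here).
cell = nn.GRUCell()
carry = nn.GRUCell.initialize_carry(key, (batch,), hidden)
x = jnp.ones((time, batch, features))
params = cell.init(key, carry, x[0])["params"]

def step(state, x_t):
    # GRUCell returns (new_carry, output); scan stacks the outputs over time.
    return cell.apply({"params": params}, state, x_t)

final_state, sequences = jax.lax.scan(step, carry, x)
# final_state: (batch, hidden), sequences: (time, batch, hidden)
```

Note that newer Flax releases construct the cell as `GRUCell(features=...)` and changed the `initialize_carry` signature, so the exact calls may differ.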

* Add key functionality to GRU implementation

- `return_state`: whether the final state should be returned
- `return_sequences`: whether all the intermediate states should be
  returned
- `go_backwards`: whether the input should be run in reverse order
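
As an illustrative usage sketch of these flags (mirroring the tests added in this commit), with `time_axis=0` and both return flags enabled the layer yields the stacked states together with the final state:

```python
import jax.numpy as jnp
import treex as tx
from treex.nn import recurrent

time, batch, features, hidden = 10, 32, 8, 5
key = tx.Key(8)

gru = recurrent.GRU(
    hidden,
    time_axis=0,
    return_sequences=True,
    return_state=True,
    go_backwards=False,
)
gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden))))

x = jnp.ones((time, batch, features))
sequences, final_state = gru(x, jnp.zeros((batch, hidden)))
# sequences: (time, batch, hidden), final_state: (batch, hidden)
# With go_backwards=True the input sequence is processed in reverse order.
```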

* Swap order of state and input variables

Swaps the order of the state and input arguments, as well as the return
order of the final state and the sequence of states, to be more in line
with the Keras API.

* Allow for optional passing of initial state

This allows an initial state to be passed optionally, and allows specifying
a function that initializes the initial state.
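
Continuing the illustrative sketch above (same `gru`, `x`, `batch`, and `hidden`), the initial state can now be omitted, passed explicitly, or produced by the `initialize_state` helper; the new test checks that these agree for the default zeros initializer:

```python
out_default = gru(x)                               # state built from initial_state_init
out_explicit = gru(x, jnp.zeros((batch, hidden)))  # state passed explicitly
out_helper = gru(x, gru.initialize_state(batch))   # state from the helper
```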

* Add `stateful` flag

Adds the `stateful` flag, which allows the last state of the GRU to be used
as the start state for the next batch.
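
A hypothetical sketch of this behaviour, reusing the names from the earlier example (not code from the diff):

```python
gru_stateful = recurrent.GRU(hidden, time_axis=0, stateful=True)
gru_stateful = gru_stateful.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden))))

s1 = gru_stateful(x)  # first batch: starts from the zero state
s2 = gru_stateful(x)  # second batch: starts from s1, the cached last_state
```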

* Adds documentation to GRU

* Change `time_major` to `time_axis`

In preparation for the updates to `flax.jax_utils.scan_in_dim`, this
changes `time_major` to `time_axis`. The type hint currently allows only a
single time dimension to be specified, although this is not enforced at
runtime. Internally the value is stored as a tuple, which would allow it to
be used with `scan_in_dim`.

* Sets default `time_axis=-2`

Changes the default `time_axis` to -2, i.e. by default the expected input
shape is of the form [..., time, :].
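
A hedged sketch, reusing the earlier names, of what the new default implies for input and output shapes:

```python
gru_default = recurrent.GRU(hidden)  # time_axis=-2 by default
gru_default = gru_default.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden))))

y = gru_default(jnp.ones((batch, time, features)))  # time sits at axis -2
# y: (batch, hidden) -- the final state, since return_sequences=False by default
```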

* Swap `jax.lax.scan` for `jax_utils.scan_in_dim`

Swaps `jax.lax.scan` for `flax.jax_utils.scan_in_dim`, which allows
multiple time dimensions by specifying `time_axis` as a tuple instead of a
single int value.

* Revert "Swap `jax.lax.scan` for `jax_utils.scan_in_dim`"

This reverts commit 4a8f15a.

* Fix test cases due to the different default time axis

* Fix black formatting issues
ptigwe authored Jan 14, 2022
1 parent 1e24884 commit fce1175
Showing 2 changed files with 356 additions and 0 deletions.
170 changes: 170 additions & 0 deletions tests/nn/test_recurrent.py
@@ -0,0 +1,170 @@
import hypothesis as hp
import hypothesis.strategies as st
import jax
import jax.numpy as jnp
import numpy as np
import pytest

import treex as tx
from treex.nn import recurrent


class TestGRU:
    @hp.given(
        batch_size=st.integers(min_value=1, max_value=32),
        hidden_dim=st.integers(min_value=1, max_value=32),
    )
    @hp.settings(deadline=None, max_examples=20)
    def test_init_carry(self, batch_size, hidden_dim):
        next_key = tx.KeySeq().init(42)
        carry = recurrent.GRU(hidden_dim).module.initialize_carry(
            next_key, (batch_size,), hidden_dim
        )
        assert carry.shape == (batch_size, hidden_dim)

    @hp.given(
        batch_size=st.integers(min_value=1, max_value=32),
        hidden_dim=st.integers(min_value=1, max_value=32),
        features=st.integers(min_value=1, max_value=32),
        timesteps=st.integers(min_value=1, max_value=32),
        time_axis=st.integers(min_value=0, max_value=1),
    )
    @hp.settings(deadline=None, max_examples=20)
    def test_forward(self, batch_size, hidden_dim, features, timesteps, time_axis):
        key = tx.Key(8)

        gru = recurrent.GRU(
            hidden_dim, return_state=True, return_sequences=True, time_axis=time_axis
        )
        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.ones((1, hidden_dim))))

        carry = gru.module.initialize_carry(key, (batch_size,), hidden_dim)

        dims = (batch_size, timesteps, features)
        if time_axis == 0:
            dims = (timesteps, batch_size, features)

        sequences, final_state = gru(jnp.ones(dims), carry)

        assert final_state.shape == (batch_size, hidden_dim)

        if time_axis == 0:
            assert sequences.shape == (timesteps, batch_size, hidden_dim)
        else:
            assert sequences.shape == (batch_size, timesteps, hidden_dim)

    def test_jit(self):
        x = np.random.uniform(size=(20, 10, 2))
        module = recurrent.GRU(3, time_axis=0).init(42, (x, jnp.zeros((10, 3))))

        @jax.jit
        def f(module, x):
            return module, module(x)

        module2, y = f(module, x)
        assert y.shape == (10, 3)
        print(jax.tree_leaves(module))
        print(jax.tree_leaves(module2))
        assert all(
            np.allclose(a, b)
            for a, b in zip(
                jax.tree_leaves(module.parameters()),
                jax.tree_leaves(module2.parameters()),
            )
        )

    @hp.given(return_state=st.booleans(), return_sequences=st.booleans())
    @hp.settings(deadline=None, max_examples=20)
    def test_return_state_and_sequences(self, return_state, return_sequences):
        key = tx.Key(8)
        hidden_dim = 5
        features = 10
        batch_size = 32
        time = 10

        gru = recurrent.GRU(
            hidden_dim,
            time_axis=0,
            return_state=return_state,
            return_sequences=return_sequences,
        )
        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))

        output = gru(
            jnp.ones((time, batch_size, features)), jnp.zeros((batch_size, hidden_dim))
        )

        sequence_shape = (time, batch_size, hidden_dim)
        state_shape = (batch_size, hidden_dim)
        if return_sequences and not return_state:
            assert output.shape == sequence_shape
        elif return_state and return_sequences:
            assert output[0].shape == sequence_shape and output[1].shape == state_shape
        else:
            assert output.shape == state_shape

    def test_backward_mode(self):
        key = tx.Key(8)
        hidden_dim = 5
        features = 10
        batch_size = 32
        time = 10

        gru_fwd = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=False)
        gru_fwd = gru_fwd.init(
            key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim)))
        )

        gru_bwd = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=True)
        gru_bwd.params = gru_fwd.params
        inputs, init_carry = (
            jnp.ones((time, batch_size, features)),
            jnp.zeros((batch_size, hidden_dim)),
        )

        assert np.allclose(
            gru_fwd(inputs[:, ::-1, :], init_carry), gru_fwd(inputs, init_carry)
        )

    def test_optional_initial_state(self):
        key = tx.Key(8)
        hidden_dim = 5
        features = 10
        batch_size = 32
        time = 10

        gru = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=False)
        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))

        inputs = np.random.rand(time, batch_size, features)
        assert np.allclose(gru(inputs), gru(inputs, np.zeros((batch_size, hidden_dim))))
        assert np.allclose(gru(inputs), gru(inputs, gru.initialize_state(batch_size)))

    def test_stateful(self):
        key = tx.Key(8)
        hidden_dim = 5
        features = 10
        batch_size = 32
        time = 10

        gru = recurrent.GRU(hidden_dim, time_axis=0, stateful=True)
        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
        base = recurrent.GRU(hidden_dim, time_axis=0, stateful=False)
        base = base.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
        base.params = gru.params

        inputs = np.random.rand(time, batch_size, features)

        # Initial state with zeros
        last_state = gru(inputs)
        assert np.allclose(last_state, base(inputs, np.zeros((batch_size, hidden_dim))))

        # Subsequent calls starting from `last_state`
        state = last_state
        last_state = gru(inputs)
        assert np.allclose(last_state, base(inputs, state))

        # Subsequent calls starting from `last_state`
        state = last_state
        last_state = gru(inputs)
        assert np.allclose(last_state, base(inputs, state))
186 changes: 186 additions & 0 deletions treex/nn/recurrent.py
@@ -0,0 +1,186 @@
import typing as tp
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import treeo as to
from flax.linen import recurrent as flax_module

from treex import types
from treex.key_seq import KeySeq
from treex.module import Module, next_key
from treex.nn.linear import Linear

CallableModule = tp.Callable[..., jnp.ndarray]


class GRU(Module):
    """Gated Recurrent Unit - Cho et al. 2014
    `GRU` is implemented as a wrapper on top of `flax.linen.GRUCell`, providing higher-level
    functionality and features similar to those of `tf.keras.layers.GRU`.
    """

    gate_fn: CallableModule = flax_module.sigmoid
    activation_fn: CallableModule = flax_module.tanh
    kernel_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
    ]
    recurrent_kernel_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
    ]
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
    ]
    initial_state_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
    ]
    params: tp.Dict[str, tp.Dict[str, flax_module.Array]] = types.Parameter.node()
    last_state: flax_module.Array = types.Cache.node()

    # static
    hidden_units: int
    return_state: bool
    return_sequences: bool
    go_backwards: bool
    stateful: bool
    time_axis: tp.Tuple[int]
    unroll: int

    def __init__(
        self,
        units: int,
        *,
        gate_fn: CallableModule = flax_module.sigmoid,
        activation_fn: CallableModule = flax_module.tanh,
        kernel_init: CallableModule = flax_module.default_kernel_init,
        recurrent_kernel_init: CallableModule = flax_module.orthogonal(),
        bias_init: CallableModule = flax_module.zeros,
        initial_state_init: CallableModule = flax_module.zeros,
        return_sequences: bool = False,
        return_state: bool = False,
        go_backwards: bool = False,
        stateful: bool = False,
        time_axis: int = -2,
        unroll: int = 1
    ):
        """
        Arguments:
            units: dimensionality of the state space
            gate_fn: activation function used for gates (default: `sigmoid`)
            activation_fn: activation function used for the output and memory update
                (default: `tanh`)
            kernel_init: initializer function for the kernels that transform the input
                (default: `lecun_normal`)
            recurrent_kernel_init: initializer function for the kernels that transform
                the hidden state (default: `orthogonal`)
            bias_init: initializer function for the bias parameters (default: `zeros`)
            initial_state_init: initializer function for the hidden state (default: `zeros`)
            return_sequences: whether to return the full sequence of states instead of only
                the last state (default: `False`)
            return_state: whether to return the last state in addition to the sequences
                (default: `False`)
            go_backwards: whether to process the input sequence backwards and return the
                reversed sequence (default: `False`)
            stateful: whether to use the last state of the current batch as the start state
                of the next batch (default: `False`)
            time_axis: specifies which axis of the input corresponds to the timesteps. By default,
                `time_axis = -2`, which corresponds to an input of shape `[..., timesteps, :]`
            unroll: number of iterations to be unrolled into a single XLA iteration using
                `jax.lax.scan` (default: `1`)
        """
        self.hidden_units = units
        self.gate_fn = gate_fn
        self.activation_fn = activation_fn
        self.kernel_init = kernel_init
        self.recurrent_kernel_init = recurrent_kernel_init
        self.bias_init = bias_init
        self.initial_state_init = initial_state_init
        self.params = {}
        self.return_sequences = return_sequences
        self.return_state = return_state
        self.go_backwards = go_backwards
        self.stateful = stateful
        self.time_axis = (time_axis,)
        self.unroll = unroll

        self.next_key = KeySeq()
        self.last_state = None

    @property
    def module(self) -> flax_module.GRUCell:
        return flax_module.GRUCell(
            gate_fn=self.gate_fn,
            activation_fn=self.activation_fn,
            kernel_init=self.kernel_init,
            recurrent_kernel_init=self.recurrent_kernel_init,
            bias_init=self.bias_init,
        )

    def initialize_state(self, batch_dim: tp.Union[tp.Tuple[int], int]) -> jnp.ndarray:
        """Initializes the hidden state of the GRU
        Arguments:
            batch_dim: number of elements in a batch, or a tuple of batch dimensions
        Returns:
            The initial hidden state as specified by `initial_state_init`
        """
        if not isinstance(batch_dim, tp.Iterable):
            batch_dim = (batch_dim,)
        return self.module.initialize_carry(
            self.next_key(), batch_dim, self.hidden_units, self.initial_state_init
        )

    def __call__(
        self, x: jnp.ndarray, initial_state: tp.Optional[jnp.ndarray] = None
    ) -> tp.Union[jnp.ndarray, tp.Tuple[jnp.ndarray, jnp.ndarray]]:
        """Applies the GRU to the sequence of inputs `x`, starting from `initial_state`.
        Arguments:
            x: sequence of inputs to the GRU
            initial_state: optional initial hidden state. If it is not given, the
                `last_state` (i.e. the output of the previous batch) is used when
                `stateful == True`; otherwise a fresh state is created with the
                `initial_state_init` function
        Returns:
            - The final state of the GRU by default
            - The full sequence of states (if `return_sequences == True`)
            - A tuple of both the sequence of states and the final state (if both
              `return_state` and `return_sequences` are `True`)
        """
        # Move the time axis to position 0 so it can be scanned over
        # Note: not needed with jax_utils.scan_in_dim
        x = jnp.swapaxes(x, self.time_axis, 0)

        if self.go_backwards:
            x = x[::-1]

        if initial_state is None:
            initial_state = self.initialize_state(x.shape[len(self.time_axis) : -1])
        if self.stateful and self.last_state is not None:
            initial_state = self.last_state

        if self.initializing():
            _variables = self.module.init(
                next_key(), initial_state, x[(0,) * len(self.time_axis)]
            )
            self.params = _variables["params"]

        variables = dict(params=self.params)

        def iter_fn(state, x):
            return self.module.apply(variables, state, x)

        final_state, sequences = jax.lax.scan(
            iter_fn, initial_state, x, unroll=self.unroll
        )
        if self.stateful and not self.initializing():
            self.last_state = final_state

        # Note: not needed with jax_utils.scan_in_dim
        sequences = jnp.swapaxes(sequences, self.time_axis, 0)

        if self.return_sequences and not self.return_state:
            return sequences
        if self.return_sequences and self.return_state:
            return sequences, final_state
        return final_state
