
Initial implementation of GRU layers #48

Merged
merged 12 commits on Jan 14, 2022
167 changes: 167 additions & 0 deletions tests/nn/test_recurrent.py
@@ -0,0 +1,167 @@
import hypothesis as hp
import hypothesis.strategies as st
import jax
import jax.numpy as jnp
import numpy as np
import pytest

import treex as tx
from treex.nn import recurrent


class TestGRU:
@hp.given(
batch_size=st.integers(min_value=1, max_value=32),
hidden_dim=st.integers(min_value=1, max_value=32),
)
@hp.settings(deadline=None, max_examples=20)
def test_init_carry(self, batch_size, hidden_dim):
next_key = tx.KeySeq().init(42)
carry = recurrent.GRU(hidden_dim).module.initialize_carry(
next_key(), (batch_size,), hidden_dim
)
assert carry.shape == (batch_size, hidden_dim)

@hp.given(
batch_size=st.integers(min_value=1, max_value=32),
hidden_dim=st.integers(min_value=1, max_value=32),
features=st.integers(min_value=1, max_value=32),
timesteps=st.integers(min_value=1, max_value=32),
time_major=st.booleans(),
)
@hp.settings(deadline=None, max_examples=20)
def test_forward(self, batch_size, hidden_dim, features, timesteps, time_major):
key = tx.Key(8)

gru = recurrent.GRU(
hidden_dim, return_state=True, return_sequences=True, time_major=time_major
)
gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.ones((1, hidden_dim))))

carry = gru.module.initialize_carry(key, (batch_size,), hidden_dim)

dims = (batch_size, timesteps, features)
if time_major:
dims = (timesteps, batch_size, features)

sequences, final_state = gru(jnp.ones(dims), carry)

assert final_state.shape == (batch_size, hidden_dim)

if time_major:
assert sequences.shape == (timesteps, batch_size, hidden_dim)
else:
assert sequences.shape == (batch_size, timesteps, hidden_dim)

def test_jit(self):
x = np.random.uniform(size=(10, 20, 2))
module = recurrent.GRU(3).init(42, (x, jnp.zeros((10, 3))))

@jax.jit
def f(module, x):
return module, module(x)

module2, y = f(module, x)
assert y.shape == (10, 3)
print(jax.tree_leaves(module))
print(jax.tree_leaves(module2))
assert all(
np.allclose(a, b)
for a, b in zip(
jax.tree_leaves(module.parameters()),
jax.tree_leaves(module2.parameters()),
)
)

@hp.given(return_state=st.booleans(), return_sequences=st.booleans())
@hp.settings(deadline=None, max_examples=20)
def test_return_state_and_sequences(self, return_state, return_sequences):
key = tx.Key(8)
hidden_dim = 5
features = 10
batch_size = 32
time = 10

gru = recurrent.GRU(
hidden_dim, return_state=return_state, return_sequences=return_sequences
)
gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))

output = gru(
jnp.ones((batch_size, time, features)), jnp.zeros((batch_size, hidden_dim))
)

sequence_shape = (batch_size, time, hidden_dim)
state_shape = (batch_size, hidden_dim)
if return_sequences and not return_state:
assert output.shape == sequence_shape
elif return_state and return_sequences:
assert output[0].shape == sequence_shape and output[1].shape == state_shape
else:
assert output.shape == state_shape

def test_backward_mode(self):
key = tx.Key(8)
hidden_dim = 5
features = 10
batch_size = 32
time = 10

gru_fwd = recurrent.GRU(hidden_dim, go_backwards=False)
gru_fwd = gru_fwd.init(
key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim)))
)

gru_bwd = recurrent.GRU(hidden_dim, go_backwards=True)
gru_bwd = gru_bwd.init(
    key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim)))
)
gru_bwd.params = gru_fwd.params

inputs, init_carry = (
    np.random.rand(batch_size, time, features),
    jnp.zeros((batch_size, hidden_dim)),
)

# The backward GRU over `inputs` should match the forward GRU applied to the
# time-reversed inputs.
assert np.allclose(
    gru_bwd(inputs, init_carry), gru_fwd(inputs[:, ::-1, :], init_carry)
)

def test_optional_initial_state(self):
key = tx.Key(8)
hidden_dim = 5
features = 10
batch_size = 32
time = 10

gru = recurrent.GRU(hidden_dim, go_backwards=False)
gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))

inputs = np.random.rand(batch_size, time, features)
assert np.allclose(gru(inputs), gru(inputs, np.zeros((batch_size, hidden_dim))))
assert np.allclose(gru(inputs), gru(inputs, gru.initialize_state(batch_size)))

def test_stateful(self):
key = tx.Key(8)
hidden_dim = 5
features = 10
batch_size = 32
time = 10

gru = recurrent.GRU(hidden_dim, stateful=True)
gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
base = recurrent.GRU(hidden_dim, stateful=False)
base = base.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
base.params = gru.params

inputs = np.random.rand(batch_size, time, features)

# Initial state with zeros
last_state = gru(inputs)
assert np.allclose(last_state, base(inputs, np.zeros((batch_size, hidden_dim))))

# Subsequent calls starting from `last_state`
state = last_state
last_state = gru(inputs)
assert np.allclose(last_state, base(inputs, state))

# Subsequent calls starting from `last_state`
state = last_state
last_state = gru(inputs)
assert np.allclose(last_state, base(inputs, state))
183 changes: 183 additions & 0 deletions treex/nn/recurrent.py
@@ -0,0 +1,183 @@
import typing as tp
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import treeo as to
from flax.linen import recurrent as flax_module

from treex import types
from treex.key_seq import KeySeq
from treex.module import Module, next_key
from treex.nn.linear import Linear

CallableModule = tp.Callable[..., jnp.ndarray]


class GRU(Module):
"""Gated Recurrent Unit - Cho et al. 2014

`GRU` is implemented as a wrapper on top of `flax.linen.GRUCell`, providing higher-level
functionality and features similar to those of `tf.keras.layers.GRU`.
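
Example (a minimal sketch based on the tests added in this PR):

    import jax.numpy as jnp
    import treex as tx
    from treex.nn import recurrent

    gru = recurrent.GRU(8, return_sequences=True)
    # `init` takes a sample (inputs, initial_state) pair to build the parameters
    gru = gru.init(tx.Key(42), (jnp.ones((1, 1, 4)), jnp.zeros((1, 8))))

    x = jnp.ones((2, 5, 4))  # [batch, timesteps, features]
    sequences = gru(x)       # shape: (2, 5, 8)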
"""

gate_fn: CallableModule = flax_module.sigmoid
activation_fn: CallableModule = flax_module.tanh
kernel_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
]
recurrent_kernel_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
]
bias_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
]
initial_state_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array
]
params: tp.Dict[str, tp.Dict[str, flax_module.Array]] = types.Parameter.node()
last_state: tp.Optional[flax_module.Array] = types.BatchStat.node()

# static
hidden_units: int
return_state: bool
return_sequences: bool
go_backwards: bool
stateful: bool
time_major: bool
unroll: int

def __init__(
self,
units: int,
*,
gate_fn: CallableModule = flax_module.sigmoid,
activation_fn: CallableModule = flax_module.tanh,
kernel_init: CallableModule = flax_module.default_kernel_init,
recurrent_kernel_init: CallableModule = flax_module.orthogonal(),
bias_init: CallableModule = flax_module.zeros,
initial_state_init: CallableModule = flax_module.zeros,
return_sequences: bool = False,
return_state: bool = False,
go_backwards: bool = False,
stateful: bool = False,
time_major: bool = False,
unroll: int = 1
):
"""
Arguments:
units: dimensionality of the state space
gate_fn: activation function used for gates (default: `sigmoid`)
activation_fn: activation function used for output and memory update (default: `tanh`)
kernel_init: initializer function for the kernels that transform the input
(default: `lecun_normal`)
recurrent_kernel_init: initializer function for the kernels that transform
the hidden state (default: `orthogonal`)
bias_init: initializer function for the bias parameters (default: `zeros`)
initial_state_init: initializer function for the hidden state (default: `zeros`)
return_sequences: whether to return the full sequence of states instead of only the
last state (default: `False`)
return_state: whether to return the last state in addition to the sequences
(default: `False`)
go_backwards: whether to process the input sequence backwards and return the
reversed sequence (default: `False`)
stateful: whether to use the last state of the current batch as the initial state
of the next batch (default: `False`)
time_major: defines the shape of the inputs and outputs. If `True`, the inputs
and outputs have a shape of `[timesteps, batch, features]`; otherwise they have
a shape of `[batch, timesteps, features]` (default: `False`)
unroll: number of iterations to be unrolled into a single XLA iteration using
`jax.lax.scan` (default: `1`)
"""
self.hidden_units = units
self.gate_fn = gate_fn
self.activation_fn = activation_fn
self.kernel_init = kernel_init
self.recurrent_kernel_init = recurrent_kernel_init
self.bias_init = bias_init
self.initial_state_init = initial_state_init
self.params = {}
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
self.time_major = time_major
self.unroll = unroll

self.next_key = KeySeq()
self.last_state = None

@property
def module(self) -> flax_module.GRUCell:
return flax_module.GRUCell(
gate_fn=self.gate_fn,
activation_fn=self.activation_fn,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)

def initialize_state(self, batch_size: int) -> jnp.ndarray:
"""Initializes the hidden state of the GRU

Arguments:
batch_size: Number of elements in a batch

Returns:
The initial hidden state as specified by `initial_state_init`
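
Example (a sketch, assuming a `GRU` with 8 hidden units):

    state = gru.initialize_state(batch_size=32)  # shape: (32, 8)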
"""
return self.module.initialize_carry(
self.next_key(), (batch_size,), self.hidden_units, self.initial_state_init
)

def __call__(
self, x: jnp.ndarray, initial_state: tp.Optional[jnp.ndarray] = None
) -> tp.Union[jnp.ndarray, tp.Tuple[jnp.ndarray, jnp.ndarray]]:
"""Applies the GRU to the sequence of inputs `x` starting from the `initial_state`.

Arguments:
`x`: sequence of inputs to the GRU
`initial_state`: optional initial hidden state. If it is not given, the `last_state`
from the previous batch is used when `stateful == True`; otherwise a fresh state
is created with `initial_state_init`

Returns:
- The final state of the GRU by default
- The full sequence of states (if `return_sequences == True`)
- A tuple of both the sequence of states and final state (if both
`return_state` and `return_sequences` are `True`)
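
Example (a sketch; assumes `gru` was constructed with `return_sequences=True`
and `return_state=True`, as in the tests above):

    sequences, final_state = gru(jnp.ones((32, 10, 4)))
    # sequences.shape   == (32, 10, units)
    # final_state.shape == (32, units)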
"""
# Move time dimension to be the first so it can be looped over
if not self.time_major:
x = jnp.transpose(x, (1, 0, 2))

if self.go_backwards:
x = x[::-1, :, :]

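# Resolve the starting state: an explicit `initial_state` takes precedence,
# then the stored `last_state` when `stateful=True`, otherwise a fresh state.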
if initial_state is None:
    if self.stateful and self.last_state is not None:
        initial_state = self.last_state
    else:
        initial_state = self.initialize_state(x.shape[1])

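# During `init`, build the GRUCell parameters from a single timestep of the input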
if self.initializing():
_variables = self.module.init(next_key(), initial_state, x[0, ...])
self.params = _variables["params"]

variables = dict(params=self.params)

def iter_fn(state, x):
return self.module.apply(variables, state, x)

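# Scan the cell over the leading (time) axis; `unroll` controls how many steps
# are fused into a single XLA loop iteration.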
final_state, sequences = jax.lax.scan(
iter_fn, initial_state, x, unroll=self.unroll
)
if self.stateful and not self.initializing():
self.last_state = final_state

if not self.time_major:
sequences = jnp.transpose(sequences, (1, 0, 2))

if self.return_sequences and not self.return_state:
return sequences
if self.return_sequences and self.return_state:
return sequences, final_state
return final_state