diff --git a/tests/nn/test_recurrent.py b/tests/nn/test_recurrent.py
new file mode 100644
index 00000000..deb93ecb
--- /dev/null
+++ b/tests/nn/test_recurrent.py
@@ -0,0 +1,170 @@
+import hypothesis as hp
+import hypothesis.strategies as st
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+import treex as tx
+from treex.nn import recurrent
+
+
+class TestGRU:
+    @hp.given(
+        batch_size=st.integers(min_value=1, max_value=32),
+        hidden_dim=st.integers(min_value=1, max_value=32),
+    )
+    @hp.settings(deadline=None, max_examples=20)
+    def test_init_carry(self, batch_size, hidden_dim):
+        next_key = tx.KeySeq().init(42)
+        carry = recurrent.GRU(hidden_dim).module.initialize_carry(
+            next_key, (batch_size,), hidden_dim
+        )
+        assert carry.shape == (batch_size, hidden_dim)
+
+    @hp.given(
+        batch_size=st.integers(min_value=1, max_value=32),
+        hidden_dim=st.integers(min_value=1, max_value=32),
+        features=st.integers(min_value=1, max_value=32),
+        timesteps=st.integers(min_value=1, max_value=32),
+        time_axis=st.integers(min_value=0, max_value=1),
+    )
+    @hp.settings(deadline=None, max_examples=20)
+    def test_forward(self, batch_size, hidden_dim, features, timesteps, time_axis):
+        key = tx.Key(8)
+
+        gru = recurrent.GRU(
+            hidden_dim, return_state=True, return_sequences=True, time_axis=time_axis
+        )
+        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.ones((1, hidden_dim))))
+
+        carry = gru.module.initialize_carry(key, (batch_size,), hidden_dim)
+
+        dims = (batch_size, timesteps, features)
+        if time_axis == 0:
+            dims = (timesteps, batch_size, features)
+
+        sequences, final_state = gru(jnp.ones(dims), carry)
+
+        assert final_state.shape == (batch_size, hidden_dim)
+
+        if time_axis == 0:
+            assert sequences.shape == (timesteps, batch_size, hidden_dim)
+        else:
+            assert sequences.shape == (batch_size, timesteps, hidden_dim)
+
+    def test_jit(self):
+        x = np.random.uniform(size=(20, 10, 2))
+        module = recurrent.GRU(3, time_axis=0).init(42, (x, jnp.zeros((10, 3))))
+
+        @jax.jit
+        def f(module, x):
+            return module, module(x)
+
+        module2, y = f(module, x)
+        assert y.shape == (10, 3)
+        print(jax.tree_leaves(module))
+        print(jax.tree_leaves(module2))
+        assert all(
+            np.allclose(a, b)
+            for a, b in zip(
+                jax.tree_leaves(module.parameters()),
+                jax.tree_leaves(module2.parameters()),
+            )
+        )
+
+    @hp.given(return_state=st.booleans(), return_sequences=st.booleans())
+    @hp.settings(deadline=None, max_examples=20)
+    def test_return_state_and_sequences(self, return_state, return_sequences):
+        key = tx.Key(8)
+        hidden_dim = 5
+        features = 10
+        batch_size = 32
+        time = 10
+
+        gru = recurrent.GRU(
+            hidden_dim,
+            time_axis=0,
+            return_state=return_state,
+            return_sequences=return_sequences,
+        )
+        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
+
+        output = gru(
+            jnp.ones((time, batch_size, features)), jnp.zeros((batch_size, hidden_dim))
+        )
+
+        sequence_shape = (time, batch_size, hidden_dim)
+        state_shape = (batch_size, hidden_dim)
+        if return_sequences and not return_state:
+            assert output.shape == sequence_shape
+        elif return_state and return_sequences:
+            assert output[0].shape == sequence_shape and output[1].shape == state_shape
+        else:
+            assert output.shape == state_shape
+
+    def test_backward_mode(self):
+        key = tx.Key(8)
+        hidden_dim = 5
+        features = 10
+        batch_size = 32
+        time = 10
+
+        gru_fwd = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=False)
+        gru_fwd = gru_fwd.init(
+            key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim)))
+        )
+
+        gru_bwd = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=True)
+        gru_bwd = gru_bwd.init(
+            key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim)))
+        )
+        gru_bwd.params = gru_fwd.params
+
+        inputs, init_carry = (
+            np.random.rand(time, batch_size, features),
+            jnp.zeros((batch_size, hidden_dim)),
+        )
+
+        # Running the backward GRU over `inputs` must match running the forward GRU
+        # over the time-reversed sequence (time is on axis 0 here).
+        assert np.allclose(
+            gru_bwd(inputs, init_carry), gru_fwd(inputs[::-1], init_carry)
+        )
+
+    def test_optional_initial_state(self):
+        key = tx.Key(8)
+        hidden_dim = 5
+        features = 10
+        batch_size = 32
+        time = 10
+
+        gru = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=False)
+        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
+
+        inputs = np.random.rand(time, batch_size, features)
+        assert np.allclose(gru(inputs), gru(inputs, np.zeros((batch_size, hidden_dim))))
+        assert np.allclose(gru(inputs), gru(inputs, gru.initialize_state(batch_size)))
+
+    def test_stateful(self):
+        key = tx.Key(8)
+        hidden_dim = 5
+        features = 10
+        batch_size = 32
+        time = 10
+
+        gru = recurrent.GRU(hidden_dim, time_axis=0, stateful=True)
+        gru = gru.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
+        base = recurrent.GRU(hidden_dim, time_axis=0, stateful=False)
+        base = base.init(key, (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim))))
+        base.params = gru.params
+
+        inputs = np.random.rand(time, batch_size, features)
+
+        # Initial state with zeros
+        last_state = gru(inputs)
+        assert np.allclose(last_state, base(inputs, np.zeros((batch_size, hidden_dim))))
+
+        # Subsequent calls start from `last_state`
+        state = last_state
+        last_state = gru(inputs)
+        assert np.allclose(last_state, base(inputs, state))
+
+        # And again, starting from the updated `last_state`
+        state = last_state
+        last_state = gru(inputs)
+        assert np.allclose(last_state, base(inputs, state))
diff --git a/treex/nn/recurrent.py b/treex/nn/recurrent.py
new file mode 100644
index 00000000..5e991e00
--- /dev/null
+++ b/treex/nn/recurrent.py
@@ -0,0 +1,186 @@
+import typing as tp
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import treeo as to
+from flax.linen import recurrent as flax_module
+
+from treex import types
+from treex.key_seq import KeySeq
+from treex.module import Module, next_key
+from treex.nn.linear import Linear
+
+CallableModule = tp.Callable[..., jnp.ndarray]
+
+
+class GRU(Module):
+    """Gated Recurrent Unit - Cho et al. 2014.
+
+    `GRU` is implemented as a wrapper on top of `flax.linen.GRUCell`, providing
+    higher-level functionality and features similar to those of `tf.keras.layers.GRU`.
+ """ + + gate_fn: CallableModule = flax_module.sigmoid + activation_fn: CallableModule = flax_module.tanh + kernel_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array + ] + recurrent_kernel_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array + ] + bias_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array + ] + initial_state_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], flax_module.Array + ] + params: tp.Dict[str, tp.Dict[str, flax_module.Array]] = types.Parameter.node() + last_state: flax_module.Array = types.Cache.node() + + # static + hidden_units: int + return_state: bool + return_sequences: bool + go_backwards: bool + stateful: bool + time_axis: tp.Tuple[int] + unroll: int + + def __init__( + self, + units: int, + *, + gate_fn: CallableModule = flax_module.sigmoid, + activation_fn: CallableModule = flax_module.tanh, + kernel_init: CallableModule = flax_module.default_kernel_init, + recurrent_kernel_init: CallableModule = flax_module.orthogonal(), + bias_init: CallableModule = flax_module.zeros, + initial_state_init: CallableModule = flax_module.zeros, + return_sequences: bool = False, + return_state: bool = False, + go_backwards: bool = False, + stateful: bool = False, + time_axis: int = -2, + unroll: int = 1 + ): + """ + Arguments: + units: dimensionality of the state space + gate_fn: activation function used for gates. (default: `sigmoid`) + kernel_init: initializer function for the kernels that transform the input + (default: `lecun_normal`) + recurrent_kernel_init: initializer function for the kernels that transform + the hidden state (default: `orthogonal`) + bias_init: initializer function for the bias parameters (default: `zeros`) + initial_state_init: initializer function for the hidden state (default: `zeros`) + return_sequences: whether to return the last state or the sequences (default: `False`) + return_state: whether to return the last state in addition to the sequences + (default: `False`) + go_backwards: whether to process the input sequence backwards and return the + reversed sequence (default: `False`) + stateful: whether to use the last state of the current batch as the start_state + of the next batch (default: `False`) + time_axis: specifies which axis of the input corresponds to the timesteps. 
+                By default, `time_axis = -2`, which corresponds to the input being of
+                shape `[..., timesteps, :]`
+            unroll: number of scan iterations to unroll into a single XLA iteration
+                using `jax.lax.scan` (default: `1`)
+        """
+        self.hidden_units = units
+        self.gate_fn = gate_fn
+        self.activation_fn = activation_fn
+        self.kernel_init = kernel_init
+        self.recurrent_kernel_init = recurrent_kernel_init
+        self.bias_init = bias_init
+        self.initial_state_init = initial_state_init
+        self.params = {}
+        self.return_sequences = return_sequences
+        self.return_state = return_state
+        self.go_backwards = go_backwards
+        self.stateful = stateful
+        self.time_axis = (time_axis,)
+        self.unroll = unroll
+
+        self.next_key = KeySeq()
+        self.last_state = None
+
+    @property
+    def module(self) -> flax_module.GRUCell:
+        return flax_module.GRUCell(
+            gate_fn=self.gate_fn,
+            activation_fn=self.activation_fn,
+            kernel_init=self.kernel_init,
+            recurrent_kernel_init=self.recurrent_kernel_init,
+            bias_init=self.bias_init,
+        )
+
+    def initialize_state(self, batch_dim: tp.Union[tp.Tuple[int], int]) -> jnp.ndarray:
+        """Initializes the hidden state of the GRU.
+
+        Arguments:
+            batch_dim: the batch dimension(s) of the input, either a single int or a
+                tuple of ints
+
+        Returns:
+            The initial hidden state as specified by `initial_state_init`
+        """
+        if not isinstance(batch_dim, tp.Iterable):
+            batch_dim = (batch_dim,)
+        return self.module.initialize_carry(
+            self.next_key(), batch_dim, self.hidden_units, self.initial_state_init
+        )
+
+    def __call__(
+        self, x: jnp.ndarray, initial_state: tp.Optional[jnp.ndarray] = None
+    ) -> tp.Union[jnp.ndarray, tp.Tuple[jnp.ndarray, jnp.ndarray]]:
+        """Applies the GRU to the sequence of inputs `x`, starting from `initial_state`.
+
+        Arguments:
+            x: sequence of inputs to the GRU
+            initial_state: optional initial hidden state. If not given, `last_state`
+                (i.e. the output of the previous batch) is used when `stateful == True`,
+                otherwise the initial state is created with the `initial_state_init`
+                function
+
+        Returns:
+            - The final state of the GRU by default
+            - The full sequence of states (if `return_sequences == True`)
+            - A tuple of both the sequence of states and the final state (if both
+              `return_state` and `return_sequences` are `True`)
+        """
+        # Move the time axis to the front so it can be looped over
+        # Note: not needed with jax_utils.scan_in_dim
+        x = jnp.swapaxes(x, self.time_axis, 0)
+
+        if self.go_backwards:
+            x = x[::-1]
+
+        if initial_state is None:
+            initial_state = self.initialize_state(x.shape[len(self.time_axis) : -1])
+            if self.stateful and self.last_state is not None:
+                initial_state = self.last_state
+
+        if self.initializing():
+            _variables = self.module.init(
+                next_key(), initial_state, x[(0,) * len(self.time_axis)]
+            )
+            self.params = _variables["params"]
+
+        variables = dict(params=self.params)
+
+        def iter_fn(state, x):
+            return self.module.apply(variables, state, x)
+
+        final_state, sequences = jax.lax.scan(
+            iter_fn, initial_state, x, unroll=self.unroll
+        )
+        if self.stateful and not self.initializing():
+            self.last_state = final_state
+
+        # Move the time axis back to its original position
+        # Note: not needed with jax_utils.scan_in_dim
+        sequences = jnp.swapaxes(sequences, self.time_axis, 0)
+
+        if self.return_sequences and not self.return_state:
+            return sequences
+        if self.return_sequences and self.return_state:
+            return sequences, final_state
+        return final_state
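
Usage sketch (illustrative, not part of the patch): the snippet below mirrors how the tests above exercise the new `GRU` layer; the concrete sizes (8 hidden units, batch of 4, 10 timesteps, 3 features) are made up for the example, everything else follows the API shown in the diff.

    import jax.numpy as jnp
    import treex as tx
    from treex.nn import recurrent

    # A GRU that returns both the per-timestep states and the final state,
    # with time on axis 0, i.e. inputs of shape (timesteps, batch, features).
    gru = recurrent.GRU(8, time_axis=0, return_sequences=True, return_state=True)
    gru = gru.init(tx.Key(42), (jnp.ones((1, 1, 3)), jnp.zeros((1, 8))))

    x = jnp.ones((10, 4, 3))
    # With no explicit initial state, the state comes from `initial_state_init` (zeros).
    sequences, final_state = gru(x)
    assert sequences.shape == (10, 4, 8)
    assert final_state.shape == (4, 8)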