Feature request: Add include_init_carry to RNN #3323
Hey @carlosgmartin, …
@cgarciae Thanks for getting back to me. Here's an example of what I mean:

```python
import jax
import flax

module = flax.linen.RNN(flax.linen.GRUCell(), 4)
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (32, 20, 3))
w = module.init(key, x)
y = module.apply(w, x)
print(y.shape)  # (32, 20, 4)
```

A recurrent cell like GRU maps an input and a state to a new state (the output). The output of the RNN is the sequence of new states. I was wondering if it would be possible to include the initial state in this sequence. The result would be a sequence one step longer than the input. The extra entry is the state of the RNN at timestep zero, before it has seen any of the input sequence, which can be used to make predictions before any input has been seen.
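Written out, with notation introduced here only for illustration (a generic cell $f$, state $h_t$, input $x_t$, sequence length $T$, and a GRU-style cell whose output equals its new state):

$$
h_t = f(h_{t-1}, x_t), \qquad t = 1, \dots, T
$$

$$
\text{current output: } (h_1, \dots, h_T), \qquad \text{requested output: } (h_0, h_1, \dots, h_T)
$$

Here $h_0$ is the initial carry, so the requested sequence has length $T + 1$. In the example above $T = 20$, so the output shape would become `(32, 21, 4)`.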
You can get the initial carry by using the …
@cgarciae Like this? Would you suggest any changes to better conform to the Flax API/style?

```python
import jax
from jax import random, numpy as jnp
from flax import linen as nn


def prepend(item, seq, axis):
    # Prepend `item` to `seq` along `axis`, leaf by leaf, so this also works
    # when the carry is a pytree (e.g. a tuple of arrays).
    return jax.tree_util.tree_map(
        lambda item, seq: jnp.concatenate([jnp.expand_dims(item, axis), seq], axis),
        item,
        seq,
    )


class FullRNN(nn.RNN):
    # Written for the Flax API used in this thread, where nn.RNN takes a
    # cell_size and initialize_carry takes (rng, batch_dims, size).
    def __call__(self, *args, **kwargs):
        y = super().__call__(*args, **kwargs)
        # Reuse the caller's initial carry if one was passed; otherwise
        # rebuild a default initial carry.
        y0 = kwargs.get("initial_carry", None)
        if y0 is None:
            init_key = kwargs.get("init_key", None)
            if init_key is None:
                init_key = random.PRNGKey(0)
            # Assumes batch-major outputs: y.shape == (*batch, time, features).
            y0 = self.cell.initialize_carry(init_key, y.shape[:-2], y.shape[-1])
        return prepend(y0, y, -2)


def main():
    module = FullRNN(nn.GRUCell(), 4)
    key = random.PRNGKey(0)
    key, subkey = random.split(key)
    x = random.uniform(subkey, (32, 20, 3))
    key, subkey = random.split(key)
    w = module.init(subkey, x)
    y = module.apply(w, x)
    print(y.shape)  # (32, 21, 4)


if __name__ == "__main__":
    main()
```

I'm using this in an autoregressive recurrent generative model where the initial (i.e., zeroth) token to be generated has no preceding tokens as input.
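A minimal sketch in plain JAX (no Flax) of why the `T + 1` states line up with autoregressive prediction: `states[t]` has consumed only tokens `0 .. t-1`, so it can predict token `t`, and `states[0]` (the initial carry) predicts the zeroth token from no input at all. The Elman-style cell, `make_cell`, `states_with_initial`, `next_token_logits`, and the random weights are all hypothetical stand-ins for the real model.

```python
import jax
import jax.numpy as jnp


def make_cell(w_h, w_x):
    # Hypothetical Elman-style cell whose output equals its new carry (GRU-style).
    def cell(h, x):
        new_h = jnp.tanh(h @ w_h + x @ w_x)
        return new_h, new_h  # (new carry, per-step output)
    return cell


def states_with_initial(cell, h0, xs):
    # xs: (T, in_features) -> (T + 1, features); states[t] has only seen xs[:t].
    _, hs = jax.lax.scan(cell, h0, xs)
    return jnp.concatenate([h0[None], hs], axis=0)


def next_token_logits(cell, h0, xs, w_out):
    # xs[t] embeds token t; states[t] predicts token t. The last state,
    # states[T], would predict the token after the sequence and is dropped.
    states = states_with_initial(cell, h0, xs)
    return states[:-1] @ w_out  # (T, vocab)


# Usage with random weights and inputs (unbatched; wrap with jax.vmap for batches).
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
T, d_in, d_h, vocab = 5, 3, 4, 7
cell = make_cell(jax.random.normal(k1, (d_h, d_h)), jax.random.normal(k2, (d_in, d_h)))
h0 = jnp.zeros(d_h)
xs = jax.random.normal(k3, (T, d_in))
w_out = jax.random.normal(k4, (d_h, vocab))
logits = next_token_logits(cell, h0, xs, w_out)
```

Because `logits[t]` depends only on `xs[:t]`, the whole sequence can be scored in one scan with teacher forcing, while the prediction for the zeroth token comes from the initial carry alone.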
Not sure about using …
Thanks for the feedback. I was considering cases like LSTMCell or ConvLSTMCell, where the carry is not just a single array.
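A small sketch of the pytree case, assuming an LSTM-style carry represented as a `(c, h)` tuple; the arrays are made-up placeholders. One caveat: in Flax's `LSTMCell` the carry is such a tuple while the per-step output is `h` alone, so the structure of the prepended item must match the sequence it is prepended to. A full `(c, h)` carry cannot be prepended directly to an output sequence of `h` values.

```python
import jax
import jax.numpy as jnp


def prepend(item, seq, axis):
    # Concatenate leaf by leaf, so `item` and `seq` may be arrays or pytrees
    # with matching structure.
    return jax.tree_util.tree_map(
        lambda i, s: jnp.concatenate([jnp.expand_dims(i, axis), s], axis=axis),
        item,
        seq,
    )


batch, T, features = 2, 5, 4
# Hypothetical stacked per-step (c, h) carries, batch-major: (batch, T, features) each.
c_seq = jnp.ones((batch, T, features))
h_seq = jnp.ones((batch, T, features))
# Initial (c, h) carry, zeros here: (batch, features) each.
c0 = jnp.zeros((batch, features))
h0 = jnp.zeros((batch, features))

# Full carries: both leaves become (batch, T + 1, features).
c_full, h_full = prepend((c0, h0), (c_seq, h_seq), axis=-2)

# If the sequence holds only per-step outputs (h alone), prepend only h0.
h_only = prepend(h0, h_seq, axis=-2)
```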
Converting this to a discussion since it seems what the feature request asks for is already possible with the current API.
`RNN` returns as its output sequence the cell output at each timestep (for `GRUCell`, this is the new carry), but it excludes the initial carry. Inside `RNN.__call__`, the initial `carry` gets overwritten by the line `carry, outputs = scan_output`.

Feature request: Add a flag `include_initial_carry: bool` that includes the initial carry in the output sequence.
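A minimal sketch of the requested flag's semantics, written as a plain-JAX helper around `jax.lax.scan` rather than as a change to Flax. `run_rnn`, `gru_like_cell`, and the shapes are hypothetical; the key point is binding the scan result to a new name (`final_carry`) so the initial carry is still available to prepend.

```python
import jax
import jax.numpy as jnp


def run_rnn(cell, initial_carry, xs, include_initial_carry=False):
    """Scan `cell(carry, x) -> (new_carry, output)` over the leading axis of `xs`.

    Hypothetical helper mirroring the requested flag; not part of Flax.
    """
    final_carry, outputs = jax.lax.scan(cell, initial_carry, xs)
    if include_initial_carry:
        # `initial_carry` was not overwritten, so it can be prepended, giving
        # T + 1 entries. Only meaningful when each per-step output has the same
        # structure as the carry (GRU-style cells).
        outputs = jax.tree_util.tree_map(
            lambda c0, ys: jnp.concatenate([c0[None], ys], axis=0),
            initial_carry,
            outputs,
        )
    return final_carry, outputs


def gru_like_cell(h, x):
    # Toy cell whose output equals its new carry.
    new_h = jnp.tanh(h + x)
    return new_h, new_h


h0 = jnp.zeros(3)
xs = jnp.ones((4, 3))
_, ys = run_rnn(gru_like_cell, h0, xs)  # ys.shape == (4, 3)
_, ys_full = run_rnn(gru_like_cell, h0, xs, include_initial_carry=True)  # (5, 3)
```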