Make RNNs blocked (and maybe fixing gradients along the way) #2258

Open
mkschleg opened this issue May 19, 2023 · 6 comments · May be fixed by #2500

@mkschleg
Contributor

Motivation and description

Given #2185 and other issues caused by the current mutability of the Recur interface, we should move to a more standard blocked (i.e. 3D for a simple RNN) interface. This has the benefits of:

  1. cleaning up the recurrent interface so it is more easily used by people coming from other packages,
  2. making workflows that use convRNNs easier to enable, and
  3. potentially enabling some optimizations we can handle on the Flux side (see Lux's Recurrence return_sequence=true vs false).

I have not tested how we might fix the gradients by moving to this restricted interface. But if we decide to remove the statefulness (see below) we can fix gradients as seen in FluxML/Fluxperimental.jl#7.

Possible Implementation

I see two ways to make this change: one is a wider change to the Flux Chain interface, and the other tries to fix only Recur. In either case, the implementation would assume the final dimension of your multi-dimensional array is the time index. For a simple RNN, it would assume the incoming array has dimensions Features x Batch x Time. It will throw an error if a 1D or 2D array is passed to Recur, to avoid ambiguities.
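The layout convention above can be illustrated with plain Julia arrays. This is only a sketch: `timesteps` is a hypothetical helper name, not a Flux function, and the error type is an assumption about how the rejection of 1D/2D input might look.

```julia
# Hypothetical helper (not part of Flux): split a Features x Batch x Time
# array into one Features x Batch matrix view per time step.
timesteps(x::AbstractArray{T,3}) where {T} = eachslice(x; dims=3)

# Reject lower-dimensional input explicitly, since a matrix could mean either
# Features x Batch or Features x Time.
timesteps(x::AbstractVecOrMat) =
    throw(ArgumentError("expected a Features x Batch x Time array, got $(ndims(x))D input"))

x = rand(Float32, 4, 2, 5)        # 4 features, batch of 2, 5 time steps
steps = collect(timesteps(x))     # 5 slices, each of size (4, 2)
```

Passing a matrix such as `rand(Float32, 4, 2)` to `timesteps` throws an `ArgumentError` instead of guessing which dimension is time.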

One possible implementation is to go ahead with the full changeover and remove state from the network entirely (see FluxML/Fluxperimental.jl#7). This would overhaul large parts of the interface into Chain, and could be targeted at 0.14. See the implementation in the above PR and FluxML/Fluxperimental.jl#5 for details.

The second possible approach is to first remove just the loop-over-timesteps interface and replace it with the 3D interface. This initial change restricts the interface to 3D input, but I haven't tested how we could fix gradients while maintaining mutability and statefulness in Recur. The interface/implementation would likely look much like:

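    # Assumes x is Features x Batch x Time; calls the stateful Recur once per time slice.
    function (m::Recur)(x::AbstractArray{T, 3}) where T
        h = [m(x_t) for x_t in eachlastdim(x)]  # one 2D output per time step
        sze = size(h[1])
        # Restack the per-step outputs into Features x Batch x Time.
        reshape(reduce(hcat, h), sze[1], sze[2], length(h))
    end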
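To make the loop-and-restack pattern above concrete without depending on Flux, here is a self-contained sketch. `ToyRecur` and `cumsum_cell` are hypothetical stand-ins invented for illustration (not the Flux `Recur` type), and `eachslice(x; dims=3)` from Base is used in place of `eachlastdim`.

```julia
# Toy stand-in for a stateful recurrent wrapper (hypothetical; not Flux's Recur).
mutable struct ToyRecur{F,S}
    cell::F      # (state, x_t) -> (new_state, output)
    state::S
end

# Per-time-step call: overwrites the stored state and returns the output.
function (m::ToyRecur)(x_t::AbstractMatrix)
    m.state, y = m.cell(m.state, x_t)
    return y
end

# Blocked call: loop over the time dimension, then restack the outputs.
function (m::ToyRecur)(x::AbstractArray{T,3}) where {T}
    h = [m(x_t) for x_t in eachslice(x; dims=3)]
    sze = size(h[1])
    return reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end

# A cell whose state and output are both the running sum of the inputs.
cumsum_cell(state, x_t) = (state .+ x_t, state .+ x_t)

x = ones(Float32, 3, 2, 4)                       # Features x Batch x Time
m = ToyRecur(cumsum_cell, zeros(Float32, 3, 2))
y = m(x)                                         # size (3, 2, 4)
```

Because the wrapper is mutable, calling `m(x)` a second time continues from the stored state rather than starting from zero, which is the statefulness this approach leaves in place.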

@ToucheSir
Member

ToucheSir commented May 26, 2023

On your second approach, how about emulating what PyTorch does with immutable struct wrappers over the RNN cell types? Say const LSTM = NewRecur{LSTMCell}. This API would only accept 3D+ sequences and return the hidden state. We could make passing the hidden state optional with a signature like (::NewRecur)(x, h = <default value>). Much like PyTorch and TensorFlow, we could add type parameters to toggle whether the RNN is bidirectional and whether it returns a full sequence or a single timestep.
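A minimal sketch of the immutable-wrapper idea, in plain Julia. All names here (`SketchRecur`, `TanhCell`, `init_state`) are hypothetical and chosen for illustration; this is not the Fluxperimental `NewRecur` implementation, and returning a `(state, outputs)` tuple is an assumption about the return convention.

```julia
# Hypothetical simple RNN cell: h_new = tanh.(Wx * x_t + Wh * h + b).
struct TanhCell{W,U,B}
    Wx::W
    Wh::U
    b::B
end

# One step: (h, x_t) -> (new_h, output). For this cell the two coincide.
function (c::TanhCell)(h, x_t)
    h_new = tanh.(c.Wx * x_t .+ c.Wh * h .+ c.b)
    return h_new, h_new
end

# Immutable wrapper: holds the cell but no hidden state.
struct SketchRecur{C}
    cell::C
end

# Assumed default: a zero hidden state sized from the cell and the batch.
init_state(c::TanhCell, x) = zeros(eltype(x), size(c.Wh, 1), size(x, 2))

# Hidden state omitted: fall back to the default.
(r::SketchRecur)(x::AbstractArray{T,3}) where {T} = r(x, init_state(r.cell, x))

# Hidden state supplied by the caller; returns (final_state, outputs).
function (r::SketchRecur)(x::AbstractArray{T,3}, h) where {T}
    ys = Any[]                                  # untyped for simplicity
    for x_t in eachslice(x; dims=3)
        h, y = r.cell(h, x_t)
        push!(ys, y)
    end
    out = reshape(reduce(hcat, ys), size(ys[1])..., length(ys))
    return h, out
end

cell = TanhCell(randn(Float32, 3, 4), randn(Float32, 3, 3), zeros(Float32, 3))
rnn = SketchRecur(cell)
x = randn(Float32, 4, 2, 5)     # Features x Batch x Time
h1, y1 = rnn(x)                 # starts from the default zero state
h2, y2 = rnn(x, h1)             # caller threads the state through explicitly
```

Since nothing is mutated, calling `rnn(x)` twice gives the same result, and `first` or `last` applied to the returned tuple picks out the state or the output sequence.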

Integration with Chain in this approach would be a bit more work since preceding and following layers would need to be sequence-aware, but in practice I haven't seen many actually taking advantage of being able to create combinations like Chain(Dense(), RNN(), LayerNorm()) to apply non-recurrent layers per-timestep. Just extracting the new hidden state or output from the RNN return value would be a simple matter of adding first or last after it in the chain.

This approach avoids having to deal with state by making the user save it and carry it over themselves. It's not as ergonomic since they'd have to thread said states through larger models, but that's more than doable with layers like Parallel. Going stateless also makes AD easier! The big remaining AD-related issue I see is that differentiating through loops is still slow, but we can address that more easily with a 3D sequence-based interface by defining rules for NewRecur. Maybe this rule could be attached to a JAX-like scan function so that people can use it for their own recurrent layers.
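For reference, a scan function in the spirit of `jax.lax.scan` might be sketched as follows. The name `scan` and its signature are assumptions made for illustration, and unlike JAX (which scans over the leading axis) this version scans over the last dimension to match the Features x Batch x Time layout discussed here.

```julia
# Hypothetical scan: thread `carry` through the slices of `xs` along its last
# dimension, calling f(carry, x_t) -> (new_carry, y) and collecting every y.
function scan(f, carry, xs::AbstractArray)
    ys = Any[]                                  # untyped for simplicity
    for x_t in eachslice(xs; dims=ndims(xs))
        carry, y = f(carry, x_t)
        push!(ys, y)
    end
    return carry, ys
end

# Example step function: the carry and the output are both a running sum.
running_sum(c, x_t) = (c .+ x_t, c .+ x_t)

xs = reshape(Float32.(1:6), 1, 2, 3)            # 1 feature, batch of 2, 3 steps
final, ys = scan(running_sum, zeros(Float32, 1, 2), xs)
```

With a single entry point like this, a custom differentiation rule would only need to be written once, and any user-defined recurrent layer built on it would pick the rule up.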

@mkschleg
Contributor Author

Good idea. This also seems similar to Haiku's approach, afaict. I think this could also give us an opportunity to provide an interface for a static unroll vs a dynamic unroll. I think at first we should just do a loop, but there might be an opportunity to use a generated function to replace the loop with the unrolled version. But that might be more problematic than it's worth, depending on how the scan function turns out.
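The difference between the two unrolling styles can be sketched in plain Julia. Both function names are hypothetical, and the static version assumes the sequence arrives as a tuple so that its length is part of the type.

```julia
# Dynamic unroll: an ordinary loop whose length is only known at run time.
function dynamic_unroll(f, carry, xs)
    for x_t in xs
        carry = f(carry, x_t)
    end
    return carry
end

# Static unroll: the number of steps N is encoded in the tuple type, so a
# generated function can emit N explicit calls with no loop.
@generated function static_unroll(f, carry, xs::NTuple{N,Any}) where {N}
    body = [:(carry = f(carry, xs[$i])) for i in 1:N]
    return quote
        $(body...)
        return carry
    end
end

a = dynamic_unroll(+, 0, [1, 2, 3])
b = static_unroll(+, 0, (1, 2, 3))
```

Both calls compute the same result; the static version trades a loop for code whose size grows with the sequence length, which is one reason it may not be worth it for long sequences.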

@mkschleg
Contributor Author

I added a first pass of this functionality to Fluxperimental.jl. There are some lingering issues that need to be resolved, but it gives us an idea of what we need to support. I think we should have a conversation on how to solve the issue of returning the carry back to the user. @ToucheSir mentioned using Parallel streams to solve the issue, but I think that would be a pretty horrendous interface to use.

@mkschleg
Contributor Author

mkschleg commented Aug 9, 2023

Ok. We have merged a potential interface (NewRecur) into Fluxperimental. I think we can make a push to get this into v0.15. I think this is necessary if we want to fully remove the old gradient interface. Thoughts? I can work on a PR (I've finally gotten through my PhD defense, so time is less of an issue).

@ToucheSir
Member

I like that idea. We should probably figure out how to make the rrule type stable as part of that :P

@mkschleg
Contributor Author

Working PR in #2316.

@ToucheSir mentioned this issue Feb 12, 2024
@CarloLucibello linked a pull request Oct 14, 2024 that will close this issue
4 participants