Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

manual gradient checks for RNN - implicit and explicit gradients #2215

Merged
merged 3 commits into from
Mar 21, 2023

Conversation

jeremiedb
Copy link
Contributor

This adds gradient tests for RNN in relation to #2185

Note that for implicit gradient mode, gradients successfully pass all tests on all Julia versions.
Implicit mode gradients only fail on Julia >= 1.7 when in REPL (and gradient isnt' call from within a function).

For explicit mode gradient (new Optimisers.jl), all gradients fail on Julia >= 1.7.
On Julia v1.6, all gradients other than state0 are correct. The correct state0 gradient is actually assigned to Recur's state rather than to the cell's state0.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you discover anything in your test creation which might give us a lead on where the incorrect grads are coming from?

test/layers/recurrent.jl Show resolved Hide resolved
@jeremiedb
Copy link
Contributor Author

Did you discover anything in your test creation which might give us a lead on where the incorrect grads are coming from?

Sorry no new insights at the moment. For explicit mode state0 issue, I suspect some "alias" from

reset!(m::Recur) = (m.state = m.cell.state0)
.

Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@ToucheSir ToucheSir merged commit 0038a60 into FluxML:master Mar 21, 2023
@mcabbott
Copy link
Member

Would it be easy to add explicit errors to the wrong ones? E.g. by overloading Zygote.pullback(::Context{true}, ...) where the struct has a flag to indicate implicit mode.

@ToucheSir
Copy link
Member

Unfortunately it's the explicit mode path which is the broken one.

@mcabbott
Copy link
Member

Sure, but we can dispatch on that too? Haven't tried & not sure whether there's a point at which this could be attached.

@ToucheSir
Copy link
Member

We can, but we have to find that point first. And if we do, it's likely we'll be able fo fix the bug then without having to put up a "this is broken" sign.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants