Add support for Optimisers.jl #114
Conversation
src/training.jl
Outdated
@@ -49,18 +49,29 @@ function step! end
function step!(learner, phase::TrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (; xs=xs, ys=ys)) do handle, state
        state.grads = gradient(learner.params) do
            state.ŷs = learner.model(state.xs)
        state.grads, _, _ = gradient(learner.model, state.xs, state.ys) do model, xs, ys
I'm not sure passing in all three is necessary here? We only want the gradient of the model given xs and ys, but I wonder if this also calculates some unneeded gradients with respect to xs and ys.
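A minimal sketch of restricting the gradient to the model alone, assuming Zygote and that the learner's loss function is available as learner.lossfn; state.xs and state.ys are simply closed over, so no gradients are computed for them:

# Sketch: differentiate only w.r.t. the model; the batch values are captured
# by the closure rather than passed as differentiated arguments.
state.grads, = gradient(learner.model) do model
    state.ŷs = model(state.xs)
    learner.lossfn(state.ŷs, state.ys)
end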
src/training.jl
Outdated
end
end

# Handle both old Flux.jl and new Optimisers.jl optimisers
function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
    update!(optimizer, model, grads)
This currently throws an error. For context, params isa Params and grads is no longer a Grads. Is a Params even needed anymore?
Left a comment above about this.
@testset "Optimisers.jl compatibility" begin
    learner = testlearner(coeff = 3, opt = Optimisers.Descent(0.001))
This already passes 👍, but all the old optim tests are broken for the time being.
src/training.jl
Outdated
state.grads = gradient(learner.params) do
    state.ŷs = learner.model(state.xs)

state.grads, = gradient(learner.model) do model
I think here you want to take the gradient w.r.t. learner.params when the optimizer is a Flux.Optimise.AbstractOptimiser. Conversely, if it is not, you take the gradient w.r.t. learner.model like you are now.
This is why update! below is erroring: you need a Grads object for the old optimizers, and you can only get that with implicit params.
I think some dispatch for the gradient would be easiest. Another option is to have a utility that takes the model, the gradient w.r.t. it, and the Params, then produces a matching Grads.
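A minimal sketch of the dispatch option, assuming Zygote; the _gradient helper is hypothetical and not code from this PR:

using Flux, Zygote

# Old-style optimiser: implicit mode, returns a Zygote.Grads keyed by the Params.
_gradient(opt::Flux.Optimise.AbstractOptimiser, f, model, params) =
    gradient(() -> f(model), params)

# Anything else (e.g. an Optimisers.jl rule): explicit mode, returns the
# gradient w.r.t. the model itself as a nested NamedTuple.
_gradient(opt, f, model, params) = only(gradient(f, model))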
Ah, I figured. I was hoping there might be a way to have the same Zygote.gradient call work for both, but I guess not. I'll add a dispatch on the optimiser there.
Well, what you're asking for might come eventually in a later version of Flux as part of the AD-agnostic push, so the code might eventually get simpler.
That's good to know. Definitely let me know then, so I can clean this up again.
src/training.jl
Outdated
end
end

# Handle both old Flux.jl and new Optimisers.jl optimisers
function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
    update!(optimizer, model, grads)
Left a comment above about this.
@darsnack I added the dispatch for the gradient.
Looks right to me!
Closes #112 (once done). @ToucheSir @darsnack

So this is a first draft for adding Optimisers.jl support (new optims) while keeping compatibility with optimisers in Flux.Optimise (old optims). Passing in new optims already works, but I've broken support for old optims. Before, FluxTraining.jl was using implicit parameters with Params and Grads objects, and I'm not sure how to use the old optims with explicit parameters and gradient.

I'll leave some more questions next to the code changes; some feedback from you two would be much appreciated!
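A hedged sketch of the two update paths the description aims for, assuming the Flux.Optimise and Optimisers.jl update! APIs; the _update! signature mirrors the diff above, but the bodies and the optstate argument are illustrative, not the PR's final code:

using Flux, Optimisers

# Old optims: Flux.Optimise.update! expects implicit Params plus a Zygote.Grads.
_update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads) =
    Flux.Optimise.update!(optimizer, params, grads)

# New optims: Optimisers.update! works on an explicit state tree (from
# Optimisers.setup) and the gradient taken w.r.t. the model itself.
_update!(optstate, params, model, grads) =
    Optimisers.update!(optstate, model, grads)

Note that Optimisers.update! returns the updated state tree and model, so the caller would need to store both back on the learner.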