Add support for Optimisers.jl #114
Changes from 3 commits
@@ -49,18 +49,30 @@ function step! end
 function step!(learner, phase::TrainingPhase, batch)
     xs, ys = batch
     runstep(learner, phase, (; xs=xs, ys=ys)) do handle, state
-        state.grads = gradient(learner.params) do
-            state.ŷs = learner.model(state.xs)
+        state.grads, = gradient(learner.model) do model
+            state.ŷs = model(state.xs)
             handle(LossBegin())
             state.loss = learner.lossfn(state.ŷs, state.ys)
             handle(BackwardBegin())
             return state.loss
         end
         handle(BackwardEnd())
-        update!(learner.optimizer, learner.params, state.grads)
+        learner.params, learner.model = _update!(
+            learner.optimizer, learner.params, learner.model, state.grads)
     end
 end
+
+# Handle both old Flux.jl and new Optimisers.jl optimisers
+function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
+    update!(optimizer, model, grads)

Review comment: This currently throws an error. For context

Review comment: Left a comment above about this.

+    return params, model
+end
+function _update!(_, st, model, grads)
+    st, model = Optimisers.update!(st, model, grads)
+    return st, model
+end
+
+
 function step!(learner, phase::ValidationPhase, batch)
     xs, ys = batch
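For context on what `_update!` is bridging, here is a minimal standalone sketch of the two optimiser interfaces involved. The model, data, and loss function below are made up for illustration and are not part of this PR.

```julia
using Flux, Optimisers, Zygote

model = Dense(2 => 1)
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
loss(m, x, y) = Flux.Losses.mse(m(x), y)

# Old Flux.jl interface: implicit parameters, the gradient comes back as a
# Grads object keyed by the parameter arrays, and update! mutates them in place.
ps = Flux.params(model)
gs = Zygote.gradient(() -> loss(model, x, y), ps)
Flux.Optimise.update!(Flux.Optimise.Descent(0.1), ps, gs)

# New Optimisers.jl interface: explicit optimiser state set up per model,
# a structural gradient taken w.r.t. the model itself, and update! returning
# the new state and model.
state = Optimisers.setup(Optimisers.Descent(0.1), model)
grad, = Zygote.gradient(m -> loss(m, x, y), model)
state, model = Optimisers.update!(state, model, grad)
```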
@@ -47,3 +47,10 @@ end
     fit!(learner, 5)
     @test learner.model.coeff[1] ≈ 3 atol = 0.1
 end
+
+
+@testset "Optimisers.jl compatibility" begin
+    learner = testlearner(coeff = 3, opt=Optimisers.Descent(0.001))

Review comment: This already passes 👍 but all the old optim tests are broken for the time being.

+    fit!(learner, 5)
+    @test learner.model.coeff[1] ≈ 3 atol = 0.1
+end
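As a rough standalone illustration of what this test exercises, the sketch below fits a single coefficient towards 3 with Optimisers.jl directly, without FluxTraining's `testlearner` helper. The toy model structure, data, and learning rate here are chosen for the sketch only and do not match the test's hyperparameters.

```julia
using Optimisers, Zygote

# Fit `coeff * x` to the target `3x` by plain gradient descent, mirroring the
# spirit of the test above.
function fit_coeff(epochs)
    model = (coeff = [0.0],)                                   # toy "model"
    state = Optimisers.setup(Optimisers.Descent(0.1), model)   # per-model optimiser state
    for _ in 1:epochs, x in rand(100)
        grad, = Zygote.gradient(m -> (m.coeff[1] * x - 3x)^2, model)
        state, model = Optimisers.update!(state, model, grad)
    end
    return model
end

model = fit_coeff(5)
@assert isapprox(model.coeff[1], 3; atol = 0.1)
```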
Review comment:

I think here you want to take the gradient w.r.t. `learner.params` when the optimizer is a `Flux.Optimise.AbstractOptimiser`. Conversely, if it is not, you take the gradient w.r.t. `learner.model` like you are now.

This is why `update!` below is erroring: you need a `Grads` object for the old optimizers, and you can only get that with implicit params.

I think some dispatch for the gradient would be easiest. Another option is to have a utility that takes the model, the gradient w.r.t. it, and `Params`, then produces a `Grads` to match.
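A hedged sketch of the kind of gradient dispatch being suggested here; the `_gradient` name and signature are hypothetical and not part of this PR.

```julia
using Flux, Zygote

# Hypothetical helper: choose the gradient style based on the optimiser type.
# Old-style Flux optimisers need implicit Params and a Grads object; anything
# else (e.g. an Optimisers.jl rule) gets a structural gradient w.r.t. the model.
_gradient(f, ::Flux.Optimise.AbstractOptimiser, model) =
    Zygote.gradient(() -> f(model), Flux.params(model))
_gradient(f, _, model) = Zygote.gradient(f, model)[1]
```

With something like this, the training step could call `_gradient(learner.optimizer, learner.model) do model ... end` and keep a single `step!` body, with the returned value (a `Grads` or a structural gradient) matching what the corresponding `_update!` method expects.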
Reply:

Ah, I figured. Was hoping there might be a way to have the same `Zygote.gradient` call work, but I guess not. I'll add a dispatch on the optimiser there.
Reply:

Well, what you're asking for might come eventually in a later version of Flux as part of the AD-agnostic push. So the code might eventually get simpler.
Reply:

That's good to know. Definitely let me know then, so I can clean this up again.