Use Optimisers.jl #1481
base: master
Conversation
src/optimise/train.jl
Outdated
function train!(m, loss, data, opt; cb = (x...) -> (),
                prehook = (x...) -> (),
                posthook = (x...) -> ())
  st = [Optimisers.init(opt, p) for p in Flux.params(m)]
I think this can be reduced to st = Optimisers.state(opt, m).
I think this should be defined as a function in the optimiser.jl file: one for state(opt, p::Params) and another for update(opt, x::Params, dx, state).
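For concreteness, a minimal sketch of what those two methods might look like. The array-of-states layout, the Grads-style indexing of dx, and the return convention are assumptions; only the state/update names and the apply(opt, x, x̄, st) call appear elsewhere in this thread.

# Hypothetical definitions for optimiser.jl; not the settled API.
state(opt, ps::Params) = [Optimisers.init(opt, p) for p in ps]

function update(opt, xs::Params, dxs, st)
  # dxs is assumed to be indexable by parameter, like Zygote's Grads
  for (i, x) in enumerate(xs)
    dxs[x] === nothing && continue          # skip parameters with no gradient
    dx, st[i] = Optimisers.apply(opt, x, dxs[x], st[i])
    x .-= dx                                # in-place update of the parameter array
  end
  return xs, st
end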
I would prefer we remove the hooks from …
Done a bunch of cleanup, including removing the train loop changes, but I feel like we'll circle back to it soon.
src/optimise/train.jl
Outdated
@@ -97,12 +102,13 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  st = [Optimisers.init(opt, p) for p in ps]
Could we define init for Params like we do update?
I'm wondering whether this should be wholesale replaced with Optimisers.state instead.
I'm only worried about corner cases that this might hit
I think state makes sense. For every p in ps, we should be able to independently optimize each one by calling the user-facing state and update. I think it's safe to do that here.
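As a hypothetical illustration of that per-parameter view, assuming the state/update names from this thread and a return-both convention:

using Flux, Optimisers            # names below are assumed, not the released API

W   = rand(Float32, 3, 3)                     # one parameter array out of ps
dW  = rand(Float32, 3, 3)                     # a stand-in gradient for it
opt = ADAM()
stW = Optimisers.state(opt, W)                # state kept independently for this array
W, stW = Optimisers.update(opt, W, dW, stW)   # assumed to return the new value and state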
Right, so that works in cases where we have arrays as params, but not arbitrary structs, which is what we want.
The user-facing API is something like this (pseudocode, although it might work as an MWE):
using Flux, Optimisers

m = Dense(3, 3)
opt = ADAM()
st = Optimisers.state(opt, m)  # `m` could contain arbitrary structs which shouldn't be functor'd
loss(m, x, y) = Flux.mse(m(x), y)
x, y = rand(Float32, 3, 16), rand(Float32, 3, 16)  # dummy data so the loop runs
for i = 1:1000
  gs, = gradient(m) do m
    @show loss(m, x, y)
  end
  m, st = opt(m, gs, st)
end
That's exactly what I'm suggesting. To be more specific, change the code to:
- st = [Optimisers.init(opt, p) for p in ps]
+ st = state(opt, ps)
And elsewhere (still in Flux.jl), define:
Optimisers.init(o, ps::Params) = [init(o, p) for p in ps]
Should allow for the future case where m is not a Params, but also the current case where we need to support Params.
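That would leave two entry points open, sketched below; the state name follows the earlier suggestion in this thread.

st = Optimisers.state(opt, Flux.params(m))   # current Params-based path
st = Optimisers.state(opt, m)                # future structural (functor-walking) path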
Looking pretty good. Two ML calls ago, we talked about having apply! also defined in Optimisers.jl as a temporary measure to avoid potential performance regressions. Is that still the plan?
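For context, the distinction being discussed is roughly in-place versus out-of-place application of the rule. The sketch below uses a stand-in rule type; the four-argument signatures mirror the apply(opt, x, x̄, st) call in the diff that follows, and treating apply! as its buffer-reusing twin is an assumption, not the settled Optimisers.jl API.

struct DescentDemo        # stand-in for a simple gradient-descent rule
  eta::Float64
end

# out-of-place: allocate a fresh update, pass the state through unchanged
apply(o::DescentDemo, x, dx, st) = (o.eta .* dx, st)

# in-place: overwrite the gradient buffer, avoiding the extra allocation
apply!(o::DescentDemo, x, dx, st) = (dx .*= o.eta; (dx, st))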
x .-= apply!(opt, x, x̄)
function update!(opt, x, x̄, st)
  x̄, st = apply(opt, x, x̄, st)
  x .-= x̄
Should we use Optimisers.patch here for consistency?
@@ -97,12 +104,13 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  st = Optimisers.init(opt, ps)
- st = Optimisers.init(opt, ps)
+ st = Optimisers.state(opt, ps)
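Putting the suggestion together, the body of train! could end up roughly like the sketch below; the per-parameter loop, the positional state indexing, and returning st at the end are assumptions layered on top of the state/apply calls shown in the diffs above.

function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  st = Optimisers.state(opt, ps)            # one state entry per parameter
  for d in data
    gs = gradient(() -> loss(d...), ps)     # Zygote Grads, indexable by parameter
    for (i, p) in enumerate(ps)
      gs[p] === nothing && continue         # skip parameters with no gradient
      dp, st[i] = Optimisers.apply(opt, p, gs[p], st[i])
      p .-= dp                              # in-place update, as in the diff above
    end
    cb()
  end
  return st                                 # caller keeps the state for the next call
end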
Cool, looks good. Should we be calling …
Ref #1613 (comment). In my mind, this PR is basically good to go. What I would suggest is establishing a …
That branch exists already: it's this PR! I think we're kind of tied to merging the zeros first, then this, and boom.
Why does this depend on the zeros? If this is the branch, then I guess let's rebase it?
We can rebase. Without the zeros, state initialisation and optimisation don't work. In order to define methods for init, one would have to do it manually.
I'm not sure what you mean. Why is this not enough?
Optimisers.init(o, ps::Params) = [Optimisers.init(o, p) for p in ps]
We also have a route that does not require the use of …
You mean where we allow passing in …?
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Fixes #637, fixes #823.