
Use Optimisers.jl #1481

Open · wants to merge 25 commits into master
Conversation

@DhairyaLGandhi (Member) commented on Jan 27, 2021

Fixes #637, fixes #823.

@DhairyaLGandhi (Member, Author)

@ModelZookeeper commands

(1 similar comment)

function train!(m, loss, data, opt; cb = (x...) -> (),
                prehook = (x...) -> (),
                posthook = (x...) -> ())
  st = [Optimisers.init(opt, p) for p in Flux.params(m)]
Member:

I think this can be reduced to st = Optimisers.state(opt, m).

Member:

I think these should be defined as functions in the optimiser.jl file: one for state(opt, p::Params) and another for update(opt, x::Params, dx, state).
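For concreteness, a minimal sketch of what those two Params methods could look like, reusing the per-array init/apply that appears elsewhere in this PR. The names, signatures, and the Grads-style indexing are illustrative assumptions, not the final API:

import Optimisers
using Flux: Params

# One optimiser state per parameter array in the Params collection.
state(opt, ps::Params) = [Optimisers.init(opt, p) for p in ps]

# Step every parameter with its own state; `dxs` is assumed to be a Zygote
# Grads (or any container indexable by the parameter array itself).
function update(opt, xs::Params, dxs, states)
  new_states = map(enumerate(xs)) do (i, x)
    dx = dxs[x]
    dx === nothing && return states[i]              # parameter untouched by this gradient
    δ, st = Optimisers.apply(opt, x, dx, states[i]) # out-of-place step, as used in this PR
    x .-= δ                                         # mutate the parameter in place
    return st
  end
  return xs, new_states
end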

@darsnack (Member) commented on Feb 2, 2021

I would prefer we remove the hooks from Flux.train! in this PR and save those for a later commit. This PR can simply remove the optimizers in favor of Optimisers.jl.

@DhairyaLGandhi (Member, Author)

Done a bunch of cleanup, including removing the train loop changes, but I feel like we will circle back to it soon.

@@ -97,12 +102,13 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  st = [Optimisers.init(opt, p) for p in ps]
Member:

Could we define init for Params like we do update?

Member (Author):

I'm wondering whether this should be wholesale replaced with Optimisers.state instead.

Member (Author):

I'm only worried about corner cases that this might hit.

@darsnack (Member) commented on Feb 24, 2021:

I think state makes sense. For every p in ps, we should be able to independently optimize each one by calling the user facing state and update. I think it's safe to do that here.
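As a self-contained sketch of that per-parameter route (rule name and per-array init/apply follow the usage elsewhere in this thread; everything here is illustrative, not the final API):

import Optimisers
using Flux

m   = Dense(3, 3)
ps  = Flux.params(m)
opt = Optimisers.ADAM()                             # assuming the ADAM rule from Optimisers.jl
sts = [Optimisers.init(opt, p) for p in ps]         # independent state per parameter array

x, y = rand(Float32, 3, 8), rand(Float32, 3, 8)     # toy data, purely for illustration
gs = gradient(() -> Flux.mse(m(x), y), ps)
for (i, p) in enumerate(ps)
  gs[p] === nothing && continue
  δ, sts[i] = Optimisers.apply(opt, p, gs[p], sts[i])
  p .-= δ                                           # each parameter optimised on its own
end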

Member (Author):

Right, so that works in cases where we have arrays as params, but not arbitrary structs, which is what we want.

The user-facing API is something like this (pseudocode, although it might work as an MWE):

m = Dense(3, 3)
opt = ADAM()
st = Optimisers.state(opt, m) # `m` could contain arbitrary structs which shouldn't be functor'd
loss(m, x, y) = Flux.mse(m(x), y)
w, w′ = rand(Float32, 3, 10), rand(Float32, 3, 10)  # example input/target pair
for i = 1:1000
  gs, = gradient(m) do m
    @show loss(m, w, w′)
  end
  m, st = opt(m, gs, st)
end

@darsnack (Member) commented on Feb 24, 2021:

That's exactly what I'm suggesting. To be more specific, change the code to:

Suggested change:
- st = [Optimisers.init(opt, p) for p in ps]
+ st = state(opt, ps)

And elsewhere (still in Flux.jl), define:

Optimisers.init(o, ps::Params) = [init(o, p) for p in ps]

This should allow for the future case where m is not a Params, but also the current case where we need to support Params.

@DhairyaLGandhi marked this pull request as ready for review on March 8, 2021.
@darsnack (Member) left a comment:

Looking pretty good. Two ML calls ago, we talked about having apply! also defined in Optimisers.jl as a temporary measure to avoid potential performance regressions. Is that still the plan?
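If that temporary apply! is still the plan, one rough way to sketch it on top of the out-of-place apply from this PR could be the following; the name, the buffer reuse, and the return convention are assumptions, not a confirmed Optimisers.jl API:

import Optimisers

# Mutating shim over the out-of-place rule: reuse the gradient buffer so the
# hot path allocates roughly like the old in-place Flux optimisers did.
function apply!(opt, x, x̄, st)
  δ, st = Optimisers.apply(opt, x, x̄, st)
  x̄ .= δ                      # overwrite the gradient with the computed step
  return x̄, st
end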

docs/src/training/optimisers.md (outdated review thread, resolved)
x .-= apply!(opt, x, x̄)
function update!(opt, x, x̄, st)
  x̄, st = apply(opt, x, x̄, st)
  x .-= x̄
Member:

Should we use Optimisers.patch here for consistency?
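For reference, a self-contained version of the update! shown in the hunk above might read as follows; the hunk cuts off before the function ends, so the trailing return value is an assumption:

import Optimisers

function update!(opt, x::AbstractArray, x̄, st)
  x̄, st = Optimisers.apply(opt, x, x̄, st)   # out-of-place step from Optimisers.jl
  x .-= x̄                                    # apply the step to the parameter in place
  return x, st                               # hand the updated state back to the caller
end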

src/optimise/train.jl (outdated review thread, resolved)
@@ -97,12 +104,13 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  st = Optimisers.init(opt, ps)
Member:

Suggested change:
- st = Optimisers.init(opt, ps)
+ st = Optimisers.state(opt, ps)

@DhairyaLGandhi (Member, Author) commented on Mar 11, 2021

See FluxML/Optimisers.jl#13

@darsnack (Member)

Cool, looks good. Should we be calling Optimisers.update here then? Calling apply directly bypasses the mutability check.
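Loosely, the mutability check being referred to might look something like this; the ismutable branch is purely illustrative and is not the actual Optimisers.update implementation:

import Optimisers

function update(opt, x, x̄, st)
  x̄, st = Optimisers.apply(opt, x, x̄, st)
  if x isa AbstractArray && ismutable(x)
    x .-= x̄          # mutate in place when the parameter supports it
  else
    x = x .- x̄       # fall back to out-of-place for immutable leaves
  end
  return x, st
end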

@darsnack (Member)

Ref #1613 (comment)

In my mind, this PR is basically good to go. What I would suggest is establishing an optimisers.jl branch off master (similar to how we had a zygote branch for that transition). Then we can safely merge this into the optimisers.jl branch without much concern, and everyone can start the process of benchmarking and validating the transition.

@DhairyaLGandhi (Member, Author)

That branch exists already - it's this PR! I think we're kind of tied to merging the Zeros changes first, then this, and boom.

@darsnack (Member)

Why does this depend on Zeros?

If this is the branch, then I guess let's rebase it?

@DhairyaLGandhi (Member, Author)

We can rebase. Without the Zeros changes, state initialisation and optimisation don't work; the necessary init methods would have to be defined manually.
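To illustrate what "manually" means here: without the Zeros changes, every non-trainable leaf (such as a fixed bias) would need hand-written methods along these lines. Both definitions below are hypothetical and only meant to show the shape of the problem:

import Optimisers
using Flux

Optimisers.init(o, ::Flux.Zeros) = nothing                   # a fixed bias carries no optimiser state
Optimisers.apply(o, ::Flux.Zeros, dx, st) = (zero(dx), st)   # and its update step is a no-op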

@darsnack (Member) commented on Jul 11, 2021

I'm not sure what you mean. Why is this not enough?

Optimisers.init(o, ps::Params) = [Optimisers.init(o, p) for p in ps]

@DhairyaLGandhi (Member, Author)

We also have a route that does not require the use of Params.

@darsnack (Member)

You mean where we allow passing in [W, b]? Wouldn't that route be handled by Optimisers.jl already?
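For concreteness, the [W, b] route sketched with plain arrays; the rule name and per-array init follow the usage earlier in this thread and are illustrative only:

import Optimisers

W, b = rand(Float32, 3, 3), rand(Float32, 3)
opt = Optimisers.ADAM()                               # assuming the ADAM rule from Optimisers.jl
sts = [Optimisers.init(opt, p) for p in (W, b)]       # no Params wrapper needed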

Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
@mcabbott mentioned this pull request on Feb 5, 2022.
Development

Successfully merging this pull request may close these issues:
- Flux Optimizers should define equality
- New Optimisers

3 participants