Upgrade train! to work with explicit parameters #2029
Conversation
src/train/Train.jl (Outdated)
for opt in [
    :Descent, :Adam, :Momentum, :Nesterov, :RMSProp,
    :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :AdamW, :RAdam, :OAdam, :AdaBelief,
    # :InvDecay, :ExpDecay, :WeightDecay, :stop, :skip, :Optimiser,
    # :ClipValue, :ClipNorm,
    # TODO check that parameters line up nicely old-vs-new, and include the remaining rules
  ]
  @eval $opt(parameters...; kw...) = FluxState(Optimisers.$opt(parameters...; kw...), missing)
end
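(For concreteness, here is a sketch of what one iteration of this loop expands to; the FluxState definition below is a stand-in so the snippet runs on its own, not the PR's actual struct.)

```julia
using Optimisers

# Stand-in for the PR's wrapper type, just so this sketch is self-contained:
mutable struct FluxState{R}
    rule::R
    state::Any
end

# One iteration of the @eval loop above effectively defines:
Adam(parameters...; kw...) = FluxState(Optimisers.Adam(parameters...; kw...), missing)

opt = Adam(0.001)   # a FluxState wrapping Optimisers.Adam(0.001), with state not yet set up
```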
Punning rule constructor names like this feels like a recipe for confusion. Could we compromise a bit on brevity and, say, define a shorthand alias for the FluxState constructor instead?
Might be too cute, but to try to defend this: I'm not sure it's a pun, it's more like different packages using the same word for slightly different implementations of the same concept. Like Zygote.gradient and ForwardDiff.gradient, not an accident they share a symbol, but never directly interchangeable. Sometimes one will use the other, internally.

There might be a much nicer name for FluxState, too.
Yeah, punning is probably not the right term. The gradient example is a good one though, because unlike Zygote and ForwardDiff, Optimisers is a dependency of Flux. It would be equivalent to Flux defining its own gradient function.

Stepping back a bit, is there a way to keep train!(loss::Function, pars::Params, opt::Flux.AbstractOptimiser) without using FluxState? I think that would allow us to be a bit more aggressive with the new train! methods.
I chose the example because ForwardDiff is also a dependency of Zygote, and will sometimes be used in the evaluation of the gradient, internally.

This PR deletes the notion of Flux.AbstractOptimiser, on the grounds that maintaining two different implementations of Adam (etc.) seems like a bug magnet. And because it only takes a few lines to replace that with a Params-to-Optimisers.jl bridge. But it means the optimiser is opt::FluxState now, instead of opt::Flux.AbstractOptimiser. Because this is mutable, a lot of code which worked on 0.13 will still work with this PR. Maybe I'm not clear on what you're suggesting here.
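(To make the "few lines" point concrete, here is a rough, hypothetical sketch of what such a Params-to-Optimisers.jl bridge could look like; the fs.rule / fs.state field names and the function name are assumptions, not the PR's code.)

```julia
using Optimisers
using Zygote: Params   # implicit parameters and their Grads come from Zygote

# Hypothetical bridge: drive Optimisers.jl rules from implicit Params,
# keeping per-array optimiser state in an IdDict (here fs.state), much like
# the old AbstractOptimiser types kept their own IdDicts.
function update_implicit!(fs, ps::Params, gs)
    for p in ps
        g = gs[p]
        g === nothing && continue
        st = get!(() -> Optimisers.init(fs.rule, p), fs.state, p)  # lazily create state
        st, dp = Optimisers.apply!(fs.rule, st, p, g)              # one optimiser step
        fs.state[p] = st
        p .-= dp                                                   # update the array in place
    end
end
```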
The big question is whether the additional internal complexity and possible naming confusion are worth it if this is already a breaking change. Now, there is the question of what is a breaking change (e.g. does preventing deserialization of saved models count?), but assuming consensus on it being one, I don't think the tradeoff is worth it.

One alternative would be to make a version of train! that works with Optimisers.jl types sans wrapper and ask users to qualify their imports until 0.14. This migration will require more user effort, but should also be free of any surprising errors that come from a false sense of security (this code runs, so that code should too, right?).

I think the main factors for deciding between both approaches are when 0.14 is coming and how long it will last. The longer the former, the more leeway we have to ask users to change their code. The longer the latter, the less jarring the transition from Flux.Adam() isa FluxState to Flux.Adam() isa AbstractRule will be.
I'm not sure I follow what alternatives you are comparing.

> One alternative would be to make a version of train! that works with Optimisers.jl types sans wrapper

I don't think this can work. train! needs to put the updated optimiser state somewhere. Even if it's given the Optimisers.jl state tree, not just the rule, it can't mutate that. The exception is rules like Descent() with no state, which this method accepts (and checks & warns you if necessary):
https://github.com/FluxML/Flux.jl/pull/2029/files#diff-c835714f94af5b03e96dd7e45827c090cac82c1c168f535ab0d81280de54eb69R112-R120

At present, Flux.Adam() has a mutable IdDict in which the state is stored. After this PR, Flux.Adam() is a different struct, storing things in a slightly different IdDict. But for implicit parameters, you use it the same way. No code changes except for Adam not being exported.
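(A usage sketch of that claim, where the qualified Flux.Adam call is the only change a 0.13 user would see; the model sizes and loss here are made up for illustration.)

```julia
using Flux

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
data  = [(rand(Float32, 2, 10), rand(Float32, 1, 10))]
ps    = Flux.params(model)
loss(x, y) = Flux.Losses.mse(model(x), y)

# Same implicit-parameter call as on 0.13; only what Flux.Adam constructs changes
# (an AbstractOptimiser with an IdDict before, a FluxState wrapper after this PR).
opt = Flux.Adam(0.001)
Flux.train!(loss, ps, data, opt)
```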
At first I was going to suggest adding just the version of train! for explicit parameters and leaving the old one alone. But this seems more confusing, as that already needs some mutable container like FluxState for its own use, which is then separate both from the AbstractOptimiser container used for implicit train!, and from what Optimisers.jl does. That's 3 different flavours. Replacing the implicit train! gets it down to 2 flavours: Flux.jl's, and Optimisers.jl's.

> when 0.14 is coming and how long it will last

Note that what's envisaged here is that train! with explicit parameters, and FluxState and Flux.Adam(), will last beyond 0.14. Perhaps in 0.x they will call Diffractor.gradient or Yota.grad instead. And perhaps in 0.y the implicit train!, and the Zygote dep, can be dropped entirely.
new_opt_state = train!(loss, model, data, Optimisers.setup(opt)) works and also keeps things to 2 flavours.
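(For reference, roughly what a state-returning train! like that would be doing internally, written against today's Optimisers.jl API; the train_epoch! name and loop are a sketch, not code from either package.)

```julia
using Flux, Optimisers, Zygote

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
data  = [(rand(Float32, 2, 10), rand(Float32, 1, 10))]
loss(m, x, y) = Flux.Losses.mse(m(x), y)

# The caller owns the Optimisers.jl state tree and gets the updated tree back,
# so everything stays in the one Optimisers.jl "flavour".
function train_epoch!(model, data, opt_state)
    for (x, y) in data
        grads = Zygote.gradient(m -> loss(m, x, y), model)[1]
        opt_state, model = Optimisers.update!(opt_state, model, grads)
    end
    return opt_state
end

opt_state = Optimisers.setup(Optimisers.Adam(0.01), model)
new_opt_state = train_epoch!(model, data, opt_state)
```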
Personally, I'd prefer not to see Flux.Adam() != Optimisers.Adam() last beyond 0.14. Perhaps Flux could absorb parts of https://github.com/FluxML/FluxTraining.jl/blob/master/src/training.jl in the future, after its own functionality has been further diffused into lower-level libraries, but train! has proven to be awkward since its inception. It's illustrative to see how much "cleaner" FluxML/FluxTraining.jl#114 was than all previous attempts to modernize train!. I think that hints at the latter function being a fundamentally poor abstraction.
What's the status now?
Nothing has moved really. But after reading the docs a bit, I think some variant of … A simpler take than this PR's present state might be to remove implicit …
Closing in favour of #2082. Adding more code to deal with implicit parameters in new ways doesn't seem great.
This PR proposes to move away from implicit parameters not by simply deleting train!, but instead by re-writing it to use explicit mode. This means that implicit train! has an easy upgrade path, and the new explicit train! can later be changed to use something other than Zygote.

The option to use Optimisers.jl directly remains. But the style is quite different, and looking after the state yourself requires a certain amount of boilerplate. According to this PR, Flux should continue to offer a tidier version, which exploits mutation to update models & state objects.

The mutable state involves a new optimiser wrapper type, which is used for both explicit and implicit mode. Both modes use Optimisers.jl internally, so all the rule definitions in Flux.Optimise can be deleted. While many uses of the old train! will continue to work without modification, I think this is likely to be sufficiently breaking that it can only be in v0.14.

Example

A simple example that runs both modes, and works if you overload explicit_withgradient to use Diffractor instead of Zygote in that mode:
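(The example code itself is not captured in this excerpt; below is a hypothetical reconstruction of the kind of two-mode example described, with the exact train! argument order and names assumed rather than taken from the PR.)

```julia
using Flux

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
data  = [(rand(Float32, 2, 10), rand(Float32, 1, 10))]
opt   = Flux.Adam(0.01)          # with this PR: a FluxState wrapping Optimisers.Adam

# Implicit mode, as on Flux 0.13: gradients are taken w.r.t. Params.
loss_implicit(x, y) = Flux.Losses.mse(model(x), y)
Flux.train!(loss_implicit, Flux.params(model), data, opt)

# Explicit mode, added by this PR: the model itself is passed to train!, and the
# loss receives it as its first argument (signature assumed for illustration).
loss_explicit(m, x, y) = Flux.Losses.mse(m(x), y)
Flux.train!(loss_explicit, model, data, opt)
```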