Handle gradients and object trees together #24

Open · wants to merge 6 commits into master

Conversation

@DhairyaLGandhi (Member) commented Aug 12, 2021

Flux has model structs, and Zygote returns NamedTuple gradients for them. With FluxML/Functors.jl#1 we add the ability to handle gradients in Functors.jl: in other words, to do a "zipped" tree walk over two equivalent functors and apply a function to the paired leaves.

This PR extends that functionality to Optimisers.jl, allowing similar walks and applying `apply` (no pun intended).

Further, it moves towards #16 (comment) for a nicer API.
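
A minimal sketch of the kind of "zipped" walk this builds on (illustrative only; `Affine` and the learning rate are made up, and it assumes the multi-argument `fmap(f, x, ys...)` from FluxML/Functors.jl#1):

```julia
using Functors

# A toy model struct standing in for a Flux layer.
struct Affine{W,B}
    weight::W
    bias::B
end
Functors.@functor Affine

model = Affine([1.0 2.0; 3.0 4.0], [0.0, 0.0])
grad  = (weight = ones(2, 2), bias = ones(2))   # NamedTuple, as Zygote would return it

# Walk the model and its gradient in lockstep, updating each pair of leaves.
updated = fmap((x, dx) -> x .- 0.1 .* dx, model, grad)
```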

@DhairyaLGandhi (Member Author)

This is relevant for the DDP work too; it's one of the fixes that needs to go in to make it tick. NamedTuples are serialised more reliably than Grads, which cause all sorts of issues with reference types.
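
For illustration (not code from this PR), the serialisation point amounts to this: a NamedTuple gradient is plain data and round-trips through `Serialization`, whereas Zygote's Grads is keyed on the identity of the parameter arrays, which is exactly what breaks when it crosses process boundaries.

```julia
using Serialization

grad = (weight = ones(2, 2), bias = ones(2))   # NamedTuple gradient, plain data

buf = IOBuffer()
serialize(buf, grad)
seekstart(buf)
deserialize(buf) == grad                        # true: nothing depends on object identity
```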

CI will currently fail since we need to update the API in a lot of other places, including the rules themselves.

@darsnack (Member) commented Aug 12, 2021

I think the old interface matches FluxML/Functors.jl#1 better, but I'm fine with putting everything into a tuple. I will still suggest that the interface should be:

update(o, st, x, (dx1, dx2, ...))

Review thread on src/interface.jl (outdated, resolved)
@ToucheSir (Member)

AFAICT the only missing part is functionality to convert a tree of gradients from Zygote into a tree of tuples. What are your thoughts about implementing that?

@DhairyaLGandhi (Member Author) commented Dec 7, 2021

Which part are you referring to exactly? We directly use the NamedTuple to parse the gradients. Are you referring to partials and higher-order terms here? For those, I think we can wrap dx in a tuple, but those would need to be handled in apply, right? The rules for all optimisers may not be generic enough to receive tuples of gradients.

@ToucheSir (Member) commented Dec 8, 2021

Yes, higher-order terms in particular. The current interface on master puts the state first: update(o, state, x::T, x̄s...) and apply(o, state, x, dxs...). Those were changed to update(o, x::T, x̄, state) and apply(o, x, dx, state) in this PR. My understanding was that with this approach, x̄ and dx must be tuples for higher-order optimizers.

However, Zygote does not return gradients with each parameter pre-wrapped in a tuple. That leaves us with 2 choices using this state-last interface:

  1. fmap-wrap the gradients before passing them to update/apply (see the sketch below). This would require apply methods to use dx[1], dx[2], etc. instead of varargs.
  2. Only (un)wrap first-order gradients. This would allow most optimizer apply methods to stay as-is (i.e. treat dx as a parameter instead of a 1-tuple), but anything second-order or higher would be stuck with the aforementioned dx[...] pattern. It would also add another layer to the process: gradient(...) (unwrapped) -> wrap -> update(...) -> unwrap -> apply(...).

Alternatively, this PR could revert to a state-first interface. That would sidestep the need for wrapping and make full use of FluxML/Functors.jl#1's support for parallel object trees (Tuple{FunctorType{A}, FunctorType{B}} instead of FunctorType{Tuple{A,B}}).
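
A rough sketch of choice 1 (purely illustrative, not proposed code), assuming the plain `fmap` from Functors.jl:

```julia
using Functors

grads = (weight = ones(2, 2), bias = ones(2))   # first-order gradient from Zygote

# Wrap every leaf in a 1-tuple before calling update/apply; higher-order terms
# would be appended to the same tuple.
wrapped = fmap(dx -> (dx,), grads)
wrapped.weight[1]                                # apply would then index dx[1], dx[2], ...
```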

@DhairyaLGandhi (Member Author)

2 is what we are building towards, but with wrapping only in case higher-order terms are present. Apart from that I think we are good here.

@darsnack (Member)

What do we lose by sticking with the interface already on master? Why do we want to add additional steps to the design?

@mcabbott (Member)

I have not followed Optimisers.jl closely, but in trying to solve FluxML/Flux.jl#1826 I re-invented fmap(f, xs...), and then started to wonder whether it has the right behaviour. See examples here:

FluxML/Flux.jl#1826 (comment)

In particular, it won't accumulate gradients for parameters which are repeated, and it won't handle branches of the tree with no gradient. I don't know whether its behaviour on such things is what's intended, or a bug. There is about one test, which does not include such cases.

@DhairyaLGandhi (Member Author)

Happy to add tests, but that comment seems to want something entirely different?

@mcabbott (Member)

No. The most relevant line is:

fmap(println, (x=[1,2], y=(a=[3,4], b=[5,6])), (x=[0,2], y=nothing))

This will happen all the time with gradients, since Zygote collapses nothings.
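
To make concrete what one might expect here (a hand-rolled sketch over NamedTuples, not Functors.jl or Optimisers.jl code): a `nothing` gradient for a whole branch should leave that branch untouched rather than erroring.

```julia
gradstep(x, ::Nothing) = x                                    # collapsed branch: keep as-is
gradstep(x::AbstractArray, dx::AbstractArray) = x .- 0.1 .* dx
gradstep(x::NamedTuple, dx::NamedTuple) =
    NamedTuple{keys(x)}(map(gradstep, values(x), values(dx)))

gradstep((x=[1,2], y=(a=[3,4], b=[5,6])), (x=[0,2], y=nothing))
# (x = [1.0, 1.8], y = (a = [3, 4], b = [5, 6]))
```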

The second most relevant line is:

sh = [7,7]
sh2 = [0,7]
fmap(println, (x=sh, y=[3,4], z=sh), (x=sh2, y=sh2, z=[0,3]))

Since fmap explicitly has special handling for repeated leaves (a === b), it also needs to accumulate their gradients, which can independently be === each other without consequence.
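
And a sketch of the accumulation being asked for (again illustrative; `addgrad!` is made up): gradients for a parameter that appears more than once in the tree should be summed, keyed on identity just like fmap's cache.

```julia
# Accumulate gradients per parameter identity, as an fmap-style cache would.
acc = IdDict{Any,Any}()
addgrad!(acc, x, dx) = acc[x] = haskey(acc, x) ? acc[x] .+ dx : dx

sh = [7, 7]                            # the same array appears twice below
ps = (x = sh, y = [3, 4], z = sh)
gs = (x = [0, 2], y = [0, 7], z = [0, 3])

for k in keys(ps)
    addgrad!(acc, ps[k], gs[k])
end

acc[sh]                                # [0, 5]: both occurrences contributed
```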

@darsnack (Member)

That's because of the cache in fmap, right? fmap isn't even used here, so I don't think it is affected.

But that speaks to a wider problem about lack of consistency. Applying the optimizers should correspond to just fmap-ing the rule across the structure and gradients. Going for consistency from the outset means we will have less whack-a-mole with behavioral bugs.

@mcabbott (Member)

Yes, maybe fmap is off-topic. I have not tried hard to digest what this PR does.

But I don't see any attempt in this package to test tricky cases, and fmap is illustrative of whether they are being considered at all.

Can we use the move here to upgrade such things? Any open issue on Flux should be translated into a test here, of whatever tricky case is exposed, and a few more inspired by it. They can be marked broken for now.
