Handle gradients and object trees together #24
Conversation
This is relevant for the DDP work too. It's one of the fixes that needed to go in to make it tick. NamedTuples are more reliably serialised compared to … CI will currently fail since we need to update the API in a lot of other places, including the rules themselves.
I think the old interface matches FluxML/Functors.jl#1 better, but I'm fine putting everything into a tuple. I will still suggest that the interface should be: `update(o, st, x, (dx1, dx2, ...))`
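To make the shape of that suggestion concrete, here is a minimal sketch of a rule under that state-last, tuple-of-gradients signature; the `Descent` struct and `update` method below are illustrative assumptions, not the package's actual implementation:

```julia
# Illustrative only: a descent-like rule under the suggested signature
# update(o, st, x, (dx1, dx2, ...)), where the tuple carries one or more
# gradient terms for the parameter x.
struct Descent
    eta::Float64
end

function update(o::Descent, st, x, dxs::Tuple)
    dx = first(dxs)               # use the first-order term only
    return st, x .- o.eta .* dx   # return (new state, updated parameter)
end

st, x = update(Descent(0.1), nothing, [1.0, 2.0], ([0.5, 0.5],))
```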
AFAICT the only missing part is functionality to convert a tree of gradients from Zygote into a tree of tuples. What are your thoughts about implementing that?
Which part are you referring to exactly? We directly use the NamedTuple to parse the gradients. Are you referring to partials and higher-order terms here? For those, I think we can wrap …
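One way the conversion being discussed could work is a single `fmap` over the gradient tree that wraps every leaf; this is only a sketch, and `wrap_grads` is a hypothetical helper, assuming array leaves as Zygote returns them:

```julia
using Functors: fmap

# Hypothetical helper: wrap every gradient leaf in a 1-tuple so that
# higher-order terms could later ride along as extra tuple entries.
wrap_grads(grads) = fmap(dx -> (dx,), grads)

grads = (weight = [0.1, 0.2], bias = [0.3])
wrap_grads(grads)   # (weight = ([0.1, 0.2],), bias = ([0.3],))
```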
Yes, higher-order terms in particular. The current interface on master puts the … However, Zygote does not return gradients with each parameter pre-wrapped in a tuple. That leaves us with two choices using this state-last interface: …
Alternatively, this PR could revert to a state-first interface. That would sidestep the need for wrapping and make full use of FluxML/Functors.jl#1's support for parallel object trees (…).
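For reference, the parallel object tree support mentioned here amounts to walking two structurally matching trees at once; a minimal sketch, assuming an `fmap` that accepts a second tree as proposed in FluxML/Functors.jl#1:

```julia
using Functors: fmap

# Two parallel trees: parameters and their gradients, visited together.
params = (weight = [1.0, 2.0], bias = [3.0])
grads  = (weight = [0.1, 0.2], bias = [0.3])

fmap((p, dp) -> p .- 0.1 .* dp, params, grads)
# (weight = [0.99, 1.98], bias = [2.97])
```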
Option 2 is what we are building towards, but with wrapping only when higher-order terms are present. Apart from that I think we are good here.
What do we lose by sticking with the interface already on master? Why do we want to add additional steps to the design?
I have not followed Optimisers.jl closely, but in trying to solve FluxML/Flux.jl#1826 I re-invented … In particular, it won't accumulate gradients for parameters which are repeated, and it won't handle branches of the tree with no gradient. I don't know whether its behaviour on such things is what's intended, or a bug. There is about 1 test, which does not include such cases.
Happy to add tests, but that comment seems to be asking for something different entirely?
No. The most relevant line is:

```julia
fmap(println, (x=[1,2], y=(a=[3,4], b=[5,6])), (x=[0,2], y=nothing))
```

This will happen all the time with gradients, since Zygote collapses nothings. The second most relevant line is:

```julia
sh = [7,7]
sh2 = [0,7]
fmap(println, (x=sh, y=[3,4], z=sh), (x=sh2, y=sh2, z=[0,3]))
```

Since …
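Since the tied-array case keeps coming up, here is a hedged sketch of one way accumulation could work: gradients for leaves that are the same object (by identity) get summed before any update is applied. The helper name and approach are assumptions, not existing package code:

```julia
# Sketch: accumulate gradients for repeated (tied) arrays by object identity.
function accumulate_shared!(acc::IdDict, x, dx)
    if x isa AbstractArray && dx !== nothing
        acc[x] = haskey(acc, x) ? acc[x] .+ dx : dx
    end
    return acc
end

sh  = [7.0, 7.0]
acc = IdDict()
accumulate_shared!(acc, sh, [0.0, 7.0])   # contribution via field x
accumulate_shared!(acc, sh, [0.0, 3.0])   # contribution via field z
acc[sh]                                   # [0.0, 10.0]
```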
That's because of the … But that speaks to a wider problem about lack of consistency. Applying the optimizers should correspond to just …
Yes, maybe … But I don't see any attempt in this package to test tricky cases, and … Can we use the move here to upgrade such things? Any open issue on Flux should be translated into a test here, of whatever tricky case is exposed, and a few more inspired by it. They can be marked broken for now.
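As an example of the kind of test being suggested, here is a sketch covering the tied-array and missing-gradient cases, marked broken; `update` and `Descent` stand in for whatever interface this PR settles on:

```julia
using Test

@testset "tied arrays and missing gradients" begin
    sh    = [7.0, 7.0]
    model = (x = sh, y = [3.0, 4.0], z = sh)
    grads = (x = [0.0, 7.0], y = nothing, z = [0.0, 3.0])

    # Gradients for the tied array should accumulate, and the branch with
    # a `nothing` gradient should be left untouched.
    @test_broken begin
        _, m2 = update(Descent(1.0), nothing, model, grads)
        m2.x === m2.z && m2.y == model.y && m2.x == [7.0, -3.0]
    end
end
```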
closing as outdated |
Flux has model structs, and Zygote would return NamedTuple gradients for them. With FluxML/Functors.jl#1 we add the ability to handle gradients in Functors.jl - in other words, do a "zipped" tree walk on two equivalent functors and apply functions to them.
This extends that functionality to Optimisers.jl to allow for similar walks and `apply` (no pun intended). Further, it moves towards #16 (comment) for a nicer API.
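To illustrate the scenario described above, a small sketch, assuming a Functors version with the zipped walk from FluxML/Functors.jl#1 and using a made-up `Affine` struct:

```julia
using Functors, Zygote

struct Affine
    W
    b
end
Functors.@functor Affine

m = Affine([1.0 2.0; 3.0 4.0], [0.0, 0.0])

# Zygote returns a NamedTuple gradient mirroring the struct's fields.
g = gradient(m -> sum(m.W * [1.0, 1.0] .+ m.b), m)[1]   # (W = ..., b = ...)

# "Zipped" walk: model leaves and gradient leaves are visited together.
fmap((w, dw) -> w .- 0.1 .* dw, m, g)
```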