-
-
Notifications
You must be signed in to change notification settings - Fork 611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Have destructure return only trainable params #1742
Conversation
This seems quite a large amount of changes for most of Flux assumptions compared to a more specific fix needed for #1733 |
8fa7352
to
c8a1d1e
Compare
Will take a more detailed pass later today. One early suggestion: doesn't it make more sense to move |
c8a1d1e
to
a649870
Compare
Moved Unfortunately, I didn't manage to use in @ChrisRackauckas is there any specific package whose test I can run to see if the changes to |
The DiffEqFlux and NeuralPDE tests should be the only two using this. |
Though we're in chaos because Zygote updates seem to have broken a lot, again. DiffEqFlux tests fail because Zygote returns zero for the gradients on FFJORD (SciML/DiffEqFlux.jl#635), and NeuralPDE fails because of a tuple multiplication in an update of ChainRules (SciML/NeuralPDE.jl#412). 😱 😱 😱 😱 😱 (Please help me) 😱 😱 😱 😱 😱 |
I would go for a more specific fix for the sciml failures. We have seen the |
This comment is not very clear, but maybe you are confusing things. The RefValue problem related to #1727 (comment) has already been fixed in Functors and Zygote. |
Both DiffEqFlux and NeuralPDE's tests error for me on Flux#master and on this branch as well. |
0400686
to
4dc70b5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slightly random comments, mostly on things not new to this PR:
src/functor.jl
Outdated
""" | ||
destructure(m) | ||
Flatten a model's parameters into a single weight vector. | ||
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This wants to be a jldoctest, and should include something like BatchNorm, to illustrate which parameter count becomes the length of θ.
It should also say that x isa AbstractArray{<:Number}
and unique objectid(x)
are the criteria for inclusion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A related question is: Is this the right test? Should these be independent?
julia> x = [1,2,3];
julia> objectid(x), objectid(x'), objectid(transpose(x))
(0x3aed6805416fa931, 0xab1cb79f2fe03e53, 0x731d5dafc3a51f2b)
You could for instance keep calling parent
until it stops, and that the parameter. Or maybe such types should be unwrapped by Functors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As you know the situation with wrapped shared parameters is pretty complex, I'm not going to address those issues here
function destructure(m) | ||
xs = Zygote.Buffer([]) | ||
collect_params!(xs, m) | ||
return vcat(vec.(copy(xs))...), p -> _restructure(m, p) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does this copy? (And splat?)
And, how easy would it be to avoid Buffer somehow, to make this ready for not using Zygote?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With
function destructure(m)
xs = AbstractArray[]
collect_params!(xs, m)
return vcat(vec.(xs)...), p -> _restructure(m, p)
end
Flux's tests still pass. I still have to test the interaction with DiffEqFlux, NeuralPDE, let's see
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DiffEqFlux tests pass with both Buffer()
and AbstractArray[]
.
@ChrisRackauckas you see any particular reason to keep Buffer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see a reason to use Buffer here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess that's for some potential higher order AD issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or just first-order AD. AFAICT Flux's current test suite never tests the gradient of destructure
(only restructuring) 🙈...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🙈 is the Flux motto, really.
# plus 2 non-trainable, 10 parameters, summarysize 836 bytes. | ||
``` | ||
|
||
Only numerical arrays are collected by `destructe`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only numerical arrays are collected by `destructe`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers) | |
Only numerical arrays are collected by `destructure`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers), |
@adjoint function _restructure(m, xs) | ||
m̄, numel = _restructure(m, xs), length(xs) | ||
function _restructure_pullback(dm) | ||
xs′ = destructure(dm)[1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This gradient can easily be wrong, it looks for duplicates in the gradient which can come from e.g. adding two parameters x + y
. It is completely unaware of duplicates in the original.
Demonstration:
#1826 (comment)
So at very least, we must (1) disable the removal of duplicates from destructure used here, and (2) throw an error if you try to use this adjoint when the original model had any gradients.
Or, failing that, we should remove it from v0.13 until someone can write a version which actually works.
@mcabbott should I close this? Maybe you have some better solution at this point |
I have written many things, e.g. FluxML/Functors.jl#31, but they aren't going to be ready tomorrow. Maybe v0.13 should change the documented behaviour to trainable, without aspiring to fix all the bugs? How difficult would it be to at least throw an error if there are repeated paramters? Edit -- maybe closer now, FluxML/Optimisers.jl#54 ? |
This is now the last issue on the v0.13 milestone. I think what we should do is simply delete |
Agreed, we should try that and run the downstream tests. If everything passes, then there is no reason not to. |
agreed. I'll leave this PR open as a reminder for the milestone, but I think it is better to start clean in a new PR and cherry-pick from here some of the tests |
Have destructure return only trainable params
+
functorize RefValue (this part should go into Functors.jl)
Fix #1733, fix #1727, fix #1732
Adding also tests from #1614