Have destructure return only trainable params #1742

Closed
Wants to merge 5 commits from cl/params2

Conversation

@CarloLucibello (Member) commented Oct 11, 2021

This PR:

- has destructure return only trainable params;
- functorizes RefValue (this part should go into Functors.jl).

Fix #1733, fix #1727, fix #1732.

Also adds the tests from #1614.
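To make the headline change concrete, here is a rough sketch of the intended behaviour (not code from the PR; it assumes a standard BatchNorm, whose running statistics are non-trainable):

```julia
using Flux

m = Chain(Dense(2, 3), BatchNorm(3))

θ, re = Flux.destructure(m)

# With this PR only trainable arrays are flattened:
# Dense weight (6) + bias (3), BatchNorm β (3) + γ (3)  =>  15 entries.
# The non-trainable running statistics μ and σ² are left out.
length(θ)   # 15 here, rather than 21 if non-trainable arrays were included

m2 = re(θ)  # rebuilds a model with the same structure from the flat vector
```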

@CarloLucibello marked this pull request as draft October 11, 2021 22:05
@CarloLucibello changed the title from "destructure returns only trainable params" to "Have destructure return only trainable params" Oct 11, 2021
@DhairyaLGandhi (Member)

This seems like quite a large amount of change to core Flux assumptions, compared to the more specific fix needed for #1733.

@CarloLucibello marked this pull request as ready for review October 12, 2021 08:14
@darsnack (Member)

Will take a more detailed pass later today. One early suggestion: doesn't it make more sense to move destructure and friends to functor.jl instead of the other way round? Also, let's just go ahead and move the RefValue fix to Functors.jl now.

@CarloLucibello (Member, Author) commented Oct 16, 2021

Moved destructure to functor.jl, and created test/functor.jl, where I moved some of the tests and added new ones (I left some comments above to mark them).

Unfortunately, I didn't manage to make params use the same infrastructure as destructure without causing regressions in some cases that are caught by the newly added tests. I'll leave that for future work.

@ChrisRackauckas, is there any specific package whose tests I can run to see whether the changes to destructure break something downstream?

@ChrisRackauckas (Member)

The DiffEqFlux and NeuralPDE tests should be the only two using this.

@ChrisRackauckas (Member)

Though we're in chaos because Zygote updates seem to have broken a lot, again. DiffEqFlux tests fail because Zygote returns zero for the gradients on FFJORD (SciML/DiffEqFlux.jl#635), and NeuralPDE fails because of a tuple multiplication in an update of ChainRules (SciML/NeuralPDE.jl#412).

😱 😱 😱 😱 😱 (Please help me) 😱 😱 😱 😱 😱

@DhairyaLGandhi (Member)

I would go for a more specific fix for the SciML failures. We have seen the *(::Tuple) failures in a number of places. @ChrisRackauckas, per #1727 (comment), could you try with this definition of destructure? The failures related to ChainRules would need to be fixed either there or by using a different adjoint.

@CarloLucibello (Member, Author)

> I would go for a more specific fix for the SciML failures. We have seen the *(::Tuple) failures in a number of places. @ChrisRackauckas, per #1727 (comment), could you try with this definition of destructure? The failures related to ChainRules would need to be fixed either there or by using a different adjoint.

This comment is not very clear, but maybe you are confusing things. The RefValue problem related to #1727 (comment) has already been fixed in Functors and Zygote.
Here we fix the remaining problem by having destructure return only trainable params (as the title says).
The ChainRules/Zygote changes impacting NeuralPDE and DiffEqFlux are not related to Flux or to this PR, and will undoubtedly be fixed elsewhere.

@CarloLucibello (Member, Author)

Both DiffEqFlux's and NeuralPDE's tests error for me on Flux#master as well as on this branch.
Let's wait for things in those two repos to be fixed before testing again and merging this.

@CarloLucibello added this to the v0.13 milestone Dec 14, 2021
@mcabbott mentioned this pull request Jan 12, 2022
@mcabbott (Member) left a comment

Slightly random comments, mostly on things not new to this PR:

src/functor.jl (outdated)
"""
destructure(m)
Flatten a model's parameters into a single weight vector.
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Member:

This wants to be a jldoctest, and should include something like BatchNorm, to illustrate which parameter count becomes the length of θ.

It should also say that x isa AbstractArray{<:Number} and unique objectid(x) are the criteria for inclusion.
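A small illustration of that inclusion rule (the model here is hypothetical, and the count assumes the deduplication-by-objectid behaviour described above):

```julia
using Flux

shared = ones(Float32, 3)
m = (a = shared, b = shared, c = Float32[2, 3])   # hypothetical "model"

θ, _ = Flux.destructure(m)

# `shared` is one object reached twice, so under the rule above it contributes
# its 3 entries once: length(θ) would be 3 + 2 = 5, not 3 + 3 + 2 = 8.
length(θ)
```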

Member:

A related question is: Is this the right test? Should these be independent?

julia> x = [1,2,3];

julia> objectid(x), objectid(x'), objectid(transpose(x))
(0x3aed6805416fa931, 0xab1cb79f2fe03e53, 0x731d5dafc3a51f2b)

You could for instance keep calling parent until it stops changing, and call that the parameter. Or maybe such types should be unwrapped by Functors?
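A minimal sketch of that unwrapping idea (the helper name is made up, not part of Flux or Functors):

```julia
# Hypothetical helper: peel off array wrappers (Adjoint, Transpose, SubArray, ...)
# until parent stops changing; parent(::Array) returns the array itself.
unwrap_parent(x::AbstractArray) = parent(x) === x ? x : unwrap_parent(parent(x))

x = [1, 2, 3]
unwrap_parent(x') === x                        # true: x and x' share storage
objectid(unwrap_parent(x')) == objectid(x)     # true, unlike objectid(x') above
```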

Member (Author):

As you know, the situation with wrapped shared parameters is pretty complex; I'm not going to address those issues here.

function destructure(m)
  xs = Zygote.Buffer([])               # AD-friendly, mutable accumulator
  collect_params!(xs, m)               # walk the model, pushing parameter arrays
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end
Member:

Why does this copy? (And splat?)

And, how easy would it be to avoid Buffer somehow, to make this ready for not using Zygote?

Member (Author):

With

function destructure(m)
  xs = AbstractArray[]
  collect_params!(xs, m)
  return vcat(vec.(xs)...), p -> _restructure(m, p)
end

Flux's tests still pass. I still have to test the interaction with DiffEqFlux and NeuralPDE; let's see.

Member (Author):

DiffEqFlux tests pass with both Buffer() and AbstractArray[].

@ChrisRackauckas, do you see any particular reason to keep Buffer?

Member:

I don't see a reason to use Buffer here.

Member:

I guess that's for some potential higher order AD issue?

Member:

Or just first-order AD. AFAICT Flux's current test suite never tests the gradient of destructure (only restructuring) 🙈...

Member:

🙈 is the Flux motto, really.
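For reference, a minimal sketch of such a test, i.e. differentiating through the flattening step itself (field names assume Flux ≥ 0.12's Dense; whether it passes depends on the destructure implementation):

```julia
using Flux, Zygote, Test

m = Dense(2, 3)

# Gradient of a scalar function of the flattened parameters, w.r.t. the model,
# i.e. this differentiates through destructure itself rather than restructure.
g = Zygote.gradient(m -> sum(abs2, Flux.destructure(m)[1]), m)[1]

# d/dW sum(abs2, θ) = 2W for every trainable array that ends up in θ.
@test g.weight ≈ 2 .* m.weight
@test g.bias ≈ 2 .* m.bias
```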

# plus 2 non-trainable, 10 parameters, summarysize 836 bytes.
```

Only numerical arrays are collected by `destructe`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers)
Member:

Suggested change:
- Only numerical arrays are collected by `destructe`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers)
+ Only numerical arrays are collected by `destructure`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers),

@adjoint function _restructure(m, xs)
  m̄, numel = _restructure(m, xs), length(xs)
  function _restructure_pullback(dm)
    xs′ = destructure(dm)[1]
Member:

This gradient can easily be wrong: it looks for duplicates in the gradient, which can come from e.g. adding two parameters x + y, but it is completely unaware of duplicates in the original.

Demonstration:
#1826 (comment)

So at the very least, we must (1) disable the removal of duplicates from the destructure used here, and (2) throw an error if you try to use this adjoint when the original model had any duplicates.

Or, failing that, we should remove it from v0.13 until someone can write a version which actually works.

@CarloLucibello (Member, Author)

@mcabbott, should I close this? Maybe you have a better solution at this point.

@mcabbott (Member) commented Feb 5, 2022

I have written many things, e.g. FluxML/Functors.jl#31, but they aren't going to be ready tomorrow. Maybe v0.13 should change the documented behaviour to trainable, without aspiring to fix all the bugs?

How difficult would it be to at least throw an error if there are repeated parameters?

Edit: maybe closer now, FluxML/Optimisers.jl#54?
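One possible shape of such a check, as a sketch only (the function name is made up, and it assumes Functors.children is available for the tree walk):

```julia
using Functors

# Walk the model tree and error out if the same numeric array is reachable twice,
# which is exactly the case the current _restructure adjoint mishandles.
function assert_no_repeated_params(m, seen = IdDict{Any,Nothing}())
    if m isa AbstractArray{<:Number}
        haskey(seen, m) && error("destructure: model contains repeated (shared) parameter arrays")
        seen[m] = nothing
    else
        foreach(c -> assert_no_repeated_params(c, seen), Functors.children(m))
    end
    return nothing
end
```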

@mcabbott (Member) commented Mar 5, 2022

This is now the last issue on the v0.13 milestone.

I think what we should do is simply delete destructure completely, and call the one from Optimisers. Apart from fixing bugs it ought to be a drop-in replacement for the one here -- that is, it also only keeps trainable parameters.
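If that route is taken, the Flux side could shrink to roughly this (a sketch, assuming Optimisers.jl's destructure keeps the same flat-vector-plus-rebuilder convention):

```julia
import Optimisers

# Forward Flux's destructure to the Optimisers.jl implementation, which likewise
# flattens only the trainable parameters and returns (flat_vector, rebuild).
destructure(m) = Optimisers.destructure(m)
```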

@darsnack (Member) commented Mar 5, 2022

Agreed, we should try that and run the downstream tests. If everything passes, then there is no reason not to.

@CarloLucibello (Member, Author)

Agreed. I'll leave this PR open as a reminder for the milestone, but I think it is better to start clean in a new PR and cherry-pick some of the tests from here.

@CarloLucibello deleted the cl/params2 branch April 7, 2022 07:01