
functor RefValue #26

Merged · 3 commits · Oct 16, 2021
Conversation

@CarloLucibello
Member

related to FluxML/Flux.jl#1727

@DhairyaLGandhi
Member

We plan on using Ref to force leaf nodes for structs that would otherwise be functored.

@CarloLucibello
Member Author

That doesn't seem like something we plan; it seems like something you said once. It's also not needed, since we already have options for defining leaves. This PR is fixing some real issues instead.

@DhairyaLGandhi
Member

No, I don't think so. Our options for defining leaves are either them being arrays or them not having been functored. Filtered walks are not always easy to implement anyway, so having a convenient way that already works seems sensible. We have a fair few options to handle FluxML/Flux.jl#1727, as mentioned in the issue comments.

@darsnack
Member

If the argument is that you can make something a leaf by wrapping it in Ref, then I think that re-using Ref is API abuse. Ref can show up in arbitrary Julia structures, and implicitly assuming it has some fixed meaning is not flexible design. If we want a wrapper type that behaves like Ref but is understood as a leaf by Functors.jl, then create a Leaf type. It will have one unambiguous meaning.

Also, if a certain type that is not an array should always be considered a leaf, then the author of that type can always extend Functors.functor to make it a leaf anywhere it is used. Walks are not meant to be a mechanism for declaring leaves; they are a mechanism for dynamically choosing how to walk over all leaves (i.e. making it possible for external APIs like trainable to compose well with fmap).
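The Leaf wrapper suggested above could look roughly like this. This is only a sketch: `Leaf` and its field name `value` are invented here, not Functors.jl API. Since Functors.jl treats any type that has not been functored as a leaf, the wrapper needs no special support; fmap simply stops at it.

```julia
using Functors  # provides fmap

# Hypothetical `Leaf` wrapper (not part of Functors.jl): because it is
# never @functor'ed, Functors.jl treats it as a leaf, so fmap does not
# recurse into its contents.
struct Leaf{T}
    value::T
end

m = (a = [1.0, 2.0], b = Leaf([3.0, 4.0]))

# Double every array, but leave Leaf-wrapped values untouched.
m2 = fmap(x -> x isa Leaf ? x : 2 .* x, m)

m2.a        # [2.0, 4.0] -- mapped over
m2.b.value  # [3.0, 4.0] -- fmap stopped at the Leaf
```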

@darsnack
Member

Seems fine to me, would want to hear @ToucheSir's thoughts on this.

```diff
@@ -0,0 +1 @@
+@functor Base.RefValue
```
Member

If we're starting a new file, https://github.com/FluxML/Functors.jl/blob/master/src/functor.jl#L10-L12 should be in here as well

@ToucheSir
Member

ToucheSir commented Oct 12, 2021

Although I can't think of any deleterious effects of functoring RefValue, and @darsnack pretty much hit the nail on the head w.r.t. overloading the semantic meaning of a type whose presence we can't control, I'm afraid it won't contribute much to FluxML/Flux.jl#1727 either.

After looking into things more, RefValue would need a custom functor method that forwards to the inner value to address that issue. As you can imagine, that falls into the same trap of overloading/pseudo-piracy. Since we know Zygote only inserts these Refs around NamedTuples to stand in for mutable types, I think the least intrusive (and thus least likely to break) intervention would be some kind of custom walk that can pattern-match on and unwrap a RefValue{Any} holding a NamedTuple. It won't be perfect, but it'd have a much smaller blast radius than trying to handle Refs more generally.
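One possible shape for that targeted unwrapping, as a sketch (the helper name `maybe_unwrap` is invented here; a real integration would live inside a custom fmap walk rather than a free function): only a RefValue{Any} holding a NamedTuple, the pattern Zygote emits for mutable-struct gradients, gets unwrapped, and every other Ref passes through unchanged.

```julia
# Sketch of the targeted unwrap (illustrative, not Functors.jl API):
# Zygote wraps gradients of mutable structs as RefValue{Any} around a
# NamedTuple, so we pattern-match exactly that and nothing else.
maybe_unwrap(x) = x
maybe_unwrap(r::Base.RefValue{Any}) = r[] isa NamedTuple ? r[] : r

g = Base.RefValue{Any}((weight = [0.1, 0.2],))
maybe_unwrap(g)        # (weight = [0.1, 0.2],) -- unwrapped
maybe_unwrap(Ref(1))   # Base.RefValue{Int64}(1) -- a user Ref, left alone
```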

@DhairyaLGandhi
Member

Seems like having something closer to FluxML/Flux.jl#1727 (comment) would be the correct implementation. We can discuss the semantic function Ref plays for us, but it is not only mutables that can produce a Ref within Zygote.

@ToucheSir
Member

There's still the problem of distinguishing a Ref that was already in the model from one created by Zygote. Unfortunately, it's hard to say what the correct semantics should be since there doesn't seem to be AD support ATM:

```julia
julia> struct HasRef{T}           # note: RefValue needs qualification as Base.RefValue
           inner::Base.RefValue{T}
       end

julia> gradient(hr -> hr.inner[] * 5, HasRef(Ref(1)))
(nothing,)

julia> gradient(x -> x[] * 5, Ref(1))
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::NamedTuple{(:x,), Tuple{Float64}})
Closest candidates are:
  (::ChainRulesCore.ProjectTo{T, D} where D<:NamedTuple)(::ChainRulesCore.AbstractZero) where T at /home/brianc/.julia/packages/ChainRulesCore/8vlYQ/src/projection.jl:120
  (::ChainRulesCore.ProjectTo{var"#s15", D} where {var"#s15"<:Real, D<:NamedTuple})(::Complex) at /home/brianc/.julia/packages/ChainRulesCore/8vlYQ/src/projection.jl:179
  (::ChainRulesCore.ProjectTo{var"#s15", D} where {var"#s15"<:Number, D<:NamedTuple})(::ChainRulesCore.Tangent{var"#s14", T} where {var"#s14"<:Number, T}) at /home/brianc/.julia/packages/ChainRulesCore/8vlYQ/src/projection.jl:185
  ...
Stacktrace:
 [1] (::ChainRulesCore.ProjectTo{Ref, NamedTuple{(:type, :x), Tuple{DataType, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}})(dx::Base.RefValue{Any})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/8vlYQ/src/projection.jl:275
 [2] _project
   @ ~/.julia/packages/Zygote/EPhp6/src/compiler/chainrules.jl:140 [inlined]
 [3] map(f::typeof(Zygote._project), t::Tuple{Base.RefValue{Int64}}, s::Tuple{Base.RefValue{Any}})
   @ Base ./tuple.jl:232
 [4] gradient(f::Function, args::Base.RefValue{Int64})
   @ Zygote ~/.julia/packages/Zygote/EPhp6/src/compiler/interface.jl:77
 [5] top-level scope
   @ REPL[8]:1
```

@CarloLucibello
Member Author

@ToucheSir this PR does (partially) solve FluxML/Flux.jl#1727

```julia
julia> model = BatchNorm(2)
BatchNorm(2)        # 4 parameters, plus 4 non-trainable

julia> p, re = Flux.destructure(model)
(Float32[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0], Flux.var"#60#62"{BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}(BatchNorm(2)))

julia> x = rand(Float32, 2, 1)
2×1 Matrix{Float32}:
 0.84261227
 0.6932374

julia> a, back = Flux.pullback(x, p) do _x, _p
           vec(re(_p)(_x))
       end
(Float32[0.0, 0.0], Zygote.var"#50#51"{typeof(∂(#5))}(∂(#5)))

julia> back(a)
┌ Warning: Expected 8 params, got 0
└ @ Flux ~/.julia/packages/Flux/ZnXxS/src/utils.jl:623
(Float32[0.0; 0.0], Any[])

julia> Flux.@functor Base.RefValue

julia> back(a)
┌ Warning: Expected 8 params, got 4
└ @ Flux ~/.julia/packages/Flux/ZnXxS/src/utils.jl:623
(Float32[0.0; 0.0], Float32[0.0, 0.0, 0.0, 0.0])
```

The other part is addressed by having destructure only return trainable params (FluxML/Flux.jl#1742)

@ToucheSir
Member

You're right, I'd forgotten that the whole broader conversation about Refs happened later in that thread. Still, I think the best compromise would be to only functor the Ref types we expect to receive from Zygote (side note: a list of these would be very much appreciated) and no more. Forwarding the functor call like I described in my last comment would also allow us to remove some similar special-cased code in Optimisers.jl.

@CarloLucibello
Member Author

Zygote returns RefValue{Any} in the cases I tested. It seems a bit odd to functorize only RefValue{Any}, and I'm not sure that covers everything Zygote can return. Also, as expressed in #49, I think we should start being bolder about what we functorize, and this is a very tiny step in that direction.
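For reference, the one-line change in this PR, `@functor Base.RefValue`, exposes a Ref's single field `x` as a functor child and allows the Ref to be rebuilt from it. A hand-written sketch of the behavior (the helper names `ref_children` and `ref_rebuild` are invented for illustration; the macro actually defines a `Functors.functor` method instead):

```julia
# Hand-written sketch of what `@functor Base.RefValue` provides
# (helper names are invented; the macro defines Functors.functor instead):
ref_children(r::Base.RefValue) = (x = r[],)   # expose the wrapped value as a child
ref_rebuild(nt) = Base.RefValue(nt.x)         # reconstruct the Ref from children

r = Ref([1.0, 2.0])
ref_children(r)                     # (x = [1.0, 2.0],)
ref_rebuild((x = [3.0, 4.0],))[]    # [3.0, 4.0]
```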

@CarloLucibello CarloLucibello merged commit 82766af into master Oct 16, 2021
@ToucheSir
Member

I don't necessarily disagree, but doing this has just made work on Optimisers.jl harder. It seems like Refs around mutable structs may have been addressed by FluxML/Zygote.jl#1102, however, so we'll have to give it all a test.

@CarloLucibello CarloLucibello deleted the cl/refvalue branch April 2, 2024 12:39