functor RefValue #26
We plan on using Ref to force leaf nodes for erstwhile functored structs.
That doesn't seem like something we planned; it seems like something you said once, and it's not needed at all since we already have options for defining leaves. This PR fixes some real issues instead.
No, I don't think so. Our options for defining leaves are either them being arrays or them not having been
If the argument is that you can make something a leaf by wrapping it in Also, if a certain type that is not an array should always be considered a leaf, then the author of that type can always extend
Seems fine to me, would want to hear @ToucheSir's thoughts on this.
@@ -0,0 +1 @@
@functor Base.RefValue
If we're starting a new file, https://github.com/FluxML/Functors.jl/blob/master/src/functor.jl#L10-L12 should be in here as well
Although I can't think of any deleterious effects of functoring After looking into things more,
Seems like having something closer to FluxML/Flux.jl#1727 (comment) would be the correct implementation. We can discuss the semantic function Ref plays for us, but it is not only mutables that can produce a Ref within Zygote.
There's still the problem of distinguishing a
julia> struct HasRef{T}
           inner::RefValue{T}
       end
julia> gradient(hr -> hr.inner[] * 5, HasRef(Ref(1)))
(nothing,)
julia> gradient(x -> x[] * 5, Ref(1))
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::NamedTuple{(:x,), Tuple{Float64}})
Closest candidates are:
(::ChainRulesCore.ProjectTo{T, D} where D<:NamedTuple)(::ChainRulesCore.AbstractZero) where T at /home/brianc/.julia/packages/ChainRulesCore/8vlYQ/src/projection.jl:120
(::ChainRulesCore.ProjectTo{var"#s15", D} where {var"#s15"<:Real, D<:NamedTuple})(::Complex) at /home/brianc/.julia/packages/ChainRulesCore/8vlYQ/src/projection.jl:179
(::ChainRulesCore.ProjectTo{var"#s15", D} where {var"#s15"<:Number, D<:NamedTuple})(::ChainRulesCore.Tangent{var"#s14", T} where {var"#s14"<:Number, T}) at /home/brianc/.julia/packages/ChainRulesCore/8vlYQ/src/projection.jl:185
...
Stacktrace:
[1] (::ChainRulesCore.ProjectTo{Ref, NamedTuple{(:type, :x), Tuple{DataType, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}})(dx::Base.RefValue{Any})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/8vlYQ/src/projection.jl:275
[2] _project
@ ~/.julia/packages/Zygote/EPhp6/src/compiler/chainrules.jl:140 [inlined]
[3] map(f::typeof(Zygote._project), t::Tuple{Base.RefValue{Int64}}, s::Tuple{Base.RefValue{Any}})
@ Base ./tuple.jl:232
[4] gradient(f::Function, args::Base.RefValue{Int64})
@ Zygote ~/.julia/packages/Zygote/EPhp6/src/compiler/interface.jl:77
[5] top-level scope
@ REPL[8]:1
@ToucheSir this PR does (partially) solve FluxML/Flux.jl#1727
julia> model = BatchNorm(2)
BatchNorm(2) # 4 parameters, plus 4 non-trainable
julia> p, re = Flux.destructure(model)
(Float32[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0], Flux.var"#60#62"{BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}(BatchNorm(2)))
julia> x = rand(Float32, 2, 1)
2×1 Matrix{Float32}:
0.84261227
0.6932374
julia> a, back = Flux.pullback(x, p) do _x, _p
           vec(re(_p)(_x))
       end
(Float32[0.0, 0.0], Zygote.var"#50#51"{typeof(∂(#5))}(∂(#5)))
julia> back(a)
┌ Warning: Expected 8 params, got 0
└ @ Flux ~/.julia/packages/Flux/ZnXxS/src/utils.jl:623
(Float32[0.0; 0.0], Any[])
julia> Flux.@functor Base.RefValue
julia> back(a)
┌ Warning: Expected 8 params, got 4
└ @ Flux ~/.julia/packages/Flux/ZnXxS/src/utils.jl:623
(Float32[0.0; 0.0], Float32[0.0, 0.0, 0.0, 0.0])
The other part is addressed by having
You're right, I'd forgotten that the whole broader conversation about Refs happened later in that thread. Still, I think the best compromise would be to only functor the Ref types we expect to receive from Zygote (side note: a list of these would be very much appreciated) and no more. Forwarding the functor call like I described in my last comment would also allow us to remove some similar special-cased code in Optimisers.jl.
Zygote returns
I don't necessarily disagree, but doing this just made work for Optimisers.jl harder. It seems like mutable struct Refs may have been addressed by FluxML/Zygote.jl#1102, however, so we'll have to give it all a test.
related to FluxML/Flux.jl#1727