Flux.destructure's restructure fails in the gradient if loss does not use all parameters #1826

Closed
newalexander opened this issue Jan 10, 2022 · 8 comments · Fixed by #1901

@newalexander

It doesn't seem possible to use the explicit-parameters, destructure-based formulation of Flux with NN architectures that produce multiple outputs from a single input. Is this by design? Ideally, the entries of the gradient corresponding to weights that don't appear in the loss would be 0 or missing, but maybe that introduces implementation challenges.

using Flux

struct Split{T}  # taken from: https://fluxml.ai/Flux.jl/stable/models/advanced/#Multiple-outputs:-a-custom-Split-layer
    paths::T
end
Split(paths...) = Split(paths)
Flux.@functor Split
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)

n_input, n_batch, n_shared = 5, 13, 11
n_outputs = [3, 7]

data = rand(Float32, n_input, n_batch)
model = Chain(
    Dense(n_input, n_shared),
    Split(Dense(n_shared, n_outputs[1]), Dense(n_shared, n_outputs[2]))
)

ps, re = Flux.destructure(model)
loss(x, idx, ps) = sum(abs2, re(ps)(x)[idx])  # loss wrt `idx`th output term
gradient(params -> loss(data, 1, params), ps)  # fails with ERROR: LoadError: DimensionMismatch("variable with size(x) == (186,) cannot have a gradient with size(dx) == (102,)")

The full error message:

┌ Warning: Expected 186 params, got 102
└ @ Flux ~/.julia/packages/Flux/BPPNj/src/utils.jl:647
ERROR: LoadError: DimensionMismatch("variable with size(x) == (186,) cannot have a gradient with size(dx) == (102,)")
Stacktrace:
 [1] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Vector{Float32})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/IFusD/src/projection.jl:226
 [2] _project
   @ ~/.julia/packages/Zygote/umM0L/src/compiler/chainrules.jl:182 [inlined]
 [3] map(f::typeof(Zygote._project), t::Tuple{Vector{Float32}}, s::Tuple{Vector{Float32}})
   @ Base ./tuple.jl:246
 [4] gradient(f::Function, args::Vector{Float32})
   @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:77
 [5] top-level scope
   @ [script location]scripts/mwe_1.jl:25
in expression starting at [script location]scripts/mwe_1.jl:25
The terminal process "julia '--color=yes', '--project=~/.julia/environments/v1.7', '[script location]scripts/mwe_1.jl'" terminated with exit code: 1.
@ToucheSir
Member

> Ideally, the entries in the gradient corresponding to weights that don't get used in the loss could be 0 or missing, but maybe that introduces implementation challenges.

They should be 0; Zygote's parameter tracking isn't that fine-grained. Usually you'd expect a mismatch error like this when an entire array parameter isn't used at all, but AFAICT that is not the case here. To troubleshoot, you'll want to look at what is being passed as the gradient value into _restructure's adjoint and check whether anything is short or missing.
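
For example (a sketch, reusing model and data from the MWE above), the structural gradient that the adjoint has to flatten can be inspected by differentiating the model directly:

# Structural gradient of the same loss, taken with respect to the model itself;
# branches the loss never touches come back as nothing.
dm = gradient(m -> sum(abs2, m(data)[1]), model)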

@newalexander
Author

newalexander commented Jan 10, 2022

@ToucheSir Thanks for the pointer. From what I can tell, the _restructure_pullback closure in the definition of the adjoint to _restructure here is not receiving a gradient for the full model.

In the slightly simpler MWE below, when _restructure_pullback gets called, dm is

(layers = ((paths = ((weight = Float32[-11.297353 -10.366326 -10.771574 -10.596351 -9.816019; 11.591414 13.127091 9.820183 12.75714 8.937439; 7.2377276 10.328969 6.7674723 8.917322 7.2669835], bias = Float32[-17.877731, 19.011127, 13.564459], σ = nothing), nothing),),),)

so it only seems to contain the gradient for the Dense(n_input, n_outputs[1]) path, rather than for the full model.

I don't have a strong grasp on Zygote's fundamentals beyond the basics, so it's not super obvious to me what might be driving this issue.

using Flux

struct Split{T}  # taken from: https://fluxml.ai/Flux.jl/stable/models/advanced/#Multiple-outputs:-a-custom-Split-layer
    paths::T
end
Split(paths...) = Split(paths)
Flux.@functor Split
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)

n_input, n_batch, n_shared = 5, 13, 11
n_outputs = [3, 7]

data = rand(Float32, n_input, n_batch)
model = Chain(
	Split(Dense(n_input, n_outputs[1]), Dense(n_input, n_outputs[2]))
)

ps, re = Flux.destructure(model);

loss(x, idx, ps) = sum(abs2, re(ps)(x)[idx])
gradient(params -> loss(data, 1, params), ps) 

@ToucheSir
Member

Ah, sorry, I missed that the tuple of arrays from Split was being indexed directly. As you saw, Zygote substitutes nothing for the rest of the params. This is probably possible with the current destructure interface in Flux, but it would require a decent-scale rewrite of the implementation. I wish I had a better answer than "this needs someone to do the (tricky) legwork"...
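
A minimal illustration of that substitution, outside of Flux entirely (only the indexed entry of the tuple gets a gradient):

using Zygote

# The second tuple entry is never used, so its gradient is nothing.
gradient(t -> sum(t[1]), ([1.0, 2.0], [3.0, 4.0]))  # (([1.0, 1.0], nothing),)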

@newalexander
Author

FWIW, it does seem viable to calculate gradients with the implicit-parameters style and then (un)flatten them post hoc (although I believe the solution linked below requires some modification to account for the fact that some gradients might be nothing); a rough sketch follows the link.
https://discourse.julialang.org/t/manually-updating-the-parameters-of-a-neural-network-in-flux/25327/3
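
A rough sketch of that workaround, reusing model and data from the MWE above (flatten_grads is an ad-hoc helper, not part of Flux; it relies on params and destructure walking the model in the same order):

θ = Flux.params(model)                               # implicit-style parameter collection
gs = gradient(() -> sum(abs2, model(data)[1]), θ)    # Grads; unused params are absent or nothing

# Flatten the per-array gradients into one vector, substituting zeros for any
# parameter the loss never used.
flatten_grads(gs, θ) = reduce(vcat,
    [haskey(gs, p) && gs[p] !== nothing ? vec(gs[p]) : vec(zero(p)) for p in θ])

g = flatten_grads(gs, θ)   # same length as Flux.destructure(model)[1]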

@ToucheSir
Member

Yes, using implicit parameters avoids this problem altogether. I don't believe it even needs any additional modification, just a check that each param is in the Grads dict before calling update! (see the sketch below). So if that works for you, great! I had assumed a flattened parameter vector was a hard requirement, given the issue title.
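
For instance, something along these lines (a sketch; opt is any Flux optimiser, θ and gs as in the implicit-parameters example above):

for p in θ
    haskey(gs, p) || continue           # parameter never touched by the loss
    gs[p] === nothing && continue
    Flux.Optimise.update!(opt, p, gs[p])
end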

@newalexander
Author

We're working through our requirements at the moment, but being able to use only flattened parameters would make some of our designs a bit easier, hence the motivation for raising the issue.

@mcabbott mcabbott changed the title Flux.destructure doesn't work with Split Flux.destructure's restructure fails in the gradient if loss does not use all parameters Jan 16, 2022
@mcabbott
Member

mcabbott commented Jan 16, 2022

I think the MWE is something like this:

julia> v, re = Flux.destructure((x=[1,2], y=[3,4,5]))
([1, 2, 3, 4, 5], Flux.var"#128#129"...

julia> gradient(zero(v)) do w
         m = re(w)
         5 * sum(m.x) + 7 * sum(m[2])  # uses both x and y
       end
([5.0, 5.0, 7.0, 7.0, 7.0],)

julia> gradient(w -> sum(re(w).x), zero(v))  # uses only x
┌ Warning: Expected 5 params, got 2
└ @ Flux ~/.julia/packages/Flux/HNHmp/src/functor.jl:61
ERROR: DimensionMismatch("variable with size(x) == (5,) cannot have a gradient with size(dx) == (2,)")

And the problem is that _restructure_pullback here:

xs′ = destructure(dm)[1]

blindly calls destructure again on dm, the "structural" gradient it gets, which in this case looks like this:

julia> dm = gradient(m -> sum(m.x), re(zero(v)))
((x = [1.0, 1.0], y = nothing),)

julia> Flux.destructure(dm)
([1.0, 1.0], ...)

which gives a vector of the wrong length, since the structure is different. Maybe Functors isn't equipped to handle that; you'd need to write a different recursion?
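
For the NamedTuple MWE above, the kind of parallel recursion needed might look like this sketch (not Flux's implementation): arrays contribute their gradient, nothing branches contribute zeros of the matching size, so the flattened result always has the full length.

flat(x::AbstractArray, dx::AbstractArray) = vec(dx)
flat(x::AbstractArray, ::Nothing) = vec(zero(x))   # unused leaf becomes zeros
flat(x::NamedTuple, dx) =
    reduce(vcat, [flat(x[k], dx === nothing ? nothing : dx[k]) for k in keys(x)])

flat((x = [1, 2], y = [3, 4, 5]), (x = [1.0, 1.0], y = nothing))  # [1.0, 1.0, 0.0, 0.0, 0.0]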

@mcabbott
Member

mcabbott commented Jan 16, 2022

Note that this also interacts badly with shared parameters: instead of an error, it can silently give you the wrong answer.

julia> sh = [7,7];

julia> v, re = Flux.destructure((x=sh, y=[3,4], z=sh))  # shared array in the model
([7, 7, 3, 4], Base.Fix1(Flux._restructure, ...))

julia> re([1,10,100,1000])
(x = [1, 10], y = [100, 1000], z = [1, 10])

julia> gradient(zero(v)) do w
         m = re(w)
         3 * sum(m.x) + 13 * sum(m.z)  # no dependence on y, but two distinct gradient arrays
       end
([3.0, 3.0, 13.0, 13.0],)  # wrong answer, should be [16, 16, 0, 0]?

julia> gradient(zero(v)) do w
         m = re(w)
         4(sum(m.x) + sum(m.z))  # now two gradients are ===, so it eliminates one
       end
┌ Warning: Expected 4 params, got 2
└ @ Flux ~/.julia/packages/Flux/HNHmp/src/functor.jl:61
ERROR: DimensionMismatch("variable with size(x) == (4,) cannot have a gradient with size(dx) == (2,)")

julia> gradient(zero(v)) do w
         m = re(w)
         4(sum(m.x) + sum(m.y)) + 13*sum(m.z)  # again two gradients are ===, so it eliminates one
       end
([4.0, 4.0, 13.0, 13.0],)  # wrong answer, the gradient with respect to y should certainly be 4

Edit: #1767 is the issue about sharing.

I see that Functors#master has a 3-arg fmap, which is the sort of thing you might want for this. But it seems to have the wrong behaviour... and of course no adversarial tests...

julia> fmap(println, (x=[1,2], y=[3,4], z=[5,6]), (x=[0,2], y=nothing, z=[0,3]))  # visits nothing
[1, 2][0, 2]
[3, 4]nothing
[5, 6][0, 3]
(x = nothing, y = nothing, z = nothing)

julia> fmap(println, (x=[1,2], y=(a=[3,4], b=[5,6])), (x=[0,2], y=nothing))  # collapsed branch is an error
[1, 2][0, 2]
ERROR: MethodError: no method matching length(::Nothing)

julia> fmap(println, (x=sh, y=[3,4], z=sh))  # default fmap omits shared array
[7, 7]
[3, 4]
(x = nothing, y = nothing, z = nothing)

julia> sh2 = [0,7];  # shared "gradient" array correctly ignored, but 

julia> fmap(println, (x=sh, y=[3,4], z=sh), (x=sh2, y=sh2, z=[0,3]))  # never sees [0,3] so can't accumulate
[7, 7][0, 7]
[3, 4][0, 7]
(x = nothing, y = nothing, z = nothing)
