Flux.destructure's restructure fails in the gradient if loss does not use all parameters #1826

It doesn't seem like you can use the explicit-parameters, destructure-based formulation of Flux in combination with NN architectures that use a single input to make multiple outputs. Is this by design? Ideally, the entries in the gradient corresponding to weights that don't get used in the loss could be 0 or missing, but maybe that introduces implementation challenges. The full error message:

Comments
They should be 0; Zygote's parameter tracking isn't that fine-grained. Usually you'd expect a mismatch error like this when an entire array parameter isn't used at all, but AFAICT that is not the case here. To troubleshoot, you'll want to look at what is being passed as the gradient value into _restructure.
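(For context, not from the thread: when an entire array goes unused, Zygote reports its gradient as nothing rather than as a zero array, and that nothing is what the flattening step downstream has to cope with. A minimal, standalone illustration:)

using Zygote

gradient((a, b) -> sum(a), [1.0, 2.0], [3.0, 4.0])
# the gradient for a is all ones; the unused argument b gets nothing, not a zero array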
@ToucheSir Thanks for the pointer. From what I can tell, in the slightly simpler MWE below, where the loss only uses the first of the two outputs, the gradient being passed in looks like

(layers = ((paths = ((weight = Float32[-11.297353 -10.366326 -10.771574 -10.596351 -9.816019; 11.591414 13.127091 9.820183 12.75714 8.937439; 7.2377276 10.328969 6.7674723 8.917322 7.2669835], bias = Float32[-17.877731, 19.011127, 13.564459], σ = nothing), nothing),),),)

so it only seems to be grabbing the gradient for the first Dense layer (the second path's entry is nothing). I don't have a strong grasp on Zygote's fundamentals beyond the basics, so it's not super obvious to me what might be driving this issue.

using Flux
struct Split{T} # taken from: https://fluxml.ai/Flux.jl/stable/models/advanced/#Multiple-outputs:-a-custom-Split-layer
paths::T
end
Split(paths...) = Split(paths)
@functor Split
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
n_input, n_batch, n_shared = 5, 13, 11
n_outputs = [3, 7]
data = rand(Float32, n_input, n_batch)
model = Chain(
Split(Dense(n_input, n_outputs[1]), Dense(n_input, n_outputs[2]))
)
ps, re = Flux.destructure(model);
loss(x, idx, ps) = sum(abs2, re(ps)(x)[idx])
gradient(params -> loss(data, 1, params), ps)
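(For reference, the nested gradient quoted above can be reproduced by differentiating with respect to the rebuilt model rather than the flat vector; this snippet is a reconstruction, not part of the original comment. Zygote then returns the structural NamedTuple directly instead of trying to re-flatten it.)

dm = gradient(m -> sum(abs2, m(data)[1]), re(ps))[1]   # same loss as above with idx = 1
dm.layers[1].paths   # ((weight = ..., bias = ..., σ = nothing), nothing): only the first Dense has a gradient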
Ah sorry, I missed that the tuple of arrays from ...
FWIW, it does seem viable to calculate gradients with the implicit-parameters style and then (un)flatten them post hoc (although I believe the solution here requires some modification to account for the fact that some gradients might be nothing).
Yes, using implicit parameters avoids this problem altogether. I don't believe it even needs any additional modification either, just a check to make sure each param is in the returned gradients.
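(A rough sketch of that implicit-parameters workaround against the MWE above, reusing model, data, and ps from there. This is not from the thread, and it assumes Flux.params and destructure visit parameters in the same order, which holds for this model.)

gs = gradient(() -> sum(abs2, model(data)[1]), Flux.params(model))   # implicit-style gradient

flat = mapreduce(vcat, Flux.params(model)) do p
    g = gs[p]
    g === nothing ? zero(vec(p)) : vec(g)   # parameters the loss never touched contribute zeros
end

length(flat) == length(ps)   # true: same length as the vector from destructure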
We're working through our requirements at the moment, but being able to use only flattened parameters would make some of our designs a bit easier, hence the motivation for raising the issue.
(Title changed from "Flux.destructure doesn't work with Split" to "Flux.destructure's restructure fails in the gradient if loss does not use all parameters".)
I think the MWE is something like this:

julia> v, re = Flux.destructure((x=[1,2], y=[3,4,5]))
([1, 2, 3, 4, 5], Flux.var"#128#129"...
julia> gradient(zero(v)) do w
m = re(w)
5 * sum(m.x) + 7 * sum(m[2]) # uses both x and y
end
([5.0, 5.0, 7.0, 7.0, 7.0],)
julia> gradient(w -> sum(re(w).x), zero(v)) # uses only x
┌ Warning: Expected 5 params, got 2
└ @ Flux ~/.julia/packages/Flux/HNHmp/src/functor.jl:61
ERROR: DimensionMismatch("variable with size(x) == (5,) cannot have a gradient with size(dx) == (2,)")

And the problem is that the Flux code at line 649 of commit ea26f45 blindly calls destructure again on dm, the "structural" gradient it gets, which in this case looks like this:
julia> dm = gradient(m -> sum(m.x), re(zero(v)))
((x = [1.0, 1.0], y = nothing),)
julia> Flux.destructure(dm)
([1.0, 1.0], ...)

which gives a vector of the wrong length, since the structure is different. And maybe Functors isn't equipped to handle that; you'd need to write a different recursion?
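(One possible shape for that "different recursion", a sketch only, not what Flux actually does; the name flatgrad is made up here. Walk the model and the structural gradient in parallel, substituting zeros wherever the gradient is nothing, so the flattened result always matches the length of destructure's parameter vector.)

using Flux, Functors

flatgrad(x::AbstractArray{<:Number}, dx::AbstractArray) = vec(dx)
flatgrad(x::AbstractArray{<:Number}, ::Nothing) = zeros(eltype(x), length(x))
flatgrad(x, ::Nothing) = flatgrad(x, map(_ -> nothing, Functors.children(x)))
function flatgrad(x, dx)
    kids = Functors.children(x)
    isempty(kids) && return Float32[]            # leaves like activation functions contribute nothing
    reduce(vcat, [flatgrad(kids[k], dx[k]) for k in keys(kids)])
end

m = (x = [1.0, 2.0], y = [3.0, 4.0, 5.0])
dm = gradient(m -> sum(m.x), m)[1]               # (x = [1.0, 1.0], y = nothing)
flatgrad(m, dm)                                  # 5 elements, [1, 1, 0, 0, 0], matching destructure(m)

This still does not handle the shared-parameter accumulation raised in the next comment, where the same array appears in two places.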
Note that this also interacts badly with shared parameters: instead of an error, it can silently give you the wrong answer.

julia> sh = [7,7];
julia> v, re = Flux.destructure((x=sh, y=[3,4], z=sh)) # shared array in the model
([7, 7, 3, 4], Base.Fix1(Flux._restructure, ...))
julia> re([1,10,100,1000])
(x = [1, 10], y = [100, 1000], z = [1, 10])
julia> gradient(zero(v)) do w
m = re(w)
3 * sum(m.x) + 13 * sum(m.z) # no dependence on y, but two distinct gradient arrays
end
([3.0, 3.0, 13.0, 13.0],) # wrong answer, should be [16, 16, 0, 0]?
julia> gradient(zero(v)) do w
m = re(w)
4(sum(m.x) + sum(m.z)) # now two gradients are ===, so it eliminates one
end
┌ Warning: Expected 4 params, got 2
└ @ Flux ~/.julia/packages/Flux/HNHmp/src/functor.jl:61
ERROR: DimensionMismatch("variable with size(x) == (4,) cannot have a gradient with size(dx) == (2,)")
julia> gradient(zero(v)) do w
m = re(w)
4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one
end
([4.0, 4.0, 13.0, 13.0],) # wrong answer, certainly the gradient with y should be 4

Edit: #1767 is the issue about sharing. I see that:

julia> fmap(println, (x=[1,2], y=[3,4], z=[5,6]), (x=[0,2], y=nothing, z=[0,3])) # visits nothing
[1, 2][0, 2]
[3, 4]nothing
[5, 6][0, 3]
(x = nothing, y = nothing, z = nothing)
julia> fmap(println, (x=[1,2], y=(a=[3,4], b=[5,6])), (x=[0,2], y=nothing)) # collapsed branch is an error
[1, 2][0, 2]
ERROR: MethodError: no method matching length(::Nothing)
julia> fmap(println, (x=sh, y=[3,4], z=sh)) # default fmap omits shared array
[7, 7]
[3, 4]
(x = nothing, y = nothing, z = nothing)
julia> sh2 = [0,7]; # shared "gradient" array correctly ignored, but
julia> fmap(println, (x=sh, y=[3,4], z=sh), (x=sh2, y=sh2, z=[0,3])) # never sees [0,3] so can't accumulate
[7, 7][0, 7]
[3, 4][0, 7]
(x = nothing, y = nothing, z = nothing)
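(As a cross-check of the "should be [16, 16, 0, 0]" expectation above, not from the thread, continuing the same session with Flux loaded: give the two slots independent copies of the shared array and add their gradients by hand.)

x1, y1, z1 = [7.0, 7.0], [3.0, 4.0], [7.0, 7.0]        # z1 is a copy of x1, not === it
dx, dy, dz = gradient((x, y, z) -> 3sum(x) + 13sum(z), x1, y1, z1)
dx .+ dz    # == [16.0, 16.0], what the shared slot should accumulate
dy          # nothing, so y's slice of the flat gradient should be zeros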