Zygote, a FillArray of structs and broadcasting don't work together #632

Open
mohamed82008 opened this issue May 1, 2020 · 2 comments
Labels: needs adjoint missing rule

Comments

@mohamed82008
Contributor

Hi!

Here is an MWE (minimal working example) that throws an error when differentiating a function that broadcasts over a FillArray of structs:

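```julia
using FillArrays, Zygote

struct T
    a
end

f(t::T, x) = t.a + x

Zygote.gradient(rand(2)) do x
    # `ts` does not depend on `x`, yet Zygote still runs the pullback of the
    # `Fill` constructor, which is where the error below is raised.
    ts = Fill(T(1), 2)
    sum(f.(ts, x))
end
```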

The error is:

```
ERROR: Need an adjoint for constructor Fill{T,1,Tuple{Base.OneTo{Int64}}}. Gradient is of type Array{NamedTuple{(:a,),Tuple{Float64}},1}
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] (::Zygote.Jnew{Fill{T,1,Tuple{Base.OneTo{Int64}}},Nothing,false})(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\lib\lib.jl:306
 [3] (::Zygote.var"#380#back#193"{Zygote.Jnew{Fill{T,1,Tuple{Base.OneTo{Int64}}},Nothing,false}})(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\dev\ZygoteRules\src\adjoint.jl:49
 [4] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:57 [inlined]
 [5] (::typeof((Fill{T,1,Tuple{Base.OneTo{Int64}}})))(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface2.jl:0
 [6] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:64 [inlined]
 [7] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:69 [inlined]
 [8] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:76 [inlined]
 [9] (::typeof((Fill)))(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface2.jl:0
 [10] #5 at .\REPL[8]:2 [inlined]
 [11] (::typeof((#5)))(::Float64) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface2.jl:0
 [12] (::Zygote.var"#36#37"{typeof((#5))})(::Float64) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface.jl:36
 [13] gradient(::Function, ::Array{Float64,1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface.jl:45
 [14] top-level scope at REPL[8]:1
```

This is on Zygote 0.4.19.

@cossio
Contributor

cossio commented Jun 7, 2020

Did you try defining an adjoint for the Fill constructor?
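To make that suggestion concrete, here is a minimal sketch of the math such an adjoint would need, assuming (as the error message shows) that Zygote hands back the per-element cotangent of a struct as a `NamedTuple` with the struct's field names. Every element of `Fill(x, sz...)` is the same `x`, so the cotangent for `x` is the sum of the element cotangents, added field by field for structs. The names `add0`, `addfields`, and `fill_pullback` are hypothetical helpers, not part of Zygote or FillArrays, and given the comment below about FillArrays.jl#153 this is not claimed to be a verified fix.

```julia
# Hypothetical helper: add two field cotangents, treating `nothing` as zero.
add0(::Nothing, ::Nothing) = nothing
add0(::Nothing, y) = y
add0(x, ::Nothing) = x
add0(x, y) = x + y

# Hypothetical helper: add two struct cotangents represented as NamedTuples
# with the same field names, field by field.
addfields(a::NamedTuple{K}, b::NamedTuple{K}) where {K} =
    NamedTuple{K}(map(add0, values(a), values(b)))

# Hypothetical pullback for `Fill(x, sz...)`: all elements alias the same `x`,
# so the cotangent for `x` is the sum of the per-element cotangents.
fill_pullback(::Nothing) = nothing
fill_pullback(Δ::AbstractArray{<:Number}) = sum(Δ)
fill_pullback(Δ::AbstractArray{<:NamedTuple}) = reduce(addfields, Δ)

# Assumed wiring (not tested here): with Zygote and FillArrays loaded, the
# rule could be registered roughly as
#   Zygote.@adjoint Fill(x, sz::Integer...) =
#       Fill(x, sz...), Δ -> (fill_pullback(Δ), map(_ -> nothing, sz)...)
```

For the MWE above, the incoming cotangent is an `Array{NamedTuple{(:a,),Tuple{Float64}},1}`, and `fill_pullback` collapses it to a single `(a = ...,)` cotangent for `T(1)` instead of passing the array on to the inner `Fill{...}` constructor, which is the call that currently fails.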

@mcabbott
Member

Not fixed by JuliaArrays/FillArrays.jl#153, FWIW.

mcabbott added the needs adjoint missing rule label on Jul 22, 2022
3 participants