Type instability in `Recur` for 3-dimensional arrays #1947
What do you mean by a "view of slices of the array along the third dimension"?

I was about to suggest going back to
Hmm, good catch. As far as I understand it, `eachslice` is effectively creating those views, with an efficient rrule. It's a shame `eachslice` is not type-stable. I see two workarounds. 1.) Do a type check after `reduce` and force the type of the return value. 1 is a fast hotfix, but 2 is likely the better long-term solution. I'm unsure how feasible 2 is, though, and I'm swamped at work, so my Flux progress has been spotty and I won't be able to help much. Given the `eachslice` problem in the Julia issue you linked, we could implement a specialized `eachslice` that always iterates over the last dimension, which we should be able to infer from the type signature. I think this would be type-stable.
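A minimal plain-Julia sketch of the two ideas above, with no Flux or Zygote dependency. `eachlastslice` and `stack3` are hypothetical names, not part of Base or Flux. `eachlastslice` always slices along the last dimension: because `N` is a type parameter, `Val(N - 1)` is a compile-time constant, so the tuple of colons (and therefore each view) has a concrete type. `stack3` illustrates workaround 1 by asserting the return type after `reduce`.

```julia
# Hypothetical helper (not part of Base or Flux): views along the last dimension.
# N is a type parameter, so Val(N - 1) is known at compile time and the
# index tuple (Colon(), ..., Colon(), i) has a concrete type.
function eachlastslice(x::AbstractArray{T,N}) where {T,N}
    colons = ntuple(_ -> Colon(), Val(N - 1))
    return [view(x, colons..., i) for i in axes(x, N)]
end

# Hypothetical illustration of workaround 1: recombine per-step matrices
# and assert the concrete array type of the result after `reduce`.
function stack3(hs::AbstractVector{<:AbstractMatrix{T}}) where {T}
    sz = size(first(hs))
    out = reshape(reduce(hcat, hs), sz[1], sz[2], length(hs))
    return out::Array{T,3}
end
```

The assertion `out::Array{T,3}` only holds for CPU `Array`s; for other array types (for example GPU arrays) it would throw, so a real fix would need a looser bound.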
I found the rrule defined in Zygote.jl; however, there is another definition in ChainRules.jl that takes advantage of `Val` to optimize it at compile time (and save a lot of allocations).
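A toy illustration of why `Val` helps. This is not the ChainRules.jl definition; `slice_runtime` and `slice_static` are hypothetical names. When the dimension is an `Int`, the result type depends on its value, so inference can only produce a `Union`; encoding the dimension in the type with `Val` lets dispatch choose the method at compile time.

```julia
# Dimension as a runtime value: the two branches return different SubArray
# types, so the inferred return type is a Union of both.
slice_runtime(x::AbstractArray{T,3}, d::Int, i::Int) where {T} =
    d == 3 ? view(x, :, :, i) : view(x, :, i, :)

# Dimension encoded in the type via Val: dispatch resolves at compile time,
# and each method has a single concrete return type.
slice_static(x::AbstractArray{T,3}, ::Val{3}, i::Int) where {T} = view(x, :, :, i)
slice_static(x::AbstractArray{T,3}, ::Val{2}, i::Int) where {T} = view(x, :, i, :)
```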
I wonder if deleting the Zygote adjoint would be enough. Have you tried that?

Currently, I'm using the ChainRules one, so we could safely delete the definition in Zygote.

Ideally, we wouldn't need #1948/
After deleting the rrules in Zygote, I tested the following code:

```julia
using Zygote
using InteractiveUtils  # provides @code_warntype outside the REPL

x = rand(Float32, 1, 1, 1, 10);
f(x) = eachslice(x; dims=4);
y, back = Zygote.pullback(f, x);
@code_warntype back(y)
```

Result:

```
MethodInstance for (::Zygote.var"#56#57"{typeof(∂(f))})(::Vector{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
  from (::Zygote.var"#56#57")(Δ) in Zygote at ~\.julia\packages\Zygote\xEPQb\src\compiler\interface.jl:41
Arguments
  #self#::Zygote.var"#56#57"{typeof(∂(f))}
  Δ::Vector{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Tuple{Array{Float32, 4}}
1 ─ %1 = Core.getfield(#self#, :back)::typeof(∂(f))
│   %2 = (%1)(Δ)::Tuple{Nothing, Array{Float32, 4}}
│   %3 = Zygote.tailmemaybe(%2)::Tuple{Array{Float32, 4}}
└──      return %3
```

So the pullback of
OK, it was worth a try! Let's continue this discussion in #1948.
While working with LSTM, I noticed that `Recur` is not type-stable for 3-dimensional arrays.
If we execute:
Result:
The type instability is caused by a call to `eachslice` in `Flux.jl/src/layers/recurrent.jl` (line 89 at commit f038cff).
This is related to JuliaLang/julia#39639. A quick fix would be to replace `eachslice` with a view of slices of the array along the third dimension.
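A minimal sketch of that quick fix on a toy recurrent layer, not Flux's actual `Recur`. `ToyRecur` and its fields `Wx`, `Wh`, `h` are hypothetical, and the (features × batch × time) layout is an assumption. Each time step is taken with `view(x, :, :, t)`, whose type depends only on `typeof(x)`, and the output is preallocated with `similar` so its type is also known up front.

```julia
# Hypothetical toy recurrent layer (not Flux's Recur): it keeps a hidden
# state `h` of size (hidden × batch) and updates it on every call.
mutable struct ToyRecur{T}
    Wx::Matrix{T}
    Wh::Matrix{T}
    h::Matrix{T}
end

# One time step: input is a (features × batch) matrix, or a view of one.
function (m::ToyRecur)(x::AbstractMatrix)
    m.h = tanh.(m.Wx * x .+ m.Wh * m.h)
    return m.h
end

# Sequence input laid out as (features × batch × time). Each step uses
# view(x, :, :, t) instead of iterating over eachslice(x; dims=3).
function (m::ToyRecur)(x::AbstractArray{<:Any,3})
    out = similar(m.h, size(m.h, 1), size(x, 2), size(x, 3))
    for t in axes(x, 3)
        out[:, :, t] = m(view(x, :, :, t))
    end
    return out
end
```

Preallocating `out` is a design choice for this sketch; recombining per-step outputs with `reduce(hcat, ...)` and `reshape` would also work but makes the result type depend on how well `reduce` is inferred.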