Type instability in Recur for 3 dimensional arrays #1947

Closed
characat0 opened this issue Apr 19, 2022 · 9 comments · Fixed by #1948

@characat0
Contributor

While working with `LSTM`, I noticed that `Recur` is not type stable for 3-dimensional arrays.

If we execute:

```julia
using Flux
using InteractiveUtils  # provides @code_warntype outside the REPL
model = LSTM(1 => 1)
x = rand(Float32, 1, 1, 1)
@code_warntype model(x)
```

Result:

```
MethodInstance for (::Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}})(::Array{Float32, 3})
  from (m::Flux.Recur)(x::AbstractArray{T, 3}) where T in Flux at ~\.julia\packages\Flux\18YZE\src\layers\recurrent.jl:88
Static Parameters
  T = Float32
Arguments
  m::Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}
  x::Array{Float32, 3}
Locals
  #269::Flux.var"#269#270"{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}
  sze::Any
  h::Any
Body::Any
1 ─ %1  = %new(Flux.var"#269#270"{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}, m)::Flux.var"#269#270"{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}
│   %2  = invoke Base.var"#eachslice##kw"()($(QuoteNode((dims = 3,)))::NamedTuple{(:dims,), Tuple{Int64}}, Flux.eachslice::typeof(eachslice), x::Array{Float32, 3})::Base.Generator{Base.OneTo{Int64}}
│   %3  = Base.Generator(%1, %2)::Base.Generator{_A, Flux.var"#269#270"{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}} where _A
│   %4  = Base.collect(%3)::Any
│   %5  = Base.getindex(%4, 1)::Any
│   %6  = Flux.size(%5)::Any
│   %7  = Flux.reduce(Flux.hcat, %4)::Any
│   %8  = Base.getindex(%6, 1)::Any
│   %9  = Base.getindex(%6, 2)::Any
│   %10 = Flux.length(%4)::Any
│   %11 = Flux.reshape(%7, %8, %9, %10)::Any
└──       return %11
```

The type instability is caused by a call to eachslice in:

```julia
h = [m(x_t) for x_t in eachslice(x, dims=3)]
```

This is related to JuliaLang/julia#39639. A quick fix would be to replace eachslice with a view of slices of the array along the third dimension.
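A minimal sketch of that quick fix, assuming the 3D method maps the cell over time steps and reassembles the outputs with `reshape(reduce(hcat, h), ...)`, as the IR above suggests. The names `apply_over_time_views` and `double` are hypothetical, and `double` stands in for a recurrent cell so the sketch runs without Flux. Iterating over `axes(x, 3)` and taking `view(x, :, :, t)` lets inference see a concrete `SubArray` for every step:

```julia
# Hypothetical stand-in for a recurrent cell: maps a (features, batch) slice
# to an output matrix with the same eltype.
double(x_t) = 2 .* x_t

# Sketch: apply `m` to each time step x[:, :, t] through a view instead of
# `eachslice`, then stitch the per-step outputs back into a 3D array.
# Assumes at least one time step (h[1] must exist).
function apply_over_time_views(m, x::AbstractArray{T,3}) where {T}
    h = [m(view(x, :, :, t)) for t in axes(x, 3)]   # one output per time step
    sze = size(h[1])
    return reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end

x = rand(Float32, 2, 3, 4)          # (features, batch, time)
y = apply_over_time_views(double, x)
```

The trade-off raised later in the thread still applies: whether this is as fast as `eachslice` under AD depends on the rrule that ends up differentiating the views.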

@ToucheSir
Member

What do you mean by a "view of slices of the array along the third dimension"?

@characat0
Contributor Author

characat0 commented Apr 19, 2022

> What do you mean by a "view of slices of the array along the third dimension"?

I was about to suggest going back to view(x, :, :, i) before I saw #1873, but now I am not sure if there is a way to make this type stable and performant.

@mkschleg
Contributor

mkschleg commented Apr 20, 2022

Hmmm. Good catch. `eachslice` is effectively doing those views, with an efficient rrule (as far as I understand it). It is a shame `eachslice` is not type stable. I can see two workarounds:

1.) Do a type check after `reduce` and force the type of the return value.
2.) Write a custom `eachslice` that is type stable and re-implement the rrule for it.

Option 1 is a fast hotfix, but option 2 is likely best in the long term. I'm unsure how feasible 2 is, though, and I'm swamped at work, so my Flux progress has been spotty and I won't be able to help much.
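A minimal sketch of workaround 1, assuming the 3D method keeps the comprehension over `eachslice` followed by `reshape(reduce(hcat, h), ...)`. The names `apply_over_time_asserted` and `double` are hypothetical, and the assertion assumes the cell preserves the input eltype `T`:

```julia
double(x_t) = 2 .* x_t   # hypothetical stand-in for a recurrent cell

# Sketch of workaround 1: keep `eachslice`, but assert the concrete return type
# so callers see Array{T,3} even if the body itself is not fully inferred.
function apply_over_time_asserted(m, x::AbstractArray{T,3}) where {T}
    h = [m(x_t) for x_t in eachslice(x; dims=3)]
    sze = size(h[1])
    out = reshape(reduce(hcat, h), sze[1], sze[2], length(h))
    return out::Array{T,3}   # throws a TypeError if the result is not an Array{T,3}
end

x = rand(Float32, 2, 3, 4)
y = apply_over_time_asserted(double, x)
```

The catch is that the assertion hard-codes `Array`, so it would reject other array types (GPU arrays, for example), and it only hides the instability from callers rather than removing it inside the method.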

Given the `eachslice` problem described in the Julia issue you linked, we could implement a specialized `eachslice` that always iterates over the last dimension, whose index we can infer from the array's type signature. I think this would be type stable.
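A minimal sketch of such a last-dimension iterator. `eachlastdim_sketch` is a hypothetical name, not the function from #1948. The key assumption is that the dimension count `N` is a type parameter, so `Val(N - 1)` is a compile-time constant and the tuple of leading colons has a concrete type:

```julia
# Sketch of workaround 2: iterate over the last dimension only. Because N is a
# type parameter, the number of leading colons is known at compile time.
function eachlastdim_sketch(A::AbstractArray{T,N}) where {T,N}
    colons = ntuple(_ -> Colon(), Val(N - 1))   # (:, :, ..., :) with N - 1 entries
    return (view(A, colons..., i) for i in axes(A, N))
end

x = rand(Float32, 2, 3, 1, 4)
slices = collect(eachlastdim_sketch(x))   # 4 views, each of size (2, 3, 1)
```

Unlike `eachslice(x; dims=d)`, there is no runtime `dims` value here, so the element type of the generator does not depend on a keyword argument.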

@characat0
Contributor Author

I found the rrule defined here in Zygote.jl; however, there is another definition in ChainRules.jl that uses `Val` to specialize at compile time (and saves a lot of allocations).
I made a custom function called `eachlastdim` that returns an iterator over the last dimension and implemented the rrule for it, but I'm unsure whether to use the Zygote or the ChainRules method, since ChainRules is not directly imported in Flux.
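The core of such an rrule is its pullback: it takes one cotangent per slice and scatters them back into an array shaped like the input. A minimal sketch follows; `∇eachlastdim_sketch` is a hypothetical name, wrapping it in a `ChainRulesCore.rrule` method is omitted so it runs with Base Julia only, and treating `nothing` as a zero cotangent is an assumption about how unused slices are reported:

```julia
# Sketch of the pullback for a last-dimension slicing helper: scatter one
# cotangent per slice back into an array shaped like the input.
function ∇eachlastdim_sketch(dys, x::AbstractArray{T,N}) where {T,N}
    dx = zero(x)                                  # gradient starts at zero
    colons = ntuple(_ -> Colon(), Val(N - 1))
    for (i, dy) in zip(axes(x, N), dys)
        dy === nothing && continue                # slice received no gradient
        view(dx, colons..., i) .= dy              # write this slice's cotangent
    end
    return dx
end

x = zeros(Float32, 2, 2, 3)
dys = [ones(Float32, 2, 2), nothing, fill(2.0f0, 2, 2)]
dx = ∇eachlastdim_sketch(dys, x)
```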

@ToucheSir
Member

I wonder if deleting the Zygote adjoint would be enough, have you tried that?

@characat0
Contributor Author

Currently, I'm using the ChainRules one, so we could safely delete the definition in Zygote.
Let me know what you think about adding `eachlastdim`, and I will gladly open a PR in Zygote to delete the old adjoints.

@ToucheSir
Member

Ideally we wouldn't need #1948/eachlastdim at all if the compiler is smart enough to make sense of the rrule in ChainRules. That's why I mentioned testing after deleting the Zygote adjoint: Zygote will automatically fall back to that rrule and you can see whether type stability is preserved then.

@characat0
Contributor Author

After deleting the rrules in Zygote, I tested the following code:

```julia
using Zygote
using InteractiveUtils  # provides @code_warntype outside the REPL
x = rand(Float32, 1, 1, 1, 10);
f(x) = eachslice(x; dims=4);
y, back = Zygote.pullback(f, x);
@code_warntype back(y)
```

Result:

```
MethodInstance for (::Zygote.var"#56#57"{typeof(∂(f))})(::Vector{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
  from (::Zygote.var"#56#57")(Δ) in Zygote at ~\.julia\packages\Zygote\xEPQb\src\compiler\interface.jl:41
Arguments
  #self#::Zygote.var"#56#57"{typeof(∂(f))}
  Δ::Vector{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Tuple{Array{Float32, 4}}
1 ─ %1 = Core.getfield(#self#, :back)::typeof(∂(f))
│   %2 = (%1)(Δ)::Tuple{Nothing, Array{Float32, 4}}
│   %3 = Zygote.tailmemaybe(%2)::Tuple{Array{Float32, 4}}
└──      return %3
```

So the pullback of `eachslice` is type stable when it uses the ChainRules rrule; however, `eachslice` itself remains type unstable.
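The checks above rely on reading `@code_warntype` output by eye. A minimal sketch of an automated check, suitable for a regression test, is to ask inference whether it produces a single concrete return type. `infers_concretely`, `stable` and `unstable` are hypothetical names; `Test.@inferred` from the standard library is the usual way to express the same check inside a test suite. The result for `eachslice` itself may differ between Julia versions, so it is not asserted here:

```julia
# Hypothetical helper: does inference produce a single concrete return type?
function infers_concretely(f, argtypes)
    rts = Base.return_types(f, argtypes)
    return length(rts) == 1 && isconcretetype(only(rts))
end

stable(x) = x[:, :, 1]                     # always a Matrix for 3D input
unstable(x) = sum(x) > 0 ? x[:, :, 1] : 0  # Matrix or Int depending on the data

infers_concretely(stable, (Array{Float32,3},))     # true
infers_concretely(unstable, (Array{Float32,3},))   # false
```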

@ToucheSir
Member

ToucheSir commented Apr 21, 2022

Ok, it was worth a try! Let's continue this discussion to #1948.

3 participants