Custom trivial rrule causes NNlib CPU backend to be used for CUDA Flux.Conv #140

Closed
DrChainsaw opened this issue Jul 30, 2023 · 8 comments

@DrChainsaw

Sorry for the high-level MWE; I could not come up with a way to break it down further:

using Flux, Yota

struct Wrapper{L}
  l::L
end

(w::Wrapper)(x) = w.l(x)

let
  model = Wrapper(gpu(Conv((1,1), 3=>1)))
  x = gpu(randn(Float32, 32,32,3,1))
  Yota.grad(m -> sum(m(x)), model)
end; # No problem!

However, after adding a trivial custom rrule for Wrapper, it seems the CPU backend is used:

import ChainRulesCore  # for rrule, rrule_via_ad, RuleConfig, Tangent

function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, w::Wrapper, args...)
    # Delegate the wrapped layer to the AD backend.
    res, back = ChainRulesCore.rrule_via_ad(config, w.l, args...)
    function Wrapper_back(Δ)
        δs = back(Δ)
        # δs[1] is the tangent for w.l; repackage it as a Tangent for the Wrapper itself.
        ChainRulesCore.Tangent{Wrapper}(l=δs[1]), δs[2:end]...
    end
    return res, Wrapper_back
end

julia> let
       model = Wrapper(gpu(Conv((1,1), 3=>1)))
       x = gpu(randn(Float32, 32,32,3,1))
       Yota.grad(m -> sum(m(x)), model)
       end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x0000022b2f8789c0.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore E:\Programs\julia\.julia\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:106
ERROR: TaskFailedException
### Full stack trace in the details section at the end ###

Seems like Zygote can handle it:

let
    model = Wrapper(gpu(Conv((1,1), 3=>1)))
    x = gpu(randn(Float32, 32,32,3,1))
    gradient(m -> sum(m(x)), model)
end;
Full Stack Trace
julia> let
     model = Wrapper(gpu(Conv((1,1), 3=>1)))
     x = gpu(randn(Float32, 32,32,3,1))
     Yota.grad(m -> sum(m(x)), model)
     end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x0000022b2f8789c0.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore E:\Programs\julia\.julia\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:106
ERROR: TaskFailedException

  nested task error: TaskFailedException

      nested task error: MethodError: no method matching gemm!(::Val{false}, ::Val{true}, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::CUDA.CuPtr{Float32}, ::Float32, ::Ptr{Float32}) 

      Closest candidates are:
        gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::Ptr{Float32}, ::Float32, ::Ptr{Float32})
         @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\gemm.jl:29
        gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float64, ::Ptr{Float64}, ::Ptr{Float64}, ::Float64, ::Ptr{Float64})
         @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\gemm.jl:29
        gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::ComplexF64, ::Ptr{ComplexF64}, ::Ptr{ComplexF64}, ::ComplexF64, ::Ptr{ComplexF64})
         @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\gemm.jl:29
        ...

      Stacktrace:
       [1] macro expansion
         @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\impl\conv_im2col.jl:163 [inlined]
       [2] (::NNlib.var"#640#641"{Float32, Array{Float32, 3}, Float32, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}, Int64})()
         @ NNlib .\threadingconstructs.jl:404
  Stacktrace:
   [1] sync_end(c::Channel{Any})
     @ Base .\task.jl:445
   [2] macro expansion
     @ .\task.jl:477 [inlined]
   [3] ∇conv_data_im2col!(dx::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, dy::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; col::Array{Float32, 3}, alpha::Float32, beta::Float32, ntasks::Int64)
     @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\impl\conv_im2col.jl:155
   [4] ∇conv_data_im2col!(dx::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, dy::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3})
     @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\impl\conv_im2col.jl:126
   [5] (::NNlib.var"#323#327"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
     @ NNlib .\threadingconstructs.jl:404
Stacktrace:
[1] sync_end(c::Channel{Any})
  @ Base .\task.jl:445
[2] macro expansion
  @ .\task.jl:477 [inlined]
[3] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
  @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:249
[4] ∇conv_data!
  @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:226 [inlined]
[5] ∇conv_data!(y::Array{Float32, 4}, x::Array{Float32, 4}, w::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
  @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:145
[6] ∇conv_data!
  @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:140 [inlined]
[7] #∇conv_data#241
  @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:99 [inlined]
[8] ∇conv_data
  @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:95 [inlined]
[9] #380
  @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:350 [inlined]
[10] unthunk
  @ E:\Programs\julia\.julia\packages\ChainRulesCore\0t04l\src\tangent_types\thunks.jl:204 [inlined]
[11] map(f::typeof(ChainRulesCore.unthunk), t::Tuple{ChainRulesCore.Tangent{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, NamedTuple{(:weight, :bias), Tuple{Array{Float32, 4}, Vector{Float32}}}}, ChainRulesCore.Thunk{NNlib.var"#380#383"{Array{Float32, 4}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, DenseConvDims{2, 2, 2, 4, 2}}}})
  @ Base .\tuple.jl:274
[12] mkcall(::Any, ::Any, ::Vararg{Any}; val::Any, line::Any, kwargs::Any, free_kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
  @ Umlaut E:\Programs\julia\.julia\packages\Umlaut\XPASX\src\tape.jl:207
[13] mkcall(::Any, ::Any, ::Vararg{Any})
  @ Umlaut E:\Programs\julia\.julia\packages\Umlaut\XPASX\src\tape.jl:188
[14] finalize_grad!(tape::Umlaut.Tape{Yota.GradCtx})
  @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:238
[15] #gradtape!#77
  @ E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:251 [inlined]
[16] gradtape(f::Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, args::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}; ctx::Yota.GradCtx, seed::Symbol)
  @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:268
[17] gradtape
  @ E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:263 [inlined]
[18] make_rrule!(f::Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, args::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
  @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\chainrules.jl:91
[19] rrule_via_ad(cfg::Yota.YotaRuleConfig, f::Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, args::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
  @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\chainrules.jl:119
[20] rrule(config::Yota.YotaRuleConfig, w::Wrapper{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, args::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
  @ Main .\REPL[38]:2
[21] mkcall(::Any, ::Any, ::Vararg{Any}; val::Any, line::Any, kwargs::Any, free_kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
  @ Umlaut E:\Programs\julia\.julia\packages\Umlaut\XPASX\src\tape.jl:207
[22] chainrules_transform!(tape::Umlaut.Tape{Yota.GradCtx})
  @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:149
[23] #gradtape!#77
  @ E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:247 [inlined]
[24] gradtape(::Function, ::Wrapper{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::Vararg{Any}; ctx::Yota.GradCtx, seed::Int64)
  @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:268
[25] grad(::Function, ::Wrapper{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::Vararg{Any}; seed::Int64)       
  @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:360
[26] grad(::Function, ::Wrapper{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::Vararg{Any})
  @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:352
[27] top-level scope
  @ REPL[46]:4
[28] top-level scope
  @ E:\Programs\julia\.julia\packages\CUDA\tVtYo\src\initialization.jl:185
dfdx (Owner) commented Jul 30, 2023

It will take me some time to get an environment with a GPU to try it, but I have a couple of quick guesses:

  1. x is passed implicitly in your code. Yota treats all non-argument values as constants, so x is recorded to the underlying tape as-is and never changes. I'm not sure this causes the issue, but I'd certainly use Yota.grad((m, x) -> sum(m(x)), model, x) instead (see the sketch after the snippet below).
  2. rrule_via_ad() has always been quite fragile, so I'd try an rrule() without it first. If it works, we can further check what's going on in rrule_via_ad() as follows:
import Yota: V
using CUDA: CuArray  # CuArray is not exported by Flux

# this is what happens internally when you call rrule_via_ad()
tape = Yota.gradtape(model.l, x; seed=:auto, ctx=Yota.GradCtx())

# check if any operation produces a non-GPU array
for i in 1:length(tape)
    op = tape[V(i)]
    if isa(op.val, AbstractArray) && !isa(op.val, CuArray)
        println("Operation $(op) produces a non-GPU array")
    elseif isa(op.val, Tuple)
        for res in op.val
            if isa(res, AbstractArray) && !isa(res, CuArray)
                println("Operation $(op) produces a non-GPU array")
            end
        end
    end
end
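
For point 1, a minimal sketch (not from the original snippet) of the explicit-argument form, so that x is a differentiable tape input rather than a captured constant:

val, g = Yota.grad((m, x) -> sum(m(x)), model, x)   # g also holds the gradient for x, since x is now an argument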

DrChainsaw (Author) commented

Thanks!

I tried 1 and it gave the same error:

let
       model = Wrapper(gpu(Conv((1,1), 3=>1)))
       x = gpu(randn(Float32, 32,32,3,1))
       Yota.grad((m, xx) -> sum(m(xx)), model, x)
       end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000002428231c7d0.

For 2 I can't really wrap my head around how to do it without calling back to AD for the wrapped function.

Here is one example from the wild where the only purpose of the rrule is to snoop on gradients. How can it be written without calling back to AD?
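
For concreteness, the pattern I mean looks roughly like this (a hypothetical sketch with invented names, not the linked example):

import ChainRulesCore

struct SnoopGrad{F,C}  # hypothetical wrapper: f is the wrapped layer, callback observes gradients
    f::F
    callback::C
end
(s::SnoopGrad)(x...) = s.f(x...)

function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, s::SnoopGrad, args...)
    res, back = ChainRulesCore.rrule_via_ad(config, s.f, args...)
    function SnoopGrad_back(Δ)
        δs = back(Δ)
        s.callback(δs)  # only observe the gradients; they pass through unchanged
        return ChainRulesCore.Tangent{SnoopGrad}(f=δs[1]), δs[2:end]...
    end
    return res, SnoopGrad_back
end

The rrule adds no math of its own; it only wants to see the tangents, which is why calling back into AD via rrule_via_ad seems unavoidable here.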

I'm not sure whether the last part was a way to see where it goes wrong or just FYI, but I tried running it and I get the same error from gradtape:

import Yota: V

let
    model = Wrapper(gpu(Conv((1,1), 3=>1)))
    x = gpu(randn(Float32, 32,32,3,1))
    tape = Yota.gradtape(model.l, x; seed=:auto, ctx=Yota.GradCtx())
end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x0000024281dbb650.

Let me know if there is something else I can do to help troubleshoot from my side.

dfdx (Owner) commented Jul 31, 2023

For 2 I can't really wrap my head around how to do it without calling back to AD for the wrapped function.

I meant using a fake rrule just to understand where the error comes from. But from the last experiment I'm now pretty much sure the problem comes from rrule_via_ad(), which narrows down the search. I will try to get a Julia + GPU setup and debug the issue this week. Thanks for discovering it!
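
For reference, such a fake rrule could be as simple as the sketch below, defined instead of the rrule_via_ad version above (e.g. in a fresh session). It is diagnostic only: the tangents it returns are placeholders, so it is useful solely for checking whether the error persists without rrule_via_ad:

import ChainRulesCore

# Diagnostic only: skip rrule_via_ad entirely and return placeholder tangents.
function ChainRulesCore.rrule(w::Wrapper, x)
    y = w.l(x)
    fake_Wrapper_back(Δ) = (ChainRulesCore.NoTangent(), ChainRulesCore.ZeroTangent())
    return y, fake_Wrapper_back
end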

dfdx (Owner) commented Aug 4, 2023

TL;DR: run ] add cuDNN.

It turns out that neither CUDA nor Flux installs cuDNN by default. import Flux reports this as a possible issue but doesn't prevent you from running the code on the GPU. As a result, even the forward pass model(x) leads to the error you posted, without Yota or ChainRules involved. Installing cuDNN solves the issue.

I don't quite understand how Zygote works around this issue, but it may indirectly install cuDNN or rewrite calls to conv to some alternative implementation other than the one model(x) invokes.
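
As a quick sanity check (a sketch, not something from the original comment), one can at least confirm that the GPU stack is functional and that cuDNN is installed in the active environment:

using CUDA

@show CUDA.functional()                       # is a working GPU + toolkit available?
@show Base.find_package("cuDNN") !== nothing  # is the cuDNN package present in this environment?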

DrChainsaw (Author) commented

I had both CUDA and cuDNN installed in the MWE. :(

The example without the rrule and the Zygote example wouldn't have worked without it. Sorry for not making that clear.

dfdx (Owner) commented Aug 5, 2023

So you mean that with cuDNN installed, the Yota + ChainRules example still doesn't work? Can you post your ] st output then?

dfdx (Owner) commented Aug 5, 2023

Ok, I can reproduce it now. Investigating.

dfdx (Owner) commented Aug 5, 2023

Fixed in version 0.8.5; please re-open if you still experience the issue.

dfdx closed this as completed Aug 5, 2023