Custom trivial rrule causes NNlib CPU backend to be used for CUDA Flux.Conv #140
Comments
It will take me some time to get an environment with a GPU to try it, but I have a couple of quick guesses:
using CUDA  # for CuArray
import Yota: V
# this is what happens internally when you call rrule_via_ad()
tape = Yota.gradtape(model.l, x; seed=:auto, ctx=Yota.GradCtx())
# check if any operation produces a non-GPU array
for i in 1:length(tape)
    op = tape[V(i)]
    if isa(op.val, AbstractArray) && !isa(op.val, CuArray)
        println("Operation $(op) produces a non-GPU array")
    elseif isa(op.val, Tuple)
        # some operations return several values; check each of them
        for res in op.val
            if isa(res, AbstractArray) && !isa(res, CuArray)
                println("Operation $(op) produces a non-GPU array")
            end
        end
    end
end
Thanks! I tried 1 and it gave the same error:
let
    model = Wrapper(gpu(Conv((1,1), 3=>1)))
    x = gpu(randn(Float32, 32,32,3,1))
    Yota.grad((m, xx) -> sum(model(xx)), model, x)
end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000002428231c7d0.
For 2 I can't really wrap my head around how to do it without calling back to AD for the wrapped function. Here is one example from the wild where the only purpose of the rrule is to snoop on gradients. How can it be written without calling back to AD?
I'm not sure if the last part was meant as a way to see where it goes wrong or if it was just FYI, but I tried running it and I get the same error from
import Yota: V
let
    model = Wrapper(gpu(Conv((1,1), 3=>1)))
    x = gpu(randn(Float32, 32,32,3,1))
    tape = Yota.gradtape(model.l, x; seed=:auto, ctx=Yota.GradCtx())
end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x0000024281dbb650.
Let me know if there is something else I can do to help troubleshoot from my side.
I meant using a fake
TLDR: Run It turns out neither CUDA nor Flux installs cuDNN by default. I don't quite understand how Zygote works around this issue, but it may indirectly install cuDNN or rewrite calls to
I had both CUDA and cuDNN installed in the MWE. :( The example without the
So you mean that with cuDNN installed the Yota + ChainRules example still doesn't work? Can you post your
Ok, I can reproduce it now. Investigating.
Fixed in version 0.8.5, please re-open if you still experience the issue.
Sorry for the high-level MWE, I could not come up with a way to break it down further:
However, after adding a trivial custom rrule for Wrapper it seems like the CPU backend is used:
Seems like Zygote can handle it:
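For readers without the original snippets, a trivial rule of the kind described above might look roughly like the following sketch. It is an assumption, not the code from the issue: the `Wrapper` definition, the field name `l` (inferred from `model.l` in the comments) and the pullback name `wrapper_pullback` are all hypothetical. It only illustrates the "call back to AD for the wrapped function" pattern via `rrule_via_ad` from ChainRulesCore.

```julia
using ChainRulesCore

# Hypothetical stand-in for the `Wrapper` type from the issue; the field
# name `l` is inferred from `model.l` in the thread.
struct Wrapper{T}
    l::T
end
(w::Wrapper)(x) = w.l(x)

# A "trivial" rule: hand the wrapped layer back to the calling AD system
# and pass values and gradients through unchanged.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, w::Wrapper, x)
    y, back = rrule_via_ad(config, w.l, x)
    function wrapper_pullback(Δ)
        dl, dx = back(Δ)  # gradients w.r.t. the wrapped layer and the input
        return Tangent{typeof(w)}(; l=dl), dx
    end
    return y, wrapper_pullback
end
```

A rule like this adds no behaviour of its own, so any difference between running with and without it comes from how the AD backend handles the `rrule_via_ad` call.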
Full Stack Trace