Conv is not working for Complex when using CUDA #1655
Comments
It seems the weights are still real. What happens if we convert them to complex? CUDA should be able to work with that.
Hi @DhairyaLGandhi, I have tried

    c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im
    m = Chain(
        Conv((3,), 1=>2, pad=1, init=c_glorot_uniform),
    ) |> gpu

and I got the same error.
https://github.com/FluxML/NNlib.jl/blob/v0.7.33/src/impl/conv_im2col.jl#L230 is the culprit, so unless we get a CUDA-compatible (conv_)im2col in NNlib, this will not work.
Seems like it would if we could use a sufficiently general rule. Does
conv_direct is even worse because it makes pervasive use of scalar indexing.
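One possible workaround, which nobody in this thread proposed or tested, relies on convolution being bilinear: a complex convolution can be assembled from four real-valued ones, and real Float32 convolutions are reported to work on the GPU in this issue. The sketch below only checks that identity in plain Julia on the CPU. `naive_conv1d` and `split_complex_conv1d` are hypothetical helper names, not part of Flux or NNlib, and the naive kernel merely stands in for whatever real-valued convolution is actually used.

```julia
# Hypothetical helper: a naive "valid" 1-D sliding-window product,
# standing in for a real-valued convolution kernel.
function naive_conv1d(x::AbstractVector, w::AbstractVector)
    n, k = length(x), length(w)
    return [sum(x[i + j - 1] * w[j] for j in 1:k) for i in 1:(n - k + 1)]
end

# Hypothetical helper: complex convolution built from four real ones.
# (xr + im*xi) * (wr + im*wi) = (xr*wr - xi*wi) + im*(xr*wi + xi*wr)
function split_complex_conv1d(x::AbstractVector{<:Complex}, w::AbstractVector{<:Complex})
    xr, xi = real.(x), imag.(x)
    wr, wi = real.(w), imag.(w)
    re_part = naive_conv1d(xr, wr) .- naive_conv1d(xi, wi)
    im_part = naive_conv1d(xr, wi) .+ naive_conv1d(xi, wr)
    return complex.(re_part, im_part)
end

x = ComplexF32[1 + 2im, 3 - 1im, 0 + 1im, 2 + 0im]
w = ComplexF32[1 - 1im, 0 + 2im]

direct    = naive_conv1d(x, w)          # complex arithmetic throughout
via_split = split_complex_conv1d(x, w)  # only real-valued kernels

println(direct ≈ via_split)
```

Whether the same decomposition is fast enough in a real model, and how it interacts with Flux's gradients, is an open question here; it costs four real convolutions per complex one.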
The code mentioned above works if T is changed to Float32. When run on the CPU, both ComplexF32 and Float32 work. The model also works on the GPU if scalar indexing is allowed.
The error message when T = ComplexF32 using CUDA: