Flux's gradient differentiating rfft leads to non-bitstype error #1835

Closed
ziyiyin97 opened this issue Mar 26, 2023 · 1 comment

Comments


ziyiyin97 commented Mar 26, 2023

Describe the bug

Taking Flux's gradient through rfft (from FFTW.jl, the Fast Fourier Transform for real-valued input) raises an error, whereas fft works fine.

To reproduce

A minimal working example (MWE) for this bug:

(@v1.8) pkg> activate --temp
  Activating new project at `/tmp/jl_yR7NT1`

(jl_yR7NT1) pkg> add CUDA, Flux, FFTW
    Updating registry at `~/.julia/registries/General.toml`
   Resolving package versions...
   Installed Optimisers ─ v0.2.17
   Installed CUDA ─────── v4.1.1
    Updating `/tmp/jl_yR7NT1/Project.toml`
  [052768ef] + CUDA v4.1.1
  [7a1cc6ca] + FFTW v1.6.0
  [587475ba] + Flux v0.13.14
    Updating `/tmp/jl_yR7NT1/Manifest.toml`
  [621f4979] + AbstractFFTs v1.3.1
  [7d9f7c33] + Accessors v0.1.28
  [79e6a3ab] + Adapt v3.6.1
  [dce04be8] + ArgCheck v2.3.0
  [a9b6321e] + Atomix v0.1.0
  [ab4f0b2a] + BFloat16s v0.4.2
  [198e06fe] + BangBang v0.3.37
  [9718e550] + Baselet v0.1.1
  [fa961155] + CEnum v0.4.2
  [052768ef] + CUDA v4.1.1
  [1af6417a] + CUDA_Runtime_Discovery v0.1.1
  [082447d4] + ChainRules v1.48.0
  [d360d2e6] + ChainRulesCore v1.15.7
  [9e997f8a] + ChangesOfVariables v0.1.6
  [bbf7d656] + CommonSubexpressions v0.3.0
  [34da2185] + Compat v4.6.1
  [a33af91c] + CompositionsBase v0.1.1
  [187b0558] + ConstructionBase v1.5.1
  [6add18c4] + ContextVariablesX v0.1.3
  [9a962f9c] + DataAPI v1.14.0
  [864edb3b] + DataStructures v0.18.13
  [e2d170a0] + DataValueInterfaces v1.0.0
  [244e2a9f] + DefineSingletons v0.1.2
  [163ba53b] + DiffResults v1.1.0
  [b552c78f] + DiffRules v1.13.0
  [ffbed154] + DocStringExtensions v0.9.3
  [e2ba6199] + ExprTools v0.1.9
  [7a1cc6ca] + FFTW v1.6.0
  [cc61a311] + FLoops v0.2.1
  [b9860ae5] + FLoopsBase v0.1.1
  [1a297f60] + FillArrays v0.13.10
  [587475ba] + Flux v0.13.14
  [9c68100b] + FoldsThreads v0.1.1
  [f6369f11] + ForwardDiff v0.10.35
  [069b7b12] + FunctionWrappers v1.1.3
  [d9f16b24] + Functors v0.4.3
  [0c68f7d7] + GPUArrays v8.6.5
  [46192b85] + GPUArraysCore v0.1.4
  [61eb1bfa] + GPUCompiler v0.18.0
  [7869d1d1] + IRTools v0.4.9
  [22cec73e] + InitialValues v0.3.1
  [3587e190] + InverseFunctions v0.1.8
  [92d709cd] + IrrationalConstants v0.2.2
  [82899510] + IteratorInterfaceExtensions v1.0.0
  [692b3bcd] + JLLWrappers v1.4.1
  [b14d175d] + JuliaVariables v0.2.4
  [63c18a36] + KernelAbstractions v0.9.1
  [929cbde3] + LLVM v4.17.1
  [2ab3a3ac] + LogExpFunctions v0.3.23
  [d8e11817] + MLStyle v0.4.17
  [f1d291b0] + MLUtils v0.4.1
  [1914dd2f] + MacroTools v0.5.10
  [128add7d] + MicroCollections v0.1.4
  [e1d29d7a] + Missings v1.1.0
  [872c559c] + NNlib v0.8.19
  [a00861dc] + NNlibCUDA v0.2.7
  [77ba4419] + NaNMath v1.0.2
  [71a1bf82] + NameResolution v0.1.5
  [0b1bfda6] + OneHotArrays v0.2.3
  [3bd65402] + Optimisers v0.2.17
  [bac558e1] + OrderedCollections v1.4.1
  [21216c6a] + Preferences v1.3.0
  [8162dcfd] + PrettyPrint v0.2.0
  [33c8b6b6] + ProgressLogging v0.1.4
  [74087812] + Random123 v1.6.0
  [e6cf234a] + RandomNumbers v1.5.3
  [c1ae055f] + RealDot v0.1.0
  [189a3867] + Reexport v1.2.2
  [ae029012] + Requires v1.3.0
  [efcf1570] + Setfield v1.1.1
  [605ecd9f] + ShowCases v0.1.0
  [699a6c99] + SimpleTraits v0.9.4
  [66db9d55] + SnoopPrecompile v1.0.3
  [a2af1166] + SortingAlgorithms v1.1.0
  [276daf66] + SpecialFunctions v2.2.0
  [171d559e] + SplittablesBase v0.1.15
  [90137ffa] + StaticArrays v1.5.19
  [1e83bf80] + StaticArraysCore v1.4.0
  [82ae8749] + StatsAPI v1.5.0
  [2913bbd2] + StatsBase v0.33.21
  [09ab397b] + StructArrays v0.6.15
  [3783bdb8] + TableTraits v1.0.1
  [bd369af6] + Tables v1.10.1
  [a759f4b9] + TimerOutputs v0.5.22
  [28d57a85] + Transducers v0.4.75
  [013be700] + UnsafeAtomics v0.2.1
  [d80eeb9a] + UnsafeAtomicsLLVM v0.1.0
  [e88e6eb3] + Zygote v0.6.59
  [700de1a5] + ZygoteRules v0.2.3
  [02a925ec] + cuDNN v1.0.2
⌅ [4ee394cb] + CUDA_Driver_jll v0.4.0+2
  [76a88914] + CUDA_Runtime_jll v0.4.0+2
  [62b44479] + CUDNN_jll v8.8.1+0
  [f5851436] + FFTW_jll v3.3.10+0
  [1d5cc7b8] + IntelOpenMP_jll v2018.0.3+2
⌅ [dad2f222] + LLVMExtra_jll v0.0.18+0
  [856f044c] + MKL_jll v2022.2.0+0
  [efe28fd5] + OpenSpecFun_jll v0.5.5+0
  [0dad84c5] + ArgTools v1.1.1
  [56f22d72] + Artifacts
  [2a0f44e3] + Base64
  [ade2ca70] + Dates
  [8bb1440f] + DelimitedFiles
  [8ba89e20] + Distributed
  [f43a241f] + Downloads v1.6.0
  [7b1f6079] + FileWatching
  [9fa8497b] + Future
  [b77e0a4c] + InteractiveUtils
  [4af54fe1] + LazyArtifacts
  [b27032c2] + LibCURL v0.6.3
  [76f85450] + LibGit2
  [8f399da3] + Libdl
  [37e2e46d] + LinearAlgebra
  [56ddb016] + Logging
  [d6f4376e] + Markdown
  [a63ad114] + Mmap
  [ca575930] + NetworkOptions v1.2.0
  [44cfe95a] + Pkg v1.8.0
  [de0858da] + Printf
  [3fa0cd96] + REPL
  [9a3f8284] + Random
  [ea8e919c] + SHA v0.7.0
  [9e88b42a] + Serialization
  [6462fe0b] + Sockets
  [2f01184e] + SparseArrays
  [10745b16] + Statistics
  [fa267f1f] + TOML v1.0.0
  [a4e569a6] + Tar v1.10.1
  [8dfed614] + Test
  [cf7118a7] + UUIDs
  [4ec0a83e] + Unicode
  [e66e0078] + CompilerSupportLibraries_jll v1.0.1+0
  [deac9b47] + LibCURL_jll v7.84.0+0
  [29816b5a] + LibSSH2_jll v1.10.2+0
  [c8ffd9c3] + MbedTLS_jll v2.28.0+0
  [14a3606d] + MozillaCACerts_jll v2022.2.1
  [4536629a] + OpenBLAS_jll v0.3.20+0
  [05823500] + OpenLibm_jll v0.8.1+0
  [83775a58] + Zlib_jll v1.2.12+3
  [8e850b90] + libblastrampoline_jll v5.1.1+0
  [8e850ede] + nghttp2_jll v1.48.0+0
  [3f19e933] + p7zip_jll v17.4.0+0
        Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated -m`
Precompiling project...
  5 dependencies successfully precompiled in 54 seconds. 100 already precompiled.

julia> using CUDA, FFTW, Flux

julia> x = CUDA.randn(3)
3-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 0.19325934
 0.55793864
 0.08928435

julia> gradient(()->sum(abs.(fft(x))), Flux.params(x)) # this works
Grads(...)

julia> gradient(()->sum(abs.(rfft(x))), Flux.params(x))
ERROR: GPU compilation of broadcast_kernel(CUDA.CuKernelContext, CuDeviceVector{ComplexF32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(/), Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}}}, Int64) in world 32592 failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(/), Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
      .x is of type Vector{Int64} which is not isbits.


Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/validation.jl:101
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/anMCs/src/driver.jl:154 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/LHjFw/src/TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/anMCs/src/driver.jl:152 [inlined]
  [5] emit_julia(job::GPUCompiler.CompilerJob; validate::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/utils.jl:83
  [6] emit_julia
    @ ~/.julia/packages/GPUCompiler/anMCs/src/utils.jl:77 [inlined]
  [7] compile(job::GPUCompiler.CompilerJob, ctx::LLVM.Context)
    @ CUDA ~/.julia/packages/CUDA/N71Iw/src/compiler/compilation.jl:105
  [8] #203
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/compilation.jl:100 [inlined]
  [9] JuliaContext(f::CUDA.var"#203#204"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/driver.jl:76
 [10] compile
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/compilation.jl:99 [inlined]
 [11] actual_compilation(cache::Dict{UInt64, Any}, key::UInt64, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, ft::Type, tt::Type, world::UInt64, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/cache.jl:184
 [12] cached_compilation(cache::Dict{UInt64, Any}, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, ft::Type, tt::Type, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/cache.jl:163
 [13] macro expansion
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:310 [inlined]
 [14] macro expansion
    @ ./lock.jl:223 [inlined]
 [15] cufunction(f::GPUArrays.var"#broadcast_kernel#28", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceVector{ComplexF32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(/), Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}}}, Int64}}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:306
 [16] cufunction
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:303 [inlined]
 [17] macro expansion
    @ ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:104 [inlined]
 [18] #launch_heuristic#244
    @ ~/.julia/packages/CUDA/N71Iw/src/gpuarrays.jl:17 [inlined]
 [19] _copyto!
    @ ~/.julia/packages/GPUArrays/XR4WO/src/host/broadcast.jl:65 [inlined]
 [20] copyto!
    @ ~/.julia/packages/GPUArrays/XR4WO/src/host/broadcast.jl:46 [inlined]
 [21] copy
    @ ~/.julia/packages/GPUArrays/XR4WO/src/host/broadcast.jl:37 [inlined]
 [22] materialize
    @ ./broadcast.jl:860 [inlined]
 [23] (::AbstractFFTs.AbstractFFTsChainRulesCoreExt.var"#rfft_pullback#6"{UnitRange{Int64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Int64}, Int64})(ȳ::CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer})
    @ AbstractFFTs.AbstractFFTsChainRulesCoreExt ~/.julia/packages/AbstractFFTs/0uOAT/ext/AbstractFFTsChainRulesCoreExt.jl:40
 [24] ZBack
    @ ~/.julia/packages/Zygote/TSj5C/src/compiler/chainrules.jl:211 [inlined]
 [25] Pullback
    @ ~/.julia/packages/AbstractFFTs/0uOAT/src/definitions.jl:62 [inlined]
 [26] Pullback
    @ ./REPL[8]:1 [inlined]
 [27] (::Zygote.Pullback{Tuple{var"#5#6"}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(rfft), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(ndims), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.ZBack{AbstractFFTs.AbstractFFTsChainRulesCoreExt.var"#rfft_pullback#6"{UnitRange{Int64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Int64}, Int64}}, Zygote.ZBack{ChainRules.var"#:_pullback#275"{Tuple{Int64, Int64}}}}}, Zygote.var"#4160#back#1438"{Zygote.var"#1434#1437"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(abs), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4128#back#1421"{Zygote.var"#bc_fwd_back#1409"{1, CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}, Tuple{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Val{1}}}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}, Zygote.var"#1955#back#190"{Zygote.var"#186#189"{Zygote.Context{true}, GlobalRef, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
 [28] (::Zygote.var"#118#119"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, Zygote.Pullback{Tuple{var"#5#6"}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(rfft), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(ndims), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.ZBack{AbstractFFTs.AbstractFFTsChainRulesCoreExt.var"#rfft_pullback#6"{UnitRange{Int64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Int64}, Int64}}, Zygote.ZBack{ChainRules.var"#:_pullback#275"{Tuple{Int64, Int64}}}}}, Zygote.var"#4160#back#1438"{Zygote.var"#1434#1437"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(abs), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4128#back#1421"{Zygote.var"#bc_fwd_back#1409"{1, CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}, Tuple{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Val{1}}}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}, Zygote.var"#1955#back#190"{Zygote.var"#186#189"{Zygote.Context{true}, GlobalRef, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, Zygote.Context{true}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface.jl:389
 [29] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface.jl:97
 [30] top-level scope
    @ REPL[8]:1
 [31] top-level scope
    @ ~/.julia/packages/CUDA/N71Iw/src/initialization.jl:163
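
If I read the trace correctly, the failure comes from the rfft pullback (frame [23], AbstractFFTsChainRulesCoreExt.jl:40) broadcasting the GPU cotangent against a plain CPU Vector{Int64}; that is the non-isbits Extruded{Vector{Int64}, ...} argument in the kernel signature above. A minimal sketch (illustrative variables, untested) that I would expect to hit the same non-bitstype KernelError, assuming any CPU Vector reaching a CuArray broadcast behaves this way:

julia> y = CUDA.randn(3);

julia> scale = [1, 2, 2];  # plain CPU Vector{Int64}: not isbits, so it cannot be passed into a GPU kernel

julia> y ./ scale  # should fail with the same "passing and using non-bitstype argument" KernelError as above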

Expected behavior

Differentiating rfft should not raise an error, just as fft does not.

Version info

Details on Julia:

julia> versioninfo()
Julia Version 1.8.5
Commit 17cfb8e65ea (2023-01-08 06:45 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × Intel(R) Xeon(R) W-2245 CPU @ 3.90GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, cascadelake)
  Threads: 16 on 16 virtual cores
Environment:
  JULIA_NUM_THREADS = 16

Details on CUDA:

julia> CUDA.versioninfo()
CUDA runtime 12.1, artifact installation
CUDA driver 12.1
NVIDIA driver 510.108.3, originally for CUDA 11.6

Libraries: 
- CUBLAS: 12.1.0
- CURAND: 10.3.2
- CUFFT: 11.0.2
- CUSOLVER: 11.4.4
- CUSPARSE: 12.0.2
- CUPTI: 18.0.0
- NVML: 11.0.0+510.108.3

Toolchain:
- Julia: 1.8.5
- LLVM: 13.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0, 7.1, 7.2
- Device capability support: sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80, sm_86

1 device:
  0: NVIDIA T1000 8GB (sm_75, 7.000 GiB / 8.000 GiB available)

Additional context

I hope this is the correct repository for this issue, but please let me know if I should redirect it to another package (Flux or FFTW).
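
In the meantime, a possible workaround sketch (untested; onesided is a hypothetical helper and assumes 1-D real input) is to take the one-sided spectrum from fft, which does differentiate above: for a length-N real vector, rfft(x) equals the first N ÷ 2 + 1 entries of fft(x).

julia> onesided(x) = fft(x)[1:(length(x) ÷ 2 + 1)]  # hypothetical helper: same values as rfft(x) for 1-D real x

julia> gradient(() -> sum(abs.(onesided(x))), Flux.params(x))  # only differentiates through fft and getindex

This sidesteps the rfft pullback entirely, at the cost of computing the full spectrum.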

ziyiyin97 added the bug (Something isn't working) label on Mar 26, 2023
ziyiyin97 changed the title from "rfft leads to non-bit error" to "Flux's gradient differentiating rfft leads to non-bitstype error" on Mar 26, 2023
maleadt (Member) commented Mar 27, 2023

As CUDA.jl does not define any adjoints, this would be better filed on the Zygote.jl repository. Or maybe start with a Discourse post to see if you aren't doing anything wrong.
