Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUFFT improvements #1313

Merged
merged 2 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 56 additions & 66 deletions lib/cufft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ Base.:(*)(p::ScaledPlan, x::DenseCuArray) = rmul!(p.p * x, p.scale)

## plan structure

# K is a flag for forward/backward
# K is an integer flag for forward/backward
# also used as an alias for r2c/c2r

# inplace is a boolean flag

abstract type CuFFTPlan{T<:cufftNumber, K, inplace} <: Plan{T} end

# for some reason, cufftHandle is an integer and not a pointer...
Expand Down Expand Up @@ -382,143 +384,131 @@ end

## plan execution

function assert_applicable(p::CuFFTPlan{T,K}, X::DenseCuArray{T}) where {T,K}
# NOTE: "in-place complex-to-real FFTs may overwrite arbitrary imaginary input point values
# [...]. Out-of-place complex-to-real FFT will always overwrite input buffer."
# see # JuliaGPU/CuArrays.jl#345, NVIDIA/cuFFT#2714055.

function assert_applicable(p::CuFFTPlan{T}, X::DenseCuArray{T}) where {T}
(size(X) == p.sz) ||
throw(ArgumentError("CuFFT plan applied to wrong-size input"))
end

function assert_applicable(p::CuFFTPlan{T,K}, X::DenseCuArray{T}, Y::DenseCuArray{Ty}) where {T,K,Ty}
function assert_applicable(p::CuFFTPlan{T,K,inplace}, X::DenseCuArray{T}, Y::DenseCuArray{T}) where {T,K,inplace}
assert_applicable(p, X)
(size(Y) == p.osz) ||
if size(Y) != p.osz
throw(ArgumentError("CuFFT plan applied to wrong-size output"))
# type errors should be impossible by dispatch, but just in case:
if p.xtype ∈ [CUFFT_C2R, CUFFT_Z2D]
(Ty == real(T)) ||
throw(ArgumentError("Type mismatch for argument Y"))
elseif p.xtype ∈ [CUFFT_R2C, CUFFT_D2Z]
(Ty == complex(T)) ||
throw(ArgumentError("Type mismatch for argument Y"))
else
(Ty == T) ||
throw(ArgumentError("Type mismatch for argument Y"))
elseif inplace != (pointer(X) == pointer(Y))
throw(ArgumentError(string("CuFFT ",
inplace ? "in-place" : "out-of-place",
" plan applied to ",
inplace ? "out-of-place" : "in-place",
" data")))
end
end

function unsafe_execute!(plan::cCuFFTPlan{cufftComplex,K,true,N},
x::DenseCuArray{cufftComplex,N}) where {K,N}
function unsafe_execute!(plan::cCuFFTPlan{cufftComplex,K,<:Any,N},
x::DenseCuArray{cufftComplex,N},
y::DenseCuArray{cufftComplex,N}) where {K,N}
@assert plan.xtype == CUFFT_C2C
update_stream(plan)
cufftExecC2C(plan, x, x, K)
cufftExecC2C(plan, x, y, K)
end

function unsafe_execute!(plan::rCuFFTPlan{cufftComplex,K,true,N},
x::DenseCuArray{cufftComplex,N}) where {K,N}
x::DenseCuArray{cufftComplex,N},
y::DenseCuArray{cufftReal,N}) where {K,N}
@assert plan.xtype == CUFFT_C2R
update_stream(plan)
cufftExecC2R(plan, x, x)
end

function unsafe_execute!(plan::cCuFFTPlan{cufftComplex,K,false,N},
x::DenseCuArray{cufftComplex,N}, y::DenseCuArray{cufftComplex}
) where {K,N}
@assert plan.xtype == CUFFT_C2C
x = copy(x) # JuliaGPU/CuArrays.jl#345, NVIDIA/cuFFT#2714055
update_stream(plan)
cufftExecC2C(plan, x, y, K)
unsafe_free!(x)
cufftExecC2R(plan, x, y)
end
function unsafe_execute!(plan::rCuFFTPlan{cufftComplex,K,false,N},
x::DenseCuArray{cufftComplex,N}, y::DenseCuArray{cufftReal}
) where {K,N}
x::DenseCuArray{cufftComplex,N},
y::DenseCuArray{cufftReal}) where {K,N}
@assert plan.xtype == CUFFT_C2R
x = copy(x) # JuliaGPU/CuArrays.jl#345, NVIDIA/cuFFT#2714055
x = copy(x)
update_stream(plan)
cufftExecC2R(plan, x, y)
unsafe_free!(x)
end

function unsafe_execute!(plan::rCuFFTPlan{cufftReal,K,false,N},
x::DenseCuArray{cufftReal,N}, y::DenseCuArray{cufftComplex,N}
) where {K,N}
function unsafe_execute!(plan::rCuFFTPlan{cufftReal,K,<:Any,N},
maleadt marked this conversation as resolved.
Show resolved Hide resolved
x::DenseCuArray{cufftReal,N},
y::DenseCuArray{cufftComplex,N}) where {K,N}
@assert plan.xtype == CUFFT_R2C
x = copy(x) # JuliaGPU/CuArrays.jl#345, NVIDIA/cuFFT#2714055
update_stream(plan)
cufftExecR2C(plan, x, y)
unsafe_free!(x)
end

function unsafe_execute!(plan::cCuFFTPlan{cufftDoubleComplex,K,true,N},
x::DenseCuArray{cufftDoubleComplex,N}) where {K,N}
function unsafe_execute!(plan::cCuFFTPlan{cufftDoubleComplex,K,<:Any,N},
x::DenseCuArray{cufftDoubleComplex,N},
y::DenseCuArray{cufftDoubleComplex}) where {K,N}
@assert plan.xtype == CUFFT_Z2Z
update_stream(plan)
cufftExecZ2Z(plan, x, x, K)
cufftExecZ2Z(plan, x, y, K)
end

function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleComplex,K,true,N},
x::DenseCuArray{cufftDoubleComplex,N}) where {K,N}
x::DenseCuArray{cufftDoubleComplex,N},
y::DenseCuArray{cufftDoubleReal}) where {K,N}
update_stream(plan)
@assert plan.xtype == CUFFT_Z2D
cufftExecZ2D(plan, x, x)
end

function unsafe_execute!(plan::cCuFFTPlan{cufftDoubleComplex,K,false,N},
x::DenseCuArray{cufftDoubleComplex,N}, y::DenseCuArray{cufftDoubleComplex}
) where {K,N}
@assert plan.xtype == CUFFT_Z2Z
x = copy(x) # JuliaGPU/CuArrays.jl#345, NVIDIA/cuFFT#2714055
update_stream(plan)
cufftExecZ2Z(plan, x, y, K)
unsafe_free!(x)
cufftExecZ2D(plan, x, y)
end
function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleComplex,K,false,N},
x::DenseCuArray{cufftDoubleComplex,N}, y::DenseCuArray{cufftDoubleReal}
) where {K,N}
x::DenseCuArray{cufftDoubleComplex,N},
y::DenseCuArray{cufftDoubleReal}) where {K,N}
@assert plan.xtype == CUFFT_Z2D
x = copy(x) # JuliaGPU/CuArrays.jl#345, NVIDIA/cuFFT#2714055
x = copy(x)
update_stream(plan)
cufftExecZ2D(plan, x, y)
unsafe_free!(x)
end

function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleReal,K,false,N},
x::DenseCuArray{cufftDoubleReal,N}, y::DenseCuArray{cufftDoubleComplex,N}
) where {K,N}
function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleReal,K,<:Any,N},
x::DenseCuArray{cufftDoubleReal,N},
y::DenseCuArray{cufftDoubleComplex,N}) where {K,N}
@assert plan.xtype == CUFFT_D2Z
x = copy(x) # JuliaGPU/CuArrays.jl#345, NVIDIA/cuFFT#2714055
update_stream(plan)
cufftExecD2Z(plan, x, y)
unsafe_free!(x)
end

function LinearAlgebra.mul!(y::DenseCuArray{Ty}, p::CuFFTPlan{T,K,false}, x::DenseCuArray{T}
) where {Ty,T,K}

## high-level integrations

function LinearAlgebra.mul!(y::DenseCuArray{T}, p::CuFFTPlan{T}, x::DenseCuArray{T}
) where {T}
assert_applicable(p,x,y)
unsafe_execute!(p,x,y)
return y
end

function Base.:(*)(p::cCuFFTPlan{T,K,true,N}, x::DenseCuArray{T,N}) where {T,K,N}
assert_applicable(p,x)
unsafe_execute!(p,x)
unsafe_execute!(p,x,x)
x
end

function Base.:(*)(p::rCuFFTPlan{T,CUFFT_FORWARD,false,N}, x::DenseCuArray{T,N}
) where {T<:cufftReals,N}
assert_applicable(p,x)
@assert p.xtype ∈ [CUFFT_R2C,CUFFT_D2Z]
y = CuArray{complex(T),N}(undef, p.osz)
mul!(y,p,x)
unsafe_execute!(p,x,y)
y
end

function Base.:(*)(p::rCuFFTPlan{T,CUFFT_INVERSE,false,N}, x::DenseCuArray{T,N}
) where {T<:cufftComplexes,N}
assert_applicable(p,x)
@assert p.xtype ∈ [CUFFT_C2R,CUFFT_Z2D]
y = CuArray{real(T),N}(undef, p.osz)
mul!(y,p,x)
unsafe_execute!(p,x,y)
y
end

function Base.:(*)(p::cCuFFTPlan{T,K,false,N}, x::DenseCuArray{T,N}) where {T,K,N}
assert_applicable(p,x)
y = CuArray{T,N}(undef, p.osz)
mul!(y,p,x)
unsafe_execute!(p,x,y)
y
end
8 changes: 3 additions & 5 deletions test/cufft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function batched(X::AbstractArray{T,N},region) where {T <: Complex,N}
p = plan_fft(d_X,region)
d_Y = p * d_X
Y = collect(d_Y)
@test_maybe_broken isapprox(Y, fftw_X, rtol = MYRTOL, atol = MYATOL)
@test isapprox(Y, fftw_X, rtol = MYRTOL, atol = MYATOL)

pinv = plan_ifft(d_Y,region)
d_Z = pinv * d_Y
Expand Down Expand Up @@ -139,7 +139,7 @@ end
@test_throws ArgumentError batched(X,(3,1))
end

@testset "Batch 2D (in 4D)" begin
CUFFT.version() >= v"10.2" && @testset "Batch 2D (in 4D)" begin
dims = (N1,N2,N3,N4)
for region in [(1,2),(1,4),(3,4)]
X = rand(T, dims)
Expand Down Expand Up @@ -173,14 +173,12 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
pinv2 = inv(p)
d_Z = pinv2 * d_Y
Z = collect(d_Z)
@test_maybe_broken isapprox(Z, X, rtol = MYRTOL, atol = MYATOL)
# JuliaGPU/CUDA.jl#345, NVIDIA/cuFFT#2714102
@test isapprox(Z, X, rtol = MYRTOL, atol = MYATOL)

pinv3 = inv(pinv)
d_W = pinv3 * d_X
W = collect(d_W)
@test isapprox(W, Y, rtol = MYRTOL, atol = MYATOL)
# JuliaGPU/CUDA.jl#345, NVIDIA/cuFFT#2714102
end

function batched(X::AbstractArray{T,N},region) where {T <: Real,N}
Expand Down