Skip to content

Commit

Permalink
Merge pull request #389 from zsoerenm/complex-convolution-fix
Browse files Browse the repository at this point in the history
Fix gradient of convolution for complex values
  • Loading branch information
ToucheSir authored Mar 2, 2022
2 parents 51595b7 + 9b6d233 commit aa86827
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/gemm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ for (gemm, elt) in gemm_datatype_mappings
beta::$(elt), C::Ptr{$elt})
# Convert our compile-time transpose marker to a char for BLAS
convtrans(V::Val{false}) = 'N'
convtrans(V::Val{true}) = 'T'
convtrans(V::Val{true}) = 'C'

if transA == Val(false)
lda = M
Expand Down
4 changes: 2 additions & 2 deletions src/impl/conv_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ Calculate the gradient imposed upon `x` in the convolution `y = x * w`.
function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
w::AbstractArray{wT,5}, cdims::DenseConvDims;
alpha::xT=xT(1), beta=false) where {xT, yT, wT}
w = transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :])
w = conj(transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :]))
dy = predilate(dy, stride(cdims))
ctdims = DenseConvDims(dy, w; padding=transpose_pad(cdims),
dilation=dilation(cdims),
Expand All @@ -188,7 +188,7 @@ Calculate the gradient imposed upon `w` in the convolution `y = x * w`.
function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
dy::AbstractArray{yT,5}, cdims::DenseConvDims;
alpha::wT=wT(1), beta=false) where {xT, yT, wT}
x = transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :])
x = conj(transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :]))
dy = transpose_swapbatch(predilate(dy, stride(cdims)))
ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims),
stride=dilation(cdims))
Expand Down
30 changes: 30 additions & 0 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,36 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
end
end

@testset "Complex Dense Convolution" begin
# For now only 1 dimensional 1x1 convolution
x = reshape(complex.(Float64[1:4;], Float64[1:4;] .+ 1), 1, 4, 1)
w = reshape(complex.(Float64[1:4;] .+ 2, Float64[1:4;] .+ 3), 1, 4, 1)
cdims = DenseConvDims(x, w)
convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,]
NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack)
for conv in convs
if NNlib.is_nnpack_available()
if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(cdims)
continue
end
end
@testset "$(conv)" begin
@test isapprox(ddims(conv(x, w, cdims)), [transpose(vec(w)) * vec(x)], rtol = 1.0e-7)
end
end
dy = NNlib.conv(x, w, cdims)
for (∇conv_filter, ∇conv_data) in (
(NNlib.∇conv_filter, NNlib.∇conv_data),
(NNlib.∇conv_filter_im2col, NNlib.∇conv_data_im2col),
(NNlib.∇conv_filter_direct, NNlib.∇conv_data_direct),
)
@testset "$(∇conv_filter)/$(∇conv_data)" begin
@test isapprox(∇conv_filter(x, dy, cdims), conj(x) .* dy, rtol = 1.0e-7)
@test isapprox(∇conv_data(dy, w, cdims), dy .* conj(w), rtol = 1.0e-7)
end
end
end

if get(ENV, "NNLIB_TEST_FUZZING", "false") == "true"
# @info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
@testset "fuzzing" begin
Expand Down

0 comments on commit aa86827

Please sign in to comment.