Skip to content

Commit

Permalink
Fix grad conv im2col (#539)
Browse files Browse the repository at this point in the history
* Fix grad conv im2col

* Also fix depthwise

* Enable previously broken tests

* Revert "Enable previously broken tests"

This reverts commit d648fdd.

* Add explicit im2col test

* Fix and test third case

* More tests now pass
  • Loading branch information
wsmoses authored Sep 27, 2023
1 parent 37d9a02 commit 8da76bd
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
14 changes: 10 additions & 4 deletions src/impl/conv_im2col.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ function ∇conv_data_im2col!(
col_ptr = pointer(col_slice)
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
end
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
end
end
end
Expand Down Expand Up @@ -276,7 +276,7 @@ end


"""
col2im!(x, col, cdims)
col2im!(x, col, cdims, beta=0)
Does the inverse of `im2col!()`, converting `col` back into a 3d image, used for backward
passes, transposed convolutions, etc...
Expand All @@ -287,7 +287,7 @@ desperate enough yet.
"""
col2im!

function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims) where T
function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims, beta::T=T(0)) where T
if spatial_dims(cdims) != 3
throw(DimensionMismatch("col2im!() only accepts 3d convoluitional inputs"))
end
Expand All @@ -303,7 +303,13 @@ function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims

# TODO: Rewrite this method so we don't have this fill!() at the beginning!
# Calculate each output pixel once rather than accumulating into it?
fill!(x, T(0))
if beta == T(0)
fill!(x, T(0))
elseif beta == T(1)
# nothing
else
x .*= beta
end

# Reshape col for easy access.
col_reshaped = reshape(col, (
Expand Down
2 changes: 1 addition & 1 deletion src/impl/depthwiseconv_im2col.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ function ∇depthwiseconv_data_im2col!(
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
end
end
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
end
end
end
Expand Down
25 changes: 17 additions & 8 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,16 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
end
end

# Test im2col: call NNlib.∇conv_data_im2col! directly on tiny arrays and check how
# `beta` blends the previous contents of dx into the result.

# The beta values include 0.0 and 1.0 as well as other positive and negative scalings.
for beta in (-2.0, -1.0, 0.0, 0.5, 1.0, 2.0)
# `;;;` concatenates along the third dimension, so dx and dy are 1×1×3 arrays and the
# kernel is a 1×1×1 array holding 1.0.
# NOTE(review): the `cache_` prefix presumably keeps these locals from colliding with
# similarly named variables in the enclosing test scope — confirm before renaming.
cache_dx, cache_dy, cache_w = ([0.17;;; 0.19;;; 0.23], [0.11;;; 0.13;;; 0.15], [1.0;;;])
# Snapshot dx before the call; the assertion below compares against these values.
dx_old = copy(cache_dx)
cdims = DenseConvDims(cache_dx, cache_w)
# `beta` after the semicolon is keyword shorthand for `beta=beta`.
NNlib.∇conv_data_im2col!(cache_dx, cache_dy, cache_w, cdims; alpha=1.0, beta)
# Expected result: beta * (old dx) + dy.
# NOTE(review): this presumes that with alpha = 1.0 and the unit kernel the data
# gradient equals dy — confirm.
@test isapprox(cache_dx, dx_old * beta + cache_dy, rtol = 1.0e-7)
end

# Test all in-place implementations/interfaces
for (∇conv_filter!, ∇conv_data!) in (
(NNlib.∇conv_filter!, NNlib.∇conv_data!),
Expand All @@ -407,47 +417,46 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
)
#α, β = 2*rand(rng) - 1, 2*rand(rng) - 1
α, β = 2e0, -1e0
flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!)

@testset "$(∇conv_filter!)/$(∇conv_data!)" begin
# First, your basic convolution with no parameters
cdims = DenseConvDims(x, w)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7)

# Next, test convolution on views and alternate datatypes:
@test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7)

@test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7)

# Next, introduce stride:
cdims = DenseConvDims(x, w; stride=2)
dy = NNlib.conv(x, w, cdims)
flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7)

# Next, introduce dilation:
cdims = DenseConvDims(x, w; dilation=2)
dy = NNlib.conv(x, w, cdims)
flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag || flag_
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag_

# Next, introduce padding:
cdims = DenseConvDims(x, w; padding=1)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7)

# Next, test crosscor/conv with a flipped kernel
cdims = DenseConvDims(x, w; flipkernel=true)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7)
end
end
end
Expand Down

0 comments on commit 8da76bd

Please sign in to comment.