Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new cross convolution layer #423

Closed
wants to merge 11 commits into from
Closed
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
params, mapleaves, cpu, gpu, f32, f64

Expand Down
62 changes: 60 additions & 2 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NNlib: conv, ∇conv_data, depthwiseconv
using NNlib: conv, ∇conv_data, depthwiseconv, DenseConvDims

@generated sub2(::Val{N}) where N = :(Val($(N-2)))

Expand Down Expand Up @@ -171,6 +171,64 @@ end
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)

Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.

Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.

Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct CrossCor{N,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
dilation::NTuple{N,Int}
end

CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
CrossCor(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)

@treelike CrossCor

function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)
return conv(x, w, ddims)
end

function (c::CrossCor)(x)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
ddims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(crosscor(x, c.weight, ddims) .+ b)
end

function Base.show(io::IO, l::CrossCor)
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end

(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)

(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
MaxPool(k)

Expand Down Expand Up @@ -213,4 +271,4 @@ MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =

function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end
end
4 changes: 4 additions & 0 deletions test/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ c = gpu(Conv((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))

c = gpu(CrossCor((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))

end

if CuArrays.libcudnn != nothing
Expand Down
13 changes: 13 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,16 @@ end

@test size(m4(r), 3) == 3
end

@testset "CrossCor" begin
r = rand(28, 28, 1, 1)
m = Chain(CrossCor((2, 2), 1=>16, relu),
MaxPool((2,2)),
CrossCor((2, 2), 16=>8, relu),
MaxPool((2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax)

@test size(m(r)) == (10, 5)

end