diff --git a/NEWS.md b/NEWS.md
index 8ac937cee0..f01f12598c 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -17,6 +17,7 @@
 * [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
 * New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
 * New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
+* New [CrossCor](https://github.com/FluxML/Flux.jl/pull/762).
 
 AD Changes:
 
diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index ec45c31e7b..3acb910ddb 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -17,6 +17,7 @@
 MaxPool
 MeanPool
 DepthwiseConv
 ConvTranspose
+CrossCor
 ```
 
 ## Recurrent Layers
 
diff --git a/src/Flux.jl b/src/Flux.jl
index eccdd6a7e5..a041a69a8d 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -6,7 +6,7 @@ using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
-export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        params, mapleaves, cpu, gpu, f32, f64
 
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 3739fd1c34..ff547b410b 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -198,6 +198,76 @@ end
 (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))
 
+"""
+    CrossCor(size, in=>out)
+    CrossCor(size, in=>out, relu)
+
+Standard cross-correlation layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+Unlike `Conv`, the kernel is not flipped, so the layer computes a cross-correlation
+rather than a convolution.
+
+Example: applying a CrossCor layer to a 1-channel input with a 2×2 window,
+giving a 16-channel output activated with ReLU:
+
+    CrossCor((2, 2), 1=>16, relu)
+
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+""" +struct CrossCor{N,M,F,A,V} + σ::F + weight::A + bias::V + stride::NTuple{N,Int} + pad::NTuple{M,Int} + dilation::NTuple{N,Int} +end + +function CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; + stride = 1, pad = 0, dilation = 1) where {T,N} + stride = expand(Val(N-2), stride) + pad = expand(Val(2*(N-2)), pad) + dilation = expand(Val(N-2), dilation) + return CrossCor(σ, w, b, stride, pad, dilation) +end + +CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; + init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N = + CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ, + stride = stride, pad = pad, dilation = dilation) + +@treelike CrossCor + +function crosscor(x, w, ddims::DenseConvDims) + ddims = DenseConvDims(ddims, F=true) + return conv(x, w, ddims) +end + +function (c::CrossCor)(x::AbstractArray) + # TODO: breaks gpu broadcast :( + # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) + σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) + cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) + σ.(crosscor(x, c.weight, cdims) .+ b) +end + +function Base.show(io::IO, l::CrossCor) + print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2]) + print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight))) + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") +end + +(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + invoke(a, Tuple{AbstractArray}, x) + +(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + a(T.(x)) """ MaxPool(k) diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 86e7f2f3b8..96d04c284b 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -36,6 +36,10 @@ c = gpu(Conv((2,2),3=>4)) l = c(gpu(rand(10,10,3,2))) Flux.back!(sum(l)) +c = gpu(CrossCor((2,2),3=>4)) +l = c(gpu(rand(10,10,3,2))) +Flux.back!(sum(l)) + end @testset "onecold gpu" begin diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 5e12e42668..5b2e2392ca 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -56,6 +56,27 @@ end @test size(x_hat) == size(x) end +@testset "CrossCor" begin + x = rand(Float32, 28, 28, 1, 1) + w = rand(2,2,1,1) + y = CrossCor(w, [0.0]) + + @test sum(w .* x[1:2, 1:2, :, :]) == y(x)[1, 1, 1, 1] + + r = zeros(Float32, 28, 28, 1, 5) + m = Chain( + CrossCor((2, 2), 1=>16, relu), + MaxPool((2,2)), + CrossCor((2, 2), 16=>8, relu), + MaxPool((2,2)), + x -> reshape(x, :, size(x, 4)), + Dense(288, 10), softmax) + + @test size(m(r)) == (10, 5) + @test y(x) != Conv(w, [0.0])(x) + @test CrossCor(w[end:-1:1, end:-1:1, :, :], [0.0])(x) == Conv(w, [0.0])(x) +end + @testset "Conv with non quadratic window #700" begin data = zeros(Float32, 7,7,1,1) data[4,4,1,1] = 1 @@ -81,3 +102,4 @@ end true end end +