CrossCor layer #762

Merged 10 commits on May 14, 2019
1 change: 1 addition & 0 deletions NEWS.md
@@ -17,6 +17,7 @@
* [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
* New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
* New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
* New [CrossCor](https://github.com/FluxML/Flux.jl/pull/762) layer, which computes cross-correlation (a convolution without flipping the kernel).

AD Changes:

1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -17,6 +17,7 @@ MaxPool
MeanPool
DepthwiseConv
ConvTranspose
CrossCor
```

## Recurrent Layers
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -6,7 +6,7 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

-export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
params, mapleaves, cpu, gpu, f32, f64

70 changes: 70 additions & 0 deletions src/layers/conv.jl
@@ -198,6 +198,76 @@ end

(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))
"""
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)

Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.

Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size,
giving us a 16-channel output. Output is activated with ReLU.

size = (2,2)
in = 1
out = 16
CrossCor((2, 2), 1=>16, relu)

Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.

Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct CrossCor{N,M,F,A,V}
  σ::F
  weight::A
  bias::V
  stride::NTuple{N,Int}
  pad::NTuple{M,Int}
  dilation::NTuple{N,Int}
end

function CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
                  stride = 1, pad = 0, dilation = 1) where {T,N}
  stride = expand(Val(N-2), stride)
  pad = expand(Val(2*(N-2)), pad)
  dilation = expand(Val(N-2), dilation)
  return CrossCor(σ, w, b, stride, pad, dilation)
end

CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
         init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
  CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
           stride = stride, pad = pad, dilation = dilation)

@treelike CrossCor

# NNlib's `conv` computes a true convolution (flipped kernel) by default;
# rebuilding the DenseConvDims with `F = true` (flipkernel) turns it into
# cross-correlation.
function crosscor(x, w, ddims::DenseConvDims)
  ddims = DenseConvDims(ddims, F=true)
  return conv(x, w, ddims)
end

function (c::CrossCor)(x::AbstractArray)
  # TODO: breaks gpu broadcast :(
  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
  # reshape the bias to (1, ..., 1, out, 1) so it broadcasts over the output
  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
  cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
  σ.(crosscor(x, c.weight, cdims) .+ b)
end

function Base.show(io::IO, l::CrossCor)
  print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
  print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
  l.σ == identity || print(io, ", ", l.σ)
  print(io, ")")
end

# Convert real-valued inputs to the weight eltype before dispatching to the
# generic method above. Note that the weight array is the *fourth* type
# parameter of CrossCor{N,M,F,A,V}, so three leading `<:Any`s are needed for
# `W` to match the weights rather than the activation function.
(a::CrossCor{<:Any,<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  invoke(a, Tuple{AbstractArray}, x)

(a::CrossCor{<:Any,<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))

"""
MaxPool(k)
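To make the docstring's WHCN convention concrete, here is a small usage sketch (not part of the diff; it assumes Flux with this PR applied):

    using Flux

    x = rand(Float32, 100, 100, 3, 50)     # batch of 50 RGB images, WHCN order
    layer = CrossCor((2, 2), 3=>16, relu)  # 2×2 window, 3 input => 16 output channels
    y = layer(x)
    size(y)  # (99, 99, 16, 50): a 2×2 window shrinks each spatial dim by 1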
4 changes: 4 additions & 0 deletions test/cuda/cuda.jl
@@ -36,6 +36,10 @@ c = gpu(Conv((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))

c = gpu(CrossCor((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))
Review comment (Member): Would be nice to have some numerical tests here,
e.g. comparing to a regular convolution. Also, let's have a news item for this.
end
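Following the review suggestion above, a numerical GPU check could compare CrossCor against a regular convolution. A sketch, not part of the PR, assuming the CuArrays setup already used in this file:

    w = rand(Float32, 2, 2, 3, 4)
    x = rand(Float32, 10, 10, 3, 2)
    # cross-correlation with a flipped kernel should match a true convolution
    cc = gpu(CrossCor(w[end:-1:1, end:-1:1, :, :], zeros(Float32, 4)))
    cv = gpu(Conv(w, zeros(Float32, 4)))
    gx = gpu(x)
    @test cpu(cc(gx)) ≈ cpu(cv(gx))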

@testset "onecold gpu" begin
22 changes: 22 additions & 0 deletions test/layers/conv.jl
@@ -56,6 +56,27 @@ end
@test size(x_hat) == size(x)
end

@testset "CrossCor" begin
x = rand(Float32, 28, 28, 1, 1)
w = rand(2,2,1,1)
y = CrossCor(w, [0.0])

@test sum(w .* x[1:2, 1:2, :, :]) == y(x)[1, 1, 1, 1]

r = zeros(Float32, 28, 28, 1, 5)
m = Chain(
CrossCor((2, 2), 1=>16, relu),
MaxPool((2,2)),
CrossCor((2, 2), 16=>8, relu),
MaxPool((2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax)

@test size(m(r)) == (10, 5)
@test y(x) != Conv(w, [0.0])(x)
@test CrossCor(w[end:-1:1, end:-1:1, :, :], [0.0])(x) == Conv(w, [0.0])(x)
end
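The first assertion above checks a single output entry; a naive reference implementation makes the full semantics explicit. A sketch (`crosscor_ref` is a hypothetical helper, not part of the test suite):

    # sliding-window cross-correlation of a single channel, no kernel flip
    function crosscor_ref(x::AbstractMatrix, w::AbstractMatrix)
      kw, kh = size(w)
      [sum(w .* x[i:i+kw-1, j:j+kh-1])
       for i in 1:size(x,1)-kw+1, j in 1:size(x,2)-kh+1]
    end

    x = rand(Float32, 5, 5, 1, 1)
    w = rand(Float32, 2, 2, 1, 1)
    y = CrossCor(w, zeros(Float32, 1))(x)
    @test y[:, :, 1, 1] ≈ crosscor_ref(x[:, :, 1, 1], w[:, :, 1, 1])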

@testset "Conv with non quadratic window #700" begin
  data = zeros(Float32, 7,7,1,1)
  data[4,4,1,1] = 1
@@ -81,3 +102,4 @@
true
end
end