Skip to content

Commit

Permalink
Merge #762
Browse files Browse the repository at this point in the history
762: CrossCor layer r=avik-pal a=ayush-1506

Same as #423 (which could be edited since I lost access to that github account).

Co-authored-by: ayush-1506 <ayush.shridhar1506@gmail.com>
  • Loading branch information
bors[bot] and ayush-1506 committed May 14, 2019
2 parents 308b199 + 98a027a commit 8ee6af1
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 1 deletion.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
* New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
* New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
* New [CrossCor](https://github.com/FluxML/Flux.jl/pull/762).

AD Changes:

Expand Down
1 change: 1 addition & 0 deletions docs/src/models/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ MaxPool
MeanPool
DepthwiseConv
ConvTranspose
CrossCor
```

## Recurrent Layers
Expand Down
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
params, mapleaves, cpu, gpu, f32, f64

Expand Down
70 changes: 70 additions & 0 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,76 @@ end

(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size,
giving us a 16-channel output. Output is activated with ReLU.
size = (2,2)
in = 1
out = 16
CrossCor((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct CrossCor{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end

function CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return CrossCor(σ, w, b, stride, pad, dilation)
end

CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)

@treelike CrossCor

function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)
return conv(x, w, ddims)
end

function (c::CrossCor)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(crosscor(x, c.weight, cdims) .+ b)
end

function Base.show(io::IO, l::CrossCor)
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end

(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)

(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
MaxPool(k)
Expand Down
4 changes: 4 additions & 0 deletions test/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ c = gpu(Conv((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))

c = gpu(CrossCor((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))

end

@testset "onecold gpu" begin
Expand Down
22 changes: 22 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,27 @@ end
@test size(x_hat) == size(x)
end

@testset "CrossCor" begin
x = rand(Float32, 28, 28, 1, 1)
w = rand(2,2,1,1)
y = CrossCor(w, [0.0])

@test sum(w .* x[1:2, 1:2, :, :]) == y(x)[1, 1, 1, 1]

r = zeros(Float32, 28, 28, 1, 5)
m = Chain(
CrossCor((2, 2), 1=>16, relu),
MaxPool((2,2)),
CrossCor((2, 2), 16=>8, relu),
MaxPool((2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax)

@test size(m(r)) == (10, 5)
@test y(x) != Conv(w, [0.0])(x)
@test CrossCor(w[end:-1:1, end:-1:1, :, :], [0.0])(x) == Conv(w, [0.0])(x)
end

@testset "Conv with non quadratic window #700" begin
data = zeros(Float32, 7,7,1,1)
data[4,4,1,1] = 1
Expand All @@ -81,3 +102,4 @@ end
true
end
end

0 comments on commit 8ee6af1

Please sign in to comment.