Skip to content

Commit

Permalink
conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush1999 committed Mar 26, 2019
1 parent 38b307b commit 161fa86
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 524 deletions.
69 changes: 12 additions & 57 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
<<<<<<< HEAD
using NNlib: conv, ∇conv_data, depthwiseconv, crossconv
=======
using NNlib: conv, depthwiseconv, crosscor
>>>>>>> some final changes
using NNlib: conv, ∇conv_data, depthwiseconv, crosscor

@generated sub2(::Val{N}) where N = :(Val($(N-2)))

Expand Down Expand Up @@ -73,8 +69,6 @@ end
"""
ConvTranspose(size, in=>out)
ConvTranspose(size, in=>out, relu)
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Expand All @@ -83,7 +77,6 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct ConvTranspose{N,F,A,V}
struct CrossCor{N,F,A,V}
σ::F
weight::A
bias::V
Expand Down Expand Up @@ -173,8 +166,8 @@ function Base.show(io::IO, l::DepthwiseConv)
end

"""
CrossConv(size, in=>out)
CrossConv(size, in=>out, relu)
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Expand All @@ -197,8 +190,8 @@ CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
CrossCor(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride = 1, pad = 0, dilation = 1) where N =
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)

Expand All @@ -218,6 +211,12 @@ function Base.show(io::IO, l::CrossCor)
print(io, ")")
end

(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)

(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
MaxPool(k)
Expand Down Expand Up @@ -260,48 +259,4 @@ MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =

function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end
<<<<<<< HEAD
=======

"""
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct CrossCor{N,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
dilation::NTuple{N,Int}
end
CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
CrossCor(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)

@treelike CrossCor

function (c::CrossCor)(x)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(crosscor(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
end
function Base.show(io::IO, l::CrossCor)
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
>>>>>>> some final changes
end
Loading

0 comments on commit 161fa86

Please sign in to comment.