diff --git a/Manifest.toml b/Manifest.toml
index 46e56ad3ab..84df71b24b 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -53,9 +53,9 @@ version = "0.2.0"
 
 [[Compat]]
 deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea"
+git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "2.0.0"
+version = "2.1.0"
 
 [[DataStructures]]
 deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
@@ -84,7 +84,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
 version = "0.0.10"
 
 [[Distributed]]
-deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
+deps = ["Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[FixedPointNumbers]]
@@ -100,7 +100,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210"
 version = "0.10.3"
 
 [[InteractiveUtils]]
-deps = ["LinearAlgebra", "Markdown"]
+deps = ["Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 
 [[Juno]]
@@ -149,7 +149,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
 [[NNlib]]
 deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
-git-tree-sha1 = "d07ac0bfd3c71c3a29bc9c22becbba19227bbeb5"
+git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 version = "0.5.0"
 
@@ -265,7 +265,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
 version = "0.4.0"
 
 [[UUIDs]]
-deps = ["Random"]
+deps = ["Random", "SHA"]
 uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 
 [[Unicode]]
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 216957a70e..30cd4f7c5a 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -50,7 +50,8 @@ function (c::Conv)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
+  cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
+  σ.(conv(x, c.weight, cdims) .+ b)
 end
 
 function Base.show(io::IO, l::Conv)
@@ -99,7 +100,17 @@ ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
 function (c::ConvTranspose)(x::AbstractArray)
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
+  # Calculate size of "input", from ∇conv_data()'s perspective...
+  I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- 2 .* c.pad
+  C_in = size(c.weight)[end-1]
+  batch_size = size(x)[end]
+  # Create DenseConvDims() that looks like the corresponding conv()
+  cdims = DenseConvDims((I..., C_in, batch_size), size(c.weight);
+                        stride=c.stride,
+                        padding=c.pad,
+                        dilation=c.dilation,
+  )
+  return σ.(∇conv_data(x, c.weight, cdims) .+ b)
 end
 
 function Base.show(io::IO, l::ConvTranspose)
@@ -134,20 +145,22 @@ struct DepthwiseConv{N,F,A,V}
   bias::V
   stride::NTuple{N,Int}
   pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
 end
 
 DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
-              stride = 1, pad = 0) where {T,N} =
-  DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
+              stride = 1, pad = 0, dilation = 1) where {T,N} =
+  DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
 
 DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
-              stride = 1, pad = 0) where N =
+              stride = 1, pad = 0, dilation = 1) where N =
   DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
-                stride = stride, pad = pad)
+                stride = stride, pad = pad, dilation=dilation)
 
 DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
               stride::NTuple{N,Integer} = map(_->1,k),
-              pad::NTuple{N,Integer} = map(_->0,k)) where N =
+              pad::NTuple{N,Integer} = map(_->0,k),
+              dilation::NTuple{N,Integer} = map(_->1,k)) where N =
   DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
-                stride = stride, pad = pad)
+                stride = stride, pad = pad, dilation = dilation)
 
@@ -155,7 +168,8 @@ DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
 
 function (c::DepthwiseConv)(x)
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(depthwiseconv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
+  cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
+  σ.(depthwiseconv(x, c.weight, cdims) .+ b)
 end
 
 function Base.show(io::IO, l::DepthwiseConv)
@@ -181,7 +195,10 @@ end
 MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
   MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
 
-(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
+function (m::MaxPool)(x)
+  pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
+  return maxpool(x, pdims)
+end
 
 function Base.show(io::IO, m::MaxPool)
   print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
@@ -203,7 +220,10 @@ end
 MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
   MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
 
-(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
+function (m::MeanPool)(x)
+  pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
+  return meanpool(x, pdims)
+end
 
 function Base.show(io::IO, m::MeanPool)
   print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index 0bec44c15e..a6b35ebb94 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -4,9 +4,9 @@ using Flux: maxpool, meanpool
 @testset "Pooling" begin
   x = randn(Float32, 10, 10, 3, 2)
   mp = MaxPool((2, 2))
-  @test mp(x) == maxpool(x, (2,2))
+  @test mp(x) == maxpool(x, PoolDims(x, 2))
   mp = MeanPool((2, 2))
-  @test mp(x) == meanpool(x, (2,2))
+  @test mp(x) == meanpool(x, PoolDims(x, 2))
 end
 
 @testset "CNN" begin