diff --git a/perf/vgg.jl b/perf/vgg.jl index 708c152c90..33b7bfd61d 100644 --- a/perf/vgg.jl +++ b/perf/vgg.jl @@ -38,7 +38,7 @@ function vgg16() Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)), BatchNorm(512), MaxPool((2,2)), - flatten, + Flux.flatten, Dense(512, 4096, relu), Dropout(0.5), Dense(4096, 4096, relu), diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 2674a6a445..635fa9c414 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -1,30 +1,3 @@ -""" - flatten(x::AbstractArray) - -Reshape arbitrarly-shaped input into a matrix-shaped output, -preserving the size of the last dimension. - -See also [`unsqueeze`](@ref). - -# Examples -```jldoctest -julia> rand(3,4,5) |> Flux.flatten |> size -(12, 5) - -julia> xs = rand(Float32, 10,10,3,7); - -julia> m = Chain(Conv((3,3), 3 => 4, pad=1), Flux.flatten, Dense(400 => 33)); - -julia> xs |> m[1] |> size -(10, 10, 4, 7) - -julia> xs |> m |> size -(33, 7) -``` -""" -function flatten(x::AbstractArray) - return reshape(x, :, size(x)[end]) -end """ normalise(x; dims=ndims(x), ϵ=1e-5)