Skip to content

Commit

Permalink
remove Flux.flatten in favor of MLUtils.flatten (#2188)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Feb 14, 2023
1 parent 6278916 commit 8a45b88
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 28 deletions.
2 changes: 1 addition & 1 deletion perf/vgg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function vgg16()
Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
BatchNorm(512),
MaxPool((2,2)),
flatten,
Flux.flatten,
Dense(512, 4096, relu),
Dropout(0.5),
Dense(4096, 4096, relu),
Expand Down
27 changes: 0 additions & 27 deletions src/layers/stateless.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,3 @@
"""
flatten(x::AbstractArray)
Reshape arbitrarly-shaped input into a matrix-shaped output,
preserving the size of the last dimension.
See also [`unsqueeze`](@ref).
# Examples
```jldoctest
julia> rand(3,4,5) |> Flux.flatten |> size
(12, 5)
julia> xs = rand(Float32, 10,10,3,7);
julia> m = Chain(Conv((3,3), 3 => 4, pad=1), Flux.flatten, Dense(400 => 33));
julia> xs |> m[1] |> size
(10, 10, 4, 7)
julia> xs |> m |> size
(33, 7)
```
"""
function flatten(x::AbstractArray)
return reshape(x, :, size(x)[end])
end

"""
normalise(x; dims=ndims(x), ϵ=1e-5)
Expand Down

0 comments on commit 8a45b88

Please sign in to comment.