Skip to content

Commit

Permalink
Revert "remove flatten"
Browse files Browse the repository at this point in the history
This reverts commit d91c203.
  • Loading branch information
CarloLucibello committed Feb 13, 2023
1 parent d91c203 commit 6278916
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion perf/vgg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function vgg16()
Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
BatchNorm(512),
MaxPool((2,2)),
Flux.flatten,
flatten,
Dense(512, 4096, relu),
Dropout(0.5),
Dense(4096, 4096, relu),
Expand Down
28 changes: 28 additions & 0 deletions src/layers/stateless.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
"""
flatten(x::AbstractArray)
Reshape arbitrarly-shaped input into a matrix-shaped output,
preserving the size of the last dimension.
See also [`unsqueeze`](@ref).
# Examples
```jldoctest
julia> rand(3,4,5) |> Flux.flatten |> size
(12, 5)
julia> xs = rand(Float32, 10,10,3,7);
julia> m = Chain(Conv((3,3), 3 => 4, pad=1), Flux.flatten, Dense(400 => 33));
julia> xs |> m[1] |> size
(10, 10, 4, 7)
julia> xs |> m |> size
(33, 7)
```
"""
function flatten(x::AbstractArray)
return reshape(x, :, size(x)[end])
end

"""
normalise(x; dims=ndims(x), ϵ=1e-5)
Expand Down

0 comments on commit 6278916

Please sign in to comment.