From 128226dbe8b2639339ca9f79f0bd77c6b578855e Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 2 Feb 2022 14:42:38 +0200 Subject: [PATCH 1/7] Fix type-stability for normalization layers --- Project.toml | 4 +- src/layers/conv.jl | 14 ++-- src/layers/normalise.jl | 126 ++++++++++++++++++++++------------- test/layers/normalisation.jl | 35 ++++++++-- test/runtests.jl | 90 ++++++++++++------------- 5 files changed, 166 insertions(+), 103 deletions(-) diff --git a/Project.toml b/Project.toml index d06b1bb5e3..77da52a829 100644 --- a/Project.toml +++ b/Project.toml @@ -38,8 +38,8 @@ Colors = "0.12" Functors = "0.2.1" Juno = "0.8" MacroTools = "0.5" -NNlib = "0.7.24" -NNlibCUDA = "0.1.7" +NNlib = "0.8.0" +NNlibCUDA = "0.2.0" Reexport = "0.2, 1.0" StatsBase = "0.33" ZipFile = "0.9" diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 0b6372364d..be0be3e34c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -162,7 +162,9 @@ end function (c::Conv)(x::AbstractArray) σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1) - cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) + cdims = DenseConvDims( + x, c.weight; stride=c.stride, padding=c.pad, + dilation=c.dilation, groups=c.groups) σ.(conv(x, c.weight, cdims) .+ b) end @@ -656,19 +658,23 @@ julia> lay(rand(Float32, 100, 7, 50)) |> size (34, 7, 50) ``` """ -struct MaxPool{N,M} +struct MaxPool{N, M} k::NTuple{N,Int} pad::NTuple{M,Int} stride::NTuple{N,Int} end -function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N +function MaxPool(k::NTuple{N, Integer}; pad = 0, stride = k) where N stride = expand(Val(N), stride) - pad = calc_padding(MaxPool ,pad, k, 1, stride) + pad = calc_padding(MaxPool, pad, k, 1, stride) return MaxPool(k, pad, stride) end function (m::MaxPool)(x) + # size_x = size(x) + # kernel, stride, padding, dilation = NNlib.prepare_pooldims( + # Val(N), size_x, m.k; padding=m.pad, stride=m.stride) + # pdims = PoolDims{kernel, stride, padding, dilation}(size_x) pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) return maxpool(x, pdims) end diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 146b7dba56..0d62969020 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -122,13 +122,13 @@ testmode!(m::AlphaDropout, mode=true) = LayerNorm(sz, λ=identity; affine=true, ϵ=1fe-5) A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be -used with recurrent hidden states. -The argument `sz` should be an integer or a tuple of integers. -In the forward pass, the layer normalises the mean and standard +used with recurrent hidden states. +The argument `sz` should be an integer or a tuple of integers. +In the forward pass, the layer normalises the mean and standard deviation of the input, the applied the elementwise activation `λ`. The input is normalised along the first `length(sz)` dimensions for tuple `sz`, along the first dimension for integer `sz`. -The input is expected to have first dimensions' size equal to `sz`. +The input is expected to have first dimensions' size equal to `sz`. If `affine=true` also applies a learnable shift and rescaling as in the [`Diagonal`](@ref) layer. 
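# --- Editor's illustration (not part of the patch) ----------------------------
# The hunk above is docstring whitespace cleanup; the substantive type-stability
# changes begin in the next hunk. A quick, assumed-setup way to check whether a
# normalisation layer's forward pass infers, in the spirit of this PR:
#
#     using Flux, Test
#     ln = LayerNorm(8)
#     @inferred ln(rand(Float32, 8, 4))   # throws if the output type cannot be inferred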
@@ -164,38 +164,70 @@ function Base.show(io::IO, l::LayerNorm) print(io, ")") end +_maybe_promote_type(::Type{T1}, ::Type{T2}) where {T1, T2} = promote_type(T1, T2) +_maybe_promote_type(::Type{Nothing}, ::Type{T2}) where T2 = T2 +_maybe_promote_type(::Type{T1}, ::Type{Nothing}) where T1 = T1 + +_maybe_eltype(::Type{T}) where T <: AbstractArray = eltype(T) +_maybe_eltype(::Type{Nothing}) = Nothing + +abstract type Normalization{F, V, N, W} end + +function _promote_to_output( + ::Normalization{F, V, N, W}, x::AbstractArray{T}, +) where {F, V, N, W, T} + Vel = _maybe_eltype(V) + Wel = _maybe_eltype(W) + _maybe_promote_type(_maybe_promote_type( + _maybe_promote_type(T, Vel), N), Wel) +end + # For InstanceNorm, GroupNorm, and BatchNorm. # Compute the statistics on the slices specified by reduce_dims. # reduce_dims=[1,...,N-2,N] for BatchNorm # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm -function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N} +_norm_layer_forward(l, x; reduce_dims, affine_shape) = + _norm_layer_forward(l, x, _promote_to_output(l, x); reduce_dims, affine_shape) + +function _norm_layer_forward( + l, x::Array{T, N}, ::Type{O}; reduce_dims, affine_shape, +) where {T, N, O} if !_isactive(l) && l.track_stats # testmode with tracked stats stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) μ = reshape(l.μ, stats_shape) σ² = reshape(l.σ², stats_shape) - else # trainmode or testmode without tracked stats + else # trainmode or testmode without tracked stats μ = mean(x; dims=reduce_dims) σ² = mean((x .- μ).^2; dims=reduce_dims) if l.track_stats - ## update moving mean/std - Zygote.ignore() do - mtm = l.momentum - m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var - μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N)) - σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N)) - l.μ = (1-mtm) .* l.μ .+ mtm .* μnew - l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new - end + _track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std end end - if hasaffine(l) - γ = reshape(l.γ, affine_shape) - β = reshape(l.β, affine_shape) - return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β) - else - return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ)) - end + + o::Array{O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ)) + hasaffine(l) || return l.λ.(o) + + γ = reshape(l.γ, affine_shape) + β = reshape(l.β, affine_shape) + return l.λ.(γ .* o .+ β) +end + +function _track_stats!( + bn, x::AbstractArray{T, N}, μ, σ², reduce_dims, +) where {T, N} + V = eltype(bn.σ²) + mtm = bn.momentum + res_mtm = one(V) - mtm + m = prod(size(x, i) for i in reduce_dims) + + μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N)) + σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N)) + + bn.μ = res_mtm .* bn.μ .+ mtm .* μnew + bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new + nothing end +Zygote.@nograd _track_stats! """ BatchNorm(channels::Integer, λ=identity; @@ -210,15 +242,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For a batch of feature vectors this is just the data dimension, for `WHCN` images it's the usual channel dimension. -`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N` +`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N` input slice and normalises the input accordingly. 
-If `affine=true`, it also applies a shift and a rescale to the input +If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias β and scale γ parameters. -After normalisation, elementwise activation `λ` is applied. +After normalisation, elementwise activation `λ` is applied. -If `track_stats=true`, accumulates mean and var statistics in training phase +If `track_stats=true`, accumulates mean and var statistics in training phase that will be used to renormalize the input in test phase. Use [`testmode!`](@ref) during inference. @@ -233,7 +265,7 @@ m = Chain( softmax) ``` """ -mutable struct BatchNorm{F,V,N,W} +mutable struct BatchNorm{F,V,N,W} <: Normalization{F, V, N, W} λ::F # activation function β::V # bias γ::V # scale @@ -248,7 +280,7 @@ mutable struct BatchNorm{F,V,N,W} end function BatchNorm(chs::Int, λ=identity; - initβ=zeros32, initγ=ones32, + initβ=zeros32, initγ=ones32, affine=true, track_stats=true, ϵ=1f-5, momentum=0.1f0) @@ -258,8 +290,8 @@ function BatchNorm(chs::Int, λ=identity; σ² = track_stats ? ones32(chs) : nothing return BatchNorm(λ, β, γ, - μ, σ², ϵ, momentum, - affine, track_stats, + μ, σ², ϵ, momentum, + affine, track_stats, nothing, chs) end @@ -294,22 +326,22 @@ end [Instance Normalization](https://arxiv.org/abs/1607.08022) layer. `channels` should be the size of the channel dimension in your data (see below). -Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. +Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. For `WHCN` images it's the usual channel dimension. -`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1` +`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1` input slice and normalises the input accordingly. -If `affine=true`, it also applies a shift and a rescale to the input +If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias `β` and scale `γ` parameters. -If `track_stats=true`, accumulates mean and var statistics in training phase +If `track_stats=true`, accumulates mean and var statistics in training phase that will be used to renormalize the input in test phase. -**Warning**: the defaults for `affine` and `track_stats` used to be `true` +**Warning**: the defaults for `affine` and `track_stats` used to be `true` in previous Flux versions (< v0.12). """ -mutable struct InstanceNorm{F,V,N,W} +mutable struct InstanceNorm{F,V,N,W} <: Normalization{F, V, N, W} λ::F # activation function β::V # bias γ::V # scale @@ -334,7 +366,7 @@ function InstanceNorm(chs::Int, λ=identity; σ² = track_stats ? ones32(chs) : nothing return InstanceNorm(λ, β, γ, - μ, σ², ϵ, momentum, + μ, σ², ϵ, momentum, affine, track_stats, nothing, chs) end @@ -377,16 +409,16 @@ The number of channels must be an integer multiple of the number of groups. `channels` should be the size of the channel dimension in your data (see below). -Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. +Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. For `WHCN` images it's the usual channel dimension. -If `affine=true`, it also applies a shift and a rescale to the input +If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias `β` and scale `γ` parameters. 
-If `track_stats=true`, accumulates mean and var statistics in training phase +If `track_stats=true`, accumulates mean and var statistics in training phase that will be used to renormalize the input in test phase. """ -mutable struct GroupNorm{F,V,N,W} +mutable struct GroupNorm{F,V,N,W} <: Normalization{F, V, N, W} G::Int # number of groups λ::F # activation function β::V # bias @@ -405,7 +437,7 @@ end trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : () function GroupNorm(chs::Int, G::Int, λ=identity; - initβ=zeros32, initγ=ones32, + initβ=zeros32, initγ=ones32, affine=true, track_stats=false, ϵ=1f-5, momentum=0.1f0) @@ -416,11 +448,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity; μ = track_stats ? zeros32(G) : nothing σ² = track_stats ? ones32(G) : nothing - return GroupNorm(G, λ, + return GroupNorm(G, λ, β, γ, - μ, σ², - ϵ, momentum, - affine, track_stats, + μ, σ², + ϵ, momentum, + affine, track_stats, nothing, chs) end @@ -451,7 +483,7 @@ end """ hasaffine(l) -Return `true` if a normalisation layer has trainable shift and +Return `true` if a normalisation layer has trainable shift and scale parameters, `false` otherwise. See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref). diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 4548decd91..5f11d178ba 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -78,7 +78,7 @@ end y = evalwgrad(m, x) @test mean(y) ≈ 0 atol=0.1 @test var(y) ≈ 1 atol=0.1 - + # Known good value ranges # Values taken from https://github.com/pytorch/pytorch/blob/v1.10.0/test/cpp/api/modules.cpp#L1337-L1338 x = ones(100) @@ -118,9 +118,15 @@ end # 1.3 # 1.3 @test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] - + x′ = m(x) @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) + + @inferred m(x) + end + + let m = BatchNorm(2; track_stats=false), x = [1.0 3.0 5.0; 2.0 4.0 6.0] + @inferred m(x) end # with activation function @@ -128,29 +134,34 @@ end 2.0 4.0 6.0] y = m(x) @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) + @inferred m(x) end let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) y = reshape(permutedims(x, [2, 1, 3]), 2, :) y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) @test m(x) == y + @inferred m(x) end let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) @test m(x) == y + @inferred m(x) end let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) @test m(x) == y + @inferred m(x) end let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1); m(x) @test (@allocated m(x)) < 100_000_000 + @inferred m(x) end @test length(Flux.params(BatchNorm(10))) == 2 @@ -201,6 +212,8 @@ end @test length(m.μ) == 2 @test length(m.σ²) == 2 @test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5 + + @inferred m(x) end # with activation function @@ -211,10 +224,12 @@ end affine_shape[[1,3]] .= 1 y = evalwgrad(m, x) - y = m(x) # inference time after a training step + y = m(x) # inference time after a training step μ = reshape(m.μ, affine_shape...) σ² = reshape(m.σ², affine_shape...) 
@test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + + @inferred m(x) end # with activation function @@ -222,24 +237,28 @@ end x = reshape(collect(1:prod(sizes)), sizes) @test Flux.hasaffine(m) == true - @test length(params(m)) == 2 + @test length(params(m)) == 2 x = Float64.(x) y = m(x) μ = mean(x, dims=1) σ² = var(x, dims=1, corrected=false) @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + + @inferred m(x) end let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), x = reshape(collect(1:prod(sizes)), sizes) @test Flux.hasaffine(m) == false @test length(params(m)) == 0 - + x = Float64.(x) y = m(x) μ = mean(x, dims=1) σ² = var(x, dims=1, corrected=false) @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + + @inferred m(x) end @@ -248,6 +267,8 @@ end y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) @test m(x) == y + + @inferred m(x) end # check that μ, σ², and the output are the correct size for higher rank tensors @@ -257,6 +278,8 @@ end @test size(m.μ) == (sizes[end - 1], ) @test size(m.σ²) == (sizes[end - 1], ) @test size(y) == sizes + + @inferred m(x) end # show that instance norm is equal to batch norm when channel and batch dims are squashed @@ -268,6 +291,8 @@ end let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1); m(x) @test (@allocated m(x)) < 100_000_000 + + @inferred m(x) end @test length(Flux.params(InstanceNorm(10))) == 0 diff --git a/test/runtests.jl b/test/runtests.jl index 781edb549d..1a5a12aaa0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,55 +8,55 @@ using CUDA Random.seed!(0) -@testset "Utils" begin - include("utils.jl") -end +# @testset "Utils" begin +# include("utils.jl") +# end -@testset "Onehot" begin - include("onehot.jl") -end +# @testset "Onehot" begin +# include("onehot.jl") +# end -@testset "Optimise" begin - include("optimise.jl") -end +# @testset "Optimise" begin +# include("optimise.jl") +# end -@testset "Data" begin - include("data.jl") -end +# @testset "Data" begin +# include("data.jl") +# end -@testset "Losses" begin - include("losses.jl") - include("ctc.jl") - CUDA.functional() && include("ctc-gpu.jl") -end +# @testset "Losses" begin +# include("losses.jl") +# include("ctc.jl") +# CUDA.functional() && include("ctc-gpu.jl") +# end @testset "Layers" begin - include("layers/basic.jl") + # include("layers/basic.jl") include("layers/normalisation.jl") - include("layers/stateless.jl") - include("layers/recurrent.jl") - include("layers/conv.jl") - include("layers/upsample.jl") - include("layers/show.jl") -end - -@testset "outputsize" begin - using Flux: outputsize - include("outputsize.jl") -end - -@testset "CUDA" begin - if CUDA.functional() - include("cuda/runtests.jl") - else - @warn "CUDA unavailable, not testing GPU support" - end -end - -@static if VERSION == v"1.6" - using Documenter - @testset "Docs" begin - DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) - doctest(Flux) - end -end + # include("layers/stateless.jl") + # include("layers/recurrent.jl") + # include("layers/conv.jl") + # include("layers/upsample.jl") + # include("layers/show.jl") +end + +# @testset "outputsize" begin +# using Flux: outputsize +# include("outputsize.jl") +# end + +# @testset "CUDA" begin +# if CUDA.functional() +# include("cuda/runtests.jl") +# else +# @warn "CUDA unavailable, not testing GPU support" +# end +# end + +# @static if VERSION == v"1.6" +# using Documenter +# @testset "Docs" begin +# DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) 
+# doctest(Flux) +# end +# end From f0f330e684a299611fc85f57dbf4c3d18035b55b Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 2 Feb 2022 14:44:53 +0200 Subject: [PATCH 2/7] Add btime --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 77da52a829..59c9139a1c 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.12.8" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" From 76e9a3e24dd369f479d5b3d7d10c7fa74a7a0390 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 2 Feb 2022 14:45:53 +0200 Subject: [PATCH 3/7] cleanup --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 59c9139a1c..77da52a829 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.12.8" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" From 4a3c80b0c4c579b018d26ca80e1baeadf058d28b Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 2 Feb 2022 23:57:17 +0200 Subject: [PATCH 4/7] Refactor --- src/layers/normalise.jl | 19 ++++++--- test/runtests.jl | 90 ++++++++++++++++++++--------------------- 2 files changed, 58 insertions(+), 51 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 1231ddaa87..931ed07b9c 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -206,16 +206,22 @@ function _promote_to_output( _maybe_promote_type(T, Vel), N), Wel) end +function _basetype(::Type{T}) where T + if T <: Array + return Array + elseif T <: CuArray + return CuArray + end + throw("Unsupported type $T") +end + # For InstanceNorm, GroupNorm, and BatchNorm. # Compute the statistics on the slices specified by reduce_dims. # reduce_dims=[1,...,N-2,N] for BatchNorm # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm -_norm_layer_forward(l, x; reduce_dims, affine_shape) = - _norm_layer_forward(l, x, _promote_to_output(l, x); reduce_dims, affine_shape) - function _norm_layer_forward( - l, x::Array{T, N}, ::Type{O}; reduce_dims, affine_shape, -) where {T, N, O} + l, x::AbstractArray{T, N}; reduce_dims, affine_shape, +) where {T, N} if !_isactive(l) && l.track_stats # testmode with tracked stats stats_shape = ntuple(i -> i == N-1 ? 
size(x, N-1) : 1, N) μ = reshape(l.μ, stats_shape) @@ -228,7 +234,8 @@ function _norm_layer_forward( end end - o::Array{O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ)) + O = _promote_to_output(l, x) + o::_basetype(typeof(x)){O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ)) hasaffine(l) || return l.λ.(o) γ = reshape(l.γ, affine_shape) diff --git a/test/runtests.jl b/test/runtests.jl index 1a5a12aaa0..781edb549d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,55 +8,55 @@ using CUDA Random.seed!(0) -# @testset "Utils" begin -# include("utils.jl") -# end +@testset "Utils" begin + include("utils.jl") +end -# @testset "Onehot" begin -# include("onehot.jl") -# end +@testset "Onehot" begin + include("onehot.jl") +end -# @testset "Optimise" begin -# include("optimise.jl") -# end +@testset "Optimise" begin + include("optimise.jl") +end -# @testset "Data" begin -# include("data.jl") -# end +@testset "Data" begin + include("data.jl") +end -# @testset "Losses" begin -# include("losses.jl") -# include("ctc.jl") -# CUDA.functional() && include("ctc-gpu.jl") -# end +@testset "Losses" begin + include("losses.jl") + include("ctc.jl") + CUDA.functional() && include("ctc-gpu.jl") +end @testset "Layers" begin - # include("layers/basic.jl") + include("layers/basic.jl") include("layers/normalisation.jl") - # include("layers/stateless.jl") - # include("layers/recurrent.jl") - # include("layers/conv.jl") - # include("layers/upsample.jl") - # include("layers/show.jl") -end - -# @testset "outputsize" begin -# using Flux: outputsize -# include("outputsize.jl") -# end - -# @testset "CUDA" begin -# if CUDA.functional() -# include("cuda/runtests.jl") -# else -# @warn "CUDA unavailable, not testing GPU support" -# end -# end - -# @static if VERSION == v"1.6" -# using Documenter -# @testset "Docs" begin -# DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) -# doctest(Flux) -# end -# end + include("layers/stateless.jl") + include("layers/recurrent.jl") + include("layers/conv.jl") + include("layers/upsample.jl") + include("layers/show.jl") +end + +@testset "outputsize" begin + using Flux: outputsize + include("outputsize.jl") +end + +@testset "CUDA" begin + if CUDA.functional() + include("cuda/runtests.jl") + else + @warn "CUDA unavailable, not testing GPU support" + end +end + +@static if VERSION == v"1.6" + using Documenter + @testset "Docs" begin + DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) + doctest(Flux) + end +end From c85ec1e4fda241adfbacdf0197396ce2aa63c868 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 3 Feb 2022 01:11:42 +0200 Subject: [PATCH 5/7] Reduce number of changes --- src/layers/conv.jl | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 316571e71a..0942a1db3c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -162,9 +162,7 @@ end function (c::Conv)(x::AbstractArray) σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1) - cdims = DenseConvDims( - x, c.weight; stride=c.stride, padding=c.pad, - dilation=c.dilation, groups=c.groups) + cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) σ.(conv(x, c.weight, cdims) .+ b) end @@ -658,23 +656,19 @@ julia> lay(rand(Float32, 100, 7, 50)) |> size (34, 7, 50) ``` """ -struct MaxPool{N, M} +struct MaxPool{N,M} k::NTuple{N,Int} pad::NTuple{M,Int} stride::NTuple{N,Int} end -function MaxPool(k::NTuple{N, Integer}; pad = 0, stride = k) where N +function 
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N stride = expand(Val(N), stride) pad = calc_padding(MaxPool, pad, k, 1, stride) return MaxPool(k, pad, stride) end function (m::MaxPool)(x) - # size_x = size(x) - # kernel, stride, padding, dilation = NNlib.prepare_pooldims( - # Val(N), size_x, m.k; padding=m.pad, stride=m.stride) - # pdims = PoolDims{kernel, stride, padding, dilation}(size_x) pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) return maxpool(x, pdims) end From ea507878b11b1b84e8a4bda0334b51f409cf8781 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 3 Feb 2022 16:29:31 +0200 Subject: [PATCH 6/7] Use inference barrier --- src/layers/normalise.jl | 33 +++++++-------------------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 931ed07b9c..266909037e 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -195,26 +195,6 @@ _maybe_promote_type(::Type{T1}, ::Type{Nothing}) where T1 = T1 _maybe_eltype(::Type{T}) where T <: AbstractArray = eltype(T) _maybe_eltype(::Type{Nothing}) = Nothing -abstract type Normalization{F, V, N, W} end - -function _promote_to_output( - ::Normalization{F, V, N, W}, x::AbstractArray{T}, -) where {F, V, N, W, T} - Vel = _maybe_eltype(V) - Wel = _maybe_eltype(W) - _maybe_promote_type(_maybe_promote_type( - _maybe_promote_type(T, Vel), N), Wel) -end - -function _basetype(::Type{T}) where T - if T <: Array - return Array - elseif T <: CuArray - return CuArray - end - throw("Unsupported type $T") -end - # For InstanceNorm, GroupNorm, and BatchNorm. # Compute the statistics on the slices specified by reduce_dims. # reduce_dims=[1,...,N-2,N] for BatchNorm @@ -234,8 +214,7 @@ function _norm_layer_forward( end end - O = _promote_to_output(l, x) - o::_basetype(typeof(x)){O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ)) + o = _norm_layer_forward(x, μ, σ², l.ϵ) hasaffine(l) || return l.λ.(o) γ = reshape(l.γ, affine_shape) @@ -243,6 +222,8 @@ function _norm_layer_forward( return l.λ.(γ .* o .+ β) end +@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ) + function _track_stats!( bn, x::AbstractArray{T, N}, μ, σ², reduce_dims, ) where {T, N} @@ -256,7 +237,7 @@ function _track_stats!( bn.μ = res_mtm .* bn.μ .+ mtm .* μnew bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new - nothing + return nothing end Zygote.@nograd _track_stats! @@ -296,7 +277,7 @@ m = Chain( softmax) ``` """ -mutable struct BatchNorm{F,V,N,W} <: Normalization{F, V, N, W} +mutable struct BatchNorm{F,V,N,W} λ::F # activation function β::V # bias γ::V # scale @@ -372,7 +353,7 @@ that will be used to renormalize the input in test phase. **Warning**: the defaults for `affine` and `track_stats` used to be `true` in previous Flux versions (< v0.12). """ -mutable struct InstanceNorm{F,V,N,W} <: Normalization{F, V, N, W} +mutable struct InstanceNorm{F,V,N,W} λ::F # activation function β::V # bias γ::V # scale @@ -449,7 +430,7 @@ through to learnable per-channel bias `β` and scale `γ` parameters. If `track_stats=true`, accumulates mean and var statistics in training phase that will be used to renormalize the input in test phase. 
""" -mutable struct GroupNorm{F,V,N,W} <: Normalization{F, V, N, W} +mutable struct GroupNorm{F,V,N,W} G::Int # number of groups λ::F # activation function β::V # bias From d1510801499a6c2b881d45fcff4c85251665744a Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 3 Feb 2022 16:51:26 +0200 Subject: [PATCH 7/7] Remove more stuff --- src/layers/normalise.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 266909037e..d7fabf7dc8 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -188,13 +188,6 @@ function Base.show(io::IO, l::LayerNorm) print(io, ")") end -_maybe_promote_type(::Type{T1}, ::Type{T2}) where {T1, T2} = promote_type(T1, T2) -_maybe_promote_type(::Type{Nothing}, ::Type{T2}) where T2 = T2 -_maybe_promote_type(::Type{T1}, ::Type{Nothing}) where T1 = T1 - -_maybe_eltype(::Type{T}) where T <: AbstractArray = eltype(T) -_maybe_eltype(::Type{Nothing}) = Nothing - # For InstanceNorm, GroupNorm, and BatchNorm. # Compute the statistics on the slices specified by reduce_dims. # reduce_dims=[1,...,N-2,N] for BatchNorm