diff --git a/src/Flux.jl b/src/Flux.jl
index efbba31e56..5ec715dae2 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -70,11 +70,11 @@ include("loading.jl")
 include("outputsize.jl")
 export @autosize
 
+include("deprecations.jl")
+
 include("losses/Losses.jl")
 using .Losses
 
-include("deprecations.jl")
-
 include("cuda/cuda.jl")
 
 end # module
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 627e5bd5b8..b796c498d7 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -203,6 +203,17 @@ function trainmode!(m, active::Bool)
   testmode!(m, !active)
 end
 
+# Greek-letter keywords deprecated in Flux 0.13
+# Arguments (old => new, :function, "β" => "beta")
+function _greek_ascii_depwarn(βbeta::Pair, func = :loss, names = "" => "")
+  Base.depwarn("""function $func no longer accepts greek-letter keyword $(names.first)
+    please use ascii $(names.second) instead""", func)
+  βbeta.first
+end
+_greek_ascii_depwarn(βbeta::Pair{Nothing}, _...) = βbeta.second
+
+ChainRulesCore.@non_differentiable _greek_ascii_depwarn(::Any...)
+
 # v0.14 deprecations
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 3794576e76..a37f1bd863 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -152,11 +152,12 @@ testmode!(m::AlphaDropout, mode=true) =
   (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)
 
 """
-    LayerNorm(size..., λ=identity; affine=true, ϵ=1fe-5)
+    LayerNorm(size..., λ=identity; affine=true, eps=1f-5)
 
 A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
 used with recurrent hidden states.
 The argument `size` should be an integer or a tuple of integers.
+
 In the forward pass, the layer normalises the mean and standard
 deviation of the input, then applies the elementwise activation `λ`.
 The input is normalised along the first `length(size)` dimensions
@@ -190,9 +191,10 @@ struct LayerNorm{F,D,T,N}
   affine::Bool
 end
 
-function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
+function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5, ϵ=nothing)
+  ε = _greek_ascii_depwarn(ϵ => eps, :LayerNorm, "ϵ" => "eps")
   diag = affine ? Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity
-  return LayerNorm(λ, diag, ϵ, size, affine)
+  return LayerNorm(λ, diag, ε, size, affine)
 end
 LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
 LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)
@@ -269,7 +271,7 @@ ChainRulesCore.@non_differentiable _track_stats!(::Any...)
     BatchNorm(channels::Integer, λ=identity;
               initβ=zeros32, initγ=ones32,
               affine=true, track_stats=true, active=nothing,
-              ϵ=1f-5, momentum= 0.1f0)
+              eps=1f-5, momentum= 0.1f0)
 
 [Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
 `channels` should be the size of the channel dimension in your data (see below).
@@ -321,8 +323,10 @@ end
 
 function BatchNorm(chs::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
-                   affine=true, track_stats=true, active::Union{Bool,Nothing}=nothing,
-                   ϵ=1f-5, momentum=0.1f0)
+                   affine::Bool=true, track_stats::Bool=true, active::Union{Bool,Nothing}=nothing,
+                   eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing)
+
+  ε = _greek_ascii_depwarn(ϵ => eps, :BatchNorm, "ϵ" => "eps")
 
   β = affine ? initβ(chs) : nothing
   γ = affine ? initγ(chs) : nothing
@@ -330,7 +334,7 @@ function BatchNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing
 
   return BatchNorm(λ, β, γ,
-                   μ, σ², ϵ, momentum,
+                   μ, σ², ε, momentum,
                    affine, track_stats, active, chs)
 end
 
@@ -361,7 +365,7 @@ end
     InstanceNorm(channels::Integer, λ=identity;
                  initβ=zeros32, initγ=ones32,
                  affine=false, track_stats=false,
-                 ϵ=1f-5, momentum=0.1f0)
+                 eps=1f-5, momentum=0.1f0)
 
 [Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
 `channels` should be the size of the channel dimension in your data (see below).
@@ -411,8 +415,10 @@ end
 
 function InstanceNorm(chs::Int, λ=identity;
                       initβ=zeros32, initγ=ones32,
-                      affine=false, track_stats=false, active::Union{Bool,Nothing}=nothing,
-                      ϵ=1f-5, momentum=0.1f0)
+                      affine::Bool=false, track_stats::Bool=false, active::Union{Bool,Nothing}=nothing,
+                      eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing)
+
+  ε = _greek_ascii_depwarn(ϵ => eps, :InstanceNorm, "ϵ" => "eps")
 
   β = affine ? initβ(chs) : nothing
   γ = affine ? initγ(chs) : nothing
@@ -420,7 +426,7 @@ function InstanceNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing
 
   return InstanceNorm(λ, β, γ,
-                      μ, σ², ϵ, momentum,
+                      μ, σ², ε, momentum,
                       affine, track_stats, active, chs)
 end
 
@@ -450,7 +456,7 @@ end
     GroupNorm(channels::Integer, G::Integer, λ=identity;
               initβ=zeros32, initγ=ones32,
              affine=true, track_stats=false,
-              ϵ=1f-5, momentum=0.1f0)
+              eps=1f-5, momentum=0.1f0)
 
 [Group Normalization](https://arxiv.org/abs/1803.08494) layer.
 
@@ -508,12 +514,13 @@ trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
 
 function GroupNorm(chs::Int, G::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
-                   affine=true, track_stats=false, active::Union{Bool,Nothing}=nothing,
-                   ϵ=1f-5, momentum=0.1f0)
+                   affine::Bool=true, track_stats::Bool=false, active::Union{Bool,Nothing}=nothing,
+                   eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing)
 
-if track_stats
+  if track_stats
     Base.depwarn("`track_stats=true` will be removed from GroupNorm in Flux 0.14. The default value is `track_stats=false`, which will work as before.", :GroupNorm)
-end
+  end
+  ε = _greek_ascii_depwarn(ϵ => eps, :GroupNorm, "ϵ" => "eps")
 
   chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)")
@@ -525,7 +532,7 @@ end
 
   return GroupNorm(G, λ,
                    β, γ,
                    μ, σ²,
-                   ϵ, momentum,
+                   ε, momentum,
                    affine, track_stats, active, chs)
 end
diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl
index 3d8f6f8149..7f9fcbe429 100644
--- a/src/losses/Losses.jl
+++ b/src/losses/Losses.jl
@@ -4,7 +4,7 @@ using Statistics
 using Zygote
 using Zygote: @adjoint
 using ChainRulesCore
-using ..Flux: ofeltype, epseltype
+using ..Flux: ofeltype, epseltype, _greek_ascii_depwarn
 using CUDA
 using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
 import Base.Broadcast: broadcasted
diff --git a/src/losses/functions.jl b/src/losses/functions.jl
index c40d4dcd76..45a2804db0 100644
--- a/src/losses/functions.jl
+++ b/src/losses/functions.jl
@@ -48,13 +48,13 @@ function mse(ŷ, y; agg = mean)
 end
 
 """
-    msle(ŷ, y; agg = mean, ϵ = eps(ŷ))
+    msle(ŷ, y; agg = mean, eps = eps(eltype(ŷ)))
 
 The loss corresponding to mean squared logarithmic errors, calculated as
 
     agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)) .^ 2)
 
-The `ϵ` term provides numerical stability.
+The `ϵ == eps` term provides numerical stability.
 Penalizes an under-estimation more than an over-estimatation.
 
 # Example
 
@@ -66,13 +66,14 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3)
 0.011100831f0
 ```
 """
-function msle(ŷ, y; agg = mean, ϵ = epseltype(ŷ))
+function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
+    ϵ = _greek_ascii_depwarn(ϵ => eps, :msle, "ϵ" => "eps")
     _check_sizes(ŷ, y)
     agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 )
 end
 
 """
-    huber_loss(ŷ, y; δ = 1, agg = mean)
+    huber_loss(ŷ, y; delta = 1, agg = mean)
 
 Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
 given the prediction `ŷ` and true values `y`.
@@ -82,17 +83,20 @@ given the prediction `ŷ` and true values `y`.
                  |  δ * (|ŷ - y| - 0.5 * δ), otherwise
 
 # Example
+
 ```jldoctest
 julia> ŷ = [1.1, 2.1, 3.1];
 
 julia> Flux.huber_loss(ŷ, 1:3)  # default δ = 1 > |ŷ - y|
 0.005000000000000009
 
-julia> Flux.huber_loss(ŷ, 1:3, δ=0.05)  # changes behaviour as |ŷ - y| > δ
+julia> Flux.huber_loss(ŷ, 1:3, delta=0.05)  # changes behaviour as |ŷ - y| > δ
 0.003750000000000005
 ```
 """
-function huber_loss(ŷ, y; agg = mean, δ = ofeltype(ŷ, 1))
+function huber_loss(ŷ, y; agg = mean, delta::Real = 1, δ = nothing)
+    delta_tmp = _greek_ascii_depwarn(δ => delta, :huber_loss, "δ" => "delta")
+    δ = ofeltype(ŷ, delta_tmp)
     _check_sizes(ŷ, y)
     abs_error = abs.(ŷ .- y)
     #TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
@@ -167,7 +171,7 @@ function label_smoothing(y::Union{AbstractArray,Number}, α::Number; dims::Int =
 end
 
 """
-    crossentropy(ŷ, y; dims = 1, ϵ = eps(ŷ), agg = mean)
+    crossentropy(ŷ, y; dims = 1, eps = eps(eltype(ŷ)), agg = mean)
 
 Return the cross entropy between the given probability distributions;
 calculated as
@@ -222,7 +226,8 @@ julia> Flux.crossentropy(y_model, y_smooth)
 1.5776052f0
 ```
 """
-function crossentropy(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
+function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
+    ϵ = _greek_ascii_depwarn(ϵ => eps, :crossentropy, "ϵ" => "eps")
     _check_sizes(ŷ, y)
     agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims))
 end
@@ -267,14 +272,14 @@ function logitcrossentropy(ŷ, y; dims = 1, agg = mean)
 end
 
 """
-    binarycrossentropy(ŷ, y; agg = mean, ϵ = eps(ŷ))
+    binarycrossentropy(ŷ, y; agg = mean, eps = eps(eltype(ŷ)))
 
 Return the binary cross-entropy loss, computed as
 
     agg(@.(-y * log(ŷ + ϵ) - (1 - y) * log(1 - ŷ + ϵ)))
 
 Where typically, the prediction `ŷ` is given by the output of a [sigmoid](@ref man-activations) activation.
-The `ϵ` term is included to avoid infinity. Using [`logitbinarycrossentropy`](@ref) is recomended
+The `ϵ == eps` term is included to avoid infinity. Using [`logitbinarycrossentropy`](@ref) is recommended
 over `binarycrossentropy` for numerical stability.
 
 Use [`label_smoothing`](@ref) to smooth the `y` value as preprocessing before
@@ -310,7 +315,8 @@ julia> Flux.crossentropy(y_prob, y_hot)
 0.43989f0
 ```
 """
-function binarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ))
+function binarycrossentropy(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
+    ϵ = _greek_ascii_depwarn(ϵ => eps, :binarycrossentropy, "ϵ" => "eps")
     _check_sizes(ŷ, y)
     agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)))
 end
@@ -346,7 +352,7 @@ function logitbinarycrossentropy(ŷ, y; agg = mean)
 end
 
 """
-    kldivergence(ŷ, y; agg = mean, ϵ = eps(ŷ))
+    kldivergence(ŷ, y; agg = mean, eps = eps(eltype(ŷ)))
 
 Return the
 [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
@@ -373,17 +379,18 @@ true
 julia> Flux.kldivergence(p2, p1; agg = sum) ≈ 2log(2)
 true
 
-julia> Flux.kldivergence(p2, p2; ϵ = 0)  # about -2e-16 with the regulator
+julia> Flux.kldivergence(p2, p2; eps = 0)  # about -2e-16 with the regulator
 0.0
 
-julia> Flux.kldivergence(p1, p2; ϵ = 0)  # about 17.3 with the regulator
+julia> Flux.kldivergence(p1, p2; eps = 0)  # about 17.3 with the regulator
 Inf
 ```
 """
-function kldivergence(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
+function kldivergence(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
+    ϵ = _greek_ascii_depwarn(ϵ => eps, :kldivergence, "ϵ" => "eps")
     _check_sizes(ŷ, y)
-    entropy = agg(sum(xlogx.(y), dims = dims))
-    cross_entropy = crossentropy(ŷ, y; dims = dims, agg = agg, ϵ = ϵ)
+    entropy = agg(sum(xlogx.(y); dims = dims))
+    cross_entropy = crossentropy(ŷ, y; dims, agg, eps=ϵ)
     return entropy + cross_entropy
 end
 
@@ -501,23 +508,27 @@ julia> 1 - Flux.dice_coeff_loss(y_pred, 1:3) # ~ F1 score for image segmentatio
 0.99900760833609
 ```
 """
-function dice_coeff_loss(ŷ, y; smooth = ofeltype(ŷ, 1.0))
+function dice_coeff_loss(ŷ, y; smooth = 1)
+    s = ofeltype(ŷ, smooth)
     _check_sizes(ŷ, y)
-    1 - (2 * sum(y .* ŷ) + smooth) / (sum(y .^ 2) + sum(ŷ .^ 2) + smooth) #TODO agg
+    # TODO add agg
+    1 - (2 * sum(y .* ŷ) + s) / (sum(y .^ 2) + sum(ŷ .^ 2) + s)
 end
 
 """
-    tversky_loss(ŷ, y; β = 0.7)
+    tversky_loss(ŷ, y; beta = 0.7)
 
 Return the [Tversky loss](https://arxiv.org/abs/1706.05721).
 Used with imbalanced data to give more weight to false negatives.
-Larger β weigh recall more than precision (by placing more emphasis on false negatives).
+Larger `β == beta` weighs recall more than precision (by placing more emphasis on false negatives).
 Calculated as:
 
     1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1)
 """
-function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
+function tversky_loss(ŷ, y; beta::Real = 0.7, β = nothing)
+    beta_temp = _greek_ascii_depwarn(β => beta, :tversky_loss, "β" => "beta")
+    β = ofeltype(ŷ, beta_temp)
     _check_sizes(ŷ, y)
     #TODO add agg
     num = sum(y .* ŷ) + 1
@@ -526,12 +537,12 @@ end
 
 """
-    binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=eps(ŷ))
+    binary_focal_loss(ŷ, y; agg=mean, gamma=2, eps=eps(eltype(ŷ)))
 
 Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
 The input, 'ŷ', is expected to be normalized (i.e. [softmax](@ref Softmax) output).
 
-For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).
+For `gamma = 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).
 
 See also: [`Losses.focal_loss`](@ref) for multi-class setting
 
@@ -553,25 +564,28 @@ julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
 true
 ```
 """
-function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
+function binary_focal_loss(ŷ, y; agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ = nothing, γ = nothing)
+    ϵ = _greek_ascii_depwarn(ϵ => eps, :binary_focal_loss, "ϵ" => "eps")
+    gamma_temp = _greek_ascii_depwarn(γ => gamma, :binary_focal_loss, "γ" => "gamma")
+    γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp)
    _check_sizes(ŷ, y)
-    ŷ = ŷ .+ ϵ
-    p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ)
-    ce = -log.(p_t)
+    ŷϵ = ŷ .+ ϵ
+    p_t = y .* ŷϵ + (1 .- y) .* (1 .- ŷϵ)
+    ce = .-log.(p_t)
     weight = (1 .- p_t) .^ γ
     loss = weight .* ce
     agg(loss)
 end
 
 """
-    focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=eps(ŷ))
+    focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps=eps(eltype(ŷ)))
 
 Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
 which can be used in classification tasks with highly imbalanced classes.
 It down-weights well-classified examples and focuses on hard examples.
 The input, 'ŷ', is expected to be normalized (i.e. [softmax](@ref Softmax) output).
 
-The modulating factor, `γ`, controls the down-weighting strength.
+The modulating factor, `γ == gamma`, controls the down-weighting strength.
 For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref).
 
 # Example
@@ -597,10 +611,13 @@ true
 
 See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
 """
-function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
-    _check_sizes(ŷ, y)
-    ŷ = ŷ .+ ϵ
-    agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
+function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ=nothing, γ=nothing)
+    ϵ = _greek_ascii_depwarn(ϵ => eps, :focal_loss, "ϵ" => "eps")
+    gamma_temp = _greek_ascii_depwarn(γ => gamma, :focal_loss, "γ" => "gamma")
+    γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp)
+    _check_sizes(ŷ, y)
+    ŷϵ = ŷ .+ ϵ
+    agg(sum(@. -y * (1 - ŷϵ)^γ * log(ŷϵ); dims))
 end
 
 """
diff --git a/test/losses.jl b/test/losses.jl
index 2ca697a657..a7f23a06c4 100644
--- a/test/losses.jl
+++ b/test/losses.jl
@@ -1,5 +1,6 @@
 using Test
 using Flux: onehotbatch, σ
+using Statistics: mean
 using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy
 using Flux.Losses: xlogx, xlogy
 
@@ -88,8 +89,8 @@ y_dis[1,:], y_dis[2,:] = y_dis[2,:], y_dis[1,:]
   @test crossentropy(ŷ, y_smoothed) ≈ lossvalue_smoothed
   @test crossentropy(ylp, label_smoothing(yl, 2sf)) ≈ -sum(yls.*log.(ylp))
   @test crossentropy(ylp, yl) ≈ -sum(yl.*log.(ylp))
-  @test iszero(crossentropy(y_same, ya, ϵ=0))
-  @test iszero(crossentropy(ya, ya, ϵ=0))
+  @test iszero(crossentropy(y_same, ya, ϵ=0))  # ϵ is deprecated
+  @test iszero(crossentropy(ya, ya, eps=0))
   @test crossentropy(y_sim, ya) < crossentropy(y_sim, ya_smoothed)
   @test crossentropy(y_dis, ya) > crossentropy(y_dis, ya_smoothed)
 end
@@ -105,7 +106,7 @@ yls = y.*(1-2sf).+sf
 
 @testset "binarycrossentropy" begin
   @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); ϵ=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ))
-  @test binarycrossentropy(σ.(logŷ), y; ϵ=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ)))
+  @test binarycrossentropy(σ.(logŷ), y; eps=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ)))
   @test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ))))
   @test binarycrossentropy([0.1,0.2,0.9], 1) ≈ -mean(log, [0.1,0.2,0.9])  # constant label
 end
@@ -208,7 +209,7 @@ end
          0.1 0.3]
   @test Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628
   @test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157
-  @test Flux.focal_loss(ŷ, y; γ=0.0) ≈ Flux.crossentropy(ŷ, y)
+  @test Flux.focal_loss(ŷ, y; gamma=0) ≈ Flux.crossentropy(ŷ, y)
 end
 
 @testset "siamese_contrastive_loss" begin
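
The dispatch trick behind `_greek_ascii_depwarn` is easy to miss in the hunks above: the Greek keyword defaults to `nothing`, so when the caller does not pass it the pair has type `Pair{Nothing}` and a dedicated method returns the ASCII value silently; any other pair means the Greek keyword was supplied, so that value wins and a deprecation warning fires. The snippet below is a minimal standalone sketch of the same pattern, not part of the patch: `resolve` and `my_loss` are made-up names, and `@warn` stands in for `Base.depwarn` so it runs without Flux.

```julia
# Sketch of the Greek-to-ASCII keyword resolution used by _greek_ascii_depwarn.
function resolve(old_new::Pair, func = :loss, names = "" => "")
    # The deprecated Greek keyword was passed: honour it, but warn.
    @warn "`$func` no longer accepts the keyword `$(names.first)`; use `$(names.second)` instead"
    return old_new.first
end
# Greek keyword left at its `nothing` default: just take the ASCII keyword's value.
resolve(old_new::Pair{Nothing}, _...) = old_new.second

function my_loss(x; eps::Real = 1e-5, ϵ = nothing)
    ε = resolve(ϵ => eps, :my_loss, "ϵ" => "eps")
    return x + ε
end

my_loss(1.0)              # uses the `eps` default, no warning
my_loss(1.0; eps = 1e-3)  # new ASCII keyword
my_loss(1.0; ϵ = 1e-3)    # old Greek keyword still works, but warns
```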
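For downstream users, the practical effect of the patch is a keyword rename with a deprecation path. The calls below are illustrative only and assume a Flux build that includes this patch; the input values are arbitrary.

```julia
using Flux

ŷ = Float32[0.1, 0.8, 0.1]   # a probability vector (sums to 1)
y = Float32[0.0, 1.0, 0.0]   # one-hot target

Flux.crossentropy(ŷ, y; eps = 1f-7)   # new ASCII keyword
Flux.crossentropy(ŷ, y; ϵ = 1f-7)     # old Greek keyword still accepted, but deprecated

Flux.huber_loss(ŷ, y; delta = 0.5)    # was δ
Flux.tversky_loss(ŷ, y; beta = 0.7)   # was β
Flux.focal_loss(reshape(ŷ, 3, 1), reshape(y, 3, 1); gamma = 2)  # was γ

BatchNorm(8; eps = 1f-5)              # layer constructors follow the same rename
LayerNorm(8; eps = 1f-5)
```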