Remove greek-letter keyword arguments #2139

Merged 5 commits on Apr 29, 2023
Changes from 3 commits
41 changes: 24 additions & 17 deletions src/layers/normalise.jl
@@ -152,11 +152,12 @@ testmode!(m::AlphaDropout, mode=true) =
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

"""
LayerNorm(size..., λ=identity; affine=true, ϵ=1fe-5)
LayerNorm(size..., λ=identity; affine=true, eps=1f-5)

A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
used with recurrent hidden states.
The argument `size` should be an integer or a tuple of integers.

In the forward pass, the layer normalises the mean and standard
deviation of the input, then applies the elementwise activation `λ`.
The input is normalised along the first `length(size)` dimensions
@@ -190,9 +191,10 @@ struct LayerNorm{F,D,T,N}
affine::Bool
end

function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5, ϵ=nothing)
ε = Losses._greek_ascii_depwarn(ϵ => eps, :LayerNorm, "ϵ" => "eps")
diag = affine ? Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity
return LayerNorm(λ, diag, ϵ, size, affine)
return LayerNorm(λ, diag, ε, size, affine)
end
LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)
@@ -269,7 +271,7 @@ ChainRulesCore.@non_differentiable _track_stats!(::Any...)
BatchNorm(channels::Integer, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=true, active=nothing,
ϵ=1f-5, momentum= 0.1f0)
eps=1f-5, momentum= 0.1f0)

[Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
`channels` should be the size of the channel dimension in your data (see below).
@@ -321,16 +323,18 @@ end

function BatchNorm(chs::Int, λ=identity;
initβ=zeros32, initγ=ones32,
Member Author:

This PR does not remove initβ, initγ. IMO they should be replaced with a method where you pass in a Scale layer. But perhaps better part of an overhaul of norm layers.

Getting rid of greek-letter field names may also be a good idea. The norm layers are the worst offenders.

Member:

We can tackle this as part of a bigger rework of norm layer internals to save on some churn.

affine=true, track_stats=true, active::Union{Bool,Nothing}=nothing,
ϵ=1f-5, momentum=0.1f0)
affine::Bool=true, track_stats::Bool=true, active::Union{Bool,Nothing}=nothing,
eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing)

ε = Losses._greek_ascii_depwarn(ϵ => eps, :BatchNorm, "ϵ" => "eps")

β = affine ? initβ(chs) : nothing
γ = affine ? initγ(chs) : nothing
μ = track_stats ? zeros32(chs) : nothing
σ² = track_stats ? ones32(chs) : nothing

return BatchNorm(λ, β, γ,
μ, σ², ϵ, momentum,
μ, σ², ε, momentum,
affine, track_stats,
active, chs)
end
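
For orientation, a minimal sketch (not part of the diff) of what this constructor change means at the call site; the layer size `3` and the epsilon value below are arbitrary:

```julia
using Flux

# New ASCII keyword introduced by this PR:
bn = BatchNorm(3; eps = 1f-5)

# Old greek keyword: still accepted, but routed through
# Losses._greek_ascii_depwarn, which emits a deprecation warning
# (visible with --depwarn=yes) and keeps the supplied value.
bn_old = BatchNorm(3; ϵ = 1f-5)
```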
@@ -361,7 +365,7 @@ end
InstanceNorm(channels::Integer, λ=identity;
initβ=zeros32, initγ=ones32,
affine=false, track_stats=false,
ϵ=1f-5, momentum=0.1f0)
eps=1f-5, momentum=0.1f0)

[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
`channels` should be the size of the channel dimension in your data (see below).
@@ -411,16 +415,18 @@ end

function InstanceNorm(chs::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=false, track_stats=false, active::Union{Bool,Nothing}=nothing,
ϵ=1f-5, momentum=0.1f0)
affine::Bool=false, track_stats::Bool=false, active::Union{Bool,Nothing}=nothing,
eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing)

ε = Losses._greek_ascii_depwarn(ϵ => eps, :InstanceNorm, "ϵ" => "eps")

β = affine ? initβ(chs) : nothing
γ = affine ? initγ(chs) : nothing
μ = track_stats ? zeros32(chs) : nothing
σ² = track_stats ? ones32(chs) : nothing

return InstanceNorm(λ, β, γ,
μ, σ², ϵ, momentum,
μ, σ², ε, momentum,
affine, track_stats,
active, chs)
end
@@ -450,7 +456,7 @@ end
GroupNorm(channels::Integer, G::Integer, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=false,
ϵ=1f-5, momentum=0.1f0)
eps=1f-5, momentum=0.1f0)

[Group Normalization](https://arxiv.org/abs/1803.08494) layer.

@@ -508,12 +514,13 @@ trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=false, active::Union{Bool,Nothing}=nothing,
ϵ=1f-5, momentum=0.1f0)
affine::Bool=true, track_stats::Bool=false, active::Union{Bool,Nothing}=nothing,
eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing)

if track_stats
if track_stats
Base.depwarn("`track_stats=true` will be removed from GroupNorm in Flux 0.14. The default value is `track_stats=false`, which will work as before.", :GroupNorm)
end
end
ε = Losses._greek_ascii_depwarn(ϵ => eps, :GroupNorm, "ϵ" => "eps")

chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)")

@@ -525,7 +532,7 @@ end
return GroupNorm(G, λ,
β, γ,
μ, σ²,
ϵ, momentum,
ε, momentum,
affine, track_stats,
active, chs)
end
85 changes: 51 additions & 34 deletions src/losses/functions.jl
@@ -48,13 +48,13 @@ function mse(ŷ, y; agg = mean)
end

"""
msle(ŷ, y; agg = mean, ϵ = eps())
msle(ŷ, y; agg = mean, eps = eps(eltype(ŷ)))

The loss corresponding to mean squared logarithmic errors, calculated as

agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)) .^ 2)

The `ϵ` term provides numerical stability.
The `ϵ == eps` term provides numerical stability.
Penalizes an under-estimation more than an over-estimation.

# Example
@@ -66,13 +66,14 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3)
0.011100831f0
```
"""
function msle(ŷ, y; agg = mean, ϵ = epseltype(ŷ))
function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :msle, "ϵ" => "eps")
_check_sizes(ŷ, y)
agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 )
end

"""
huber_loss(ŷ, y; δ = 1, agg = mean)
huber_loss(ŷ, y; delta = 1, agg = mean)

Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
given the prediction `ŷ` and true values `y`.
@@ -82,17 +83,20 @@ given the prediction `ŷ` and true values `y`.
| δ * (|ŷ - y| - 0.5 * δ), otherwise

# Example

```jldoctest
julia> ŷ = [1.1, 2.1, 3.1];

julia> Flux.huber_loss(ŷ, 1:3) # default δ = 1 > |ŷ - y|
0.005000000000000009

julia> Flux.huber_loss(ŷ, 1:3, δ=0.05) # changes behaviour as |ŷ - y| > δ
julia> Flux.huber_loss(ŷ, 1:3, delta=0.05) # changes behaviour as |ŷ - y| > δ
0.003750000000000005
```
"""
function huber_loss(ŷ, y; agg = mean, δ = ofeltype(ŷ, 1))
function huber_loss(ŷ, y; agg = mean, delta::Real = 1, δ = nothing)
delta_tmp = _greek_ascii_depwarn(δ => delta, :huber_loss, "δ" => "delta")
δ = ofeltype(ŷ, delta_tmp)
_check_sizes(ŷ, y)
abs_error = abs.(ŷ .- y)
#TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
@@ -167,7 +171,7 @@ function label_smoothing(y::Union{AbstractArray,Number}, α::Number; dims::Int =
end

"""
crossentropy(ŷ, y; dims = 1, ϵ = eps(), agg = mean)
crossentropy(ŷ, y; dims = 1, eps = eps(eltype(ŷ)), agg = mean)

Return the cross entropy between the given probability distributions;
calculated as
@@ -222,7 +226,8 @@ julia> Flux.crossentropy(y_model, y_smooth)
1.5776052f0
```
"""
function crossentropy(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :crossentropy, "ϵ" => "eps")
_check_sizes(ŷ, y)
agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims))
end
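
As a usage sketch (not part of the diff; the probabilities and labels below are made up), the loss-function keywords behave analogously to the layer constructors:

```julia
using Flux

ŷ = softmax(Float32[1 2 0; 0 1 3])     # predicted probabilities, 2×3
y = Flux.onehotbatch([1, 2, 2], 1:2)   # one-hot targets, 2×3

Flux.crossentropy(ŷ, y)                # eps defaults to eps(Float32) via epseltype(ŷ)
Flux.crossentropy(ŷ, y; eps = 1f-7)    # new ASCII spelling
Flux.crossentropy(ŷ, y; ϵ = 1f-7)      # deprecated greek spelling: same value, plus a depwarn
```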
@@ -267,14 +272,14 @@ function logitcrossentropy(ŷ, y; dims = 1, agg = mean)
end

"""
binarycrossentropy(ŷ, y; agg = mean, ϵ = eps())
binarycrossentropy(ŷ, y; agg = mean, eps = eps(eltype(ŷ)))

Return the binary cross-entropy loss, computed as

agg(@.(-y * log(ŷ + ϵ) - (1 - y) * log(1 - ŷ + ϵ)))

Where typically, the prediction `ŷ` is given by the output of a [sigmoid](@ref man-activations) activation.
The `ϵ` term is included to avoid infinity. Using [`logitbinarycrossentropy`](@ref) is recomended
The `ϵ == eps` term is included to avoid infinity. Using [`logitbinarycrossentropy`](@ref) is recommended
over `binarycrossentropy` for numerical stability.

Use [`label_smoothing`](@ref) to smooth the `y` value as preprocessing before
@@ -310,7 +315,8 @@ julia> Flux.crossentropy(y_prob, y_hot)
0.43989f0
```
"""
function binarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ))
function binarycrossentropy(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :binarycrossentropy, "ϵ" => "eps")
_check_sizes(ŷ, y)
agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)))
end
@@ -346,7 +352,7 @@ function logitbinarycrossentropy(ŷ, y; agg = mean)
end

"""
kldivergence(ŷ, y; agg = mean, ϵ = eps())
kldivergence(ŷ, y; agg = mean, eps = eps(eltype(ŷ)))

Return the
[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
@@ -373,17 +379,18 @@ true
julia> Flux.kldivergence(p2, p1; agg = sum) ≈ 2log(2)
true

julia> Flux.kldivergence(p2, p2; ϵ = 0) # about -2e-16 with the regulator
julia> Flux.kldivergence(p2, p2; eps = 0) # about -2e-16 with the regulator
0.0

julia> Flux.kldivergence(p1, p2; ϵ = 0) # about 17.3 with the regulator
julia> Flux.kldivergence(p1, p2; eps = 0) # about 17.3 with the regulator
Inf
```
"""
function kldivergence(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
function kldivergence(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :kldivergence, "ϵ" => "eps")
_check_sizes(ŷ, y)
entropy = agg(sum(xlogx.(y), dims = dims))
cross_entropy = crossentropy(ŷ, y; dims = dims, agg = agg, ϵ = ϵ)
entropy = agg(sum(xlogx.(y); dims = dims))
cross_entropy = crossentropy(ŷ, y; dims, agg, eps=ϵ)
return entropy + cross_entropy
end

@@ -501,23 +508,27 @@ julia> 1 - Flux.dice_coeff_loss(y_pred, 1:3) # ~ F1 score for image segmentatio
0.99900760833609
```
"""
function dice_coeff_loss(ŷ, y; smooth = ofeltype(ŷ, 1.0))
function dice_coeff_loss(ŷ, y; smooth = 1)
s = ofeltype(ŷ, smooth)
_check_sizes(ŷ, y)
1 - (2 * sum(y .* ŷ) + smooth) / (sum(y .^ 2) + sum(ŷ .^ 2) + smooth) #TODO agg
# TODO add agg
1 - (2 * sum(y .* ŷ) + s) / (sum(y .^ 2) + sum(ŷ .^ 2) + s)
end

"""
tversky_loss(ŷ, y; β = 0.7)
tversky_loss(ŷ, y; beta = 0.7)

Return the [Tversky loss](https://arxiv.org/abs/1706.05721).
Used with imbalanced data to give more weight to false negatives.
Larger β weigh recall more than precision (by placing more emphasis on false negatives).
A larger `β == beta` weighs recall more than precision (by placing more emphasis on false negatives).
Calculated as:

1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1)

"""
function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
function tversky_loss(ŷ, y; beta::Real = 0.7, β = nothing)
beta_temp = _greek_ascii_depwarn(β => beta, :tversky_loss, "β" => "beta")
β = ofeltype(ŷ, beta_temp)
_check_sizes(ŷ, y)
#TODO add agg
num = sum(y .* ŷ) + 1
@@ -526,12 +537,12 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
end

"""
binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=eps())
binary_focal_loss(ŷ, y; agg=mean, gamma=2, eps=eps(eltype(ŷ)))

Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
The input, 'ŷ', is expected to be normalized (i.e. [softmax](@ref Softmax) output).

For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).
For `gamma = 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).

See also: [`Losses.focal_loss`](@ref) for multi-class setting

@@ -553,25 +564,28 @@ julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
true
```
"""
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
function binary_focal_loss(ŷ, y; agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ = nothing, γ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :binary_focal_loss, "ϵ" => "eps")
gamma_temp = _greek_ascii_depwarn(γ => gamma, :binary_focal_loss, "γ" => "gamma")
γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp)
_check_sizes(ŷ, y)
ŷ = ŷ .+ ϵ
p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ)
ce = -log.(p_t)
ŷϵ = ŷ .+ ϵ
p_t = y .* ŷϵ + (1 .- y) .* (1 .- ŷϵ)
ce = .-log.(p_t)
weight = (1 .- p_t) .^ γ
loss = weight .* ce
agg(loss)
end

"""
focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=eps())
focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps=eps(eltype(ŷ)))

Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
which can be used in classification tasks with highly imbalanced classes.
It down-weights well-classified examples and focuses on hard examples.
The input, 'ŷ', is expected to be normalized (i.e. [softmax](@ref Softmax) output).

The modulating factor, `γ`, controls the down-weighting strength.
The modulating factor, `γ == gamma`, controls the down-weighting strength.
For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref).

# Example
@@ -597,10 +611,13 @@ true
See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels

"""
function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
_check_sizes(ŷ, y)
ŷ = ŷ .+ ϵ
agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ=nothing, γ=nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :focal_loss, "ϵ" => "eps")
gamma_temp = _greek_ascii_depwarn(γ => gamma, :focal_loss, "γ" => "gamma")
γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp)
_check_sizes(ŷ, y)
ŷϵ = ŷ .+ ϵ
agg(sum(@. -y * (1 - ŷϵ)^γ * log(ŷϵ); dims))
end
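
The focal losses deprecate two greek keywords at once (`ϵ` and `γ`). A small sketch of the intended call-site behaviour (not part of the diff; the data below is arbitrary), including the branch that keeps an integer `gamma` as an `Int` rather than converting it to `eltype(ŷ)`:

```julia
using Flux

ŷ = softmax(Float32.(reshape(1:15, 3, 5)))    # predicted probabilities, 3×5
y = Flux.onehotbatch([1, 2, 3, 1, 2], 1:3)    # one-hot targets, 3×5

Flux.focal_loss(ŷ, y)               # default gamma = 2, kept as an Int exponent
Flux.focal_loss(ŷ, y; gamma = 2.5)  # non-integer gamma is converted via ofeltype(ŷ, ...)
Flux.focal_loss(ŷ, y; γ = 2)        # deprecated spelling, still honoured with a depwarn
```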

"""
9 changes: 9 additions & 0 deletions src/losses/utils.jl
@@ -36,3 +36,12 @@ end
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1

ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)

# Greek-letter keywords deprecated in Flux 0.13
# Arguments (old => new, :function, "β" => "beta")
function _greek_ascii_depwarn(βbeta::Pair, func = :loss, names = "" => "")
Base.depwarn("""loss function $func no longer accepts greek-letter keyword $(names.first)
please use ascii $(names.second) instead""", func)
βbeta.first
end
_greek_ascii_depwarn(βbeta::Pair{Nothing}, _...) = βbeta.second
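
A short sketch (not part of the diff) of how the two methods dispatch; the values are arbitrary:

```julia
# Old keyword not supplied: the pair is `nothing => new`, which hits the
# Pair{Nothing} method and silently returns the new ASCII keyword's value.
Flux.Losses._greek_ascii_depwarn(nothing => 0.7, :tversky_loss, "β" => "beta")  # 0.7

# Old keyword supplied: emits a Base.depwarn naming the function and both
# spellings, then returns the old value so existing call sites keep working.
Flux.Losses._greek_ascii_depwarn(0.5 => 0.7, :tversky_loss, "β" => "beta")      # 0.5
```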
Member:

Could this live outside of the Losses module?

Member:

We may want to make it non-differentiable as well.

Member Author:

It could live somewhere else, but where? And is it worth it?

(Most uses are in Losses, I spotted the normalisation ones afterwards, just 3.)

Member:

utils.jl? That's included before any layers or loss functions are defined.
