Implementation of Focal loss #1489

Merged
merged 14 commits on Feb 5, 2021
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

## v0.12.0

* Added a [focal loss function](https://github.com/FluxML/Flux.jl/pull/1489) to the Losses module.
* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
* Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).
* Excise datasets in favour of other providers in the julia ecosystem.
2 changes: 2 additions & 0 deletions docs/src/models/losses.md
@@ -39,4 +39,6 @@ Flux.Losses.hinge_loss
Flux.Losses.squared_hinge_loss
Flux.Losses.dice_coeff_loss
Flux.Losses.tversky_loss
Flux.Losses.binary_focal_loss
Flux.Losses.focal_loss
```
3 changes: 2 additions & 1 deletion src/losses/Losses.jl
@@ -18,7 +18,8 @@ export mse, mae, msle,
dice_coeff_loss,
poisson_loss,
hinge_loss, squared_hinge_loss,
ctc_loss
ctc_loss,
binary_focal_loss, focal_loss

include("utils.jl")
include("functions.jl")
75 changes: 75 additions & 0 deletions src/losses/functions.jl
@@ -429,7 +429,82 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
1 - num / den
end

"""
binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))

Return the [binary focal loss](https://arxiv.org/pdf/1708.02002.pdf), a
cross-entropy variant that down-weights well-classified examples.
The input, `ŷ`, is expected to be normalized (i.e. [`softmax`](@ref) output).

For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).

# Example
```jldoctest
julia> y = [0 1 0
1 0 1]
2×3 Array{Int64,2}:
0 1 0
1 0 1

julia> ŷ = [0.268941 0.5 0.268941
0.731059 0.5 0.731059]
2×3 Array{Float64,2}:
0.268941 0.5 0.268941
0.731059 0.5 0.731059

julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
true
```

See also: [`Losses.focal_loss`](@ref) for the multi-class setting.

"""
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
    ŷ = ŷ .+ ϵ                           # guard against log(0)
    p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ)  # probability assigned to the true class
    ce = -log.(p_t)                      # per-element cross-entropy
    weight = (1 .- p_t) .^ γ             # modulating factor; small for easy examples
    loss = weight .* ce
    agg(loss)
end
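
To make the down-weighting concrete, here is a small sketch (the probability matrix is made up for illustration and is not part of the PR's tests):

```julia
using Flux
using Flux.Losses: binary_focal_loss, binarycrossentropy

y = [0 1 0
     1 0 1]
ŷ = [0.1 0.9 0.1    # every prediction agrees with its label,
     0.9 0.1 0.9]   # so p_t ≈ 0.9 everywhere: an "easy" batch

# γ = 0 turns the modulating factor (1 - p_t)^γ into 1,
# recovering plain binary cross-entropy.
binary_focal_loss(ŷ, y; γ=0) ≈ binarycrossentropy(ŷ, y)  # true

# The default γ = 2 down-weights these easy examples heavily.
binary_focal_loss(ŷ, y) < binarycrossentropy(ŷ, y)       # true
```

With every `p_t ≈ 0.9`, the γ = 2 factor `(1 - p_t)^2 = 0.01` scales each term down a hundredfold, which is exactly the behaviour the loss is designed for.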

"""
focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))

Return the [focal loss](https://arxiv.org/pdf/1708.02002.pdf),
which can be used in classification tasks with highly imbalanced classes.
It down-weights well-classified examples and focuses on hard examples.
The input, `ŷ`, is expected to be normalized (i.e. [`softmax`](@ref) output).

The modulating factor, `γ`, controls the down-weighting strength.
For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref).

# Example
```jldoctest
julia> y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
3×5 Array{Int64,2}:
1 0 0 0 1
0 1 0 1 0
0 0 1 0 0

julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
3×5 Array{Float32,2}:
0.0900306 0.0900306 0.0900306 0.0900306 0.0900306
0.244728 0.244728 0.244728 0.244728 0.244728
0.665241 0.665241 0.665241 0.665241 0.665241

julia> Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628
true
```

See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels.

"""
function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
    ŷ = ŷ .+ ϵ  # guard against log(0)
    # Sum -y * (1 - ŷ)^γ * log(ŷ) over classes, then aggregate over the batch.
    agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
end
```@meta
DocTestFilters = nothing
```
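
The multi-class version admits the same sanity checks. The sketch below reuses the matrices from the doctest above; the `manual` line is an illustrative restatement of the formula the implementation uses, not code from the PR:

```julia
using Flux
using Statistics: mean
using Flux.Losses: focal_loss, crossentropy

y = [1 0 0 0 1
     0 1 0 1 0
     0 0 1 0 0]
ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)

# γ = 0: the (1 - ŷ)^γ factor drops out, leaving ordinary cross-entropy.
focal_loss(ŷ, y; γ=0) ≈ crossentropy(ŷ, y)   # true

# Spelling out the definition mean(Σ_classes -y * (1 - ŷ)^γ * log(ŷ))
# reproduces the implementation's value, up to the ϵ it adds for stability.
manual = mean(sum(@. -y * (1 - ŷ)^2 * log(ŷ); dims=1))
manual ≈ focal_loss(ŷ, y; γ=2)               # true
```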
13 changes: 12 additions & 1 deletion test/cuda/losses.jl
@@ -1,4 +1,4 @@
using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy
using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy, binary_focal_loss, focal_loss


@testset "Losses" begin
@@ -14,6 +14,17 @@ y = [1, 1, 0.]
@test binarycrossentropy(σ.(x), y) ≈ binarycrossentropy(gpu(σ.(x)), gpu(y))
@test logitbinarycrossentropy(x, y) ≈ logitbinarycrossentropy(gpu(x), gpu(y))

x = [0.268941 0.5 0.268941
0.731059 0.5 0.731059]
y = [0 1 0
1 0 1]
@test binary_focal_loss(x, y) ≈ binary_focal_loss(gpu(x), gpu(y))

x = softmax(reshape(-7:7, 3, 5) .* 1f0)
y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
@test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y))

@testset "GPU grad tests" begin
x = rand(Float32, 3,3)
34 changes: 33 additions & 1 deletion test/losses.jl
@@ -13,7 +13,8 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
Flux.Losses.tversky_loss,
Flux.Losses.dice_coeff_loss,
Flux.Losses.poisson_loss,
Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss]
Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss,
Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss]


@testset "xlogx & xlogy" begin
@@ -174,3 +175,34 @@ end
end
end
end

@testset "binary_focal_loss" begin
y = [0 1 0
1 0 1]
ŷ = [0.268941 0.5 0.268941
0.731059 0.5 0.731059]

y1 = [1 0
0 1]
ŷ1 = [0.6 0.3
0.4 0.7]
@test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
@test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222
@test Flux.binary_focal_loss(ŷ, y; γ=0.0) ≈ Flux.binarycrossentropy(ŷ, y)
end

@testset "focal_loss" begin
y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
y1 = [1 0
0 0
0 1]
ŷ1 = [0.4 0.2
0.5 0.5
0.1 0.3]
@test Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628
@test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157
@test Flux.focal_loss(ŷ, y; γ=0.0) ≈ Flux.crossentropy(ŷ, y)
end