Skip to content

Commit

Permalink
Merge #1489
Browse files Browse the repository at this point in the history
1489: Implementation of Focal loss r=darsnack a=shikhargoswami

Focal loss was introduced in the RetinaNet paper (https://arxiv.org/pdf/1708.02002.pdf). 

Focal loss is useful for classification when you we highly imbalanced classes. It down-weights well-classified examples and focuses on hard examples. The loss value is much high for a sample which is misclassified by the classifier as compared to the loss value corresponding to a well-classified example. 

Used in single-shot object detection where the imbalance between the background class and other classes is extremely high.

Here's it's tensorflow implementation (https://github.com/tensorflow/addons/blob/v0.12.0/tensorflow_addons/losses/focal_loss.py#L26-L81)
### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: Shikhar Goswami <shikhargoswami2308@gmail.com>
Co-authored-by: Shikhar Goswami <44720861+shikhargoswami@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 5, 2021
2 parents 7e9a180 + 63e4d98 commit 9bee3f3
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 3 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## v0.12.0

* Added [Focal Loss function](https://github.com/FluxML/Flux.jl/pull/1489) to Losses module
* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
* Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).
* Excise datasets in favour of other providers in the julia ecosystem.
Expand Down
3 changes: 2 additions & 1 deletion src/losses/Losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ export mse, mae, msle,
dice_coeff_loss,
poisson_loss,
hinge_loss, squared_hinge_loss,
ctc_loss
ctc_loss,
binary_focal_loss, focal_loss

include("utils.jl")
include("functions.jl")
Expand Down
75 changes: 75 additions & 0 deletions src/losses/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,82 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
1 - num / den
end

"""
binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=eps(ŷ))
Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output).
For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).
# Example
```jldoctest
julia> y = [0 1 0
1 0 1]
2×3 Array{Int64,2}:
0 1 0
1 0 1
julia> ŷ = [0.268941 0.5 0.268941
0.731059 0.5 0.731059]
2×3 Array{Float64,2}:
0.268941 0.5 0.268941
0.731059 0.5 0.731059
julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
true
```
See also: [`Losses.focal_loss`](@ref) for multi-class setting
"""
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
=.+ ϵ
p_t = y .*+ (1 .- y) .* (1 .- ŷ)
ce = -log.(p_t)
weight = (1 .- p_t) .^ γ
loss = weight .* ce
agg(loss)
end

"""
focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=eps(ŷ))
Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
which can be used in classification tasks with highly imbalanced classes.
It down-weights well-classified examples and focuses on hard examples.
The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output).
The modulating factor, `γ`, controls the down-weighting strength.
For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref).
# Example
```jldoctest
julia> y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
3×5 Array{Int64,2}:
1 0 0 0 1
0 1 0 1 0
0 0 1 0 0
julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
3×5 Array{Float32,2}:
0.0900306 0.0900306 0.0900306 0.0900306 0.0900306
0.244728 0.244728 0.244728 0.244728 0.244728
0.665241 0.665241 0.665241 0.665241 0.665241
julia> Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628
true
```
See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
"""
function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
=.+ ϵ
agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
end
```@meta
DocTestFilters = nothing
```
13 changes: 12 additions & 1 deletion test/cuda/losses.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy
using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy, binary_focal_loss, focal_loss


@testset "Losses" begin
Expand All @@ -14,6 +14,17 @@ y = [1, 1, 0.]
@test binarycrossentropy(σ.(x), y) binarycrossentropy(gpu(σ.(x)), gpu(y))
@test logitbinarycrossentropy(x, y) logitbinarycrossentropy(gpu(x), gpu(y))

x = [0.268941 0.5 0.268941
0.731059 0.5 0.731059]
y = [0 1 0
1 0 1]
@test binary_focal_loss(x, y) binary_focal_loss(gpu(x), gpu(y))

x = softmax(reshape(-7:7, 3, 5) .* 1f0)
y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
@test focal_loss(x, y) focal_loss(gpu(x), gpu(y))

@testset "GPU grad tests" begin
x = rand(Float32, 3,3)
Expand Down
34 changes: 33 additions & 1 deletion test/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
Flux.Losses.tversky_loss,
Flux.Losses.dice_coeff_loss,
Flux.Losses.poisson_loss,
Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss]
Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss,
Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss]


@testset "xlogx & xlogy" begin
Expand Down Expand Up @@ -174,3 +175,34 @@ end
end
end
end

@testset "binary_focal_loss" begin
y = [0 1 0
1 0 1]
ŷ = [0.268941 0.5 0.268941
0.731059 0.5 0.731059]

y1 = [1 0
0 1]
ŷ1 = [0.6 0.3
0.4 0.7]
@test Flux.binary_focal_loss(ŷ, y) 0.0728675615927385
@test Flux.binary_focal_loss(ŷ1, y1) 0.05691642237852222
@test Flux.binary_focal_loss(ŷ, y; γ=0.0) Flux.binarycrossentropy(ŷ, y)
end

@testset "focal_loss" begin
y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
y1 = [1 0
0 0
0 1]
ŷ1 = [0.4 0.2
0.5 0.5
0.1 0.3]
@test Flux.focal_loss(ŷ, y) 1.1277571935622628
@test Flux.focal_loss(ŷ1, y1) 0.45990566879720157
@test Flux.focal_loss(ŷ, y; γ=0.0) Flux.crossentropy(ŷ, y)
end

0 comments on commit 9bee3f3

Please sign in to comment.