From 428a2d78c9205f52a00a0867e8674fe5247b7963 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Sat, 30 Jan 2021 14:30:16 +0530 Subject: [PATCH 01/14] Implementation of Focal Loss --- src/losses/Losses.jl | 3 ++- src/losses/functions.jl | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index e781affaf6..f87f487dfc 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -18,7 +18,8 @@ export mse, mae, msle, dice_coeff_loss, poisson_loss, hinge_loss, squared_hinge_loss, - ctc_loss + ctc_loss, + focal_loss include("utils.jl") include("functions.jl") diff --git a/src/losses/functions.jl b/src/losses/functions.jl index d7fd666b63..2a70fb149c 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -429,6 +429,23 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) 1 - num / den end +""" + focal_loss(yhat, y; dims=1, agg=mean, gamma=2.0, eps = eps(eltype(yhat)) + + 𝛄: modulating factor. + +Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf) +Extremely useful for classification when you have highly imbalanced classes. It down-weights +well-classified examples and focuses on hard examples. Loss is much high for misclassified points as compared to well-classified points. Used in single-shot detectors. +""" +function focal_loss(ŷ, y; dims=1, agg=mean, 𝛄=2.0, eps = eps(eltype(ŷ))) + ŷ = ŷ .+ eps + p_t = [y==1 ? ŷ : 1-ŷ for (ŷ, y) in zip(ŷ, y)] + ce = -log.(p_t) + weight = (1 .- p_t) .^ 𝛄 + loss = weight .* ce + agg(sum(loss, dims=dims)) +end ```@meta DocTestFilters = nothing From 3c05c80ca4d6a4413cf1b5700fb754dac9d4cb35 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Sat, 30 Jan 2021 15:11:17 +0530 Subject: [PATCH 02/14] Added tests for focal_loss --- test/losses.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/losses.jl b/test/losses.jl index 6f7a5c8407..8a3157d7a3 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -174,3 +174,16 @@ end end end end + +y = [0 1 1] +ŷ = [0.1 0.7 0.9] +y1 = [1 0 + 0 0 + 0 1] +ŷ1 = [0.6 0.3 + 0.3 0.1 + 0.1 0.6] +@testset "focal_loss" begin + @test Flux.focal_loss(ŷ, y) ≈ 0.011402651755880793 + @test Flux.focal_loss(ŷ1, y1) ≈ 0.11488644991362265 +end From 00a785d3ab163fb5922ea4f4ae92b8cf4cb4d171 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Sun, 31 Jan 2021 00:18:37 +0530 Subject: [PATCH 03/14] Changes done --- src/losses/functions.jl | 27 ++++++++++++++++++-------- test/losses.jl | 42 ++++++++++++++++++++++++++++++----------- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 2a70fb149c..2d0a33d54b 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -430,23 +430,34 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) end """ - focal_loss(yhat, y; dims=1, agg=mean, gamma=2.0, eps = eps(eltype(yhat)) + binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, eps = epseltype(ŷ)) - 𝛄: modulating factor. + γ: modulating factor. -Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf) -Extremely useful for classification when you have highly imbalanced classes. It down-weights -well-classified examples and focuses on hard examples. Loss is much high for misclassified points as compared to well-classified points. Used in single-shot detectors. """ -function focal_loss(ŷ, y; dims=1, agg=mean, 𝛄=2.0, eps = eps(eltype(ŷ))) +function binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, eps = epseltype(ŷ)) ŷ = ŷ .+ eps - p_t = [y==1 ? ŷ : 1-ŷ for (ŷ, y) in zip(ŷ, y)] + p_t = y .*ŷ + (1 .- y) .* (1 .- ŷ) ce = -log.(p_t) - weight = (1 .- p_t) .^ 𝛄 + weight = (1 .- p_t) .^ γ loss = weight .* ce agg(sum(loss, dims=dims)) end +""" + categorical_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, eps = epseltype(ŷ)) + Softmax version of Focal Loss + γ: modulating factor. + +""" +function categorical_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, eps = epseltype(ŷ)) + ŷ = softmax(ŷ; dims=dims) + ŷ = ŷ .+ eps + ce = -y .* log.(ŷ) + weight = (1 .- ŷ) .^ γ + loss = weight .* ce + agg(sum(loss, dims=dims)) +end ```@meta DocTestFilters = nothing ``` diff --git a/test/losses.jl b/test/losses.jl index 8a3157d7a3..2723c073b0 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -175,15 +175,35 @@ end end end -y = [0 1 1] -ŷ = [0.1 0.7 0.9] -y1 = [1 0 - 0 0 - 0 1] -ŷ1 = [0.6 0.3 - 0.3 0.1 - 0.1 0.6] -@testset "focal_loss" begin - @test Flux.focal_loss(ŷ, y) ≈ 0.011402651755880793 - @test Flux.focal_loss(ŷ1, y1) ≈ 0.11488644991362265 +@testset "binary_focal_loss" begin + y = [0 1 1] + ŷ = [0.1 0.7 0.9] + y1 = [1 0 + 0 0 + 0 1] + ŷ1 = [0.6 0.3 + 0.3 0.1 + 0.1 0.6] + @test Flux.binary_focal_loss(ŷ, y) ≈ 0.011402651755880793 + @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.11488644991362265 +end + +@testset "categorical_focal_loss" begin + y = [1 0 0 1 0 1 1 + 0 1 0 0 1 0 0 + 0 0 1 0 0 0 0] + + ŷ = [-9.0 -6.0 -3.0 0.0 2.0 5.0 8.0 + -8.0 -5.0 -2.0 0.0 3.0 6.0 9.0 + -7.0 -4.0 -1.0 1.0 4.0 7.0 7.5] + y1 = [1 0 + 0 0 + 0 1] + ŷ1 = [5.0 1.0 + 2.3 2.0 + 3.8 9.0] + @test Flux.categorical_focal_loss(ŷ, y) ≈ 1.0668209889165343 + @test Flux.categorical_focal_loss(ŷ1, y1) ≈ 0.011366240888043638 end + + From b87a6d282848a3a67662c2e828c89b3a9f9cbc29 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Sun, 31 Jan 2021 00:48:23 +0530 Subject: [PATCH 04/14] oops! forgot toadd these --- src/losses/Losses.jl | 3 ++- test/losses.jl | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index f87f487dfc..39a09e84c9 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -19,7 +19,8 @@ export mse, mae, msle, poisson_loss, hinge_loss, squared_hinge_loss, ctc_loss, - focal_loss + binary_focal_loss, + categorical_focal_loss include("utils.jl") include("functions.jl") diff --git a/test/losses.jl b/test/losses.jl index 2723c073b0..5466a669bf 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -205,5 +205,3 @@ end @test Flux.categorical_focal_loss(ŷ, y) ≈ 1.0668209889165343 @test Flux.categorical_focal_loss(ŷ1, y1) ≈ 0.011366240888043638 end - - From 36c2f2c91798a212aae505f51dc25fc8244bc6c2 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Sun, 31 Jan 2021 23:21:30 +0530 Subject: [PATCH 05/14] Revised focal_loss --- src/losses/Losses.jl | 3 +-- src/losses/functions.jl | 38 ++++++++++++++++++++++---------------- test/losses.jl | 35 +++++++++++++++++++---------------- 3 files changed, 42 insertions(+), 34 deletions(-) diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 39a09e84c9..bf944f9231 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -19,8 +19,7 @@ export mse, mae, msle, poisson_loss, hinge_loss, squared_hinge_loss, ctc_loss, - binary_focal_loss, - categorical_focal_loss + binary_focal_loss, focal_loss include("utils.jl") include("functions.jl") diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 2d0a33d54b..3ff2701470 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -430,33 +430,39 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) end """ - binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, eps = epseltype(ŷ)) + binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, ϵ=epseltype(ŷ)) - γ: modulating factor. +Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf) +Can be used in classification tasks in the presence of highly imbalanced classes. +It down-weights well-classified examples and focuses on hard examples. + +γ(default=2.0) is a number called modulating factor. +For γ=0, the loss is mathematically equivalent to binarycrossentropy(ŷ, y) + +See also: ['focal_loss'](@ref) """ -function binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, eps = epseltype(ŷ)) - ŷ = ŷ .+ eps - p_t = y .*ŷ + (1 .- y) .* (1 .- ŷ) +function binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) + ŷ = ŷ .+ ϵ + p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ) ce = -log.(p_t) weight = (1 .- p_t) .^ γ loss = weight .* ce - agg(sum(loss, dims=dims)) + agg(loss) end """ - categorical_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, eps = epseltype(ŷ)) - Softmax version of Focal Loss - γ: modulating factor. + focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, ϵ=epseltype(ŷ)) + +Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf) +For γ=0, the loss is mathematically equivalent to crossentropy(ŷ, y) + +See also: [`binary_focal_loss`](@ref) """ -function categorical_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, eps = epseltype(ŷ)) - ŷ = softmax(ŷ; dims=dims) - ŷ = ŷ .+ eps - ce = -y .* log.(ŷ) - weight = (1 .- ŷ) .^ γ - loss = weight .* ce - agg(sum(loss, dims=dims)) +function focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) + ŷ = ŷ .+ ϵ + agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=1)) end ```@meta DocTestFilters = nothing diff --git a/test/losses.jl b/test/losses.jl index 5466a669bf..8ae4510c02 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -176,32 +176,35 @@ end end @testset "binary_focal_loss" begin - y = [0 1 1] - ŷ = [0.1 0.7 0.9] + y = [0 1 0 + 1 0 1] + yhat = [0.268941 0.5 0.268941 + 0.731059 0.5 0.731059] + y1 = [1 0 - 0 0 0 1] ŷ1 = [0.6 0.3 - 0.3 0.1 - 0.1 0.6] - @test Flux.binary_focal_loss(ŷ, y) ≈ 0.011402651755880793 - @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.11488644991362265 + 0.4 0.7] + @test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 + @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222 + @test Flux.binary_focal_loss(ŷ, y; γ=0.0) == Flux.binarycrossentropy(ŷ, y) end -@testset "categorical_focal_loss" begin - y = [1 0 0 1 0 1 1 - 0 1 0 0 1 0 0 - 0 0 1 0 0 0 0] +@testset "focal_loss" begin + y = [1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0] - ŷ = [-9.0 -6.0 -3.0 0.0 2.0 5.0 8.0 - -8.0 -5.0 -2.0 0.0 3.0 6.0 9.0 - -7.0 -4.0 -1.0 1.0 4.0 7.0 7.5] + ŷ = [0.0900306 0.0900306 0.0900306 0.0900306 0.0900306 + 0.244728 0.244728 0.244728 0.244728 0.244728 + 0.665241 0.665241 0.665241 0.665241 0.665241] y1 = [1 0 0 0 0 1] ŷ1 = [5.0 1.0 2.3 2.0 3.8 9.0] - @test Flux.categorical_focal_loss(ŷ, y) ≈ 1.0668209889165343 - @test Flux.categorical_focal_loss(ŷ1, y1) ≈ 0.011366240888043638 + @test Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628 + @test Flux.focal_loss(ŷ1, y1) ≈ 0.011366240888043638 + @test Flux.focal_loss(ŷ, y; γ=0.0) == Flux.crossentropy(ŷ, y) end From feeb6d7c725eb0f120a43630430faab19efd6f5a Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Mon, 1 Feb 2021 13:56:10 +0530 Subject: [PATCH 06/14] Refactored code and docstring --- src/losses/functions.jl | 22 ++++++++++++++-------- test/losses.jl | 10 +++++----- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 3ff2701470..f9b6f012c9 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -430,16 +430,17 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) end """ - binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, ϵ=epseltype(ŷ)) + binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf) -Can be used in classification tasks in the presence of highly imbalanced classes. +which can be used in classification tasks with highly imbalanced classes. It down-weights well-classified examples and focuses on hard examples. +The input, 'ŷ', is expected to be unnormalized. -γ(default=2.0) is a number called modulating factor. -For γ=0, the loss is mathematically equivalent to binarycrossentropy(ŷ, y) +The modulating factor, `γ`, controls the down-weighting strength. +For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref). -See also: ['focal_loss'](@ref) +See also: [`Losses.focal_loss`](@ref) for multi-class setting """ function binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) @@ -452,12 +453,17 @@ function binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=e end """ - focal_loss(ŷ, y; dims=1, agg=mean, γ=2.0, ϵ=epseltype(ŷ)) + focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf) -For γ=0, the loss is mathematically equivalent to crossentropy(ŷ, y) +which can be used in classification tasks with highly imbalanced classes. +It down-weights well-classified examples and focuses on hard examples. +The input, `ŷ`, is expected to be unnormalized. + +The modulating factor, `γ`, controls the down-weighting strength. +For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref). -See also: [`binary_focal_loss`](@ref) +See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ function focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) diff --git a/test/losses.jl b/test/losses.jl index 8ae4510c02..4983d63679 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -178,7 +178,7 @@ end @testset "binary_focal_loss" begin y = [0 1 0 1 0 1] - yhat = [0.268941 0.5 0.268941 + ŷ = [0.268941 0.5 0.268941 0.731059 0.5 0.731059] y1 = [1 0 @@ -201,10 +201,10 @@ end y1 = [1 0 0 0 0 1] - ŷ1 = [5.0 1.0 - 2.3 2.0 - 3.8 9.0] + ŷ1 = [0.4 0.2 + 0.5 0.5 + 0.1 0.3] @test Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628 - @test Flux.focal_loss(ŷ1, y1) ≈ 0.011366240888043638 + @test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157 @test Flux.focal_loss(ŷ, y; γ=0.0) == Flux.crossentropy(ŷ, y) end From 07c33eb73d7527fddfc94138208c6ec59f2b0f64 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Tue, 2 Feb 2021 20:56:03 +0530 Subject: [PATCH 07/14] Changes made --- src/losses/functions.jl | 8 ++++---- test/losses.jl | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index f9b6f012c9..460a256e7a 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -435,7 +435,7 @@ end Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf) which can be used in classification tasks with highly imbalanced classes. It down-weights well-classified examples and focuses on hard examples. -The input, 'ŷ', is expected to be unnormalized. +The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output). The modulating factor, `γ`, controls the down-weighting strength. For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref). @@ -443,7 +443,7 @@ For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentr See also: [`Losses.focal_loss`](@ref) for multi-class setting """ -function binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) +function binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) ŷ = ŷ .+ ϵ p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ) ce = -log.(p_t) @@ -458,7 +458,7 @@ end Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf) which can be used in classification tasks with highly imbalanced classes. It down-weights well-classified examples and focuses on hard examples. -The input, `ŷ`, is expected to be unnormalized. +The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output). The modulating factor, `γ`, controls the down-weighting strength. For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref). @@ -468,7 +468,7 @@ See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ function focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) ŷ = ŷ .+ ϵ - agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=1)) + agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims)) end ```@meta DocTestFilters = nothing diff --git a/test/losses.jl b/test/losses.jl index 4983d63679..7b44ee1586 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -187,7 +187,7 @@ end 0.4 0.7] @test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222 - @test Flux.binary_focal_loss(ŷ, y; γ=0.0) == Flux.binarycrossentropy(ŷ, y) + @test Flux.binary_focal_loss(ŷ, y; γ=0.0) ≈ Flux.binarycrossentropy(ŷ, y) end @testset "focal_loss" begin @@ -206,5 +206,5 @@ end 0.1 0.3] @test Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628 @test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157 - @test Flux.focal_loss(ŷ, y; γ=0.0) == Flux.crossentropy(ŷ, y) + @test Flux.focal_loss(ŷ, y; γ=0.0) ≈ Flux.crossentropy(ŷ, y) end From 8ccff262d49595e4bee281e591f911a217c26e84 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Tue, 2 Feb 2021 21:04:13 +0530 Subject: [PATCH 08/14] docstring mistake solved! --- src/losses/functions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 460a256e7a..5f023459f3 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -430,7 +430,7 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) end """ - binary_focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) + binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf) which can be used in classification tasks with highly imbalanced classes. From e6342a46a997518adcf9d4689935abfb83f7aad4 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Fri, 5 Feb 2021 15:19:26 +0530 Subject: [PATCH 09/14] Added GPU tests and requested changes made --- src/losses/functions.jl | 11 ++++------- test/cuda/losses.jl | 13 ++++++++++++- test/losses.jl | 10 ++++------ 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 5f023459f3..210e40c9c1 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -430,20 +430,17 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) end """ - binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) + binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ)) Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf) -which can be used in classification tasks with highly imbalanced classes. -It down-weights well-classified examples and focuses on hard examples. The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output). -The modulating factor, `γ`, controls the down-weighting strength. For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref). See also: [`Losses.focal_loss`](@ref) for multi-class setting """ -function binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) +function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ)) ŷ = ŷ .+ ϵ p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ) ce = -log.(p_t) @@ -453,7 +450,7 @@ function binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype end """ - focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) + focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ)) Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf) which can be used in classification tasks with highly imbalanced classes. @@ -466,7 +463,7 @@ For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`]( See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ -function focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) +function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ)) ŷ = ŷ .+ ϵ agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims)) end diff --git a/test/cuda/losses.jl b/test/cuda/losses.jl index 0913b0eb6a..a0f7f47d80 100644 --- a/test/cuda/losses.jl +++ b/test/cuda/losses.jl @@ -1,4 +1,4 @@ -using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy +using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy, binary_focal_loss, focal_loss @testset "Losses" begin @@ -14,6 +14,17 @@ y = [1, 1, 0.] @test binarycrossentropy(σ.(x), y) ≈ binarycrossentropy(gpu(σ.(x)), gpu(y)) @test logitbinarycrossentropy(x, y) ≈ logitbinarycrossentropy(gpu(x), gpu(y)) +x = [0.268941 0.5 0.268941 + 0.731059 0.5 0.731059] +y = [0 1 0 + 1 0 1] +@test binary_focal_loss(x, y) ≈ binary_focal_loss(gpu(x), gpu(y)) + +x = softmax(reshape(-7:7, 3, 5) .* 1f0) +y = [1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0] +@test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y)) @testset "GPU grad tests" begin x = rand(Float32, 3,3) diff --git a/test/losses.jl b/test/losses.jl index 7b44ee1586..9abc03abb8 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -13,7 +13,8 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle, Flux.Losses.tversky_loss, Flux.Losses.dice_coeff_loss, Flux.Losses.poisson_loss, - Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss] + Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss, + Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss] @testset "xlogx & xlogy" begin @@ -179,7 +180,7 @@ end y = [0 1 0 1 0 1] ŷ = [0.268941 0.5 0.268941 - 0.731059 0.5 0.731059] + 0.731059 0.5 0.731059] y1 = [1 0 0 1] @@ -194,10 +195,7 @@ end y = [1 0 0 0 1 0 1 0 1 0 0 0 1 0 0] - - ŷ = [0.0900306 0.0900306 0.0900306 0.0900306 0.0900306 - 0.244728 0.244728 0.244728 0.244728 0.244728 - 0.665241 0.665241 0.665241 0.665241 0.665241] + ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0) y1 = [1 0 0 0 0 1] From 63747a6f97daf0791cad9c02fcce819923ac7fe3 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Fri, 5 Feb 2021 19:13:42 +0530 Subject: [PATCH 10/14] Added Doctest and fixed docstring mistake --- src/losses/functions.jl | 43 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 210e40c9c1..8687147a2f 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -430,13 +430,31 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) end """ - binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ)) + binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=eps(ŷ)) Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf) The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output). For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref). +# Example +```jldoctest +julia> y = [0 1 0 + 1 0 1] +2×3 Array{Int64,2}: + 0 1 0 + 1 0 1 + +julia> ŷ = [0.268941 0.5 0.268941 + 0.731059 0.5 0.731059] +2×3 Array{Float64,2}: + 0.268941 0.5 0.268941 + 0.731059 0.5 0.731059 + +julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 +true +``` + See also: [`Losses.focal_loss`](@ref) for multi-class setting """ @@ -450,7 +468,7 @@ function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ)) end """ - focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ)) + focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=eps(ŷ)) Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf) which can be used in classification tasks with highly imbalanced classes. @@ -460,6 +478,26 @@ The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output). The modulating factor, `γ`, controls the down-weighting strength. For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref). +# Example +```jldoctest +julia> y = [1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0] +3×5 Array{Int64,2}: + 1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0 + +julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0) +3×5 Array{Float32,2}: + 0.0900306 0.0900306 0.0900306 0.0900306 0.0900306 + 0.244728 0.244728 0.244728 0.244728 0.244728 + 0.665241 0.665241 0.665241 0.665241 0.665241 + +julia> Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628 +true +``` + See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ @@ -470,3 +508,4 @@ end ```@meta DocTestFilters = nothing ``` + From 5ce5481fcf40d91e6954f3f14db32d28bd94494e Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Fri, 5 Feb 2021 19:16:32 +0530 Subject: [PATCH 11/14] Done! --- src/losses/functions.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 8687147a2f..154291b7be 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -458,7 +458,7 @@ true See also: [`Losses.focal_loss`](@ref) for multi-class setting """ -function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ)) +function binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) ŷ = ŷ .+ ϵ p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ) ce = -log.(p_t) @@ -501,7 +501,7 @@ true See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ -function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ)) +function focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) ŷ = ŷ .+ ϵ agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims)) end From 19876931299e99ed98df3839a7da19bd1c46c995 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami Date: Fri, 5 Feb 2021 19:41:12 +0530 Subject: [PATCH 12/14] Added entry in NEWS.md --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index db7853995e..8001bf7c6e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ ## v0.12.0 +* Added [Focal Loss function](https://github.com/FluxML/Flux.jl/pull/1489) to Losses module * The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405). * Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394). * Excise datasets in favour of other providers in the julia ecosystem. From 63e4d98c95770b3b0051dd7b791f4f74a5984365 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami <44720861+shikhargoswami@users.noreply.github.com> Date: Fri, 5 Feb 2021 22:20:34 +0530 Subject: [PATCH 13/14] Applied the suggestions Co-authored-by: Carlo Lucibello --- src/losses/functions.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 154291b7be..20d45b7a80 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -458,7 +458,7 @@ true See also: [`Losses.focal_loss`](@ref) for multi-class setting """ -function binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) +function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ)) ŷ = ŷ .+ ϵ p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ) ce = -log.(p_t) @@ -501,11 +501,10 @@ true See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ -function focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ)) +function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ)) ŷ = ŷ .+ ϵ agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims)) end ```@meta DocTestFilters = nothing ``` - From 284425b4afdb940fd7614256a57994eb8eb8f875 Mon Sep 17 00:00:00 2001 From: Shikhar Goswami <44720861+shikhargoswami@users.noreply.github.com> Date: Fri, 5 Feb 2021 22:57:49 +0530 Subject: [PATCH 14/14] Update losses.md --- docs/src/models/losses.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md index 5d8eb48700..d35ebd3673 100644 --- a/docs/src/models/losses.md +++ b/docs/src/models/losses.md @@ -39,4 +39,6 @@ Flux.Losses.hinge_loss Flux.Losses.squared_hinge_loss Flux.Losses.dice_coeff_loss Flux.Losses.tversky_loss +Flux.Losses.binary_focal_loss +Flux.Losses.focal_loss ```