From 3f6f60734df9cf7858cd6819d86802c407d30e3a Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Mon, 15 Jan 2024 18:43:13 +0100 Subject: [PATCH 1/4] Fix VGG Dropout probability --- src/convnets/vgg.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index ea9f42fe..864285b2 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -103,13 +103,13 @@ const VGG_CONFIGS = Dict(11 => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)] """ VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false, - inchannels::Integer = 3, nclasses::Integer = 1000) + inchannels::Integer = 3, nclasses::Integer = 1000, dropout_prob = 0.5) Create a VGG style model with specified `depth`. ([reference](https://arxiv.org/abs/1409.1556v6)). !!! warning - + `VGG` does not currently support pretrained weights for the `batchnorm = true` option. # Arguments @@ -119,6 +119,7 @@ Create a VGG style model with specified `depth`. - `batchnorm`: set to `true` to use batch normalization after each convolution - `inchannels`: number of input channels - `nclasses`: number of output classes +- `dropout_prob`: probability of `Dropout` layers setting inputs to zero See also [`vgg`](@ref). """ @@ -128,9 +129,9 @@ end @functor VGG function VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false, - inchannels::Integer = 3, nclasses::Integer = 1000) + inchannels::Integer = 3, nclasses::Integer = 1000, dropout_prob = 0.5) _checkconfig(depth, keys(VGG_CONFIGS)) - layers = vgg((224, 224); config = VGG_CONFIGS[depth], batchnorm, inchannels, nclasses) + layers = vgg((224, 224); config = VGG_CONFIGS[depth], batchnorm, inchannels, nclasses, dropout_prob) model = VGG(layers) if pretrain artifact_name = string("vgg", depth) From 69b34c3dc386ae63b3e4efa1f791518e2bfe8618 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Tue, 16 Jan 2024 17:17:29 +0100 Subject: [PATCH 2/4] Remove kwarg `dropout_prob` --- src/convnets/vgg.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 864285b2..6a6e04a8 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -119,7 +119,6 @@ Create a VGG style model with specified `depth`. - `batchnorm`: set to `true` to use batch normalization after each convolution - `inchannels`: number of input channels - `nclasses`: number of output classes -- `dropout_prob`: probability of `Dropout` layers setting inputs to zero See also [`vgg`](@ref). """ @@ -129,9 +128,10 @@ end @functor VGG function VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false, - inchannels::Integer = 3, nclasses::Integer = 1000, dropout_prob = 0.5) + inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(depth, keys(VGG_CONFIGS)) - layers = vgg((224, 224); config = VGG_CONFIGS[depth], batchnorm, inchannels, nclasses, dropout_prob) + layers = vgg((224, 224); config = VGG_CONFIGS[depth], batchnorm, inchannels, nclasses, + dropout_prob = 0.5) model = VGG(layers) if pretrain artifact_name = string("vgg", depth) From be19ebf6054c9f73ff74e7a30a7f4b66d28e2e86 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Tue, 16 Jan 2024 17:18:18 +0100 Subject: [PATCH 3/4] Remove kwarg `dropout_prob` from docstring --- src/convnets/vgg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 6a6e04a8..46876c6f 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -103,7 +103,7 @@ const VGG_CONFIGS = Dict(11 => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)] """ VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false, - inchannels::Integer = 3, nclasses::Integer = 1000, dropout_prob = 0.5) + inchannels::Integer = 3, nclasses::Integer = 1000) Create a VGG style model with specified `depth`. ([reference](https://arxiv.org/abs/1409.1556v6)). From 01bb4adccc8f56f34fa2de4ae29a923bd3c2ec71 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Wed, 17 Jan 2024 19:44:28 +0100 Subject: [PATCH 4/4] Increment patch version number to 0.9.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3e7d07f4..d865a5d0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Metalhead" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.9.2" +version = "0.9.3" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"