Skip to content

Commit

Permalink
Actually rewrite VGG to use vgg
Browse files Browse the repository at this point in the history
  • Loading branch information
theabhirath committed Jan 4, 2024
1 parent ee2d8bf commit 09f91c0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ include("vit-based/vit.jl")
# Load pretrained weights
include("pretrain.jl")

# deprecated
include("deprecations.jl")

# export model functions
export AlexNet, VGG, ResNet, WideResNet, ResNeXt, DenseNet,
GoogLeNet, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
Expand Down
13 changes: 5 additions & 8 deletions src/convnets/vgg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,10 @@ function vgg(imsize::Dims{2}; config, batchnorm::Bool = false, fcsize::Integer =
return Chain(Chain(conv...), class)
end

const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
:B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)],
:D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)],
:E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)])

const VGG_CONFIGS = Dict(11 => :A, 13 => :B, 16 => :D, 19 => :E)
const VGG_CONFIGS = Dict(11 => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
13 => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)],
16 => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)],
19 => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)])

"""
VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false,
Expand Down Expand Up @@ -132,8 +130,7 @@ end
function VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false,
inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(depth, keys(VGG_CONFIGS))
model = VGG((224, 224); config = VGG_CONV_CONFIGS[VGG_CONFIGS[depth]], batchnorm,
inchannels, nclasses)
model = vgg((224, 224); config = VGG_CONFIGS[depth], batchnorm, inchannels, nclasses)
if pretrain
artifact_name = string("vgg", depth)
if batchnorm
Expand Down
2 changes: 1 addition & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ function VGG(imsize::Dims{2}; config, batchnorm::Bool = false, dropout_prob = 0.
inchannels, batchnorm, nclasses)` instead for the same functionality.", :VGG)
layers = vgg(imsize; config, inchannels, batchnorm, nclasses, dropout_prob)
return VGG(layers)
end
end

0 comments on commit 09f91c0

Please sign in to comment.