diff --git a/src/Metalhead.jl b/src/Metalhead.jl
index aab12aa11..c97bc0053 100644
--- a/src/Metalhead.jl
+++ b/src/Metalhead.jl
@@ -11,7 +11,7 @@ using MLUtils
 import Functors
 
 include("utilities.jl")
-include("layers.jl")
+include("layers.jl")
 
 # CNN models
 include("convnets/alexnet.jl")
@@ -23,6 +23,7 @@ include("convnets/resnext.jl")
 include("convnets/densenet.jl")
 include("convnets/squeezenet.jl")
 include("convnets/mobilenet.jl")
+include("convnets/efficientnet/efficientnet.jl")
 
 # Other models
 include("other/mlpmixer.jl")
@@ -37,12 +38,13 @@ export AlexNet, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        ResNeXt, MobileNetv2, MobileNetv3,
+       EfficientNet,
        MLPMixer, ViT
 
 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
-          :MobileNetv2, :MobileNetv3, :MLPMixer, :ViT)
+for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
+          :MobileNetv2, :MobileNetv3, :EfficientNet, :MLPMixer, :ViT)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl
new file mode 100644
index 000000000..d2fa84a5a
--- /dev/null
+++ b/src/convnets/efficientnet/efficientnet.jl
@@ -0,0 +1,72 @@
+include("params.jl")
+include("mb.jl")
+
+struct EfficientNet{S, B, H, P, F}
+    stem::S
+    blocks::B
+
+    head::H
+    pooling::P
+    top::F
+end
+Flux.@functor EfficientNet
+
+function EfficientNet(
+    model_name, block_params, global_params; in_channels, n_classes, pretrain,
+)
+    pad, bias = SamePad(), false
+    out_channels = round_filter(32, global_params)
+    stem = Chain(
+        Conv((3, 3), in_channels=>out_channels; bias, stride=2, pad),
+        BatchNorm(out_channels, swish))
+
+    blocks = MBConv[]
+    for bp in block_params
+        in_channels = round_filter(bp.in_channels, global_params)
+        out_channels = round_filter(bp.out_channels, global_params)
+        repeat = global_params.depth_coef ≈ 1 ?
+            bp.repeat : ceil(Int64, global_params.depth_coef * bp.repeat)
+
+        push!(blocks, MBConv(
+            in_channels, out_channels, bp.kernel, bp.stride;
+            expansion_ratio=bp.expansion_ratio))
+        for _ in 1:(repeat - 1)
+            push!(blocks, MBConv(
+                out_channels, out_channels, bp.kernel, 1;
+                expansion_ratio=bp.expansion_ratio))
+        end
+    end
+    blocks = Chain(blocks...)
+
+    head_out_channels = round_filter(1280, global_params)
+    head = Chain(
+        Conv((1, 1), out_channels=>head_out_channels; bias, pad),
+        BatchNorm(head_out_channels, swish))
+
+    top = Dense(head_out_channels, n_classes)
+    model = EfficientNet(stem, blocks, head, AdaptiveMeanPool((1, 1)), top)
+    pretrain && loadpretrain!(model, "EfficientNet" * model_name)
+    model
+end
+
+"""
+    EfficientNet(model_name; in_channels = 3, n_classes = 1000, pretrain = false)
+
+Construct an EfficientNet model
+([reference](https://arxiv.org/abs/1905.11946)).
+
+# Arguments
+- `model_name::String`: Name of the model. Accepts `b0`-`b8` names.
+- `in_channels::Int`: Number of input channels. Default is `3`.
+- `n_classes::Int`: Number of output classes. Default is `1000`.
+- `pretrain::Bool`: Whether to load ImageNet pretrained weights.
+  Default is `false`.
+"""
+EfficientNet(
+    model_name::String; in_channels::Int = 3,
+    n_classes::Int = 1000, pretrain::Bool = false,
+) = EfficientNet(
+    model_name, get_efficientnet_params(model_name)...;
+    in_channels, n_classes, pretrain)
+
+(m::EfficientNet)(x) = m.top(Flux.flatten(m.pooling(m.head(m.blocks(m.stem(x))))))
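+
+# A minimal usage sketch (hypothetical REPL session; `pretrain = true` tries to
+# load ImageNet weights and throws an `ArgumentError` if none are available):
+#
+#   m = EfficientNet("b0")             # width/depth coefficients (1.0, 1.0)
+#   x = rand(Float32, 224, 224, 3, 1)  # a WHCN image batch
+#   size(m(x))                         # -> (1000, 1)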
+""" +EfficientNet( + model_name::String; in_channels::Int = 3, + n_classes::Int = 1000, pretrain::Bool = false, +) = EfficientNet( + model_name, get_efficientnet_params(model_name)...; + in_channels, n_classes, pretrain) + +(m::EfficientNet)(x) = m.top(Flux.flatten(m.pooling(m.head(m.blocks(m.stem(x)))))) diff --git a/src/convnets/efficientnet/mb.jl b/src/convnets/efficientnet/mb.jl new file mode 100644 index 000000000..466daf2d9 --- /dev/null +++ b/src/convnets/efficientnet/mb.jl @@ -0,0 +1,82 @@ +struct MBConv{E, D, X, P} + expansion::E + depthwise::D + excitation::X + projection::P + + do_expansion::Bool + do_excitation::Bool + do_skip::Bool +end +Flux.@functor MBConv + +""" + MBConv( + in_channels, out_channels, kernel, stride; + expansion_ratio, se_ratio) + +Mobile Inverted Residual Bottleneck Block +([reference](https://arxiv.org/abs/1801.04381)). + +# Arguments +- `in_channels`: Number of input channels. +- `out_channels`: Number of output channels. +- `expansion_ratio`: + Expansion ratio defines the number of output channels. + Set to `1` to disable expansion phase. + `out_channels = input_channels * expansion_ratio`. +- `kernel`: Size of the kernel for the depthwise conv phase. +- `stride`: Size of the stride for the depthwise conv phase. +- `se_ratio`: + Squeeze-Excitation ratio. Should be in `(0, 1]` range. + Set to `-1` to disable. +""" +function MBConv( + in_channels, out_channels, kernel, stride; + expansion_ratio, se_ratio = 0.25, +) + do_skip = stride == 1 && in_channels == out_channels + do_expansion, do_excitation = expansion_ratio != 1, 0 < se_ratio ≤ 1 + pad, bias = SamePad(), false + + mid_channels = ceil(Int, in_channels * expansion_ratio) + expansion = do_expansion ? + Chain( + Conv((1, 1), in_channels=>mid_channels; bias, pad), + BatchNorm(mid_channels, swish)) : + identity + + depthwise = Chain( + Conv(kernel, mid_channels=>mid_channels; bias, stride, pad, groups=mid_channels), + BatchNorm(mid_channels, swish)) + + if do_excitation + n_squeezed_channels = max(1, ceil(Int, in_channels * se_ratio)) + excitation = Chain( + AdaptiveMeanPool((1, 1)), + Conv((1, 1), mid_channels=>n_squeezed_channels, swish; pad), + Conv((1, 1), n_squeezed_channels=>mid_channels; pad)) + else + excitation = identity + end + + projection = Chain( + Conv((1, 1), mid_channels=>out_channels; pad, bias), + BatchNorm(out_channels)) + MBConv( + expansion, depthwise, excitation, projection, do_expansion, + do_excitation, do_skip) +end + +function (m::MBConv)(x) + o = m.depthwise(m.expansion(x)) + + if m.do_excitation + o = σ.(m.excitation(o)) .* o + end + o = m.projection(o) + if m.do_skip + o = o + x + end + o +end diff --git a/src/convnets/efficientnet/params.jl b/src/convnets/efficientnet/params.jl new file mode 100644 index 000000000..45e0db198 --- /dev/null +++ b/src/convnets/efficientnet/params.jl @@ -0,0 +1,59 @@ +struct BlockParams + repeat::Int + kernel::Tuple{Int, Int} + stride::Int + expansion_ratio::Int + in_channels::Int + out_channels::Int +end + +struct GlobalParams + width_coef::Real + depth_coef::Real + image_size::Tuple{Int, Int} + + depth_divisor::Int + min_depth::Union{Nothing, Int} +end + +# (width_coefficient, depth_coefficient, resolution) +get_efficientnet_coefficients(model_name::String) = + Dict( + "b0" => (1.0, 1.0, 224), + "b1" => (1.0, 1.1, 240), + "b2" => (1.1, 1.2, 260), + "b3" => (1.2, 1.4, 300), + "b4" => (1.4, 1.8, 380), + "b5" => (1.6, 2.2, 456), + "b6" => (1.8, 2.6, 528), + "b7" => (2.0, 3.1, 600), + "b8" => (2.2, 3.6, 672))[model_name] + +function 
diff --git a/src/convnets/efficientnet/params.jl b/src/convnets/efficientnet/params.jl
new file mode 100644
index 000000000..45e0db198
--- /dev/null
+++ b/src/convnets/efficientnet/params.jl
@@ -0,0 +1,59 @@
+struct BlockParams
+    repeat::Int
+    kernel::Tuple{Int, Int}
+    stride::Int
+    expansion_ratio::Int
+    in_channels::Int
+    out_channels::Int
+end
+
+struct GlobalParams
+    width_coef::Real
+    depth_coef::Real
+    image_size::Tuple{Int, Int}
+
+    depth_divisor::Int
+    min_depth::Union{Nothing, Int}
+end
+
+# (width_coefficient, depth_coefficient, resolution)
+get_efficientnet_coefficients(model_name::String) =
+    Dict(
+        "b0" => (1.0, 1.0, 224),
+        "b1" => (1.0, 1.1, 240),
+        "b2" => (1.1, 1.2, 260),
+        "b3" => (1.2, 1.4, 300),
+        "b4" => (1.4, 1.8, 380),
+        "b5" => (1.6, 2.2, 456),
+        "b6" => (1.8, 2.6, 528),
+        "b7" => (2.0, 3.1, 600),
+        "b8" => (2.2, 3.6, 672))[model_name]
+
+function get_efficientnet_params(model_name)
+    block_params = [
+        BlockParams(1, (3, 3), 1, 1, 32, 16),
+        BlockParams(2, (3, 3), 2, 6, 16, 24),
+        BlockParams(2, (5, 5), 2, 6, 24, 40),
+        BlockParams(3, (3, 3), 2, 6, 40, 80),
+        BlockParams(3, (5, 5), 1, 6, 80, 112),
+        BlockParams(4, (5, 5), 2, 6, 112, 192),
+        BlockParams(1, (3, 3), 1, 6, 192, 320)]
+
+    width_coef, depth_coef, resolution = get_efficientnet_coefficients(model_name)
+    global_params = GlobalParams(
+        width_coef, depth_coef, (resolution, resolution), 8, nothing)
+    block_params, global_params
+end
+
+function round_filter(filters, global_params::GlobalParams)
+    global_params.width_coef ≈ 1 && return filters
+
+    depth_divisor = global_params.depth_divisor
+    filters *= global_params.width_coef
+    min_depth = global_params.min_depth
+    min_depth = min_depth ≡ nothing ? depth_divisor : min_depth
+
+    new_filters = max(min_depth,
+        (floor(Int, filters + depth_divisor / 2) ÷ depth_divisor) * depth_divisor)
+    new_filters < 0.9 * filters && (new_filters += global_params.depth_divisor)
+    new_filters
+end
diff --git a/test/convnets.jl b/test/convnets.jl
index e9a99748d..f7f558712 100644
--- a/test/convnets.jl
+++ b/test/convnets.jl
@@ -125,3 +125,15 @@
         end
     end
 end
+
+@testset "EfficientNet" begin
+    m = EfficientNet("b4")
+    x = rand(Float32, 224, 224, 3, 2)
+    @test size(m(x)) == (1000, 2)
+    @test_throws ArgumentError EfficientNet("b0"; pretrain = true)
+    @test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
+
+    # TODO: add a test for inferrability once a Flux release includes
+    # https://github.com/FluxML/Flux.jl/pull/1856
+    # @inferred m(x)
+end
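+
+# A hypothetical extension of the testset above, smoke-testing other variants
+# at their native resolutions from `get_efficientnet_coefficients`:
+#
+#   for (name, res) in (("b0", 224), ("b5", 456))
+#       @test size(EfficientNet(name)(rand(Float32, res, res, 3, 1))) == (1000, 1)
+#   end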