Implement EfficientNet #113
Conversation
As for the pretrained weights, we can use utilities to load PyTorch's weights into the model, then save them as …
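For context, a rough sketch of the kind of conversion involved (the helper name is hypothetical; it assumes PyTorch stores conv weights as `(out, in, kH, kW)` while Flux expects `(kW, kH, in, out)`, and that Flux's `Conv` flips its kernel where PyTorch's `Conv2d` does not; the flip is worth double-checking for your Flux version):

```julia
# Hypothetical sketch: convert a PyTorch Conv2d weight of shape
# (out, in, kH, kW) into Flux's (kW, kH, in, out) layout.
# PyTorch's Conv2d computes a cross-correlation, while Flux's Conv
# flips its kernel, so the spatial dims are reversed here as well.
function pytorch2flux_conv(w::AbstractArray{<:Real, 4})
    return reverse(permutedims(w, (4, 3, 2, 1)); dims = (1, 2))
end
```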
I wanted to add pretrained weights, but it looks like Metalhead's utilities for loading weights only include the trainable parameters: Line 23 in 4955fde
But it is important to also load the moving mean and variance for the BatchNorm layers.
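To make the problem concrete, here is a minimal sketch with plain Flux showing that the running statistics are not part of the trainable parameters:

```julia
using Flux

bn = BatchNorm(4)

# Only the trainable parameters (the scale γ and bias β) are collected,
# so a params-based save/load round trip drops the running statistics:
length(Flux.params(bn))  # == 2

# The moving mean and variance live on the layer as plain fields:
bn.μ   # moving mean
bn.σ²  # moving variance
```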
Additionally, maybe we should start adding inferrability tests for the models? (in light of Taking TTFX seriously)
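Such a test could look something like the following sketch (assuming the `EfficientNet` constructor from this PR and `Test.@inferred` on a forward pass):

```julia
using Test, Flux

m = EfficientNet(:b0)
x = rand(Float32, 224, 224, 3, 1)

# Fails if the return type of the forward pass cannot be inferred:
@inferred m(x)
```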
Based on FluxML/MetalheadWeights#2, that seems to be the way to go.
Thank you for this PR! I have filed FluxML/Flux.jl#1875 to address the issue about loading state. Along with FluxML/MetalheadWeights#2, we should be able to handle models saved with …

The implementation in this PR is pretty much good to go. It mostly needs to be refactored to match the model-building style of this package, and it needs to reuse existing functions. Based on this, I suggest the following (this is not a complete rewrite; the internal code is the same, it was just too annoying to input as a GH review comment):

```julia
function efficientnet(imsize, scalings, block_config;
                      inchannels = 3, nclasses = 1000, max_width = 1280)
    wscale, dscale = scalings
    # compound-scaling helpers: scale channel widths and block depths
    scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w)
    scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d)
    out_channels = _round_channels(scalew(32), 8)
    stem = Chain(Conv((3, 3), inchannels => out_channels; bias = false, stride = 2,
                      pad = SamePad()),
                 BatchNorm(out_channels, swish))
    blocks = []
    for (n, k, s, e, i, o) in block_config
        in_channels = _round_channels(scalew(i), 8)
        out_channels = _round_channels(scalew(o), 8)
        repeats = scaled(n)
        push!(blocks,
              invertedresidual(k, in_channels, in_channels * e, out_channels, swish;
                               stride = s, reduction = 4))
        for _ in 1:(repeats - 1)
            push!(blocks,
                  invertedresidual(k, out_channels, out_channels * e, out_channels,
                                   swish; stride = 1, reduction = 4))
        end
    end
    blocks = Chain(blocks...)
    head_out_channels = _round_channels(max_width, 8)
    head = Chain(Conv((1, 1), out_channels => head_out_channels; bias = false,
                      pad = SamePad()),
                 BatchNorm(head_out_channels, swish))
    top = Dense(head_out_channels, nclasses)
    return Chain(Chain(stem, blocks, head),
                 Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top))
end
# block configuration
# n: number of block repetitions
# k: kernel size (k × k)
# s: stride
# e: expansion ratio
# i: block input channels
# o: block output channels
const efficientnet_block_configs = [
    # (n, k, s, e,   i,   o)
    (1, 3, 1, 1,  32,  16),
    (2, 3, 2, 6,  16,  24),
    (2, 5, 2, 6,  24,  40),
    (3, 3, 2, 6,  40,  80),
    (3, 5, 1, 6,  80, 112),
    (4, 5, 2, 6, 112, 192),
    (1, 3, 1, 6, 192, 320)
]
# w: width scaling
# d: depth scaling
# r: image resolution
const efficientnet_global_configs = Dict(
    #      (  r, (  w,   d))
    :b0 => (224, (1.0, 1.0)),
    :b1 => (240, (1.0, 1.1)),
    :b2 => (260, (1.1, 1.2)),
    :b3 => (300, (1.2, 1.4)),
    :b4 => (380, (1.4, 1.8)),
    :b5 => (456, (1.6, 2.2)),
    :b6 => (528, (1.8, 2.6)),
    :b7 => (600, (2.0, 3.1)),
    :b8 => (672, (2.2, 3.6))
)
struct EfficientNet
    layers
end

function EfficientNet(imsize, scalings, block_config;
                      inchannels = 3, nclasses = 1000, max_width = 1280)
    layers = efficientnet(imsize, scalings, block_config;
                          inchannels = inchannels, nclasses = nclasses,
                          max_width = max_width)
    return EfficientNet(layers)
end

@functor EfficientNet

(m::EfficientNet)(x) = m.layers(x)

backbone(m::EfficientNet) = m.layers[1]
classifier(m::EfficientNet) = m.layers[2]

function EfficientNet(name::Symbol; pretrain = false)
    @assert name in keys(efficientnet_global_configs) "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"
    model = EfficientNet(efficientnet_global_configs[name]...,
                         efficientnet_block_configs)
    pretrain && loadpretrain!(model, string("EfficientNet", name))
    return model
end
```

Note that this requires a rebase to pass since it depends on #120. The above code can be under …
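For reference, `_round_channels` is assumed but not shown above; here is a sketch of the usual MobileNet-style divisor rounding it refers to (the helper that actually lands in the package may differ):

```julia
# Round `channels` to the nearest multiple of `divisor`, never going
# below `min_value` and never rounding down by more than 10%.
function _round_channels(channels, divisor, min_value = divisor)
    new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor)
    return new_channels < 0.9 * channels ? new_channels + divisor : new_channels
end
```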
@pxl-th any interest in reviving this PR with the feedback above? If not, is it okay if I continue off this PR to complete it?
Sorry, I don't have a lot of free time currently. It is totally ok if you are interested in completing this PR :)
Superseded by #171.
This is an implementation of the EfficientNet model, in a similar way to the PyTorch EfficientNet implementation.
TODO
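As a quick smoke test of the API sketched in the review above (assuming the `EfficientNet(name::Symbol)` constructor and b0's 224-pixel input resolution):

```julia
using Flux

m = EfficientNet(:b0)
x = rand(Float32, 224, 224, 3, 1)  # one 224×224 RGB image in WHCN order
size(m(x))                         # expected: (1000, 1)
```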