From 649e8d0a4f8b8b2ce0014519f73376850acdc3c8 Mon Sep 17 00:00:00 2001 From: Piotrek Rybiec Date: Wed, 17 Apr 2024 19:15:16 +0200 Subject: [PATCH] efficientnet constructor --- ads/machine_learning/efficientnet.py | 80 +++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/ads/machine_learning/efficientnet.py b/ads/machine_learning/efficientnet.py index f122664..1040b73 100644 --- a/ads/machine_learning/efficientnet.py +++ b/ads/machine_learning/efficientnet.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +import math + import mlx import mlx.core as mx import mlx.nn as nn @@ -10,7 +12,8 @@ class MBConvBlock: def __init__( self, - kernel_sizestrides, + kernel_size, + strides, expand_ratio, input_filters, output_filters, @@ -63,5 +66,80 @@ def __call__(self, inputs): x = x.add(inputs) return x +class EfficientNet(nn.Module): + def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True, input_channels=3, has_fc_output=True): + self.number = number + global_params = [ + (1.0, 1.0), + (1.0, 1.1), + (1.1, 1.2), + (1.2, 1.4), + (1.4, 1.8), + (1.6, 2.2), + (1.8, 2.6), + (2.0, 3.1), + (2.2, 3.6), + (4.3, 5.3), + ][max(number, 0)] + + def round_filters(filters): + multiplier = global_params[0] + divisor = 8 + filters *= multiplier + new_filters = max(divisor, int(filters + divisor / 2) // divisor**2) + if new_filters < .9 * filters: + new_filters += divisor + return int(new_filters) + + def round_repeats(repeats): + return int(math.ceil(global_params[1] * repeats)) + + out_channels = round_filters(32) + self._conv_stem = nn.init.glorot_uniform()(mx.zeros((out_channels, input_channels, 3, 3))) + self._bn0 = nn.BatchNorm(out_channels) + blocks_args = [ + # num_reapeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio + [1, 3, (1, 1), 1, 32, 16, 0.25], + [2, 3, (2, 2), 6, 16, 24, 0.25], + [2, 5, (2, 2), 6, 24, 40, 0.25], + [3, 5, (2, 2), 6, 40, 80, 0.25], + [3, 5, (1, 1), 6, 80, 112, 0.25], + [4, 5, (2, 2), 6, 112, 192, 0.25], + [1, 3, (1, 1), 6, 192, 320, 0.25], + ] + + if self.number == -1: + blocks_args = [ + [1, 3, (2, 2), 1, 32, 40, 0.25], + [1, 3, (2, 2), 1, 40, 80, 0.25], + [1, 3, (2, 2), 1, 80, 192, 0.25], + [1, 3, (2, 2), 1, 192, 320, 0.25], + ] + elif self.numer == -2: + blocks_args = [ + [1, 9, (8, 8), 1, 32, 320, 0.25], + ] + + self.blocks = [] + + for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args: + input_filters = round_filters(input_filters) + output_filters = round_filters(output_filters) + for _ in range(round_repeats(num_repeats)): + self.blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats)) + input_filters = output_filters + strides = (1, 1) + + in_channels = round_filters(320) + out_channels = round_filters(1280) + self._conv_head = nn.init.glorot_uniform()(mx.zeros((out_channels, in_channels, 1, 1))) + self._bn1 = nn.BatchNorm(out_channels) + if has_fc_output: + self._fc = nn.init.glorot_uniform()(mx.zeros((classes, out_channels))) + self._fc_bias = mx.zeros(classes) + else: self._fc = None + + +