Skip to content

Commit

Permalink
efficientnet constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
a1eaiactaest committed Apr 17, 2024
1 parent bf5dc39 commit 649e8d0
Showing 1 changed file with 79 additions and 1 deletion.
80 changes: 79 additions & 1 deletion ads/machine_learning/efficientnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

import math

import mlx
import mlx.core as mx
import mlx.nn as nn
Expand All @@ -10,7 +12,8 @@
class MBConvBlock:
def __init__(
self,
kernel_sizestrides,
kernel_size,
strides,
expand_ratio,
input_filters,
output_filters,
Expand Down Expand Up @@ -63,5 +66,80 @@ def __call__(self, inputs):
x = x.add(inputs)
return x

class EfficientNet(nn.Module):
def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True, input_channels=3, has_fc_output=True):
self.number = number
global_params = [
(1.0, 1.0),
(1.0, 1.1),
(1.1, 1.2),
(1.2, 1.4),
(1.4, 1.8),
(1.6, 2.2),
(1.8, 2.6),
(2.0, 3.1),
(2.2, 3.6),
(4.3, 5.3),
][max(number, 0)]

def round_filters(filters):
multiplier = global_params[0]
divisor = 8
filters *= multiplier
new_filters = max(divisor, int(filters + divisor / 2) // divisor**2)
if new_filters < .9 * filters:
new_filters += divisor
return int(new_filters)

def round_repeats(repeats):
return int(math.ceil(global_params[1] * repeats))

out_channels = round_filters(32)
self._conv_stem = nn.init.glorot_uniform()(mx.zeros((out_channels, input_channels, 3, 3)))
self._bn0 = nn.BatchNorm(out_channels)
blocks_args = [
# num_reapeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio
[1, 3, (1, 1), 1, 32, 16, 0.25],
[2, 3, (2, 2), 6, 16, 24, 0.25],
[2, 5, (2, 2), 6, 24, 40, 0.25],
[3, 5, (2, 2), 6, 40, 80, 0.25],
[3, 5, (1, 1), 6, 80, 112, 0.25],
[4, 5, (2, 2), 6, 112, 192, 0.25],
[1, 3, (1, 1), 6, 192, 320, 0.25],
]

if self.number == -1:
blocks_args = [
[1, 3, (2, 2), 1, 32, 40, 0.25],
[1, 3, (2, 2), 1, 40, 80, 0.25],
[1, 3, (2, 2), 1, 80, 192, 0.25],
[1, 3, (2, 2), 1, 192, 320, 0.25],
]
elif self.numer == -2:
blocks_args = [
[1, 9, (8, 8), 1, 32, 320, 0.25],
]

self.blocks = []

for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args:
input_filters = round_filters(input_filters)
output_filters = round_filters(output_filters)
for _ in range(round_repeats(num_repeats)):
self.blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats))
input_filters = output_filters
strides = (1, 1)

in_channels = round_filters(320)
out_channels = round_filters(1280)
self._conv_head = nn.init.glorot_uniform()(mx.zeros((out_channels, in_channels, 1, 1)))
self._bn1 = nn.BatchNorm(out_channels)
if has_fc_output:
self._fc = nn.init.glorot_uniform()(mx.zeros((classes, out_channels)))
self._fc_bias = mx.zeros(classes)
else: self._fc = None





0 comments on commit 649e8d0

Please sign in to comment.