From 0ddc3016d3e468201595d4ececec87c8baaca530 Mon Sep 17 00:00:00 2001 From: Piotrek Rybiec Date: Sun, 28 Apr 2024 12:39:04 +0200 Subject: [PATCH] resnet done --- algo/machine_learning/models/resnet.py | 31 ++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/algo/machine_learning/models/resnet.py b/algo/machine_learning/models/resnet.py index 129395f..591650b 100644 --- a/algo/machine_learning/models/resnet.py +++ b/algo/machine_learning/models/resnet.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from algo.helpers import sequential +from algo.helpers import sequential, torch_load, fetch, get_child class BasicBlock(nn.Module): expansion = 1 @@ -116,4 +116,31 @@ def __call__(self, x): return out return features - + def load_pretrained(self): + urls = { + (18, 1 ,64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + (34, 1 ,64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + (50, 1 ,64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + (50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + (101, 1 ,64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + (152, 1 ,64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + } + + self.url = urls[(self.num, self.groups, self.base_width)] + for k, v in torch_load(fetch(self.url)).items(): + obj = get_child(self, k) + dat = mx.array(v) + + if 'fc.' in k and obj.shape != dat.shape: + print('skipping fully connected layer') + continue + + assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape) + +_num_classes = 1000 +ResNet18 = ResNet(18, num_classes=_num_classes) +ResNet34 = ResNet(34, num_classes=_num_classes) +ResNet50 = ResNet(50, num_classes=_num_classes) +ResNet101 = ResNet(101, num_classes=_num_classes) +ResNet152 = ResNet(152, num_classes=_num_classes) +ResNetXt50_32x4D = ResNet(50, num_classes=_num_classes, groups=32, width_per_group=4)