Skip to content

Commit

Permalink
resnet done
Browse files Browse the repository at this point in the history
  • Loading branch information
a1eaiactaest committed Apr 28, 2024
1 parent a247e73 commit 0ddc301
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions algo/machine_learning/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import mlx.core as mx
import mlx.nn as nn

from algo.helpers import sequential
from algo.helpers import sequential, torch_load, fetch, get_child

class BasicBlock(nn.Module):
expansion = 1
Expand Down Expand Up @@ -116,4 +116,31 @@ def __call__(self, x):
return out
return features


def load_pretrained(self):
urls = {
(18, 1 ,64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
(34, 1 ,64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
(50, 1 ,64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
(50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
(101, 1 ,64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
(152, 1 ,64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

self.url = urls[(self.num, self.groups, self.base_width)]
for k, v in torch_load(fetch(self.url)).items():
obj = get_child(self, k)
dat = mx.array(v)

if 'fc.' in k and obj.shape != dat.shape:
print('skipping fully connected layer')
continue

assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape)

_num_classes = 1000
ResNet18 = ResNet(18, num_classes=_num_classes)
ResNet34 = ResNet(34, num_classes=_num_classes)
ResNet50 = ResNet(50, num_classes=_num_classes)
ResNet101 = ResNet(101, num_classes=_num_classes)
ResNet152 = ResNet(152, num_classes=_num_classes)
ResNetXt50_32x4D = ResNet(50, num_classes=_num_classes, groups=32, width_per_group=4)

0 comments on commit 0ddc301

Please sign in to comment.