forked from vadimkantorov/metriclearningbench
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path resnet18.py
26 lines (22 loc) · 887 Bytes
/
resnet18.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from collections import OrderedDict
import torch.nn as nn
import torchvision
class resnet18(nn.Sequential):
    """torchvision ResNet-18 trunk (``conv1`` .. ``avgpool``) used as a feature extractor.

    The module is an ``nn.Sequential`` whose children are torchvision's
    ResNet-18 layers from ``conv1`` through ``avgpool``. The final ``fc``
    classifier is dropped and no flatten is appended, so ``forward`` returns
    the 4-D pooled feature map.

    Every ``BatchNorm2d`` layer is frozen in eval mode. Its running statistics
    are never updated, even when the whole model is put in training mode.
    """

    # Number of channels produced by ``avgpool`` (the width of ResNet-18's layer4).
    output_size = 512
    # NOTE(review): presumably the input crop size and pixel-value scale
    # expected by the data pipeline -- confirm against callers.
    input_side = 224
    rescale = 1
    # Per-channel normalisation constants (torchvision's ImageNet mean / std).
    rgb_mean = [0.485, 0.456, 0.406]
    rgb_std = [0.229, 0.224, 0.225]

    # Children of the torchvision model that are kept, in forward order.
    _trunk_layers = ('conv1', 'bn1', 'relu', 'maxpool',
                     'layer1', 'layer2', 'layer3', 'layer4', 'avgpool')

    def __init__(self, dilation=False, pretrained=True):
        """Build the trunk.

        Args:
            dilation: if True, remove the stride-2 downsampling at the start
                of ``layer4``, which doubles the spatial size of its output.
                ``layer4[0].conv1`` is dilated by 2 instead; the other convs
                in ``layer4`` keep dilation 1.
            pretrained: load torchvision's ImageNet weights (the default and
                the original behaviour). Pass False for random initialisation.
        """
        backbone = self._build_backbone(pretrained)
        if dilation:
            first_block = backbone.layer4[0]
            # With padding == dilation, the 3x3 conv keeps the spatial size at stride 1.
            first_block.conv1.dilation = (2, 2)
            first_block.conv1.padding = (2, 2)
            first_block.conv1.stride = (1, 1)
            # The shortcut must match the main branch's (now unstrided) resolution.
            first_block.downsample[0].stride = (1, 1)
        super().__init__(OrderedDict(
            (name, getattr(backbone, name)) for name in self._trunk_layers))
        self._freeze_batchnorm(self)

    @staticmethod
    def _build_backbone(pretrained):
        """Return a torchvision ResNet-18, ImageNet-initialised if *pretrained*."""
        weights_enum = getattr(torchvision.models, 'ResNet18_Weights', None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the boolean flag.
            return torchvision.models.resnet18(pretrained=pretrained)
        # torchvision >= 0.13 deprecates ``pretrained``. IMAGENET1K_V1 is the
        # checkpoint that ``pretrained=True`` maps to there.
        return torchvision.models.resnet18(
            weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    @staticmethod
    def _frozen_train(module, mode=True):
        """Stand-in for ``module.train`` that ignores *mode*, so the module stays in eval mode.

        It keeps ``nn.Module.train``'s signature (``mode`` optional) and return
        value (the module), so ``bn.train()`` and chained calls still work.
        """
        return module

    @classmethod
    def _freeze_batchnorm(cls, root):
        """Put every ``BatchNorm2d`` under *root* into eval mode and keep it there.

        In eval mode the running statistics are used but not updated.
        ``requires_grad`` is not touched, so the affine weight and bias
        remain trainable.
        """
        import functools

        for module in root.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
                # A partial over a class-level function (rather than a lambda)
                # keeps the module picklable, e.g. for ``torch.save(model)``.
                module.train = functools.partial(cls._frozen_train, module)