-
Notifications
You must be signed in to change notification settings - Fork 28
/
model.py
29 lines (24 loc) · 940 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch.nn as nn
from torchvision import models
def get_head(out_size, cfg):
    """Create the projection head g() described by the config.

    The head is a stack of ``cfg.head_layers`` linear layers. Every layer
    except the last is ``cfg.head_size`` wide and is followed by an optional
    ``BatchNorm1d`` (when ``cfg.add_bn`` is truthy) and a ``ReLU``; the last
    layer maps to ``cfg.emb`` features and has no normalisation/activation.

    Args:
        out_size: number of features fed into the first linear layer.
        cfg: object exposing ``head_layers``, ``head_size``, ``add_bn``
            and ``emb`` attributes.

    Returns:
        ``nn.Sequential`` holding the layers in application order.
    """
    layers = []
    in_size = out_size
    # Hidden blocks: Linear -> [BatchNorm1d] -> ReLU. With head_layers <= 1
    # the range is empty and only the output layer below is created.
    for _ in range(cfg.head_layers - 1):
        layers.append(nn.Linear(in_size, cfg.head_size))
        if cfg.add_bn:
            layers.append(nn.BatchNorm1d(cfg.head_size))
        layers.append(nn.ReLU())
        # Later layers consume the hidden width, not the encoder width.
        in_size = cfg.head_size
    # Output projection to the embedding size.
    layers.append(nn.Linear(in_size, cfg.emb))
    return nn.Sequential(*layers)
def get_model(arch, dataset):
    """Create the encoder E() by name and adapt it to the dataset.

    Args:
        arch: name of a model constructor in ``torchvision.models``. The
            constructed model must expose an ``fc`` layer with an
            ``in_features`` attribute.
        dataset: dataset name. Any value other than ``"imagenet"`` replaces
            ``conv1`` with a 3x3, stride-1, 64-channel convolution;
            ``"cifar10"`` and ``"cifar100"`` additionally replace
            ``maxpool`` with an identity.

    Returns:
        Tuple ``(encoder, out_size)``: the model wrapped in
        ``nn.DataParallel`` with ``fc`` replaced by an identity, and the
        ``in_features`` of the removed ``fc`` layer.
    """
    # The constructor default already builds a model without pretrained
    # weights; the explicit ``pretrained`` keyword is deprecated in newer
    # torchvision releases and only triggers warnings there.
    model = getattr(models, arch)()
    if dataset != "imagenet":
        # The 64 output channels are hard-coded, so this presumably targets
        # ResNet-style architectures — confirm before using other archs.
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    if dataset in ("cifar10", "cifar100"):
        model.maxpool = nn.Identity()
    # Record the feature width before dropping the classifier.
    out_size = model.fc.in_features
    model.fc = nn.Identity()
    return nn.DataParallel(model), out_size