-
Notifications
You must be signed in to change notification settings - Fork 2
/
architectures_unstructured.py
49 lines (45 loc) · 2.9 KB
/
architectures_unstructured.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
# from torchvision.models.resnet import resnet50
import torch.backends.cudnn as cudnn
from archs_unstructured.cifar_resnet import resnet34_in, resnet50_in, resnet18_in, resnet9_in
from archs_unstructured.cifar_resnet import wide_resnet_22_8, wide_resnet_28_10_drop02
# from archs.cifar_resnet import wide_resnet_22_8, wide_resnet_22_8_drop02, wide_resnet_28_10_drop02, wide_resnet_28_12_drop02, wide_resnet_16_8_drop02
from torch.nn.functional import interpolate
# Architecture names known to this module (duplicate-free, original order kept).
# NOTE: get_architecture() only builds a subset of these names (the resnet*_in
# and wide_resnet_22_8 / wide_resnet_28_10 variants for specific datasets) and
# raises AssertionError for everything else.
ARCHITECTURES = [
    "resnet50",
    "lenet300",
    "lenet5",
    "vgg19",
    "resnet32",
    "resnet34_in",
    "resnet50_in",
    "vgg16",
    "lenet_5_caffe",
    "resnet18_in",
    "wide_resnet_16_8_drop02",
    "resnet9_in",
    "wide_resnet_22_8_drop02",
    "wide_resnet_22_8",
    "wide_resnet_28_10",
    "wide_resnet_28_12",
    "wide_resnet_28_10_drop02",
    "wide_resnet_28_12_drop02",
]
# Number of output classes used for each dataset name get_architecture accepts.
_DATASET_NUM_CLASSES = {"cifar10": 10, "cifar100": 100, "tiny_imagenet": 200}

# The (architecture -> datasets) combinations get_architecture will build.
# Any pair outside this table is rejected before a model is constructed.
_SUPPORTED_DATASETS = {
    "resnet9_in": ("cifar100",),
    "resnet18_in": ("cifar100", "cifar10", "tiny_imagenet"),
    "resnet34_in": ("cifar100", "cifar10", "tiny_imagenet"),
    "resnet50_in": ("cifar100",),
    "wide_resnet_22_8": ("cifar100", "cifar10", "tiny_imagenet"),
    "wide_resnet_28_10": ("cifar100", "cifar10"),
}


def get_architecture(arch: str, dataset: str, device, args) -> torch.nn.Module:
    """Return a newly constructed neural network (random weights) on ``device``.

    :param arch: architecture name; must be a key of ``_SUPPORTED_DATASETS``.
    :param dataset: dataset name; must be one of the datasets listed for
        ``arch`` in ``_SUPPORTED_DATASETS``. It selects ``num_classes``
        (cifar10 -> 10, cifar100 -> 100, tiny_imagenet -> 200).
    :param device: target passed to ``Module.to`` after construction.
    :param args: forwarded unchanged as the ``args=`` keyword of the
        architecture constructor.
    :return: a PyTorch module, already moved to ``device``.
    :raises AssertionError: if the (arch, dataset) pair is not supported.
        AssertionError (rather than ValueError) is kept so existing callers
        that catch it keep working.
    """
    supported = _SUPPORTED_DATASETS.get(arch)
    if supported is None or dataset not in supported:
        combos = "; ".join(
            f"{name}: {', '.join(datasets)}"
            for name, datasets in _SUPPORTED_DATASETS.items()
        )
        # Keep the original message prefix so existing log searches still match.
        raise AssertionError(
            "Your architecture is not in the list. "
            f"Unsupported combination arch={arch!r}, dataset={dataset!r}. "
            f"Supported combinations -> {combos}"
        )

    # Built only after validation, so an invalid request never touches the
    # constructors.
    builders = {
        "resnet9_in": resnet9_in,
        "resnet18_in": resnet18_in,
        "resnet34_in": resnet34_in,
        "resnet50_in": resnet50_in,
        "wide_resnet_22_8": wide_resnet_22_8,
        # "wide_resnet_28_10" has always been built with the *_drop02 variant.
        "wide_resnet_28_10": wide_resnet_28_10_drop02,
    }
    model = builders[arch](num_classes=_DATASET_NUM_CLASSES[dataset], args=args)
    return model.to(device)