-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathGetModel.py
45 lines (33 loc) · 1.74 KB
/
GetModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from models import phc_models, real_models
def GetModel(str_model, n, num_classes=1, weights=None, shared=False,
             patch_weights=True, visualize=False):
    """
    Build and return the model selected by str_model.

    Parameters:
    - str_model: one of resnet18, phcresnet18, resnet50, phcresnet50 (two-view models)
      or sbonet, physbonet, senet, physenet (four-view models).
    - n: forwarded as `channels` to resnet18/resnet50 and as `n` to the PHC models
      (presumably the hypercomplex dimension there -- confirm in phc_models).
      Not used by sbonet or senet.
    - num_classes: forwarded as `num_classes` to every model.
    - weights: path to weights, forwarded to the four-view models.
      Needed for physbonet and physenet.
    - shared: parameter of sbonet and physbonet.
    - patch_weights: parameter of senet and physenet.
    - visualize: forwarded to resnet18, phcresnet18, senet and physenet.

    Returns:
    - The instantiated model.

    Raises:
    - ValueError: if str_model is not one of the names listed above.
    """
    print('Model:', str_model)
    print()

    # Each entry is a zero-argument builder so that only the requested model is
    # constructed; the dict keys double as the list of allowed model names.
    builders = {
        ## Two-view models ##
        'resnet18': lambda: real_models.ResNet18(
            num_classes, channels=n, visualize=visualize),
        'phcresnet18': lambda: phc_models.PHCResNet18(
            channels=2, n=n, num_classes=num_classes, visualize=visualize),
        'resnet50': lambda: real_models.ResNet50(num_classes, channels=n),
        'phcresnet50': lambda: phc_models.PHCResNet50(
            channels=2, n=n, num_classes=num_classes),
        ## Four-view models ##
        # NOTE(review): this entry previously built real_models.SEnet (a copy of the
        # 'senet' entry), which is otherwise called without `shared`. SBOnet is taken
        # to be the real-valued counterpart of PHYSBOnet -- confirm the class name
        # in models/real_models.
        'sbonet': lambda: real_models.SBOnet(
            shared=shared, num_classes=num_classes, weights=weights),
        'physbonet': lambda: phc_models.PHYSBOnet(
            n=n, shared=shared, num_classes=num_classes, weights=weights),
        'senet': lambda: real_models.SEnet(
            num_classes=num_classes, weights=weights,
            patch_weights=patch_weights, visualize=visualize),
        'physenet': lambda: phc_models.PHYSEnet(
            n=n, num_classes=num_classes, weights=weights,
            patch_weights=patch_weights, visualize=visualize),
    }

    builder = builders.get(str_model)
    if builder is None:
        raise ValueError(
            "Model '{}' not implemented, check allowed models (-help): {}. "
            "Check the model you typed.".format(str_model, ', '.join(builders))
        )
    return builder()