-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval_imagenet.py
142 lines (115 loc) · 4.67 KB
/
eval_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import argparse
from utils import *
# Command-line options. `choices` rejects unsupported values up front instead
# of letting them fail later (unbound `model` / undefined `calculate_score`).
parser = argparse.ArgumentParser()
parser.add_argument('--net', '-n', default='resnet50', type=str,
                    choices=['resnet50', 'vgg16', 'mobilenetv3'],
                    help='pretrained torchvision backbone to evaluate')
parser.add_argument('--gpu', '-g', default='0', type=str,
                    help="GPU index; the model runs on 'cuda:<gpu>'")
parser.add_argument('--save_path', '-s', default='.', type=str,
                    help='directory that receives <net>_result.txt')
parser.add_argument('--method', '-m', default='featurenorm', type=str,
                    choices=['msp', 'featurenorm'],
                    help='OOD score: maximum softmax probability or FeatureNorm')
args = parser.parse_args()
def forward_feature_resnet50(model, x):
    """Run a torchvision-style ResNet and collect every residual block's output.

    The stem (conv1 -> bn1 -> relu -> maxpool) is applied first. Then every
    block of layer1..layer4 is applied in order, and each block's output is
    appended to the returned list. Iterating the stages instead of hard-coding
    the resnet50 block counts (3, 4, 6, 3) keeps the result identical for
    resnet50. It also makes deeper/shallower ResNets use all of their blocks.

    Args:
        model: ResNet exposing ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
            iterable stages ``layer1``..``layer4``.
        x: input batch.

    Returns:
        List of per-block outputs, in forward order (16 entries for resnet50).
    """
    features = []
    x = model.maxpool(model.relu(model.bn1(model.conv1(x))))
    for stage in (model.layer1, model.layer2, model.layer3, model.layer4):
        for block in stage:
            x = block(x)
            features.append(x)
    return features
def forward_feature_vgg16(model, x):
    """Run a VGG ``features`` stack and collect the output of each max-pool stage.

    Each module of ``model.features`` is applied in order. The activation is
    recorded right after every ``nn.MaxPool2d``. For torchvision vgg16 this
    yields the same 5 tensors as the previous hard-coded 31-entry layer config.
    Detecting the pooling modules directly also works for VGG variants whose
    ``features`` layout differs (e.g. other depths), where index-based
    alignment broke.

    Args:
        model: VGG-style network whose ``features`` is an iterable of modules.
        x: input batch.

    Returns:
        List with one tensor per max-pooling stage, in forward order.
    """
    pooled = []
    for module in model.features:
        x = module(x)
        if isinstance(module, nn.MaxPool2d):
            pooled.append(x)
    return pooled
def forward_feature_mobilenetv3(model, x):
    """Run ``model.features`` module by module and collect every module's output.

    Args:
        model: network whose ``features`` is an ordered iterable of modules.
        x: input batch.

    Returns:
        List with the output of each module of ``model.features``, in order.
    """
    features = []
    for block in model.features:
        x = block(x)
        features.append(x)
    return features
def calculate_norm(model, loader, device):
    """Compute the FeatureNorm OOD score for every sample in ``loader``.

    The per-block features are extracted with the architecture-specific
    ``forward_feature_*`` helper, and the block at index ``model.sblock`` is
    kept. The score is ``||relu(f)||`` over dims 2 and 3, averaged over dim 1.

    Args:
        model: torchvision ``ResNet``, ``VGG`` or ``MobileNetV3`` carrying an
            integer ``sblock`` attribute that selects the scored block.
        loader: iterable of ``(inputs, targets)`` batches; targets are ignored.
        device: device the inputs are moved to.

    Returns:
        1-D tensor with one score per sample, on ``device``.

    Raises:
        ValueError: if the model's class is not one of the supported ones.
    """
    arch = type(model).__name__
    if arch == 'ResNet':
        forward_features = forward_feature_resnet50
    elif arch == 'VGG':
        forward_features = forward_feature_vgg16
    elif arch == 'MobileNetV3':
        forward_features = forward_feature_mobilenetv3
    else:
        # Previously this fell through and failed with UnboundLocalError below.
        raise ValueError(
            'FeatureNorm is not implemented for {}; expected ResNet, VGG or '
            'MobileNetV3'.format(arch))
    model.eval()
    predictions = []
    with torch.no_grad():
        for inputs, _ in loader:
            x = inputs.to(device)
            selected = forward_features(model, x)[model.sblock]
            # Norm over dims 2 and 3, then mean over dim 1.
            # NOTE(review): assumes NCHW feature maps (spatial dims 2, 3) — confirm.
            norm = torch.norm(F.relu(selected), dim=[2, 3]).mean(1)
            predictions.append(norm)
    return torch.cat(predictions).to(device)
def calculate_msp(model, loader, device):
    """Score each sample by its maximum softmax probability (MSP).

    Args:
        model: classifier returning logits with classes along dim 1.
        loader: iterable of ``(inputs, targets)`` batches; targets are ignored.
        device: device the inputs are moved to.

    Returns:
        1-D tensor of per-sample maximum class probabilities, on ``device``.
    """
    model.eval()
    batch_scores = []
    with torch.no_grad():
        for inputs, _targets in loader:
            logits = model(inputs.to(device))
            probabilities = torch.softmax(logits, dim=1)
            batch_scores.append(probabilities.max(dim=1).values)
    all_scores = torch.cat(batch_scores)
    return all_scores.to(device)
# Select the OOD score function named by --method.
if args.method == 'msp':
    calculate_score = calculate_msp
elif args.method == 'featurenorm':
    calculate_score = calculate_norm
else:
    # Without this branch an unknown method only surfaced later as a
    # NameError on `calculate_score` inside eval().
    parser.error("unknown --method {!r}; expected 'msp' or 'featurenorm'".format(args.method))
def OOD_results(preds_id, model, loader, device, method, file):
    """Score an OOD loader and report it against the in-distribution scores.

    The mean OOD score and mean in-distribution score are printed. Both score
    sets are then handed to ``show_performance``, which receives the label
    ``method`` and the open result file ``file``.

    Args:
        preds_id: in-distribution scores (already on CPU).
        model: network scored by the module-level ``calculate_score``.
        loader: OOD data loader.
        device: device used for scoring.
        method: label passed through to ``show_performance``.
        file: writable file handle passed through to ``show_performance``.
    """
    ood_scores = calculate_score(model, loader, device)
    preds_ood = ood_scores.cpu()
    mean_ood, mean_id = torch.mean(preds_ood), torch.mean(preds_id)
    print(mean_ood, mean_id)
    show_performance(preds_id, preds_ood, method, file=file)
def eval():
    """Run the ImageNet accuracy and OOD-detection evaluation for ``args.net``.

    Loads the pretrained torchvision backbone selected by ``--net`` and writes
    its validation accuracy to ``<save_path>/<net>_result.txt``. It then scores
    the validation set with ``calculate_score`` and reports OOD results on
    iNaturalist, SUN, Places and Textures.

    The name shadows the builtin ``eval``. It is kept because the ``__main__``
    guard calls it.

    Raises:
        ValueError: if ``--net`` is not a supported backbone.
    """
    device = 'cuda:' + args.gpu
    # --net value -> (torchvision.models constructor name, block index stored
    # as model.sblock and used by calculate_norm).
    backbones = {
        'resnet50': ('resnet50', 14),
        'vgg16': ('vgg16', 4),
        'mobilenetv3': ('mobilenet_v3_large', 16),
    }
    if args.net not in backbones:
        # Previously an unknown net fell through and crashed on the unbound `model`.
        raise ValueError('unsupported --net {!r}; expected one of: {}'.format(
            args.net, ', '.join(backbones)))
    ctor_name, sblock = backbones[args.net]
    # NOTE(review): torchvision is not imported directly in this file; it
    # presumably arrives via `from utils import *` — confirm.
    model = getattr(torchvision.models, ctor_name)(pretrained=True, num_classes=1000)
    model.sblock = sblock
    model.to(device)
    model.eval()

    config = read_conf('conf/imagenet.json')
    _, valid_loader = get_imagenet(config['id_dataset'], 32)
    # (dataset path, label suffix). The iNaturalist entry was previously
    # mislabelled '-SVHN' (copy-paste error).
    ood_sets = [
        ('./OOD_for_ImageNet/iNaturalist', 'iNaturalist'),
        ('./OOD_for_ImageNet/SUN', 'SUN'),
        ('./OOD_for_ImageNet/Places', 'PLACES'),
        ('./OOD_for_ImageNet/dtd/images', 'Textures'),
    ]
    # `with` closes the result file even if an evaluation step raises.
    with open('{}/{}_result.txt'.format(args.save_path, args.net), 'w') as f:
        valid_accuracy = validation_accuracy(model, valid_loader, device)
        print(valid_accuracy)
        f.write('Accuracy for ValidationSet: {}\n'.format(str(valid_accuracy)))
        preds_in = calculate_score(model, valid_loader, device).cpu()
        for path, label in ood_sets:
            OOD_results(preds_in, model, get_ood(path, for_imagenet=True),
                        device, args.method + '-' + label, f)
if __name__ == '__main__':
    # Script entry point: runs this module's eval() (which shadows the builtin).
    eval()