-
Notifications
You must be signed in to change notification settings - Fork 3
/
validation.py
242 lines (195 loc) · 9.85 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# -*- coding: utf-8 -*-
# --------------------------------------------------
#
# validation.py
#
# Validation phase. Model tuning of attribute label embedding (ALE) method on APY dataset
# 15 seen classes : bird - cat - mug - bus - diningtable - bottle - car - boat
# dog - zebra - monkey - centaur - chair - bicycle - building
# 5 unseen classes : aeroplane - wolf - carriage - sofa - bag
#
# Written by cetinsamet -*- cetin.samet@metu.edu.tr
# April, 2019
# --------------------------------------------------
import random
random.seed(123)  # seed immediately after import so all later uses are reproducible
import numpy as np
np.random.seed(123)  # fix numpy RNG (data shuffling / any sampling)
import torch
torch.manual_seed(123)  # fix torch RNG (weight init, DataLoader shuffling)
from torch.utils.data import TensorDataset, DataLoader
from easydict import EasyDict as edict
from tools import load_data, map_labels   # project helpers: I/O and label remapping
from model import Network, evaluate       # compatibility network + eval routine
from config import OBJPATH                # path to the pickled run configuration
import pickle
def main():
    """Validation phase: tune the Attribute Label Embedding (ALE) model on APY.

    Loads the pickled run configuration from ``OBJPATH``, reads the
    precomputed visual features / labels / class attribute vectors for the
    train, validation-seen and validation-unseen splits, trains a
    compatibility network with summed cross-entropy over the 15 seen
    classes, and every ``INFO_EPOCH`` epochs evaluates:

      * zero-shot accuracy         (unseen samples vs. unseen classes only)
      * generalized seen accuracy  (seen samples vs. all classes)
      * generalized unseen accuracy(unseen samples vs. all classes)
      * H-score                    (harmonic mean of the two generalized accs)

    Whenever a new best H-score is reached, the four metrics are printed.
    Returns nothing; all output goes to stdout.
    """
    # Configuration (data paths + hyper-parameters) was pickled by an earlier
    # preprocessing step.  NOTE(review): pickle is unsafe on untrusted input;
    # this assumes OBJPATH is a locally produced artifact.
    with open(OBJPATH, 'rb') as infile:
        __C = pickle.load(infile)

    # --------------------------------------------------------------------- #
    # Load data (numpy arrays produced by the preprocessing pipeline).
    allClassVectors = load_data(__C.ALL_CLASS_VEC,       'all_class_vec')
    trainFeatures   = load_data(__C.TRAIN_FEATURES,      'train_features')
    trainLabels     = load_data(__C.TRAIN_LABELS,        'train_labels')
    seenFeatures    = load_data(__C.VAL_SEEN_FEATURES,   'val_seen_features')
    seenLabels      = load_data(__C.VAL_SEEN_LABELS,     'val_seen_labels')
    unseenFeatures  = load_data(__C.VAL_UNSEEN_FEATURES, 'val_unseen_features')
    unseenLabels    = load_data(__C.VAL_UNSEEN_LABELS,   'val_unseen_labels')

    # --------------------------------------------------------------------- #
    # Data dimensions: one attribute vector per class, one feature row
    # per sample.
    n_class, attr_dim = allClassVectors.shape
    n_train, feat_dim = trainFeatures.shape
    n_seen, _         = seenFeatures.shape
    n_unseen, _       = unseenFeatures.shape

    # --------------------------------------------------------------------- #
    # Map global class indices onto contiguous [0, k) ranges so that
    # CrossEntropyLoss / argmax over the restricted class set work directly.
    # The *flattened* original labels are kept for the generalized setting,
    # where scoring is done against ALL classes.
    seenClassIndices   = np.unique(trainLabels)
    unseenClassIndices = np.unique(unseenLabels)

    m_trainLabels     = map_labels(trainLabels,  n_class, seenClassIndices)
    m_seenLabels      = map_labels(seenLabels,   n_class, seenClassIndices)
    m_genSeenLabels   = seenLabels.flatten()
    m_unseenLabels    = map_labels(unseenLabels, n_class, unseenClassIndices)
    m_genUnseenLabels = unseenLabels.flatten()

    # --------------------------------------------------------------------- #
    # Network, optimizer and loss function from configured hyper-parameters.
    n_epoch    = __C.N_EPOCH
    batch_size = __C.BATCH_SIZE
    lr         = __C.LR

    model     = Network(feature_dim=feat_dim, vector_dim=attr_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # 'sum' reduction: per-epoch loss is later normalizable by sample count.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')

    # --------------------------------------------------------------------- #
    # Convert numpy arrays to torch tensors once, up front.
    x_trainFeatures   = torch.from_numpy(trainFeatures).float()
    y_trainLabels     = torch.from_numpy(m_trainLabels).long()
    x_seenFeatures    = torch.from_numpy(seenFeatures).float()
    y_seenLabels      = torch.from_numpy(m_seenLabels).long()
    y_genSeenLabels   = torch.from_numpy(m_genSeenLabels).long()
    x_unseenFeatures  = torch.from_numpy(unseenFeatures).float()
    y_unseenLabels    = torch.from_numpy(m_unseenLabels).long()
    y_genUnseenLabels = torch.from_numpy(m_genUnseenLabels).long()

    # Class attribute (side-information) matrices for each evaluation regime.
    seenVectors   = torch.from_numpy(allClassVectors[seenClassIndices, :]).float()
    unseenVectors = torch.from_numpy(allClassVectors[unseenClassIndices, :]).float()
    allVectors    = torch.from_numpy(allClassVectors).float()

    # Mini-batch loader over the training split.
    trainData   = TensorDataset(x_trainFeatures, y_trainLabels)
    trainLoader = DataLoader(trainData, batch_size=batch_size, shuffle=True)

    # ********************************************************************* #
    #                      ATTRIBUTE LABEL EMBEDDING                        #
    # ********************************************************************* #
    max_zslAcc     = float('-inf')
    max_gSeenAcc   = float('-inf')
    max_gUnseenAcc = float('-inf')
    max_hScore     = float('-inf')

    for epochID in range(n_epoch):
        # ------------------------- TRAINING -------------------------- #
        model.train()
        running_train_loss = 0.  # accumulated summed loss (for optional logging)
        for x, y in trainLoader:
            y_out      = model(x, seenVectors)
            train_loss = criterion(y_out, y)
            optimizer.zero_grad()   # clear stale gradients
            train_loss.backward()   # back-propagate
            optimizer.step()        # update weights
            running_train_loss += train_loss.item()

        # ------------------------ EVALUATION ------------------------- #
        if (epochID + 1) % __C.INFO_EPOCH == 0:
            model.eval()
            # no_grad: inference only; skips autograd bookkeeping over the
            # full validation splits.
            with torch.no_grad():
                # Zero-shot: unseen samples scored against unseen classes only.
                zslAcc = evaluate(model=model,
                                  x=x_unseenFeatures,
                                  y=y_unseenLabels,
                                  vec=unseenVectors)
                # Generalized seen: seen samples scored against ALL classes.
                gSeenAcc = evaluate(model=model,
                                    x=x_seenFeatures,
                                    y=y_genSeenLabels,
                                    vec=allVectors)
                # Generalized unseen: unseen samples scored against ALL classes.
                gUnseenAcc = evaluate(model=model,
                                      x=x_unseenFeatures,
                                      y=y_genUnseenLabels,
                                      vec=allVectors)

            # H-score = harmonic mean of the generalized accuracies;
            # guard against division by zero when both are 0.
            if gSeenAcc + gUnseenAcc == 0.:
                hScore = 0.
            else:
                hScore = (2 * gSeenAcc * gUnseenAcc) / (gSeenAcc + gUnseenAcc)

            # Track and report every new best model (by H-score).
            if hScore > max_hScore:
                max_zslAcc     = zslAcc
                max_gSeenAcc   = gSeenAcc
                max_gUnseenAcc = gUnseenAcc
                max_hScore     = hScore
                print("Zsl Acc: %.5s\tGen Seen Acc: %.5s\tGen Unseen Acc: %.5s\t\033[1mH-Score: %.5s\033[0m"
                      % (str(max_zslAcc),
                         str(max_gSeenAcc),
                         str(max_gUnseenAcc),
                         str(max_hScore)))
    return


if __name__ == '__main__':
    main()