evaluate.py
# Base code is from https://github.com/cs230-stanford/cs230-code-examples
import logging
import copy
import torch
import numpy as np
from collections import OrderedDict
from src.data_loader import fetch_dataloaders


def evaluate(model, loss_fn, meta_classes, task_lr, task_type, metrics, params,
             split):
"""
Evaluate the model on `num_steps` batches.
Args:
model: (MetaLearner) a meta-learner that is trained on MAML
loss_fn: a loss function
meta_classes: (list) a list of classes to be evaluated in meta-training or meta-testing
task_lr: (float) a task-specific learning rate
task_type: (subclass of FewShotTask) a type for generating tasks
metrics: (dict) a dictionary of functions that compute a metric using
the output and labels of each batch
params: (Params) hyperparameters
split: (string) 'train' if evaluate on 'meta-training' and
'test' if evaluate on 'meta-testing' TODO 'meta-validating'
"""
    # params information
    SEED = params.SEED
    num_classes = params.num_classes
    num_samples = params.num_samples
    num_query = params.num_query
    num_steps = params.num_steps
    num_eval_updates = params.num_eval_updates

    # NOTE model.eval() is deliberately not called: the task changes every
    # episode, so batchnorm should compute statistics within the current task.
    # model.eval()

    # summary for current eval loop
    summ = []
    # compute metrics over the dataset
    for episode in range(num_steps):
        # Make a single task and dataloaders for its support and query sets
        task = task_type(meta_classes, num_classes, num_samples, num_query)
        dataloaders = fetch_dataloaders(['train', 'test'], task)
        dl_sup = dataloaders['train']
        dl_que = dataloaders['test']
        X_sup, Y_sup = next(iter(dl_sup))
        X_que, Y_que = next(iter(dl_que))

        # move to GPU if available
        if params.cuda:
            X_sup, Y_sup = X_sup.cuda(non_blocking=True), Y_sup.cuda(non_blocking=True)
            X_que, Y_que = X_que.cuda(non_blocking=True), Y_que.cuda(non_blocking=True)
        # Direct optimization: adapt a deep copy of the meta-learner on the
        # support set, then score the adapted copy on the query set.
        net_clone = copy.deepcopy(model)
        optim = torch.optim.SGD(net_clone.parameters(), lr=task_lr)
        for _ in range(num_eval_updates):
            Y_sup_hat = net_clone(X_sup)
            loss = loss_fn(Y_sup_hat, Y_sup)
            optim.zero_grad()
            loss.backward()
            optim.step()
        Y_que_hat = net_clone(X_que)
        loss = loss_fn(Y_que_hat, Y_que)
        # # Alternative: a functional update that leaves `model` untouched by
        # # building an adapted state dict from manually computed gradients.
        # # clear previous gradients, compute gradients of all variables wrt loss
        # def zero_grad(params):
        #     for p in params:
        #         if p.grad is not None:
        #             p.grad.zero_()
        # # NOTE In the Meta-SGD paper, num_eval_updates=1 is enough
        # for _ in range(num_eval_updates):
        #     Y_sup_hat = model(X_sup)
        #     loss = loss_fn(Y_sup_hat, Y_sup)
        #     zero_grad(model.parameters())
        #     grads = torch.autograd.grad(loss, model.parameters())
        #     # step() manually
        #     adapted_state_dict = model.cloned_state_dict()
        #     adapted_params = OrderedDict()
        #     for (key, val), grad in zip(model.named_parameters(), grads):
        #         adapted_params[key] = val - task_lr * grad
        #         adapted_state_dict[key] = adapted_params[key]
        # Y_que_hat = model(X_que, adapted_state_dict)
        # loss = loss_fn(Y_que_hat, Y_que)  # NOTE !!!!!!!!
        # extract data from torch tensors, move to cpu, convert to numpy arrays
        Y_que_hat = Y_que_hat.data.cpu().numpy()
        Y_que = Y_que.data.cpu().numpy()

        # compute all metrics on this batch
        summary_batch = {
            metric: metrics[metric](Y_que_hat, Y_que)
            for metric in metrics
        }
        summary_batch['loss'] = loss.item()
        summ.append(summary_batch)
    # compute mean of all metrics in summary
    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join(
        "{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- [" + split.upper() + "] Eval metrics : " + metrics_string)
    return metrics_mean


if __name__ == '__main__':
    # TODO Evaluate trained model.
    pass
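
    # The lines below are not part of the original script; they are a small,
    # self-contained sketch of the same inner-loop adaptation pattern that
    # evaluate() performs for each task: clone the model, take a few SGD steps
    # on a support set, then measure the loss on a query set. The toy linear
    # model and randomly generated tensors are placeholders only.
    torch.manual_seed(0)
    toy_model = torch.nn.Linear(4, 3)           # stand-in for the meta-learner
    toy_loss_fn = torch.nn.CrossEntropyLoss()
    X_sup, Y_sup = torch.randn(10, 4), torch.randint(0, 3, (10,))  # fake support set
    X_que, Y_que = torch.randn(15, 4), torch.randint(0, 3, (15,))  # fake query set

    adapted = copy.deepcopy(toy_model)          # leave the original weights untouched
    optim = torch.optim.SGD(adapted.parameters(), lr=0.1)
    for _ in range(5):                          # a few task-specific updates
        optim.zero_grad()
        toy_loss_fn(adapted(X_sup), Y_sup).backward()
        optim.step()
    print('query loss after adaptation:',
          toy_loss_fn(adapted(X_que), Y_que).item())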