meta_template.py
# This code is modified from https://github.com/wyharveychen/CloserLookFewShot
import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import utils
from abc import abstractmethod
import cv2


class MetaTemplate(nn.Module):
    def __init__(self, model_func, n_way, n_support, change_way=False):
        super(MetaTemplate, self).__init__()
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = -1  # set per episode from the input in train_loop/test_loop
        self.change_way = change_way  # stored for subclasses; unused in this base class
        self.feature = model_func()
        self.feat_dim = self.feature.final_feat_dim

    # @abstractmethod marks the episodic interface that subclasses must implement:
    # MetaTemplate itself does not implement set_forward / set_forward_loss, so it is
    # meant to be inherited rather than used directly (see the usage sketch at the end
    # of this file). Note that because nn.Module does not use ABCMeta, this is a
    # convention here rather than something Python enforces at instantiation time.
    @abstractmethod
    def set_forward(self, x, is_feature):
        pass

    @abstractmethod
    def set_forward_loss(self, x):
        pass

    def forward(self, x):
        out = self.feature.forward(x)
        return out

    def parse_feature(self, x, is_feature):
        x = Variable(x.cuda())
        if is_feature:
            # x is assumed to already be an episode of extracted features,
            # shaped (n_way, n_support + n_query, feat_dim)
            z_all = x
        else:
            # flatten the episode: (n_way * (n_support + n_query), 3, 224, 224), e.g. (5*21, 3, 224, 224)
            x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), *x.size()[2:])
            z_all = self.feature.forward(x)  # backbone features, e.g. (5*21, 512, 1, 1)
            c = z_all.shape[1]
            z_all = z_all.view(self.n_way, self.n_support + self.n_query, c)  # (n_way, n_support + n_query, c)
        z_support = z_all[:, :self.n_support]  # (n_way, n_support, c), e.g. (5, 5, 512)
        z_query = z_all[:, self.n_support:]    # (n_way, n_query, c), e.g. (5, 16, 512)
        return z_support, z_query

    def correct(self, x):
        # evaluate top-1 accuracy of the main classification scores on the query set
        scores1, scores2, scores3 = self.set_forward(x)
        y_query = np.repeat(range(self.n_way), self.n_query)
        topk_scores, topk_labels = scores1.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()
        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        return float(top1_correct), len(y_query)

    def train_loop(self, epoch, train_loader, optimizer, scheduler, const1, const2):
        print_freq = 10
        avg_loss = 0
        for i, (x, _) in enumerate(train_loader):
            self.n_query = x.size(1) - self.n_support
            optimizer.zero_grad()
            # total loss = classification loss + weighted calibration losses
            loss1, loss2, loss3 = self.set_forward_loss(x)
            loss = loss1 + const1 * loss2 + const2 * loss3
            loss.backward()
            optimizer.step()
            avg_loss = avg_loss + loss.item()
            if i % print_freq == 0:
                print('Epoch [%d], Batch [%d/%d], Loss: %.6f, lr: %f' % (epoch, i, len(train_loader), avg_loss / float(i + 1), scheduler.get_lr()[0]))
        scheduler.step(epoch)  # step the learning-rate scheduler once per epoch

    def test_loop(self, test_loader, record=None):
        acc_all = []
        iter_num = len(test_loader)
        for i, (x, _) in enumerate(test_loader):
            self.n_query = x.size(1) - self.n_support
            correct_this, count_this = self.correct(x)
            acc_all.append(correct_this / count_this * 100)
        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        # report mean accuracy with a 95% confidence interval over the test episodes
        print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
        return acc_mean
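

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). MetaTemplate
# only defines the episodic interface above, so a concrete model must subclass it
# and implement set_forward / set_forward_loss. The minimal prototype-style
# subclass below, its name ProtoLikeNet, the choice of backbone.ResNet10, the
# loss weights, and the episodic train_loader / test_loader in the commented
# driver code are all assumptions for illustration; the repository's real
# subclasses differ.
# ---------------------------------------------------------------------------
class ProtoLikeNet(MetaTemplate):
    def __init__(self, model_func, n_way, n_support):
        super(ProtoLikeNet, self).__init__(model_func, n_way, n_support)
        self.loss_fn = nn.CrossEntropyLoss()

    def set_forward(self, x, is_feature=False):
        z_support, z_query = self.parse_feature(x, is_feature)
        proto = z_support.mean(1)                                   # (n_way, feat_dim)
        z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)
        scores = -torch.cdist(z_query, proto)                       # (n_way * n_query, n_way)
        # this file expects three score tensors (main head + two auxiliary heads);
        # the sketch simply reuses the same scores for all three
        return scores, scores, scores

    def set_forward_loss(self, x):
        y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query)).cuda()
        scores1, scores2, scores3 = self.set_forward(x)
        loss = self.loss_fn(scores1, y_query)
        return loss, loss, loss

# Hypothetical driver loop, assuming episodic data loaders that yield tensors of
# shape (n_way, n_support + n_query, 3, 224, 224):
#
#   model = ProtoLikeNet(backbone.ResNet10, n_way=5, n_support=5).cuda()
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40)
#   for epoch in range(stop_epoch):
#       model.train()
#       model.train_loop(epoch, train_loader, optimizer, scheduler, const1=0.1, const2=0.1)
#       model.eval()
#       with torch.no_grad():
#           acc = model.test_loop(test_loader)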