-
Notifications
You must be signed in to change notification settings - Fork 4
/
model.py
127 lines (74 loc) · 3.82 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from torch import nn
from modules import *
from utils import get_glimpse
import math
class GDRAM(nn.Module):
    """Recurrent visual-attention classifier (DRAM-style).

    Processes an image with a sequence of multi-scale glimpses at locations
    predicted by an emission network, and combines the per-step class logits
    with confidence weights derived from the emission std.

    ``forward`` returns everything a REINFORCE-style training loop needs:
    (action_logits, locations, location_log_probs, baselines, weights).
    """

    def __init__(self, device=None, dataset=None, Fast=False):
        super(GDRAM, self).__init__()
        self.glimpse_size = 12       # side length of the base glimpse patch (pixels)
        self.num_scales = 4          # number of resolution scales stacked per glimpse
        self.img_size = 128          # assumed input image side -- TODO confirm against dataloader
        # 100-way classification for CIFAR-100, 10-way otherwise.
        self.class_num = 100 if dataset == 'cifar100' else 10
        # Glimpse size expressed in the [-1, 1] normalized coordinate system.
        self.normalized_glimpse_size = self.glimpse_size / (self.img_size / 2)
        self.glimpse_net = GlimpseNetwork(3 * self.num_scales, self.glimpse_size, 2, 128, 128)
        self.rnn1 = GlimpseLSTMCoreNetwork(128, 128)
        self.rnn2 = LocationLSTMCoreNetwork(128, 128, self.glimpse_size)
        self.class_net = ActionNetwork(128, self.class_num)
        self.emission_net = EmissionNetwork(128)
        self.baseline_net = BaselineNetwork(128 * 2, 1)
        self.num_glimpses = 8        # maximum attention steps per image
        self.location_size = 2       # (x, y) location coordinates
        self.device = device
        self.fast = Fast             # enable confidence-based early exit at inference

    def forward(self, x):
        """Run up to ``num_glimpses`` attention steps over a batch of images.

        Args:
            x: image batch; presumably (batch, 3, H, W) -- TODO confirm.

        Returns:
            action_logits: confidence-weighted average of per-step class logits.
            locations: (batch, num_glimpses, 2) glimpse centers in [-1, 1].
            location_log_probs: (batch, num_glimpses) log-probs of sampled locations.
            baselines: (batch, num_glimpses) value baselines for REINFORCE.
            weights: (batch, num_glimpses) per-step confidence weights.
            NOTE(review): when the fast path breaks early, columns past the
            last executed step remain uninitialized (``torch.empty``) -- the
            caller must only consume the filled prefix.
        """
        batch_size = x.size(0)
        hidden1, cell_state1 = self.rnn1.init_hidden(batch_size)
        hidden1 = hidden1.to(self.device)
        cell_state1 = cell_state1.to(self.device)
        hidden2, cell_state2 = self.rnn2.init_hidden(x, batch_size)
        hidden2 = hidden2.to(self.device)
        cell_state2 = cell_state2.to(self.device)
        # Initial location from the emission network; clamp so the glimpse
        # window stays fully inside the normalized [-1, 1] image frame.
        # (Removed a dead `std = ones * exp(-1/2)` assignment that was
        # immediately overwritten by this call.)
        location, std, log_prob = self.emission_net(hidden2)
        location = torch.clamp(location, min=-1 + self.normalized_glimpse_size / 2,
                               max=1 - self.normalized_glimpse_size / 2)
        location_log_probs = torch.empty(batch_size, self.num_glimpses).to(self.device)
        locations = torch.empty(batch_size, self.num_glimpses, self.location_size).to(self.device)
        baselines = torch.empty(batch_size, self.num_glimpses).to(self.device)
        weights = torch.empty(batch_size, self.num_glimpses).to(self.device)
        weight = torch.ones(batch_size).to(self.device)  # first step gets full confidence
        action_logits = 0
        weight_sum = 0
        for i in range(self.num_glimpses):
            locations[:, i] = location
            location_log_probs[:, i] = log_prob
            # Extract the multi-scale glimpse; location is detached so glimpse
            # extraction is not differentiated through (REINFORCE handles it).
            glimpse = get_glimpse(x, location.detach(), self.glimpse_size, self.num_scales, device=self.device).to(self.device)
            glimpse_feature = self.glimpse_net(glimpse, location)
            hidden1, cell_state1 = self.rnn1(glimpse_feature, (hidden1, cell_state1))
            hidden2, cell_state2 = self.rnn2(hidden1, (hidden2, cell_state2))
            # Emit an offset for the next glimpse location, scaled to the
            # coarsest glimpse scale, then clamp back inside the frame.
            loc_diff, std, log_prob = self.emission_net(hidden2)
            loc_diff *= (self.normalized_glimpse_size / 2 * 2 ** (self.num_scales - 1))
            new_location = location.detach() + loc_diff
            new_location = torch.clamp(new_location, min=-1 + self.normalized_glimpse_size / 2,
                                       max=1 - self.normalized_glimpse_size / 2)
            location = new_location
            hidden = torch.cat((hidden1, hidden2), dim=1)
            baseline = self.baseline_net(hidden)
            baselines[:, i] = baseline.squeeze()
            weight = weight.unsqueeze(1)
            action_logit = self.class_net(hidden1)
            # Accumulate logits weighted by the previous step's confidence.
            action_logits += weight * action_logit
            weights[:, i] = weight.squeeze()
            weight_sum += weight
            # Fast inference: stop once two consecutive steps have low
            # confidence. BUGFIX: the original indexed weights[0, -1] and
            # weights[0, -2] -- the last columns of a torch.empty() buffer,
            # which are uninitialized (nondeterministic) unless i is already
            # the final step. Index the columns actually written this step.
            if (not self.training and i > 1) and self.fast:
                if weights[0, i] < 0.5 and weights[0, i - 1] < 0.5:
                    break
            # Confidence for the next step: map mean emission std from
            # [exp(-1/2), exp(1/2)] onto [1, 0] (low std -> high confidence).
            std = torch.mean(std, dim=1)
            normalized_std = (std - math.exp(-1 / 2)) / (math.exp(1 / 2) - math.exp(-1 / 2))
            weight = 1 - normalized_std
        # Normalize so the output is a weighted average of the step logits.
        action_logits /= weight_sum
        return action_logits, locations, location_log_probs, baselines, weights