# metalearner.py (forked from markdtw/meta-learning-lstm-pytorch)
from __future__ import division, print_function, absolute_import
import pdb
import math
import torch
import torch.nn as nn


class MetaLSTMCell(nn.Module):
    r"""C_t = f_t * C_{t-1} + i_t * \tilde{C_t}"""

    def __init__(self, input_size, hidden_size, n_learner_params):
        """Args:
            input_size (int): cell input size, default = 20
            hidden_size (int): should be 1
            n_learner_params (int): number of learner's parameters
        """
        super(MetaLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_learner_params = n_learner_params
        self.WF = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1))
        self.bI = nn.Parameter(torch.Tensor(1, hidden_size))
        self.bF = nn.Parameter(torch.Tensor(1, hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.parameters():
            nn.init.uniform_(weight, -0.01, 0.01)

        # want initial forget value to be high and input value to be low so that
        # model starts with gradient descent
        nn.init.uniform_(self.bF, 4, 6)
        nn.init.uniform_(self.bI, -5, -4)

    def init_cI(self, flat_params):
        self.cI.data.copy_(flat_params.unsqueeze(1))

    def forward(self, inputs, hx=None):
        """Args:
            inputs = [x_all, grad]:
                x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM
                grad (torch.Tensor of size [n_learner_params, 1]): gradients from learner
            hx = [f_prev, i_prev, c_prev]:
                f (torch.Tensor of size [n_learner_params, 1]): forget gate
                i (torch.Tensor of size [n_learner_params, 1]): input gate
                c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters
        """
        x_all, grad = inputs
        batch, _ = x_all.size()

        if hx is None:
            f_prev = torch.zeros((batch, self.hidden_size)).to(self.WF.device)
            i_prev = torch.zeros((batch, self.hidden_size)).to(self.WI.device)
            c_prev = self.cI
            hx = [f_prev, i_prev, c_prev]

        f_prev, i_prev, c_prev = hx

        # f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f)
        f_next = torch.mm(torch.cat((x_all, c_prev, f_prev), 1), self.WF) + self.bF.expand_as(f_prev)
        # i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i)
        i_next = torch.mm(torch.cat((x_all, c_prev, i_prev), 1), self.WI) + self.bI.expand_as(i_prev)
        # next cell state / updated learner parameters:
        # theta_t = sigmoid(f_t) * theta_{t-1} - sigmoid(i_t) * grad_t
        c_next = torch.sigmoid(f_next).mul(c_prev) - torch.sigmoid(i_next).mul(grad)

        return c_next, [f_next, i_next, c_next]

    def extra_repr(self):
        s = '{input_size}, {hidden_size}, {n_learner_params}'
        return s.format(**self.__dict__)
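

# Illustrative shape note (added; not part of the original file): a single
# MetaLSTMCell step maps x_all of size [n_learner_params, input_size] and
# grad of size [n_learner_params, 1] to c_next of size [n_learner_params, 1],
# i.e. one updated value per flattened learner parameter. A minimal, hedged
# sanity check with assumed sizes (input_size=20, 7 learner parameters):
#
#     cell = MetaLSTMCell(input_size=20, hidden_size=1, n_learner_params=7)
#     cell.init_cI(torch.randn(7))
#     c_next, hx = cell([torch.randn(7, 20), torch.randn(7, 1)])
#     assert c_next.shape == (7, 1)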


class MetaLearner(nn.Module):

    def __init__(self, input_size, hidden_size, n_learner_params):
        """Args:
            input_size (int): for the first LSTM layer, default = 4
            hidden_size (int): for the first LSTM layer, default = 20
            n_learner_params (int): number of learner's parameters
        """
        super(MetaLearner, self).__init__()
        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1, n_learner_params=n_learner_params)

    def forward(self, inputs, hs=None):
        """Args:
            inputs = [loss, grad_prep, grad]
                loss (torch.Tensor of size [1, 2])
                grad_prep (torch.Tensor of size [n_learner_params, 2])
                grad (torch.Tensor of size [n_learner_params, 1])
            hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]]
        """
        loss, grad_prep, grad = inputs
        loss = loss.expand_as(grad_prep)
        inputs = torch.cat((loss, grad_prep), 1)    # [n_learner_params, 4]

        if hs is None:
            hs = [None, None]

        lstmhx, lstmcx = self.lstm(inputs, hs[0])
        flat_learner_unsqzd, metalstm_hs = self.metalstm([lstmhx, grad], hs[1])

        return flat_learner_unsqzd.squeeze(), [(lstmhx, lstmcx), metalstm_hs]
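

# Hedged usage sketch (added; not from the original repository). It shows the
# calling convention implied by the docstrings above: copy the learner's
# flattened parameters into the cell state cI once, then feed preprocessed
# loss/gradient statistics at every optimization step. The preprocessing that
# produces the [*, 2] inputs lives elsewhere in the repo, so random tensors
# stand in for it here; all sizes are illustrative assumptions.
if __name__ == '__main__':
    n_params = 10  # assumed number of flattened learner parameters
    metalearner = MetaLearner(input_size=4, hidden_size=20, n_learner_params=n_params)

    # Initialize the cell state with the learner's flattened parameters.
    metalearner.metalstm.init_cI(torch.randn(n_params))

    hs = None
    for step in range(3):
        loss_prep = torch.randn(1, 2)           # stand-in for the preprocessed scalar loss
        grad_prep = torch.randn(n_params, 2)    # stand-in for the preprocessed gradients
        grad = torch.randn(n_params, 1)         # raw per-parameter gradients (column vector)
        new_flat_params, hs = metalearner([loss_prep, grad_prep, grad], hs)
        print(step, new_flat_params.shape)      # torch.Size([10]) each step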