-
Notifications
You must be signed in to change notification settings - Fork 3
/
pcgrad.py
199 lines (163 loc) · 6.04 KB
/
pcgrad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pdb
import numpy as np
import copy
import random
class PCGrad():
    '''
    wrap a torch optimizer so that the gradients of several objectives are
    combined by projecting away conflicting components before each update

    usage: call pc_backward(list_of_losses) instead of loss.backward(), then
    call step() as usual.
    '''

    def __init__(self, optimizer, reduction='mean'):
        '''
        input:
        - optimizer: the wrapped torch optimizer; its param_groups decide
          which parameters are packed, projected and updated
        - reduction: how projected gradients are merged on entries that every
          objective produced a gradient for, either 'mean' or 'sum'
          (validated when pc_backward runs)
        '''
        self._optim, self._reduction = optimizer, reduction

    @property
    def optimizer(self):
        '''
        the wrapped optimizer
        '''
        return self._optim

    @property
    def param_groups(self):
        '''
        the param_groups of the wrapped optimizer
        '''
        return self._optim.param_groups

    def state_dict(self):
        '''
        return the state dict of the wrapped optimizer
        '''
        return self._optim.state_dict()

    def zero_grad(self):
        '''
        clear the gradient of the parameters (gradients are set to None)
        '''
        return self._optim.zero_grad(set_to_none=True)

    def step(self):
        '''
        update the parameters with the gradient
        '''
        return self._optim.step()

    def pc_backward(self, objectives):
        '''
        calculate the projected gradient of the parameters and store it in
        each parameter's .grad
        input:
        - objectives: a non-empty iterable of scalar objectives (losses)
        raises:
        - ValueError: if objectives is empty or the reduction is invalid
        '''
        objectives = list(objectives)
        if not objectives:
            raise ValueError('pc_backward needs at least one objective')
        grads, shapes, has_grads = self._pack_grad(objectives)
        pc_grad = self._project_conflicting(grads, has_grads)
        pc_grad = self._unflatten_grad(pc_grad, shapes[0])
        self._set_grad(pc_grad)

    def _project_conflicting(self, grads, has_grads, shapes=None):
        '''
        project away conflicting gradient components and merge the results
        input:
        - grads: a list of flattened gradients, one per objective
        - has_grads: a list of flattened masks, non-zero where the objective
          produced a gradient for that entry
        - shapes: unused, kept for interface compatibility
        output:
        - merged_grad: a single flattened gradient
        raises:
        - ValueError: if the reduction method is neither 'mean' nor 'sum'
        '''
        # entries that received a gradient from every objective
        shared = torch.stack(has_grads).prod(0).bool()
        # clone so the in-place projections below leave the originals intact
        pc_grad = [g.clone() for g in grads]
        # shuffle a private copy so the caller's list is not reordered
        others = list(grads)
        for g_i in pc_grad:
            random.shuffle(others)
            for g_j in others:
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    # remove the component of g_i along the conflicting g_j
                    g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)
        merged_grad = torch.zeros_like(grads[0])
        if self._reduction == 'mean':
            merged_grad[shared] = torch.stack(
                [g[shared] for g in pc_grad]).mean(dim=0)
        elif self._reduction == 'sum':
            merged_grad[shared] = torch.stack(
                [g[shared] for g in pc_grad]).sum(dim=0)
        else:
            raise ValueError(
                'invalid reduction method: {!r}'.format(self._reduction))
        # entries reached by only some objectives are always summed
        merged_grad[~shared] = torch.stack(
            [g[~shared] for g in pc_grad]).sum(dim=0)
        return merged_grad

    def _set_grad(self, grads):
        '''
        set the modified gradients to the network
        input:
        - grads: a list of tensors, one per parameter, in param_groups order
        '''
        idx = 0
        for group in self._optim.param_groups:
            for p in group['params']:
                p.grad = grads[idx]
                idx += 1

    def _pack_grad(self, objectives):
        '''
        pack the gradient of the parameters of the network for each objective
        output:
        - grad: a list of the flattened gradient for each objective
        - shape: a list (per objective) of the shapes of the parameters
        - has_grad: a list of flattened masks representing whether each
          parameter entry received a gradient from that objective
        '''
        grads, shapes, has_grads = [], [], []
        for obj in objectives:
            self._optim.zero_grad(set_to_none=True)
            # the graph is retained because later objectives may share it
            obj.backward(retain_graph=True)
            grad, shape, has_grad = self._retrieve_grad()
            grads.append(self._flatten_grad(grad, shape))
            has_grads.append(self._flatten_grad(has_grad, shape))
            shapes.append(shape)
        return grads, shapes, has_grads

    def _unflatten_grad(self, grads, shapes):
        '''
        split a flattened gradient back into per-parameter tensors
        input:
        - grads: a 1-D tensor, the concatenation of all parameter gradients
        - shapes: the parameter shapes, in the order used for flattening
        output:
        - unflatten_grad: a list of tensors with the given shapes
        '''
        unflatten_grad, idx = [], 0
        for shape in shapes:
            # int(): np.prod returns the float 1.0 for a 0-d shape, which is
            # not a valid slice index
            length = int(np.prod(shape))
            unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
            idx += length
        return unflatten_grad

    def _flatten_grad(self, grads, shapes):
        '''
        concatenate a list of tensors into one 1-D tensor
        (shapes is unused; the order of grads defines the layout)
        '''
        return torch.cat([g.flatten() for g in grads])

    def _retrieve_grad(self):
        '''
        get the gradient of the parameters of the network with specific
        objective
        output:
        - grad: a list of the gradient of the parameters
        - shape: a list of the shape of the parameters
        - has_grad: a list of masks representing whether the parameter has
          a gradient (ones) or not (zeros)
        '''
        grad, shape, has_grad = [], [], []
        for group in self._optim.param_groups:
            for p in group['params']:
                shape.append(p.shape)
                if p.grad is None:
                    # multi-head scenario: this objective does not reach p,
                    # so use a zero gradient and mark it as missing
                    grad.append(torch.zeros_like(p))
                    has_grad.append(torch.zeros_like(p))
                else:
                    grad.append(p.grad.clone())
                    has_grad.append(torch.ones_like(p))
        return grad, shape, has_grad
class TestNet(nn.Module):
    '''
    toy fully-shared model: a single 3 -> 4 linear layer, so every parameter
    receives a gradient from every objective computed on its output
    '''

    def __init__(self):
        super(TestNet, self).__init__()
        self._linear = nn.Linear(in_features=3, out_features=4)

    def forward(self, x):
        '''
        apply the linear layer to x (last dimension must be 3)
        '''
        out = self._linear(x)
        return out
class MultiHeadTestNet(nn.Module):
    '''
    toy multi-head model: a shared 3 -> 2 trunk followed by two independent
    2 -> 4 heads, so each head's parameters only see one head's objective
    '''

    def __init__(self):
        super(MultiHeadTestNet, self).__init__()
        # registration order (trunk, head1, head2) fixes parameter order
        self._linear = nn.Linear(in_features=3, out_features=2)
        self._head1 = nn.Linear(in_features=2, out_features=4)
        self._head2 = nn.Linear(in_features=2, out_features=4)

    def forward(self, x):
        '''
        return a tuple (head1 output, head2 output) computed from the shared
        trunk features of x
        '''
        trunk = self._linear(x)
        return tuple(head(trunk) for head in (self._head1, self._head2))
if __name__ == '__main__':

    def _print_grads(model):
        '''
        print the .grad of every parameter of model, in parameter order
        '''
        for param in model.parameters():
            print(param.grad)

    # scenario 1: fully shared network, every parameter is reached by both
    # objectives (L1 and MSE on the same prediction)
    torch.manual_seed(4)
    inputs, target = torch.randn(2, 3), torch.randn(2, 4)
    model = TestNet()
    prediction = model(inputs)
    wrapped = PCGrad(optim.Adam(model.parameters()))
    wrapped.zero_grad()
    objectives = [
        nn.L1Loss()(prediction, target),
        nn.MSELoss()(prediction, target),
    ]
    wrapped.pc_backward(objectives)
    _print_grads(model)

    print('-' * 80)

    # scenario 2: separated heads on a shared trunk, each objective only
    # reaches the trunk and its own head
    torch.manual_seed(4)
    inputs, target = torch.randn(2, 3), torch.randn(2, 4)
    model = MultiHeadTestNet()
    pred_head1, pred_head2 = model(inputs)
    wrapped = PCGrad(optim.Adam(model.parameters()))
    wrapped.zero_grad()
    objectives = [
        nn.MSELoss()(pred_head1, target),
        nn.MSELoss()(pred_head2, target),
    ]
    wrapped.pc_backward(objectives)
    _print_grads(model)