# powersgd code copied here to get the norm and gradients
# without doing distributed
import datetime
import os
import time
from contextlib import contextmanager
from typing import List

import numpy as np
import torch

try:
    import bit2byte
except ImportError:
    pass


class Reducer:
    def __init__(self, random_seed, device, timer):
        self.rng = np.random.RandomState(random_seed)
        M = 1024 * 1024
        # self.precalc_numbers = (
        #     torch.from_numpy(self.rng.randn(128 * M)).to(device).type(torch.float32)
        # )
        # Fall back to a single worker when no process group has been initialized,
        # so the reducers can also be used without torch.distributed.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            self.n_workers = torch.distributed.get_world_size()
            self.rank = torch.distributed.get_rank()
        else:
            self.n_workers = 1
            self.rank = 0
        self.device = device
        self.timer = timer

    def reduce(self, grad_in, grad_out, memory_out):
        """Return the number of floats communicated"""
        raise NotImplementedError()
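
# A minimal sketch of a timer compatible with the `timer` argument above: the
# reducers only use it as `self.timer(label, verbosity=...)` inside a `with`
# statement, so any callable returning a context manager will do. `noop_timer`
# is a hypothetical stand-in for illustration, not part of the original code.
@contextmanager
def noop_timer(label, verbosity=0):
    yield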

class RankKReducer(Reducer):
    def __init__(self, random_seed, device, timer, n_power_iterations=0,
                 reuse_query=True, rank=1):
        super().__init__(random_seed, device, timer)
        assert n_power_iterations == 0
        self.rank = rank
        self.p_memory = None
        self.q_memory = None
        self.reuse_query = reuse_query
        # self.memory_update = None

    def set_random(self, vector):
        # TODO: Verify what this is doing
        # this brings some non-determinism
        torch.manual_seed(self.rng.randint(1_000_000_000))
        vector.data[:] = torch.randn(*vector.shape, device=self.device)
        # orthogonalize(vector)

    def reduce(self, grad_in, grad_out, memory_out):
        """
        Reduce gradients between the workers in place.
        :param grad_in: list of gradient tensors
        :param grad_out: list of tensors that receive the compressed gradients
        :param memory_out: list of tensors that receive the error-feedback residual
        """
        floats_communicated = 0
        # import ipdb; ipdb.set_trace()
        # if use_memory and self.memory_update is None:
        #     # need to initialize the vector
        #     self.memory_update = [torch.zeros_like(gg) for gg in grad_in]
        # if use_memory:
        #     # add the memory term to the gradient before using
        #     for idx, mem_term in enumerate(self.memory_update):
        #         grad_in[idx] = grad_in[idx] + mem_term

        # Split the tensors into rank-1 ones that would be reduced uncompressed
        # and rank > 1 tensors that are compressed
        # No need for rank 1 tensors
        # rank1_tensors = [
        #     (tensor, out, mem)
        #     for tensor, out, mem in zip(grad_in, grad_out, memory_out)
        #     if tensor.ndimension() <= 1
        # ]
        high_rank_tensors = [
            (tensor, out, mem)
            for tensor, out, mem in zip(grad_in, grad_out, memory_out)
            if tensor.ndimension() > 1
        ]

        # We are building a rank-k (k = self.rank) approximation of every tensor
        # that can be interpreted as a matrix. Let the approximation be
        #     M = p q^T
        # We are allocating consecutive memory for the p's and q's
        memory_is_uninitialized = self.p_memory is None

        with self.timer("reduce.allocate_memory", verbosity=2):
            p_total_size = 0
            q_total_size = 0
            for tensor, _, _ in high_rank_tensors:
                matrix = tensor.view(tensor.shape[0], -1)
                n, m = matrix.shape
                rank = min(n, m, self.rank)
                p_total_size += n * rank
                q_total_size += m * rank
            if self.p_memory is None:
                self.p_memory = torch.empty(p_total_size, device=self.device)
                self.q_memory = torch.empty(q_total_size, device=self.device)

            # Find them again and make lists of pointers
            ps = []
            qs = []
            p_idx = 0
            q_idx = 0
            for tensor, _, _ in high_rank_tensors:
                matrix = tensor.view(tensor.shape[0], -1)
                n, m = matrix.shape
                rank = min(n, m, self.rank)
                ps.append(self.p_memory[p_idx : p_idx + n * rank].view(n, rank))
                qs.append(self.q_memory[q_idx : q_idx + m * rank].view(m, rank))
                p_idx += n * rank
                q_idx += m * rank

        with self.timer("reduce.prepare.q", verbosity=2):
            for (tensor, _, _), q, p in zip(high_rank_tensors, qs, ps):
                matrix = tensor.view(tensor.shape[0], -1)
                n, m = matrix.shape
                if self.reuse_query and not memory_is_uninitialized:
                    # orthogonalize(q)
                    pass
                else:
                    # Sample a query vector q
                    self.set_random(q)

        with self.timer("reduce.compute.p", verbosity=2):
            for (tensor, _, _), q, p in zip(high_rank_tensors, qs, ps):
                matrix = tensor.view(tensor.shape[0], -1)
                torch.matmul(matrix, q, out=p)

        with self.timer("reduce.p", verbosity=2):
            # Don't need all reduce
            # all_reduce(self.p_memory)
            torch.distributed.all_reduce(self.p_memory, async_op=False)
            # bits_communicated += n_bits(self.p_memory)
            floats_communicated += torch.numel(self.p_memory)

        # Start communicating rank 1 tensors
        # no need for rank1 tensors
        # with self.timer("reduce.rank1.pack", verbosity=2):
        #     rank1_tensor_list = TensorBuffer([tensor for (tensor, _, _) in rank1_tensors])
        # Don't need all reduce even in the distributed case
        # because rank 1 tensors will not even be part of powersgd
        # TODO: Verify above hypothesis
        # with self.timer("reduce.rank1.all_reduce", verbosity=2):
        #     rank1_handle = rank1_tensor_list.all_reduce(async_op=True)
        #     bits_communicated += rank1_tensor_list.bits()

        with self.timer("reduce.normalize.p", verbosity=2):
            for p in ps:
                orthogonalize(p)

        with self.timer("reduce.compute.q", verbosity=2):
            for p, q, (tensor, _, _) in zip(ps, qs, high_rank_tensors):
                matrix = tensor.view(tensor.shape[0], -1)
                torch.matmul(matrix.t(), p, out=q)

        with self.timer("reduce.q", verbosity=2):
            # all_reduce(self.q_memory)
            torch.distributed.all_reduce(self.q_memory, async_op=False)
            # bits_communicated += n_bits(self.q_memory)
            floats_communicated += torch.numel(self.q_memory)
            self.q_memory.data[:] /= self.n_workers

        with self.timer("reduce.outerprod", verbosity=2):
            for p, q, (tensor, out, mem) in zip(ps, qs, high_rank_tensors):
                # Set the output gradient to the low-rank approximation p q^T
                torch.matmul(p, q.t(), out=out.data[:])
                # Keep the approximation error as error-feedback memory
                mem.data[:] = tensor - out

        # no need for rank1 tensors
        # with self.timer("reduce.rank1.unpack", verbosity=2):
        #     # rank1_handle.wait()
        #     rank1_tensor_list.buffer /= self.n_workers
        #     rank1_tensor_list.unpack([out for (_, out, _) in rank1_tensors])
        # if use_memory:
        #     # very dirty hack: the previous iteration added the memory term,
        #     # which updates things in place and affects the whole gradient in
        #     # subsequent methods; this is a quick fix where I subtract the same term again
        #     for idx, mem_val in enumerate(self.memory_update):
        #         grad_in[idx] = grad_in[idx] - mem_val
        # if use_memory:
        #     for idx, mem_update in enumerate(memory_out):
        #         self.memory_update[idx] = mem_update
        return floats_communicated
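
# A minimal single-process sketch of how RankKReducer might be driven. The
# size-1 "gloo" process group, the MASTER_ADDR/MASTER_PORT values and the
# hypothetical `noop_timer` above are illustrative assumptions, not part of
# the original code.
def _demo_rank_k_reducer():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("gloo", rank=0, world_size=1)
    device = torch.device("cpu")
    grads = [torch.randn(8, 4, device=device), torch.randn(6, 3, device=device)]
    outs = [torch.zeros_like(g) for g in grads]
    mems = [torch.zeros_like(g) for g in grads]
    reducer = RankKReducer(random_seed=0, device=device, timer=noop_timer, rank=1)
    floats = reducer.reduce(grads, outs, mems)
    # `outs` now holds the rank-1 approximations, `mems` the approximation error.
    print(floats, [torch.norm(m).item() for m in mems])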

class TopKReducer(Reducer):
    def __init__(self, random_seed, device, timer, rank=1):
        super().__init__(random_seed, device, timer)
        self.k = rank  # fraction of gradient entries to keep

    def reduce(self, grad_in, grad_out, memory_out):
        grad_single_tensor = list_to_tensor(grad_in)
        num_sample = int(self.k * len(grad_single_tensor))
        indexes = torch.argsort(torch.abs(grad_single_tensor),
                                descending=True)[:num_sample]
        values = grad_single_tensor[indexes]
        index_list = [torch.empty_like(indexes) for k in range(self.n_workers)]
        ind_wait = torch.distributed.all_gather(index_list, indexes, async_op=True)
        value_list = [torch.empty_like(values) for k in range(self.n_workers)]
        val_wait = torch.distributed.all_gather(value_list, values,
                                                async_op=True)
        ind_wait.wait()
        val_wait.wait()
        # grads synced
        grad_accum = torch.zeros_like(grad_single_tensor)
        for idx, vals in zip(index_list, value_list):
            grad_accum[idx] += vals
        # got the sparsified gradient; scatter it back into per-tensor outputs
        start_index = 0
        for idx, ts in enumerate(grad_out):
            num_elements_ts = ts.numel()
            ts = grad_accum[start_index:start_index + num_elements_ts].reshape(ts.shape)
            grad_out[idx] = ts
            start_index += num_elements_ts
        # import ipdb; ipdb.set_trace()
        return 2 * indexes.numel() * self.n_workers


def list_to_tensor(input_list):
    temp_list = [t.reshape(-1) for t in input_list]
    return torch.cat(temp_list)
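
# A minimal sketch of the magnitude-based top-k selection that TopKReducer
# performs, run locally without any all_gather. The 10% keep-fraction and the
# example tensors are illustrative assumptions, not values from the original code.
def _demo_topk_selection(keep_fraction=0.1):
    grads = [torch.randn(5, 5), torch.randn(20)]
    flat = list_to_tensor(grads)                 # flatten all gradients
    num_sample = int(keep_fraction * len(flat))  # how many entries survive
    indexes = torch.argsort(torch.abs(flat), descending=True)[:num_sample]
    sparse = torch.zeros_like(flat)
    sparse[indexes] = flat[indexes]              # keep only the largest-magnitude entries
    print(num_sample, torch.count_nonzero(sparse).item())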

class TensorBuffer:
    """
    Packs multiple tensors into one flat buffer for efficient
    intra-worker communication.
    """
    def __init__(self, tensors):
        indices = [0]
        for tensor in tensors:
            new_end = indices[-1] + tensor.nelement()
            indices.append(new_end)
        self._start_idx = indices[:-1]
        self._end_idx = indices[1:]
        self._tensors = tensors
        self.buffer = torch.cat([t.view(-1) for t in tensors])  # copies

    def __getitem__(self, index):
        return self.buffer[self._start_idx[index] : self._end_idx[index]].view(*self._tensors[index].shape)

    def __len__(self):
        return len(self._tensors)

    def pack(self, tensors=None):
        # Optional. __init__ already does this.
        if tensors is None:
            tensors = self._tensors
        for tensor, entry in zip(tensors, self):
            entry[:] = tensor

    def unpack(self, tensors):
        for tensor, entry in zip(tensors, self):
            tensor[:] = entry

    def nelement(self):
        return self.buffer.nelement()

    def element_size(self):
        return self.buffer.element_size()

    def bits(self):
        return 8 * self.nelement() * self.element_size()

    def all_reduce(self, async_op=False):
        return torch.distributed.all_reduce(self.buffer, async_op=async_op)

    def all_gather(self, async_op=False):
        n_workers = (
            torch.distributed.get_world_size()
            if torch.distributed.is_available() and torch.distributed.is_initialized()
            else 1
        )
        buffers = [torch.empty_like(self.buffer) for i in range(n_workers)]
        handle = torch.distributed.all_gather(buffers, self.buffer, async_op=async_op)
        if async_op:
            return buffers, handle
        else:
            return buffers
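
# A minimal sketch of TensorBuffer's flat-buffer round trip on the local worker
# (no communication). The example tensors are illustrative assumptions only.
def _demo_tensor_buffer():
    tensors = [torch.randn(3, 2), torch.randn(4)]
    buf = TensorBuffer(tensors)
    assert buf.nelement() == 10        # 3*2 + 4 entries packed into one flat buffer
    buf.buffer /= 2                    # operate on the flat view ...
    outs = [torch.empty_like(t) for t in tensors]
    buf.unpack(outs)                   # ... then scatter back into per-tensor shapes
    print([o.shape for o in outs])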

def n_bits(tensor):
    return 8 * tensor.nelement() * tensor.element_size()


@torch.jit.script
def orthogonalize(matrix):
    n, m = matrix.shape
    for i in range(m):
        # Normalize the i'th column
        col = matrix[:, i : i + 1]
        col /= torch.sqrt(torch.sum(col ** 2))
        # Project it on the rest and remove it
        if i + 1 < m:
            rest = matrix[:, i + 1 :]
            # rest -= torch.matmul(col.t(), rest) * col
            rest -= torch.sum(col * rest, dim=0) * col
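
# A quick sanity check (illustrative, not part of the original code): after this
# in-place Gram-Schmidt pass, the columns of a tall random matrix should be
# orthonormal, i.e. matrix^T matrix is close to the identity.
def _demo_orthogonalize():
    matrix = torch.randn(16, 4)
    orthogonalize(matrix)
    gram = matrix.t() @ matrix
    print(torch.allclose(gram, torch.eye(4), atol=1e-4))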