-
Notifications
You must be signed in to change notification settings - Fork 486
/
test_mp_reduce_scatter.py
180 lines (147 loc) · 5.57 KB
/
test_mp_reduce_scatter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import sys
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
def _expected_shard(input_tensor, scale, scatter_dim, index, shard_size):
  """Return this replica's expected reduce-scatter shard for `input_tensor`.

  Computes the scaled sum across replicas with ``xm.all_reduce``, forces
  execution with ``xm.mark_step()``, and selects the ``shard_size`` entries
  along ``scatter_dim`` that start at ``index * shard_size`` on the CPU copy.
  """
  expected_world = xm.all_reduce(xm.REDUCE_SUM, input_tensor, scale)
  xm.mark_step()
  slice_idx = torch.arange(index * shard_size, (index + 1) * shard_size)
  return expected_world.cpu().index_select(scatter_dim, slice_idx)


def _assert_matches_all_reduce(results, inputs, scale, scatter_dim, index,
                               shard_size):
  """Assert that every ``results[i]`` equals the all-reduce shard of ``inputs[i]``.

  Pairs are checked in order, one all-reduce + mark_step per pair.

  Raises:
    AssertionError: if the lists differ in length, or a result is not
      ``allclose`` to its expected shard.
  """
  # Without this, a short `results` list would silently skip some inputs.
  assert len(results) == len(inputs), (len(results), len(inputs))
  for res, inp in zip(results, inputs):
    expected = _expected_shard(inp, scale, scatter_dim, index, shard_size)
    assert res.cpu().allclose(expected)


def _assert_same_tensors(returned, provided):
  """Assert `returned` holds exactly the tensor objects in `provided`, in order.

  Comparing the lists with ``==`` only passes through the per-element identity
  shortcut; any non-identical pair falls back to tensor ``__eq__``, whose
  multi-element result raises on ``bool()`` instead of failing the assertion.
  """
  assert len(returned) == len(provided), (len(returned), len(provided))
  assert all(r is p for r, p in zip(returned, provided))


def _mp_fn(index):
  """Check xm.reduce_scatter / xm.reduce_scatter_bucketized on one replica.

  Each case compares the reduce-scatter result against the slice of a scaled
  ``xm.all_reduce`` that this replica should own, then synchronizes replicas
  with ``xm.rendezvous``. On devices other than TPU/CUDA it prints a message
  to stderr and returns without checking anything.

  Args:
    index: process index supplied by the launcher; used as this replica's
      shard position along ``scatter_dim`` (presumably its global ordinal in
      a single-host run -- TODO confirm for multi-host).
  """
  device = xm.xla_device()
  world_size = xr.world_size()
  scale = 1 / world_size
  scatter_dim = 1
  shard_size = 2
  input_list_size = 5
  # Full size along scatter_dim, so each of the world_size replicas owns
  # shard_size entries.
  shape = (32, shard_size * world_size, 32)

  if xm.xla_device_hw(device) not in ('TPU', 'CUDA'):
    print(
        'Default device {} is not a TPU or CUDA device'.format(device),
        file=sys.stderr)
    return

  def rand_list():
    # Fresh random tensors created on CPU and moved to the XLA device.
    return [torch.rand(shape).to(device) for _ in range(input_list_size)]

  def check(results, inputs):
    _assert_matches_all_reduce(results, inputs, scale, scatter_dim, index,
                               shard_size)

  # Single tensor input.
  xrand = torch.rand(shape).to(device)
  res = xm.reduce_scatter(xm.REDUCE_SUM, xrand, scale, scatter_dim,
                          world_size)
  check([res], [xrand])
  xm.rendezvous('test_reduce_scatter')

  # List input.
  xrand_list = rand_list()
  # TODO: fix the broken case with pin_layout=True
  res_list = xm.reduce_scatter(
      xm.REDUCE_SUM,
      xrand_list,
      scale,
      scatter_dim,
      world_size,
      pin_layout=False)
  check(res_list, xrand_list)
  xm.rendezvous('test_reduce_scatter_list_input')

  # List input, bucketized.
  xrand_list = rand_list()
  # TODO: fix the broken case with pin_layout=True
  res_list = xm.reduce_scatter_bucketized(
      xm.REDUCE_SUM,
      xrand_list,
      scale,
      scatter_dim,
      world_size,
      pin_layout=False)
  check(res_list, xrand_list)
  xm.rendezvous('test_reduce_scatter_list_input_bucketized')

  # List input with caller-provided output tensors; these cases reuse the
  # inputs from the bucketized case above.
  # NOTE(review): outputs are allocated at the full input shape rather than
  # the shard shape; presumably the out-variant overwrites them -- confirm.
  output_cases = (
      (xm.reduce_scatter, {}, 'test_reduce_scatter_list_input_output'),
      (xm.reduce_scatter_bucketized, {},
       'test_reduce_scatter_list_input_output_bucketized'),
      (xm.reduce_scatter_bucketized, {'bucket_cap_mb': 0},
       'test_reduce_scatter_list_input_output_bucketized, zero bucket size'),
  )
  for scatter_fn, extra_kwargs, tag in output_cases:
    xoutput_list = rand_list()
    # TODO: fix the broken case with pin_layout=True
    res_list = scatter_fn(
        xm.REDUCE_SUM,
        xrand_list,
        scale,
        scatter_dim,
        world_size,
        output=xoutput_list,
        pin_layout=False,
        **extra_kwargs)
    # The output variants must hand back the very tensors they were given.
    _assert_same_tensors(res_list, xoutput_list)
    check(xoutput_list, xrand_list)
    xm.rendezvous(tag)
if __name__ == '__main__':
  # torch_xla.launch is expected to run _mp_fn(index) in each worker process;
  # the test body needs no extra positional arguments.
  launch_args = ()
  torch_xla.launch(_mp_fn, args=launch_args)