-
Notifications
You must be signed in to change notification settings - Fork 480
/
xla_model.py
1569 lines (1296 loc) · 57.5 KB
/
xla_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import contextlib
import io
import itertools
import logging
import sys
import re
import threading
import time
import warnings
from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Tuple, TypedDict, Union
import torch
import torch.distributed._functional_collectives
from torch.library import Library
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
from torch_xla import runtime
import torch_xla.core.xla_env_vars as xenv
import torch_xla.debug.metrics_saver as ms
import torch_xla.utils.utils as xu
import torch_xla.utils.closures as xc
from torch_xla.distributed.spmd.xla_sharding import ShardingSpec
import os
from torch_xla.experimental.deprecation import deprecated
import torch_xla._internal.utils as _utils
# Device list reported by `torch_xla._XLAC._xla_get_devices()`, fetched lazily
# through `xu.LazyProperty` on first access of `.value`.
_DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())

# Reduction identifiers accepted as `reduce_type` by the collective ops below
# (e.g. `all_reduce`).
REDUCE_SUM = 'sum'
REDUCE_MUL = 'mul'
REDUCE_AND = 'and'
REDUCE_OR = 'or'
REDUCE_MIN = 'min'
REDUCE_MAX = 'max'

# DeviceContext cache keyed by device string; all access goes through
# `_get_device_context`, which holds `_DEVICE_CONTEXTS_LOCK`.
_DEVICE_CONTEXTS = {}
_DEVICE_CONTEXTS_LOCK = threading.Lock()

# `torch.library.Library` handle for the "xla" namespace, opened with "DEF".
XLA_LIB = Library("xla", "DEF")
from . import xla_model as this_module

# Backward-compatible aliases. Each wraps its replacement via
# `deprecated(this_module, replacement, message)`; the message names the
# release in which the alias is slated for removal.
xrt_world_size = deprecated(
    this_module,
    torch_xla.runtime.world_size,
    'xrt_world_size() will be removed in release 2.7.',
)
get_ordinal = deprecated(
    this_module,
    torch_xla.runtime.global_ordinal,
    'xla_model.get_ordinal() will be removed in release 2.7.',
)
parse_xla_device = deprecated(
    this_module,
    _utils.parse_xla_device,
    'xla_model.parse_xla_device() will be removed in release 2.7.',
)
class DeviceContext:
  """Per-device mutable state holder.

  Only `device` is set here; other attributes (e.g. `device_index`, assigned
  by `set_replication`) are attached later by callers.
  """

  def __init__(self, device: Union[str, torch.device]):
    # The device identifier this context belongs to.
    self.device = device
def _get_device_context(
    device: Optional[Union[str, torch.device]] = None) -> DeviceContext:
  """Returns the cached DeviceContext for `device`, creating it on first use.

  Args:
    device: The device to look up. When `None`, the runtime default device
      (`torch_xla._XLAC._xla_get_default_device()`) is used; otherwise the
      device is normalized with `str()`.

  Returns:
    The DeviceContext stored in `_DEVICE_CONTEXTS` for that device key.
  """
  key = (torch_xla._XLAC._xla_get_default_device()
         if device is None else str(device))
  # Lookup and insertion happen under one lock so a key maps to one context.
  with _DEVICE_CONTEXTS_LOCK:
    ctx = _DEVICE_CONTEXTS.get(key)
    if ctx is None:
      ctx = DeviceContext(key)
      _DEVICE_CONTEXTS[key] = ctx
    return ctx
def is_xla_tensor(tensor: torch.Tensor) -> bool:
  """Returns True when `tensor` lives on an `xla` device."""
  device_type = tensor.device.type
  return device_type == 'xla'
def get_xla_supported_devices(devkind: Optional[str] = None,
                              max_devices: Optional[int] = None) -> List[str]:
  """Returns a list of supported devices of a given kind.

  Args:
    devkind (string..., optional): If specified, a device type such as `TPU`,
      `CUDA`, `CPU`, or name of custom PJRT device. Deprecated.
    max_devices (int, optional): The maximum number of devices to be returned of
      that kind. A falsy value (``None`` or ``0``) means no limit.

  Returns:
    The list of device strings such as ['xla:0', 'xla:1', ...]. When `devkind`
    is given and no device matches it, an empty list is returned.
  """
  # TODO(wcromar): Remove `devkind` after 2.3 release cut. We no longer support
  # multiple device types.
  if not devkind:
    devices = torch_xla._XLAC._xla_get_devices()
    selected = devices[:max_devices] if max_devices else devices
    return [f'xla:{i}' for i in range(len(selected))]

  warnings.warn("`devkind` argument is deprecated and will be removed in a "
                "future release.")
  # Escape `devkind` so it is matched literally rather than as a regex.
  pattern = re.compile(re.escape(devkind) + r':\d+$')
  kind_devices = [
      f'xla:{i}' for i, device in enumerate(_DEVICES.value)
      if pattern.match(device)
  ]
  # Previously this fell off the end (returning None) when nothing matched;
  # always return a list, as the annotation and docstring promise.
  return kind_devices[:max_devices] if max_devices else kind_devices
def get_local_ordinal() -> int:
  """Returns the current thread's replication ordinal among local devices.

  Local ordinals span 0 .. (number of local devices - 1).

  Returns:
    The value reported by `runtime.local_ordinal()`.
  """
  local_ordinal = runtime.local_ordinal()
  return local_ordinal
def is_master_ordinal(local: bool = True) -> bool:
  """Tells whether the current process holds master ordinal 0.

  Args:
    local (bool): Check the local ordinal (`get_local_ordinal()`) when True,
      or the global one (`runtime.global_ordinal()`) when False. With
      multi-host replication there is a single global master (host 0,
      device 0) but one local master per host.
      Default: True

  Returns:
    True if the selected ordinal equals 0.
  """
  if local:
    ordinal = get_local_ordinal()
  else:
    ordinal = runtime.global_ordinal()
  return ordinal == 0
def master_print(*args: Any,
                 fd: TextIO = sys.stdout,
                 local: bool = False,
                 flush: bool = False):
  """Prints `args` only from the master ordinal.

  Args:
    *args: Values forwarded to `print`.
    fd: The stream to print to. Default: `sys.stdout` (bound at import time).
    local (bool): Check the local master ordinal instead of the global one.
      Default: False
    flush (bool): Forwarded to `print`. Default: False
  """
  if not is_master_ordinal(local=local):
    return
  print(*args, file=fd, flush=flush)
def xla_device(n: Optional[int] = None,
               devkind: Optional[str] = None) -> torch.device:
  """Returns an XLA device instance.

  Args:
    n (int, optional): The ordinal of the device to return. When omitted, the
      first device of `devkind` is returned.
    devkind (string..., optional): Device type such as `TPU`, `CUDA`, `CPU`,
      or a custom PJRT device. Deprecated.

  Returns:
    A `torch.device` for the requested instance. With the `XLA_USE_SPMD` env
    flag set, this is always `xla:0`, which is also made the default device.
  """
  if not xu.check_env_flag('XLA_USE_SPMD'):
    return runtime.xla_device(n, devkind)
  # Under SPMD the user only ever sees `xla:0`; every XLA tensor is backed by
  # virtual-device logic under the hood.
  spmd_device = 'xla:0'
  torch_xla._XLAC._xla_set_default_device(spmd_device)
  return torch.device(spmd_device)
def _xla_real_device(device: torch.device) -> Any:
device_str = str(device)
m = re.match(r'xla:(\d+)$', device_str)
if not m:
raise RuntimeError('Invalid device format: {}'.format(device_str))
return _DEVICES.value[int(m.group(1))]
def xla_real_devices(devices: Optional[List[torch.device]] = None) -> List[str]:
  """Translates XLA devices into the runtime's real device names.

  Args:
    devices: XLA devices such as ['xla:0', 'xla:1']. When empty or `None`,
      every device from `get_xla_supported_devices()` is used.

  Returns:
    Real device names such as ['CUDA:0', 'CUDA:1'].
  """
  devices = devices or get_xla_supported_devices()
  return list(map(_xla_real_device, devices))
def xla_device_hw(device: Union[str, torch.device]) -> str:
  """Returns the hardware type (the part before ':') of an XLA device.

  Args:
    device (string or torch.device): The XLA device to resolve to its real
      device.

  Returns:
    The hardware type string, e.g. 'CUDA' for a real device 'CUDA:0'.
  """
  hw_type, _, _ = _xla_real_device(device).partition(':')
  return hw_type
def xla_replication_devices(
    local_devices: Optional[List[torch.device]] = None) -> List[str]:
  """Returns the real devices that participate in replication.

  Args:
    local_devices: The local XLA devices (e.g. ['xla:0', 'xla:1']). If `None`,
      the devices from `get_xla_supported_devices()` are used.

  Returns:
    Every device reported by `_xla_get_all_devices()` whose type matches the
    local devices' type (e.g. ['CUDA:0', 'CUDA:1', ...]), sorted by ordinal.

  Raises:
    RuntimeError: If a device name cannot be parsed, if the local devices span
      more than one device type, or if their number differs from the number
      of supported devices.
  """
  if local_devices is None:
    # Resolve the default up front; the length check below needs a real list
    # (previously this path crashed on len(None)).
    local_devices = get_xla_supported_devices()
  real_devices = xla_real_devices(local_devices)

  device_types = set()
  for device in real_devices:
    xdev = _utils.parse_xla_device(device)
    if not xdev:
      raise RuntimeError('Invalid device format: {}'.format(device))
    device_types.add(xdev[0])
  if len(device_types) != 1:
    # No replication if the device set spans multiple device types.
    raise RuntimeError(
        'Cannot replicate across different device types: devices={}/{}'.format(
            local_devices, real_devices))
  device_type = device_types.pop()

  kind_devices = get_xla_supported_devices()
  if len(kind_devices) != len(local_devices):
    # Replication can only happen among all devices of one kind.
    raise RuntimeError(
        'Cannot replicate if number of devices ({}) is different from {}'.
        format(len(local_devices), len(kind_devices)))

  # Keep (ordinal, name) pairs so each name is parsed only once.
  by_ordinal = []
  for device in torch_xla._XLAC._xla_get_all_devices():
    # device is like 'CUDA:0'
    xdev = _utils.parse_xla_device(device)
    if not xdev:
      raise RuntimeError('Invalid device format: {}'.format(device))
    if xdev[0] == device_type:
      by_ordinal.append((xdev[1], device))
  by_ordinal.sort(key=lambda entry: entry[0])
  return [device for _, device in by_ordinal]
def unlazy(tensors: List[torch.Tensor]):
  """Materializes `tensors`, blocking until they are computed.

  Meant for benchmarking only; do not use it in real models.

  Args:
    tensors: The `torch.Tensor`s to materialize; each one's `device` must be
      an `xla` device.
  """
  # No extra devices are synced; `wait=True` makes the call blocking.
  torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True)
def set_replication(device: torch.device,
                    devices: Optional[List[torch.device]]):
  """Configures replication for `device` and makes it the default device.

  Args:
    device: The current XLA device.
    devices: The XLA devices taking part in replication. When empty or `None`,
      replication devices are cleared and the device index is reset to 0.

  Raises:
    ValueError: If `devices` is non-empty and does not contain `device`
      (raised by `list.index`).
  """
  device = str(device)
  devctx = _get_device_context(device=device)
  # `devices` is Optional: treat None like an empty list instead of failing
  # while iterating it.
  devices = [str(x) for x in devices] if devices else []
  if devices:
    # sample replication_devices: ['CUDA:0', 'CUDA:1', 'CUDA:2', 'CUDA:3']
    replication_devices = xla_replication_devices(devices)
    torch_xla._XLAC._xla_set_replication_devices(replication_devices)
    devctx.device_index = devices.index(device)
  else:
    torch_xla._XLAC._xla_set_replication_devices([])
    devctx.device_index = 0
  # Reset the collective token for this device.
  torch_xla._XLAC._set_all_reduce_token(devctx.device, None)
  torch_xla._XLAC._xla_set_default_device(device)
class RateTracker(object):
  """Tracks how fast counts are added over time.

  `rate()` reports an exponentially smoothed rate for the counts added since
  the previous `rate()` call; `global_rate()` reports the average rate since
  the tracker was created.
  """

  def __init__(self, smooth_factor: Optional[float] = None):
    # Weight given to the previous smoothed rate. When omitted it comes from
    # `xu.getenv_as('RATE_TRACKER_SMOOTHING', float, 0.4)`.
    if smooth_factor is None:
      smooth_factor = xu.getenv_as('RATE_TRACKER_SMOOTHING', float, 0.4)
    self._smooth_factor = smooth_factor
    self._start_time = time.time()
    self._partial_time = self._start_time
    self._partial_count = 0.0
    self._partial_rate = None
    self._count = 0.0

  def _update(self, now: float, rate: float):
    # Fold the pending count into the running total and open a new window.
    self._partial_count += self._count
    self._count = 0.0
    self._partial_time, self._partial_rate = now, rate

  def add(self, count: float):
    self._count += count

  def _smooth(self, current_rate: float) -> float:
    if self._partial_rate is None:
      return current_rate
    alpha = self._smooth_factor
    return (1 - alpha) * current_rate + alpha * self._partial_rate

  def rate(self):
    now = time.time()
    elapsed = now - self._partial_time
    if elapsed > 0:
      smoothed = self._smooth(self._count / elapsed)
      self._update(now, smoothed)
      return smoothed
    return 0.0

  def global_rate(self):
    elapsed = time.time() - self._start_time
    if elapsed > 0:
      return (self._partial_count + self._count) / elapsed
    return 0.0
class ToXlaTensorArena(object):
  """Batch-converts the selected objects nested inside a structure.

  `transform()` collects every object for which `select_fn` is true (via
  `xu.for_each_instance`), converts all of them with one `convert_fn` call,
  then rewrites the structure via `xu.for_each_instance_rewrite`, handing out
  converted objects in collection order. NOTE(review): this assumes both `xu`
  walkers visit objects in the same order — confirm.
  """

  def __init__(self, convert_fn: Callable[[List[torch.Tensor]],
                                          List[torch.Tensor]],
               select_fn: Callable[[torch.Tensor], bool]):
    self._convert_fn = convert_fn
    self._select_fn = select_fn
    self._tensors = []

  def _add(self, tensor: torch.Tensor):
    self._tensors.append(tensor)

  def _convert(self):
    # Rewind the read cursor; skip `convert_fn` entirely when nothing was
    # collected.
    self._index = 0
    self._converted_tensors = (
        self._convert_fn(self._tensors) if self._tensors else [])

  def _get_converted_tensor(self) -> torch.Tensor:
    assert self._index < len(self._converted_tensors)
    position = self._index
    self._index = position + 1
    return self._converted_tensors[position]

  def _collect_tensors(self, inputs: Any):
    xu.for_each_instance(inputs, lambda x: self._select_fn(x),
                         lambda value: self._add(value))

  def _replace_tensors(self, inputs: Any):
    return xu.for_each_instance_rewrite(
        inputs, lambda x: self._select_fn(x),
        lambda _value: self._get_converted_tensor())

  def transform(self, inputs: Any):
    self._tensors = []
    self._collect_tensors(inputs)
    self._convert()
    return self._replace_tensors(inputs)
def check_view_sharing(obj):
  """Raises if two distinct tensors inside `obj` share a view.

  Walks `obj` with `xu.for_each_instance`, visiting only objects whose type is
  exactly `torch.Tensor` (subclasses are not visited). XLA tensors are keyed by
  their XLA tensor ID and view-alias ID; other tensors by `id()` and their
  storage data pointer.

  Args:
    obj: A (possibly nested) structure holding tensors.

  Raises:
    RuntimeError: If two different tensors map to the same alias key.
  """
  tensors = set()
  aliases = dict()

  def tensor_info(t: torch.Tensor) -> str:
    return '{}{}'.format(t.dtype, list(t.size()))

  def tensor_id(t: torch.Tensor) -> Tuple[int, str]:
    if is_xla_tensor(t):
      return torch_xla._XLAC._xla_get_tensor_id(t), 'xla'
    return id(t), 'torch'

  def alias_id(t: torch.Tensor) -> Tuple[Optional[int], str]:
    if is_xla_tensor(t):
      aid = torch_xla._XLAC._xla_get_tensor_view_alias_id(t)
      return None if aid == 0 else aid, 'xla'
    # A tensor without allocated memory (e.g. zero elements) may report a null
    # (0) data pointer. Treat that as "no alias", mirroring the `aid == 0` case
    # above, so unrelated empty tensors are not reported as sharing a view.
    # `untyped_storage()` replaces the deprecated `storage()` accessor.
    ptr = t.untyped_storage().data_ptr()
    return None if ptr == 0 else ptr, 'torch'

  def check_object(obj):
    tid = tensor_id(obj)
    if tid in tensors:
      return
    tensors.add(tid)
    aid = alias_id(obj)
    if aid[0] is None:
      return
    if aid in aliases:
      oobj = aliases[aid]
      raise RuntimeError(
          'Tensor ID {} ({}) is sharing a view with tensor ID {} ({})'.format(
              tid, tensor_info(obj), tensor_id(oobj), tensor_info(oobj)))
    aliases[aid] = obj

  xu.for_each_instance(obj, lambda x: type(x) == torch.Tensor, check_object)
def _fetch_gradients(optimizer: optim.Optimizer) -> List[torch.Tensor]:
gradients = []
for param_group in optimizer.__getstate__()['param_groups']:
for group, params in param_group.items():
if group == 'params':
for p in params:
if isinstance(p, torch.Tensor) and p.grad is not None:
gradients.append(p.grad.data)
return gradients
def _get_all_reduce_token() -> Tuple[Any, DeviceContext]:
  """Returns the collective token of the default device and its context.

  The context is returned so callers can later store an updated token with
  `_set_all_reduce_token(devctx.device, ...)`.
  """
  devctx = _get_device_context()
  return torch_xla._XLAC._get_all_reduce_token(devctx.device), devctx
def all_reduce(
    reduce_type: str,
    inputs: Union[torch.Tensor, List[torch.Tensor]],
    scale: float = 1.0,
    groups: Optional[List[List[int]]] = None,
    pin_layout: bool = True) -> Union[torch.Tensor, List[torch.Tensor]]:
  """Reduces `inputs` across replicas.

  Args:
    reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
      ``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
      ``xm.REDUCE_MAX``.
    inputs: A single `torch.Tensor`, or a list of `torch.Tensor`s reduced in
      place.
    scale (float): Scaling value applied after the reduction.
      Default: 1.0
    groups (list, optional): Replica groups for the `all_reduce()` operation,
      e.g. `[[0, 1, 2, 3], [4, 5, 6, 7]]` forms two groups of four replicas.
      `None` means a single group containing every replica.
    pin_layout (bool, optional): Whether to pin the layout of this
      communication op. Pinning guards against data corruption when the
      participating processes run slightly different programs, but may make
      some XLA compilations fail; unpin it on errors such as
      "HloModule has a mix of layout constrained".

  Returns:
    For a single tensor, a new tensor holding the value reduced across
    replicas. For a list/tuple, the same container after its tensors were
    reduced in place.
  """
  groups = groups or []

  # A single replica has nothing to reduce, unless the XLA_ALWAYS_ALLREDUCE
  # env flag forces the collective anyway.
  if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
                                                    bool, False):
    return inputs.clone() if isinstance(inputs, torch.Tensor) else inputs

  if not isinstance(inputs, torch.Tensor):
    torch_xla._XLAC._xla_all_reduce_inplace(reduce_type, inputs, scale, groups,
                                            pin_layout)
    return inputs

  if scale == 1.0 and groups == [] and pin_layout:
    # TODO(alanwaketan): Support groups.
    # Only c10d_functional version cc ops are traceable by Dynamo.
    return torch.ops._c10d_functional.all_reduce(inputs, reduce_type, "")
  return torch_xla._XLAC._xla_all_reduce(reduce_type, inputs, scale, groups,
                                         pin_layout)
def _all_gather_using_all_reduce(
    value: torch.Tensor,
    dim: int = 0,
    groups: Optional[List[List[int]]] = None,
    pin_layout: bool = True) -> Optional[torch.Tensor]:
  """Performs an all-gather operation using all-reduce along a given dimension.

  Each replica zero-pads `value` along `dim` so that its data sits in the slot
  matching its position (global ordinal, or index within its replica group),
  then the padded tensors are summed with `all_reduce`.

  Args:
    value (torch.Tensor): The input tensor.
    dim (int): The gather dimension. Negative values count from the end.
      Default: 0
    groups (list, optional): A list of list, representing the replica groups for
      the `all_gather()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
      the `[4, 5, 6, 7]` replicas. If `None` or empty there will be only one
      group with all the replicas in it.
    pin_layout (bool, optional): whether to pin the layout of the underlying
      `all_reduce`. Layout pinning can prevent potential data corruption when
      each process that participates in the communication has a slightly
      different program, but it might cause some xla compilation to fail.
      Unpin the layout when you see error message like
      "HloModule has a mix of layout constrained".

  Returns:
    A tensor which has, in the ``dim`` dimension, all the values from the
    participating replicas.

  Raises:
    RuntimeError: If `groups` is given and the current ordinal is in none of
      them.
  """
  if dim < 0:
    dim = value.dim() + dim
  size = value.size(dim)
  padding = [0] * (2 * value.dim())
  ordinal = runtime.global_ordinal()
  # `not groups` also covers `[]`, which `all_gather` may pass through and
  # which previously produced a KeyError below.
  if not groups:
    left, right = ordinal, runtime.world_size() - 1 - ordinal
  else:
    ordinals = dict()
    for g in groups:
      for i, x in enumerate(g):
        ordinals[x] = (i, len(g) - 1 - i)
    if ordinal not in ordinals:
      raise RuntimeError('Ordinal {} is not part of any replica group: {}'.format(
          ordinal, groups))
    left, right = ordinals[ordinal]
  # F.pad lists (before, after) pairs starting from the LAST dimension, so
  # dimension `dim` maps to pair index `value.dim() - 1 - dim`.
  idx = value.dim() - 1 - dim
  padding[2 * idx] = left * size
  padding[2 * idx + 1] = right * size
  # Forward `pin_layout`; it was previously accepted but silently ignored.
  return all_reduce(
      REDUCE_SUM, F.pad(value, padding), groups=groups, pin_layout=pin_layout)
def all_gather(value: torch.Tensor,
               dim: int = 0,
               groups: Optional[List[List[int]]] = None,
               output: Optional[torch.Tensor] = None,
               pin_layout: bool = True) -> torch.Tensor:
  """Performs an all-gather operation along a given dimension.

  Args:
    value (torch.Tensor or list of torch.Tensor): The input tensor, or a list
      of tensors gathered with one coalesced op.
    dim (int): The gather dimension. Negative values count from the end; for a
      list input, all tensors must then have the same rank.
      Default: 0
    groups (list, optional): A list of list, representing the replica groups for
      the `all_gather()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
      the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
      all the replicas in it.
    output (torch.Tensor or list of torch.Tensor, optional): Optional output
      tensor, or a list of output tensors matching a list `value`.
    pin_layout (bool, optional): whether to pin the layout for this communication op.
      Layout pining can prevent potential data corruption when each process that
      participate in the communication has slightly different program, but it might
      cause some xla compilation to fail. Unpin the layout when you see error message
      like "HloModule has a mix of layout constrained".

  Returns:
    A tensor which has, in the ``dim`` dimension, all the values from the
    participating replicas. For a list input, a list of such tensors (or
    `output` itself when it is given).

  Raises:
    RuntimeError: If `pin_layout` is True for a list input.
    TypeError: If `value` or `output` has an unsupported type.
    ValueError: If `output` and `value` lists differ in length, or if a
      negative `dim` is used with list tensors of different ranks.
  """
  # _all_gather_using_all_reduce does not support list of tensors as input
  if pin_layout and output is None and isinstance(value, torch.Tensor):
    # There is not an easy way to pin the all_gather layout, so use all_reduce
    # based all_gather for this purpose.
    return _all_gather_using_all_reduce(
        value, dim=dim, groups=groups, pin_layout=True)

  is_tensor_list = isinstance(value, list) and all(
      isinstance(v, torch.Tensor) for v in value)
  if dim < 0:
    if isinstance(value, torch.Tensor):
      dim = value.dim() + dim
    elif is_tensor_list and value:
      # A list has no `.dim()` (previously an AttributeError); normalize the
      # dimension against the tensors' common rank instead.
      ranks = {v.dim() for v in value}
      if len(ranks) != 1:
        raise ValueError(
            "A negative `dim` requires all tensors in `value` to have the "
            f"same rank, but given ranks {sorted(ranks)}.")
      dim = ranks.pop() + dim

  if groups:
    shard_count = len(groups[0])
    assert all(len(group) == shard_count for group in groups), \
      "Replica groups must have the same number of replicas/shards."
  else:
    # All replicas belong to a single group
    shard_count = runtime.world_size()

  token, devctx = _get_all_reduce_token()

  if isinstance(value, torch.Tensor):
    if output is not None:
      # Call the out of place version of the all_gather
      new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
                                                      shard_count, groups or [],
                                                      pin_layout)
      torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
      return output
    return torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
                                           [], pin_layout)

  if not is_tensor_list:
    raise TypeError("`value` needs to be a Tensor or a list of Tensors, but "
                    f"given {type(value)}.")

  # From here on the input is a list of Tensors.
  if pin_layout:
    raise RuntimeError(
        "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported."
    )
  if output is not None:
    if not isinstance(output, list) or any(
        not isinstance(v, torch.Tensor) for v in output):
      raise TypeError(
          f"`output` needs to be a list of Tensors, but given {type(output)}."
      )
    if len(output) != len(value):
      # Report the length of `value`; this previously called len() on the
      # builtin `input` and raised TypeError instead of this ValueError.
      raise ValueError("`output` length doesn't match `input` length: "
                       f"{len(output)} vs {len(value)}.")
    # Call the out of place version of the all_gather
    new_token = torch_xla._XLAC._xla_all_gather_coalesced_out(
        output, value, token, dim, shard_count, groups or [], pin_layout)
    torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
    return output

  result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
                                                     shard_count, groups or
                                                     [], pin_layout)
  # The last element is the updated token; the rest are the gathered tensors.
  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
  return result[:-1]
class CoalescingBuckets(object):
  """Splits a list of tensors into size-capped buckets and runs `func` on each.

  Tensors are added to a bucket in input order, and the bucket is flushed
  before a tensor that would push it past `bucket_cap_mb` megabytes. A bucket
  holding one tensor is passed to `func` as a single tensor (plus its output
  tensor or `None`); larger buckets are passed as lists. Calling the instance
  returns the results in input order and may be repeated.
  """

  def __init__(
      self,
      func: Callable[[
          Union[torch.Tensor, List[torch.Tensor]],
          Optional[Union[torch.Tensor, List[torch.Tensor]]]
      ], Union[torch.Tensor, List[torch.Tensor]]],
      input_list: Any,
      output_list: Optional[Any] = None,
      bucket_cap_mb: int = 160):
    """Validates the inputs and sets up an empty bucket.

    Args:
      func: The collective to run on each bucket.
      input_list: A list of input tensors.
      output_list: Optional list of output tensors, one per input.
      bucket_cap_mb: Target bucket size in megabytes (MiB).

    Raises:
      TypeError: If `input_list` (or a given `output_list`) is not a list of
        tensors.
      ValueError: If `output_list` and `input_list` lengths differ.
    """
    if not isinstance(input_list, list) or any(
        not isinstance(v, torch.Tensor) for v in input_list):
      raise TypeError(
          f"`input_list` needs to be a list of Tensors, but given {type(input_list)}."
      )
    if output_list is not None:
      if not isinstance(output_list, list) or any(
          not isinstance(v, torch.Tensor) for v in output_list):
        raise TypeError(
            f"`output_list` needs to be a list of Tensors, but given {type(output_list)}."
        )
      if len(output_list) != len(input_list):
        raise ValueError(
            "`output_list` length doesn't match `input_list` length: "
            f"{len(output_list)} vs {len(input_list)}.")
    self._func = func
    self._input_list = input_list
    self._output_list = output_list
    self._bucket_cap = bucket_cap_mb * 1024 * 1024
    self._out_tensors = []
    self._reset_bucket()

  def _reset_bucket(self):
    """Empties the pending bucket: inputs, matching outputs and byte total."""
    self._total = 0
    self._tensor_bucket = []
    self._output_bucket = [] if self._output_list is not None else None

  def flush(self):
    """Runs `func` on the pending bucket (if any) and empties it."""
    if len(self._tensor_bucket) == 1:
      # Use non-coalesced CCOp if its just one tensor
      output = self._output_bucket[0] if self._output_bucket else None
      self._out_tensors.append(self._func(self._tensor_bucket[0], output))
    elif self._tensor_bucket:
      self._out_tensors.extend(
          self._func(self._tensor_bucket, self._output_bucket))
    self._reset_bucket()

  def add(self, tensor: torch.Tensor, idx: int):
    """Appends `tensor` (and output `idx`, if outputs exist) to the bucket."""
    self._total += tensor.numel() * tensor.element_size()
    self._tensor_bucket.append(tensor)
    if self._output_list is not None:
      self._output_bucket.append(self._output_list[idx])

  def __call__(self) -> Union[torch.Tensor, List[torch.Tensor]]:
    # Start from a clean state: previously a second call appended to the
    # results of the first one and tripped the length assertion below.
    self._out_tensors = []
    self._reset_bucket()
    for idx, tensor in enumerate(self._input_list):
      tensor_bytes = tensor.numel() * tensor.element_size()
      # Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content
      # is small (1/2 cap) but don't combine if combined total is over 2x cap
      total_new = self._total + tensor_bytes
      if (tensor_bytes > self._bucket_cap and
          self._total < 0.5 * self._bucket_cap and
          total_new <= 2 * self._bucket_cap):
        self.add(tensor, idx)
        self.flush()
      else:
        # Bucketize till the total spills over
        if total_new > self._bucket_cap:
          self.flush()
        self.add(tensor, idx)
    # Flush the last remaining bucket
    self.flush()
    assert len(self._out_tensors) == len(self._input_list)
    return self._out_tensors
def all_gather_bucketized(
    input_list: List[torch.Tensor],
    dim: int = 0,
    groups: Optional[List[List[int]]] = None,
    output: Optional[torch.Tensor] = None,
    pin_layout: bool = False,
    bucket_cap_mb=160) -> Union[torch.Tensor, List[torch.Tensor]]:
  """All-gathers a list of tensors in size-capped coalesced buckets.

  Args:
    input_list: The input tensors.
    dim, groups, output, pin_layout: As for `all_gather`; `output`, when
      given, is a list of output tensors matching `input_list`.
      `pin_layout=True` is not supported.
    bucket_cap_mb: Megabytes of tensors to accumulate per all-gather call.

  Returns:
    One gathered tensor per input, each holding along ``dim`` the values from
    all participating replicas.

  Raises:
    RuntimeError: If `pin_layout` is True.
  """
  if pin_layout:
    raise RuntimeError(
        "For xm.all_gather_bucketized, pin_layout=True is not yet supported.")

  def _gather_bucket(bucket_inputs, bucket_outputs=None):
    # Every bucket goes through all_gather with the caller's settings.
    return all_gather(
        value=bucket_inputs,
        dim=dim,
        groups=groups,
        output=bucket_outputs,
        pin_layout=pin_layout)

  return CoalescingBuckets(
      _gather_bucket, input_list, output, bucket_cap_mb=bucket_cap_mb)()
def all_to_all(value: torch.Tensor,
               split_dimension: int,
               concat_dimension: int,
               split_count: int,
               groups: Optional[List[List[int]]] = None,
               pin_layout: bool = True) -> torch.Tensor:
  """Runs an XLA `AllToAll()` on `value`.

  See: https://www.tensorflow.org/xla/operation_semantics#alltoall

  Args:
    value (torch.Tensor): The input tensor.
    split_dimension (int): The dimension to split along.
    concat_dimension (int): The dimension to concatenate along.
    split_count (int): The number of splits.
    groups (list, optional): Replica groups for the `all_to_all()` operation,
      e.g. `[[0, 1, 2, 3], [4, 5, 6, 7]]` forms two groups of four replicas.
      `None` means a single group containing every replica.
    pin_layout (bool, optional): Whether to pin the layout of this
      communication op. Pinning guards against data corruption when the
      participating processes run slightly different programs, but may make
      some XLA compilations fail; unpin it on errors such as
      "HloModule has a mix of layout constrained".

  Returns:
    The `torch.Tensor` produced by the `all_to_all()` operation.
  """
  token, devctx = _get_all_reduce_token()
  reply = torch_xla._XLAC._xla_all_to_all(value, token, split_dimension,
                                          concat_dimension, split_count,
                                          groups or [], pin_layout)
  # reply[0] is the exchanged tensor, reply[1] the token to store back.
  exchanged, new_token = reply[0], reply[1]
  torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
  return exchanged
def collective_permute(value: torch.Tensor,
                       pairs: List[List[int]]) -> torch.Tensor:
  """Runs an XLA `CollectivePermute()` on `value`.

  WARNING: This function is not very reliable and may produce wrong results
  for certain inputs. Use it at your own risk.

  See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute

  Args:
    value (torch.Tensor): The input tensor.
    pairs (list): (source_replica_id, target_replica_id) pairs naming the
      sender and receiver of each transfer. For example,
      `[[0, 1], [1, 2], [2, 0]]` sends replica 0's tensor to replica 1,
      replica 1's to replica 2, and replica 2's to replica 0.

  Returns:
    The `torch.Tensor` produced by the `collective_permute()` operation.
  """
  token, devctx = _get_all_reduce_token()
  permuted = torch_xla._XLAC._xla_collective_permute(value, token, pairs)
  # permuted[1] is the token to store back; permuted[0] is the result tensor.
  torch_xla._XLAC._set_all_reduce_token(devctx.device, permuted[1])
  return permuted[0]
def collective_broadcast(tensors: List[torch.Tensor],
                         root_ordinal: int = 0,
                         groups: Optional[List[List[int]]] = None,
                         pin_layout: bool = True) -> None:
  """Broadcast values of `tensors` from root replica to other replicas in-place.

  Non-root replicas multiply their tensors by 0 and the root by 1, then the
  tensors are summed in place with `all_reduce`, leaving the root's values
  everywhere.

  Args:
    tensors (list): List of `torch.Tensor`s to broadcast.
    root_ordinal (int): Global ordinal of the replica with values to broadcast.
    groups (list, optional): A list of list, representing the replica groups for
      the underlying `all_reduce()` operation. Example:
      `[[0, 1, 2, 3], [4, 5, 6, 7]]` defines two groups, one with the
      `[0, 1, 2, 3]` replicas and one with the `[4, 5, 6, 7]` replicas. If
      `None` there will be only one group with all the replicas in it.
    pin_layout (bool, optional): whether to pin the layout for this communication op.
      Layout pining can prevent potential data corruption when each process that
      participate in the communication has slightly different program, but it might
      cause some xla compilation to fail. Unpin the layout when you see error message
      like "HloModule has a mix of layout constrained".
  """
  # The ordinal is fixed for the process; check it once instead of per tensor.
  is_root = runtime.global_ordinal() == root_ordinal
  with torch.no_grad():
    # We must produce the exact same graph in each replica to prevent hanging,
    # so each replica must have the same multiply op with the same parameters.
    for tensor in tensors:
      scale = torch.tensor(1 if is_root else 0, dtype=tensor.dtype)
      # Transfer scale tensor as device data instead of constant 1 or 0.
      xscale = send_cpu_data_to_device(scale, tensor.device)
      tensor.mul_(xscale[0])
    all_reduce(REDUCE_SUM, tensors, groups=groups, pin_layout=pin_layout)
def send(value: torch.Tensor, channel_id: int) -> torch.Tensor:
  """Runs an XLA `Send()` of `value`.

  See: https://www.tensorflow.org/xla/operation_semantics#send

  Args:
    value (torch.Tensor): The tensor to send.
    channel_id (int64): Opaque id identifying the destination of the send op.

  Returns:
    The input, as handed back by `_xla_send`.
  """
  token, devctx = _get_all_reduce_token()
  # `_xla_send` returns the input itself together with an updated token.
  sent, updated_token = torch_xla._XLAC._xla_send(value, token, channel_id)
  torch_xla._XLAC._set_all_reduce_token(devctx.device, updated_token)
  return sent
def recv(output: torch.Tensor, channel_id: int) -> torch.Tensor:
  """Runs an XLA `Recv()` into `output`.

  See: https://www.tensorflow.org/xla/operation_semantics#recv

  Args:
    output (torch.Tensor): The output tensor.
    channel_id (int64): Opaque id identifying the source of the recv op.

  Returns:
    The tensor returned by `_xla_recv`.
  """
  token, devctx = _get_all_reduce_token()
  received, updated_token = torch_xla._XLAC._xla_recv(output, token, channel_id)
  torch_xla._XLAC._set_all_reduce_token(devctx.device, updated_token)
  return received
def reduce_scatter(
    reduce_type: str,
    input: Union[torch.Tensor, List[torch.Tensor]],
    scale: float,
    scatter_dim: int,
    shard_count: int,
    groups: Optional[List[List[int]]] = None,
    output: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    pin_layout: bool = True,
) -> Union[torch.Tensor, List[torch.Tensor]]:
  """Performs an XLA `ReduceScatter()` operation on the input tensor(s).

  See: https://www.tensorflow.org/xla/operation_semantics#reducescatter

  Args:
    reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
      ``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
      ``xm.REDUCE_MAX``.
    input (torch.Tensor or a list of torch.Tensor): The input. If it is a
      list, the coalesced variant of the op is used and a list is returned.
    scale (float): A default scaling value to be applied after the reduce.
    scatter_dim (int): Dimension number to which apply scatter operation.
    shard_count (int): The number of ways to split up the scatter_dim in.
    groups (list): A list of list, representing the replica groups for
      the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
      the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
      all the replicas in it.
    output: Optional output tensor if `input` is a torch.Tensor, or a list of
      torch.Tensor (of the same length as `input`) if `input` is a list of
      torch.Tensor. When provided, it is passed to the `*_out` variant of the
      op and returned.
    pin_layout (bool, optional): whether to pin the layout for this
      communication op. Layout pinning can prevent potential data corruption
      when each process that participates in the communication has a slightly
      different program, but it might cause some XLA compilation to fail.
      Unpin the layout when you see an error message like
      "HloModule has a mix of layout constrained".

  Returns:
    If `input` is a `torch.Tensor`: a `torch.Tensor` with all the values
    reduced across replicas. Each process gets a shard split along the
    `scatter_dim`. All other dimensions are the same as the input.
    If `input` is a list: a list of the reduced tensors returned by the
    coalesced op. When `output` is given, `output` itself is returned.

  Raises:
    TypeError: If `input` is neither a Tensor nor a list of Tensors, or if
      `output` is not of the matching kind (Tensor / list of Tensors).
    ValueError: If `input` is a list and `output` has a different length.
  """
  replica_groups = groups or []

  if isinstance(input, torch.Tensor):
    # Validate before fetching the all-reduce token so invalid calls fail
    # before any XLA call is made.
    if output is not None and not isinstance(output, torch.Tensor):
      raise TypeError("`output` needs to be a Tensor when `input` is a Tensor, "
                      f"but given {type(output)}.")
    token, devctx = _get_all_reduce_token()
    if output is not None:
      # Call the out of place version of the reduce_scatter.
      new_token = torch_xla._XLAC._xla_reduce_scatter_out(
          reduce_type, output, input, token, scale, scatter_dim, shard_count,
          replica_groups, pin_layout)
      torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
      return output
    # result[0] is the reduced tensor, result[1] the new all-reduce token.
    result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
                                                 scale, scatter_dim,
                                                 shard_count, replica_groups,
                                                 pin_layout)
    torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
    return result[0]

  # Anything other than a single Tensor must be a list of Tensors.
  if not (isinstance(input, list) and
          all(isinstance(v, torch.Tensor) for v in input)):
    raise TypeError("`input` needs to be a Tensor or a list of Tensors, but "
                    f"given {type(input)}.")

  if output is not None:
    if not isinstance(output, list) or any(
        not isinstance(v, torch.Tensor) for v in output):
      raise TypeError(
          f"`output` needs to be a list of Tensors, but given {type(output)}.")
    if len(output) != len(input):
      raise ValueError("`output` length doesn't match `input` length: "
                       f"{len(output)} vs {len(input)}.")

  token, devctx = _get_all_reduce_token()
  if output is not None:
    # Call the out of place version of the coalesced reduce_scatter.
    new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out(
        reduce_type, output, input, token, scale, scatter_dim, shard_count,
        replica_groups, pin_layout)
    torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
    return output
  # The coalesced op returns the reduced tensors followed by the new token.
  result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
      reduce_type, input, token, scale, scatter_dim, shard_count,
      replica_groups, pin_layout)
  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
  return result[:-1]
def reduce_scatter_bucketized(
    reduce_type: str,
    input_list: Union[torch.Tensor, List[torch.Tensor]],
    scale: float,
    scatter_dim: int,
    shard_count: int,
    groups: Optional[List[List[int]]] = None,
    output: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    pin_layout: bool = False,
    bucket_cap_mb: int = 160,
) -> Union[torch.Tensor, List[torch.Tensor]]:
  """Performs an XLA `ReduceScatter()` operation on a list of tensors (bucketized).

  See: https://www.tensorflow.org/xla/operation_semantics#reducescatter

  The inputs (and optional outputs) are handed to `CoalescingBuckets` together
  with a per-bucket callback that forwards to `reduce_scatter()`. The value
  obtained by calling the resulting buckets object is returned.

  Args:
    reduce_type, scale, scatter_dim, shard_count, groups: Forwarded unchanged
      to `reduce_scatter()` for every bucket; see its documentation.
    input_list: List of input tensors.
    output: Optional list of output tensors.
    pin_layout (bool, optional): Forwarded to `reduce_scatter()`. Note that it
      defaults to False here, whereas `reduce_scatter()` defaults to True.
    bucket_cap_mb: Number of megabytes of the tensor bucket to fill before
      doing reduce-scatter.

  Returns:
    A list of `torch.Tensor` with all the values reduced across replicas. Each
    process gets a shard split along the `scatter_dim`. All other dimensions
    are the same as the input.
  """

  def _reduce_scatter_coalesced(
      _input_list: Union[torch.Tensor, List[torch.Tensor]],
      _output_list: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
  ) -> Union[torch.Tensor, List[torch.Tensor]]:
    # Per-bucket callback: every collective parameter is captured from the
    # enclosing call; only the bucket's inputs/outputs vary.
    return reduce_scatter(
        reduce_type=reduce_type,
        input=_input_list,
        scale=scale,
        scatter_dim=scatter_dim,
        shard_count=shard_count,
        groups=groups,
        output=_output_list,
        pin_layout=pin_layout)

  buckets = CoalescingBuckets(
      _reduce_scatter_coalesced,
      input_list,
      output,
      bucket_cap_mb=bucket_cap_mb)
  # The result of calling the buckets object (not the CoalescingBuckets
  # instance itself) is returned, hence the Union return annotation above.
  return buckets()
def add_step_closure(closure: Callable[..., Any],
args: Tuple[Any, ...] = (),
run_async: bool = False):
"""Adds a closure to the list of the ones to be run at the end of the step.
Many times during model training there is the need to print/report (print to
console, post to tensorboard, etc...) information which require the content of
intermediary tensors to be inspected.
Inspecting different tensors content in different points of the model code
requires many executions and typically causes performance issues.
Adding a step closure will ensure that it will be run after the barrier, when
all the live tensors will be already materialized to device data.
Live tensors which will include the ones captured by the closure arguments.
So using `add_step_closure()` will ensure a single execution will be
performed, even when multiple closures are queued, requiring multiple tensors