-
Notifications
You must be signed in to change notification settings - Fork 486
/
test_operations.py
3238 lines (2596 loc) · 103 KB
/
test_operations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Parse local options first, and rewrite the sys.argv[].
# We need to do that before import "common", as otherwise we get an error for
# unrecognized arguments.
import argparse
import os
import sys
import tempfile
# Flags handled locally by this test file. They are parsed here and removed
# from sys.argv so that argument parsers run later do not reject them.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--replicated', action='store_true')
parser.add_argument('--long_test', action='store_true')
parser.add_argument('--max_diff_count', type=int, default=25)
parser.add_argument('--verbosity', type=int, default=0)
# parse_known_args() returns the parsed namespace plus the unrecognized args.
FLAGS, leftovers = parser.parse_known_args()
# Keep the program name and hand back only the arguments we did not consume.
sys.argv = [sys.argv[0]] + leftovers
from absl.testing import absltest, parameterized
# Normal imports section starts here.
import collections
import copy
import itertools
import math
from numbers import Number
from functools import reduce
import numpy
import random
import re
import torch
import torch.autograd as ad
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.testing._internal.common_device_type import dtypes
from torch.testing._internal.common_dtype import (
all_types_and_complex_and,
all_types_and,
)
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.distributed.data_parallel as dp
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
import torch_xla.debug.metrics as met
import torch_xla.debug.model_comparator as mc
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
import torch_xla.test.test_utils as xtu
import torch_xla.utils.dlpack as xdlpack
import torch_xla.utils.utils as xu
import torch_xla.utils.serialization as xser
import torch_xla.core.xla_model as xm
import torch_xla.core.functions as xf
import torch_xla.debug.profiler as xp
import unittest
import test_utils
# Record holding the number of available devices of one kind; built by
# _get_device_support() below.
DeviceSupport = collections.namedtuple('DeviceSupport', ['num_devices'])
def _env_flag_enabled(name, default=False):
  """Interpret environment variable `name` as a boolean flag.

  Args:
    name: name of the environment variable to read.
    default: value returned when the variable is not set at all.

  Returns:
    False when the variable is set to an empty string, "0", "false", "no" or
    "off" (case-insensitive); True for any other value; `default` when unset.
  """
  raw = os.environ.get(name)
  if raw is None:
    return default
  # bool("0") is True in Python, so the textual value has to be inspected;
  # otherwise XLA_DISABLE_FUNCTIONALIZATION=0 would read as "disabled".
  return raw.strip().lower() not in ('', '0', 'false', 'no', 'off')


# True when functionalization has been turned off through the environment.
XLA_DISABLE_FUNCTIONALIZATION = _env_flag_enabled(
    'XLA_DISABLE_FUNCTIONALIZATION')
def _is_on_tpu():
  """Tell whether the suite runs against a TPU.

  True when the XRT_TPU_CONFIG environment variable is present, or when the
  runtime reports 'TPU' as its device type.
  """
  if 'XRT_TPU_CONFIG' in os.environ:
    return True
  return xr.device_type() == 'TPU'
# Decorator that skips a test when _is_on_tpu() was true at import time.
skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU')
# Decorator that skips a test when torch_xla reports eager mode at import time.
skipOnEagerDebug = unittest.skipIf(torch_xla.experimental.is_eager_mode(),
                                   'skip on eager debug mode')
def _skipIfFunctionalization(value=True, reason=""):
  """Build a skip decorator keyed on the functionalization flag.

  Args:
    value: the decorated test is skipped when
      `XLA_DISABLE_FUNCTIONALIZATION is value`.
    reason: optional explanation appended to the skip message.

  Returns:
    The decorator produced by `unittest.skipIf`.
  """
  # The test is skipped when the flag equals `value`, so it only works in the
  # opposite state; the message must describe that opposite state.
  verb = "is not" if value else "is"
  message = f'Works only when functionalization {verb} disabled.'
  if reason:
    # Only add the reason sentence (and its period) when one was given, so
    # the message never ends in a doubled period.
    message += f' Reason: {reason}.'
  return unittest.skipIf(XLA_DISABLE_FUNCTIONALIZATION is value, message)
def skipIfFunctionalizationEnabled(reason):
  """Return a decorator skipping the test unless functionalization is disabled.

  Delegates to _skipIfFunctionalization with value=False, i.e. the test is
  skipped when XLA_DISABLE_FUNCTIONALIZATION is False.
  """
  skip_when_flag_is = False
  return _skipIfFunctionalization(value=skip_when_flag_is, reason=reason)
def skipIfFunctionalizationDisabled(reason):
  """Return a decorator skipping the test when functionalization is disabled.

  Delegates to _skipIfFunctionalization with value=True, i.e. the test is
  skipped when XLA_DISABLE_FUNCTIONALIZATION is True.
  """
  skip_when_flag_is = True
  return _skipIfFunctionalization(value=skip_when_flag_is, reason=reason)
def onlyOnCUDA(fn):
  """Skip `fn` unless the PJRT_DEVICE environment variable equals "cuda".

  The comparison is case-insensitive. An unset PJRT_DEVICE results in the
  test being skipped.

  Args:
    fn: the test function (or class) to decorate.

  Returns:
    `fn` itself when PJRT_DEVICE is CUDA, otherwise a skipped version of it.
  """
  # Default to "" so an unset variable skips the test instead of raising
  # AttributeError (None has no .lower()) while the module is being imported.
  accelerator = os.environ.get("PJRT_DEVICE", "").lower()
  return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)
def onlyIfXLAExperimentalContains(feat):
  """Return a decorator that skips the test unless `feat` is enabled.

  The enabled features are read from the colon-separated XLA_EXPERIMENTAL
  environment variable (treated as empty when unset).
  """
  enabled_features = os.environ.get("XLA_EXPERIMENTAL", "").split(":")
  feature_missing = feat not in enabled_features
  return unittest.skipIf(feature_missing, f"XLA_EXPERIMENTAL={feat} required")
def _gen_tensor(*args, **kwargs):
  """Create a random tensor by forwarding all arguments to torch.randn."""
  tensor = torch.randn(*args, **kwargs)
  return tensor
def _gen_int_tensor(*args, **kwargs):
  """Create a random integer tensor by forwarding to torch.randint."""
  tensor = torch.randint(*args, **kwargs)
  return tensor
def _gen_mask(size):
  """Create a random boolean mask tensor of the given size."""
  low, high = 0, 2  # randint's upper bound is exclusive: values are 0 or 1.
  return torch.randint(low, high, size, dtype=torch.bool)
def _get_device_support(devname):
  """Count the XLA devices named `<devname>:<index>`.

  Args:
    devname: device name prefix; it is used as-is as the start of a regular
      expression matched against each device string.

  Returns:
    A DeviceSupport with the number of matching devices, or None when no
    device matches.
  """
  pattern = devname + r':\d+$'
  matching = sum(1 for device in torch_xla._XLAC._xla_get_devices()
                 if re.match(pattern, device))
  if matching > 0:
    return DeviceSupport(num_devices=matching)
  return None
def _support_replicated(devname, num_devices):
  """Tell whether at least `num_devices` devices of kind `devname` exist."""
  support = _get_device_support(devname)
  if support:
    return support.num_devices >= num_devices
  # No matching device at all.
  return False
def _random_inputs(shapes, num_replicas=1):
  """Generate random input tensors for each replica.

  Args:
    shapes: iterable of shapes; each shape is unpacked into _gen_tensor.
    num_replicas: number of per-replica input tuples to generate.

  Returns:
    A tuple with `num_replicas` entries, each a tuple holding one random
    tensor per entry of `shapes`.
  """
  return tuple(
      tuple(_gen_tensor(*shape)
            for shape in shapes)
      for _ in range(num_replicas))
def _random_like(tensor_list):
  """Build one new tensor per input with matching shape and dtype.

  float32/float64 inputs get random values from _gen_tensor; int64 inputs get
  an uninitialized tensor from torch.empty_like.

  Args:
    tensor_list: iterable of tensors.

  Returns:
    A list of new tensors, in the same order as `tensor_list`.

  Raises:
    RuntimeError: if a tensor has a dtype other than float32, float64, int64.
  """
  random_tensors = []
  for o in tensor_list:
    if o.dtype in (torch.float32, torch.float64):
      random_tensors.append(_gen_tensor(*o.shape, dtype=o.dtype))
    elif o.dtype == torch.int64:
      # TODO remove this, we shouldn't be needing to pass random_tensor for long types
      random_tensors.append(torch.empty_like(o))
    else:
      # Format the dtype into the message; passing it as a second exception
      # argument made the error render as a tuple.
      raise RuntimeError(f'Unsupported type: {o.dtype}')
  return random_tensors
def _zeros_like(tensor_list):
  """Build one zero-filled tensor per input with matching shape and dtype.

  Args:
    tensor_list: iterable of tensors.

  Returns:
    A list of zero tensors, in the same order as `tensor_list`.

  Raises:
    RuntimeError: if a tensor has a dtype other than float32, float64, int64.
  """
  zeros_tensors = []
  for o in tensor_list:
    if o.dtype in (torch.float32, torch.float64):
      zeros_tensors.append(torch.zeros(*o.shape, dtype=o.dtype))
    elif o.dtype == torch.int64:
      # TODO remove this, we shouldn't be needing to pass zeros_tensor for long types
      zeros_tensors.append(torch.zeros_like(o))
    else:
      # Format the dtype into the message; passing it as a second exception
      # argument made the error render as a tuple.
      raise RuntimeError(f'Unsupported type: {o.dtype}')
  return zeros_tensors
def onlyIfTorchSupportsCUDA(fn):
  """Skip `fn` when torch.cuda.is_available() is false."""
  cuda_missing = not torch.cuda.is_available()
  skip_decorator = unittest.skipIf(
      cuda_missing, reason="requires PyTorch CUDA support")
  return skip_decorator(fn)
def onlyIfPJRTDeviceIsCUDA(fn):
  """Skip `fn` unless PJRT_DEVICE is exactly "GPU" or "CUDA"."""
  pjrt_device = os.environ.get("PJRT_DEVICE")
  not_cuda = pjrt_device not in ("GPU", "CUDA")
  skip_decorator = unittest.skipIf(
      not_cuda, reason="requires CUDA as PJRT_DEVICE")
  return skip_decorator(fn)
class TestToXlaTensorArena(test_utils.XlaTestCase):
  """Checks xm.ToXlaTensorArena on tensors nested inside containers."""

  def test(self):
    xla_device = xm.xla_device()

    # Nested payload: tensors appear as dict keys and inside a tuple, a list
    # and a set, next to a plain float and a string key.
    kdata = [_gen_tensor(2, 3), _gen_tensor(3, 4)]
    kdata.append([_gen_tensor(2, 5), _gen_tensor(3, 6)])
    data = {}
    data[_gen_tensor(2, 2)] = tuple(kdata)
    data[_gen_tensor(2, 4)] = set([12.0, _gen_tensor(3, 7)])
    data['ABC'] = _gen_tensor(4, 3)

    def select_fn(v):
      # Exact type match only.
      return type(v) == torch.Tensor

    def convert_fn(tensors):
      targets = [str(xla_device)] * len(tensors)
      return torch_xla._XLAC._xla_tensors_from_aten(tensors, targets)

    def check_fn(v):
      # Recursively verify that every selected tensor became an XLA tensor.
      if select_fn(v):
        return xm.is_xla_tensor(v)
      if isinstance(v, (list, tuple, set)):
        return all(check_fn(item) for item in v)
      if isinstance(v, dict):
        return all(check_fn(key) and check_fn(val) for key, val in v.items())
      return True

    xla_data = xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
    self.assertTrue(check_fn(xla_data))
class TestParallelLoader(test_utils.XlaTestCase):
  """Checks that per-device loaders yield tensors on their own device."""

  def test(self):
    devices = [torch.device(name) for name in xm.get_xla_supported_devices()]
    slope = 3.11
    intercept = 4.09
    batch_size = 128 * len(devices)

    def target_fn(x):
      return x * slope + intercept

    gen = xu.FnDataGenerator(
        target_fn, batch_size, _gen_tensor, dims=[8], count=10)
    para_loader = pl.ParallelLoader(gen, devices)
    for device in devices:
      for data, target in para_loader.per_device_loader(device):
        self.assertEqual(data.device, device)
        self.assertEqual(target.device, device)
class TestAtenTensorTo(test_utils.XlaTestCase):
  """Checks Tensor.to() between CPU and XLA devices and across XLA devices."""

  def test(self):
    devices = xm.get_xla_supported_devices()

    # CPU -> every XLA device.
    for name in reversed(devices):
      target = torch.device(name)
      cpu_tensor = _gen_tensor(8, 12)
      moved = cpu_tensor.to(device=target)
      self.assertEqual(moved.device, target)

    # First XLA device -> every other XLA device.
    t = _gen_tensor(8, 12).to(device=torch.device(devices[0]))
    for name in devices[1:]:
      target = torch.device(name)
      moved = t.to(device=target)
      self.assertEqual(moved.device, target)

    # The same computation on neighbouring devices must agree.
    for dev0, dev1 in zip(devices, devices[1:]):
      t0 = torch.zeros(4, 4, device=torch.device(dev0))
      t1 = t0.to(device=torch.device(dev1))
      t0 = t0 + torch.ones_like(t0, device=torch.device(dev0))
      t1 = t1 + torch.ones_like(t1, device=torch.device(dev1))
      self.assertEqual(t0.cpu(), t1.cpu())
class XlaMNIST(nn.Module):
  """Small MNIST classifier: two conv layers followed by two linear layers.

  forward() returns log-probabilities over 10 classes.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    features = F.relu(F.max_pool2d(self.conv1(x), 2))
    features = F.relu(F.max_pool2d(self.conv2(features), 2))
    # Flatten to (batch, 320) before the fully connected layers.
    flat = features.view(-1, 320)
    hidden = F.relu(self.fc1(flat))
    logits = self.fc2(hidden)
    return F.log_softmax(logits, dim=1)
@unittest.skipIf(
    xr.device_type() == 'CUDA',
    'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.'
)
class TestParallelTensorMNIST(test_utils.XlaTestCase):
  """Runs an XlaMNIST training loop on all devices through dp.DataParallel."""

  def test(self):
    # device_names=['xla:0', 'xla:1', 'xla:2', 'xla:3'] for example.
    device_names = xm.get_xla_supported_devices()
    batch_size = xu.getenv_as('BATCH_SIZE', int, defval=8)
    sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
    images = torch.zeros(batch_size, 1, 28, 28)
    labels = torch.zeros(batch_size, dtype=torch.int64)
    train_loader = xu.SampleGenerator(
        data=(images, labels), sample_count=sample_count * len(device_names))

    def loop_fn(model, loader, device, context):
      criterion = nn.NLLLoss()
      sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

      for data, target in loader:
        with xu.TimedScope(msg='Training loop: ', printfn=None):
          sgd.zero_grad()
          output = xu.timed(lambda: model(data), msg='Model: ', printfn=None)
          loss = xu.timed(
              lambda: criterion(output, target), msg='Loss: ', printfn=None)
          xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
          xu.timed(lambda: xm.optimizer_step(sgd), msg='Step: ', printfn=None)
          self.assertLess(loss.cpu().item(), 3.0)

    model_parallel = dp.DataParallel(XlaMNIST, device_ids=device_names)
    model_parallel(loop_fn, train_loader)
class TestLongGraphChain(test_utils.XlaTestCase):
  """Accumulates 2000 additions on CPU and XLA and compares the results."""

  def test(self):
    device = xm.xla_device()
    cpu_x = torch.Tensor([[1, 2], [3, 4]])
    cpu_y = torch.Tensor([[0.1, 0.2], [0.3, 0.4]])
    xla_x = cpu_x.to(device)
    xla_y = cpu_y.to(device)
    for _ in range(2000):
      cpu_x = cpu_x + 2 * cpu_y
      xla_x = xla_x + 2 * xla_y
    self.assertEqualRel(
        cpu_x,
        xla_x.cpu(),
        rel_err=1e-3,
        abs_err=5,
        max_diff_count=FLAGS.max_diff_count)
class TestSelect(test_utils.XlaTestCase):
  """select() and masked_fill() comparisons between CPU and XLA tensors."""

  def test_get_xla_tensor(self):
    xla_tensor = _gen_tensor(14, 24, 8, device=xm.xla_device())
    cpu_tensor = xla_tensor.data.cpu()
    xla_selected = xla_tensor.select(1, 12)
    cpu_selected = cpu_tensor.select(1, 12)
    self.assertEqual(cpu_selected, xla_selected.data.cpu())

  def test_masked_fill_scalar(self):

    def fn(tensor):
      # The mask comes from the first row and is unsqueezed so that it has
      # the same rank as the original tensor.
      first_row_mask = tensor[0].ge(0.5)
      mask = first_row_mask.unsqueeze(dim=0)
      return tensor.masked_fill(mask, 10)

    xla_tensor = _gen_tensor(2, 2, device=xm.xla_device())
    cpu_tensor = xla_tensor.cpu()
    self.assertEqual(fn(cpu_tensor), fn(xla_tensor))
class TestRandom(test_utils.XlaTestCase):
  """Range checks for Tensor.random_(from, to) on an XLA tensor."""

  def test_random_from_to_bool(self):
    delta = 1
    for from_val, to_val in [[0, 1], [0, 2], [1, 2]]:
      x = _gen_tensor(10, device=xm.xla_device())
      x.random_(from_val, to_val)
      smallest = x.to(torch.int).min()
      self.assertTrue(from_val <= smallest < (from_val + delta))
      largest = x.to(torch.int).max()
      self.assertTrue((to_val - delta) <= largest < to_val)
class TestBinaryCrossEntropyLimitValue(test_utils.XlaTestCase):
  """BCELoss with predictions at and very close to the limits 0 and 1."""

  def test_cross_entropy_loss(self):

    def test_fn(pred, target):
      return nn.BCELoss()(pred, target)

    pred = torch.tensor(1.0)
    target = torch.tensor(1.0)
    for offset in (1, 0, 1e-8, 1e-7):
      shifted_pred = pred - offset
      self.runAtenTest([shifted_pred, target], test_fn)
class TestNllLossLimitValue(test_utils.XlaTestCase):
  """nll_loss with infinite and NaN logits, for every target combination."""

  def test_nll_loss_inf(self):

    def test_fn(logits, target):
      return nn.functional.nll_loss(logits, target)

    inf = float('inf')
    logits = torch.tensor([[1., inf], [-inf, 2.]])
    for labels in ([0, 0], [0, 1], [1, 0], [1, 1]):
      self.runAtenTest([logits, torch.tensor(labels)], test_fn)

  def test_nll_loss_nan(self):

    def test_fn(logits, target):
      return nn.functional.nll_loss(logits, target)

    nan = float('nan')
    logits = torch.tensor([[1, nan], [nan, 2]])
    # The targets cover both the case where the NaN entry is the labeled
    # class and the case where it is not.
    for labels in ([0, 0], [0, 1], [1, 0], [1, 1]):
      self.runAtenTest([logits, torch.tensor(labels)], test_fn)
class TestInterOpSyncTensors(test_utils.XlaTestCase):
  """Chains an op without XLA lowering into masked_select."""

  def test_inter_op_sync(self):

    def test_fn(x):
      # logaddexp can be replaced with any op that does not have
      # xla lowering.
      summed = torch.logaddexp(x, x)
      zero_mask = summed.eq(0)
      return torch.masked_select(summed, zero_mask)

    values = torch.tensor([1., 2., 3.])
    self.runAtenTest([values], test_fn)
class TestDynamicShape(test_utils.XlaTestCase):
  """Dimension-size queries on results whose size depends on the data."""

  def test_nonzero_shape(self):
    x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device())
    nonzero_indices = torch.nonzero(x, as_tuple=False)
    dim0_size = torch_xla._XLAC._get_xla_tensor_dimension_size(
        nonzero_indices, 0)
    self.assertEqual(dim0_size.item(), 4)

  def test_masked_select_shape(self):
    x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device())
    selected = torch.masked_select(x, x.ge(2))
    dim0_size = torch_xla._XLAC._get_xla_tensor_dimension_size(selected, 0)
    self.assertEqual(dim0_size.item(), 3)

  def test_nonzero_cast(self):
    ones = torch.ones(5, 2, device=xm.xla_device())
    # Result of the nonzero should be the index type. Currently
    # index type is s64 on cpu and gpu, but s32 on TPU. We should be
    # able to cast it to any other type without error.
    casted = torch.nonzero(ones.int()).float()
    xm.mark_step()
class TestOptimizationBarrier(test_utils.XlaTestCase):
  """Checks values are unchanged by xm.optimization_barrier_ (TPU only)."""

  def test_optimization_barrier_correctness(self):
    device = xm.xla_device()
    # only test optimization_barrier on TPU
    if xm.xla_device_hw(device) == 'TPU':
      x = torch.randn(5, 5, device=device)
      y = torch.randn(5, 5, device=device)
      total = x + y
      xm.optimization_barrier_([x, y])
      self.assertEqual(total, x + y)
class TestDataType(test_utils.XlaTestCase):
  """Checks a registered XLA op returning a tuple with two element types."""

  def test_mixed_dtype_tuple(self):

    def op_fn(a):
      bf16_copy = a.cast(xb.Type.BF16)
      return xb.Op.tuple((a, bf16_copy))

    op = xor.register('test_mixed_dtype_tuple', op_fn)
    source = torch.randn([2, 3]).to(xm.xla_device())
    unchanged, casted = op(source)
    self.assertEqual(unchanged.dtype, torch.float)
    self.assertEqual(casted.dtype, torch.bfloat16)
class TestAtenXlaTensor(test_utils.XlaTestCase):
def test_get_real_xla_devices(self):
  # Every real device name must look like "<UPPERCASE>:<index>".
  supported = xm.get_xla_supported_devices()
  real_names = torch_xla._XLAC._xla_real_devices(supported)
  # zip() limits the check to the shorter of the two lists.
  for _, real_name in zip(supported, real_names):
    matched = re.fullmatch(r'[A-Z]+:\d+$', real_name)
    self.assertIsNotNone(matched)
def test_negative_slice(self):
  # Indexing the last dimension with -1 must match between CPU and XLA.
  cpu_tensor = _gen_tensor(32, 24, 32)
  xla_tensor = cpu_tensor.to(xm.xla_device())
  expected = cpu_tensor[:, :, -1]
  actual = xla_tensor[:, :, -1]
  self.assertEqual(expected.data, actual.data.cpu())
def test_negative_cat(self):
  # Concatenating along dim=-1 must match between CPU and XLA.
  cpu_tensor = _gen_tensor(2, 5, 3)
  xla_tensor = cpu_tensor.to(xm.xla_device())
  expected = torch.cat([cpu_tensor, cpu_tensor], -1)
  actual = torch.cat([xla_tensor, xla_tensor], -1)
  self.assertEqual(expected.data, actual.data.cpu())
def test_cat_empty_tensor(self):
  # Concatenating with an empty tensor must match between CPU and XLA.
  cpu_tensor = _gen_tensor(2, 5, 3)
  cpu_empty = torch.Tensor()
  xla_tensor = cpu_tensor.to(xm.xla_device())
  xla_empty = cpu_empty.to(xm.xla_device())
  expected = torch.cat([cpu_tensor, cpu_empty], 0)
  actual = torch.cat([xla_tensor, xla_empty], 0)
  self.assertEqual(expected.data, actual.data.cpu())
def test_nan_to_num_in_place(self):
  # In-place nan_to_num_ on a tensor holding NaNs and one finite value.
  nan = float('nan')
  values = torch.tensor([nan, nan, -nan, 3.14])

  def fn(x):
    x.nan_to_num_(1.0, 2.0, 3.0)
    return x

  self.runAtenTest(values, fn)
@skipOnTpu
def test_nan_to_num_in_place_with_inf(self):
  # Since TPU converts double to float (unlike CPU), the Inf entries are
  # expected to be different. Skipping tests for Inf entries.
  inf = float('inf')
  values = torch.tensor([float('nan'), inf, -inf, 3.14])

  def fn(x):
    x.nan_to_num_(1.0, 2.0, 3.0)
    return x

  self.runAtenTest(values, fn)
@skipOnTpu
def test_amp_foreach_non_finite_check_and_unscale_(self):
  # Since TPU converts double to float (unlike CPU), the Inf entries are
  # expected to be different. Skipping tests for Inf entries.
  finite_grads = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
  nan_grads = torch.tensor([1.0, 2.0, float('nan'), 4.0], dtype=torch.float32)
  inv_scale = torch.tensor(0.2, dtype=torch.float32)
  found_inf = torch.tensor(0, dtype=torch.float32)
  expected_unscaled = finite_grads * inv_scale
  expected_flag_clear = torch.tensor(0, dtype=torch.float32)
  expected_flag_set = torch.tensor(1, dtype=torch.float32)

  xla_device = xm.xla_device()
  xla_finite_grads = finite_grads.to(xla_device)
  xla_inv_scale = inv_scale.to(xla_device)
  xla_found_inf = found_inf.to(xla_device)

  # Finite gradients: unscaled in place, flag stays at 0.
  torch._amp_foreach_non_finite_check_and_unscale_([xla_finite_grads],
                                                   xla_found_inf,
                                                   xla_inv_scale)
  self.assertEqual(expected_unscaled, xla_finite_grads, prec=1e-4)
  self.assertEqual(expected_flag_clear, xla_found_inf)

  # Gradients containing NaN: the flag must be set to 1.
  xla_nan_grads = nan_grads.to(xla_device)
  torch._amp_foreach_non_finite_check_and_unscale_([xla_nan_grads],
                                                   xla_found_inf,
                                                   xla_inv_scale)
  self.assertEqual(expected_flag_set, xla_found_inf)
def test_masked_fill_with_tensor(self):
  # Out-of-place masked_fill with a 0-dim tensor as the fill value.
  cpu_input = _gen_tensor(2, 5, 4, 3)
  cpu_mask = _gen_mask(cpu_input.size())
  cpu_value = torch.tensor(42)
  xla_input = cpu_input.to(xm.xla_device())
  xla_mask = cpu_mask.to(xm.xla_device())
  xla_value = cpu_value.to(xm.xla_device())
  expected = torch.masked_fill(cpu_input, cpu_mask, cpu_value)
  actual = torch.masked_fill(xla_input, xla_mask, xla_value)
  # The input itself must be left untouched.
  self.assertEqual(cpu_input.data, xla_input.data.cpu())
  self.assertEqual(expected.data, actual.data.cpu())
def test_masked_fill_in_out_place(self):
  # Exercises the out-of-place and the in-place masked_fill variants together.

  def test_fn(a, b, m):
    out_of_place = torch.masked_fill(a, m, 17.0)
    in_place = b.masked_fill_(m, 21.0)
    return out_of_place, in_place

  first = _gen_tensor(5, 3)
  second = _gen_tensor(*first.size())
  self.runAtenTest([first, second, _gen_mask(first.size())], test_fn)
def test_add_mixed_device(self):
  # tensor + Python scalar must match between CPU and XLA.
  cpu_input = _gen_tensor(3, 800, 1066)
  xla_input = cpu_input.to(xm.xla_device())
  expected = cpu_input + 2
  actual = xla_input + 2
  self.assertEqual(expected.data, actual.data.cpu())
def test_mul_mixed_device(self):
  # tensor * Python scalar must match between CPU and XLA.
  cpu_input = _gen_tensor(3, 800, 1066)
  xla_input = cpu_input.to(xm.xla_device())
  expected = cpu_input * 2
  actual = xla_input * 2
  self.assertEqual(expected.data, actual.data.cpu())
def test_sub_mixed_device(self):
  # tensor - Python scalar must match between CPU and XLA.
  cpu_input = _gen_tensor(3, 800, 1066)
  xla_input = cpu_input.to(xm.xla_device())
  expected = cpu_input - 2
  actual = xla_input - 2
  self.assertEqual(expected.data, actual.data.cpu())
def test_div_mixed_device(self):
  # tensor / Python scalar must match between CPU and XLA.
  cpu_input = _gen_tensor(3, 800, 1066)
  xla_input = cpu_input.to(xm.xla_device())
  expected = cpu_input / 2
  actual = xla_input / 2
  self.assertEqual(expected.data, actual.data.cpu())
def test_rand(self):
  # torch.rand with an XLA device must produce an XLA tensor.
  sample = torch.rand(3, 5, device=xm.xla_device())
  self.assertEqual(sample.device.type, 'xla')
def test_randperm(self):
  # torch.randperm with an XLA device must produce an XLA tensor.
  permutation = torch.randperm(3, device=xm.xla_device(), dtype=torch.int32)
  self.assertEqual(permutation.device.type, 'xla')
def test_randn_like(self):
  # randn_like of an XLA tensor must stay on the XLA device.
  shape = (5, 1, 1)
  template = torch.zeros(shape, device=xm.xla_device())
  sample = torch.randn_like(template)
  self.assertEqual(sample.device.type, 'xla')
def test_rand_like(self):
  # rand_like of an XLA tensor must stay on the XLA device.
  shape = (5, 1, 1)
  template = torch.zeros(shape, device=xm.xla_device())
  sample = torch.rand_like(template)
  self.assertEqual(sample.device.type, 'xla')
def test_randint_like(self):
  # randint_like of a uint8 XLA tensor must stay on the XLA device.
  shape = (5, 1, 1)
  template = torch.zeros(shape, device=xm.xla_device(), dtype=torch.uint8)
  sample = torch.randint_like(template, 6, 10)
  self.assertEqual(sample.device.type, 'xla')
def test_no_storage(self):
  xla_tensor = torch.randn(5, device=xm.xla_device())
  # NOTE(review): assertRaises calls its second argument, and what is passed
  # here is the value of the `device` attribute rather than a method. The
  # check may therefore pass for a reason unrelated to storage; confirm what
  # this test was meant to call.
  self.assertRaises(Exception, xla_tensor.device)
def test_slice_copy(self):
  # copy_ into a sliced region of a larger zeroed tensor.
  src = torch.rand(3, 3, 3)
  xla_src = src.to(xm.xla_device())
  shape = (4, 4, 4)
  dst = src.new(*shape).zero_()
  xla_dst = xla_src.new(*shape).zero_()
  rows, cols, depth = src.shape[0], src.shape[1], src.shape[2]
  dst[:rows, :cols, :depth].copy_(src)
  xla_dst[:rows, :cols, :depth].copy_(xla_src)
  self.assertEqual(dst.data, xla_dst.data.cpu())
def test_slice_assign(self):
  # Assign a scalar to the first plane of a zeroed tensor.
  src = torch.rand(3, 3, 3)
  xla_src = src.to(xm.xla_device())
  shape = (4, 4, 4)
  dst = src.new(*shape).zero_()
  xla_dst = xla_src.new(*shape).zero_()
  dst[0, :, :] = 1
  xla_dst[0, :, :] = 1
  self.assertEqual(dst.data, xla_dst.data.cpu())
def test_slice_stepped_assign(self):
  # Assign to every second column, starting at column 0.
  cpu_tensor = torch.ones((10, 4))
  xla_tensor = cpu_tensor.to(xm.xla_device())
  cpu_tensor[:, 0::2] = 2
  xla_tensor[:, 0::2] = 2
  self.assertEqual(cpu_tensor.data, xla_tensor.data.cpu())
def test_slice_stepped_other_assign(self):
  # Assign to every fourth column, starting at column 1.
  cpu_tensor = torch.ones((10, 4))
  xla_tensor = cpu_tensor.to(xm.xla_device())
  cpu_tensor[:, 1::4] = 2
  xla_tensor[:, 1::4] = 2
  self.assertEqual(cpu_tensor.data, xla_tensor.data.cpu())
def test_ailing_slice(self):
  # Takes a stepped column slice, clamps it, and compares CPU against XLA.
  xla_device = xm.xla_device()
  # Keep the reference tensor on CPU. It used to be moved to the XLA device
  # as well, so the test compared the XLA result with itself.
  a = torch.ones((1000, 324))
  xla_a = a.to(xla_device)
  w = a[:, 2::4]
  # Slice the XLA copy (not the reference tensor).
  xla_w = xla_a[:, 2::4]
  dw = torch.clamp(w, max=3.1)
  xla_dw = torch.clamp(xla_w, max=3.1)
  self.assertEqual(w.data, xla_w.data.cpu())
  # The clamp results were computed but never checked before.
  self.assertEqual(dw.data, xla_dw.data.cpu())
def test_slice_rnd_stepped_assign(self):
  # Stepped slice assignment for every (start, step) pair within the width.
  xla_device = xm.xla_device()
  size = 10
  for start in range(size - 1):
    for step in range(1, size - start):
      cpu_tensor = torch.ones((3, size))
      xla_tensor = cpu_tensor.to(xla_device)
      cpu_tensor[:, start::step] = 2
      xla_tensor[:, start::step] = 2
      self.assertEqual(cpu_tensor.data, xla_tensor.data.cpu())
def test_arange_nan(self):
  # A NaN bound must be rejected, whether it is the end or the start.
  nan = float('nan')
  with self.assertRaisesRegex(RuntimeError, r'unsupported range'):
    torch.arange(-5, nan, device=xm.xla_device())
  with self.assertRaisesRegex(RuntimeError, r'unsupported range'):
    torch.arange(nan, 5, device=xm.xla_device())
def test_empty_advanced_indexing(self):
  # Advanced indexing with an index tensor that has no elements.
  cpu_base = torch.randn(2, 3, 4, 5)
  xla_base = cpu_base.to(device=xm.xla_device())
  expected = cpu_base[:, torch.empty(0, 6, dtype=torch.int64)]
  actual = xla_base[:, torch.empty(0, 6, dtype=torch.int64)]
  self.assertEqual(expected, actual)
@unittest.skip(
    "grad_input produces wrong results after functionalization. pytorch/pytorch#91199"
)
def test_empty_strided(self):
  # Compares the output, the first-order gradients and the second-order
  # gradients of a grouped Conv1d between CPU and XLA.
  xla_device = xm.xla_device()
  m = nn.Conv1d(4, 6, kernel_size=3, groups=2)
  a = torch.rand(2, 4, 6, requires_grad=True)
  # The XLA module is a deep copy, so both modules start from equal weights.
  xla_m = copy.deepcopy(m).to(xla_device)
  # detach() makes the XLA input a leaf tensor before enabling gradients.
  xla_a = a.clone().to(xla_device).detach()
  xla_a.requires_grad = True

  # CPU reference: first-order gradients keep their graph (create_graph=True)
  # so that they can be differentiated again below.
  output = m(a)
  grad_input = torch.autograd.grad(
      output, (a,) + tuple(m.parameters()), output, create_graph=True)
  grad_grad_input = torch.autograd.grad(
      output.sum() + sum(map(lambda x: x.sum(), grad_input)),
      (a, output) + tuple(m.parameters()),
      retain_graph=True)

  # Same computation on the XLA device.
  xla_output = xla_m(xla_a)
  xla_grad_input = torch.autograd.grad(
      xla_output, (xla_a,) + tuple(xla_m.parameters()),
      xla_output,
      create_graph=True)
  xla_grad_grad_input = torch.autograd.grad(
      xla_output.sum() + sum(map(lambda x: x.sum(), xla_grad_input)),
      (xla_a, xla_output) + tuple(xla_m.parameters()),
      retain_graph=True)

  self.assertEqual(output, xla_output, prec=1e-4)
  self.assertEqual(grad_input, xla_grad_input, prec=1e-4)
  self.assertEqual(grad_grad_input, xla_grad_grad_input, prec=1e-4)
def test_clamp(self):
  # clamp with only an upper bound must match between CPU and XLA.
  cpu_tensor = torch.randn(3, 3)
  xla_tensor = cpu_tensor.to(xm.xla_device())
  expected = torch.clamp(cpu_tensor, max=3.4)
  actual = torch.clamp(xla_tensor, max=3.4)
  self.assertEqual(expected.data, actual.data.cpu())
def test_rrelu_module(self):
  # Forward and backward of nn.RReLU on inputs drawn from torch.rand.
  xla_device = xm.xla_device()
  cpu_input = torch.rand(1, 2, 2, requires_grad=True)
  # detach() makes the XLA input a leaf tensor before enabling gradients.
  xla_input = cpu_input.to(xla_device).detach()
  xla_input.requires_grad = True

  module = nn.RReLU()
  xla_module = module.to(xla_device)

  cpu_output = module(cpu_input)
  xla_output = xla_module(xla_input)
  self.assertEqual(cpu_output, xla_output.cpu())

  cpu_output.sum().backward()
  xla_output.sum().backward()
  self.assertEqual(cpu_input.grad, xla_input.grad.cpu())
def test_max_broadcast(self):
  # Element-wise max of operands with shapes (3, 1, 2) and (4, 2).
  xla_device = xm.xla_device()
  lhs = torch.rand(3, 1, 2)
  rhs = torch.rand(4, 2)
  expected = torch.max(lhs, rhs)
  xla_lhs = lhs.to(xla_device)
  xla_rhs = rhs.to(xla_device)
  actual = torch.max(xla_lhs, xla_rhs)
  self.assertEqual(expected.data, actual.data.cpu())
def test_sgn(self):
  # Compares Tensor.sgn() between CPU and XLA, first on a complex tensor and
  # then on a float tensor, both seeded with special values via division by
  # zero.
  xla_device = xm.xla_device()
  t = torch.randn(2, 3, dtype=torch.cfloat)
  # Generate inf+infj
  t[0][0].real.div_(0)
  t[0][0].imag.div_(0)
  # Generate nan+nanj
  t[0][1] = 0
  t[0][1].real.div_(0)
  t[0][1].imag.div_(0)
  # Generate 0+0j
  t[1][0] = 0
  # Generate inf+0j
  t[1][1].real.div_(0)
  t[1][1] = t[1][1].real.abs()
  # Generate -inf+0j
  t[1][2].real.div_(0)
  t[1][2] = t[1][1].real.abs() * -1
  a = t.sgn()
  xla_a = t.to(xla_device).sgn()
  self.assertEqual(a.data, xla_a.data.cpu())

  # Real-valued case, set up along the same lines as the complex one.
  t = torch.randn(2, 3, dtype=torch.float32)
  t[0][0].div_(0)
  t[0][1] = 0
  t[0][1].div_(0)
  t[1][0] = 0
  t[1][2].div_(0)
  # NOTE(review): unlike the complex case, t[1][1] is not divided by zero
  # here, so this stores a finite negative value - confirm this is intended.
  t[1][2] = t[1][1].abs() * -1
  a = t.sgn()
  xla_a = t.to(xla_device).sgn()
  self.assertEqual(a.data, xla_a.data.cpu())
@skipIfFunctionalizationDisabled("view_as_real unsupported")
def test_view_as_real_c64(self):
  xla_device = torch_xla.device()
  source = torch.randn(4, dtype=torch.cfloat, device=xla_device)
  as_real = torch.view_as_real(source)
  self.assertEqual(as_real.dtype, torch.float32)
  # The XLA type of the real view needs to be f32 as well.
  debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(as_real)
  self.assertIn("f32[4,2]", debug_info)
  # The generated HLO needs to have type f32 as well.
  hlo_lines = torch_xla._XLAC._get_xla_tensors_text([as_real]).split('\n')
  self.assertIn("f32[4,2]", hlo_lines[-3])
@skipIfFunctionalizationDisabled("view_as_real unsupported")
def test_view_as_real_c128(self):
  xla_device = torch_xla.device()
  source = torch.randn(4, dtype=torch.cdouble, device=xla_device)
  as_real = torch.view_as_real(source)
  self.assertEqual(as_real.dtype, torch.float64)
  # The XLA type of the real view needs to be f64 as well.
  debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(as_real)
  self.assertIn("f64[4,2]", debug_info)
  # The generated HLO needs to have type f64 as well.
  hlo_lines = torch_xla._XLAC._get_xla_tensors_text([as_real]).split('\n')
  self.assertIn("f64[4,2]", hlo_lines[-3])
@skipIfFunctionalizationDisabled("view_as_real unsupported")
def test_view_as_complex_f32(self):
  xla_device = torch_xla.device()
  source = torch.randn(4, 2, device=xla_device)
  as_complex = torch.view_as_complex(source)
  self.assertEqual(as_complex.dtype, torch.complex64)
  # The XLA type of the complex view needs to be c64 as well.
  debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(as_complex)
  self.assertIn("c64[4]", debug_info)
  # The generated HLO needs to have type c64 as well.
  hlo_lines = torch_xla._XLAC._get_xla_tensors_text([as_complex]).split('\n')
  self.assertIn("c64[4]", hlo_lines[-3])
@skipIfFunctionalizationDisabled("view_as_real unsupported")
def test_view_as_complex_f64(self):
  xla_device = torch_xla.device()
  source = torch.randn(4, 2, dtype=torch.float64, device=xla_device)
  as_complex = torch.view_as_complex(source)
  self.assertEqual(as_complex.dtype, torch.complex128)
  # The XLA type of the complex view needs to be c128 as well.
  debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(as_complex)
  self.assertIn("c128[4]", debug_info)
  # The generated HLO needs to have type c128 as well.
  hlo_lines = torch_xla._XLAC._get_xla_tensors_text([as_complex]).split('\n')
  self.assertIn("c128[4]", hlo_lines[-3])
def test_index_put(self):
  # Boolean-mask assignment: masked entries become 10, the others stay 1.
  xla_device = xm.xla_device()
  ones = torch.tensor([1, 1, 1, 1])
  xla_values = ones.to(xla_device).to(dtype=torch.float32)
  mask = torch.rand(4) > 0.1
  xla_values[mask] = 10
  num_set = mask.sum().item()
  expected_sum = 10.0 * num_set + (4.0 - num_set)
  self.assertEqual(xla_values.sum().item(), expected_sum)
@skipOnTpu
def test_pow_integer_types(self):
  # Every pow variant is run on its own fresh random integer tensor.
  pow_variants = (
      lambda x: torch.pow(x, 2),
      lambda x: torch.pow(2, x),
      lambda x: torch.pow(x, x),
      lambda x: x.pow_(2),
      lambda x: x.pow_(x),
  )
  for pow_fn in pow_variants:
    self.runAtenTest(torch.randint(10, (2, 2)), pow_fn)
@skipOnTpu
def test_matmul_integer_types(self):
  # all variance of matmul: dot/mv/mm/bmm
  shape_pairs = (
      ((2,), (2,)),
      ((3, 4), (4,)),
      ((10, 3, 4), (4,)),
      ((10, 3, 4), (10, 4, 5)),
      ((10, 3, 4), (4, 5)),
  )
  for lhs_shape, rhs_shape in shape_pairs:
    operands = (torch.randint(10, lhs_shape), torch.randint(10, rhs_shape))
    self.runAtenTest(operands, lambda x, y: torch.matmul(x, y))
@skipOnTpu
def test_addmm_integer_types(self):
  # addmm on integer tensors: a (2, 3) addend plus (2, 3) @ (3, 3).
  addend = torch.randint(10, (2, 3))
  mat1 = torch.randint(10, (2, 3))
  mat2 = torch.randint(10, (3, 3))
  self.runAtenTest((addend, mat1, mat2),
                   lambda x, y, z: torch.addmm(x, y, z))
@skipOnTpu
def test_baddmm_integer_types(self):
  # baddbmm on integer tensors with a batch of 10.
  addend = torch.randint(10, (10, 3, 5))
  batch1 = torch.randint(10, (10, 3, 4))
  batch2 = torch.randint(10, (10, 4, 5))
  self.runAtenTest((addend, batch1, batch2),
                   lambda x, y, z: torch.baddbmm(x, y, z))
def test_view_empty(self):
  # These used to throw floating point exception.
  empty = torch.empty(0, device=xm.xla_device())
  for ambiguous_shape in ((-1, 0), (3, 0, -1, 0)):
    with self.assertRaisesRegex(
        RuntimeError, r'unspecified dimension size -1 can be any value'):
      empty.view(*ambiguous_shape)
def test_view_1718(self):
  # Overwrites a slice of a Linear layer's output in place, then compares the
  # loss and the weight gradient between CPU and XLA.

  def test_fn(device):
    # Same seed on every call so both devices start from the same values.
    torch.manual_seed(0)
    linear = nn.Linear(8, 16).to(device=device)
    batch = torch.rand(4, 8).to(device=device)
    activations = linear(batch)
    activations[:, :4] = 0
    loss = activations.sum()
    loss.backward()
    return loss, linear.weight.grad

  cpu_loss, cpu_weight_grad = test_fn('cpu')
  xla_loss, xla_weight_grad = test_fn(xm.xla_device())
  self.assertEqual(cpu_loss, xla_loss)
  self.assertEqual(cpu_weight_grad, xla_weight_grad)
def test_inplace_view_backprop_base(self):
  # Modify a view in place and backprop through the base tensor.
  root = torch.randn(2, 2, device=xm.xla_device(), requires_grad=True)
  base = root.clone()
  first_row = base.narrow(0, 0, 1)
  first_row.mul_(2)
  base.sum().backward()
  self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]])
def test_inplace_view_backprop_view_of_view(self):
  # Modify one view in place and backprop through a second view of the same
  # region of the base.
  root = torch.randn(2, 2, device=xm.xla_device(), requires_grad=True)
  base = root.clone()
  modified_view = base.narrow(0, 0, 1)
  observed_view = base.narrow(0, 0, 1)
  modified_view.mul_(2)
  observed_view.sum().backward()
  self.assertEqual(root.grad.tolist(), [[2, 2], [0, 0]])
def test_inplace_view_of_view(self):
  # modify view-of-view and backprop through base
  root = torch.randn(2, 2, device=xm.xla_device(), requires_grad=True)
  base = root.clone()
  first_row = base.narrow(0, 0, 1)
  last_of_first_row = first_row.narrow(1, 1, 1)
  last_of_first_row.mul_(2)
  base.sum().backward()
  self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]])
def test_inplace_view_multiple_outputs(self):
  # In-place edits of unbind() outputs, and of views of them, must raise.
  root = torch.arange(9., device=xm.xla_device())
  root = root.reshape(3, 3).requires_grad_()
  base = root.clone()
  rows = base.unbind()
  with self.assertRaises(RuntimeError):
    rows[0].mul_(2)
  row_prefix = rows[0].narrow(0, 0, 2)
  with self.assertRaises(RuntimeError):
    row_prefix.mul_(2)
def test_inplace_view_gradcheck(self):
  # gradcheck modifications to views
  base_input = torch.randn(4, 4, requires_grad=True)
  factor = torch.randn(2, 2, requires_grad=True)

  def test_fn(root, b):
    x = root.clone()
    # Scale two 2x2 blocks of rows 1-2 in place: columns 2-3, then 0-1.
    right_block = x.narrow(1, 2, 2).narrow(0, 1, 2)
    right_block.mul_(b)
    left_block = x.narrow(1, 0, 2).narrow(0, 1, 2)
    left_block.mul_(b)
    x.sum().backward()
    return x

  self.runAtenTest((base_input, factor), test_fn)
def test_inplace_view_makes_base_require_grad(self):
  # in-place modification to view makes base require grad
  base_input = torch.randn(4, 4, requires_grad=False)
  factor = torch.randn(4, 2, requires_grad=True)

  def func(root, b):
    x = root.clone()
    self.assertFalse(x.requires_grad)
    right_columns = x.narrow(1, 2, 2)
    right_columns.mul_(b)
    self.assertTrue(x.requires_grad)
    x.sum().backward()
    self.assertTrue(root.grad is None)
    return b.grad

  self.runAtenTest((base_input, factor), func)
def test_inplace_view_backprop_view(self):
  # modify view and backprop through view
  xla_device = xm.xla_device()
  base = torch.tensor([2., 5.], device=xla_device, requires_grad=False)
  factor = torch.tensor([3.], device=xla_device, requires_grad=True)
  last_element = base.narrow(0, 1, 1)
  scaled = last_element.mul_(factor)
  scaled.sum().backward()
  self.assertEqual(factor.grad.tolist(), [5])
  self.assertIsNone(base.grad)
def test_inplace_view_modify_base(self):
# Test that an in-place operation on a base that forced it to require
# grad also forces any previous views to require grad and backprop
# correctly
r = torch.ones(1, requires_grad=True)