-
Notifications
You must be signed in to change notification settings - Fork 486
/
test_xla_sharding.py
1409 lines (1216 loc) · 57.4 KB
/
test_xla_sharding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import copy
import unittest
from unittest.mock import patch
import math
import numpy as np
import os
import sys
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
import torch_xla.distributed.parallel_loader as pl
import test_xla_sharding_base
import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu
from torch_xla._internal import tpu
class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
@classmethod
def setUpClass(cls):
  """Run the shared class-level setup provided by XlaShardingTest."""
  super(BasicXlaShardingTest, cls).setUpClass()
def test_xla_sharded_tensor(self):
  """mark_sharding on a (1, 8) device tensor returns an XLAShardedTensor."""
  partition_spec = (0, 1)
  xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                     dtype=torch.float,
                     device=xm.xla_device())
  xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)),
                          partition_spec)
  # assertIsInstance reports the actual type on failure; assertTrue would
  # only report "False is not true".
  self.assertIsInstance(xst1, XLAShardedTensor)
def test_xla_sharded_tensor_repr(self):
  """Only the model output computed from the wrapper prints as XLAShardedTensor.

  Feeding the plain tensor `xt` gives an output whose repr does not mention
  XLAShardedTensor. Feeding the wrapper `xst` returned by mark_sharding gives
  one whose repr does.
  """
  xt = torch.randn(128, 128).to(xm.xla_device())
  model = self.SimpleLinear().to(xm.xla_device())
  mesh = self._get_mesh((1, self.n_devices))
  partition_spec = (0, 1)
  xst = xs.mark_sharding(xt, mesh, partition_spec)
  self.assertIsInstance(xst, XLAShardedTensor)
  xt_output = model(xt)
  # assertNotIn/assertIn print the searched string on failure.
  self.assertNotIn('XLAShardedTensor', str(xt_output))
  xst_output = model(xst)
  self.assertIn('XLAShardedTensor', str(xst_output))
def test_sharded_tensor_debug_info(self):
  """The debug string of a sharded tensor exposes its sharding metadata."""
  source = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                        dtype=torch.float,
                        device=xm.xla_device())
  sharded = xs.mark_sharding(source, self._get_mesh((1, self.n_devices)),
                             (0, 1))
  info = torch_xla._XLAC._get_xla_tensor_debug_info(sharded.global_tensor)
  expected_fragments = (
      'XLAShardedData',
      'Data Device: SPMD:0',
      'OpSharding: {',
      'NumShards: %s' % (self.n_devices),
  )
  for fragment in expected_fragments:
    self.assertIn(fragment, info)
def test_xla_shards(self):
  """Tile a 1-D tensor across every device and inspect each CPU-side shard.

  Each shard must live on CPU and hold a contiguous `shard_len` slice of the
  source. Its indices must match that slice (or be Ellipsis), and its
  replica_id must be 0.
  """
  count = self.n_devices
  mesh = self._get_mesh((self.n_devices,))
  source = torch.arange(count, dtype=torch.float32)
  sharded = xs.mark_sharding(source.to(xm.xla_device()), mesh, (0,))
  shards = sharded.local_shards
  self.assertEqual(len(shards), self.n_devices)
  shard_len = math.ceil(count / self.n_devices)
  for position, shard in enumerate(shards):
    self.assertEqual(shard.data.device, torch.device('cpu'))
    self.assertEqual(shard.data.shape, (shard_len,))
    first = position * shard_len
    last = (position + 1) * shard_len
    self.assertTrue(
        torch.allclose(shard.data,
                       torch.arange(first, last, dtype=torch.float32)))
    if not isinstance(shard.indices, list):
      # Non-list indices are expected to be Ellipsis (the whole tensor).
      self.assertIsInstance(shard.indices, type(Ellipsis))
    else:
      self.assertEqual(len(shard.indices), len(source.shape))
      self.assertEqual(shard.indices[0], slice(first, last, 1))
    self.assertTrue(torch.allclose(shard.data, source[shard.indices]))
    # Tiled sharding makes all shards have replica_id 0.
    self.assertEqual(shard.replica_id, 0)
def test_padded_xla_shards(self):
  """Tile a tensor one element longer than the device count.

  Shards are `ceil(len / n_devices)` long. Any tail beyond the real data
  must be zero, and `unpadded_data` must equal the indexed source slice.
  """
  count = self.n_devices + 1  # Ensure padding with two or more devices
  mesh = self._get_mesh((self.n_devices,))
  source = torch.arange(count, dtype=torch.float32)
  sharded = xs.mark_sharding(source.to(xm.xla_device()), mesh, (0,))
  shards = sharded.local_shards
  self.assertEqual(len(shards), self.n_devices)
  shard_len = math.ceil(count / self.n_devices)
  length = source.shape[0]
  for position, shard in enumerate(shards):
    self.assertEqual(shard.data.device, torch.device('cpu'))
    self.assertEqual(shard.data.shape, (shard_len,))
    # Clamp the slice to the tensor length; fully out-of-range shards hold
    # only padding.
    first = min(position * shard_len, length)
    last = min((position + 1) * shard_len, length)
    if first >= count:
      expected = torch.zeros(shard.data.shape, dtype=torch.float32)
    else:
      real = torch.arange(first, last, dtype=torch.float32)
      expected = F.pad(real, (0, shard_len - real.shape[0]), "constant", 0)
    self.assertTrue(torch.allclose(shard.data, expected))
    if not isinstance(shard.indices, list):
      self.assertIsInstance(shard.indices, type(Ellipsis))
    else:
      self.assertEqual(len(shard.indices), len(source.shape))
      self.assertEqual(shard.indices[0], slice(first, last, 1))
    self.assertTrue(torch.allclose(shard.unpadded_data, source[shard.indices]))
    # Tiled sharding makes all shards have replica_id 0.
    self.assertEqual(shard.replica_id, 0)
def test_replicated_xla_shards(self):
  """A fully replicated 1-D tensor gives every device a full CPU copy.

  The replica_id of each shard equals its position in `local_shards`.
  """
  count = self.n_devices
  mesh = self._get_mesh((self.n_devices,))
  source = torch.arange(count, dtype=torch.float32)
  replicated = xs.mark_sharding(source.to(xm.xla_device()), mesh, (None,))
  shards = replicated.local_shards
  self.assertEqual(len(shards), self.n_devices)
  cpu = torch.device('cpu')
  for ordinal, shard in enumerate(shards):
    self.assertEqual(shard.data.device, cpu)
    self.assertEqual(shard.data.shape, (count,))
    self.assertTrue(torch.allclose(shard.data, source))
    self.assertIsInstance(shard.indices, type(Ellipsis))
    self.assertTrue(torch.allclose(shard.data, source[shard.indices]))
    self.assertTrue(torch.allclose(shard.data, shard.unpadded_data))
    # Replicated sharding sets the shard replica_id to the device ordinal
    self.assertEqual(shard.replica_id, ordinal)
@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                     "Multiple devices required for partial replication")
def test_partially_replicated_xla_shards(self):
  """Replicate dim 0 and shard dim 1 two ways on an (n/2, 2) mesh.

  Each shard keeps every row and half of the columns. Its replica_id is the
  device's row in the mesh (ordinal // 2).
  """
  mesh = self._get_mesh((self.n_devices // 2, 2))
  source = torch.arange(256, dtype=torch.float32).reshape((16, 16))
  # Partial replication along the 0th tensor axis, shard 2-way on the 1st
  sharded = xs.mark_sharding(source.to(xm.xla_device()), mesh, (None, 1))
  half = source.shape[1] // 2
  shards = sharded.local_shards
  self.assertEqual(len(shards), self.n_devices)
  for ordinal, shard in enumerate(shards):
    column = ordinal % 2
    first, last = column * half, (column + 1) * half
    self.assertEqual(shard.data.device, torch.device('cpu'))
    self.assertEqual(shard.data.shape, (source.shape[0], half))
    self.assertEqual(len(shard.indices), len(source.shape))
    # Every shard spans the full range of dim 0 ...
    self.assertEqual(shard.indices[0], slice(0, source.shape[0], 1))
    # ... and one half of dim 1.
    self.assertEqual(shard.indices[1], slice(first, last, 1))
    self.assertTrue(torch.allclose(shard.data, source[shard.indices]))
    self.assertTrue(torch.allclose(shard.data, shard.unpadded_data))
    self.assertEqual(shard.replica_id, ordinal // 2)
def test_load_local_shards(self):
  """Round-trip CPU-side edits of local shards back onto the device tensor.

  The test also checks that `load_local_shards_` raises RuntimeError in
  three cases: an incomplete shard list, shards with mismatched shapes, and
  a replicated tensor.
  """
  num_element = self.n_devices
  mesh = self._get_mesh((self.n_devices,))
  t = torch.arange(num_element, dtype=torch.float32) + 1
  xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (0,))
  local_shards = xt.local_shards
  self.assertEqual(len(local_shards), self.n_devices)

  # More than one device is required for sharding to not be REPLICATED
  if self.n_devices > 1:
    for shard in local_shards:
      # Update the shard's data on CPU
      self.assertEqual(shard.data.device, torch.device('cpu'))
      shard.data *= -1
    # Loading a complete list of shards should succeed
    xt.load_local_shards_(local_shards)
    self.assertTrue(torch.allclose(xt.cpu(), -t))

  # Loading an incomplete list of shards should fail
  with self.assertRaises(RuntimeError):
    xt.load_local_shards_(local_shards[:-1])

  # Loading incompatible shapes should fail. Double each dimension
  # explicitly: `2 * shape` on a torch.Size is tuple repetition, which
  # repeats the dims (changing the rank) instead of scaling them.
  for local_shard in local_shards:
    local_shard.data = torch.randn(*[2 * dim for dim in local_shard.data.shape])
  with self.assertRaises(RuntimeError):
    xt.load_local_shards_(local_shards)

  # Replicated shards should fail
  rt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (None,))
  local_shards = rt.local_shards
  with self.assertRaises(RuntimeError):
    rt.load_local_shards_(local_shards)
def test_xla_sharding_type(self):
  """sharding_type reports TILED, REPLICATED or PARTIAL according to the spec."""
  tensor = torch.randn(10, 20).to(xm.xla_device())
  self.assertEqual(torch_xla._XLAC._get_xla_sharding_type(tensor), None)
  # Mesh is (2, n/2) with two or more devices, else (1, 1).
  rows = 2 if self.n_devices >= 2 else 1
  mesh = self._get_mesh((rows, self.n_devices // rows))
  kinds = xs.ShardingType

  tiled = xs.mark_sharding(tensor, mesh, (0, 1))
  self.assertEqual(
      tiled.sharding_type,
      kinds.TILED if self.n_devices >= 2 else kinds.REPLICATED)

  xs.clear_sharding(tensor)
  replicated = xs.mark_sharding(tensor, mesh, (None, None))
  self.assertEqual(replicated.sharding_type, kinds.REPLICATED)

  xs.clear_sharding(tensor)
  partial = xs.mark_sharding(tensor, mesh, (None, 1))
  columns = mesh.get_logical_mesh().shape[1]
  self.assertEqual(partial.sharding_type,
                   kinds.PARTIAL if columns > 1 else kinds.REPLICATED)
def test_custom_tile_assignment(self):
  """A mesh over reversed device ids yields a reversed tile assignment."""
  tensor = torch.randn(10, 20).to(device=xm.xla_device())
  mesh = self._get_mesh((1, self.n_devices), np.flip(self.device_ids))
  xs.mark_sharding(tensor, mesh, (0, 1))
  if self.n_devices <= 1:
    return
  reversed_ids = ','.join(
      str(ordinal) for ordinal in range(self.n_devices - 1, -1, -1))
  self.assertEqual('{devices=[1,%d]%s}' % (self.n_devices, reversed_ids),
                   torch_xla._XLAC._get_xla_sharding_spec(tensor))
def test_mark_sharding_2d(self):
  """Tile a (1, 128) tensor over a (1, n) mesh and check annotation and sum."""
  lhs_cpu = torch.randn(1, 128, device='cpu')
  rhs_cpu = torch.randn(1, 128, device='cpu')
  reference = lhs_cpu + rhs_cpu
  lhs = lhs_cpu.to(xm.xla_device())
  rhs = rhs_cpu.to(xm.xla_device())
  xs.mark_sharding(lhs, self._get_mesh((1, self.n_devices)), (0, 1))
  if self.n_devices > 1:
    ordinals = ','.join(map(str, range(self.n_devices)))
    self.assertEqual('{devices=[1,%d]%s}' % (self.n_devices, ordinals),
                     torch_xla._XLAC._get_xla_sharding_spec(lhs))
  self.assertTrue(torch.allclose(reference, (lhs + rhs).cpu()))
def test_mark_sharding_4d(self):
  """Shard the trailing dims of a 4-D tensor; check annotation and sum."""
  host = torch.randn(2, 4, 8, 16, device='cpu')
  reference = host + host
  dev = host.to(xm.xla_device())
  # Shard along two axes if four or more devices are available
  split = 1
  if self.n_devices >= 4:
    split = 2
  cols = self.n_devices // split
  xs.mark_sharding(dev, self._get_mesh((1, 1, split, cols)), (0, 1, 2, 3))
  if self.n_devices > 1:
    ordinals = ','.join(str(k) for k in range(self.n_devices))
    self.assertEqual('{devices=[1,1,%d,%d]%s}' % (split, cols, ordinals),
                     torch_xla._XLAC._get_xla_sharding_spec(dev))
  self.assertTrue(torch.allclose(reference, (dev + dev).cpu()))
def test_mark_sharding_not_ordered_sharding_spec_2d(self):
  """Spec (1, 0) on a (1, n) mesh shards tensor dim 0 over all n devices."""
  device = xm.xla_device()
  host = torch.randn(8, 16, device='cpu')
  reference = host + host
  # Shard along first dimension
  sharded = xs.mark_sharding(
      host.to(device), self._get_mesh((1, self.n_devices)), (1, 0))
  for shard in sharded.local_shards:
    dims = shard.data.size()
    self.assertEqual(dims[0], 8 / self.n_devices)
    self.assertEqual(dims[1], 16)
  self.assertTrue(torch.allclose(reference, (sharded + sharded).cpu()))
def test_mark_sharding_not_ordered_sharding_spec_3d(self):
  """A permuted spec (1, 2, 0) on a 3-D mesh gives the expected shard sizes."""
  device = xm.xla_device()
  host = torch.randn(4, 8, 16, device='cpu')
  reference = host + host
  z_dim = 2 if self.n_devices >= 4 else 1
  width = self.n_devices // z_dim
  # Expect local shard size to be [4, 8 / (self.n_devices / z_dim), 16 / z_dim]
  sharded = xs.mark_sharding(host.to(device),
                             self._get_mesh((z_dim, 1, width)), (1, 2, 0))
  for shard in sharded.local_shards:
    dims = shard.data.size()
    self.assertEqual(dims[0], 4)
    self.assertEqual(dims[1], 8 / (self.n_devices / z_dim))
    self.assertEqual(dims[2], 16 / z_dim)
  self.assertTrue(torch.allclose(reference, (sharded + sharded).cpu()))
def test_mark_sharding_not_ordered_sharding_spec_4d(self):
  """A permuted spec (3, 1, 2, 0) on a 4-D mesh gives the expected shard sizes."""
  device = xm.xla_device()
  host = torch.randn(32, 4, 8, 16, device='cpu')
  reference = host + host
  z_dim = 2 if self.n_devices >= 4 else 1
  width = self.n_devices // z_dim
  # Expect local shard size to be [32 / (self.n_devices / z_dim), 4, 8 , 16 / z_dim]
  sharded = xs.mark_sharding(host.to(device),
                             self._get_mesh((z_dim, 1, 1, width)),
                             (3, 1, 2, 0))
  expected_dims = (32 / (self.n_devices / z_dim), 4, 8, 16 / z_dim)
  for shard in sharded.local_shards:
    for axis, want in enumerate(expected_dims):
      self.assertEqual(shard.data.size()[axis], want)
  self.assertTrue(torch.allclose(reference, (sharded + sharded).cpu()))
def test_mark_sharding_partial(self):
  """Shard dim 0 of a matrix and replicate it across the remaining devices.

  With >= 4 devices the annotation must carry last_tile_dim_replicate, and
  shards in the same replication group must hold identical data. The matmul
  result must match an unsharded reference.
  """
  device = xm.xla_device()
  a = torch.randn(4, 4).to(device)
  b = torch.randn(4, 4).to(device)
  # The reference is computed on XLA rather than eagerly on CPU; the original
  # author observed the eager CPU result differing from the XLA one.
  expected = a @ b
  # To re-materialize a and b.
  xm.mark_step()
  xm.wait_device_ops()
  expected = expected.cpu()
  # Shard along two axes if four or more devices are available
  z_dim = 2 if self.n_devices >= 4 else 1
  replicas = self.n_devices // z_dim
  mesh = self._get_mesh((z_dim, replicas))
  sharded = xs.mark_sharding(a, mesh, (0, None))
  # partial replication requires >= 4 devices; otherwise, it's replicated.
  if self.n_devices >= 4:
    # `sharded` is split `z_dim`-way and replicated `replicas`-way.
    self.assertIn('last_tile_dim_replicate',
                  torch_xla._XLAC._get_xla_sharding_spec(a))
    self.assertIn('[%d,1,%d]' % (z_dim, replicas),
                  torch_xla._XLAC._get_xla_sharding_spec(a))
  # replicated group should share the same data content.
  if replicas > 1:
    pieces = sharded.local_shards
    self.assertTrue(torch.allclose(pieces[0].data, pieces[1].data))
  self.assertTrue(torch.allclose(expected, (sharded @ b).cpu()))
def test_propagate_replicated_sharding(self):
  """After execution, a matmul of un-annotated inputs reports a replicated spec."""
  device = xm.xla_device()
  lhs = torch.randn(4, 4).to(device)
  rhs = torch.randn(4, 4).to(device)
  product = lhs @ rhs
  # To propagate replicated sharding
  xm.mark_step()
  xm.wait_device_ops()
  spec = torch_xla._XLAC._get_xla_sharding_spec(product)
  self.assertIn("replicated", spec)
def test_mark_sharding_partial_unordered(self):
  """Use a permuted partial spec (1, None, 0) on a 3-D tensor.

  Dim 2 is split `z_dim`-way and the remaining mesh axis is replicated. The
  elementwise sum must match an unsharded reference.
  """
  device = xm.xla_device()
  a = torch.randn(4, 3, 4).to(device)
  b = torch.randn(4, 3, 4).to(device)
  expected = a + b
  # To re-materialize a and b.
  xm.mark_step()
  xm.wait_device_ops()
  expected = expected.cpu()
  # Shard along two axes if four or more devices are available
  z_dim = 2 if self.n_devices >= 4 else 1
  replicas = self.n_devices // z_dim
  mesh = self._get_mesh((z_dim, 1, replicas))
  sharded = xs.mark_sharding(a, mesh, (1, None, 0))
  # partial replication requires >= 4 devices; otherwise, it's replicated.
  if self.n_devices >= 4:
    # `sharded` is split `z_dim`-way and replicated `replicas`-way.
    self.assertIn('last_tile_dim_replicate',
                  torch_xla._XLAC._get_xla_sharding_spec(a))
    self.assertIn('[1,1,%d,%d]' % (z_dim, replicas),
                  torch_xla._XLAC._get_xla_sharding_spec(a))
  # replicated group should share the same data content.
  if replicas > 1:
    pieces = sharded.local_shards
    self.assertTrue(torch.allclose(pieces[0].data, pieces[1].data))
    self.assertEqual(pieces[0].data.shape, (4, 3, 4 // z_dim))
  self.assertTrue(torch.allclose(expected, (sharded + b).cpu()))
@unittest.skipUnless(xr.global_runtime_device_count() > 1,
                     "Multiple devices required for tupled partition spec")
def test_tupled_partition_spec(self):
  """Sharding one dim over the tuple of both mesh axes tiles it on all devices."""
  mesh = self._get_mesh((2, self.n_devices // 2))
  vec = torch.randn(16).to(xm.xla_device())
  xs.mark_sharding(vec, mesh, ((0, 1),))
  all_ids = ','.join(map(str, range(self.n_devices)))
  self.assertEqual(
      torch_xla._XLAC._get_xla_sharding_spec(vec),
      "{devices=[%d]%s}" % (self.n_devices, all_ids))
@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                     "Multiple devices required for tupled partition spec")
def test_named_partial_tupled_partition_spec(self):
  """Tuples of named mesh axes ('r', 'b', 'm') in a spec map to tile layouts."""
  mesh = xs.Mesh(
      range(self.n_devices), (1, 2, self.n_devices // 2), ('r', 'b', 'm'))
  half = self.n_devices // 2

  def spec_of(tensor):
    return torch_xla._XLAC._get_xla_sharding_spec(tensor)

  def joined(order):
    return ','.join(str(x) for x in order)

  # Shard the first dimension on `r` and `b`, replicate the second dimension
  t = torch.randn(16, 16).to(xm.xla_device())
  xs.mark_sharding(t, mesh, (('r', 'b'), None))
  self.assertEqual(
      spec_of(t), "{devices=[2,1,%d]%s last_tile_dim_replicate}" %
      (half, joined(range(self.n_devices))))

  # Replicate the first dimension, shard the second on `b` and `m`
  u = torch.randn(16, 16).to(xm.xla_device())
  xs.mark_sharding(u, mesh, (None, ('b', 'm')))
  self.assertEqual(
      spec_of(u),
      "{devices=[1,%d]%s}" % (self.n_devices, joined(range(self.n_devices))))

  # Replicate the first dimension, shard the second on `r` and `m`
  v = torch.randn(16, 16).to(xm.xla_device())
  xs.mark_sharding(v, mesh, (None, ('r', 'm')))
  r_m_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
  self.assertEqual(
      spec_of(v), "{devices=[1,%d,2]%s last_tile_dim_replicate}" %
      (half, joined(r_m_order)))

  # Replicate the first dimension, shard the second on `m` and `b`
  w = torch.randn(16, 16).to(xm.xla_device())
  xs.mark_sharding(w, mesh, (None, ('m', 'b')))
  m_b_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
  self.assertEqual(
      spec_of(w), "{devices=[1,%d]%s}" % (self.n_devices, joined(m_b_order)))
@unittest.skipUnless(xr.global_runtime_device_count() > 1,
                     'Multiple devices required for tupled partition spec')
def test_multiple_tuples_in_spec(self):
  """Each tensor dim may be sharded over its own tuple of named mesh axes."""
  axis_names = ('a', 'b', 'c', 'd')
  mesh = xs.Mesh(
      range(self.n_devices), (1, 2, self.n_devices // 2, 1), axis_names)
  matrix = torch.randn(2, 2).to(xm.xla_device())
  xs.mark_sharding(matrix, mesh, (('a', 'b'), ('c', 'd')))
  ordinals = ','.join(map(str, range(self.n_devices)))
  self.assertEqual(
      torch_xla._XLAC._get_xla_sharding_spec(matrix),
      "{devices=[2,%d]%s}" % (self.n_devices // 2, ordinals))
@unittest.skipUnless(xr.global_runtime_device_count() > 1,
                     'At least 2 devices needed for 2D mesh')
def test_3d_tensor_2d_mesh(self):
  """A 3-D tensor with an unsharded leading dim fits on a 2-D mesh."""
  mesh = self._get_mesh((2, self.n_devices // 2))
  cube = torch.randn(16, 16, 16).to(xm.xla_device())
  xs.mark_sharding(cube, mesh, (None, 0, 1))
  ordinals = ','.join(map(str, range(self.n_devices)))
  self.assertEqual(
      torch_xla._XLAC._get_xla_sharding_spec(cube),
      '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ordinals))
def test_partial_replication_addmm(self):
  """x @ w + b with partially replicated x and w matches the unsharded result."""
  device = xm.xla_device()
  z_dim = 2 if self.n_devices >= 4 else 1
  mesh = self._get_mesh((z_dim, self.n_devices // z_dim))

  x = torch.randn(16, 128).to(device)
  w = torch.randn(128, 256).to(device)
  bias = torch.randn(16, 256).to(device)
  # The reference is computed on XLA rather than eagerly on CPU; the original
  # author observed the eager CPU result differing from the XLA one.
  reference = x @ w + bias
  xm.mark_step()  # To re-materialize x, w, and bias.
  xm.wait_device_ops()
  reference = reference.cpu()

  xs.mark_sharding(x, mesh, (0, None))
  xs.mark_sharding(w, mesh, (None, 1))
  # Check if the partial replication annotations are passed to the compiler.
  # Note that partial replication requires >= 4 devices; otherwise, it's replicated.
  if self.n_devices >= 4:
    for operand in (x, w):
      self.assertIn('last_tile_dim_replicate',
                    torch_xla._XLAC._get_xla_sharding_spec(operand))
  self.assertTrue(
      torch.allclose(reference, (x @ w + bias).cpu(), atol=1e-5))
def test_clear_sharding(self):
  """clear_sharding removes a previously applied sharding annotation."""
  tensor = torch.randn(2, 4, 8, 16).to(xm.xla_device())
  xs.mark_sharding(tensor, self._get_mesh((1, 1, 1, self.n_devices)),
                   (0, 1, 2, 3))
  spec_of = torch_xla._XLAC._get_xla_sharding_spec
  self.assertTrue(spec_of(tensor))
  xs.clear_sharding(tensor)
  self.assertFalse(spec_of(tensor))
def test_replication_with_no_clear_sharding(self):
  """A tiled annotation may overwrite an existing replicated annotation."""
  xt = torch.randn(2, 4).to(xm.xla_device())
  # replication
  xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (None, None))
  # sharding annotation over an existing replication sharding is permitted.
  xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1))
  if self.n_devices > 1:
    # assertNotIn shows the actual spec string on failure.
    self.assertNotIn("replicated",
                     torch_xla._XLAC._get_xla_sharding_spec(xt))
def test_deep_copy(self):
  """copy.deepcopy preserves a tensor's sharding spec."""
  original = torch.randn(2, 4, 8, 16).to(xm.xla_device())
  xs.mark_sharding(original, self._get_mesh((1, 1, 1, self.n_devices)),
                   (0, 1, 2, 3))
  duplicate = copy.deepcopy(original)
  spec_of = torch_xla._XLAC._get_xla_sharding_spec
  self.assertEqual(spec_of(original), spec_of(duplicate))
def test_clone(self):
  """clone() keeps the source's sharding and gives the clone the same spec."""
  source = torch.randn(2, 4, 8, 16).to(xm.xla_device())
  xs.mark_sharding(source, self._get_mesh((1, 1, 1, self.n_devices)),
                   (0, 1, 2, 3))
  spec_of = torch_xla._XLAC._get_xla_sharding_spec
  before = spec_of(source)
  cloned = source.clone()
  # check the original sharding spec is preserved after clone()
  self.assertEqual(before, spec_of(source))
  # check the cloned sharding spec is the same
  self.assertEqual(spec_of(source), spec_of(cloned))
def test_mark_step_with_sharding(self):
  """The sharding spec of a tensor is unchanged by xm.mark_step()."""
  tensor = torch.ones(2, 2).to(xm.xla_device())
  xs.mark_sharding(tensor, self._get_mesh((1, self.n_devices)), (0, 1))
  spec_before = torch_xla._XLAC._get_xla_sharding_spec(tensor)
  xm.mark_step()  # mark_step should preserve the sharding
  spec_after = torch_xla._XLAC._get_xla_sharding_spec(tensor)
  self.assertEqual(spec_before, spec_after)
def test_execute_replicated_metrics(self):
  """One sharded step records exactly one ExecuteReplicatedTime entry."""
  met.clear_all()
  tensor = torch.ones(2, 2).to(xm.xla_device())
  xs.mark_sharding(tensor, self._get_mesh((1, self.n_devices)), (0, 1))
  tensor += 2
  xm.mark_step()
  xm.wait_device_ops()
  # Element [0] of metric_data is compared to 1; presumably the sample
  # count - confirm against torch_xla.debug.metrics.
  recorded = met.metric_data('ExecuteReplicatedTime')[0]
  self.assertEqual(recorded, 1)
def test_optimizer_step_with_sharding(self):
  """fc1's sharding survives three SGD steps separated by mark_step calls."""
  # Use simple linear model to test model parameter sharding
  model = self.SimpleLinear().to(xm.xla_device())
  xs.mark_sharding(model.fc1.weight, self._get_mesh((1, self.n_devices)),
                   (0, 1))
  sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)
  model.train()
  optimizer = optim.SGD(model.parameters(), lr=0.1)
  data = torch.randn(128, 128).to(xm.xla_device())
  target = torch.zeros(128).to(xm.xla_device())
  loss_fn = nn.CrossEntropyLoss()
  # The loop index is not needed; `_` makes that explicit.
  for _ in range(3):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    xm.mark_step()
    # Sharding is persisted across mark_step calls, and test if the sharded computation
    # can repeat more than once without crashing.
    self.assertEqual(sharding_spec,
                     torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
def test_sharding_propagation(self):
  """Train with only fc1 annotated, then read results back to the host.

  The ReplicateShardedData counter starts unset after clear_all. After
  fc1.weight and the last output are moved to CPU, it must read exactly 2.
  """
  met.clear_all()
  self.assertFalse(met.counter_value("ReplicateShardedData"))
  # Linear model with two linear layers and only one is annotated.
  model = self.SimpleLinear().to(xm.xla_device())
  xs.mark_sharding(model.fc1.weight, self._get_mesh((1, self.n_devices)),
                   (0, 1))
  self.assertTrue(torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
  self.assertFalse(torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight))
  model.train()
  optimizer = optim.SGD(model.parameters(), lr=0.1)
  data = torch.randn(128, 128).to(xm.xla_device())
  target = torch.zeros(128).to(xm.xla_device())
  loss_fn = nn.CrossEntropyLoss()
  # The loop index is not needed; `_` makes that explicit.
  for _ in range(3):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    xm.mark_step()
  # Verify that the fc1 & output are sharded and valid
  model.fc1.weight.to('cpu')
  output.to('cpu')
  self.assertEqual(met.counter_value("ReplicateShardedData"), 2)
def test_inplace_add_with_sharding(self):
  """An in-place add keeps the sharding spec and emits a Sharding custom-call."""
  tensor = torch.ones(2, 2).to(xm.xla_device())
  xs.mark_sharding(tensor, self._get_mesh((1, self.n_devices)), (0, 1))
  spec_before = torch_xla._XLAC._get_xla_sharding_spec(tensor)
  tensor.add_(1)  # inplace update should preserve the sharding
  self.assertEqual(spec_before, torch_xla._XLAC._get_xla_sharding_spec(tensor))
  expected_line = (
      '%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding='
  )
  self.assertIn(expected_line, torch_xla._XLAC._get_xla_tensors_hlo([tensor]))
# avoid calling xr.addressable_device_count here otherwise it will init the test
# in non-spmd mode.
@unittest.skipIf(xr.device_type() == 'CPU',
                 "sharding will be the same for both tensors on single device"
                )
def test_shard_hashing(self):
  """Sharded and unsharded versions of the same computation hash differently."""
  sharded = torch.ones(2, 2).to(xm.xla_device())
  plain = torch.ones(2, 2).to(xm.xla_device())
  # Only `sharded` gets an annotation, so the two graph hashes should diverge.
  xs.mark_sharding(sharded, self._get_mesh((1, self.n_devices)), (0, 1))
  # Adding 0 forces graph compilation, which would surface IR hash collisions.
  self.assertTrue(torch.allclose(sharded + 0, plain + 0))
  sharded_hash = torch_xla._XLAC._get_graph_hash([sharded + 0])
  plain_hash = torch_xla._XLAC._get_graph_hash([plain + 0])
  self.assertNotEqual(sharded_hash, plain_hash)
def test_transfer_sharded_data_to_host(self):
  """Moving a sharded tensor to CPU reassembles its full contents."""
  device_ones = torch.ones(16, 16).to(xm.xla_device())
  xs.mark_sharding(device_ones, self._get_mesh((1, self.n_devices)), (0, 1))
  host_copy = device_ones.cpu()
  self.assertTrue(torch.allclose(host_copy, torch.ones(16, 16)))
def test_send_cpu_data_to_device_with_sharding(self):
  """Pass a ShardingSpec to send_cpu_data_to_device and check the result.

  The transferred byte count must equal the tensor's size. The resulting
  annotation must match what an explicit mark_sharding call produces.
  """
  # Execute pending graph to avoid contaminating metrics
  xm.mark_step(wait=True)
  met.clear_all()
  host = torch.arange(16, dtype=torch.float32).reshape(1, 16)
  mesh = self._get_mesh((1, self.n_devices))
  # Build a ShardingSpec so the tensor is sharded during the host->device send.
  spec = xs.ShardingSpec(mesh, (0, 1))
  self.assertTrue(spec.can_apply(host))
  transferred = xm.send_cpu_data_to_device([host],
                                           xm.xla_device(),
                                           input_sharding=spec)
  self.assertEqual(len(transferred), 1)
  # Element [1] of the OutboundData metric is compared with the raw byte
  # size of the tensor.
  sent_bytes = met.metric_data("OutboundData")[1]
  self.assertEqual(sent_bytes, host.element_size() * host.nelement())
  # Verify the resulting sharding annotation matches an explicit
  # `mark_sharding` call.
  via_send = transferred[0]
  via_mark = host.to(xm.xla_device())
  xs.mark_sharding(via_mark, mesh, (0, 1))
  spec_of = torch_xla._XLAC._get_xla_sharding_spec
  self.assertEqual(spec_of(via_send), spec_of(via_mark))
def test_multiple_operations(self):
  """Two independent sharded additions each match their CPU reference."""
  a_cpu = torch.randn(2, 2)
  b_cpu = torch.randn(2, 2)
  first_reference = a_cpu + b_cpu
  a = a_cpu.to(xm.xla_device())
  b = b_cpu.to(xm.xla_device())
  xs.mark_sharding(a, self._get_mesh((1, self.n_devices)), (0, 1))
  self.assertTrue(torch.allclose(first_reference, (a + b).cpu()))

  c_cpu = torch.randn(2, 2)
  d_cpu = torch.randn(2, 2)
  second_reference = c_cpu + d_cpu
  c = c_cpu.to(xm.xla_device())
  d = d_cpu.to(xm.xla_device())
  for operand in (c, d):
    xs.mark_sharding(operand, self._get_mesh((1, self.n_devices)), (0, 1))
  self.assertTrue(torch.allclose(second_reference, (c + d).cpu()))
def test_no_sharding(self):
  """Elementwise add of two unannotated device tensors gives the expected values."""
  # The unused `partition_spec` local was removed: this test applies no
  # sharding at all.
  t1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                    dtype=torch.float,
                    device=xm.xla_device())
  t2 = torch.tensor([[8, 7, 6, 5, 4, 3, 2, 1]],
                    dtype=torch.float,
                    device=xm.xla_device())
  t3 = t1 + t2
  t3_expected = [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0]
  self.assertEqual(t3.tolist()[0], t3_expected)
def test_xla_sharded_hlo_dump(self):
  """The HLO dump marks the sharded input parameter with its sharding."""
  source = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                        dtype=torch.float,
                        device=xm.xla_device())
  sharded = xs.mark_sharding(source, self._get_mesh((1, self.n_devices)),
                             (0, 1))
  shifted = sharded + 5
  hlo = torch_xla._XLAC._get_xla_tensors_hlo([shifted.global_tensor])
  self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
  if not torch_xla._XLAC._xla_get_auto_sharding():
    return
  # scalar 5 should be implicitly replicated, so the pre-optimization HLO
  # shouldn't mark it with sharding.
  self.assertNotIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)
def test_2d_tensor_3d_mesh(self):
  """Shard a 2-D tensor on a 3-D mesh, leaving one mesh axis for replication.

  The expected annotation depends on the device count. The elementwise sum
  must be unaffected.
  """
  lhs_cpu = torch.randn(16, 16, device='cpu')
  rhs_cpu = torch.randn(16, 16, device='cpu')
  reference = lhs_cpu + rhs_cpu
  lhs = lhs_cpu.to(xm.xla_device())
  rhs = rhs_cpu.to(xm.xla_device())
  # Meaningful test for higher-order mesh with extra replication
  # requires multiple devices. Otherwise, this should defaults back to
  # full replication.
  if self.n_devices >= 4:
    mesh_shape = (2, self.n_devices // 2, 1)
    annotation = 'sharding={devices=[1,%d,2]' % (self.n_devices // 2)
  else:
    mesh_shape = (2, 1, 1) if self.n_devices == 2 else (1, 1, 1)
    annotation = "sharding={replicated}"
  xs.mark_sharding(lhs, self._get_mesh(mesh_shape), partition_spec=(2, 1))
  self.assertIn(annotation, torch_xla._XLAC._get_xla_tensors_hlo([lhs]))
  self.assertTrue(torch.allclose(reference, (lhs + rhs).cpu()))
@unittest.skipIf(xr.device_type() == 'TPU' and tpu.version() < 3,
                 "Crash on TPU v2")
@unittest.skipUnless(
    xu.getenv_as(xenv.PJRT_DEVICE, str) == "TPU",
    # Plain string: the f-prefix had no placeholders (ruff F541); the
    # message text is unchanged.
    "Requires PJRT_DEVICE set to `TPU`.")
def test_hybrid_mesh_shape(self):
  """A hybrid mesh and a regular mesh of shape (1, n) have the same logical shape."""
  mesh = self._get_mesh((1, self.n_devices))
  hybrid_mesh = self._get_hybrid_mesh((1, self.n_devices))
  # Check if shape of hybrid mesh matches mesh
  self.assertEqual(mesh.get_logical_mesh().shape,
                   hybrid_mesh.get_logical_mesh().shape)
@unittest.skipIf(xr.device_type() == 'TPU' and tpu.version() < 3,
                 "Crash on TPU v2")
@patch('torch_xla.runtime.global_runtime_device_attributes')
@patch('torch_xla.core.xla_model.xla_device_hw')
def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):
  # `patch` decorators inject mocks bottom-up: the first mock argument is
  # `xla_device_hw`, the second is `global_runtime_device_attributes`.
  # Mock the device attributes of 2 slices of v4-8.
  num_slices = 2
  xla_device_mock.return_value = "TPU"
  # One row per device, in the order the runtime reports them:
  # (slice_index, x coord, y coord, device id).
  layout = (
      (0, 0, 0, 2),
      (0, 1, 0, 1),
      (0, 0, 1, 0),
      (0, 1, 1, 3),
      (1, 0, 0, 4),
      (1, 1, 0, 7),
      (1, 0, 1, 6),
      (1, 1, 1, 5),
  )
  device_attributes_mock.return_value = [{
      'coords': [x, y, 0],
      'core_on_chip': 0,
      'slice_index': slice_index,
      'name': 'TPU:%d' % device_id,
  } for slice_index, x, y, device_id in layout]
  hybrid_mesh = xs.HybridMesh(
      ici_mesh_shape=(2, 2), dcn_mesh_shape=(num_slices, 1))
  self.assertEqual(hybrid_mesh.get_logical_mesh().tolist(),
                   [[2, 1], [0, 3], [4, 7], [6, 5]])
def test_mark_sharding_ir(self):
  # NOTE: the HLO assertions below pin exact instruction ids, so the order of
  # tensor creation and ops here must not change.
  cpu_lhs = torch.randn(1, 128, device='cpu')
  cpu_rhs = torch.randn(1, 128, device='cpu')
  reference = cpu_lhs + cpu_rhs
  dev_lhs = cpu_lhs.to(xm.xla_device())
  dev_rhs = cpu_rhs.to(xm.xla_device())
  sharded = dev_lhs + dev_rhs
  # Shard the result of a pending op (an IR value rather than device data).
  sharded = xs.mark_sharding(sharded, self._get_mesh((1, self.n_devices)),
                             (0, 1))
  hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([sharded.global_tensor])
  # The annotation is expected to lower to a "Sharding" custom call on the add.
  self.assertIn(
      '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
      hlo_text)
  sharded += 0
  hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([sharded.global_tensor])
  # A follow-up op should consume the custom-call output.
  self.assertIn(
      '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
      hlo_text)
  self.assertTrue(torch.allclose(reference, sharded.cpu()))
def test_sharded_tensor_aliasing(self):
  met.clear_all()
  base = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                      dtype=torch.float,
                      device=xm.xla_device())
  sharded = xs.mark_sharding(base, self._get_mesh((1, self.n_devices)),
                             (0, 1))
  # An in-place update of the sharded tensor, executed by mark_step, is
  # expected to record exactly one input/output alias.
  sharded += 1
  xm.mark_step()
  alias_count = met.metric_data("InputOutputAliasCount")[0]
  self.assertEqual(alias_count, 1)
def test_mark_sharding_ir_with_multiple_output(self):
  source = torch.randn(8, 8).to(xm.xla_device())
  # torch.max along a dim returns two tensors, `values` and `indices`, which
  # are outputs of the same IR node (`MaxInDim`).
  values, indices = torch.max(source, 1)
  xs.mark_sharding(values, self._get_mesh((self.n_devices,)), (0,))
  # Only `values` was marked: it must carry a sharding spec, `indices` not.
  self.assertNotEqual(torch_xla._XLAC._get_xla_sharding_spec(values), '')
  self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(indices), '')
  # The HLO producing `indices` should not attach any sharding to it.
  # NOTE(review): this matches a hard-coded instruction id
  # (%get-tuple-element.25); if HLO numbering shifts, the check passes
  # vacuously.
  self.assertNotIn('convert(s32[8]{0} %get-tuple-element.25), sharding',
                   torch_xla._XLAC._get_xla_tensors_hlo([indices]))
def test_sharded_tensor_to_cpu_int_type(self):
  # Round-trip an integer tensor (torch.arange yields int64) through a
  # sharded XLA tensor and back to the host.
  host = torch.arange(64).reshape(8, 8)
  device_copy = host.clone().to(xm.xla_device())
  sharded = xs.mark_sharding(device_copy,
                             self._get_mesh((self.n_devices, 1)), (0, 1))
  self.assertTrue(torch.allclose(host, sharded.cpu()))
def test_named_partition_spec(self):
  xt1 = torch.arange(64).reshape(8, 8).to(xm.xla_device())
  # Mesh of shape (1, n_devices) with named axes: 'data' has size 1 and
  # 'model' has size n_devices.
  mesh = xs.Mesh(
      list(range(self.n_devices)), (1, self.n_devices), ('data', 'model'))
  # Partition spec by axis name: tensor dim 0 over 'model', dim 1 over 'data'.
  partition_spec = ('model', 'data')
  xs.mark_sharding(xt1, mesh, partition_spec)
  sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt1)
  # assertIn (rather than assertTrue(... in ...)) shows the actual spec on
  # failure.
  if self.n_devices > 1:
    self.assertIn(f"devices=[{self.n_devices},1]", sharding_spec)
  else:
    # With a single device the sharding degenerates to full replication.
    self.assertIn("replicated", sharding_spec)
def test_shard_device_data_ir(self):
  spec_of = torch_xla._XLAC._get_xla_sharding_spec
  xla_x = torch.randn(8, 128, device=xm.xla_device())
  # xla_x is now a device data IR; xla_y is a pending op that reads it.
  xla_y = xla_x * 5
  self.assertEqual(spec_of(xla_x), '')
  # Shard the device data while a consumer of it is still pending.
  xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (1, 0))
  self.assertNotEqual(spec_of(xla_x), '')
  xm.mark_step()
  # The consumer recorded before sharding must still match xla_x * 5.
  self.assertTrue(torch.allclose(xla_y.cpu(), xla_x.cpu() * 5))
def test_shard_device_data_ir_after_mark_step(self):
  xla_x = torch.randn(8, 128, device=xm.xla_device())
  cpu_snapshot = xla_x.cpu()
  # After this mark_step, xla_x is a device data IR without XLAData.
  xm.mark_step()
  mesh = self._get_mesh((1, self.n_devices))
  xs.mark_sharding(xla_x, mesh, (1, 0))
  self.assertNotEqual(torch_xla._XLAC._get_xla_sharding_spec(xla_x), '')
  # Sharding must not change the tensor's contents.
  self.assertTrue(torch.allclose(xla_x.cpu(), cpu_snapshot))
def test_op_sharding_cache(self):
  met.clear_all()
  mesh = self._get_mesh((1, self.n_devices))

  def fresh_tensor():
    # A new (1, n_devices) tensor on the XLA device.
    return torch.randn(1, self.n_devices).to(xm.xla_device())

  first = fresh_tensor()
  xs.mark_sharding(first, mesh, (0, 1))
  self.assertIn("CreateOpSharding", met.counter_names())
  self.assertEqual(met.counter_value("CreateOpSharding"), 1)
  # Reusing the same mesh and partition spec must not create another
  # OpSharding.
  second = fresh_tensor()
  xs.mark_sharding(second, mesh, (0, 1))
  self.assertEqual(met.counter_value("CreateOpSharding"), 1)
  # A different partition spec triggers one more CreateOpSharding.
  third = fresh_tensor()
  xs.mark_sharding(third, mesh, (0, None))
  self.assertEqual(met.counter_value("CreateOpSharding"), 2)
def test_from_cpu_shards_replicated(self):
  from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
  # 1D mesh over all devices; the partition spec (None,) leaves the tensor
  # unpartitioned, so every device gets the same shard.
  op_sharding = self._get_mesh((self.n_devices,)).get_op_sharding((None,))
  shards = [torch.arange(4)] * self.n_devices
  single_shard = shards[0]
  # Without an explicit global shape, the shape of one shard is used.
  self.assertTrue(
      torch.allclose(from_cpu_shards(shards, op_sharding).cpu(), single_shard))
  # An explicit global shape equal to the shard shape is valid.
  self.assertTrue(
      torch.allclose(
          from_cpu_shards(shards, op_sharding, single_shard.shape).cpu(),
          single_shard))
  # Any other global shape must be rejected.
  for bad_shape in ((5,), (3,), (2, 2)):
    with self.assertRaises(RuntimeError):
      from_cpu_shards(shards, op_sharding, torch.Size(bad_shape))
def test_from_cpu_shards_tiled(self):
  from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
  # 1D mesh over all devices; partition spec (0,) tiles tensor dim 0 across it.
  op_sharding = self._get_mesh((self.n_devices,)).get_op_sharding((0,))
  # Shard i holds the single value i, so the assembled tensor is arange(n).
  shards = [torch.LongTensor([i]) for i in range(self.n_devices)]
  assembled = from_cpu_shards(shards, op_sharding)
  self.assertTrue(
      torch.allclose(assembled.cpu(), torch.arange(self.n_devices)))

  # One shard short of the device count.
  with self.assertRaises(RuntimeError):
    from_cpu_shards(shards[:-1], op_sharding)
  # Global shape asks for more values than the shards provide.
  with self.assertRaises(RuntimeError):
    from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,)))
  # Global shape has the wrong rank.
  with self.assertRaises(RuntimeError):
    from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices)))

  # A smaller global shape is valid: only the first value is meaningful and
  # the rest are treated as padding.
  assembled = from_cpu_shards(shards, op_sharding, torch.Size((1,)))
  self.assertTrue(torch.allclose(assembled.cpu(), torch.arange(1)))
def test_from_cpu_shards_2d(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
# Create an appropriate 2D mesh for the number of devices
if self.n_devices >= 4:
mesh_shape = (self.n_devices // 2, 2)
else:
mesh_shape = (1, self.n_devices)
mesh_2d = self._get_mesh(mesh_shape)
# Replicated sharding
shards = [torch.LongTensor([self.n_devices])] * self.n_devices
partition_spec = (None, None)
op_sharding = mesh_2d.get_op_sharding(partition_spec)
global_tensor = from_cpu_shards(shards, op_sharding)
self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))
if self.n_devices > 1:
# Tiled sharding