Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 9, 2023
1 parent e38c56c commit 5585a72
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/spmd/test_xla_spmd_python_api_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import sys

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.distributed.xla_backend
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
from torch_xla.amp import autocast
Expand Down Expand Up @@ -132,6 +134,18 @@ def test_xla_autocast_api(self):
self.assertTrue(t3.dtype == expected_dtype)


class BasicDistributedTest(test_xla_sharding_base.XlaShardingTest):
@classmethod
def setUpClass(cls):
xr.use_spmd()
return super().setUpClass()

def test_xla_backend(self):
# XLA backend is not supported with SPMD
with self.assertRaises(AssertionError):
dist.init_process_group('xla', init_method='xla://')


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)

0 comments on commit 5585a72

Please sign in to comment.