diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index f42289f8d26..04e079089f2 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -1089,6 +1089,17 @@ def test_mark_shard_scalar(self):
     with self.assertRaises(AttributeError):
       xt.mesh_shape
 
+  def test_global_mesh(self):
+    # Put back whatever mesh was registered before, so this test does not
+    # leak module-global state into the tests that run after it.
+    self.addCleanup(xs.set_global_mesh, xs.get_global_mesh())
+
+    expected_mesh = self._get_mesh((1, self.n_devices))
+    xs.set_global_mesh(expected_mesh)
+
+    # The getter must hand back the very same object, not an equal copy.
+    self.assertIs(xs.get_global_mesh(), expected_mesh)
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/distributed/spmd/__init__.py b/torch_xla/distributed/spmd/__init__.py
index 7f494b74c9d..53376979d3c 100644
--- a/torch_xla/distributed/spmd/__init__.py
+++ b/torch_xla/distributed/spmd/__init__.py
@@ -1,7 +1,8 @@
 from .xla_sharded_tensor import XLAShard, XLAShardedTensor
 from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec,
                            XLAPatchedLinear, mark_sharding, clear_sharding,
-                           wrap_if_sharded, xla_patched_nn_linear_forward)
+                           wrap_if_sharded, xla_patched_nn_linear_forward,
+                           set_global_mesh, get_global_mesh)
 from .api import xla_distribute_tensor, xla_distribute_module
 
 __all__ = [
@@ -18,4 +19,6 @@
     "xla_distribute_tensor",
     "xla_distribute_module",
     "xla_patched_nn_linear_forward",
+    "set_global_mesh",
+    "get_global_mesh",
 ]
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 6c7061aeb87..4c5d1cde0c8 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -123,6 +123,25 @@ def get_op_sharding(self,
                                      replication_groups, sharding_type)
 
 
+# Module-wide default mesh; stays None until set_global_mesh() is called.
+_GLOBAL_MESH: Mesh = None
+
+
+def set_global_mesh(mesh: Mesh):
+  """Stores `mesh` as the module-wide default mesh.
+
+  Args:
+    mesh: The mesh to store; it replaces any previously stored mesh.
+  """
+  global _GLOBAL_MESH
+  _GLOBAL_MESH = mesh
+
+
+def get_global_mesh():
+  """Returns the mesh last passed to `set_global_mesh`, or None if unset."""
+  return _GLOBAL_MESH
+
+
 # HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4ƒ
 class HybridMesh(Mesh):
   """Creates a hybrid device mesh of devices connected with ICI and DCN networks.