From b41a8121c0ac38e44deacbea7295336b5e2c1350 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Thu, 8 Feb 2024 01:55:36 -0800 Subject: [PATCH] [FSDPv2] Use the global mesh API (#6500) Summary: Make FSDPv2 to use the global mesh API such that users can just set the global mesh. Test Plan: python test/spmd/test_fsdp_v2.py -v -k test_fsdp_v2_global_mesh --- test/spmd/test_fsdp_v2.py | 15 +++++++++++++++ .../spmd_fully_sharded_data_parallel.py | 8 +++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index e8f275a4b99e..776e154b35c1 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -114,6 +114,21 @@ def auto_wrapper_callable(m, *args, **kwargs): self.assertFalse(isinstance(model.fc1, FSDPv2)) self.assertFalse(isinstance(model.fc2, FSDPv2)) + def test_fsdp_v2_global_mesh(self): + model = self.SimpleLinear().to(xm.xla_device()) + mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) + xs.set_global_mesh(mesh) + + model = FSDPv2(model) + self.assertEqual(id(model._mesh), id(mesh)) + + def test_fsdp_v2_global_mesh_error(self): + model = self.SimpleLinear().to(xm.xla_device()) + xs.set_global_mesh(None) + + with self.assertRaises(ValueError): + model = FSDPv2(model) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index 3eb4e3fdc4eb..783c8f567190 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -43,7 +43,7 @@ class SpmdFullyShardedDataParallel(nn.Module): def __init__( self, module: nn.Module, - mesh: spmd.Mesh, + mesh: Optional[spmd.Mesh] = None, shard_output: Optional[Callable] = None, auto_wrap_policy: Optional[Callable] = None, auto_wrapper_callable: Optional[Callable] = None, @@ -65,6 +65,12 @@ def __init__( "and do not perform the forward pass in other ways apart from the `forward` method. " "(i.e. you should directly call the FSDP-wrapped module itself in your code, " "instead of using any of its submodules or its weights).") + if mesh is None: + mesh = spmd.get_global_mesh() + if mesh is None: + raise ValueError( + "No mesh is provided and no global mesh is set. Please provide a mesh." + ) if "fsdp" not in mesh.axis_names: raise ValueError("The mesh must have an axis named 'fsdp'.")