Skip to content

Commit

Permalink
[FSDPv2] Use the global mesh API (pytorch#6500)
Browse files Browse the repository at this point in the history
Summary:
Make FSDPv2 use the global mesh API so that users can simply set the global mesh instead of passing a mesh explicitly.

Test Plan:
python test/spmd/test_fsdp_v2.py -v -k test_fsdp_v2_global_mesh
  • Loading branch information
alanwaketan authored and amithrm committed Mar 1, 2024
1 parent e3a3060 commit b41a812
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
15 changes: 15 additions & 0 deletions test/spmd/test_fsdp_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ def auto_wrapper_callable(m, *args, **kwargs):
self.assertFalse(isinstance(model.fc1, FSDPv2))
self.assertFalse(isinstance(model.fc2, FSDPv2))

def test_fsdp_v2_global_mesh(self):
  """FSDPv2 uses the global mesh when no mesh argument is passed."""
  model = self.SimpleLinear().to(xm.xla_device())
  mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
  xs.set_global_mesh(mesh)

  model = FSDPv2(model)
  # Identity, not equality: the wrapper must hold the very mesh object that
  # was registered globally.
  self.assertIs(model._mesh, mesh)

def test_fsdp_v2_global_mesh_error(self):
  """FSDPv2 raises ValueError when no mesh is passed and no global mesh is set."""
  model = self.SimpleLinear().to(xm.xla_device())
  xs.set_global_mesh(None)

  # The constructor is expected to raise, so its result is never bound.
  with self.assertRaises(ValueError):
    FSDPv2(model)


if __name__ == '__main__':
test = unittest.main()
Expand Down
8 changes: 7 additions & 1 deletion torch_xla/experimental/spmd_fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class SpmdFullyShardedDataParallel(nn.Module):
def __init__(
self,
module: nn.Module,
mesh: spmd.Mesh,
mesh: Optional[spmd.Mesh] = None,
shard_output: Optional[Callable] = None,
auto_wrap_policy: Optional[Callable] = None,
auto_wrapper_callable: Optional[Callable] = None,
Expand All @@ -65,6 +65,12 @@ def __init__(
"and do not perform the forward pass in other ways apart from the `forward` method. "
"(i.e. you should directly call the FSDP-wrapped module itself in your code, "
"instead of using any of its submodules or its weights).")
if mesh is None:
mesh = spmd.get_global_mesh()
if mesh is None:
raise ValueError(
"No mesh is provided and no global mesh is set. Please provide a mesh."
)
if "fsdp" not in mesh.axis_names:
raise ValueError("The mesh must have an axis named 'fsdp'.")

Expand Down

0 comments on commit b41a812

Please sign in to comment.