From 7ccbf8f8d18e4159d974d676f45989b246d8d555 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Thu, 8 Feb 2024 00:43:47 +0000 Subject: [PATCH 1/3] [FSDPv2] Enable auto-wrapping Summary: This pull request enables auto-wrapping for FSDPv2 following the original XlaFullyShardedDataParallel design and implementation. Test Plan: python test/spmd/test_fsdp_v2.py -v -k test_fsdp_v2_auto_wrap_ --- test/spmd/test_fsdp_v2.py | 60 ++++++++++++++++++- .../spmd_fully_sharded_data_parallel.py | 52 +++++++++++++++- 2 files changed, 108 insertions(+), 4 deletions(-) diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index 0057efcf10f..6a0ce21b98a 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -1,4 +1,5 @@ import copy +import functools import unittest import os import sys @@ -8,11 +9,11 @@ import torch_xla.runtime as xr import torch_xla.core.xla_model as xm import torch_xla.distributed.spmd as xs +from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy import test_xla_sharding_base from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2 - # TODO(alanwaketan): Add more tests for FSDPv2. class FSDPv2Test(test_xla_sharding_base.XlaShardingTest): @@ -78,6 +79,63 @@ def test_fsdp_v2_output_correctness(self): output = model(x) self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu())) + def test_fsdp_v2_auto_wrap_basic(self): + model = self.SimpleLinear().to(xm.xla_device()) + mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={torch.nn.Linear}, + ) + model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy) + + # Make sure all weights are sharded. + if self.n_devices > 1: + annotation = '{devices=[%d,1]%s}' % (self.n_devices, ','.join( + [str(i) for i in range(self.n_devices)])) + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) + + x = torch.randn(16, 128).to(xm.xla_device()) + xs.mark_sharding(x, mesh, ('fsdp', None)) + output = model(x) + # Make sure output are sharded. + if self.n_devices > 1: + annotation = '{devices=[%d,1]%s}' % (self.n_devices, ','.join( + [str(i) for i in range(self.n_devices)])) + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(output)) + + loss = output.sum() + loss.backward() + + # Make sure optimization barrier is applied. + hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad]) + self.assertIn( + 'opt-barrier.38 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.37', + hlo) + + # Make sure the model can execute without error. + xm.mark_step() + xm.wait_device_ops() + + def test_fsdp_v2_auto_wrap_callable(self): + model = self.SimpleLinear().to(xm.xla_device()) + mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={torch.nn.Linear}, + ) + def auto_wrapper_callable(m, *args, **kwargs): + # Does nothing. + return m + model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy, auto_wrapper_callable=auto_wrapper_callable) + + # Since the callable is doing nothing, the weights should not be sharded. 
+ self.assertEqual("{replicated}", torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) + self.assertEqual("{replicated}", torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index 67cc58a43a0..6d5e1a0d91a 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -1,4 +1,4 @@ -from typing import (Any, Callable, Optional, Union) +from typing import (Any, Callable, Dict, Optional, Union) import warnings import torch @@ -9,7 +9,7 @@ import torch_xla import torch_xla.distributed.spmd as spmd - +from torch_xla.distributed.fsdp.wrap import recursive_wrap def _prepare_spmd_partition_spec(param): partition_spec = [None] * len(param.shape) @@ -42,7 +42,9 @@ class SpmdFullyShardedDataParallel(nn.Module): def __init__(self, module: nn.Module, mesh: spmd.Mesh, - shard_output: Optional[Callable] = None): + shard_output: Optional[Callable] = None, + auto_wrap_policy: Optional[Callable] = None, + auto_wrapper_callable: Optional[Callable] = None,): if isinstance(module, SpmdFullyShardedDataParallel): raise RuntimeError( "Cannot wrap a module that is already wrapped with FSDP. For nested FSDP, " @@ -65,6 +67,24 @@ def __init__(self, super().__init__() + wrapper_cls = auto_wrapper_callable or SpmdFullyShardedDataParallel + if auto_wrap_policy is not None: + auto_wrap_kwargs = { + "module": module, + "auto_wrap_policy": auto_wrap_policy, + "wrapper_cls": wrapper_cls, + "ignored_modules": [], + "ignored_params": [], + "only_wrap_children": True, # avoid double wrapping the root + } + fsdp_kwargs = dict( + mesh=mesh, + shard_output=shard_output, + # `auto_wrap_policy` doesn't need to be specified in auto-wrapping + # `auto_wrapper_callable`` doesn't need to be specified in auto-wrapping + ) + self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs) + self._orig_module = module self._mesh = mesh @@ -127,3 +147,29 @@ def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]: def __getitem__(self, key: int) -> nn.Module: """Forward indexing calls in case the module is a nn.Sequential.""" return self.module.__getitem__(key) + + def _auto_wrap( + self, + auto_wrap_kwargs: Dict[str, Any], + fsdp_kwargs: Dict[str, Any], + ) -> None: + """ + Recursively auto wraps the root module given by the key "module" in + ``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and + ``fsdp_kwargs``. + Precondition: ``auto_wrap_policy`` contains the arguments expected by + ``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``. + ``fsdp_kwargs`` contains all FSDP arguments except ``module``. 
+ """ + auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"] + root_module = auto_wrap_kwargs["module"] + assert auto_wrap_policy is not None + # For auto wrapping, submodules should not already be wrapped with FSDP + # since double wrapping is not supported + for module_name, module in root_module.named_modules(): + if isinstance(module, SpmdFullyShardedDataParallel): + raise ValueError( + f"Expected {module_name} to NOT be SpmdFullyShardedDataParallel " + "if using an `auto_wrap_policy`") + + recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs) From b4eb6c1ab05a203d61a3671b08d05586914385ae Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Thu, 8 Feb 2024 01:07:41 +0000 Subject: [PATCH 2/3] Fix the tests --- test/spmd/test_fsdp_v2.py | 39 +++++---------------------------------- 1 file changed, 5 insertions(+), 34 deletions(-) diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index 6a0ce21b98a..d4488fa8b5a 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -88,37 +88,8 @@ def test_fsdp_v2_auto_wrap_basic(self): ) model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy) - # Make sure all weights are sharded. - if self.n_devices > 1: - annotation = '{devices=[%d,1]%s}' % (self.n_devices, ','.join( - [str(i) for i in range(self.n_devices)])) - self.assertEqual(annotation, - torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) - self.assertEqual(annotation, - torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) - - x = torch.randn(16, 128).to(xm.xla_device()) - xs.mark_sharding(x, mesh, ('fsdp', None)) - output = model(x) - # Make sure output are sharded. - if self.n_devices > 1: - annotation = '{devices=[%d,1]%s}' % (self.n_devices, ','.join( - [str(i) for i in range(self.n_devices)])) - self.assertEqual(annotation, - torch_xla._XLAC._get_xla_sharding_spec(output)) - - loss = output.sum() - loss.backward() - - # Make sure optimization barrier is applied. - hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad]) - self.assertIn( - 'opt-barrier.38 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.37', - hlo) - - # Make sure the model can execute without error. - xm.mark_step() - xm.wait_device_ops() + self.assertTrue(isinstance(model.fc1, FSDPv2)) + self.assertTrue(isinstance(model.fc2, FSDPv2)) def test_fsdp_v2_auto_wrap_callable(self): model = self.SimpleLinear().to(xm.xla_device()) @@ -132,9 +103,9 @@ def auto_wrapper_callable(m, *args, **kwargs): return m model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy, auto_wrapper_callable=auto_wrapper_callable) - # Since the callable is doing nothing, the weights should not be sharded. - self.assertEqual("{replicated}", torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) - self.assertEqual("{replicated}", torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) + # Since the callable is doing nothing, the children should not be wrapped. 
+ self.assertFalse(isinstance(model.fc1, FSDPv2)) + self.assertFalse(isinstance(model.fc2, FSDPv2)) if __name__ == '__main__': From 114d17efa9e61c652a5a96db8b30d95a97919005 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Thu, 8 Feb 2024 01:10:54 +0000 Subject: [PATCH 3/3] Fix the linters --- test/spmd/test_fsdp_v2.py | 9 ++++++++- .../spmd_fully_sharded_data_parallel.py | 15 +++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index d4488fa8b5a..e8f275a4b99 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -14,6 +14,7 @@ import test_xla_sharding_base from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2 + # TODO(alanwaketan): Add more tests for FSDPv2. class FSDPv2Test(test_xla_sharding_base.XlaShardingTest): @@ -98,10 +99,16 @@ def test_fsdp_v2_auto_wrap_callable(self): transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Linear}, ) + def auto_wrapper_callable(m, *args, **kwargs): # Does nothing. return m - model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy, auto_wrapper_callable=auto_wrapper_callable) + + model = FSDPv2( + model, + mesh, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable) # Since the callable is doing nothing, the children should not be wrapped. self.assertFalse(isinstance(model.fc1, FSDPv2)) diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index 6d5e1a0d91a..3eb4e3fdc4e 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -11,6 +11,7 @@ import torch_xla.distributed.spmd as spmd from torch_xla.distributed.fsdp.wrap import recursive_wrap + def _prepare_spmd_partition_spec(param): partition_spec = [None] * len(param.shape) # Skip scalar tensors and it replicated. @@ -39,12 +40,14 @@ class SpmdFullyShardedDataParallel(nn.Module): If the output is a tuple, only the first tensor will be sharded. """ - def __init__(self, - module: nn.Module, - mesh: spmd.Mesh, - shard_output: Optional[Callable] = None, - auto_wrap_policy: Optional[Callable] = None, - auto_wrapper_callable: Optional[Callable] = None,): + def __init__( + self, + module: nn.Module, + mesh: spmd.Mesh, + shard_output: Optional[Callable] = None, + auto_wrap_policy: Optional[Callable] = None, + auto_wrapper_callable: Optional[Callable] = None, + ): if isinstance(module, SpmdFullyShardedDataParallel): raise RuntimeError( "Cannot wrap a module that is already wrapped with FSDP. For nested FSDP, "
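
A minimal usage sketch of the auto-wrapping API introduced in this series. The FSDPv2 constructor arguments (mesh, auto_wrap_policy, auto_wrapper_callable), transformer_auto_wrap_policy, and xs.mark_sharding are taken from the diffs and tests above; the MyDecoderLayer and MyModel classes, the mesh layout, the tensor sizes, and the xr.use_spmd() call are illustrative assumptions, not part of the patches.

import functools

import numpy as np
import torch
import torch.nn as nn

import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # assumed: SPMD execution mode is enabled before any sharding calls


class MyDecoderLayer(nn.Module):
  """Illustrative stand-in for a transformer block."""

  def __init__(self):
    super().__init__()
    self.proj = nn.Linear(128, 128)

  def forward(self, x):
    return torch.relu(self.proj(x))


class MyModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.layers = nn.Sequential(MyDecoderLayer(), MyDecoderLayer())

  def forward(self, x):
    return self.layers(x)


# A (num_devices, 1) mesh with ('fsdp', 'tensor') axes, mirroring the tests.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'tensor'))

# Auto-wrap every MyDecoderLayer with FSDPv2; the root is wrapped explicitly.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={MyDecoderLayer},
)
model = FSDPv2(
    MyModel().to(xm.xla_device()), mesh, auto_wrap_policy=auto_wrap_policy)

# Shard the batch dimension of the input along the 'fsdp' axis, as in the tests.
x = torch.randn(16, 128).to(xm.xla_device())
xs.mark_sharding(x, mesh, ('fsdp', None))
loss = model(x).sum()
loss.backward()
xm.mark_step()

Passing auto_wrapper_callable alongside auto_wrap_policy substitutes a different wrapper for each matched submodule; as the second test shows, a callable that returns the module unchanged leaves the children unwrapped, while omitting it defaults to SpmdFullyShardedDataParallel.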