diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py
index 0057efcf10f..e8f275a4b99 100644
--- a/test/spmd/test_fsdp_v2.py
+++ b/test/spmd/test_fsdp_v2.py
@@ -1,4 +1,5 @@
 import copy
+import functools
 import unittest
 import os
 import sys
@@ -8,6 +9,7 @@
 import torch_xla.runtime as xr
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
+from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 import test_xla_sharding_base
 from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
@@ -78,6 +80,40 @@ def test_fsdp_v2_output_correctness(self):
     output = model(x)
     self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu()))
 
+  def test_fsdp_v2_auto_wrap_basic(self):
+    model = self.SimpleLinear().to(xm.xla_device())
+    mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
+    auto_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls={torch.nn.Linear},
+    )
+    model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy)
+
+    self.assertTrue(isinstance(model.fc1, FSDPv2))
+    self.assertTrue(isinstance(model.fc2, FSDPv2))
+
+  def test_fsdp_v2_auto_wrap_callable(self):
+    model = self.SimpleLinear().to(xm.xla_device())
+    mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
+    auto_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls={torch.nn.Linear},
+    )
+
+    def auto_wrapper_callable(m, *args, **kwargs):
+      # Does nothing.
+      return m
+
+    model = FSDPv2(
+        model,
+        mesh,
+        auto_wrap_policy=auto_wrap_policy,
+        auto_wrapper_callable=auto_wrapper_callable)
+
+    # Since the callable is doing nothing, the children should not be wrapped.
+    self.assertFalse(isinstance(model.fc1, FSDPv2))
+    self.assertFalse(isinstance(model.fc2, FSDPv2))
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
index 67cc58a43a0..3eb4e3fdc4e 100644
--- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
+++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -1,4 +1,4 @@
-from typing import (Any, Callable, Optional, Union)
+from typing import (Any, Callable, Dict, Optional, Union)
 import warnings
 
 import torch
@@ -9,6 +9,7 @@
 
 import torch_xla
 import torch_xla.distributed.spmd as spmd
+from torch_xla.distributed.fsdp.wrap import recursive_wrap
 
 
 def _prepare_spmd_partition_spec(param):
@@ -39,10 +40,14 @@ class SpmdFullyShardedDataParallel(nn.Module):
     If the output is a tuple, only the first tensor will be sharded.
   """
 
-  def __init__(self,
-               module: nn.Module,
-               mesh: spmd.Mesh,
-               shard_output: Optional[Callable] = None):
+  def __init__(
+      self,
+      module: nn.Module,
+      mesh: spmd.Mesh,
+      shard_output: Optional[Callable] = None,
+      auto_wrap_policy: Optional[Callable] = None,
+      auto_wrapper_callable: Optional[Callable] = None,
+  ):
     if isinstance(module, SpmdFullyShardedDataParallel):
       raise RuntimeError(
           "Cannot wrap a module that is already wrapped with FSDP. For nested FSDP, "
@@ -65,6 +70,24 @@ def __init__(self,
 
     super().__init__()
 
+    wrapper_cls = auto_wrapper_callable or SpmdFullyShardedDataParallel
+    if auto_wrap_policy is not None:
+      auto_wrap_kwargs = {
+          "module": module,
+          "auto_wrap_policy": auto_wrap_policy,
+          "wrapper_cls": wrapper_cls,
+          "ignored_modules": [],
+          "ignored_params": [],
+          "only_wrap_children": True,  # avoid double wrapping the root
+      }
+      fsdp_kwargs = dict(
+          mesh=mesh,
+          shard_output=shard_output,
+          # `auto_wrap_policy` doesn't need to be specified in auto-wrapping
+          # `auto_wrapper_callable` doesn't need to be specified in auto-wrapping
+      )
+      self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)
+
     self._orig_module = module
     self._mesh = mesh
 
@@ -127,3 +150,29 @@ def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
   def __getitem__(self, key: int) -> nn.Module:
     """Forward indexing calls in case the module is a nn.Sequential."""
     return self.module.__getitem__(key)
+
+  def _auto_wrap(
+      self,
+      auto_wrap_kwargs: Dict[str, Any],
+      fsdp_kwargs: Dict[str, Any],
+  ) -> None:
+    """
+    Recursively auto-wraps the root module given by the key "module" in
+    ``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and
+    ``fsdp_kwargs``.
+    Precondition: ``auto_wrap_kwargs`` contains the arguments expected by
+    ``recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``.
+    ``fsdp_kwargs`` contains all FSDP arguments except ``module``.
+    """
+    auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"]
+    root_module = auto_wrap_kwargs["module"]
+    assert auto_wrap_policy is not None
+    # For auto wrapping, submodules should not already be wrapped with FSDP
+    # since double wrapping is not supported.
+    for module_name, module in root_module.named_modules():
+      if isinstance(module, SpmdFullyShardedDataParallel):
+        raise ValueError(
+            f"Expected {module_name} to NOT be SpmdFullyShardedDataParallel "
+            "if using an `auto_wrap_policy`")
+
+    recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)
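
The new `auto_wrap_policy` / `auto_wrapper_callable` arguments are exercised end to end by the tests above; the following is a minimal standalone sketch of the same pattern, assuming SPMD mode is enabled and substituting a toy `nn.Sequential` and a hand-built mesh for the test harness's `SimpleLinear` and `_get_mesh` helpers:

    import functools

    import numpy as np
    import torch.nn as nn
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.spmd as xs
    import torch_xla.runtime as xr
    from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
        SpmdFullyShardedDataParallel as FSDPv2)

    xr.use_spmd()  # FSDPv2 assumes the SPMD execution mode

    # Toy stand-in for the test harness's SimpleLinear model.
    model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 8)).to(xm.xla_device())

    # A (num_devices, 1) mesh with ('fsdp', 'tensor') axes, mirroring _get_mesh.
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('fsdp', 'tensor'))

    # Wrap each nn.Linear child in its own FSDPv2 instance; the root module is
    # wrapped by the outer FSDPv2 call itself.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={nn.Linear},
    )
    model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy)

Passing `auto_wrapper_callable` as in `test_fsdp_v2_auto_wrap_callable` swaps out the wrapper applied to each matched submodule, since the constructor uses `auto_wrapper_callable or SpmdFullyShardedDataParallel` as the wrapper class.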