[FSDPv2] Enable auto-wrapping (#6499)
Summary:
This pull request enables auto-wrapping for FSDPv2, following the original XlaFullyShardedDataParallel design and implementation.

Test Plan:
python test/spmd/test_fsdp_v2.py -v -k test_fsdp_v2_auto_wrap_
alanwaketan authored and bhavya01 committed Apr 22, 2024
1 parent 3ae25e2 commit 7128027
Showing 2 changed files with 90 additions and 5 deletions.
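For illustration (not part of this diff), a minimal usage sketch of the new option, mirroring test_fsdp_v2_auto_wrap_basic below. The two-layer model, the feature sizes, and the (num_devices, 1) mesh layout are placeholders, and xr.use_spmd() is assumed to be enabled as in other SPMD workflows:

import functools

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # FSDPv2 runs on top of the SPMD execution mode.

# A (num_devices, 1) mesh whose 'fsdp' axis spans all devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('fsdp', 'tensor'))

# A stand-in two-layer model; any nn.Module works.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 16),
).to(xm.xla_device())

# Auto-wrap every nn.Linear submodule instead of wrapping each one by hand.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Linear})
model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy)

At construction time, FSDPv2 recursively wraps every submodule matched by the policy before the root module itself is wrapped, as implemented in the diff below.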
36 changes: 36 additions & 0 deletions test/spmd/test_fsdp_v2.py
@@ -1,4 +1,5 @@
import copy
import functools
import unittest
import os
import sys
@@ -8,6 +9,7 @@
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

import test_xla_sharding_base
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
@@ -78,6 +80,40 @@ def test_fsdp_v2_output_correctness(self):
output = model(x)
self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu()))

def test_fsdp_v2_auto_wrap_basic(self):
model = self.SimpleLinear().to(xm.xla_device())
mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={torch.nn.Linear},
)
model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy)

self.assertTrue(isinstance(model.fc1, FSDPv2))
self.assertTrue(isinstance(model.fc2, FSDPv2))

def test_fsdp_v2_auto_wrap_callable(self):
model = self.SimpleLinear().to(xm.xla_device())
mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={torch.nn.Linear},
)

def auto_wrapper_callable(m, *args, **kwargs):
# Does nothing.
return m

model = FSDPv2(
model,
mesh,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable)

# Since the callable is doing nothing, the children should not be wrapped.
self.assertFalse(isinstance(model.fc1, FSDPv2))
self.assertFalse(isinstance(model.fc2, FSDPv2))


if __name__ == '__main__':
test = unittest.main()
59 changes: 54 additions & 5 deletions torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -1,4 +1,4 @@
from typing import (Any, Callable, Optional, Union)
from typing import (Any, Callable, Dict, Optional, Union)
import warnings

import torch
@@ -9,6 +9,7 @@

import torch_xla
import torch_xla.distributed.spmd as spmd
from torch_xla.distributed.fsdp.wrap import recursive_wrap


def _prepare_spmd_partition_spec(param):
@@ -39,10 +40,14 @@ class SpmdFullyShardedDataParallel(nn.Module):
If the output is a tuple, only the first tensor will be sharded.
"""

def __init__(self,
module: nn.Module,
mesh: spmd.Mesh,
shard_output: Optional[Callable] = None):
def __init__(
self,
module: nn.Module,
mesh: spmd.Mesh,
shard_output: Optional[Callable] = None,
auto_wrap_policy: Optional[Callable] = None,
auto_wrapper_callable: Optional[Callable] = None,
):
if isinstance(module, SpmdFullyShardedDataParallel):
raise RuntimeError(
"Cannot wrap a module that is already wrapped with FSDP. For nested FSDP, "
@@ -65,6 +70,24 @@ def __init__(self,

super().__init__()

wrapper_cls = auto_wrapper_callable or SpmdFullyShardedDataParallel
if auto_wrap_policy is not None:
auto_wrap_kwargs = {
"module": module,
"auto_wrap_policy": auto_wrap_policy,
"wrapper_cls": wrapper_cls,
"ignored_modules": [],
"ignored_params": [],
"only_wrap_children": True, # avoid double wrapping the root
}
fsdp_kwargs = dict(
mesh=mesh,
shard_output=shard_output,
# `auto_wrap_policy` doesn't need to be specified in auto-wrapping
          # `auto_wrapper_callable` doesn't need to be specified in auto-wrapping
)
self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)

self._orig_module = module
self._mesh = mesh

@@ -127,3 +150,29 @@ def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
def __getitem__(self, key: int) -> nn.Module:
"""Forward indexing calls in case the module is a nn.Sequential."""
return self.module.__getitem__(key)

def _auto_wrap(
self,
auto_wrap_kwargs: Dict[str, Any],
fsdp_kwargs: Dict[str, Any],
) -> None:
"""
Recursively auto wraps the root module given by the key "module" in
``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and
``fsdp_kwargs``.
    Precondition: ``auto_wrap_kwargs`` contains the arguments expected by
    ``recursive_wrap()``, and ``auto_wrap_policy`` is not ``None``.
``fsdp_kwargs`` contains all FSDP arguments except ``module``.
"""
auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"]
root_module = auto_wrap_kwargs["module"]
assert auto_wrap_policy is not None
# For auto wrapping, submodules should not already be wrapped with FSDP
# since double wrapping is not supported
for module_name, module in root_module.named_modules():
if isinstance(module, SpmdFullyShardedDataParallel):
raise ValueError(
f"Expected {module_name} to NOT be SpmdFullyShardedDataParallel "
"if using an `auto_wrap_policy`")

recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)
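
For illustration (not part of the patch), the `auto_wrapper_callable` hook can substitute a custom wrapper, e.g. to apply gradient checkpointing to each auto-wrapped submodule, following the original XlaFullyShardedDataParallel design referenced in the summary. A sketch, assuming `checkpoint_module` is available from `torch_xla.distributed.fsdp` and reusing the hypothetical `mesh` and `auto_wrap_policy` from the sketch above on a fresh, unwrapped `model`:

from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

def checkpointing_wrapper(m, *args, **kwargs):
  # recursive_wrap invokes this callable with the FSDP kwargs (mesh,
  # shard_output); apply gradient checkpointing before sharding the submodule.
  return FSDPv2(checkpoint_module(m), *args, **kwargs)

model = FSDPv2(
    model,  # a fresh, unwrapped nn.Module on the XLA device
    mesh,
    auto_wrap_policy=auto_wrap_policy,
    auto_wrapper_callable=checkpointing_wrapper)

Because recursive_wrap forwards the FSDP kwargs to the callable, a custom wrapper only needs to pass them through to the FSDPv2 constructor.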
