Skip to content

Commit

Permalink
[FSDPv2] Support MultiSlice (pytorch#7044)
Browse files Browse the repository at this point in the history
Summary:
This pull request adds the multi-slice support for FSDPv2. Basically, the default setup is to use the dcn axis as the data axis, and it means we only do data parallel over multi-slices. In the future, we could also support FSDP over mutli-slices.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_fsdp_v2.py
  • Loading branch information
alanwaketan authored May 11, 2024
1 parent 40f7e1f commit 6f0b61e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 11 deletions.
70 changes: 62 additions & 8 deletions test/spmd/test_fsdp_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def setUpClass(cls):
def test_fsdp_v2_basic(self):
model = self.SimpleLinear().to(xm.xla_device())
mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
model.fc1 = FSDPv2(model.fc1, mesh)
model.fc2 = FSDPv2(model.fc2, mesh)
model = FSDPv2(model, mesh)
model.fc1 = FSDPv2(model.fc1, mesh=mesh)
model.fc2 = FSDPv2(model.fc2, mesh=mesh)
model = FSDPv2(model, mesh=mesh)

# Make sure all weights are sharded.
if self.n_devices > 1:
Expand Down Expand Up @@ -67,9 +67,9 @@ def test_fsdp_v2_output_correctness(self):

model = copy.deepcopy(model_expected)
mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
model.fc1 = FSDPv2(model.fc1, mesh)
model.fc2 = FSDPv2(model.fc2, mesh)
model = FSDPv2(model, mesh)
model.fc1 = FSDPv2(model.fc1, mesh=mesh)
model.fc2 = FSDPv2(model.fc2, mesh=mesh)
model = FSDPv2(model, mesh=mesh)

x_expected = torch.randn(16, 128).to(xm.xla_device())

Expand All @@ -87,7 +87,7 @@ def test_fsdp_v2_auto_wrap_basic(self):
transformer_auto_wrap_policy,
transformer_layer_cls={torch.nn.Linear},
)
model = FSDPv2(model, mesh, auto_wrap_policy=auto_wrap_policy)
model = FSDPv2(model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)

self.assertTrue(isinstance(model.fc1, FSDPv2))
self.assertTrue(isinstance(model.fc2, FSDPv2))
Expand All @@ -106,7 +106,7 @@ def auto_wrapper_callable(m, *args, **kwargs):

model = FSDPv2(
model,
mesh,
mesh=mesh,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable)

Expand Down Expand Up @@ -139,6 +139,60 @@ def test_fsdp_v2_cpu_model(self):
self.assertEqual(
str(list(model._orig_module.parameters())[0].device), "xla:0")

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_fsdp_v2_multi_slice(self):
model = self.SimpleLinear().to(xm.xla_device())
mesh = self._get_mesh((2, self.n_devices // 2, 1), None,
('data', 'fsdp', 'tensor'))
model = FSDPv2(model, mesh=mesh, extra_data_axis="data")

# Make sure all weights are sharded.
annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}'
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight))

x = torch.randn(16, 128).to(xm.xla_device())
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
output = model(x)
# Make sure output are sharded.
annotation = '{devices=[4,1]0,1,2,3}'
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(x))
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(output))

# Make sure the model can execute without error.
xm.mark_step()
xm.wait_device_ops()

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_fsdp_v2_multi_slice_output_correctness(self):
model_expected = self.SimpleLinear().to(xm.xla_device())

model = copy.deepcopy(model_expected)
mesh = self._get_mesh((2, self.n_devices // 2, 1), None,
('data', 'fsdp', 'tensor'))
model = FSDPv2(model, mesh=mesh, extra_data_axis="data")

x_expected = torch.randn(16, 128).to(xm.xla_device())

x = copy.deepcopy(x_expected)
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))

output_expected = model_expected(x_expected)
output = model(x)
self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu()))

def test_fsdp_v2_multi_slice_error(self):
model = self.SimpleLinear().to(xm.xla_device())
xs.set_global_mesh(
self._get_mesh((2, self.n_devices // 2, 1), None,
('data', 'fsdp', 'tensor')))

with self.assertRaisesRegex(ValueError,
"The provided ddp axis is not in the mesh."):
model = FSDPv2(model, extra_data_axis='ddp')


if __name__ == '__main__':
test = unittest.main()
Expand Down
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ python3 test/spmd/test_xla_distributed_checkpoint.py
python3 test/spmd/test_train_spmd_linear_model.py
python3 test/spmd/test_xla_spmd_python_api_interaction.py
python3 test/spmd/test_xla_auto_sharding.py
python3 test/spmd/test_fsdp_v2.py
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
python3 test/test_autocast.py
Expand Down
14 changes: 11 additions & 3 deletions torch_xla/experimental/spmd_fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch_xla.distributed.fsdp.wrap import recursive_wrap


def _prepare_spmd_partition_spec(param):
def _prepare_spmd_partition_spec(param, extra_data_axis=None):
partition_spec = [None] * len(param.shape)
# Skip scalar tensors and it replicated.
if len(partition_spec) == 0:
Expand All @@ -24,6 +24,8 @@ def _prepare_spmd_partition_spec(param):
# TODO: should we shard on the maximal dim for param? Then we need
# another helper for the output.
partition_spec[0] = "fsdp"
if extra_data_axis:
partition_spec[0] = (extra_data_axis, "fsdp")
return tuple(partition_spec)


Expand All @@ -44,10 +46,12 @@ class SpmdFullyShardedDataParallel(nn.Module):
def __init__(
self,
module: nn.Module,
*,
mesh: Optional[spmd.Mesh] = None,
shard_output: Optional[Callable] = None,
auto_wrap_policy: Optional[Callable] = None,
auto_wrapper_callable: Optional[Callable] = None,
extra_data_axis: Optional[str] = None,
):
if isinstance(module, SpmdFullyShardedDataParallel):
raise RuntimeError(
Expand All @@ -74,6 +78,9 @@ def __init__(
)
if "fsdp" not in mesh.axis_names:
raise ValueError("The mesh must have an axis named 'fsdp'.")
if extra_data_axis and extra_data_axis not in mesh.axis_names:
raise ValueError(
f"The provided {extra_data_axis} axis is not in the mesh.")

super().__init__()

Expand Down Expand Up @@ -130,8 +137,9 @@ def shard_output_impl(output, mesh):
f"The output type is not supported: {type(output)}. Please provide your own shard_output callable."
)

spmd.mark_sharding(real_output, mesh,
_prepare_spmd_partition_spec(real_output))
spmd.mark_sharding(
real_output, mesh,
_prepare_spmd_partition_spec(real_output, extra_data_axis))

shard_output = shard_output_impl

Expand Down

0 comments on commit 6f0b61e

Please sign in to comment.