Skip to content

Commit

Permalink
[FSDPv2] Move the module to xla device (#6525)
Browse files Browse the repository at this point in the history
Summary:
This change allows moving module to the xla device during wrapping such that the caller doesn't need to move the module to the xla device.

Test Plan:
python test/spmd/test_fsdp_v2.py -v -k test_fsdp_v2_cpu_model
  • Loading branch information
alanwaketan committed Feb 13, 2024
1 parent 959e478 commit afc1f0e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
10 changes: 10 additions & 0 deletions test/spmd/test_fsdp_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ def test_fsdp_v2_global_mesh_error(self):
with self.assertRaises(ValueError):
model = FSDPv2(model)

def test_fsdp_v2_cpu_model(self):
cpu_model = self.SimpleLinear()

mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
xs.set_global_mesh(mesh)

model = FSDPv2(cpu_model)
self.assertEqual(
str(list(model._orig_module.parameters())[0].device), "xla:0")


if __name__ == '__main__':
test = unittest.main()
Expand Down
5 changes: 4 additions & 1 deletion torch_xla/experimental/spmd_fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as spmd
from torch_xla.distributed.fsdp.wrap import recursive_wrap

Expand Down Expand Up @@ -94,7 +95,9 @@ def __init__(
)
self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)

self._orig_module = module
# Let's move the module to xla device in case it's not moved
# by the caller already.
self._orig_module = module.to(xm.xla_device())
self._mesh = mesh

# Only handle params which are not already sharded. This enables
Expand Down

0 comments on commit afc1f0e

Please sign in to comment.