diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py
index 776e154b35c..ae997892547 100644
--- a/test/spmd/test_fsdp_v2.py
+++ b/test/spmd/test_fsdp_v2.py
@@ -129,6 +129,16 @@ def test_fsdp_v2_global_mesh_error(self):
     with self.assertRaises(ValueError):
       model = FSDPv2(model)
 
+  def test_fsdp_v2_cpu_model(self):
+    cpu_model = self.SimpleLinear()
+
+    mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
+    xs.set_global_mesh(mesh)
+
+    model = FSDPv2(cpu_model)
+    self.assertEqual(
+        str(list(model._orig_module.parameters())[0].device), "xla:0")
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
index 783c8f56719..461d66b8565 100644
--- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
+++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -8,6 +8,7 @@
 import numpy as np
 
 import torch_xla
+import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as spmd
 from torch_xla.distributed.fsdp.wrap import recursive_wrap
 
@@ -94,7 +95,9 @@ def __init__(
     )
     self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)
 
-    self._orig_module = module
+    # Let's move the module to xla device in case it's not moved
+    # by the caller already.
+    self._orig_module = module.to(xm.xla_device())
     self._mesh = mesh
 
     # Only handle params which are not already sharded. This enables
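
For context, a minimal usage sketch of what this change enables: a module constructed on the CPU can now be handed to FSDPv2 directly, and its parameters end up on the XLA device. The toy `nn.Linear` model and the mesh construction below are illustrative assumptions that mirror the new test, not part of this patch.

```python
# Minimal sketch (assumed setup, mirroring the new test): FSDPv2 now moves a
# CPU-resident module to the XLA device itself, so the caller no longer needs
# an explicit `module.to(xm.xla_device())` before wrapping.
import numpy as np
import torch.nn as nn

import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # enable the SPMD execution mode required by FSDPv2

# Build a (fsdp, tensor) mesh over all devices and register it globally,
# analogous to self._get_mesh(...) and xs.set_global_mesh(mesh) in the test.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'tensor'))
xs.set_global_mesh(mesh)

cpu_model = nn.Linear(128, 64)  # hypothetical toy model, still on the CPU
model = FSDPv2(cpu_model)       # __init__ now moves it to the XLA device

print(list(model.parameters())[0].device)  # expected: something like xla:0
```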