Skip to content

Commit

Permalink
Enable eager spmd (#7341)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Jul 12, 2024
1 parent b34abad commit c00762d
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 1 deletion.
49 changes: 49 additions & 0 deletions examples/eager/train_decoder_only_eager_spmd_data_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_decoder_only_base import TrainDecoderOnlyBase

import numpy as np

import torch
import torch_xla
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
from torch_xla import runtime as xr

# Enable the SPMD
xr.use_spmd()


# More detailed examaple can be found in https://github.com/pytorch/xla/blob/master/test/spmd/test_train_spmd_imagenet.py
# Check out our user guide in https://github.com/pytorch/xla/blob/master/docs/spmd.md
class TrainDecoderSpmdDDP(TrainDecoderOnlyBase):

def __init__(self):
super().__init__()
# Shard along batch dimension only
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
mesh_shape = (num_devices,)
mesh = xs.Mesh(device_ids, mesh_shape, ('data',))
# scale the batch size with num_devices since there will be only one
# process that handles all runtime devices.
self.batch_size *= num_devices

train_loader = xu.SampleGenerator(
data=(torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64),
torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64)),
sample_count=self.train_dataset_len // self.batch_size)
self.train_device_loader = pl.MpDeviceLoader(
train_loader,
self.device,
# Shard the input's batch dimension along the `data` axis, no sharding along other dimensions
input_sharding=xs.ShardingSpec(mesh, ('data', None)))


if __name__ == '__main__':
torch_xla.experimental.eager_mode(True)
spmd_ddp = TrainDecoderSpmdDDP()
spmd_ddp.start_training()
44 changes: 44 additions & 0 deletions test/eager/test_eager_spmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import unittest
import sys

import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import numpy as np


class Eager(unittest.TestCase):

@classmethod
def setUpClass(cls):
torch_xla.experimental.eager_mode(True)
xr.use_spmd()
cls.n_devices = xr.global_runtime_device_count()
cls.device_ids = np.array(range(cls.n_devices))

def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None):
assert type(mesh_shape) is tuple, 'mesh_shape must be Tuple[int]'
if device_ids is None:
device_ids = self.device_ids
assert len(device_ids) == self.n_devices
return xs.Mesh(device_ids, mesh_shape, axis_names)

def test_eager_spmd_basic(self):
device = torch_xla.device()
mesh = self._get_mesh((self.n_devices,), axis_names=('data',))
torch.manual_seed(100)
linear = torch.nn.Linear(10, 20)
input = torch.randn(8, 10)
input_xla = input.to(device)
res = linear(input)
linear.to(device)
res_xla = linear(input_xla)
self.assertTrue(torch.allclose(res, res_xla.cpu(), rtol=1e-3))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
run_test "$CDIR/eager/test_eager_with_torch_compile.py"
run_test "$CDIR/eager/test_eager_all_reduce_in_place.py"
run_test "$CDIR/eager/test_eager_spmd.py"
}

# All the new xla op tests should go to run_xla_op_tests3
Expand Down
2 changes: 2 additions & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ python3 examples/train_resnet_amp.py

# HACK: don't confuse local `torch_xla` folder with installed package
# Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559
# Egaer tests will take more HBM, only run them on TPU v4 CI
TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; print(torch_xla._internal.tpu.version())")
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
python3 examples/eager/train_decoder_only_eager.py
python3 examples/eager/train_decoder_only_eager_spmd_data_parallel.py
python3 examples/eager/train_decoder_only_eager_with_compile.py
python3 examples/eager/train_decoder_only_eager_multi_process.py
fi
4 changes: 3 additions & 1 deletion torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2742,7 +2742,9 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,

// 2) Aid SPMD.
XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();
if (sharding && sharding->sharding.type() != xla::OpSharding::UNKNOWN) {
// don't propagate sharding in eager mode.
if (!XLAGraphExecutor::Get()->UseEagerMode() && sharding &&
sharding->sharding.type() != xla::OpSharding::UNKNOWN) {
tensor_methods::custom_sharding_(output_tensor,
input_tensor->sharding_spec());
}
Expand Down

0 comments on commit c00762d

Please sign in to comment.