diff --git a/docs/spmd.md b/docs/spmd.md
index 89d206e933f..c139c613d8d 100644
--- a/docs/spmd.md
+++ b/docs/spmd.md
@@ -492,4 +492,33 @@ generated_table = visualize_sharding(sharding, use_color=False)
 
 You could use these examples on TPU/GPU/CPU single-host and modify it to run on multi-host.
 And you could modify it to sharding-style `tiled`, `partial_replication` and `replicated`.
 
+### Auto-Sharding
+We are introducing a new PyTorch/XLA SPMD feature called `auto-sharding` ([RFC](https://github.com/pytorch/xla/issues/6322)). This is an experimental feature in `r2.3` and `nightly` that supports `XLA:TPU` on a single TPU VM host.
+PyTorch/XLA auto-sharding can be enabled by one of the following:
+- Setting the environment variable `XLA_AUTO_SPMD=1`
+- Calling the SPMD API at the beginning of your code:
+```python
+import torch_xla.runtime as xr
+xr.use_spmd(auto=True)
+```
+- Calling `torch.distributed._tensor.distribute_module` with `auto_policy` and an `xla` device mesh:
+```python
+import torch_xla.runtime as xr
+from torch.distributed._tensor import DeviceMesh, distribute_module
+from torch_xla.distributed.spmd import auto_policy
+
+device_count = xr.global_runtime_device_count()
+device_mesh = DeviceMesh("xla", list(range(device_count)))
+
+# Currently, the model should be loaded onto the xla device via distribute_module.
+model = MyModule()  # nn.Module
+sharded_model = distribute_module(model, device_mesh, auto_policy)
+```
+
+Optionally, one can set the following environment variables to control the behavior of
+the XLA-based auto-sharding pass:
+- `XLA_AUTO_USE_GROUP_SHARDING`: reshard the parameters in a single group, which is more compute-efficient but can increase peak memory. Enabled by default.
+- `XLA_AUTO_SPMD_MESH`: logical mesh shape to be used for auto-sharding. For example,
+`XLA_AUTO_SPMD_MESH=2,2` corresponds to a 2-by-2 mesh with 4 global devices. If unset,
+a default device mesh shape of `num_devices,1` will be used.
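To make the `XLA_AUTO_SPMD_MESH` option above concrete, here is a minimal sketch (illustration only, not part of the patch) that combines it with `xr.use_spmd(auto=True)`. It assumes a single host whose global device count matches the mesh (2x2 = 4 devices); the `nn.Linear` model, tensor shapes, and single training step are arbitrary placeholders, and `CompileWithAutoSharding` is the counter added by this change.

```python
# Sketch: enable auto-sharding and override the logical mesh used by the pass.
# Assumes 4 global devices so that the 2x2 mesh below is valid.
import os
os.environ["XLA_AUTO_SPMD_MESH"] = "2,2"  # optional; defaults to (num_devices, 1)

import torch
from torch import nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr

xr.use_spmd(auto=True)  # enable SPMD with the auto-sharding pass

model = nn.Linear(128, 64).to(xm.xla_device())  # placeholder model
data = torch.randn(32, 128, device=xm.xla_device())
loss = model(data).sum()
loss.backward()
xm.mark_step()  # compilation runs the XLA auto-sharding pass here

# The counter added in this change confirms the pass actually ran.
print(met.counter_value("CompileWithAutoSharding"))
```

The same counter is what the new `test_xla_auto_sharding.py` and `test_dtensor_integration2.py` tests below assert on.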
diff --git a/test/run_tests.sh b/test/run_tests.sh index 206539632d6..ebd6e109db2 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -226,6 +226,8 @@ function run_xla_op_tests3 { run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py" run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py" run_test "$CDIR/spmd/test_dtensor_integration.py" + run_test "$CDIR/spmd/test_dtensor_integration2.py" + run_test "$CDIR/spmd/test_xla_auto_sharding.py" run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_input_output_aliases.py" run_test "$CDIR/test_torch_distributed_xla_backend.py" diff --git a/test/spmd/args_parse.py b/test/spmd/args_parse.py index a1540047545..858f84265a6 100644 --- a/test/spmd/args_parse.py +++ b/test/spmd/args_parse.py @@ -39,6 +39,7 @@ def parse_common_options(datadir=None, parser.add_argument('--async_closures', action='store_true') parser.add_argument('--debug', action='store_true') parser.add_argument('--profile', action='store_true') + parser.add_argument('--auto_spmd', action='store_true') if opts: for name, aopts in opts: parser.add_argument(name, **aopts) diff --git a/test/spmd/test_dtensor_integration.py b/test/spmd/test_dtensor_integration.py index 7e67aaeae59..272b8c7d7c4 100644 --- a/test/spmd/test_dtensor_integration.py +++ b/test/spmd/test_dtensor_integration.py @@ -4,11 +4,13 @@ import torch from torch import nn import torch.optim as optim -from torch.distributed._tensor import DeviceMesh, Shard +from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor, + distribute_module) import torch_xla +import torch_xla.debug.metrics as met import torch_xla.runtime as xr import torch_xla.core.xla_model as xm -from torch_xla.distributed.spmd import xla_distribute_tensor, xla_distribute_module +from torch_xla.distributed.spmd import auto_policy import unittest @@ -19,7 +21,6 @@ class DTensorIntegrationTest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() def test_xla_distribute_tensor(self): @@ -33,8 +34,7 @@ def test_xla_distribute_tensor(self): 3, requires_grad=requires_grad, device=xm.xla_device()) - dist_tensor = xla_distribute_tensor(tensor_to_shard, device_mesh, - shard_spec) + dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor assert type(dist_tensor).__name__ == "XLAShardedTensor" assert len(dist_tensor.sharding_spec) > 0 @@ -47,65 +47,68 @@ def test_xla_distribute_tensor(self): self.assertTrue(dist_tensor.global_tensor.requires_grad) self.assertTrue(dist_tensor.is_leaf) - def test_xla_distribute_module(self): + def test_optimizer_step_with_sharding(self): + # Use simple linear model to test model parameter sharding model = self.SimpleLinear().to(xm.xla_device()) + # Running the same mark_sharding test with xla_distribute_tensor instead device_count = xr.global_runtime_device_count() device_mesh = DeviceMesh("xla", list(range(device_count))) + shard_spec = [Shard(0)] + distribute_tensor(model.fc1.weight, device_mesh, shard_spec) + sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight) - def shard_params(mod_name, mod, mesh): - shard_spec = [Shard(0)] - # annoate fc1 and fc2 - if isinstance(mod, nn.Linear): - for name, param in mod.named_parameters(): - dist_param = xla_distribute_tensor(param, mesh, shard_spec) - - sharded_model = xla_distribute_module(model, device_mesh, shard_params) - self.assertTrue( - 
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != "") - self.assertTrue( - torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != "") - - sharded_model.train() + model.train() optimizer = optim.SGD(model.parameters(), lr=0.1) data = torch.randn(128, 128).to(xm.xla_device()) target = torch.zeros(128).to(xm.xla_device()) loss_fn = nn.CrossEntropyLoss() - for i in range(3): + for _ in range(3): optimizer.zero_grad() output = model(data) loss = loss_fn(output, target) loss.backward() optimizer.step() xm.mark_step() + # Sharding is persisted across mark_step calls, and test if the sharded computation + # can repeat more than once without crashing. + self.assertEqual(sharding_spec, + torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) - def test_optimizer_step_with_sharding(self): - # Use simple linear model to test model parameter sharding + def test_xla_distribute_module(self): model = self.SimpleLinear().to(xm.xla_device()) - # Running the same mark_sharding test with xla_distribute_tensor instead device_count = xr.global_runtime_device_count() device_mesh = DeviceMesh("xla", list(range(device_count))) - shard_spec = [Shard(0)] - xla_distribute_tensor(model.fc1.weight, device_mesh, shard_spec) - sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight) - model.train() - optimizer = optim.SGD(model.parameters(), lr=0.1) + def shard_params(mod_name, mod, mesh): + shard_spec = [Shard(0)] + # annoate fc1 and fc2 + if isinstance(mod, nn.Linear): + for name, param in mod.named_parameters(): + dist_param = distribute_tensor(param, mesh, shard_spec) + + sharded_model = distribute_module(model, device_mesh, shard_params) + self.assertTrue( + torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != "") + self.assertTrue( + torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != "") + + sharded_model.train() + optimizer = optim.SGD(sharded_model.parameters(), lr=0.1) data = torch.randn(128, 128).to(xm.xla_device()) target = torch.zeros(128).to(xm.xla_device()) loss_fn = nn.CrossEntropyLoss() - for i in range(3): + for _ in range(3): optimizer.zero_grad() - output = model(data) + output = sharded_model(data) loss = loss_fn(output, target) loss.backward() optimizer.step() xm.mark_step() - # Sharding is persisted across mark_step calls, and test if the sharded computation - # can repeat more than once without crashing. - self.assertEqual(sharding_spec, - torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) + # Should run with SPMD mode, ExecuteReplicated. + self.assertTrue(met.counter_value("ExecuteReplicated") > 0) + self.assertTrue(met.counter_value("ExecuteComputation") is None) if __name__ == '__main__': diff --git a/test/spmd/test_dtensor_integration2.py b/test/spmd/test_dtensor_integration2.py new file mode 100644 index 00000000000..3955d02b552 --- /dev/null +++ b/test/spmd/test_dtensor_integration2.py @@ -0,0 +1,58 @@ +import os +import sys + +import torch +from torch import nn +import torch.optim as optim +from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor, + distribute_module) +import torch_xla +import torch_xla.debug.metrics as met +import torch_xla.runtime as xr +import torch_xla.core.xla_model as xm +from torch_xla.distributed.spmd import auto_policy + +import unittest + +import test_xla_sharding_base + + +# This integration test passes when run independently. 
+class DTensorIntegrationTest2(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"], + "Auto-sharding currently supports TPU device.") + def test_xla_distribute_module_auto(self): + device_count = xr.global_runtime_device_count() + device_mesh = DeviceMesh("xla", list(range(device_count))) + + # Use torch_xla.distributed.spmd.auto_policy to enable auto-sharding; + # Currently, model should be loaded to xla device via distribute_module. + model = self.SimpleLinear() + sharded_model = distribute_module(model, device_mesh, auto_policy) + sharded_model.train() + self.assertTrue(torch_xla._XLAC._xla_get_auto_sharding()) + + optimizer = optim.SGD(sharded_model.parameters(), lr=0.1) + data = torch.randn(128, 128).to(xm.xla_device()) + target = torch.zeros(128).to(xm.xla_device()) + loss_fn = nn.CrossEntropyLoss() + for _ in range(5): + optimizer.zero_grad() + output = sharded_model(data) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + xm.mark_step() + # Should compile with auto-sharding, we expect up to 3 times + cnt = met.counter_value("CompileWithAutoSharding") + self.assertTrue((cnt is not None) and (cnt <= 3)) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index cc29e98265f..0595f502da0 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -41,7 +41,6 @@ class DynamoSpmdInferenceTest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() def test_dynamo_spmd_basic(self): diff --git a/test/spmd/test_spmd_graph_dump.py b/test/spmd/test_spmd_graph_dump.py index 9d6f697025d..e3cadd7b3ce 100644 --- a/test/spmd/test_spmd_graph_dump.py +++ b/test/spmd/test_spmd_graph_dump.py @@ -18,7 +18,6 @@ class SpmdGraphDumpTest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() def test_dump_with_output_sharding(self): diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py index c7802c64a31..8757a45cc1b 100644 --- a/test/spmd/test_train_spmd_imagenet.py +++ b/test/spmd/test_train_spmd_imagenet.py @@ -86,6 +86,8 @@ import torch_xla.test.test_utils as test_utils import torch_xla.distributed.spmd as xs +xr.use_spmd(auto=FLAGS.auto_spmd) + DEFAULT_KWARGS = dict( batch_size=128, test_set_batch_size=64, diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py index e08f361c42a..686178292ea 100644 --- a/test/spmd/test_train_spmd_linear_model.py +++ b/test/spmd/test_train_spmd_linear_model.py @@ -36,6 +36,8 @@ FLAGS = args_parse.parse_common_options( batch_size=128, num_epochs=1, opts=MODEL_OPTS.items()) +xr.use_spmd(auto=FLAGS.auto_spmd) + class SimpleLinear(nn.Module): diff --git a/test/spmd/test_xla_auto_sharding.py b/test/spmd/test_xla_auto_sharding.py new file mode 100644 index 00000000000..2647f3c7bd1 --- /dev/null +++ b/test/spmd/test_xla_auto_sharding.py @@ -0,0 +1,74 @@ +import copy + +import unittest +from unittest.mock import patch +import math +import numpy as np +import os +import sys + +import torch +from torch import nn +import torch.nn.functional as F +import torch.optim as optim +import torch_xla +import torch_xla.debug.metrics as met +import torch_xla.runtime as xr +import torch_xla.core.xla_model as xm +import 
torch_xla.debug.metrics as met +import torch_xla.distributed.spmd as xs +from torch_xla.distributed.spmd import XLAShardedTensor +import test_xla_sharding_base + +import torch_xla.core.xla_env_vars as xenv +import torch_xla.utils.utils as xu +from torch_xla._internal import tpu + + +class XlaAutoShardingTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + xr.use_spmd(auto=True) + super().setUpClass() + + @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"], + "Auto-sharding currently supports TPU & CPU backends.") + def test_matmul(self): + met.clear_counters() + t1 = torch.ones(64, 128) + t2 = torch.ones(128, 256) + t3 = (t1 @ t2).sum() + + xt1 = t1.to(xm.xla_device()) + xt2 = t2.to(xm.xla_device()) + xt3 = (xt1 @ xt2).sum() + xm.mark_step() + self.assertEqual(met.counter_value("CompileWithAutoSharding"), 1) + self.assertTrue(torch.allclose(t3, xt3.cpu())) + + @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"], + "Auto-sharding currently supports TPU & CPU backends.") + def test_simple_linear_training(self): + met.clear_counters() + + model = self.SimpleLinear().to(xm.xla_device()) + model.train() + optimizer = optim.SGD(model.parameters(), lr=0.1) + data = torch.randn(128, 128).to(xm.xla_device()) + target = torch.zeros(128).to(xm.xla_device()) + loss_fn = nn.CrossEntropyLoss() + for i in range(5): + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + xm.mark_step() + cnt = met.counter_value("CompileWithAutoSharding") + self.assertTrue((cnt is not None) and (cnt <= 3)) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index d8d19ad7c9e..98c465e0718 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -41,7 +41,6 @@ class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() def _get_sharded_model(self, mesh_shape=None): diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 6fccfe32800..2dabc383f69 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -28,7 +28,6 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() def test_xla_sharded_tensor(self): @@ -38,8 +37,6 @@ def test_xla_sharded_tensor(self): device=xm.xla_device()) xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) - - # TODO(244003536) add more tests for XLAShardedTensror. 
self.assertTrue(isinstance(xst1, XLAShardedTensor)) def test_xla_sharded_tensor_repr(self): diff --git a/test/spmd/test_xla_sharding_base.py b/test/spmd/test_xla_sharding_base.py index 91525b6ccc7..abed0d00904 100644 --- a/test/spmd/test_xla_sharding_base.py +++ b/test/spmd/test_xla_sharding_base.py @@ -1,3 +1,4 @@ +import os import unittest import numpy as np @@ -31,6 +32,13 @@ def forward(self, x): def setUpClass(cls): cls.n_devices = xr.global_runtime_device_count() cls.device_ids = np.array(range(cls.n_devices)) + xr.use_spmd() + + @classmethod + def tearDownClass(cls): + del os.environ['XLA_USE_SPMD'] + if 'XLA_AUTO_SPMD' in os.environ: + del os.environ['XLA_AUTO_SPMD'] def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None): assert type(mesh_shape) is tuple, 'mesh_shape must be Tuple[int]' diff --git a/test/spmd/test_xla_sharding_hlo.py b/test/spmd/test_xla_sharding_hlo.py index fd0540941c8..a5a1159aa9e 100644 --- a/test/spmd/test_xla_sharding_hlo.py +++ b/test/spmd/test_xla_sharding_hlo.py @@ -18,7 +18,6 @@ class XlaShardingHloTest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() @patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"}) diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index 9c857de84e8..9415769d91c 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -19,7 +19,6 @@ class BasicXMAPITest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() def test_get_xla_supported_devices(self): @@ -65,7 +64,6 @@ class BasicRuntimeAPITest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() def test_local_process_count(self): @@ -146,7 +144,6 @@ class BasicAutocastAPITest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() @unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'], @@ -165,7 +162,6 @@ class BasicDistributedTest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() return super().setUpClass() def test_xla_backend(self): diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index b953c780a61..739680de405 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -17,7 +17,6 @@ class VirtualDeviceTest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): - xr.use_spmd() super().setUpClass() def test_mark_sharding(self): diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index 7fce41f8ab2..3342d3901d6 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -69,6 +69,7 @@ ) import os +import sys import schedulers import numpy as np import torch diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 4b3de9267c9..77676b60abe 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -10,6 +10,7 @@ python3 test/spmd/test_xla_virtual_device.py python3 test/spmd/test_xla_distributed_checkpoint.py python3 test/spmd/test_train_spmd_linear_model.py python3 test/spmd/test_xla_spmd_python_api_interaction.py +python3 test/spmd/test_xla_auto_sharding.py XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shape_models.py -v XLA_EXPERIMENTAL=nonzero:masked_select python3 
test/ds/test_dynamic_shapes.py -v python3 test/test_autocast.py diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index e30e83ee3d1..bd9b46b0a04 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -39,6 +39,16 @@ class AtenXlaDeviceMapper { return devices_; } + void SetVirtualDevice() { + for (auto& device : GetAllDevices()) { + if (static_cast(device.type()) == XlaDeviceType::SPMD) { + return; + } + } + devices_.emplace_back(ParseDeviceString("SPMD:0")); + devices_ordinals_[devices_.back()] = 0; + } + private: AtenXlaDeviceMapper() { if (UseVirtualDevice()) { diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index adf4961534c..6e3d235f166 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -1002,8 +1002,7 @@ xla::StatusOr XlaHelpers::WrapXlaComputation( if (input_output_alias_pair.size() > 0) { for (const auto& [input_index, output_index] : input_output_alias_pair) { // Both input and output will be a tuple so parameter_number will always - // be - // 0 + // be 0 builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}), /*param_number=*/0, /*param_index=*/xla::ShapeIndex({input_index})); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c66bfe915b0..66090744e8f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1170,6 +1170,12 @@ void InitXlaModuleBindings(py::module m) { [](const at::Tensor& tensor) { return GetTensorViewAliasId(tensor); }); m.def("_xla_get_tensor_id", [](const at::Tensor& tensor) { return GetTensorId(tensor); }); + m.def("_xla_set_auto_sharding", []() { + ShardingUtil::SetAutoSharding(); + XLA_CHECK(ShardingUtil::GetAutoSharding()); + }); + m.def("_xla_get_auto_sharding", + []() { return ShardingUtil::GetAutoSharding(); }); m.def("_xla_get_spmd_config_is_locked", []() { return GetLockSpmdConfig(); }); m.def("_xla_force_spmd_device", []() { // It is actually more easier to force SPMD mode than blocking if there is @@ -1179,9 +1185,10 @@ void InitXlaModuleBindings(py::module m) { // allows the users to call `xr.use_spmd()` more freely, given that the // earlier they call, the smaller the one-time overhead of replicating // non-SPMD backed tensors. - torch::lazy::BackendDevice backend_device = bridge::GetCurrentDevice(); + torch::lazy::BackendDevice current_device = bridge::GetCurrentDevice(); std::vector xtensors = - XLAGraphExecutor::Get()->GetLiveTensors(&backend_device); + XLAGraphExecutor::Get()->GetLiveTensors(¤t_device); + torch::lazy::BackendDevice spmd_device = ParseDeviceString("SPMD:0"); for (auto xtensor : xtensors) { XlaDeviceType xla_device_type = static_cast(xtensor->GetDevice().type()); @@ -1189,16 +1196,14 @@ void InitXlaModuleBindings(py::module m) { // Internally this moves the device data to the host and then copy // to the SPMD virtual device. The original data should be destroyed // in the transition, after creating a detached host-side copy. - // TODO(yeounoh) this can be further optimized via CopyToDevice. + // TODO(yeounoh) Consider CopyToDevice, and make data's device mutable. at::Tensor tensor = xtensor->ToTensor(false); - torch::lazy::BackendDevice device = ParseDeviceString("SPMD:0"); - xtensor->SetXlaData(TensorToXlaData(tensor, device)); - // TODO(yeounoh) allow tensor data's device to be mutable. 
+ xtensor->SetXlaData(TensorToXlaData(tensor, spmd_device)); } } - if (!UseVirtualDevice()) { - XLA_CHECK(UseVirtualDevice(/*force_spmd=*/true)); - } + + // Ensure that virtual device is registered. + XLA_CHECK(UseVirtualDevice(/*force_spmd=*/true)); }); m.def("_init_computation_client", []() { runtime::GetComputationClient(); }); m.def("_xla_get_device_hw_type", [](const at::Tensor& tensor) { diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index f1c2bb42467..33b48255baf 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -213,7 +213,10 @@ class ComputationClient { const xla::Shape* output_shape, bool parameter_is_tupled_arguments = false, bool is_sharded = false, - bool allow_spmd_sharding_propagation_to_output = true) + bool allow_spmd_sharding_propagation_to_output = true, + bool use_auto_spmd_partitioning = false, + std::vector auto_spmd_mesh_shape = {}, + std::vector auto_spmd_mesh_ids = {}) : computation(std::move(computation)), compilation_device(std::move(compilation_device)), devices(std::move(devices)), @@ -221,7 +224,10 @@ class ComputationClient { parameter_is_tupled_arguments(parameter_is_tupled_arguments), is_sharded(is_sharded), allow_spmd_sharding_propagation_to_output( - allow_spmd_sharding_propagation_to_output) {} + allow_spmd_sharding_propagation_to_output), + use_auto_spmd_partitioning(use_auto_spmd_partitioning), + auto_spmd_mesh_shape(auto_spmd_mesh_shape), + auto_spmd_mesh_ids(auto_spmd_mesh_ids) {} xla::XlaComputation computation; std::string compilation_device; @@ -230,6 +236,9 @@ class ComputationClient { bool parameter_is_tupled_arguments; bool is_sharded; bool allow_spmd_sharding_propagation_to_output; + bool use_auto_spmd_partitioning; + std::vector auto_spmd_mesh_shape; + std::vector auto_spmd_mesh_ids; }; struct ExecuteComputationOptions : public ClientExecuteOptions {}; @@ -270,6 +279,12 @@ class ComputationClient { virtual std::vector TransferToDevice( absl::Span> tensors) = 0; + // Reshard and return data sharded by `sharding` spec. This is a no-op if the + // input sharding spec is identical to the target `sharding` sharding spec. + virtual std::vector ReshardData( + absl::Span handles, + absl::Span shardings) = 0; + // Transfers local sharded tensor values to the TPU devices and returns a // `PjRtShardedData`. 
virtual DataPtr TransferShardsToDevice( diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 7d7736a872e..d6d914ad8da 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -45,6 +45,12 @@ class IfrtComputationClient : public ComputationClient { std::vector TransferToDevice( absl::Span> tensors) override; + std::vector ReshardData( + absl::Span handles, + absl::Span shardings) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + std::vector TransferFromDevice( absl::Span handles) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index ae3f3c47b7e..28de28c7b8e 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -28,6 +28,7 @@ #include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/protobuf_util.h" #include "xla/shape.h" using xla::internal::XlaBuilderFriend; @@ -223,7 +224,7 @@ ComputationClient::DataPtr PjRtComputationClient::GetDataShard( ComputationClient::DataPtr PjRtComputationClient::WrapDataShards( absl::Span shards, std::string device, xla::Shape shape, xla::OpSharding sharding) { - XLA_CHECK_EQ(shards.size(), client_->addressable_device_count()); + XLA_CHECK_EQ(shards.size(), client_->addressable_devices().size()); std::vector> pjrt_data_shards; pjrt_data_shards.reserve(shards.size()); for (auto& shard : shards) { @@ -312,9 +313,9 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice( // Returns error if the buffer is already on `dst_device`. xla::StatusOr> status_or = pjrt_data->buffer->CopyToDevice(dst_device); - XLA_CHECK(status_or.ok()) - << pjrt_data->device() << " buffer already exists on " << dst; - + if (!status_or.ok()) { + return data; + } return std::make_shared(dst, pjrt_data->shape(), std::move(status_or.value())); } @@ -382,6 +383,83 @@ PjRtComputationClient::ReplicateShardedData( << handle->ToString(); } +std::vector PjRtComputationClient::ReshardData( + absl::Span handles, + absl::Span shardings) { + tsl::profiler::TraceMe activity("ReshardData", + tsl::profiler::TraceMeLevel::kInfo); + XLA_COUNTER("ReshardData", 1); + XLA_CHECK_EQ(handles.size(), shardings.size()) + << "input handles and shardings must have the same length."; + XLA_CHECK(UseVirtualDevice()) << "We only supports SPMD mode resharding."; + + // Perform a simple identity calculation to reshard. 
+ xla::XlaBuilder builder("ReshardData"); + + std::vector shapes; + shapes.reserve(handles.size()); + std::vector hlo_shardings; + hlo_shardings.reserve(handles.size()); + std::vector param_ops; + param_ops.reserve(handles.size()); + for (int i = 0; i < handles.size(); ++i) { + PjRtShardedData* sharded_data = + dynamic_cast(handles[i].get()); + XLA_CHECK_NE(sharded_data, nullptr) + << "Resharding requires PjRtShardedData on SPMD virtual device, " + << "current device: " << handles[i]->device(); + shapes.push_back(sharded_data->shape()); + + const xla::OpSharding& sharding = shardings[i]; + XLA_CHECK_NE(sharding.type(), xla::OpSharding::UNKNOWN) + << "Resharding by UNKNOWN sharding type is not allowed."; + + hlo_shardings.push_back( + ConsumeValue(xla::HloSharding::FromProto(sharding))); + + xla::OpSharding fallback_sharding; + fallback_sharding.set_type(xla::OpSharding::REPLICATED); + xla::XlaScopedShardingAssignment assign( + &builder, sharded_data->GetSharding().type() == xla::OpSharding::UNKNOWN + ? fallback_sharding + : sharded_data->GetSharding()); + param_ops.push_back( + xla::Parameter(&builder, i, shapes[i], absl::StrCat("p.", i))); + } + + xla::XlaOp root; + { + xla::Shape shapes_tuple = xla::ShapeUtil::MakeTupleShape(shapes); + XLA_CHECK_EQ(shapes_tuple.tuple_shapes_size(), hlo_shardings.size()); + xla::HloSharding new_shardings_tuple = + xla::HloSharding::Tuple(shapes_tuple, hlo_shardings); + xla::XlaScopedShardingAssignment assign(&builder, + new_shardings_tuple.ToProto()); + root = xla::Tuple(&builder, param_ops); + } + + xla::XlaComputation xla_computation = ConsumeValue(builder.Build(root)); + xla::ProgramShape program_shape = + ConsumeValue(xla_computation.GetProgramShape()); + + std::string device = GetDefaultDevice(); + std::vector instances; + instances.push_back({std::move(xla_computation), device, + GetCompilationDevices(device, {}), + &program_shape.result(), + /*should_wrap_parameter=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/false}); + std::shared_ptr + computation = Compile(std::move(instances)).front(); + + torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions + execute_options; + auto resharded_results = ExecuteReplicated( + *computation, handles, GetLocalDevices(), execute_options); + return resharded_results; +} + std::vector PjRtComputationClient::TransferFromDevice( absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); @@ -425,17 +503,42 @@ std::vector PjRtComputationClient::Compile( if (instance.is_sharded) { // TODO(yeounoh) multi-host, multi-slice configurations compile_options.executable_build_options.set_use_spmd_partitioning(true); + // We can override the compiler's default behavior to replicate the // outputs. Setting this to true would wrapping the sharded outputs in // PjRtShardedData. compile_options.executable_build_options .set_allow_spmd_sharding_propagation_to_output( {instance.allow_spmd_sharding_propagation_to_output}); + + int num_partitions = client_->device_count(); compile_options.executable_build_options.set_num_partitions( - client_->device_count()); + num_partitions); compile_options.executable_build_options.set_num_replicas(1); compile_options.parameter_is_tupled_arguments = instance.parameter_is_tupled_arguments; + compile_options.executable_build_options.set_use_auto_spmd_partitioning( + instance.use_auto_spmd_partitioning); + TF_VLOG(3) << "Auto SPMD partitioning " + << (instance.use_auto_spmd_partitioning ? "enabled!" 
+ : "disabled."); + if (!instance.auto_spmd_mesh_shape.empty()) { + compile_options.executable_build_options + .set_auto_spmd_partitioning_mesh_shape( + instance.auto_spmd_mesh_shape); + TF_VLOG(3) << "auto_spmd_partitioning_mesh_shape=" + << absl::StrJoin(compile_options.executable_build_options + .auto_spmd_partitioning_mesh_shape(), + ","); + } + if (!instance.auto_spmd_mesh_ids.empty()) { + compile_options.executable_build_options + .set_auto_spmd_partitioning_mesh_ids(instance.auto_spmd_mesh_ids); + TF_VLOG(3) << "auto_spmd_partitioning_mesh_ids=" + << absl::StrJoin(compile_options.executable_build_options + .auto_spmd_partitioning_mesh_ids(), + ","); + } // TODO(244391366) verify this is correct for the collectives ops xla::DeviceAssignment device_assignment(1, client_->device_count()); @@ -481,7 +584,6 @@ std::vector PjRtComputationClient::Compile( const auto& hlo_modules = ConsumeValue(executable->GetHloModules()); xla::HloComputation* hlo_computation = hlo_modules[0]->entry_computation(); - std::shared_ptr pjrt_computation = std::make_shared( std::move(xla::XlaComputation(hlo_modules[0]->ToProto())), diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index a1f73ef3562..9a911c0139b 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -44,6 +44,14 @@ class PjRtComputationClient : public ComputationClient { std::vector TransferToDevice( absl::Span> tensors) override; + // Reshard and return data sharded by `sharding` spec. This is a no-op if + // the input sharding spec is identical to the target `sharding` sharding + // spec. + // TODO(yeounoh) replace ReplicateShardedData with this. + std::vector ReshardData( + absl::Span handles, + absl::Span shardings) override; + std::vector TransferFromDevice( absl::Span handles) override; diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 5ef977cc415..33f24b2b9a0 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -50,6 +50,14 @@ namespace torch_xla { +namespace { +bool CanApplySharding(const XLATensor::ShardingSpecPtr sharding) { + return !sharding || + sharding->sharding.type() == xla::OpSharding::REPLICATED || + sharding->sharding.type() == xla::OpSharding::UNKNOWN; +} +} // namespace + XLATensor::Data::~Data() { XLAGraphExecutor::Get()->UnregisterTensor(this); } XLATensorPtr XLATensor::Create(const at::Tensor& tensor, @@ -247,12 +255,15 @@ void XLATensor::SetShardingSpec(const ShardingSpec& sharding, bool overwrite) { TORCH_LAZY_COUNTER("SetShardingSpec", 1); data()->sharding = std::make_shared(sharding); } else { + // Tensor is already sharding annotated, check if it is UNKNOWN or + // the same sharding type. XLA_CHECK(ShardingUtil::EqualShardingSpecs(sharding, *sharding_spec())) << "Existing sharding annotation, " << sharding_spec()->sharding.DebugString() << ", must be cleared before applying a new one, " << sharding.sharding.DebugString(); } + // Sync to the node. dynamic_cast(GetIrValue().node.get()) ->SetSharding(sharding_spec()->sharding, GetIrValue().index); } @@ -268,18 +279,31 @@ void XLATensor::ClearShardingSpec() { XLATensor::ShardingSpecPtr XLATensor::sharding_spec() const { ShardingSpecPtr sharding = data()->sharding; torch::lazy::Value ir_value = CurrentIrValue(); - if (sharding && ir_value) { - // The copy of sharding annotation on the IR node should be the same. 
+ if (ir_value) { auto* xla_node = dynamic_cast(ir_value.node.get()); - if (xla_node->GetSharding(ir_value.index)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs( - *sharding, ShardingSpec{*xla_node->GetSharding(ir_value.index), - xla_node->xla_shape()})) - << "Sharding on tensor: " - << xla::HloSharding::FromProto(sharding->sharding)->ToString() - << ", sharding on IR: " - << xla::HloSharding::FromProto(*xla_node->GetSharding(ir_value.index)) - ->ToString(); + const auto* new_op_sharding = xla_node->GetSharding(ir_value.index).get(); + if (new_op_sharding && + (new_op_sharding->type() != xla::OpSharding::UNKNOWN)) { + // Re-sync the sharding annotation from the node to the tensor if there is + // one attached to the node. A new sharding annotation is attached + // directly to the node, and gets synced to the tensor after this. + // If sharding is attached via SetShardingSpec, then it flows from the + // tensor to the node. If sharding is attached by the compiler pass, then + // it first gets attached to the graph node, and then synced to the tensor + // here. + if (!sharding || + (sharding && !ShardingUtil::EqualOpShardings(*new_op_sharding, + sharding->sharding))) { + TF_VLOG(5) << "Syncing node sharding (type=" << new_op_sharding->type() + << ") to tensor (shape=" << xla_node->xla_shape().ToString() + << ")."; + data()->sharding = std::make_shared( + *new_op_sharding, xla_node->xla_shape()); + } + } else if (sharding) { + // There is a case where the sharding spec on the tensor is not + // propagated down to the node after a reset. + xla_node->SetSharding(sharding->sharding, ir_value.index); } } return sharding; @@ -335,7 +359,7 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) { } void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const { - TF_VLOG(5) << "Assign IR value: " + TF_VLOG(6) << "Assign IR value: " << (ir_value ? ir_value->ToString() : "empty node"); data()->ir_value = std::move(ir_value); data()->generation += 1; diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 74ba457aed7..cd66bdb2db9 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -614,6 +614,12 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( coll.device = *unique_device; coll.indices.reserve(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { + // Sync sharding annotations between the tensor and its node, if exists. + // This either push down the sharding on the tensor to the IR before node + // hash computation if the node has no annotation, or it actually syncs the + // sharding attached to the node to the tensor, since sharding propagation & + // resharding should attach the latest to the node. + tensors[i]->sharding_spec(); if (tensor_ids.insert(tensors[i]->GetUniqueId()).second && // A tensor's xla_data might not be up to date if there is a view // associated with it. Make sure to sync those tensors here too. @@ -623,14 +629,6 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( torch::lazy::Value ir_value = tensors[i]->CurrentIrValue(); if (ir_value) { if (ShouldSyncIrValue(ir_value)) { - // `sharding_spec()` checks sharding equality. If IR node has no - // sharding, then sync XLATensor sharding to the IR node. XLATensor's - // sharding takes the precedence as the source of the truth. 
- XLATensor::ShardingSpecPtr sharding = tensors[i]->sharding_spec(); - if (sharding) { - dynamic_cast(ir_value.node.get()) - ->SetSharding(sharding->sharding, ir_value.index); - } auto device_data = torch_xla::DeviceData::Cast(ir_value.node.get()); // If current tensor is cloned from another tensor, we want to assign // a new XlaData to it after current execution. Cloned tensor might @@ -1204,6 +1202,19 @@ XLAGraphExecutor::TryRunCachedSync( TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", po_data->post_order.size()); TF_VLOG(5) << "TensorsGraphSize=" << po_data->post_order.size(); + if (ShardingUtil::GetAutoSharding()) { + // TODO(yeounoh) we may be able to update the cache to avoid this. + // The current issue is that we are not properly updating the original + // tensors to track the new sharded data after resharding. + const xla::HloModuleProto& computation_proto = + cached_computation->computation->computation().proto(); + ShardingUtil::ReshardParameters(computation_proto, tensors, + &po_data->parameters_data, + &po_data->post_order); + TF_VLOG(5) << "Parameter sequence hash after resharding: " + << torch::lazy::Hash(po_data->parameter_sequence); + } + // don't schedule the execution if the purpose of this SyncTensor is just to // warm up the cache. return std::pair>( @@ -1313,6 +1324,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200); static const bool using_pjrt = runtime::sys_util::GetEnvString("PJRT_DEVICE", "").size() > 0; + static const bool use_autosharding = ShardingUtil::GetAutoSharding(); LoweringContext lowering_ctx("SyncTensorsGraph", coll.device, po_data->post_order, std::move(po_data->emission_map)); @@ -1328,10 +1340,10 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( std::vector> input_output_alias_pair; std::vector buffer_donor_indices; - // TODO(yeounoh) aliasing is disabled for partitioned computation, + // TODO(yeounoh) enable aliasing is disabled for partitioned computation, // since the current aliasing compares the unpartitioned input and output // shapes which can lead to an incorrect aliasing pairs if sharded. - if (enable_aliasing) { + if (enable_aliasing && !use_autosharding) { if (coll.config.sync_ltc_data && coll.config.force_ltc_data) { // We can only alias at the step barrier, when force_ltc_data is true. // Consider the case: @@ -1369,9 +1381,10 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla()); xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); + // TODO(yeounoh) enable wrapping with auto-sharding. bool should_wrap_parameter = (program_shape.parameters_size() >= parameter_wrapping_threadshold) && - using_pjrt; + using_pjrt && !use_autosharding; if (should_wrap_parameter) { TF_VLOG(3) << "Wrapping graph with " << program_shape.parameters_size() << " parameters. Threadshold = " @@ -1390,6 +1403,25 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( coll.device.toString(), devices), &shape, should_wrap_parameter, is_sharded}); + if (use_autosharding) { + TF_VLOG(5) << "use_auto_spmd_partitioning is set."; + TF_CHECK(is_sharded) << "Auto-sharding pass requires SPMD mode."; + instances.front().use_auto_spmd_partitioning = use_autosharding; + TORCH_LAZY_COUNTER("CompileWithAutoSharding", 1); + + // Apply XLA_AUTO_SPMD_MESH if it is set. + // TODO(yeounoh) allow multi mesh exploration. 
+ auto mesh_shape_ids = ShardingUtil::GetAutoShardingMesh(); + std::vector auto_spmd_mesh_shape = std::get<0>(mesh_shape_ids); + std::vector auto_spmd_mesh_ids = std::get<1>(mesh_shape_ids); + instances.front().auto_spmd_mesh_shape = auto_spmd_mesh_shape; + instances.front().auto_spmd_mesh_ids = auto_spmd_mesh_ids; + TF_VLOG(5) << "auto_spmd_mesh_shape={" + << absl::StrJoin(auto_spmd_mesh_shape, ",") << "}\n" + << "auto_spmd_mesh_ids={" + << absl::StrJoin(auto_spmd_mesh_ids, ",") << "}"; + } + DebugUtil::analyze_graph_execution_python_frame( DebugUtil::GraphAnalysisSource::Compilation, /*graph_hash=*/coll.hash, /*program_shape=*/&program_shape); @@ -1410,6 +1442,17 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( << " is computation hash " << torch::lazy::HashToString(torch::lazy::Hash( computations.front()->computation().proto().SerializeAsString())); + + if (use_autosharding) { + const xla::HloModuleProto& computation_proto = + computations.front()->computation().proto(); + ShardingUtil::ReshardParameters(computation_proto, &tensors, + &po_data->parameters_data, + &po_data->post_order); + TF_VLOG(5) << "Parameter sequence hash after resharding: " + << torch::lazy::Hash(po_data->parameter_sequence); + } + if (should_wrap_parameter) { XLA_CHECK_EQ(program_shape.parameters_size(), 1); XLA_CHECK_EQ(program_shape.parameters()[0].tuple_shapes_size(), @@ -1421,8 +1464,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( return {/*device=*/coll.device, /*emitted_nodes=*/lowering_ctx.GetEmittedNodeCount(), - /*computation=*/ - computations.front(), + /*computation=*/computations.front(), /*parameters_data=*/std::move(po_data->parameters_data), /*is_sharded=*/is_sharded}; } @@ -1482,9 +1524,9 @@ XLAGraphExecutor::SyncTensorsGraphInternal( } CompilationResult compile_result = Compile(*tensors, devices, coll, &po_data, ir_values); + TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes); TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes; - auto cached_computation = std::make_shared( std::move(compile_result.computation), compile_result.is_sharded); GetComputationCache()->Add(coll.hash, cached_computation); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index d3b5e78877d..7f8b96254ff 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -30,6 +30,7 @@ #include "xla/service/sharding_propagation.h" #include "xla/service/spmd/spmd_partitioner.h" #include "xla/xla.pb.h" +#include "xla_sharding_util.h" namespace torch_xla { @@ -52,6 +53,8 @@ using tsl::ERROR; using tsl::INFO; using xla::internal::XlaBuilderFriend; +static bool use_auto_sharding = false; + // Return py::obj type as string. 
std::string GetPyType(const py::object& elem) { return elem.attr("__class__").attr("__name__").cast(); @@ -163,6 +166,21 @@ std::vector> ExtractGroupMembers( return groups; } +std::vector ParseStringToIntVector(const std::string& str) { + std::istringstream ss; + ss.str(str); + std::vector result; + for (std::string s; std::getline(ss, s, ',');) { + try { + result.push_back(std::stoi(s)); + } catch (std::invalid_argument const& e) { + TF_LOG(ERROR) << "Error parsing string: " << str + << " with an exception: " << e.what(); + } + } + return result; +} + } // namespace bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) { @@ -606,6 +624,121 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( source_tensors, GetVirtualDevice().toString(), global_shape, sharding); } +std::tuple, std::vector> +ShardingUtil::GetAutoShardingMesh() { + // Auto-sharding uses mesh_shape = {n_devices, 1} if XLA_AUTO_SPMD_MESH + // is not set. XLA_AUTO_SPMD_MESH takes a form of string, "2,2" which + // corresponds to a 2-by-2 mesh. + std::vector mesh_shape = ParseStringToIntVector( + runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "")); + std::vector device_mesh_ids; + if (!mesh_shape.empty()) { + int64_t total_devices = 1; + for (auto i : mesh_shape) { + total_devices *= i; + } + XLA_CHECK_EQ(total_devices, + runtime::GetComputationClient()->GetAllDevices().size()) + << "Invalid auto-sharding mesh_shape: " + << absl::StrJoin(mesh_shape, ","); + device_mesh_ids = std::vector(total_devices); + std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); + } + return std::make_tuple(mesh_shape, device_mesh_ids); +} + +void ShardingUtil::ReshardParameters( + const xla::HloModuleProto& module, std::vector* tensors, + std::vector* parameters, + std::vector* nodes) { + // Extract input shardings generated from auto-sharding pass. + std::vector input_shardings; + if (module.spmd_parameters_shardings().size() == 1 && + module.spmd_parameters_shardings()[0].type() == xla::OpSharding::TUPLE) { + auto tuple_shardings = + module.spmd_parameters_shardings()[0].tuple_shardings(); + input_shardings = std::vector(tuple_shardings.begin(), + tuple_shardings.end()); + } else { + for (auto sharding : module.spmd_parameters_shardings()) { + input_shardings.push_back(sharding); + } + } + if (input_shardings.size() == 0) { + TF_VLOG(3) << "ReshardParamters... skip with empty input_shardings."; + return; + } + XLA_CHECK_EQ(input_shardings.size(), parameters->size()); + + // Reshard parameters as needed, as with a new sharding spec. + std::vector data = + UnwrapXlaData(*parameters); + + std::vector reshard_indices; + std::vector data_to_reshard; + std::vector shardings_to_reshard; + for (int i = 0; i < input_shardings.size(); ++i) { + XLA_CHECK(input_shardings[i].type() != xla::OpSharding::UNKNOWN) + << "Resharding by UNKNOWN sharding type is not allowed."; + // Skip re-sharding if not necessary. + if (!xla::protobuf_util::ProtobufEquals(data[i]->GetSharding(), + input_shardings[i])) { + reshard_indices.push_back(i); + data_to_reshard.push_back(data[i]); + shardings_to_reshard.push_back(input_shardings[i]); + } + } + if (reshard_indices.size() == 0) { + TF_VLOG(3) << "ReshardParamters... skip with no new shardings."; + return; + } + TF_VLOG(3) << "ReshardParamters... resharding " << reshard_indices.size() + << " parameters."; + + TORCH_LAZY_COUNTER("ReshardParameters", 1); + + // Construct parameter handle to XlaNode mappping for faster look-up. 
+ std::unordered_map + xla_node_map; + for (const torch::lazy::Node* node : *nodes) { + const auto backend_data = + torch::lazy::getBackend()->GetComputationDataFromNode(node); + if (backend_data) { + torch::lazy::BackendData::Handle handle = backend_data->GetHandle(); + xla_node_map[handle] = node; + } + } + + std::vector outputs; + outputs.reserve(reshard_indices.size()); + // Groupping is computationally more efficient but increases memory + // consumption. It is groupped by default, but can be overriden for + // more-granular control over the peak memory consumption. + bool group_sharding = + runtime::sys_util::GetEnvBool("XLA_AUTO_USE_GROUP_SHARDING", true); + if (group_sharding) { + outputs = WrapXlaData(runtime::GetComputationClient()->ReshardData( + data_to_reshard, shardings_to_reshard)); + } else { + for (int i = 0; i < data_to_reshard.size(); ++i) { + auto output = WrapXlaData(runtime::GetComputationClient()->ReshardData( + {data_to_reshard[i]}, {shardings_to_reshard[i]})); + outputs.insert(outputs.end(), output.begin(), output.end()); + } + } + XLA_CHECK_EQ(outputs.size(), reshard_indices.size()); + + for (int i = 0; i < outputs.size(); ++i) { + (*parameters)[reshard_indices[i]] = outputs[i]; + auto it_node = xla_node_map.find(data_to_reshard[i]->GetHandle()); + XLA_CHECK(it_node != xla_node_map.end()) + << "xla_node_map does not contain " << data_to_reshard[i]->ToString() + << ", target sharding: " << shardings_to_reshard[i].DebugString(); + auto device_data_node = DeviceData::Cast(it_node->second); + device_data_node->SetSharding(shardings_to_reshard[i], 0); + } +} + void ShardingUtil::XlaMarkSharding(const at::Tensor& input, xla::OpSharding sharding) { TORCH_LAZY_COUNTER("XlaMarkSharding", 1); @@ -719,4 +852,14 @@ void ShardingUtil::XlaMarkShardingDynamoCustomOp( ShardingUtil::XlaMarkSharding(input, op_sharding); } +void ShardingUtil::SetAutoSharding() { + // This stays on throughout the program. + use_auto_sharding = true; +} +bool ShardingUtil::GetAutoSharding() { + if (runtime::sys_util::GetEnvBool("XLA_AUTO_SPMD", false)) { + use_auto_sharding = true; + } + return use_auto_sharding; +} } // namespace torch_xla diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 601c93945a2..d25aee9e4a2 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -122,6 +122,27 @@ class ShardingUtil { static void XlaMarkSharding(const at::Tensor& input, xla::OpSharding sharding); + //////////////////////////// Auto-Sharding //////////////////////////// + + // Construct a device mesh for auto-sharding pass. Returns a tuple of mesh + // shape and device ids vectors. + static std::tuple, std::vector> + GetAutoShardingMesh(); + + // Reshard the parameters if the expected shardings mismatch. Resharding is + // expensive especially for those already sharded. The cost can easily be + // armotized over multiple steps, though, since the input sharding is + // propagated to the output for the subsequent runs. Sharded data transfer + // during resharding should be asynchronous. It is recommended to keep the + // input sharding on the input data as-is. 
+ static void ReshardParameters( + const xla::HloModuleProto& module, std::vector* tensors, + std::vector* parameters, + std::vector* nodes); + + static void SetAutoSharding(); + static bool GetAutoSharding(); + //////////////////////////// Dynamo Integration //////////////////////////// static void XlaMarkShardingDynamoCustomOp( diff --git a/torch_xla/distributed/spmd/__init__.py b/torch_xla/distributed/spmd/__init__.py index 53376979d3c..87ac6f8e965 100644 --- a/torch_xla/distributed/spmd/__init__.py +++ b/torch_xla/distributed/spmd/__init__.py @@ -3,9 +3,10 @@ XLAPatchedLinear, mark_sharding, clear_sharding, wrap_if_sharded, xla_patched_nn_linear_forward, set_global_mesh, get_global_mesh) -from .api import xla_distribute_tensor, xla_distribute_module +from .api import xla_distribute_tensor, xla_distribute_module, auto_policy __all__ = [ + "auto_policy", "XLAShard", "XLAShardedTensor", "Mesh", diff --git a/torch_xla/distributed/spmd/api.py b/torch_xla/distributed/spmd/api.py index dae46e60d55..a521804ae2d 100644 --- a/torch_xla/distributed/spmd/api.py +++ b/torch_xla/distributed/spmd/api.py @@ -30,12 +30,20 @@ def wrapper( *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc] ) -> None: - os.environ["XLA_USE_SPMD"] = "1" + xr.use_spmd() return func(self, *args, **kwargs) # type: ignore[misc] return wrapper +# Passing `partition_fn=auto_policy` to xla_distribute_module will +# enable auto-sharding pass. +def auto_policy(mod_name, mod, mesh): + # no-op, xla_distribute_module will check if this is + # `auto_policy` and enable auto-sharding. + return + + @with_xla def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> "Mesh": """ @@ -205,9 +213,16 @@ def xla_distribute_module( """ if partition_fn: - # apply partition_fun to submodules - for name, submod in module.named_modules(): - partition_fn(name, submod, device_mesh) + if getattr(partition_fn, '__name__', 'unknown') == "auto_policy": + # TODO(yeounoh) allow pre-loading to xla device in the future. + assert next(module.parameters()).device != xm.xla_device(), \ + f"Currently requires module to be on cpu, before xla_distribute_module." + xr.use_spmd(auto=True) + module = module.to(xm.xla_device()) + else: + # apply partition_fun to submodules + for name, submod in module.named_modules(): + partition_fn(name, submod, device_mesh) # non-partitioned (annotated) submodules and parameters are implicilty replicated # register input_fn as module forward pre hook diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 4c5d1cde0c8..1cdc66a20c2 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -371,6 +371,7 @@ class ShardingType(IntEnum): TILED = 3 MANUAL = 4 PARTIAL = 5 + UNKNOWN = 6 # implicit replication. TODO(yeounoh) wait for auto-sharding support def _get_sharding_type(partition_spec: Tuple[Union[int, None]], diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 4082e02d2a8..917de07ce29 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -233,10 +233,11 @@ def addressable_runtime_device_count() -> int: # API to enable SPMD mode. This is a recommended way to enable SPMD. -# TODO(yeounoh) this does not block users from using XLA_USE_SPMD flag, yet. -# we will enforce `use_spmd()` once the flag is fully deprecated. +# This forces SPMD mode if some tensors are already initialized on non-SPMD +# devices. This means that those tensors would be replicated across the devices. 
+# TODO(yeounoh) introduce SPMD configuration. @requires_pjrt -def use_spmd(): +def use_spmd(auto: Optional[bool] = False): if os.environ.get("XLA_USE_SPMD") is not None: warnings.warn("XLA_USE_SPMD is being deprecated. " "Use torch_xla.runtime.use_spmd() " @@ -250,9 +251,13 @@ def use_spmd(): "please set SPMD mode before initializting tensors " "(i.e., call use_spmd() in the beginning of the program).") torch_xla._XLAC._xla_force_spmd_device() + xm.wait_device_ops() - # TODO(yeounoh) replace this when we fully deprecate the flag. + # TODO(yeounoh) we can drop envvar in the future os.environ["XLA_USE_SPMD"] = "1" + if auto: + torch_xla._XLAC._xla_set_auto_sharding() + os.environ["XLA_AUTO_SPMD"] = "1" @requires_pjrt
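As an end-to-end illustration of the `auto_policy` path wired up in `torch_xla/distributed/spmd/api.py` above, the following sketch (illustration only, not part of the patch) mirrors `test_dtensor_integration2.py`. The module must start on CPU because `xla_distribute_module` asserts this before moving it to the XLA device; the small `nn.Sequential` model, batch size, and step count are arbitrary placeholders.

```python
# Sketch: DTensor-style entry point for auto-sharding added in this change.
import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import DeviceMesh, distribute_module
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Keep the module on CPU; distribute_module(..., auto_policy) moves it to the
# XLA device and enables the auto-sharding pass.
model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
sharded_model = distribute_module(model, device_mesh, auto_policy)
assert torch_xla._XLAC._xla_get_auto_sharding()

optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
data = torch.randn(64, 128).to(xm.xla_device())
target = torch.zeros(64, dtype=torch.long).to(xm.xla_device())
for _ in range(3):
    optimizer.zero_grad()
    loss = loss_fn(sharded_model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()
```

Running this under SPMD should bump the `ExecuteReplicated` and `CompileWithAutoSharding` counters that the new tests check.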