Skip to content

Commit

Permalink
[SPMD] auto-sharding PoC (#6719)
Browse files Browse the repository at this point in the history
  • Loading branch information
yeounoh authored Mar 14, 2024
1 parent fe3f23c commit 370089a
Show file tree
Hide file tree
Showing 33 changed files with 667 additions and 101 deletions.
29 changes: 29 additions & 0 deletions docs/spmd.md
Original file line number Diff line number Diff line change
Expand Up @@ -492,4 +492,33 @@ generated_table = visualize_sharding(sharding, use_color=False)

You could use these examples on TPU/GPU/CPU single-host and modify it to run on multi-host. And you could modify it to sharding-style `tiled`, `partial_replication` and `replicated`.

### Auto-Sharding
We are introducing a new experimental PyTorch/XLA SPMD feature called `auto-sharding` (see the [RFC](https://github.com/pytorch/xla/issues/6322)). It is available in `r2.3` and `nightly`, and currently supports `XLA:TPU` on a single TPUVM host.

PyTorch/XLA auto-sharding can be enabled by one of the following:
- Setting envvar `XLA_AUTO_SPMD=1`
- Calling the SPMD API in the beginning of your code:
```python
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
```
- Calling `torch.distributed._tensor.distribute_module` with `auto_policy` and an `xla` device mesh:
```python
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, model should be loaded to xla device via distribute_module.
model = MyModule() # an nn.Module
sharded_model = distribute_module(model, device_mesh, auto_policy)
```

Optionally, one can set the following options/env-vars to control the behavior of
the XLA-based auto-sharding pass:
- `XLA_AUTO_USE_GROUP_SHARDING`: group resharding of the parameters. Set by default.
- `XLA_AUTO_SPMD_MESH`: logical mesh shape to be used for auto-sharding. For example,
`XLA_AUTO_SPMD_MESH=2,2` corresponds to a 2-by-2 mesh with 4 global devices. If unset,
a default device mesh shape of `num_devices,1` will be used.
2 changes: 2 additions & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
run_test "$CDIR/spmd/test_dtensor_integration.py"
run_test "$CDIR/spmd/test_dtensor_integration2.py"
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
Expand Down
1 change: 1 addition & 0 deletions test/spmd/args_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def parse_common_options(datadir=None,
parser.add_argument('--async_closures', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--auto_spmd', action='store_true')
if opts:
for name, aopts in opts:
parser.add_argument(name, **aopts)
Expand Down
73 changes: 38 additions & 35 deletions test/spmd/test_dtensor_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import DeviceMesh, Shard
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
distribute_module)
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import xla_distribute_tensor, xla_distribute_module
from torch_xla.distributed.spmd import auto_policy

import unittest

Expand All @@ -19,7 +21,6 @@ class DTensorIntegrationTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_xla_distribute_tensor(self):
Expand All @@ -33,8 +34,7 @@ def test_xla_distribute_tensor(self):
3,
requires_grad=requires_grad,
device=xm.xla_device())
dist_tensor = xla_distribute_tensor(tensor_to_shard, device_mesh,
shard_spec)
dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
# TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
assert type(dist_tensor).__name__ == "XLAShardedTensor"
assert len(dist_tensor.sharding_spec) > 0
Expand All @@ -47,65 +47,68 @@ def test_xla_distribute_tensor(self):
self.assertTrue(dist_tensor.global_tensor.requires_grad)
self.assertTrue(dist_tensor.is_leaf)

def test_xla_distribute_module(self):
def test_optimizer_step_with_sharding(self):
# Use simple linear model to test model parameter sharding
model = self.SimpleLinear().to(xm.xla_device())

# Running the same mark_sharding test with xla_distribute_tensor instead
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]
distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)

def shard_params(mod_name, mod, mesh):
shard_spec = [Shard(0)]
# annotate fc1 and fc2
if isinstance(mod, nn.Linear):
for name, param in mod.named_parameters():
dist_param = xla_distribute_tensor(param, mesh, shard_spec)

sharded_model = xla_distribute_module(model, device_mesh, shard_params)
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != "")
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != "")

sharded_model.train()
model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(3):
for _ in range(3):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Sharding is persisted across mark_step calls, and test if the sharded computation
# can repeat more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))

def test_optimizer_step_with_sharding(self):
# Use simple linear model to test model parameter sharding
def test_xla_distribute_module(self):
model = self.SimpleLinear().to(xm.xla_device())

# Running the same mark_sharding test with xla_distribute_tensor instead
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]
xla_distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)

model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
def shard_params(mod_name, mod, mesh):
shard_spec = [Shard(0)]
# annotate fc1 and fc2
if isinstance(mod, nn.Linear):
for name, param in mod.named_parameters():
dist_param = distribute_tensor(param, mesh, shard_spec)

sharded_model = distribute_module(model, device_mesh, shard_params)
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != "")
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != "")

sharded_model.train()
optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(3):
for _ in range(3):
optimizer.zero_grad()
output = model(data)
output = sharded_model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Sharding is persisted across mark_step calls, and test if the sharded computation
# can repeat more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
# Should run with SPMD mode, ExecuteReplicated.
self.assertTrue(met.counter_value("ExecuteReplicated") > 0)
self.assertTrue(met.counter_value("ExecuteComputation") is None)


if __name__ == '__main__':
Expand Down
58 changes: 58 additions & 0 deletions test/spmd/test_dtensor_integration2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import sys

import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
distribute_module)
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import auto_policy

import unittest

import test_xla_sharding_base


# This integration test passes when run independently.
class DTensorIntegrationTest2(test_xla_sharding_base.XlaShardingTest):
  """Integration test for `distribute_module` driven by `auto_policy`."""

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  # The skip reason names both backends the condition actually admits.
  @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
                       "Auto-sharding currently supports TPU & CPU backends.")
  def test_xla_distribute_module_auto(self):
    """Run a few SGD steps on an auto-sharded model and check compile counts."""
    device_count = xr.global_runtime_device_count()
    device_mesh = DeviceMesh("xla", list(range(device_count)))

    # Use torch_xla.distributed.spmd.auto_policy to enable auto-sharding;
    # Currently, model should be loaded to xla device via distribute_module.
    model = self.SimpleLinear()
    sharded_model = distribute_module(model, device_mesh, auto_policy)
    sharded_model.train()
    self.assertTrue(torch_xla._XLAC._xla_get_auto_sharding())

    optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
    data = torch.randn(128, 128).to(xm.xla_device())
    target = torch.zeros(128).to(xm.xla_device())
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(5):
      optimizer.zero_grad()
      output = sharded_model(data)
      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()
      xm.mark_step()

    # Should compile with auto-sharding, we expect up to 3 times.
    # Two separate assertions so a failure reports whether the counter was
    # missing or simply too large.
    cnt = met.counter_value("CompileWithAutoSharding")
    self.assertIsNotNone(cnt)
    self.assertLessEqual(cnt, 3)


if __name__ == '__main__':
  # unittest.main() calls sys.exit() itself unless exit=False is passed, which
  # would leave the explicit exit-status line below unreachable.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 0 additions & 1 deletion test/spmd/test_dynamo_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class DynamoSpmdInferenceTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_dynamo_spmd_basic(self):
Expand Down
1 change: 0 additions & 1 deletion test/spmd/test_spmd_graph_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ class SpmdGraphDumpTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_dump_with_output_sharding(self):
Expand Down
2 changes: 2 additions & 0 deletions test/spmd/test_train_spmd_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@
import torch_xla.test.test_utils as test_utils
import torch_xla.distributed.spmd as xs

xr.use_spmd(auto=FLAGS.auto_spmd)

DEFAULT_KWARGS = dict(
batch_size=128,
test_set_batch_size=64,
Expand Down
2 changes: 2 additions & 0 deletions test/spmd/test_train_spmd_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
FLAGS = args_parse.parse_common_options(
batch_size=128, num_epochs=1, opts=MODEL_OPTS.items())

xr.use_spmd(auto=FLAGS.auto_spmd)


class SimpleLinear(nn.Module):

Expand Down
74 changes: 74 additions & 0 deletions test/spmd/test_xla_auto_sharding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import copy

import unittest
from unittest.mock import patch
import math
import numpy as np
import os
import sys

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
import test_xla_sharding_base

import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu
from torch_xla._internal import tpu


class XlaAutoShardingTest(test_xla_sharding_base.XlaShardingTest):
  """Tests that run with SPMD auto-sharding enabled via `xr.use_spmd(auto=True)`."""

  @classmethod
  def setUpClass(cls):
    # Enable auto-sharding before the base class performs its own setup.
    xr.use_spmd(auto=True)
    super().setUpClass()

  @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
                       "Auto-sharding currently supports TPU & CPU backends.")
  def test_matmul(self):
    """Compare a matmul+sum on the XLA device against the CPU result."""
    met.clear_counters()
    t1 = torch.ones(64, 128)
    t2 = torch.ones(128, 256)
    t3 = (t1 @ t2).sum()

    xt1 = t1.to(xm.xla_device())
    xt2 = t2.to(xm.xla_device())
    xt3 = (xt1 @ xt2).sum()
    xm.mark_step()
    self.assertEqual(met.counter_value("CompileWithAutoSharding"), 1)
    self.assertTrue(torch.allclose(t3, xt3.cpu()))

  @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
                       "Auto-sharding currently supports TPU & CPU backends.")
  def test_simple_linear_training(self):
    """Run a few SGD steps on SimpleLinear and check the compile counter."""
    met.clear_counters()

    model = self.SimpleLinear().to(xm.xla_device())
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    data = torch.randn(128, 128).to(xm.xla_device())
    target = torch.zeros(128).to(xm.xla_device())
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(5):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()
      xm.mark_step()

    # Two separate assertions so a failure reports whether the counter was
    # missing or simply too large.
    cnt = met.counter_value("CompileWithAutoSharding")
    self.assertIsNotNone(cnt)
    self.assertLessEqual(cnt, 3)


if __name__ == '__main__':
  # unittest.main() calls sys.exit() itself unless exit=False is passed, which
  # would leave the explicit exit-status line below unreachable.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 0 additions & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def _get_sharded_model(self, mesh_shape=None):
Expand Down
3 changes: 0 additions & 3 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_xla_sharded_tensor(self):
Expand All @@ -38,8 +37,6 @@ def test_xla_sharded_tensor(self):
device=xm.xla_device())
xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)),
partition_spec)

# TODO(244003536) add more tests for XLAShardedTensor.
self.assertTrue(isinstance(xst1, XLAShardedTensor))

def test_xla_sharded_tensor_repr(self):
Expand Down
8 changes: 8 additions & 0 deletions test/spmd/test_xla_sharding_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import unittest
import numpy as np

Expand Down Expand Up @@ -31,6 +32,13 @@ def forward(self, x):
def setUpClass(cls):
cls.n_devices = xr.global_runtime_device_count()
cls.device_ids = np.array(range(cls.n_devices))
xr.use_spmd()

@classmethod
def tearDownClass(cls):
  """Remove `XLA_USE_SPMD` and `XLA_AUTO_SPMD` from the environment, if set."""
  # pop() with a default tolerates a variable that is already absent, where a
  # bare `del os.environ[...]` would raise KeyError.
  os.environ.pop('XLA_USE_SPMD', None)
  os.environ.pop('XLA_AUTO_SPMD', None)

def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None):
assert type(mesh_shape) is tuple, 'mesh_shape must be Tuple[int]'
Expand Down
Loading

0 comments on commit 370089a

Please sign in to comment.