[r2.3] backport: auto-sharding PoC (#6719) #6755

Merged · 2 commits · Mar 15, 2024
29 changes: 29 additions & 0 deletions docs/spmd.md
@@ -492,4 +492,33 @@ generated_table = visualize_sharding(sharding, use_color=False)

You can use these examples on a single TPU/GPU/CPU host and modify them to run on multiple hosts. You can also change the sharding style to `tiled`, `partial_replication`, or `replicated`.

### Auto-Sharding
We are introducing a new PyTorch/XLA SPMD feature called `auto-sharding` ([RFC](https://github.com/pytorch/xla/issues/6322)). This is an experimental feature in `r2.3` and `nightly` that supports `XLA:TPU` and a single TPUVM host.

PyTorch/XLA auto-sharding can be enabled by one of the following:
- Setting the environment variable `XLA_SPMD_AUTO=1` (see the sketch after this list)
- Calling the SPMD API at the beginning of your code:
```python
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
```
- Calling `torch.distributed._tensor.distribute_module` with `auto_policy` and the `xla` device mesh:
```python
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, the model should be loaded onto the XLA device via distribute_module.
model = MyModule()  # nn.Module
sharded_model = distribute_module(model, device_mesh, auto_policy)
```
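
For the environment-variable path, a minimal sketch is shown below. It assumes `XLA_SPMD_AUTO` is read when the XLA runtime initializes, so the variable must be set before the first XLA computation runs; the matmul is only for illustration.

```python
import os

# Assumption: the variable must be visible before the XLA runtime initializes;
# equivalently, launch the script as `XLA_SPMD_AUTO=1 python my_script.py`.
os.environ["XLA_SPMD_AUTO"] = "1"

import torch
import torch_xla.core.xla_model as xm

# Any traced computation is now compiled with the auto-sharding pass enabled.
t1 = torch.ones(64, 128).to(xm.xla_device())
t2 = torch.ones(128, 256).to(xm.xla_device())
out = (t1 @ t2).sum()
xm.mark_step()
```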

Optionally, one can set the following environment variables to control the behavior of
the XLA-based auto-sharding pass (a usage sketch follows the list):
- `XLA_AUTO_USE_GROUP_SHARDING`: group re-sharding of the parameters. Set by default.
- `XLA_AUTO_SPMD_MESH`: logical mesh shape to be used for auto-sharding. For example,
`XLA_AUTO_SPMD_MESH=2,2` corresponds to a 2-by-2 mesh with 4 global devices. If unset,
a default device mesh shape of `num_devices,1` will be used.
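
A minimal sketch of combining these options with the auto-sharding API is shown below. It assumes the variables are read when the XLA runtime initializes (so they are set before `torch_xla` is imported) and uses a toy `nn.Linear` model purely for illustration.

```python
import os

# Assumption: these are read at XLA runtime initialization, so set them early.
os.environ["XLA_AUTO_USE_GROUP_SHARDING"] = "1"  # group re-sharding of parameters
os.environ["XLA_AUTO_SPMD_MESH"] = "2,2"         # 2-by-2 logical mesh over 4 global devices

import torch
from torch import nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.use_spmd(auto=True)  # enable SPMD with the XLA auto-sharding pass

model = nn.Linear(128, 64).to(xm.xla_device())
data = torch.randn(32, 128).to(xm.xla_device())
output = model(data)
xm.mark_step()  # compile; parameter/activation shardings are chosen automatically
```
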
2 changes: 2 additions & 0 deletions test/run_tests.sh
@@ -226,6 +226,8 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
run_test "$CDIR/spmd/test_dtensor_integration.py"
run_test "$CDIR/spmd/test_dtensor_integration2.py"
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
1 change: 1 addition & 0 deletions test/spmd/args_parse.py
@@ -39,6 +39,7 @@ def parse_common_options(datadir=None,
parser.add_argument('--async_closures', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--auto_spmd', action='store_true')
if opts:
for name, aopts in opts:
parser.add_argument(name, **aopts)
73 changes: 38 additions & 35 deletions test/spmd/test_dtensor_integration.py
@@ -4,11 +4,13 @@
import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import DeviceMesh, Shard
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
distribute_module)
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import xla_distribute_tensor, xla_distribute_module
from torch_xla.distributed.spmd import auto_policy

import unittest

@@ -19,7 +21,6 @@ class DTensorIntegrationTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_xla_distribute_tensor(self):
@@ -33,8 +34,7 @@ def test_xla_distribute_tensor(self):
3,
requires_grad=requires_grad,
device=xm.xla_device())
dist_tensor = xla_distribute_tensor(tensor_to_shard, device_mesh,
shard_spec)
dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
# TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
assert type(dist_tensor).__name__ == "XLAShardedTensor"
assert len(dist_tensor.sharding_spec) > 0
@@ -47,65 +47,68 @@ def test_xla_distribute_tensor(self):
self.assertTrue(dist_tensor.global_tensor.requires_grad)
self.assertTrue(dist_tensor.is_leaf)

def test_xla_distribute_module(self):
def test_optimizer_step_with_sharding(self):
# Use simple linear model to test model parameter sharding
model = self.SimpleLinear().to(xm.xla_device())

# Running the same mark_sharding test with xla_distribute_tensor instead
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]
distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)

def shard_params(mod_name, mod, mesh):
shard_spec = [Shard(0)]
# annotate fc1 and fc2
if isinstance(mod, nn.Linear):
for name, param in mod.named_parameters():
dist_param = xla_distribute_tensor(param, mesh, shard_spec)

sharded_model = xla_distribute_module(model, device_mesh, shard_params)
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != "")
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != "")

sharded_model.train()
model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(3):
for _ in range(3):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Sharding persists across mark_step calls; test that the sharded computation
# can repeat more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))

def test_optimizer_step_with_sharding(self):
# Use simple linear model to test model parameter sharding
def test_xla_distribute_module(self):
model = self.SimpleLinear().to(xm.xla_device())

# Running the same mark_sharding test with xla_distribute_tensor instead
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]
xla_distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)

model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
def shard_params(mod_name, mod, mesh):
shard_spec = [Shard(0)]
# annotate fc1 and fc2
if isinstance(mod, nn.Linear):
for name, param in mod.named_parameters():
dist_param = distribute_tensor(param, mesh, shard_spec)

sharded_model = distribute_module(model, device_mesh, shard_params)
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != "")
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != "")

sharded_model.train()
optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(3):
for _ in range(3):
optimizer.zero_grad()
output = model(data)
output = sharded_model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Sharding persists across mark_step calls; test that the sharded computation
# can repeat more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
# Should run in SPMD mode (ExecuteReplicated).
self.assertTrue(met.counter_value("ExecuteReplicated") > 0)
self.assertTrue(met.counter_value("ExecuteComputation") is None)


if __name__ == '__main__':
58 changes: 58 additions & 0 deletions test/spmd/test_dtensor_integration2.py
@@ -0,0 +1,58 @@
import os
import sys

import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
distribute_module)
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import auto_policy

import unittest

import test_xla_sharding_base


# This integration test passes when run independently.
class DTensorIntegrationTest2(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
super().setUpClass()

@unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
"Auto-sharding currently supports TPU device.")
def test_xla_distribute_module_auto(self):
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Use torch_xla.distributed.spmd.auto_policy to enable auto-sharding.
# Currently, the model should be loaded onto the XLA device via distribute_module.
model = self.SimpleLinear()
sharded_model = distribute_module(model, device_mesh, auto_policy)
sharded_model.train()
self.assertTrue(torch_xla._XLAC._xla_get_auto_sharding())

optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for _ in range(5):
optimizer.zero_grad()
output = sharded_model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Should compile with auto-sharding; we expect at most 3 compilations.
cnt = met.counter_value("CompileWithAutoSharding")
self.assertTrue((cnt is not None) and (cnt <= 3))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 0 additions & 1 deletion test/spmd/test_dynamo_spmd.py
@@ -41,7 +41,6 @@ class DynamoSpmdInferenceTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_dynamo_spmd_basic(self):
1 change: 0 additions & 1 deletion test/spmd/test_spmd_graph_dump.py
@@ -18,7 +18,6 @@ class SpmdGraphDumpTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_dump_with_output_sharding(self):
2 changes: 2 additions & 0 deletions test/spmd/test_train_spmd_imagenet.py
@@ -86,6 +86,8 @@
import torch_xla.test.test_utils as test_utils
import torch_xla.distributed.spmd as xs

xr.use_spmd(auto=FLAGS.auto_spmd)

DEFAULT_KWARGS = dict(
batch_size=128,
test_set_batch_size=64,
2 changes: 2 additions & 0 deletions test/spmd/test_train_spmd_linear_model.py
@@ -36,6 +36,8 @@
FLAGS = args_parse.parse_common_options(
batch_size=128, num_epochs=1, opts=MODEL_OPTS.items())

xr.use_spmd(auto=FLAGS.auto_spmd)


class SimpleLinear(nn.Module):

74 changes: 74 additions & 0 deletions test/spmd/test_xla_auto_sharding.py
@@ -0,0 +1,74 @@
import copy

import unittest
from unittest.mock import patch
import math
import numpy as np
import os
import sys

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
import test_xla_sharding_base

import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu
from torch_xla._internal import tpu


class XlaAutoShardingTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd(auto=True)
super().setUpClass()

@unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
"Auto-sharding currently supports TPU & CPU backends.")
def test_matmul(self):
met.clear_counters()
t1 = torch.ones(64, 128)
t2 = torch.ones(128, 256)
t3 = (t1 @ t2).sum()

xt1 = t1.to(xm.xla_device())
xt2 = t2.to(xm.xla_device())
xt3 = (xt1 @ xt2).sum()
xm.mark_step()
self.assertEqual(met.counter_value("CompileWithAutoSharding"), 1)
self.assertTrue(torch.allclose(t3, xt3.cpu()))

@unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
"Auto-sharding currently supports TPU & CPU backends.")
def test_simple_linear_training(self):
met.clear_counters()

model = self.SimpleLinear().to(xm.xla_device())
model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(5):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
cnt = met.counter_value("CompileWithAutoSharding")
self.assertTrue((cnt is not None) and (cnt <= 3))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 0 additions & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
@@ -41,7 +41,6 @@ class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def _get_sharded_model(self, mesh_shape=None):
3 changes: 0 additions & 3 deletions test/spmd/test_xla_sharding.py
@@ -28,7 +28,6 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_xla_sharded_tensor(self):
@@ -38,8 +37,6 @@ def test_xla_sharded_tensor(self):
device=xm.xla_device())
xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)),
partition_spec)

# TODO(244003536) add more tests for XLAShardedTensror.
self.assertTrue(isinstance(xst1, XLAShardedTensor))

def test_xla_sharded_tensor_repr(self):
8 changes: 8 additions & 0 deletions test/spmd/test_xla_sharding_base.py
@@ -1,3 +1,4 @@
import os
import unittest
import numpy as np

@@ -31,6 +32,13 @@ def forward(self, x):
def setUpClass(cls):
cls.n_devices = xr.global_runtime_device_count()
cls.device_ids = np.array(range(cls.n_devices))
xr.use_spmd()

@classmethod
def tearDownClass(cls):
del os.environ['XLA_USE_SPMD']
if 'XLA_AUTO_SPMD' in os.environ:
del os.environ['XLA_AUTO_SPMD']

def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None):
assert type(mesh_shape) is tuple, 'mesh_shape must be Tuple[int]'