Showing 33 changed files with 667 additions and 101 deletions.
@@ -0,0 +1,58 @@
import os
import sys

import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
                                       distribute_module)
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import auto_policy

import unittest

import test_xla_sharding_base


# This integration test passes when run independently.
class DTensorIntegrationTest2(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
                       "Auto-sharding currently supports TPU & CPU backends.")
  def test_xla_distribute_module_auto(self):
    device_count = xr.global_runtime_device_count()
    device_mesh = DeviceMesh("xla", list(range(device_count)))

    # Use torch_xla.distributed.spmd.auto_policy to enable auto-sharding;
    # currently, the model must be moved to the XLA device via distribute_module.
    model = self.SimpleLinear()
    sharded_model = distribute_module(model, device_mesh, auto_policy)
    sharded_model.train()
    self.assertTrue(torch_xla._XLAC._xla_get_auto_sharding())

    optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
    data = torch.randn(128, 128).to(xm.xla_device())
    target = torch.zeros(128).to(xm.xla_device())
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(5):
      optimizer.zero_grad()
      output = sharded_model(data)
      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()
      xm.mark_step()
    # Should compile with auto-sharding; we expect at most 3 compilations.
    cnt = met.counter_value("CompileWithAutoSharding")
    self.assertTrue((cnt is not None) and (cnt <= 3))

if __name__ == '__main__':
  test = unittest.main(exit=False)  # exit=False so the result can be inspected below.
  sys.exit(0 if test.result.wasSuccessful() else 1)
@@ -0,0 +1,74 @@
import copy

import unittest
from unittest.mock import patch
import math
import numpy as np
import os
import sys

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
import test_xla_sharding_base

import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu
from torch_xla._internal import tpu


class XlaAutoShardingTest(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    xr.use_spmd(auto=True)
    super().setUpClass()

  @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
                       "Auto-sharding currently supports TPU & CPU backends.")
  def test_matmul(self):
    met.clear_counters()
    t1 = torch.ones(64, 128)
    t2 = torch.ones(128, 256)
    t3 = (t1 @ t2).sum()

    xt1 = t1.to(xm.xla_device())
    xt2 = t2.to(xm.xla_device())
    xt3 = (xt1 @ xt2).sum()
    xm.mark_step()
    self.assertEqual(met.counter_value("CompileWithAutoSharding"), 1)
    self.assertTrue(torch.allclose(t3, xt3.cpu()))

  @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
                       "Auto-sharding currently supports TPU & CPU backends.")
  def test_simple_linear_training(self):
    met.clear_counters()

    model = self.SimpleLinear().to(xm.xla_device())
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    data = torch.randn(128, 128).to(xm.xla_device())
    target = torch.zeros(128).to(xm.xla_device())
    loss_fn = nn.CrossEntropyLoss()
    for i in range(5):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()
      xm.mark_step()
    cnt = met.counter_value("CompileWithAutoSharding")
    self.assertTrue((cnt is not None) and (cnt <= 3))

if __name__ == '__main__':
  test = unittest.main(exit=False)  # exit=False so the result can be inspected below.
  sys.exit(0 if test.result.wasSuccessful() else 1)
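
A minimal sketch of the workflow the XlaAutoShardingTest cases exercise, outside the unittest harness. It assumes a TPU or CPU XLA backend, and substitutes a stand-in nn.Sequential model with long-dtype class targets for the tests' shared SimpleLinear fixture and float targets; only APIs already used in the files above appear here.

import torch
from torch import nn, optim
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr

# Enable SPMD with auto-sharding before creating any XLA tensors,
# mirroring setUpClass in XlaAutoShardingTest.
xr.use_spmd(auto=True)

device = xm.xla_device()
# Stand-in model (the tests use their SimpleLinear fixture instead).
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)).to(device)
model.train()

optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
data = torch.randn(128, 128).to(device)
target = torch.zeros(128, dtype=torch.long).to(device)

for _ in range(5):
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  optimizer.step()
  xm.mark_step()  # Cut the graph so XLA compiles (and auto-shards) this step.

# The tests assert this counter is set and stays small (<= 3 compilations).
print(met.counter_value("CompileWithAutoSharding"))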