Implement mark_sharding as a custom op to support dynamo spmd activation sharding #5712
@@ -15,16 +15,22 @@
 class SimpleLinear(nn.Module):

-  def __init__(self):
+  def __init__(self, mesh=None):
     super(SimpleLinear, self).__init__()
     self.fc1 = nn.Linear(128, 128)
     self.relu = nn.ReLU()
     self.fc2 = nn.Linear(128, 1)
     # Add an additional 1x1 layer at the end to ensure the final layer
     # is not sharded.
     self.fc3 = nn.Linear(1, 1)
+    # If mesh is not none, we'll do a mark sharding inside the forward function
+    # to ensure dynamo can recognize and trace it in a torch compile.
+    self.mesh = mesh

   def forward(self, x):
+    if self.mesh and 'xla' in str(self.fc2.weight.device):
+      xs.mark_sharding(
+          self.fc2.weight, self.mesh, (1, 0), dynamo_custom_op=True)
     y = self.relu(self.fc1(x))
     z = self.fc2(y)
     return self.fc3(z)
@@ -53,6 +59,22 @@ def test_dynamo_spmd_basic(self):
     # TODO(JackCaoG): add counter checks after ExecuteReplicated also creates
     # a ExecuteMetric.

+  def test_dynamo_spmd_basic_with_custom_mark_sharding_op(self):
+    device = xm.xla_device()
+    linear = SimpleLinear().to(device)
+    linear.eval()
+    xla_x = torch.randn(1, 128, device=device)
+    xs.mark_sharding(
+        linear.fc2.weight,
+        self._get_mesh((1, self.n_devices)), (1, 0),
+        dynamo_custom_op=True)
+    xla_res = linear(xla_x)
+    xm.mark_step()
+
+    dynamo_linear = torch.compile(linear, backend="openxla")
+    dynamo_res = dynamo_linear(xla_x)
+    torch.allclose(xla_res.cpu(), dynamo_res.cpu())
Comment: Can we add a counter check? We want to make sure we are not recompiling across different runs. You can either add it here or add a separate test.

Reply: Updated.

   def test_dynamo_spmd_output_sharding_spec(self):
     device = xm.xla_device()
     linear = SimpleLinear().to(device)
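A rough sketch of the counter check asked for above, assuming the torch_xla.debug.metrics helpers and that CompileTime is the metric to watch; the counter used in the final test may differ.

    # Hypothetical no-recompile check (not part of this diff): run the compiled
    # module twice and assert the compile metric did not move between runs.
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    def assert_no_recompile(dynamo_linear, device):
      dynamo_linear(torch.randn(1, 128, device=device))  # first run compiles
      xm.wait_device_ops()
      compile_count = met.metric_data('CompileTime')[0]

      dynamo_linear(torch.randn(1, 128, device=device))  # should reuse the graph
      xm.wait_device_ops()
      assert met.metric_data('CompileTime')[0] == compile_count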
@@ -171,6 +193,23 @@ def test_dynamo_input_sharding_threashold(self):
     else:
       del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD']

+  def test_mark_sharding_inside_compile(self):
+    device = xm.xla_device()
+    mesh = self._get_mesh((1, self.n_devices))
+
+    # Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo
+    # custom op variant of mark_sharding inside the forward function.
+    linear = SimpleLinear(mesh=mesh).to(device)
+    linear.eval()
+
+    xla_x = torch.randn(1, 128, device=device)
+    xla_res = linear(xla_x)
+    xm.mark_step()
+
+    dynamo_linear = torch.compile(linear, backend="openxla")
+    dynamo_res = dynamo_linear(xla_x)
+    torch.allclose(xla_res.cpu(), dynamo_res.cpu())
+

 if __name__ == '__main__':
   test = unittest.main()
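For orientation, a simplified sketch of how the dynamo_custom_op=True flag in xs.mark_sharding could route to the two bindings added in this PR; the real wrapper lives in torch_xla.experimental.xla_sharding, and _partition_spec_to_op_sharding_args / _make_op_sharding below are hypothetical helpers, not actual APIs.

    # Illustrative only -- Python-side dispatch between the eager binding and
    # the dynamo-traceable custom op, under the assumptions stated above.
    import torch_xla

    def mark_sharding_sketch(t, mesh, partition_spec, dynamo_custom_op=False):
      if dynamo_custom_op:
        # Traceable path: plain nested int lists that dynamo can carry in the
        # captured graph, handed to the custom op registered in this PR.
        tile, group, replication, sharding_type = (
            _partition_spec_to_op_sharding_args(mesh, partition_spec))
        torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
            t, tile, group, replication, sharding_type)
      else:
        # Eager/lazy path: build an xla::OpSharding on the C++ side and attach
        # it directly to the XLA tensor.
        op_sharding = _make_op_sharding(mesh, partition_spec)
        torch_xla._XLAC._xla_mark_sharding(t, op_sharding)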
@@ -29,6 +29,7 @@
 #include "pybind11/pytypes.h"
 #include "pybind11/stl_bind.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
+#include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/helpers.h"
@@ -651,6 +652,136 @@ std::string GetPyTypeString(py::handle obj) {
   return type;
 }

+void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) {
+  TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
+  XLA_CHECK(UseVirtualDevice())
+      << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
+  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
+      sharding, MakeShapeWithDeviceLayout(
+                    xtensor->shape(),
+                    static_cast<XlaDeviceType>(xtensor->GetDevice().type())));
+
+  // For Non DeviceData IR values, we directly attach the sharding spec
+  // to the xtensor.
+  const DeviceData* device_data_node = nullptr;
+  if (xtensor->CurrentIrValue()) {
+    device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
+    if (!device_data_node) {
+      tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
+      return;
+    }
+  }
+
+  // For data, we need to deal with the data transfers between
+  // host and device.
+  at::Tensor cpu_tensor;
+  if (xtensor->CurrentTensorData().has_value()) {
+    TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
+    // When virtual device is enabled for SPMD, we defer the initial
+    // data transfer to the device and retain the original data on the
+    // host, until the sharded data transfer.
+    cpu_tensor = xtensor->CurrentTensorData().value();
+  } else {
+    // A new input tensor is not expected to be sharded. But sometimes,
+    // the same input is called for sharding annotation over multiple steps,
+    // in which case we can skip if it's the same sharding; however, if it's
+    // the same input with a different sharding then we block & ask the user
+    // to clear the existing sharding first.
+    auto current_sharding_spec = xtensor->sharding_spec();
+    if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
+                                  xla::OpSharding::REPLICATED)) {
+      XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
+                                                 *current_sharding_spec))
+          << "Existing annotation must be cleared first.";
+      return;
+    }
+
+    // If the at::Tensor data is not present, we need to re-download the
+    // tensor from the physical device to CPU. In that case, the value
+    // must be present on the backend device.
+    XLA_CHECK((xtensor->CurrentDataHandle() &&
+               xtensor->CurrentDataHandle()->HasValue()) ||
+              device_data_node != nullptr)
+        << "Cannot shard tensor. Data does not present on any device.";
+    std::vector<XLATensorPtr> xla_tensors{xtensor};
+    cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
+  }
+  auto xla_data = CreateTensorsData(
+      std::vector<at::Tensor>{cpu_tensor},
+      std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
+      std::vector<std::string>{GetVirtualDevice().toString()})[0];
+  xtensor->SetXlaData(xla_data);
+  xtensor->SetShardingSpec(*new_sharding_spec);
+
+  // Register sharded tensor data.
+  XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
+}
+
+void xla_mark_sharding_dynamo_custom_op(
+    const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
+    c10::List<at::IntArrayRef> group_assignment,
+    c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
+  py::list tile_assignment_py = py::list();
+  for (int i = 0; i < tile_assignment.size(); i++) {
+    py::list pylist = py::list();
+    for (int64_t t : tile_assignment[i].get().toIntList()) {
+      pylist.append(t);
+    }
+    tile_assignment_py.append(pylist);
+  }
+
+  py::list group_assignment_py = py::list();
+  for (int i = 0; i < group_assignment.size(); i++) {
+    py::list pylist = py::list();
+    for (int64_t t : group_assignment[i].get().toIntList()) {
+      pylist.append(t);
+    }
+    group_assignment_py.append(pylist);
+  }
+
+  py::list replication_groups_py = py::list();
+  for (int i = 0; i < replication_groups.size(); i++) {
+    py::list pylist = py::list();
+    for (int64_t t : replication_groups[i].get().toIntList()) {
+      pylist.append(t);
+    }
+    replication_groups_py.append(pylist);
+  }
+
+  xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
+      tile_assignment_py, group_assignment_py, replication_groups_py,
+      ShardingUtil::ShardingType(sharding_type));
+
+  xla_mark_sharding(input, op_sharding);
+}
+
+// Macro for defining a function that will be run at static initialization time
+// to define a library of operators in the namespace. Used to define a new set
+// of custom operators that do not already exist in PyTorch.
+TORCH_LIBRARY(xla, m) {
+  m.def(
+      "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
+      "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
+      torch::dispatch(
+          c10::DispatchKey::XLA,
+          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));
+
+  m.def(
+      "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
+      "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
+      "-> Tensor",
+      torch::dispatch(
+          c10::DispatchKey::XLA,
+          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
+  m.def(
+      "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
+      "tile_assignment, int[][] group_assignment, int[][] replication_groups, "
+      "int sharding_type) -> ()",
+      torch::dispatch(c10::DispatchKey::XLA,
+                      TORCH_FN(xla_mark_sharding_dynamo_custom_op)));
Comment: Just a random question, will pytorch try to run backward on this custom op?

Reply: According to https://pytorch.org/docs/stable/notes/extending.html, it looks like we need to implement and override a specific autograd `Function` if we want a backward for this custom op.

+}
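The extending notes referenced in the reply describe wrapping an op in a torch.autograd.Function when gradients need to flow through it. Below is a minimal, purely hypothetical sketch of that pattern; the sharding op registered here returns nothing, so no backward is actually run for it.

    # Hypothetical example of attaching a backward to a custom op, following
    # pytorch.org/docs/stable/notes/extending.html. The doubling forward is a
    # placeholder, not an op defined in this PR.
    import torch

    class MyCustomOpWithGrad(torch.autograd.Function):

      @staticmethod
      def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2  # placeholder for a call into the custom op

      @staticmethod
      def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2  # gradient of the placeholder forward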
+
 std::vector<bool> check_materialization_helper(
     const std::vector<XLATensorPtr>& xtensors) {
   std::vector<bool> need_materialization;
@@ -1559,72 +1690,39 @@ void InitXlaModuleBindings(py::module m) {
               tile_assignment, group_assignment, replication_groups,
               ShardingUtil::ShardingType(sharding_type));
         }));
-  m.def("_xla_mark_sharding", [](const at::Tensor& input,
-                                 xla::OpSharding sharding) {
-    TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
-    XLA_CHECK(UseVirtualDevice())
-        << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
-    XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-    auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
-        sharding, MakeShapeWithDeviceLayout(
-                      xtensor->shape(),
-                      static_cast<XlaDeviceType>(xtensor->GetDevice().type())));
-
-    // For Non DeviceData IR values, we directly attach the sharding spec
-    // to the xtensor.
-    const DeviceData* device_data_node = nullptr;
-    if (xtensor->CurrentIrValue()) {
-      device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
-      if (!device_data_node) {
-        tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
-        return;
-      }
-    }
-
-    // For data, we need to deal with the data transfers between
-    // host and device.
-    at::Tensor cpu_tensor;
-    if (xtensor->CurrentTensorData().has_value()) {
-      TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
-      // When virtual device is enabled for SPMD, we defer the initial
-      // data transfer to the device and retain the original data on the
-      // host, until the sharded data transfer.
-      cpu_tensor = xtensor->CurrentTensorData().value();
-    } else {
-      // A new input tensor is not expected to be sharded. But sometimes,
-      // the same input is called for sharding annotation over multiple steps,
-      // in which case we can skip if it's the same sharding; however, if it's
-      // the same input with a different sharding then we block & ask the user
-      // to clear the existing sharding first.
-      auto current_sharding_spec = xtensor->sharding_spec();
-      if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
-                                    xla::OpSharding::REPLICATED)) {
-        XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
-                                                   *current_sharding_spec))
-            << "Existing annotation must be cleared first.";
-        return;
-      }
-
-      // If the at::Tensor data is not present, we need to re-download the
-      // tensor from the physical device to CPU. In that case, the value
-      // must be present on the backend device.
-      XLA_CHECK((xtensor->CurrentDataHandle() &&
-                 xtensor->CurrentDataHandle()->HasValue()) ||
-                device_data_node != nullptr)
-          << "Cannot shard tensor. Data does not present on any device.";
-      std::vector<XLATensorPtr> xla_tensors{xtensor};
-      cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
-    }
-    auto xla_data = CreateTensorsData(
-        std::vector<at::Tensor>{cpu_tensor},
-        std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
-        std::vector<std::string>{GetVirtualDevice().toString()})[0];
-    xtensor->SetXlaData(xla_data);
-    xtensor->SetShardingSpec(*new_sharding_spec);
-
-    // Register sharded tensor data.
-    XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
-  });
+  m.def("_xla_mark_sharding",
+        [](const at::Tensor& input, xla::OpSharding sharding) {
+          xla_mark_sharding(input, sharding);
+        });
+  m.def("_xla_mark_sharding_dynamo_custom_op",
+        [](const at::Tensor& input, const py::list& tile_assignment,
+           const py::list& group_assignment, const py::list& replication_groups,
+           int sharding_type) {
+          c10::List<at::IntArrayRef> time_assignment_list =
+              c10::List<at::IntArrayRef>();
+          for (auto t : tile_assignment) {
+            time_assignment_list.push_back(
+                at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          c10::List<at::IntArrayRef> group_assignment_list =
+              c10::List<at::IntArrayRef>();
+          for (auto t : group_assignment) {
+            group_assignment_list.push_back(
+                at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          c10::List<at::IntArrayRef> replication_groups_list =
+              c10::List<at::IntArrayRef>();
+          for (auto t : replication_groups) {
+            replication_groups_list.push_back(
+                at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          xla_mark_sharding_dynamo_custom_op(
+              input, time_assignment_list, group_assignment_list,
+              replication_groups_list, sharding_type);
+        });
   m.def("_xla_clear_sharding", [](const at::Tensor& input) {
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
     xtensor->ClearShardingSpec();
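Taken together with the TORCH_LIBRARY registration above, the op also becomes callable from Python through the torch.ops namespace; a hedged sketch of invoking it directly is below. The argument values are assumptions for illustration; in the tests above they come from xs.Mesh and the partition spec rather than being written by hand.

    # Illustrative only: calling the dynamo-visible custom op registered under
    # TORCH_LIBRARY(xla, ...). Inside torch.compile this call is captured in
    # the FX graph instead of being executed only at trace time.
    import torch
    import torch_xla  # loads the extension that registers the op

    def annotate_for_dynamo(t, tile_assignment, group_assignment,
                            replication_groups, sharding_type):
      torch.ops.xla.xla_mark_sharding_dynamo_custom_op(
          t, tile_assignment, group_assignment, replication_groups,
          sharding_type)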
Comment: @yeounoh, I updated the unit test for doing mark_sharding inside torch.compile to be part of the existing SimpleLinear. Here, I just do a mark_sharding call inside the forward function. Please let me know if you think this will suffice.

Reply: SG