Implement mark_sharding as a custom op to support dynamo spmd activation sharding #5712

Merged (14 commits) on Nov 11, 2023
41 changes: 40 additions & 1 deletion test/spmd/test_dynamo_spmd.py
@@ -15,16 +15,22 @@

class SimpleLinear(nn.Module):

def __init__(self):
def __init__(self, mesh=None):
super(SimpleLinear, self).__init__()
self.fc1 = nn.Linear(128, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 1)
# Add an additional 1x1 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(1, 1)
# If mesh is not None, we'll call mark_sharding inside the forward function
# to ensure dynamo can recognize and trace it under torch.compile.
self.mesh = mesh

def forward(self, x):
if self.mesh and 'xla' in str(self.fc2.weight.device):
Collaborator (Author):

@yeounoh, I updated the unit test for doing mark_sharding inside torch compile to be part of the existing SimpleLinear. Here, I just do a mark_sharding call inside the forward function. Please let me know if you think this will suffice.

Contributor:

SG

xs.mark_sharding(
self.fc2.weight, self.mesh, (1, 0), dynamo_custom_op=True)
y = self.relu(self.fc1(x))
z = self.fc2(y)
return self.fc3(z)
@@ -53,6 +59,22 @@ def test_dynamo_spmd_basic(self):
# TODO(JackCaoG): add counter checks after ExecuteReplicated also creates
# a ExecuteMetric.

def test_dynamo_spmd_basic_with_custom_mark_sharding_op(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
linear.eval()
xla_x = torch.randn(1, 128, device=device)
xs.mark_sharding(
linear.fc2.weight,
self._get_mesh((1, self.n_devices)), (1, 0),
dynamo_custom_op=True)
xla_res = linear(xla_x)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())
Collaborator:

Can we add a counter check? We want to make sure we are not recompiling across different runs. You can either add it here or add a separate test.

Collaborator (Author):

Updated
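
For reference, a minimal sketch of the kind of counter check being discussed, assuming torch_xla's `torch_xla.debug.metrics` module and its `CompileTime` metric; the helper name and exact assertion below are illustrative, not the code that landed in the test:

```python
# Hedged sketch: assert that a second dynamo run does not trigger a fresh
# XLA compilation. Assumes torch_xla.debug.metrics and the 'CompileTime'
# metric; the exact metric used in the final test may differ.
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def check_no_recompile(dynamo_linear, xla_x):
  met.clear_all()
  dynamo_linear(xla_x)  # first run may compile
  xm.wait_device_ops()
  compile_count = met.metric_data('CompileTime')[0]
  dynamo_linear(xla_x)  # second run should reuse the cached computation
  xm.wait_device_ops()
  assert met.metric_data('CompileTime')[0] == compile_count
```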


def test_dynamo_spmd_output_sharding_spec(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
@@ -171,6 +193,23 @@ def test_dynamo_input_sharding_threashold(self):
else:
del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD']

def test_mark_sharding_inside_compile(self):
device = xm.xla_device()
mesh = self._get_mesh((1, self.n_devices))

# Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo custom op
# variant of mark_sharding inside the forward function.
linear = SimpleLinear(mesh=mesh).to(device)
linear.eval()

xla_x = torch.randn(1, 128, device=device)
xla_res = linear(xla_x)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())


if __name__ == '__main__':
test = unittest.main()
12 changes: 0 additions & 12 deletions torch_xla/csrc/aten_autograd_ops.cpp
@@ -253,17 +253,5 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
return grad;
}

TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
}
} // namespace aten_autograd_ops
} // namespace torch_xla
11 changes: 11 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.h
@@ -46,6 +46,17 @@ struct MaxPool3dAutogradFunction
torch::autograd::variable_list grad_output);
};

torch::Tensor max_pool2d_forward(torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding,
torch::IntArrayRef dilation, bool ceil_mode);

torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding, bool ceil_mode);

} // namespace aten_autograd_ops
} // namespace torch_xla

224 changes: 161 additions & 63 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -29,6 +29,7 @@
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/helpers.h"
@@ -651,6 +652,136 @@ std::string GetPyTypeString(py::handle obj) {
return type;
}

void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
py::list tile_assignment_py = py::list();
for (int i = 0; i < tile_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : tile_assignment[i].get().toIntList()) {
pylist.append(t);
}
tile_assignment_py.append(pylist);
}

py::list group_assignment_py = py::list();
for (int i = 0; i < group_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : group_assignment[i].get().toIntList()) {
pylist.append(t);
}
group_assignment_py.append(pylist);
}

py::list replication_groups_py = py::list();
for (int i = 0; i < replication_groups.size(); i++) {
py::list pylist = py::list();
for (int64_t t : replication_groups[i].get().toIntList()) {
pylist.append(t);
}
replication_groups_py.append(pylist);
}

xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
tile_assignment_py, group_assignment_py, replication_groups_py,
ShardingUtil::ShardingType(sharding_type));

xla_mark_sharding(input, op_sharding);
}

// Macro for defining a function that will be run at static initialization time
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
"int sharding_type) -> ()",
torch::dispatch(c10::DispatchKey::XLA,
TORCH_FN(xla_mark_sharding_dynamo_custom_op)));
Collaborator:

Just a random question, will PyTorch try to run backward on this custom op?

Collaborator (Author):

According to https://pytorch.org/docs/stable/notes/extending.html, it looks like we need to implement and override a specific backward function. So in this case, PyTorch should not try to run backward on this custom op.
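
For context, a minimal sketch of what that means in practice, reusing `SimpleLinear` from the test file above; the `mesh` value and the `torch_xla.experimental.xla_sharding` import path are assumptions. Because the custom op returns no tensor, autograd has nothing to differentiate through it, and gradients flow through the ordinary linear/ReLU ops as usual:

```python
# Hedged sketch: the custom mark_sharding op only annotates fc2.weight's
# sharding as a side effect and returns (), so backward never runs on it.
# `SimpleLinear` is the test module above; `mesh` and the xs import path
# are assumed (e.g. mesh built via xs.Mesh over the available devices).
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

device = xm.xla_device()
model = SimpleLinear().to(device).train()
xs.mark_sharding(model.fc2.weight, mesh, (1, 0), dynamo_custom_op=True)

out = model(torch.randn(1, 128, device=device))
out.sum().backward()
xm.mark_step()
assert model.fc2.weight.grad is not None  # gradients unaffected by the annotation
```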

}

std::vector<bool> check_materialization_helper(
const std::vector<XLATensorPtr>& xtensors) {
std::vector<bool> need_materialization;
@@ -1559,72 +1690,39 @@ void InitXlaModuleBindings(py::module m) {
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type));
}));
m.def("_xla_mark_sharding", [](const at::Tensor& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}
m.def("_xla_mark_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
xla_mark_sharding(input, sharding);
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, const py::list& tile_assignment,
const py::list& group_assignment, const py::list& replication_groups,
int sharding_type) {
c10::List<at::IntArrayRef> tile_assignment_list =
c10::List<at::IntArrayRef>();
for (auto t : tile_assignment) {
tile_assignment_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}
c10::List<at::IntArrayRef> group_assignment_list =
c10::List<at::IntArrayRef>();
for (auto t : group_assignment) {
group_assignment_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);
c10::List<at::IntArrayRef> replication_groups_list =
c10::List<at::IntArrayRef>();
for (auto t : replication_groups) {
replication_groups_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
});
xla_mark_sharding_dynamo_custom_op(
input, tile_assignment_list, group_assignment_list,
replication_groups_list, sharding_type);
});
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->ClearShardingSpec();