Commit a98bfb2

Update unit test and run linter

wonjoolee95 committed Nov 7, 2023
1 parent 9aaa533 commit a98bfb2
Showing 9 changed files with 153 additions and 170 deletions.
40 changes: 20 additions & 20 deletions test/spmd/test_dynamo_spmd.py
@@ -15,19 +15,22 @@

class SimpleLinear(nn.Module):

def __init__(self, mark_sharding_inside = False, op_sharding = None):
def __init__(self, mesh=None):
super(SimpleLinear, self).__init__()
self.fc1 = nn.Linear(128, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 1)
# Add an additional 1x1 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(1, 1)
# If mesh is not None, we'll do a mark_sharding inside the forward function
# to ensure dynamo can recognize and trace it under torch.compile.
self.mesh = mesh

def forward(self, x):
print(f'self.fc2.weight.device={self.fc2.weight.device}')
if self.mark_sharding_inside and self.op_sharding and 'xla' in self.fc2.weight.device:
xs.mark_sharding(self.fc2.weight, self.op_sharding)
if self.mesh and 'xla' in str(self.fc2.weight.device):
xs.mark_sharding(
self.fc2.weight, self.mesh, (1, 0), dynamo_custom_op=True)
y = self.relu(self.fc1(x))
z = self.fc2(y)
return self.fc3(z)
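
For context, the tests in this file build their mesh through a _get_mesh helper that is not shown in the hunk above. Below is a minimal sketch of constructing an equivalent mesh and exercising the in-forward sharding path under torch.compile, reusing the SimpleLinear class defined above; the exact import path for the `xs` alias and the Mesh constructor arguments are assumptions, not taken from this diff.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
# Module path is an assumption; the test file imports the SPMD sharding module under the alias `xs`.
import torch_xla.experimental.xla_sharding as xs

xr.use_spmd()  # SPMD must be enabled before any mark_sharding call (see the C++ check further down)
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))  # mirrors _get_mesh((1, n_devices))

model = SimpleLinear(mesh=mesh).to(xm.xla_device())
model.eval()
compiled = torch.compile(model, backend="openxla")
out = compiled(torch.randn(1, 128, device=xm.xla_device()))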
@@ -61,9 +64,10 @@ def test_dynamo_spmd_basic_with_custom_mark_sharding_op(self):
linear = SimpleLinear().to(device)
linear.eval()
xla_x = torch.randn(1, 128, device=device)
xs.mark_sharding_dynamo_custom_op(linear.fc2.weight,
self._get_mesh((1, self.n_devices)),
(1, 0))
xs.mark_sharding(
linear.fc2.weight,
self._get_mesh((1, self.n_devices)), (1, 0),
dynamo_custom_op=True)
xla_res = linear(xla_x)
xm.mark_step()

@@ -191,23 +195,19 @@ def test_dynamo_input_sharding_threashold(self):

def test_mark_sharding_inside_compile(self):
device = xm.xla_device()
mesh = self._get_mesh((1, self.n_devices))

def fn_simple(t):
xs.mark_sharding_dynamo_custom_op(
t, self._get_mesh((1, self.n_devices)), (0, 1))

x = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
dtype=torch.float,
device=device)

return t + x
# Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo custom op
# variant of mark_sharding inside the forward function.
linear = SimpleLinear(mesh=mesh).to(device)
linear.eval()

x_xla = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]]).to(device)
xla_res = fn_simple(x_xla)
xla_x = torch.randn(1, 128, device=device)
xla_res = linear(xla_x)
xm.mark_step()

dynamo_fn_simple = torch.compile(fn_simple, backend="openxla")
dynamo_res = dynamo_fn_simple(x_xla)
dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())
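
One way to confirm that the annotation applied inside the compiled forward actually landed on fc2.weight is to read back its sharding spec. A hedged sketch, not part of the committed test, assuming the private _XLAC._get_xla_sharding_spec binding is available and returns an empty string for unsharded tensors:

import torch_xla

# Hypothetical verification step: a non-empty HLO sharding string means the weight was annotated.
spec = torch_xla._XLAC._get_xla_sharding_spec(linear.fc2.weight)
assert spec != '', f'expected fc2.weight to carry a sharding annotation, got {spec!r}'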


12 changes: 0 additions & 12 deletions torch_xla/csrc/aten_autograd_ops.cpp
@@ -253,17 +253,5 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
return grad;
}

// TORCH_LIBRARY(xla, m) {
// m.def(
// "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
// "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));

// m.def(
// "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
// "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
// "-> Tensor",
// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
// }
} // namespace aten_autograd_ops
} // namespace torch_xla
181 changes: 100 additions & 81 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -653,72 +653,75 @@ std::string GetPyTypeString(py::handle obj) {
}

void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}
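
Both entry points converge on this function: the pybind _xla_mark_sharding binding further down passes an already-built xla::OpSharding into it, while the dynamo custom op below first reconstructs that sharding from plain integer lists. A rough Python-side sketch of the two call modes, reusing the `xs`, `xm`, and `mesh` names from the sketch near the top of this page; the wrapper internals described in the comments are assumptions:

import torch

t = torch.randn(128, 128).to(xm.xla_device())

# Eager path: the Python wrapper builds the OpSharding from mesh + partition spec
# and is expected to end up in the C++ xla_mark_sharding above.
xs.mark_sharding(t, mesh, (0, 1))

# Compile-friendly path: the same wrapper with dynamo_custom_op=True routes through
# the torch.ops.xla custom op registered below, so dynamo can trace the annotation.
xs.mark_sharding(t, mesh, (0, 1), dynamo_custom_op=True)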

void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment, c10::List<at::IntArrayRef> group_assignment, c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
py::list tile_assignment_py = py::list();
for (int i = 0; i < tile_assignment.size(); i++) {
py::list pylist = py::list();
@@ -747,28 +750,36 @@ void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List<at::I
}

xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
tile_assignment_py, group_assignment_py, replication_groups_py,
ShardingUtil::ShardingType(sharding_type));
tile_assignment_py, group_assignment_py, replication_groups_py,
ShardingUtil::ShardingType(sharding_type));

xla_mark_sharding(input, op_sharding);
}

// Macro for defining a function that will be run at static initialization time to define a library of operators in the namespace.
// Used to define a new set of custom operators that do not already exist in PyTorch.
// Macro for defining a function that will be run at static initialization time
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] tile_assignment, int[][] group_assignment, int[][] replication_groups, int sharding_type) -> ()",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(xla_mark_sharding_dynamo_custom_op)));
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
"int sharding_type) -> ()",
torch::dispatch(c10::DispatchKey::XLA,
TORCH_FN(xla_mark_sharding_dynamo_custom_op)));
}
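
The schema above is what makes the annotation traceable: the op takes only a Tensor and nested int lists, so dynamo can capture it in its graph, unlike the pybind variant that takes an xla::OpSharding object. A sketch of calling the registered op directly follows; the list values and the sharding-type integer are illustrative assumptions, not values taken from this diff.

import torch
import torch_xla.core.xla_model as xm

t = torch.randn(8, 8).to(xm.xla_device())
# Arguments follow the int[][] schema registered in TORCH_LIBRARY(xla, ...) above.
torch.ops.xla.xla_mark_sharding_dynamo_custom_op(
    t,
    [[0], [1]],  # tile_assignment (illustrative for two devices)
    [[0, 1]],    # group_assignment (illustrative)
    [[0, 1]],    # replication_groups (illustrative)
    3)           # sharding_type (illustrative enum value; the mapping is not shown in this diff)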

std::vector<bool> check_materialization_helper(
@@ -1679,30 +1690,38 @@ void InitXlaModuleBindings(py::module m) {
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type));
}));
m.def("_xla_mark_sharding", [](const at::Tensor& input,
xla::OpSharding sharding) {
xla_mark_sharding(input, sharding);
});
m.def("_xla_mark_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
xla_mark_sharding(input, sharding);
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, const py::list& tile_assignment,
const py::list& group_assignment,
const py::list& replication_groups, int sharding_type) {
c10::List<at::IntArrayRef> time_assignment_list = c10::List<at::IntArrayRef>();
const py::list& group_assignment, const py::list& replication_groups,
int sharding_type) {
c10::List<at::IntArrayRef> time_assignment_list =
c10::List<at::IntArrayRef>();
for (auto t : tile_assignment) {
time_assignment_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
time_assignment_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

c10::List<at::IntArrayRef> group_assignment_list = c10::List<at::IntArrayRef>();
c10::List<at::IntArrayRef> group_assignment_list =
c10::List<at::IntArrayRef>();
for (auto t : group_assignment) {
group_assignment_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
group_assignment_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

c10::List<at::IntArrayRef> replication_groups_list = c10::List<at::IntArrayRef>();
c10::List<at::IntArrayRef> replication_groups_list =
c10::List<at::IntArrayRef>();
for (auto t : replication_groups) {
replication_groups_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
replication_groups_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

xla_mark_sharding_dynamo_custom_op(input, time_assignment_list, group_assignment_list, replication_groups_list, sharding_type);
xla_mark_sharding_dynamo_custom_op(
input, time_assignment_list, group_assignment_list,
replication_groups_list, sharding_type);
});
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
1 change: 0 additions & 1 deletion torch_xla/csrc/ops/xla_ops.cpp
@@ -30,6 +30,5 @@ const OpKindWrapper xla_tensor_data("xla::tensor_data");
const OpKindWrapper xla_unselect("xla::unselect");
const OpKindWrapper xla_update_slice("xla::update_slice");
const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
const OpKindWrapper xla_custom_mark_sharding("xla::custom_mark_sharding");

} // namespace torch_xla
1 change: 0 additions & 1 deletion torch_xla/csrc/ops/xla_ops.h
@@ -55,7 +55,6 @@ extern const OpKindWrapper xla_tensor_data;
extern const OpKindWrapper xla_unselect;
extern const OpKindWrapper xla_update_slice;
extern const OpKindWrapper xla_custom_sharding;
extern const OpKindWrapper xla_custom_mark_sharding;

} // namespace torch_xla

1 change: 0 additions & 1 deletion torch_xla/csrc/tensor_methods.cpp
@@ -140,7 +140,6 @@
#include "torch_xla/csrc/tensor_ops.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_graph_executor.h"
#include "torch_xla/csrc/xla_sharding_util.h"
#include "xla/literal_util.h"

namespace torch_xla {
7 changes: 0 additions & 7 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -1228,11 +1228,4 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {
{input}, ShapeHelper::ShapeOfXlaOp(input));
}

xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device,
const xla::XlaOp& input,
const xla::XlaOp& sharding) {
return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding",
{input}, ShapeHelper::ShapeOfXlaOp(input));
}

} // namespace torch_xla
4 changes: 0 additions & 4 deletions torch_xla/csrc/xla_lower_util.h
@@ -150,10 +150,6 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p,

xla::XlaOp BuildCustomSharding(const xla::XlaOp& input);

xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device,
const xla::XlaOp& input,
const xla::XlaOp& sharding);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_