
Commit

Clean up some code
wonjoolee95 committed Nov 2, 2023
1 parent ae05c9a commit 08d6296
Showing 10 changed files with 188 additions and 218 deletions.
1 change: 0 additions & 1 deletion .torch_pin

This file was deleted.

32 changes: 28 additions & 4 deletions test/spmd/test_dynamo_spmd.py
@@ -200,11 +200,35 @@ def fn_simple(x):
device = xm.xla_device()
x_xla = torch.zeros((1, 8)).to(device)
xla_res = fn_simple(x_xla)
xm.mark_step()
print(xla_res)
# xm.mark_step()

dynamo_linear = torch.compile(fn_simple, backend="openxla")
dynamo_res = dynamo_linear(x_xla)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())
# dynamo_linear = torch.compile(fn_simple, backend="openxla")
# dynamo_res = dynamo_linear(x_xla)
# torch.allclose(xla_res.cpu(), dynamo_res.cpu())
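For reference, a minimal sketch of the eager-versus-Dynamo consistency check that the commented-out lines above correspond to. The body of fn_simple here is a placeholder; the real function is defined just above this hunk and is not shown in the diff.

import torch
import torch_xla.core.xla_model as xm

def fn_simple(x):
    return x + 1  # placeholder computation; the real fn_simple is defined above

device = xm.xla_device()
x_xla = torch.zeros((1, 8)).to(device)

xla_res = fn_simple(x_xla)
xm.mark_step()  # flush the lazy graph so the eager XLA result materializes

dynamo_fn = torch.compile(fn_simple, backend="openxla")
dynamo_res = dynamo_fn(x_xla)
assert torch.allclose(xla_res.cpu(), dynamo_res.cpu())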

# TODO (@wonjoo) Remove this test, this is just for debugging
def test_wonjoo(self):

def fn_simple(x):
print(f'x inside fn_simple before: {x}')
torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(x, [0], [0], [0], 0)
print(f'x inside fn_simple after: {x}')
return x

device = xm.xla_device()

x_xla = torch.zeros((1, 8)).to(device)

# print(torch.ops.xla.add)
print(torch.ops.xla.max_pool2d_forward)
print(torch.ops.xla.xla_mark_sharding_dynamo_custom_op)
print(dir(torch.ops.xla.xla_mark_sharding_dynamo_custom_op))
# print(f'x_xla before: {x_xla}')

# dynamo_fn = torch.compile(fn_simple, backend="openxla")
# dynamo_res = dynamo_fn(x_xla)
# print(f'dynamo_res: {dynamo_res}')


if __name__ == '__main__':
22 changes: 11 additions & 11 deletions torch_xla/csrc/aten_autograd_ops.cpp
@@ -253,17 +253,17 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
return grad;
}

TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));
// TORCH_LIBRARY(xla, m) {
// m.def(
// "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
// "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
}
// m.def(
// "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
// "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
// "-> Tensor",
// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
// }
} // namespace aten_autograd_ops
} // namespace torch_xla
11 changes: 11 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.h
@@ -46,6 +46,17 @@ struct MaxPool3dAutogradFunction
torch::autograd::variable_list grad_output);
};

torch::Tensor max_pool2d_forward(torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding,
torch::IntArrayRef dilation, bool ceil_mode);

torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding, bool ceil_mode);

} // namespace aten_autograd_ops
} // namespace torch_xla

191 changes: 123 additions & 68 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -29,6 +29,7 @@
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/helpers.h"
@@ -651,6 +652,121 @@ std::string GetPyTypeString(py::handle obj) {
return type;
}

void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data is not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}
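The helper above is what the `_xla_mark_sharding` Python binding further down in this file now delegates to. As a rough usage sketch of how it is reached in practice, assuming the experimental Python-side SPMD wrappers (`xs.Mesh` / `xs.mark_sharding` in `torch_xla.experimental.xla_sharding`) that build the `xla::OpSharding` and call the binding — the module and helper names here are assumptions, not taken from this diff:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs  # assumed wrapper module

xr.use_spmd()  # required: xla_mark_sharding checks UseVirtualDevice()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,))

t = torch.zeros((1, 8)).to(xm.xla_device())
# Replicate dim 0 and shard dim 1 across the single mesh axis; internally this
# builds an xla::OpSharding and ends up in the C++ helper above.
xs.mark_sharding(t, mesh, (None, 0))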

void xla_mark_sharding_dynamo_custom_op(
    const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
    c10::List<at::IntArrayRef> group_assignment,
    c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
std::cout << "at xla_mark_sharding_dynamo_custom_op0" << std::endl;

std::cout << "input: " << input << std::endl;
// std::cout << "tile_assignment: " << tile_assignment << std::endl;
std::cout << "converting tile_assignment_py" << std::endl;
// const py::list& tile_assignment_py = py::cast(tile_assignment[0]);
const py::list& tile_assignment_py = py::cast(torch::lazy::ToVector<int64_t>(tile_assignment[0]));

// std::cout << "group_assignment: " << group_assignment << std::endl;
std::cout << "converting group_assignment_py" << std::endl;
const py::list& group_assignment_py = py::cast(group_assignment);

// std::cout << "replication_groups: " << replication_groups << std::endl;
std::cout << "converting replication_groups_py" << std::endl;
const py::list& replication_groups_py = py::cast(replication_groups);

std::cout << "at xla_mark_sharding_dynamo_custom_op1" << std::endl;

const xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
tile_assignment_py, group_assignment_py, replication_groups_py,
ShardingUtil::ShardingType(sharding_type));


std::cout << "at xla_mark_sharding_dynamo_custom_op2" << std::endl;

xla_mark_sharding(input, op_sharding);

std::cout << "at xla_mark_sharding_dynamo_custom_op3" << std::endl;
}

// TORCH_LIBRARY expands to code that runs at static initialization time and
// defines a library of operators in the given namespace. It is used here to
// register custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] tile_assignment, int[][] group_assignment, int[][] replication_groups, int sharding_type) -> ()",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(xla_mark_sharding_dynamo_custom_op)));
}
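With this registration in place, the operators become reachable from Python under `torch.ops.xla` (the test file above already prints these handles). A minimal sketch of calling the sharding op through that path, assuming SPMD is enabled; the nested lists follow the `int[][]` schema and are placeholder values rather than a meaningful sharding:

import torch
import torch_xla  # importing the extension runs the TORCH_LIBRARY registration
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.use_spmd()
x = torch.zeros((1, 8)).to(xm.xla_device())

print(torch.ops.xla.max_pool2d_forward)
print(torch.ops.xla.xla_mark_sharding_dynamo_custom_op)

# int[][] arguments map to nested Python lists of ints; "-> ()" means the op
# returns nothing and annotates `x` in place.
torch.ops.xla.xla_mark_sharding_dynamo_custom_op(x, [[0]], [[0]], [[0]], 0)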

std::vector<bool> check_materialization_helper(
const std::vector<XLATensorPtr>& xtensors) {
std::vector<bool> need_materialization;
@@ -1561,75 +1677,14 @@ void InitXlaModuleBindings(py::module m) {
}));
m.def("_xla_mark_sharding", [](const at::Tensor& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
xla_mark_sharding(input, sharding);
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, xla::OpSharding sharding) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
tensor_methods::custom_mark_sharding(xtensor, sharding);
});
// m.def("_xla_mark_sharding_dynamo_custom_op",
// [](const at::Tensor& input, xla::OpSharding sharding) {
// // xla_mark_sharding_dynamo_custom_op(input, tile_assignment, group_assignment, replication_groups, sharding_type);
// // at::IntArrayRef tile_assignment, at::IntArrayRef group_assignment, c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type
// at::IntArrayRef tile_assignment =
// });
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->ClearShardingSpec();
34 changes: 0 additions & 34 deletions torch_xla/csrc/ops/custom_mark_sharding.cpp

This file was deleted.

23 changes: 0 additions & 23 deletions torch_xla/csrc/ops/custom_mark_sharding.h

This file was deleted.

