Skip to content

Commit

Permalink
Address comments -- fix typos and variable names
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed Nov 8, 2023
1 parent a98bfb2 commit f468ed6
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 145 deletions.
138 changes: 4 additions & 134 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,136 +652,6 @@ std::string GetPyTypeString(py::handle obj) {
return type;
}

void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
py::list tile_assignment_py = py::list();
for (int i = 0; i < tile_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : tile_assignment[i].get().toIntList()) {
pylist.append(t);
}
tile_assignment_py.append(pylist);
}

py::list group_assignment_py = py::list();
for (int i = 0; i < group_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : group_assignment[i].get().toIntList()) {
pylist.append(t);
}
group_assignment_py.append(pylist);
}

py::list replication_groups_py = py::list();
for (int i = 0; i < replication_groups.size(); i++) {
py::list pylist = py::list();
for (int64_t t : replication_groups[i].get().toIntList()) {
pylist.append(t);
}
replication_groups_py.append(pylist);
}

xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
tile_assignment_py, group_assignment_py, replication_groups_py,
ShardingUtil::ShardingType(sharding_type));

xla_mark_sharding(input, op_sharding);
}

// Macro for defining a function that will be run at static initialization time
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
"int sharding_type) -> ()",
torch::dispatch(c10::DispatchKey::XLA,
TORCH_FN(xla_mark_sharding_dynamo_custom_op)));
}

std::vector<bool> check_materialization_helper(
const std::vector<XLATensorPtr>& xtensors) {
std::vector<bool> need_materialization;
Expand Down Expand Up @@ -1692,16 +1562,16 @@ void InitXlaModuleBindings(py::module m) {
}));
m.def("_xla_mark_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
xla_mark_sharding(input, sharding);
ShardingUtil::xla_mark_sharding(input, sharding);
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, const py::list& tile_assignment,
const py::list& group_assignment, const py::list& replication_groups,
int sharding_type) {
c10::List<at::IntArrayRef> time_assignment_list =
c10::List<at::IntArrayRef> tile_assignment_list =
c10::List<at::IntArrayRef>();
for (auto t : tile_assignment) {
time_assignment_list.push_back(
tile_assignment_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

Expand All @@ -1720,7 +1590,7 @@ void InitXlaModuleBindings(py::module m) {
}

xla_mark_sharding_dynamo_custom_op(
input, time_assignment_list, group_assignment_list,
input, tile_assignment_list, group_assignment_list,
replication_groups_list, sharding_type);
});
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
Expand Down
135 changes: 135 additions & 0 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <unordered_map>

#include "torch/csrc/lazy/core/ir_util.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ops/device_data.h"
Expand All @@ -14,7 +16,9 @@
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/tensor_methods.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_graph_executor.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/execution_options_util.h"
#include "xla/hlo/ir/hlo_module.h"
Expand Down Expand Up @@ -742,4 +746,135 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
}

void ShardingUtil::xla_mark_sharding(const at::Tensor& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
py::list tile_assignment_py = py::list();
for (int i = 0; i < tile_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : tile_assignment[i].get().toIntList()) {
pylist.append(t);
}
tile_assignment_py.append(pylist);
}

py::list group_assignment_py = py::list();
for (int i = 0; i < group_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : group_assignment[i].get().toIntList()) {
pylist.append(t);
}
group_assignment_py.append(pylist);
}

py::list replication_groups_py = py::list();
for (int i = 0; i < replication_groups.size(); i++) {
py::list pylist = py::list();
for (int64_t t : replication_groups[i].get().toIntList()) {
pylist.append(t);
}
replication_groups_py.append(pylist);
}

xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
tile_assignment_py, group_assignment_py, replication_groups_py,
ShardingUtil::ShardingType(sharding_type));

ShardingUtil::xla_mark_sharding(input, op_sharding);
}

// Macro for defining a function that will be run at static initialization time
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
"int sharding_type) -> ()",
torch::dispatch(c10::DispatchKey::XLA,
TORCH_FN(torch_xla::xla_mark_sharding_dynamo_custom_op)));
}

} // namespace torch_xla
8 changes: 8 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,16 @@ class ShardingUtil {
const std::vector<at::Tensor>& shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);

static void xla_mark_sharding(const at::Tensor& input,
xla::OpSharding sharding);
};

void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_XLA_SHARDING_UTIL_H_
Loading

0 comments on commit f468ed6

Please sign in to comment.