Commit 3c51d5e: Clean up some code

wonjoolee95 committed Oct 31, 2023
1 parent ae05c9a commit 3c51d5e
Showing 10 changed files with 69 additions and 147 deletions.
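In brief: the commit deletes the .torch_pin, moves the TORCH_LIBRARY(xla, m) operator registrations from aten_autograd_ops.cpp into init_python_bindings.cpp, registers a new (still stubbed) xla_mark_sharding_dynamo_custom_op custom op there, removes the old custom_mark_sharding lowering, and points mark_sharding in xla_sharding.py plus a temporary debugging test at the new op.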
1 change: 0 additions & 1 deletion .torch_pin

This file was deleted.

18 changes: 18 additions & 0 deletions test/spmd/test_dynamo_spmd.py
@@ -206,6 +206,24 @@ def fn_simple(x):
dynamo_res = dynamo_linear(x_xla)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())

# TODO (@wonjoo): Remove this test; it is only for debugging.
def test_wonjoo(self):

def fn_simple(x):
print(f'x inside fn_simple before: {x}')
torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(x, [0], [0], [0], 0)
print(f'x inside fn_simple after: {x}')
return x

device = xm.xla_device()

x_xla = torch.zeros((1, 8)).to(device)
print(f'x_xla before: {x_xla}')

dynamo_fn = torch.compile(fn_simple, backend="openxla")
dynamo_res = dynamo_fn(x_xla)
print(f'dynamo_res: {dynamo_res}')


if __name__ == '__main__':
test = unittest.main()
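Because the op is registered via TORCH_LIBRARY(xla, m) in init_python_bindings.cpp below, it is also reachable through the generic torch.ops namespace outside of Dynamo. A minimal sketch, assuming an XLA device is available (the integer arguments are placeholders, mirroring the test above):

import torch
import torch_xla.core.xla_model as xm

x = torch.zeros((1, 8)).to(xm.xla_device())
# Ops defined via TORCH_LIBRARY(xla, m) are exposed as torch.ops.xla.<name>;
# with the current stub implementation this simply fills x with ones.
torch.ops.xla.xla_mark_sharding_dynamo_custom_op(x, [0], [0], [0], 0)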
22 changes: 11 additions & 11 deletions torch_xla/csrc/aten_autograd_ops.cpp
@@ -253,17 +253,17 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
return grad;
}

TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));
// TORCH_LIBRARY(xla, m) {
// m.def(
// "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
// "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
}
// m.def(
// "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
// "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
// "-> Tensor",
// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
// }
} // namespace aten_autograd_ops
} // namespace torch_xla
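Note: a namespace may only be defined once per process with TORCH_LIBRARY (additional blocks would need TORCH_LIBRARY_FRAGMENT), so the registrations are commented out here rather than duplicated; the same max_pool2d_forward/max_pool2d_backward definitions now live in init_python_bindings.cpp below.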
11 changes: 11 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.h
@@ -46,6 +46,17 @@ struct MaxPool3dAutogradFunction
torch::autograd::variable_list grad_output);
};

torch::Tensor max_pool2d_forward(torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding,
torch::IntArrayRef dilation, bool ceil_mode);

torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding, bool ceil_mode);

} // namespace aten_autograd_ops
} // namespace torch_xla
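Promoting these declarations to the header lets the relocated TORCH_LIBRARY(xla, m) block in init_python_bindings.cpp reference both functions via TORCH_FN.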

30 changes: 27 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -29,6 +29,7 @@
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/helpers.h"
@@ -651,6 +652,30 @@ std::string GetPyTypeString(py::handle obj) {
return type;
}

void xla_mark_sharding_dynamo_custom_op(
    const at::Tensor& input, at::IntArrayRef tile_assignment,
    at::IntArrayRef group_assignment, at::IntArrayRef replication_groups,
    int64_t sharding_type) {
// TODO create the OpSharding and manually call the pybind
XLATensorPtr self_tensor = bridge::GetXlaTensor(input);
tensor_methods::fill_(self_tensor, 1);
}

// TORCH_LIBRARY runs at static initialization time to define a library of
// operators in the given namespace. It is used here to register custom
// operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
  m.def(
      "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
      "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::XLA,
          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

  m.def(
      "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
      "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
      "-> Tensor",
      torch::dispatch(
          c10::DispatchKey::XLA,
          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));

  m.def(
      "xla_mark_sharding_dynamo_custom_op(Tensor input, int[] "
      "tile_assignment, int[] group_assignment, int[] replication_groups, "
      "int sharding_type) -> ()",
      torch::dispatch(c10::DispatchKey::XLA,
                      TORCH_FN(xla_mark_sharding_dynamo_custom_op)));
}

std::vector<bool> check_materialization_helper(
const std::vector<XLATensorPtr>& xtensors) {
std::vector<bool> need_materialization;
@@ -1626,9 +1651,8 @@ void InitXlaModuleBindings(py::module m) {
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, xla::OpSharding sharding) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
tensor_methods::custom_mark_sharding(xtensor, sharding);
[](const at::Tensor& input, at::IntArrayRef tile_assignment,
   at::IntArrayRef group_assignment, at::IntArrayRef replication_groups,
   int64_t sharding_type) {
  xla_mark_sharding_dynamo_custom_op(input, tile_assignment, group_assignment,
                                     replication_groups, sharding_type);
});
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
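As the TODO above notes, xla_mark_sharding_dynamo_custom_op is still a stub: it ignores the sharding arguments and fills the input with 1 via tensor_methods::fill_. A minimal sketch of the new flattened calling convention through the pybind entry point (placeholder argument values, assuming an XLA device is available):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.zeros((1, 8)).to(xm.xla_device())
# New signature: (input, tile_assignment, group_assignment,
# replication_groups, sharding_type) with flat int lists and an int type.
torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, [0], [0], [0], 0)
print(t)  # with the current stub, t has been filled with ones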
34 changes: 0 additions & 34 deletions torch_xla/csrc/ops/custom_mark_sharding.cpp

This file was deleted.

23 changes: 0 additions & 23 deletions torch_xla/csrc/ops/custom_mark_sharding.h

This file was deleted.

71 changes: 0 additions & 71 deletions torch_xla/csrc/tensor_methods.cpp
@@ -39,7 +39,6 @@
#include "torch_xla/csrc/ops/count_nonzero.h"
#include "torch_xla/csrc/ops/cumprod.h"
#include "torch_xla/csrc/ops/cumsum.h"
#include "torch_xla/csrc/ops/custom_mark_sharding.h"
#include "torch_xla/csrc/ops/custom_sharding.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/diagonal.h"
@@ -444,76 +443,6 @@ void custom_sharding_(
input->SetShardingSpec(*sharding_spec);
}

void custom_mark_sharding(const XLATensorPtr& input, xla::OpSharding sharding) {
// TODO (@wonjoo) Do we need this `sharding` here?
input->SetInPlaceIrValue(torch::lazy::MakeNode<CustomMarkSharding>(
input->GetIrValue(), input->GetIrValue()));

TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
// XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
input->shape(),
static_cast<XlaDeviceType>(input->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (input->CurrentIrValue()) {
device_data_node = DeviceData::Cast(input->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(input, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (input->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = input->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = input->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((input->CurrentDataHandle() &&
input->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{input};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
input->SetXlaData(xla_data);
input->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(input->data());
}

XLATensorPtr get_dimensions_size(const XLATensorPtr& input,
std::vector<int64_t> dimensions) {
return input->CreateFrom(torch::lazy::MakeNode<GetDimensionsSize>(
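The deleted custom_mark_sharding above duplicated most of the existing _xla_mark_sharding data-transfer logic; presumably an equivalent path will be rebuilt inside xla_mark_sharding_dynamo_custom_op once the OpSharding construction from the flattened arguments (the TODO in init_python_bindings.cpp) is in place.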
2 changes: 0 additions & 2 deletions torch_xla/csrc/tensor_methods.h
@@ -57,8 +57,6 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
void custom_sharding_(const XLATensorPtr& input,
const std::shared_ptr<XLATensor::ShardingSpec>& spec);

void custom_mark_sharding(const XLATensorPtr& input, xla::OpSharding sharding);

XLATensorPtr get_dimensions_size(const XLATensorPtr& input,
std::vector<int64_t> dimensions);

4 changes: 2 additions & 2 deletions torch_xla/experimental/xla_sharding.py
@@ -500,9 +500,9 @@ def mark_sharding(
op_sharding = mesh.get_op_sharding(partition_spec)

if isinstance(t, XLAShardedTensor):
torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t.global_tensor, op_sharding)
return t
torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, op_sharding)
return XLAShardedTensor(t)
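Caution: these call sites still pass a single OpSharding, while the rebound _xla_mark_sharding_dynamo_custom_op now expects (input, tile_assignment, group_assignment, replication_groups, sharding_type); the two signatures will need to be reconciled before mark_sharding can run against the new custom op.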


