diff --git a/test/test_operations.py b/test/test_operations.py index 4db6be38cea..75d8e6f5f7d 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -13,6 +13,7 @@ parser.add_argument('--verbosity', type=int, default=0) FLAGS, leftovers = parser.parse_known_args() sys.argv = [sys.argv[0]] + leftovers +from absl.testing import absltest, parameterized # Normal imports section starts here. import collections @@ -28,6 +29,11 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim +from torch.testing._internal.common_device_type import dtypes +from torch.testing._internal.common_dtype import ( + all_types_and_complex_and, + all_types_and, +) import torch_xla import torch_xla.core.xla_builder as xb import torch_xla.core.xla_op_registry as xor @@ -40,6 +46,7 @@ import torch_xla.distributed.spmd as xs from torch_xla import runtime as xr import torch_xla.test.test_utils as xtu +import torch_xla.utils.dlpack as xdlpack import torch_xla.utils.utils as xu import torch_xla.utils.serialization as xser import torch_xla.core.xla_model as xm @@ -2464,6 +2471,139 @@ def test_unsafe_buffer_pointer(self): self.assertGreaterEqual(buf_ptr_3, 0) +class TestDLPack(parameterized.TestCase): + + def _test_dlpack_capsule_conversion_helper(self, xla_tensor): + dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule + xla_tensor2 = xdlpack.from_dlpack(dlpt) + + self.assertEqual(xla_tensor.device, xla_tensor2.device) + self.assertTrue(torch.allclose(xla_tensor.cpu(), xla_tensor2.cpu())) + self.assertRaisesRegex(RuntimeError, + "DLTensor capsule can be consumed only once", + lambda: xdlpack.from_dlpack(dlpt)) + + self.assertEqual( + torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor), + torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2)) + + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + @parameterized.parameters(*all_types_and(torch.half, torch.bfloat16)) + def test_dlpack_roundtrip_tensor(self, dtype): + xla_device = xm.xla_device() + # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr + # xla_tensor_2 uses XLANativeFunctions::_to_copy + xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device) + self._test_dlpack_capsule_conversion_helper(xla_tensor_2) + + # xla_tensor_3 uses arange_out IR node. 
+    xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device())
+    xm.mark_step()
+    self._test_dlpack_capsule_conversion_helper(xla_tensor_3)
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  @parameterized.parameters(*all_types_and_complex_and(torch.half,
+                                                       torch.bfloat16,
+                                                       torch.bool, torch.uint16,
+                                                       torch.uint32,
+                                                       torch.uint64))
+  def test_dlpack_roundtrip_scalar(self, dtype):
+    xla_device = xm.xla_device()
+    xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)
+    # `mark_step` ensures xtensor->CurrentDataHandle() != nullptr
+    xm.mark_step()
+    self._test_dlpack_capsule_conversion_helper(xla_tensor_0)
+
+    xla_tensor_1 = torch.tensor(42, dtype=dtype).to(xla_device)
+    # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
+    self._test_dlpack_capsule_conversion_helper(xla_tensor_1)
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_dlpack_roundtrip_bool(self):
+    xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device())
+    self._test_dlpack_capsule_conversion_helper(xla_tensor)
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_dlpack_pytorch_cuda_to_xla(self):
+    t1_cuda = torch.arange(5).cuda()
+    dlt1 = torch.utils.dlpack.to_dlpack(t1_cuda)
+    xla_t1 = xdlpack.from_dlpack(dlt1)
+    self.assertEqual(xla_t1.device.type, 'xla')
+    self.assertEqual(xla_t1.device.index, t1_cuda.device.index)
+    t1_cuda[0] = t1_cuda[0] + 20
+    self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu()))
+
+    t2_cuda = torch.tensor(5).cuda()
+    dlt2 = torch.utils.dlpack.to_dlpack(t2_cuda)
+    xla_t2 = xdlpack.from_dlpack(dlt2)
+    self.assertEqual(xla_t2.device.type, 'xla')
+    self.assertEqual(xla_t2.device.index, t2_cuda.device.index)
+    t2_cuda.fill_(6)
+    self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))
+
+    cuda1 = torch.device('cuda:1')
+    t3_cuda = torch.tensor(5, device=cuda1)
+    dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda)
+    xla_t3 = xdlpack.from_dlpack(dlt3)
+    self.assertEqual(xla_t3.device.type, 'xla')
+    self.assertEqual(
+        xla_t3.device.index,
+        t3_cuda.device.index,
+        msg='both values should be 1. 
xla_t3.device should be xla:1.') + t3_cuda.fill_(6) + self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu())) + + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_dlpack_xla_to_pytorch_cuda(self): + xla_t1 = torch.arange(5).to(xm.xla_device()) + dlt1 = xdlpack.to_dlpack(xla_t1) + cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1) + self.assertEqual(cuda_t1.device.type, 'cuda') + self.assertEqual(cuda_t1.device.index, xla_t1.device.index) + cuda_t1[0] = cuda_t1[0] + 20 + self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu())) + + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_dlpack_non_default_layout(self): + cuda_t = torch.arange(25, device=torch.device('cuda')).reshape(5, 5) + + t1 = cuda_t.t() + xla_t1 = xdlpack.from_dlpack(t1.__dlpack__()) + self.assertEqual(xla_t1.device.type, 'xla') + self.assertEqual(xla_t1.device.index, 0) + self.assertTrue(torch.allclose(t1.cpu(), xla_t1.cpu())) + + t2 = cuda_t[0] + xla_t2 = xdlpack.from_dlpack(t2.__dlpack__()) + self.assertEqual(xla_t2.device.type, 'xla') + self.assertEqual(xla_t2.device.index, 0) + self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu())) + + t3 = cuda_t[:, 0] + self.assertRaisesRegex( + RuntimeError, + r"Only DLPack tensors with trivial \(compact\) striding are supported", + lambda: xdlpack.from_dlpack(t3.__dlpack__())) + + t4 = cuda_t[1, :] + xla_t4 = xdlpack.from_dlpack(t4.__dlpack__()) + self.assertEqual(xla_t4.device.type, 'xla') + self.assertEqual(xla_t4.device.index, 0) + self.assertTrue(torch.allclose(t4.cpu(), xla_t4.cpu())) + + t5 = cuda_t[1] + xla_t5 = xdlpack.from_dlpack(t5.__dlpack__()) + self.assertEqual(xla_t5.device.type, 'xla') + self.assertEqual(xla_t5.device.index, 0) + self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu())) + + class SimpleModelWithDropout(torch.nn.Module): def __init__(self): diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 2faf483f067..a2aadc0c633 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -42,6 +42,7 @@ ptxla_cc_library( "cross_replica_reduces.cpp", "data_ops.cpp", "debug_util.cpp", + "dl_convertor.cpp", "elementwise.cpp", "helpers.cpp", "ir_dump_util.cpp", @@ -81,6 +82,7 @@ ptxla_cc_library( "cross_replica_reduces.h", "data_ops.h", "debug_util.h", + "dl_convertor.h", "elementwise.h", "generated_file_include.h", "helpers.h", diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp new file mode 100644 index 00000000000..d29401be8fe --- /dev/null +++ b/torch_xla/csrc/dl_convertor.cpp @@ -0,0 +1,345 @@ +#include "torch_xla/csrc/dl_convertor.h" + +#include + +#include "absl/types/span.h" +#include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/ops/device_data.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/pjrt_computation_client.h" +#include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/runtime/tf_logging.h" +#include "torch_xla/csrc/tensor.h" +#include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/unwrap_data.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/status.h" + +namespace torch_xla { + +struct DLPackTensor { + ~DLPackTensor(); + std::unique_ptr external_reference; + std::shared_ptr buffer_reference; + + std::vector shape; + std::vector strides; + DLManagedTensor tensor; +}; + +DLPackTensor::~DLPackTensor() { + if (external_reference) { + external_reference.reset(nullptr); + } +} + +void 
DLPackTensorDeleter(DLManagedTensor* t) { + if (t) { + delete static_cast(t->manager_ctx); + } +} + +DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) { + if (device.client()->platform_id() == xla::CpuId()) { + return DLDeviceType::kDLCPU; + } else if (device.client()->platform_id() == xla::CudaId()) { + return DLDeviceType::kDLCUDA; + } + XLA_ERROR() << "Device " << device.DebugString() + << " cannot be used as a DLPack device."; +} + +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc +DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) { + DLDevice dlDevice; + dlDevice.device_type = DLDeviceTypeForDevice(device); + dlDevice.device_id = device.local_hardware_id(); + return dlDevice; +} + +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc +DLDataType PrimitiveTypeToDLDataType(xla::PrimitiveType type) { + switch (type) { + case xla::PrimitiveType::S8: + return DLDataType{kDLInt, 8, 1}; + case xla::PrimitiveType::S16: + return DLDataType{kDLInt, 16, 1}; + case xla::PrimitiveType::S32: + return DLDataType{kDLInt, 32, 1}; + case xla::PrimitiveType::S64: + return DLDataType{kDLInt, 64, 1}; + case xla::PrimitiveType::U8: + return DLDataType{kDLUInt, 8, 1}; + case xla::PrimitiveType::U16: + return DLDataType{kDLUInt, 16, 1}; + case xla::PrimitiveType::U32: + return DLDataType{kDLUInt, 32, 1}; + case xla::PrimitiveType::U64: + return DLDataType{kDLUInt, 64, 1}; + case xla::PrimitiveType::F16: + return DLDataType{kDLFloat, 16, 1}; + case xla::PrimitiveType::F32: + return DLDataType{kDLFloat, 32, 1}; + case xla::PrimitiveType::F64: + return DLDataType{kDLFloat, 64, 1}; + case xla::PrimitiveType::BF16: + return DLDataType{kDLBfloat, 16, 1}; + case xla::PrimitiveType::PRED: + return DLDataType{kDLBool, 8, 1}; + case xla::PrimitiveType::C64: + return DLDataType{kDLComplex, 64, 1}; + case xla::PrimitiveType::C128: + return DLDataType{kDLComplex, 128, 1}; + default: + XLA_ERROR() << "XLA type " << xla::PrimitiveType_Name(type) + << " has no DLPack equivalent"; + } +} + +std::vector StridesForShape(xla::PrimitiveType element_type, + absl::Span dimensions, + const xla::Layout& layout) { + XLA_CHECK_EQ(dimensions.size(), layout.minor_to_major().size()); + std::vector strides; + strides.resize(dimensions.size()); + int64_t stride = 1; + for (int i : layout.minor_to_major()) { + strides[i] = stride; + stride *= dimensions[i]; + } + return strides; +} + +// Convert an XLA tensor to a dlPack tensor. +DLManagedTensor* toDLPack(const at::Tensor& input) { + std::shared_ptr handle = + get_data_handle(input); + XLA_CHECK(handle != nullptr) + << "Could not extract a valid data handle from the input tensor"; + + std::shared_ptr pjrt_buffer = + runtime::GetComputationClient()->GetPjRtBuffer(handle); + XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer"; + + XLA_CHECK(!pjrt_buffer->IsTuple()) + << "Unimplemented. BufferToDLPackManagedTensor is not " + "implemented for tuple buffers."; + XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions()) + << "Unimplemented. 
DynamicShape is not implemented in DLPack."; + + auto pack = std::make_unique(); + DLTensor& dt = pack->tensor.dl_tensor; + { + // AcquireExternalReference may block + auto external_ref = pjrt_buffer->AcquireExternalReference(); + XLA_CHECK_OK(external_ref.status()); + pack->external_reference = std::move(external_ref.value()); + xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture(); + absl::Status status = future.Await(); + XLA_CHECK_OK(status); + } + pack->buffer_reference = pjrt_buffer; + + dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + dt.device = DLDeviceForDevice(*pjrt_buffer->device()); + dt.device.device_id = pjrt_buffer->device()->local_hardware_id(); + dt.ndim = pjrt_buffer->dimensions().size(); + dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type()); + + pack->shape = std::vector(pjrt_buffer->dimensions().begin(), + pjrt_buffer->dimensions().end()); + xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout()); + pack->strides = StridesForShape(pjrt_buffer->element_type(), + pjrt_buffer->dimensions(), xla_layout); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = reinterpret_cast(pack->strides.data()); + dt.byte_offset = 0; + + return &(pack.release()->tensor); +} + +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc +absl::StatusOr DeviceForDLDevice(const DLDevice& context) { + switch (context.device_type) { + case DLDeviceType::kDLCPU: + XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), + xla::CpuId()); + return runtime::GetComputationClient()->LookupAddressableDevice( + context.device_id); + case DLDeviceType::kDLCUDA: + XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), + xla::CudaId()); + return runtime::GetComputationClient()->LookupAddressableDevice( + context.device_id); + default: + return tsl::errors::InvalidArgument( + "Unknown/unsupported DLPack device type %d", context.device_type); + } +} + +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { + if (type.lanes != 1) { + return tsl::errors::Unimplemented( + "DLPack types with lanes != 1 not implemented, got %d", type.lanes); + } + switch (type.code) { + case kDLBool: + switch (type.bits) { + case 8: + return xla::PrimitiveType::PRED; + default: + return tsl::errors::Unimplemented( + "Only 8-bit DLPack booleans are supported, got %d bits", + type.bits); + } + case kDLInt: + switch (type.bits) { + case 8: + return xla::PrimitiveType::S8; + case 16: + return xla::PrimitiveType::S16; + case 32: + return xla::PrimitiveType::S32; + case 64: + return xla::PrimitiveType::S64; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack integer width: %d bits", + type.bits); + } + case kDLUInt: + switch (type.bits) { + case 8: + return xla::PrimitiveType::U8; + case 16: + return xla::PrimitiveType::U16; + case 32: + return xla::PrimitiveType::U32; + case 64: + return xla::PrimitiveType::U64; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack unsigned integer width: %d bits", + type.bits); + } + case kDLFloat: + switch (type.bits) { + case 16: + return xla::PrimitiveType::F16; + case 32: + return xla::PrimitiveType::F32; + case 64: + return xla::PrimitiveType::F64; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack float width: %d bits", type.bits); + } + case 
kDLBfloat: + switch (type.bits) { + case 16: + return xla::PrimitiveType::BF16; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits); + } + case kDLComplex: + switch (type.bits) { + case 64: + return xla::PrimitiveType::C64; + case 128: + return xla::PrimitiveType::C128; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack complex width: %d bits", + type.bits); + } + default: + return tsl::errors::Unimplemented( + "Unknown or invalid DLPack type code %d", type.code); + } +} + +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + XLA_CHECK_EQ(dims.size(), strides.size()); + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides[a] < strides[b]) { + return true; + } + if (strides[a] > strides[b]) { + return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return tsl::errors::Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +at::Tensor fromDLPack(DLManagedTensor* dlmt) { + XLA_CHECK(dlmt->dl_tensor.ndim >= 0) + << "Number of dimensions in DLManagedTensor must be nonnegative, got " + << dlmt->dl_tensor.ndim; + xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value(); + absl::Span dimensions( + const_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + xla::PrimitiveType element_type = + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value(); + + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + const_cast(dlmt->dl_tensor.strides), dlmt->dl_tensor.ndim); + minor_to_major = StridesToLayout(dimensions, strides).value(); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + element_type, dimensions, minor_to_major); + + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + xla::StatusOr> pjrt_buffer = + device->client()->CreateViewOfDeviceBuffer( + static_cast(dlmt->dl_tensor.data) + + dlmt->dl_tensor.byte_offset, + shape, device, on_delete_callback); + XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer."; + XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null."; + + runtime::ComputationClient::DataPtr data = + runtime::PjRtComputationClient::CreateData( + runtime::GetComputationClient()->PjRtDeviceToString(device), shape, + std::move(pjrt_buffer.value())); + + at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype); + XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type); + return bridge::AtenFromXlaTensor(xla_tensor); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/dl_convertor.h 
b/torch_xla/csrc/dl_convertor.h
new file mode 100644
index 00000000000..f5a54823e2e
--- /dev/null
+++ b/torch_xla/csrc/dl_convertor.h
@@ -0,0 +1,14 @@
+#ifndef XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_
+#define XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_
+
+#include <ATen/Tensor.h>
+#include <ATen/dlpack.h>
+
+namespace torch_xla {
+
+DLManagedTensor* toDLPack(const at::Tensor& src);
+at::Tensor fromDLPack(DLManagedTensor* src);
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 503a3382b74..a9240db692a 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1,3 +1,4 @@
+#include <ATen/dlpack.h>
 #include
 #include
 #include
@@ -34,6 +35,7 @@
 #include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/device.h"
+#include "torch_xla/csrc/dl_convertor.h"
 #include "torch_xla/csrc/dtype.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ir.h"
@@ -1098,6 +1100,36 @@ void BuildLoweringContextSubmodule(py::module* m) {
       .def("get_name_string", &PyLoweringContext::GetNameString);
 }
 
+// Used in the _to_dlpack binding below.
+void dlPack_Capsule_Destructor(PyObject* data) {
+  if (!PyCapsule_IsValid(data, "dltensor")) {
+    return;
+  }
+  DLManagedTensor* dlMTensor =
+      static_cast<DLManagedTensor*>(PyCapsule_GetPointer(data, "dltensor"));
+  if (dlMTensor) {
+    dlMTensor->deleter(dlMTensor);
+  } else {
+    // The tensor has been deleted. Clear any error from
+    // PyCapsule_GetPointer.
+    PyErr_Clear();
+  }
+}
+
+at::Tensor tensor_fromDLPack(PyObject* data) {
+  DLManagedTensor* dlMTensor =
+      (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
+  XLA_CHECK(dlMTensor != nullptr)
+      << "from_dlpack received an invalid capsule. Note that a DLTensor "
+         "capsule can be consumed only once. You may have already constructed "
+         "a tensor from it once.";
+
+  at::Tensor tensor = torch_xla::fromDLPack(dlMTensor);
+  PyCapsule_SetName(data, "used_dltensor");
+  PyCapsule_SetDestructor(data, nullptr);
+  return tensor;
+}
+
 void InitXlaModuleBindings(py::module m) {
   m.def("_prepare_to_exit", []() { PrepareToExit(); });
   m.def("_xla_runtime_is_initialized", []() {
@@ -2507,6 +2539,29 @@ void InitXlaModuleBindings(py::module m) {
                  "without a data handle or an IR.";
   });
 
+  // Converts an XLA tensor to a DLPack tensor. If the input is the result of
+  // a CUDA computation, we should synchronize (wait for all kernels in all
+  // streams on the CUDA device to complete) if the current stream is
+  // different from the producing stream; otherwise we risk getting incorrect
+  // results.
+  m.def("_to_dlpack", [](const at::Tensor& input) -> py::handle {
+    DLManagedTensor* dlMTensor;
+    {
+      NoGilSection nogil;
+      dlMTensor = torch_xla::toDLPack(input);
+    }
+    return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor);
+  });
+
+  // Converts a DLPack tensor to an XLA tensor. If ext_data is the result of
+  // a CUDA computation, we should synchronize (wait for all kernels in all
+  // streams on the CUDA device to complete) if the current stream is
+  // different from ext_data's stream; otherwise we risk getting incorrect
+  // results.
+ m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor { + return tensor_fromDLPack(ext_data.ptr()); + }); + // -------------Dynamo Integration API Start------------------------- /* * Return tensor ids and at::tensors for all DeviceData nodes that is needed diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 5a0c42482e4..93664900ebd 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -27,6 +27,8 @@ #include "xla/client/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal_util.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/types.h" namespace torch_xla { @@ -279,6 +281,9 @@ class ComputationClient { // structure will be empty if there is no sharding, like with PjRtData. virtual std::optional GetDataSharding(DataPtr handle) = 0; + virtual std::string PjRtDeviceToString( + xla::PjRtDevice* const device) const = 0; + // Transfers local tensor values to the TPU devices and fetches the handles. virtual std::vector TransferToDevice( absl::Span> tensors) = 0; @@ -308,6 +313,9 @@ class ComputationClient { virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0; + virtual std::shared_ptr GetPjRtBuffer( + const DataPtr handle) = 0; + // Compiles a set of computations. virtual std::vector Compile( std::vector instances) = 0; @@ -348,6 +356,11 @@ class ComputationClient { virtual torch_xla::DeviceType GetDeviceType() const = 0; + virtual xla::PjRtPlatformId GetPlatformID() const = 0; + + virtual absl::StatusOr LookupAddressableDevice( + int local_device_id) const = 0; + virtual size_t GetNumDevices() const = 0; virtual std::vector GetLocalDevices() const = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 842398126d0..e2a72992d6f 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -402,6 +402,11 @@ std::uintptr_t IfrtComputationClient::UnsafeBufferPointer( XLA_ERROR() << __FUNCTION__ << " not implemented"; } +std::shared_ptr IfrtComputationClient::GetPjRtBuffer( + const DataPtr handle) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; +} + std::vector IfrtComputationClient::TransferFromDevice( absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 4c10be9d1ca..ca40c8fb02c 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -56,6 +56,8 @@ class IfrtComputationClient : public ComputationClient { std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; + std::shared_ptr GetPjRtBuffer(const DataPtr handle) override; + DataPtr TransferShardsToDevice( absl::Span> tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) override; @@ -84,6 +86,15 @@ class IfrtComputationClient : public ComputationClient { absl::AsciiStrToUpper(client_->platform_name())); }; + xla::PjRtPlatformId GetPlatformID() const override { + return client_->platform_id(); + } + + absl::StatusOr LookupAddressableDevice( + int local_device_id) const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + std::vector GetLocalDevices() const override; std::vector GetAllDevices() const override; @@ -121,6 +132,10 @@ class IfrtComputationClient : public 
ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; }; + std::string PjRtDeviceToString(xla::PjRtDevice* const device) const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + std::string SerializeComputation(const ComputationPtr computation) override { XLA_ERROR() << __FUNCTION__ << " not implemented"; } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index df49f725b7b..55089014152 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -184,6 +184,13 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder( return std::make_shared(std::move(device), std::move(shape)); } +ComputationClient::DataPtr PjRtComputationClient::CreateData( + std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer) { + return std::make_shared(std::move(device), std::move(shape), + pjrt_buffer); +} + std::vector PjRtComputationClient::GetDataShards( ComputationClient::DataPtr data) { tsl::profiler::TraceMe activity("PjRtComputationClient::GetDataShards", @@ -462,12 +469,31 @@ std::uintptr_t PjRtComputationClient::UnsafeBufferPointer( std::shared_ptr pjrt_data = std::dynamic_pointer_cast(handle); XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); + XLA_CHECK(pjrt_data->buffer != nullptr) + << "PjRt buffer is null in " << __FUNCTION__; xla::StatusOr ptr = client_->UnsafeBufferPointer(pjrt_data->buffer.get()); XLA_CHECK(ptr.ok()); return ptr.value(); } +std::shared_ptr PjRtComputationClient::GetPjRtBuffer( + const DataPtr handle) { + std::shared_ptr pjrt_data = + std::dynamic_pointer_cast(handle); + + XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); + std::shared_ptr pjrt_buffer = pjrt_data->buffer; + if (pjrt_buffer != nullptr) { + return pjrt_buffer; + } else { + TF_VLOG(3) << "The pjrt buffer is null so we need to wait for device ops " + "to finish."; + WaitDeviceOps({}); + return std::dynamic_pointer_cast(handle)->buffer; + } +} + std::vector PjRtComputationClient::TransferFromDevice( absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); @@ -482,7 +508,9 @@ std::vector PjRtComputationClient::TransferFromDevice( // Use XLA replication to reassemble the sharded data. If input handle // is not sharded, then it is a no-op. 
std::shared_ptr pjrt_data = ReplicateShardedData(handle); - XLA_CHECK(pjrt_data); + XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; + XLA_CHECK(pjrt_data->buffer != nullptr) + << "PjRt buffer is null in " << __FUNCTION__; xla::Literal& literal = literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); @@ -492,7 +520,8 @@ std::vector PjRtComputationClient::TransferFromDevice( } for (auto& future : futures) { absl::Status status = future.Await(); - XLA_CHECK_OK(status); + XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in" + << __FUNCTION__; } InboundDataMetric()->AddSample(total_size); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index e3359e80b26..5ed4326d283 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -32,6 +32,9 @@ class PjRtComputationClient : public ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) override; + static DataPtr CreateData(std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer); + std::vector GetDataShards(DataPtr data) override; DataPtr GetDataShard(DataPtr data, size_t index) override; @@ -57,6 +60,8 @@ class PjRtComputationClient : public ComputationClient { std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; + std::shared_ptr GetPjRtBuffer(const DataPtr handle) override; + DataPtr TransferShardsToDevice( absl::Span> tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) override; @@ -89,6 +94,16 @@ class PjRtComputationClient : public ComputationClient { absl::AsciiStrToUpper(client_->platform_name())); }; + xla::PjRtPlatformId GetPlatformID() const override { + return client_->platform_id(); + } + + absl::StatusOr LookupAddressableDevice( + int local_device_id) const override { + return client_->LookupAddressableDevice( + xla::PjRtLocalDeviceId(local_device_id)); + } + std::vector GetLocalDevices() const override; std::vector GetAllDevices() const override; @@ -122,6 +137,10 @@ class PjRtComputationClient : public ComputationClient { MemoryInfo GetMemoryInfo(const std::string& device) override; + std::string PjRtDeviceToString(xla::PjRtDevice* const device) const override; + std::vector PjRtDevicesToString( + absl::Span devices) const; + private: std::unique_ptr client_; std::unique_ptr coordinator_; @@ -137,10 +156,6 @@ class PjRtComputationClient : public ComputationClient { xla::PjRtDevice* StringToPjRtDevice(const std::string& device); - std::string PjRtDeviceToString(xla::PjRtDevice* const device) const; - std::vector PjRtDevicesToString( - absl::Span devices) const; - struct PjRtData : public Data { PjRtData(std::string device, xla::Shape device_shape) : Data(std::move(device), std::move(device_shape)) {} diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index dd13bd63d1b..8822b6de7c4 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -17,6 +17,7 @@ #include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" +#include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/runtime.h" @@ -931,4 +932,24 @@ xla::PrimitiveType GetShapeDimensionType( return xla::PrimitiveType::S32; } +std::shared_ptr get_data_handle( + const at::Tensor& input) { + 
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  if (xtensor->CurrentDataHandle() != nullptr) {
+    TF_VLOG(4) << "The xla tensor has a current data handle.";
+    return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+        xtensor->CurrentDataHandle());
+  } else if (xtensor->CurrentIrValue().node != nullptr) {
+    DeviceData* device_data =
+        DeviceData::Cast(xtensor->CurrentIrValue().node.get());
+    if (device_data != nullptr) {
+      return UnwrapXlaData(device_data->data());
+    }
+    TF_VLOG(4) << "The xla tensor has an IR value but does not have device data.";
+  }
+  TF_VLOG(4)
+      << "The xla tensor either has no current data handle or has no IR value.";
+  return nullptr;
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h
index 7d726c00b50..0804d3e9f78 100644
--- a/torch_xla/csrc/tensor_util.h
+++ b/torch_xla/csrc/tensor_util.h
@@ -212,6 +212,9 @@ inline std::vector<at::Tensor> xla_expand_outplace(at::TensorList to_expand) {
   }
 }
 
+std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(
+    const at::Tensor& input);
+
 }  // namespace torch_xla
 
 #endif  // XLA_TORCH_XLA_CSRC_TENSOR_UTIL_H_
diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py
new file mode 100644
index 00000000000..9f93d532b27
--- /dev/null
+++ b/torch_xla/utils/dlpack.py
@@ -0,0 +1,10 @@
+from typing import Any
+import torch_xla
+
+
+def to_dlpack(xla_tensor: Any):
+  return torch_xla._XLAC._to_dlpack(xla_tensor)
+
+
+def from_dlpack(ext_tensor: Any):
+  return torch_xla._XLAC._from_dlpack(ext_tensor)
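Editorial note (not part of the patch): a minimal usage sketch of the new torch_xla.utils.dlpack API introduced above, assuming a CUDA-enabled PyTorch build and PJRT_DEVICE=CUDA, mirroring test_dlpack_pytorch_cuda_to_xla and test_dlpack_xla_to_pytorch_cuda.

# Sketch only. Assumes PJRT_DEVICE=CUDA and a CUDA-enabled torch build,
# the same preconditions the tests above require.
import torch
import torch.utils.dlpack
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as xdlpack

# CUDA tensor -> XLA tensor: the XLA tensor aliases the same device buffer,
# so in-place updates on the CUDA side are visible through the XLA view.
t_cuda = torch.arange(5, device='cuda')
t_xla = xdlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t_cuda))
t_cuda[0] += 20
assert torch.allclose(t_xla.cpu(), t_cuda.cpu())

# XLA tensor -> CUDA tensor: to_dlpack returns a PyCapsule that can be
# consumed only once; a second from_dlpack on the same capsule raises.
x_xla = torch.arange(5, device=xm.xla_device())
xm.mark_step()  # ensures the tensor has a materialized device buffer
capsule = xdlpack.to_dlpack(x_xla)
t_back = torch.utils.dlpack.from_dlpack(capsule)
assert t_back.device.type == 'cuda'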