From 2f499bd279133c5ee5b5623106cfb6824efec798 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 31 Jan 2025 15:35:09 +0000 Subject: [PATCH 1/8] Working set --- .../pjrt_implementation/device_instance.h | 9 +++++++ .../pjrt_implementation/device_instance.cc | 24 ++++++++++++++++--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/inc/common/pjrt_implementation/device_instance.h b/inc/common/pjrt_implementation/device_instance.h index 7365811a..200b6915 100644 --- a/inc/common/pjrt_implementation/device_instance.h +++ b/inc/common/pjrt_implementation/device_instance.h @@ -10,6 +10,8 @@ #include "xla/pjrt/c/pjrt_c_api.h" +#include "tt/runtime/runtime.h" + #include "common/pjrt_implementation/device_description.h" #include "common/pjrt_implementation/event_instance.h" #include "common/status.h" @@ -67,6 +69,13 @@ class DeviceInstance { private: tt_pjrt_status OpenDevice(); + size_t getSize(const std::vector &shape, size_t element_size); + + tt::runtime::Tensor + MakeDeviceTensor(const void *data_ptr, std::vector &shape, + std::vector &strides, size_t element_size, + tt::target::DataType element_type); + ClientInstance &client_; uint64_t last_transfer_timepoint_ = 0; DeviceDescription description_; diff --git a/src/common/pjrt_implementation/device_instance.cc b/src/common/pjrt_implementation/device_instance.cc index 8b7b7593..4d322224 100644 --- a/src/common/pjrt_implementation/device_instance.cc +++ b/src/common/pjrt_implementation/device_instance.cc @@ -80,9 +80,8 @@ tt_pjrt_status DeviceInstance::HostBufferToDevice( shape.push_back(dims[i]); strides.push_back(byte_strides[i] / element_size); } - std::shared_ptr data_ptr(const_cast(data), [](void *) {}); - tt::runtime::Tensor tensor = tt::runtime::createTensor( - data_ptr, shape, strides, element_size, element_type); + tt::runtime::Tensor tensor = + MakeDeviceTensor(data, shape, strides, element_size, element_type); BufferInstance *buffer_instance = new BufferInstance(*this, tensor, shape, strides); DLOG_F(INFO, "Buffer created with id: %d", buffer_instance->unique_id()); @@ -93,4 +92,23 @@ tt_pjrt_status DeviceInstance::HostBufferToDevice( return tt_pjrt_status::kSuccess; } +size_t DeviceInstance::getSize(const std::vector &shape, size_t element_size) { + size_t size = 1; + for (auto dim : shape) { + size *= dim; + } + return size*element_size; +} + +tt::runtime::Tensor DeviceInstance::MakeDeviceTensor( + const void *data, std::vector &shape, + std::vector &strides, size_t element_size, + tt::target::DataType element_type) { + size_t tensor_size = getSize(shape, element_size); + std::shared_ptr new_memory(new char[tensor_size], [](void *) {}); + std::memcpy(new_memory.get(), data, tensor_size); + return tt::runtime::createTensor( + new_memory, shape, strides, element_size, element_type); +} + } // namespace tt::pjrt From 5d10c9100dae1c43a293a8e03e5a0b77c57cff8e Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 31 Jan 2025 15:44:14 +0000 Subject: [PATCH 2/8] Added new memcpy --- inc/common/pjrt_implementation/buffer_instance.h | 6 +++--- inc/common/pjrt_implementation/device_instance.h | 2 +- src/common/pjrt_implementation/buffer_instance.cc | 4 ++-- src/common/pjrt_implementation/device_instance.cc | 8 ++++---- .../pjrt_implementation/loaded_executable_instance.cc | 3 ++- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/inc/common/pjrt_implementation/buffer_instance.h b/inc/common/pjrt_implementation/buffer_instance.h index b9d552c6..53f2823f 100644 --- a/inc/common/pjrt_implementation/buffer_instance.h +++ b/inc/common/pjrt_implementation/buffer_instance.h @@ -23,7 +23,7 @@ class DeviceInstance; class BufferInstance { public: - BufferInstance(DeviceInstance &device, tt::runtime::Tensor tensor, + BufferInstance(DeviceInstance &device, std::unique_ptr& tensor, std::vector shape, std::vector stride); BufferInstance(DeviceInstance &device); @@ -44,7 +44,7 @@ class BufferInstance { // the hook to get an unsafe pointer (avoids a copy). return false; } - tt::runtime::Tensor tensor() { return tensor_.value(); } + tt::runtime::Tensor tensor() { return *tensor_; } PJRT_Error *GetMemoryLayout(PJRT_Buffer_GetMemoryLayout_Args *args); // Gets the required host size in bytes to copy to host. @@ -74,7 +74,7 @@ class BufferInstance { // API elements that must have the same lifetime as BufferInstance. std::vector dims_; std::vector stride_; - std::optional tensor_; + std::unique_ptr tensor_; std::vector minor_to_major_; std::vector tile_dims_; diff --git a/inc/common/pjrt_implementation/device_instance.h b/inc/common/pjrt_implementation/device_instance.h index 200b6915..1c8c3d5f 100644 --- a/inc/common/pjrt_implementation/device_instance.h +++ b/inc/common/pjrt_implementation/device_instance.h @@ -71,7 +71,7 @@ class DeviceInstance { size_t getSize(const std::vector &shape, size_t element_size); - tt::runtime::Tensor + std::unique_ptr MakeDeviceTensor(const void *data_ptr, std::vector &shape, std::vector &strides, size_t element_size, tt::target::DataType element_type); diff --git a/src/common/pjrt_implementation/buffer_instance.cc b/src/common/pjrt_implementation/buffer_instance.cc index dfdc2cae..2ab3d31d 100644 --- a/src/common/pjrt_implementation/buffer_instance.cc +++ b/src/common/pjrt_implementation/buffer_instance.cc @@ -19,12 +19,12 @@ int BufferInstance::id_counter_ = 0; BufferInstance::~BufferInstance() = default; BufferInstance::BufferInstance(DeviceInstance &device, - tt::runtime::Tensor tensor, + std::unique_ptr& tensor, std::vector shape, std::vector stride) : device_(device) { DLOG_F(LOG_DEBUG, "BufferInstance::BufferInstance"); - tensor_ = tensor; + tensor_ = std::move(tensor); dims_.resize(shape.size()); for (int i = 0; i < shape.size(); i++) { dims_[i] = shape[i]; diff --git a/src/common/pjrt_implementation/device_instance.cc b/src/common/pjrt_implementation/device_instance.cc index 4d322224..9582b58b 100644 --- a/src/common/pjrt_implementation/device_instance.cc +++ b/src/common/pjrt_implementation/device_instance.cc @@ -80,7 +80,7 @@ tt_pjrt_status DeviceInstance::HostBufferToDevice( shape.push_back(dims[i]); strides.push_back(byte_strides[i] / element_size); } - tt::runtime::Tensor tensor = + std::unique_ptr tensor = MakeDeviceTensor(data, shape, strides, element_size, element_type); BufferInstance *buffer_instance = new BufferInstance(*this, tensor, shape, strides); @@ -100,15 +100,15 @@ size_t DeviceInstance::getSize(const std::vector &shape, size_t e return size*element_size; } -tt::runtime::Tensor DeviceInstance::MakeDeviceTensor( +std::unique_ptr DeviceInstance::MakeDeviceTensor( const void *data, std::vector &shape, std::vector &strides, size_t element_size, tt::target::DataType element_type) { size_t tensor_size = getSize(shape, element_size); std::shared_ptr new_memory(new char[tensor_size], [](void *) {}); std::memcpy(new_memory.get(), data, tensor_size); - return tt::runtime::createTensor( - new_memory, shape, strides, element_size, element_type); + return std::make_unique(tt::runtime::createTensor( + new_memory, shape, strides, element_size, element_type)); } } // namespace tt::pjrt diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc index 9b025172..5dad1f6b 100644 --- a/src/common/pjrt_implementation/loaded_executable_instance.cc +++ b/src/common/pjrt_implementation/loaded_executable_instance.cc @@ -129,8 +129,9 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { // PJRT expects an empty shape for scalars. std::vector output_shape = is_scalar ? std::vector() : output_specs[i].shape; + std::unique_ptr tensor_ptr = std::make_unique(rt_outputs[i]); auto result_buffer = std::make_unique( - *this->addressable_devices_[dev_index], rt_outputs[i], output_shape, + *this->addressable_devices_[dev_index], tensor_ptr, output_shape, output_specs[i].stride); result_buffer->setType(tt::pjrt::utils::convertElementTypeToBufferType( output_specs[i].dataType)); From bd4013d096430b62b9896093cab320772d4b71b3 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 31 Jan 2025 15:58:27 +0000 Subject: [PATCH 3/8] Formatting --- inc/common/pjrt_implementation/buffer_instance.h | 3 ++- src/common/pjrt_implementation/buffer_instance.cc | 2 +- src/common/pjrt_implementation/device_instance.cc | 5 +++-- src/common/pjrt_implementation/loaded_executable_instance.cc | 3 ++- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/inc/common/pjrt_implementation/buffer_instance.h b/inc/common/pjrt_implementation/buffer_instance.h index 53f2823f..2bb300d9 100644 --- a/inc/common/pjrt_implementation/buffer_instance.h +++ b/inc/common/pjrt_implementation/buffer_instance.h @@ -23,7 +23,8 @@ class DeviceInstance; class BufferInstance { public: - BufferInstance(DeviceInstance &device, std::unique_ptr& tensor, + BufferInstance(DeviceInstance &device, + std::unique_ptr &tensor, std::vector shape, std::vector stride); BufferInstance(DeviceInstance &device); diff --git a/src/common/pjrt_implementation/buffer_instance.cc b/src/common/pjrt_implementation/buffer_instance.cc index 2ab3d31d..a2218694 100644 --- a/src/common/pjrt_implementation/buffer_instance.cc +++ b/src/common/pjrt_implementation/buffer_instance.cc @@ -19,7 +19,7 @@ int BufferInstance::id_counter_ = 0; BufferInstance::~BufferInstance() = default; BufferInstance::BufferInstance(DeviceInstance &device, - std::unique_ptr& tensor, + std::unique_ptr &tensor, std::vector shape, std::vector stride) : device_(device) { diff --git a/src/common/pjrt_implementation/device_instance.cc b/src/common/pjrt_implementation/device_instance.cc index 9582b58b..45bb7edf 100644 --- a/src/common/pjrt_implementation/device_instance.cc +++ b/src/common/pjrt_implementation/device_instance.cc @@ -92,12 +92,13 @@ tt_pjrt_status DeviceInstance::HostBufferToDevice( return tt_pjrt_status::kSuccess; } -size_t DeviceInstance::getSize(const std::vector &shape, size_t element_size) { +size_t DeviceInstance::getSize(const std::vector &shape, + size_t element_size) { size_t size = 1; for (auto dim : shape) { size *= dim; } - return size*element_size; + return size * element_size; } std::unique_ptr DeviceInstance::MakeDeviceTensor( diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc index 5dad1f6b..4fd71f48 100644 --- a/src/common/pjrt_implementation/loaded_executable_instance.cc +++ b/src/common/pjrt_implementation/loaded_executable_instance.cc @@ -129,7 +129,8 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { // PJRT expects an empty shape for scalars. std::vector output_shape = is_scalar ? std::vector() : output_specs[i].shape; - std::unique_ptr tensor_ptr = std::make_unique(rt_outputs[i]); + std::unique_ptr tensor_ptr = + std::make_unique(rt_outputs[i]); auto result_buffer = std::make_unique( *this->addressable_devices_[dev_index], tensor_ptr, output_shape, output_specs[i].stride); From 34045a25a844f750751c0108bdbe7d700717a63a Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 4 Feb 2025 10:32:37 +0000 Subject: [PATCH 4/8] Lifetime fix --- inc/common/pjrt_implementation/buffer_instance.h | 9 +++++++-- inc/common/pjrt_implementation/device_instance.h | 9 +++++---- .../pjrt_implementation/buffer_instance.cc | 7 ++++--- .../pjrt_implementation/device_instance.cc | 16 +++++++++------- .../loaded_executable_instance.cc | 2 +- 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/inc/common/pjrt_implementation/buffer_instance.h b/inc/common/pjrt_implementation/buffer_instance.h index 2bb300d9..da9ebcdb 100644 --- a/inc/common/pjrt_implementation/buffer_instance.h +++ b/inc/common/pjrt_implementation/buffer_instance.h @@ -26,7 +26,8 @@ class BufferInstance { BufferInstance(DeviceInstance &device, std::unique_ptr &tensor, std::vector shape, - std::vector stride); + std::vector stride, + std::shared_ptr host_buffer_ptr); BufferInstance(DeviceInstance &device); ~BufferInstance(); operator PJRT_Buffer *() { return reinterpret_cast(this); } @@ -85,7 +86,11 @@ class BufferInstance { std::optional DataType; // OnReady event - currently not used. - EventInstance *on_ready_event_; + std::shared_ptr on_ready_event_; + + // Pointer to the host memory used to create this buffer, if buffer is created + // on device, the value of this pointer is nullptr. + std::shared_ptr host_buffer_ptr_; }; } // namespace tt::pjrt diff --git a/inc/common/pjrt_implementation/device_instance.h b/inc/common/pjrt_implementation/device_instance.h index 1c8c3d5f..9275e361 100644 --- a/inc/common/pjrt_implementation/device_instance.h +++ b/inc/common/pjrt_implementation/device_instance.h @@ -71,10 +71,11 @@ class DeviceInstance { size_t getSize(const std::vector &shape, size_t element_size); - std::unique_ptr - MakeDeviceTensor(const void *data_ptr, std::vector &shape, - std::vector &strides, size_t element_size, - tt::target::DataType element_type); + BufferInstance *MakeDeviceBuffer(const void *data_ptr, + std::vector &shape, + std::vector &strides, + size_t element_size, + tt::target::DataType element_type); ClientInstance &client_; uint64_t last_transfer_timepoint_ = 0; diff --git a/src/common/pjrt_implementation/buffer_instance.cc b/src/common/pjrt_implementation/buffer_instance.cc index a2218694..02d77655 100644 --- a/src/common/pjrt_implementation/buffer_instance.cc +++ b/src/common/pjrt_implementation/buffer_instance.cc @@ -21,8 +21,9 @@ BufferInstance::~BufferInstance() = default; BufferInstance::BufferInstance(DeviceInstance &device, std::unique_ptr &tensor, std::vector shape, - std::vector stride) - : device_(device) { + std::vector stride, + std::shared_ptr host_buffer_ptr) + : device_(device), host_buffer_ptr_(host_buffer_ptr) { DLOG_F(LOG_DEBUG, "BufferInstance::BufferInstance"); tensor_ = std::move(tensor); dims_.resize(shape.size()); @@ -132,7 +133,7 @@ void BufferInstance::BindApi(PJRT_Api *api) { +[](PJRT_Buffer_ReadyEvent_Args *args) -> PJRT_Error * { DLOG_F(LOG_DEBUG, "BufferInstance::PJRT_Buffer_ReadyEvent"); BufferInstance *buffer = BufferInstance::Unwrap(args->buffer); - buffer->on_ready_event_ = new EventInstance(); + buffer->on_ready_event_ = std::make_shared(); args->event = *buffer->on_ready_event_; return nullptr; }; diff --git a/src/common/pjrt_implementation/device_instance.cc b/src/common/pjrt_implementation/device_instance.cc index 45bb7edf..4c016fd7 100644 --- a/src/common/pjrt_implementation/device_instance.cc +++ b/src/common/pjrt_implementation/device_instance.cc @@ -80,10 +80,8 @@ tt_pjrt_status DeviceInstance::HostBufferToDevice( shape.push_back(dims[i]); strides.push_back(byte_strides[i] / element_size); } - std::unique_ptr tensor = - MakeDeviceTensor(data, shape, strides, element_size, element_type); BufferInstance *buffer_instance = - new BufferInstance(*this, tensor, shape, strides); + MakeDeviceBuffer(data, shape, strides, element_size, element_type); DLOG_F(INFO, "Buffer created with id: %d", buffer_instance->unique_id()); buffer_instance->setType(type); *out_buffer = buffer_instance; @@ -101,15 +99,19 @@ size_t DeviceInstance::getSize(const std::vector &shape, return size * element_size; } -std::unique_ptr DeviceInstance::MakeDeviceTensor( +BufferInstance *DeviceInstance::MakeDeviceBuffer( const void *data, std::vector &shape, std::vector &strides, size_t element_size, tt::target::DataType element_type) { size_t tensor_size = getSize(shape, element_size); - std::shared_ptr new_memory(new char[tensor_size], [](void *) {}); + std::shared_ptr new_memory(new char[tensor_size], [](void *ptr) { + delete[] static_cast(ptr); + }); std::memcpy(new_memory.get(), data, tensor_size); - return std::make_unique(tt::runtime::createTensor( - new_memory, shape, strides, element_size, element_type)); + std::unique_ptr device_tensor = + std::make_unique(tt::runtime::createTensor( + new_memory, shape, strides, element_size, element_type)); + return new BufferInstance(*this, device_tensor, shape, strides, new_memory); } } // namespace tt::pjrt diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc index 4fd71f48..76cdb99a 100644 --- a/src/common/pjrt_implementation/loaded_executable_instance.cc +++ b/src/common/pjrt_implementation/loaded_executable_instance.cc @@ -133,7 +133,7 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { std::make_unique(rt_outputs[i]); auto result_buffer = std::make_unique( *this->addressable_devices_[dev_index], tensor_ptr, output_shape, - output_specs[i].stride); + output_specs[i].stride, nullptr); result_buffer->setType(tt::pjrt::utils::convertElementTypeToBufferType( output_specs[i].dataType)); DLOG_F(INFO, "Runtime output id: %d", result_buffer->unique_id()); From acd80067ee65d395024b8fafa7ec1dfc937c5e32 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 7 Feb 2025 09:19:39 +0000 Subject: [PATCH 5/8] Some changes --- inc/common/pjrt_implementation/device_instance.h | 2 +- src/common/pjrt_implementation/device_instance.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/inc/common/pjrt_implementation/device_instance.h b/inc/common/pjrt_implementation/device_instance.h index 9275e361..5dc5c4e8 100644 --- a/inc/common/pjrt_implementation/device_instance.h +++ b/inc/common/pjrt_implementation/device_instance.h @@ -69,7 +69,7 @@ class DeviceInstance { private: tt_pjrt_status OpenDevice(); - size_t getSize(const std::vector &shape, size_t element_size); + static size_t getTensorSize(const std::vector &shape, size_t element_size); BufferInstance *MakeDeviceBuffer(const void *data_ptr, std::vector &shape, diff --git a/src/common/pjrt_implementation/device_instance.cc b/src/common/pjrt_implementation/device_instance.cc index 4c016fd7..7a4f40ab 100644 --- a/src/common/pjrt_implementation/device_instance.cc +++ b/src/common/pjrt_implementation/device_instance.cc @@ -90,7 +90,7 @@ tt_pjrt_status DeviceInstance::HostBufferToDevice( return tt_pjrt_status::kSuccess; } -size_t DeviceInstance::getSize(const std::vector &shape, +size_t DeviceInstance::getTensorSize(const std::vector &shape, size_t element_size) { size_t size = 1; for (auto dim : shape) { @@ -103,7 +103,7 @@ BufferInstance *DeviceInstance::MakeDeviceBuffer( const void *data, std::vector &shape, std::vector &strides, size_t element_size, tt::target::DataType element_type) { - size_t tensor_size = getSize(shape, element_size); + size_t tensor_size = getTensorSize(shape, element_size); std::shared_ptr new_memory(new char[tensor_size], [](void *ptr) { delete[] static_cast(ptr); }); From ed353380ff53037cb184535b6a218bf3ed1d6ed7 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 7 Feb 2025 10:16:37 +0000 Subject: [PATCH 6/8] Addressed comments --- .../pjrt_implementation/buffer_instance.h | 23 +++++++----- .../pjrt_implementation/device_instance.h | 18 ++++++---- .../pjrt_implementation/buffer_instance.cc | 31 ++++++++++------ .../pjrt_implementation/device_instance.cc | 35 +++++++++++-------- .../loaded_executable_instance.cc | 8 ++--- tests/jax/ops/test_convert.py | 12 +++++++ 6 files changed, 81 insertions(+), 46 deletions(-) diff --git a/inc/common/pjrt_implementation/buffer_instance.h b/inc/common/pjrt_implementation/buffer_instance.h index da9ebcdb..dd0764bf 100644 --- a/inc/common/pjrt_implementation/buffer_instance.h +++ b/inc/common/pjrt_implementation/buffer_instance.h @@ -23,10 +23,13 @@ class DeviceInstance; class BufferInstance { public: - BufferInstance(DeviceInstance &device, - std::unique_ptr &tensor, - std::vector shape, - std::vector stride, + BufferInstance(DeviceInstance &device, tt::runtime::Tensor &tensor, + const std::vector &shape, + const std::vector &stride); + + BufferInstance(DeviceInstance &device, tt::runtime::Tensor &tensor, + const std::vector &shape, + const std::vector &stride, std::shared_ptr host_buffer_ptr); BufferInstance(DeviceInstance &device); ~BufferInstance(); @@ -46,7 +49,7 @@ class BufferInstance { // the hook to get an unsafe pointer (avoids a copy). return false; } - tt::runtime::Tensor tensor() { return *tensor_; } + const tt::runtime::Tensor &getTensor() const { return tensor_; } PJRT_Error *GetMemoryLayout(PJRT_Buffer_GetMemoryLayout_Args *args); // Gets the required host size in bytes to copy to host. @@ -76,7 +79,7 @@ class BufferInstance { // API elements that must have the same lifetime as BufferInstance. std::vector dims_; std::vector stride_; - std::unique_ptr tensor_; + tt::runtime::Tensor tensor_; std::vector minor_to_major_; std::vector tile_dims_; @@ -86,11 +89,13 @@ class BufferInstance { std::optional DataType; // OnReady event - currently not used. - std::shared_ptr on_ready_event_; + EventInstance *on_ready_event_; // Pointer to the host memory used to create this buffer, if buffer is created - // on device, the value of this pointer is nullptr. - std::shared_ptr host_buffer_ptr_; + // on device, the value of this pointer is nullptr. It is necessary to keep + // track of this memory since the runtime will not clean it, and we need to + // pass the shared pointer to the runtime. + std::shared_ptr host_buffer_ptr_ = nullptr; }; } // namespace tt::pjrt diff --git a/inc/common/pjrt_implementation/device_instance.h b/inc/common/pjrt_implementation/device_instance.h index 5dc5c4e8..5124ddbb 100644 --- a/inc/common/pjrt_implementation/device_instance.h +++ b/inc/common/pjrt_implementation/device_instance.h @@ -69,13 +69,17 @@ class DeviceInstance { private: tt_pjrt_status OpenDevice(); - static size_t getTensorSize(const std::vector &shape, size_t element_size); - - BufferInstance *MakeDeviceBuffer(const void *data_ptr, - std::vector &shape, - std::vector &strides, - size_t element_size, - tt::target::DataType element_type); + static size_t getTensorSize(const std::vector &shape, + size_t element_size); + + // Create a buffer instance from a host data pointer, by copying it into + // another memory. This is necessary as we have no ownership of the passed + // pointer, and it might happen that the pointer is deallocated before the + // buffer is used. + std::unique_ptr + MakeDeviceBuffer(const void *data_ptr, std::vector &shape, + std::vector &strides, size_t element_size, + tt::target::DataType element_type); ClientInstance &client_; uint64_t last_transfer_timepoint_ = 0; diff --git a/src/common/pjrt_implementation/buffer_instance.cc b/src/common/pjrt_implementation/buffer_instance.cc index 02d77655..04fcf746 100644 --- a/src/common/pjrt_implementation/buffer_instance.cc +++ b/src/common/pjrt_implementation/buffer_instance.cc @@ -19,13 +19,11 @@ int BufferInstance::id_counter_ = 0; BufferInstance::~BufferInstance() = default; BufferInstance::BufferInstance(DeviceInstance &device, - std::unique_ptr &tensor, - std::vector shape, - std::vector stride, - std::shared_ptr host_buffer_ptr) - : device_(device), host_buffer_ptr_(host_buffer_ptr) { + tt::runtime::Tensor &tensor, + const std::vector &shape, + const std::vector &stride) + : device_(device), tensor_(tensor) { DLOG_F(LOG_DEBUG, "BufferInstance::BufferInstance"); - tensor_ = std::move(tensor); dims_.resize(shape.size()); for (int i = 0; i < shape.size(); i++) { dims_[i] = shape[i]; @@ -34,6 +32,15 @@ BufferInstance::BufferInstance(DeviceInstance &device, unique_id_ = id_counter_++; } +BufferInstance::BufferInstance(DeviceInstance &device, + tt::runtime::Tensor &tensor, + const std::vector &shape, + const std::vector &stride, + std::shared_ptr host_buffer_ptr) + : BufferInstance(device, tensor, shape, stride) { + host_buffer_ptr_ = host_buffer_ptr; +} + void BufferInstance::ComputeLayout() { DLOG_F(LOG_DEBUG, "BufferInstance::ComputeLayout"); } @@ -133,8 +140,12 @@ void BufferInstance::BindApi(PJRT_Api *api) { +[](PJRT_Buffer_ReadyEvent_Args *args) -> PJRT_Error * { DLOG_F(LOG_DEBUG, "BufferInstance::PJRT_Buffer_ReadyEvent"); BufferInstance *buffer = BufferInstance::Unwrap(args->buffer); - buffer->on_ready_event_ = std::make_shared(); - args->event = *buffer->on_ready_event_; + std::unique_ptr onReadyEvent = + std::make_unique(); + buffer->on_ready_event_ = onReadyEvent.get(); + // Releasing the ownership to the PJRT API caller since the caller is + // responsible for calling PJRT_Event_Destroy on event. + args->event = *onReadyEvent.release(); return nullptr; }; // TODO: Rework the API to be Aliases(b1, b2) to let the plugin explicitly @@ -209,7 +220,7 @@ tt_pjrt_status BufferInstance::CopyToHost(void *dst, size_t dst_size, }; DLOG_F(INFO, "Copy to host id: %d", unique_id()); - tt::runtime::memcpy(dst, tensor()); + tt::runtime::memcpy(dst, getTensor()); EventInstance *copy_done_event = new EventInstance(); copy_done_event->OnReady(copy_done_callback, nullptr); @@ -220,7 +231,7 @@ tt_pjrt_status BufferInstance::CopyToHost(void *dst, size_t dst_size, PJRT_Buffer_Type BufferInstance::getRuntimeType() { DLOG_F(LOG_DEBUG, "BufferInstance::element_type"); - tt::target::DataType Type = tt::runtime::getTensorDataType(tensor()); + tt::target::DataType Type = tt::runtime::getTensorDataType(getTensor()); return tt::pjrt::utils::convertElementTypeToBufferType(Type); } diff --git a/src/common/pjrt_implementation/device_instance.cc b/src/common/pjrt_implementation/device_instance.cc index 7a4f40ab..adffb983 100644 --- a/src/common/pjrt_implementation/device_instance.cc +++ b/src/common/pjrt_implementation/device_instance.cc @@ -8,6 +8,8 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // https://llvm.org/LICENSE.txt +#include + #include "common/pjrt_implementation/device_instance.h" #include "common/pjrt_implementation/buffer_instance.h" @@ -80,38 +82,41 @@ tt_pjrt_status DeviceInstance::HostBufferToDevice( shape.push_back(dims[i]); strides.push_back(byte_strides[i] / element_size); } - BufferInstance *buffer_instance = + std::unique_ptr buffer_instance = MakeDeviceBuffer(data, shape, strides, element_size, element_type); DLOG_F(INFO, "Buffer created with id: %d", buffer_instance->unique_id()); buffer_instance->setType(type); - *out_buffer = buffer_instance; + *out_buffer = buffer_instance.release(); EventInstance *event_instance = new EventInstance(); *out_done_with_host_buffer_event = event_instance; return tt_pjrt_status::kSuccess; } size_t DeviceInstance::getTensorSize(const std::vector &shape, - size_t element_size) { - size_t size = 1; - for (auto dim : shape) { - size *= dim; - } - return size * element_size; + size_t element_size) { + std::uint32_t elementsCount = std::accumulate( + shape.begin(), shape.end(), 1, std::multiplies()); + + return static_cast(elementsCount) * element_size; } -BufferInstance *DeviceInstance::MakeDeviceBuffer( +std::unique_ptr DeviceInstance::MakeDeviceBuffer( const void *data, std::vector &shape, std::vector &strides, size_t element_size, tt::target::DataType element_type) { size_t tensor_size = getTensorSize(shape, element_size); - std::shared_ptr new_memory(new char[tensor_size], [](void *ptr) { - delete[] static_cast(ptr); + + std::shared_ptr new_memory(new std::byte[tensor_size], [](void *ptr) { + delete[] static_cast(ptr); }); + std::memcpy(new_memory.get(), data, tensor_size); - std::unique_ptr device_tensor = - std::make_unique(tt::runtime::createTensor( - new_memory, shape, strides, element_size, element_type)); - return new BufferInstance(*this, device_tensor, shape, strides, new_memory); + + tt::runtime::Tensor device_tensor = tt::runtime::createTensor( + new_memory, shape, strides, element_size, element_type); + + return std::make_unique(*this, device_tensor, shape, strides, + new_memory); } } // namespace tt::pjrt diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc index 76cdb99a..f8ff1417 100644 --- a/src/common/pjrt_implementation/loaded_executable_instance.cc +++ b/src/common/pjrt_implementation/loaded_executable_instance.cc @@ -97,7 +97,7 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { for (size_t i = 0; i < args->num_args; ++i) { BufferInstance *buffer = BufferInstance::Unwrap(args->argument_lists[dev_index][i]); - rt_inputs.emplace_back(buffer->tensor()); + rt_inputs.emplace_back(buffer->getTensor()); int64_t buffer_device_id = buffer->device().device_description()->getDeviceId(); device_ids.insert(chip_ids[buffer_device_id]); @@ -129,11 +129,9 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { // PJRT expects an empty shape for scalars. std::vector output_shape = is_scalar ? std::vector() : output_specs[i].shape; - std::unique_ptr tensor_ptr = - std::make_unique(rt_outputs[i]); auto result_buffer = std::make_unique( - *this->addressable_devices_[dev_index], tensor_ptr, output_shape, - output_specs[i].stride, nullptr); + *this->addressable_devices_[dev_index], rt_outputs[i], output_shape, + output_specs[i].stride); result_buffer->setType(tt::pjrt::utils::convertElementTypeToBufferType( output_specs[i].dataType)); DLOG_F(INFO, "Runtime output id: %d", result_buffer->unique_id()); diff --git a/tests/jax/ops/test_convert.py b/tests/jax/ops/test_convert.py index 57d47b9c..ad092671 100644 --- a/tests/jax/ops/test_convert.py +++ b/tests/jax/ops/test_convert.py @@ -25,6 +25,18 @@ def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike): Extracted here in order not to pollute the test function. """ # ---------- Atol comparison failed ---------- + # When no conversion is required, a no-op MLIR graph is created. + # However, due to input tensor ownership issues, the output tensor + # returned by the MLIR runtime will reference the same data as the input. + # If the input tensor is deallocated, the output tensor will lose access + # to valid data and may contain garbage. + # See issue #244 for more details. + if from_dtype == to_dtype or (from_dtype == jnp.uint32 and to_dtype == jnp.uint64): + pytest.xfail( + runtime_fail( + "Atol comparison failed. Calculated: atol=65535.0. Required: atol=0.16." + ) + ) if from_dtype == jnp.uint32 and to_dtype in [jnp.uint16, jnp.int16]: pytest.xfail( From f3bf9644e1fd714c2ab9ff66920ffc6b4fca633a Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 7 Feb 2025 13:18:01 +0000 Subject: [PATCH 7/8] Addressed more comments --- .../pjrt_implementation/buffer_instance.h | 5 +++-- .../pjrt_implementation/buffer_instance.cc | 20 +++++++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/inc/common/pjrt_implementation/buffer_instance.h b/inc/common/pjrt_implementation/buffer_instance.h index dd0764bf..87961b6b 100644 --- a/inc/common/pjrt_implementation/buffer_instance.h +++ b/inc/common/pjrt_implementation/buffer_instance.h @@ -91,11 +91,12 @@ class BufferInstance { // OnReady event - currently not used. EventInstance *on_ready_event_; - // Pointer to the host memory used to create this buffer, if buffer is created + // Pointer to the host memory used to create this buffer. + // If buffer is created // on device, the value of this pointer is nullptr. It is necessary to keep // track of this memory since the runtime will not clean it, and we need to // pass the shared pointer to the runtime. - std::shared_ptr host_buffer_ptr_ = nullptr; + std::shared_ptr host_buffer_ptr_; }; } // namespace tt::pjrt diff --git a/src/common/pjrt_implementation/buffer_instance.cc b/src/common/pjrt_implementation/buffer_instance.cc index 04fcf746..73a1c397 100644 --- a/src/common/pjrt_implementation/buffer_instance.cc +++ b/src/common/pjrt_implementation/buffer_instance.cc @@ -22,14 +22,8 @@ BufferInstance::BufferInstance(DeviceInstance &device, tt::runtime::Tensor &tensor, const std::vector &shape, const std::vector &stride) - : device_(device), tensor_(tensor) { - DLOG_F(LOG_DEBUG, "BufferInstance::BufferInstance"); - dims_.resize(shape.size()); - for (int i = 0; i < shape.size(); i++) { - dims_[i] = shape[i]; - } - stride_ = stride; - unique_id_ = id_counter_++; + : BufferInstance(device, tensor, shape, stride, nullptr) { + } BufferInstance::BufferInstance(DeviceInstance &device, @@ -37,8 +31,14 @@ BufferInstance::BufferInstance(DeviceInstance &device, const std::vector &shape, const std::vector &stride, std::shared_ptr host_buffer_ptr) - : BufferInstance(device, tensor, shape, stride) { - host_buffer_ptr_ = host_buffer_ptr; + : device_(device), tensor_(tensor), host_buffer_ptr_(host_buffer_ptr) { + DLOG_F(LOG_DEBUG, "BufferInstance::BufferInstance"); + dims_.resize(shape.size()); + for (int i = 0; i < shape.size(); i++) { + dims_[i] = shape[i]; + } + stride_ = stride; + unique_id_ = id_counter_++; } void BufferInstance::ComputeLayout() { From 5fcaf3a85595e60d0c4820fc7edf3968626a30ac Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 7 Feb 2025 13:48:10 +0000 Subject: [PATCH 8/8] Addressing comments --- inc/common/pjrt_implementation/buffer_instance.h | 2 +- inc/common/pjrt_implementation/device_instance.h | 2 +- src/common/pjrt_implementation/buffer_instance.cc | 4 +--- tests/jax/ops/test_convert.py | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/inc/common/pjrt_implementation/buffer_instance.h b/inc/common/pjrt_implementation/buffer_instance.h index 87961b6b..5befa815 100644 --- a/inc/common/pjrt_implementation/buffer_instance.h +++ b/inc/common/pjrt_implementation/buffer_instance.h @@ -91,7 +91,7 @@ class BufferInstance { // OnReady event - currently not used. EventInstance *on_ready_event_; - // Pointer to the host memory used to create this buffer. + // Pointer to the host memory used to create this buffer. // If buffer is created // on device, the value of this pointer is nullptr. It is necessary to keep // track of this memory since the runtime will not clean it, and we need to diff --git a/inc/common/pjrt_implementation/device_instance.h b/inc/common/pjrt_implementation/device_instance.h index 5124ddbb..e692b6eb 100644 --- a/inc/common/pjrt_implementation/device_instance.h +++ b/inc/common/pjrt_implementation/device_instance.h @@ -75,7 +75,7 @@ class DeviceInstance { // Create a buffer instance from a host data pointer, by copying it into // another memory. This is necessary as we have no ownership of the passed // pointer, and it might happen that the pointer is deallocated before the - // buffer is used. + // buffer is used. See issue #248 for more details. std::unique_ptr MakeDeviceBuffer(const void *data_ptr, std::vector &shape, std::vector &strides, size_t element_size, diff --git a/src/common/pjrt_implementation/buffer_instance.cc b/src/common/pjrt_implementation/buffer_instance.cc index 73a1c397..cbb5286b 100644 --- a/src/common/pjrt_implementation/buffer_instance.cc +++ b/src/common/pjrt_implementation/buffer_instance.cc @@ -22,9 +22,7 @@ BufferInstance::BufferInstance(DeviceInstance &device, tt::runtime::Tensor &tensor, const std::vector &shape, const std::vector &stride) - : BufferInstance(device, tensor, shape, stride, nullptr) { - -} + : BufferInstance(device, tensor, shape, stride, nullptr) {} BufferInstance::BufferInstance(DeviceInstance &device, tt::runtime::Tensor &tensor, diff --git a/tests/jax/ops/test_convert.py b/tests/jax/ops/test_convert.py index ad092671..0521cc99 100644 --- a/tests/jax/ops/test_convert.py +++ b/tests/jax/ops/test_convert.py @@ -30,7 +30,7 @@ def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike): # returned by the MLIR runtime will reference the same data as the input. # If the input tensor is deallocated, the output tensor will lose access # to valid data and may contain garbage. - # See issue #244 for more details. + # See issue #248 for more details. if from_dtype == to_dtype or (from_dtype == jnp.uint32 and to_dtype == jnp.uint64): pytest.xfail( runtime_fail(