Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing tensor creation #220

Merged
merged 8 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions inc/common/pjrt_implementation/buffer_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ class DeviceInstance;

class BufferInstance {
public:
BufferInstance(DeviceInstance &device, tt::runtime::Tensor tensor,
std::vector<std::uint32_t> shape,
std::vector<std::uint32_t> stride);
BufferInstance(DeviceInstance &device, tt::runtime::Tensor &tensor,
const std::vector<std::uint32_t> &shape,
const std::vector<std::uint32_t> &stride);

BufferInstance(DeviceInstance &device, tt::runtime::Tensor &tensor,
const std::vector<std::uint32_t> &shape,
const std::vector<std::uint32_t> &stride,
std::shared_ptr<void> host_buffer_ptr);
BufferInstance(DeviceInstance &device);
~BufferInstance();
operator PJRT_Buffer *() { return reinterpret_cast<PJRT_Buffer *>(this); }
Expand All @@ -44,7 +49,7 @@ class BufferInstance {
// the hook to get an unsafe pointer (avoids a copy).
return false;
}
tt::runtime::Tensor tensor() { return tensor_.value(); }
const tt::runtime::Tensor &getTensor() const { return tensor_; }

PJRT_Error *GetMemoryLayout(PJRT_Buffer_GetMemoryLayout_Args *args);
// Gets the required host size in bytes to copy to host.
Expand Down Expand Up @@ -74,7 +79,7 @@ class BufferInstance {
// API elements that must have the same lifetime as BufferInstance.
std::vector<int64_t> dims_;
std::vector<std::uint32_t> stride_;
std::optional<tt::runtime::Tensor> tensor_;
tt::runtime::Tensor tensor_;

std::vector<int64_t> minor_to_major_;
std::vector<int64_t> tile_dims_;
Expand All @@ -85,6 +90,13 @@ class BufferInstance {

// OnReady event - currently not used.
EventInstance *on_ready_event_;

// Pointer to the host memory used to create this buffer. If the buffer is
// created on device, this pointer is nullptr. We must keep track of this
// memory because the runtime will not free it, and we need to pass the
// shared pointer on to the runtime.
std::shared_ptr<void> host_buffer_ptr_;
};

} // namespace tt::pjrt
Expand Down
14 changes: 14 additions & 0 deletions inc/common/pjrt_implementation/device_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include "xla/pjrt/c/pjrt_c_api.h"

#include "tt/runtime/runtime.h"

#include "common/pjrt_implementation/device_description.h"
#include "common/pjrt_implementation/event_instance.h"
#include "common/status.h"
Expand Down Expand Up @@ -67,6 +69,18 @@ class DeviceInstance {
private:
tt_pjrt_status OpenDevice();

static size_t getTensorSize(const std::vector<std::uint32_t> &shape,
size_t element_size);

// Creates a buffer instance from a host data pointer by copying the data
// into newly allocated memory. This is necessary because we do not own the
// passed pointer, and it might be deallocated before the buffer is used.
// See issue #248 for more details.
std::unique_ptr<BufferInstance>
MakeDeviceBuffer(const void *data_ptr, std::vector<std::uint32_t> &shape,
std::vector<std::uint32_t> &strides, size_t element_size,
tt::target::DataType element_type);

ClientInstance &client_;
uint64_t last_transfer_timepoint_ = 0;
DeviceDescription description_;
Expand Down
28 changes: 19 additions & 9 deletions src/common/pjrt_implementation/buffer_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@ int BufferInstance::id_counter_ = 0;
BufferInstance::~BufferInstance() = default;

BufferInstance::BufferInstance(DeviceInstance &device,
tt::runtime::Tensor tensor,
std::vector<std::uint32_t> shape,
std::vector<std::uint32_t> stride)
: device_(device) {
tt::runtime::Tensor &tensor,
const std::vector<std::uint32_t> &shape,
const std::vector<std::uint32_t> &stride)
: BufferInstance(device, tensor, shape, stride, nullptr) {}

BufferInstance::BufferInstance(DeviceInstance &device,
tt::runtime::Tensor &tensor,
const std::vector<std::uint32_t> &shape,
const std::vector<std::uint32_t> &stride,
std::shared_ptr<void> host_buffer_ptr)
: device_(device), tensor_(tensor), host_buffer_ptr_(host_buffer_ptr) {
DLOG_F(LOG_DEBUG, "BufferInstance::BufferInstance");
tensor_ = tensor;
dims_.resize(shape.size());
for (int i = 0; i < shape.size(); i++) {
dims_[i] = shape[i];
Expand Down Expand Up @@ -132,8 +138,12 @@ void BufferInstance::BindApi(PJRT_Api *api) {
+[](PJRT_Buffer_ReadyEvent_Args *args) -> PJRT_Error * {
DLOG_F(LOG_DEBUG, "BufferInstance::PJRT_Buffer_ReadyEvent");
BufferInstance *buffer = BufferInstance::Unwrap(args->buffer);
buffer->on_ready_event_ = new EventInstance();
args->event = *buffer->on_ready_event_;
std::unique_ptr<EventInstance> onReadyEvent =
std::make_unique<EventInstance>();
buffer->on_ready_event_ = onReadyEvent.get();
// Releasing the ownership to the PJRT API caller since the caller is
// responsible for calling PJRT_Event_Destroy on event.
args->event = *onReadyEvent.release();
return nullptr;
};
// TODO: Rework the API to be Aliases(b1, b2) to let the plugin explicitly
Expand Down Expand Up @@ -208,7 +218,7 @@ tt_pjrt_status BufferInstance::CopyToHost(void *dst, size_t dst_size,
};

DLOG_F(INFO, "Copy to host id: %d", unique_id());
tt::runtime::memcpy(dst, tensor());
tt::runtime::memcpy(dst, getTensor());

EventInstance *copy_done_event = new EventInstance();
copy_done_event->OnReady(copy_done_callback, nullptr);
Expand All @@ -219,7 +229,7 @@ tt_pjrt_status BufferInstance::CopyToHost(void *dst, size_t dst_size,

PJRT_Buffer_Type BufferInstance::getRuntimeType() {
DLOG_F(LOG_DEBUG, "BufferInstance::element_type");
tt::target::DataType Type = tt::runtime::getTensorDataType(tensor());
tt::target::DataType Type = tt::runtime::getTensorDataType(getTensor());
return tt::pjrt::utils::convertElementTypeToBufferType(Type);
}

Expand Down
38 changes: 32 additions & 6 deletions src/common/pjrt_implementation/device_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// https://llvm.org/LICENSE.txt

#include <numeric>

#include "common/pjrt_implementation/device_instance.h"

#include "common/pjrt_implementation/buffer_instance.h"
Expand Down Expand Up @@ -80,17 +82,41 @@ tt_pjrt_status DeviceInstance::HostBufferToDevice(
shape.push_back(dims[i]);
strides.push_back(byte_strides[i] / element_size);
}
std::shared_ptr<void> data_ptr(const_cast<void *>(data), [](void *) {});
tt::runtime::Tensor tensor = tt::runtime::createTensor(
data_ptr, shape, strides, element_size, element_type);
BufferInstance *buffer_instance =
new BufferInstance(*this, tensor, shape, strides);
std::unique_ptr<BufferInstance> buffer_instance =
MakeDeviceBuffer(data, shape, strides, element_size, element_type);
DLOG_F(INFO, "Buffer created with id: %d", buffer_instance->unique_id());
buffer_instance->setType(type);
*out_buffer = buffer_instance;
*out_buffer = buffer_instance.release();
EventInstance *event_instance = new EventInstance();
*out_done_with_host_buffer_event = event_instance;
return tt_pjrt_status::kSuccess;
}

// Returns the size in bytes of a densely packed tensor with the given shape,
// where each element occupies `element_size` bytes.
//
// The accumulation is performed entirely in size_t: the original code
// multiplied the dimensions in std::uint32_t (init 1,
// std::multiplies<std::uint32_t>) and only then cast to size_t, which can
// silently overflow for tensors whose element count exceeds 32 bits. Seeding
// the fold with `element_size` folds the final multiplication in as well.
// For an empty shape (a scalar) the result is simply `element_size`,
// matching the previous behavior (element count of 1).
size_t DeviceInstance::getTensorSize(const std::vector<std::uint32_t> &shape,
                                     size_t element_size) {
  return std::accumulate(shape.begin(), shape.end(), element_size,
                         std::multiplies<size_t>());
}

std::unique_ptr<BufferInstance> DeviceInstance::MakeDeviceBuffer(
const void *data, std::vector<std::uint32_t> &shape,
std::vector<std::uint32_t> &strides, size_t element_size,
tt::target::DataType element_type) {
size_t tensor_size = getTensorSize(shape, element_size);

std::shared_ptr<void> new_memory(new std::byte[tensor_size], [](void *ptr) {
delete[] static_cast<std::byte *>(ptr);
});

std::memcpy(new_memory.get(), data, tensor_size);

tt::runtime::Tensor device_tensor = tt::runtime::createTensor(
new_memory, shape, strides, element_size, element_type);

return std::make_unique<BufferInstance>(*this, device_tensor, shape, strides,
new_memory);
}

} // namespace tt::pjrt
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
for (size_t i = 0; i < args->num_args; ++i) {
BufferInstance *buffer =
BufferInstance::Unwrap(args->argument_lists[dev_index][i]);
rt_inputs.emplace_back(buffer->tensor());
rt_inputs.emplace_back(buffer->getTensor());
int64_t buffer_device_id =
buffer->device().device_description()->getDeviceId();
device_ids.insert(chip_ids[buffer_device_id]);
Expand Down
12 changes: 12 additions & 0 deletions tests/jax/ops/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@ def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike):
Extracted here in order not to pollute the test function.
"""
# ---------- Atol comparison failed ----------
# When no conversion is required, a no-op MLIR graph is created.
# However, due to input tensor ownership issues, the output tensor
# returned by the MLIR runtime will reference the same data as the input.
# If the input tensor is deallocated, the output tensor will lose access
# to valid data and may contain garbage.
# See issue #248 for more details.
if from_dtype == to_dtype or (from_dtype == jnp.uint32 and to_dtype == jnp.uint64):
pytest.xfail(
runtime_fail(
"Atol comparison failed. Calculated: atol=65535.0. Required: atol=0.16."
)
)

if from_dtype == jnp.uint32 and to_dtype in [jnp.uint16, jnp.int16]:
pytest.xfail(
Expand Down
Loading