diff --git a/tt_metal/host_api.hpp b/tt_metal/host_api.hpp index cdbaa865b56..7b728c3d1c9 100644 --- a/tt_metal/host_api.hpp +++ b/tt_metal/host_api.hpp @@ -259,6 +259,17 @@ uint32_t CreateSemaphore( */ std::shared_ptr CreateBuffer(const InterleavedBufferConfig &config); +/** +* Creates a pre-allocated interleaved DRAM or L1 buffer on device +* +* Return value: std::shared_ptr +* +* | Argument | Description | Type | Valid Range | Required | +* |-----------------|---------------------------------------- |--------------------------|-------------|----------| +* | config | config for buffer | InterleavedBufferConfig | | Yes | +*/ +std::shared_ptr CreateBuffer(const InterleavedBufferConfig &config, DeviceAddr address); + /** * Allocates a sharded DRAM or L1 buffer on device * @@ -270,6 +281,17 @@ std::shared_ptr CreateBuffer(const InterleavedBufferConfig &config); */ std::shared_ptr CreateBuffer(const ShardedBufferConfig &config); +/** +* Creates a pre-allocated sharded DRAM or L1 buffer on device +* +* Return value: std::shared_ptr +* +* | Argument | Description | Type | Valid Range | Required | +* |-----------------|---------------------------------------- |--------------------------|-------------|----------| +* | config | config for buffer | ShardedBufferConfig | | Yes | +*/ +std::shared_ptr CreateBuffer(const ShardedBufferConfig &config, DeviceAddr address); + /** * Deallocates buffer from device by marking its memory as free. * diff --git a/tt_metal/impl/buffers/buffer.cpp b/tt_metal/impl/buffers/buffer.cpp index bdd6b6a7137..d32869d383a 100644 --- a/tt_metal/impl/buffers/buffer.cpp +++ b/tt_metal/impl/buffers/buffer.cpp @@ -203,7 +203,9 @@ Buffer::Buffer( const BufferType buffer_type, const TensorMemoryLayout buffer_layout, const std::optional& shard_parameters, - const std::optional bottom_up) : + const std::optional bottom_up, + const bool owns_data, + Private) : device_(device), size_(size), page_size_(page_size), @@ -211,6 +213,7 @@ Buffer::Buffer( buffer_layout_(buffer_layout), shard_parameters_(shard_parameters), bottom_up_(bottom_up.value_or(this->is_dram())), + owns_data_(owns_data), buffer_page_mapping_(nullptr) { TT_FATAL(this->device_ != nullptr && this->device_->allocator_ != nullptr, "Device and allocator need to not be null."); @@ -227,7 +230,8 @@ std::shared_ptr Buffer::create( const TensorMemoryLayout buffer_layout, const std::optional& shard_parameters, const std::optional bottom_up) { - auto* bufferPtr = new Buffer(device, size, page_size, buffer_type, buffer_layout, shard_parameters, bottom_up); + auto* bufferPtr = new Buffer(device, size, page_size, buffer_type, buffer_layout, shard_parameters, bottom_up, true /* owns data */, Private()); + // Using a custom deleter to properly clean up the owned datas auto buffer = std::shared_ptr(bufferPtr, deleter); buffer->weak_self = buffer; @@ -240,7 +244,7 @@ std::shared_ptr Buffer::create( buffer->address_ = detail::AllocateBuffer(buffer.get()); std::unique_lock lock(buffer->allocation_mutex_); - buffer->allocation_status_.store(AllocationStatus::ALLOCATED, std::memory_order::relaxed); + buffer->allocation_status_.store(AllocationStatus::ALLOCATED, std::memory_order::release); lock.unlock(); buffer->allocation_cv_.notify_all(); }); @@ -248,8 +252,30 @@ std::shared_ptr Buffer::create( return buffer; } +std::shared_ptr Buffer::create( + Device *device, + DeviceAddr address, + DeviceAddr size, + DeviceAddr page_size, + const BufferType buffer_type, + const TensorMemoryLayout buffer_layout, + const std::optional& shard_parameters, + const std::optional bottom_up) { + // Not using a custom deleter, because it doesn't own any data to cleanup + auto buffer = std::make_shared(device, size, page_size, buffer_type, buffer_layout, shard_parameters, bottom_up, false /* owns data */, Private()); + buffer->weak_self = buffer; + + buffer->address_ = address; + buffer->allocation_status_.store(AllocationStatus::ALLOCATED, std::memory_order::relaxed); + + return buffer; +} + void Buffer::deallocate() { deallocation_requested_.store(true, std::memory_order::relaxed); + if (!owns_data_) { + return; + } device_->push_work([self = weak_self.lock()] { self->deallocate_impl(); }); @@ -289,7 +315,7 @@ bool Buffer::is_allocated() const { } uint32_t Buffer::address() const { - if (device_->can_use_passthrough_scheduling()) { + if (allocation_status_.load(std::memory_order::acquire) != AllocationStatus::ALLOCATION_REQUESTED) { return address_; } diff --git a/tt_metal/impl/buffers/buffer.hpp b/tt_metal/impl/buffers/buffer.hpp index 2c51edf576b..b12f9967c8c 100644 --- a/tt_metal/impl/buffers/buffer.hpp +++ b/tt_metal/impl/buffers/buffer.hpp @@ -144,6 +144,8 @@ struct BufferPageMapping { inline namespace v0 { class Buffer final { + struct Private {}; + public: static std::shared_ptr create( Device *device, @@ -153,6 +155,15 @@ class Buffer final { TensorMemoryLayout buffer_layout = TensorMemoryLayout::INTERLEAVED, const std::optional& shard_parameter = std::nullopt, std::optional bottom_up = std::nullopt); + static std::shared_ptr create( + Device *device, + DeviceAddr address, + DeviceAddr size, + DeviceAddr page_size, + BufferType buffer_type, + TensorMemoryLayout buffer_layout = TensorMemoryLayout::INTERLEAVED, + const std::optional& shard_parameter = std::nullopt, + std::optional bottom_up = std::nullopt); Buffer(const Buffer &other) = delete; Buffer &operator=(const Buffer &other) = delete; @@ -210,7 +221,7 @@ class Buffer final { const std::shared_ptr& get_buffer_page_mapping(); - private: + Buffer( Device *device, DeviceAddr size, @@ -218,8 +229,11 @@ class Buffer final { BufferType buffer_type, TensorMemoryLayout buffer_layout, const std::optional& shard_parameter, - std::optional bottom_up); + std::optional bottom_up, + bool owns_data, + Private); + private: enum class AllocationStatus : uint8_t { ALLOCATION_REQUESTED, ALLOCATED, @@ -239,6 +253,7 @@ class Buffer final { const BufferType buffer_type_; const TensorMemoryLayout buffer_layout_; const bool bottom_up_; + const bool owns_data_; std::atomic allocation_status_ = AllocationStatus::ALLOCATION_REQUESTED; DeviceAddr address_ = 0; diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index d905e5627bf..7f3133c7550 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -1090,6 +1090,11 @@ std::shared_ptr CreateBuffer(const InterleavedBufferConfig &config) { config.device, config.size, config.page_size, config.buffer_type, config.buffer_layout, std::nullopt, std::nullopt); } +std::shared_ptr CreateBuffer(const InterleavedBufferConfig &config, DeviceAddr address) { + return Buffer::create( + config.device, address, config.size, config.page_size, config.buffer_type, config.buffer_layout, std::nullopt, std::nullopt); +} + std::shared_ptr CreateBuffer(const ShardedBufferConfig &config) { return Buffer::create( config.device, @@ -1101,6 +1106,18 @@ std::shared_ptr CreateBuffer(const ShardedBufferConfig &config) { std::nullopt); } +std::shared_ptr CreateBuffer(const ShardedBufferConfig &config, DeviceAddr address) { + return Buffer::create( + config.device, + address, + config.size, + config.page_size, + config.buffer_type, + config.buffer_layout, + config.shard_parameters, + std::nullopt); +} + void DeallocateBuffer(Buffer &buffer) { buffer.deallocate(); } void AssignGlobalBufferToProgram(