Skip to content

Commit

Permalink
Allow querying the allocator for the buffer size (#1404)
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath authored Sep 12, 2024
1 parent 8b30acd commit 881f09b
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 13 deletions.
15 changes: 13 additions & 2 deletions mlx/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,22 @@ void free(Buffer buffer) {
}

Buffer CommonAllocator::malloc(size_t size, bool) {
return Buffer{std::malloc(size)};
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}

void CommonAllocator::free(Buffer buffer) {
std::free(buffer.raw_ptr());
std::free(buffer.ptr());
}

size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}

Buffer malloc_or_wait(size_t size) {
Expand Down
2 changes: 2 additions & 0 deletions mlx/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;

Allocator() = default;
Allocator(const Allocator& other) = delete;
Expand All @@ -57,6 +58,7 @@ class CommonAllocator : public Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;

private:
CommonAllocator() = default;
Expand Down
4 changes: 4 additions & 0 deletions mlx/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ class array {
return array_desc_->data->buffer;
}

size_t buffer_size() const {
return allocator::allocator().size(buffer());
}

// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
Expand Down
16 changes: 8 additions & 8 deletions mlx/backend/common/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ void set_binary_op_output_data(
array& out,
BinaryOpType bopt,
bool donate_with_move = false) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (b_donatable) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
Expand All @@ -64,7 +66,7 @@ void set_binary_op_output_data(
}
break;
case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (a_donatable) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
Expand All @@ -79,13 +81,13 @@ void set_binary_op_output_data(
}
break;
case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (a_donatable) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
} else if (b_donatable) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
Expand All @@ -100,16 +102,14 @@ void set_binary_op_output_data(
}
break;
case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (
b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/common/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void set_ternary_op_output_data(
TernaryOpType topt,
bool donate_with_move = false) {
auto maybe_donate = [&out, donate_with_move](const array& x) {
if (x.is_donatable() && x.itemsize() == out.itemsize()) {
if (is_donatable(x, out)) {
if (donate_with_move) {
out.move_shared_buffer(x);
} else {
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/common/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace mlx::core {
namespace {

void set_unary_output_data(const array& in, array& out) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
Expand Down
7 changes: 7 additions & 0 deletions mlx/backend/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,11 @@ inline auto check_contiguity(
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
}

inline bool is_donatable(const array& in, const array& out) {
constexpr size_t donation_extra = 16384;

return in.is_donatable() && in.itemsize() == out.itemsize() &&
in.buffer_size() <= out.nbytes() + donation_extra;
}

} // namespace mlx::core
4 changes: 4 additions & 0 deletions mlx/backend/metal/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ void MetalAllocator::free(Buffer buffer) {
}
}

size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
}

MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/metal/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class MetalAllocator : public allocator::Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() {
return active_memory_;
};
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/no_metal/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Allocator& allocator() {
}

void* Buffer::raw_ptr() {
return ptr_;
return static_cast<size_t*>(ptr_) + 1;
}

} // namespace mlx::core::allocator

0 comments on commit 881f09b

Please sign in to comment.