From dd580f21ffb6e4d23a899c7e0ac6d2bc502f3f1a Mon Sep 17 00:00:00 2001
From: guo ran <360112263@qq.com>
Date: Wed, 13 Jul 2022 17:21:19 +0800
Subject: [PATCH] OneEmbedding add tmp_buffer allocator (#8588)

* fix embedding manager

* format

* refine embedding_manager tmp_buffer allocator

* fix

* format

* refine

* refine

* auto format by CI

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot
---
 oneflow/core/embedding/embedding_manager.cpp  | 155 ++++++------------
 oneflow/core/embedding/embedding_manager.h    |  24 ++-
 oneflow/user/kernels/data_shuffle_kernel.cu   | 122 ++++++--------
 oneflow/user/kernels/one_embedding_kernels.cu |  37 ++---
 4 files changed, 132 insertions(+), 206 deletions(-)

diff --git a/oneflow/core/embedding/embedding_manager.cpp b/oneflow/core/embedding/embedding_manager.cpp
index d6843991377..52cc123bf22 100644
--- a/oneflow/core/embedding/embedding_manager.cpp
+++ b/oneflow/core/embedding/embedding_manager.cpp
@@ -37,6 +37,23 @@ struct IdStatistics {
 
 #if CUDA_VERSION >= 11020
 
+class DynamicTmpBufferAllocator final : public TmpBufferAllocator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DynamicTmpBufferAllocator);
+  DynamicTmpBufferAllocator(cudaStream_t stream, cudaMemPool_t pool)
+      : stream_(stream), mem_pool_(pool) {}
+  ~DynamicTmpBufferAllocator() override = default;
+
+  void Allocate(void** ptr, size_t size) override {
+    OF_CUDA_CHECK(cudaMallocFromPoolAsync(ptr, GetCudaAlignedSize(size), mem_pool_, stream_));
+  }
+  void Free(void* ptr) override { OF_CUDA_CHECK(cudaFreeAsync(ptr, stream_)); }
+
+ private:
+  cudaStream_t stream_{};
+  cudaMemPool_t mem_pool_{};
+};
+
 class DynamicAllocationEmbeddingState final : public EmbeddingState {
  public:
   OF_DISALLOW_COPY_AND_MOVE(DynamicAllocationEmbeddingState);
@@ -67,12 +84,10 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
     OF_CUDA_CHECK(cudaMemPoolDestroy(mem_pool_));
   }
 
-  void OnEmbeddingPrefetchStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    // do nothing
-  }
-
-  void OnEmbeddingPrefetchEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    // do nothing
+  std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
+      user_op::KernelComputeContext* ctx) override {
+    return std::make_unique<DynamicTmpBufferAllocator>(
+        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), mem_pool_);
   }
 
   void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
@@ -142,14 +157,6 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
     // do nothing
   }
 
-  void OnEmbeddingGradientShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    // do nothing
-  }
-
-  void OnEmbeddingGradientShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    // do nothing
-  }
-
   void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
     const user_op::Tensor* updated_unique_embeddings =
         ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0);
@@ -204,24 +211,6 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
     // do nothing
   }
 
-  void AllocPrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr,
-                              size_t size) override {
-    this->AllocTmpBuffer(ctx, ptr, size);
-  }
-
-  void FreePrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
-    this->FreeTmpBuffer(ctx, ptr);
-  }
-
-  void AllocTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr, size_t size) override {
-    OF_CUDA_CHECK(cudaMallocFromPoolAsync(ptr, size, mem_pool_,
-                                          ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-  }
-
-  void FreeTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
-    OF_CUDA_CHECK(cudaFreeAsync(ptr, ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-  }
-
   void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {
     std::unique_lock<std::mutex> lock(mutex_);
     int64_t index = iter % kRingBufferSize;
@@ -271,6 +260,31 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
 
 #endif
 
+class StaticTmpBufferAllocator final : public TmpBufferAllocator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(StaticTmpBufferAllocator);
+  StaticTmpBufferAllocator(void* ptr, size_t size) : ptr_(ptr), offset_(0), size_(size) {}
+  ~StaticTmpBufferAllocator() override = default;
+
+  void Allocate(void** ptr, size_t size) override {
+    CHECK(ptr_ != nullptr);
+    CHECK_GE(offset_, 0);
+    size_t aligned_size = GetCudaAlignedSize(size);
+    CHECK_LE(offset_ + aligned_size, size_);
+    *ptr = reinterpret_cast<char*>(ptr_) + offset_;
+    offset_ += aligned_size;
+  }
+
+  void Free(void* ptr) override {
+    // do nothing
+  }
+
+ private:
+  void* ptr_;
+  int64_t offset_;
+  size_t size_;
+};
+
 class StaticAllocationEmbeddingState final : public EmbeddingState {
  public:
   OF_DISALLOW_COPY_AND_MOVE(StaticAllocationEmbeddingState);
@@ -282,40 +296,16 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
         embeding_update_unique_embeddings_(nullptr),
         embeding_update_updated_unique_embeddings_(nullptr),
         embedding_put_unique_embeddings_(nullptr),
-        tmp_buffer_ptr_(nullptr),
-        tmp_buffer_offset_(0),
-        tmp_buffer_size_(0),
-        prefetch_tmp_buffer_ptr_(nullptr),
-        prefetch_tmp_buffer_offset_(0),
-        prefetch_tmp_buffer_size_(0) {
+        embedding_fused_update_put_unique_embeddings_(nullptr) {
     id_statistics_vec_.resize(kRingBufferSize);
   }
   ~StaticAllocationEmbeddingState() override = default;
 
-  void InitTmpBufferPtr(user_op::KernelComputeContext* ctx) {
+  std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
+      user_op::KernelComputeContext* ctx) override {
     user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
-    tmp_buffer_ptr_ = tmp_buffer->mut_dptr();
-    tmp_buffer_offset_ = 0;
-    tmp_buffer_size_ = tmp_buffer->shape_view().elem_cnt();
-  }
-
-  void ResetTmpBufferPtr() {
-    tmp_buffer_ptr_ = nullptr;
-    tmp_buffer_offset_ = 0;
-    tmp_buffer_size_ = 0;
-  }
-
-  void OnEmbeddingPrefetchStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
-    prefetch_tmp_buffer_ptr_ = tmp_buffer->mut_dptr();
-    prefetch_tmp_buffer_offset_ = 0;
-    prefetch_tmp_buffer_size_ = tmp_buffer->shape_view().elem_cnt();
-  }
-
-  void OnEmbeddingPrefetchEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    prefetch_tmp_buffer_ptr_ = nullptr;
-    prefetch_tmp_buffer_offset_ = 0;
-    prefetch_tmp_buffer_size_ = 0;
+    return std::make_unique<StaticTmpBufferAllocator>(tmp_buffer->mut_dptr(),
+                                                      tmp_buffer->shape_view().elem_cnt());
   }
 
   void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
@@ -326,7 +316,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
       has_lookup_embeddings_ = true;
       lookup_embeddings_ = embeddings->mut_dptr();
     }
-    this->InitTmpBufferPtr(ctx);
   }
 
   void* LookupUniqueValues(int64_t iter) override { return lookup_unique_values_; }
@@ -340,14 +329,12 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
     lookup_unique_values_ = nullptr;
     lookup_embeddings_ = nullptr;
     has_lookup_embeddings_ = false;
-    this->ResetTmpBufferPtr();
   }
 
   void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
     const user_op::Tensor* cur_rank_embeddings =
         ctx->Tensor4ArgNameAndIndex("cur_rank_embeddings", 0);
     embedding_shuffle_cur_rank_embeddings_ = cur_rank_embeddings->dptr();
-    this->InitTmpBufferPtr(ctx);
   }
 
   const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override {
@@ -356,15 +343,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
 
   void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
     embedding_shuffle_cur_rank_embeddings_ = nullptr;
-    this->ResetTmpBufferPtr();
-  }
-
-  void OnEmbeddingGradientShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    this->InitTmpBufferPtr(ctx);
-  }
-
-  void OnEmbeddingGradientShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    this->ResetTmpBufferPtr();
   }
 
   void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
@@ -414,31 +392,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
     embedding_fused_update_put_unique_embeddings_ = nullptr;
   }
 
-  void AllocPrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr,
-                              size_t size) override {
-    CHECK(prefetch_tmp_buffer_ptr_ != nullptr);
-    CHECK_GE(prefetch_tmp_buffer_offset_, 0);
-    CHECK_LE(prefetch_tmp_buffer_offset_ + size, prefetch_tmp_buffer_size_);
-    *ptr = reinterpret_cast<char*>(prefetch_tmp_buffer_ptr_) + prefetch_tmp_buffer_offset_;
-    prefetch_tmp_buffer_offset_ += size;
-  }
-
-  void FreePrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
-    // do nothing
-  }
-
-  void AllocTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr, size_t size) override {
-    CHECK(tmp_buffer_ptr_ != nullptr);
-    CHECK_GE(tmp_buffer_offset_, 0);
-    CHECK_LE(tmp_buffer_offset_ + size, tmp_buffer_size_);
-    *ptr = reinterpret_cast<char*>(tmp_buffer_ptr_) + tmp_buffer_offset_;
-    tmp_buffer_offset_ += size;
-  }
-
-  void FreeTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
-    // do nothing
-  }
-
   void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {
     std::unique_lock<std::mutex> lock(mutex_);
     int64_t index = iter % kRingBufferSize;
@@ -480,12 +433,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
   const void* embedding_put_unique_embeddings_;
   const void* embedding_fused_update_put_unique_embeddings_;
   std::vector<IdStatistics> id_statistics_vec_;
-  void* tmp_buffer_ptr_;
-  int64_t tmp_buffer_offset_;
-  size_t tmp_buffer_size_;
-  void* prefetch_tmp_buffer_ptr_;
-  int64_t prefetch_tmp_buffer_offset_;
-  size_t prefetch_tmp_buffer_size_;
   std::mutex mutex_;
 };
 
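The static-allocation path above is a bump (arena) allocator: each Allocate call hands back the next GetCudaAlignedSize-aligned slice of the kernel's pre-reserved tmp_buffer tensor, and Free is deliberately a no-op because the whole buffer is returned when the tensor is. A self-contained sketch of the same arithmetic (illustrative only; the names and the 512-byte alignment constant are assumptions, not part of this patch):

    #include <cassert>
    #include <cstddef>

    constexpr size_t kCudaAlignSize = 512;  // assumed value of oneflow's CUDA alignment
    inline size_t AlignedSize(size_t size) {
      return (size + kCudaAlignSize - 1) / kCudaAlignSize * kCudaAlignSize;
    }

    // Mirrors StaticTmpBufferAllocator: carve aligned chunks from one buffer.
    class BumpAllocator {
     public:
      BumpAllocator(void* base, size_t capacity)
          : base_(static_cast<char*>(base)), offset_(0), capacity_(capacity) {}
      void Allocate(void** ptr, size_t size) {
        size_t aligned_size = AlignedSize(size);
        assert(offset_ + aligned_size <= capacity_);  // same guard as the CHECK_LE above
        *ptr = base_ + offset_;
        offset_ += aligned_size;
      }
      void Free(void*) {}  // no-op: the backing tensor owns the memory
     private:
      char* base_;
      size_t offset_;
      size_t capacity_;
    };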
diff --git a/oneflow/core/embedding/embedding_manager.h b/oneflow/core/embedding/embedding_manager.h
index b3ea9d7cfbd..44fcd4e73cf 100644
--- a/oneflow/core/embedding/embedding_manager.h
+++ b/oneflow/core/embedding/embedding_manager.h
@@ -42,13 +42,22 @@ inline bool UseDynamicMemoryAllocation() {
 
 #ifdef WITH_CUDA
 
+class TmpBufferAllocator {
+ public:
+  TmpBufferAllocator() = default;
+  virtual ~TmpBufferAllocator() = default;
+
+  virtual void Allocate(void** ptr, size_t size) = 0;
+  virtual void Free(void* ptr) = 0;
+};
+
 class EmbeddingState {
  public:
   EmbeddingState() = default;
   virtual ~EmbeddingState() = default;
 
-  virtual void OnEmbeddingPrefetchStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
-  virtual void OnEmbeddingPrefetchEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
+  virtual std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
+      user_op::KernelComputeContext* ctx) = 0;
 
   virtual void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
   virtual void* LookupUniqueValues(int64_t iter) = 0;
@@ -59,10 +68,6 @@ class EmbeddingState {
   virtual const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) = 0;
   virtual void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
 
-  virtual void OnEmbeddingGradientShuffleStart(user_op::KernelComputeContext* ctx,
-                                               int64_t iter) = 0;
-  virtual void OnEmbeddingGradientShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
-
   virtual void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
   virtual const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) = 0;
   virtual void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) = 0;
@@ -76,13 +81,6 @@ class EmbeddingState {
   virtual const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) = 0;
   virtual void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
 
-  virtual void AllocPrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr,
-                                      size_t size) = 0;
-  virtual void FreePrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) = 0;
-
-  virtual void AllocTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr, size_t size) = 0;
-  virtual void FreeTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) = 0;
-
   virtual void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) = 0;
   virtual void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix,
                                     int64_t iter) = 0;
 
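Kernels obtain one allocator per Compute call from the state object, so temporary buffers no longer thread through start/end bookkeeping on the state. The dynamic implementation is built on CUDA's stream-ordered allocator, which is why Free can be issued immediately after the kernels that consume the buffer are enqueued. A minimal standalone demonstration of that property (a sketch using cudaMallocAsync rather than the pool variant; requires CUDA 11.2+):

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      cudaStream_t stream;
      cudaStreamCreate(&stream);
      void* buf = nullptr;
      cudaMallocAsync(&buf, 1024, stream);    // allocation is ordered on the stream
      cudaMemsetAsync(buf, 0, 1024, stream);  // work that uses the buffer
      cudaFreeAsync(buf, stream);             // safe here: the free is also stream-ordered
      cudaStreamSynchronize(stream);
      cudaStreamDestroy(stream);
      printf("stream-ordered alloc/free done\n");
      return 0;
    }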
diff --git a/oneflow/user/kernels/data_shuffle_kernel.cu b/oneflow/user/kernels/data_shuffle_kernel.cu
index 3e41a2fcb0b..6c30edabf09 100644
--- a/oneflow/user/kernels/data_shuffle_kernel.cu
+++ b/oneflow/user/kernels/data_shuffle_kernel.cu
@@ -939,6 +939,8 @@ class EmbeddingShuffleKernel final : public user_op::OpKernel {
     auto* kernel_state = dynamic_cast<DataShuffleKernelState<IDX>*>(state);
     CHECK(kernel_state != nullptr);
     embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();
+    std::unique_ptr<embedding::TmpBufferAllocator> allocator =
+        embedding_state->NewTmpBufferAllocator(ctx);
     embedding_state->OnEmbeddingShuffleStart(ctx, current_iter_);
     const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0);
     const user_op::Tensor* cur_rank_inverse_indices =
@@ -986,9 +988,8 @@ class EmbeddingShuffleKernel final : public user_op::OpKernel {
       // 1. reverse cur_rank unique, from (num_unique, embedding_size) to (cur_rank_num_ids,
      //    embedding_size)
       void* reverse_unique_cur_rank_embeddings;
-      embedding_state->AllocTmpBuffer(
-          ctx, &reverse_unique_cur_rank_embeddings,
-          GetCudaAlignedSize(cur_rank_num_ids * embedding_size * sizeof(T)));
+      allocator->Allocate(&reverse_unique_cur_rank_embeddings,
+                          cur_rank_num_ids * embedding_size * sizeof(T));
       GatherKernelUtilImpl<DeviceType::kCUDA, T, uint32_t>::Forward(
           ctx->stream(), reinterpret_cast<const uint32_t*>(cur_rank_inverse_indices->dptr()),
           cur_rank_num_ids, cur_rank_embeddings_ptr, Shape({1, num_unique, embedding_size}),
@@ -1001,18 +1002,17 @@ class EmbeddingShuffleKernel final : public user_op::OpKernel {
                         data_type, host_num_unique_matrix,
                         reinterpret_cast<T*>(reverse_unique_cur_rank_embeddings),
                         embeddings->mut_dptr<T>());
-      embedding_state->FreeTmpBuffer(ctx, reverse_unique_cur_rank_embeddings);
+      allocator->Free(reverse_unique_cur_rank_embeddings);
     } else {
       void* received_embeddings;  // T
-      embedding_state->AllocTmpBuffer(
-          ctx, &received_embeddings,
-          GetCudaAlignedSize(unique_partitioned_num_ids * embedding_size * sizeof(T)));
+      allocator->Allocate(&received_embeddings, GetCudaAlignedSize(unique_partitioned_num_ids
+                                                                   * embedding_size * sizeof(T)));
       ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size,
                         data_type, host_num_unique_matrix,
                         reinterpret_cast<T*>(reverse_unique_cur_rank_embeddings),
                         reinterpret_cast<T*>(received_embeddings));
-      embedding_state->FreeTmpBuffer(ctx, reverse_unique_cur_rank_embeddings);
+      allocator->Free(reverse_unique_cur_rank_embeddings);
 
       // 3. reverse unique_partition, from (unique_partitioned_num_ids, embedding_size) to
       //    (num_ids, embedding_size)
       GatherKernelUtilImpl<DeviceType::kCUDA, T, uint32_t>::Forward(
           ctx->stream(), reinterpret_cast<const uint32_t*>(inverse_unique_partition_indices->dptr()),
           num_ids, reinterpret_cast<const T*>(received_embeddings),
           Shape({1, unique_partitioned_num_ids, embedding_size}), embeddings->mut_dptr<T>(), 0);
-      embedding_state->FreeTmpBuffer(ctx, received_embeddings);
+      allocator->Free(received_embeddings);
     }
   } else {
     CHECK(!skip_last_gather) << "when enable_quantized_comm, should not use fuse kernel.";
       // 1. quantize cur_rank_embeddings, from (num_unique, embedding_size) T to (num_unique,
       //    embedding_size) int8_t, and get (num_unique,) T factor
       void* quantize_cur_rank_embeddings;  // int8_t
-      embedding_state->AllocTmpBuffer(
-          ctx, &quantize_cur_rank_embeddings,
-          GetCudaAlignedSize(num_unique * embedding_size * sizeof(int8_t)));
+      allocator->Allocate(&quantize_cur_rank_embeddings,
+                          num_unique * embedding_size * sizeof(int8_t));
       void* cur_rank_quantize_factor;  // T
-      embedding_state->AllocTmpBuffer(ctx, &cur_rank_quantize_factor,
-                                      GetCudaAlignedSize(num_unique * sizeof(T)));
+      allocator->Allocate(&cur_rank_quantize_factor, num_unique * sizeof(T));
       DispatchQuantizeWarpImplPackSize()(
           cuda_stream, cur_rank_embeddings_ptr,
           reinterpret_cast<int8_t*>(quantize_cur_rank_embeddings),
@@ -1041,36 +1039,32 @@ class EmbeddingShuffleKernel final : public user_op::OpKernel {
       //    embedding_size)
       void* reverse_unique_cur_rank_embeddings;  // int8_t
-      embedding_state->AllocTmpBuffer(
-          ctx, &reverse_unique_cur_rank_embeddings,
-          GetCudaAlignedSize(cur_rank_num_ids * embedding_size * sizeof(int8_t)));
+      allocator->Allocate(&reverse_unique_cur_rank_embeddings,
+                          cur_rank_num_ids * embedding_size * sizeof(int8_t));
       GatherKernelUtilImpl<DeviceType::kCUDA, int8_t, uint32_t>::Forward(
           ctx->stream(), reinterpret_cast<const uint32_t*>(cur_rank_inverse_indices->dptr()),
           cur_rank_num_ids, reinterpret_cast<const int8_t*>(quantize_cur_rank_embeddings),
           Shape({1, num_unique, embedding_size}),
           reinterpret_cast<int8_t*>(reverse_unique_cur_rank_embeddings), 0);
-      embedding_state->FreeTmpBuffer(ctx, quantize_cur_rank_embeddings);
+      allocator->Free(quantize_cur_rank_embeddings);
 
       // 3. reverse cur_rank quantize factor unique, from (num_unique) to (cur_rank_num_ids)
       void* reverse_cur_rank_quantize_factor;  // T
-      embedding_state->AllocTmpBuffer(ctx, &reverse_cur_rank_quantize_factor,
-                                      GetCudaAlignedSize(cur_rank_num_ids * sizeof(T)));
+      allocator->Allocate(&reverse_cur_rank_quantize_factor, cur_rank_num_ids * sizeof(T));
       GatherKernelUtilImpl<DeviceType::kCUDA, T, uint32_t>::Forward(
           ctx->stream(), reinterpret_cast<const uint32_t*>(cur_rank_inverse_indices->dptr()),
           cur_rank_num_ids, reinterpret_cast<const T*>(cur_rank_quantize_factor),
           Shape({1, num_unique, 1}),
           reinterpret_cast<T*>(reverse_cur_rank_quantize_factor), 0);
-      embedding_state->FreeTmpBuffer(ctx, cur_rank_quantize_factor);
+      allocator->Free(cur_rank_quantize_factor);
       // 4. send recv embedding and factor, from (cur_rank_num_ids, embedding_size) to
       //    (unique_partitioned_num_ids, embedding_size)
       void* received_embeddings;   // int8_t
       void* recv_quantize_factor;  // T
-      embedding_state->AllocTmpBuffer(
-          ctx, &received_embeddings,
-          GetCudaAlignedSize(unique_partitioned_num_ids * embedding_size * sizeof(int8_t)));
-      embedding_state->AllocTmpBuffer(ctx, &recv_quantize_factor,
-                                      GetCudaAlignedSize(unique_partitioned_num_ids * sizeof(T)));
+      allocator->Allocate(&received_embeddings,
+                          unique_partitioned_num_ids * embedding_size * sizeof(int8_t));
+      allocator->Allocate(&recv_quantize_factor, unique_partitioned_num_ids * sizeof(T));
       ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size,
                         data_type, host_num_unique_matrix,
@@ -1078,33 +1072,31 @@ class EmbeddingShuffleKernel final : public user_op::OpKernel {
                         reinterpret_cast<int8_t*>(received_embeddings),
                         reinterpret_cast<T*>(reverse_cur_rank_quantize_factor),
                         reinterpret_cast<T*>(recv_quantize_factor));
-      embedding_state->FreeTmpBuffer(ctx, reverse_unique_cur_rank_embeddings);
-      embedding_state->FreeTmpBuffer(ctx, reverse_cur_rank_quantize_factor);
+      allocator->Free(reverse_unique_cur_rank_embeddings);
+      allocator->Free(reverse_cur_rank_quantize_factor);
 
       // 5. reverse unique_partition, from (unique_partitioned_num_ids, embedding_size) to (num_ids,
       //    embedding_size)
       void* reverse_recv_quantize_cur_rank_embeddings;  // int8_t
-      embedding_state->AllocTmpBuffer(
-          ctx, &reverse_recv_quantize_cur_rank_embeddings,
-          GetCudaAlignedSize(num_ids * embedding_size * sizeof(int8_t)));
+      allocator->Allocate(&reverse_recv_quantize_cur_rank_embeddings,
+                          num_ids * embedding_size * sizeof(int8_t));
       GatherKernelUtilImpl<DeviceType::kCUDA, int8_t, uint32_t>::Forward(
           ctx->stream(), reinterpret_cast<const uint32_t*>(inverse_unique_partition_indices->dptr()),
           num_ids, reinterpret_cast<const int8_t*>(received_embeddings),
           Shape({1, unique_partitioned_num_ids, embedding_size}),
           reinterpret_cast<int8_t*>(reverse_recv_quantize_cur_rank_embeddings), 0);
-      embedding_state->FreeTmpBuffer(ctx, received_embeddings);
+      allocator->Free(received_embeddings);
 
       // 6. reverse unique_partition_factor, from (unique_partitioned_num_ids) to (num_ids)
       void* reverse_recv_quantize_factor;  // T
-      embedding_state->AllocTmpBuffer(ctx, &reverse_recv_quantize_factor,
-                                      GetCudaAlignedSize(num_ids * sizeof(T)));
+      allocator->Allocate(&reverse_recv_quantize_factor, num_ids * sizeof(T));
       GatherKernelUtilImpl<DeviceType::kCUDA, T, uint32_t>::Forward(
           ctx->stream(), reinterpret_cast<const uint32_t*>(inverse_unique_partition_indices->dptr()),
           num_ids, reinterpret_cast<const T*>(recv_quantize_factor),
           Shape({1, unique_partitioned_num_ids, 1}),
           reinterpret_cast<T*>(reverse_recv_quantize_factor), 0);
-      embedding_state->FreeTmpBuffer(ctx, recv_quantize_factor);
+      allocator->Free(recv_quantize_factor);
       // 7. dequantize embeddings, from (num_ids, embedding_size) int8_t to (num_ids,
       //    embedding_size) T
@@ -1114,8 +1106,8 @@ class EmbeddingShuffleKernel final : public user_op::OpKernel {
           cuda_stream, reinterpret_cast<const int8_t*>(reverse_recv_quantize_cur_rank_embeddings),
           reinterpret_cast<const T*>(reverse_recv_quantize_factor), embeddings->mut_dptr<T>(),
           embedding_size, dequantize_elem_cnt)));
-      embedding_state->FreeTmpBuffer(ctx, reverse_recv_quantize_cur_rank_embeddings);
-      embedding_state->FreeTmpBuffer(ctx, reverse_recv_quantize_factor);
+      allocator->Free(reverse_recv_quantize_cur_rank_embeddings);
+      allocator->Free(reverse_recv_quantize_factor);
     }
     embedding_state->OnEmbeddingShuffleEnd(ctx, current_iter_);
     current_iter_++;
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
@@ -1370,7 +1362,8 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
     auto* kernel_state = dynamic_cast<DataShuffleKernelState<IDX>*>(state);
     CHECK(kernel_state != nullptr);
     embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();
-    embedding_state->OnEmbeddingGradientShuffleStart(ctx, current_iter_);
+    std::unique_ptr<embedding::TmpBufferAllocator> allocator =
+        embedding_state->NewTmpBufferAllocator(ctx);
     const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0);
     const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0);
@@ -1420,9 +1413,8 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
       // 1. sum to unique grad, from (num_ids, embedding_size) to (unique_partitioned_num_ids,
       //    padded_embedding_size)
       void* unique_partition_embedding_grad;  // T
-      embedding_state->AllocTmpBuffer(
-          ctx, &unique_partition_embedding_grad,
-          GetCudaAlignedSize(unique_partitioned_num_ids * padded_embedding_size * sizeof(T)));
+      allocator->Allocate(&unique_partition_embedding_grad,
+                          unique_partitioned_num_ids * padded_embedding_size * sizeof(T));
 
       const T* unique_embedding_grad_ptr;
       if (skip_first_scatter) {
@@ -1438,9 +1430,8 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
       // 2. send recv grad, from (unique_partitioned_num_ids, padded_embedding_size) to
       //    (cur_rank_num_ids, padded_embedding_size)
       void* received_embedding_grad;  // T
-      embedding_state->AllocTmpBuffer(
-          ctx, &received_embedding_grad,
-          GetCudaAlignedSize(cur_rank_num_ids * padded_embedding_size * sizeof(T)));
+      allocator->Allocate(&received_embedding_grad,
+                          cur_rank_num_ids * padded_embedding_size * sizeof(T));
       ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids,
                             padded_embedding_size, data_type, host_num_unique_matrix,
@@ -1460,16 +1451,15 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
                                reinterpret_cast<const T*>(received_embedding_grad),
                                reinterpret_cast<const uint32_t*>(cur_rank_inverse_indices->dptr()),
                                cur_rank_unique_embedding_grad->mut_dptr<T>(), buffer_ptr);
-      embedding_state->FreeTmpBuffer(ctx, unique_partition_embedding_grad);
-      embedding_state->FreeTmpBuffer(ctx, received_embedding_grad);
+      allocator->Free(unique_partition_embedding_grad);
+      allocator->Free(received_embedding_grad);
     } else {
       CHECK(!skip_first_scatter) << "when enable_quantized_comm, should not use fuse kernel.";
       // 1. sum to unique grad, from (num_ids, embedding_size) to (unique_partitioned_num_ids,
       //    padded_embedding_size)
       void* unique_partition_embedding_grad;  // T
-      embedding_state->AllocTmpBuffer(
-          ctx, &unique_partition_embedding_grad,
-          GetCudaAlignedSize(unique_partitioned_num_ids * padded_embedding_size * sizeof(T)));
+      allocator->Allocate(&unique_partition_embedding_grad,
+                          unique_partitioned_num_ids * padded_embedding_size * sizeof(T));
       UniquePartitionEmbeddingGrad(
           ctx->stream(), unique_partitioned_num_ids, num_ids, embedding_size, padded_embedding_size,
@@ -1481,12 +1471,10 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
      //    quantize_cur_rank_embedding_grad(unique_partitioned_num_ids, padded_embedding_size) int8_t
      //    and cur_rank_quantize_factor(unique_partitioned_num_ids) T
       void* quantize_cur_rank_embedding_grad;  // int8_t
-      embedding_state->AllocTmpBuffer(
-          ctx, &quantize_cur_rank_embedding_grad,
-          GetCudaAlignedSize(unique_partitioned_num_ids * padded_embedding_size * sizeof(int8_t)));
+      allocator->Allocate(&quantize_cur_rank_embedding_grad,
+                          unique_partitioned_num_ids * padded_embedding_size * sizeof(int8_t));
       void* cur_rank_quantize_factor;  // T
-      embedding_state->AllocTmpBuffer(ctx, &cur_rank_quantize_factor,
-                                      GetCudaAlignedSize(unique_partitioned_num_ids * sizeof(T)));
+      allocator->Allocate(&cur_rank_quantize_factor, unique_partitioned_num_ids * sizeof(T));
       DispatchQuantizeWarpImplPackSize()(
           cuda_stream, reinterpret_cast<const T*>(unique_partition_embedding_grad),
@@ -1498,12 +1486,10 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
      //    (cur_rank_num_ids, padded_embedding_size) int8_t send recv quantize_factor, from
      //    (unique_partitioned_num_ids) T to (cur_rank_num_ids) T
       void* received_embedding_grad;  // int8_t
-      embedding_state->AllocTmpBuffer(
-          ctx, &received_embedding_grad,
-          GetCudaAlignedSize(cur_rank_num_ids * padded_embedding_size * sizeof(int8_t)));
+      allocator->Allocate(&received_embedding_grad,
+                          cur_rank_num_ids * padded_embedding_size * sizeof(int8_t));
       void* received_cur_rank_quantize_factor;  // T
-      embedding_state->AllocTmpBuffer(ctx, &received_cur_rank_quantize_factor,
-                                      GetCudaAlignedSize(cur_rank_num_ids * sizeof(T)));
+      allocator->Allocate(&received_cur_rank_quantize_factor, cur_rank_num_ids * sizeof(T));
       ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids,
                             padded_embedding_size, data_type, host_num_unique_matrix,
                             reinterpret_cast<int8_t*>(quantize_cur_rank_embedding_grad),
                             reinterpret_cast<int8_t*>(received_embedding_grad),
                             reinterpret_cast<T*>(cur_rank_quantize_factor),
                             reinterpret_cast<T*>(received_cur_rank_quantize_factor));
-      embedding_state->FreeTmpBuffer(ctx, quantize_cur_rank_embedding_grad);
-      embedding_state->FreeTmpBuffer(ctx, cur_rank_quantize_factor);
+      allocator->Free(quantize_cur_rank_embedding_grad);
+      allocator->Free(cur_rank_quantize_factor);
 
       /*
       Host num unique matrix:
@@ -1527,17 +1513,16 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
       // 4. dequantize grad, from (cur_rank_num_ids, padded_embedding_size) int8_t to
       //    (cur_rank_num_ids, padded_embedding_size) T
       void* dequantize_cur_rank_embedding_grad;  // T
-      embedding_state->AllocTmpBuffer(
-          ctx, &dequantize_cur_rank_embedding_grad,
-          GetCudaAlignedSize(cur_rank_num_ids * padded_embedding_size * sizeof(T)));
+      allocator->Allocate(&dequantize_cur_rank_embedding_grad,
+                          cur_rank_num_ids * padded_embedding_size * sizeof(T));
       OF_CUDA_CHECK((LaunchDequantizeKernel(
           cuda_stream, reinterpret_cast<const int8_t*>(received_embedding_grad),
           reinterpret_cast<const T*>(received_cur_rank_quantize_factor),
           reinterpret_cast<T*>(dequantize_cur_rank_embedding_grad), padded_embedding_size,
           cur_rank_num_ids * padded_embedding_size)));
-      embedding_state->FreeTmpBuffer(ctx, received_embedding_grad);
-      embedding_state->FreeTmpBuffer(ctx, received_cur_rank_quantize_factor);
+      allocator->Free(received_embedding_grad);
+      allocator->Free(received_cur_rank_quantize_factor);
 
       // use unique_partition_embedding_grad as UniqueCurRankEmbeddingGrad buffer.
       T* buffer_ptr = reinterpret_cast<T*>(unique_partition_embedding_grad);
@@ -1552,10 +1537,9 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
                                reinterpret_cast<const T*>(dequantize_cur_rank_embedding_grad),
                                reinterpret_cast<const uint32_t*>(cur_rank_inverse_indices->dptr()),
                                cur_rank_unique_embedding_grad->mut_dptr<T>(), buffer_ptr);
-      embedding_state->FreeTmpBuffer(ctx, unique_partition_embedding_grad);
-      embedding_state->FreeTmpBuffer(ctx, dequantize_cur_rank_embedding_grad);
+      allocator->Free(unique_partition_embedding_grad);
+      allocator->Free(dequantize_cur_rank_embedding_grad);
     }
-    embedding_state->OnEmbeddingGradientShuffleEnd(ctx, current_iter_);
     current_iter_++;
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
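The InferTmpSizeFn change in the next file is the other half of the static-allocator contract: StaticTmpBufferAllocator advances its offset by the aligned size of every request, so the reserved tmp_buffer must be the sum of aligned chunk sizes rather than raw byte counts. A worked example of the accounting for the prefetch kernel's three buffers (the alignment value is an assumption):

    #include <cstdint>
    #include <cstdio>

    constexpr size_t kCudaAlignSize = 512;  // assumed CUDA alignment
    inline size_t AlignedSize(size_t size) {
      return (size + kCudaAlignSize - 1) / kCudaAlignSize * kCudaAlignSize;
    }

    int main() {
      const size_t num_unique = 1000, line_size = 128;
      size_t raw = sizeof(uint32_t)                          // num_missing
                 + num_unique * sizeof(uint32_t)             // missing_indices
                 + num_unique * line_size * sizeof(float);   // values
      size_t reserved = AlignedSize(sizeof(uint32_t))
                      + AlignedSize(num_unique * sizeof(uint32_t))
                      + AlignedSize(num_unique * line_size * sizeof(float));
      printf("raw = %zu bytes, reserved = %zu bytes\n", raw, reserved);
      return 0;  // reserved >= raw; reserved is what the size function must report
    }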
diff --git a/oneflow/user/kernels/one_embedding_kernels.cu b/oneflow/user/kernels/one_embedding_kernels.cu
index 1bbe0de06e0..f217d0d8339 100644
--- a/oneflow/user/kernels/one_embedding_kernels.cu
+++ b/oneflow/user/kernels/one_embedding_kernels.cu
@@ -561,7 +561,7 @@ user_op::InferTmpSizeFn GenEmbeddingInferTmpSizeFn() {
     size_t value_buffer_size;
     if (is_prefetch) {
       size_t value_byte_size = ctx->Attr<int64_t>("line_size") * sizeof(T);
-      value_buffer_size = num_ids * value_byte_size;
+      value_buffer_size = GetCudaAlignedSize(num_ids * value_byte_size);
     } else {
       value_buffer_size = 0;
     }
@@ -590,7 +590,8 @@ class EmbeddingPrefetchKernel final : public user_op::OpKernel {
     auto* kernel_state = dynamic_cast<EmbeddingKernelState<IDX>*>(state);
     CHECK(kernel_state != nullptr);
     embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();
-    embedding_state->OnEmbeddingPrefetchStart(ctx, current_iter_);
+    std::unique_ptr<embedding::TmpBufferAllocator> allocator =
+        embedding_state->NewTmpBufferAllocator(ctx);
     uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);
     const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0);
     const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex("unique_ids", 0);
     const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex("table_ids", 0);
@@ -599,21 +600,17 @@ class EmbeddingPrefetchKernel final : public user_op::OpKernel {
     const int64_t line_size = ctx->Attr<int64_t>("line_size");
 
     void* num_missing_ptr;
-    embedding_state->AllocPrefetchTmpBuffer(ctx, &num_missing_ptr,
-                                            GetCudaAlignedSize(sizeof(uint32_t)));
+    allocator->Allocate(&num_missing_ptr, sizeof(uint32_t));
     void* missing_indices_ptr;
-    embedding_state->AllocPrefetchTmpBuffer(ctx, &missing_indices_ptr,
-                                            GetCudaAlignedSize(num_unique * sizeof(uint32_t)));
+    allocator->Allocate(&missing_indices_ptr, num_unique * sizeof(uint32_t));
     void* values_ptr;
-    embedding_state->AllocPrefetchTmpBuffer(ctx, &values_ptr,
-                                            GetCudaAlignedSize(num_unique * line_size * sizeof(T)));
+    allocator->Allocate(&values_ptr, num_unique * line_size * sizeof(T));
     LookupAndInitMissing(ctx->stream(), kernel_state, num_unique, embedding_size, line_size,
                          true, unique_ids->dptr(), table_ids->dptr(), num_missing_ptr,
                          missing_indices_ptr, values_ptr);
-    embedding_state->FreePrefetchTmpBuffer(ctx, num_missing_ptr);
-    embedding_state->FreePrefetchTmpBuffer(ctx, missing_indices_ptr);
-    embedding_state->FreePrefetchTmpBuffer(ctx, values_ptr);
-    embedding_state->OnEmbeddingPrefetchEnd(ctx, current_iter_);
+    allocator->Free(num_missing_ptr);
+    allocator->Free(missing_indices_ptr);
+    allocator->Free(values_ptr);
     current_iter_++;
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
@@ -666,6 +663,8 @@ class EmbeddingLookupKernel final : public user_op::OpKernel {
     auto* kernel_state = dynamic_cast<EmbeddingKernelState<IDX>*>(state);
     CHECK(kernel_state != nullptr);
     embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();
+    std::unique_ptr<embedding::TmpBufferAllocator> allocator =
+        embedding_state->NewTmpBufferAllocator(ctx);
     embedding_state->OnEmbeddingLookupStart(ctx, current_iter_);
     const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0);
     const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex("unique_ids", 0);
@@ -680,25 +679,23 @@ class EmbeddingLookupKernel final : public user_op::OpKernel {
       void* embeddings_ptr = embedding_state->LookupEmbeddings(current_iter_);
       user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0);
       void* lookup_mask_ptr;
-      embedding_state->AllocTmpBuffer(ctx, &lookup_mask_ptr,
-                                      GetCudaAlignedSize(num_unique * sizeof(uint8_t)));
+      allocator->Allocate(&lookup_mask_ptr, num_unique * sizeof(uint8_t));
       LookupAndFusedInitMissingSliceCast(
           ctx->stream(), kernel_state, num_unique, embedding_size, line_size,
           unique_values->data_type(), embeddings->data_type(), unique_ids->dptr(),
           table_ids->dptr(), reinterpret_cast<uint8_t*>(lookup_mask_ptr), values_ptr,
           embeddings_ptr);
-      embedding_state->FreeTmpBuffer(ctx, lookup_mask_ptr);
+      allocator->Free(lookup_mask_ptr);
     } else {
       void* num_missing_ptr;
-      embedding_state->AllocTmpBuffer(ctx, &num_missing_ptr, GetCudaAlignedSize(sizeof(uint32_t)));
+      allocator->Allocate(&num_missing_ptr, sizeof(uint32_t));
       void* missing_indices_ptr;
-      embedding_state->AllocTmpBuffer(ctx, &missing_indices_ptr,
-                                      GetCudaAlignedSize(num_unique * sizeof(uint32_t)));
+      allocator->Allocate(&missing_indices_ptr, num_unique * sizeof(uint32_t));
       LookupAndInitMissing(ctx->stream(), kernel_state, num_unique, embedding_size, line_size,
                            false, unique_ids->dptr(), table_ids->dptr(), num_missing_ptr,
                            missing_indices_ptr, values_ptr);
-      embedding_state->FreeTmpBuffer(ctx, num_missing_ptr);
-      embedding_state->FreeTmpBuffer(ctx, missing_indices_ptr);
+      allocator->Free(num_missing_ptr);
+      allocator->Free(missing_indices_ptr);
       if (has_output_embeddings) {
         void* embeddings_ptr = embedding_state->LookupEmbeddings(current_iter_);
         user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0);
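For reference, the quantized-communication steps in the two shuffle kernels move int8 payloads plus one scale factor per row. The patch does not include the quantize/dequantize kernels themselves; the sketch below shows the usual symmetric-quantization arithmetic they imply (factor = row_abs_max / 127), which is an assumption about the scheme, not code from this PR:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const float row[4] = {0.5f, -2.0f, 1.25f, 0.75f};
      float abs_max = 0.0f;
      for (float v : row) abs_max = std::fmax(abs_max, std::fabs(v));
      const float factor = abs_max / 127.0f;  // per-row scale, shuffled as T
      for (float v : row) {
        const int8_t q = static_cast<int8_t>(std::lrintf(v / factor));  // sent as int8
        printf("%+.3f -> %4d -> %+.3f\n", v, q, q * factor);            // dequantized on receive
      }
      return 0;
    }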