OneEmbedding add tmp_buffer allocator (#8588)
* fix embedding manager

* format

* refine embedding_manager tmp_buffer allocator

* fix

* format

* refine

* refine

* auto format by CI

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
3 people authored Jul 13, 2022
1 parent 09601e1 commit dd580f2
Showing 4 changed files with 132 additions and 206 deletions.
155 changes: 51 additions & 104 deletions oneflow/core/embedding/embedding_manager.cpp
@@ -37,6 +37,23 @@ struct IdStatistics {

 #if CUDA_VERSION >= 11020

+class DynamicTmpBufferAllocator final : public TmpBufferAllocator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DynamicTmpBufferAllocator);
+  DynamicTmpBufferAllocator(cudaStream_t stream, cudaMemPool_t pool)
+      : stream_(stream), mem_pool_(pool) {}
+  ~DynamicTmpBufferAllocator() override = default;
+
+  void Allocate(void** ptr, size_t size) override {
+    OF_CUDA_CHECK(cudaMallocFromPoolAsync(ptr, GetCudaAlignedSize(size), mem_pool_, stream_));
+  }
+  void Free(void* ptr) override { OF_CUDA_CHECK(cudaFreeAsync(ptr, stream_)); }
+
+ private:
+  cudaStream_t stream_{};
+  cudaMemPool_t mem_pool_{};
+};
+
 class DynamicAllocationEmbeddingState final : public EmbeddingState {
  public:
  OF_DISALLOW_COPY_AND_MOVE(DynamicAllocationEmbeddingState);
@@ -67,12 +84,10 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
     OF_CUDA_CHECK(cudaMemPoolDestroy(mem_pool_));
   }

-  void OnEmbeddingPrefetchStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    // do nothing
-  }
-
-  void OnEmbeddingPrefetchEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    // do nothing
+  std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
+      user_op::KernelComputeContext* ctx) override {
+    return std::make_unique<DynamicTmpBufferAllocator>(
+        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), mem_pool_);
   }

   void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
@@ -142,14 +157,6 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
     // do nothing
   }

-  void OnEmbeddingGradientShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    // do nothing
-  }
-
-  void OnEmbeddingGradientShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    // do nothing
-  }
-
   void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
     const user_op::Tensor* updated_unique_embeddings =
         ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0);
@@ -204,24 +211,6 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
     // do nothing
   }

-  void AllocPrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr,
-                              size_t size) override {
-    this->AllocTmpBuffer(ctx, ptr, size);
-  }
-
-  void FreePrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
-    this->FreeTmpBuffer(ctx, ptr);
-  }
-
-  void AllocTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr, size_t size) override {
-    OF_CUDA_CHECK(cudaMallocFromPoolAsync(ptr, size, mem_pool_,
-                                          ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-  }
-
-  void FreeTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
-    OF_CUDA_CHECK(cudaFreeAsync(ptr, ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-  }
-
   void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {
     std::unique_lock<std::mutex> lock(mutex_);
     int64_t index = iter % kRingBufferSize;
@@ -271,6 +260,31 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {

 #endif

+class StaticTmpBufferAllocator final : public TmpBufferAllocator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(StaticTmpBufferAllocator);
+  StaticTmpBufferAllocator(void* ptr, size_t size) : ptr_(ptr), offset_(0), size_(size) {}
+  ~StaticTmpBufferAllocator() override = default;
+
+  void Allocate(void** ptr, size_t size) override {
+    CHECK(ptr_ != nullptr);
+    CHECK_GE(offset_, 0);
+    size_t aligned_size = GetCudaAlignedSize(size);
+    CHECK_LE(offset_ + aligned_size, size_);
+    *ptr = reinterpret_cast<char*>(ptr_) + offset_;
+    offset_ += aligned_size;
+  }
+
+  void Free(void* ptr) override {
+    // do nothing
+  }
+
+ private:
+  void* ptr_;
+  int64_t offset_;
+  size_t size_;
+};
+
 class StaticAllocationEmbeddingState final : public EmbeddingState {
  public:
  OF_DISALLOW_COPY_AND_MOVE(StaticAllocationEmbeddingState);
@@ -282,40 +296,16 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
         embeding_update_unique_embeddings_(nullptr),
         embeding_update_updated_unique_embeddings_(nullptr),
         embedding_put_unique_embeddings_(nullptr),
-        tmp_buffer_ptr_(nullptr),
-        tmp_buffer_offset_(0),
-        tmp_buffer_size_(0),
-        prefetch_tmp_buffer_ptr_(nullptr),
-        prefetch_tmp_buffer_offset_(0),
-        prefetch_tmp_buffer_size_(0) {
+        embedding_fused_update_put_unique_embeddings_(nullptr) {
     id_statistics_vec_.resize(kRingBufferSize);
   }
   ~StaticAllocationEmbeddingState() override = default;

-  void InitTmpBufferPtr(user_op::KernelComputeContext* ctx) {
+  std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
+      user_op::KernelComputeContext* ctx) override {
     user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
-    tmp_buffer_ptr_ = tmp_buffer->mut_dptr();
-    tmp_buffer_offset_ = 0;
-    tmp_buffer_size_ = tmp_buffer->shape_view().elem_cnt();
-  }
-
-  void ResetTmpBufferPtr() {
-    tmp_buffer_ptr_ = nullptr;
-    tmp_buffer_offset_ = 0;
-    tmp_buffer_size_ = 0;
-  }
-
-  void OnEmbeddingPrefetchStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
-    prefetch_tmp_buffer_ptr_ = tmp_buffer->mut_dptr();
-    prefetch_tmp_buffer_offset_ = 0;
-    prefetch_tmp_buffer_size_ = tmp_buffer->shape_view().elem_cnt();
-  }
-
-  void OnEmbeddingPrefetchEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    prefetch_tmp_buffer_ptr_ = nullptr;
-    prefetch_tmp_buffer_offset_ = 0;
-    prefetch_tmp_buffer_size_ = 0;
+    return std::make_unique<StaticTmpBufferAllocator>(tmp_buffer->mut_dptr(),
+                                                      tmp_buffer->shape_view().elem_cnt());
   }

   void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
@@ -326,7 +316,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
       has_lookup_embeddings_ = true;
       lookup_embeddings_ = embeddings->mut_dptr();
     }
-    this->InitTmpBufferPtr(ctx);
   }

   void* LookupUniqueValues(int64_t iter) override { return lookup_unique_values_; }
@@ -340,14 +329,12 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
     lookup_unique_values_ = nullptr;
     lookup_embeddings_ = nullptr;
     has_lookup_embeddings_ = false;
-    this->ResetTmpBufferPtr();
   }

   void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
     const user_op::Tensor* cur_rank_embeddings =
         ctx->Tensor4ArgNameAndIndex("cur_rank_embeddings", 0);
     embedding_shuffle_cur_rank_embeddings_ = cur_rank_embeddings->dptr();
-    this->InitTmpBufferPtr(ctx);
   }

   const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override {
@@ -356,15 +343,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {

   void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
     embedding_shuffle_cur_rank_embeddings_ = nullptr;
-    this->ResetTmpBufferPtr();
   }

-  void OnEmbeddingGradientShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    this->InitTmpBufferPtr(ctx);
-  }
-
-  void OnEmbeddingGradientShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
-    this->ResetTmpBufferPtr();
-  }
-
   void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
@@ -414,31 +392,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
     embedding_fused_update_put_unique_embeddings_ = nullptr;
   }

-  void AllocPrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr,
-                              size_t size) override {
-    CHECK(prefetch_tmp_buffer_ptr_ != nullptr);
-    CHECK_GE(prefetch_tmp_buffer_offset_, 0);
-    CHECK_LE(prefetch_tmp_buffer_offset_ + size, prefetch_tmp_buffer_size_);
-    *ptr = reinterpret_cast<char*>(prefetch_tmp_buffer_ptr_) + prefetch_tmp_buffer_offset_;
-    prefetch_tmp_buffer_offset_ += size;
-  }
-
-  void FreePrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
-    // do nothing
-  }
-
-  void AllocTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr, size_t size) override {
-    CHECK(tmp_buffer_ptr_ != nullptr);
-    CHECK_GE(tmp_buffer_offset_, 0);
-    CHECK_LE(tmp_buffer_offset_ + size, tmp_buffer_size_);
-    *ptr = reinterpret_cast<char*>(tmp_buffer_ptr_) + tmp_buffer_offset_;
-    tmp_buffer_offset_ += size;
-  }
-
-  void FreeTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
-    // do nothing
-  }
-
   void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {
     std::unique_lock<std::mutex> lock(mutex_);
     int64_t index = iter % kRingBufferSize;
Expand Down Expand Up @@ -480,12 +433,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
const void* embedding_put_unique_embeddings_;
const void* embedding_fused_update_put_unique_embeddings_;
std::vector<IdStatistics> id_statistics_vec_;
void* tmp_buffer_ptr_;
int64_t tmp_buffer_offset_;
size_t tmp_buffer_size_;
void* prefetch_tmp_buffer_ptr_;
int64_t prefetch_tmp_buffer_offset_;
size_t prefetch_tmp_buffer_size_;
std::mutex mutex_;
};

Expand Down
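
The static path above replaces the old per-state bookkeeping (InitTmpBufferPtr/ResetTmpBufferPtr and the prefetch variants) with a short-lived bump allocator: each Allocate rounds the request up to CUDA alignment and advances an offset into the kernel's preallocated tmp_buffer tensor, and Free is a no-op because the storage lives as long as that tensor. A self-contained host-side sketch of the arithmetic (the 512-byte alignment mirrors what OneFlow's GetCudaAlignedSize is assumed to use; the buffer sizes are illustrative):

#include <cassert>
#include <cstddef>
#include <cstdio>

// Assumed alignment, mirroring OneFlow's GetCudaAlignedSize (kCudaAlignSize = 512).
constexpr size_t kAlign = 512;
size_t AlignedSize(size_t size) { return (size + kAlign - 1) / kAlign * kAlign; }

// Host-side model of StaticTmpBufferAllocator: bump-pointer suballocation
// out of one preallocated buffer; Free is a no-op.
class BumpAllocator {
 public:
  BumpAllocator(void* ptr, size_t size) : ptr_(ptr), offset_(0), size_(size) {}
  void Allocate(void** ptr, size_t size) {
    size_t aligned = AlignedSize(size);
    assert(offset_ + aligned <= size_);  // the inferred tmp size must cover all requests
    *ptr = static_cast<char*>(ptr_) + offset_;
    offset_ += aligned;
  }
  void Free(void*) {}  // storage is reclaimed with the tmp_buffer tensor

 private:
  void* ptr_;
  size_t offset_;
  size_t size_;
};

int main() {
  alignas(kAlign) static char buffer[4096];
  BumpAllocator alloc(buffer, sizeof(buffer));
  void* a = nullptr;
  void* b = nullptr;
  alloc.Allocate(&a, 100);   // rounded up to 512
  alloc.Allocate(&b, 1000);  // rounded up to 1024, placed at offset 512
  std::printf("offset of b: %td\n", static_cast<char*>(b) - buffer);  // prints 512
  return 0;
}

Because allocations only move the offset forward, a kernel must issue its requests in a fixed order that fits within the tmp_buffer size inferred for the op.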
24 changes: 11 additions & 13 deletions oneflow/core/embedding/embedding_manager.h
@@ -42,13 +42,22 @@ inline bool UseDynamicMemoryAllocation() {

 #ifdef WITH_CUDA

+class TmpBufferAllocator {
+ public:
+  TmpBufferAllocator() = default;
+  virtual ~TmpBufferAllocator() = default;
+
+  virtual void Allocate(void** ptr, size_t size) = 0;
+  virtual void Free(void* ptr) = 0;
+};
+
 class EmbeddingState {
  public:
   EmbeddingState() = default;
   virtual ~EmbeddingState() = default;

-  virtual void OnEmbeddingPrefetchStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
-  virtual void OnEmbeddingPrefetchEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
+  virtual std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
+      user_op::KernelComputeContext* ctx) = 0;

   virtual void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
   virtual void* LookupUniqueValues(int64_t iter) = 0;
@@ -59,10 +68,6 @@ class EmbeddingState {
   virtual const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) = 0;
   virtual void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;

-  virtual void OnEmbeddingGradientShuffleStart(user_op::KernelComputeContext* ctx,
-                                               int64_t iter) = 0;
-  virtual void OnEmbeddingGradientShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
-
   virtual void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
   virtual const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) = 0;
   virtual void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) = 0;
@@ -76,13 +81,6 @@ class EmbeddingState {
   virtual const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) = 0;
   virtual void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;

-  virtual void AllocPrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr,
-                                      size_t size) = 0;
-  virtual void FreePrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) = 0;
-
-  virtual void AllocTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr, size_t size) = 0;
-  virtual void FreeTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) = 0;
-
   virtual void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) = 0;
   virtual void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix,
                                     int64_t iter) = 0;
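
On the dynamic path, DynamicTmpBufferAllocator forwards Allocate/Free to CUDA 11.2+ stream-ordered allocation (cudaMallocFromPoolAsync/cudaFreeAsync), so temporaries come from a memory pool and are ordered on the kernel's stream instead of being carved out of a statically inferred tmp_buffer. A minimal standalone sketch of those APIs (it uses the device's default mempool for brevity, whereas this commit allocates from a pool owned by DynamicAllocationEmbeddingState):

#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_CHECK(expr)                                     \
  do {                                                       \
    cudaError_t err = (expr);                                \
    if (err != cudaSuccess) {                                \
      std::fprintf(stderr, "%s\n", cudaGetErrorString(err)); \
      return 1;                                              \
    }                                                        \
  } while (0)

int main() {
  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));

  // Grab a memory pool; the commit creates its own with cudaMemPoolCreate,
  // the device default pool keeps this example short.
  cudaMemPool_t pool;
  CUDA_CHECK(cudaDeviceGetDefaultMemPool(&pool, /*device=*/0));

  // Stream-ordered allocate/free: both calls are enqueued on `stream`,
  // so no device-wide synchronization is needed between iterations.
  void* workspace = nullptr;
  CUDA_CHECK(cudaMallocFromPoolAsync(&workspace, 1 << 20, pool, stream));
  // ... launch kernels that use `workspace` on `stream` ...
  CUDA_CHECK(cudaFreeAsync(workspace, stream));

  CUDA_CHECK(cudaStreamSynchronize(stream));
  CUDA_CHECK(cudaStreamDestroy(stream));
  return 0;
}

Because both the allocation and the free are enqueued on the stream, a buffer freed in one iteration can be reused by the next without a device synchronization, which is the point of routing kernels through NewTmpBufferAllocator on this path.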