Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OneEmbedding add tmp_buffer allocator #8588

Merged
merged 19 commits into from
Jul 13, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 51 additions & 104 deletions oneflow/core/embedding/embedding_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ struct IdStatistics {

#if CUDA_VERSION >= 11020

class DynamicTmpBufferAllocator final : public TmpBufferAllocator {
public:
OF_DISALLOW_COPY_AND_MOVE(DynamicTmpBufferAllocator);
DynamicTmpBufferAllocator(cudaStream_t stream, cudaMemPool_t pool)
: stream_(stream), mem_pool_(pool) {}
~DynamicTmpBufferAllocator() override = default;

void Allocate(void** ptr, size_t size) override {
OF_CUDA_CHECK(cudaMallocFromPoolAsync(ptr, GetCudaAlignedSize(size), mem_pool_, stream_));
}
void Free(void* ptr) override { OF_CUDA_CHECK(cudaFreeAsync(ptr, stream_)); }

private:
cudaStream_t stream_{};
cudaMemPool_t mem_pool_{};
};

class DynamicAllocationEmbeddingState final : public EmbeddingState {
public:
OF_DISALLOW_COPY_AND_MOVE(DynamicAllocationEmbeddingState);
Expand Down Expand Up @@ -67,12 +84,10 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
OF_CUDA_CHECK(cudaMemPoolDestroy(mem_pool_));
}

void OnEmbeddingPrefetchStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}

void OnEmbeddingPrefetchEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
user_op::KernelComputeContext* ctx) override {
return std::make_unique<DynamicTmpBufferAllocator>(
ctx->stream()->As<ep::CudaStream>()->cuda_stream(), mem_pool_);
}

void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
Expand Down Expand Up @@ -142,14 +157,6 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
// do nothing
}

void OnEmbeddingGradientShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}

void OnEmbeddingGradientShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}

void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* updated_unique_embeddings =
ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0);
Expand Down Expand Up @@ -204,24 +211,6 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {
// do nothing
}

void AllocPrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr,
size_t size) override {
this->AllocTmpBuffer(ctx, ptr, size);
}

void FreePrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
this->FreeTmpBuffer(ctx, ptr);
}

void AllocTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr, size_t size) override {
OF_CUDA_CHECK(cudaMallocFromPoolAsync(ptr, size, mem_pool_,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
}

void FreeTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
OF_CUDA_CHECK(cudaFreeAsync(ptr, ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
}

void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
Expand Down Expand Up @@ -271,6 +260,31 @@ class DynamicAllocationEmbeddingState final : public EmbeddingState {

#endif

// Bump-pointer TmpBufferAllocator over a fixed, pre-reserved buffer (the
// kernel's "tmp_buffer" tensor). Allocate hands out successive CUDA-aligned
// slices; Free is a no-op because the whole buffer is reclaimed at once when
// the backing tensor goes away.
class StaticTmpBufferAllocator final : public TmpBufferAllocator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(StaticTmpBufferAllocator);
  // ptr: base address of the pre-allocated buffer; size: its capacity in bytes.
  StaticTmpBufferAllocator(void* ptr, size_t size) : ptr_(ptr), offset_(0), size_(size) {}
  ~StaticTmpBufferAllocator() override = default;

  // Returns the next `size` bytes (rounded up to CUDA alignment) of the
  // buffer. CHECK-fails if the request would overrun the remaining capacity.
  void Allocate(void** ptr, size_t size) override {
    CHECK(ptr_ != nullptr);
    size_t aligned_size = GetCudaAlignedSize(size);
    // Overflow-safe form of offset_ + aligned_size <= size_, and all-unsigned
    // so there is no signed/unsigned comparison.
    CHECK_LE(aligned_size, size_ - offset_);
    *ptr = reinterpret_cast<char*>(ptr_) + offset_;
    offset_ += aligned_size;
  }

  void Free(void* ptr) override {
    // do nothing: sub-allocations live as long as the backing buffer.
  }

 private:
  void* ptr_;
  size_t offset_;  // bytes handed out so far; size_t keeps arithmetic unsigned
  size_t size_;
};

class StaticAllocationEmbeddingState final : public EmbeddingState {
public:
OF_DISALLOW_COPY_AND_MOVE(StaticAllocationEmbeddingState);
Expand All @@ -282,40 +296,16 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
embeding_update_unique_embeddings_(nullptr),
embeding_update_updated_unique_embeddings_(nullptr),
embedding_put_unique_embeddings_(nullptr),
tmp_buffer_ptr_(nullptr),
tmp_buffer_offset_(0),
tmp_buffer_size_(0),
prefetch_tmp_buffer_ptr_(nullptr),
prefetch_tmp_buffer_offset_(0),
prefetch_tmp_buffer_size_(0) {
embedding_fused_update_put_unique_embeddings_(nullptr) {
id_statistics_vec_.resize(kRingBufferSize);
}
~StaticAllocationEmbeddingState() override = default;

void InitTmpBufferPtr(user_op::KernelComputeContext* ctx) {
std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
user_op::KernelComputeContext* ctx) override {
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
tmp_buffer_ptr_ = tmp_buffer->mut_dptr();
tmp_buffer_offset_ = 0;
tmp_buffer_size_ = tmp_buffer->shape_view().elem_cnt();
}

void ResetTmpBufferPtr() {
tmp_buffer_ptr_ = nullptr;
tmp_buffer_offset_ = 0;
tmp_buffer_size_ = 0;
}

void OnEmbeddingPrefetchStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
prefetch_tmp_buffer_ptr_ = tmp_buffer->mut_dptr();
prefetch_tmp_buffer_offset_ = 0;
prefetch_tmp_buffer_size_ = tmp_buffer->shape_view().elem_cnt();
}

void OnEmbeddingPrefetchEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
prefetch_tmp_buffer_ptr_ = nullptr;
prefetch_tmp_buffer_offset_ = 0;
prefetch_tmp_buffer_size_ = 0;
return std::make_unique<StaticTmpBufferAllocator>(tmp_buffer->mut_dptr(),
tmp_buffer->shape_view().elem_cnt());
}

void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
Expand All @@ -326,7 +316,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
has_lookup_embeddings_ = true;
lookup_embeddings_ = embeddings->mut_dptr();
}
this->InitTmpBufferPtr(ctx);
}

void* LookupUniqueValues(int64_t iter) override { return lookup_unique_values_; }
Expand All @@ -340,14 +329,12 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
lookup_unique_values_ = nullptr;
lookup_embeddings_ = nullptr;
has_lookup_embeddings_ = false;
this->ResetTmpBufferPtr();
}

void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* cur_rank_embeddings =
ctx->Tensor4ArgNameAndIndex("cur_rank_embeddings", 0);
embedding_shuffle_cur_rank_embeddings_ = cur_rank_embeddings->dptr();
this->InitTmpBufferPtr(ctx);
}

const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override {
Expand All @@ -356,15 +343,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {

void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_shuffle_cur_rank_embeddings_ = nullptr;
this->ResetTmpBufferPtr();
}

void OnEmbeddingGradientShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
this->InitTmpBufferPtr(ctx);
}

void OnEmbeddingGradientShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
this->ResetTmpBufferPtr();
}

void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
Expand Down Expand Up @@ -414,31 +392,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
embedding_fused_update_put_unique_embeddings_ = nullptr;
}

void AllocPrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr,
size_t size) override {
CHECK(prefetch_tmp_buffer_ptr_ != nullptr);
CHECK_GE(prefetch_tmp_buffer_offset_, 0);
CHECK_LE(prefetch_tmp_buffer_offset_ + size, prefetch_tmp_buffer_size_);
*ptr = reinterpret_cast<char*>(prefetch_tmp_buffer_ptr_) + prefetch_tmp_buffer_offset_;
prefetch_tmp_buffer_offset_ += size;
}

void FreePrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
// do nothing
}

void AllocTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr, size_t size) override {
CHECK(tmp_buffer_ptr_ != nullptr);
CHECK_GE(tmp_buffer_offset_, 0);
CHECK_LE(tmp_buffer_offset_ + size, tmp_buffer_size_);
*ptr = reinterpret_cast<char*>(tmp_buffer_ptr_) + tmp_buffer_offset_;
tmp_buffer_offset_ += size;
}

void FreeTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) override {
// do nothing
}

void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
Expand Down Expand Up @@ -480,12 +433,6 @@ class StaticAllocationEmbeddingState final : public EmbeddingState {
const void* embedding_put_unique_embeddings_;
const void* embedding_fused_update_put_unique_embeddings_;
std::vector<IdStatistics> id_statistics_vec_;
void* tmp_buffer_ptr_;
int64_t tmp_buffer_offset_;
size_t tmp_buffer_size_;
void* prefetch_tmp_buffer_ptr_;
int64_t prefetch_tmp_buffer_offset_;
size_t prefetch_tmp_buffer_size_;
std::mutex mutex_;
};

Expand Down
24 changes: 11 additions & 13 deletions oneflow/core/embedding/embedding_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,22 @@ inline bool UseDynamicMemoryAllocation() {

#ifdef WITH_CUDA

// Interface for allocating temporary (scratch) buffers used by the
// OneEmbedding kernels. Implementations may sub-allocate from a statically
// reserved tensor or draw from a dynamic memory pool; callers should pair
// each Allocate with a Free and not assume which strategy backs the pointer.
class TmpBufferAllocator {
public:
TmpBufferAllocator() = default;
virtual ~TmpBufferAllocator() = default;

// Sets *ptr to a buffer of at least `size` bytes.
virtual void Allocate(void** ptr, size_t size) = 0;
// Releases a pointer previously obtained from Allocate (may be a no-op,
// depending on the implementation).
virtual void Free(void* ptr) = 0;
};

class EmbeddingState {
public:
EmbeddingState() = default;
virtual ~EmbeddingState() = default;

virtual void OnEmbeddingPrefetchStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual void OnEmbeddingPrefetchEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
user_op::KernelComputeContext* ctx) = 0;

virtual void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual void* LookupUniqueValues(int64_t iter) = 0;
Expand All @@ -59,10 +68,6 @@ class EmbeddingState {
virtual const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) = 0;
virtual void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;

virtual void OnEmbeddingGradientShuffleStart(user_op::KernelComputeContext* ctx,
int64_t iter) = 0;
virtual void OnEmbeddingGradientShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;

virtual void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) = 0;
virtual void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) = 0;
Expand All @@ -76,13 +81,6 @@ class EmbeddingState {
virtual const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) = 0;
virtual void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;

virtual void AllocPrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr,
size_t size) = 0;
virtual void FreePrefetchTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) = 0;

virtual void AllocTmpBuffer(user_op::KernelComputeContext* ctx, void** ptr, size_t size) = 0;
virtual void FreeTmpBuffer(user_op::KernelComputeContext* ctx, void* ptr) = 0;

virtual void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) = 0;
virtual void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix,
int64_t iter) = 0;
Expand Down
Loading