[OneEmbedding] optimize unsorted_segment_sum when col is odd (#8204)
* optimize unsorted_segment_sum when col is odd

* address review

Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
3 people authored May 15, 2022
1 parent 4b3633d commit 4787170
Showing 3 changed files with 207 additions and 59 deletions.
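In brief: when an fp16 embedding has an odd number of columns, this commit pads the gradient buffers by one column (see GetPaddedEmbeddingSize in the diff) so that unsorted_segment_sum can accumulate with vectorized half2 atomics instead of scalar half atomics; the padded column receives only zeros and is sliced away again after the sum. Below is a minimal standalone sketch of that trick (editor's illustration with invented names, not code from this commit; half2 atomicAdd requires sm_60 or newer):

#include <cuda_fp16.h>

// One thread accumulates one half2 (two adjacent columns) per input row.
// `col` is the real, odd width; `out` has col + 1 columns per segment and
// must be zero-initialized before launch.
__global__ void SegmentSumHalf2Sketch(const half* data, const int* segment_ids,
                                      int num_rows, int col, half2* out) {
  const int pairs = (col + 1) / 2;  // half2 slots per padded output row
  const int n = num_rows * pairs;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    const int row = i / pairs;
    const int pair = i - row * pairs;
    const int c0 = 2 * pair;
    const int c1 = c0 + 1;
    half2 val;
    val.x = data[row * col + c0];
    // The padded lane gets a zero, so it never perturbs the real columns.
    val.y = (c1 < col) ? data[row * col + c1] : __float2half(0.f);
    // A single vectorized atomic replaces two scalar half atomics.
    atomicAdd(out + segment_ids[row] * pairs + pair, val);
  }
}

The commit's UnsortedSegmentHalfGpu below uses the same indexing, templated on the id and index types and driven by a grid-stride loop.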
220 changes: 168 additions & 52 deletions oneflow/user/kernels/data_shuffle_kernel.cu
@@ -23,6 +23,8 @@ limitations under the License.
#include "oneflow/core/cuda/atomic.cuh"
#include "oneflow/core/embedding/hash_functions.cuh"
#include "oneflow/core/cuda/elementwise.cuh"
#include "oneflow/core/ep/include/primitive/copy_nd.h"
#include "oneflow/core/cuda/atomic.cuh"

namespace oneflow {

@@ -1099,6 +1101,131 @@ void ShuffleEmbeddingsGrad(cudaStream_t cuda_stream, ncclComm_t comm, int64_t pa
recv_offsets, recv_elem_cnt, received_cur_rank_quantize_factor);
}

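// UnsortedSegmentHalfGpu: one thread per output half2, i.e. per pair of
// adjacent columns. The input row is read as scalar halves, so an odd
// inner_dim_size is legal: the missing last lane is zero-filled and the pair
// is accumulated into the padded output with a single half2 atomic add.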
template<typename K, typename IDX>
__global__ void UnsortedSegmentHalfGpu(const IDX in_h2_elem_cnt, const IDX h2_inner_dim_size,
const IDX inner_dim_size, const half* data,
const K* segment_ids, const IDX num_segments,
half2* out_h2) {
CUDA_1D_KERNEL_LOOP_T(IDX, i, in_h2_elem_cnt) {
const IDX segment_id_idx = i / h2_inner_dim_size;
const IDX h2_inner_idx = i - segment_id_idx * h2_inner_dim_size;
const IDX inner_idx_0 = 2 * h2_inner_idx;
const IDX inner_idx_1 = inner_idx_0 + 1;
const half* data_row = data + segment_id_idx * inner_dim_size;
half2 val;
val.x = data_row[inner_idx_0];
val.y = (inner_idx_1 >= inner_dim_size) ? static_cast<half>(0) : data_row[inner_idx_1];
const IDX idx = segment_ids[segment_id_idx];
const IDX out_h2_offset = idx * h2_inner_dim_size + h2_inner_idx;
cuda::atomic::Add(out_h2 + out_h2_offset, val);
}
}

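// The primary template is a deliberate placeholder: padding is only ever
// requested for half (see GetPaddedEmbeddingSize below), so only the half
// specialization of UnsortedSegmentSumPad is reachable.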
template<typename T, typename K>
struct UnsortedSegmentSumPad {
void operator()(ep::Stream* stream, const K* segment_ids, const T* data, int64_t num_segment_ids,
int64_t num_segments, int64_t inner_dim_size, int64_t padded_inner_dim_size,
T* out) const {
UNIMPLEMENTED();
}
};

template<typename K>
struct UnsortedSegmentSumPad<half, K> {
void operator()(ep::Stream* stream, const K* segment_ids, const half* data,
int64_t num_segment_ids, int64_t num_segments, int64_t inner_dim_size,
int64_t padded_inner_dim_size, half* out) const {
const int64_t data_elem_cnt = num_segment_ids * inner_dim_size;
const int64_t out_elem_cnt = num_segments * padded_inner_dim_size;
CHECK_EQ(padded_inner_dim_size % 2, 0);
CHECK_EQ(inner_dim_size + 1, padded_inner_dim_size);
const int64_t h2_inner_dim_size = padded_inner_dim_size / 2;
const int64_t in_h2_elem_cnt = num_segment_ids * h2_inner_dim_size;
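// Editor's note: 32-bit indexing is preferred when every element count fits
// with headroom (the division by 2); 64-bit index arithmetic is costlier on
// the GPU and is used only as a fallback.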
if (std::max(data_elem_cnt, out_elem_cnt) < GetMaxVal<int32_t>() / 2) {
UnsortedSegmentHalfGpu<K, int32_t>
<<<BlocksNum4ThreadsNum(in_h2_elem_cnt), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(
in_h2_elem_cnt, h2_inner_dim_size, inner_dim_size, data, segment_ids, num_segments,
reinterpret_cast<half2*>(out));
} else {
UnsortedSegmentHalfGpu<K, int64_t>
<<<BlocksNum4ThreadsNum(in_h2_elem_cnt), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(
in_h2_elem_cnt, h2_inner_dim_size, inner_dim_size, data, segment_ids, num_segments,
reinterpret_cast<half2*>(out));
}
}
};
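
// Dispatcher: when no padding is required, fall through to the generic
// UnsortedSegmentSum primitive; otherwise take the padded half2 path above.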
template<typename T, typename K>
void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const T* data,
int64_t num_segment_ids, int64_t num_segments, int64_t inner_dim_size,
int64_t padded_inner_dim_size, T* out) {
if (inner_dim_size == padded_inner_dim_size) {
UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, K, T>::UnsortedSegmentSum(
stream, segment_ids, data, num_segment_ids, num_segments, 1, inner_dim_size, 0, out);
} else {
CHECK_EQ(inner_dim_size + 1, padded_inner_dim_size);
UnsortedSegmentSumPad<T, K>()(stream, segment_ids, data, num_segment_ids, num_segments,
inner_dim_size, padded_inner_dim_size, out);
}
}
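
// Zeroes only the valid rows of each rank's partition, then segment-sums the
// incoming gradients directly into the padded row layout.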
template<typename T, typename IDX>
void UniquePartitionEmbeddingGrad(ep::Stream* stream, int64_t parallel_id, int64_t parallel_num,
int64_t num_ids, int64_t embedding_size,
int64_t padded_embedding_size, const IDX* host_num_unique_matrix,
const T* embedding_grad,
const IDX* inverse_unique_partition_indices,
T* unique_partition_embedding_grad) {
for (int64_t i = 0; i < parallel_num; ++i) {
const int64_t offset = i * num_ids * padded_embedding_size;
const int64_t valid_value_size =
host_num_unique_matrix[parallel_id * parallel_num + i] * padded_embedding_size * sizeof(T);
OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad + offset, 0, valid_value_size,
stream->As<ep::CudaStream>()->cuda_stream()));
}
UnsortedSegmentSum<T, IDX>(stream, inverse_unique_partition_indices, embedding_grad, num_ids,
parallel_num * num_ids, embedding_size, padded_embedding_size,
unique_partition_embedding_grad);
}
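
// Segment-sums into a buffer at the padded width, then, when padding was
// applied, slices the real embedding_size columns back out with a 2-D CopyNd.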
template<typename T, typename IDX>
void UniqueCurRankEmbeddingGrad(ep::Stream* stream, DataType data_type, int64_t cur_rank_num_ids,
int64_t embedding_size, int64_t padded_embedding_size,
const T* cur_rank_embedding_grad,
const IDX* cur_rank_inverse_indices,
T* cur_rank_unique_embedding_grad, T* tmp_buffer) {
T* unsorted_segment_sum_out =
(embedding_size == padded_embedding_size) ? cur_rank_unique_embedding_grad : tmp_buffer;
OF_CUDA_CHECK(cudaMemsetAsync(unsorted_segment_sum_out, 0,
cur_rank_num_ids * padded_embedding_size * sizeof(T),
stream->As<ep::CudaStream>()->cuda_stream()));
UnsortedSegmentSum<T, IDX>(stream, cur_rank_inverse_indices, cur_rank_embedding_grad,
cur_rank_num_ids, cur_rank_num_ids, padded_embedding_size,
padded_embedding_size, unsorted_segment_sum_out);
if (embedding_size != padded_embedding_size) {
std::unique_ptr<ep::primitive::CopyNd> primitive =
ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(DeviceType::kCUDA, 2);
DimVector dst_shape = {cur_rank_num_ids, embedding_size};
DimVector dst_pos_vec = {0, 0};
DimVector src_shape = {cur_rank_num_ids, padded_embedding_size};
DimVector src_pos_vec = {0, 0};
DimVector extent_vec = {cur_rank_num_ids, embedding_size};
primitive->Launch(stream, data_type, 2, cur_rank_unique_embedding_grad, dst_shape.data(),
dst_pos_vec.data(), unsorted_segment_sum_out, src_shape.data(),
src_pos_vec.data(), extent_vec.data());
}
}
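
// Odd fp16 widths are rounded up to even (e.g. 127 -> 128) so gradients can
// be accumulated as half2; every other case keeps its original width.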
int64_t GetPaddedEmbeddingSize(DataType data_type, int64_t embedding_size) {
if (data_type == DataType::kFloat16 && embedding_size % 2 != 0) {
return embedding_size + 1;
} else {
return embedding_size;
}
}
template<typename T, typename IDX>
class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
public:
@@ -1131,12 +1258,14 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
const int64_t num_ids = inverse_unique_partition_indices->shape().elem_cnt();
const int64_t parallel_num = ctx->parallel_ctx().parallel_num();
const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
+const int64_t padded_embedding_size = GetPaddedEmbeddingSize(data_type, embedding_size);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
ncclComm_t comm = kernel_state->comm();
using ComputeType = typename DefaultComputeType<T>::type;
bool enable_quantized_comm_env_var =
ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false);
-bool enable_quantized_comm = enable_quantized_comm_env_var && (embedding_size < kMaxColSize);
+bool enable_quantized_comm =
+    enable_quantized_comm_env_var && (padded_embedding_size < kMaxColSize);
if (enable_quantized_comm_env_var && !enable_quantized_comm) {
LOG(WARNING) << "Only envrionment variable ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 and "
"embedding_size less equal than 1024 can use quantized communication. ";
@@ -1152,7 +1281,7 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id];
}
size_t full_num_ids = parallel_num * num_ids;
-size_t full_elem_cnt = full_num_ids * embedding_size;
+size_t full_elem_cnt = full_num_ids * padded_embedding_size;
size_t unique_partition_embedding_grad_size = GetCudaAlignedSize(full_elem_cnt * sizeof(T));
if (!enable_quantized_comm) {
@@ -1163,30 +1292,22 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
CHECK_GE(tmp_buffer->shape().elem_cnt(),
unique_partition_embedding_grad_size + received_embedding_grad_size);
// unique and partition embedding grad
-for (int64_t i = 0; i < parallel_num; ++i) {
-const int64_t offset = i * num_ids * embedding_size;
-const int64_t valid_value_size =
-host_num_unique_matrix[parallel_id * parallel_num + i] * embedding_size * sizeof(T);
-OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad + offset, 0, valid_value_size,
-cuda_stream));
-}
-UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, IDX, T>::UnsortedSegmentSum(
-ctx->stream(), reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),
-embedding_grad->dptr<T>(), num_ids, parallel_num * num_ids, 1, embedding_size, 0,
+UniquePartitionEmbeddingGrad(
+ctx->stream(), parallel_id, parallel_num, num_ids, embedding_size, padded_embedding_size,
+host_num_unique_matrix, embedding_grad->dptr<T>(),
+reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),
unique_partition_embedding_grad);
-ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size,
-data_type, host_num_unique_matrix, unique_partition_embedding_grad,
-received_embedding_grad);
+ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids,
+padded_embedding_size, data_type, host_num_unique_matrix,
+unique_partition_embedding_grad, received_embedding_grad);
// unique cur_rank embedding grad
-OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad->mut_dptr(), 0,
-cur_rank_num_ids * embedding_size * sizeof(T), cuda_stream));
-UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, IDX, T>::UnsortedSegmentSum(
-ctx->stream(), reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),
-received_embedding_grad, cur_rank_num_ids, cur_rank_num_ids, 1, embedding_size, 0,
-cur_rank_unique_embedding_grad->mut_dptr<T>());
+// use unique_partition_embedding_grad as UniqueCurRankEmbeddingGrad buffer.
+T* buffer_ptr = unique_partition_embedding_grad;
+UniqueCurRankEmbeddingGrad(ctx->stream(), data_type, cur_rank_num_ids, embedding_size,
+padded_embedding_size, received_embedding_grad,
+reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),
+cur_rank_unique_embedding_grad->mut_dptr<T>(), buffer_ptr);
} else {
size_t received_embedding_grad_size = GetCudaAlignedSize(full_elem_cnt * sizeof(int8_t));
size_t quantize_cur_rank_embedding_grad_size = received_embedding_grad_size;
@@ -1218,35 +1339,28 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
+ received_embedding_grad_size + quantize_cur_rank_embedding_grad_size
+ cur_rank_quantize_factor_size + received_cur_rank_quantize_factor_size);
// unique and partition embedding grad
-for (int64_t i = 0; i < parallel_num; ++i) {
-const int64_t offset = i * num_ids * embedding_size;
-const int64_t valid_value_size =
-host_num_unique_matrix[parallel_id * parallel_num + i] * embedding_size * sizeof(T);
-OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad + offset, 0, valid_value_size,
-cuda_stream));
-}
-
-UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, IDX, T>::UnsortedSegmentSum(
-ctx->stream(), reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),
-embedding_grad->dptr<T>(), num_ids, parallel_num * num_ids, 1, embedding_size, 0,
+UniquePartitionEmbeddingGrad(
+ctx->stream(), parallel_id, parallel_num, num_ids, embedding_size, padded_embedding_size,
+host_num_unique_matrix, embedding_grad->dptr<T>(),
+reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),
unique_partition_embedding_grad);
// Quantize.
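// Editor's note: one quantize scale factor per row (quantize_factor_offset
// steps by num_ids), with rows quantized at the padded width so the shuffled
// payload matches the padded layout.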
for (int64_t i = 0; i < parallel_num; ++i) {
-const int64_t embedding_grad_offset = i * num_ids * embedding_size;
+const int64_t embedding_grad_offset = i * num_ids * padded_embedding_size;
const int64_t quantize_factor_offset = i * num_ids;
const int64_t valid_row_size = host_num_unique_matrix[parallel_id * parallel_num + i];
DispatchQuantizeWarpImplPackSize<T, ComputeType>()(
cuda_stream, unique_partition_embedding_grad + embedding_grad_offset,
quantize_cur_rank_embedding_grad + embedding_grad_offset,
-cur_rank_quantize_factor + quantize_factor_offset, valid_row_size, embedding_size);
+cur_rank_quantize_factor + quantize_factor_offset, valid_row_size,
+padded_embedding_size);
}
-ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size,
-data_type, host_num_unique_matrix, quantize_cur_rank_embedding_grad,
-received_embedding_grad, cur_rank_quantize_factor,
-received_cur_rank_quantize_factor);
+ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids,
+padded_embedding_size, data_type, host_num_unique_matrix,
+quantize_cur_rank_embedding_grad, received_embedding_grad,
+cur_rank_quantize_factor, received_cur_rank_quantize_factor);
int64_t dequantize_cur_rank_num = 0;
for (int64_t i = 0; i < parallel_num; ++i) {
@@ -1262,17 +1376,16 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
*/
dequantize_cur_rank_num += host_num_unique_matrix[i * parallel_num + parallel_id];
}
-IDX dequantize_elem_cnt = dequantize_cur_rank_num * embedding_size;
+IDX dequantize_elem_cnt = dequantize_cur_rank_num * padded_embedding_size;
OF_CUDA_CHECK((LaunchDequantizeKernel<T, ComputeType, IDX>(
cuda_stream, received_embedding_grad, received_cur_rank_quantize_factor,
-dequantize_cur_rank_embedding_grad, embedding_size, dequantize_elem_cnt)));
-// unique cur_rank embedding grad
-OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad->mut_dptr(), 0,
-cur_rank_num_ids * embedding_size * sizeof(T), cuda_stream));
-UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, IDX, T>::UnsortedSegmentSum(
-ctx->stream(), reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),
-dequantize_cur_rank_embedding_grad, cur_rank_num_ids, cur_rank_num_ids, 1, embedding_size,
-0, cur_rank_unique_embedding_grad->mut_dptr<T>());
+dequantize_cur_rank_embedding_grad, padded_embedding_size, dequantize_elem_cnt)));
+// use unique_partition_embedding_grad as UniqueCurRankEmbeddingGrad buffer.
+T* buffer_ptr = unique_partition_embedding_grad;
+UniqueCurRankEmbeddingGrad(ctx->stream(), data_type, cur_rank_num_ids, embedding_size,
+padded_embedding_size, dequantize_cur_rank_embedding_grad,
+reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),
+cur_rank_unique_embedding_grad->mut_dptr<T>(), buffer_ptr);
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
@@ -1291,10 +1404,13 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
ctx->InputTensorDesc("cur_rank_unique_embedding_grad", 0); \
size_t cur_rank_embedding_grad_num = cur_rank_unique_embedding_grad.shape().At(0); \
size_t embedding_size = cur_rank_unique_embedding_grad.shape().At(1); \
-size_t cur_rank_embedding_grad_elem_cnt = cur_rank_embedding_grad_num * embedding_size; \
+size_t padded_embedding_size = \
+GetPaddedEmbeddingSize(cur_rank_unique_embedding_grad.data_type(), embedding_size); \
+size_t cur_rank_embedding_grad_elem_cnt = \
+cur_rank_embedding_grad_num * padded_embedding_size; \
bool enable_quantized_comm = \
ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false) \
-&& (embedding_size < kMaxColSize); \
+&& (padded_embedding_size < kMaxColSize); \
size_t tmp_size = 0; \
if (!enable_quantized_comm) { \
size_t cur_rank_embedding_grad_size = GetCudaAlignedSize( \
