Skip to content

Commit

Permalink
[LLM INFER] Optimize by fusing some kernels in postprocess (#9201)
Browse files Browse the repository at this point in the history
* optimize fuse some kernels

* optimize fuse some kernels

* fix top_p reject

* fix

* ci

* fix review

* fix
  • Loading branch information
gzy19990617 authored Nov 6, 2024
1 parent 19a2e1f commit 0977858
Show file tree
Hide file tree
Showing 10 changed files with 837 additions and 58 deletions.
46 changes: 19 additions & 27 deletions csrc/gpu/get_padding_offset_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,14 @@
// limitations under the License.

#include "paddle/extension.h"
#include "helper.h"

// Packs the valid (non-padding) tokens of every sequence into a dense layout.
// Launch geometry: one block per batch row; the block's threads stride over
// that row's `seq_lens[row]` tokens. cum_offsets[row] is the total padding
// accumulated before row `row`, so subtracting it maps the padded index to
// the packed index.
__global__ void RemovePaddingV2(int64_t *output_data,
                                const int64_t *input_data,
                                const int *seq_lens,
                                const int *cum_offsets,
                                const int sequence_length) {
  const int row = blockIdx.x;
  const int padded_base = row * sequence_length;          // row start, padded layout
  const int packed_base = padded_base - cum_offsets[row]; // row start, packed layout
  const int valid_len = seq_lens[row];

  for (int pos = threadIdx.x; pos < valid_len; pos += blockDim.x) {
    output_data[packed_base + pos] = input_data[padded_base + pos];
  }
}

__global__ void GetPaddingOffsetKernelV2(int *padding_offset,
__global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
int *cum_offsets_out,
int *cu_seqlens_q,
int *cu_seqlens_k,
int64_t *output_data,
const int64_t *input_data,
const int *cum_offsets,
const int *seq_lens,
const int max_seq_len) {
Expand All @@ -42,8 +30,15 @@ __global__ void GetPaddingOffsetKernelV2(int *padding_offset,
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
const int tgt_seq_id = bi * max_seq_len - cum_offset + i;
const int src_seq_id = bi * max_seq_len + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
}
if (ti == 0) {
if (bi == 0) {
cu_seqlens_q[0] = 0;
cu_seqlens_k[0] = 0;
}
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
Expand All @@ -64,24 +59,21 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);

const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::full({token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::full({token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
GetPaddingOffsetKernelV2<<<bsz, 128, 0, cu_stream>>>(

auto x_remove_padding = GetEmptyTensor({token_num_data}, paddle::DataType::INT64, input_ids.place());
auto padding_offset = GetEmptyTensor({token_num_data}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());

GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
padding_offset.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length);
RemovePaddingV2<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
cum_offsets_out.data<int>(),
seq_length);
return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num};
}
Expand Down
258 changes: 258 additions & 0 deletions csrc/gpu/set_preids_token_penalty_multi_scores.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// Fused sampling post-process kernel. Launch geometry: one thread block per
// batch row (<<<bs, threads>>>); block `bi` owns logits row `bi`.
// For its row it:
//   1) appends the newest generated token to the pre_ids history,
//   2) suppresses EOS logits while the row is shorter than min_len,
//   3) applies repetition / frequency / presence penalties and temperature,
//   4) hard-masks every token in the bad-words list.
// Layouts established by the host wrapper: logits [bs, length],
// pre_ids [bs, length_id], input_ids [bs, length_input_ids];
// repeat_times is a zero-initialized [bs, length] scratch buffer.
template<typename T>
__global__ void set_preids_token_penalty_multi_scores_kernel(const bool *stop_flags,
                                int64_t *pre_ids,
                                const int64_t *input_ids,
                                const int *seq_lens_encoder,
                                const int *seq_lens_decoder,
                                const int64_t *step_idx,
                                const T *penalty_scores,
                                const T *frequency_score,
                                const T *presence_score,
                                const float *temperatures,
                                const int64_t *cur_len,
                                const int64_t *min_len,
                                const int64_t *eos_token_id,
                                const int64_t *bad_words_list,
                                int *repeat_times,
                                T *logits,
                                const int64_t bs,
                                const int64_t length,
                                const int64_t end_length,
                                const int64_t length_id,
                                const int64_t bad_words_length,
                                const int64_t length_input_ids) {
  int bi = blockIdx.x;
  T *logits_now = logits + bi * length;
  int tid = threadIdx.x;

  // Step 1: record the newest token of row `bi` in its pre_ids history.
  // BUG FIX (review): the original mixed `tid` and `bi` here (stop_flags[tid],
  // step_idx[bi], a `tid == 0` guard), so every block wrote row 0's history at
  // row bi's step index; it also used `length` (vocab size) as the pre_ids row
  // stride, inconsistent with the `length_id` stride used in Step 3. One block
  // serves one batch row, so row `bi` is now used throughout. Finally, the
  // original `return` for a stopped row skipped the __syncthreads() below on a
  // subset of threads (divergent barrier -> undefined behavior) and left the
  // strided loops in Steps 3-5 with unprocessed residues; a guard replaces it.
  if (tid == 0 && !stop_flags[bi]) {
    const int seq_len_dec = seq_lens_decoder[bi];
    const int seq_len_enc = seq_lens_encoder[bi];
    if (seq_len_dec != 0 || seq_len_enc != 0) {  // both zero => row stopped
      const int64_t step_idx_now = step_idx[bi];
      if (step_idx_now >= 0) {
        int64_t *pre_ids_now = pre_ids + bi * length_id;
        const int64_t *input_ids_now = input_ids + bi * length_input_ids;
        if (seq_len_enc > 0) {
          // Encoder (prefill) step: take the last prompt token.
          pre_ids_now[step_idx_now] = input_ids_now[seq_len_enc - 1];
        } else {
          // Decoder step: slot 0 holds the token generated last step.
          pre_ids_now[step_idx_now] = input_ids_now[0];
        }
      }
    }
  }
  __syncthreads();

  // Step 2: min-length — forbid EOS tokens until the row reaches min_len.
  // (`bi < bs` kept defensively in case of an oversized grid.)
  if (bi < bs && cur_len[bi] < min_len[bi]) {
    if (tid < end_length) {
      logits_now[eos_token_id[tid]] = -1e10;
    }
  }

  // Step 3: count how often each vocab id already occurs in the history.
  int *repeat_times_now = repeat_times + bi * length;
  const int64_t *pre_ids_now = pre_ids + bi * length_id;
  for (int i = tid; i < length_id; i += blockDim.x) {
    int64_t id = pre_ids_now[i];
    if (id < 0) break;  // history is negative-padded past the generated prefix
    atomicAdd(&repeat_times_now[id], 1);
  }
  __syncthreads();

  // Step 4: repetition penalty (multiplicative, sign-aware), frequency and
  // presence penalties (subtractive), then temperature scaling — in float for
  // precision regardless of T.
  float alpha = static_cast<float>(penalty_scores[bi]);
  float beta = static_cast<float>(frequency_score[bi]);
  float gamma = static_cast<float>(presence_score[bi]);
  for (int i = tid; i < length; i += blockDim.x) {
    int times = repeat_times_now[i];
    float logit_now = static_cast<float>(logits_now[i]);
    if (times != 0) {
      logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
      logit_now = logit_now - times * beta - gamma;
    }
    logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
  }
  __syncthreads();

  // Step 5: hard-mask bad words; ids outside [0, length) are ignored so a
  // padded/invalid bad-words list cannot write out of bounds.
  for (int i = tid; i < bad_words_length; i += blockDim.x) {
    const int64_t bad_words_token_id = bad_words_list[i];
    if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
    logits_now[bad_words_token_id] = -1e10;
  }
}

// Dtype-specialized host launcher: resolves the compile-time logits dtype via
// PDTraits, allocates the zero-filled repeat_times scratch buffer, and
// launches the fused kernel with one thread block (1024 threads) per batch
// row on the logits tensor's stream.
template <paddle::DataType D>
void set_preids_token_penalty_multi_scores(const paddle::Tensor& pre_ids,
                                           const paddle::Tensor& input_ids,
                                           const paddle::Tensor& seq_lens_encoder,
                                           const paddle::Tensor& seq_lens_decoder,
                                           const paddle::Tensor& step_idx,
                                           const paddle::Tensor& stop_flags,
                                           const paddle::Tensor& logits,
                                           const paddle::Tensor& penalty_scores,
                                           const paddle::Tensor& frequency_score,
                                           const paddle::Tensor& presence_score,
                                           const paddle::Tensor& temperatures,
                                           const paddle::Tensor& bad_tokens,
                                           const paddle::Tensor& cur_len,
                                           const paddle::Tensor& min_len,
                                           const paddle::Tensor& eos_token_id) {
  using traits_ = PDTraits<D>;
  using DataType_ = typename traits_::DataType;
  using data_t = typename traits_::data_t;

  auto cu_stream = logits.stream();

  // Geometry of every tensor the kernel touches.
  const std::vector<int64_t> logits_shape = logits.shape();
  const int64_t bs = logits_shape[0];
  const int64_t length = logits_shape[1];
  const int64_t length_id = pre_ids.shape()[1];
  const int64_t length_bad_words = bad_tokens.shape()[0];
  const int64_t length_input_ids = input_ids.shape()[1];
  const int64_t end_length = eos_token_id.shape()[0];

  // Per-call occurrence counters, one vocab-sized row per batch entry.
  auto repeat_times =
      paddle::full(logits_shape, 0, paddle::DataType::INT32, pre_ids.place());

  set_preids_token_penalty_multi_scores_kernel<DataType_>
      <<<bs, 1024, 0, cu_stream>>>(
          stop_flags.data<bool>(),
          const_cast<int64_t*>(pre_ids.data<int64_t>()),
          input_ids.data<int64_t>(),
          seq_lens_encoder.data<int>(),
          seq_lens_decoder.data<int>(),
          step_idx.data<int64_t>(),
          reinterpret_cast<DataType_*>(const_cast<data_t*>(penalty_scores.data<data_t>())),
          reinterpret_cast<DataType_*>(const_cast<data_t*>(frequency_score.data<data_t>())),
          reinterpret_cast<DataType_*>(const_cast<data_t*>(presence_score.data<data_t>())),
          temperatures.data<float>(),
          cur_len.data<int64_t>(),
          min_len.data<int64_t>(),
          eos_token_id.data<int64_t>(),
          bad_tokens.data<int64_t>(),
          repeat_times.data<int>(),
          reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
          bs,
          length,
          end_length,
          length_id,
          length_bad_words,
          length_input_ids);
}

// Runtime-dtype entry point registered with PD_BUILD_OP. Dispatches on the
// logits dtype to the matching compile-time specialization; any dtype other
// than bfloat16 / float16 / float32 is rejected with PD_THROW.
void SetPreidsTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
                                      const paddle::Tensor& input_ids,
                                      const paddle::Tensor& seq_lens_encoder,
                                      const paddle::Tensor& seq_lens_decoder,
                                      const paddle::Tensor& step_idx,
                                      const paddle::Tensor& stop_flags,
                                      const paddle::Tensor& logits,
                                      const paddle::Tensor& penalty_scores,
                                      const paddle::Tensor& frequency_scores,
                                      const paddle::Tensor& presence_scores,
                                      const paddle::Tensor& temperatures,
                                      const paddle::Tensor& bad_tokens,
                                      const paddle::Tensor& cur_len,
                                      const paddle::Tensor& min_len,
                                      const paddle::Tensor& eos_token_id) {
  switch (logits.type()) {
    case paddle::DataType::BFLOAT16:
      set_preids_token_penalty_multi_scores<paddle::DataType::BFLOAT16>(
          pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder, step_idx,
          stop_flags, logits, penalty_scores, frequency_scores,
          presence_scores, temperatures, bad_tokens, cur_len, min_len,
          eos_token_id);
      break;
    case paddle::DataType::FLOAT16:
      set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT16>(
          pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder, step_idx,
          stop_flags, logits, penalty_scores, frequency_scores,
          presence_scores, temperatures, bad_tokens, cur_len, min_len,
          eos_token_id);
      break;
    case paddle::DataType::FLOAT32:
      set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT32>(
          pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder, step_idx,
          stop_flags, logits, penalty_scores, frequency_scores,
          presence_scores, temperatures, bad_tokens, cur_len, min_len,
          eos_token_id);
      break;
    default:
      PD_THROW(
          "NOT supported data type. "
          "Only float16, bfloat16 and float32 are supported. ");
      break;
  }
}

// Registers the fused op with Paddle's custom-operator system.
// `logits` and `pre_ids` are updated in place (see SetInplaceMap), so the
// declared outputs alias those inputs rather than allocating new tensors.
PD_BUILD_OP(set_preids_token_penalty_multi_scores)
.Inputs({"pre_ids",
"input_ids",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"stop_flags",
"logits",
"penalty_scores",
"frequency_scores",
"presence_scores",
"temperatures",
"bad_tokens",
"cur_len",
"min_len",
"eos_token_id"})
.Outputs({"logits_out", "pre_ids_out"})
.SetInplaceMap({{"logits", "logits_out"}, {"pre_ids", "pre_ids_out"}})
.SetKernelFn(PD_KERNEL(SetPreidsTokenPenaltyMultiScores));
Loading

0 comments on commit 0977858

Please sign in to comment.