Commit: fix review
gzy19990617 committed Oct 27, 2024
1 parent 6b3f627 commit 55aacac

Showing 6 changed files with 13 additions and 10 deletions.
csrc/gpu/get_padding_offset_v2.cu (11 changes: 7 additions & 4 deletions)
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/extension.h"
+#include "helper.h"
 
 __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
                                          int *cum_offsets_out,
@@ -54,10 +55,12 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
   auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
 
   const int token_num_data = cpu_token_num.data<int64_t>()[0];
-  auto x_remove_padding = paddle::full({token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
-  auto padding_offset = paddle::full({token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
-  auto cu_seqlens_q = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-  auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
+
+  auto x_remove_padding = GetEmptyTensor({token_num_data}, paddle::DataType::INT64, input_ids.place());
+  auto padding_offset = GetEmptyTensor({token_num_data}, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_q = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_k = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
+
   GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
       padding_offset.data<int>(),
       cum_offsets_out.data<int>(),
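The substantive change in this file: the four output buffers were previously zero-filled with paddle::full, but the kernels that follow are expected to overwrite them in full, so the commit switches to uninitialized allocations via GetEmptyTensor, which is why helper.h is now included. Below is a minimal sketch of what such a helper might look like, assuming it simply wraps the paddle::empty allocation API; the actual definition in helper.h may allocate differently.

// Hypothetical sketch of a GetEmptyTensor helper; the real helper.h
// implementation may differ. Unlike paddle::full, paddle::empty skips the
// device-side zero-fill, so the buffer holds garbage until the consuming
// kernel writes every element it later reads.
#include <vector>
#include "paddle/extension.h"

inline paddle::Tensor GetEmptyTensor(const std::vector<int64_t>& shape,
                                     const paddle::DataType& dtype,
                                     const paddle::Place& place) {
  return paddle::empty(shape, dtype, place);
}

Dropping the fill saves a memset-like kernel launch per tensor on this hot path; the trade-off is that any element a downstream kernel fails to write is now undefined rather than zero.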
csrc/gpu/set_preids_token_penalty_multi_scores.cu (12 changes: 6 additions & 6 deletions)
@@ -15,7 +15,7 @@
 #include "helper.h"
 
 template<typename T>
-__global__ void update_value_all(const bool *stop_flags,
+__global__ void set_preids_token_penalty_multi_scores_kernel(const bool *stop_flags,
                                  int64_t *pre_ids,
                                  const int64_t *input_ids,
                                  const int *seq_lens_encoder,
@@ -98,7 +98,7 @@ __global__ void update_value_all(const bool *stop_flags,
 }
 
 template <paddle::DataType D>
-void set_preids_token_penalty_multi_scores_kernel(const paddle::Tensor& pre_ids,
+void set_preids_token_penalty_multi_scores(const paddle::Tensor& pre_ids,
                                            const paddle::Tensor& input_ids,
                                            const paddle::Tensor& seq_lens_encoder,
                                            const paddle::Tensor& seq_lens_decoder,
@@ -128,7 +128,7 @@ void set_preids_token_penalty_multi_scores_kernel(const paddle::Tensor& pre_ids,
 
   int64_t end_length = eos_token_id.shape()[0];
 
-  update_value_all<DataType_><<<bs, 1024, 0, cu_stream>>>(
+  set_preids_token_penalty_multi_scores_kernel<DataType_><<<bs, 1024, 0, cu_stream>>>(
       stop_flags.data<bool>(),
       const_cast<int64_t*>(pre_ids.data<int64_t>()),
       input_ids.data<int64_t>(),
@@ -172,7 +172,7 @@ void SetPreidsTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
 
   switch (logits.type()) {
     case paddle::DataType::BFLOAT16: {
-      return set_preids_token_penalty_multi_scores_kernel<paddle::DataType::BFLOAT16>(
+      return set_preids_token_penalty_multi_scores<paddle::DataType::BFLOAT16>(
          pre_ids,
          input_ids,
          seq_lens_encoder,
@@ -191,7 +191,7 @@
       );
     }
     case paddle::DataType::FLOAT16: {
-      return set_preids_token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
+      return set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT16>(
          pre_ids,
          input_ids,
          seq_lens_encoder,
@@ -210,7 +210,7 @@
       );
     }
     case paddle::DataType::FLOAT32: {
-      return set_preids_token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
+      return set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT32>(
          pre_ids,
          input_ids,
          seq_lens_encoder,
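The edits in this file are renames that make the naming convention consistent: the __global__ kernel now carries the _kernel suffix, and the templated host-side launcher, which previously also ended in _kernel, drops it. The runtime dtype switch in SetPreidsTokenPenaltyMultiScores therefore now dispatches to set_preids_token_penalty_multi_scores. For reference, here is a condensed, self-contained sketch of that switch-on-dtype dispatch pattern; the names MyOp, launch_my_op, and my_op_kernel are illustrative, and PDTraits is assumed to be the dtype-to-C++-type trait provided by helper.h.

// Condensed sketch of the dtype-dispatch pattern used above; all names are
// illustrative, and PDTraits is assumed to come from helper.h.
#include "helper.h"

template <typename T>
__global__ void my_op_kernel(T *data, const int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = data[i];  // placeholder body
}

// Host-side launcher, templated on the compile-time dtype tag.
template <paddle::DataType D>
void launch_my_op(const paddle::Tensor &t) {
  typedef typename PDTraits<D>::data_t data_t;  // paddle-side scalar type
  const int n = static_cast<int>(t.numel());
  my_op_kernel<data_t><<<(n + 255) / 256, 256, 0, t.stream()>>>(
      const_cast<data_t *>(t.data<data_t>()), n);
}

// Public entry point: one runtime switch maps the tensor's dtype onto the
// matching template instantiation.
void MyOp(const paddle::Tensor &t) {
  switch (t.type()) {
    case paddle::DataType::BFLOAT16:
      return launch_my_op<paddle::DataType::BFLOAT16>(t);
    case paddle::DataType::FLOAT16:
      return launch_my_op<paddle::DataType::FLOAT16>(t);
    case paddle::DataType::FLOAT32:
      return launch_my_op<paddle::DataType::FLOAT32>(t);
    default:
      PD_THROW("Unsupported data type for MyOp.");
  }
}

Keeping the switch in a thin non-template entry point means only the launcher is instantiated per dtype, and the kernel rename makes each level of this stack (entry point, launcher, kernel) identifiable from its name alone.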
