From e0b124a77c847f431775b0fe68fab8dc207503cc Mon Sep 17 00:00:00 2001
From: Flex Wang
Date: Wed, 4 Oct 2023 21:30:07 -0700
Subject: [PATCH] Add ability to force BOS id for mbart (#22)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Merge with main (#1)

* Update beam_search_topk_kernels.cu: fix a bug in beam search

* fix: change int to int64_t in some kernels to prevent overflow

* fix: gpt tensor shapes inconsistency (#505)

  Signed-off-by: AkiyamaYummy <842720660@qq.com>

* Update gpt_guide.md (#529)

* fix: fix overflow in the gpt buffer and gpt gemm

* Update T5DecodingWeight.cc: fix a loading bug in t5

* [Enhancement] Add pytorch backend support for gptneox (#550)

  * add pytorch backend support for gptneox
  * fix invalid early stopping
  * 1) Remove some unused parameters and logic. 2) Revert revisions that
    would affect pipeline parallelism. 3) Make the code directly
    verifiable on TabbyML/NeoX-1.3B.
  * Change the names of classes, removing 'parallel' from their names
  * Format the code.
  * Only print results when rank is 0.
  * Add dist.init_process_group().
  * Update docs

  Signed-off-by: AkiyamaYummy <842720660@qq.com>

* Update cublasMMWrapper.cc: fix the CUBLAS_VERSION check in
  cublasMMWrapper

* Update cublasMMWrapper.cc

* fix overflow in softmax_kernel when processing long seqlen and big
  batch_size (#524)

* Update unfused_attention_kernels.cu: fix a bug in the softmax kernel

* [Enhancement] Create huggingface_gptneox_convert.py (#569)

  * create huggingface_gptneox_convert.py
  * adjust HF's multi bin files
  * update gptneox_guide.md

  Signed-off-by: AkiyamaYummy <842720660@qq.com>

* perf(bloom): improve performance of huggingface_bloom_convert.py,
  decreasing both time cost and memory usage (#568)

  Co-authored-by: r.yang

* Fix/gpt early stop (#584): fix a bug in early stopping of gpt

* [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (#672)

  FasterTransformer's 2-shot all-reduce is implemented as a
  reduce-scatter + all-gather. There was an indexing bug in the
  all-gather step: prior to this change, the 2-shot all-reduce produced
  correct results only on device 0. Now all devices have the correct
  results.
* fix: swap tensor bug (#683)

* Support size_per_head=112 (#660)

  * fix multi-gpu build
  * add support for size_per_head=112 for the gpt decoder

* remove mpi_cxx from multi-gpu build for now (#705)

---------

Signed-off-by: AkiyamaYummy <842720660@qq.com>
Co-authored-by: Asim Shankar
Co-authored-by: byshiue
Co-authored-by: _yummy_ <842720660@qq.com>
Co-authored-by: Ying Sheng
Co-authored-by: zhangxin81 <115389973+zhangxin81@users.noreply.github.com>
Co-authored-by: 杨睿 <595403043@qq.com>
Co-authored-by: r.yang
Co-authored-by: Rahul Kindi
Co-authored-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
Co-authored-by: Daya Khudia <37562707+dskhudia@users.noreply.github.com>
Co-authored-by: Dean Wyatte <2512762+dwyatte@users.noreply.github.com>
---
 .../kernels/decoding_kernels.cu  | 28 +++++++++++++++++
 .../kernels/decoding_kernels.h   |  7 +++++
 .../models/bart/BartDecoding.cc  | 30 +++++++++++++++++++
 .../models/bart/BartDecoding.h   |  5 ++--
 4 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/src/fastertransformer/kernels/decoding_kernels.cu b/src/fastertransformer/kernels/decoding_kernels.cu
index 89f0d5011..040c1bcff 100644
--- a/src/fastertransformer/kernels/decoding_kernels.cu
+++ b/src/fastertransformer/kernels/decoding_kernels.cu
@@ -64,6 +64,34 @@ void invokeDecodingInitialize(bool* finished,
         finished, sequence_length, word_ids, cum_log_probs, sentence_ids, batch_size, beam_width, max_input_length);
 }
 
+__global__ void forceId(int*       word_ids,
+                        const int* force_bos_ids,
+                        const int  batch_size,
+                        const int  beam_width,
+                        const int  step)
+{
+    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
+         index += blockDim.x * gridDim.x) {
+        if (word_ids != nullptr) {
+            word_ids[index + step * batch_size * beam_width] = force_bos_ids[index / beam_width];
+        }
+    }
+}
+
+void invokeForceId(int*         word_ids,
+                   const int*   force_bos_ids,
+                   const int    batch_size,
+                   const int    beam_width,
+                   const int    step,
+                   cudaStream_t stream)
+{
+    dim3 grid((int)ceil(batch_size * beam_width * 1.0 / 256));
+    dim3 block(256);
+
+    forceId<<<grid, block, 0, stream>>>(
+        word_ids, force_bos_ids, batch_size, beam_width, step);
+}
+
 template void invokeDecodingInitialize(bool* finished,
                                        int*  sequence_length,
                                        int*  word_ids,
diff --git a/src/fastertransformer/kernels/decoding_kernels.h b/src/fastertransformer/kernels/decoding_kernels.h
index 7527d8fc4..6307c5b93 100644
--- a/src/fastertransformer/kernels/decoding_kernels.h
+++ b/src/fastertransformer/kernels/decoding_kernels.h
@@ -33,6 +33,13 @@ void invokeDecodingInitialize(bool* finished,
                               const int    max_input_length,
                               cudaStream_t stream);
 
+void invokeForceId(int*         word_ids,
+                   const int*   force_bos_ids,
+                   const int    batch_size,
+                   const int    beam_width,
+                   const int    step,
+                   cudaStream_t stream);
+
 // get token from all_ids at step, then lookup from the embedding table
 // by the token
 template<typename T>
diff --git a/src/fastertransformer/models/bart/BartDecoding.cc b/src/fastertransformer/models/bart/BartDecoding.cc
index 8b91796f5..fef6508e9 100644
--- a/src/fastertransformer/models/bart/BartDecoding.cc
+++ b/src/fastertransformer/models/bart/BartDecoding.cc
@@ -116,6 +116,7 @@ void BartDecoding<T>::allocateBuffer(
     start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * batch_size, false));
     end_ids_buf_   = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false));
+    forced_bos_ids_buf_ = (int*)(allocator_->reMalloc(forced_bos_ids_buf_, sizeof(int) * batch_size, false));
 
     output_ids_buf_ =
         (int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * (max_seq_len + 1), false));
@@ -182,6 +183,7 @@ void BartDecoding<T>::freeBuffer()
     allocator_->free((void**)(&tiled_encoder_sequence_length_));
 
     allocator_->free((void**)(&start_ids_buf_));
+    allocator_->free((void**)(&forced_bos_ids_buf_));
     allocator_->free((void**)(&end_ids_buf_));
 
     allocator_->free((void**)(&output_ids_buf_));
@@ -343,6 +345,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
     //      stop_words_list [batch_size, 2, stop_words_length], optional
     //      bad_words_list [batch_size, 2, stop_words_length], optional
     //      start_id [batch_size] on cpu, optional
+    //      forced_bos_id [batch_size] on cpu, optional
     //      end_id [batch_size] on cpu, optional
     //      runtime_top_k [1] or [batch_size] on cpu, optional, uint.
     //      runtime_top_p [1] or [batch_size] on cpu, optional, float.
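The kernel's indexing assumes word_ids is stored step-major, i.e. as
[max_seq_len, batch_size, beam_width]: at a fixed step, all beam_width beams
of batch entry index / beam_width receive the same forced token. A host-side
sketch of the same write pattern, for orientation only (forceIdHost is an
illustrative name, not part of FasterTransformer):

    #include <vector>

    // Host-side reference for what forceId writes on the GPU, assuming the
    // step-major [max_seq_len, batch_size, beam_width] layout that the
    // kernel's indexing implies.
    void forceIdHost(std::vector<int>&       word_ids,       // flattened output ids
                     const std::vector<int>& force_bos_ids,  // one id per batch entry
                     int                     batch_size,
                     int                     beam_width,
                     int                     step)
    {
        for (int index = 0; index < batch_size * beam_width; index++) {
            // index / beam_width recovers the batch entry, so every beam of a
            // batch entry is forced to the same token at position `step`.
            word_ids[step * batch_size * beam_width + index] = force_bos_ids[index / beam_width];
        }
    }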
@@ -382,6 +385,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
         dynamic_decode_layer_->setup(batch_size, beam_width, &input_map);
         handleOptArg(&input_map, "start_id", start_ids_buf_, start_id_, batch_size);
         handleOptArg(&input_map, "end_id", end_ids_buf_, end_id_, batch_size);
+        handleOptArg(&input_map, "forced_bos_id", forced_bos_ids_buf_, -1, batch_size);
     }
 
     FT_CHECK_WITH_INFO(input_tensors->at("encoder_output").shape[2] == d_model_,
@@ -792,6 +796,32 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
             dynamic_decode_output_tensors.insert(*t);
         }
         dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
+        if (step == 1 && input_tensors->isExist("forced_bos_id")) {
+            invokeForceId(output_ids_buf_,
+                          forced_bos_ids_buf_,
+                          batch_size,
+                          beam_width,
+                          step,
+                          stream_);
+            sync_check_cuda_error();
+        }
+        // {
+        //     for (auto t = dynamic_decode_output_tensors.begin(); t != dynamic_decode_output_tensors.end(); ++t) {
+        //         printf("step: %d, t->first: %s\n", step, t->first.c_str());
+        //         // printf("%s\n", t->second.toString().c_str());
+        //         {
+        //             int* buf;
+        //             int  st = t->second.size();
+        //             buf = new int[st];
+        //             cudaMemcpy(buf, t->second.data, sizeof(int) * t->second.size(), cudaMemcpyDeviceToHost);
+        //             for (int i=0; i
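On the caller side, forced_bos_id is an optional CPU int tensor of shape
[batch_size], per the input-tensor comment above. A minimal sketch of
supplying it, assuming the usual FasterTransformer Tensor/TensorMap
conventions (bos_ids and target_lang_id are illustrative, not part of this
patch):

    // Illustrative only: one forced BOS id per batch entry, e.g. the
    // target-language token id when running mBART-style generation.
    std::vector<int> bos_ids(batch_size, target_lang_id);

    input_tensors->insert("forced_bos_id",
                          Tensor{MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, bos_ids.data()});

Because the guard fires at step == 1, position 0 keeps the start id written
during decoding initialization and the forced token lands immediately after
it, mirroring the forced_bos_token_id behavior of Hugging Face's mBART
generation.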