diff --git a/README.md b/README.md
index d9cb1b68a..a00e0d631 100644
--- a/README.md
+++ b/README.md
@@ -213,6 +213,9 @@ In the experiments of decoding, we updated the following parameters:
 
 ### Changelog
 
+May 2023
+- Fix bugs of generation early stopping
+
 January 2023
 - Support GPT MoE
 - Support FP8 for Bert and GPT (**Experimental**)
diff --git a/src/fastertransformer/kernels/gpt_kernels.cu b/src/fastertransformer/kernels/gpt_kernels.cu
index 9402b57fa..7dc9af620 100644
--- a/src/fastertransformer/kernels/gpt_kernels.cu
+++ b/src/fastertransformer/kernels/gpt_kernels.cu
@@ -640,6 +640,7 @@ __global__ void generate_dups_indices(int* batch_to_compact,
                                       int*         compact_size,
                                       const int*   shared_contexts,
                                       const size_t batch_size,
+                                      const size_t beam_width,
                                       const size_t input_seq_len)
 {
     const int padded_batchsize = blockDim.x * ((batch_size + blockDim.x - 1) / blockDim.x);
@@ -649,20 +650,23 @@ __global__ void generate_dups_indices(int* batch_to_compact,
     __shared__ int scan_offset;
 
     int scan = 0;
-    for (int batch = threadIdx.x; batch < padded_batchsize; batch += blockDim.x) {
-        bool masked     = (batch >= batch_size);
-        bool first_iter = batch < blockDim.x;
+    for (int seq_idx = threadIdx.x; seq_idx < padded_batchsize; seq_idx += blockDim.x) {
+        bool masked     = (seq_idx >= batch_size);
+        bool first_iter = seq_idx < blockDim.x;
 
-        int is_first_occur = masked ? 0 : shared_contexts[batch] == batch;
+        int is_first_occur = masked ? 0 : shared_contexts[seq_idx] == seq_idx;
         BlockScan(temp_storage).ExclusiveSum(is_first_occur, scan);
 
         if (!masked && is_first_occur) {
             int compact_idx = scan + (first_iter ? 0 : scan_offset);
             // Context rep. writes initial index
-            batch_to_compact[batch] = compact_idx;
-            compact_to_batch[compact_idx] = batch;
+            batch_to_compact[seq_idx * beam_width] = compact_idx;
+            // input ids are tiled in context part
+            compact_to_batch[compact_idx] = seq_idx * beam_width;
         }
 
+        __syncthreads();
+
         if (threadIdx.x == blockDim.x - 1) {
             scan_offset = scan + is_first_occur + (first_iter ? 0 : scan_offset);
         }
@@ -671,8 +675,15 @@ __global__ void generate_dups_indices(int* batch_to_compact,
 
         if (!masked && !is_first_occur) {
             // Fill the rest of batch_to_compact based on what rep. wrote
-            const int src_idx = batch_to_compact[shared_contexts[batch]];
-            batch_to_compact[batch] = src_idx;
+            const int src_idx = batch_to_compact[shared_contexts[seq_idx] * beam_width];
+            batch_to_compact[seq_idx * beam_width] = src_idx;
+        }
+
+        if (!masked) {
+            // set same compact idx for beams
+            for (int beam_id = 1; beam_id < beam_width; ++beam_id) {
+                batch_to_compact[seq_idx * beam_width + beam_id] = batch_to_compact[seq_idx * beam_width];
+            }
         }
     }
 
@@ -696,14 +707,17 @@ void invokeFindContextDups(int* shared_contexts,
                            int*         compact_size,
                            const int*   input_ids,
                            const size_t batch_size,
+                           const size_t beam_width,
                            const size_t input_seq_len,
                            cudaStream_t stream)
 {
     dim3 block{512};
     dim3 grid{((int)batch_size + block.x - 1) / block.x};
+    // set shared_contexts[i] = i
     init_shared_contexts<<<grid, block, 0, stream>>>(shared_contexts, batch_size);
 
     grid = dim3{(unsigned int)(batch_size * (batch_size - 1)) / 2};
+    // set shared_contexts[i] = j, where j = min{k, such that input_ids[k] == input_ids[i]}
     if (input_seq_len <= 128) {
         block = 128;
         find_context_dups<128><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
@@ -713,8 +727,10 @@ void invokeFindContextDups(int* shared_contexts,
         find_context_dups<256><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
     }
 
+    // set batch_to_compact[i] = j, where j is the position of input_ids[i] in the compact batch
+    // set compact_to_batch[i] = j, where j is the index in the tiled batch of the representative context for compact index i
     generate_dups_indices<<<1, DUPS_INDICES_BLOCK_SIZE, 0, stream>>>(
-        batch_to_compact, compact_to_batch, compact_size, shared_contexts, batch_size, input_seq_len);
+        batch_to_compact, compact_to_batch, compact_size, shared_contexts, batch_size, beam_width, input_seq_len);
 }
 
 template<typename T>
diff --git a/src/fastertransformer/kernels/gpt_kernels.h b/src/fastertransformer/kernels/gpt_kernels.h
index 617f9bc05..d78224e0a 100644
--- a/src/fastertransformer/kernels/gpt_kernels.h
+++ b/src/fastertransformer/kernels/gpt_kernels.h
@@ -127,6 +127,7 @@ void invokeFindContextDups(int* shared_contexts,
                            int*         compact_size,
                            const int*   input_ids,
                            const size_t batch_size,
+                           const size_t beam_width,
                            const size_t input_seq_len,
                            cudaStream_t stream = 0);
 
diff --git a/src/fastertransformer/kernels/stop_criteria_kernels.cu b/src/fastertransformer/kernels/stop_criteria_kernels.cu
index 5d6611153..a8d4b98fa 100644
--- a/src/fastertransformer/kernels/stop_criteria_kernels.cu
+++ b/src/fastertransformer/kernels/stop_criteria_kernels.cu
@@ -150,7 +150,7 @@ void invokeLengthCriterion(bool* finished,
     length_criterion<<<grid, block, 0, stream>>>(
         finished, should_stop, h_pinned_finished_sum_, sequence_limit_length, batch_size, beam_width, step);
-    while (((volatile size_t*)h_pinned_finished_sum_)[0] == -1) {};
+    while (((volatile int*)h_pinned_finished_sum_)[0] == -1) {};
     sync_check_cuda_error();
 
     *should_stop = h_pinned_finished_sum_[0] == batch_size * beam_width;
diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
index ad9c3527b..93b80ae6e 100644
--- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
+++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
@@ -101,7 +101,11 @@ void ParallelGpt<T>::allocateBuffer(size_t batch_size,
                                     bool is_return_context_cum_log_probs)
 {
     FT_LOG_DEBUG(__PRETTY_FUNCTION__);
-    const size_t batchxbeam = batch_size * beam_width;
+    const size_t batchxbeam       = batch_size * beam_width;
+    const size_t local_batch_size = getLocalBatchSize(batch_size, 1, pipeline_para_.world_size_);
+    FT_CHECK(batch_size % local_batch_size == 0);
+    const size_t num_microbatches = batch_size / local_batch_size;
+
     const size_t self_cache_size =
         (num_layer_ / pipeline_para_.world_size_) * batchxbeam * memory_len * hidden_units_ / tensor_para_.world_size_;
 
@@ -111,8 +115,8 @@ void ParallelGpt<T>::allocateBuffer(size_t batch_size,
         padded_embedding_kernel_ptr_ = padded_embedding_kernel_;
     }
 
-    input_attention_mask_ = (T*)(allocator_->reMalloc(
-        input_attention_mask_, sizeof(T) * batchxbeam * max_input_len * max_input_len, false));
+    tiled_input_attention_mask_ = (T*)(allocator_->reMalloc(
+        tiled_input_attention_mask_, sizeof(T) * batchxbeam * max_input_len * max_input_len, false));
     decoder_input_buf_ = (T*)(allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units_, false));
     decoder_normed_input_buf_ =
         (T*)(allocator_->reMalloc(decoder_normed_input_buf_, sizeof(T) * batchxbeam * hidden_units_, false));
@@ -125,7 +129,6 @@ void ParallelGpt<T>::allocateBuffer(size_t batch_size,
         (float*)(allocator_->reMalloc(nccl_logits_buf_, sizeof(float) * batchxbeam * vocab_size_padded_, false));
     cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false));
     finished_buf_  = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false));
-    h_finished_buf_ = new bool[batchxbeam];
     sequence_lengths_ = (int*)(allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false));
 
     key_cache_ = (T*)(allocator_->reMalloc(key_cache_, sizeof(T) * self_cache_size * 2, true));
@@ -154,7 +157,8 @@ void ParallelGpt<T>::allocateBuffer(size_t batch_size,
     output_ids_buf_ = (int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * max_session_len, true));
     parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_session_len, true));
     seq_limit_len_  = (uint32_t*)(allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false));
-    masked_tokens_ = (bool*)(allocator_->reMalloc(masked_tokens_, sizeof(bool) * batchxbeam * memory_len, true));
+    tiled_masked_tokens_ =
+        (bool*)(allocator_->reMalloc(tiled_masked_tokens_, sizeof(bool) * batchxbeam * memory_len, true));
 
     context_decoder_input_buf_ = (T*)(allocator_->reMalloc(
         context_decoder_input_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false));
@@ -184,12 +188,13 @@ void ParallelGpt<T>::allocateBuffer(size_t batch_size,
         lp_logprob_buf_ = (float*)allocator_->reMalloc(lp_logprob_buf_, sizeof(float) * batchxbeam * max_input_len);
     }
     if (shared_contexts_ratio_ > 0.0f) {
-        shared_contexts_idx_  = (int*)allocator_->reMalloc(shared_contexts_idx_, 3 * batchxbeam * sizeof(int), false);
-        batch_to_compact_idx_ = shared_contexts_idx_ + batchxbeam;
-        compact_idx_          = shared_contexts_idx_ + 2 * batchxbeam;
+        shared_contexts_idx_  = (int*)allocator_->reMalloc(shared_contexts_idx_, batch_size * sizeof(int), false);
+        batch_to_compact_idx_ = (int*)allocator_->reMalloc(batch_to_compact_idx_, batchxbeam * sizeof(int), false);
+        compact_idx_          = (int*)allocator_->reMalloc(compact_idx_, batch_size * sizeof(int), false);
         compact_size_         = (int*)allocator_->reMalloc(compact_size_, sizeof(int), false);
     }
-    generation_should_stop_ = (bool*)allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true);
+    microbatch_should_stop_ =
+        (bool*)allocator_->reMalloc(microbatch_should_stop_, sizeof(bool) * num_microbatches, true, true);
     tiled_total_padding_count_ =
         (int*)allocator_->reMalloc(tiled_total_padding_count_, batchxbeam * sizeof(int), false);
 
@@ -205,7 +210,7 @@ void ParallelGpt<T>::freeBuffer()
         allocator_->free((void**)(&padded_embedding_kernel_));
     }
 
-    allocator_->free((void**)(&input_attention_mask_));
+    allocator_->free((void**)(&tiled_input_attention_mask_));
     allocator_->free((void**)(&decoder_input_buf_));
     allocator_->free((void**)(&decoder_output_buf_));
     allocator_->free((void**)(&normed_decoder_output_buf_));
@@ -213,7 +218,6 @@ void ParallelGpt<T>::freeBuffer()
     allocator_->free((void**)(&nccl_logits_buf_));
     allocator_->free((void**)(&cum_log_probs_));
     allocator_->free((void**)(&finished_buf_));
-    delete[] h_finished_buf_;
     allocator_->free((void**)(&sequence_lengths_));
 
     allocator_->free((void**)(&key_cache_));
@@ -230,7 +234,7 @@ void ParallelGpt<T>::freeBuffer()
     allocator_->free((void**)(&transposed_output_ids_buf_));
     allocator_->free((void**)(&output_ids_buf_));
     allocator_->free((void**)(&parent_ids_buf_));
-    allocator_->free((void**)(&masked_tokens_));
+    allocator_->free((void**)(&tiled_masked_tokens_));
 
     allocator_->free((void**)(&seq_limit_len_));
 
@@ -254,7 +258,7 @@ void ParallelGpt<T>::freeBuffer()
     allocator_->free((void**)(&lp_nccl_logits_buf_));
     allocator_->free((void**)(&lp_logprob_buf_));
 
-    allocator_->free((void**)(&generation_should_stop_), true);
+    allocator_->free((void**)(&microbatch_should_stop_), true);
 
     if (shared_contexts_ratio_ > 0.0f) {
         allocator_->free((void**)(&shared_contexts_idx_));
@@ -416,6 +420,8 @@ void ParallelGpt<T>::computeContextCumLogProbs(float* cum_l
     const size_t batchxbeam      = batch_size * beam_width;
     const size_t n_hidden_states = batchxbeam * max_input_length;
 
+    const cudaDataType_t cublas_type = getCudaDataType<T>();
+
     if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) {
         // normed decoder output [batch_size * beam_width, max_input_length, hidden_units_]
         invokeGeneralLayerNorm(lp_normed_decoder_output_buf_,
@@ -439,10 +445,10 @@ void ParallelGpt<T>::computeContextCumLogProbs(float* cum_l
                               hidden_units_,  // k
                               &alpha,
                               padded_embedding_kernel_ptr_,
-                              sizeof(T) == 2 ? CUDA_R_16F : CUDA_R_32F,
+                              cublas_type,
                               hidden_units_,  // k
                               lp_normed_decoder_output_buf_,
-                              sizeof(T) == 2 ? CUDA_R_16F : CUDA_R_32F,
+                              cublas_type,
                               hidden_units_,  // k
                               &beta,
                               lp_logits_buf_,
@@ -464,10 +470,10 @@ void ParallelGpt<T>::computeContextCumLogProbs(float* cum_l
                                   hidden_units_,  // k
                                   &alpha,
                                   padded_embedding_kernel_ptr_ + tensor_para_.rank_ * local_vocab_size * hidden_units_,
-                                  sizeof(T) == 2 ? CUDA_R_16F : CUDA_R_32F,
+                                  cublas_type,
                                   hidden_units_,  // k
                                   lp_normed_decoder_output_buf_,
-                                  sizeof(T) == 2 ? CUDA_R_16F : CUDA_R_32F,
+                                  cublas_type,
                                   hidden_units_,  // k
                                   &beta,
                                   lp_nccl_logits_buf_ + tensor_para_.rank_ * n_hidden_states * local_vocab_size,
@@ -809,8 +815,9 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
         num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, local_head_num_, memory_len, size_per_head_};
 
     {
-        PUSH_RANGE("dynamic decode setup");
         TensorMap input_map(*input_tensors);
+
+        PUSH_RANGE("dynamic decode setup");
         dynamic_decode_layer_->setup(batch_size, beam_width, &input_map);
         handleOptArg(&input_map, "start_id", start_ids_buf_, start_id_, batch_size);
         handleOptArg(&input_map, "end_id", end_ids_buf_, end_id_, batch_size);
@@ -858,7 +865,7 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
     PUSH_RANGE("initialize output and parent ids");
     cudaMemsetAsync(output_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_);
     cudaMemsetAsync(parent_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_);
-    cudaMemsetAsync(masked_tokens_, false, sizeof(bool) * batch_size * beam_width * memory_len, stream_);
+    cudaMemsetAsync(tiled_masked_tokens_, false, sizeof(bool) * batch_size * beam_width * memory_len, stream_);
     cudaMemsetAsync(tiled_total_padding_count_, 0, sizeof(int) * batch_size * beam_width, stream_);
     if (beam_width > 1) {
         cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * memory_len, stream_);
@@ -879,6 +886,25 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
     }
     POP_RANGE;
 
+    int  compact_size;
+    bool use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1);
+    PUSH_RANGE("find context dups");
+    if (use_shared_contexts) {
+        invokeFindContextDups(shared_contexts_idx_,
+                              batch_to_compact_idx_,
+                              compact_idx_,
+                              compact_size_,
+                              input_tensors->at("input_ids").getPtr(),
+                              batch_size,
+                              beam_width,
+                              max_input_length,
+                              stream_);
+        cudaD2Hcpy(&compact_size, compact_size_, 1);
+        use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size;
+        sync_check_cuda_error();
+    }
+    POP_RANGE;
+
     // NOTE: p/prompt-tuning process here (lookup prompt embedding tables by task name ids)
     // get p/prompt-tuning weight for each batch --> shape [batch, beam_width]
     // --> ptrs with shape [prompt_len, hidden_size]
@@ -1010,7 +1036,7 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
             POP_RANGE;
         }
         PUSH_RANGE("build decoder attention mask");
-        invokeBuildDecoderAttentionMask(input_attention_mask_,
+        invokeBuildDecoderAttentionMask(tiled_input_attention_mask_,
                                         tiled_input_lengths_buf_,
                                         nullptr,
                                         batch_size * beam_width,
@@ -1020,24 +1046,6 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
         sync_check_cuda_error();
         POP_RANGE;
 
-        int  compact_size;
-        bool use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1);
-        PUSH_RANGE("find context dups");
-        if (use_shared_contexts) {
-            invokeFindContextDups(shared_contexts_idx_,
-                                  batch_to_compact_idx_,
-                                  compact_idx_,
-                                  compact_size_,
-                                  tiled_input_ids_buf_,
-                                  batch_size * beam_width,
-                                  max_input_length,
-                                  stream_);
-            cudaD2Hcpy(&compact_size, compact_size_, 1);
-            use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size * beam_width;
-            sync_check_cuda_error();
-        }
-        POP_RANGE;
-
         TensorMap decoder_input_tensors(
             {{"decoder_input",
               Tensor(MEMORY_GPU,
@@ -1049,15 +1057,16 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
              Tensor(MEMORY_GPU,
                     data_type,
                     {batch_size * beam_width, 1, (size_t)max_input_length, (size_t)max_input_length},
-                    input_attention_mask_)},
+                    tiled_input_attention_mask_)},
             {"input_lengths",
              Tensor(MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, tiled_input_lengths_buf_)}});
 
         if (use_shared_contexts) {
             decoder_input_tensors.insert("compact_idx",
                                          Tensor(MEMORY_GPU, TYPE_INT32, {(size_t)compact_size}, compact_idx_));
-            decoder_input_tensors.insert("batch_to_compact_idx",
-                                         Tensor(MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, batch_to_compact_idx_));
+            decoder_input_tensors.insert(
+                "batch_to_compact_idx",
+                Tensor(MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, batch_to_compact_idx_));
         }
 
         if (gpt_variant_params_.use_attention_linear_bias) {
             decoder_input_tensors.insert("linear_bias_slopes",
@@ -1169,7 +1178,7 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
     }
 
     PUSH_RANGE("mask padding tokens");
-    invokeMaskPaddingTokens(masked_tokens_,
+    invokeMaskPaddingTokens(tiled_masked_tokens_,
                             input_tensors->at("input_lengths").getPtr(),
                             memory_len,
                             max_input_length,
@@ -1184,6 +1193,10 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
 
     const size_t local_batch_size = getLocalBatchSize(batch_size, 1, pipeline_para_.world_size_);
     FT_CHECK(batch_size % local_batch_size == 0);
+    const size_t iteration_num = batch_size / local_batch_size;
+    for (int microbatch = 0; microbatch < iteration_num; ++microbatch) {
+        microbatch_should_stop_[microbatch] = false;
+    }
 
     for (step_ = step_start; step_ < (int)gen_len; step_++) {
         // Loop body produces Nth token by embedding && encoding token (N-1)
@@ -1192,11 +1205,14 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
         const int src_indir_idx = (step_ - step_start) % 2;
         const int tgt_indir_idx = 1 - src_indir_idx;
 
-        const size_t iteration_num = batch_size / local_batch_size;
-        *generation_should_stop_ = !fill_caches_only;
+        bool generation_should_stop = !fill_caches_only;
 
         PUSH_RANGE(fmtstr("token_%d", step_ - step_start));
         for (uint ite = 0; ite < iteration_num; ++ite) {
+            // skip microbatches that already finished in previous steps
+            if (microbatch_should_stop_[ite]) {
+                continue;
+            }
             const int id_offset               = ite * local_batch_size * beam_width;
             const int hidden_units_offset     = id_offset * hidden_units_;
             const int vocab_size_units_offset = id_offset * vocab_size_padded_;
@@ -1214,10 +1230,9 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
                               pipeline_para_,
                               stream_);
 
-                // receive updated generation_should_stop_ from last rank
-                if (ite == 0) {
-                    ftNcclRecv(generation_should_stop_, 1, pipeline_para_.world_size_ - 1, pipeline_para_, stream_);
-                }
+                // receive updated microbatch_should_stop_ from last rank
+                ftNcclRecv(microbatch_should_stop_ + ite, 1, pipeline_para_.world_size_ - 1, pipeline_para_, stream_);
+                generation_should_stop &= microbatch_should_stop_[ite];
 
                 // receive updated cache_indirections from last rank
                 if (beam_width > 1) {
@@ -1241,10 +1256,10 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
                 // throw errors when detected
                 ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_);
                 sync_check_cuda_error();
-
-                if (ite == 0 && *generation_should_stop_) {
-                    break;
-                }
+            }
+            // skip the microbatch if the last rank reported that it already finished
+            if (microbatch_should_stop_[ite]) {
+                continue;
+            }
 
             if ((max_input_length <= 1) || (step_ > step_start) || continue_gen) {
@@ -1302,7 +1317,7 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
                          Tensor(MEMORY_GPU,
                                 TYPE_BOOL,
                                 {local_batch_size * beam_width, memory_len},
-                                masked_tokens_ + id_offset * memory_len)}});
+                                tiled_masked_tokens_ + id_offset * memory_len)}});
                 if (beam_width > 1) {
                     decoder_input_tensors.insert({"cache_indirection",
                                                   Tensor(MEMORY_GPU,
@@ -1403,7 +1418,7 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
                                        CUDA_R_32F,
                                        cublasGemmAlgo_t(-1));
                     POP_RANGE;
-                    PUSH_RANGE("logits all reduce sum");
+                    PUSH_RANGE("logits all gather");
                     ftNcclAllGather(nccl_logits_buf_ + vocab_size_units_offset,
                                     nccl_logits_buf_ + vocab_size_units_offset,
                                     local_batch_size * beam_width * local_vocab_size,
@@ -1484,9 +1499,14 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
 
                 PUSH_RANGE("result sampling and stop check");
                 dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
-                *generation_should_stop_ &= subbatch_should_stop;
+                generation_should_stop &= subbatch_should_stop;
+                microbatch_should_stop_[ite] = subbatch_should_stop;
                 POP_RANGE;
             }
+            else {
+                // other ranks cannot update the stop flag via DynamicDecode; fold in the value received from the last rank
+                generation_should_stop &= microbatch_should_stop_[ite];
+            }
 
             PUSH_RANGE("result communication");
             // send results to other rank
@@ -1504,10 +1524,8 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
                     ftNcclSend(
                         sequence_lengths_ + id_offset, local_batch_size * beam_width, i, pipeline_para_, stream_);
 
-                    // send updated generation_should_stop_
-                    if (ite == 0) {
-                        ftNcclSend(generation_should_stop_, 1, i, pipeline_para_, stream_);
-                    }
+                    // send updated microbatch_should_stop_
+                    ftNcclSend(microbatch_should_stop_ + ite, 1, i, pipeline_para_, stream_);
 
                     // send updated cache_indirections
                     if (beam_width > 1) {
@@ -1547,13 +1565,20 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
         if (step_ == initial_step + max_input_length) {
             /* We have just finished processing input: update the padding count:
              * total_padding_count += (max_input_length - input_lengths) */
+            PUSH_RANGE("Update padding count");
             invokeUpdatePaddingCount(tiled_total_padding_count_,
                                      input_tensors->at("input_lengths").getPtr(),
                                      max_input_length,
                                      batch_size,
                                      beam_width,
                                      stream_);
+            POP_RANGE;
         }
+
+        if (generation_should_stop) {
+            break;
+        }
+
         POP_RANGE;
     }
 
     PUSH_RANGE("communicate tensors");
@@ -1605,6 +1630,7 @@ void ParallelGpt<T>::setOutputTensors(std::unordered_map<std::string, Tensor>*
                                       const size_t max_context_len,
                                       const size_t max_input_without_prompt_length)
 {
+    PUSH_RANGE("Resolve output tensors");
     if (pipeline_para_.rank_ != pipeline_para_.world_size_ - 1) {
         return;
     }
@@ -1706,6 +1732,7 @@ void ParallelGpt<T>::setOutputTensors(std::unordered_map<std::string, Tensor>*
         cudaD2Dcpy(
             output_tensors->at("is_finished").getPtr(), finished_buf_, output_tensors->at("is_finished").size());
     }
+    POP_RANGE;
 }
 
 template<typename T>
diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h
index 39b6bab5e..ea24de2d3 100644
--- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h
+++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h
@@ -116,7 +116,7 @@ class ParallelGpt: public BaseLayer {
     T*       padded_embedding_kernel_;
     const T* padded_embedding_kernel_ptr_;
 
-    T* input_attention_mask_;
+    T* tiled_input_attention_mask_;
 
     T* decoder_input_buf_;
     T* decoder_normed_input_buf_ = nullptr;
@@ -126,10 +126,9 @@ class ParallelGpt: public BaseLayer {
     float*    nccl_logits_buf_;
     float*    cum_log_probs_;
     bool*     finished_buf_;
-    bool*     h_finished_buf_;
     int*      sequence_lengths_ = nullptr;
     uint32_t* seq_limit_len_    = nullptr;
-    bool*     generation_should_stop_ = nullptr;
+    bool*     microbatch_should_stop_ = nullptr;
 
     int* shared_contexts_idx_      = nullptr;
     T*   compact_decoder_features_ = nullptr;
@@ -154,7 +153,7 @@ class ParallelGpt: public BaseLayer {
     int*  transposed_output_ids_buf_;
     int*  output_ids_buf_;
     int*  parent_ids_buf_;
-    bool* masked_tokens_ = nullptr;
+    bool* tiled_masked_tokens_ = nullptr;
 
     T* context_decoder_input_buf_;
     T* context_decoder_normed_input_buf_;
diff --git a/src/fastertransformer/th_op/common/GptOps.cc b/src/fastertransformer/th_op/common/GptOps.cc
index ea3a86887..fbb018085 100644
--- a/src/fastertransformer/th_op/common/GptOps.cc
+++ b/src/fastertransformer/th_op/common/GptOps.cc
@@ -48,6 +48,7 @@ std::vector<Tensor> find_context_duplications(Tensor input_ids)
                            get_ptr(compact_size_tensor),
                            get_ptr(input_ids),
                            batch_size,
+                           1,
                            seq_len,
                            stream);
 
diff --git a/tests/unittests/test_gpt_kernels.cu b/tests/unittests/test_gpt_kernels.cu
index cef959078..c41308b8c 100644
--- a/tests/unittests/test_gpt_kernels.cu
+++ b/tests/unittests/test_gpt_kernels.cu
@@ -85,6 +85,7 @@ int test_find_context_dups()
                           d_compact_size,
                           d_input_ids,
                           batch_size,
+                          1,  // beam_width
                           vec_size);
 
     int compact_size;