Fix/gpt early stop #584

Merged
merged 3 commits on May 1, 2023
3 changes: 3 additions & 0 deletions README.md
@@ -213,6 +213,9 @@ In the experiments of decoding, we updated the following parameters:

### Changelog

+May 2023
+- Fix bugs of generation early stopping
+
January 2023
- Support GPT MoE
- Support FP8 for Bert and GPT (**Experimental**)
34 changes: 25 additions & 9 deletions src/fastertransformer/kernels/gpt_kernels.cu
@@ -640,6 +640,7 @@ __global__ void generate_dups_indices(int* batch_to_compact,
int* compact_size,
const int* shared_contexts,
const size_t batch_size,
+const size_t beam_width,
const size_t input_seq_len)
{
const int padded_batchsize = blockDim.x * ((batch_size + blockDim.x - 1) / blockDim.x);
@@ -649,20 +650,23 @@
__shared__ int scan_offset;

int scan = 0;
-for (int batch = threadIdx.x; batch < padded_batchsize; batch += blockDim.x) {
-bool masked = (batch >= batch_size);
-bool first_iter = batch < blockDim.x;
+for (int seq_idx = threadIdx.x; seq_idx < padded_batchsize; seq_idx += blockDim.x) {
+bool masked = (seq_idx >= batch_size);
+bool first_iter = seq_idx < blockDim.x;

-int is_first_occur = masked ? 0 : shared_contexts[batch] == batch;
+int is_first_occur = masked ? 0 : shared_contexts[seq_idx] == seq_idx;
BlockScan(temp_storage).ExclusiveSum(is_first_occur, scan);

if (!masked && is_first_occur) {
int compact_idx = scan + (first_iter ? 0 : scan_offset);
// Context rep. writes initial index
-batch_to_compact[batch] = compact_idx;
-compact_to_batch[compact_idx] = batch;
+batch_to_compact[seq_idx * beam_width] = compact_idx;
+// input ids are tiled in context part
+compact_to_batch[compact_idx] = seq_idx * beam_width;
}

+__syncthreads();
+
if (threadIdx.x == blockDim.x - 1) {
scan_offset = scan + is_first_occur + (first_iter ? 0 : scan_offset);
}
@@ -671,8 +675,15 @@

if (!masked && !is_first_occur) {
// Fill the rest of batch_to_compact based on what rep. wrote
-const int src_idx = batch_to_compact[shared_contexts[batch]];
-batch_to_compact[batch] = src_idx;
+const int src_idx = batch_to_compact[shared_contexts[seq_idx] * beam_width];
+batch_to_compact[seq_idx * beam_width] = src_idx;
}

+if (!masked) {
+// set same compact idx for beams
+for (int beam_id = 1; beam_id < beam_width; ++beam_id) {
+batch_to_compact[seq_idx * beam_width + beam_id] = batch_to_compact[seq_idx * beam_width];
+}
+}
}
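For reference, here is a minimal host-side C++ sketch of the mapping the updated kernel produces, with a scalar loop in place of the CUB block scan. The helper name `build_compact_maps` and the toy inputs are invented for illustration and are not part of FasterTransformer; the array names mirror the kernel arguments. Each prompt row `seq_idx` either claims a new compact slot (when it is the first occurrence of its context) or reuses its representative's slot; all `beam_width` beams of a prompt point at the same slot, and `compact_to_batch` records the tiled row `seq_idx * beam_width` that holds the shared input ids.

```cpp
#include <cstdio>
#include <vector>

// Scalar re-implementation of the mapping built by generate_dups_indices.
// shared_contexts[i] = index of the first prompt whose context equals prompt i's.
void build_compact_maps(const std::vector<int>& shared_contexts,
                        int beam_width,
                        std::vector<int>& batch_to_compact,   // batch_size * beam_width entries
                        std::vector<int>& compact_to_batch)   // one entry per unique context
{
    const int batch_size = static_cast<int>(shared_contexts.size());
    batch_to_compact.assign(batch_size * beam_width, 0);
    compact_to_batch.clear();

    for (int seq_idx = 0; seq_idx < batch_size; ++seq_idx) {
        int compact_idx;
        if (shared_contexts[seq_idx] == seq_idx) {
            // First occurrence: claim the next compact slot and remember which
            // tiled batch row (input ids are tiled per beam) holds its context.
            compact_idx = static_cast<int>(compact_to_batch.size());
            compact_to_batch.push_back(seq_idx * beam_width);
        }
        else {
            // Duplicate: reuse the slot already assigned to its representative.
            compact_idx = batch_to_compact[shared_contexts[seq_idx] * beam_width];
        }
        // Every beam of this prompt shares the same compact slot.
        for (int beam_id = 0; beam_id < beam_width; ++beam_id) {
            batch_to_compact[seq_idx * beam_width + beam_id] = compact_idx;
        }
    }
}

int main()
{
    // Prompts 0 and 2 share a context, prompt 1 is unique, beam_width = 2.
    const std::vector<int> shared_contexts = {0, 1, 0};
    std::vector<int> batch_to_compact, compact_to_batch;
    build_compact_maps(shared_contexts, 2, batch_to_compact, compact_to_batch);

    for (int v : batch_to_compact) std::printf("%d ", v);   // 0 0 1 1 0 0
    std::printf("| ");
    for (int v : compact_to_batch) std::printf("%d ", v);   // 0 2
    std::printf("\n");
    return 0;
}
```

The per-beam indexing of `batch_to_compact` and the `seq_idx * beam_width` rows in `compact_to_batch` are the substance of the kernel change above.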

@@ -696,14 +707,17 @@ void invokeFindContextDups(int* shared_contexts,
int* compact_size,
const int* input_ids,
const size_t batch_size,
+const size_t beam_width,
const size_t input_seq_len,
cudaStream_t stream)
{
dim3 block{512};
dim3 grid{((int)batch_size + block.x - 1) / block.x};
+// set shared_contexts[i] = i
init_shared_contexts<<<grid, block, 0, stream>>>(shared_contexts, batch_size);

grid = dim3{(unsigned int)(batch_size * (batch_size - 1)) / 2};
+// set shared_contexts[i] = j, where j = min{k such that input_ids[k] == input_ids[i]}
if (input_seq_len <= 128) {
block = 128;
find_context_dups<128><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
@@ -713,8 +727,10 @@
find_context_dups<256><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
}

+// set batch_to_compact[i] = j, where j is the position of input_ids[i] in the compact batch
+// set compact_to_batch[i] = j, where j is the original batch index whose input ids fill compact slot i
generate_dups_indices<<<1, DUPS_INDICES_BLOCK_SIZE, 0, stream>>>(
-batch_to_compact, compact_to_batch, compact_size, shared_contexts, batch_size, input_seq_len);
+batch_to_compact, compact_to_batch, compact_size, shared_contexts, batch_size, beam_width, input_seq_len);
}

template<typename T>
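To make the comments in `invokeFindContextDups` concrete, here is a host-only C++ sketch of the duplicate-detection step that `init_shared_contexts` and `find_context_dups` perform on the GPU. It assumes one row of padded input ids per prompt and ignores beam tiling; the plain O(batch² · seq_len) loops stand in for the pairwise CUDA blocks, and the helper name `find_shared_contexts` is invented.

```cpp
#include <vector>

// shared_contexts[i] becomes the smallest j whose input-id row equals row i
// (j == i when prompt i is the first occurrence of its context).
std::vector<int> find_shared_contexts(const std::vector<int>& input_ids,  // batch_size * input_seq_len
                                      int batch_size,
                                      int input_seq_len)
{
    std::vector<int> shared_contexts(batch_size);
    for (int i = 0; i < batch_size; ++i) {
        shared_contexts[i] = i;  // init_shared_contexts: each row starts as its own representative
        for (int j = 0; j < i; ++j) {
            bool same = true;
            for (int t = 0; t < input_seq_len && same; ++t) {
                same = (input_ids[i * input_seq_len + t] == input_ids[j * input_seq_len + t]);
            }
            if (same) {
                shared_contexts[i] = j;  // earliest matching row wins
                break;
            }
        }
    }
    return shared_contexts;
}
```

The result is exactly the `shared_contexts` array consumed by `generate_dups_indices` (see the `build_compact_maps` sketch earlier).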
1 change: 1 addition & 0 deletions src/fastertransformer/kernels/gpt_kernels.h
@@ -127,6 +127,7 @@ void invokeFindContextDups(int* shared_contexts,
int* compact_size,
const int* input_ids,
const size_t batch_size,
+const size_t beam_width,
const size_t input_seq_len,
cudaStream_t stream = 0);

2 changes: 1 addition & 1 deletion src/fastertransformer/kernels/stop_criteria_kernels.cu
@@ -150,7 +150,7 @@ void invokeLengthCriterion(bool* finished,

length_criterion<<<grid, block, 0, stream>>>(
finished, should_stop, h_pinned_finished_sum_, sequence_limit_length, batch_size, beam_width, step);
-while (((volatile size_t*)h_pinned_finished_sum_)[0] == -1) {};
+while (((volatile int*)h_pinned_finished_sum_)[0] == -1) {};
sync_check_cuda_error();

*should_stop = h_pinned_finished_sum_[0] == batch_size * beam_width;
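The one-line change above is the actual early-stopping fix. `h_pinned_finished_sum_` appears to be a pinned host `int` that is set to -1 before `length_criterion` is launched and later overwritten by the kernel with the finished count; reading it through a `size_t*` compares eight bytes (the four-byte sentinel plus whatever follows it) against `SIZE_MAX`, so the spin-wait can fall through before the kernel has written anything. A host-only C++ sketch of the mismatch, assuming a 4-byte `int`, an 8-byte `size_t`, and little-endian layout:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
{
    // Emulate the pinned buffer: a 4-byte int sentinel followed by unrelated bytes.
    alignas(8) unsigned char buffer[8] = {};
    const int sentinel = -1;
    std::memcpy(buffer, &sentinel, sizeof(int));  // host writes -1 before the kernel launch

    // Old check: reinterpret the 4-byte sentinel as an 8-byte size_t.
    size_t as_size_t = 0;
    std::memcpy(&as_size_t, buffer, sizeof(size_t));
    const bool old_check_waits = (as_size_t == (size_t)-1);  // false: the high 4 bytes are zero

    // Fixed check: read the value with the type it was written as.
    int as_int = 0;
    std::memcpy(&as_int, buffer, sizeof(int));
    const bool new_check_waits = (as_int == -1);              // true: keep spinning until the kernel writes

    std::printf("size_t check waits: %d, int check waits: %d\n",
                old_check_waits, new_check_waits);
    return 0;
}
```

With the `int` read, the loop keeps spinning until the kernel overwrites the sentinel, so the `*should_stop` comparison above sees the finished count the kernel wrote rather than a stale sentinel.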