
Commit 98422bc

fix(tgi): bogus max_new_tokens with static batching
dacorvo committed May 23, 2024
1 parent ac55c7e commit 98422bc
Showing 1 changed file with 4 additions and 2 deletions.
@@ -169,6 +169,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
         self.seed = request.parameters.seed
         # TODO: watermark
         self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
+        self._max_new_tokens = self._generation_config.max_new_tokens
         # TODO: stop_sequences, ignore_eos_token
 
     def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector):
@@ -197,8 +198,9 @@ def pause(self, reset_on_pause: bool):
         if reset_on_pause:
             # Drop the last token as it will be added back when resuming the slot
             self._generated_tokens -= 1
-            # Subtract the number of cached tokens from the maximum number of tokens
-            self._generation_config.max_new_tokens -= self._generated_tokens
+            # Since generated tokens are now part of the prefill, we need to reevaluate
+            # max_new_tokens for the next generation
+            self._generation_config.max_new_tokens = self._max_new_tokens - self._generated_tokens
         self._state = Slot.State.PAUSE
 
     def resume(self):
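
For context, a minimal sketch of why the slot must remember the max_new_tokens originally requested by the client. This is not the server code changed above: Slot below is a simplified stand-in with a hypothetical driver, used only to show that with static batching a slot can be paused several times, and repeatedly subtracting the running _generated_tokens total from a mutable budget over-decrements it, whereas recomputing the budget from the stored original value stays correct.

# Standalone illustration of the fix; Slot is a simplified stand-in,
# not the real implementation.
class Slot:
    def __init__(self, max_new_tokens: int):
        self._max_new_tokens = max_new_tokens   # original client request (the field added by this commit)
        self._budget = max_new_tokens           # stands in for _generation_config.max_new_tokens
        self._generated_tokens = 0

    def generate(self, n: int):
        self._generated_tokens += n

    def pause_fixed(self):
        # New behaviour: generated tokens become part of the next prefill,
        # so the remaining budget is recomputed from the original request.
        self._budget = self._max_new_tokens - self._generated_tokens

    def pause_buggy(self):
        # Old behaviour: subtract the running total from the current budget,
        # which double-counts tokens on every pause after the first.
        self._budget -= self._generated_tokens


slot = Slot(max_new_tokens=20)
slot.generate(5)
slot.pause_fixed()   # budget = 20 - 5 = 15
slot.generate(5)
slot.pause_fixed()   # budget = 20 - 10 = 10 (correct remaining tokens)

buggy = Slot(max_new_tokens=20)
buggy.generate(5)
buggy.pause_buggy()  # budget = 20 - 5 = 15
buggy.generate(5)
buggy.pause_buggy()  # budget = 15 - 10 = 5, the "bogus" value this commit fixes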
