fix(tgi): copy tokens in each slot
dacorvo committed Sep 12, 2023
1 parent 4e16ede commit 5b13e6a
Showing 1 changed file with 2 additions and 2 deletions.
@@ -157,7 +157,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
         self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
         # TODO: stop_sequences, ignore_eos_token
 
-    def reset(self, input_ids, selector):
+    def reset(self, input_ids: torch.LongTensor, selector: TokenSelector):
         """Reset the slot for the next generation.
 
         Args:
@@ -166,7 +166,7 @@ def reset(self, input_ids, selector):
             selector: (`optimum.neuron.generation.TokenSelector`):
                 An object implementing the updated token selection logic.
         """
-        self._tokens = input_ids
+        self._tokens = input_ids.clone()
         self._selector = selector
 
     def pause(self):
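
Context for the change (not part of the commit itself): the behavioral fix is the .clone() call. Before it, reset() stored a reference to the input_ids tensor it was given, so any later in-place modification of that buffer (for example, if the caller reuses it for another request) would also change the slot's cached tokens. Cloning gives each slot a private copy. The sketch below is a hypothetical, simplified illustration of that failure mode, not the actual Slot class from the TGI server; SlotSketch and copy_tokens are invented names.

# Minimal sketch (not the actual optimum-neuron Slot class) showing why the
# slot must clone the tensor it receives: keeping a bare reference lets later
# in-place edits of the caller's input_ids corrupt the slot's token history.
import torch


class SlotSketch:
    """Toy stand-in for a TGI generation slot that caches its token ids."""

    def __init__(self, copy_tokens: bool):
        self._copy_tokens = copy_tokens
        self._tokens = None

    def reset(self, input_ids: torch.LongTensor):
        # Before the fix: keep a reference; after the fix: keep a private copy.
        self._tokens = input_ids.clone() if self._copy_tokens else input_ids


input_ids = torch.tensor([101, 2023, 2003], dtype=torch.long)

shared = SlotSketch(copy_tokens=False)
copied = SlotSketch(copy_tokens=True)
shared.reset(input_ids)
copied.reset(input_ids)

# Simulate the caller overwriting the same buffer for a later request.
input_ids[0] = 0

print(shared._tokens)  # tensor([   0, 2023, 2003])  <- silently changed
print(copied._tokens)  # tensor([ 101, 2023, 2003])  <- unaffected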
