Skip to content

Commit

Permalink
fix type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
bhavya01 committed Apr 24, 2024
1 parent 54d98e4 commit f669516
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jetstream/engine/token_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,12 @@ def encode(self, s: str, **kwargs):
true_length: Actual length of the non-padded sequence if padding is used.
"""
is_bos = kwargs.pop('is_bos', True)
prefill_length = kwargs.pop('prefill_length', None)
prefill_lengths = kwargs.pop('prefill_lengths', None)
max_prefill_length = kwargs.pop('max_prefill_length', None)

tokens, true_length = tokenize_and_pad(
s, self.vocab, is_bos=is_bos,
prefill_length=prefill_length,
prefill_lengths=prefill_lengths,
max_prefill_length=max_prefill_length
)
return tokens, true_length
Expand All @@ -234,7 +234,7 @@ def decode(
result_tokens: engine_api.ResultTokens,
complete: np.ndarray,
**kwargs,
) -> Tuple[List[str], np.ndarray]:
) -> Tuple[List[List[int]], np.ndarray]:
"""Processes a result tokens into a list of strings, handling multiple
samples.
Expand Down

0 comments on commit f669516

Please sign in to comment.