
Commit

feat: TRTLLM API handle tokenizers without pad_id (e.g., tiktoken) (NVIDIA#399)

Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: NeMo-Aligner CI <nemo-aligner-ci@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: abukharin <abukharin@nvidia.com>
2 people authored and abukharin committed Nov 22, 2024
1 parent 96f5455 commit f5b93c9
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions nemo_aligner/utils/trt_llm.py
@@ -44,8 +44,9 @@ def append_and_repad_list(list_of_items, item_to_append, pad_id):
 
 
 class GPTGenerateTRTLLM:
-    # If a tokenizer does not have a pad_id, we use a large negative number and replace
-    # with self.eos_id after generation.
+    # Use a reserved negative number since there is variation between tokenizers if
+    # they (1) have a pad_id (2) don't have a pad_id or (3) have None as the pad_id.
+    # This pad_id is replaced with eos_id after generation.
     DEFAULT_PAD_ID = -42
 
     def __init__(
@@ -72,12 +73,6 @@ def __init__(
                 "You are trying to use NeMo-Aligner's TensorRT-LLM acceleration for LLM generation. Please build the dockerfile to enable this feature: https://github.com/NVIDIA/NeMo-Aligner/blob/main/Dockerfile"
             )
 
-        # If this assert turns out to be a blocker with some tokenizers, potential workarounds could be to:
-        # - add a config option to allow specifying which token we pass as `end_id` to TRT-LLM (should
-        #   be a token that the model is guaranteed to never generate)
-        assert (
-            tokenizer.pad_id != tokenizer.eos_id
-        ), f"We require tokenizers to have a different {tokenizer.pad_id=} than {tokenizer.eos_id=} when using TRT-LLM. This is to make sure all code goes into the same path and include the eos_id when the response lengths are computed"
         assert max_input_len > 0
         assert max_generation_length > 0
         assert (
@@ -104,7 +99,7 @@ def __init__(
         rng_generator.manual_seed(seed)
         self.rng_generator = rng_generator
 
-        self.pad_id = tokenizer.pad_id if tokenizer.pad_id is not None else GPTGenerateTRTLLM.DEFAULT_PAD_ID
+        self.pad_id = GPTGenerateTRTLLM.DEFAULT_PAD_ID
         self.eos_id = tokenizer.eos_id
         end_strings = list(end_strings)
 
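For context, here is a minimal sketch of the sentinel-pad approach this commit adopts: generation always pads with a reserved negative id, which is then swapped back to eos_id afterwards. This is an illustrative sketch only; the helper name replace_sentinel_pad and the example tensors are hypothetical and not part of the NeMo-Aligner API.

import torch

DEFAULT_PAD_ID = -42  # reserved value that no real tokenizer produces


def replace_sentinel_pad(output_ids: torch.Tensor, eos_id: int, pad_id: int = DEFAULT_PAD_ID) -> torch.Tensor:
    """Swap the reserved pad_id back to eos_id after generation.

    Padding with a reserved negative value keeps the padding unambiguous even for
    tokenizers (e.g. tiktoken) that have no pad_id, report pad_id as None, or share
    the same id for pad and eos.
    """
    return torch.where(output_ids == pad_id, torch.full_like(output_ids, eos_id), output_ids)


# Hypothetical usage: a batch of two generated sequences padded with the sentinel value.
generated = torch.tensor([[11, 12, 13, -42, -42], [21, 22, -42, -42, -42]])
print(replace_sentinel_pad(generated, eos_id=0))
# tensor([[11, 12, 13,  0,  0],
#         [21, 22,  0,  0,  0]])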
