
outputs includes eos token #2538

Closed
wuchaooooo opened this issue Jan 22, 2024 · 1 comment

wuchaooooo commented Jan 22, 2024

code

gen_kwargs = {"top_p": top_p, "temperature": temperature, "max_tokens": max_length, "include_stop_str_in_output": False}
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), tokenizer.get_command("<|observation|>")]
sampling_params = SamplingParams(stop_token_ids=eos_token_id, **gen_kwargs)
outputs = self.model.generate(sampling_params=sampling_params, prompt_token_ids=inputs["input_ids"].tolist())

output string

I need to use the insauto_quote_tool tool to get the user's car insurance quote.<|assistant|> insauto_quote_tool
 ```python
tool_call(type='object', properties={'quote_biz_id': '20220906000831000002005700226000'})
```<|observation|>

question

Why is '<|observation|>' still at the end?

environment

Model: ChatGLM3-6B

simon-mo (Collaborator) commented

I think by default vLLM removes the special tokens as specified by the tokenizer.

def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
    """Decodes the new token for a sequence."""
    (new_tokens, new_output_text, prefix_offset,
     read_offset) = detokenize_incrementally(
         self.tokenizer,
         all_input_ids=seq.get_token_ids(),
         prev_tokens=seq.tokens,
         prefix_offset=seq.prefix_offset,
         read_offset=seq.read_offset,
         skip_special_tokens=prms.skip_special_tokens,
         spaces_between_special_tokens=prms.spaces_between_special_tokens,
     )
    if seq.tokens is None:
        seq.tokens = new_tokens
    else:
        seq.tokens.extend(new_tokens)
    seq.prefix_offset = prefix_offset
    seq.read_offset = read_offset
    seq.output_text += new_output_text

So while stop_token_ids is the stopping criterion, it is not the criterion for deciding which tokens get skipped during decoding.
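
For illustration, a minimal sketch of the two independent knobs (parameter names are from vLLM's SamplingParams; eos_token_id is the list built in the issue above):

```python
from vllm import SamplingParams

# Sketch only: stop_token_ids controls *when* generation stops, while
# skip_special_tokens controls which tokens the detokenizer drops from the
# output text. Whether "<|observation|>" is dropped therefore depends on the
# tokenizer treating it as a special token, not on it being a stop token.
sampling_params = SamplingParams(
    stop_token_ids=eos_token_id,  # stopping criterion
    skip_special_tokens=True,     # decode-time skipping (the default)
)
```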

There's another flag, include_stop_str_in_output, but currently it only applies to SamplingParams.stop:

def _check_stop(self, seq: Sequence,
                sampling_params: SamplingParams) -> None:
    """Stop the finished sequences."""
    for stop_str in sampling_params.stop:
        if seq.output_text.endswith(stop_str):
            if not sampling_params.include_stop_str_in_output:
                # Truncate the output text so that the stop string is
                # not included in the output.
                seq.output_text = seq.output_text[:-len(stop_str)]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return
    if seq.get_last_token_id() in sampling_params.stop_token_ids:
        seq.status = SequenceStatus.FINISHED_STOPPED
        return

I would recommend extending the include_stop_str_in_output flag to support stop_token_ids for your use case.
Contributions welcome!
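
A rough, untested sketch of what that could look like, replacing the final stop_token_ids branch of _check_stop (it assumes the engine's tokenizer is reachable as self.tokenizer, as in _decode_sequence above; this is not the actual implementation):

```python
    # ... inside _check_stop, replacing the final stop_token_ids branch ...
    if seq.get_last_token_id() in sampling_params.stop_token_ids:
        if not sampling_params.include_stop_str_in_output:
            # Hypothetical behavior: decode the stop token and truncate it
            # from the output text, mirroring what is done for stop strings.
            stop_str = self.tokenizer.decode(seq.get_last_token_id(),
                                             skip_special_tokens=False)
            if stop_str and seq.output_text.endswith(stop_str):
                seq.output_text = seq.output_text[:-len(stop_str)]
        seq.status = SequenceStatus.FINISHED_STOPPED
        return
```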
