Skip to content

Commit

Permalink
Support specifying a prefix of assistant response (#2172)
Browse files Browse the repository at this point in the history
* Support prefix of assistant response

* fix UT
  • Loading branch information
AllentDan authored Jul 31, 2024
1 parent fed65b1 commit 02dece1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
8 changes: 7 additions & 1 deletion lmdeploy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
role = message['role']
content = message['content']
ret += f'{box_map[role]}{content}{eox_map[role]}'
if len(messages) and messages[-1]['role'] == 'assistant':
return ret[:-len(eox_map['assistant'])] # prefix of response
ret += f'{self.assistant}'
return ret

Expand Down Expand Up @@ -510,6 +512,8 @@ def messages2prompt(self,
) + f" name={name_map[message['name']]}\n" if 'name' in message else box_map[
role]
ret += f'{begin}{content}{eox_map[role]}'
if len(messages) and messages[-1]['role'] == 'assistant':
return ret[:-len(eox_map['assistant'])] # prefix of response
ret += f'{self.assistant}'
return ret

Expand Down Expand Up @@ -844,9 +848,11 @@ def messages2prompt(self,
ret += f'{box_map[role]}{self.tools}{tool_prompt}{self.eotools}{content}{eox_map[role]}'
else:
ret += f'{box_map[role]}{content}{eox_map[role]}'
ret += f'{self.assistant}'
if sequence_start and not isinstance(messages, str):
ret = '<|begin_of_text|>' + ret
if len(messages) and messages[-1]['role'] == 'assistant':
return ret[:-len(eox_map['assistant'])] # prefix of response
ret += f'{self.assistant}'
return ret

@classmethod
Expand Down
15 changes: 14 additions & 1 deletion tests/test_lmdeploy/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ def test_vicuna():
assert _prompt is None


def test_prefix_response():
    """When the last message has role 'assistant', messages2prompt must
    return a prompt that keeps that message's content as a trailing,
    un-terminated prefix for the model to continue."""
    chat_template = MODELS.get('internlm2')()
    conversation = [{'role': 'assistant', 'content': 'prefix test'}]
    rendered = chat_template.messages2prompt(conversation)
    # The assistant content must survive verbatim at the very end
    # (i.e. no end-of-turn token appended after it).
    assert rendered.endswith('prefix test')


def test_internlm_chat():
prompt = 'hello, can u introduce yourself'
model = MODELS.get('internlm')(capability='completion')
Expand Down Expand Up @@ -311,6 +318,9 @@ def test_deepseek_coder():
}, {
'role': 'assistant',
'content': 'I am an AI'
}, {
'role': 'user',
'content': 'hi'
}]
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
Expand Down Expand Up @@ -392,6 +402,9 @@ def test_internvl_phi3():
}, {
'role': 'assistant',
'content': 'I am an AI'
}, {
'role': 'user',
'content': 'hi'
}]
res = model.messages2prompt(messages)
from huggingface_hub import hf_hub_download
Expand Down Expand Up @@ -426,7 +439,7 @@ def test_internvl2():
expected = '<|im_start|>system\n你是由上海人工智能实验室联合商汤科技开发的'\
'书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'\
'<|im_end|>\n<|im_start|>user\nwho are you<|im_end|>\n<|im_start|>'\
'assistant\nI am an AI<|im_end|>\n<|im_start|>assistant\n'
'assistant\nI am an AI'
res = model.messages2prompt(messages)
assert res == expected

Expand Down

0 comments on commit 02dece1

Please sign in to comment.