Fix chat template for Yuan when answer is contained within question
lvdongyi committed Nov 23, 2024
1 parent 1ba3bef commit 9be04bf
Showing 3 changed files with 141 additions and 1 deletion.
82 changes: 81 additions & 1 deletion paddlenlp/transformers/yuan/tokenizer.py
@@ -16,8 +16,9 @@
"""Tokenization class for Yuan2.0 model"""

import os
import re
from shutil import copyfile
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm

@@ -200,3 +201,82 @@ def create_token_type_ids_from_sequences(
        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]

    def _encode_chat_inputs(
        self,
        conversations: List[Tuple[str, str]],
        context_data: Dict[str, Any] = {},
        system: str = None,
        add_generation_prompt=True,
    ):
        result = {}

        # Some templates do not support a system message, so check that first.
        if system:
            try:
                self.chat_template.render(messages={"role": "system", "content": system})
            except Exception as e:
                raise ValueError("System is not supported in this tokenizer.", e)

        # convert the list-form conversation into role dicts
        conversation_dict = []
        origin_msg = []
        for round in conversations:
            round_role = [
                {"role": "user", "content": round[0]},
                {"role": "assistant", "content": round[1]},
            ]
            origin_msg.extend(round_role)
            conversation_dict.append(round_role)
        ans = []

        # render each round's answer separately, then render the chat as a whole and split it by those answers
        # attention: each answer must include its end token!
        for conv in conversation_dict:
            roundi = [system] + conv if system else conv
            roundi_str = self.chat_template.render(
                messages=roundi, add_generation_prompt=False, **self.special_tokens_map
            )
            roundi_no_ans = [system] + [conv[0]] if system else [conv[0]]
            roundi_no_ans_str = self.chat_template.render(
                messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
            )

            ans_roundi = roundi_str[len(roundi_no_ans_str) - len("<sep>") + len("<n>") : -len("<sep>")]
            ans.append(ans_roundi)
        for idx, _ in enumerate(ans):
            ans[idx] += "<n>" if idx != len(ans) - 1 else "<sep>"

        non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans)
        assert len(non_learnable_parts) == len(ans)

        conversation_ids = []
        for i in range(len(non_learnable_parts)):
            conversation_ids.append(
                self.batch_encode(
                    [non_learnable_parts[i], ans[i]],
                    add_special_tokens=False,
                    padding=False,
                )["input_ids"]
            )

        result["conversations"] = conversation_ids
        return result

    def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
        """Split the entire chat by the specified answer strings and extract the non-learnable parts."""
        # prefix each answer with "<n>" so that an answer that also appears inside a question
        # does not trigger a split in the wrong place
        split_s_with_front_token = split_s.copy()
        for idx, _ in enumerate(split_s):
            split_s_with_front_token[idx] = "<n>" + split_s_with_front_token[idx]
        # escape regex metacharacters in the split strings so they are matched literally, e.g. | -> \|
        regex_pattern = "|".join(map(re.escape, split_s_with_front_token))
        # split the fully rendered chat by the escaped answer strings
        non_learnable_parts = re.split(
            r"(?:%s)" % regex_pattern,
            self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map),
        )
        if non_learnable_parts[-1] == "":
            non_learnable_parts.pop()
        for idx, _ in enumerate(non_learnable_parts):
            non_learnable_parts[idx] = non_learnable_parts[idx] + "<n>"
        return non_learnable_parts
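
The core of the fix is the "<n>" prefix that _extract_non_learnable_parts adds to each answer before building the split pattern: without it, an answer that also occurs verbatim inside a question (the ["Q.A.", "A."] round in the new test) would split the rendered chat at the wrong position. Below is a minimal standalone sketch of the idea; the rendered string is a hypothetical illustration of the Yuan turn layout (turns joined by <n>, final answer terminated by <sep>), not the exact template output.

import re

# Hypothetical rendered chat for two rounds in the Yuan style:
#   round 1: question "Q."    answer "A."
#   round 2: question "Q.A."  answer "A."   (the answer also appears inside the question)
rendered = "Q.<n>A.<n>Q.A.<n>A.<sep>"
answers = ["A.<n>", "A.<sep>"]  # per-round answers, suffixed as in _encode_chat_inputs

# Splitting on the bare answers also matches the "A." inside the second question,
# so the pieces no longer line up with the rounds (four pieces for two answers).
naive = re.split("|".join(map(re.escape, answers)), rendered)
print(naive)  # ['Q.<n>', 'Q.', '', '']

# Prefixing each answer with "<n>", as _extract_non_learnable_parts does, anchors the
# pattern to the turn separator, so only genuine answer positions match.
anchored = ["<n>" + a for a in answers]
fixed = re.split("|".join(map(re.escape, anchored)), rendered)
print(fixed)  # ['Q.', 'Q.A.', '']  -> drop the trailing "" and append "<n>" to each piece

Inside _encode_chat_inputs, a mis-split of that kind would also surface as a failure of the assert len(non_learnable_parts) == len(ans) check.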
13 changes: 13 additions & 0 deletions tests/transformers/yuan/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
47 changes: 47 additions & 0 deletions tests/transformers/yuan/test_tokenizer.py
@@ -0,0 +1,47 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from paddlenlp.transformers import YuanTokenizer


class YuanTokenizationTest(unittest.TestCase):
    def test_extract_non_learnable_parts(self):
        models_with_templates = [
            "IEITYuan/Yuan2-2B",
        ]
        dummy_conversations = [
            ["Q.", "A."],
            ["Q.A.", "A."],
            ["Q?", "A!"],
        ]
        decode_outputs = [
            ["Q.<n>", "A.<n>"],
            ["Q.A.<n>", "A.<n>"],
            ["Q?<n>", " A!<sep>"],  # note the extra leading space
        ]
        context_data = {}
        context_data["is_training"] = True
        for model_id in models_with_templates:
            tokenizer = YuanTokenizer.from_pretrained(model_id)
            if tokenizer.chat_template is None:
                continue
            conversation_result = tokenizer.encode_chat_inputs(
                dummy_conversations,
                context_data=context_data,
            )
            for idx, round in enumerate(conversation_result["conversations"]):
                self.assertEqual(tokenizer.decode(round[0]), decode_outputs[idx][0])
                self.assertEqual(tokenizer.decode(round[1]), decode_outputs[idx][1])
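
For local verification, the new test case can be run on its own. A small sketch, assuming paddlenlp is installed, the repository's tests directory is importable as a package from the repo root, and the IEITYuan/Yuan2-2B tokenizer files can be downloaded:

import unittest

from tests.transformers.yuan.test_tokenizer import YuanTokenizationTest

# Collect and run only the new Yuan tokenizer test case.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(YuanTokenizationTest)
unittest.TextTestRunner(verbosity=2).run(suite)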
