diff --git a/src/transformers/models/code_llama/tokenization_code_llama_fast.py b/src/transformers/models/code_llama/tokenization_code_llama_fast.py index 66a312eb3dfa8f..7d1e237022377e 100644 --- a/src/transformers/models/code_llama/tokenization_code_llama_fast.py +++ b/src/transformers/models/code_llama/tokenization_code_llama_fast.py @@ -278,7 +278,7 @@ def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens= special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else [] if suffix_first: # format as "
{suf} {pre}" - pair += [self.prefix_token, self.suffix_token, "$A", self.middle_token, "$B"] + pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"] special_tokens += [ (self.prefix_token, self.prefix_id), (self.suffix_token, self.suffix_id), diff --git a/tests/models/code_llama/test_tokenization_code_llama.py b/tests/models/code_llama/test_tokenization_code_llama.py index 3df0c552c0daa4..2673981527048d 100644 --- a/tests/models/code_llama/test_tokenization_code_llama.py +++ b/tests/models/code_llama/test_tokenization_code_llama.py @@ -643,3 +643,15 @@ def main(): input_ids = tokenizer.encode(PROMPTS[0]) self.assertEqual(input_ids, tokenizer.encode(prefix, suffix=suffix)) self.assertEqual(tokenizer.encode(prefix, suffix=suffix), tokenizer_fast.encode(prefix, suffix=suffix)) + + # Adding suffix_first check for infilling tasks + suffix_first_formatted_prompt = tokenizer.tokenize(PROMPTS[0], suffix_first=True) + self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(PROMPTS[0], suffix_first=True)) + prefix, suffix = PROMPTS[0].split(" ") + self.assertEqual(suffix_first_formatted_prompt, tokenizer.tokenize(prefix, suffix, suffix_first=True)) + self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(prefix, suffix, suffix_first=True)) + + prefix, suffix = PROMPTS[0].split(" ") + suffix_first_input_ids = tokenizer.encode(PROMPTS[0], suffix_first=True) + self.assertEqual(suffix_first_input_ids, tokenizer.encode(prefix, suffix=suffix, suffix_first=True)) + self.assertEqual(suffix_first_input_ids, tokenizer_fast.encode(prefix, suffix=suffix, suffix_first=True))