diff --git a/src/transformers/models/code_llama/tokenization_code_llama_fast.py b/src/transformers/models/code_llama/tokenization_code_llama_fast.py index 66a312eb3dfa8f..7d1e237022377e 100644 --- a/src/transformers/models/code_llama/tokenization_code_llama_fast.py +++ b/src/transformers/models/code_llama/tokenization_code_llama_fast.py @@ -278,7 +278,7 @@ def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens= special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else [] if suffix_first: # format as "
 {suf}  {pre}"
-            pair += [self.prefix_token, self.suffix_token, "$A", self.middle_token, "$B"]
+            pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
             special_tokens += [
                 (self.prefix_token, self.prefix_id),
                 (self.suffix_token, self.suffix_id),
diff --git a/tests/models/code_llama/test_tokenization_code_llama.py b/tests/models/code_llama/test_tokenization_code_llama.py
index 3df0c552c0daa4..2673981527048d 100644
--- a/tests/models/code_llama/test_tokenization_code_llama.py
+++ b/tests/models/code_llama/test_tokenization_code_llama.py
@@ -643,3 +643,15 @@ def main():
         input_ids = tokenizer.encode(PROMPTS[0])
         self.assertEqual(input_ids, tokenizer.encode(prefix, suffix=suffix))
         self.assertEqual(tokenizer.encode(prefix, suffix=suffix), tokenizer_fast.encode(prefix, suffix=suffix))
+
+        # Adding suffix_first check for infilling tasks
+        suffix_first_formatted_prompt = tokenizer.tokenize(PROMPTS[0], suffix_first=True)
+        self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(PROMPTS[0], suffix_first=True))
+        prefix, suffix = PROMPTS[0].split("")
+        self.assertEqual(suffix_first_formatted_prompt, tokenizer.tokenize(prefix, suffix, suffix_first=True))
+        self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(prefix, suffix, suffix_first=True))
+
+        prefix, suffix = PROMPTS[0].split("")
+        suffix_first_input_ids = tokenizer.encode(PROMPTS[0], suffix_first=True)
+        self.assertEqual(suffix_first_input_ids, tokenizer.encode(prefix, suffix=suffix, suffix_first=True))
+        self.assertEqual(suffix_first_input_ids, tokenizer_fast.encode(prefix, suffix=suffix, suffix_first=True))