From 20e6fb0a2929234bfd124dfb6315a23725ab5d83 Mon Sep 17 00:00:00 2001
From: Matt
Date: Thu, 10 Oct 2024 15:28:05 +0100
Subject: [PATCH] Rename arg to "save_raw_chat_template" across all classes

---
 src/transformers/processing_utils.py        | 22 ++++++++++-----------
 src/transformers/tokenization_utils_base.py |  4 ++--
 tests/test_processing_common.py             |  2 +-
 tests/test_tokenization_common.py           |  8 ++++----
 4 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index 1a46dfa4e563bc..ffcd9545c62042 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -526,17 +526,17 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
         # If we save using the predefined names, we can load using `from_pretrained`
         # plus we save chat_template in its own file
         output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
-        output_naked_chat_template_file = os.path.join(save_directory, "processor_chat_template.jinja")
+        output_raw_chat_template_file = os.path.join(save_directory, "processor_chat_template.jinja")
         output_chat_template_file = os.path.join(save_directory, "chat_template.json")

         processor_dict = self.to_dict()
         # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
         # to avoid serializing chat template in json config file. So let's get it from `self` directly
         if self.chat_template is not None:
-            if kwargs.get("save_naked_chat_template", False):
-                with open(output_naked_chat_template_file, "w", encoding="utf-8") as writer:
+            if kwargs.get("save_raw_chat_template", False):
+                with open(output_raw_chat_template_file, "w", encoding="utf-8") as writer:
                     writer.write(self.chat_template)
-                logger.info(f"chat template saved in {output_naked_chat_template_file}")
+                logger.info(f"chat template saved in {output_raw_chat_template_file}")
             else:
                 chat_template_json_string = (
                     json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
@@ -611,18 +611,18 @@ def get_processor_dict(
             resolved_processor_file = pretrained_model_name_or_path
             # cant't load chat-template when given a file as pretrained_model_name_or_path
             resolved_chat_template_file = None
-            resolved_naked_chat_template_file = None
+            resolved_raw_chat_template_file = None
             is_local = True
         elif is_remote_url(pretrained_model_name_or_path):
             processor_file = pretrained_model_name_or_path
             resolved_processor_file = download_url(pretrained_model_name_or_path)
             # can't load chat-template when given a file url as pretrained_model_name_or_path
             resolved_chat_template_file = None
-            resolved_naked_chat_template_file = None
+            resolved_raw_chat_template_file = None
         else:
             processor_file = PROCESSOR_NAME
             chat_template_file = "chat_template.json"
-            naked_chat_template_file = "processor_chat_template.jinja"
+            raw_chat_template_file = "processor_chat_template.jinja"
             try:
                 # Load from local folder or from cache or download from model Hub and cache
                 resolved_processor_file = cached_file(
@@ -658,9 +658,9 @@ def get_processor_dict(
                     _raise_exceptions_for_missing_entries=False,
                 )

-                resolved_naked_chat_template_file = cached_file(
+                resolved_raw_chat_template_file = cached_file(
                     pretrained_model_name_or_path,
-                    naked_chat_template_file,
+                    raw_chat_template_file,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     proxies=proxies,
@@ -686,8 +686,8 @@ def get_processor_dict(
         )

         # Add chat template as kwarg before returning because most models don't have processor config
-        if resolved_naked_chat_template_file is not None:
-            with open(resolved_naked_chat_template_file, "r", encoding="utf-8") as reader:
+        if resolved_raw_chat_template_file is not None:
+            with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader:
                 chat_template = reader.read()
             kwargs["chat_template"] = chat_template
         elif resolved_chat_template_file is not None:
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index bac6d7bff28d63..c990ca7713318b 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -2602,14 +2602,14 @@ def save_pretrained(
         # Let's make sure we properly save the special tokens.
         tokenizer_config.update(self.special_tokens_map)

-        if self.chat_template is not None and not kwargs.get("skip_chat_template_save", False):
+        if self.chat_template is not None:
             if isinstance(self.chat_template, dict):
                 # Chat template dicts are saved to the config as lists of dicts with fixed key names.
                 # They will be reconstructed as a single dict during loading.
                 # We're trying to discourage chat template dicts, and they are always
                 # saved in the config, never as single files.
                 tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
-            elif kwargs.get("save_chat_template_file", False):
+            elif kwargs.get("save_raw_chat_template", False):
                 with open(chat_template_file, "w", encoding="utf-8") as f:
                     f.write(self.chat_template)
                 logger.info(f"chat template saved in {chat_template_file}")
diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py
index 66bc4add7e4e53..8d8fc52a08a6ad 100644
--- a/tests/test_processing_common.py
+++ b/tests/test_processing_common.py
@@ -532,7 +532,7 @@ def test_chat_template_save_loading(self):
             self.assertEqual(processor.chat_template, reloaded_processor.chat_template)

         with tempfile.TemporaryDirectory() as tmpdirname:
-            processor.save_pretrained(tmpdirname, save_naked_chat_template=True)
+            processor.save_pretrained(tmpdirname, save_raw_chat_template=True)
             self.assertTrue(Path(tmpdirname, "processor_chat_template.jinja").is_file())
             self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
             reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index dadab49139cde8..d4833a83236b9f 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -1117,7 +1117,7 @@ def test_chat_template(self):
                 new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)

                 with tempfile.TemporaryDirectory() as tmp_dir_name:
-                    tokenizer.save_pretrained(tmp_dir_name, save_chat_template_file=True)
+                    tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
                     chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
                     self.assertTrue(chat_template_file.is_file())
                     self.assertEqual(chat_template_file.read_text(), dummy_template)
@@ -1407,11 +1407,11 @@ def test_chat_template_dict_saving(self):
         tokenizers = self.get_tokenizers()
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):
-                for save_chat_template_file in (True, False):
+                for save_raw_chat_template in (True, False):
                     tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
                     with tempfile.TemporaryDirectory() as tmp_dir_name:
-                        # Test that save_chat_template_file is ignored when there's a dict of multiple templates
-                        tokenizer.save_pretrained(tmp_dir_name, save_chat_template_file=save_chat_template_file)
+                        # Test that save_raw_chat_template is ignored when there's a dict of multiple templates
+                        tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template)
                         config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
                         # Assert that chat templates are correctly serialized as lists of dictionaries
                         self.assertEqual(
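For reference, a minimal sketch of how the renamed kwarg is used after this patch. The checkpoint ids are placeholders, and the snippet assumes the loaded tokenizer and processor each carry a single (non-dict) chat template; dict templates keep being serialized into the config and ignore the kwarg, as the tests above check.

    import os
    import tempfile

    from transformers import AutoProcessor, AutoTokenizer

    # Placeholder checkpoint ids -- substitute any models that ship a chat template.
    tokenizer = AutoTokenizer.from_pretrained("some-org/some-chat-model")
    processor = AutoProcessor.from_pretrained("some-org/some-multimodal-model")

    with tempfile.TemporaryDirectory() as save_dir:
        # Tokenizers write the template as a raw Jinja file instead of
        # embedding it in tokenizer_config.json.
        tokenizer.save_pretrained(save_dir, save_raw_chat_template=True)
        print(os.path.isfile(os.path.join(save_dir, "chat_template.jinja")))  # True

    with tempfile.TemporaryDirectory() as save_dir:
        # Processors use their own file name and skip chat_template.json.
        processor.save_pretrained(save_dir, save_raw_chat_template=True)
        print(os.path.isfile(os.path.join(save_dir, "processor_chat_template.jinja")))  # True
        # from_pretrained reads the raw file back in via get_processor_dict().
        reloaded = AutoProcessor.from_pretrained(save_dir)
        print(reloaded.chat_template == processor.chat_template)  # True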