From 2eca8c1daf5878e572b052ea322d1596438b08d8 Mon Sep 17 00:00:00 2001
From: "Wang, Chang"
Date: Mon, 5 Feb 2024 17:31:22 +0800
Subject: [PATCH] Set trainer.save_model state_dict format to safetensors
 (#1227)

---
 .../transformers/modeling/model.py | 128 +++++++++---------
 .../transformers/trainer.py        |   6 +-
 2 files changed, 67 insertions(+), 67 deletions(-)

diff --git a/intel_extension_for_transformers/transformers/modeling/model.py b/intel_extension_for_transformers/transformers/modeling/model.py
index 597988dd3dc..197f67a9c1a 100644
--- a/intel_extension_for_transformers/transformers/modeling/model.py
+++ b/intel_extension_for_transformers/transformers/modeling/model.py
@@ -124,7 +124,39 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs):
             model_class._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load
         else:  # pragma: no cover
             model_class._keys_to_ignore_on_load_missing.extend(missing_keys_to_ignore_on_load)
-
+        if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path):  # pragma: no cover
+            from transformers.utils import cached_file
+            try:
+                # Load from URL or cache if already cached
+                resolved_weights_file = cached_file(
+                    model_name_or_path,
+                    filename=WEIGHTS_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    use_auth_token=use_auth_token,
+                )
+            except EnvironmentError as err:  # pragma: no cover
+                logger.error(err)
+                msg = (
+                    f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{model_name_or_path}' is a correct model identifier "
+                    f"listed on 'https://huggingface.co/models'\n (make sure "
+                    f"'{model_name_or_path}' is not a path to a local directory with "
+                    f"something else, in that case)\n\n- or '{model_name_or_path}' is "
+                    f"the correct path to a directory containing a file "
+                    f"named one of {WEIGHTS_NAME}\n\n"
+                )
+                if revision is not None:
+                    msg += (f"- or '{revision}' is a valid git identifier "
+                            f"(branch name, a tag name, or a commit id) that "
+                            f"exists for this model name as listed on its model "
+                            f"page on 'https://huggingface.co/models'\n\n"
+                    )
+                raise EnvironmentError(msg)
+        else:
+            resolved_weights_file = os.path.join(model_name_or_path, WEIGHTS_NAME)
+        state_dict = torch.load(resolved_weights_file, {})
         model = model_class.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -132,77 +164,43 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs):
             resume_download=resume_download,
             use_auth_token=use_auth_token,
             revision=revision,
+            state_dict=state_dict,
             **kwargs,
         )
-
         model_class._keys_to_ignore_on_load_unexpected = keys_to_ignore_on_load_unexpected
         model_class._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing

         dataloader = kwargs.get("dataloader", None)
         if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path):  # pragma: no cover
-            # pylint: disable=E0611
-            if Version(transformers.__version__) < Version('4.22.0'):
-                from transformers.file_utils import cached_path, hf_bucket_url
-                weights_file = hf_bucket_url(model_name_or_path,
-                                             filename=WEIGHTS_NAME,
-                                             revision=revision)
-                try:
-                    # Load from URL or cache if already cached
-                    resolved_weights_file = cached_path(
-                        weights_file,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        use_auth_token=use_auth_token,
-                    )
-                except EnvironmentError as err:  # pragma: no cover
-                    logger.error(err)
-                    msg = (
-                        f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n"
-                        f"- '{model_name_or_path}' is a correct model identifier "
-                        f"listed on 'https://huggingface.co/models'\n (make sure "
-                        f"'{model_name_or_path}' is not a path to a local directory with "
-                        f"something else, in that case)\n\n- or '{model_name_or_path}' is "
-                        f"the correct path to a directory containing a file "
-                        f"named one of {WEIGHTS_NAME}\n\n"
-                    )
-                    if revision is not None:
-                        msg += (f"- or '{revision}' is a valid git identifier "
-                                f"(branch name, a tag name, or a commit id) that "
-                                f"exists for this model name as listed on its model "
-                                f"page on 'https://huggingface.co/models'\n\n"
-                        )
-                    raise EnvironmentError(msg)
-            else:
-                from transformers.utils import cached_file
-                try:
-                    # Load from URL or cache if already cached
-                    resolved_weights_file = cached_file(
-                        model_name_or_path,
-                        filename=WEIGHTS_NAME,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        use_auth_token=use_auth_token,
-                    )
-                except EnvironmentError as err:  # pragma: no cover
-                    logger.error(err)
-                    msg = (
-                        f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n"
-                        f"- '{model_name_or_path}' is a correct model identifier "
-                        f"listed on 'https://huggingface.co/models'\n (make sure "
-                        f"'{model_name_or_path}' is not a path to a local directory with "
-                        f"something else, in that case)\n\n- or '{model_name_or_path}' is "
-                        f"the correct path to a directory containing a file "
-                        f"named one of {WEIGHTS_NAME}\n\n"
-                    )
-                    if revision is not None:
-                        msg += (f"- or '{revision}' is a valid git identifier "
-                                f"(branch name, a tag name, or a commit id) that "
-                                f"exists for this model name as listed on its model "
-                                f"page on 'https://huggingface.co/models'\n\n"
-                        )
-                    raise EnvironmentError(msg)
+            from transformers.utils import cached_file
+            try:
+                # Load from URL or cache if already cached
+                resolved_weights_file = cached_file(
+                    model_name_or_path,
+                    filename=WEIGHTS_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    use_auth_token=use_auth_token,
+                )
+            except EnvironmentError as err:  # pragma: no cover
+                logger.error(err)
+                msg = (
+                    f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{model_name_or_path}' is a correct model identifier "
+                    f"listed on 'https://huggingface.co/models'\n (make sure "
+                    f"'{model_name_or_path}' is not a path to a local directory with "
+                    f"something else, in that case)\n\n- or '{model_name_or_path}' is "
+                    f"the correct path to a directory containing a file "
+                    f"named one of {WEIGHTS_NAME}\n\n"
+                )
+                if revision is not None:
+                    msg += (f"- or '{revision}' is a valid git identifier "
+                            f"(branch name, a tag name, or a commit id) that "
+                            f"exists for this model name as listed on its model "
+                            f"page on 'https://huggingface.co/models'\n\n"
+                    )
+                raise EnvironmentError(msg)
         q_model = load(
             resolved_weights_file,
             model,
diff --git a/intel_extension_for_transformers/transformers/trainer.py b/intel_extension_for_transformers/transformers/trainer.py
index f5f52fb2056..81c7ee36a7e 100644
--- a/intel_extension_for_transformers/transformers/trainer.py
+++ b/intel_extension_for_transformers/transformers/trainer.py
@@ -1981,7 +1981,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
             if is_pretrained:
                 if state_dict is None:
                     state_dict = unwrapped_model.state_dict()
-                unwrapped_model.save_pretrained(output_dir, state_dict=state_dict)
+                unwrapped_model.save_pretrained(output_dir, state_dict=state_dict,
+                                                safe_serialization=self.args.save_safetensors)
             else:
                 logger.info(
                     "Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
@@ -1993,7 +1994,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
             if self.enable_inc_quant and self.opt_model:
                 self._save_inc_int8(self.opt_model, output_dir)
             else:
-                self.model.save_pretrained(output_dir, state_dict=state_dict)
+                self.model.save_pretrained(output_dir, state_dict=state_dict,
+                                           safe_serialization=self.args.save_safetensors)

         if self.tokenizer is not None:  # pragma: no cover
             self.tokenizer.save_pretrained(output_dir)
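
---
Usage note (not part of the patch): a minimal sketch of how the new behavior is
exercised, assuming the NLPTrainer class exported by
intel_extension_for_transformers.transformers.trainer; the model id and output
directory below are placeholders. save_safetensors is a standard
transformers.TrainingArguments field, so with it enabled, trainer.save_model()
now writes model.safetensors instead of pytorch_model.bin.

    # Hypothetical usage example; model id and paths are placeholders.
    from transformers import AutoModelForSequenceClassification, TrainingArguments
    from intel_extension_for_transformers.transformers.trainer import NLPTrainer

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    args = TrainingArguments(output_dir="./out", save_safetensors=True)
    trainer = NLPTrainer(model=model, args=args)
    trainer.save_model("./out")  # serializes weights to ./out/model.safetensors via the patched _save()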