From 2cdc69da03b46be554d6afb83dbe4ef230eac91a Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Wed, 22 Jun 2022 14:26:03 +0200
Subject: [PATCH] initial commit (#17818)

---
 src/transformers/modeling_utils.py | 107 +----------------------------
 1 file changed, 1 insertion(+), 106 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index de9273c710bfd8..3ede5b14ad46b3 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -32,6 +32,7 @@
 from torch.nn import CrossEntropyLoss
 
 from requests import HTTPError
+from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
 
 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
@@ -205,40 +206,6 @@ def get_state_dict_dtype(state_dict):
     return next(state_dict.values()).dtype
 
 
-def convert_file_size_to_int(size: Union[int, str]):
-    """
-    Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
-
-    Args:
-        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
-
-    Example:
-
-    ```py
-    >>> convert_file_size_to_int("1MiB")
-    1048576
-    ```
-    """
-    if isinstance(size, int):
-        return size
-    if size.upper().endswith("GIB"):
-        return int(size[:-3]) * (2**30)
-    if size.upper().endswith("MIB"):
-        return int(size[:-3]) * (2**20)
-    if size.upper().endswith("KIB"):
-        return int(size[:-3]) * (2**10)
-    if size.upper().endswith("GB"):
-        int_size = int(size[:-2]) * (10**9)
-        return int_size // 8 if size.endswith("b") else int_size
-    if size.upper().endswith("MB"):
-        int_size = int(size[:-2]) * (10**6)
-        return int_size // 8 if size.endswith("b") else int_size
-    if size.upper().endswith("KB"):
-        int_size = int(size[:-2]) * (10**3)
-        return int_size // 8 if size.endswith("b") else int_size
-    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
-
-
 def dtype_byte_size(dtype):
     """
     Returns the size (in bytes) occupied by one parameter of type `dtype`.
@@ -324,78 +291,6 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
     return shards, index
 
 
-def get_checkpoint_shard_files(
-    pretrained_model_name_or_path,
-    index_filename,
-    cache_dir=None,
-    force_download=False,
-    proxies=None,
-    resume_download=False,
-    local_files_only=False,
-    use_auth_token=None,
-    user_agent=None,
-    revision=None,
-    mirror=None,
-):
-    """
-    For a given model:
-
-    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
-      Hub
-    - returns the list of paths to all the shards, as well as some metadata.
-
-    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
-    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
-    """
-    with open(index_filename, "r") as f:
-        index = json.loads(f.read())
-
-    shard_filenames = sorted(list(set(index["weight_map"].values())))
-    sharded_metadata = index["metadata"]
-    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
-
-    # First, let's deal with local folder.
-    if os.path.isdir(pretrained_model_name_or_path):
-        shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
-        return shard_filenames, sharded_metadata
-
-    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
-    cached_filenames = []
-    for shard_filename in shard_filenames:
-        shard_url = hf_bucket_url(
-            pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
-        )
-
-        try:
-            # Load from URL
-            cached_filename = cached_path(
-                shard_url,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                user_agent=user_agent,
-            )
-        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
-        # we don't have to catch them here.
-        except EntryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
-                "required according to the checkpoint index."
-            )
-        except HTTPError:
-            raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
-                " again after checking your internet connection."
-            )
-
-        cached_filenames.append(cached_filename)
-
-    return cached_filenames, sharded_metadata
-
-
 def load_sharded_checkpoint(model, folder, strict=True):
     """
     This is the same as
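For reference, a minimal usage sketch (not part of the patch itself): after this change the two helpers removed here are expected to be importable from `transformers.utils.hub`, as the new import line in the first hunk indicates, while the re-import is assumed to keep existing `transformers.modeling_utils` call sites working. The values in the comments follow the removed `convert_file_size_to_int` body and its own docstring example.

```py
# Sketch only: assumes the helpers now live in transformers.utils.hub,
# per the import added by this patch.
from transformers.utils.hub import convert_file_size_to_int

print(convert_file_size_to_int("1MiB"))  # 1048576, the example from the removed docstring
print(convert_file_size_to_int("5GB"))   # 5000000000 (decimal units; lowercase "b" would divide by 8)
print(convert_file_size_to_int(2048))    # ints are returned unchanged: 2048
```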