Skip to content

Commit

Permalink
initial commit (huggingface#17818)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker authored and younesbelkada committed Jun 25, 2022
1 parent 422e8aa commit 2cdc69d
Showing 1 changed file with 1 addition and 106 deletions.
107 changes: 1 addition & 106 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torch.nn import CrossEntropyLoss

from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files

from .activations import get_activation
from .configuration_utils import PretrainedConfig
Expand Down Expand Up @@ -205,40 +206,6 @@ def get_state_dict_dtype(state_dict):
return next(state_dict.values()).dtype


def convert_file_size_to_int(size: Union[int, str]):
    """
    Converts a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).

    Args:
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

    Example:

    ```py
    >>> convert_file_size_to_int("1MiB")
    1048576
    ```
    """
    if isinstance(size, int):
        return size
    upper = size.upper()
    # Binary (IEC) units: exact powers of two. No bit variant is supported for these.
    for suffix, shift in (("GIB", 30), ("MIB", 20), ("KIB", 10)):
        if upper.endswith(suffix):
            return int(size[:-3]) * (2**shift)
    # Decimal (SI) units: a lowercase trailing "b" means bits, so divide by 8 to get bytes.
    for suffix, exponent in (("GB", 9), ("MB", 6), ("KB", 3)):
        if upper.endswith(suffix):
            int_size = int(size[:-2]) * (10**exponent)
            return int_size // 8 if size.endswith("b") else int_size
    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")


def dtype_byte_size(dtype):
"""
Returns the size (in bytes) occupied by one parameter of type `dtype`.
Expand Down Expand Up @@ -324,78 +291,6 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
return shards, index


def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    local_files_only=False,
    use_auth_token=None,
    user_agent=None,
    revision=None,
    mirror=None,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on
      the Hub
    - returns the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).

    Returns:
        `Tuple[List[str], dict]`: the list of shard file paths (local paths, or cached download locations) and the
        index metadata, augmented with an `"all_checkpoint_keys"` entry listing every weight name in the checkpoint.

    Raises:
        `EnvironmentError`: if a shard listed in the index is missing on the Hub, or the Hub cannot be reached.
    """
    with open(index_filename, "r", encoding="utf-8") as f:
        index = json.load(f)

    # Several weights can live in the same shard file, so deduplicate before downloading.
    shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())

    # First, let's deal with local folder.
    if os.path.isdir(pretrained_model_name_or_path):
        shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
        return shard_filenames, sharded_metadata

    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
    cached_filenames = []
    for shard_filename in shard_filenames:
        shard_url = hf_bucket_url(
            pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
        )

        try:
            # Load from URL
            cached_filename = cached_path(
                shard_url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
        # we don't have to catch them here.
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
                "required according to the checkpoint index."
            )
        except HTTPError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
                " again after checking your internet connection."
            )

        cached_filenames.append(cached_filename)

    return cached_filenames, sharded_metadata


def load_sharded_checkpoint(model, folder, strict=True):
"""
This is the same as
Expand Down

0 comments on commit 2cdc69d

Please sign in to comment.