Skip to content

Commit

Permalink
add load_url to handle imcompatibility of PyTorch versions
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida committed Oct 10, 2021
1 parent af8f18c commit 01fe0ed
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 19 deletions.
23 changes: 5 additions & 18 deletions mmcv/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo

import mmcv
from mmcv.utils import TORCH_VERSION, digit_version
from mmcv.utils import load_url
from ..fileio import FileClient
from ..fileio import load as load_file
from ..parallel import is_module_wrapper
Expand Down Expand Up @@ -281,12 +280,12 @@ def load_from_http(filename, map_location=None, model_dir=None):
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
checkpoint = model_zoo.load_url(
checkpoint = load_url(
filename, model_dir=model_dir, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
checkpoint = model_zoo.load_url(
checkpoint = load_url(
filename, model_dir=model_dir, map_location=map_location)
return checkpoint

Expand Down Expand Up @@ -629,18 +628,6 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
"""

def _save_ckpt(checkpoint, file):
# The 1.6 release of PyTorch switched torch.save to use a new
# zipfile-based file format. It will cause RuntimeError when a
# checkpoint was saved in high version (PyTorch version>=1.6.0) but
# loaded in low version (PyTorch version<1.6.0). More details at
# https://github.com/open-mmlab/mmpose/issues/904
if digit_version(TORCH_VERSION) >= digit_version('1.6.0'):
torch.save(checkpoint, file, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, file)

if meta is None:
meta = {}
elif not isinstance(meta, dict):
Expand Down Expand Up @@ -683,12 +670,12 @@ def _save_ckpt(checkpoint, file):
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, 'wb') as f:
_save_ckpt(checkpoint, f)
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, 'wb') as f:
_save_ckpt(checkpoint, f)
torch.save(checkpoint, f)
f.flush()
3 changes: 2 additions & 1 deletion mmcv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
_MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
from .registry import Registry, build_from_cfg
from .trace import is_jit_tracing
from .hub import load_url
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
Expand All @@ -65,5 +66,5 @@
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home'
'_get_cuda_home', 'load_url'
]
114 changes: 114 additions & 0 deletions mmcv/utils/hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# The 1.6 release of PyTorch switched torch.save to use a new zipfile-based
# file format. It will cause RuntimeError when a checkpoint was saved in high
# version (PyTorch version>=1.6.0) but loaded in low version torch < 1.6.0.
# More details at# https://github.com/open-mmlab/mmpose/issues/904
from .parrots_wrapper import TORCH_VERSION, digit_version
from .path import mkdir_or_exist

if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
'1.6.0'):
# Modified from https://github.com/pytorch/pytorch/blob/master/torch/hub.py
import os
import torch
import warnings
from urllib.parse import urlparse
import sys
import zipfile
from torch.hub import download_url_to_file, _get_torch_home, HASH_REGEX

# Hub used to support automatically extracts from zipfile manually
# compressed by users. The legacy zip format expects only one file from
# torch.save() < 1.6 in the zip. We should remove this support since
# zipfile is now default zipfile format for torch.save().
def _is_legacy_zip_format(filename):
if zipfile.is_zipfile(filename):
infolist = zipfile.ZipFile(filename).infolist()
return len(infolist) == 1 and not infolist[0].is_dir()
return False

def _legacy_zip_load(filename, model_dir, map_location):
warnings.warn('Falling back to the old format < 1.6. This support will'
' be deprecated in favor of default zipfile format '
'introduced in 1.6. Please redo torch.save() to save it '
'in the new zipfile format.')
# Note: extractall() defaults to overwrite file if exists. No need to
# clean up beforehand. We deliberately don't handle tarfile here
# since our legacy serialization format was in tar.
# E.g. resnet18-5c106cde.pth which is widely used.
with zipfile.ZipFile(filename) as f:
members = f.infolist()
if len(members) != 1:
raise RuntimeError(
'Only one file(not dir) is allowed in the zipfile')
f.extractall(model_dir)
extraced_name = members[0].filename
extracted_file = os.path.join(model_dir, extraced_name)
return torch.load(extracted_file, map_location=map_location)

def load_url(url,
model_dir=None,
map_location=None,
progress=True,
check_hash=False,
file_name=None):
r"""Loads the Torch serialized object at the given URL.
If downloaded file is a zip file, it will be automatically decompressed
If the object is already present in `model_dir`, it's deserialized and
returned.
The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
Args:
url (str): URL of the object to download
model_dir (str, optional): directory in which to save the object
map_location (optional): a function or a dict specifying how to
remap storage locations (see torch.load)
progress (bool, optional): whether or not to display a progress bar
to stderr. Default: True
check_hash(bool, optional): If True, the filename part of the URL
should follow the naming convention ``filename-<sha256>.ext``
where ``<sha256>`` is the first eight or more digits of the
SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file.
Default: False
file_name (str, optional): name for the downloaded file. Filename
from ``url`` will be used if not set. Default: None.
Example:
>>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106'
... 'cde.pth')
>>> state_dict = torch.hub.load_state_dict_from_url(url)
"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env '
'TORCH_HOME instead')

if model_dir is None:
torch_home = _get_torch_home()
model_dir = os.path.join(torch_home, 'checkpoints')

mkdir_or_exist(model_dir)

parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(
url, cached_file))
hash_prefix = None
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
download_url_to_file(
url, cached_file, hash_prefix, progress=progress)

if _is_legacy_zip_format(cached_file):
return _legacy_zip_load(cached_file, model_dir, map_location)
return torch.load(cached_file, map_location=map_location)
else:
from torch.utils.model_zoo import load_url # noqa: F401

0 comments on commit 01fe0ed

Please sign in to comment.