From 990d8b6a831acdfa07689b52d4f76b68feb68c0b Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Fri, 19 Nov 2021 17:34:18 +0800 Subject: [PATCH] [Fix] Add load_url to handle incompatibility of PyTorch versions (#1377) * [Fix] Fix torch.load error * [Fix] Fix torch.load error * rename _save to _save_ckpt * add load_url to handle incompatibility of PyTorch versions * add unittest for load_url * fix typo * print a friendly message when error occurred --- .github/workflows/build.yml | 2 +- mmcv/runner/checkpoint.py | 7 +- mmcv/utils/__init__.py | 3 +- mmcv/utils/hub.py | 127 +++++++++++++++++++++++++++++++++++ tests/test_load_model_zoo.py | 6 +- tests/test_utils/test_hub.py | 32 +++++++++ 6 files changed, 168 insertions(+), 9 deletions(-) create mode 100644 mmcv/utils/hub.py create mode 100644 tests/test_utils/test_hub.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9c52827ed2..91f1cd54c1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -47,7 +47,7 @@ jobs: - name: Run unittests and generate coverage report run: | pip install -r requirements/test.txt - pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py + pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py 
--ignore=tests/test_utils/test_hub.py build_without_ops: runs-on: ubuntu-18.04 diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 6ad605b854..678e385cc9 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -13,13 +13,12 @@ import torch import torchvision from torch.optim import Optimizer -from torch.utils import model_zoo import mmcv from ..fileio import FileClient from ..fileio import load as load_file from ..parallel import is_module_wrapper -from ..utils import mkdir_or_exist +from ..utils import load_url, mkdir_or_exist from .dist_utils import get_dist_info ENV_MMCV_HOME = 'MMCV_HOME' @@ -281,12 +280,12 @@ def load_from_http(filename, map_location=None, model_dir=None): rank, world_size = get_dist_info() rank = int(os.environ.get('LOCAL_RANK', rank)) if rank == 0: - checkpoint = model_zoo.load_url( + checkpoint = load_url( filename, model_dir=model_dir, map_location=map_location) if world_size > 1: torch.distributed.barrier() if rank > 0: - checkpoint = model_zoo.load_url( + checkpoint = load_url( filename, model_dir=model_dir, map_location=map_location) return checkpoint diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 378a006843..478f015111 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -46,6 +46,7 @@ _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home) from .registry import Registry, build_from_cfg from .trace import is_jit_tracing + from .hub import load_url __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', @@ -65,5 +66,5 @@ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home', 'has_method' + '_get_cuda_home', 'load_url', 'has_method' ] diff --git a/mmcv/utils/hub.py b/mmcv/utils/hub.py new file mode 100644 index 
# The 1.6 release of PyTorch switched torch.save to use a new zipfile-based
# file format. It will cause RuntimeError when a checkpoint was saved in
# torch >= 1.6.0 but loaded in torch < 1.7.0.
# More details at https://github.com/open-mmlab/mmpose/issues/904
from .parrots_wrapper import TORCH_VERSION
from .path import mkdir_or_exist
from .version_utils import digit_version

if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
        '1.7.0'):
    # Modified from https://github.com/pytorch/pytorch/blob/master/torch/hub.py
    import os
    import sys
    import warnings
    import zipfile
    from urllib.parse import urlparse

    import torch
    from torch.hub import HASH_REGEX, _get_torch_home, download_url_to_file

    def _is_legacy_zip_format(filename):
        """Check whether ``filename`` is a single-file zip archive.

        Hub used to support automatically extracting zip files that users
        compressed manually. That legacy layout expects the archive to hold
        exactly one file.

        Args:
            filename (str): Path of the file to inspect.

        Returns:
            bool: True if ``filename`` is a zip archive with exactly one
                entry and that entry is not a directory, otherwise False.
        """
        if not zipfile.is_zipfile(filename):
            return False
        # Open the archive in a ``with`` block so the underlying file handle
        # is closed as soon as the member list has been read.
        with zipfile.ZipFile(filename) as archive:
            infolist = archive.infolist()
        return len(infolist) == 1 and not infolist[0].is_dir()

    def _legacy_zip_load(filename, model_dir, map_location):
        """Extract a single-file zip archive and load the extracted file.

        Args:
            filename (str): Path of the zip archive.
            model_dir (str): Directory into which the member is extracted.
            map_location (optional): Forwarded to ``torch.load``.

        Returns:
            The object returned by ``torch.load`` for the extracted file.

        Raises:
            RuntimeError: If the archive does not contain exactly one entry.
        """
        warnings.warn('Falling back to the old format < 1.6. This support will'
                      ' be deprecated in favor of default zipfile format '
                      'introduced in 1.6. Please redo torch.save() to save it '
                      'in the new zipfile format.')
        # Note: extraction overwrites the file if it exists. No need to
        #       clean up beforehand. We deliberately don't handle tarfile here
        #       since our legacy serialization format was in tar.
        #       E.g. resnet18-5c106cde.pth which is widely used.
        with zipfile.ZipFile(filename) as archive:
            members = archive.infolist()
            if len(members) != 1:
                raise RuntimeError(
                    'Only one file(not dir) is allowed in the zipfile')
            # ``extract`` returns the path it actually wrote to, so the file
            # that gets loaded is always the file that was just extracted,
            # even when zipfile normalizes the member name.
            extracted_file = archive.extract(members[0], model_dir)
        return torch.load(extracted_file, map_location=map_location)

    def load_url(url,
                 model_dir=None,
                 map_location=None,
                 progress=True,
                 check_hash=False,
                 file_name=None):
        r"""Loads the Torch serialized object at the given URL.

        If the downloaded file is a zip archive holding a single file, it is
        extracted automatically before loading.

        If the object is already present in ``model_dir``, it is deserialized
        and returned without downloading again.
        The default value of ``model_dir`` is ``<torch_home>/checkpoints``
        where ``<torch_home>`` is the directory returned by
        ``torch.hub._get_torch_home()``.

        Args:
            url (str): URL of the object to download
            model_dir (str, optional): directory in which to save the object
            map_location (optional): a function or a dict specifying how to
                remap storage locations (see torch.load)
            progress (bool, optional): whether or not to display a progress bar
                to stderr. Default: True
            check_hash(bool, optional): If True, the name of the cached file
                should follow the naming convention ``filename-<sha256>.ext``
                where ``<sha256>`` is the first eight or more digits of the
                SHA256 hash of the contents of the file. The hash is used to
                ensure unique names and to verify the contents of the file.
                Default: False
            file_name (str, optional): name for the downloaded file. Filename
                from ``url`` will be used if not set. Default: None.

        Returns:
            The object returned by ``torch.load`` for the cached file.

        Raises:
            RuntimeError: Re-raised from ``torch.load`` when the cached file
                cannot be deserialized. With torch < 1.5.0 a warning that
                suggests upgrading torch is emitted first.

        Example:
            >>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106'
            ...        'cde.pth')
            >>> state_dict = load_url(url)
        """
        # Issue warning to move data if old env is set
        if os.getenv('TORCH_MODEL_ZOO'):
            warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env '
                          'TORCH_HOME instead')

        if model_dir is None:
            torch_home = _get_torch_home()
            model_dir = os.path.join(torch_home, 'checkpoints')

        mkdir_or_exist(model_dir)

        parts = urlparse(url)
        filename = os.path.basename(parts.path)
        if file_name is not None:
            filename = file_name
        cached_file = os.path.join(model_dir, filename)
        if not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(
                url, cached_file))
            hash_prefix = None
            if check_hash:
                r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
                hash_prefix = r.group(1) if r else None
            download_url_to_file(
                url, cached_file, hash_prefix, progress=progress)

        # NOTE: torch.load unpickles the file, so only load checkpoints from
        # URLs you trust.
        if _is_legacy_zip_format(cached_file):
            return _legacy_zip_load(cached_file, model_dir, map_location)

        try:
            return torch.load(cached_file, map_location=map_location)
        except RuntimeError:
            if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
                warnings.warn(
                    f'If the error is the same as "{cached_file} is a zip '
                    'archive (did you mean to use torch.jit.load()?)", you can'
                    ' upgrade your torch to 1.5.0 or higher (current torch '
                    f'version is {TORCH_VERSION}). The error was raised '
                    ' because the checkpoint was saved in torch>=1.6.0 but '
                    'loaded in torch<1.5.')
            # Bare ``raise`` re-raises the active exception unchanged.
            raise
else:
    from torch.utils.model_zoo import load_url  # noqa: F401
def test_load_url():
    """Compare ``model_zoo.load_url`` and mmcv's ``load_url`` across versions.

    Two remote checkpoints are used: one saved with torch 1.5 and one saved
    with torch 1.6 (the zipfile-based format).
    """
    ckpt_pt15 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.5.pth'
    ckpt_pt16 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.6.pth'
    installed = digit_version(TORCH_VERSION)

    # The 1.6 release of PyTorch switched torch.save to use a new zipfile-based
    # file format. It will cause RuntimeError when a checkpoint was saved in
    # torch >= 1.6.0 but loaded in torch < 1.7.0.
    # More details at https://github.com/open-mmlab/mmpose/issues/904
    # The checkpoint saved in torch 1.5 is loaded first on every version.
    model_zoo.load_url(ckpt_pt15)
    if installed < digit_version('1.7.0'):
        with pytest.raises(RuntimeError):
            model_zoo.load_url(ckpt_pt16)
    else:
        # high version of PyTorch can load checkpoints from url, regardless
        # of which version they were saved in
        model_zoo.load_url(ckpt_pt16)

    # mmcv's load_url
    load_url(ckpt_pt15)
    # if a checkpoint was saved in torch >= 1.6.0 but loaded in torch < 1.5.0,
    # it will raise a RuntimeError
    if installed < digit_version('1.5.0'):
        with pytest.raises(RuntimeError):
            load_url(ckpt_pt16)
    else:
        load_url(ckpt_pt16)