From 17aad8dc8747ec87182d5fb5d74607a9ab3d1966 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 10 Oct 2021 18:06:55 +0800 Subject: [PATCH] add unittest for load_url --- .github/workflows/build.yml | 2 +- mmcv/runner/checkpoint.py | 3 +-- mmcv/utils/hub.py | 11 ++++++----- tests/test_load_model_zoo.py | 6 +++--- tests/test_utils/test_hub.py | 28 ++++++++++++++++++++++++++++ 5 files changed, 39 insertions(+), 11 deletions(-) create mode 100644 tests/test_utils/test_hub.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 697c3ba6c5..d65a1671b5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -50,7 +50,7 @@ jobs: - name: Run unittests and generate coverage report run: | pip install -r requirements/test.txt - pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py + pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py build_without_ops: runs-on: ubuntu-18.04 diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index abb1d193dc..b07f0a6755 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -15,11 +15,10 @@ from torch.optim import Optimizer import mmcv -from mmcv.utils import load_url from ..fileio import FileClient from ..fileio import load as load_file from ..parallel import is_module_wrapper -from ..utils import mkdir_or_exist +from ..utils import load_url, mkdir_or_exist from .dist_utils import get_dist_info ENV_MMCV_HOME = 'MMCV_HOME' diff --git a/mmcv/utils/hub.py b/mmcv/utils/hub.py index 839d4e6e66..f2ec13e1af 100644 --- a/mmcv/utils/hub.py +++ b/mmcv/utils/hub.py @@ -1,12 +1,13 @@ # The 1.6 release of PyTorch switched torch.save to use a new zipfile-based -# file format. It will cause RuntimeError when a checkpoint was saved in high -# version (PyTorch version>=1.6.0) but loaded in low version torch < 1.6.0. -# More details at# https://github.com/open-mmlab/mmpose/issues/904 -from .parrots_wrapper import TORCH_VERSION, digit_version +# file format. It will cause RuntimeError when a checkpoint was saved in +# torch >= 1.6.0 but loaded in torch < 1.7.0. +# More details at https://github.com/open-mmlab/mmpose/issues/904 +from .parrots_wrapper import TORCH_VERSION from .path import mkdir_or_exist +from .version_utils import digit_version if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version( - '1.6.0'): + '1.7.0'): # Modified from https://github.com/pytorch/pytorch/blob/master/torch/hub.py import os import torch diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index ada0bc78f6..c230bfa62e 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -73,8 +73,8 @@ def load(filepath, map_location=None): @patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')]) @patch('mmcv.runner.checkpoint.load_from_http', load_from_http) +@patch('mmcv.runner.checkpoint.load_url', load_url) @patch('torch.load', load) -@patch('torch.utils.model_zoo.load_url', load_url) def test_load_external_url(): # test modelzoo:// url = _load_checkpoint('modelzoo://resnet50') @@ -128,7 +128,7 @@ def test_load_external_url(): os.environ[ENV_MMCV_HOME] = mmcv_home url = _load_checkpoint('open-mmlab://train') assert url == 'url:https://localhost/train.pth' - with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'): + with pytest.raises(IOError, match='train.pth is not a checkpoint file'): _load_checkpoint('open-mmlab://train_empty') url = _load_checkpoint('open-mmlab://test') assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}' @@ -140,7 +140,7 @@ def test_load_external_url(): assert url == 'url:http://localhost/train.pth' # test local file - with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'): + with pytest.raises(IOError, match='train.pth is not a checkpoint file'): _load_checkpoint('train.pth') url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth')) assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}' diff --git a/tests/test_utils/test_hub.py b/tests/test_utils/test_hub.py new file mode 100644 index 0000000000..cc3d700360 --- /dev/null +++ b/tests/test_utils/test_hub.py @@ -0,0 +1,28 @@ +import pytest +from torch.utils import model_zoo + +from mmcv.utils import TORCH_VERSION, digit_version, load_url + + +def test_load_url(): + url1 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.5.pth' + url2 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.6.pth' + + # The 1.6 release of PyTorch switched torch.save to use a new zipfile-based + # file format. It will cause RuntimeError when a checkpoint was saved in + # torch >= 1.6.0 but loaded in torch < 1.7.0. + # More details at https://github.com/open-mmlab/mmpose/issues/904 + if digit_version(TORCH_VERSION) < digit_version('1.7.0'): + with pytest.raises(RuntimeError): + model_zoo.load_url(url2) + with pytest.raises(RuntimeError): + model_zoo.load_url(url1) + else: + # high version of PyTorch can load checkpoints from url, regardless + # of which version they were saved in + model_zoo.load_url(url1) + model_zoo.load_url(url2) + + # load_url works well for all versions of PyTorch + load_url(url1) + load_url(url2)