Skip to content

Commit

Permalink
add unittest for load_url
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida committed Oct 11, 2021
1 parent 01fe0ed commit 17aad8d
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
- name: Run unittests and generate coverage report
run: |
pip install -r requirements/test.txt
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py
build_without_ops:
runs-on: ubuntu-18.04
Expand Down
3 changes: 1 addition & 2 deletions mmcv/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
from torch.optim import Optimizer

import mmcv
from mmcv.utils import load_url
from ..fileio import FileClient
from ..fileio import load as load_file
from ..parallel import is_module_wrapper
from ..utils import mkdir_or_exist
from ..utils import load_url, mkdir_or_exist
from .dist_utils import get_dist_info

ENV_MMCV_HOME = 'MMCV_HOME'
Expand Down
11 changes: 6 additions & 5 deletions mmcv/utils/hub.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# The 1.6 release of PyTorch switched torch.save to use a new zipfile-based
# file format. It will cause RuntimeError when a checkpoint was saved in high
# version (PyTorch version>=1.6.0) but loaded in low version torch < 1.6.0.
# More details at# https://github.com/open-mmlab/mmpose/issues/904
from .parrots_wrapper import TORCH_VERSION, digit_version
# file format. It will cause RuntimeError when a checkpoint was saved in
# torch >= 1.6.0 but loaded in torch < 1.7.0.
# More details at https://github.com/open-mmlab/mmpose/issues/904
from .parrots_wrapper import TORCH_VERSION
from .path import mkdir_or_exist
from .version_utils import digit_version

if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
'1.6.0'):
'1.7.0'):
# Modified from https://github.com/pytorch/pytorch/blob/master/torch/hub.py
import os
import torch
Expand Down
6 changes: 3 additions & 3 deletions tests/test_load_model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def load(filepath, map_location=None):

@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch('mmcv.runner.checkpoint.load_from_http', load_from_http)
@patch('mmcv.runner.checkpoint.load_url', load_url)
@patch('torch.load', load)
@patch('torch.utils.model_zoo.load_url', load_url)
def test_load_external_url():
# test modelzoo://
url = _load_checkpoint('modelzoo://resnet50')
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_load_external_url():
os.environ[ENV_MMCV_HOME] = mmcv_home
url = _load_checkpoint('open-mmlab://train')
assert url == 'url:https://localhost/train.pth'
with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'):
with pytest.raises(IOError, match='train.pth is not a checkpoint file'):
_load_checkpoint('open-mmlab://train_empty')
url = _load_checkpoint('open-mmlab://test')
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
Expand All @@ -140,7 +140,7 @@ def test_load_external_url():
assert url == 'url:http://localhost/train.pth'

# test local file
with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'):
with pytest.raises(IOError, match='train.pth is not a checkpoint file'):
_load_checkpoint('train.pth')
url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth'))
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
28 changes: 28 additions & 0 deletions tests/test_utils/test_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from torch.utils import model_zoo

from mmcv.utils import TORCH_VERSION, digit_version, load_url


def test_load_url():
url1 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.5.pth'
url2 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.6.pth'

# The 1.6 release of PyTorch switched torch.save to use a new zipfile-based
# file format. It will cause RuntimeError when a checkpoint was saved in
# torch >= 1.6.0 but loaded in torch < 1.7.0.
# More details at https://github.com/open-mmlab/mmpose/issues/904
if digit_version(TORCH_VERSION) < digit_version('1.7.0'):
with pytest.raises(RuntimeError):
model_zoo.load_url(url2)
with pytest.raises(RuntimeError):
model_zoo.load_url(url1)
else:
# high version of PyTorch can load checkpoints from url, regardless
# of which version they were saved in
model_zoo.load_url(url1)
model_zoo.load_url(url2)

# load_url works well for all versions of PyTorch
load_url(url1)
load_url(url2)

0 comments on commit 17aad8d

Please sign in to comment.