From c28a22e0c9a811dbdba6e617987ab5780660460c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 22 Apr 2021 16:51:58 +0100 Subject: [PATCH] Catch URLError (#237) * Catch URLError * Updates * Update CHANGELOG.md * Update CHANGELOG.md * Fix error --- CHANGELOG.md | 2 ++ flash/vision/backbones.py | 36 ++++++++++++++++++++++++++------- flash/vision/detection/model.py | 4 ++-- tests/vision/test_backbones.py | 14 ++++++++++++- 4 files changed, 46 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 627e907654..f116987086 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed classification softmax ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169)) +- Fixed a bug where loading from a local checkpoint that had `pretrained=True` without an internet connection would sometimes raise an error ([#237](https://github.com/PyTorchLightning/lightning-flash/pull/237)) + ### Removed diff --git a/flash/vision/backbones.py b/flash/vision/backbones.py index c1e6a144a9..d7a8fb9906 100644 --- a/flash/vision/backbones.py +++ b/flash/vision/backbones.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import functools import os +import urllib.error import warnings from functools import partial from typing import Tuple from pytorch_lightning import LightningModule -from pytorch_lightning.utilities import _BOLTS_AVAILABLE +from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn from torch import nn as nn from torchvision.models.detection.backbone_utils import resnet_fpn_backbone @@ -51,6 +52,24 @@ OBJ_DETECTION_BACKBONES = FlashRegistry("backbones") +def catch_url_error(fn): + + @functools.wraps(fn) + def wrapper(pretrained=False, **kwargs): + try: + return fn(pretrained=pretrained, **kwargs) + except urllib.error.URLError: + result = fn(pretrained=False, **kwargs) + rank_zero_warn( + "Failed to download pretrained weights for the selected backbone. The backbone has been created with" + " `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely" + " ignored.", UserWarning + ) + return result + + return wrapper + + @IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts") def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **_): simclr: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False) @@ -83,7 +102,7 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu _type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg" IMAGE_CLASSIFIER_BACKBONES( - fn=partial(_fn_mobilenet_vgg, model_name), + fn=catch_url_error(partial(_fn_mobilenet_vgg, model_name)), name=model_name, namespace="vision", package="torchvision", @@ -99,7 +118,7 @@ def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int return backbone, num_features IMAGE_CLASSIFIER_BACKBONES( - fn=partial(_fn_resnet, model_name), + fn=catch_url_error(partial(_fn_resnet, model_name)), name=model_name, namespace="vision", package="torchvision", @@ -118,7 +137,10 @@ def _fn_resnet_fpn( return backbone, 256 OBJ_DETECTION_BACKBONES( - fn=partial(_fn_resnet_fpn, model_name), name=model_name, package="torchvision", type="resnet-fpn" + fn=catch_url_error(partial(_fn_resnet_fpn, model_name)), + name=model_name, + package="torchvision", + type="resnet-fpn" ) for model_name in DENSENET_MODELS: @@ -130,7 +152,7 @@ def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, i return backbone, num_features IMAGE_CLASSIFIER_BACKBONES( - fn=partial(_fn_densenet, model_name), + fn=catch_url_error(partial(_fn_densenet, model_name)), name=model_name, namespace="vision", package="torchvision", @@ -156,5 +178,5 @@ def _fn_timm( return backbone, num_features IMAGE_CLASSIFIER_BACKBONES( - fn=partial(_fn_timm, model_name), name=model_name, namespace="vision", package="timm" + fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm" ) diff --git a/flash/vision/detection/model.py b/flash/vision/detection/model.py index b36d726d62..a7eed0e105 100644 --- a/flash/vision/detection/model.py +++ b/flash/vision/detection/model.py @@ -137,8 +137,8 @@ def get_model( ) else: backbone_model, num_features = ObjectDetector.backbones.get(backbone)( - pretrained_backbone, - trainable_backbone_layers, + pretrained=pretrained_backbone, + trainable_layers=trainable_backbone_layers, **kwargs, ) backbone_model.out_channels = num_features diff --git a/tests/vision/test_backbones.py b/tests/vision/test_backbones.py index e2dd5882e9..046f1a5f1c 100644 --- a/tests/vision/test_backbones.py +++ b/tests/vision/test_backbones.py @@ -1,8 +1,10 @@ +import urllib.error + import pytest from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE from flash.utils.imports import _TIMM_AVAILABLE -from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES +from flash.vision.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES @pytest.mark.parametrize(["backbone", "expected_num_features"], [ @@ -17,3 +19,13 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features): backbone_model, num_features = backbone_fn(pretrained=False) assert backbone_model assert num_features == expected_num_features + + +def test_pretrained_backbones_catch_url_error(): + + def raise_error_if_pretrained(pretrained=False): + if pretrained: + raise urllib.error.URLError('Test error') + + with pytest.warns(UserWarning, match="Failed to download pretrained weights"): + catch_url_error(raise_error_if_pretrained)(pretrained=True)