From 9d32d8b5dd5000e2f8b23cc598f0de4822c9fda4 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 22 Apr 2021 12:22:51 +0100 Subject: [PATCH 1/5] Catch URLError --- flash/vision/backbones.py | 38 +++++++++++++++++++++++++++------- tests/vision/test_backbones.py | 14 ++++++++++++- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/flash/vision/backbones.py b/flash/vision/backbones.py index c1e6a144a9..0c4974a8b8 100644 --- a/flash/vision/backbones.py +++ b/flash/vision/backbones.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import functools import os +import urllib.error import warnings from functools import partial from typing import Tuple from pytorch_lightning import LightningModule -from pytorch_lightning.utilities import _BOLTS_AVAILABLE +from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn from torch import nn as nn from torchvision.models.detection.backbone_utils import resnet_fpn_backbone @@ -51,7 +52,26 @@ OBJ_DETECTION_BACKBONES = FlashRegistry("backbones") +def catch_url_error(fn): + + @functools.wraps(fn) + def wrapper(*args, pretrained=False, **kwargs): + try: + return fn(*args, pretrained=pretrained, **kwargs) + except urllib.error.URLError: + result = fn(*args, pretrained=False, **kwargs) + rank_zero_warn( + "Failed to download pretrained weights for the selected backbone. The backbone has been created with" + " `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely" + " ignored.", UserWarning + ) + return result + + return wrapper + + @IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts") +@catch_url_error def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **_): simclr: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False) # remove the last two layers & turn it into a Sequential model @@ -60,6 +80,7 @@ def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simc @IMAGE_CLASSIFIER_BACKBONES(name="swav-imagenet", namespace="vision", package="bolts") +@catch_url_error def load_swav_imagenet( path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar", **_, @@ -83,7 +104,7 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu _type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg" IMAGE_CLASSIFIER_BACKBONES( - fn=partial(_fn_mobilenet_vgg, model_name), + fn=catch_url_error(partial(_fn_mobilenet_vgg, model_name)), name=model_name, namespace="vision", package="torchvision", @@ -99,7 +120,7 @@ def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int return backbone, num_features IMAGE_CLASSIFIER_BACKBONES( - fn=partial(_fn_resnet, model_name), + fn=catch_url_error(partial(_fn_resnet, model_name)), name=model_name, namespace="vision", package="torchvision", @@ -118,7 +139,10 @@ def _fn_resnet_fpn( return backbone, 256 OBJ_DETECTION_BACKBONES( - fn=partial(_fn_resnet_fpn, model_name), name=model_name, package="torchvision", type="resnet-fpn" + fn=catch_url_error(partial(_fn_resnet_fpn, model_name)), + name=model_name, + package="torchvision", + type="resnet-fpn" ) for model_name in DENSENET_MODELS: @@ -130,7 +154,7 @@ def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, i return backbone, num_features IMAGE_CLASSIFIER_BACKBONES( - fn=partial(_fn_densenet, model_name), + fn=catch_url_error(partial(_fn_densenet, model_name)), name=model_name, namespace="vision", package="torchvision", @@ -156,5 +180,5 @@ def _fn_timm( return backbone, num_features IMAGE_CLASSIFIER_BACKBONES( - fn=partial(_fn_timm, model_name), name=model_name, namespace="vision", package="timm" + fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm" ) diff --git a/tests/vision/test_backbones.py b/tests/vision/test_backbones.py index e2dd5882e9..046f1a5f1c 100644 --- a/tests/vision/test_backbones.py +++ b/tests/vision/test_backbones.py @@ -1,8 +1,10 @@ +import urllib.error + import pytest from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE from flash.utils.imports import _TIMM_AVAILABLE -from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES +from flash.vision.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES @pytest.mark.parametrize(["backbone", "expected_num_features"], [ @@ -17,3 +19,13 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features): backbone_model, num_features = backbone_fn(pretrained=False) assert backbone_model assert num_features == expected_num_features + + +def test_pretrained_backbones_catch_url_error(): + + def raise_error_if_pretrained(pretrained=False): + if pretrained: + raise urllib.error.URLError('Test error') + + with pytest.warns(UserWarning, match="Failed to download pretrained weights"): + catch_url_error(raise_error_if_pretrained)(pretrained=True) From 110195d20cf4d1570941e18471ac7bf54de98913 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 22 Apr 2021 12:25:22 +0100 Subject: [PATCH 2/5] Updates --- flash/vision/backbones.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flash/vision/backbones.py b/flash/vision/backbones.py index 0c4974a8b8..577398e9d6 100644 --- a/flash/vision/backbones.py +++ b/flash/vision/backbones.py @@ -71,7 +71,6 @@ def wrapper(*args, pretrained=False, **kwargs): @IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts") -@catch_url_error def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **_): simclr: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False) # remove the last two layers & turn it into a Sequential model @@ -80,7 +79,6 @@ def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simc @IMAGE_CLASSIFIER_BACKBONES(name="swav-imagenet", namespace="vision", package="bolts") -@catch_url_error def load_swav_imagenet( path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar", **_, From 08e789ca9f47508d8c28829f8e637ae593345b81 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 22 Apr 2021 12:27:29 +0100 Subject: [PATCH 3/5] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 627e907654..78e1c71d18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed classification softmax ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169)) +- Fixed a bug where loading from a local checkpoint without an internet connection would sometimes raise an error ([#237](https://github.com/PyTorchLightning/lightning-flash/pull/237)) + ### Removed From a429a16902c3a2f6d7b128d17fa48e3e23889031 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 22 Apr 2021 13:13:40 +0100 Subject: [PATCH 4/5] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 78e1c71d18..f116987086 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed classification softmax ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169)) -- Fixed a bug where loading from a local checkpoint without an internet connection would sometimes raise an error ([#237](https://github.com/PyTorchLightning/lightning-flash/pull/237)) +- Fixed a bug where loading from a local checkpoint that had `pretrained=True` without an internet connection would sometimes raise an error ([#237](https://github.com/PyTorchLightning/lightning-flash/pull/237)) ### Removed From 0115ed6551c6dbf51682982b5616ba0aded09f25 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 22 Apr 2021 13:41:52 +0100 Subject: [PATCH 5/5] Fix error --- flash/vision/backbones.py | 6 +++--- flash/vision/detection/model.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flash/vision/backbones.py b/flash/vision/backbones.py index 577398e9d6..d7a8fb9906 100644 --- a/flash/vision/backbones.py +++ b/flash/vision/backbones.py @@ -55,11 +55,11 @@ def catch_url_error(fn): @functools.wraps(fn) - def wrapper(*args, pretrained=False, **kwargs): + def wrapper(pretrained=False, **kwargs): try: - return fn(*args, pretrained=pretrained, **kwargs) + return fn(pretrained=pretrained, **kwargs) except urllib.error.URLError: - result = fn(*args, pretrained=False, **kwargs) + result = fn(pretrained=False, **kwargs) rank_zero_warn( "Failed to download pretrained weights for the selected backbone. The backbone has been created with" " `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely" diff --git a/flash/vision/detection/model.py b/flash/vision/detection/model.py index b36d726d62..a7eed0e105 100644 --- a/flash/vision/detection/model.py +++ b/flash/vision/detection/model.py @@ -137,8 +137,8 @@ def get_model( ) else: backbone_model, num_features = ObjectDetector.backbones.get(backbone)( - pretrained_backbone, - trainable_backbone_layers, + pretrained=pretrained_backbone, + trainable_layers=trainable_backbone_layers, **kwargs, ) backbone_model.out_channels = num_features