Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Catch URLError #237

Merged
merged 7 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed classification softmax ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169))

- Fixed a bug where loading from a local checkpoint without an internet connection would sometimes raise an error ([#237](https://github.com/PyTorchLightning/lightning-flash/pull/237))
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

### Removed


Expand Down
36 changes: 29 additions & 7 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import urllib.error
import warnings
from functools import partial
from typing import Tuple

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import _BOLTS_AVAILABLE
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from torch import nn as nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

Expand Down Expand Up @@ -51,6 +52,24 @@
OBJ_DETECTION_BACKBONES = FlashRegistry("backbones")


def catch_url_error(fn):

@functools.wraps(fn)
def wrapper(*args, pretrained=False, **kwargs):
try:
return fn(*args, pretrained=pretrained, **kwargs)
except urllib.error.URLError:
result = fn(*args, pretrained=False, **kwargs)
rank_zero_warn(
"Failed to download pretrained weights for the selected backbone. The backbone has been created with"
" `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely"
" ignored.", UserWarning
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the order so if the fn call fails too, the warning is still shown

Suggested change
result = fn(*args, pretrained=False, **kwargs)
rank_zero_warn(
"Failed to download pretrained weights for the selected backbone. The backbone has been created with"
" `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely"
" ignored.", UserWarning
)
rank_zero_warn(
"Failed to download pretrained weights for the selected backbone. The backbone has been created with"
" `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely"
" ignored.", UserWarning
)
result = fn(*args, pretrained=False, **kwargs)

Copy link
Collaborator Author

@ethanwharris ethanwharris Apr 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we want this. If it fails with pretrained=False then the warning is no longer true when it says "the backbone has been created"?

return result

return wrapper


@IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts")
def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **_):
simclr: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False)
Expand Down Expand Up @@ -83,7 +102,7 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu
_type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg"

IMAGE_CLASSIFIER_BACKBONES(
fn=partial(_fn_mobilenet_vgg, model_name),
fn=catch_url_error(partial(_fn_mobilenet_vgg, model_name)),
name=model_name,
namespace="vision",
package="torchvision",
Expand All @@ -99,7 +118,7 @@ def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int
return backbone, num_features

IMAGE_CLASSIFIER_BACKBONES(
fn=partial(_fn_resnet, model_name),
fn=catch_url_error(partial(_fn_resnet, model_name)),
name=model_name,
namespace="vision",
package="torchvision",
Expand All @@ -118,7 +137,10 @@ def _fn_resnet_fpn(
return backbone, 256

OBJ_DETECTION_BACKBONES(
fn=partial(_fn_resnet_fpn, model_name), name=model_name, package="torchvision", type="resnet-fpn"
fn=catch_url_error(partial(_fn_resnet_fpn, model_name)),
name=model_name,
package="torchvision",
type="resnet-fpn"
)

for model_name in DENSENET_MODELS:
Expand All @@ -130,7 +152,7 @@ def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, i
return backbone, num_features

IMAGE_CLASSIFIER_BACKBONES(
fn=partial(_fn_densenet, model_name),
fn=catch_url_error(partial(_fn_densenet, model_name)),
name=model_name,
namespace="vision",
package="torchvision",
Expand All @@ -156,5 +178,5 @@ def _fn_timm(
return backbone, num_features

IMAGE_CLASSIFIER_BACKBONES(
fn=partial(_fn_timm, model_name), name=model_name, namespace="vision", package="timm"
fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm"
)
14 changes: 13 additions & 1 deletion tests/vision/test_backbones.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import urllib.error

import pytest
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE

from flash.utils.imports import _TIMM_AVAILABLE
from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES
from flash.vision.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES


@pytest.mark.parametrize(["backbone", "expected_num_features"], [
Expand All @@ -17,3 +19,13 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features):
backbone_model, num_features = backbone_fn(pretrained=False)
assert backbone_model
assert num_features == expected_num_features


def test_pretrained_backbones_catch_url_error():

def raise_error_if_pretrained(pretrained=False):
if pretrained:
raise urllib.error.URLError('Test error')

with pytest.warns(UserWarning, match="Failed to download pretrained weights"):
catch_url_error(raise_error_if_pretrained)(pretrained=True)