From c40634567fcbf22ecc92d7c5e6e5fa6651327b1c Mon Sep 17 00:00:00 2001 From: PythicCoder Date: Tue, 8 Jun 2021 14:45:51 +0300 Subject: [PATCH 1/2] Added **kwargs Fixed keywords for timm for vision transformer models --- flash/image/backbones.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/image/backbones.py b/flash/image/backbones.py index 47b2252f94..4c40d45b36 100644 --- a/flash/image/backbones.py +++ b/flash/image/backbones.py @@ -161,10 +161,10 @@ def _fn_timm( model_name: str, pretrained: bool = True, num_classes: int = 0, - global_pool: str = '', + **kwargs, ) -> Tuple[nn.Module, int]: backbone = timm.create_model( - model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool + model_name, pretrained=pretrained, num_classes=num_classes, **kwargs ) num_features = backbone.num_features return backbone, num_features From 5b6aeb043d2c8e97775db98ef12c0e039090346b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Jun 2021 11:46:28 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/image/backbones.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flash/image/backbones.py b/flash/image/backbones.py index 4c40d45b36..ec7cacf8f4 100644 --- a/flash/image/backbones.py +++ b/flash/image/backbones.py @@ -163,9 +163,7 @@ def _fn_timm( num_classes: int = 0, **kwargs, ) -> Tuple[nn.Module, int]: - backbone = timm.create_model( - model_name, pretrained=pretrained, num_classes=num_classes, **kwargs - ) + backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs) num_features = backbone.num_features return backbone, num_features