From 849dd81fddf337cd358b5b752efad2b7631acdf6 Mon Sep 17 00:00:00 2001 From: PythicCoder Date: Tue, 8 Jun 2021 15:19:18 +0300 Subject: [PATCH] Added **kwargs (#377) * Added **kwargs Fixed keywords for timm for vision transformer models * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- flash/image/backbones.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/flash/image/backbones.py b/flash/image/backbones.py index 47b2252f94..ec7cacf8f4 100644 --- a/flash/image/backbones.py +++ b/flash/image/backbones.py @@ -161,11 +161,9 @@ def _fn_timm( model_name: str, pretrained: bool = True, num_classes: int = 0, - global_pool: str = '', + **kwargs, ) -> Tuple[nn.Module, int]: - backbone = timm.create_model( - model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool - ) + backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs) num_features = backbone.num_features return backbone, num_features