diff --git a/flash/image/backbones.py b/flash/image/backbones.py index 47b2252f94..ec7cacf8f4 100644 --- a/flash/image/backbones.py +++ b/flash/image/backbones.py @@ -161,11 +161,9 @@ def _fn_timm( model_name: str, pretrained: bool = True, num_classes: int = 0, - global_pool: str = '', + **kwargs, ) -> Tuple[nn.Module, int]: - backbone = timm.create_model( - model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool - ) + backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs) num_features = backbone.num_features return backbone, num_features