Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Added **kwargs (#377)
Browse files Browse the repository at this point in the history
* Added **kwargs

Fixed keywords for timm for vision transformer models

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
aribornstein and pre-commit-ci[bot] committed Jun 8, 2021
1 parent 4b6d2ec commit 849dd81
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions flash/image/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,9 @@ def _fn_timm(
model_name: str,
pretrained: bool = True,
num_classes: int = 0,
global_pool: str = '',
**kwargs,
) -> Tuple[nn.Module, int]:
backbone = timm.create_model(
model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool
)
backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs)
num_features = backbone.num_features
return backbone, num_features

Expand Down

0 comments on commit 849dd81

Please sign in to comment.