Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Fix TnT compatibility and verbose warning. #436

Merged
merged 2 commits into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmcls/models/backbones/tnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,4 +364,4 @@ def forward(self, x):
pixel_embed, patch_embed = layer(pixel_embed, patch_embed)

patch_embed = self.norm(patch_embed)
return patch_embed[:, 0]
return (patch_embed[:, 0], )
3 changes: 2 additions & 1 deletion mmcls/models/classifiers/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from ..utils.augment import Augments
from .base import BaseClassifier

warnings.simplefilter('once')


@CLASSIFIERS.register_module()
class ImageClassifier(BaseClassifier):
Expand Down Expand Up @@ -74,7 +76,6 @@ def extract_feat(self, img):
if self.return_tuple:
if not isinstance(x, tuple):
x = (x, )
warnings.simplefilter('once')
warnings.warn(
'We will force all backbones to return a tuple in the '
'future. Please check your backbone and wrap the output '
Expand Down
6 changes: 4 additions & 2 deletions tests/test_models/test_backbones/test_tnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def test_tnt_backbone():

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == torch.Size((1, 640))
assert len(feat) == 1
assert feat[0].shape == torch.Size((1, 640))

# Test tnt with embed_dims=768
arch = {
Expand All @@ -45,4 +46,5 @@ def test_tnt_backbone():

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == torch.Size((1, 768))
assert len(feat) == 1
assert feat[0].shape == torch.Size((1, 768))