From 1a567ca87d856def11ca01ecb618e7717b5b39a3 Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Fri, 17 Dec 2021 15:55:02 +0800 Subject: [PATCH] [Enhance] Enhance feature extraction function. (#593) * Fix MobileNet V3 configs * Refactor to support more powerful feature extraction. * Add unit tests * Fix unit test * Imporve according to comments * Update checkpoints path * Fix unit tests * Add docstring of `simple_test` * Add docstring of `extract_feat` * Update model zoo --- .../models/mobilenet_v3_large_imagenet.py | 2 + .../models/mobilenet_v3_small_imagenet.py | 2 + configs/deit/README.md | 8 +- configs/deit/metafile.yml | 8 +- .../mobilenet-v3-small_8xb32_in1k.py | 2 +- docs/en/model_zoo.md | 22 +- mmcls/models/classifiers/base.py | 10 +- mmcls/models/classifiers/image.py | 87 +++++++- mmcls/models/heads/cls_head.py | 49 ++++- mmcls/models/heads/conformer_head.py | 53 +++-- mmcls/models/heads/deit_head.py | 66 ++++-- mmcls/models/heads/linear_head.py | 43 +++- mmcls/models/heads/multi_label_head.py | 47 ++++- mmcls/models/heads/multi_label_linear_head.py | 46 ++++- mmcls/models/heads/stacked_head.py | 64 ++++-- mmcls/models/heads/vision_transformer_head.py | 56 +++++- tests/test_models/test_classifiers.py | 100 ++++++++- tests/test_models/test_heads.py | 189 ++++++++++++++---- 18 files changed, 707 insertions(+), 147 deletions(-) diff --git a/configs/_base_/models/mobilenet_v3_large_imagenet.py b/configs/_base_/models/mobilenet_v3_large_imagenet.py index b6fdafab6e8..5318f50feeb 100644 --- a/configs/_base_/models/mobilenet_v3_large_imagenet.py +++ b/configs/_base_/models/mobilenet_v3_large_imagenet.py @@ -11,4 +11,6 @@ dropout_rate=0.2, act_cfg=dict(type='HSwish'), loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=dict( + type='Normal', layer='Linear', mean=0., std=0.01, bias=0.), topk=(1, 5))) diff --git a/configs/_base_/models/mobilenet_v3_small_imagenet.py b/configs/_base_/models/mobilenet_v3_small_imagenet.py index 5b8af1f9acc..af6cc1b8d9d 100644 --- a/configs/_base_/models/mobilenet_v3_small_imagenet.py +++ b/configs/_base_/models/mobilenet_v3_small_imagenet.py @@ -11,4 +11,6 @@ dropout_rate=0.2, act_cfg=dict(type='HSwish'), loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=dict( + type='Normal', layer='Linear', mean=0., std=0.01, bias=0.), topk=(1, 5))) diff --git a/configs/deit/README.md b/configs/deit/README.md index ba496b5734f..9418e34a1a6 100644 --- a/configs/deit/README.md +++ b/configs/deit/README.md @@ -34,11 +34,11 @@ The pre-trained models are converted from the [official repo](https://github.com | Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | |:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:| | DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | -| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) | +| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) | | DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | 
[config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | -| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) | +| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) | | DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | -| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) | +| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) | *Models with \* are converted from other repos.* @@ -51,7 +51,7 @@ The fine-tuned models are converted from the [official repo](https://github.com/ | Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | |:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:| | DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) | -| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth) | +| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) | *Models with \* are converted from other repos.* diff --git a/configs/deit/metafile.yml b/configs/deit/metafile.yml index 9f475eaa067..7d1980224ba 100644 --- a/configs/deit/metafile.yml +++ b/configs/deit/metafile.yml @@ -40,7 +40,7 @@ Models: Top 1 Accuracy: 74.51 Top 5 Accuracy: 91.90 Task: Image Classification - Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth Converted From: Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108 @@ -72,7 +72,7 @@ Models: Top 1 Accuracy: 81.17 Top 5 Accuracy: 95.40 Task: Image Classification - Weights: 
https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth Converted From: Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123 @@ -104,7 +104,7 @@ Models: Top 1 Accuracy: 83.33 Top 5 Accuracy: 96.49 Task: Image Classification - Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth Converted From: Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L138 @@ -136,7 +136,7 @@ Models: Top 1 Accuracy: 85.55 Top 5 Accuracy: 97.35 Task: Image Classification - Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth Converted From: Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L168 diff --git a/configs/mobilenet_v3/mobilenet-v3-small_8xb32_in1k.py b/configs/mobilenet_v3/mobilenet-v3-small_8xb32_in1k.py index 74a5a6ab06a..2612166fd2b 100644 --- a/configs/mobilenet_v3/mobilenet-v3-small_8xb32_in1k.py +++ b/configs/mobilenet_v3/mobilenet-v3-small_8xb32_in1k.py @@ -17,7 +17,7 @@ # - modify: RandomErasing use RE-M instead of RE-0 _base_ = [ - '../_base_/models/mobilenet-v3-small_8xb32_in1k.py', + '../_base_/models/mobilenet_v3_small_imagenet.py', '../_base_/datasets/imagenet_bs32_pil_resize.py', '../_base_/default_runtime.py' ] diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index 918f5e4b0f4..782ddacb824 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -63,16 +63,18 @@ The ResNet family models below are trained by standard data augmentations, i.e., | T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()| | Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()| | Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()| -| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | 
[model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()| -| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) | [log]()| -| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()| -| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) | [log]()| -| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()| -| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) | [log]()| -| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()| -| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()| -| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()| -| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()| +| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()| +| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) | [log]()| +| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()| +| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | 
[model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) | [log]()| +| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()| +| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) | [log]()| +| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) | [log]()| +| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) | [log]()| +| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()| +| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()| +| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()| +| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()| Models with * are converted from other repos, others are trained by ourselves. 
diff --git a/mmcls/models/classifiers/base.py b/mmcls/models/classifiers/base.py index 02391c71252..a0f6b0656cc 100644 --- a/mmcls/models/classifiers/base.py +++ b/mmcls/models/classifiers/base.py @@ -2,6 +2,7 @@ import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict +from typing import Sequence import mmcv import torch @@ -35,13 +36,14 @@ def with_head(self): return hasattr(self, 'head') and self.head is not None @abstractmethod - def extract_feat(self, imgs): + def extract_feat(self, imgs, stage=None): pass - def extract_feats(self, imgs): - assert isinstance(imgs, list) + def extract_feats(self, imgs, stage=None): + assert isinstance(imgs, Sequence) + kwargs = {} if stage is None else {'stage': stage} for img in imgs: - yield self.extract_feat(img) + yield self.extract_feat(img, **kwargs) @abstractmethod def forward_train(self, imgs, **kwargs): diff --git a/mmcls/models/classifiers/image.py b/mmcls/models/classifiers/image.py index 5c2f5cefa47..cd453b1917a 100644 --- a/mmcls/models/classifiers/image.py +++ b/mmcls/models/classifiers/image.py @@ -3,6 +3,7 @@ import warnings from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck +from ..heads import MultiLabelClsHead from ..utils.augment import Augments from .base import BaseClassifier @@ -70,8 +71,74 @@ def __init__(self, cfg['prob'] = cutmix_prob self.augments = Augments(cfg) - def extract_feat(self, img): - """Directly extract features from the backbone + neck.""" + def extract_feat(self, img, stage='neck'): + """Directly extract features from the specified stage. + + Args: + img (Tensor): The input images. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + stage (str): Which stage to output the feature. Choose from + "backbone", "neck" and "pre_logits". Defaults to "neck". + + Returns: + tuple | Tensor: The output of specified stage. + The output depends on detailed implementation. In general, the + output of backbone and neck is a tuple and the output of + pre_logits is a tensor. + + Examples: + 1. Backbone output + + >>> import torch + >>> from mmcv import Config + >>> from mmcls.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 256, 14, 14]) + torch.Size([1, 512, 7, 7]) + + 2. Neck output + + >>> import torch + >>> from mmcv import Config + >>> from mmcls.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64]) + torch.Size([1, 128]) + torch.Size([1, 256]) + torch.Size([1, 512]) + + 3. 
Pre-logits output (without the final linear classifier head) + + >>> import torch + >>> from mmcv import Config + >>> from mmcls.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model + >>> model = build_classifier(cfg) + >>> + >>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits') + >>> print(out.shape) # The hidden dims in head is 3072 + torch.Size([1, 3072]) + """ # noqa: E501 + assert stage in ['backbone', 'neck', 'pre_logits'], \ + (f'Invalid output stage "{stage}", please choose from "backbone", ' + '"neck" and "pre_logits"') + x = self.backbone(img) if self.return_tuple: if not isinstance(x, tuple): @@ -83,8 +150,16 @@ def extract_feat(self, img): else: if isinstance(x, tuple): x = x[-1] + if stage == 'backbone': + return x + if self.with_neck: x = self.neck(x) + if stage == 'neck': + return x + + if self.with_head and hasattr(self.head, 'pre_logits'): + x = self.head.pre_logits(x) return x def forward_train(self, img, gt_label, **kwargs): @@ -122,12 +197,16 @@ def forward_train(self, img, gt_label, **kwargs): return losses - def simple_test(self, img, img_metas): + def simple_test(self, img, img_metas=None, **kwargs): """Test without augmentation.""" x = self.extract_feat(img) try: - res = self.head.simple_test(x) + if isinstance(self.head, MultiLabelClsHead): + assert 'softmax' not in kwargs, ( + 'Please use `sigmoid` instead of `softmax` ' + 'in multi-label tasks.') + res = self.head.simple_test(x, **kwargs) except TypeError as e: if 'not tuple' in str(e) and self.return_tuple: return TypeError( diff --git a/mmcls/models/heads/cls_head.py b/mmcls/models/heads/cls_head.py index 87305134fb8..2e430c5a5f1 100644 --- a/mmcls/models/heads/cls_head.py +++ b/mmcls/models/heads/cls_head.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings + import torch import torch.nn.functional as F @@ -62,14 +64,49 @@ def forward_train(self, cls_score, gt_label, **kwargs): losses = self.loss(cls_score, gt_label, **kwargs) return losses - def simple_test(self, cls_score): - """Test without augmentation.""" + def pre_logits(self, x): + if isinstance(x, tuple): + x = x[-1] + + warnings.warn( + 'The input of ClsHead should be already logits. ' + 'Please modify the backbone if you want to get pre-logits feature.' + ) + return x + + def simple_test(self, cls_score, softmax=True, post_process=True): + """Inference without augmentation. + + Args: + cls_score (tuple[Tensor]): The input classification score logits. + Multi-stage inputs are acceptable but only the last stage will + be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + softmax (bool): Whether to softmax the classification score. + post_process (bool): Whether to do post processing the + inference results. It will convert the output to a list. + + Returns: + Tensor | list: The inference results. + + - If no post processing, the output is a tensor with shape + ``(num_samples, num_classes)``. + - If post processing, the output is a multi-dimentional list of + float and the dimensions are ``(num_samples, num_classes)``. 
+ """ if isinstance(cls_score, tuple): cls_score = cls_score[-1] - if isinstance(cls_score, list): - cls_score = sum(cls_score) / float(len(cls_score)) - pred = F.softmax(cls_score, dim=1) if cls_score is not None else None - return self.post_process(pred) + + if softmax: + pred = ( + F.softmax(cls_score, dim=1) if cls_score is not None else None) + else: + pred = cls_score + + if post_process: + return self.post_process(pred) + else: + return pred def post_process(self, pred): on_trace = is_tracing() diff --git a/mmcls/models/heads/conformer_head.py b/mmcls/models/heads/conformer_head.py index c913b657157..c6557962ae3 100644 --- a/mmcls/models/heads/conformer_head.py +++ b/mmcls/models/heads/conformer_head.py @@ -16,7 +16,7 @@ class ConformerHead(ClsHead): category. in_channels (int): Number of channels in the input feature map. init_cfg (dict | optional): The extra init config of layers. - Defaults to use dict(type='Normal', layer='Linear', std=0.01). + Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``. """ def __init__( @@ -55,25 +55,54 @@ def init_weights(self): else: self.apply(self._init_weights) - def simple_test(self, x): - """Test without augmentation.""" + def pre_logits(self, x): if isinstance(x, tuple): x = x[-1] - assert isinstance(x, - list) # There are two outputs in the Conformer model + return x + + def simple_test(self, x, softmax=True, post_process=True): + """Inference without augmentation. + + Args: + x (tuple[tuple[tensor, tensor]]): The input features. + Multi-stage inputs are acceptable but only the last stage will + be used to classify. Every item should be a tuple which + includes convluation features and transformer features. The + shape of them should be ``(num_samples, in_channels[0])`` and + ``(num_samples, in_channels[1])``. + softmax (bool): Whether to softmax the classification score. + post_process (bool): Whether to do post processing the + inference results. It will convert the output to a list. + + Returns: + Tensor | list: The inference results. + + - If no post processing, the output is a tensor with shape + ``(num_samples, num_classes)``. + - If post processing, the output is a multi-dimentional list of + float and the dimensions are ``(num_samples, num_classes)``. 
+ """ + x = self.pre_logits(x) + # There are two outputs in the Conformer model + assert len(x) == 2 conv_cls_score = self.conv_cls_head(x[0]) tran_cls_score = self.trans_cls_head(x[1]) - cls_score = conv_cls_score + tran_cls_score - - pred = F.softmax(cls_score, dim=1) if cls_score is not None else None - - return self.post_process(pred) + if softmax: + cls_score = conv_cls_score + tran_cls_score + pred = ( + F.softmax(cls_score, dim=1) if cls_score is not None else None) + if post_process: + pred = self.post_process(pred) + else: + pred = [conv_cls_score, tran_cls_score] + if post_process: + pred = list(map(self.post_process, pred)) + return pred def forward_train(self, x, gt_label): - if isinstance(x, tuple): - x = x[-1] + x = self.pre_logits(x) assert isinstance(x, list) and len(x) == 2, \ 'There should be two outputs in the Conformer model' diff --git a/mmcls/models/heads/deit_head.py b/mmcls/models/heads/deit_head.py index 4a79e455587..5aaf3babc42 100644 --- a/mmcls/models/heads/deit_head.py +++ b/mmcls/models/heads/deit_head.py @@ -12,25 +12,67 @@ class DeiTClsHead(VisionTransformerClsHead): def __init__(self, *args, **kwargs): super(DeiTClsHead, self).__init__(*args, **kwargs) - self.head_dist = nn.Linear(self.in_channels, self.num_classes) + if self.hidden_dim is None: + head_dist = nn.Linear(self.in_channels, self.num_classes) + else: + head_dist = nn.Linear(self.hidden_dim, self.num_classes) + self.layers.add_module('head_dist', head_dist) - def simple_test(self, x): - """Test without augmentation.""" - x = x[-1] - assert isinstance(x, list) and len(x) == 3 + def pre_logits(self, x): + if isinstance(x, tuple): + x = x[-1] _, cls_token, dist_token = x - cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2 - pred = F.softmax(cls_score, dim=1) if cls_score is not None else None - return self.post_process(pred) + if self.hidden_dim is None: + return cls_token, dist_token + else: + cls_token = self.layers.act(self.layers.pre_logits(cls_token)) + dist_token = self.layers.act(self.layers.pre_logits(dist_token)) + return cls_token, dist_token + + def simple_test(self, x, softmax=True, post_process=True): + """Inference without augmentation. + + Args: + x (tuple[tuple[tensor, tensor, tensor]]): The input features. + Multi-stage inputs are acceptable but only the last stage will + be used to classify. Every item should be a tuple which + includes patch token, cls token and dist token. The cls token + and dist token will be used to classify and the shape of them + should be ``(num_samples, in_channels)``. + softmax (bool): Whether to softmax the classification score. + post_process (bool): Whether to do post processing the + inference results. It will convert the output to a list. + + Returns: + Tensor | list: The inference results. + + - If no post processing, the output is a tensor with shape + ``(num_samples, num_classes)``. + - If post processing, the output is a multi-dimentional list of + float and the dimensions are ``(num_samples, num_classes)``. 
+ """ + cls_token, dist_token = self.pre_logits(x) + cls_score = (self.layers.head(cls_token) + + self.layers.head_dist(dist_token)) / 2 + + if softmax: + pred = F.softmax( + cls_score, dim=1) if cls_score is not None else None + else: + pred = cls_score + + if post_process: + return self.post_process(pred) + else: + return pred def forward_train(self, x, gt_label): logger = get_root_logger() logger.warning("MMClassification doesn't support to train the " 'distilled version DeiT.') - x = x[-1] - assert isinstance(x, list) and len(x) == 3 - _, cls_token, dist_token = x - cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2 + cls_token, dist_token = self.pre_logits(x) + cls_score = (self.layers.head(cls_token) + + self.layers.head_dist(dist_token)) / 2 losses = self.loss(cls_score, gt_label) return losses diff --git a/mmcls/models/heads/linear_head.py b/mmcls/models/heads/linear_head.py index d9c96add50c..113b41b685f 100644 --- a/mmcls/models/heads/linear_head.py +++ b/mmcls/models/heads/linear_head.py @@ -35,20 +35,47 @@ def __init__(self, self.fc = nn.Linear(self.in_channels, self.num_classes) - def simple_test(self, x): - """Test without augmentation.""" + def pre_logits(self, x): if isinstance(x, tuple): x = x[-1] + return x + + def simple_test(self, x, softmax=True, post_process=True): + """Inference without augmentation. + + Args: + x (tuple[Tensor]): The input features. + Multi-stage inputs are acceptable but only the last stage will + be used to classify. The shape of every item should be + ``(num_samples, in_channels)``. + softmax (bool): Whether to softmax the classification score. + post_process (bool): Whether to do post processing the + inference results. It will convert the output to a list. + + Returns: + Tensor | list: The inference results. + + - If no post processing, the output is a tensor with shape + ``(num_samples, num_classes)``. + - If post processing, the output is a multi-dimentional list of + float and the dimensions are ``(num_samples, num_classes)``. + """ + x = self.pre_logits(x) cls_score = self.fc(x) - if isinstance(cls_score, list): - cls_score = sum(cls_score) / float(len(cls_score)) - pred = F.softmax(cls_score, dim=1) if cls_score is not None else None - return self.post_process(pred) + if softmax: + pred = ( + F.softmax(cls_score, dim=1) if cls_score is not None else None) + else: + pred = cls_score + + if post_process: + return self.post_process(pred) + else: + return pred def forward_train(self, x, gt_label, **kwargs): - if isinstance(x, tuple): - x = x[-1] + x = self.pre_logits(x) cls_score = self.fc(x) losses = self.loss(cls_score, gt_label, **kwargs) return losses diff --git a/mmcls/models/heads/multi_label_head.py b/mmcls/models/heads/multi_label_head.py index 5f59ed96d13..e11a7733192 100644 --- a/mmcls/models/heads/multi_label_head.py +++ b/mmcls/models/heads/multi_label_head.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import torch -import torch.nn.functional as F from ..builder import HEADS, build_loss from ..utils import is_tracing @@ -47,14 +46,50 @@ def forward_train(self, cls_score, gt_label, **kwargs): losses = self.loss(cls_score, gt_label, **kwargs) return losses - def simple_test(self, x): + def pre_logits(self, x): if isinstance(x, tuple): x = x[-1] - if isinstance(x, list): - x = sum(x) / float(len(x)) - pred = F.sigmoid(x) if x is not None else None - return self.post_process(pred) + from mmcls.utils import get_root_logger + logger = get_root_logger() + logger.warning( + 'The input of MultiLabelClsHead should be already logits. ' + 'Please modify the backbone if you want to get pre-logits feature.' + ) + return x + + def simple_test(self, x, sigmoid=True, post_process=True): + """Inference without augmentation. + + Args: + cls_score (tuple[Tensor]): The input classification score logits. + Multi-stage inputs are acceptable but only the last stage will + be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + sigmoid (bool): Whether to sigmoid the classification score. + post_process (bool): Whether to do post processing the + inference results. It will convert the output to a list. + + Returns: + Tensor | list: The inference results. + + - If no post processing, the output is a tensor with shape + ``(num_samples, num_classes)``. + - If post processing, the output is a multi-dimentional list of + float and the dimensions are ``(num_samples, num_classes)``. + """ + if isinstance(x, tuple): + x = x[-1] + + if sigmoid: + pred = torch.sigmoid(x) if x is not None else None + else: + pred = x + + if post_process: + return self.post_process(pred) + else: + return pred def post_process(self, pred): on_trace = is_tracing() diff --git a/mmcls/models/heads/multi_label_linear_head.py b/mmcls/models/heads/multi_label_linear_head.py index d3534a76899..0e9d0684a1b 100644 --- a/mmcls/models/heads/multi_label_linear_head.py +++ b/mmcls/models/heads/multi_label_linear_head.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch import torch.nn as nn -import torch.nn.functional as F from ..builder import HEADS from .multi_label_head import MultiLabelClsHead @@ -39,21 +39,47 @@ def __init__(self, self.fc = nn.Linear(self.in_channels, self.num_classes) - def forward_train(self, x, gt_label, **kwargs): + def pre_logits(self, x): if isinstance(x, tuple): x = x[-1] + return x + + def forward_train(self, x, gt_label, **kwargs): + x = self.pre_logits(x) gt_label = gt_label.type_as(x) cls_score = self.fc(x) losses = self.loss(cls_score, gt_label, **kwargs) return losses - def simple_test(self, x): - """Test without augmentation.""" - if isinstance(x, tuple): - x = x[-1] + def simple_test(self, x, sigmoid=True, post_process=True): + """Inference without augmentation. + + Args: + x (tuple[Tensor]): The input features. + Multi-stage inputs are acceptable but only the last stage will + be used to classify. The shape of every item should be + ``(num_samples, in_channels)``. + sigmoid (bool): Whether to sigmoid the classification score. + post_process (bool): Whether to do post processing the + inference results. It will convert the output to a list. + + Returns: + Tensor | list: The inference results. + + - If no post processing, the output is a tensor with shape + ``(num_samples, num_classes)``. + - If post processing, the output is a multi-dimentional list of + float and the dimensions are ``(num_samples, num_classes)``. 
+ """ + x = self.pre_logits(x) cls_score = self.fc(x) - if isinstance(cls_score, list): - cls_score = sum(cls_score) / float(len(cls_score)) - pred = F.sigmoid(cls_score) if cls_score is not None else None - return self.post_process(pred) + if sigmoid: + pred = torch.sigmoid(cls_score) if cls_score is not None else None + else: + pred = cls_score + + if post_process: + return self.post_process(pred) + else: + return pred diff --git a/mmcls/models/heads/stacked_head.py b/mmcls/models/heads/stacked_head.py index 760274d1c28..bbb0dc24ccb 100644 --- a/mmcls/models/heads/stacked_head.py +++ b/mmcls/models/heads/stacked_head.py @@ -49,8 +49,7 @@ class StackedLinearClsHead(ClsHead): """Classifier head with several hidden fc layer and a output fc layer. Args: - num_classes (int): Number of categories excluding the background - category. + num_classes (int): Number of categories. in_channels (int): Number of channels in the input feature map. mid_channels (Sequence): Number of channels in the hidden fc layers. dropout_rate (float): Dropout rate after each hidden fc layer, @@ -89,9 +88,7 @@ def __init__(self, self._init_layers() def _init_layers(self): - self.layers = ModuleList( - init_cfg=dict( - type='Normal', layer='Linear', mean=0., std=0.01, bias=0.)) + self.layers = ModuleList() in_channels = self.in_channels for hidden_channels in self.mid_channels: self.layers.append( @@ -114,24 +111,53 @@ def _init_layers(self): def init_weights(self): self.layers.init_weights() - def simple_test(self, x): - """Test without augmentation.""" + def pre_logits(self, x): if isinstance(x, tuple): x = x[-1] - cls_score = x - for layer in self.layers: - cls_score = layer(cls_score) - if isinstance(cls_score, list): - cls_score = sum(cls_score) / float(len(cls_score)) - pred = F.softmax(cls_score, dim=1) if cls_score is not None else None + for layer in self.layers[:-1]: + x = layer(x) + return x - return self.post_process(pred) + @property + def fc(self): + return self.layers[-1] + + def simple_test(self, x, softmax=True, post_process=True): + """Inference without augmentation. + + Args: + x (tuple[Tensor]): The input features. + Multi-stage inputs are acceptable but only the last stage will + be used to classify. The shape of every item should be + ``(num_samples, in_channels)``. + softmax (bool): Whether to softmax the classification score. + post_process (bool): Whether to do post processing the + inference results. It will convert the output to a list. + + Returns: + Tensor | list: The inference results. + + - If no post processing, the output is a tensor with shape + ``(num_samples, num_classes)``. + - If post processing, the output is a multi-dimentional list of + float and the dimensions are ``(num_samples, num_classes)``. 
+ """ + x = self.pre_logits(x) + cls_score = self.fc(x) + + if softmax: + pred = ( + F.softmax(cls_score, dim=1) if cls_score is not None else None) + else: + pred = cls_score + + if post_process: + return self.post_process(pred) + else: + return pred def forward_train(self, x, gt_label, **kwargs): - if isinstance(x, tuple): - x = x[-1] - cls_score = x - for layer in self.layers: - cls_score = layer(cls_score) + x = self.pre_logits(x) + cls_score = self.fc(x) losses = self.loss(cls_score, gt_label, **kwargs) return losses diff --git a/mmcls/models/heads/vision_transformer_head.py b/mmcls/models/heads/vision_transformer_head.py index 564818ee0ec..a12aa7a0e0e 100644 --- a/mmcls/models/heads/vision_transformer_head.py +++ b/mmcls/models/heads/vision_transformer_head.py @@ -68,20 +68,54 @@ def init_weights(self): std=math.sqrt(1 / self.layers.pre_logits.in_features)) nn.init.zeros_(self.layers.pre_logits.bias) - def simple_test(self, x): - """Test without augmentation.""" - x = x[-1] + def pre_logits(self, x): + if isinstance(x, tuple): + x = x[-1] _, cls_token = x - cls_score = self.layers(cls_token) - if isinstance(cls_score, list): - cls_score = sum(cls_score) / float(len(cls_score)) - pred = F.softmax(cls_score, dim=1) if cls_score is not None else None + if self.hidden_dim is None: + return cls_token + else: + x = self.layers.pre_logits(cls_token) + return self.layers.act(x) + + def simple_test(self, x, softmax=True, post_process=True): + """Inference without augmentation. + + Args: + x (tuple[tuple[tensor, tensor]]): The input features. + Multi-stage inputs are acceptable but only the last stage will + be used to classify. Every item should be a tuple which + includes patch token and cls token. The cls token will be used + to classify and the shape of it should be + ``(num_samples, in_channels)``. + softmax (bool): Whether to softmax the classification score. + post_process (bool): Whether to do post processing the + inference results. It will convert the output to a list. + + Returns: + Tensor | list: The inference results. - return self.post_process(pred) + - If no post processing, the output is a tensor with shape + ``(num_samples, num_classes)``. + - If post processing, the output is a multi-dimentional list of + float and the dimensions are ``(num_samples, num_classes)``. 
+ """ + x = self.pre_logits(x) + cls_score = self.layers.head(x) + + if softmax: + pred = ( + F.softmax(cls_score, dim=1) if cls_score is not None else None) + else: + pred = cls_score + + if post_process: + return self.post_process(pred) + else: + return pred def forward_train(self, x, gt_label, **kwargs): - x = x[-1] - _, cls_token = x - cls_score = self.layers(cls_token) + x = self.pre_logits(x) + cls_score = self.layers.head(x) losses = self.loss(cls_score, gt_label, **kwargs) return losses diff --git a/tests/test_models/test_classifiers.py b/tests/test_models/test_classifiers.py index 41cffbb6264..2157aad9a52 100644 --- a/tests/test_models/test_classifiers.py +++ b/tests/test_models/test_classifiers.py @@ -73,6 +73,19 @@ def test_image_classifier(): pred = model(single_img, return_loss=False, img_metas=None) assert isinstance(pred, list) and len(pred) == 1 + pred = model.simple_test(imgs, softmax=False) + assert isinstance(pred, list) and len(pred) == 16 + assert len(pred[0] == 10) + + pred = model.simple_test(imgs, softmax=False, post_process=False) + assert isinstance(pred, torch.Tensor) + assert pred.shape == (16, 10) + + soft_pred = model.simple_test(imgs, softmax=True, post_process=False) + assert isinstance(soft_pred, torch.Tensor) + assert soft_pred.shape == (16, 10) + torch.testing.assert_allclose(soft_pred, torch.softmax(pred, dim=1)) + # test pretrained # TODO remove deprecated pretrained with pytest.warns(UserWarning): @@ -83,7 +96,7 @@ def test_image_classifier(): type='Pretrained', checkpoint='checkpoint') # test show_result - img = np.random.random_integers(0, 255, (224, 224, 3)).astype(np.uint8) + img = np.random.randint(0, 256, (224, 224, 3)).astype(np.uint8) result = dict(pred_class='cat', pred_label=0, pred_score=0.9) with tempfile.TemporaryDirectory() as tmpdir: @@ -304,3 +317,88 @@ def forward(self, x): with pytest.warns(DeprecationWarning): model.extract_feat(imgs) + + +def test_classifier_extract_feat(): + model_cfg = ConfigDict( + type='ImageClassifier', + backbone=dict( + type='ResNet', + depth=18, + num_stages=4, + out_indices=(0, 1, 2, 3), + style='pytorch'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=512, + loss=dict(type='CrossEntropyLoss'), + topk=(1, 5), + )) + + model = CLASSIFIERS.build(model_cfg) + + # test backbone output + outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone') + assert outs[0].shape == (1, 64, 56, 56) + assert outs[1].shape == (1, 128, 28, 28) + assert outs[2].shape == (1, 256, 14, 14) + assert outs[3].shape == (1, 512, 7, 7) + + # test neck output + outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck') + assert outs[0].shape == (1, 64) + assert outs[1].shape == (1, 128) + assert outs[2].shape == (1, 256) + assert outs[3].shape == (1, 512) + + # test pre_logits output + out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits') + assert out.shape == (1, 512) + + # test transformer style feature extraction + model_cfg = dict( + type='ImageClassifier', + backbone=dict( + type='VisionTransformer', arch='b', out_indices=[-3, -2, -1]), + neck=None, + head=dict( + type='VisionTransformerClsHead', + num_classes=1000, + in_channels=768, + hidden_dim=1024, + loss=dict(type='CrossEntropyLoss'), + )) + model = CLASSIFIERS.build(model_cfg) + + # test backbone output + outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone') + for out in outs: + patch_token, cls_token = out + assert patch_token.shape == (1, 768, 14, 
14) + assert cls_token.shape == (1, 768) + + # test neck output (the same with backbone) + outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck') + for out in outs: + patch_token, cls_token = out + assert patch_token.shape == (1, 768, 14, 14) + assert cls_token.shape == (1, 768) + + # test pre_logits output + out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits') + assert out.shape == (1, 1024) + + # test extract_feats + multi_imgs = [torch.rand(1, 3, 224, 224) for _ in range(3)] + outs = model.extract_feats(multi_imgs) + for outs_per_img in outs: + for out in outs_per_img: + patch_token, cls_token = out + assert patch_token.shape == (1, 768, 14, 14) + assert cls_token.shape == (1, 768) + + outs = model.extract_feats(multi_imgs, stage='pre_logits') + for out_per_img in outs: + assert out_per_img.shape == (1, 1024) diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index c56d509ea06..392afe74cf5 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -4,36 +4,53 @@ import pytest import torch -from mmcls.models.heads import (ClsHead, DeiTClsHead, LinearClsHead, - MultiLabelClsHead, MultiLabelLinearClsHead, - StackedLinearClsHead, VisionTransformerClsHead) +from mmcls.models.heads import (ClsHead, ConformerHead, DeiTClsHead, + LinearClsHead, MultiLabelClsHead, + MultiLabelLinearClsHead, StackedLinearClsHead, + VisionTransformerClsHead) -@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )]) +@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )]) def test_cls_head(feat): + fake_gt_label = torch.randint(0, 10, (4, )) - # test ClsHead with cal_acc=False - head = ClsHead() - fake_gt_label = torch.randint(0, 2, (4, )) - + # test forward_train with cal_acc=True + head = ClsHead(cal_acc=True) losses = head.forward_train(feat, fake_gt_label) assert losses['loss'].item() > 0 + assert 'accuracy' in losses - # test ClsHead with cal_acc=True - head = ClsHead(cal_acc=True) - feat = torch.rand(4, 3) - fake_gt_label = torch.randint(0, 2, (4, )) - + # test forward_train with cal_acc=False + head = ClsHead() losses = head.forward_train(feat, fake_gt_label) assert losses['loss'].item() > 0 - # test ClsHead with weight + # test forward_train with weight weight = torch.tensor([0.5, 0.5, 0.5, 0.5]) - losses_ = head.forward_train(feat, fake_gt_label) losses = head.forward_train(feat, fake_gt_label, weight=weight) assert losses['loss'].item() == losses_['loss'].item() * 0.5 + # test simple_test with post_process + pred = head.simple_test(feat) + assert isinstance(pred, list) and len(pred) == 4 + with patch('torch.onnx.is_in_onnx_export', return_value=True): + pred = head.simple_test(feat) + assert pred.shape == (4, 10) + + # test simple_test without post_process + pred = head.simple_test(feat, post_process=False) + assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) + logits = head.simple_test(feat, softmax=False, post_process=False) + torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1)) + + # test pre_logits + features = head.pre_logits(feat) + if isinstance(feat, tuple): + torch.testing.assert_allclose(features, feat[0]) + else: + torch.testing.assert_allclose(features, feat) + @pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )]) def test_linear_head(feat): @@ -50,35 +67,85 @@ def test_linear_head(feat): head.init_weights() assert abs(head.fc.weight).sum() > 0 - # test simple_test - head = LinearClsHead(10, 3) + # test simple_test with 
post_process pred = head.simple_test(feat) assert isinstance(pred, list) and len(pred) == 4 - with patch('torch.onnx.is_in_onnx_export', return_value=True): - head = LinearClsHead(10, 3) pred = head.simple_test(feat) assert pred.shape == (4, 10) + # test simple_test without post_process + pred = head.simple_test(feat, post_process=False) + assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) + logits = head.simple_test(feat, softmax=False, post_process=False) + torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1)) -@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )]) + # test pre_logits + features = head.pre_logits(feat) + if isinstance(feat, tuple): + torch.testing.assert_allclose(features, feat[0]) + else: + torch.testing.assert_allclose(features, feat) + + +@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )]) def test_multilabel_head(feat): head = MultiLabelClsHead() - fake_gt_label = torch.randint(0, 2, (4, 3)) + fake_gt_label = torch.randint(0, 2, (4, 10)) losses = head.forward_train(feat, fake_gt_label) assert losses['loss'].item() > 0 + # test simple_test with post_process + pred = head.simple_test(feat) + assert isinstance(pred, list) and len(pred) == 4 + with patch('torch.onnx.is_in_onnx_export', return_value=True): + pred = head.simple_test(feat) + assert pred.shape == (4, 10) + + # test simple_test without post_process + pred = head.simple_test(feat, post_process=False) + assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) + logits = head.simple_test(feat, sigmoid=False, post_process=False) + torch.testing.assert_allclose(pred, torch.sigmoid(logits)) + + # test pre_logits + features = head.pre_logits(feat) + if isinstance(feat, tuple): + torch.testing.assert_allclose(features, feat[0]) + else: + torch.testing.assert_allclose(features, feat) + @pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )]) def test_multilabel_linear_head(feat): - head = MultiLabelLinearClsHead(3, 5) - fake_gt_label = torch.randint(0, 2, (4, 3)) + head = MultiLabelLinearClsHead(10, 5) + fake_gt_label = torch.randint(0, 2, (4, 10)) head.init_weights() losses = head.forward_train(feat, fake_gt_label) assert losses['loss'].item() > 0 + # test simple_test with post_process + pred = head.simple_test(feat) + assert isinstance(pred, list) and len(pred) == 4 + with patch('torch.onnx.is_in_onnx_export', return_value=True): + pred = head.simple_test(feat) + assert pred.shape == (4, 10) + + # test simple_test without post_process + pred = head.simple_test(feat, post_process=False) + assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) + logits = head.simple_test(feat, sigmoid=False, post_process=False) + torch.testing.assert_allclose(pred, torch.sigmoid(logits)) + + # test pre_logits + features = head.pre_logits(feat) + if isinstance(feat, tuple): + torch.testing.assert_allclose(features, feat[0]) + else: + torch.testing.assert_allclose(features, feat) + @pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )]) def test_stacked_linear_cls_head(feat): @@ -93,20 +160,28 @@ def test_stacked_linear_cls_head(feat): # test forward with default setting head = StackedLinearClsHead( - num_classes=3, in_channels=5, mid_channels=[10]) + num_classes=10, in_channels=5, mid_channels=[20]) head.init_weights() losses = head.forward_train(feat, fake_gt_label) assert losses['loss'].item() > 0 - # test simple test + # test simple_test with post_process pred = head.simple_test(feat) - assert len(pred) == 4 - - 
# test simple test in tracing + assert isinstance(pred, list) and len(pred) == 4 with patch('torch.onnx.is_in_onnx_export', return_value=True): pred = head.simple_test(feat) - assert pred.shape == torch.Size((4, 3)) + assert pred.shape == (4, 10) + + # test simple_test without post_process + pred = head.simple_test(feat, post_process=False) + assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) + logits = head.simple_test(feat, softmax=False, post_process=False) + torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1)) + + # test pre_logits + features = head.pre_logits(feat) + assert features.shape == (4, 20) # test forward with full function head = StackedLinearClsHead( @@ -144,21 +219,56 @@ def test_vit_head(): head.init_weights() assert abs(head.layers.pre_logits.weight).sum() > 0 - # test simple_test head = VisionTransformerClsHead(10, 100, hidden_dim=20) + # test simple_test with post_process pred = head.simple_test(fake_features) assert isinstance(pred, list) and len(pred) == 4 - with patch('torch.onnx.is_in_onnx_export', return_value=True): - head = VisionTransformerClsHead(10, 100, hidden_dim=20) pred = head.simple_test(fake_features) assert pred.shape == (4, 10) + # test simple_test without post_process + pred = head.simple_test(fake_features, post_process=False) + assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) + logits = head.simple_test(fake_features, softmax=False, post_process=False) + torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1)) + + # test pre_logits + features = head.pre_logits(fake_features) + assert features.shape == (4, 20) + # test assertion with pytest.raises(ValueError): VisionTransformerClsHead(-1, 100) +def test_conformer_head(): + fake_features = ([torch.rand(4, 64), torch.rand(4, 96)], ) + fake_gt_label = torch.randint(0, 10, (4, )) + + # test conformer head forward + head = ConformerHead(num_classes=10, in_channels=[64, 96]) + losses = head.forward_train(fake_features, fake_gt_label) + assert losses['loss'].item() > 0 + + # test simple_test with post_process + pred = head.simple_test(fake_features) + assert isinstance(pred, list) and len(pred) == 4 + with patch('torch.onnx.is_in_onnx_export', return_value=True): + pred = head.simple_test(fake_features) + assert pred.shape == (4, 10) + + # test simple_test without post_process + pred = head.simple_test(fake_features, post_process=False) + assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) + logits = head.simple_test(fake_features, softmax=False, post_process=False) + torch.testing.assert_allclose(pred, torch.softmax(sum(logits), dim=1)) + + # test pre_logits + features = head.pre_logits(fake_features) + assert features is fake_features[0] + + def test_deit_head(): fake_features = ([ torch.rand(4, 7, 7, 16), @@ -185,16 +295,25 @@ def test_deit_head(): head.init_weights() assert abs(head.layers.pre_logits.weight).sum() > 0 - # test simple_test head = DeiTClsHead(10, 100, hidden_dim=20) + # test simple_test with post_process pred = head.simple_test(fake_features) assert isinstance(pred, list) and len(pred) == 4 - with patch('torch.onnx.is_in_onnx_export', return_value=True): - head = DeiTClsHead(10, 100, hidden_dim=20) pred = head.simple_test(fake_features) assert pred.shape == (4, 10) + # test simple_test without post_process + pred = head.simple_test(fake_features, post_process=False) + assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) + logits = head.simple_test(fake_features, softmax=False, post_process=False) + 
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1)) + + # test pre_logits + cls_token, dist_token = head.pre_logits(fake_features) + assert cls_token.shape == (4, 20) + assert dist_token.shape == (4, 20) + # test assertion with pytest.raises(ValueError): DeiTClsHead(-1, 100)
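
---

Below is a minimal usage sketch of the feature-extraction API this patch introduces, assembled from the `extract_feat` docstring examples and the new unit tests above. It assumes MMClassification with this change applied is importable and that the script is run from the repository root so the relative config path resolves; the shapes noted in the comments follow the assertions in `test_classifier_extract_feat`.

```python
# Minimal sketch of the `stage` argument of `extract_feat` and the
# `softmax`/`post_process` switches of `simple_test` added in this patch.
# Assumes MMClassification (with this patch) is installed and the working
# directory is the repository root so the config path resolves.
import torch
from mmcv import Config

from mmcls.models import build_classifier

cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
cfg.backbone.out_indices = (0, 1, 2, 3)  # output multi-scale feature maps
model = build_classifier(cfg)
model.eval()

img = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    # Raw backbone feature maps, one tensor per selected stage.
    backbone_feats = model.extract_feat(img, stage='backbone')
    # Pooled features after the neck (GlobalAveragePooling here).
    neck_feats = model.extract_feat(img, stage='neck')
    # Features right before the final classification layer of the head.
    pre_logits = model.extract_feat(img, stage='pre_logits')

    # Heads now accept `softmax`/`post_process` switches at test time, so raw
    # logits can be returned as a single tensor instead of per-sample lists.
    logits = model.simple_test(img, softmax=False, post_process=False)

    # The plural helper forwards `stage` to every image in a sequence.
    feats = list(model.extract_feats([img, img], stage='pre_logits'))

print([f.shape for f in backbone_feats])  # (1, 64, 56, 56) ... (1, 512, 7, 7)
print([f.shape for f in neck_feats])      # (1, 64) ... (1, 512)
print(pre_logits.shape)                   # torch.Size([1, 512])
print(logits.shape)                       # torch.Size([1, 1000])
```

The `softmax=False, post_process=False` path is the same one the new head-level unit tests use to compare raw logits against the post-processed scores.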