[Enhance] Enhance feature extraction function. #593

Merged: 12 commits, Dec 17, 2021
2 changes: 2 additions & 0 deletions configs/_base_/models/mobilenet_v3_large_imagenet.py
@@ -11,4 +11,6 @@
        dropout_rate=0.2,
        act_cfg=dict(type='HSwish'),
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        init_cfg=dict(
+            type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
        topk=(1, 5)))
2 changes: 2 additions & 0 deletions configs/_base_/models/mobilenet_v3_small_imagenet.py
@@ -11,4 +11,6 @@
        dropout_rate=0.2,
        act_cfg=dict(type='HSwish'),
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        init_cfg=dict(
+            type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
        topk=(1, 5)))
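Both MobileNet V3 configs gain the same `init_cfg`, which initializes the head's `Linear` layer from a normal distribution (std 0.01) with zero bias. For intuition only, a rough hand-rolled equivalent (a hypothetical helper; in practice MMCV's init machinery consumes `init_cfg` directly):

```python
import torch.nn as nn

def init_linear_like_cfg(head: nn.Module) -> None:
    # Mirrors init_cfg=dict(type='Normal', layer='Linear',
    #                       mean=0., std=0.01, bias=0.)
    for m in head.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
```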
8 changes: 4 additions & 4 deletions configs/deit/README.md
@@ -34,11 +34,11 @@ The pre-trained models are converted from the [official repo](https://github.com
 | Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
 |:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
 | DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
-| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) |
+| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) |
 | DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
-| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) |
+| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) |
 | DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
-| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) |
+| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) |

*Models with \* are converted from other repos.*

@@ -51,7 +51,7 @@ The fine-tuned models are converted from the [official repo](https://github.com/
 | Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
 |:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
 | DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
-| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth) |
+| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) |

*Models with \* are converted from other repos.*

8 changes: 4 additions & 4 deletions configs/deit/metafile.yml
@@ -40,7 +40,7 @@ Models:
          Top 1 Accuracy: 74.51
          Top 5 Accuracy: 91.90
        Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108
@@ -72,7 +72,7 @@ Models:
          Top 1 Accuracy: 81.17
          Top 5 Accuracy: 95.40
        Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123
@@ -104,7 +104,7 @@ Models:
          Top 1 Accuracy: 83.33
          Top 5 Accuracy: 96.49
        Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L138
@@ -136,7 +136,7 @@ Models:
          Top 1 Accuracy: 85.55
          Top 5 Accuracy: 97.35
        Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L168
2 changes: 1 addition & 1 deletion configs/mobilenet_v3/mobilenet-v3-small_8xb32_in1k.py
@@ -17,7 +17,7 @@
 # - modify: RandomErasing use RE-M instead of RE-0

 _base_ = [
-    '../_base_/models/mobilenet-v3-small_8xb32_in1k.py',
+    '../_base_/models/mobilenet_v3_small_imagenet.py',
     '../_base_/datasets/imagenet_bs32_pil_resize.py',
     '../_base_/default_runtime.py'
 ]
22 changes: 12 additions & 10 deletions docs/en/model_zoo.md
@@ -63,16 +63,18 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
 | T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
 | Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()|
 | Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()|
-| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()|
-| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) | [log]()|
-| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()|
-| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) | [log]()|
-| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()|
-| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) | [log]()|
-| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
-| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
-| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
-| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|
+| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()|
+| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) | [log]()|
+| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()|
+| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) | [log]()|
+| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()|
+| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) | [log]()|
+| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) | [log]()|
+| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) | [log]()|
+| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
+| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
+| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
+| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|

Models with * are converted from other repos, others are trained by ourselves.

10 changes: 6 additions & 4 deletions mmcls/models/classifiers/base.py
@@ -2,6 +2,7 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
+from typing import Sequence

 import mmcv
 import torch
@@ -35,13 +36,14 @@ def with_head(self):
        return hasattr(self, 'head') and self.head is not None

    @abstractmethod
-    def extract_feat(self, imgs):
+    def extract_feat(self, imgs, stage=None):
        pass

-    def extract_feats(self, imgs):
-        assert isinstance(imgs, list)
+    def extract_feats(self, imgs, stage=None):
+        assert isinstance(imgs, Sequence)
+        kwargs = {} if stage is None else {'stage': stage}
        for img in imgs:
-            yield self.extract_feat(img)
+            yield self.extract_feat(img, **kwargs)

    @abstractmethod
    def forward_train(self, imgs, **kwargs):
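The generator now accepts any `Sequence` of image batches (not just a `list`) and forwards an optional `stage`. A minimal usage sketch, assuming a classifier built as in the `ImageClassifier.extract_feat` docstring examples below:

```python
import torch
from mmcv import Config
from mmcls.models import build_classifier

# Build a classifier; the ResNet-18 config mirrors the docstring examples.
cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
model = build_classifier(cfg)

# extract_feats lazily yields one feature set per input batch; a tuple of
# batches works now that the check is isinstance(imgs, Sequence).
# stage=None keeps each classifier's default behavior.
imgs = (torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224))
for feats in model.extract_feats(imgs, stage='backbone'):
    print([f.shape for f in feats])
```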
87 changes: 83 additions & 4 deletions mmcls/models/classifiers/image.py
@@ -3,6 +3,7 @@
 import warnings

 from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck
+from ..heads import MultiLabelClsHead
 from ..utils.augment import Augments
 from .base import BaseClassifier

@@ -70,8 +71,74 @@ def __init__(self,
                cfg['prob'] = cutmix_prob
            self.augments = Augments(cfg)

-    def extract_feat(self, img):
-        """Directly extract features from the backbone + neck."""
+    def extract_feat(self, img, stage='neck'):
+        """Directly extract features from the specified stage.
+
+        Args:
+            img (Tensor): The input images. The shape of it should be
+                ``(num_samples, num_channels, *img_shape)``.
+            stage (str): Which stage to output the feature. Choose from
+                "backbone", "neck" and "pre_logits". Defaults to "neck".
+
+        Returns:
+            tuple | Tensor: The output of the specified stage. The output
+                depends on the specific implementation. In general, the
+                output of the backbone and neck is a tuple and the output
+                of pre_logits is a tensor.
+
+        Examples:
+            1. Backbone output
+
+            >>> import torch
+            >>> from mmcv import Config
+            >>> from mmcls.models import build_classifier
+            >>>
+            >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
+            >>> cfg.backbone.out_indices = (0, 1, 2, 3)  # Output multi-scale feature maps
+            >>> model = build_classifier(cfg)
+            >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
+            >>> for out in outs:
+            ...     print(out.shape)
+            torch.Size([1, 64, 56, 56])
+            torch.Size([1, 128, 28, 28])
+            torch.Size([1, 256, 14, 14])
+            torch.Size([1, 512, 7, 7])
+
+            2. Neck output
+
+            >>> import torch
+            >>> from mmcv import Config
+            >>> from mmcls.models import build_classifier
+            >>>
+            >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
+            >>> cfg.backbone.out_indices = (0, 1, 2, 3)  # Output multi-scale feature maps
+            >>> model = build_classifier(cfg)
+            >>>
+            >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
+            >>> for out in outs:
+            ...     print(out.shape)
+            torch.Size([1, 64])
+            torch.Size([1, 128])
+            torch.Size([1, 256])
+            torch.Size([1, 512])
+
+            3. Pre-logits output (without the final linear classifier head)
+
+            >>> import torch
+            >>> from mmcv import Config
+            >>> from mmcls.models import build_classifier
+            >>>
+            >>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model
+            >>> model = build_classifier(cfg)
+            >>>
+            >>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
+            >>> print(out.shape)  # The hidden dim of the head is 3072
+            torch.Size([1, 3072])
+        """  # noqa: E501
+        assert stage in ['backbone', 'neck', 'pre_logits'], \
+            (f'Invalid output stage "{stage}", please choose from "backbone", '
+             '"neck" and "pre_logits"')
+
        x = self.backbone(img)
        if self.return_tuple:
            if not isinstance(x, tuple):
@@ -83,8 +150,16 @@ def extract_feat(self, img):
        else:
            if isinstance(x, tuple):
                x = x[-1]
+        if stage == 'backbone':
+            return x
+
        if self.with_neck:
            x = self.neck(x)
+        if stage == 'neck':
+            return x
+
+        if self.with_head and hasattr(self.head, 'pre_logits'):
+            x = self.head.pre_logits(x)
        return x
    def forward_train(self, img, gt_label, **kwargs):
@@ -122,12 +197,16 @@ def forward_train(self, img, gt_label, **kwargs):

        return losses

-    def simple_test(self, img, img_metas):
+    def simple_test(self, img, img_metas=None, **kwargs):
        """Test without augmentation."""
        x = self.extract_feat(img)

        try:
-            res = self.head.simple_test(x)
+            if isinstance(self.head, MultiLabelClsHead):
+                assert 'softmax' not in kwargs, (
+                    'Please use `sigmoid` instead of `softmax` '
+                    'in multi-label tasks.')
+            res = self.head.simple_test(x, **kwargs)
        except TypeError as e:
            if 'not tuple' in str(e) and self.return_tuple:
                return TypeError(
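Finally, a short sketch of the extended `simple_test` interface. This assumes `model` is a hypothetical already-built `ImageClassifier` whose head is a `MultiLabelClsHead`, and that the head's `simple_test` accepts a `sigmoid` flag, as the new assertion implies:

```python
import torch

img = torch.rand(1, 3, 224, 224)

# Extra keyword arguments are now forwarded to the head. Passing
# `softmax` to a multi-label head trips the new assertion; `sigmoid`
# is the intended post-processing switch for multi-label predictions.
res = model.simple_test(img, sigmoid=True)
```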