diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 36059577e0..658b74d8aa 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -156,7 +156,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 07a17dabdf..d73276d4e4 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -273,7 +273,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 82d887a92a..aee2f99fb7 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -739,7 +739,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 6f75817844..a435533fd4 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -280,7 +280,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.last_linear - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 7e3e758770..9a568ff075 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -26,9 +26,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - from collections import OrderedDict from functools import partial +from typing import Optional import torch import torch.nn as nn @@ -548,7 +548,7 @@ def __init__( # if using MlpHead, dropout is handled by MlpHead if num_classes > 0: if self.use_mlp_head: - # FIXME hidden size + # FIXME not actually returning mlp hidden state right now as pre-logits. final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate) self.head_hidden_size = self.num_features else: @@ -583,7 +583,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes=0, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if global_pool is not None: self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index af072aa942..0bcc048568 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -518,7 +518,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.last_linear - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 6d6d9dbd9b..20d17945b5 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -307,7 +307,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.last_linear - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 4edb257d49..374ecaa05a 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -514,7 +514,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head.reset(num_classes, pool_type=global_pool) def forward_intermediates( diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index eeadeb337b..9971728c24 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -12,6 +12,7 @@ from functools import partial from math import ceil +from typing import Optional import torch import torch.nn as nn @@ -229,7 +230,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, global_pool) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 7fa6c3e4aa..fdfa16c318 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -161,7 +161,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/senet.py b/timm/models/senet.py index c04250fd60..dd9b149b3d 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -337,7 +337,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.last_linear - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 61003014c3..ed66068eb2 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -381,7 +381,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes: int, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token') diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index 2fd5209c85..aeabc77097 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -536,7 +536,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes=0, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head.reset(num_classes, global_pool) def forward_intermediates( diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 5ca409d383..86851666a2 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from typing import List +from typing import List, Optional import torch import torch.nn as nn @@ -134,9 +134,17 @@ def __init__( else: drop_path = None blocks += [OsaBlock( - in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise, - attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path) - ] + in_chs, + mid_chs, + out_chs, + layer_per_block, + residual=residual and i > 0, + depthwise=depthwise, + attn=attn if last_block else '', + norm_layer=norm_layer, + act_layer=act_layer, + drop_path=drop_path + )] in_chs = out_chs self.blocks = nn.Sequential(*blocks) @@ -252,8 +260,9 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + def reset_classifier(self, num_classes, global_pool: Optional[str] = None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/xception.py b/timm/models/xception.py index c023705194..e1f92abfa0 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -174,7 +174,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 0eabdca213..f9071ed3f3 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -274,7 +274,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head.reset(num_classes, pool_type=global_pool) def forward_features(self, x):