From 9567cf6d84012f1fcfe95c55f5233d32ad7150f9 Mon Sep 17 00:00:00 2001 From: Fernando Cossio <39391180+fcossio@users.noreply.github.com> Date: Fri, 14 Jun 2024 15:24:54 +0200 Subject: [PATCH 1/4] Feature: add option global_pool='max' to VisionTransformer Most of the CNNs have a max global pooling option. I would like to extend ViT to have this option. --- timm/models/vision_transformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index a3ca0990d8..5beb77a224 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -400,7 +400,7 @@ def __init__( patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, - global_pool: Literal['', 'avg', 'token', 'map'] = 'token', + global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, @@ -459,10 +459,10 @@ def __init__( block_fn: Transformer block layer. """ super().__init__() - assert global_pool in ('', 'avg', 'token', 'map') + assert global_pool in ('', 'avg', 'max', 'token', 'map') assert class_token or global_pool != 'token' assert pos_embed in ('', 'none', 'learn') - use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) act_layer = get_act_layer(act_layer) or nn.GELU @@ -761,6 +761,8 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso x = self.attn_pool(x) elif self.global_pool == 'avg': x = x[:, self.num_prefix_tokens:].mean(dim=1) + elif self.global_pool == 'max': + x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1) elif self.global_pool: x = x[:, 0] # class token x = self.fc_norm(x) From 71101ebba002c72fb9203a66b2b6682b67a22002 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 14 Jun 2024 23:16:33 -0700 Subject: [PATCH 2/4] Refactor vit pooling to add more reduction options, separately callable --- timm/models/vision_transformer.py | 50 +++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 5beb77a224..441ac0c55d 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -386,6 +386,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self._forward(x) +def global_pool_nlc( + x: torch.Tensor, + pool_type: str = 'token', + num_prefix_tokens: int = 1, + reduce_include_prefix: bool = False, +): + if not pool_type: + return x + + if pool_type == 'token': + x = x[:, 0] # class token + else: + x = x if reduce_include_prefix else x[:, num_prefix_tokens:] + if pool_type == 'avg': + x = x.mean(dim=1) + elif pool_type == 'avgmax': + x = 0.5 * (x.amax(dim=1) + x.mean(dim=1)) + elif pool_type == 'max': + x = x.amax(dim=1) + else: + assert not pool_type, f'Unknown pool type {pool_type}' + + return x + + class VisionTransformer(nn.Module): """ Vision Transformer @@ -400,7 +425,7 @@ def __init__( patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, - global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token', + global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, @@ -459,10 +484,10 @@ def __init__( block_fn: Transformer block layer. 
""" super().__init__() - assert global_pool in ('', 'avg', 'max', 'token', 'map') + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') assert class_token or global_pool != 'token' assert pos_embed in ('', 'none', 'learn') - use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm + use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) act_layer = get_act_layer(act_layer) or nn.GELU @@ -596,10 +621,10 @@ def set_grad_checkpointing(self, enable: bool = True) -> None: def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes: int, global_pool = None) -> None: + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: - assert global_pool in ('', 'avg', 'token', 'map') + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') if global_pool == 'map' and self.attn_pool is None: assert False, "Cannot currently add attention pooling in reset_classifier()." elif global_pool != 'map ' and self.attn_pool is not None: @@ -756,15 +781,16 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.norm(x) return x - def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor: if self.attn_pool is not None: x = self.attn_pool(x) - elif self.global_pool == 'avg': - x = x[:, self.num_prefix_tokens:].mean(dim=1) - elif self.global_pool == 'max': - x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1) - elif self.global_pool: - x = x[:, 0] # class token + return x + pool_type = self.global_pool if pool_type is None else pool_type + x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = self.pool(x) x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) From b1a6f4a9461f9347520f7e3e1add13d536276917 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 16 Jun 2024 10:39:27 -0700 Subject: [PATCH 3/4] Some missed reset_classifier() type annotations --- timm/models/efficientnet.py | 2 +- timm/models/ghostnet.py | 2 +- timm/models/hrnet.py | 2 +- timm/models/inception_v4.py | 2 +- timm/models/metaformer.py | 6 +++--- timm/models/nasnet.py | 2 +- timm/models/pnasnet.py | 2 +- timm/models/regnet.py | 2 +- timm/models/rexnet.py | 3 ++- timm/models/selecsls.py | 2 +- timm/models/senet.py | 2 +- timm/models/vision_transformer_relpos.py | 2 +- timm/models/vision_transformer_sam.py | 2 +- timm/models/vovnet.py | 21 +++++++++++++++------ timm/models/xception.py | 2 +- timm/models/xception_aligned.py | 2 +- 16 files changed, 33 insertions(+), 23 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 36059577e0..658b74d8aa 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -156,7 +156,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/ghostnet.py 
b/timm/models/ghostnet.py index 07a17dabdf..d73276d4e4 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -273,7 +273,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 82d887a92a..aee2f99fb7 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -739,7 +739,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 6f75817844..a435533fd4 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -280,7 +280,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.last_linear - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 7e3e758770..9a568ff075 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -26,9 +26,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - from collections import OrderedDict from functools import partial +from typing import Optional import torch import torch.nn as nn @@ -548,7 +548,7 @@ def __init__( # if using MlpHead, dropout is handled by MlpHead if num_classes > 0: if self.use_mlp_head: - # FIXME hidden size + # FIXME not actually returning mlp hidden state right now as pre-logits. 
final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate) self.head_hidden_size = self.num_features else: @@ -583,7 +583,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes=0, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if global_pool is not None: self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index af072aa942..0bcc048568 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -518,7 +518,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.last_linear - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 6d6d9dbd9b..20d17945b5 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -307,7 +307,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.last_linear - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 4edb257d49..374ecaa05a 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -514,7 +514,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head.reset(num_classes, pool_type=global_pool) def forward_intermediates( diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index eeadeb337b..9971728c24 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -12,6 +12,7 @@ from functools import partial from math import ceil +from typing import Optional import torch import torch.nn as nn @@ -229,7 +230,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, global_pool) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 7fa6c3e4aa..fdfa16c318 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -161,7 +161,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/senet.py b/timm/models/senet.py index c04250fd60..dd9b149b3d 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -337,7 +337,7 @@ def 
set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.last_linear - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 61003014c3..ed66068eb2 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -381,7 +381,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes: int, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token') diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index 2fd5209c85..aeabc77097 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -536,7 +536,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes=0, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head.reset(num_classes, global_pool) def forward_intermediates( diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 5ca409d383..86851666a2 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from typing import List +from typing import List, Optional import torch import torch.nn as nn @@ -134,9 +134,17 @@ def __init__( else: drop_path = None blocks += [OsaBlock( - in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise, - attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path) - ] + in_chs, + mid_chs, + out_chs, + layer_per_block, + residual=residual and i > 0, + depthwise=depthwise, + attn=attn if last_block else '', + norm_layer=norm_layer, + act_layer=act_layer, + drop_path=drop_path + )] in_chs = out_chs self.blocks = nn.Sequential(*blocks) @@ -252,8 +260,9 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + def reset_classifier(self, num_classes, global_pool: Optional[str] = None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/xception.py b/timm/models/xception.py index c023705194..e1f92abfa0 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -174,7 +174,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 
0eabdca213..f9071ed3f3 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -274,7 +274,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head.reset(num_classes, pool_type=global_pool) def forward_features(self, x): From 6254dfaece48589d3865d2bfda7e37b80bbc3b68 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 16 Jun 2024 11:24:45 -0700 Subject: [PATCH 4/4] Add numpy<2.0 to requirements until tests are sorted out for pytorch 2.3 vs older --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b750d03d5e..918fda0ec4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ torch>=1.7 torchvision pyyaml huggingface_hub -safetensors>=0.2 \ No newline at end of file +safetensors>=0.2 +numpy<2.0
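Usage sketch for the pooling changes in patches 1 and 2 -- illustrative only, not part of the patch series. It assumes the patched timm is installed; the model name 'vit_base_patch16_224' and the input shape are example choices, and timm.create_model() is expected to forward extra keyword arguments such as global_pool to the VisionTransformer constructor.

import torch
import timm

# 'max' and 'avgmax' are the new reductions; '', 'avg', 'token' and 'map' already existed.
model = timm.create_model('vit_base_patch16_224', pretrained=False, global_pool='max')
x = torch.randn(2, 3, 224, 224)
logits = model(x)            # patch tokens reduced with max before the classifier head
print(logits.shape)          # torch.Size([2, 1000])

# Patch 2 also makes the reduction separately callable: pool the unpooled token
# features with a pool_type other than the one the model was built with.
tokens = model.forward_features(x)                # (B, 1 + num_patches, C), class token first
pooled = model.pool(tokens, pool_type='avgmax')   # (B, C), 0.5 * (max + mean) over patch tokens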
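A similar sketch for the reset_classifier() signature that patch 3 annotates consistently across models -- again an illustration under the same assumptions, with 'efficientnet_b0' used only as an example model.

import timm

model = timm.create_model('efficientnet_b0', pretrained=False)
model.reset_classifier(num_classes=10)                    # swap in a fresh 10-class head
model.reset_classifier(num_classes=0, global_pool='avg')  # no head: forward() returns pooled features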