ViT pooling refactor #2209

Merged: 4 commits, Jun 17, 2024
3 changes: 2 additions & 1 deletion requirements.txt
@@ -2,4 +2,5 @@ torch>=1.7
torchvision
pyyaml
huggingface_hub
safetensors>=0.2
safetensors>=0.2
numpy<2.0
2 changes: 1 addition & 1 deletion timm/models/efficientnet.py
@@ -156,7 +156,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.classifier

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
2 changes: 1 addition & 1 deletion timm/models/ghostnet.py
@@ -273,7 +273,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.classifier

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
2 changes: 1 addition & 1 deletion timm/models/hrnet.py
@@ -739,7 +739,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.classifier

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
2 changes: 1 addition & 1 deletion timm/models/inception_v4.py
@@ -280,7 +280,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.last_linear

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
6 changes: 3 additions & 3 deletions timm/models/metaformer.py
@@ -26,9 +26,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import OrderedDict
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
@@ -548,7 +548,7 @@ def __init__(
# if using MlpHead, dropout is handled by MlpHead
if num_classes > 0:
if self.use_mlp_head:
# FIXME hidden size
# FIXME not actually returning mlp hidden state right now as pre-logits.
final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
self.head_hidden_size = self.num_features
else:
@@ -583,7 +583,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes=0, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
2 changes: 1 addition & 1 deletion timm/models/nasnet.py
@@ -518,7 +518,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.last_linear

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
2 changes: 1 addition & 1 deletion timm/models/pnasnet.py
@@ -307,7 +307,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.last_linear

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
2 changes: 1 addition & 1 deletion timm/models/regnet.py
@@ -514,7 +514,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)

def forward_intermediates(
3 changes: 2 additions & 1 deletion timm/models/rexnet.py
@@ -12,6 +12,7 @@

from functools import partial
from math import ceil
from typing import Optional

import torch
import torch.nn as nn
@@ -229,7 +230,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

2 changes: 1 addition & 1 deletion timm/models/selecsls.py
@@ -161,7 +161,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

2 changes: 1 addition & 1 deletion timm/models/senet.py
@@ -337,7 +337,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.last_linear

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
48 changes: 38 additions & 10 deletions timm/models/vision_transformer.py
@@ -386,6 +386,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self._forward(x)


def global_pool_nlc(
x: torch.Tensor,
pool_type: str = 'token',
num_prefix_tokens: int = 1,
reduce_include_prefix: bool = False,
):
if not pool_type:
return x

if pool_type == 'token':
x = x[:, 0] # class token
else:
x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
if pool_type == 'avg':
x = x.mean(dim=1)
elif pool_type == 'avgmax':
x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
elif pool_type == 'max':
x = x.amax(dim=1)
else:
assert not pool_type, f'Unknown pool type {pool_type}'

return x
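
(Not part of the diff.) A minimal usage sketch of the new pooling helper on a dummy (batch, tokens, channels) sequence with one class token; the import path assumes global_pool_nlc is importable from timm.models.vision_transformer as defined above, and the shapes are purely illustrative:

import torch
from timm.models.vision_transformer import global_pool_nlc

x = torch.randn(2, 197, 768)  # e.g. ViT-B/16 at 224x224: 1 class token + 196 patch tokens

tok = global_pool_nlc(x, pool_type='token')      # (2, 768), class token only
avg = global_pool_nlc(x, pool_type='avg')        # (2, 768), mean over the 196 patch tokens
avgmax = global_pool_nlc(x, pool_type='avgmax')  # (2, 768), 0.5 * (max + mean) over patch tokens
mx = global_pool_nlc(x, pool_type='max')         # (2, 768), max over patch tokens
none = global_pool_nlc(x, pool_type='')          # (2, 197, 768), returned unchanged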


class VisionTransformer(nn.Module):
""" Vision Transformer

@@ -400,7 +425,7 @@ def __init__(
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
@@ -459,10 +484,10 @@ def __init__(
block_fn: Transformer block layer.
"""
super().__init__()
assert global_pool in ('', 'avg', 'token', 'map')
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
assert class_token or global_pool != 'token'
assert pos_embed in ('', 'none', 'learn')
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
act_layer = get_act_layer(act_layer) or nn.GELU

@@ -596,10 +621,10 @@ def set_grad_checkpointing(self, enable: bool = True) -> None:
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool = None) -> None:
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token', 'map')
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
if global_pool == 'map' and self.attn_pool is None:
assert False, "Cannot currently add attention pooling in reset_classifier()."
elif global_pool != 'map' and self.attn_pool is not None:
@@ -756,13 +781,16 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.norm(x)
return x

def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
if self.attn_pool is not None:
x = self.attn_pool(x)
elif self.global_pool == 'avg':
x = x[:, self.num_prefix_tokens:].mean(dim=1)
elif self.global_pool:
x = x[:, 0] # class token
return x
pool_type = self.global_pool if pool_type is None else pool_type
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
return x

def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
x = self.pool(x)
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
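
(Not part of the diff.) A usage sketch of the refactored head, assuming this branch of timm is installed so the new 'avgmax'/'max' options are accepted, and that create_model forwards global_pool to the constructor as usual; the model name and shapes are illustrative:

import timm
import torch

# 'avgmax' and 'max' are now accepted alongside '', 'avg', 'token', 'map'
model = timm.create_model('vit_base_patch16_224', pretrained=False, global_pool='avgmax')

x = torch.randn(1, 3, 224, 224)
logits = model(x)  # pooling happens inside forward_head() via the new pool() method
feats = model.forward_head(model.forward_features(x), pre_logits=True)  # pooled, pre-classifier features

# pooling and head can be switched after creation; 'map' still cannot be added here per the assert above
model.reset_classifier(num_classes=10, global_pool='max')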
2 changes: 1 addition & 1 deletion timm/models/vision_transformer_relpos.py
@@ -381,7 +381,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
2 changes: 1 addition & 1 deletion timm/models/vision_transformer_sam.py
@@ -536,7 +536,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes=0, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, global_pool)

def forward_intermediates(
21 changes: 15 additions & 6 deletions timm/models/vovnet.py
@@ -11,7 +11,7 @@
Hacked together by / Copyright 2020 Ross Wightman
"""

from typing import List
from typing import List, Optional

import torch
import torch.nn as nn
@@ -134,9 +134,17 @@ def __init__(
else:
drop_path = None
blocks += [OsaBlock(
in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
]
in_chs,
mid_chs,
out_chs,
layer_per_block,
residual=residual and i > 0,
depthwise=depthwise,
attn=attn if last_block else '',
norm_layer=norm_layer,
act_layer=act_layer,
drop_path=drop_path
)]
in_chs = out_chs
self.blocks = nn.Sequential(*blocks)

@@ -252,8 +260,9 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_features(self, x):
x = self.stem(x)
2 changes: 1 addition & 1 deletion timm/models/xception.py
@@ -174,7 +174,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

2 changes: 1 addition & 1 deletion timm/models/xception_aligned.py
@@ -274,7 +274,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)

def forward_features(self, x):