Merge pull request #2225 from huggingface/small_things

Small things

rwightman committed Jul 26, 2024
2 parents 8b14fc7 + 61df3fd commit cec70b6
Showing 22 changed files with 1,213 additions and 422 deletions.
2 changes: 1 addition & 1 deletion tests/test_models.py
@@ -60,7 +60,7 @@
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*'
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
]
NUM_NON_STD = len(NON_STD_FILTERS)

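The new 'test_vit*' entry adds the tiny test ViT configs to the non-standard filter list. As a rough illustration (assuming the filters are applied as fnmatch-style globs, which is how model-name filters are matched elsewhere in the test suite), a hypothetical check could look like this:

from fnmatch import fnmatch

# abbreviated copy of the list above; 'test_vit*' is the new entry
NON_STD_FILTERS = ['vit_*', 'eva_*', 'hiera_*', 'vitamin*', 'test_vit*']

def is_non_std(model_name: str) -> bool:
    # a model counts as non-standard if it matches any of the glob patterns
    return any(fnmatch(model_name, pattern) for pattern in NON_STD_FILTERS)

print(is_non_std('test_vit'))   # True, now skipped by the standard-interface tests
print(is_non_std('resnet50'))   # False, still covered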
51 changes: 50 additions & 1 deletion timm/data/transforms.py
@@ -5,6 +5,7 @@
from typing import List, Sequence, Tuple, Union

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
try:
from torchvision.transforms.functional import InterpolationMode
@@ -17,7 +18,7 @@
__all__ = [
"ToNumpy", "ToTensor", "str_to_interp_mode", "str_to_pil_interp", "interp_mode_to_str",
"RandomResizedCropAndInterpolation", "CenterCropOrPad", "center_crop_or_pad", "crop_or_pad",
"RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder"
"RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder", "MaybeToTensor", "MaybePILToTensor"
]


@@ -40,6 +41,54 @@ def __call__(self, pil_img):
return F.pil_to_tensor(pil_img).to(dtype=self.dtype)


class MaybeToTensor(transforms.ToTensor):
"""Convert a PIL Image or ndarray to tensor if it's not already one.
"""

def __init__(self) -> None:
super().__init__()

def __call__(self, pic) -> torch.Tensor:
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if isinstance(pic, torch.Tensor):
return pic
return F.to_tensor(pic)

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"


class MaybePILToTensor:
"""Convert a PIL Image to a tensor of the same type - this does not scale values.
"""

def __init__(self) -> None:
super().__init__()

def __call__(self, pic):
"""
Note: A deep copy of the underlying array is performed.
Args:
pic (PIL Image): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if isinstance(pic, torch.Tensor):
return pic
return F.pil_to_tensor(pic)

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"


# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
# favor of the Image.Resampling enum. The top-level resampling attributes will be
# removed in Pillow 10.
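For context on the two classes added above, a minimal sketch of their pass-through behavior: inputs that are already tensors are returned unchanged, while PIL images go through the usual torchvision conversion (scaled float32 for MaybeToTensor, original dtype for MaybePILToTensor).

import torch
from PIL import Image

from timm.data.transforms import MaybeToTensor, MaybePILToTensor

pil_img = Image.new('RGB', (224, 224))
tensor_img = torch.rand(3, 224, 224)

to_tensor = MaybeToTensor()
print(to_tensor(pil_img).dtype)             # torch.float32, values scaled to [0, 1]
print(to_tensor(tensor_img) is tensor_img)  # True, an existing tensor is returned as-is

pil_to_tensor = MaybePILToTensor()
print(pil_to_tensor(pil_img).dtype)             # torch.uint8, values not scaled
print(pil_to_tensor(tensor_img) is tensor_img)  # True, an existing tensor is returned as-is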
16 changes: 8 additions & 8 deletions timm/data/transforms_factory.py
@@ -11,8 +11,8 @@

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
from timm.data.random_erasing import RandomErasing


@@ -49,10 +49,10 @@ def transforms_noaug_train(
tfl += [ToNumpy()]
elif not normalize:
# when normalize disabled, converted to tensor without scaling, keep original dtype
tfl += [transforms.PILToTensor()]
tfl += [MaybePILToTensor()]
else:
tfl += [
transforms.ToTensor(),
MaybeToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std)
@@ -218,10 +218,10 @@ def transforms_imagenet_train(
final_tfl += [ToNumpy()]
elif not normalize:
# when normalize disabled, converted to tensor without scaling, keeps original dtype
final_tfl += [transforms.PILToTensor()]
final_tfl += [MaybePILToTensor()]
else:
final_tfl += [
transforms.ToTensor(),
MaybeToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
@@ -318,10 +318,10 @@ def transforms_imagenet_eval(
tfl += [ToNumpy()]
elif not normalize:
# when normalize disabled, converted to tensor without scaling, keeps original dtype
tfl += [transforms.PILToTensor()]
tfl += [MaybePILToTensor()]
else:
tfl += [
transforms.ToTensor(),
MaybeToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
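The hunks above all make the same swap: torchvision's ToTensor/PILToTensor are replaced by the Maybe* variants, so the generated pipelines also accept samples that were already decoded to tensors upstream. A small usage sketch, assuming default factory arguments (the exact transform list depends on the options passed to create_transform):

import torch
from PIL import Image

from timm.data import create_transform

# standard eval pipeline: resize, center crop, MaybeToTensor, normalize
eval_tf = create_transform(input_size=224, is_training=False)

x = eval_tf(Image.new('RGB', (320, 320)))
print(x.shape, x.dtype)   # torch.Size([3, 224, 224]) torch.float32

# a dataset that already yields scaled float tensors no longer trips up ToTensor;
# MaybeToTensor just forwards the tensor on to Normalize
y = eval_tf(torch.rand(3, 320, 320))
print(y.shape)            # torch.Size([3, 224, 224])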
1 change: 1 addition & 0 deletions timm/layers/__init__.py
@@ -27,6 +27,7 @@
from .global_context import GlobalContext
from .grid import ndgrid, meshgrid
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
253 changes: 253 additions & 0 deletions timm/layers/hybrid_embed.py
@@ -0,0 +1,253 @@
""" Image to Patch Hybrid Embedding Layer
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn as nn
import torch.nn.functional as F

from .format import Format, nchw_to
from .helpers import to_2tuple
from .patch_embed import resample_patch_embed


_logger = logging.getLogger(__name__)


class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
output_fmt: Format
dynamic_img_pad: torch.jit.Final[bool]

def __init__(
self,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias: bool = True,
proj: bool = True,
flatten: bool = True,
output_fmt: Optional[str] = None,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
):
super().__init__()
assert isinstance(backbone, nn.Module)
self.backbone = backbone
self.in_chans = in_chans
(
self.img_size,
self.patch_size,
self.feature_size,
self.feature_ratio,
self.feature_dim,
self.grid_size,
self.num_patches,
) = self._init_backbone(
img_size=img_size,
patch_size=patch_size,
feature_size=feature_size,
feature_ratio=feature_ratio,
)

if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
if not dynamic_img_pad:
assert self.feature_size[0] % self.patch_size[0] == 0 and self.feature_size[1] % self.patch_size[1] == 0

if proj:
self.proj = nn.Conv2d(
self.feature_dim,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
)
else:
assert self.feature_dim == embed_dim, \
f'The feature dim ({self.feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
self.proj = nn.Identity()

def _init_backbone(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
feature_dim: Optional[int] = None,
):
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
if feature_size is None:
with torch.no_grad():
# NOTE Most reliable way of determining output dims is to run forward pass
training = self.backbone.training
if training:
self.backbone.eval()
o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
self.backbone.train(training)
feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
else:
feature_size = to_2tuple(feature_size)
feature_ratio = to_2tuple(feature_ratio or 16)
if feature_dim is None:
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
grid_size = tuple([f // p for f, p in zip(feature_size, patch_size)])
num_patches = grid_size[0] * grid_size[1]
return img_size, patch_size, feature_size, feature_ratio, feature_dim, grid_size, num_patches

def set_input_size(
self,
img_size: Optional[Union[int, Tuple[int, int]]] = None,
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
feature_dim: Optional[int] = None,
):
assert img_size is not None or patch_size is not None
img_size = img_size or self.img_size
new_patch_size = None
if patch_size is not None:
new_patch_size = to_2tuple(patch_size)
if new_patch_size is not None and new_patch_size != self.patch_size:
assert isinstance(self.proj, nn.Conv2d), 'HybridEmbed must have a projection layer to change patch size.'
with torch.no_grad():
new_proj = nn.Conv2d(
self.proj.in_channels,
self.proj.out_channels,
kernel_size=new_patch_size,
stride=new_patch_size,
bias=self.proj.bias is not None,
)
new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
if self.proj.bias is not None:
new_proj.bias.copy_(self.proj.bias)
self.proj = new_proj
patch_size = new_patch_size
patch_size = patch_size or self.patch_size

if img_size != self.img_size or patch_size != self.patch_size:
(
self.img_size,
self.patch_size,
self.feature_size,
self.feature_ratio,
self.feature_dim,
self.grid_size,
self.num_patches,
) = self._init_backbone(
img_size=img_size,
patch_size=patch_size,
feature_size=feature_size,
feature_ratio=feature_ratio,
feature_dim=feature_dim,
)

def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
total_reduction = (
self.feature_ratio[0] * self.patch_size[0],
self.feature_ratio[1] * self.patch_size[1]
)
if as_scalar:
return max(total_reduction)
else:
return total_reduction

def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
""" Get feature grid size taking into account dynamic padding and backbone network feat reduction
"""
feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
if self.dynamic_img_pad:
return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
else:
return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable

def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
_, _, H, W = x.shape
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
return x


class HybridEmbedWithSize(HybridEmbed):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(
self,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias=True,
proj=True,
):
super().__init__(
backbone=backbone,
img_size=img_size,
patch_size=patch_size,
feature_size=feature_size,
feature_ratio=feature_ratio,
in_chans=in_chans,
embed_dim=embed_dim,
bias=bias,
proj=proj,
)

@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable

def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x)
return x.flatten(2).transpose(1, 2), x.shape[-2:]
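To illustrate the new module end to end, a minimal sketch of standalone use. The backbone choice and sizes are illustrative; any timm CNN that returns an unpooled NCHW feature map from forward() should work, for example one created with num_classes=0 and global_pool=''.

import torch
import timm
from timm.layers import HybridEmbed

# CNN backbone returning an unpooled NCHW feature map (2048 x 7 x 7 for a 224 input)
backbone = timm.create_model('resnet26d', pretrained=False, num_classes=0, global_pool='')

# the feature size is probed with a dummy forward pass, then projected to the embed dim
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=768)
print(embed.grid_size, embed.num_patches)   # (7, 7) 49

tokens = embed(torch.randn(2, 3, 224, 224))
print(tokens.shape)                         # torch.Size([2, 49, 768])

# re-target the embedding to a new input resolution; the feature grid is re-probed
embed.set_input_size(img_size=256)
print(embed.grid_size, embed.num_patches)   # (8, 8) 64

HybridEmbedWithSize behaves the same way but returns the token sequence together with the spatial size of the projected feature map.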
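A second sketch, with the same hypothetical backbone, for the dynamic_img_pad path: when the backbone's feature grid is not divisible by patch_size, the feature map is padded on the right/bottom before projection instead of failing the divisibility assert at construction time.

import torch
import timm
from timm.layers import HybridEmbed

backbone = timm.create_model('resnet26d', pretrained=False, num_classes=0, global_pool='')

# a 224 input gives a 7x7 feature map; patch_size=3 does not divide 7, so enable dynamic padding
embed = HybridEmbed(backbone, img_size=224, patch_size=3, embed_dim=768, dynamic_img_pad=True)

tokens = embed(torch.randn(1, 3, 224, 224))   # feature map padded 7x7 -> 9x9, then 3x3 stride-3 projection
print(tokens.shape)                           # torch.Size([1, 9, 768])
print(embed.dynamic_feat_size((224, 224)))    # (3, 3), the token grid after padding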