convnext zepto, rmsnorm experiments #2281

Merged · 4 commits · Sep 30, 2024

Changes from all commits:
2 changes: 1 addition & 1 deletion timm/layers/__init__.py
@@ -34,7 +34,7 @@
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
3 changes: 2 additions & 1 deletion timm/layers/create_norm.py
@@ -10,7 +10,7 @@

import torch.nn as nn

from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d

_NORM_MAP = dict(
@@ -22,6 +22,7 @@
layernorm=LayerNorm,
layernorm2d=LayerNorm2d,
rmsnorm=RmsNorm,
rmsnorm2d=RmsNorm2d,
frozenbatchnorm2d=FrozenBatchNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
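
With `rmsnorm2d` registered in `_NORM_MAP`, the string name is expected to resolve through the layer factory. A minimal sketch of the lookup, assuming the standard `get_norm_layer` helper (the same one the convnext.py change below imports):

```python
import torch
from timm.layers import get_norm_layer

norm_cls = get_norm_layer('rmsnorm2d')   # expected to resolve to RmsNorm2d via _NORM_MAP
norm = norm_cls(64, eps=1e-6)            # channels-first norm over 64 channels
y = norm(torch.randn(1, 64, 7, 7))       # NCHW input, shape preserved
print(norm_cls.__name__, tuple(y.shape))
```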
38 changes: 38 additions & 0 deletions timm/layers/norm.py
@@ -152,3 +152,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
return x


class RmsNorm2d(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
normalized_shape = channels
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter('weight', None)

self.reset_parameters()

def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
# NOTE fast norm fallback needs our rms norm impl, so both paths go through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
return x
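
As context for the permute / `fast_rms_norm` / permute sequence above, here is a hedged, standalone reference (not the timm implementation itself) of what channel-wise RMS normalization on an NCHW tensor is intended to compute; `eps` and the per-channel weight mirror the class defaults:

```python
import torch

def rms_norm_nchw_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale each spatial position by the reciprocal root-mean-square over channels,
    # then apply the learned per-channel weight (no bias, no mean subtraction).
    rms = torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)
    return x * rms * weight.view(1, -1, 1, 1)

x = torch.randn(2, 32, 56, 56)   # NCHW feature map
w = torch.ones(32)               # per-channel scale, as initialized by reset_parameters()
y = rms_norm_nchw_reference(x, w)
assert y.shape == x.shape
```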
53 changes: 46 additions & 7 deletions timm/models/convnext.py
@@ -45,7 +45,7 @@

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
@@ -289,24 +289,27 @@ def __init__(
super().__init__()
assert output_stride in (8, 16, 32)
kernel_sizes = to_ntuple(4)(kernel_sizes)
if norm_layer is None:
norm_layer = LayerNorm2d
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
use_rms = isinstance(norm_layer, str) and norm_layer.startswith('rmsnorm')
if norm_layer is None or use_rms:
norm_layer = RmsNorm2d if use_rms else LayerNorm2d
norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm)
if norm_eps is not None:
norm_layer = partial(norm_layer, eps=norm_eps)
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
else:
assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norms expect rank-4, channels-first input'
norm_layer = get_norm_layer(norm_layer)
norm_layer_cl = norm_layer
if norm_eps is not None:
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
act_layer = get_act_layer(act_layer)

self.num_classes = num_classes
self.drop_rate = drop_rate
self.feature_info = []

assert stem_type in ('patch', 'overlap', 'overlap_tiered')
assert stem_type in ('patch', 'overlap', 'overlap_tiered', 'overlap_act')
if stem_type == 'patch':
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
@@ -316,11 +319,12 @@ def __init__(
stem_stride = patch_size
else:
mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
self.stem = nn.Sequential(
self.stem = nn.Sequential(*filter(None, [
nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
act_layer() if 'act' in stem_type else None,
nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
norm_layer(dims[0]),
)
]))
stem_stride = 4

self.stages = nn.Sequential()
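
The switch to `nn.Sequential(*filter(None, [...]))` lets the optional stem activation be dropped cleanly when the stem type doesn't request it. A small illustrative sketch of the pattern (module choices here are placeholders, not the ConvNeXt stem itself):

```python
import torch.nn as nn

def build_stem(in_chans: int, mid_chs: int, out_chs: int, use_act: bool) -> nn.Sequential:
    # None entries are filtered out, so the Sequential only holds the enabled modules.
    layers = [
        nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1),
        nn.GELU() if use_act else None,
        nn.Conv2d(mid_chs, out_chs, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(out_chs),  # stand-in for the channels-first norm_layer
    ]
    return nn.Sequential(*filter(None, layers))

print(len(build_stem(3, 16, 32, use_act=False)))  # 3 -> activation dropped
print(len(build_stem(3, 16, 32, use_act=True)))   # 4 -> activation kept
```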
@@ -592,6 +596,13 @@ def _cfgv2(url='', **kwargs):
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

'convnext_zepto_rms.ra4_e3600_r224_in1k': _cfg(
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'convnext_zepto_rms_ols.untrained': _cfg(
# hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
test_input_size=(3, 256, 256), test_crop_pct=0.95),
'convnext_atto.d2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
hf_hub_id='timm/',
@@ -600,6 +611,9 @@ def _cfgv2(url='', **kwargs):
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_atto_rms.untrained': _cfg(
#hf_hub_id='timm/',
test_input_size=(3, 256, 256), test_crop_pct=0.95),
'convnext_femto.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
hf_hub_id='timm/',
@@ -968,6 +982,23 @@ def _cfgv2(url='', **kwargs):
})


@register_model
def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
# timm zepto variant (a step below atto; rmsnorm2d norm layers in place of layernorm)
model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d')
model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
# timm zepto variant w/ overlapping 3x3 conv stem + act (stem_type='overlap_act'); rmsnorm2d norm layers
model_args = dict(
depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d', stem_type='overlap_act')
model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
# timm atto variant (NOTE: still tweaking depths, will vary between 3-4M params; current is 3.7M)
@@ -984,6 +1015,14 @@ def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
return model


@register_model
def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
# timm atto variant w/ rmsnorm2d in place of layernorm (same depths/dims as convnext_atto, ~3.7M params)
model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
model = _create_convnext('convnext_atto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
# timm femto variant
4 changes: 2 additions & 2 deletions timm/models/vision_transformer.py
@@ -2019,7 +2019,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
hf_hub_id='timm/',
input_size=(3, 160, 160), crop_pct=0.95),
'test_vit3.r160_in1k': _cfg(
#hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 160, 160), crop_pct=0.95),
}

Expand Down Expand Up @@ -3238,7 +3238,7 @@ def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT Test
"""
model_args = dict(
patch_size=16, embed_dim=96, depth=10, num_heads=3, mlp_ratio=2,
patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=2,
class_token=False, reg_tokens=1, global_pool='map', init_values=1e-5)
model = _create_vision_transformer('test_vit3', pretrained=pretrained, **dict(model_args, **kwargs))
return model