Skip to content

Commit

Permalink
Add xavier_uniform init of MNVC hybrid attention modules. Small impro…
Browse files Browse the repository at this point in the history
…vement in training stability.
  • Loading branch information
rwightman committed Jul 27, 2024
1 parent 9558a7f commit ab8cb07
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
10 changes: 10 additions & 0 deletions timm/layers/attention2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ def __init__(

self.einsum = False

def init_weights(self) -> None:
    """Re-initialize the attention projection weights with Xavier-uniform.

    Applies ``nn.init.xavier_uniform_`` in place to the ``proj`` weight of
    the query, key, value and output branches. When ``self.kv_stride > 1``
    the ``down_conv`` weights of the key and value branches are initialized
    the same way.
    """
    # using xavier appeared to improve stability for mobilenetv4 hybrid w/ this layer
    nn.init.xavier_uniform_(self.query.proj.weight)
    nn.init.xavier_uniform_(self.key.proj.weight)
    nn.init.xavier_uniform_(self.value.proj.weight)
    if self.kv_stride > 1:
        # NOTE(review): down_conv is presumably only created on the strided
        # key/value path -- confirm against __init__.
        nn.init.xavier_uniform_(self.key.down_conv.weight)
        nn.init.xavier_uniform_(self.value.down_conv.weight)
    # The output projection is initialized regardless of kv_stride.
    nn.init.xavier_uniform_(self.output.proj.weight)

def _reshape_input(self, t: torch.Tensor):
"""Reshapes a tensor to three dimensions, keeping the batch and channels."""
s = t.shape
Expand Down
7 changes: 6 additions & 1 deletion timm/models/_efficientnet_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

import torch.nn as nn

from ._efficientnet_blocks import *
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType
from ._efficientnet_blocks import *
from ._manipulate import named_modules

__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
Expand Down Expand Up @@ -569,3 +570,7 @@ def efficientnet_init_weights(model: nn.Module, init_fn=None):
for n, m in model.named_modules():
init_fn(m, n)

# iterate and call any module.init_weights() fn, children first
for n, m in named_modules(model):
if hasattr(m, 'init_weights'):
m.init_weights()

0 comments on commit ab8cb07

Please sign in to comment.