Skip to content

Commit

Permalink
[feat] Adding a conv MLP, following VAN (#321)
Browse files Browse the repository at this point in the history
* Adding a conv MLP, following VAN

* Renaming to Conv2DFeedforward, more specific I believe

* Catch FF requiring squared context length

* Adding a reference in the README

* removing dead code
  • Loading branch information
blefaudeux authored Jun 6, 2022
1 parent bcb7075 commit b41a3f3
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Four blocksparsity layouts from DeepSpeed [#320]
- Support several initialization options [#312]
- Conv2DFeedforward feedforward part [#321]


## [0.0.11] - 2022-05-30
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
- [MLP](xformers/components/feedforward/mlp.py)
- [Fused](xformers/components/feedforward/fused_mlp.py)
- [Mixture of Experts](xformers/components/feedforward/mixture_of_experts.py)
- [Conv2DFeedforward](xformers/components/feedforward/conv_mlp.py)
</p></details>
Expand Down
1 change: 1 addition & 0 deletions examples/cifarMetaformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
use_rotary_embeddings=use_rotary_embeddings,
mlp_multiplier=4,
dim_head=32,
feedforward="Conv2DFeedforward",
)

# Now instantiate the metaformer trunk
Expand Down
5 changes: 4 additions & 1 deletion tests/test_block_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,10 @@ def test_xformer_decoder_block(
)

# Test different sequence lengths when encoding and decoding
if not decoder_block.requires_same_k_q_dimensions:
if (
not decoder_block.requires_same_k_q_dimensions
and not decoder_block.requires_squared_context_length
):
if not causal or not decoder_block.causal_attention:
_ = decoder_block(inputs[:, :-16], encoded)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from xformers.helpers.test_utils import init_torch_distributed_local

BATCH = 4
SEQ = 512
SEQ = 256
EMBD = 16
LATENT = 128
DROPOUT = 0.5
Expand Down
5 changes: 4 additions & 1 deletion tests/test_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,19 @@ def check_against_default(p):


@pytest.mark.parametrize("weight_init", [w.value for w in xFormerWeightInit])
@pytest.mark.parametrize("feedforward", ["MLP", "Conv2DFeedforward"])
@pytest.mark.parametrize("deepnorm", [False, True])
@pytest.mark.parametrize("device", DEVICES)
def test_weight_init(weight_init, deepnorm, device):
def test_weight_init(weight_init, feedforward, deepnorm, device):
torch.cuda.manual_seed(42)
torch.manual_seed(42)

config = test_configs_dict

if deepnorm:
config["encoder"]["layer_norm_style"] = "deepnorm"
config["encoder"]["feedforward_config"]["name"] = feedforward

config["decoder"]["layer_norm_style"] = "deepnorm"

# Make sure that all the init methods catch all the weights
Expand Down
5 changes: 5 additions & 0 deletions xformers/components/feedforward/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ def __init__(
**kwargs,
):
super().__init__()

# This feedforward requires a CUDA accelerator
self.requires_cuda = False

# This feedforward requires a context length which is squared, often due to 2D pooling
self.requires_squared_context = False

@classmethod
def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self:
# Generate the class inputs from the config
Expand Down
97 changes: 97 additions & 0 deletions xformers/components/feedforward/conv_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: Largely reusing the code from the reference VAN implementation
# see https://github.com/Visual-Attention-Network

import math
from dataclasses import dataclass
from typing import Optional

import torch.nn as nn

from xformers.components import Activation, build_activation
from xformers.components.feedforward import Feedforward, FeedforwardConfig

from . import register_feedforward


@dataclass
class ConvMlpConfig(FeedforwardConfig):
hidden_layer_multiplier: int
dim_model: int
dim_model_out: Optional[int]
act_layer: Activation
dropout: float


@register_feedforward("Conv2DFeedforward", ConvMlpConfig)
class Conv2DFeedforward(Feedforward):
"""
A Convolutional feed-forward network, as proposed in VAN_ (Vision Attention Network, Guo et al.)
.. _VAN: https://arxiv.org/pdf/2202.09741.pdf
"""

def __init__(
self,
dim_model: int,
hidden_layer_multiplier: int = 1,
dim_model_out: Optional[int] = None,
activation: Activation = Activation.GeLU,
dropout=0.0,
*args,
**kwargs,
):
super().__init__()
out_features = dim_model_out or dim_model
hidden_features = hidden_layer_multiplier * dim_model

self.conv_mlp = nn.Sequential(
nn.Conv2d(dim_model, hidden_features, 1),
nn.Conv2d(
hidden_features,
hidden_features,
3,
1,
1,
bias=True,
groups=hidden_features,
),
build_activation(activation),
nn.Conv2d(hidden_features, out_features, 1),
nn.Dropout(dropout),
)

# This feedforward requires a context length which is squared, often due to 2D pooling
self.requires_squared_context = True

def init_weights(self, **kwargs):
# Follow the original init, but also make it possible to initialize from the outside
def init_module(m: nn.Module):
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()

self.apply(init_module)

def forward(self, x):
# The conv layers expect NCHW, we have NLC by default
B, L, C = x.shape
HW = int(math.sqrt(x.shape[-2]))
assert HW**2 == L, "Conv2DFeedforward requires squared context lengths"

x = x.reshape((B, HW, HW, C)).swapdims(1, -1)

# The actual FW, including the 2d convolutions
x = self.conv_mlp(x)

# back to NLC
x = x.transpose(1, -1)
return x.flatten(1, 2)
4 changes: 3 additions & 1 deletion xformers/factory/block_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,11 @@ def __init__(self, config: xFormerDecoderConfig, **kwargs):
cross_mha = build_multi_head_attention(config.multi_head_config_cross)
feedforward = build_feedforward(config.feedforward_config)

# Expose attention specific capabilities
# Expose attention or feedforward specific capabilities
self.supports_attention_mask = mha.attention.supports_attention_mask
self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
self.requires_squared_context_length = feedforward.requires_squared_context

self.causal_attention = (
mha.attention.causal if hasattr(mha.attention, "causal") else False
)
Expand Down
4 changes: 0 additions & 4 deletions xformers/factory/weight_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ class xFormerWeightInit(str, Enum):
Small = "small"


# TODO: Check with a bunch of quick trainings whether all the inits are in the green
# TODO: Check test coverage


def get_weight_init_fn(init_choice: xFormerWeightInit):
"""
Provide the xFormers factory with weight init routines.
Expand Down
3 changes: 2 additions & 1 deletion xformers/helpers/hierarchical_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_hierarchical_configuration(
use_rotary_embeddings: bool = True,
mlp_multiplier: int = 4,
dim_head=32,
feedforward="MLP",
):
"""
A small helper to generate hierarchical xformers configurations,
Expand All @@ -49,7 +50,7 @@ def get_hierarchical_configuration(
},
},
"feedforward_config": {
"name": "MLP",
"name": feedforward,
"activation": "gelu",
"hidden_layer_multiplier": mlp_multiplier,
"dropout": 0.0,
Expand Down

0 comments on commit b41a3f3

Please sign in to comment.