Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug] Pre-norm wrapper only normalizing the first input #233

Merged
merged 2 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Expose bias flag for feedforwards, same default as Timm [#220]
- Update eps value for layernormm, same default as torch [#221]
- PreNorm bugfix, only one input was normalized [#233]

## [0.0.9] - 2022-02-09
### Added
Expand Down
27 changes: 27 additions & 0 deletions tests/test_residual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from xformers.components import PreNorm


class Passthrough(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, *args):
return args


def test_pre_norm():
# Check that passing the same tensor a bunch of times skips the extra normalizations
x = torch.rand((3, 3))

wrap = PreNorm(d_model=3, sublayer=Passthrough(), use_triton=False)
outputs = wrap(inputs=[x, x, x])

assert id(outputs[0]) == id(outputs[1])
6 changes: 5 additions & 1 deletion xformers/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from .in_proj_container import InProjContainer, InProjParams # noqa
from .multi_head_dispatch import MultiHeadDispatch # noqa
from .multi_head_dispatch import MultiHeadDispatchConfig
from .residual import LayerNormStyle, PostNorm, PreNorm, Residual # noqa
from .residual import LayerNormStyle # noqa; noqa
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: extra noqa typo?

from .residual import PostNorm # noqa
from .residual import PreNorm # noqa
from .residual import RequiresWrappedInputs # noqa
from .residual import Residual # noqa

# automatically import any Python files in the directory
import_all_modules(str(Path(__file__).parent), "xformers.components")
Expand Down
69 changes: 46 additions & 23 deletions xformers/components/residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@
from xformers.triton.layer_norm import FusedLayerNorm


def _to_tensor_list(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was supposed to be a helper, but in the end it was masking bugs (in that layer(x, y, z) could have a different behaviour depending on the residual wraps). I think that it's better to force inputs in a single fashion

inputs: Union[torch.Tensor, List[torch.Tensor]]
) -> List[torch.Tensor]:
if not isinstance(inputs, list):
inputs = [inputs]
return inputs


class LayerNormStyle(str, Enum):
"""Support different layer norm styles.
See "On Layer Normalization in the Transformer Architecture",
Expand All @@ -34,21 +26,37 @@ class LayerNormStyle(str, Enum):
Post = "post"


class RequiresWrappedInputs:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

classes which derive from this class only accept a single input list (makes it impossible to subtly footgun)

"""Used to mark, through inheritance,
the fact that this class will require inputs to be passed as a single list"""

pass


# CREDITS: the following is inspired by FastAI's Transformer implementation
class Residual(nn.Module):
"""Object-oriented handling of the residual path"""
class Residual(nn.Module, RequiresWrappedInputs):
"""
Object-oriented handling of the residual path

.. Note: the wrapped layers must accept all the inputs as a single list
"""

def __init__(self, layer: nn.Module):
super().__init__()
self.layer = layer

def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
inputs = _to_tensor_list(inputs)
# PreNorm and PostNorm require all the tensors to be passed as a list
self.wrap_inputs = isinstance(layer, PreNorm) or isinstance(layer, PostNorm)

return inputs[0] + self.layer(*inputs, *args, **kwargs)
def forward(self, inputs: List[torch.Tensor], **kwargs):
if self.wrap_inputs:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the trick here is that these residual/norm wrapper can wrap themselves at times. When they wrap an external layer, then the inputs are unrolled, when the sublayer is another wrap then we maintain inputs=List[Tensor] to prevent bugs like this one

return inputs[0] + self.layer(inputs=inputs, **kwargs)

else:
return inputs[0] + self.layer(*inputs, **kwargs)


class PreNorm(nn.Module):
class PreNorm(nn.Module, RequiresWrappedInputs):
"""Adds LayerNorm before computing attention

..Note: If a list of inputs is passed, all of them get normalized"""
Expand All @@ -61,15 +69,28 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True):
self.norm = nn.LayerNorm(d_model)

self.sublayer = sublayer
self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
inputs = _to_tensor_list(inputs)
def forward(self, inputs: List[torch.Tensor], **kwargs):
assert len(inputs) > 0

x_norm = [self.norm(x_) for x_ in inputs]
return self.sublayer(*x_norm, *args, **kwargs)
# Perf improvement: if the inputs are all the same, only norm once
ids = [id(x) for x in inputs]
if ids.count(ids[0]) == len(ids):
# The same tensor is passed multiple times
x_norm = self.norm(inputs[0])
inputs_normed = [x_norm for _ in inputs]
else:
# The inputs differ, norm them all
inputs_normed = [self.norm(x_) for x_ in inputs]

if self.wrap_inputs:
return self.sublayer(inputs=inputs_normed, **kwargs)
else:
return self.sublayer(*inputs_normed, **kwargs)

class PostNorm(nn.Module):

class PostNorm(nn.Module, RequiresWrappedInputs):
"""Adds LayerNorm after computing attention"""

def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True):
Expand All @@ -80,9 +101,11 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True):
self.norm = nn.LayerNorm(d_model)

self.sublayer = sublayer
self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
inputs = _to_tensor_list(inputs)

x = self.sublayer(*inputs, *args, **kwargs)
def forward(self, inputs: List[torch.Tensor], **kwargs):
if self.wrap_inputs:
x = self.sublayer(inputs=inputs, **kwargs)
else:
x = self.sublayer(*inputs, **kwargs)
return self.norm(x)
14 changes: 12 additions & 2 deletions xformers/components/reversible.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

from xformers.components import RequiresWrappedInputs

# CREDITS: Code adapted from
# https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
# https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py,
Expand All @@ -26,6 +28,7 @@ def __init__(self, net: nn.Module):
self.cuda_in_fwd: bool = False
self.gpu_devices: List[int] = []
self.gpu_states: List[torch.Tensor] = []
self.wrap_inputs = isinstance(net, RequiresWrappedInputs)

def record_rng(self, *args):
self.cpu_state = torch.get_rng_state()
Expand All @@ -38,7 +41,10 @@ def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwar
self.record_rng(*args)

if not set_rng:
return self.net(*args, **kwargs)
if self.wrap_inputs:
return self.net(inputs=args, **kwargs)
else:
return self.net(*args, **kwargs)

rng_devices: List[int] = []
if self.cuda_in_fwd:
Expand All @@ -48,7 +54,11 @@ def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwar
torch.set_rng_state(self.cpu_state)
if self.cuda_in_fwd:
set_device_states(self.gpu_devices, self.gpu_states)
return self.net(*args, **kwargs)

if self.wrap_inputs:
return self.net(inputs=args, **kwargs)
else:
return self.net(*args, **kwargs)


class ReversibleBlock(nn.Module):
Expand Down
12 changes: 7 additions & 5 deletions xformers/factory/block_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ def forward(
q, k, v = x, x, x

# Pre/Post norms and residual paths are already handled
x = self.wrap_att(q, k, v, att_mask=att_mask)
x = self.wrap_ff(x)
x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
Copy link
Contributor Author

@blefaudeux blefaudeux Mar 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all the wraps require a single input list + kwargs, which I believe is more future proof (this normalizing bug cannot happen, or at least not as easily)

x = self.wrap_ff(inputs=[x])

return x

Expand Down Expand Up @@ -397,8 +397,10 @@ def forward(
else:
target_q, target_k, target_v = target, target, target

x = self.wrap_att([target_q, target_k, target_v], att_mask=decoder_att_mask)
x = self.wrap_cross([x, memory, memory], att_mask=encoder_att_mask)
x = self.wrap_ff(x)
x = self.wrap_att(
inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask
)
x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask)
x = self.wrap_ff(inputs=[x])

return x