
[bug] Pre-norm wrapper only normalizing the first input #233

Merged: 2 commits into main from pre_post_norm_fix, Mar 11, 2022

Conversation

@blefaudeux (Contributor) commented on Mar 11, 2022

What does this PR do?

This could well explain #219 (cc @jramapuram). I stumbled upon this bug by luck while rewriting part of the input projection and being stricter about self-attention (in that case the tensors need to be the exact same objects; an incoming PR makes it easier to catch a bug when the intent is self-attention).

This showed that the pre-norm wrapper was (a) only normalizing the first input and (b) creating new objects, because it normalized the tensors one by one even when they were the same tensor to begin with. This bug existed in the past and was fixed; I'm not sure how it came back (long-lived branch plus a botched merge, maybe), hence the unit test and the interface change to make sure it does not happen again. The consequence was both a correctness and a speed hit: after the pre-norm the tensors were no longer the same object, so the self-attention speed-up was lost.

For (a), I changed the interface so that it is compulsory to pass the inputs as a list for all of these wrappers, so there is no more confusion (see the sketch below).
For (b), this PR adds a unit test, and the incoming PR will add another one.
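
To make the failure mode concrete, here is a minimal sketch of the buggy vs. fixed behaviour. `PreNormSketch` is a simplified stand-in, not the actual `PreNorm` in `xformers/components/residual.py`; its structure is an assumption.

```python
from typing import List

import torch
import torch.nn as nn


class PreNormSketch(nn.Module):
    """Simplified pre-norm wrapper: normalize *all* inputs with one shared LayerNorm."""

    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        # Buggy behaviour (before this PR): only inputs[0] went through the norm,
        # and normalizing tensors one by one broke identity between q, k and v:
        #   return self.sublayer(self.norm(inputs[0]), *inputs[1:], **kwargs)

        # Fixed behaviour: normalize every input, but normalize a shared tensor
        # only once, so that `q is k is v` still holds after the pre-norm.
        cache = {}
        normalized = []
        for x in inputs:
            if id(x) not in cache:
                cache[id(x)] = self.norm(x)
            normalized.append(cache[id(x)])
        return self.sublayer(*normalized, **kwargs)
```

With the list-based interface, a self-attention call such as `wrapper(inputs=[x, x, x], att_mask=mask)` hands the sublayer three references to the same normalized tensor, which is what the self-attention fast path relies on.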

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • Did you update the changelog? (if needed)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in a GitHub issue, there's a high chance it will not be merged.

@facebook-github-bot added the "CLA Signed" label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Mar 11, 2022
@blefaudeux marked this pull request as a draft on March 11, 2022 at 06:54
@blefaudeux (Contributor, Author) commented:

As usual, reversible is the pain point... fixing that.

@blefaudeux marked this pull request as ready for review on March 11, 2022 at 07:06
@@ -34,21 +26,37 @@ class LayerNormStyle(str, Enum):
    Post = "post"


class RequiresWrappedInputs:
@blefaudeux (Contributor, Author) commented:

classes which derive from this class only accept a single input list (makes it impossible to subtly footgun)
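
A rough sketch of the idea (assumed names and structure, not the exact xformers code): the marker carries no behaviour of its own, it only pins down the calling convention so that wrappers and unit tests can assert it.

```python
from typing import List

import torch
import torch.nn as nn


class RequiresWrappedInputs:
    """Marker mixin: subclasses only accept forward(inputs: List[Tensor], **kwargs)."""


class ResidualSketch(nn.Module, RequiresWrappedInputs):
    """Skip connection around a sublayer, using the list-based calling convention."""

    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        # There is exactly one way to call this module, so a call site cannot
        # silently fall back to positional tensors with a different meaning.
        return inputs[0] + self.layer(*inputs, **kwargs)


block = ResidualSketch(nn.Linear(64, 64))
out = block(inputs=[torch.randn(2, 16, 64)])
assert isinstance(block, RequiresWrappedInputs)  # what callers / tests can check
```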

@@ -16,14 +16,6 @@
from xformers.triton.layer_norm import FusedLayerNorm


def _to_tensor_list(
@blefaudeux (Contributor, Author) commented:

This was supposed to be a helper, but in the end it was masking bugs (in that layer(x, y, z) could have a different behaviour depending on the residual wraps). I think it's better to force a single way of passing the inputs.
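
For illustration only, here is an assumed approximation of such a coercion helper (the real `_to_tensor_list` is removed by this PR): because it silently accepts either a bare tensor or a list, `layer(x, y, z)` and `layer([x, y, z])` can both run, while normalizing different things depending on which residual wrap sits in between.

```python
from typing import List, Sequence, Union

import torch


def _to_tensor_list_sketch(
    inputs: Union[torch.Tensor, Sequence[torch.Tensor]]
) -> List[torch.Tensor]:
    # Implicit coercion: a bare tensor becomes a one-element list. Convenient,
    # but it lets a mismatched call site run without error, which is how the
    # "only the first input is normalized" bug could go unnoticed.
    if isinstance(inputs, torch.Tensor):
        return [inputs]
    return list(inputs)
```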


        return inputs[0] + self.layer(*inputs, *args, **kwargs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        if self.wrap_inputs:
@blefaudeux (Contributor, Author) commented:

The trick here is that these residual/norm wrappers can wrap each other at times. When they wrap an external layer, the inputs are unrolled; when the sublayer is another wrapper, we keep inputs as a List[Tensor] to prevent bugs like this one.
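
A sketch of that dispatch (assumed structure, reusing the marker idea from the earlier sketch): the wrapper decides at construction time whether its sublayer is another wrapper, then either passes the list through untouched or unrolls it into positional arguments.

```python
from typing import List

import torch
import torch.nn as nn


class RequiresWrappedInputs:
    """Marker mixin, as in the earlier sketch."""


class WrapSketch(nn.Module, RequiresWrappedInputs):
    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer
        # Nested wrapper -> keep inputs as a list; external layer -> unroll it.
        self.wrap_inputs = isinstance(layer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        # The real wrappers differ in what they do here (residual add, pre-norm,
        # post-norm); only the list-in / unroll-out dispatch is sketched.
        if self.wrap_inputs:
            return self.layer(inputs=inputs, **kwargs)
        return self.layer(*inputs, **kwargs)


# Residual(PreNorm(layer))-style nesting:
inner = WrapSketch(nn.Linear(64, 64))  # wraps an external layer -> unrolls the list
outer = WrapSketch(inner)              # wraps another wrapper -> forwards the list
out = outer(inputs=[torch.randn(2, 16, 64)])
```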

@@ -335,8 +335,8 @@ def forward(
q, k, v = x, x, x

# Pre/Post norms and residual paths are already handled
x = self.wrap_att(q, k, v, att_mask=att_mask)
x = self.wrap_ff(x)
x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
@blefaudeux (Contributor, Author) commented on Mar 11, 2022:

All the wrappers require a single input list plus kwargs, which I believe is more future-proof (this normalizing bug cannot happen again, or at least not as easily).

@codecov-commenter commented:

Codecov Report

Merging #233 (fb0b479) into main (0abfb78) will decrease coverage by 0.04%.
The diff coverage is 88.63%.


@@            Coverage Diff             @@
##             main     #233      +/-   ##
==========================================
- Coverage   92.22%   92.17%   -0.05%     
==========================================
  Files          60       60              
  Lines        3228     3247      +19     
==========================================
+ Hits         2977     2993      +16     
- Misses        251      254       +3     
Flag Coverage Δ
Python 92.17% <88.63%> (-0.05%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
xformers/components/reversible.py 83.82% <50.00%> (-3.28%) ⬇️
xformers/components/residual.py 96.15% <96.15%> (+0.80%) ⬆️
xformers/components/__init__.py 100.00% <100.00%> (ø)
xformers/factory/block_factory.py 93.93% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@dianaml0 (Contributor) left a comment:

Nice catch and nice fix!

@@ -15,7 +15,11 @@
from .in_proj_container import InProjContainer, InProjParams # noqa
from .multi_head_dispatch import MultiHeadDispatch # noqa
from .multi_head_dispatch import MultiHeadDispatchConfig
from .residual import LayerNormStyle, PostNorm, PreNorm, Residual # noqa
from .residual import LayerNormStyle # noqa; noqa
@dianaml0 (Contributor) commented:

nit: extra noqa typo?

@blefaudeux merged commit 78bc58c into main on Mar 11, 2022
@blefaudeux deleted the pre_post_norm_fix branch on March 11, 2022 at 21:34