
[feat] Rewrite the input projection + add several init options #312

Merged
merged 6 commits into main from xformers_init
Jun 3, 2022

Conversation

blefaudeux
Contributor

@blefaudeux blefaudeux commented May 26, 2022

What does this PR do?

Loosely linked to #219 and #264

  • rewrite the input projection to make it a little cleaner and easier to follow

    • nn.Linear instead of reinventing the linear projection wheel
    • explicit setting for self attention
    • move the weight init settings out of it (this was the only part which exposed something specific)
    • rename in_proj_container into input_projection; it feels a little easier to understand
  • add a dedicated weight_init file (purely optional, opt-in basis)

    • handle deepnorm from there
    • handle a couple of popular options for weight inits in general

Many TODOs left, comments welcome

  • Rewrite input projection
  • Dedicated place to handle all inits
  • Handle deepnorm QKV init
  • Handle the small init option
  • Typesafe all the inits
  • Add a weight init unit test
  • Add some explanations in the README
  • Check test coverage
  • Check with a bunch of quick trainings whether all the inits are in the green
  • Make it possible in the MHA to select the bias per projection

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 26, 2022
@blefaudeux blefaudeux marked this pull request as draft May 26, 2022 22:21
@blefaudeux blefaudeux force-pushed the xformers_init branch 4 times, most recently from c477645 to 32a3ac6 on May 30, 2022 17:35
@blefaudeux
Contributor Author

blefaudeux commented May 30, 2022

It turns out that using a single projection for QKV in the self-attention case is actually slower, I think due to the posterior .contiguous() call. cc @dianaml0, I got this from your benchmark_multihead script. Removing it as part of this PR, which simplifies the initialization code a lot as a positive side effect. Having a custom self-attention projection which produces contiguous tensors from the get-go would fix that, and ideally we could use tensor views to keep the initialization code the same (so that the changes from this PR would not need to be reversed).

Separate projections:
--- Type: torch.float16 ---

Units: runtime in ms, lower is better

| | B=8, M=384, K=128, N_HEADS=4 | B=8, M=784, K=512, N_HEADS=4 | B=4, M=1024, K=768, N_HEADS=4 | B=4, M=2048, K=1024, N_HEADS=4 | B=2, M=2048, K=2048, N_HEADS=4 | B=2, M=2048, K=4096, N_HEADS=4 | B=2, M=4096, K=4096, N_HEADS=4 | B=1, M=2048, K=12288, N_HEADS=4 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| torch - fw (self_attn) | 0.24 | 1.61 | 1.86 | 6.59 | 7.70 | 21.76 | 51.08 | 73.89 |
| xf - fw (self_attn) | 0.19 | 1.45 | 1.62 | 5.71 | 7.25 | 21.16 | 49.29 | 75.29 |

With a single QKV projection but a .contiguous() call afterwards:
--- Type: torch.float16 ---

Units: runtime in ms, lower is better

| | B=8, M=384, K=128, N_HEADS=4 | B=8, M=784, K=512, N_HEADS=4 | B=4, M=1024, K=768, N_HEADS=4 | B=4, M=2048, K=1024, N_HEADS=4 | B=2, M=2048, K=2048, N_HEADS=4 | B=2, M=2048, K=4096, N_HEADS=4 | B=2, M=4096, K=4096, N_HEADS=4 | B=1, M=2048, K=12288, N_HEADS=4 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| torch - fw (self_attn) | 0.24 | 1.62 | 1.86 | 6.60 | 7.75 | 21.57 | 51.10 | 76.29 |
| xf - fw (self_attn) | 0.20 | 1.47 | 1.76 | 5.82 | 7.42 | 21.53 | 51.18 | 76.19 |

Basically, for the bigger (not that big) sizes we become bandwidth-bound anyway, and the single projection followed by tensor reordering loses, which makes sense.
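
For reference, here is a minimal sketch of the two approaches being compared; shapes and layer names are illustrative, not the actual xFormers code:

```python
import torch
import torch.nn as nn

B, S, E = 2, 2048, 1024  # illustrative batch, sequence and embedding sizes
x = torch.randn(B, S, E)

# Option A: three separate projections, each output is contiguous from the start
q_proj, k_proj, v_proj = nn.Linear(E, E), nn.Linear(E, E), nn.Linear(E, E)
q, k, v = q_proj(x), k_proj(x), v_proj(x)

# Option B: one fused projection, then a split along the last dimension.
# The chunks are strided views over the fused output, so the subsequent head
# reshape needs an extra .contiguous() copy, which is the cost measured above.
qkv_proj = nn.Linear(E, 3 * E)
q2, k2, v2 = (t.contiguous() for t in qkv_proj(x).chunk(3, dim=-1))
```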

@blefaudeux blefaudeux force-pushed the xformers_init branch 5 times, most recently from 162746f to e97c288 on May 30, 2022 21:30
@@ -12,7 +12,7 @@
from torch.nn.init import constant_

from xformers.components.attention import Attention
from xformers.components.in_proj_container import InProjContainer, InProjParams
Contributor Author

renaming this; I thought it was a little easier to understand what it does from the outside

@blefaudeux blefaudeux force-pushed the xformers_init branch 8 times, most recently from bf76de6 to bfd27b7 on May 31, 2022 23:24
@@ -59,10 +59,10 @@ def __init__(
attention: Attention,
bias: bool = True,
Contributor Author

would it be worth it to expose one bias option per input? It could be a Union[bool, Tuple[bool, bool, bool]] instead

Contributor

Perhaps it would be worth the flexibility?
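
A rough sketch of what that option could look like (hypothetical helper, not part of this PR):

```python
from typing import Tuple, Union

import torch.nn as nn


def make_qkv_projections(
    dim_model: int,
    dim_k: int,
    dim_v: int,
    bias: Union[bool, Tuple[bool, bool, bool]] = True,
) -> nn.ModuleList:
    # a single flag applies to all three projections, a 3-tuple selects the bias per projection
    if isinstance(bias, bool):
        bias = (bias, bias, bias)
    bias_q, bias_k, bias_v = bias
    return nn.ModuleList(
        [
            nn.Linear(dim_model, dim_k, bias=bias_q),
            nn.Linear(dim_model, dim_k, bias=bias_k),
            nn.Linear(dim_model, dim_v, bias=bias_v),
        ]
    )
```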

@@ -205,16 +202,16 @@ def check(t, name):
_split_heads if self.attention.requires_head_dimension else _fold_heads
)

q = reshape_fn(q, B, S_Q, self.num_heads, self.dim_k)
k = reshape_fn(k, B, S_K, self.num_heads, self.dim_k)
v = reshape_fn(v, B, S_K, self.num_heads, self.dim_k)
Contributor Author

all this was a little buggy if the Q/K/V dimensions were not the same, and we were not properly testing for this. The PR makes it clear that V can be different
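
Put differently, the value path needs its own per-head dimension instead of reusing dim_k. A small sketch, with assumed dimension names:

```python
import torch


def fold_heads(t: torch.Tensor, B: int, S: int, num_heads: int, dim_per_head: int) -> torch.Tensor:
    # (B, S, num_heads * dim_per_head) -> (B * num_heads, S, dim_per_head)
    return t.view(B, S, num_heads, dim_per_head).transpose(1, 2).flatten(0, 1)


B, S_Q, S_K, num_heads, dim_k, dim_v = 2, 128, 256, 4, 64, 32  # illustrative sizes
q = torch.randn(B, S_Q, num_heads * dim_k)
k = torch.randn(B, S_K, num_heads * dim_k)
v = torch.randn(B, S_K, num_heads * dim_v)

q = fold_heads(q, B, S_Q, num_heads, dim_k)
k = fold_heads(k, B, S_K, num_heads, dim_k)
v = fold_heads(v, B, S_K, num_heads, dim_v)  # value uses dim_v, not dim_k
```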


@classmethod
def from_config(cls, config: xFormerConfig):
return cls(config.stack_configs, config.tie_embedding_weights)

def _deepnorm_weight_init(self):
Contributor Author

instead of an ad-hoc init here, only for deepnorm, create a dedicated weight_init part (covering a couple of popular options) and move everything there
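
For context, a rough sketch of what a deepnorm-style init could look like once it lives in such a file, following the DeepNet paper's encoder-only recipe; the module-name matching below is an assumption, not the actual xFormers implementation:

```python
import torch.nn as nn


def deepnorm_init_(block: nn.Module, n_layers: int) -> None:
    # DeepNet (Wang et al., 2022), encoder-only case: scale the value, output and
    # feedforward projections by beta = (8 * n_layers) ** -0.25, keep Q/K at gain 1
    beta = (8 * n_layers) ** -0.25
    for name, module in block.named_modules():
        if isinstance(module, nn.Linear):
            gain = beta if any(tag in name for tag in ("v_proj", "out_proj", "mlp", "feedforward")) else 1.0
            nn.init.xavier_normal_(module.weight, gain=gain)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
```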

)


class xFormerWeightInit(str, Enum):
Contributor Author

this file should handle all supported init options, making them easy to refer to and easy to compare across methods. It is always possible for users to do this on their own from the outside, of course

Contributor

Great to have this!
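
Roughly, the shape of the idea is something like this; member names and gains below are placeholders, the real options live in xformers/factory/weight_init.py:

```python
from enum import Enum

import torch.nn as nn


class xFormerWeightInit(str, Enum):
    # placeholder members, the supported schemes are defined in weight_init.py
    Vit = "vit"
    Small = "small"
    DeepNorm = "deepnorm"


def init_weights(model: nn.Module, scheme: xFormerWeightInit) -> None:
    # placeholder gains, one per scheme; a scheme can take precedence over a
    # part's own init_weights(), which otherwise provides the sane default
    gain = {"vit": 1.0, "small": 0.5, "deepnorm": 1.0}[scheme.value]

    for _, module in model.named_modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=gain)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif hasattr(module, "init_weights"):
            module.init_weights()
```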

elif hasattr(module, "init_weights"):
module.init_weights() # type: ignore
else:
_maybe_report_no_init(module, name)
Contributor Author

catch cases which look really suspect; normally all the weights should be covered/touched above

@blefaudeux blefaudeux force-pushed the xformers_init branch 4 times, most recently from 5baebeb to f3a7857 on June 1, 2022 21:06
@blefaudeux blefaudeux marked this pull request as ready for review June 1, 2022 21:07
@blefaudeux blefaudeux requested a review from danthe3rd June 1, 2022 21:07
_ = decoder_block(
inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask
)

# Test different sequence lengths when encoding and decoding
if not decoder_block.mha.attention.requires_same_k_q_dimensions:
Contributor Author

this was getting a little obscure, so the encoder and decoder blocks now expose these attention flags directly; I figured it was easier to read and made sense
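
Concretely, the kind of passthrough this refers to is just a small property on the block; a sketch with assumed attribute names:

```python
import torch.nn as nn


class DecoderBlockSketch(nn.Module):
    def __init__(self, mha: nn.Module):
        super().__init__()
        self.mha = mha

    @property
    def requires_same_k_q_dimensions(self) -> bool:
        # surface the attention flag so tests don't have to reach into block.mha.attention
        return self.mha.attention.requires_same_k_q_dimensions

    @property
    def causal_attention(self) -> bool:
        # True only when the wrapped attention is causal
        return getattr(self.mha.attention, "causal", False)
```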

config = test_configs_dict

# Make sure that all the init methods catch all the weights
xformers_weight_init._assert_if_not_initialized = True
Contributor Author

this is a little stringent but probably a good thing: any tensor not handled by the initialization schemes will assert out, to catch a possible mistake / typo
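
A minimal sketch of what such a strict-mode guard can look like (hypothetical, not the actual xformers_weight_init code):

```python
import torch.nn as nn

# module-level switch that the unit test flips to True, mirroring
# xformers_weight_init._assert_if_not_initialized above
_assert_if_not_initialized = False


def _maybe_report_no_init(module: nn.Module, name: str) -> None:
    # only complain about modules which actually own parameters of their own
    if next(module.parameters(recurse=False), None) is None:
        return

    message = f"Not initializing weights for {name} ({type(module).__name__})"
    if _assert_if_not_initialized:
        # strict mode: fail loudly so a forgotten layer cannot slip through
        raise AssertionError(message)
    print(message)
```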

@@ -172,7 +169,6 @@ def test_pytorch_tranformer_parity(device=torch.device("cuda")):
dim_feedforward=4 * EMB,
dropout=DROP,
activation=ACTIVATION,
layer_norm_eps=1e-06,
Contributor Author

the default is actually 1e-5 for both xformers/triton and PyTorch

@@ -35,7 +35,6 @@
"attention": {
"name": "scaled_dot_product",
"dropout": DROP,
"causal": False,
Contributor Author

this is the default, and less is more

dim_key: Optional[int] = None,
dim_value: Optional[int] = None,
in_proj_container: Optional[InProjContainer] = None,
use_separate_proj_weight: Optional[bool] = False,
Contributor Author

this was a nasty default I believe: we would use the same weights for the q/k/v projections, which for self-attention is probably not a good idea
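
For illustration, this is what the old default amounted to in the self-attention case (a sketch, not the actual code path):

```python
import torch
import torch.nn as nn

E = 256
x = torch.randn(2, 64, E)  # illustrative self-attention input

# old default (use_separate_proj_weight left at False): one set of weights serves
# Q, K and V, so for self-attention the three tensors come out identical
shared_proj = nn.Linear(E, E)
q, k, v = shared_proj(x), shared_proj(x), shared_proj(x)
assert torch.equal(q, k) and torch.equal(k, v)

# new default: independent projections, so Q, K and V can diverge during training
q_proj, k_proj, v_proj = nn.Linear(E, E), nn.Linear(E, E), nn.Linear(E, E)
q, k, v = q_proj(x), k_proj(x), v_proj(x)
```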

self.dim_model = dim_model
self.attention = attention

# key, query, value projections for all heads
# critical options are
# - are we sharing weights ?
# - are we adding biases, and if yes are they shared ?
# - are we adding biases ?
Contributor Author

I'm not sure how much nuance to support here, for instance one bias setting per input? Could be a little overkill, but easy to add if need be


# Self-attend
y = self.attention(q, k, v, **kw_mask_args)

# Re-assemble all head outputs side by side
y = (
y.view(B, self.num_heads, S_Q, self.dim_k)
Contributor Author

the embedding size is linked to the value dimension here; it only worked because we did not test this case

@@ -130,12 +130,19 @@ def __init__(self, config: xFormerEncoderConfig, **kwargs):
residual_scale=residual_scale,
)

self.mha = build_multi_head_attention(config.multi_head_config)
self.feedforward = build_feedforward(asdict(config.feedforward_config))
mha = build_multi_head_attention(config.multi_head_config)
Contributor Author

we don't have to keep mha and feedforward as direct members; they were showing up in the model graph and at initialization time, making it a little hard to see what was actually in the xformer model.

@blefaudeux
Contributor Author

MinGPT training with a couple of options. Not earth-shattering, everything works, but there are some measurable differences

[Screenshot from 2022-06-01 15-10-40]

@blefaudeux
Contributor Author

blefaudeux commented Jun 1, 2022

I really like this minGPT use case by the way; I get to read a new dialogue every time. This is just after one epoch so it does not make a ton of sense, but there's some structure to it already

Example
Friends of my soul, lends the tackle's foe.

FRIAR LAURENCE:
I do not say 'tis true; and if I cannot,
Much in the view of heaven and am look'd
For fortune may be approved by the grave;
There was no man spraying for their throne,
Who even for them that we make us all,
And bid my knee proportion'd at my hand.
What further was the banish'd of the world,
The chapel-bone, the sun of Angelo,
Whom out of bitter jealousies with her:
So soon affections and these things, make no trade.

ESCALUS:
Neither.

LUCIO:
In chamber, I say. Farewell:
I think so.

PETRUCHIO:
I do beseech you, sir; for I knew of his pack.

KING RICHARD II:
Why, then I see my duty to him?

KATHARINA:
I like it well: I'll serve him in an eye
Where he shall poison on for a visor
That foul deeds die the entrails of my land,
Is curse this business.

DUKE VINCENTIO:
Take you this duke?

Provost:
None, sir, none.

ANGELO:
The gods bless more.

LUCENTIO:
His mouth in evil. I say his mother
Upon my head; and that's oft my other sense.

@codecov-commenter

codecov-commenter commented Jun 1, 2022

Codecov Report

Merging #312 (249e113) into main (9e27c8f) will decrease coverage by 0.52%.
The diff coverage is 87.72%.

@@            Coverage Diff             @@
##             main     #312      +/-   ##
==========================================
- Coverage   93.04%   92.51%   -0.53%     
==========================================
  Files          66       67       +1     
  Lines        3564     3634      +70     
==========================================
+ Hits         3316     3362      +46     
- Misses        248      272      +24     
| Flag | Coverage Δ |
| --- | --- |
| Python | 92.51% <87.72%> (-0.53%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
| --- | --- |
| xformers/triton/dropout.py | 77.41% <25.00%> (-2.36%) ⬇️ |
| xformers/factory/weight_init.py | 80.32% <80.32%> (ø) |
| xformers/components/__init__.py | 100.00% <100.00%> (ø) |
| xformers/components/attention/compositional.py | 100.00% <100.00%> (ø) |
| xformers/components/input_projection.py | 100.00% <100.00%> (ø) |
| xformers/components/multi_head_dispatch.py | 97.95% <100.00%> (+0.08%) ⬆️ |
| xformers/components/positional_embedding/vocab.py | 100.00% <100.00%> (ø) |
| xformers/factory/__init__.py | 100.00% <100.00%> (ø) |
| xformers/factory/block_factory.py | 97.01% <100.00%> (+0.13%) ⬆️ |
| xformers/factory/model_factory.py | 98.16% <100.00%> (+1.89%) ⬆️ |

... and 1 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Contributor

@dianaml0 dianaml0 left a comment

Really nice changes!

This is completely optional, and will only occur when generating full models through xFormers, not when picking parts individually.
There are basically two initialization mechanisms exposed, but the user is free to initialize weights as he/she sees fit after the fact.
- Parts can expose an `init_weights()` method, which defines sane defaults
- xFormers supports [specific init schemes](xformers/factory/weight_init.py) which *can take precedence* over the init_weights()
Contributor

Not sure why but the markdown formatting isn't working when I view it, just for this paragraph

Contributor Author

I've checked on https://github.com/facebookresearch/xformers/tree/xformers_init and it's properly rendered, maybe a PR rendering thing.

if not decoder_block.mha.attention.requires_same_k_q_dimensions:
if not causal or not hasattr(decoder_block.mha.attention, "causal"):
if not decoder_block.requires_same_k_q_dimensions:
if not causal or not decoder_block.causal_attention:
Contributor

Nice! Helps readability

@@ -12,7 +12,7 @@

from .activations import Activation, build_activation # noqa
from .attention import Attention, build_attention # noqa
from .in_proj_container import InProjContainer, InProjParams # noqa
from .input_projection import InputProjection, InputProjectionConfig # noqa
Contributor

Really like this change too


@blefaudeux blefaudeux force-pushed the xformers_init branch 2 times, most recently from 07d335d to 249e113 on June 3, 2022 00:02
@blefaudeux blefaudeux merged commit 52d1dd0 into main Jun 3, 2022
@blefaudeux blefaudeux deleted the xformers_init branch June 3, 2022 03:51