[feat] Rewrite the input projection + add several init options #312
Conversation
Force-pushed from c477645 to 32a3ac6
It turns out that using a single projection for QKV in the self-attention case is actually slower, due to the posterior .contiguous() call I think. cc @dianaml0, I got this from your benchmark_multihead script. I'm removing it as part of this PR, which simplifies the initialization code a lot as a positive side effect. Having a custom self-attention projection which produces contiguous tensors from the get-go would fix that, and ideally we could use tensor views to keep the initialization code the same (so that the changes from this PR would not need to be reversed).
Separate projections:
With a single QKV projection but a .contiguous() call afterwards:
Basically, for the bigger (not that big) sizes we become bandwidth bound anyway, so the single projection followed by tensor reordering loses; makes sense.
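For reference, a rough sketch of the two code paths being compared (this is not the benchmark_multihead script itself; the sizes and names below are illustrative):

```python
import torch
import torch.nn as nn

B, S, E = 32, 1024, 768  # illustrative sizes
x = torch.randn(B, S, E)

# Single fused projection: one large GEMM, but splitting q/k/v afterwards
# requires a .contiguous() copy, which adds memory traffic
fused = nn.Linear(E, 3 * E)
qkv = fused(x)
q, k, v = (t.contiguous() for t in qkv.chunk(3, dim=-1))

# Separate projections: three smaller GEMMs whose outputs are already contiguous
proj_q, proj_k, proj_v = (nn.Linear(E, E) for _ in range(3))
q2, k2, v2 = proj_q(x), proj_k(x), proj_v(x)
```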
Force-pushed from 162746f to e97c288
@@ -12,7 +12,7 @@
from torch.nn.init import constant_

from xformers.components.attention import Attention
from xformers.components.in_proj_container import InProjContainer, InProjParams
Renaming this; I thought it made it a little easier to understand what it was doing from the outside.
Force-pushed from bf76de6 to bfd27b7
@@ -59,10 +59,10 @@ def __init__(
    attention: Attention,
    bias: bool = True,
Would it be worth exposing one bias option per input? It could be a Union[bool, Tuple[bool, bool, bool]] instead.
Perhaps it would be worth the flexibility?
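If that route were taken, a hypothetical sketch of the `Union[bool, Tuple[bool, bool, bool]]` idea could look like this (`normalize_bias` is purely illustrative, not part of the PR):

```python
from typing import Tuple, Union

def normalize_bias(bias: Union[bool, Tuple[bool, bool, bool]]) -> Tuple[bool, bool, bool]:
    # A single bool applies to all three projections, a tuple is taken as-is
    if isinstance(bias, bool):
        return (bias, bias, bias)
    return bias

assert normalize_bias(True) == (True, True, True)
assert normalize_bias((True, False, True)) == (True, False, True)
```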
@@ -205,16 +202,16 @@ def check(t, name):
    _split_heads if self.attention.requires_head_dimension else _fold_heads
)

q = reshape_fn(q, B, S_Q, self.num_heads, self.dim_k)
k = reshape_fn(k, B, S_K, self.num_heads, self.dim_k)
v = reshape_fn(v, B, S_K, self.num_heads, self.dim_k)
All this was a little buggy if the Q/K/V dimensions were not the same; we were not properly testing for this. The PR makes it clear that V can be different.
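A minimal sketch of what the head reshape looks like once the value dimension is allowed to differ (the helper and variable names here are illustrative, not the exact library code):

```python
import torch

def _split_heads_sketch(t: torch.Tensor, B: int, S: int, num_heads: int, dim_per_head: int) -> torch.Tensor:
    # (B, S, num_heads * dim_per_head) -> (B, num_heads, S, dim_per_head)
    return t.view(B, S, num_heads, dim_per_head).transpose(1, 2)

B, S_Q, S_K, num_heads, dim_key, dim_value = 2, 16, 24, 4, 8, 12
q = torch.randn(B, S_Q, num_heads * dim_key)
k = torch.randn(B, S_K, num_heads * dim_key)
v = torch.randn(B, S_K, num_heads * dim_value)

q = _split_heads_sketch(q, B, S_Q, num_heads, dim_key)
k = _split_heads_sketch(k, B, S_K, num_heads, dim_key)
v = _split_heads_sketch(v, B, S_K, num_heads, dim_value)  # the value dim can differ from the key dim
```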
@classmethod
def from_config(cls, config: xFormerConfig):
    return cls(config.stack_configs, config.tie_embedding_weights)

def _deepnorm_weight_init(self):
Instead of an ad-hoc init here, only for deepnorm, create a dedicated weight_init part (covering a couple of popular options) and move everything there.
)


class xFormerWeightInit(str, Enum):
This file should handle all the supported init options, making them easy to refer to and easy to compare. It's always possible for users to do this on their own from the outside, of course.
Great to have this!
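As a rough illustration of the idea, here is a sketch of an enum-driven init dispatch; the real xformers/factory/weight_init.py may differ, and the scheme names below are placeholders:

```python
from enum import Enum

import torch.nn as nn

class WeightInitSketch(str, Enum):
    # placeholder scheme names, for illustration only
    ViT = "vit"
    Small = "small"

def apply_weight_init(model: nn.Module, scheme: WeightInitSketch) -> None:
    # walk the model once and initialize every Linear according to the chosen scheme
    for module in model.modules():
        if isinstance(module, nn.Linear):
            if scheme == WeightInitSketch.ViT:
                nn.init.trunc_normal_(module.weight, std=0.02)
            else:
                nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
```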
elif hasattr(module, "init_weights"):
    module.init_weights()  # type: ignore
else:
    _maybe_report_no_init(module, name)
Catch cases which look really suspect; normally all the weights should be covered/touched above.
Force-pushed from 38e2ccf to 562a984
Force-pushed from 5baebeb to f3a7857
_ = decoder_block(
    inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask
)

# Test different sequence lengths when encoding and decoding
if not decoder_block.mha.attention.requires_same_k_q_dimensions:
This was getting a little obscure, so the encoder and decoder blocks now expose these attention flags directly; I figured it was easier to read and made sense.
config = test_configs_dict

# Make sure that all the init methods catch all the weights
xformers_weight_init._assert_if_not_initialized = True
This is a little stringent but probably a good thing: any tensor not handled by the initialization schemes will assert out, to catch a possible mistake or typo.
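A sketch of how such a strict guard could work; the flag name mirrors the test diff, but `_maybe_report_no_init` below is an illustration rather than the library implementation:

```python
import torch.nn as nn

_assert_if_not_initialized = False  # tests can flip this to be strict

def _maybe_report_no_init(module: nn.Module, name: str) -> None:
    # only flag modules that directly own parameters and were not handled by any init branch
    if len(list(module.parameters(recurse=False))) > 0:
        message = f"Module {name} was not covered by the weight init scheme"
        if _assert_if_not_initialized:
            raise AssertionError(message)
        print(message)
```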
@@ -172,7 +169,6 @@ def test_pytorch_tranformer_parity(device=torch.device("cuda")):
    dim_feedforward=4 * EMB,
    dropout=DROP,
    activation=ACTIVATION,
    layer_norm_eps=1e-06,
The default is actually 1e-5 for both xformers/triton and pytorch.
@@ -35,7 +35,6 @@
"attention": {
    "name": "scaled_dot_product",
    "dropout": DROP,
    "causal": False,
this is the default, and less is more
dim_key: Optional[int] = None,
dim_value: Optional[int] = None,
in_proj_container: Optional[InProjContainer] = None,
use_separate_proj_weight: Optional[bool] = False,
This was a nasty default, I believe: we would use the same weights for the q/k/v projections, which for self-attention is probably not a good idea.
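A tiny illustration of the difference (the names below are not the library API): with shared projection weights, q and k are identical projections of the same input, which is rarely what you want for self-attention.

```python
import torch
import torch.nn as nn

E = 64
x = torch.randn(2, 16, E)

# Shared projection weights: q and k carry literally the same values
shared = nn.Linear(E, E)
q_shared, k_shared = shared(x), shared(x)
assert torch.equal(q_shared, k_shared)

# Separate projection weights: q and k differ, as expected for self-attention
proj_q, proj_k = nn.Linear(E, E), nn.Linear(E, E)
q_sep, k_sep = proj_q(x), proj_k(x)
assert not torch.equal(q_sep, k_sep)
```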
self.dim_model = dim_model
self.attention = attention

# key, query, value projections for all heads
# critical options are
# - are we sharing weights ?
# - are we adding biases, and if yes are they shared ?
# - are we adding biases ?
I'm not sure how many nuances to support here, for instance one bias setting per input? It could be a little overkill, but easy to add if need be.
# Self-attend
y = self.attention(q, k, v, **kw_mask_args)

# Re-assemble all head outputs side by side
y = (
    y.view(B, self.num_heads, S_Q, self.dim_k)
The embedding size is linked to the value dimension here; it only worked because we did not test this case.
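For instance, a sketch of the re-assembly with the value dimension made explicit, assuming the attention output has shape (B * num_heads, S_Q, dim_value); the names are illustrative:

```python
import torch

B, num_heads, S_Q, dim_value = 2, 4, 16, 12
y = torch.randn(B * num_heads, S_Q, dim_value)

y = (
    y.view(B, num_heads, S_Q, dim_value)  # dim_value, not dim_k: they can differ
    .transpose(1, 2)                      # (B, S_Q, num_heads, dim_value)
    .flatten(start_dim=2)                 # (B, S_Q, num_heads * dim_value)
)
assert y.shape == (B, S_Q, num_heads * dim_value)
```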
@@ -130,12 +130,19 @@ def __init__(self, config: xFormerEncoderConfig, **kwargs):
        residual_scale=residual_scale,
    )

    self.mha = build_multi_head_attention(config.multi_head_config)
    self.feedforward = build_feedforward(asdict(config.feedforward_config))
    mha = build_multi_head_attention(config.multi_head_config)
We don't have to keep mha and feedforward as direct members; they were showing up in the model graph and when initializing, making it a little hard to see what was actually in the xformer model.
I really like this minGPT use case by the way, I get to read a new dialog every time. This is just after one epoch so it does not make a ton of sense, but there's some structure to it already.
Codecov Report
@@            Coverage Diff             @@
##             main     #312      +/-   ##
==========================================
- Coverage   93.04%   92.51%   -0.53%
==========================================
  Files          66       67       +1
  Lines        3564     3634      +70
==========================================
+ Hits         3316     3362      +46
- Misses        248      272      +24
Really nice changes!
This is completely optional, and will only occur when generating full models through xFormers, not when picking parts individually.
There are basically two initialization mechanisms exposed, but the user is free to initialize weights as he/she sees fit after the fact.
- Parts can expose an `init_weights()` method, which defines sane defaults
- xFormers supports [specific init schemes](xformers/factory/weight_init.py) which *can take precedence* over `init_weights()`
Not sure why but the markdown formatting isn't working when I view it, just for this paragraph
I've checked on https://github.com/facebookresearch/xformers/tree/xformers_init and it's properly rendered, maybe a PR thing..
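As a small sketch of the first mechanism described above (the class below is an illustration, not an actual xFormers part): a component can expose its own init_weights() with sane defaults, and a factory-level scheme may later take precedence.

```python
import torch.nn as nn

class TinyFeedforward(nn.Module):
    """Illustrative part exposing the init_weights() hook."""

    def __init__(self, dim: int):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def init_weights(self) -> None:
        # sane per-part default; a factory-level init scheme can take precedence
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

block = TinyFeedforward(64)
block.init_weights()
```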
if not decoder_block.mha.attention.requires_same_k_q_dimensions:
if not causal or not hasattr(decoder_block.mha.attention, "causal"):
if not decoder_block.requires_same_k_q_dimensions:
if not causal or not decoder_block.causal_attention:
Nice! Helps readability
@@ -12,7 +12,7 @@

from .activations import Activation, build_activation  # noqa
from .attention import Attention, build_attention  # noqa
from .in_proj_container import InProjContainer, InProjParams  # noqa
from .input_projection import InputProjection, InputProjectionConfig  # noqa
Really like this change too
Force-pushed from 07d335d to 249e113
What does this PR do?
Loosely linked to #219 and #264:
- rewrite the input projection to make it a little cleaner and easier to follow
- add a dedicated weight_init file (purely optional, opt-in basis)

Many TODOs left, comments welcome.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.