causal_mask of the decoder #16
Hi @burcehan! Yes, this is an issue. To be honest, I'm not sure how to get around it. There was a discussion about this in Issue #11, and what was decided there was that we would include both types of masking:
```python
import torch

def gen_causal_mask(input_size, dim_k, full_attention=False):
    """
    Generates a causal mask of size (input_size, dim_k) for the Linformer.
    If full_attention is True, it generates an (input_size, input_size) mask
    for full attention instead.
    """
    if full_attention:
        return (torch.triu(torch.ones(input_size, input_size)) == 1).transpose(0, 1)
    return (torch.triu(torch.ones(dim_k, input_size)) == 1).transpose(0, 1)
```
```python
# encoder and decoder are Linformer encoder/decoder instances built beforehand
x = torch.randint(1, 10000, (1, 512))
y = torch.randint(1, 10000, (1, 512))
x_mask = torch.ones_like(x).bool()
y_mask = torch.ones_like(y).bool()

enc_output = encoder(x, input_mask=x_mask)
print(enc_output.shape)  # (1, 512, 128)

dec_output = decoder(y, embeddings=enc_output, input_mask=y_mask, embeddings_mask=x_mask)
print(dec_output.shape)  # (1, 512, 10000)
```

Testing can be seen in #13. Unfortunately, I have been quite busy lately, so I haven't had the chance to actually finish testing (as in, making sure this performs on par with other transformers). I'll get back to it, hopefully sometime in October.

Also, a side note on the Linformer in general: I think all interpretability of tokens stops being relevant once a downsampling technique is used. As stated in issue #15, the tokens can become dispersed throughout the whole head if a linear layer is used. With convolution, you at least get some sense of locality and grouping, but how to mask the Linformer is still an open research problem, precisely because tokens get mixed together when downsampling. (I personally think technique 2 is more interpretable, as it masks right away, so there is no possibility of information flowing forward.)

I hope this answered your question!
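For reference, a minimal sketch (assuming the gen_causal_mask helper above) of the two mask shapes it produces:

```python
# (512, 512) full-attention causal mask vs. the (512, 128) low-rank variant
mask_full = gen_causal_mask(input_size=512, dim_k=128, full_attention=True)
mask_lowrank = gen_causal_mask(input_size=512, dim_k=128)
print(mask_full.shape, mask_lowrank.shape)  # torch.Size([512, 512]) torch.Size([512, 128])
```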
A quick note that even if you do an "n by k" mask on the low-rank attention, future information will still bleed into the past and is therefore not "causal".
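To make that bleed-through concrete, here is a small standalone sketch (not the library's actual code; E, Q, K, V and the single-head shapes are simplified stand-ins) showing that an (n, k) mask on the projected scores does not stop a change at a future position from affecting an earlier output, because each projected key/value is a mixture of all n positions:

```python
import torch

torch.manual_seed(0)
n, k, d = 8, 4, 16                       # sequence length, projected length, head dim

E = torch.randn(n, k)                    # Linformer-style projection over the sequence dim
Q = torch.randn(n, d)

def lowrank_causal_output(K, V):
    K_proj = E.transpose(0, 1) @ K       # (k, d): each row mixes all n keys
    V_proj = E.transpose(0, 1) @ V       # (k, d): each row mixes all n values
    scores = Q @ K_proj.transpose(0, 1)  # (n, k)
    # the (n, k) "causal" mask, built the same way as gen_causal_mask above
    mask = (torch.triu(torch.ones(k, n)) == 1).transpose(0, 1)
    scores = scores.masked_fill(~mask, float('-inf'))
    attn = torch.softmax(scores / d ** 0.5, dim=-1)
    return attn @ V_proj                 # (n, d)

K, V = torch.randn(n, d), torch.randn(n, d)
out_before = lowrank_causal_output(K, V)

# perturb only the last token's key/value -- a "future" position for row 0
K2, V2 = K.clone(), V.clone()
K2[-1] += 1.0
V2[-1] += 1.0
out_after = lowrank_causal_output(K2, V2)

# if masking were truly causal, row 0 could not see position n-1
print((out_before[0] - out_after[0]).abs().max())  # nonzero -> future info leaks
```

Under a truly causal mask this difference would be exactly zero; here it is not, because the downsampling matrix E mixes future tokens into every one of the k projected positions before the mask is ever applied.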
Yes, I know. Unfortunately, the Linformer was never meant to be used for causal decoding, but rather for BERT-style attention. I heard that the authors may release another version of the Linformer that deals with causal masking, but I have my own ideas about how this mask should be applied. As for the implementation... I may have time one of these weekends to try something out and publish it.
That is, I have my own idea about how to make this causal, but unfortunately not the computing resources to test it out and actually write a paper.
Hi,
You've done a great job, and thanks for sharing.
I don't understand the causal_mask of the decoder: the shape of the attention matrix is (n, k), and only the (k, k) part is masked. Does it work? Are there any test results on a language model?
Thanks for your time!