causal_mask of the decoder #16

Closed
burcehan opened this issue Sep 10, 2020 · 4 comments

@burcehan
Hi,
You've done a great job, and thanks for sharing.
I don't understand the causal_mask of the decoder: the attention matrix has shape (n, k), but only the (k, k) part of it is masked. Does that actually work? Are there any test results on a language model?
Thanks for your time!

@tatp22
Owner

tatp22 commented Sep 10, 2020

Hi @burcehan!

Yes, this is an issue. To be honest, I don't know how I would get around it. There was a discussion about this in issue #11, and what was decided there was to include both types of masking:

  1. The upper-right triangle is masked (this is exactly what you are referring to). This is done with the function gen_causal_mask:
import torch

def gen_causal_mask(input_size, dim_k, full_attention=False):
    """
    Generates a causal mask of size (input_size, dim_k) for the Linformer;
    for full attention, it generates an (input_size, input_size) mask instead.
    """
    if full_attention:
        # standard lower-triangular causal mask, shape (input_size, input_size)
        return (torch.triu(torch.ones(input_size, input_size)) == 1).transpose(0, 1)
    # low-rank case: lower-triangular-style mask of shape (input_size, dim_k)
    return (torch.triu(torch.ones(dim_k, input_size)) == 1).transpose(0, 1)
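For a quick illustration of what @burcehan is pointing at, here is a hypothetical check with small made-up sizes: every row of the (n, k) mask from index k-1 onward is entirely True, so only the top-left (k, k) block has any masking structure.

# Hypothetical check with toy sizes n=8, k=4: rows at index >= k-1 of the
# (n, k) mask are all True, i.e. only the top-left (k, k) block masks anything.
mask = gen_causal_mask(input_size=8, dim_k=4)
print(mask.shape)   # torch.Size([8, 4])
print(mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1],
#         ...            <- every remaining row is all ones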
  2. You can mask the inputs to the decoder yourself by supplying your own masks:
import torch
# encoder and decoder are the Linformer encoder/decoder modules from this repo
x = torch.randint(1, 10000, (1, 512))
y = torch.randint(1, 10000, (1, 512))

x_mask = torch.ones_like(x).bool()
y_mask = torch.ones_like(y).bool()

enc_output = encoder(x, input_mask=x_mask)
print(enc_output.shape)  # (1, 512, 128)
dec_output = decoder(y, embeddings=enc_output, input_mask=y_mask, embeddings_mask=x_mask)
print(dec_output.shape)  # (1, 512, 10000)
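As a hedged illustration of option 2 (an untested sketch, reusing y, x_mask, enc_output, and decoder from the snippet above): during autoregressive decoding you would rebuild the target mask at every step, so that only already-generated tokens are visible to the decoder.

# Untested sketch of step-wise decoding with option 2. At step t, only
# positions 0..t-1 of the target are unmasked, so no future token is visible.
for t in range(1, y.size(1)):
    y_mask = torch.zeros_like(y).bool()
    y_mask[:, :t] = True
    dec_output = decoder(y, embeddings=enc_output, input_mask=y_mask,
                         embeddings_mask=x_mask)
    y[:, t] = dec_output[:, t - 1].argmax(dim=-1)  # e.g. greedy decoding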

Testing can be seen in #13. Unfortunately, I have been quite busy lately, so I haven't had the chance to actually finish testing (that is, to make sure this performs on par with other transformers). I'll get back to it, hopefully sometime in October.

Also, just a side note on the Linformer in general: I think all token-level interpretability stops being relevant once a downsampling technique is used. As stated in issue #15, tokens can become dispersed throughout the whole head if a linear layer is used for the downsampling. With convolution you at least keep some sense of locality and grouping, but how to mask the Linformer is still an open research problem, precisely because tokens get mixed together when downsampling. (I personally think technique 2 is more interpretable, since it masks the inputs right away, so there is no possibility of information flowing forward.)
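To make the dispersion point concrete, here is a hypothetical toy comparison (made-up sizes, standard PyTorch layers rather than the repo's implementation): push a sequence with a single nonzero token through a linear downsampling versus a strided convolution, and see which compressed positions it reaches.

import torch
import torch.nn as nn

# Toy comparison with made-up sizes: n=8 tokens downsampled to k=4 positions.
n, k, d = 8, 4, 1
seq = torch.zeros(1, n, d)
seq[0, 5] = 1.0                           # a single nonzero token at position 5

linear = nn.Linear(n, k, bias=False)      # mixes along the sequence dimension
lin_out = linear(seq.transpose(1, 2))     # (1, d, k)
print((lin_out.abs() > 1e-6).squeeze())   # typically all True: token 5 leaks everywhere

conv = nn.Conv1d(d, d, kernel_size=2, stride=2, bias=False)
conv_out = conv(seq.transpose(1, 2))      # (1, d, k)
print((conv_out.abs() > 1e-6).squeeze())  # only position 2 (covering tokens 4-5): locality kept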

I hope this answered your question!

tatp22 closed this as completed Sep 30, 2020
@vanzytay

A quick note that even if you do an "n by k" mask on the low-rank attention, future information will still bleed into the past, so the result is not truly "causal".
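A hedged sketch of why this happens (toy sizes, plain tensor ops rather than the repo's modules): the projected keys and values K' = E·K and V' = E·V mix every position, so the gradient from the last input reaches the first output even under the (n, k) mask.

import torch

# Toy demonstration (made-up sizes, not the repo's code): even with the
# (n, k) mask applied, the output at position 0 still depends on the input
# at position n-1, because E mixes all positions into the compressed keys/values.
n, k, d = 8, 4, 16
E = torch.randn(k, n)                       # downsampling projection
x = torch.randn(n, d, requires_grad=True)   # use x as queries, keys, and values

K_proj = E @ x                              # (k, d): each row mixes all n positions
V_proj = E @ x
scores = (x @ K_proj.T) / d ** 0.5          # (n, k) attention scores
mask = (torch.triu(torch.ones(k, n)) == 1).transpose(0, 1)  # the (n, k) "causal" mask
scores = scores.masked_fill(~mask, float('-inf'))
out = torch.softmax(scores, dim=-1) @ V_proj

out[0].sum().backward()
print(x.grad[-1].abs().sum() > 0)           # tensor(True): the future bled into the past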

@tatp22
Owner

tatp22 commented Nov 25, 2020

Yes, I know. Unfortunately, the Linformer was never meant to be used for causal decoding, but rather for BERT-style attention. I've heard that the authors may release another version of the Linformer that deals with causal masking, but I also have my own ideas about how this mask should be applied. As for an implementation, I may have time one of these weekends to try something out and publish it.

@tatp22
Owner

tatp22 commented Nov 25, 2020

That is, I have my own idea about how to make this causal, but unfortunately not the computing resources to test it out and actually write a paper.
