Assertion statements in attention implementation #264

Open
dnnspark opened this issue Apr 8, 2022 · 9 comments

@dnnspark

dnnspark commented Apr 8, 2022

❓ Questions and Help

I'm trying to implement Perceiver using xformers, and stumbled upon two assertion statements.

  1. The first one is this one: Doesn't this have to be `t.shape[2] % self.dim_head == 0`, to be consistent with the error message one line below?

  2. The second one is this one: why does the query projection have to preserve the dimension? I'm trying to implement a cross-attention scenario where query and key come from different sources (so they have different dimensions), and the linear projections make sure they end up with the same dimension (i.e. N x D_{query} -> N x d for query and M x D_{key} -> M x d for key); see the sketch below this list. However, the assertion above enforces that the query projection preserves its dimension. What's the point of this assertion (btw, this assertion does not exist in the original torchtext implementation)? Or, what's the right way to implement this idea?
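For concreteness, here is a sketch of the projection scheme I have in mind (plain PyTorch with made-up sizes, not the xformers code):

```python
import torch
import torch.nn as nn

# Query and key/value come from different sources with different channel counts;
# separate linear projections map both onto a common dimension d before attention.
D_query, D_key, d = 512, 256, 128
q_proj = nn.Linear(D_query, d)   # N x D_query -> N x d
k_proj = nn.Linear(D_key, d)     # M x D_key   -> M x d
v_proj = nn.Linear(D_key, d)     # M x D_key   -> M x d

query = torch.randn(4, 24, D_query)   # N = 24 query tokens
key = torch.randn(4, 36, D_key)       # M = 36 key/value tokens

q, k, v = q_proj(query), k_proj(key), v_proj(key)
print(q.shape, k.shape, v.shape)  # (4, 24, 128) (4, 36, 128) (4, 36, 128)
# The assertion in question rejects q_proj here, because D_query != d.
```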

dnnspark changed the title from "Assertion on linear projection of queries" to "Assertion statements in attention implementation" on Apr 11, 2022
@blefaudeux
Contributor

hi @dnnspark, thanks for your message! Replying to 1. first: can you walk me through the problem? I may be a little tired, but I don't see the issue right now. `dim_k` is defined as `dim_key // num_heads` (ok, the choice of letters is probably not great), so it looks like we're talking about the same thing.
Looking into 2. :)

@blefaudeux
Contributor

ok, 2. now:

  • it's not a part of the repo that I personally like a lot (I wrote a big part of it, so I'm probably allowed to say that); it's too complicated and confusing for what it does. The gist of it was to enable a small self-attention optimization (projecting with one buffer saves a lot of reads, that's here; see the sketch after this list) and to handle different init options, which are not always the same in NLP or ViT for instance
  • it looks like this assertion was just there to remove one variable; it's not fundamentally required as far as I can see. I can propose a PR to remove it; actually, I had a draft PR to rewrite this part, and it could be part of it
  • if you want absolute freedom (probably the best option if you have an exotic projection scheme), you can swap this block for another one that is just right for you; it's used as-is if it's part of the config
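To make the "project with one buffer" part of the first bullet concrete, the general trick looks roughly like this (a plain PyTorch sketch, not the actual xformers code):

```python
import torch
import torch.nn as nn

# Self-attention case: q, k and v are all projected from the same tensor, so the
# three projection weights can live in a single buffer and be applied with one
# matmul, which saves memory reads compared to three separate projections.
d = 384
in_proj = nn.Linear(d, 3 * d)            # one buffer instead of three
x = torch.randn(2, 128, d)               # (batch, sequence, model dim)

q, k, v = in_proj(x).chunk(3, dim=-1)    # each is (2, 128, 384)
```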

let me know if this helps, I can definitely follow up on this assert

blefaudeux self-assigned this on Apr 13, 2022
@dnnspark
Author

dnnspark commented Apr 13, 2022

Hi @blefaudeux, thanks for checking!

For 1, because the inputs of forward() are query, key and value with no constraints on their shape (e.g. assertion or docstring), I was assuming it works in general cross-attention scenarios. Let's say the shapes of query, key, value are (4, 24, 300), (4, 36, 300), (4, 36, 200) respectively (batch_size=4) -- they are projected from source and target data outside of this function -- and the number of heads is 20. It's a valid input because all channels (i.e. 300, 300, 200) are divisible by the number of heads: each chunk of the query and key, of shape (4, 24, 15) and (4, 36, 15), is used to compute an attention matrix of shape (4, 24, 36), which is multiplied with the corresponding chunk of the value input of shape (4, 36, 10) to fill the corresponding part of the output of shape (4, 24, 10). There are 20 of these outputs (num_heads=20), and they are concatenated to make the final output (4, 24, 200), the same shape as the query input (and it's added residually to it outside of this function).

In this case, `dim_head=20` and `dim_k=15` (300 / 20), but `dim_value` (200) is not divisible by `dim_k`.
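In code, the shape bookkeeping I mean is roughly the following (random tensors, just to make the shapes concrete):

```python
import torch

# Shapes from the example above: batch=4, 24 query tokens, 36 key/value tokens,
# query/key channels 300, value channels 200, num_heads=20.
B, Lq, Lk, Dqk, Dv, H = 4, 24, 36, 300, 200, 20
q = torch.randn(B, Lq, Dqk)   # (4, 24, 300)
k = torch.randn(B, Lk, Dqk)   # (4, 36, 300)
v = torch.randn(B, Lk, Dv)    # (4, 36, 200)

# Per-head chunks: 300 / 20 = 15 channels for q and k, 200 / 20 = 10 for v.
qh = q.reshape(B, Lq, H, Dqk // H).transpose(1, 2)   # (4, 20, 24, 15)
kh = k.reshape(B, Lk, H, Dqk // H).transpose(1, 2)   # (4, 20, 36, 15)
vh = v.reshape(B, Lk, H, Dv // H).transpose(1, 2)    # (4, 20, 36, 10)

attn = torch.softmax(qh @ kh.transpose(-2, -1) / (Dqk // H) ** 0.5, dim=-1)  # (4, 20, 24, 36)
out = (attn @ vh).transpose(1, 2).reshape(B, Lq, Dv)  # heads concatenated -> (4, 24, 200)
print(out.shape)
```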

For 2, I see; that makes a lot of sense for self-attention. And agreed, it may need some adjustments to be used in a more general cross-attention use case.

Thanks!

@blefaudeux
Contributor

> For 1, because the inputs of forward() are query, key and value with no constraints on their shape, I was assuming it works in general cross-attention scenarios. [...] In this case, dim_head=20 and dim_k=15 (300 / 20), but dim_value (200) is not divisible by dim_k.

aah yes, I see your point now: it implicitly assumes the same dimension everywhere, that's bad. It can be fixed; I'm trying to get out of a CI quagmire and will submit a PR, or feel free to do that if you fancy it

@blefaudeux
Contributor

oh, let me just volley a PR right now for 1. and this will be fixed. one sec

@dnnspark
Author

dnnspark commented Apr 13, 2022

Hey @blefaudeux, on second thought, I think you're right about this dimension issue. In my example above, there's a flaw:

> There are 20 of these outputs (num_heads=20), and they are concatenated to make the final output (4, 24, 200), the same shape as the query input (and it's added residually to it outside of this function).

It's actually not the same shape as the query input, which is (4, 24, 300). So I think the dimensions of all inputs (query, key, value) always have to be the same. In that case, the first assertion is actually correct, even though the name is a bit confusing (which makes your PR still legit).

@blefaudeux
Contributor

> Hey @blefaudeux, on second thought, I think you're right about this dimension issue. [...] So I think the dimensions of all inputs (query, key, value) always have to be the same. In that case, the first assertion is actually correct, even though the name is a bit confusing (which makes your PR still legit).

an afterthought on my side is that this assert is not at the right place anyway, unless the projection preserves dimensions (I thought that was partly your point in your explanation actually). We check the dimensions pre-projection, then project, then split heads (which is where the dimension misfit would be visible), but one could imagine an initial misfit which is "fixed" by differentiated projections (not saying that this would be a good thing to do, but it would work I believe). I'll try to fix that in the PR
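Roughly the ordering I have in mind (a sketch with made-up sizes, not the exact code that will land in the PR):

```python
import torch
import torch.nn as nn

# Differentiated projections can "fix" an initial dimension misfit, so the
# equality / divisibility checks only make sense after projection, right before
# the head split. All sizes below are made up for illustration.
num_heads, d_model = 8, 256
q_proj = nn.Linear(300, d_model)
k_proj = nn.Linear(200, d_model)
v_proj = nn.Linear(200, d_model)

query = torch.randn(4, 24, 300)
key = torch.randn(4, 36, 200)
value = torch.randn(4, 36, 200)

q, k, v = q_proj(query), k_proj(key), v_proj(value)

# Post-projection checks: the raw inputs did not match, but the projected tensors do.
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
assert q.shape[-1] % num_heads == 0
```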

blefaudeux added a commit that referenced this issue Apr 14, 2022
* Fixing #264, thanks @dnnspark
* changelog addendum
* moving the dimension check to post projection
@blefaudeux
Contributor

it turns out that some of the checks were not correct (undue constraints); fixed with the attached PR

@blefaudeux
Contributor

@dnnspark I think this is fixed with the PR which landed yesterday?
