About W2A16 weight only matmul #10

Open
goddice opened this issue Sep 26, 2024 · 2 comments

goddice commented Sep 26, 2024

Hi, if I have a linear layer whose weights only take values in {0, 1, -1}, is it possible to use your kernel for weight compression and inference speed-up? My current weights are stored in bfloat16 format.

For example, if I have this code:

# 64 x 1024 bfloat16 activations
input = torch.randn(64, 1024, dtype=torch.bfloat16).cuda()
# ternary weights in {-1, 0, 1}, generated as int8 and then cast to bfloat16
weights = torch.randint(-1, 2, (1024, 1024), dtype=torch.int8)
weights_bf16 = weights.bfloat16().cuda()
# plain bf16 matmul that I would like to accelerate
output = torch.nn.functional.linear(input, weights_bf16, None)

How could I use ABQ's kernel to optimize this computation? Thanks!

lswzjuer (Collaborator) commented

Thanks for your interest in our work. Matrix multiplication between int weights and float activations is not supported. However, in our experience with model optimization, int16 activations and float16 activations give essentially the same results (for both SD and LLM models).

So I suggest you try W2Aint16 (2-bit weights with int16 activations). In that case you can use our operator directly for acceleration; it is designed for W2 scenarios.
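
For illustration only, here is a minimal sketch of what the suggested W2Aint16 setup implies: ternary {-1, 0, 1} weights packed into 2-bit codes and activations quantized to int16. The helpers pack_ternary_to_2bit and quantize_act_int16 below are hypothetical and are not part of the ABQ-LLM API; the fused low-bit GEMM itself would be carried out by ABQ's CUDA kernel on these packed/quantized tensors.

import torch

def pack_ternary_to_2bit(w_bf16: torch.Tensor) -> torch.Tensor:
    # Map {-1, 0, 1} weights to the 2-bit codes {0, 1, 2} and pack four
    # codes per uint8 byte (assumes the number of elements is divisible by 4).
    codes = (w_bf16.to(torch.int8) + 1).to(torch.uint8).reshape(-1, 4)
    return (codes[:, 0]
            | (codes[:, 1] << 2)
            | (codes[:, 2] << 4)
            | (codes[:, 3] << 6))          # 8x smaller than the bf16 weights

def quantize_act_int16(x_bf16: torch.Tensor):
    # Symmetric per-tensor quantization of the activations to int16.
    scale = x_bf16.abs().max().float() / 32767.0
    x_q = torch.clamp((x_bf16.float() / scale).round(), -32768, 32767).to(torch.int16)
    return x_q, scale

inputs = torch.randn(64, 1024, dtype=torch.bfloat16).cuda()
weights_bf16 = torch.randint(-1, 2, (1024, 1024), dtype=torch.int8).bfloat16().cuda()

packed_w = pack_ternary_to_2bit(weights_bf16)   # what a W2 kernel would store
x_q, x_scale = quantize_act_int16(inputs)       # what an Aint16 kernel would consume
# A fused W2Aint16 GEMM (such as ABQ's CUDA operator) would then take packed_w,
# x_q and the scales and produce the same result as the bf16 linear above.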


goddice commented Sep 27, 2024

Thanks for the reply.
So if all of the linear layers in my model have bfloat16 weights with values in {0, 1, -1}, and the inputs and outputs are bfloat16, what steps should I follow to use your W2Aint16 operator? Thanks!
