FEAT Adding experimental feature: Triton mm int8xint2 #195
Conversation
awesome work!
```python
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
```
add a paper reference and more comments. It is a bit hard to understand the current code
Yes, I will add some comments to explain the process when I have some time
Claude / ChatGPT can be helpful
```python
    return c


def test_kernel(size=2048):
```
move this to tests/
```python
    return packed


def get_cuda_autotune_config():
```
should not be cuda?
Sorry, I didn't understand the comment
I mean Triton and CUDA are different things. Maybe replace cuda with triton
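For illustration, a minimal sketch of the suggested rename, assuming the function simply returns a list of triton.Config candidates for @triton.autotune (the block sizes and stage/warp counts below are made up, not the PR's actual values):

```python
import triton

# Hypothetical backend-neutral rename of get_cuda_autotune_config():
# the configs are Triton autotune candidates, not CUDA-specific settings.
def get_autotune_config():
    return [
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
            num_stages=3, num_warps=8,
        ),
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4,
        ),
    ]
```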
Good to merge if we can add more comments / refs to explain.
```python
    triton_output = matmul(ht.view(B * M, N), u.T.contiguous()).view(B, M, -1)

    # Validate packing and unpacking of weights
    assert (pack_weights(unpack_weights(u.T), 2) == u.T).all(), "Packed weights do not match original weights."
```
IMO we can separate the correctness of pack + unpack into another testing func
Okay, I will
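A sketch of what the split-out test could look like. The pack_weights(w, bits) and unpack_weights(p) signatures follow the diff above; the shapes and the unpacked starting point are illustrative:

```python
import torch

def test_pack_unpack_roundtrip(K=256, N=256):
    # Ternary weights in {-1, 0, 1}, stored as int8 before packing.
    w = torch.randint(-1, 2, (K, N), dtype=torch.int8)
    packed = pack_weights(w, 2)  # four 2-bit values per int8
    # Round-trip in both directions, mirroring the assert in the diff above.
    assert (unpack_weights(packed) == w).all(), "unpack(pack(w)) != w"
    assert (pack_weights(unpack_weights(packed), 2) == packed).all(), "pack(unpack(p)) != p"
```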
```python
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Only triggered when TRITON_DEBUG is set to 1, e.g. TRITON_DEBUG=1 python script.py
```
what is this?
It's a device_assert; it only runs when TRITON_DEBUG is not set to 0, and it ensures that K is a multiple of BLOCK_SIZE_K * 4, which is the case for weight matrices, for alignment purposes. In the future we can find a way to make it more general
Does that mean it doesn't error out even if the alignment is incorrect, when TRITON_DEBUG is not enabled? Wondering if we can use https://triton-lang.org/main/python-api/generated/triton.language.static_assert.html#triton.language.static_assert
Or can we just lift the assertion to before the kernel launch (line 158)?
Now it's working with static_assert; I just had to specify that K is tl.constexpr
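For context, a minimal sketch of the compile-time check this resolution describes, assuming K and BLOCK_SIZE_K are declared tl.constexpr as in the fix (the rest of the kernel signature is omitted):

```python
import triton
import triton.language as tl

@triton.jit
def matmul_kernel_header_sketch(
    K: tl.constexpr,             # constexpr makes K visible at compile time
    BLOCK_SIZE_K: tl.constexpr,
):
    # Fails at kernel compile time regardless of TRITON_DEBUG, unlike
    # tl.device_assert, which only fires at runtime when debugging is on.
    tl.static_assert(K % (BLOCK_SIZE_K * 4) == 0,
                     "K must be a multiple of BLOCK_SIZE_K * 4")
```

The alternative raised above, a plain Python assert before the kernel launch, would catch the same misalignment without involving the compiler at all.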
Please also update the doc: https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#experimental-kernels
For my own learning purposes, any specific reason we chose 2-bit here? I think 1-bit has some CUDA support for bit counting or matching operations, so it could be faster if implemented in the context of CUDA, and 4-bit is safer when doing weight quantization. Is 2-bit a trade-off of accuracy and speed compared to both 4-bit and 1-bit?
@qingquansong Yeah, according to the 1.58-bit LLM paper, using only -1 and 1 can actually deteriorate performance. Adding a 0 element to select important features (or not) seems like a much better approach. Plus, using -1, 1, and 0 helps in the context of matmul-free LLMs, where, with the right hardware support, it can significantly boost inference speed and reduce energy consumption, because add operations consume significantly less energy and time than mul operations.
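To make the 2-bit choice concrete, here is one illustrative packing scheme (a sketch; not necessarily what this PR's pack_weights does): map each ternary value to an unsigned 2-bit code via v + 1, then store four codes per int8 with shifts.

```python
import torch

def pack_ternary(w: torch.Tensor) -> torch.Tensor:
    """Pack four ternary values ({-1, 0, 1}) per int8 along dim 0 (illustrative)."""
    assert w.shape[0] % 4 == 0
    codes = (w + 1).to(torch.uint8)                      # {-1, 0, 1} -> {0, 1, 2}
    codes = codes.reshape(4, w.shape[0] // 4, *w.shape[1:])
    packed = torch.zeros_like(codes[0])
    for i in range(4):
        packed |= codes[i] << (2 * i)                    # 2 bits per value
    return packed.to(torch.int8)

def unpack_ternary(p: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_ternary: recover four ternary int8 values per byte."""
    p = p.to(torch.uint8)
    chunks = [((p >> (2 * i)) & 0b11).to(torch.int8) - 1 for i in range(4)]
    return torch.cat(chunks, dim=0)
```

If the packing runs along the K dimension like this, the kernel's alignment requirement (K being a multiple of BLOCK_SIZE_K * 4) falls out naturally, since every packed int8 covers four consecutive K positions.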
LGTM
cc @MekkCyber please fix the style & tests and we are good to merge.
@MekkCyber CI still failing
Weird, all tests pass locally, will look into that
@MekkCyber can you follow up on this? We can merge it for the next release
Summary
Introducing int8 x int2 matrix multiplication in Triton as an experimental feature. The approach performs the matmul with on-the-fly unpacking of the packed weights, using cached tiling techniques. Currently it leverages tl.dot with int8 values, which is the most optimized method available at this time; with future hardware advancements, however, this could become significantly more efficient, particularly with ternary weights, potentially eliminating the need for multiplication altogether.
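As an illustration of the on-the-fly unpacking idea (a sketch under the same v + 1 code mapping assumed earlier, not this PR's exact kernel): inside the K loop, each tile of packed weights can be expanded into four int8 tiles with shifts and masks, and each expanded tile feeds a regular tl.dot against the int8 activations.

```python
import triton
import triton.language as tl

@triton.jit
def unpack_2bit_tile(b_packed):
    # Each int8 holds four 2-bit codes; extract them with shifts and masks,
    # then undo the +1 offset to recover ternary values in {-1, 0, 1}.
    # The & 0b11 mask also discards sign bits smeared in by the
    # arithmetic right shift on int8.
    b0 = ((b_packed >> 0) & 0b11) - 1
    b1 = ((b_packed >> 2) & 0b11) - 1
    b2 = ((b_packed >> 4) & 0b11) - 1
    b3 = ((b_packed >> 6) & 0b11) - 1
    return b0, b1, b2, b3
```

One load of the packed tile then covers four steps along K, which is where the tiling and caching pay off; on hardware with native ternary support, those multiplications could in principle become additions, which is the efficiency gain the summary alludes to.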