
FEAT Adding experimental feature: Triton mm int8xint2 #195

Merged (18 commits, Oct 2, 2024)

Conversation

@MekkCyber (Contributor) commented Sep 3, 2024:

Summary

Introducing int8 x int2 matrix multiplication in Triton as an experimental feature. This approach performs the matmul with on-the-fly unpacking of the packed 2-bit weights, using cached tiling techniques. Currently, it leverages tl.dot with int8 values, which is the most optimized method available at this time. However, with future hardware advancements this could become significantly more efficient, particularly with ternary weights, potentially eliminating the need for multiplication altogether.
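
To make the packing scheme concrete, here is a minimal sketch of 2-bit ternary packing in PyTorch. The helper names (pack_ternary, unpack_ternary) and the block-interleaved bit layout are illustrative assumptions, not necessarily the PR's exact pack_weights/unpack_weights implementation:

import torch

def pack_ternary(w: torch.Tensor) -> torch.Tensor:
    # Illustrative layout: map ternary values {-1, 0, 1} to 2-bit codes {0, 1, 2},
    # then pack four codes into each int8 by splitting K into four row blocks.
    assert w.shape[0] % 4 == 0, "K must be a multiple of 4"
    codes = (w + 1).to(torch.uint8)                      # {-1,0,1} -> {0,1,2}
    blocks = codes.view(4, w.shape[0] // 4, w.shape[1])  # four contiguous K-blocks
    packed = blocks[0] | (blocks[1] << 2) | (blocks[2] << 4) | (blocks[3] << 6)
    return packed.to(torch.int8)                         # 4x smaller along K

def unpack_ternary(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_ternary: extract each 2-bit field and undo the +1 offset.
    p = packed.to(torch.uint8)
    blocks = [((p >> s) & 0b11).to(torch.int8) - 1 for s in (0, 2, 4, 6)]
    return torch.cat(blocks, dim=0)

w = torch.randint(-1, 2, (8, 4), dtype=torch.int8)  # random ternary weights
assert (unpack_ternary(pack_ternary(w)) == w).all()

Inside the kernel, the same shift-and-mask unpacking happens on the fly per tile before the int8 tl.dot.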

@ByronHsu (Collaborator) left a comment:

Awesome work!

key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
@ByronHsu (Collaborator) commented Sep 3, 2024:

Add a paper reference and more comments; it is a bit hard to understand the current code.

@MekkCyber (Contributor, Author) replied:

Yes, I will add some comments to explain the process when I have some time.

@ByronHsu (Collaborator) replied:

Claude / ChatGPT can be helpful.

return c


def test_kernel(size=2048):
@ByronHsu (Collaborator) commented:

Move this to tests/.

return packed


def get_cuda_autotune_config():
@ByronHsu (Collaborator) commented:

Should this not be named cuda?

@MekkCyber (Contributor, Author) replied:

Sorry, I didn't understand the comment.

@ByronHsu (Collaborator) replied:

I mean Triton and CUDA are different things; maybe replace cuda with triton.
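
For illustration, a minimal sketch of what the rename could look like; the specific configs below are placeholder values, not the PR's tuned ones:

import triton
import triton.language as tl

def get_triton_autotune_config():
    # Placeholder search space; the PR's actual configs may differ.
    return [
        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64,
                       "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
                      num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64,
                       "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
                      num_stages=3, num_warps=8),
    ]

@triton.autotune(configs=get_triton_autotune_config(), key=["M", "N", "K"])
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    ...  # kernel body elided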

@ByronHsu (Collaborator) commented Sep 4, 2024:

Good to merge if we can add more comments / a reference to explain.

triton_output = matmul(ht.view(B * M, N), u.T.contiguous()).view(B, M, -1)

# Validate packing and unpacking of weights
assert (pack_weights(unpack_weights(u.T), 2) == u.T).all(), "Packed weights do not match original weights."
@ByronHsu (Collaborator) commented:

IMO we can separate the correctness check of pack + unpack into another test function.

@MekkCyber (Contributor, Author) replied:

Okay, I will.
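
A hypothetical shape for that standalone test, reusing the PR's pack_weights/unpack_weights helpers with the same call pattern as the assert above (the shapes, and the assumption that the helpers are exact inverses in this direction, are illustrative):

import torch
# Assumes pack_weights/unpack_weights are importable from the PR's module.

def test_pack_unpack(K=2048, N=1024):
    w = torch.randint(-1, 2, (K, N), dtype=torch.int8)  # random ternary weights
    packed = pack_weights(w, 2)                          # pack as 2-bit values
    assert (unpack_weights(packed) == w).all(), "pack/unpack roundtrip failed"

test_pack_unpack()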

BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
# Only triggered when TRITON_DEBUG is set to 1, e.g.: TRITON_DEBUG=1 python script.py
@ByronHsu (Collaborator) commented:

What is this?

@MekkCyber (Contributor, Author) replied:

It's a device_assert; it only works when TRITON_DEBUG is not set to 0, and it ensures that K is a multiple of BLOCK_SIZE * 4 (which is the case for weight matrices) for alignment purposes. In the future we can find a way to make it more general.

@ByronHsu (Collaborator) replied:

Does that mean it doesn't error out even if the alignment is incorrect when TRITON_DEBUG is not enabled? Wondering if we can use https://triton-lang.org/main/python-api/generated/triton.language.static_assert.html#triton.language.static_assert

@ByronHsu (Collaborator) replied:

Or can we just lift the assertion to before the kernel launch (line 158)?

@MekkCyber (Contributor, Author) replied:

Now it's working with static_assert; I just had to specify that K is tl.constexpr.
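
A minimal sketch of the resulting pattern (the signature is illustrative; only the K: tl.constexpr annotation and the compile-time check matter here). Unlike tl.device_assert, tl.static_assert fires when the kernel is compiled, regardless of TRITON_DEBUG:

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N,
    K: tl.constexpr,              # constexpr, so the value is known at compile time
    BLOCK_SIZE_K: tl.constexpr,
):
    # Checked during compilation: misaligned K fails fast even without TRITON_DEBUG.
    tl.static_assert(K % (BLOCK_SIZE_K * 4) == 0,
                     "K must be a multiple of BLOCK_SIZE_K * 4")
    ...  # kernel body elided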

@qingquansong (Collaborator) left a comment:

For my own learning, is there any specific reason we chose 2-bit here? I think 1-bit has some CUDA support for bit-counting or matching operations, so it could be faster if implemented in CUDA, and 4-bit is safer when doing weight quantization. Is 2-bit a trade-off of accuracy and speed compared to both 4-bit and 1-bit?

@MekkCyber (Contributor, Author) replied:

@qingquansong Yeah, according to the 1.58-bit LLM paper, using only -1 and 1 can actually deteriorate performance. Adding a 0 element, to select whether a feature is important or not, seems like a much better approach. Plus, using -1, 1, and 0 helps in the context of matmul-free LLMs, where, with the right hardware support, it can significantly boost inference speed and reduce energy consumption, because add operations consume significantly less energy and time than mul operations.
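
To make the energy argument concrete, a small PyTorch illustration of why ternary weights need no true multiplications: every output is a sum of activations where the weight is +1, minus a sum where it is -1 (expressed with masked matmuls here for clarity; suitable hardware could realize these as pure additions):

import torch

def ternary_matvec(W: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # W: (out, in) with values in {-1, 0, 1}; x: (in,)
    pos = (W == 1).to(x.dtype) @ x   # sum of activations where the weight is +1
    neg = (W == -1).to(x.dtype) @ x  # sum of activations where the weight is -1
    return pos - neg                 # zeros drop out entirely

W = torch.randint(-1, 2, (4, 8)).float()
x = torch.randn(8)
assert torch.allclose(ternary_matvec(W, x), W @ x)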

ByronHsu previously approved these changes Sep 7, 2024

@ByronHsu (Collaborator) left a comment:

LGTM

@ByronHsu ByronHsu enabled auto-merge (squash) September 7, 2024 20:06
@ByronHsu (Collaborator) commented Sep 8, 2024:

cc @MekkCyber please fix the style and tests, and we are good to merge.

auto-merge was automatically disabled September 9, 2024 09:53

Head branch was pushed to by a user without write access

@ByronHsu (Collaborator) commented Sep 9, 2024:

@MekkCyber CI is still failing.

@MekkCyber (Contributor, Author) replied:

Weird, all tests pass locally; will look into that.

@ByronHsu ByronHsu mentioned this pull request Sep 30, 2024
@ByronHsu (Collaborator) commented:

@MekkCyber can you follow up on this? We can merge it for the next release.

@lancerts lancerts enabled auto-merge (squash) October 2, 2024 22:04
@lancerts lancerts merged commit 60640e1 into linkedin:main Oct 2, 2024
2 checks passed