Add int4b_t/uint4b_t support for mixed dtypes GEMM #1190
base: main
Conversation
This PR is intended to extend the existing support for mixed dtypes GEMM with the int4b_t/uint4b_t data types. The PR is opened for an initial review; there are a number of things left to add: […]
include/cutlass/numeric_conversion.h (Outdated)
@@ -2690,6 +2691,77 @@ struct FastNumericArrayConverter<cutlass::half_t, uint8_t, 4, Round> {
  }
};

/// Partial specialization for Array<cutlass::half_t, 8> <= Array<int4b_t, 8>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, cutlass::int4b_t, 8, Round> {
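For context, the converter's contract is to turn a packed Array<int4b_t, 8> into Array<half_t, 8>. A naive scalar sketch of that contract (illustrative only; not the optimized code this PR adds, and written as it would appear inside namespace cutlass in numeric_conversion.h) might look like:

/// Naive scalar sketch of the converter's contract (illustrative only).
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, cutlass::int4b_t, 8, Round> {
  using result_type = Array<cutlass::half_t, 8>;
  using source_type = Array<cutlass::int4b_t, 8>;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      // Each int4b_t value lies in [-8, 7], so the float -> half conversion is exact.
      int value = static_cast<int>(static_cast<cutlass::int4b_t>(source[i]));
      result[i] = cutlass::half_t(static_cast<float>(value));
    }

    return result;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const { return convert(s); }
};

An optimized version would typically extract nibbles with bit tricks and convert several elements per instruction rather than looping per element.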
Can you temporarily remove this converter, please? I have one that is slightly faster and I am hoping to push out next week along with some other changes.
No problem. But this PR needs work anyway, and it certainly won't be landed before these changes of yours, so I'm keeping it for now just so that the rest of the PR works; as soon as your changes are merged, I'll rebase and remove it.
Ahh ok, it sounds good. :) How much work do you think is needed? Is it ready for full review?
The list of what remains to be done is in the first comment above. I'd appreciate feedback on what is already there, and possibly some suggestions about item 1 from that list. Items 2 and 3 should be quick to complete afterwards.
Makes sense. I am not sure how to handle it, apart from a fresh threadblock-level mma where we skip loading some of the int4 data from smem.
@alexsamardzic, thanks for the PR. I am going to try item 1. Can you give write privileges to your branch?
@manishucsd is the plan to add a fresh threadblock-level mma to handle item 1?
// eventually for narrower ElementA.
using LoadInstructionShapeA =
    GemmShape<LoadInstructionM, InstructionShape::kN, InstructionShape::kK>;
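To illustrate the kind of adjustment discussed in the comments below (the names and values here are assumptions for the sketch, not code from this PR): for a TN layout, where K is the contiguous dimension, a load on the 4-bit operand spans proportionally more logical K elements than the same load on the 16-bit operand.

#include "cutlass/gemm/gemm.h"
#include "cutlass/numeric_types.h"

// Hypothetical sketch of scaling the contiguous K extent for a 4-bit operand B.
using ElementWide      = cutlass::half_t;   // 16-bit operand A
using ElementNarrow    = cutlass::int4b_t;  // 4-bit operand B
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

// Ratio of element widths: 16 / 4 = 4.
static int const kExpansionFactor =
    cutlass::sizeof_bits<ElementWide>::value /
    cutlass::sizeof_bits<ElementNarrow>::value;

// One load on the 4-bit operand covers kExpansionFactor times more logical
// K elements than it would for the 16-bit operand.
using LoadInstructionShapeB = cutlass::gemm::GemmShape<
    InstructionShape::kM,
    InstructionShape::kN,
    InstructionShape::kK * kExpansionFactor>;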
We are only supporting TN layouts with this, right? I only see an adjustment needed for the K-dim (which is going to be the contiguous dimension), and not for the M-dim or N-dim. The K-dim of the iterator that is on 4 bits will need to adjust the load shape K?
I want to eventually support all layout combinations that are supported for int8, so yes, I'll have to refine this part.
I don't have the details figured out yet; it is in my head for now. We can do it either (a) in the threadblock-level mma / mainloop (as you mentioned), or (b) within the iterator (efficiently skipping iterator++ using a predicate that toggles). What do you recommend? Have you handled something similar in the FasterTransformer version?
I handled it in FT by copying the threadblock-level mma. If it is possible, handling it at the iterator level seems a lot cleaner, in my opinion. The reason I copied the threadblock mma in FT is that I already needed a separate mainloop to handle the application of scales.
Done. (FWIW, my time zone is GMT+1 - sorry for the delay.)
if constexpr (elem_B_4bit) {
  if (!even_flag_) {
    ptr_B += TransformedFragmentB::kElements / 2 / MmaOperandB::kElements;
  }
  even_flag_ = !even_flag_;
}
@alexsamardzic can you explain what you are trying to achieve here, and in general with the mutable even_flag_ in this file?
OK, I think I got it. Ideally, this should also be in the iterator. You are loading double the amount of data for the int4 operand.
Yes, just because of how the iterators used here are currently written, twice the amount of int4 data needed gets loaded on each load() call. At the moment, as mentioned under item 1 in my initial comment, there is a part of handling this in the testbed too, which apparently needs to be moved elsewhere - so I'm open to suggestions for a better approach.
Please see the two approaches below. I am leaning towards the second approach.

// F16 * S4
// ThreadblockShape : (128x128x64)
// WarpShape        : (64x64x64)
// InstructionShape : (16x8x16)

* Approach 1 *

/* Consider F16 * S4 mixed-input warp-level matrix multiply */
// Pros: Cleaner "looking" warp-level and threadblock-level mma
//       Similar to the rest of the warp-level and threadblock-level mma operations
// Cons: Requires additional internal state and a fragment iterator for the narrower input fragment,
//       Non-intuitive warp-level and threadblock-level mma operations

for (int k = 0; k < ThreadblockShape::kK; k += InstructionShape::kK) {

  smem_iter_A.load(loaded_frag_A);  // loads 64x16 per warp
  ++smem_iter_A;                    // advance iterator to the next 64x16 block

  // Skips every other load() call using predication.
  // Needs additional internal state within the iterator and predication.
  smem_iter_B.load(loaded_frag_B);  // loads 32x64 per warp
  ++smem_iter_B;                    // advance iterator to the next 32x64 block (skips every other operator++() call)

  // transform the fragments A (64x16) and B (32x64) into the appropriate format for the mma
  mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, loaded_frag_B);

  // MmaMixedInputTensorOp maintains an internal state and a fragment iterator for the narrower input fragment
  mma(accum, transformed_frag_A, transformed_frag_B, accum);
}

* Approach 2 *

/* Consider F16 * S4 mixed-input warp-level matrix multiply */
// Pros: Intuitive warp-level and threadblock-level mma
//       Cleaner warp-level and threadblock-level mma operations
//       No additional internal state, iterator changes, or handling of predicated loads.
// Cons: Different threadblock-level mainloop for mixed-input with the S4 datatype.

static const int kFragmentInt4TileK = 32;  // InstructionShape::kK * 2

for (int k = 0; k < ThreadblockShape::kK; k += kFragmentInt4TileK) {

  smem_iter_A.load(loaded_frag_A);  // loads 64x16 per warp
  ++smem_iter_A;                    // advance iterator to the next 64x16 block

  smem_iter_B.load(loaded_frag_B);  // loads 32x64 per warp
  ++smem_iter_B;                    // advance iterator to the next 32x64 block

  Array<int4b_t, FragmentB::kElements / 2>* partition_loaded_frag_B =
      reinterpret_cast<Array<int4b_t, FragmentB::kElements / 2>*>(&loaded_frag_B);

  // transform [fragment A (64x16)] = 64x16xF16 and [first half of B (32x64)] = 16x64xS4
  mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, partition_loaded_frag_B[0]);

  // MmaMixedInputTensorOp
  mma(accum, transformed_frag_A, transformed_frag_B, accum);

  smem_iter_A.load(loaded_frag_A);  // loads 64x16 per warp
  ++smem_iter_A;                    // advance iterator to the next 64x16 block

  // transform [fragment A (64x16)] = 64x16xF16 and [second half of B (32x64)] = 16x64xS4
  mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, partition_loaded_frag_B[1]);

  // MmaMixedInputTensorOp
  mma(accum, transformed_frag_A, transformed_frag_B, accum);
}
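For what it's worth, the core mechanism of approach 2 - viewing one loaded 4-bit fragment of B as two half-size fragments that feed two consecutive warp-level MMAs - is just a reinterpretation of the fragment type. A minimal, self-contained sketch (fragment size and names are illustrative, not taken from the PR):

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"

// Illustrative fragment type for the 4-bit operand B; the real element count
// comes from the warp-level iterator, 64 is only an example.
using FragmentB     = cutlass::Array<cutlass::int4b_t, 64>;
using HalfFragmentB = cutlass::Array<cutlass::int4b_t, FragmentB::kElements / 2>;

// View the single loaded fragment as two halves: partition(frag)[0] feeds the
// first warp-level MMA along K, partition(frag)[1] feeds the second one.
inline HalfFragmentB const *partition(FragmentB const &loaded_frag_B) {
  return reinterpret_cast<HalfFragmentB const *>(&loaded_frag_B);
}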
Thanks, let me try to make an update along the lines of approach 2.
Force-pushed from bf3d571 to 9d5bd54.
A branch with an implementation following approach 2 above is here. It has some quirks too: […]
Overall, my initial proposal (as in this PR) is also along the lines of approach 2, but I'd say it is less intrusive. Its only downside is that it uses twice as many registers to keep […]
Any feedback on my last comment above?
Hi @alexsamardzic, thank you for being persistent on this. Apologies, I have been away for a few days. I don't get the need for the bool in the template and the if/else. Also, isn't it going to add a branch into the mainloop? We don't want any branches in the multistage mainloop, for performance. The responsibility of […] I have drawn a figure to help us visualize the different components of the computation. (@rawnhenry and @hwu36 are working on the Hopper version of […])
Hi @manishucsd, thanks for the additional clarifications - I think by now I understand well the various components of the changes needed for int4b_t/uint4b_t support. To recap, so far I provided two different attempts at supporting int4b_t/uint4b_t in mixed dtypes GEMM. As already discussed, the main challenge with extending mixed data-types GEMM with int4b_t/uint4b_t is […]. As far as the remaining details are concerned, I think all of them are minor, and we could discuss them along the way:
Force-pushed from 774fa01 to 3bcd5b0.
Force-pushed from 3bcd5b0 to bed979f.
Rebased on the latest main and made several updates, the most important of which is that […]
Hi guys, would it be possible for you to provide any feedback on this PR at this stage? This functionality is really needed for PyTorch, and I recently updated the PR so that, in my opinion, it should be close to its final form, at least for […]
I will circle back to this one after I finish what is in my hands. This is a big and important change, so I need to spend quite a bit of time to work on it.
Force-pushed from bed979f to cc92d34.
Force-pushed from f5667c3 to 855b057.
@hwu36 - Do you have cycles coming up to review this? This is very useful for quantization.
@manishucsd - Do you have some time coming up to review this PR?
Force-pushed from 855b057 to b96bd61.
  pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
- pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
+ pipe_state.warp_loaded_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2 + 1) % 2 : (warp_mma_k + 1) % 2]);
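To restate what the changed index above does (this helper is only an illustration, not code from the PR): when B is 4-bit, one loaded warp fragment of B covers two warp-level MMA steps along K, so the double-buffer index for B advances at half the rate.

// Illustrative restatement of the B-fragment index selection above.
inline int next_warp_frag_B_index(int warp_mma_k, bool is_mixed_and_B_4bit) {
  int k = is_mixed_and_B_4bit ? (warp_mma_k / 2) : warp_mma_k;
  return (k + 1) % 2;  // double-buffered: alternates between 0 and 1
}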
@alexsamardzic, for these mainloop changes, can we run the full device-level tests to see that nothing is broken on SM80? I would prefer not to touch this file and instead create a separate version of this file for just the F16 x S4 cases.
@hwu36, what are your thoughts on this? Also, is it possible to handle this outside of the mainloop, for example in a specialization of the shared memory iterator? We have probably discussed it before, but it is worth revisiting that thought. Happy to schedule something between the three of us to brainstorm this PR further.
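As a rough sketch of the "handle it in the iterator" alternative mentioned above (entirely hypothetical, not an existing CUTLASS component): a thin wrapper around a shared-memory warp iterator could advance only on every other operator++, so that the 4-bit operand is not over-fetched relative to the 16-bit one.

#include "cutlass/cutlass.h"

// Hypothetical wrapper (names assumed): advances the underlying iterator only
// on every other ++, mirroring the even_flag_ toggle discussed earlier.
template <typename Base>
struct SkipAlternateAdvanceIterator : public Base {
  using Base::Base;

  bool advance_toggle_ = false;

  CUTLASS_DEVICE
  SkipAlternateAdvanceIterator &operator++() {
    if (advance_toggle_) {
      Base::operator++();  // advance the wrapped iterator every other call
    }
    advance_toggle_ = !advance_toggle_;
    return *this;
  }
};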
Force-pushed from b96bd61 to ca299c2.
Rebased on latest main.
@manishucsd @rhenry-nv