
Add int4b_t/uint4b_t support for mixed dtypes GEMM #1190

Open · wants to merge 1 commit into main from 4bit-support-mixed-dtypes-gemm

Conversation

@alexsamardzic (Contributor, Author) commented Nov 15, 2023

This PR extends the existing mixed-dtypes GEMM support for the S8/U8 dtypes with the same kind of support for S4/U4.

The PR is opened for an initial review; there are a number of things still to add:

  1. Because of how the existing tile iterators work, the iterator will load twice as many values for the S4/U4 operand as are actually needed. For the purpose of testing the changes, this is currently handled in test/unit/gemm/warp/testbed.h; it will eventually have to be handled some other way, and any suggestions on how best to do that would be appreciated (see the sketch after this list).
  2. The code currently implements only S4 support, for the case where the B operand is of this dtype; the PR will eventually include both S4 and U4 support, for both the A and B operands, but these should be straightforward to add once we're content with this initial case.
  3. The PR will also extend the existing mixed-dtypes support in cutlass_library to the 4-bit dtypes. This should also be simple to add once the rest is ready.
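A minimal sketch of the double-loading issue from item 1, with an assumed fragment size (the real size comes from the warp-level iterator's Fragment type): one shared-memory load() brings int4 data for two mma steps, so the loaded fragment can be viewed as two halves that are consumed one step at a time.

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

// Assumed size, for illustration only.
using FragmentB     = cutlass::Array<cutlass::int4b_t, 64>;
using HalfFragmentB = cutlass::Array<cutlass::int4b_t, FragmentB::kElements / 2>;

CUTLASS_HOST_DEVICE
void consume_in_two_steps(FragmentB const &loaded_frag_B) {
  // One load() brought data for two mma steps; view it as two halves.
  HalfFragmentB const *halves =
      reinterpret_cast<HalfFragmentB const *>(&loaded_frag_B);
  // halves[0] would feed the first mma step, halves[1] the second.
  (void)halves;
}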

@@ -2690,6 +2691,77 @@ struct FastNumericArrayConverter<cutlass::half_t, uint8_t, 4, Round> {
}
};

/// Partial specialization for Array<cutlass::half_t, 8> <= Array<int4b_t, 8>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, cutlass::int4b_t, 8, Round> {
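The declaration above is truncated by the diff excerpt. As a hedged scalar reference only (not the PR's implementation, which converts all eight nibbles at once with packed bit tricks, and not the faster variant mentioned in the review comment below), the per-lane math such a converter performs amounts to sign-extending each 4-bit lane and casting it to half precision; the helper name here is hypothetical:

// Hypothetical helper, for illustration only.
CUTLASS_HOST_DEVICE
float convert_one_s4_lane(unsigned packed, int lane) {
  int v = (packed >> (4 * lane)) & 0xF;  // extract the 4-bit lane
  if (v & 0x8) {
    v -= 16;                             // sign-extend S4: 8..15 map to -8..-1
  }
  return static_cast<float>(v);          // the real converter produces cutlass::half_t
}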
@rawnhenry commented Nov 15, 2023

Can you temporarily remove this converter, please? I have one that is slightly faster and I am hoping to push out next week along with some other changes.

@alexsamardzic (Contributor, Author)

No problem. But this PR needs work anyway, and it certainly won't land before these changes of yours, so I'm keeping it for now just so that the rest works; as soon as your changes are merged, I'll rebase and remove it.

@rawnhenry commented Nov 15, 2023

Ahh ok, it sounds good. :) How much work do you think is needed? Is it ready for full review?

@alexsamardzic (Contributor, Author)

The list of what remains to be done is in the first comment above. I'd appreciate feedback on what is already there, and eventually some suggestions about item 1 from that list. Items 2 and 3 should be quick to complete afterwards.


Makes sense. I am not sure how to handle it apart from a fresh threadblock-level mma where we skip loading some of the int4 data from smem.

@manishucsd (Contributor)

@alexsamardzic, thanks for the PR. I am going to try item 1. Can you give write privileges on your branch?

@rawnhenry

@manishucsd, is the plan to add a fresh threadblock-level mma to handle item 1?

// eventually for narrower ElementA.
using LoadInstructionShapeA =
GemmShape<LoadInstructionM, InstructionShape::kN, InstructionShape::kK>;

Contributor

We are only supporting TN layouts with this, right? I only see an adjustment needed for the K-dim (which is going to be the contiguous dimension) and not for the M-dim or N-dim. Will the K-dim of the iterator that is on 4 bits need to adjust the load shape along K?

@alexsamardzic (Contributor, Author)

I want to eventually support all layout combinations that are supported for int8, so - yes, I'll have to refine this part.

@manishucsd (Contributor) commented Nov 16, 2023

@manishucsd, is the plan to add a fresh threadblock-level mma to handle item 1?

I don't have the details figured out yet; it is in my head for now. We can do it either (a) in the threadblock-level mma / mainloop (as you mentioned), or (b) within the iterator (efficiently skipping iterator++ using a predicate that toggles). What do you recommend? Have you handled something similar in the FasterTransformer version?
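A hedged, simplified sketch of option (b), with hypothetical names (not code from the PR or from CUTLASS): the 4-bit operand's tile iterator only advances on every other operator++() call, so each loaded fragment is reused for two mma steps.

// Hypothetical, simplified iterator; real CUTLASS warp tile iterators advance
// a shared-memory pointer rather than a plain tile index.
struct SkippingTileIteratorB {
  int tile_k_ = 0;               // current K tile index
  bool advance_toggle_ = false;  // toggles on every call

  SkippingTileIteratorB &operator++() {
    if (advance_toggle_) {
      ++tile_k_;                 // advance only on every other call
    }
    advance_toggle_ = !advance_toggle_;
    return *this;
  }
};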

@rawnhenry

I handled it in FT by copying the threadblock level mma. If it is possible, handling at the iterator level seems a lot cleaner, in my opinion.

The reason I copied the threadblock mma in FT is because I already needed a separate mainloop to handle the application of scales.

@alexsamardzic (Contributor, Author)

@alexsamardzic, thanks for the PR. I am going to try item 1. Can you give write privileges on your branch?

Done. (FWIW, my time zone is GMT+1 - sorry for the delay.)

Comment on lines 628 to 636
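// The transformed B fragment holds int4-sourced data for two mma steps (see
// item 1 in the PR description); on every other call, ptr_B is pointed at the
// second half of that fragment.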
if constexpr (elem_B_4bit) {
  if (!even_flag_) {
    ptr_B += TransformedFragmentB::kElements / 2 / MmaOperandB::kElements;
  }
  even_flag_ = !even_flag_;
}

Contributor

@alexsamardzic, can you explain what you are trying to achieve here, and in general with a mutable even_flag_ in this file?

Contributor

OK, I think I got it. Ideally, this should also be in the iterator. You are loading double the amount of data for the int4 operand.

@alexsamardzic (Contributor, Author)

Yes; because of how the iterators are currently written, twice the amount of int4 data needed gets loaded on each load() call. At the moment, as mentioned under item 1 in my initial comment, part of the handling for this is in the testbed too, and that apparently needs to be moved elsewhere - so I'm open to suggestions for a better approach.

@manishucsd (Contributor) commented Nov 28, 2023

Please see the two approaches below. I am leaning towards the second approach.

// F16 * S4
// ThreadblockShape  : (128x128x64)
// WarpShape         : (64x64x64)
// InstructionShape  : (16x8x16)


* Approach 1 *
/* Consider F16 * S4 mixed-input warp-level matrix multiply */
// Pros: Cleaner "looking"  warp-level and threadblock-level mma
//       Similar to the rest of the warp-level and threadblock-level mma operation
// Cons: Requires additional internal state and fragment iterator for the narrower input fragment, 
//       Non-intuitive warp-level and threadblock-level mma operations

for (int k = 0; k < ThreadblockShape::kK; k += InstructionShape::kK) {

  smem_iter_A.load(loaded_frag_A); // loads 64x16 per warp
  ++smem_iter_A;  // advance iterators to next 64x16 block

  // Skips every other load call using predication.
  // Needs additional internal state within the iterator and predication
  smem_iter_B.load(loaded_frag_B); // loads 32x64 per warp 
  ++smem_iter_B;  // advance iterator to the next 32x64 block (skips every other operator++() call)

  // transform the fragments A (64x16) and B (32x64) into the appropriate format for mma
  mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, loaded_frag_B);

  // MmaMixedInputTensorOp maintains an internal state and a fragment iterator for the narrower input fragment
  mma(accum, transformed_frag_A, transformed_frag_B, accum); 
}


* Approach 2 *

/* Consider F16 * S4 mixed-input warp-level matrix multiply */
// Pros: Intuitive warp-level and threadblock-level mma
//       Cleaner warp-level and threadblock-level mma operations
//       No additional internal iterator state or predicated-load handling needed.
// Cons: Different threadblock-level mainloop for mixed-input with S4 datatype. 

static const int kFragmentInt4TileK = 32; // InstructionShape::kK * 2
for (int k = 0; k < ThreadblockShape::kK; k += kFragmentInt4TileK) {

  smem_iter_A.load(loaded_frag_A); // loads 64x16 per warp
  ++smem_iter_A;  // advance iterators to next 64x16 block

  smem_iter_B.load(loaded_frag_B); // loads 32x64 per warp
  ++smem_iter_B;  // advance iterators to next 32x64 block
  

  Array<int4b_t, FragmentB::kElements / 2>* partition_loaded_frag_B =
    reinterpret_cast<Array<int4b_t, FragmentB::kElements / 2>*>(&loaded_frag_B);

  // transform the [fragments A (64x16)]=64x16xF16 and [half of B (32x64)]=16x64xS4 
  mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, partition_loaded_frag_B[0]);

  // MmaMixedInputTensorOp 
  mma(accum, transformed_frag_A, transformed_frag_B, accum);

  smem_iter_A.load(loaded_frag_A); // loads 64x16 per warp
  ++smem_iter_A;  // advance iterators to next 64x16 block

  // transform the [fragments A (64x16)]=64x16xF16 and [half of B (32x64)]=16x64xS4 
  mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, partition_loaded_frag_B[1]);

  // MmaMixedInputTensorOp 
  mma(accum, transformed_frag_A, transformed_frag_B, accum);

}

@alexsamardzic (Contributor, Author)

Thanks, let me try to make an update along the lines of approach 2.

@alexsamardzic (Contributor, Author) commented Dec 3, 2023

A branch with an implementation following approach 2 above is here. It has some quirks too:

  1. The code at the threadblock level, calling mma.transform() and mma(), which I also put in testbed.h for now, is relatively clean; but still (as mentioned in my first comment above), changes like this will have to be put somewhere in cutlass/gemm/threadblock - I guess we'd now have to change threadblock loops like this one in several classes there?
  2. The MmaMixedInputTensorOp::transform() method would now have to know whether it is about to transform the "lower" or the "upper" half of the loaded fragment. But the interface of this method should not be changed, so I've put in a quick fix for now: a flag that this method flips each time it is called - which obviously relies on the user of this method calling it in the proper sequence. For the same reason, the FragmentShuffler now has an additional bool template argument. This is all rather ugly.

Overall, my initial proposal (as in this PR) is also along the lines of approach 2, but I'd say it is less intrusive. Its only downside is that it uses twice as many registers to keep transformed_frag_B in the threadblock loop, but I don't think that makes much of a difference (considering the number of registers needed for the F16/S8 case).

@alexsamardzic (Contributor, Author)

Any feedback on my last comment above?

@manishucsd (Contributor)

Hi @alexsamardzic, thank you for being persistent on this. Apologies, I have been away for a few days.

I don't get the need for the bool in the template and the if/else. Also, isn't it going to add a branch into the mainloop? We don't want any branches in the multistage mainloop, for performance. The responsibility of FragmentShuffler followed by the Converter is to take the entire fragment, be it for one mma.sync or two mma.sync, shuffle it to fix the thread-value ownership, and call the Converter. For the B operand in the F16 * S4 case, FragmentB will hold data worth two mma.sync. I believe all of this can be handled at compile time by reading the datatypes. Let us shoot for a single FragmentShuffler implementation for S8 and S4; if that is not possible with a single implementation, we can have a specialized implementation of FragmentShuffler for S4, selected based on the datatype, with no need for the bool.
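A minimal sketch of the compile-time selection being suggested here, with hypothetical type names (neither CUTLASS nor the PR defines these): the shuffler variant is picked from the operand datatype, so no bool template argument or runtime branch is needed.

#include "cutlass/numeric_types.h"

struct FragmentShufflerS8 {};  // stand-in for the existing 8-bit shuffler
struct FragmentShufflerS4 {};  // stand-in for an S4-specific shuffler

// Primary template: default to the 8-bit path.
template <typename ElementB>
struct FragmentShufflerFor {
  using type = FragmentShufflerS8;
};

// Specialization picked at compile time when the operand is int4b_t.
template <>
struct FragmentShufflerFor<cutlass::int4b_t> {
  using type = FragmentShufflerS4;
};

// Usage (sketch): resolved entirely at compile time, no branch in the mainloop.
// using ShufflerB = typename FragmentShufflerFor<ElementB>::type;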

I have drawn a figure to help us visualize the different components of the computation (F16 * S4). Please take a look and see if it helps:
[Figure: screenshot dated 2023-12-27 illustrating the components of the F16 * S4 mixed-input computation]

@rawnhenry and @hwu36 are working on a Hopper version of F16 * S4 that does not need the shuffle; tagging them to see if they have suggestions for the Ampere version based on the figures and the discussion.

@alexsamardzic (Contributor, Author)

Hi @manishucsd, thanks for the additional clarifications - I think by now I understand the various components of the changes needed for S4 well, but it was still helpful to make sure I'm on the same page.

To recap, so far I have provided two attempts at supporting F16/S4 mixed data-types GEMM: here is the first one, and here is the new one. In the first one, transforming (meaning: reshuffling + upcasting) the S4 fragment is performed all at once, while in the second one I tried to follow what you suggested as "approach 2" in your next-to-last comment above, so half of the S4 fragment is transformed on each MmaMixedInputTensorOp::transform() call. Your last comment calls again for transforming the whole S4 fragment at once, but that's not a problem - I have both implementations ready.

As already discussed: the main challenge with extending mixed data-types GEMM to S4 is that twice as much data is loaded per corresponding load() call as is actually needed for the mma() call. For this reason, some changes (I mean here code like the code you provided in your next-to-last comment, and the code that is on the bottom-right of the picture attached to your last comment) will have to be made at the threadblock level, more precisely somewhere under include/cutlass/gemm/threadblock. So far, in both of my above-mentioned branches, I've put these changes in test/unit/gemm/warp/testbed.h in order to be able to test the rest of my changes, but apparently this is not the right place; the main input that I need at this stage is where exactly to make these changes (e.g., is it just in include/cutlass/gemm/threadblock/mma_multistage.h, or somewhere else too)?

As far as the remaining details are concerned, I think all of them are minor, and we can discuss them along the way:

  • @rawnhenry mentioned above that he has a more efficient converter, I guess implemented for the Hopper version; that's fine, I can use that one instead of mine.
  • I don't think it's possible to use the same shuffler for S4 and S8, but we could consider that too.
  • Of course I'll try to avoid run-time branching as much as possible and further optimize the code; but it'll be more effective for me to make changes like that once I'm sure this loop of load()-transform()-mma() calls is put in the right place and looks the way you guys are happy with.

@alexsamardzic force-pushed the 4bit-support-mixed-dtypes-gemm branch 3 times, most recently from 774fa01 to 3bcd5b0, on January 12, 2024 at 18:36.
@alexsamardzic (Contributor, Author)

Rebased on the latest main and made several updates, the most important of which is that the cutlass::gemm::threadblock::MmaMultistage class is changed to support 4-bit mixed data-types GEMM. A new test, gemm_universal_f16t_s4n_f16t_mixed_input_tensor_op_f16_sm80.cu, is added to unit/gemm/device to validate these changes. So the PR now contains all the changes needed, at least for the f16t/s4n case - would it be possible to have it reviewed again?

@alexsamardzic (Contributor, Author)

Hi guys, would it be possible for you to provide any feedback on this PR at this stage? This functionality is really needed for PyTorch, and I recently updated the PR so that, in my opinion, it should be close to its final form, at least for the f16t/s4n case - I plan to add support for other combinations (as well as the cutlass_library stuff), but before that it would be good to know whether this would be acceptable.

@hwu36 (Collaborator) commented Feb 8, 2024

I will circle back to this one after I finish what is currently on my plate. This is a big and important change, so I need to spend quite a bit of time working on it.

github-actions bot commented Mar 9, 2024

This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

@cpuhrsch

@hwu36 - Do you have cycles coming up to review this? This is very useful for quantization.


@cpuhrsch commented Jun 5, 2024

@manishucsd - do you have some time coming up to review this PR?

  pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
- pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
+ pipe_state.warp_loaded_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2 + 1) % 2 : (warp_mma_k + 1) % 2]);
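// When B is the 4-bit operand, one loaded B fragment covers two warp_mma_k
// steps, so the double-buffer index for B toggles every two steps instead of
// every step.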
Contributor

@alexsamardzic, for these mainloop changes, can we run the full device-level tests to see that nothing is broken on SM80? I would prefer not to touch this file, and instead to create a separate version of this file just for the F16 x S4 cases.

@hwu36, what are your thoughts on this? Also, is it possible to handle this outside of the mainloop, e.g. in a specialization of the shared-memory iterator? We have probably discussed it before, but it is worth revisiting that thought. Happy to schedule something between the three of us to brainstorm this PR further.


This PR has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates.

@alexsamardzic (Contributor, Author)

Rebased on latest main.
