Improve performance of quantizelinear for int4 #1706

Open

dhernandez0 wants to merge 3 commits into develop from 1665-quantizelinear-slower-than-f16-on-mi300x
Conversation

@dhernandez0 (Contributor) commented Dec 17, 2024

This PR improves the performance of quantizelinear for int4. The changes are:

  • Use the right element type for the getVectorDim() call: when there are input fusions the type can change, and we should use the original type to decide how many k/d elements to copy per thread.
  • Move generic ops before the LDS barrier when possible.
  • Pack scale and bias together in the same tensor (quantizelinear).

There's a companion PR in MIGraphX that also changes the layout of the scale+bias tensor: ROCm/AMDMIGraphX#3718
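For intuition on the packed layout, here is a minimal numpy sketch (my illustration, not code from either PR) of interleaving scale and bias along a trailing axis of size 2, matching the lens=[384, 32, 32, 2] parameter x1 in the program below:

```python
import numpy as np

# Hypothetical per-group quantization parameters.
scale = np.random.rand(384, 32, 32).astype(np.float16)
bias = np.random.rand(384, 32, 32).astype(np.float16)

# Interleave along a new trailing axis: [..., 0] holds scale, [..., 1] holds bias,
# so each (scale, bias) pair of halfs sits in one 32-bit word.
packed = np.stack([scale, bias], axis=-1)  # shape (384, 32, 32, 2)

# The pair can be recovered with slices, as the MIGraphX program does below.
assert np.array_equal(packed[..., 0], scale)
assert np.array_equal(packed[..., 1], bias)
```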

This is the MIGraphX program for the layout change (scale and bias packed together, so each half-precision pair occupies one int32):

```python
import migraphx

p = migraphx.program()
m = p.get_main_module()
x_1 = m.add_parameter("x1", migraphx.shape(type="half_type", lens=[384, 32, 32, 2]))
x_2 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="uint8_type", lens=[12288, 2048]), 2))
p_x4 = m.add_parameter("x4", migraphx.shape(type="half_type", lens=[1, 1, 4096]))
x_1_transposed = m.add_instruction(migraphx.op("transpose", permutation=[0,2,1,3]), [x_1])
x_1_new = m.add_instruction(migraphx.op("reshape", dims=[12288,32,2]), [x_1_transposed])
x_4 = m.add_instruction(migraphx.op("unpack_int4", axis=1), [x_2]) # migraphx.shape(type="uint8_type", lens=[12288, 4096])
x_6 = m.add_instruction(migraphx.op("unsqueeze", axes=[2]), [x_1_new]) # migraphx.shape(type="half_type", lens=[12288, 32, 1, 2])
x_7 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[12288,32,128,2]), [x_6]) # migraphx.shape(type="half_type", lens=[12288, 32, 128, 2], strides=[64, 2, 0, 1])
x_8 = m.add_instruction(migraphx.op("reshape", dims=[12288,4096,2]), [x_7]) # migraphx.shape(type="half_type", lens=[12288, 4096, 2])
scale = m.add_instruction(migraphx.op("slice", axes=[2], starts=[0], ends=[1]), [x_8])
bias = m.add_instruction(migraphx.op("slice", axes=[2], starts=[1], ends=[2]), [x_8])
scale_squeeze = m.add_instruction(migraphx.op("squeeze", axes=[2]), [scale])
bias_squeeze = m.add_instruction(migraphx.op("squeeze", axes=[2]), [bias])
x_12 = m.add_instruction(migraphx.op("dequantizelinear"), [x_4, scale_squeeze, bias_squeeze]) # migraphx.shape(type="half_type", lens=[12288, 4096])
x_13 = m.add_instruction(migraphx.op("unsqueeze", axes=[0]), [x_12]) # migraphx.shape(type="half_type", lens=[1, 12288, 4096])
x_14 = m.add_instruction(migraphx.op("transpose", permutation=[0,2,1]), [x_13]) # migraphx.shape(type="half_type", lens=[1, 4096, 12288], strides=[50331648, 1, 4096])
x_15 = m.add_instruction(migraphx.op("dot"), [p_x4, x_14]) # migraphx.shape(type="half_type", lens=[1, 1, 12288])
m.add_return([x_15])
```
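For intuition, here is a minimal numpy sketch (my illustration, not code from the PR) of the arithmetic the dequantizelinear step performs: (input - zero_point) * scale, with the packed tensor's bias slice serving as the zero point:

```python
import numpy as np

# Unpacked int4 weight values (one uint8 per value after unpack_int4).
x = np.random.randint(0, 16, size=(12288, 4096), dtype=np.uint8)
# Packed scale+bias, already broadcast to one (scale, bias) pair per weight,
# mirroring x_8 above: lens=[12288, 4096, 2].
packed = np.random.rand(12288, 4096, 2).astype(np.float16)

scale = packed[..., 0]  # the slice at index 0, like the "scale" instruction
bias = packed[..., 1]   # the slice at index 1, like the "bias" instruction
y = (x.astype(np.float16) - bias) * scale  # half_type result, lens=[12288, 4096]
```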
|  | develop branch | this PR (+new layout) | speed up |
| --- | --- | --- | --- |
| gfx1101 | 0.1248 ms | 0.0897 ms | 1.39x |
| gfx1150 (strix) | 0.5302 ms | 0.4071 ms | 1.30x |

@pfultz2 pointed out that we can use slice operations instead of changing quantizelinear to take a single packed parameter. This simplifies the PR a lot.

TODO:


codecov bot commented Dec 17, 2024

Codecov Report

Attention: Patch coverage is 23.80952% with 48 lines in your changes missing coverage. Please review.

Project coverage is 78.45%. Comparing base (a4e8230) to head (e663907).
Report is 17 commits behind head on develop.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...lir/lib/Dialect/Rock/utility/transformMapUtils.cpp | 6.12% | 45 Missing and 1 partial ⚠️ |
| ...ialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | 83.33% | 0 Missing and 2 partials ⚠️ |
Additional details and impacted files
```diff
@@             Coverage Diff             @@
##           develop    #1706      +/-   ##
===========================================
- Coverage    78.88%   78.45%   -0.44%     
===========================================
  Files          100      100              
  Lines        28346    28405      +59     
  Branches      4130     4146      +16     
===========================================
- Hits         22361    22285      -76     
- Misses        4368     4458      +90     
- Partials      1617     1662      +45     
```
| Flag | Coverage Δ |
| --- | --- |
| mfma | 78.45% <23.80%> (-0.43%) ⬇️ |

Flags with carried forward coverage won't be shown.


```diff
@@ -2549,7 +2549,16 @@ struct GridwiseGemmAccelRewritePattern
     // Obtain data types of inputs.
     auto elementTypeA = op.getA().getType().getElementType();
     auto maybeElementTypeALoad = getGemmInputElementType(op.getA());
```
Member:
It would be nice to have a unit test that uses getGemmInputElementType, to exercise that logic along with GridwiseGemmToBlockwise.

@dhernandez0 (author):

Sure, do you know where unit tests are written? I've only done lit/MLIR tests (not sure what they are called).

Member:
I meant a lit test. You can also write an e2e test that verifies results against the host runner.

@dhernandez0 force-pushed the 1665-quantizelinear-slower-than-f16-on-mi300x branch from 1748a96 to f87b60a on December 18, 2024 15:55