
Add a parametrized benchmark for swiglu #1105

Merged · 3 commits into main · Sep 5, 2024

Conversation

@IvanYashchuk (Collaborator) commented Sep 5, 2024

Adds a new benchmark case for the activation function of the Llama MLP, `torch.nn.functional.silu(x_fc_1) * x_fc_2`: https://github.com/Lightning-AI/litgpt/blob/fdf6a120056d1363287285599eb84907f6c589b9/litgpt/model.py#L372
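For context, the operation under benchmark is just a SiLU-gated elementwise product. A minimal sketch (the shapes here are illustrative, not the benchmark's exact configuration):

```python
import torch
import torch.nn.functional as F

def swiglu(x_fc_1: torch.Tensor, x_fc_2: torch.Tensor) -> torch.Tensor:
    # SwiGLU as used in the Llama MLP: SiLU-gate the first projection,
    # then multiply elementwise by the second projection.
    return F.silu(x_fc_1) * x_fc_2

# Illustrative shapes (batch, seq, intermediate); 14336 is the
# MLP intermediate size of Llama 3 8B.
x_fc_1 = torch.randn(1, 8, 14336)
x_fc_2 = torch.randn(1, 8, 14336)
out = swiglu(x_fc_1, x_fc_2)
```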

Here are the results for the Llama 3 8B configuration with batch sizes 1 and 2 on an H100:

------------------------------------------------------------------------------- benchmark 'config=Llama-3-8B bs=1 compute_type=ComputeType.INFERENCE': 5 tests -------------------------------------------------------------------------------
Name (time in us)                                                         Min                 Max                Mean             StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_swiglu[Llama-3-8B-inference-bs1-torch.compile]           257.9034 (1.0)      373.7697 (1.0)      263.7706 (1.0)      17.6783 (1.0)      259.6658 (1.0)      1.2085 (1.07)       60;140        3.7912 (1.0)        1294           3
test_litgpt_swiglu[Llama-3-8B-inference-bs1-torch.compile+liger]     260.1664 (1.01)     421.2319 (1.13)     266.7489 (1.01)     19.2305 (1.09)     262.1046 (1.01)     1.3307 (1.18)       61;159        3.7488 (0.99)       1282           3
test_litgpt_swiglu[Llama-3-8B-inference-bs1-torch+liger]             273.8149 (1.06)     468.4066 (1.25)     284.0076 (1.08)     30.0264 (1.70)     276.4002 (1.06)     1.6572 (1.47)      105;195        3.5210 (0.93)       1828           2
test_litgpt_swiglu[Llama-3-8B-inference-bs1-thunder]                 291.0770 (1.13)     486.3574 (1.30)     302.6241 (1.15)     30.6936 (1.74)     294.6382 (1.13)     2.3446 (2.08)       98;185        3.3044 (0.87)       1718           2
test_litgpt_swiglu[Llama-3-8B-inference-bs1-torch]                   436.3761 (1.69)     570.9534 (1.53)     443.0198 (1.68)     21.7283 (1.23)     438.1957 (1.69)     1.1252 (1.0)         51;94        2.2572 (0.60)       1147           2
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------- benchmark 'config=Llama-3-8B bs=1 compute_type=ComputeType.TRAINING_BACKWARD': 5 tests ---------------------------------------------------------------------------
Name (time in us)                                                        Min                   Max                Mean             StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_swiglu[Llama-3-8B-backward-bs1-torch.compile+liger]     438.9444 (1.0)        910.5259 (1.52)     452.9431 (1.00)     38.0853 (1.53)     444.4392 (1.0)      4.7262 (2.47)        65;84        2.2078 (1.00)       1130           2
test_litgpt_swiglu[Llama-3-8B-backward-bs1-torch+liger]             442.6504 (1.01)       618.9051 (1.04)     454.4416 (1.00)     31.9635 (1.28)     446.1873 (1.00)     2.2256 (1.16)       65;127        2.2005 (1.00)       1130           2
test_litgpt_swiglu[Llama-3-8B-backward-bs1-torch.compile]           443.4905 (1.01)       597.2490 (1.0)      452.4754 (1.0)      24.9601 (1.0)      446.7046 (1.01)     1.9167 (1.0)        54;100        2.2101 (1.0)        1132           2
test_litgpt_swiglu[Llama-3-8B-backward-bs1-thunder]                 450.0614 (1.03)       644.9015 (1.08)     461.6722 (1.02)     30.8360 (1.24)     454.0675 (1.02)     2.3335 (1.22)       53;127        2.1660 (0.98)       1128           2
test_litgpt_swiglu[Llama-3-8B-backward-bs1-torch]                   777.6967 (1.77)     1,070.5460 (1.79)     795.0041 (1.76)     48.0013 (1.92)     783.6400 (1.76)     5.2054 (2.72)        61;70        1.2579 (0.57)       1284           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------- benchmark 'config=Llama-3-8B bs=1 compute_type=ComputeType.TRAINING_FORWARD': 5 tests ---------------------------------------------------------------------------
Name (time in us)                                                       Min                 Max                Mean             StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_swiglu[Llama-3-8B-forward-bs1-torch.compile]           261.1143 (1.0)      371.6873 (1.0)      268.0700 (1.0)      20.2952 (1.0)      262.8528 (1.0)      1.2310 (1.0)        73;130        3.7304 (1.0)        1280           3
test_litgpt_swiglu[Llama-3-8B-forward-bs1-torch+liger]             273.9176 (1.05)     452.1019 (1.22)     285.0940 (1.06)     31.9310 (1.57)     276.9795 (1.05)     1.7657 (1.43)      104;187        3.5076 (0.94)       1825           2
test_litgpt_swiglu[Llama-3-8B-forward-bs1-torch.compile+liger]     280.1844 (1.07)     479.2414 (1.29)     290.8800 (1.09)     30.3655 (1.50)     283.4606 (1.08)     1.9370 (1.57)      102;198        3.4378 (0.92)       1785           2
test_litgpt_swiglu[Llama-3-8B-forward-bs1-thunder]                 293.8700 (1.13)     523.5821 (1.41)     307.9894 (1.15)     36.2610 (1.79)     298.2542 (1.13)     3.2156 (2.61)       99;205        3.2469 (0.87)       1697           2
test_litgpt_swiglu[Llama-3-8B-forward-bs1-torch]                   437.6438 (1.68)     585.5404 (1.58)     444.9212 (1.66)     23.9094 (1.18)     439.5393 (1.67)     1.3874 (1.13)        52;75        2.2476 (0.60)       1143           2
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- benchmark 'config=Llama-3-8B bs=2 compute_type=ComputeType.INFERENCE': 5 tests --------------------------------------------------------------------------------
Name (time in us)                                                         Min                   Max                Mean             StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_swiglu[Llama-3-8B-inference-bs2-torch.compile]           498.0909 (1.0)        672.4799 (1.02)     506.9069 (1.0)      26.7068 (1.0)      500.6442 (1.0)      1.7794 (1.00)       48;105        1.9727 (1.0)        1004           2
test_litgpt_swiglu[Llama-3-8B-inference-bs2-torch+liger]             499.2611 (1.00)       660.4067 (1.0)      509.4721 (1.01)     29.0899 (1.09)     501.9680 (1.00)     1.8635 (1.05)       56;128        1.9628 (0.99)       1003           2
test_litgpt_swiglu[Llama-3-8B-inference-bs2-torch.compile+liger]     500.1344 (1.00)       667.2249 (1.01)     509.2604 (1.00)     27.4019 (1.03)     502.6642 (1.00)     1.7755 (1.0)        48;116        1.9636 (1.00)       1000           2
test_litgpt_swiglu[Llama-3-8B-inference-bs2-thunder]                 580.0794 (1.16)       967.0402 (1.46)     603.1700 (1.19)     61.4176 (2.30)     586.9209 (1.17)     6.0075 (3.38)       93;186        1.6579 (0.84)       1725           1
test_litgpt_swiglu[Llama-3-8B-inference-bs2-torch]                   867.0869 (1.74)     1,156.1923 (1.75)     882.4813 (1.74)     48.8080 (1.83)     871.2760 (1.74)     2.6614 (1.50)        56;79        1.1332 (0.57)       1153           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------ benchmark 'config=Llama-3-8B bs=2 compute_type=ComputeType.TRAINING_BACKWARD': 5 tests -----------------------------------------------------------------------------
Name (time in us)                                                          Min                   Max                  Mean             StdDev                Median               IQR            Outliers         OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_swiglu[Llama-3-8B-backward-bs2-torch.compile+liger]       869.8208 (1.0)      1,598.7931 (1.28)       898.9833 (1.00)     83.4485 (1.50)       878.5098 (1.00)     8.1114 (2.24)       68;108  1,112.3677 (1.00)       1146           1
test_litgpt_swiglu[Llama-3-8B-backward-bs2-torch+liger]               871.0301 (1.00)     1,574.0991 (1.26)       896.1385 (1.0)      76.8793 (1.38)       877.8151 (1.0)      7.7938 (2.15)        63;96  1,115.8990 (1.0)        1138           1
test_litgpt_swiglu[Llama-3-8B-backward-bs2-torch.compile]             890.4757 (1.02)     1,251.0130 (1.0)        915.5389 (1.02)     67.2028 (1.21)       897.2869 (1.02)     4.9510 (1.37)       70;123  1,092.2529 (0.98)       1123           1
test_litgpt_swiglu[Llama-3-8B-backward-bs2-thunder]                   895.2850 (1.03)     1,276.3930 (1.02)       923.7420 (1.03)     71.2424 (1.28)       904.2779 (1.03)     5.3896 (1.49)       70;115  1,082.5534 (0.97)       1117           1
test_litgpt_swiglu[Llama-3-8B-backward-bs2-torch]                   1,476.1370 (1.70)     1,820.4669 (1.46)     1,496.7633 (1.67)     55.6942 (1.0)      1,483.8863 (1.69)     3.6168 (1.0)         33;67    668.1083 (0.60)        678           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------------------------------- benchmark 'config=Llama-3-8B bs=2 compute_type=ComputeType.TRAINING_FORWARD': 5 tests ----------------------------------------------------------------------------
Name (time in us)                                                       Min                   Max                Mean             StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_swiglu[Llama-3-8B-forward-bs2-torch+liger]             499.5775 (1.0)        655.7805 (1.0)      509.2559 (1.0)      28.2605 (1.0)      502.1906 (1.0)      1.5695 (1.0)        56;106        1.9636 (1.0)        1001           2
test_litgpt_swiglu[Llama-3-8B-forward-bs2-torch.compile]           504.6555 (1.01)       722.7624 (1.10)     516.1296 (1.01)     33.7983 (1.20)     507.3906 (1.01)     1.9771 (1.26)       58;105        1.9375 (0.99)        992           2
test_litgpt_swiglu[Llama-3-8B-forward-bs2-torch.compile+liger]     505.2150 (1.01)       734.2570 (1.12)     516.9047 (1.02)     28.8159 (1.02)     509.2416 (1.01)     8.1468 (5.19)        39;44        1.9346 (0.99)        991           2
test_litgpt_swiglu[Llama-3-8B-forward-bs2-thunder]                 583.4647 (1.17)       987.3961 (1.51)     607.6304 (1.19)     64.1179 (2.27)     590.7598 (1.18)     6.5868 (4.20)       90;194        1.6457 (0.84)       1714           1
test_litgpt_swiglu[Llama-3-8B-forward-bs2-torch]                   869.0651 (1.74)     1,191.2668 (1.82)     885.4322 (1.74)     50.9958 (1.80)     874.0772 (1.74)     3.1209 (1.99)        52;65        1.1294 (0.58)       1150           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean


cc @crcrpar

@t-vi (Collaborator) left a comment:


Thank you @IvanYashchuk

@t-vi t-vi merged commit e64d347 into main Sep 5, 2024
37 checks passed
@t-vi t-vi deleted the benchmark-swiglu branch September 5, 2024 17:09
@IvanYashchuk (Collaborator, Author) commented Sep 6, 2024

Liger kernel performance can be improved by replacing `zeros_like` with `empty_like` (see PR linkedin/Liger-Kernel#217). For example, here is the impact for `test_litgpt_swiglu[Llama-3-8B-inference-bs1-torch+liger]` (a 1.5x speedup):

---------------------------------------------------------------------------------------------------------------- benchmark: 2 tests ---------------------------------------------------------------------------------------------------------------
Name (time in us)                                                              Min                 Max                Mean             StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_swiglu[Llama-3-8B-forward-bs1-torch+liger] (0002_e64d347)     272.5641 (1.0)      433.6799 (1.0)      282.0790 (1.0)      28.3710 (1.11)     274.9823 (1.0)      1.6231 (1.0)       103;188        3.5451 (1.0)        1838           2
test_litgpt_swiglu[Llama-3-8B-forward-bs1-torch+liger] (0001_e64d347)     415.7315 (1.53)     575.9231 (1.33)     424.2210 (1.50)     25.6607 (1.0)      417.8598 (1.52)     1.6439 (1.01)       67;115        2.3573 (0.66)       1204           2
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
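The `zeros_like` vs `empty_like` difference boils down to paying for a zero-fill pass over a buffer that is immediately overwritten anyway. A minimal sketch of the idea (not the actual Liger kernel code):

```python
import torch
import torch.nn.functional as F

def swiglu_with_zeros(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(a)  # zero-fills the buffer: an extra memory pass...
    out.copy_(F.silu(a) * b)   # ...which this full overwrite makes redundant
    return out

def swiglu_with_empty(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(a)  # uninitialized allocation: no zero-fill pass
    out.copy_(F.silu(a) * b)   # every element is written before being read
    return out

a, b = torch.randn(4, 1024), torch.randn(4, 1024)
```

The substitution is only safe when the output buffer is written in full before any element is read, which holds for the SwiGLU output.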

@ByronHsu (Contributor) commented Sep 6, 2024

awesome
