feat: support fused silu mul #427
Conversation
Using CUTLASS would be great if it already incorporates half2 operations.
Makes sense.
IMO I'm okay with introducing these new operators as a workaround solution, and it's preferable to use existing building blocks to minimize the maintenance overhead. Regarding this operator, can we try using Triton directly? I think Triton should already support optimizations such as half2.
Ok. I'll take a look. Thanks!
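For reference, a minimal Triton sketch of a fused SiLU-mul kernel might look like the following. This is an illustration of the suggestion above, not code from this PR: the names `_silu_and_mul_kernel` and `silu_and_mul_triton` are made up, and it assumes a contiguous input whose last dimension stores the gate and up halves back to back.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _silu_and_mul_kernel(x_ptr, out_ptr, d, BLOCK: tl.constexpr):
    # One program per row: x is viewed as (num_rows, 2 * d), out as (num_rows, d).
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < d
    gate = tl.load(x_ptr + row * 2 * d + offs, mask=mask, other=0.0).to(tl.float32)
    up = tl.load(x_ptr + row * 2 * d + d + offs, mask=mask, other=0.0).to(tl.float32)
    y = gate * tl.sigmoid(gate) * up  # SiLU(gate) * up, computed in fp32
    tl.store(out_ptr + row * d + offs, y.to(out_ptr.dtype.element_ty), mask=mask)


def silu_and_mul_triton(x: torch.Tensor) -> torch.Tensor:
    # x: (..., 2 * d) -> out: (..., d)
    d = x.shape[-1] // 2
    x = x.contiguous()
    out = torch.empty((*x.shape[:-1], d), dtype=x.dtype, device=x.device)
    num_rows = x.numel() // (2 * d)
    BLOCK = triton.next_power_of_2(d)
    _silu_and_mul_kernel[(num_rows,)](x, out, d, BLOCK=BLOCK)
    return out
```

Loads and stores in Triton are vectorized by the compiler, which is the reason it was suggested as an alternative to hand-written half2 code.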
```python
import torch
from torch.utils.benchmark import Timer
from itertools import product

from vllm import _custom_ops as ops
from flashinfer.activation import silu_and_mul as flashinfer_silu_and_mul
from flag_gems import silu_and_mul as flag_gems_silu_and_mul


def forward_vllm(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=torch.float16, device=x.device)
    ops.silu_and_mul(out, x)
    return out


def forward_flashinfer(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    out = torch.empty((*x.shape[:-1], d), dtype=torch.float16, device=x.device)
    flashinfer_silu_and_mul(x, out)
    return out


def forward_flag_gems(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return flag_gems_silu_and_mul(x[..., :d], x[..., d:])


def test_consistency():
    x = torch.randn(2, 4, 2 * d, dtype=torch.float16, device=device)
    out_vllm = forward_vllm(x)
    out_flashinfer = forward_flashinfer(x)
    out_flag_gems = forward_flag_gems(x)
    assert torch.allclose(out_vllm, out_flashinfer, atol=1e-3, rtol=1e-3)
    assert torch.allclose(out_vllm, out_flag_gems, atol=1e-3, rtol=1e-3)
    assert torch.allclose(out_flashinfer, out_flag_gems, atol=1e-3, rtol=1e-3)
    print("Consistency test passed!")


device = torch.device("cuda")
d = 4096

test_consistency()

results = []
sizes = [2, 8, 32, 128, 512]
for batch_size, seq_length in product(sizes, sizes):
    label = "SiLU and Mul"
    sub_label = f"[{batch_size}, {seq_length}]"
    input_tensor = torch.randn(batch_size, seq_length, 2 * d, dtype=torch.float16, device=device)
    min_run_time = max(0.1, min(1, batch_size * seq_length / 1e6))
    for num_threads in [1, 4, 16, 32]:
        results.append(
            Timer(
                stmt="forward_vllm(input_tensor)",
                setup="from __main__ import forward_vllm",
                globals={"input_tensor": input_tensor},
                num_threads=num_threads,
                label=label,
                sub_label=sub_label,
                description="vLLM",
            ).blocked_autorange(min_run_time=min_run_time)
        )
        results.append(
            Timer(
                stmt="forward_flashinfer(input_tensor)",
                setup="from __main__ import forward_flashinfer",
                globals={"input_tensor": input_tensor},
                num_threads=num_threads,
                label=label,
                sub_label=sub_label,
                description="FlashInfer",
            ).blocked_autorange(min_run_time=min_run_time)
        )
        results.append(
            Timer(
                stmt="forward_flag_gems(input_tensor)",
                setup="from __main__ import forward_flag_gems",
                globals={"input_tensor": input_tensor},
                num_threads=num_threads,
                label=label,
                sub_label=sub_label,
                description="Flag_gems",
            ).blocked_autorange(min_run_time=min_run_time)
        )

compare = torch.utils.benchmark.Compare(results)
compare.print()
```
@zhyncs I made some simple changes to the code (using vectorized read/write), and here are the results I got on H100 (measured with Triton's do_bench function):
Consistency test passed!
| batch_size | seq_length | vllm_time | flashinfer_time | flaggems_time |
|---|---|---|---|---|
| 2 | 2 | 0.007171261124312878 | 0.005875087808817625 | 0.02994345873594284 |
| 2 | 8 | 0.007260866463184357 | 0.005772186443209648 | 0.0059105088002979755 |
| 2 | 32 | 0.0077180881053209305 | 0.006187621038407087 | 0.006364865694195032 |
| 2 | 128 | 0.009424506686627865 | 0.00816467683762312 | 0.008360029198229313 |
| 2 | 512 | 0.02061079442501068 | 0.014950418844819069 | 0.014861035160720348 |
| 8 | 2 | 0.007269856985658407 | 0.005773282144218683 | 0.005844910629093647 |
| 8 | 8 | 0.00772811146453023 | 0.006187872029840946 | 0.006329760421067476 |
| 8 | 32 | 0.009468046016991138 | 0.00817921757698059 | 0.008257889188826084 |
| 8 | 128 | 0.020637067034840584 | 0.015106520615518093 | 0.015257231891155243 |
| 8 | 512 | 0.06076494976878166 | 0.04020121321082115 | 0.04041324928402901 |
| 32 | 2 | 0.007802661973983049 | 0.006300441455096006 | 0.00637076934799552 |
| 32 | 8 | 0.009482021443545818 | 0.008183696307241917 | 0.008226810954511166 |
| 32 | 32 | 0.020641470327973366 | 0.015115585178136826 | 0.015271436423063278 |
| 32 | 128 | 0.0607980377972126 | 0.040251944214105606 | 0.04044438898563385 |
| 32 | 512 | 0.21253922581672668 | 0.1371561884880066 | 0.153084397315979 |
| 128 | 2 | 0.00945486780256033 | 0.008165393956005573 | 0.008223879151046276 |
| 128 | 8 | 0.020657455548644066 | 0.015147659927606583 | 0.015288702212274075 |
| 128 | 32 | 0.06075974926352501 | 0.04024820774793625 | 0.04044437035918236 |
| 128 | 128 | 0.2123134285211563 | 0.13708913326263428 | 0.15339134633541107 |
| 128 | 512 | 0.8181041479110718 | 0.5250738263130188 | 0.5300045013427734 |
| 512 | 2 | 0.020511353388428688 | 0.01491069421172142 | 0.015027211979031563 |
| 512 | 8 | 0.060630060732364655 | 0.040194932371377945 | 0.04028919339179993 |
| 512 | 32 | 0.2125125527381897 | 0.13712455332279205 | 0.15308579802513123 |
| 512 | 128 | 0.818162202835083 | 0.5249825119972229 | 0.529996395111084 |
| 512 | 512 | 3.2437238693237305 | 2.0770304203033447 | 2.1354780197143555 |
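For reference, per-shape timings like the above can be collected with Triton's benchmarking helper. The snippet below is a sketch, not the exact script used; it assumes the `forward_vllm`, `forward_flashinfer`, and `forward_flag_gems` helpers from the benchmark script above are in scope.

```python
import torch
import triton

device = torch.device("cuda")
d = 4096

for batch_size in [2, 8, 32, 128, 512]:
    for seq_length in [2, 8, 32, 128, 512]:
        x = torch.randn(batch_size, seq_length, 2 * d, dtype=torch.float16, device=device)
        for name, fn in [("vllm", forward_vllm),
                         ("flashinfer", forward_flashinfer),
                         ("flaggems", forward_flag_gems)]:
            # do_bench runs the callable repeatedly and reports a representative time
            t = triton.testing.do_bench(lambda: fn(x))
            print(f"batch_size: {batch_size} seq_length: {seq_length} {name}_time: {t}")
```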
I think we achieve the best performance of the three in most cases. Let's merge this first; I don't want to spend too much time optimizing elementwise kernels :)
🤖 I have created a release *beep* *boop*

---

## [0.1.4](v0.1.3...v0.1.4) (2024-08-09)

### Features

* append attention kernels for fp8 kv-cache ([#420](#420)) ([906c2f5](906c2f5))
* support min_p sampling ([#422](#422)) ([d52f2da](d52f2da))
* deterministic sampling ([#417](#417)) ([0dd801d](0dd801d))
* more sampling operator options ([#431](#431)) ([68df9c4](68df9c4))
* support fused add rmsnorm ([#419](#419)) ([b781513](b781513))
* support fused silu mul ([#427](#427)) ([ea0ba9a](ea0ba9a))

### Bug Fixes

* fix dispatch fp16 type when enable fp8 ([#430](#430)) ([daa5566](daa5566))
* improve numerical stability of sampling kernels ([#429](#429)) ([898d8ea](898d8ea))

### Other improvements

* break up `_kernels` into multiple modules ([#428](#428)) ([8e482d9](8e482d9))

### Acknowledgement

We thank contributions and feedback from the community: [@comaniac](https://github.com/comaniac), [@esmeetu](https://github.com/esmeetu), [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU), [@peng1999](https://github.com/peng1999), [@xslingcn](https://github.com/xslingcn), [@Yard1](https://github.com/Yard1), [@zhyncs](https://github.com/zhyncs).

---

This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
Motivation
As titled.
I implemented a simplified version based on FasterTransformer. I am considering whether to use optimizations like half2, and whether to use CUTLASS's LeftSiLUAndMul. Do you have any suggestions? Thanks. @yzh119
Modification