LightLLM benchmark #670

Closed · zhyncs opened this issue Aug 3, 2023 · 8 comments
Labels: enhancement (New feature or request), P0

Comments

zhyncs (Contributor) commented Aug 3, 2023

Hi vLLM genius @zhuohan123 @WoosukKwon

I found a new project: https://github.com/ModelTC/lightllm

After reading their blog, it seems the performance advantage on the 7B model is not very obvious, but the gap is larger on the 65B model. We will also do some verification and comparison later. I'm raising this issue in the hope that we can identify what LightLLM does well, so that we can refer to it and port similar optimizations to vLLM. Cheers.

zhyncs (Contributor, Author) commented Aug 3, 2023

As mentioned in the blog, one point particularly caught my attention, and I feel that this feature may be compatible with the current design of vLLM.

Quoted from the blog (translated from the original Chinese):

The three-process architecture is mainly used to handle tokenize and detokenize operations asynchronously, so that these time-consuming CPU steps do not block GPU scheduling during model inference; such blocking would lower GPU utilization and in turn degrade overall service performance.

Also cc @naed90 @wejoncy
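
For illustration only, here is a minimal sketch of the idea, not LightLLM's or vLLM's actual code: all names (dummy_tokenize, tokenizer_worker, the queues) are hypothetical, and the point is just that tokenization runs in its own process so the GPU-facing loop only ever consumes ready-made token IDs.

```python
# Minimal sketch: offload tokenization to a worker process so the main
# (GPU-facing) loop never blocks on CPU-heavy tokenize calls.
# All names are hypothetical; this is not LightLLM's actual architecture.
import multiprocessing as mp


def dummy_tokenize(text: str) -> list[int]:
    # Stand-in for a real (and potentially slow) tokenizer call.
    return [len(word) for word in text.split()]


def tokenizer_worker(request_q: mp.Queue, token_q: mp.Queue) -> None:
    # Runs in its own process: pulls raw prompts, pushes token IDs.
    while True:
        item = request_q.get()
        if item is None:  # shutdown signal
            break
        request_id, prompt = item
        token_q.put((request_id, dummy_tokenize(prompt)))


if __name__ == "__main__":
    request_q: mp.Queue = mp.Queue()
    token_q: mp.Queue = mp.Queue()
    worker = mp.Process(target=tokenizer_worker, args=(request_q, token_q))
    worker.start()

    # The model/router side only ever sees pre-tokenized requests, so the
    # GPU scheduling loop is never stalled by tokenization.
    for i, prompt in enumerate(["hello world", "benchmarking LightLLM and vLLM"]):
        request_q.put((i, prompt))

    for _ in range(2):
        request_id, token_ids = token_q.get()
        print(request_id, token_ids)

    request_q.put(None)
    worker.join()
```

A symmetric detokenizer process on the output side would complete the three-process picture described in the blog.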

zhyncs (Contributor, Author) commented Aug 3, 2023

After briefly reviewing LightLLM, I found that its kernels are written in OpenAI Triton rather than CUDA. This makes it easier for newcomers to get started with optimization work, and the performance is still good. For multi-GPU inference, it uses RPyC instead of Ray. Both vLLM and LightLLM use tensor parallelism across multiple GPUs for faster inference. It's hard to say whether these technology choices alone can account for such a big performance difference; I think the improvement comes more from TokenAttention and the Efficient Router, as well as the asynchronous handling of tokenize and detokenize. I also counted the lines of code in LightLLM: looking only at Llama 2 (i.e., excluding the Bloom and LLaMA code), there are only a little over 2000 lines overall, which is quite remarkable.
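
For readers who haven't seen Triton before, here is a generic element-wise kernel in the style of the official Triton tutorials, just to show the programming model; this is not LightLLM code, and a real attention kernel is of course considerably more involved.

```python
# Generic Triton vector-add kernel, shown only to illustrate the programming
# model LightLLM uses for its kernels (this is not LightLLM code).
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard against out-of-bounds accesses
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out


if __name__ == "__main__":
    a = torch.randn(4096, device="cuda")
    b = torch.randn(4096, device="cuda")
    assert torch.allclose(add(a, b), a + b)
```

The appeal is that this is plain Python decorated with @triton.jit, which lowers the barrier to entry compared to hand-written CUDA.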

zhuohan123 (Member) commented Aug 3, 2023

Hi! Thanks for bringing this up. We are excited to see new open-source efforts built on the vLLM project. I'm reproducing the results from LightLLM right now, but at first glance:

  • TokenAttention is the special case of PagedAttention with block size equal to 1. We have tested this before and found that it under-utilizes GPU compute compared to larger block sizes. Unless LightLLM's Triton kernel implementation is surprisingly fast, this should not bring a speedup.
  • The memory saving brought by TokenAttention should also not be significant. In our current benchmarks, the memory waste is already less than 4%, and going from 4% to 0% should not bring that much speed gain (a rough back-of-the-envelope calculation is sketched right after this list). In addition, running 7B LLaMA on an 80GB A100 leaves abundant memory, so memory waste should not be the speed bottleneck anyway.
  • The GPU-based efficient router is a cool design. However, in vLLM, the page table only has ~10k blocks, which is very small compared to the neural network computation and should not be a speed bottleneck either.
  • Therefore, we believe the speedup should mainly come from the design of the tokenizer and sampler, which is a known speed bottleneck in vLLM.
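
As a rough sanity check on the memory-waste point above, here is a back-of-the-envelope estimate; the 512-token average sequence length is an assumed, illustrative number, not a measurement.

```python
# Back-of-the-envelope estimate of KV-cache waste from block padding:
# on average, the last block of each sequence is about half empty.
# The 512-token average length is an assumption for illustration only.
def waste_fraction(avg_seq_len: int, block_size: int) -> float:
    wasted = (block_size - 1) / 2           # expected empty slots per sequence
    return wasted / (avg_seq_len + wasted)  # fraction of allocated slots wasted


for block_size in (1, 16, 32):
    print(block_size, f"{waste_fraction(avg_seq_len=512, block_size=block_size):.2%}")
# block size 1  -> 0.00%  (TokenAttention)
# block size 16 -> ~1.4%
# block size 32 -> ~3.0%
```

Even with block size 32 the expected waste stays in the low single digits, consistent with the less-than-4% figure, so eliminating it alone should not explain a large end-to-end gap.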

However, these are all my initial guesses. We will perform a more thorough benchmark and update the results here.

zhuohan123 (Member) commented

I tried to rerun the LLaMA 7B benchmark on an 80GB A100 on GCP. With the latest vLLM main branch:

```
(vllm) zhuohan@zhuohan-1:~/vllm/vllm/benchmarks$ python benchmark_serving.py --backend vllm --tokenizer huggyllama/llama-7b --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000 --request-rate 200 --host 127.0.0.1 --port 9009
Namespace(backend='vllm', host='127.0.0.1', port=9009, dataset='ShareGPT_V3_unfiltered_cleaned_split.json', tokenizer='huggyllama/llama-7b', best_of=1, use_beam_search=False, num_prompts=2000, request_rate=200.0, seed=0, trust_remote_code=False)
INFO 08-03 21:09:35 tokenizer.py:29] For some LLaMA-based models, initializing the fast tokenizer may take a long time. To eliminate the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
Total time: 433.25 s
Throughput: 4.62 requests/s
Average latency: 200.60 s
Average latency per token: 0.69 s
Average latency per output token: 4.12 s
```

When commenting out all tokenization and using batched argmax sampling (branch):

```
(vllm) zhuohan@zhuohan-1:~/vllm/vllm/benchmarks$ python benchmark_serving.py --backend vllm --tokenizer huggyllama/llama-7b --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000 --request-rate 200 --host 127.0.0.1 --port 9009
Namespace(backend='vllm', host='127.0.0.1', port=9009, dataset='ShareGPT_V3_unfiltered_cleaned_split.json', tokenizer='huggyllama/llama-7b', best_of=1, use_beam_search=False, num_prompts=2000, request_rate=200.0, seed=0, trust_remote_code=False)
INFO 08-04 00:29:29 tokenizer.py:29] For some LLaMA-based models, initializing the fast tokenizer may take a long time. To eliminate the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
Total time: 245.59 s
Throughput: 8.14 requests/s
Average latency: 113.03 s
Average latency per token: 0.39 s
Average latency per output token: 2.33 s
```

The improvement in throughput (1.76x) is very close to the improvement reported in LightLLM (1.92x).

Will update the results after reproducing LightLLM's results.
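
For context on the second run: "batched argmax sampling" here just means greedy decoding done as one vectorized op over the whole batch of last-token logits rather than sampling each sequence separately in Python. A minimal sketch with made-up shapes (this is not the actual vLLM sampler):

```python
# Minimal sketch of batched greedy (argmax) sampling.
# Shapes and tensors are illustrative; this is not vLLM's Sampler code.
import torch

batch_size, vocab_size = 8, 32000
logits = torch.randn(batch_size, vocab_size)   # last-token logits per sequence

# One vectorized op for the whole batch instead of a per-sequence Python loop.
next_token_ids = torch.argmax(logits, dim=-1)  # shape: (batch_size,)
print(next_token_ids.tolist())
```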

zhyncs (Contributor, Author) commented Aug 4, 2023

Hi @zhuohan123

Thanks for your detailed reply. As mentioned above, the throughput advantage on 7B is not obvious. Perhaps, in addition to reproducing the 7B results, we could also investigate the 65B results, since the throughput gap becomes even larger as the model size increases. Thanks.

irasin (Contributor) commented Aug 4, 2023

I found that a lot of coroutines are used in the LightLLM router; maybe these also bring some throughput improvement?
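
For what it's worth, the pattern in question looks roughly like the generic asyncio sketch below (illustrative only, not LightLLM's actual router code): coroutines let the router keep admitting and scheduling requests while other requests are waiting on I/O such as tokenize/detokenize RPCs.

```python
# Generic asyncio sketch of a router-style loop that overlaps waiting
# requests instead of handling them serially. Not taken from LightLLM.
import asyncio


async def handle_request(request_id: int) -> str:
    # Stand-in for awaiting a tokenize/detokenize RPC or a model step.
    await asyncio.sleep(0.01)
    return f"request {request_id} done"


async def main() -> None:
    # All requests are in flight concurrently; the event loop interleaves
    # them whenever one is blocked waiting.
    results = await asyncio.gather(*(handle_request(i) for i in range(4)))
    print(results)


asyncio.run(main())
```

Whether this alone buys much throughput is unclear, since coroutines do not parallelize CPU-bound work; the multi-process split discussed earlier matters more for that.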

OmarSayedMostafa commented

(quoting @zhuohan123's LLaMA 7B benchmark comment above in full)

@zhuohan123 can you update the link to the branch where only batched argmax sampling is used, since it's not working anymore? Thanks in advance.

zhuohan123 (Member) commented

This optimization has been integrated into the latest main branch of vLLM. Please try it out!
