[Performance]: empirical measurement of object serialization for input/output of worker #6241
Comments
okay,

```python
from vllm.sampling_params import SamplingParams
import pickle

print(len(pickle.dumps(SamplingParams())))  # 611
```

Even though it does not contain any information, it uses 611 bytes.
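To see where those bytes go, here is a minimal sketch using a hypothetical stand-in dataclass (`ToySamplingParams`, whose fields are chosen for illustration and are not vLLM's actual `SamplingParams`). Pickling a default-constructed instance still embeds the module and class name plus every field name, whereas pickling only the values as a tuple drops that schema information:

```python
import pickle
from dataclasses import dataclass, astuple


@dataclass
class ToySamplingParams:
    # Hypothetical stand-in; these fields are for illustration only.
    n: int = 1
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    max_tokens: int = 16
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0


params = ToySamplingParams()  # all defaults: carries no real information

as_object = pickle.dumps(params)           # class path + field names + values
as_values = pickle.dumps(astuple(params))  # values only, by position

print(len(as_object), len(as_values))
print(b"ToySamplingParams" in as_object, b"temperature" in as_object)
print(b"temperature" in as_values)
```

The tuple encoding is not a drop-in replacement: it only works if sender and receiver agree on the field order, which is exactly the information pickle repeats inline.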
Btw, both of these were done inside Anyscale before, and when I last benchmarked them (this January), this approach achieved nearly the same results as the NCCL-broadcast-based solution.
Glad to know that. Then I think we should work on this, to replace the NCCL broadcast.
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!
Proposal to improve performance
Currently, `LLMEngine` (the driver) lives in the same process as the tensor parallel rank 0 worker, which has caused us a lot of trouble; e.g., we cannot easily create two instances of vLLM on different GPUs. Spec decode has to hack around this a lot.

Basically, the function we care about is `LLMEngine.step`, and the core line of code is:

When we use tensor parallelism of size N, this line will convert `execute_model_req` into tensors and broadcast them to the remaining N - 1 workers.

If we want to separate the TP rank 0 process from the engine process (as in #6032), there will be two serializations:

- `execute_model_req` will be serialized and sent to the TP processes. Even with advanced techniques that let us send it once and have all processes receive it, we still need to serialize it.
- `output` will first live in the TP rank 0 process and then be passed to the engine process.

Therefore, we need to measure how large these objects are and how much it costs to serialize them.
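The two serialization points above can be sketched as follows, assuming hypothetical stand-in types (`ToyExecuteModelRequest`, `ToySamplerOutput`) rather than vLLM's real classes, and using threads with in-memory queues in place of separate processes. The engine pickles the request once and hands the same bytes to all N workers; only rank 0 pickles its (dummy) output back. The reported sizes and timing only illustrate how one might measure the cost; they are not real vLLM numbers:

```python
import pickle
import queue
import threading
import time
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyExecuteModelRequest:
    # Hypothetical stand-in for ExecuteModelRequest.
    request_id: str
    prompt_token_ids: List[int]
    generated_token_ids: List[int] = field(default_factory=list)


@dataclass
class ToySamplerOutput:
    # Hypothetical stand-in for SamplerOutput.
    request_id: str
    next_token_id: int


def worker(rank: int, inbox: queue.Queue, outbox: queue.Queue) -> None:
    req = pickle.loads(inbox.get())  # every rank deserializes the same bytes
    if rank == 0:
        # Serialization #2: only rank 0 sends a (dummy) result to the engine.
        outbox.put(pickle.dumps(ToySamplerOutput(req.request_id, 42)))


def run_step(req: ToyExecuteModelRequest, tp_size: int):
    inboxes = [queue.Queue() for _ in range(tp_size)]
    outbox: queue.Queue = queue.Queue()
    threads = [threading.Thread(target=worker, args=(r, inboxes[r], outbox))
               for r in range(tp_size)]
    for t in threads:
        t.start()

    start = time.perf_counter()
    payload = pickle.dumps(req)  # serialization #1: once, not once per worker
    ser_seconds = time.perf_counter() - start
    for box in inboxes:
        box.put(payload)  # the same bytes go to all N workers

    out_payload = outbox.get()
    for t in threads:
        t.join()
    return len(payload), len(out_payload), ser_seconds, pickle.loads(out_payload)


req = ToyExecuteModelRequest("req-0", prompt_token_ids=list(range(100)))
in_bytes, out_bytes, secs, out = run_step(req, tp_size=4)
print(f"input: {in_bytes} B ({secs * 1e6:.1f} us to pickle), output: {out_bytes} B")
```

In a real deployment the queues would be an actual inter-process channel; the point of the sketch is only that serialization #1 happens once per step regardless of N, while serialization #2 happens once, on rank 0.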
Here is a simple script:
And we can use the branch https://github.com/youkaichao/vllm/tree/measure_serialization to measure the serialization overhead (remember to `pip install levenshtein`):

As we can see, the actual information we pass every step (the difference, or edit distance, between consecutive messages) is quite small: only several dozen bytes. However, the serialized data are 10x to 100x larger. Why?
For the output, this is because we use a very inefficient serialization format: `pickle` for dataclasses. It stores field names and class names again and again.

For the input, besides the above limitation (e.g., the serialization of `SamplingParams` is terribly long but not informative), there is another limitation: it sends the prompt again and again.

What's next?
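Both observations above (consecutive messages differ by only a few bytes, and the prompt is resent every step) can be reproduced in a small sketch. It is not the script or branch referenced earlier, and `FullRequest`, `DeltaRequest`, and `WorkerCache` are hypothetical names. The sketch pickles a full request for five decode steps and measures the edit distance between consecutive payloads, using a small pure-Python Levenshtein function instead of the `levenshtein` package. It then shows one possible direction, not necessarily the plan here: send the prompt once, send only the new token afterwards, and let the worker rebuild the full sequence from cached state:

```python
import pickle
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class FullRequest:
    # Hypothetical stand-in: the whole prompt is resent on every step.
    request_id: str
    prompt_token_ids: List[int]
    output_token_ids: List[int]


@dataclass
class DeltaRequest:
    # Hypothetical alternative: the prompt is sent once, then only new tokens.
    request_id: str
    new_token_ids: List[int]
    prompt_token_ids: Optional[List[int]] = None


def levenshtein(a: bytes, b: bytes) -> int:
    """Classic O(len(a) * len(b)) edit distance over raw bytes."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,                # delete ca
                           cur[j - 1] + 1,             # insert cb
                           prev[j - 1] + (ca != cb)))  # substitute
        prev = cur
    return prev[-1]


class WorkerCache:
    """Hypothetical worker-side state that rebuilds each full sequence."""

    def __init__(self) -> None:
        self.seqs: Dict[str, List[int]] = {}

    def apply(self, msg: DeltaRequest) -> List[int]:
        if msg.prompt_token_ids is not None:  # first step for this request
            self.seqs[msg.request_id] = list(msg.prompt_token_ids)
        self.seqs[msg.request_id].extend(msg.new_token_ids)
        return self.seqs[msg.request_id]


prompt = list(range(1000, 1064))  # 64 dummy prompt token ids
outputs: List[int] = []
cache = WorkerCache()
full_sizes, distances, delta_sizes = [], [], []
previous: Optional[bytes] = None

for step in range(5):
    token = 7 + step  # dummy "sampled" token for this decode step
    outputs.append(token)

    full = pickle.dumps(FullRequest("req-0", prompt, outputs))
    full_sizes.append(len(full))
    if previous is not None:
        distances.append(levenshtein(previous, full))
    previous = full

    delta = pickle.dumps(
        DeltaRequest("req-0", [token], prompt if step == 0 else None))
    delta_sizes.append(len(delta))
    seq = cache.apply(pickle.loads(delta))

print("full payload sizes:       ", full_sizes)
print("edit distance to previous:", distances)
print("delta payload sizes:      ", delta_sizes)
```

The trade-off is that the worker becomes stateful: it must keep per-request state, drop it when a request finishes (not shown), and stay in sync with the engine about which requests it has already seen.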
`msgpack` does not work here, as it cannot serialize `ExecuteModelRequest`.

What if `ExecuteModelRequest` or `SamplerOutput` contains GPU data?

GPU data is expensive to move across processes, so it should be used as little as possible. In most cases, we should leave GPU data in the worker; ideally, the engine will not own any GPU data.
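One way to keep GPU data out of the engine, sketched with hypothetical names (`WorkerLocalStore`, `EngineFacingOutput`) and a plain `bytes` buffer standing in for a device tensor, is to have the worker keep the large data and return only small results plus an opaque integer handle. Comparing the pickled sizes shows what crossing the process boundary would cost in each case:

```python
import pickle
from dataclasses import dataclass
from typing import Any, Dict, List


class WorkerLocalStore:
    """Hypothetical worker-side store: large (e.g. GPU) data stays here."""

    def __init__(self) -> None:
        self._data: Dict[int, Any] = {}
        self._next_handle = 0

    def put(self, value: Any) -> int:
        handle = self._next_handle
        self._next_handle += 1
        self._data[handle] = value
        return handle

    def get(self, handle: int) -> Any:
        return self._data[handle]


@dataclass
class EngineFacingOutput:
    # Hypothetical: small CPU-side results plus an opaque handle.
    sampled_token_ids: List[int]
    hidden_states_handle: int


@dataclass
class EngineFacingOutputWithData:
    # Hypothetical anti-pattern: the large buffer travels to the engine.
    sampled_token_ids: List[int]
    hidden_states: bytes


store = WorkerLocalStore()
hidden_states = bytes(4096 * 4)  # 16 KiB stand-in for a device tensor

lean = pickle.dumps(EngineFacingOutput([42], store.put(hidden_states)))
heavy = pickle.dumps(EngineFacingOutputWithData([42], hidden_states))
print(len(lean), len(heavy))
```

A handle is only meaningful inside the worker that created it, so the engine would pass it back to the worker rather than dereference it, and the worker would need some rule for freeing entries.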
cc @WoosukKwon @zhuohan123 @simon-mo @comaniac @cadedaniel @stephanie-wang @ruisearch42
Report of performance regression
No response
Misc discussion on performance
No response
Your current environment (if you think it is necessary)