Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][distributed] add zmq fallback for broadcasting large objects #6183

Merged
merged 11 commits into from
Jul 10, 2024

Conversation

youkaichao
Copy link
Member

The input to vision language model contains images, which has variable length and can be quite large.

While the shared memory broadcast introduced in #5399 works fine for LLMs, later we find we often need to adjust the buffer size for vision language models.

Estimating the size upper bound can be difficult. To solve the problem, this PR adds a fallback option using zeromq.

  • When the object is small, we will use shared memory broadcast, which is fast and efficient.
  • If the object is too large, we leave an overflow message in the shared memory, and send the data via zeromq, which can handle arbitrary sized data.

In addition, shared memory broadcast is limited to single node, while zeromq (socket-based) is not. Therefore, we can extend the broadcast to also work for cross-node settings. This PR extends the functionality.

cc @DarkLight1337 @ywang96 for vision language model related.

Comment on lines +415 to +425
if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
with self.acquire_write() as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
else:
with self.acquire_write() as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the most critical part, previously, when object size is too large (e.g. large image data), vLLM will error. Now, we will fall back to zqm.

@DarkLight1337
Copy link
Member

DarkLight1337 commented Jul 8, 2024

This idea sounds good, but I don't have much experience with cross-device communication, so I'll leave the review to someone who is more qualified.

Copy link
Member

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a very brief look and the code in general LGTM. I didn't check the internal logic of the message queue-based broadcast, but the interface and code change looks good. Please let me know whether I need to look into any part more carefully.

vllm/distributed/parallel_state.py Outdated Show resolved Hide resolved
@WoosukKwon
Copy link
Collaborator

@youkaichao pytest tests/distributed/test_shm_broadcast.py worked on the AMD MI210 machine. Do you want other kinds of tests as well?

@youkaichao
Copy link
Member Author

@WoosukKwon thanks for testing! I think it works because I added the dependency in requirements-common.txt , so all platforms can use it.

@youkaichao youkaichao merged commit da78cae into vllm-project:main Jul 10, 2024
70 checks passed
@youkaichao youkaichao deleted the add_zmq branch July 10, 2024 01:49
adityagoel14 pushed a commit to adityagoel14/vllm-torchrun-test that referenced this pull request Jul 10, 2024
…-project#6183)

[core][distributed] add zmq fallback for broadcasting large objects (vllm-project#6183)

(cherry picked from commit da78cae)
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 17, 2024
…-project#6183)

[core][distributed] add zmq fallback for broadcasting large objects (vllm-project#6183)
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
…-project#6183)

[core][distributed] add zmq fallback for broadcasting large objects (vllm-project#6183)
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
…-project#6183)

[core][distributed] add zmq fallback for broadcasting large objects (vllm-project#6183)

Signed-off-by: Alvant <alvasian@yandex.ru>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants