[core][distributed] improve shared memory broadcast #5754

Merged: 5 commits, Jun 22, 2024
Changes from 4 commits
34 changes: 26 additions & 8 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -48,6 +48,26 @@ def __init__(self,
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+--------------+--------------+--------------+-----+--------------+

+ The state of the metadata is as follows:
+
+ (case 1) 0???...???: the block is not written yet, cannot read, can write
+ (case 2) 1000...000: the block is just written, can read, cannot write
+ (case 3) 1???...???: the block is written and read by some readers, can be read by a reader that has not read it yet, cannot write
+ (case 4) 1111...111: the block is written and read by all readers, cannot read, can write
+
+ State transition for readers:
+
+ When a reader finds a block that it can read (case 2 or 3), it can yield the block for the caller to read.
+ Only after the caller finishes reading the block can the reader mark the block as read.
+ Readers only mark the block as read (their own flag goes from 0 to 1); the writer marks the block as ready to read by resetting the reader flags (from 1 to 0).
+
+ State transition for writer:
+
+ When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
+ to case 1. Then it can yield the block for the caller to write. After the caller finishes writing the block, the writer
+ resets the reader flags to 0 and marks the block as written (from 0 to 1).
+ NOTE: the order is important here: first reset the reader flags (so that we are still in case 1), then mark the block as written. In this order, readers see the block go from case 1 directly to case 2 in a single flag write. If we do it in the reverse order, the block goes through case 3 and then back to case 2, and readers might observe the intermediate case 3, which is not correct.
+
During creation, `name` is None and the buffer is created. We can pass the
created object to other processes by pickling it. The other processes will
get the name of the shared memory and open it, so that they can access the
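The four metadata cases described in this docstring can be expressed as two predicates. This is a minimal sketch, not the code in `shm_broadcast.py`: `can_write` and `can_read` are hypothetical helper names, and the flags are modeled as a plain `bytearray` laid out as `[written_flag, reader0_flag, ..., readerN_flag]` rather than as a view into shared memory.

```python
def can_write(flags: bytearray) -> bool:
    """True in case 1 (not written) or case 4 (written and read by all readers)."""
    written_flag = flags[0]
    reader_flags = flags[1:]
    return written_flag == 0 or all(f == 1 for f in reader_flags)


def can_read(flags: bytearray, reader_idx: int) -> bool:
    """True in case 2 or 3, provided this reader has not marked the block as read."""
    written_flag = flags[0]
    own_flag = flags[1 + reader_idx]
    return written_flag == 1 and own_flag == 0


# One block shared by two readers, walked through the four cases.
case1 = bytearray([0, 1, 0])  # not written: only the writer may proceed
case2 = bytearray([1, 0, 0])  # just written: both readers may read
case3 = bytearray([1, 1, 0])  # reader 0 is done, reader 1 may still read
case4 = bytearray([1, 1, 1])  # everyone is done: the writer may reuse the block
```

Under this model a reader never has to inspect another reader's flag; only the writer looks at all of them.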
@@ -81,10 +101,6 @@ def __init__(self,
lambda *args, **kwargs: None):
self.shared_memory = shared_memory.SharedMemory(name=name)
assert self.shared_memory.size == self.total_bytes_of_buffer
-with memoryview(self.shared_memory.buf[self.metadata_offset:]
-) as metadata_buffer:
-tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
-assert torch.all(tensor == 0)

def __reduce__(self):
return (
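This hunk is four lines shorter after the change (10 lines before, 6 after): the check that the metadata region is all zeros when a process attaches to an existing segment is gone. The diff does not state the reason. One plausible reading is that a writer may already have set flags by the time a reader attaches, in which case an all-zero check would fail. The sketch below illustrates only that scenario, using the standard `multiprocessing.shared_memory` API in a single process; `attach_and_read_flag` is a hypothetical name and the 16-byte size is arbitrary.

```python
from multiprocessing import shared_memory


def attach_and_read_flag(size: int = 16) -> int:
    """Create a segment, set a flag, then attach by name and read the flag back."""
    creator = shared_memory.SharedMemory(create=True, size=size)
    try:
        creator.buf[:size] = bytes(size)  # start from an all-zero region
        creator.buf[0] = 1                # the "writer" sets a flag early
        # A late "reader" attaches by name, as the unpickled buffer would.
        attached = shared_memory.SharedMemory(name=creator.name)
        try:
            return attached.buf[0]        # no longer all zeros at attach time
        finally:
            attached.close()
    finally:
        creator.close()
        creator.unlink()
```

The attaching side sees whatever the creator has already written, so any invariant checked at attach time has to tolerate in-flight state.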
@@ -163,11 +179,15 @@ def acquire_write(self):
yield buf

# caller has written to the buffer
-# mark the block as written
-metadata_buffer[0] = 1
+# NOTE: order is important here
+# first set the read flags to 0
+# then set the written flag to 1
+# otherwise, the readers may think they already read the block
 for i in range(1, self.buffer.n_reader + 1):
     # set read flag to 0, meaning it is not read yet
     metadata_buffer[i] = 0
+# mark the block as written
+metadata_buffer[0] = 1
break

@contextmanager
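To see why the ordering in `acquire_write` matters, the sketch below replays both orderings on a plain `bytes` snapshot and records every intermediate state a concurrent reader could observe. It is an illustration under stated assumptions, not the vLLM implementation: `publish_states` and `is_stale_visible` are hypothetical helpers, and real shared-memory visibility is reduced here to a list of snapshots.

```python
def publish_states(flags: bytes, n_reader: int, reset_readers_first: bool) -> list:
    """Return each intermediate flag state produced while publishing a block."""
    state = bytearray(flags)
    snapshots = []
    if reset_readers_first:
        # Order used after this change: clear reader flags, then set written flag.
        for i in range(1, n_reader + 1):
            state[i] = 0
            snapshots.append(bytes(state))
        state[0] = 1
        snapshots.append(bytes(state))
    else:
        # Previous order: set written flag first, then clear reader flags.
        state[0] = 1
        snapshots.append(bytes(state))
        for i in range(1, n_reader + 1):
            state[i] = 0
            snapshots.append(bytes(state))
    return snapshots


def is_stale_visible(snapshot: bytes) -> bool:
    """Written flag is set while some reader flag still holds an old 1."""
    return snapshot[0] == 1 and any(snapshot[1:])


# A block previously read by both readers, after the writer cleared flag 0.
start = bytes([0, 1, 1])
new_order = publish_states(start, n_reader=2, reset_readers_first=True)
old_order = publish_states(start, n_reader=2, reset_readers_first=False)
```

With the reader flags cleared first, every snapshot with the written flag set has all reader flags at 0. With the old order, a reader can see the written flag set while its own flag is still 1 and conclude it has already read the new block.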
@@ -248,12 +268,10 @@ def create_from_process_group(pg: ProcessGroup,
 if group_rank == writer_rank:
     buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
     dist.broadcast_object_list([buffer], src=global_ranks[writer_rank])
-    dist.barrier(pg)
     return ShmRingBufferIO(buffer, -1)
 else:
     recv = [None]
     dist.broadcast_object_list(recv, src=global_ranks[writer_rank])
-    dist.barrier(pg)
     buffer = recv[0]  # type: ignore
     rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
     return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
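The reader index passed to `ShmRingBufferIO` is the position of a rank among the non-writer ranks, and the writer gets `-1`. The sketch below reproduces only that mapping. `reader_rank_for` is a hypothetical helper, and it assumes `ranks_inside_group` is `0..group_size-1`, which this hunk does not show.

```python
def reader_rank_for(group_rank: int, writer_rank: int, group_size: int) -> int:
    """Map a rank inside the group to its reader slot; the writer maps to -1."""
    if group_rank == writer_rank:
        return -1
    # Assumption: ranks inside the group are simply 0..group_size-1.
    ranks_inside_group = list(range(group_size))
    rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
    return rest_ranks.index(group_rank)


# Four ranks with rank 1 as the writer: readers are packed into slots 0, 1, 2.
mapping = {rank: reader_rank_for(rank, writer_rank=1, group_size=4)
           for rank in range(4)}
```

Packing the reader slots densely keeps the metadata layout at exactly one flag per reader, regardless of which rank is chosen as the writer.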