diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index d264f18156438..f0ee54357af74 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -53,7 +53,7 @@ benchmark() { model="meta-llama/Meta-Llama-3.1-8B-Instruct" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" - num_prompts=20 + num_prompts=10 qps=$1 prefix_len=50 input_len=2048 diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index de2e7cf5d5d57..6172bf092fb03 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -2,7 +2,7 @@ from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \ KVLookupBufferBase from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from typing import Dict, Tuple, List, Optional +from typing import Dict, Tuple, List, Optional, Union import threading import torch from collections import deque @@ -14,10 +14,19 @@ class SimpleKVLookupBuffer(KVLookupBufferBase): - def __init__(self, signal_pipe, data_pipe, buffer_size_thresh): + def __init__(self, + signal_pipe: KVPipeBase, + data_pipe: KVPipeBase, + buffer_size_thresh: int): """ - signal_pipe: on CPU -- avoid recv() stops the python intepreter - data_pipe: on GPU + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. GPU) """ self.buffer = deque() @@ -33,7 +42,9 @@ def __init__(self, signal_pipe, data_pipe, buffer_size_thresh): self.end_signal = None - def _matches(self, tokens_roi_sender, tokens_roi_recver): + def _matches(self, + tokens_roi_sender: List[torch.Tensor], + tokens_roi_recver: List[torch.Tensor]): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) @@ -69,7 +80,7 @@ def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: self.buffer_size -= tensor.element_size() * tensor.numel() self.data_pipe.send_tensor(tensor) - def _get_element_size(self, data): + def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): if data == [] or data is None: return 0 @@ -78,7 +89,12 @@ def _get_element_size(self, data): assert False, "Unknown data type %s" % type(data) - def _add_to_buffer(self, input_tokens, roi, key, value, hidden): + def _add_to_buffer(self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor): if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() @@ -150,7 +166,9 @@ def drop_select_handler(self): logger.debug("Closing drop_select_handler") - def drop_select(self, input_tokens, roi): + def drop_select(self, + input_tokens: torch.Tensor, + roi: torch.Tensor): assert self.request_handling_thread is None, \ "drop_select should be called by the receiver" @@ -183,6 +201,7 @@ def insert(self, input_tokens, roi, key, value, hidden) -> None: while self.buffer_size > self.buffer_size_threshold: # logger.debug("KV transfer buffer is full. Handling...") self.full_handler() + self._add_to_buffer(input_tokens, roi, key, value, hidden)