Merge pull request vllm-project#8 from KuntaiDu/jiayi-dev-v2
update overhead benchmark
KuntaiDu authored Sep 15, 2024
2 parents 1f47731 + caaaeb8 commit 0dd3571
Showing 2 changed files with 28 additions and 9 deletions.
2 changes: 1 addition & 1 deletion benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
@@ -53,7 +53,7 @@ benchmark() {
   model="meta-llama/Meta-Llama-3.1-8B-Instruct"
   dataset_name="sonnet"
   dataset_path="../sonnet_4x.txt"
-  num_prompts=20
+  num_prompts=10
   qps=$1
   prefix_len=50
   input_len=2048
35 changes: 27 additions & 8 deletions
@@ -2,7 +2,7 @@
 from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \
     KVLookupBufferBase
 from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
-from typing import Dict, Tuple, List, Optional
+from typing import Dict, Tuple, List, Optional, Union
 import threading
 import torch
 from collections import deque
@@ -14,10 +14,19 @@
 
 class SimpleKVLookupBuffer(KVLookupBufferBase):
 
-    def __init__(self, signal_pipe, data_pipe, buffer_size_thresh):
+    def __init__(self,
+                 signal_pipe: KVPipeBase,
+                 data_pipe: KVPipeBase,
+                 buffer_size_thresh: int):
         """
-            signal_pipe: on CPU -- avoid recv() stops the python intepreter
-            data_pipe: on GPU
+            signal_pipe: on CPU
+
+            NOTE: on-device recv will block all threads in the process, making
+            the KV cache producer unable to listen to new requests while
+            transmitting KV cache. Luckily CPU recv only blocks the current
+            thread, so we use CPU recv to listen for new requests.
+
+            data_pipe: on device (e.g. GPU)
         """
 
         self.buffer = deque()
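The NOTE added in this docstring encodes a design choice: a blocking recv on a device pipe stalls every thread in the process, while a blocking recv on a CPU pipe parks only the listening thread. Below is a minimal, self-contained sketch of that pattern; queue.Queue stands in for the CPU signal pipe and a sleep stands in for the device-side transfer, neither is vLLM's actual pipe class.

import queue
import threading
import time

# Stand-in for the CPU-side signal pipe; NOT vLLM's actual pipe class.
signal_pipe = queue.Queue()

def listen_for_requests():
    # A blocking recv on the CPU pipe parks only THIS thread; the rest of
    # the process (e.g. an in-flight KV cache transfer) keeps running.
    while True:
        signal = signal_pipe.get()   # blocks, but only this thread
        if signal is None:           # end signal, mirroring self.end_signal
            break
        print(f"got request signal: {signal!r}")

listener = threading.Thread(target=listen_for_requests, daemon=True)
listener.start()

# The main thread stays free to "transmit KV cache" while the listener blocks.
signal_pipe.put("drop_select")
time.sleep(0.1)  # stand-in for an on-device transfer
signal_pipe.put(None)
listener.join()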
@@ -33,7 +42,9 @@ def __init__(self, signal_pipe, data_pipe, buffer_size_thresh):
         self.end_signal = None
 
 
-    def _matches(self, tokens_roi_sender, tokens_roi_recver):
+    def _matches(self,
+                 tokens_roi_sender: List[torch.Tensor],
+                 tokens_roi_recver: List[torch.Tensor]):
 
         # tokens_roi_sender: tokens and roi of the producer (in the buffer)
         # tokens_roi_recver: tokens and roi of the consumer (query)
@@ -69,7 +80,7 @@ def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None:
         self.buffer_size -= tensor.element_size() * tensor.numel()
         self.data_pipe.send_tensor(tensor)
 
-    def _get_element_size(self, data):
+    def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
 
         if data == [] or data is None:
             return 0
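For reference, the size accounting this hunk implies (element_size() times numel() for tensors, matching _send_tensor_and_dec_size above, and 0 for None or an empty list) can be sketched as a standalone function. The tensor-first ordering below is an editorial assumption, chosen so the == [] comparison never sees a tensor:

from typing import List, Optional, Union
import torch

def get_element_size(data: Optional[Union[List, torch.Tensor]]) -> int:
    # Bytes occupied by the payload: dtype size times element count.
    if isinstance(data, torch.Tensor):
        return data.element_size() * data.numel()
    if data == [] or data is None:
        return 0
    raise AssertionError("Unknown data type %s" % type(data))

print(get_element_size(torch.zeros(2, 8, dtype=torch.float16)))  # 32 bytes
print(get_element_size(None))                                    # 0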
@@ -78,7 +89,12 @@ def _get_element_size(self, data):
 
         assert False, "Unknown data type %s" % type(data)
 
-    def _add_to_buffer(self, input_tokens, roi, key, value, hidden):
+    def _add_to_buffer(self,
+                       input_tokens: torch.Tensor,
+                       roi: torch.Tensor,
+                       key: torch.Tensor,
+                       value: torch.Tensor,
+                       hidden: torch.Tensor):
 
         if isinstance(input_tokens, torch.Tensor):
             input_tokens = input_tokens.clone()
@@ -150,7 +166,9 @@ def drop_select_handler(self):
             logger.debug("Closing drop_select_handler")
 
 
-    def drop_select(self, input_tokens, roi):
+    def drop_select(self,
+                    input_tokens: torch.Tensor,
+                    roi: torch.Tensor):
 
         assert self.request_handling_thread is None, \
             "drop_select should be called by the receiver"
@@ -183,6 +201,7 @@ def insert(self, input_tokens, roi, key, value, hidden) -> None:
         while self.buffer_size > self.buffer_size_threshold:
             # logger.debug("KV transfer buffer is full. Handling...")
             self.full_handler()
+
 
         self._add_to_buffer(input_tokens, roi, key, value, hidden)
 
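Taken together, the typed signatures suggest the calling pattern: the producer stages KV cache with insert(), which loops in full_handler() while the buffer is over its size threshold, and the consumer retrieves matching entries with drop_select(), which the assertion above restricts to the receiver side. A hedged usage sketch, assuming a SimpleKVLookupBuffer has already been constructed with its pipes; producer_publish and consumer_fetch are hypothetical helper names, not part of this commit:

import torch

def producer_publish(buffer, input_tokens: torch.Tensor, roi: torch.Tensor,
                     key: torch.Tensor, value: torch.Tensor,
                     hidden: torch.Tensor) -> None:
    # Prefill worker: stage KV cache; insert() spins in full_handler()
    # while buffer_size exceeds the configured threshold.
    buffer.insert(input_tokens, roi, key, value, hidden)

def consumer_fetch(buffer, input_tokens: torch.Tensor, roi: torch.Tensor):
    # Decode worker: request the entry whose tokens/roi match;
    # drop_select() asserts request_handling_thread is None,
    # i.e. it is receiver-side only.
    return buffer.drop_select(input_tokens, roi)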
