
Commit d347dab
Enable vineyard llm kv cache in vLLM
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
(cherry picked from commit 1545f6bf7edcd667e305d3fbcadd913066f04747)

# Conflicts:
#	vllm/attention/backends/flash_attn.py
#	vllm/worker/model_runner.py
sighingnow committed Jun 18, 2024
1 parent 26e1188 commit d347dab
Showing 4 changed files with 436 additions and 0 deletions.
8 changes: 8 additions & 0 deletions vllm/engine/llm_engine.py
@@ -197,6 +197,14 @@ def __init__(
        )
        # TODO(woosuk): Print more configs in debug mode.

        try:
            import better_exceptions
        except ImportError:
            better_exceptions = None

        if better_exceptions is not None:
            better_exceptions.hook()

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -32,6 +32,7 @@
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda"
VLLM_USE_VINEYARD_CACHE: Optional[str] = None
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False
@@ -52,6 +53,10 @@
"VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"),

# Enable vineyard kv cache for vLLM.
"VLLM_USE_VINEYARD_CACHE":
lambda: os.getenv("VLLM_USE_VINEYARD_CACHE", None),

# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS":
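For context on how the new flag is consumed: the lookup above is lazy (the lambda calls os.getenv at access time), and the model runner below only tests the value for truthiness, together with two further requirements checked in _init_vineyard_cache (chunked prefill and flash-attention decoding). A minimal enablement sketch, assuming the standard vllm.LLM entry point, that any non-empty string activates both flags, and an arbitrary example model:

import os

# Assumption: any non-empty value works, since model_runner.py only checks
# `if envs.VLLM_USE_VINEYARD_CACHE:`.
os.environ["VLLM_USE_VINEYARD_CACHE"] = "1"
# Also required by _init_vineyard_cache(), read via envs.VLLM_USE_FLASH_ATTN_DECODING.
os.environ["VLLM_USE_FLASH_ATTN_DECODING"] = "1"

from vllm import LLM

# Chunked prefill is the third requirement; without it the cache is skipped
# with a warning.
llm = LLM(model="facebook/opt-125m", enable_chunked_prefill=True)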
41 changes: 41 additions & 0 deletions vllm/worker/model_runner.py
@@ -14,6 +14,8 @@
                         VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.parallel_state import graph_capture

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@@ -143,6 +145,32 @@ def __init__(
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

        # Delay the initialization of the vineyard cache until after model
        # loading to ensure the tensor model parallel group is initialized.
        self.vineyard_llm_cache = None

    def _init_vineyard_cache(self):
        if envs.VLLM_USE_VINEYARD_CACHE:
            if not self.scheduler_config.chunked_prefill_enabled:
                logger.warn("Vineyard LLM cache is not enabled, requires chunked prefill")
            elif not envs.VLLM_USE_FLASH_ATTN_DECODING:
                logger.warn("Vineyard LLM cache is not enabled, requires flash attention decoding")
            else:
                from vllm.worker.vineyard_llm_cache import VineyardLLMCache
                self.vineyard_llm_cache: VineyardLLMCache = VineyardLLMCache.from_envs(
                    model_config=self.model_config,
                    parallel_config=self.parallel_config,
                    kv_cache_dtype=self.kv_cache_dtype,
                    torch_dtype=get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                         self.model_config.dtype),
                )
            if self.vineyard_llm_cache:
                logger.info("Using Vineyard LLM cache")
            else:
                logger.warn("Vineyard LLM cache failed to initialize")
        else:
            logger.info("Vineyard LLM cache is not enabled")

    def load_model(self) -> None:
        with CudaMemoryProfiler() as m:
            self.model = get_model(
@@ -209,6 +237,8 @@ def load_model(self) -> None:
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")

self._init_vineyard_cache()

def save_sharded_state(
self,
path: str,
@@ -223,6 +253,9 @@ def save_sharded_state(
            max_size=max_size,
        )

    def set_block_size(self, block_size: int) -> None:
        self.block_size = block_size

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
@@ -730,6 +763,10 @@ def execute_model(
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        if self.vineyard_llm_cache and kv_caches[0] is not None:
            cache_hints = self.vineyard_llm_cache.prefetch_kv_caches(
                seq_group_metadata_list, kv_caches, getattr(self, 'block_size', None))

        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_kwargs
         ) = self.prepare_input_tensors(seq_group_metadata_list)
@@ -757,6 +794,10 @@
        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        if self.vineyard_llm_cache and kv_caches[0] is not None:
            self.vineyard_llm_cache.update_kv_caches(
                cache_hints, seq_group_metadata_list, kv_caches, getattr(self, 'block_size', None))

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None
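The fourth changed file (its diff did not render above) presumably holds the VineyardLLMCache implementation, given the import of vllm.worker.vineyard_llm_cache in the hunk above. A rough interface sketch inferred purely from those call sites; argument names mirror the ones used in model_runner.py and the bodies are placeholders, not the actual implementation:

from typing import Any, List, Optional

import torch


class VineyardLLMCache:
    """Sketch only: shape of the API as used by model_runner.py."""

    @classmethod
    def from_envs(cls, model_config, parallel_config, kv_cache_dtype,
                  torch_dtype) -> Optional["VineyardLLMCache"]:
        # Build the cache from environment/config, or return None when the
        # vineyard backend cannot be set up (model_runner.py handles None).
        raise NotImplementedError

    def prefetch_kv_caches(self, seq_group_metadata_list,
                           kv_caches: List[torch.Tensor],
                           block_size: Optional[int]) -> Any:
        # Populate kv_caches with previously stored blocks for the scheduled
        # sequence groups; the return value is passed to update_kv_caches().
        raise NotImplementedError

    def update_kv_caches(self, cache_hints, seq_group_metadata_list,
                         kv_caches: List[torch.Tensor],
                         block_size: Optional[int]) -> None:
        # Write the newly computed KV blocks back to vineyard.
        raise NotImplementedError

The prefetch/update pair brackets the forward pass in execute_model(): cached blocks are injected before prepare_input_tensors() builds the batch, and fresh blocks are written back right after the logits are computed, before sampling.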