[core][model] yet another cpu offload implementation (#6496)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
youkaichao and mgoin authored Jul 18, 2024
1 parent 18fecc3 commit 1c27d25
Showing 7 changed files with 128 additions and 4 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -140,6 +140,7 @@ steps:
# install tensorizer for tensorize_vllm_model.py
- pip install awscli tensorizer
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 llava_example.py
22 changes: 22 additions & 0 deletions examples/cpu_offload.py
@@ -0,0 +1,22 @@
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -433,6 +433,7 @@ def __init__(
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
cpu_offload_gb: float = 0,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
@@ -441,6 +442,7 @@ def __init__(
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
24 changes: 23 additions & 1 deletion vllm/engine/arg_utils.py
@@ -45,6 +45,7 @@ class EngineArgs:
disable_sliding_window: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
cpu_offload_gb: int = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
@@ -303,6 +304,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    type=int,
    default=EngineArgs.swap_space,
    help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
    '--cpu-offload-gb',
    type=float,
    default=0,
    help='The space in GiB to offload to CPU, per GPU. Default is 0, '
         'which means no offloading. Intuitively, this argument can be '
         'seen as a virtual way to increase the GPU memory size. For '
         'example, if you have one 24 GB GPU and set this to 10, you '
         'can virtually think of it as a 34 GB GPU. Then you can load '
         'a 13B model with BF16 weights, which requires at least '
         '26 GB of GPU memory. Note that this requires a fast CPU-GPU '
         'interconnect, as part of the model is loaded from CPU '
         'memory to GPU memory on the fly in each model forward '
         'pass.')
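
For concreteness, the sizing example in the help text above works out as follows; a back-of-the-envelope sketch (weights only; activations and the KV cache still need GPU memory):

    # Rough check of the 24 GB GPU + 10 GB offload example in the help text.
    params = 13e9                                # ~13B parameters
    bytes_per_param = 2                          # BF16
    weight_gb = params * bytes_per_param / 1e9   # ~26 GB of weights
    gpu_gb, cpu_offload_gb = 24, 10
    print(weight_gb <= gpu_gb + cpu_offload_gb)  # True: fits once 10 GB is offloaded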
parser.add_argument(
'--gpu-memory-utilization',
type=float,
@@ -633,6 +648,11 @@ def create_engine_config(self, ) -> EngineConfig:
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")

assert self.cpu_offload_gb >= 0, (
    "CPU offload space must be non-negative"
    f", but got {self.cpu_offload_gb}")

multimodal_config = MultiModalConfig()

device_config = DeviceConfig(device=self.device)
@@ -666,7 +686,9 @@ def create_engine_config(self, ) -> EngineConfig:
cache_dtype=self.kv_cache_dtype,
num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching)
enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
)
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
6 changes: 6 additions & 0 deletions vllm/entrypoints/llm.py
@@ -69,6 +69,10 @@ class LLM:
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
the model weights. This virtually increases the GPU memory space
you can use to hold the model weights, at the cost of CPU-GPU data
transfer for every forward pass.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
@@ -114,6 +118,7 @@ def __init__(
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
@@ -141,6 +146,7 @@ def __init__(
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
73 changes: 70 additions & 3 deletions vllm/model_executor/models/utils.py
@@ -1,8 +1,10 @@
from typing import Callable, Dict, List, Tuple

import torch
from torch.func import functional_call

from vllm.multimodal import BatchedTensors
from vllm.utils import is_pin_memory_available


def merge_vision_embeddings(input_ids: torch.Tensor,
@@ -52,6 +54,70 @@ def __init__(self, *args, **kwargs):
super().__init__()


_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_BYTES = 0
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    device = next(module.parameters()).device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty(size=p.data.size(),
                               dtype=p.data.dtype,
                               layout=p.data.layout,
                               device='cpu',
                               pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()

    state_dict: Dict[str, torch.Tensor] = module.state_dict()

    original_forward = module.forward

    def forward(*args, **kwargs):
        module.forward = original_forward
        device_state = {
            # here we blindly call `to(device)`
            # if the parameter is already on the device, it will be a no-op
            k: v.to(device, non_blocking=True)
            for k, v in state_dict.items()
        }
        output = functional_call(module,
                                 device_state,
                                 args=args,
                                 kwargs=kwargs)
        module.forward = forward
        return output

    module.forward = forward

    return module
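
A minimal sketch of how these two helpers compose (not part of the diff; the toy torch.nn.Linear module and the 1 GiB budget are made up for illustration, and a CUDA device is assumed):

    import torch

    from vllm.model_executor.models.utils import (maybe_offload_to_cpu,
                                                   set_cpu_offload_max_bytes)

    set_cpu_offload_max_bytes(1 * 1024**3)      # allow up to 1 GiB of weights on the CPU
    layer = torch.nn.Linear(4096, 4096).cuda()  # weights start on the GPU
    layer = maybe_offload_to_cpu(layer)         # weights are copied to (pinned) CPU memory
    print(next(layer.parameters()).device)      # cpu
    x = torch.randn(2, 4096, device="cuda")
    out = layer(x)                              # the wrapped forward moves weights back on the fly
    print(out.device)                           # cuda:0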


def make_layers(
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
) -> Tuple[int, int, torch.nn.ModuleList]:
@@ -64,9 +130,10 @@ def make_layers(
get_pp_group().rank_in_group,
get_pp_group().world_size)
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] +
[layer_fn() for _ in range(start_layer, end_layer)] +
[PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
[PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn())
for _ in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules


4 changes: 4 additions & 0 deletions vllm/worker/model_runner.py
@@ -39,6 +39,7 @@
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping
@@ -544,6 +545,9 @@ def __init__(
self.flashinfer_prefill_workspace_buffer = None
self.flashinfer_prefill_wrapper = None

set_cpu_offload_max_bytes(
    int(self.cache_config.cpu_offload_gb * 1024**3))

def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(model_config=self.model_config,
