Add simple CPU offloading support. (#2081)
janimo authored Nov 23, 2024
1 parent 865233e commit d98fa1e
Showing 9 changed files with 174 additions and 29 deletions.
5 changes: 4 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -61,6 +61,7 @@
     is_hip,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
+    set_cpu_offload_max_bytes,
 )

 logger = logging.getLogger(__name__)
@@ -145,7 +146,9 @@ def __init__(
             }
         )

-        # Init componnets
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+        # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
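For context (not part of this diff): the byte budget is installed before self.load_model() runs, so the make_layers calls in the model files below see a non-zero budget while layers are being constructed. A minimal sketch of the conversion, assuming only the new helper from python/sglang/srt/utils.py:

from sglang.srt.utils import set_cpu_offload_max_bytes

cpu_offload_gb = 3  # illustrative value of --cpu-offload-gb
# GiB -> bytes; a value of 0 disables offloading entirely.
set_cpu_offload_max_bytes(int(cpu_offload_gb * 1024**3))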
15 changes: 10 additions & 5 deletions python/sglang/srt/models/gemma2.py
@@ -38,6 +38,7 @@
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers


 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -267,11 +268,15 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Gemma2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+            ),
+            prefix="",
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

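For context (not part of this diff): gemma2.py above and llama.py, olmo.py, olmoe.py, and qwen2.py below all apply the same conversion. The decoder stack is built through make_layers, which invokes the factory as layer_fn(idx=..., prefix=...) and wraps each constructed layer in maybe_offload_to_cpu. A minimal sketch of the factory contract, using a hypothetical ToyLayer stand-in:

import torch
from torch import nn

from sglang.srt.utils import make_layers, set_cpu_offload_max_bytes


class ToyLayer(nn.Module):  # hypothetical stand-in for a decoder layer
    def __init__(self, layer_id: int, prefix: str):
        super().__init__()
        self.proj = nn.Linear(16, 16)


set_cpu_offload_max_bytes(0)  # zero budget: layers stay where they were created
layers = make_layers(
    4,
    lambda idx, prefix: ToyLayer(layer_id=idx, prefix=prefix),
    prefix="model.layers",
)
assert isinstance(layers, nn.ModuleList) and len(layers) == 4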
15 changes: 8 additions & 7 deletions python/sglang/srt/models/llama.py
@@ -43,6 +43,7 @@
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers


 class LlamaMLP(nn.Module):
@@ -255,14 +256,14 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                LlamaDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: LlamaDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            prefix="model.layers",
         )
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
13 changes: 8 additions & 5 deletions python/sglang/srt/models/olmo.py
@@ -38,6 +38,7 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers


 class OlmoAttention(nn.Module):
@@ -220,11 +221,13 @@ def __init__(
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size, config.hidden_size
         )
-        self.layers = nn.ModuleList(
-            [
-                OlmoDecoderLayer(config, layer_id, quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: OlmoDecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+            ),
         )
         self.norm = nn.LayerNorm(
             config.hidden_size, elementwise_affine=False, bias=False
13 changes: 8 additions & 5 deletions python/sglang/srt/models/olmoe.py
@@ -48,6 +48,7 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers


 class OlmoeMoE(nn.Module):
@@ -261,11 +262,13 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: OlmoeDecoderLayer(
+                config=config,
+                quant_config=quant_config,
+                layer_id=idx,
+            ),
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)

13 changes: 8 additions & 5 deletions python/sglang/srt/models/qwen2.py
@@ -40,6 +40,7 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers

 Qwen2Config = None

@@ -230,11 +231,13 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                Qwen2DecoderLayer(config, i, quant_config=quant_config)
-                for i in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Qwen2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+            ),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

8 changes: 8 additions & 0 deletions python/sglang/srt/server_args.py
@@ -62,6 +62,7 @@ class ServerArgs:
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
+    cpu_offload_gb: int = 0

     # Other runtime options
     tp_size: int = 1
@@ -373,6 +374,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )

+        parser.add_argument(
+            "--cpu-offload-gb",
+            type=int,
+            default=ServerArgs.cpu_offload_gb,
+            help="How many GBs of RAM to reserve for CPU offloading",
+        )
+
         # Other runtime options
         parser.add_argument(
             "--tensor-parallel-size",
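For context (not part of this diff): the flag is consumed at server launch time. An illustrative invocation, assuming the usual launch entry point and a placeholder model path:

# Reserve roughly 4 GiB of host RAM for offloaded weights.
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --cpu-offload-gb 4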
91 changes: 90 additions & 1 deletion python/sglang/srt/utils.py
@@ -31,7 +31,7 @@
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

 import numpy as np
 import psutil
@@ -44,6 +44,7 @@
 from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
+from torch.func import functional_call
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
@@ -190,6 +191,94 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
     return free_gpu_memory / (1 << 30)


+def is_pin_memory_available() -> bool:
+    return torch.cuda.is_available()
+
+
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty_strided(
+            size=p.data.size(),
+            stride=p.data.stride(),
+            dtype=p.data.dtype,
+            layout=p.data.layout,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
+
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module, device_state, args=args, kwargs=kwargs)
+            module.forward = forward
+            return output
+
+        module.forward = forward
+
+    return module
+
+
+class LayerFn(Protocol):
+
+    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function"""
+    modules = torch.nn.ModuleList(
+        [
+            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
+            for idx in range(num_hidden_layers)
+        ]
+    )
+    return modules
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
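For context (not part of this diff): a minimal sketch of how the two new helpers interact, assuming a CUDA device and an nn.Linear standing in for a decoder layer:

import torch
from torch import nn

from sglang.srt.utils import maybe_offload_to_cpu, set_cpu_offload_max_bytes

set_cpu_offload_max_bytes(1 * 1024**3)  # allow up to 1 GiB of weights on the CPU

layer = nn.Linear(4096, 4096).cuda()  # illustrative stand-in for a decoder layer
layer = maybe_offload_to_cpu(layer)

# The parameters now live in pinned host memory ...
print(next(layer.parameters()).device)  # cpu
# ... but forward still runs on the GPU: the patched forward moves the state
# dict back to the original device and dispatches via torch.func.functional_call.
x = torch.randn(8, 4096, device="cuda")
print(layer(x).device)  # cuda:0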
30 changes: 30 additions & 0 deletions test/srt/test_srt_engine.py
@@ -160,6 +160,36 @@ def test_7_engine_offline_throughput(self):
         result = throughput_test(server_args=server_args, bench_args=bench_args)
         self.assertGreater(result["total_throughput"], 3500)

+    def test_8_engine_cpu_offload(self):
+        prompt = "Today is a sunny day and I like"
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            max_total_tokens=128,
+        )
+        out1 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            max_total_tokens=128,
+            cpu_offload_gb=3,
+        )
+        out2 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        print("==== Answer 1 ====")
+        print(out1)
+
+        print("==== Answer 2 ====")
+        print(out2)
+        self.assertEqual(out1, out2)
+

 if __name__ == "__main__":
     unittest.main()
