From 3a39a8da66f8ac36cc9b00c719a5aa433a261d1d Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Mon, 18 Nov 2024 16:51:05 +0200
Subject: [PATCH 1/6] Add simple CPU offloading support.

---
 .../sglang/srt/model_executor/model_runner.py |  5 +-
 python/sglang/srt/models/gemma2.py            | 15 ++-
 python/sglang/srt/models/llama.py             | 15 +--
 python/sglang/srt/server_args.py              |  8 ++
 python/sglang/srt/utils.py                    | 91 ++++++++++++++++++-
 5 files changed, 120 insertions(+), 14 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 3144efe84b..1717b62432 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -63,6 +63,7 @@
     is_hip,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
+    set_cpu_offload_max_bytes,
 )
 
 logger = logging.getLogger(__name__)
@@ -147,7 +148,9 @@ def __init__(
             }
         )
 
-        # Init componnets
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+        # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 36749f22f4..f8bb6d8ae7 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -38,6 +38,7 @@
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -267,11 +268,15 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Gemma2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+            ),
+            prefix="",
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 284334396c..0bf65cbc56 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -44,6 +44,7 @@
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 class LlamaMLP(nn.Module):
@@ -256,14 +257,14 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                LlamaDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: LlamaDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            prefix="model.layers",
         )
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5487f772f4..6fb62c5cd5 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -64,6 +64,7 @@ class ServerArgs:
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
+    cpu_offload_gb: int = 0
 
     # Other runtime options
     tp_size: int = 1
@@ -374,6 +375,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
 
+        parser.add_argument(
+            "--cpu-offload-gb",
+            type=int,
+            default=ServerArgs.cpu_offload_gb,
+            help="How many GBs of RAM to reserve for CPU offloading",
+        )
+
         # Other runtime options
         parser.add_argument(
             "--tensor-parallel-size",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index d177a0bf82..a4aac75991 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -33,7 +33,7 @@
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Protocol, Union
 
 import numpy as np
 import psutil
@@ -46,6 +46,7 @@
 from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
+from torch.func import functional_call
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
@@ -192,6 +193,94 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
     return free_gpu_memory / (1 << 30)
 
 
+def is_pin_memory_available() -> bool:
+    return torch.cuda.is_available()
+
+
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty_strided(
+            size=p.data.size(),
+            stride=p.data.stride(),
+            dtype=p.data.dtype,
+            layout=p.data.layout,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
+
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module, device_state, args=args, kwargs=kwargs)
+            module.forward = forward
+            return output
+
+        module.forward = forward
+
+    return module
+
+
+class LayerFn(Protocol):
+
+    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str,
+) -> torch.nn.ModuleList:
+    """Make a list of layers with the given layer function"""
+    modules = torch.nn.ModuleList(
+        [
+            maybe_offload_to_cpu(layer_fn(idx, f"{prefix}.{idx}"))
+            for idx in range(num_hidden_layers)
+        ]
+    )
+    return modules
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)

From f6211c3f4b33025564bd23d75f7ba6e0a9d4c139 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Mon, 18 Nov 2024 19:53:34 +0200
Subject: [PATCH 2/6] Qwen2 CPU offload support

---
 python/sglang/srt/models/qwen2.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 796e34a4a4..1c18c78090 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -40,6 +40,7 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 Qwen2Config = None
 
@@ -230,11 +231,14 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                Qwen2DecoderLayer(config, i, quant_config=quant_config)
-                for i in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Qwen2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+            ),
+            prefix="",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 

From 901ad3d5a8eda8178e123c65f9802988572a80a6 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Mon, 18 Nov 2024 22:15:31 +0200
Subject: [PATCH 3/6] Add CPU offload test case

---
 test/srt/test_srt_engine.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index 5d7f954408..a17fdea8bf 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -160,6 +160,38 @@ def test_7_engine_offline_throughput(self):
         result = throughput_test(server_args=server_args, bench_args=bench_args)
         self.assertGreater(result["total_throughput"], 3500)
 
+    def test_8_engine_cpu_offload(self):
+        prompt = "Today is a sunny day and I like"
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            max_total_tokens=128,
+            log_level="error",
+        )
+        out1 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            max_total_tokens=128,
+            log_level="error",
+            cpu_offload_gb=3,
+        )
+        out2 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        print("==== Answer 1 ====")
+        print(out1)
+
+        print("==== Answer 2 ====")
+        print(out2)
+        self.assertEqual(out1, out2)
+
 
 if __name__ == "__main__":
     unittest.main()

From 969e672ad333e57f5586723944b5babf017a6981 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Thu, 21 Nov 2024 23:01:20 +0200
Subject: [PATCH 4/6] Support OLMo and OLMoE

---
 python/sglang/srt/models/olmo.py  | 13 ++++++++-----
 python/sglang/srt/models/olmoe.py | 13 ++++++++-----
 python/sglang/srt/models/qwen2.py |  1 -
 python/sglang/srt/utils.py        |  2 +-
 4 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py
index d073ca3b62..4aa998e481 100644
--- a/python/sglang/srt/models/olmo.py
+++ b/python/sglang/srt/models/olmo.py
@@ -39,6 +39,7 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 class OlmoAttention(nn.Module):
@@ -221,11 +222,13 @@ def __init__(
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size, config.hidden_size
         )
-        self.layers = nn.ModuleList(
-            [
-                OlmoDecoderLayer(config, layer_id, quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: OlmoDecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+            ),
         )
         self.norm = nn.LayerNorm(
             config.hidden_size, elementwise_affine=False, bias=False
diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py
index a33523847d..3593082722 100644
--- a/python/sglang/srt/models/olmoe.py
+++ b/python/sglang/srt/models/olmoe.py
@@ -48,6 +48,7 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 class OlmoeMoE(nn.Module):
@@ -261,11 +262,13 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: OlmoeDecoderLayer(
+                config=config,
+                quant_config=quant_config,
+                layer_id=idx,
+            ),
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
 
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 1c18c78090..2579e91c8f 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -238,7 +238,6 @@ def __init__(
                 config=config,
                 quant_config=quant_config,
             ),
-            prefix="",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index a4aac75991..d4a9d9c2e8 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -269,7 +269,7 @@ def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
 def make_layers(
     num_hidden_layers: int,
     layer_fn: LayerFn,
-    prefix: str,
+    prefix: str = "",
 ) -> torch.nn.ModuleList:
     """Make a list of layers with the given layer function"""
     modules = torch.nn.ModuleList(

From d5c883915f247cdd1bb867dce1911decac805241 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Fri, 22 Nov 2024 21:59:29 -0800
Subject: [PATCH 5/6] Update test/srt/test_srt_engine.py

---
 test/srt/test_srt_engine.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index a17fdea8bf..4be14b4eb9 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -179,7 +179,6 @@ def test_8_engine_cpu_offload(self):
             model_path=model_path,
             random_seed=42,
             max_total_tokens=128,
-            log_level="error",
             cpu_offload_gb=3,
         )
         out2 = engine.generate(prompt, sampling_params)["text"]

From d57b974793ff52547a54558d87dfdb9e6b06234a Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Fri, 22 Nov 2024 21:59:34 -0800
Subject: [PATCH 6/6] Update test/srt/test_srt_engine.py

---
 test/srt/test_srt_engine.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index 4be14b4eb9..f0dfa8f85a 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -170,7 +170,6 @@ def test_8_engine_cpu_offload(self):
             model_path=model_path,
             random_seed=42,
             max_total_tokens=128,
-            log_level="error",
         )
         out1 = engine.generate(prompt, sampling_params)["text"]
         engine.shutdown()
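
Note: the heart of the series is the forward-swap in maybe_offload_to_cpu (patch 1). Offloaded parameters live in (ideally pinned) host memory, and each call to the layer rebuilds a device-side copy of its state dict through torch.func.functional_call, so nothing stays resident on the GPU between calls. A minimal standalone sketch of that pattern, with a plain nn.Linear standing in for a decoder layer (the layer, sizes, and CUDA device are illustrative assumptions, not code from the patches):

import torch
from torch import nn
from torch.func import functional_call

device = torch.device("cuda")  # assumes a CUDA device is available

# Toy "offloaded" layer: its parameters stay on the CPU.
layer = nn.Linear(16, 16)

def offloaded_forward(x: torch.Tensor) -> torch.Tensor:
    # Materialize a GPU copy of the parameters for this call only.
    device_state = {
        k: v.to(device, non_blocking=True) for k, v in layer.state_dict().items()
    }
    return functional_call(layer, device_state, (x,))

print(offloaded_forward(torch.randn(2, 16, device=device)).shape)  # torch.Size([2, 16])

The patch additionally restores module.forward to the original bound method while functional_call runs, which avoids infinite recursion when functional_call invokes the module.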
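Note: the cpu_offload_gb budget is global rather than per-layer. set_cpu_offload_max_bytes converts the gigabyte count to bytes once in ModelRunner.__init__, and make_layers then offloads parameters in layer-construction order until the budget runs out, so the first layers end up (possibly partially) on the CPU and the rest stay on the GPU. A usage sketch via the engine API, mirroring the test in patch 3 (the model path and budget below are example values, not from the patches):

import sglang as sgl

# Reserve ~3 GiB of host RAM for offloaded parameters.
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",  # example model
    cpu_offload_gb=3,
)
out = engine.generate(
    "Today is a sunny day and I like",
    {"temperature": 0, "max_new_tokens": 8},
)["text"]
print(out)
engine.shutdown()

The server-side equivalent is the new flag from patch 1, e.g. python -m sglang.launch_server --model-path <model> --cpu-offload-gb 3.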
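Note: patches 2 and 4 apply one mechanical recipe, so further models can be converted the same way: swap the hand-built nn.ModuleList for make_layers, which routes every layer through maybe_offload_to_cpu. A self-contained sketch with a placeholder decoder layer (MyDecoderLayer and the sizes are hypothetical):

from torch import nn
from sglang.srt.utils import make_layers, set_cpu_offload_max_bytes

set_cpu_offload_max_bytes(1024**3)  # 1 GiB budget, example value

class MyDecoderLayer(nn.Module):
    # Placeholder; real models also take config/quant_config, as in the diffs above.
    def __init__(self, layer_id: int):
        super().__init__()
        self.mlp = nn.Linear(256, 256)

layers = make_layers(
    num_hidden_layers=4,
    layer_fn=lambda idx, prefix: MyDecoderLayer(layer_id=idx),
    prefix="model.layers",
)
print(len(layers))  # 4

On a CPU-only build maybe_offload_to_cpu returns each layer unchanged (the parameters are already on the CPU), so the budget only takes effect when layers are created on an accelerator.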