[7/N] torch.compile, reduce compilation time (vllm-project#10460)
Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao authored Nov 20, 2024
1 parent: 5f1d6af · commit: 0cd3d97
Showing 5 changed files with 27 additions and 16 deletions.
tests/compile/piecewise/test_simple.py (2 changes: 1 addition & 1 deletion)
@@ -79,7 +79,7 @@ def test_simple_piecewise_compile():
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
-        non_cudagraph_ops=["silly.attention"],
+        splitting_ops=["silly.attention"],
         cudagraph_copy_inputs=True,
     ))
     with set_current_vllm_config(vllm_config):
tests/compile/piecewise/test_toy_llama.py (4 changes: 2 additions & 2 deletions)
@@ -258,7 +258,7 @@ def run_model(llama_config,
             use_cudagraph=True,
         )
         if split_attn:
-            compilation_config.non_cudagraph_ops = ["silly.attention"]
+            compilation_config.splitting_ops = ["silly.attention"]
     else:
         compilation_config = CompilationConfig(
             level=CompilationLevel.NO_COMPILATION, )
@@ -378,7 +378,7 @@ def benchmark():
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            non_cudagraph_ops=["silly.attention"],
+            splitting_ops=["silly.attention"],
         )
     else:
         compilation_config = CompilationConfig(
vllm/compilation/backends.py (2 changes: 1 addition & 1 deletion)
@@ -447,7 +447,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.add_passes_to_config()
 
         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_configs.non_cudagraph_ops)
+            graph, self.compilation_configs.splitting_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
         logger.debug("%s", lazy_format_graph_code("before split", self.graph))
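
The backend change above is just the rename of the config field handed to split_graph. For readers unfamiliar with piecewise compilation, the sketch below illustrates the general idea of cutting a traced FX graph at designated ops, so that the pieces between the cuts can later be compiled and cudagraph-captured while the splitting ops (attention) stay outside. It is built on torch.fx.passes.split_module and is an illustration only, not vLLM's actual split_graph; the helper name split_at_ops and the way op names are matched are assumptions.

# Illustrative sketch only; not vLLM's split_graph. It partitions an FX
# GraphModule so every "splitting op" lands in its own submodule, with the
# surrounding code grouped into separate pieces.
from typing import Set

import torch
from torch import fx
from torch.fx.passes.split_module import split_module


def split_at_ops(gm: fx.GraphModule, root: torch.nn.Module,
                 splitting_ops: Set[str]) -> fx.GraphModule:
    partition = [0]

    def callback(node: fx.Node) -> int:
        # Give each splitting op a dedicated partition, then continue in a
        # fresh partition for the code that follows it.
        if node.op == "call_function" and str(node.target) in splitting_ops:
            partition[0] += 1
            own_partition = partition[0]
            partition[0] += 1
            return own_partition
        return partition[0]

    return split_module(gm, root, callback)

The returned GraphModule holds one submodule per piece, which corresponds loosely to the split_gm / piecewise_graphs pair the backend keeps.
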
vllm/config.py (17 changes: 10 additions & 7 deletions)
@@ -2089,13 +2089,15 @@ class CompilationConfig(BaseModel):
             - 'none,+op1,+op2' to enable only op1 and op2
           By default, all custom ops are enabled when running without Inductor
           and disabled when running with Inductor (compile_level >= Inductor).
+        - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
     - CudaGraph capture:
         - use_cudagraph: whether to use cudagraph inside compilation.
             - False: cudagraph inside compilation is not used.
             - True: cudagraph inside compilation is used. It requires
-                that all input buffers have fixed addresses.
-                Note that this is orthogonal to the cudagraph capture out
-                side of compilation.
+                that all input buffers have fixed addresses, and all
+                splitting ops write their outputs to input buffers.
+            Note that this is orthogonal to the cudagraph capture logic
+            outside of compilation.
             TODO: move outside cudagraph logic into compilation.
             torch.compile will handle cudagraph capture logic in the future.
         - cudagraph_capture_sizes: sizes to capture cudagraph.
@@ -2149,6 +2151,11 @@ class CompilationConfig(BaseModel):
     level: int = 0
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
+    splitting_ops: List[str] = Field(default_factory=lambda: [
+        "vllm.unified_flash_attention",
+        "vllm.unified_flash_infer",
+        "vllm.unified_v1_flash_attention",
+    ])
 
     use_inductor: bool = True
     inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
@@ -2157,7 +2164,6 @@ class CompilationConfig(BaseModel):
     inductor_passes: Dict[str, str] = Field(default_factory=dict)
 
     use_cudagraph: bool = False
-    non_cudagraph_ops: List[str] = Field(default_factory=list)
     cudagraph_num_of_warmups: int = 0
     cudagraph_capture_sizes: Optional[List[int]] = None
     cudagraph_copy_inputs: bool = False
@@ -2348,9 +2354,6 @@ def __post_init__(self):
             # and avoid any potential issues with the inductor.
             self.compilation_config.custom_ops = ["none"]
             self.compilation_config.use_cudagraph = True
-            self.compilation_config.non_cudagraph_ops = [
-                "vllm.unified_v1_flash_attention"
-            ]
             self.compilation_config.use_inductor = True
             self.compilation_config.enable_fusion = False
 
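
Taken together, the config change renames non_cudagraph_ops to splitting_ops and gives it a default covering the unified attention ops, which is why the hard-coded assignment in __post_init__ can be dropped. A minimal usage sketch, adapted from the piecewise test updated in this commit (the set_current_vllm_config import path is an assumption based on the test code):

from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config  # assumed import path

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    # Override the new default (the vllm.unified_* attention ops) to split
    # the graph around a custom attention op instead, as the tests do.
    splitting_ops=["silly.attention"],
    cudagraph_copy_inputs=True,
))
with set_current_vllm_config(vllm_config):
    ...  # build and run the model under piecewise compilation
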
vllm/worker/worker.py (18 changes: 13 additions & 5 deletions)
@@ -1,6 +1,7 @@
 """A GPU worker class."""
 import gc
 import os
+import time
 from typing import Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
@@ -189,6 +190,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         torch.cuda.reset_peak_memory_stats()
 
         free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
+        start_time = time.time()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
@@ -229,12 +231,18 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
 
+        end_time = time.time()
         logger.info(
-            "Memory profiling results: total_gpu_memory=%.2fGiB"
-            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
-            " memory_usage_post_profile=%.2fGiB"
-            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
-            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
+            "Memory profiling results: "
+            "duration=%.2f seconds, "
+            "total_gpu_memory=%.2fGiB, "
+            "initial_memory_usage=%.2fGiB, "
+            "peak_torch_memory=%.2fGiB, "
+            "memory_usage_post_profile=%.2fGiB, "
+            "non_torch_memory=%.2fGiB, "
+            "kv_cache_size=%.2fGiB, "
+            "gpu_memory_utilization=%.2f.", end_time - start_time,
+            total_gpu_memory / (1024**3),
             (total_gpu_memory - free_memory_pre_profile) / (1024**3),
             (peak_memory - non_torch_allocations) / (1024**3),
             total_allocated_bytes / (1024**3),
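
The worker change only brackets the existing profiling pass with time.time() calls and folds the duration into the (now comma-separated) log line. Below is a standalone sketch of the same measurement pattern using public torch.cuda APIs; it is simplified and independent of the worker code, and the names timed_profile_run and run_forward_pass are illustrative.

import logging
import time

import torch

logger = logging.getLogger(__name__)
GiB = 1024**3


def timed_profile_run(run_forward_pass) -> None:
    # Snapshot memory, time a dummy forward pass, then log the duration and
    # memory figures in GiB, mirroring the pattern in
    # determine_num_available_blocks.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    _free, total_gpu_memory = torch.cuda.mem_get_info()
    start_time = time.time()

    run_forward_pass()  # the dummy profiling pass
    torch.cuda.synchronize()

    peak_torch_memory = torch.cuda.max_memory_allocated()
    end_time = time.time()
    logger.info(
        "Memory profiling results: duration=%.2f seconds, "
        "total_gpu_memory=%.2fGiB, peak_torch_memory=%.2fGiB.",
        end_time - start_time, total_gpu_memory / GiB,
        peak_torch_memory / GiB)
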
