Skip to content

Commit

Permalink
[v1] reduce graph capture time for piecewise cudagraph (vllm-project#…
Browse files Browse the repository at this point in the history
…10059)

Signed-off-by: youkaichao <youkaichao@gmail.com>
  • Loading branch information
youkaichao authored Nov 6, 2024
1 parent 0c63c34 commit c4cacba
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import copy
import dataclasses
import operator
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch

import torch
import torch.fx as fx
Expand Down Expand Up @@ -503,17 +505,29 @@ def __call__(self, *args) -> Any:
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()

# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output = weak_ref_tensors(output)
with ExitStack() as stack:
if not self.is_first_graph:
# during every model forward, we will capture
# many pieces of cudagraphs (roughly one per layer).
# running gc again and again across layers will
# make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))

# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output = weak_ref_tensors(output)

# here we always use weak ref for the output
# to save memory
Expand Down

0 comments on commit c4cacba

Please sign in to comment.