[torch.compile] rework compile control with piecewise cudagraph (vllm-project#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Loc Huynh <jc1da.3011@gmail.com>
youkaichao authored and JC1DA committed Nov 11, 2024
1 parent 2337338 commit 6453eb9
Showing 17 changed files with 983 additions and 106 deletions.
3 changes: 3 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -229,6 +229,9 @@ steps:
  - tests/compile
  commands:
  - pytest -v -s compile/test_basic_correctness.py
  # these tests need to be separated: they set compilation env vars at import time, so they cannot share one pytest process
  - pytest -v -s compile/piecewise/test_simple.py
  - pytest -v -s compile/piecewise/test_toy_llama.py

- label: "PyTorch Fullgraph Test" # 18min
source_file_dependencies:
0 changes: 0 additions & 0 deletions tests/compile/piecewise/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions tests/compile/piecewise/piecewise_compilation_config.json
@@ -0,0 +1,4 @@
{
    "use_cudagraph": true,
    "non_cudagraph_ops": ["silly.attention"]
}
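A hedged reading of these two knobs, based on how the tests in this commit consume them: ops listed in non_cudagraph_ops mark the points at which the traced graph is split into pieces, and those ops run eagerly outside CUDA graph capture (here the custom silly.attention op), while use_cudagraph enables capture and replay of the remaining pieces. The config is activated through an environment variable; a minimal sketch using only the env vars that appear in test_simple.py below:

import os

from vllm.compilation.levels import CompilationLevel

# Both env vars are the ones set in test_simple.py; there the config path
# is resolved relative to the test's own directory.
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = "piecewise_compilation_config.json"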
96 changes: 96 additions & 0 deletions tests/compile/piecewise/test_simple.py
@@ -0,0 +1,96 @@
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import os

import torch
from torch import nn

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel

os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)

global_counter = 0


@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    global global_counter
    global_counter += 1
    print(f"{global_counter=}")
    out.copy_(q)
    out[0] += 1


@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
      out: torch.Tensor) -> None:
    return


@support_torch_compile
class SillyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overall effect:
        x += 1
        x[0] += 2
        global_counter += 2
        """
        x = x + 1
        x = x + 2
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x


def test_simple_piecewise_compile():
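    # How the expected counters below are derived: forward() calls the
    # splitting op silly.attention twice ("2 layers"), so the traced graph
    # is cut into 2 * 2 + 1 = 5 pieces; the 3 pieces that contain no
    # attention call are capturable, and with the two sizes [1, 2] declared
    # via set_compile_context, 2 * 3 = 6 CUDA graphs are captured in total.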

    model = SillyModel()

    directory = os.path.dirname(__file__)
    config = os.path.join(directory, "piecewise_compilation_config.json")
    os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config

    input_buffer = torch.randn(100).cuda()

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
            num_inductor_compilations=3,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_caputured=
            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):

        with set_compile_context([1, 2]):
            model(input_buffer)

            model(input_buffer[:2])
            model(input_buffer[:1])

        input_buffer[:2].zero_()
        global global_counter
        global_counter = 0
        output = model(input_buffer[:2])
        assert global_counter == 2
        assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))

    # clean up to avoid side effects for other tests
    del os.environ["VLLM_TORCH_COMPILE_CONFIG"]
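For readers unfamiliar with the custom-op pattern used by silly_attention above: torch.library.custom_op registers an op that torch.compile treats as opaque (the compiler will not trace into its body), and register_fake supplies the metadata implementation used during tracing. Below is a minimal self-contained sketch of the same pattern with a made-up op name (demo::inplace_copy); it assumes a PyTorch version that provides torch.library.custom_op.

import torch


@torch.library.custom_op("demo::inplace_copy", mutates_args=["out"])
def inplace_copy(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x)  # real (eager) implementation


@inplace_copy.register_fake
def _(x: torch.Tensor, out: torch.Tensor) -> None:
    return  # mutates `out` in place and returns nothing, like the real op


@torch.compile
def fn(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    torch.ops.demo.inplace_copy(x, out)  # opaque node in the compiled graph
    return out + 1


print(fn(torch.ones(3)))  # tensor([2., 2., 2.])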
