forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[torch.compile] rework compile control with piecewise cudagraph (vllm…
…-project#9715) Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Loc Huynh <jc1da.3011@gmail.com>
- Loading branch information
1 parent
2337338
commit 6453eb9
Showing
17 changed files
with
983 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{ | ||
"use_cudagraph": true, | ||
"non_cudagraph_ops": ["silly.attention"] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Test the piecewise compilation with a simple model so that we | ||
can exactly calculate the expected output and side effects. | ||
""" | ||
import os | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from vllm.compilation.compile_context import set_compile_context | ||
from vllm.compilation.counter import compilation_counter | ||
from vllm.compilation.decorators import support_torch_compile | ||
from vllm.compilation.levels import CompilationLevel | ||
|
||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) | ||
|
||
global_counter = 0 | ||
|
||
|
||
@torch.library.custom_op("silly::attention", mutates_args=["out"]) | ||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, | ||
out: torch.Tensor) -> None: | ||
global global_counter | ||
global_counter += 1 | ||
print(f"{global_counter=}") | ||
out.copy_(q) | ||
out[0] += 1 | ||
|
||
|
||
@silly_attention.register_fake | ||
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, | ||
out: torch.Tensor) -> None: | ||
return | ||
|
||
|
||
@support_torch_compile | ||
class SillyModel(nn.Module): | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Overall effect: | ||
x += 1 | ||
x[0] += 2 | ||
global_counter += 2 | ||
""" | ||
x = x + 1 | ||
x = x + 2 | ||
out = torch.empty_like(x) | ||
torch.ops.silly.attention(x, x, x, out) | ||
x = out | ||
x = x - 2 | ||
x = x - 1 | ||
out = torch.empty_like(x) | ||
torch.ops.silly.attention(x, x, x, out) | ||
x = out | ||
x = x + 1 | ||
return x | ||
|
||
|
||
def test_simple_piecewise_compile(): | ||
|
||
model = SillyModel() | ||
|
||
directory = os.path.dirname(__file__) | ||
config = os.path.join(directory, "piecewise_compilation_config.json") | ||
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config | ||
|
||
input_buffer = torch.randn(100).cuda() | ||
|
||
with compilation_counter.expect( | ||
num_graphs_seen=1, # one graph for the model | ||
num_piecewise_graphs_seen=5, # 2 * num_layers + 1 | ||
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers | ||
num_inductor_compilations=3, # num_piecewise_capturable_graphs_seen | ||
num_cudagraph_caputured= | ||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen | ||
): | ||
|
||
with set_compile_context([1, 2]): | ||
model(input_buffer) | ||
|
||
model(input_buffer[:2]) | ||
model(input_buffer[:1]) | ||
|
||
input_buffer[:2].zero_() | ||
global global_counter | ||
global_counter = 0 | ||
output = model(input_buffer[:2]) | ||
assert global_counter == 2 | ||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) | ||
|
||
# clean up to avoid side effects for other tests | ||
del os.environ["VLLM_TORCH_COMPILE_CONFIG"] |
Oops, something went wrong.