Skip to content

Commit

Permalink
[plugin][torch.compile] allow to add custom compile backend (vllm-pro…
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored and siddharth9820 committed Sep 30, 2024
1 parent bcf00b4 commit 1368786
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
13 changes: 13 additions & 0 deletions vllm/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Callable, Optional, Union

import vllm.envs as envs

Expand Down Expand Up @@ -29,3 +30,15 @@ def load_general_plugins():
except Exception:
logger.exception("Failed to load general plugin: %s",
plugin.name)


_torch_compile_backend: Optional[Union[Callable, str]] = None


def set_torch_compile_backend(backend: Union[Callable, str]):
global _torch_compile_backend
_torch_compile_backend = backend


def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
return _torch_compile_backend
4 changes: 3 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,10 +1064,12 @@ def load_model(self) -> None:
"This may lead to less accurate results!")

if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or "eager"
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend="eager")
backend=backend)

def save_sharded_state(
self,
Expand Down

0 comments on commit 1368786

Please sign in to comment.