From 5211a3dcac765cd67d8f751834cb9cb5fc404907 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Sep 2024 23:00:44 -0700 Subject: [PATCH] add torch compile backend in plugin --- vllm/plugins/__init__.py | 13 +++++++++++++ vllm/worker/model_runner.py | 4 +++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 765f74fe7356f..7939688ef0da3 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,4 +1,5 @@ import logging +from typing import Callable, Optional, Union import vllm.envs as envs @@ -29,3 +30,15 @@ def load_general_plugins(): except Exception: logger.exception("Failed to load general plugin: %s", plugin.name) + + +_torch_compile_backend: Optional[Union[Callable, str]] = None + + +def set_torch_compile_backend(backend: Union[Callable, str]): + global _torch_compile_backend + _torch_compile_backend = backend + + +def get_torch_compile_backend() -> Optional[Union[Callable, str]]: + return _torch_compile_backend diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index acb7bafefc204..bff789c429710 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1064,10 +1064,12 @@ def load_model(self) -> None: "This may lead to less accurate results!") if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() or "eager" self.model = torch.compile( self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend="eager") + backend=backend) def save_sharded_state( self,