Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support multi cards in ascend graph mode #2755

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
super().__init__(model, model_config, cache_config, backend_config,
device)

self.supported_model = ['Llama3-8B', 'Llama2-7B', 'Qwen2-7B']
self.enable_graph = self.check_enable_graph()
if self.enable_graph:
import dlinfer.graph
Expand All @@ -39,26 +38,23 @@ def check_enable_graph(self):
# eager_mode
if self.backend_config.eager_mode:
return False
# tp
if torch.distributed.is_initialized():
warnings.warn(
"Graph mode of device_type 'ascend' only supports tp=1 "
'for now, fallback to eager mode', RuntimeWarning)
return False

warnings.warn(
'\n\n'
'**********************************************************\n'
' The following models were tested in graph mode of\n'
" device_type 'ascend' when tp=1:\n"
f" {', '.join(self.supported_model)}\n"
' Other LLaMa-like models may work in graph mode, please\n'
' check the result yourself!\n'
' If graph mode does not work correctly with your model,\n'
' please use eager mode instead.\n'
'**********************************************************\n\n',
'************************************************************\n'
' Graph mode is an experimental feature. We currently\n'
' support both dense and Mixture of Experts (MoE) models\n'
' with bf16 and fp16 data types.\n'
' If graph mode does not function correctly with your model,\n'
' please consider using eager mode as an alternative.\n'
'************************************************************\n\n',
RuntimeWarning)

# tp
if torch.distributed.is_initialized():
torch._inductor.config.compile_threads = 1
return True

return True

def patch_kernels_custom_op(self):
Expand Down
Loading