From f7c21579aaf483d8c438dc4540ad35a27992e671 Mon Sep 17 00:00:00 2001
From: Nir David
Date: Thu, 27 Jun 2024 14:02:32 +0300
Subject: [PATCH] pass optimization flags only in Lazy mode

---
 vllm/worker/habana_model_runner.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index c928552256856..1bff70e49c3af 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -145,8 +145,6 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype):
     def forward(self, *args, **kwargs):
         kwargs = kwargs.copy()
         selected_token_indices = kwargs.pop('selected_token_indices')
-        if 'bypass_hpu_graphs' in kwargs:
-            kwargs.pop('bypass_hpu_graphs')  # required for PT eager
         input_ids = kwargs['input_ids']
         kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), input_ids.device, torch.bfloat16)
         hidden_states = self.model(*args, **kwargs)
@@ -866,14 +864,15 @@ def execute_model(
         }
         if self.vision_language_config:
             execute_model_kwargs.update({"image_input": multi_modal_input})
-
+        if htorch.utils.internal.is_lazy():
+            execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs, "warmup_mode": warmup_mode})
         htorch.core.mark_step()
         if self.is_driver_worker:
             model_event_name = f"model_{'prompt' if is_prompt else 'decode'}_bs{batch_size}_seq{seq_len}_graphs{'T' if use_graphs else 'F'}"
         else:
             model_event_name = 'model_executable'
         with self.profiler.record_event('internal', model_event_name):
-            hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices, bypass_hpu_graphs=not use_graphs, warmup_mode=warmup_mode)
+            hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices)
 
         # Compute the logits.
         with self.profiler.record_event('internal', f'compute_logits_{"prompt" if is_prompt else "decode"}_bs{batch_size}_seq{seq_len}'):
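
The diff gates the HPU-graph flags on the runtime mode: bypass_hpu_graphs and warmup_mode only mean something when habana_frameworks runs in Lazy mode, while in PT eager mode the wrapped model does not accept them, which is why forward() previously had to pop 'bypass_hpu_graphs'. A minimal sketch of the resulting control flow, assuming habana_frameworks.torch is installed; the helper name build_execute_kwargs is illustrative only, since the patch inlines this logic in HabanaModelRunner.execute_model():

    import habana_frameworks.torch as htorch

    def build_execute_kwargs(base_kwargs, use_graphs, warmup_mode):
        # base_kwargs: model inputs (input_ids, attn_metadata, ...).
        kwargs = dict(base_kwargs)
        if htorch.utils.internal.is_lazy():
            # Lazy mode: the HPU-graph wrapper consumes these flags.
            kwargs.update({"bypass_hpu_graphs": not use_graphs,
                           "warmup_mode": warmup_mode})
        # Eager mode: the flags are never added, so the model adapter's
        # forward() no longer needs to strip them before calling the model.
        return kwargs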