Skip to content

Commit

Permalink
pass optimizations flags only in Lazy mode
Browse files Browse the repository at this point in the history
  • Loading branch information
nirda7 committed Jun 27, 2024
1 parent 90c2527 commit f7c2157
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,6 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype):
def forward(self, *args, **kwargs):
kwargs = kwargs.copy()
selected_token_indices = kwargs.pop('selected_token_indices')
if 'bypass_hpu_graphs' in kwargs:
kwargs.pop('bypass_hpu_graphs') # required for PT eager
input_ids = kwargs['input_ids']
kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), input_ids.device, torch.bfloat16)
hidden_states = self.model(*args, **kwargs)
Expand Down Expand Up @@ -866,14 +864,15 @@ def execute_model(
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})

if htorch.utils.internal.is_lazy():
execute_model_kwargs.update({"bypass_hpu_graphs":not use_graphs, "warmup_mode":warmup_mode})
htorch.core.mark_step()
if self.is_driver_worker:
model_event_name = f"model_{'prompt' if is_prompt else 'decode'}_bs{batch_size}_seq{seq_len}_graphs{'T' if use_graphs else 'F'}"
else:
model_event_name = 'model_executable'
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices, bypass_hpu_graphs=not use_graphs, warmup_mode=warmup_mode)
hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices)

# Compute the logits.
with self.profiler.record_event('internal', f'compute_logits_{"prompt" if is_prompt else "decode"}_bs{batch_size}_seq{seq_len}'):
Expand Down

0 comments on commit f7c2157

Please sign in to comment.