
Commit

apply fix from vllm-project#6214
tlrmchlsmth committed Jul 16, 2024
1 parent b733a84 commit fb846ce
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions vllm/model_executor/models/mamba.py
@@ -538,20 +538,20 @@ def _prepare_current_run_mamba_cache(
 
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         """
-            Copy the relevant Mamba cache into the CUDA graph input buffer
-            that was provided during the capture runs
-            (MambaForCausalLM.mamba_gc_cache_buffer).
+        Copy the relevant Mamba cache into the CUDA graph input buffer
+        that was provided during the capture runs
+        (MambaForCausalLM.mamba_gc_cache_buffer).
         """
         assert all(
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        batch_size = len(request_ids_to_seq_ids)
+        cg_batch_size = input_buffers['input_ids'].shape[0]
         (
             current_mamba_cache,
             indices,
         ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  batch_size)
+                                                  cg_batch_size)
         self.current_indices = indices
         finished_requests_ids = kwargs["finished_requests_ids"]
         self._release_mamba_cache(finished_requests_ids)
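The substantive change is sizing the Mamba cache copy by cg_batch_size, the padded batch dimension of the CUDA graph's captured input buffers, instead of batch_size, the number of live requests. Because a CUDA graph replays with exactly the tensor shapes recorded at capture time, sizing the cache view by the live request count can leave padded rows of the captured buffer unfilled. The following is a minimal, illustrative sketch of that sizing rule; the function name cache_view_for_replay and the tensor shapes are assumptions for the example, not vLLM's API.

import torch

def cache_view_for_replay(input_buffers: dict, mamba_cache: torch.Tensor,
                          request_ids_to_seq_ids: dict) -> torch.Tensor:
    # The captured input buffers are padded to a fixed capture batch size,
    # so take the batch dimension from the buffer itself.
    cg_batch_size = input_buffers["input_ids"].shape[0]
    # Pre-fix sizing: the number of live requests, which can be smaller than
    # the captured batch and would not cover the padded rows.
    num_live = len(request_ids_to_seq_ids)
    assert num_live <= cg_batch_size, "live requests must fit the captured batch"
    # Size the view by the captured buffer, not by num_live.
    return mamba_cache[:cg_batch_size]

# Toy usage: a graph captured for batch size 8 while only 3 requests are live.
input_buffers = {"input_ids": torch.zeros(8, 16, dtype=torch.long)}
mamba_cache = torch.zeros(8, 4, 32)   # (padded batch, heads, state) -- shapes assumed
view = cache_view_for_replay(input_buffers, mamba_cache,
                             {"req0": [0], "req1": [1], "req2": [2]})
print(view.shape)                     # torch.Size([8, 4, 32])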

