diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index ca8d58fd3d6aa..a76c3757be739 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -538,20 +538,20 @@ def _prepare_current_run_mamba_cache(

     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         """
-        Copy the relevant Mamba cache into the CUDA graph input buffer
-        that was provided during the capture runs
-        (MambaForCausalLM.mamba_gc_cache_buffer).
+        Copy the relevant Mamba cache into the CUDA graph input buffer
+        that was provided during the capture runs
+        (MambaForCausalLM.mamba_gc_cache_buffer).
         """
         assert all(
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        batch_size = len(request_ids_to_seq_ids)
+        cg_batch_size = input_buffers['input_ids'].shape[0]
         (
             current_mamba_cache,
             indices,
         ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  batch_size)
+                                                  cg_batch_size)
         self.current_indices = indices
         finished_requests_ids = kwargs["finished_requests_ids"]
         self._release_mamba_cache(finished_requests_ids)
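
The functional change above sizes the Mamba cache by the CUDA graph's captured batch size (`input_buffers['input_ids'].shape[0]`) rather than the number of live requests (`len(request_ids_to_seq_ids)`): a replayed CUDA graph requires the exact padded tensor shapes it was captured with, which can be larger than the current step's batch. Below is a minimal sketch of that padding idea, under stated assumptions; the names `pad_mamba_cache_for_cuda_graph` and `per_request_cache` are illustrative and not part of vLLM's API.

import torch


def pad_mamba_cache_for_cuda_graph(per_request_cache, input_buffers):
    """Sketch: pad per-request cache tensors to the batch size the CUDA
    graph was captured with, not the number of live requests."""
    # A CUDA graph replays with exactly the tensor shapes it was captured
    # with, and capture runs use padded batches. The graph's batch size is
    # therefore the first dimension of the captured input buffer, which may
    # exceed the number of requests in the current step.
    cg_batch_size = input_buffers['input_ids'].shape[0]

    live = list(per_request_cache.values())
    assert 0 < len(live) <= cg_batch_size, "more requests than captured batch"

    # Stack the live caches and zero-pad the unused slots so the result
    # matches the captured shape.
    padded = torch.zeros(cg_batch_size, *live[0].shape,
                         dtype=live[0].dtype, device=live[0].device)
    padded[:len(live)] = torch.stack(live)
    return padded


# Usage: two live requests, graph captured at batch size 4.
buffers = {'input_ids': torch.zeros(4, 16, dtype=torch.long)}
cache = {'req-0': torch.randn(8), 'req-1': torch.randn(8)}
print(pad_mamba_cache_for_cuda_graph(cache, buffers).shape)  # torch.Size([4, 8])

Had the cache stayed sized to `len(request_ids_to_seq_ids)`, its shape would drift from the captured buffer whenever fewer requests are in flight than the graph's capture batch, so copying it into the graph's input buffer would fail or corrupt the padded slots.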