Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Optimize block_manager_v2 vs block_manager_v1 (to make V2 default) #5602

Merged
merged 48 commits into from
Jul 2, 2024

Conversation

alexm-neuralmagic
Copy link
Collaborator

@alexm-neuralmagic alexm-neuralmagic commented Jun 17, 2024

This PR optimizes block_manager_v2 python logic to make it comparable to block_manager_v1. The goal is to enable block_manager_v2 by default as part of the spec decode project.

The issues optimized are:

  1. Python Block object allocations/deallocations are expensive on the hot-path of iterative batching, so a block pool is used to cache block objects.
  2. Any string/list duplication should be avoided, especially for token id lists
  3. Modified Prefix Caching Block/Allocator to avoid any full traversals of block_ids by using dynamic/incremental style computations
  4. Redid the way access all blocks updates timestamps by deferring the actual updates to free(..) of sequences

Here is initial performance comparison for both standard and prefix-cache enabled runs:

image

@alexm-neuralmagic alexm-neuralmagic marked this pull request as draft June 17, 2024 15:02
vllm/sequence.py Show resolved Hide resolved
block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we keep these in for correctness? can have a flag strict_mode which checks these only in testing / not in production

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added "assert block_id is not None" checks into BlockList so the invariant of "assert all(b is not None for b in block_ids)" is always kept.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome

block_size: int,
block_id: Optional[int] = None):
# Please keep sync with the __init__()
# (Calling __init__() directly raises linter errors)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we ignore the linter error instead of duplicating code ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually works! Thanks for the suggestion


if block_token_ids:
blocks.extend(
self._allocator.allocate_immutable_group(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can name it allocate_immutable_blocks to reduce new concepts. can also rename the bs=1 path to be allocate_immutable_block so contrast is clear.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, renamed the functions as you proposed. In addition renamed allocate_mutable => allocate_mutable_block

Comment on lines 143 to 196
blocks = self._blocks[self._num_full_slots // self._block_size:]
blocks = self.blocks[self._num_full_slots // self._block_size:]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this working?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this invokes the property blocks(..) and it returns self._blocks.list()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh gotcha

token_ids=token_ids,
block_size=block_size,
block_id=physical_block_id)
block.block_pool_id = block_pool_id
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we avoid extending the block API for this optimization? we can keep a mapping of object address to block pool id in this class

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, just replaced with simple class member

vllm/core/block/naive_block.py Outdated Show resolved Hide resolved
vllm/core/block/naive_block.py Show resolved Hide resolved
assert block.block_id is not None
self._free_block_id(block.block_id)
block.block_id = None

def free(self, block: Block) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: for readability, have this invoke free_block_id instead of _free_block_id

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, modified to invoke free_block_id directly

@@ -149,6 +169,17 @@ def allocate_immutable(self, prev_block: Optional[Block],
return self._allocators[device].allocate_immutable(
prev_block, token_ids)

def free_block_id(self, block: Block) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran out of time to review today. Can you help me understand why we need a new API for this // if there's no way to combine free_block and free_block_id? ideally we have one way of freeing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that inside cow_block_if_not_appendable(..) (in common.py) we decrement ref count for the block_id for this block, and then in the caller, we reuse the same block object while assigning to its block_id the newly allocated block id (self._block_id = (self._allocator.cow_block_if_not_appendable(..)).

Same happens in prefix caching inside _free_block_id_for_block(..) when we promote a naive block to the immutable (prefix block) => we call return self._hashless_allocator.free_block_id(block), and at the caller reuse the same block object.

Without the block pool a free() was simply setting block.block_id = None, but with block pool, free(..) is actually releasing the block itself, so the second free_block_id() is behaving more similar to block.block_id = None

Copy link
Collaborator Author

@alexm-neuralmagic alexm-neuralmagic Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try to restructure the code a bit, so that we don't have the free_block_id. Will keep you posted about this issue.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. It sounds like my original design should have had more thought on the distinction between Python block objects and block ids themselves. It's OK if we have some suboptimality given that, but also hope you're able to find a simple solution :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was able to refactor the code so that only free() is used at all places. I think it is a good change since it forces an explicit free/alloc calls for block objects, and this avoids potential memory leaks (due to previous separation between the block_id and block - currently they are "more fused").

The main things I needed to change is CoW and promote_to_immutable (in prefix-caching). The change moves these two functions to the allocator level (outside of the block itself), since these functions free-and-reallocate a new block, which needs to be updated in the associated lists in block_table.py. To make this cleaner, I added a function in block_table.py that is called "append_token_ids_and_update_allocator".

In addition, I redid the free() procedure of prefix-caching since it was a bit complicated, by separating the two main cases there: (1) immutable/promoted block and (2) mutable/hashless block.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have verified performance it is even a little better now.

Copy link
Collaborator Author

@alexm-neuralmagic alexm-neuralmagic Jun 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I have squashed the relevant commits to "refactor code so that only free() is used" so it will be easier to see the changes I did only for this change.

@cadedaniel
Copy link
Collaborator

Great work btw! thanks!

@alexm-neuralmagic
Copy link
Collaborator Author

Updated the PR with performance fixes for prefix-caching block_manager_v2. The table above is updated with new numbers for both standard run and prefix-cache enabled run.

@alexm-neuralmagic
Copy link
Collaborator Author

Will start addressing review comments and cleaning up the PR

@hibukipanim
Copy link

As the PR touches prefix caching and preparing v2-block-manager to be default, I was curious to see if the PR might resolve this correctness issue: #5543 (comment).
and you might be interested to know that when running with this branch (commit c1f650fa7f162eb48763d8eeb70081986379f7e1) with --enable-prefix-caching --use-v2-block-manager, the snippet in the linked issue crashes the server with:

ERROR 06-19 07:45:58 async_llm_engine.py:45] Engine background task failed
ERROR 06-19 07:45:58 async_llm_engine.py:45] Traceback (most recent call last):
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/engine/async_llm_engine.py", line 40, in _raise_exception_on_finish
ERROR 06-19 07:45:58 async_llm_engine.py:45]     task.result()
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/engine/async_llm_engine.py", line 521, in run_engine_loop
ERROR 06-19 07:45:58 async_llm_engine.py:45]     has_requests_in_progress = await asyncio.wait_for(
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/lib/python3.10/asyncio/tasks.py", line 445, in wait_for
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return fut.result()
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/engine/async_llm_engine.py", line 495, in engine_step
ERROR 06-19 07:45:58 async_llm_engine.py:45]     request_outputs = await self.engine.step_async()
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/engine/async_llm_engine.py", line 226, in step_async
ERROR 06-19 07:45:58 async_llm_engine.py:45]     output = await self.model_executor.execute_model_async(
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/executor/gpu_executor.py", line 117, in execute_model_async
ERROR 06-19 07:45:58 async_llm_engine.py:45]     output = await make_async(self.driver_worker.execute_model
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
ERROR 06-19 07:45:58 async_llm_engine.py:45]     result = self.fn(*self.args, **self.kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return func(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/worker/worker.py", line 272, in execute_model
ERROR 06-19 07:45:58 async_llm_engine.py:45]     output = self.model_runner.execute_model(seq_group_metadata_list,
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return func(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/worker/model_runner.py", line 736, in execute_model
ERROR 06-19 07:45:58 async_llm_engine.py:45]     hidden_states = model_executable(
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return self._call_impl(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return forward_call(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 371, in forward
ERROR 06-19 07:45:58 async_llm_engine.py:45]     hidden_states = self.model(input_ids, positions, kv_caches,
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return self._call_impl(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return forward_call(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 288, in forward
ERROR 06-19 07:45:58 async_llm_engine.py:45]     hidden_states, residual = layer(
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return self._call_impl(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return forward_call(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 227, in forward
ERROR 06-19 07:45:58 async_llm_engine.py:45]     hidden_states = self.self_attn(
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return self._call_impl(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return forward_call(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 161, in forward
ERROR 06-19 07:45:58 async_llm_engine.py:45]     attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return self._call_impl(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return forward_call(*args, **kwargs)
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/attention/layer.py", line 89, in forward
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return self.impl.forward(query, key, value, kv_cache, attn_metadata,
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/workspace/nm-vllm/vllm/attention/backends/flash_attn.py", line 338, in forward
ERROR 06-19 07:45:58 async_llm_engine.py:45]     flash_attn_varlen_func(
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/vllm_flash_attn/flash_attn_interface.py", line 1099, in flash_attn_varlen_func
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return FlashAttnVarlenFunc.apply(
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 598, in apply
ERROR 06-19 07:45:58 async_llm_engine.py:45]     return super().apply(*args, **kwargs)  # type: ignore[misc]
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/vllm_flash_attn/flash_attn_interface.py", line 596, in forward
ERROR 06-19 07:45:58 async_llm_engine.py:45]     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
ERROR 06-19 07:45:58 async_llm_engine.py:45]   File "/usr/local/lib/python3.10/dist-packages/vllm_flash_attn/flash_attn_interface.py", line 88, in _flash_attn_varlen_forward
ERROR 06-19 07:45:58 async_llm_engine.py:45]     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
ERROR 06-19 07:45:58 async_llm_engine.py:45] RuntimeError: out must have shape (total_q, num_heads, head_size_og)
Exception in callback functools.partial(<function _raise_exception_on_finish at 0x7f2bc22f4160>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7f2bb73e0910>>)
handle: <Handle functools.partial(<function _raise_exception_on_finish at 0x7f2bc22f4160>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7f2bb73e0910>>)>
Traceback (most recent call last):
  File "/workspace/nm-vllm/vllm/engine/async_llm_engine.py", line 40, in _raise_exception_on_finish
    task.result()
  File "/workspace/nm-vllm/vllm/engine/async_llm_engine.py", line 521, in run_engine_loop
    has_requests_in_progress = await asyncio.wait_for(
  File "/usr/lib/python3.10/asyncio/tasks.py", line 445, in wait_for
    return fut.result()
  File "/workspace/nm-vllm/vllm/engine/async_llm_engine.py", line 495, in engine_step
    request_outputs = await self.engine.step_async()
  File "/workspace/nm-vllm/vllm/engine/async_llm_engine.py", line 226, in step_async
    output = await self.model_executor.execute_model_async(
  File "/workspace/nm-vllm/vllm/executor/gpu_executor.py", line 117, in execute_model_async
    output = await make_async(self.driver_worker.execute_model
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/nm-vllm/vllm/worker/worker.py", line 272, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/nm-vllm/vllm/worker/model_runner.py", line 736, in execute_model
    hidden_states = model_executable(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 371, in forward
    hidden_states = self.model(input_ids, positions, kv_caches,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 288, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 227, in forward
    hidden_states = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 161, in forward
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/nm-vllm/vllm/attention/layer.py", line 89, in forward
    return self.impl.forward(query, key, value, kv_cache, attn_metadata,
  File "/workspace/nm-vllm/vllm/attention/backends/flash_attn.py", line 338, in forward
    flash_attn_varlen_func(
  File "/usr/local/lib/python3.10/dist-packages/vllm_flash_attn/flash_attn_interface.py", line 1099, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/vllm_flash_attn/flash_attn_interface.py", line 596, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/usr/local/lib/python3.10/dist-packages/vllm_flash_attn/flash_attn_interface.py", line 88, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: out must have shape (total_q, num_heads, head_size_og)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "uvloop/cbhandles.pyx", line 63, in uvloop.loop.Handle._run
  File "/workspace/nm-vllm/vllm/engine/async_llm_engine.py", line 47, in _raise_exception_on_finish
    raise AsyncEngineDeadError(
vllm.engine.async_llm_engine.AsyncEngineDeadError: Task finished unexpectedly. This should never happen! Please open an issue on Github. See stack trace above for the actual cause.
INFO 06-19 07:45:58 async_llm_engine.py:158] Aborted request cmpl-4ce91102896f49d598ec6313f9629a10-0.
INFO:     172.17.0.1:47640 - "POST /v1/completions HTTP/1.1" 500 Internal Server Error
ERROR:    Exception in ASGI application

@alexm-neuralmagic
Copy link
Collaborator Author

@hibukipanim thanks for pointing this issue, I will check

@alexm-neuralmagic alexm-neuralmagic marked this pull request as ready for review June 19, 2024 19:03
@alexm-neuralmagic alexm-neuralmagic force-pushed the block_manager_v2_perf branch 2 times, most recently from 0148b6e to e08d643 Compare June 20, 2024 21:34
vllm/sequence.py Outdated

@property
def prompt_token_ids(self) -> List[int]:
return self._prompt_token_ids
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should return a tuple/shallow copy so that this and also output_token_ids doesn't get modified by mistake (and thus bypass _update_cached_all_tokens)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, what happens if someone modifies the prompt token ids / output token ids list?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, changed the return types to be tuples.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed the approach here to protect accesses to prompt_token_ids and output_token_ids. Now, it uses a class MonitoredList that records a timestamp of the last update, and based on that, the cached all tokens is updated. I did in this way to avoid changing all usages of the prompt/output token ids due to tuple change and also it avoids unnecessary copies of list => tuples which are also expensive.

Copy link
Collaborator Author

@alexm-neuralmagic alexm-neuralmagic Jun 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yard1 found out that there is actually an issue with the deserialization with ray, so I have removed this and made the prompt/output token_ids accessors return tuples. It introduces a conversion for the output_token_ids to tuple but it seems not to be bad and the performance is still good. To make it work, I have propagated the tuple type upward in the vllm software stack, since we don't expect seq_data users to use these accessors to change data (but only via the append_token() function)

@cadedaniel
Copy link
Collaborator

ok looking

Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

most comments are nits. big question is the design change around CoW/promotion (I think it's actually a bad design change). let's schedule some time to go over this sync as I think it will be faster than back and forth.

Comment on lines 14 to 16
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="facebook/opt-125m",
use_v2_block_manager=True,
enable_prefix_caching=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leave this out for now

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, removed


first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)

# mark all blocks in first chain as computed
allocator.mark_blocks_as_computed(blocks)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO(cade) see why this api is no longer required

from vllm.utils import Device, cdiv, chunk_list


# This class is an optimization to allow fast-access to physical block ids
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's write this as a docstring

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest writing also how it achieves the optimization (can write docstrings for individual functions but it's more tedious)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

from vllm.utils import Device, cdiv, chunk_list


# This class is an optimization to allow fast-access to physical block ids
class BlockList:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would be great to have basic unit tests for this helper

from vllm.utils import Device, cdiv, chunk_list


# This class is an optimization to allow fast-access to physical block ids
class BlockList:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I have preference for putting helper methods/functions below the main class of the file, so the file can be read top-down

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to block/common.py

Comment on lines 103 to 104
self._cached_computed_seq_blocks: Dict[SeqId, List[int]] = {}
self._seq_last_access: Dict[SeqId, float] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the motivation for raising these to BlockManger level? we should keep things simple at this layer unless there's good reason not to

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a significant overhead in these function calls, since they traversed the full block lists.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we modify the API such that it allows caching the result // we don't have to traverse the full block lists?

Two downsides:

  • we expose more complexity in this layer than is necessary (this is tech debt we can live with, if it's too hard)
  • we make it harder for other block managers to use prefix caching (we may have a block manager which specializes for another type, e.g. the newer models which use sliding window + normal attention).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good idea. I have refactored this logic out to two classes: ComputedBlocksTracker and LastAccessBlocksTracker so it will be easier to port the logic to other places.

Comment on lines 239 to 240
# TODO: Ask Cade how it may be possible to have
# allocated block id inside the evictor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's go over this

block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome

@@ -274,6 +285,43 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# So this function is useless for block_v2.
pass

def get_and_update_computed_block_ids(self, seqs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring / typing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

vllm/sequence.py Outdated

@property
def prompt_token_ids(self) -> List[int]:
return self._prompt_token_ids
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, what happens if someone modifies the prompt token ids / output token ids list?

Copy link
Collaborator Author

@alexm-neuralmagic alexm-neuralmagic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the PR with addressed review comments from Cade and Yard1. I have moved the CoW and Promo functionality back to the block and ensured that there is no new _free_block_id() interface to minimize interface changes.

Also, I had moved the code a bit inside the prefix-caching allocator to make it more readable and easier to maintain.

Verified that performance is still good, for both standard and prefix-cached runs.

TODO:
Fixing tests now

Comment on lines 12 to 19
# Used to pre-allocate block objects, in order to avoid excessive python
# object allocations/deallocations.
# The pool starts from "pool_size" objects and will increase to more objects
# if necessary
#
# Note that multiple block objects may point to the same physical block id,
# which is why this pool is needed, so that it will be easier to support
# prefix caching and more complicated sharing of physical blocks.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added docstring and moved BlockPool class to block/common.py

@@ -19,6 +19,28 @@
_DEFAULT_LAST_ACCESSED_TIME = -1


class BlockTracker:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

from vllm.utils import Device, cdiv, chunk_list


# This class is an optimization to allow fast-access to physical block ids
class BlockList:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to block/common.py

return self._block_ids


def append_token_ids_and_update_allocator(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this function in favor of moving this logic back into block class

block: Block, token_ids: List[int],
allocator: DeviceAwareBlockAllocator) -> Block:
new_block = allocator.cow_block_if_not_appendable(block)
if new_block:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

vllm/sequence.py Outdated

@property
def prompt_token_ids(self) -> List[int]:
return self._prompt_token_ids
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, changed the return types to be tuples.

from vllm.utils import Device, cdiv, chunk_list


# This class is an optimization to allow fast-access to physical block ids
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

elif block_id in self.evictor:
self.evictor.update(block_id, now)
else:
raise ValueError(
"Mark block as accessed which is not belonged to GPU")

def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching."""
raise NotImplementedError("Marking as computed is incremental")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For prefix caching, a block is "computed" when it is full, so it is possible to use the block.content_hash as the indicator for computed or not computed without the need from the scheduler to explicitly state it. Which is why the original implementation was not doing anything for that case, and this function was never called. I simply replaced the code with an error exception just to make sure it is indeed not used.


self._update_num_token_ids()

def _update_num_token_ids(self):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

Comment on lines 103 to 104
self._cached_computed_seq_blocks: Dict[SeqId, List[int]] = {}
self._seq_last_access: Dict[SeqId, float] = {}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a significant overhead in these function calls, since they traversed the full block lists.

@DarkLight1337
Copy link
Member

To speed up the CI queue for #5905, I've cancelled the distributed tests for the latest CI run in this PR since they won't pass anyway until #5905 has been merged. Please merge main into your branch after that happens so that the CI can pass once again.

self._num_full_slots = len(token_ids)

def update(self, blocks):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

elif block_id in self.evictor:
self.evictor.update(block_id, now)
else:
raise ValueError(
"Mark block as accessed which is not belonged to GPU")

def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching."""
raise NotImplementedError("Marking as computed is incremental")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. let's delete the API?

allocator=self._allocator,
block_id=None))

def increase_pool(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: docstrings on public methods

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mark_blocks_as_computed still used in block_manager_v1
added docstring

vllm/core/block/block_table.py Show resolved Hide resolved
Comment on lines 298 to 328
raise NotImplementedError
device = Device.GPU
return self._allocators[device].promote_to_immutable_block(block)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this implementation and cow_block_if_not_appendable? technically, vLLM does not support modification of block content for CPU-based allocators

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this method is only invoked when appending tokens

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comment when it's used? (I think they should be removed but seems I miss a case)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You actually right, this is cpu-gpu allocator, so it is not doing the actual CoW or promo, since it is done only by the specific Naive or Prefix allocators, and they have these functions define via the base class BlockAllocator. Good catch!

Comment on lines 376 to 379
if self._proxy.token_ids:
return len(self._proxy.token_ids)
else:
return 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you see my comment about token_ids being optional? It adds more complexity to the API, and leaks abstraction details here and other places that need to check if it's None before deciding behavior.

If we want a no-op token id List for the undefined blocks, we can have a class which implements List and always returns 0 for len / raises NotImplemented for anything that writes. that way we don't have Optional / no branches checking for it everywhere

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was able to remove the Optional from token_ids. Now it is the same as before.

Comment on lines 103 to 104
self._cached_computed_seq_blocks: Dict[SeqId, List[int]] = {}
self._seq_last_access: Dict[SeqId, float] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we modify the API such that it allows caching the result // we don't have to traverse the full block lists?

Two downsides:

  • we expose more complexity in this layer than is necessary (this is tech debt we can live with, if it's too hard)
  • we make it harder for other block managers to use prefix caching (we may have a block manager which specializes for another type, e.g. the newer models which use sliding window + normal attention).

Comment on lines 315 to 326
self._cached_computed_seq_blocks[seq_id] = computed_block_ids
else:
computed_block_ids = self._cached_computed_seq_blocks[seq_id]
if len(computed_block_ids) < len(block_ids):
# Incremental init for seq_id => Look only at the new blocks
computed_block_ids = self.block_allocator.get_computed_block_ids( # noqa: E501
computed_block_ids, block_ids)
self._cached_computed_seq_blocks[
seq_id] = computed_block_ids
else:
# Cache HIT
assert len(computed_block_ids) == len(block_ids)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will still result in constant recomputation in the worst case. I think we can do the following:

  1. After the first run, if len(computed_block_ids) != len(block_ids), we know that we will never add any extra blocks to computed_block_ids (since we'd have a gap otherwise). Therefore, we should save that as a boolean in the cache alongside the computed block ids
  2. In the subsequent runs, if the seq_id is present in cache, but the boolean is False, we just return the cached computed block ids without calling get_computed_block_ids. Otherwise, if the boolean is true, we call get_computed_block_ids for the new blocks and save in cache, with the len(computed_block_ids) == len(block_ids) boolean.

let me know if this makes sense? I may be missing something here.

Copy link
Collaborator

@Yard1 Yard1 Jun 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the suggested change:

    def _get_and_update_computed_block_ids(self, seqs):
        """Handles caching of per-sequence computed block ids. 
        When a sequence appears for the first time, it traverses all of the 
        blocks and detects the prefix of blocks that is computed. On the
        subsequent times, it only traverses the new blocks that were added 
        and updates the already recorded prefix of blocks with the newly 
        computed blocks.
        """
        ret = []
        for seq in seqs:
            seq_id = seq.seq_id

            # Get block ids of this sequence, while not considering the
            # last block
            block_ids = self.block_tables[seq_id].physical_block_ids[:-1]

            # Here we cache the detection of computed_block_ids for seq_id.
            # Since computed_block_ids form a prefix of block_ids,
            # the first time we see seq_id, we detect computed_block_ids
            # fully and store them in the cache. In the next times we see
            # seq_id, we detect computed_block_ids incrementally, by looking
            # only at the new blocks that come after the cached
            # computed_block_ids
            if seq_id not in self._cached_computed_seq_blocks:
                # First time init for seq_id => Detect fully
                computed_block_ids = self.block_allocator.get_computed_block_ids(  # noqa: E501
                    [], block_ids)
                self._cached_computed_seq_blocks[seq_id] = (computed_block_ids, len(computed_block_ids)>=len(block_ids)-1)
            else:
                computed_block_ids, should_continue_adding = self._cached_computed_seq_blocks[seq_id]
                if should_continue_adding:
                    if len(computed_block_ids) < len(block_ids):
                        # Incremental init for seq_id => Look only at the new blocks
                        computed_block_ids = self.block_allocator.get_computed_block_ids(  # noqa: E501
                            computed_block_ids, block_ids)
                        self._cached_computed_seq_blocks[
                            seq_id] = (computed_block_ids, len(computed_block_ids)>=len(block_ids)-1)
                    else:
                        # Cache HIT
                        assert len(computed_block_ids) == len(block_ids)

            ret.append(computed_block_ids)

        return ret

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yard1 and I discussed this in more detail and this is a really good suggestion that should help with performance. Will add this to the algorithm.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yard1 Added your idea inside. All works.

@alexm-neuralmagic
Copy link
Collaborator Author

@cadedaniel @Yard1 I have addressed the review comments, the PR is ready for a pass.

Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small comments only, let's go!

block.append_token_ids(token_block)
self._blocks[idx] = block # Refresh the cached block_id
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this still necessary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I redid the code so it is hidden inside the BlockList (by adding append_token_ids(block_idx, tokens) api func)

Comment on lines 301 to 303
cur_token_ids = block.token_ids
if cur_token_ids is not None:
token_ids.extend(cur_token_ids)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove check now that it can't be None?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

Comment on lines 308 to 309
if not self._is_allocated:
return 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we don't need this branch anymore. if it's not allocated, self.blocks will be empty

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, removed

Comment on lines +129 to +130
assert src_block_id is not None
assert trg_block_id is not None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: a little weird that we check a non-Optional is not None. but my guess it's due to python typing weakness...

can ignore

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the type to Optional[BlockId], I think it makes more sense

Comment on lines 298 to 328
raise NotImplementedError
device = Device.GPU
return self._allocators[device].promote_to_immutable_block(block)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comment when it's used? (I think they should be removed but seems I miss a case)

pass

@abstractmethod
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""NOTE: This should not be used besides Block"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest keeping the NOTE in

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

Comment on lines +315 to +321
"""Decrements the refcount of the block. The block may be in two
possible states: (1) immutable/cached or (2) mutable/hashless.
In the first case, the refcount is decremented directly and the block
may be possibly added to the evictor. In other case, hashless
allocator free(..) with keep_block_object=True is called to only free
the block id (since the block object may be reused by the caller)
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

love this :)

@@ -658,6 +801,7 @@ def content_hash(self) -> Optional[int]:
if prev_block_hash is None and not is_first_block:
return None

assert len(self.token_ids) > 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do we need this assert given if not self.is_full?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You right, removed

Comment on lines +850 to +851
Note that currently, for a given sequence, we also skip the last
block id for caching purposes, to avoid caching of a full sequence
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this work with lookahead scheduling (where potenially >1 block is modified in single step)? don't have to fix now but in the future we want speculative decoding x prefix caching to work

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should work since the blocks that are used for appending or speculative tokens won't be marked as computed, so they won't go into the common cache prefix.

Comment on lines +918 to +921
class LastAccessBlocksTracker:
"""Manages the last access time of the tracked sequences, in order to allow
an efficient update of allocator's block last access times
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

@alexm-neuralmagic
Copy link
Collaborator Author

@cadedaniel fixed the nits, thanks for catching these issues!

Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the excellent contribution!

@cadedaniel cadedaniel merged commit 3476ed0 into vllm-project:main Jul 2, 2024
69 checks passed
kzawora-intel added a commit to HabanaAI/vllm-fork that referenced this pull request Jul 2, 2024
* [Hardware][Intel] Optimize CPU backend and add more performance tips (vllm-project#4971)

Co-authored-by: Jianan Gu <jianan.gu@intel.com>

* [Docs] Add 4th meetup slides (vllm-project#5509)

* [Misc] Add vLLM version getter to utils (vllm-project#5098)

* [CI/Build] Simplify OpenAI server setup in tests (vllm-project#5100)

* [Doc] Update LLaVA docs (vllm-project#5437)

Co-authored-by: Roger Wang <ywang@roblox.com>

* [Kernel] Factor out epilogues from cutlass kernels (vllm-project#5391)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: zifeitong <zifei.tong@parasail.io>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>

* [MISC] Remove FP8 warning (vllm-project#5472)

Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>

* Seperate dev requirements into lint and test (vllm-project#5474)

* Revert "[Core] Remove unnecessary copies in flash attn backend" (vllm-project#5478)

* [misc] fix format.sh (vllm-project#5511)

* [CI/Build] Disable test_fp8.py (vllm-project#5508)

* [Kernel] Disable CUTLASS kernels for fp8 (vllm-project#5505)

* Add `cuda_device_count_stateless` (vllm-project#5473)

* [Hardware][Intel] Support CPU inference with AVX2 ISA (vllm-project#5452)

* [Misc] Fix arg names in quantizer script (vllm-project#5507)

* bump version to v0.5.0.post1 (vllm-project#5522)

* [CI/Build][Misc] Add CI that benchmarks vllm performance on those PRs with `perf-benchmarks` label (vllm-project#5073)

Co-authored-by: simon-mo <simon.mo@hey.com>

* [CI/Build] Disable LLaVA-NeXT CPU test (vllm-project#5529)

* [Kernel] Fix CUTLASS 3.x custom broadcast load epilogue (vllm-project#5516)

* [Misc] Fix arg names (vllm-project#5524)

* [ Misc ] Rs/compressed tensors cleanup (vllm-project#5432)

Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>

* [Kernel] Suppress mma.sp warning on CUDA 12.5 and later (vllm-project#5401)

* [mis] fix flaky test of test_cuda_device_count_stateless (vllm-project#5546)

* [Core] Remove duplicate processing in async engine (vllm-project#5525)

* [misc][distributed] fix benign error in `is_in_the_same_node` (vllm-project#5512)

* [Docs] Add ZhenFund as a Sponsor (vllm-project#5548)

* [Doc] Update documentation on Tensorizer (vllm-project#5471)

* [Bugfix] Enable loading FP8 checkpoints for gpt_bigcode models  (vllm-project#5460)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

* [Bugfix] Fix typo in Pallas backend (vllm-project#5558)

* [Core][Distributed] improve p2p cache generation (vllm-project#5528)

* Add ccache to amd (vllm-project#5555)

* [Core][Bugfix]: fix prefix caching for blockv2 (vllm-project#5364)

Signed-off-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Lei Wen <wenlei03@qiyi.com>

* [mypy] Enable type checking for test directory (vllm-project#5017)

* [CI/Build] Test both text and token IDs in batched OpenAI Completions API (vllm-project#5568)

* [misc] Do not allow to use lora with chunked prefill. (vllm-project#5538)

Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

* add gptq_marlin test for bug report vllm-project#5088 (vllm-project#5145)

* [BugFix] Don't start a Ray cluster when not using Ray (vllm-project#5570)

* [Fix] Correct OpenAI batch response format (vllm-project#5554)

* Add basic correctness 2 GPU tests to 4 GPU pipeline (vllm-project#5518)

* [CI][BugFix] Flip is_quant_method_supported condition (vllm-project#5577)

* [build][misc] limit numpy version (vllm-project#5582)

* [Doc] add debugging tips for crash and multi-node debugging (vllm-project#5581)

* Fix w8a8 benchmark and add Llama-3-8B (vllm-project#5562)

* [Model] Rename Phi3 rope scaling type (vllm-project#5595)

* Correct alignment in the seq_len diagram. (vllm-project#5592)

Co-authored-by: Liqian Chen <liqian.chen@deeplang.ai>

* [Kernel] `compressed-tensors` marlin 24 support (vllm-project#5435)

* [Misc] use AutoTokenizer for benchmark serving when vLLM not installed (vllm-project#5588)

* [Hardware][Intel GPU] Add Intel GPU(XPU) inference backend (vllm-project#3814)

Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Abhilash Majumder <abhilash.majumder@intel.com>
Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>

* [CI/BUILD] Support non-AVX512 vLLM building and testing (vllm-project#5574)

* [CI] the readability of benchmarking and prepare for dashboard (vllm-project#5571)

[CI] Improve the readability of performance benchmarking results and prepare for upcoming performance dashboard (vllm-project#5571)

* [bugfix][distributed] fix 16 gpus local rank arrangement (vllm-project#5604)

* [Optimization] use a pool to reuse LogicalTokenBlock.token_ids (vllm-project#5584)

* [Bugfix] Fix KV head calculation for MPT models when using GQA (vllm-project#5142)

* [Fix] Use utf-8 encoding in entrypoints/openai/run_batch.py (vllm-project#5606)

* [Speculative Decoding 1/2 ] Add typical acceptance sampling as one of the sampling techniques in the verifier (vllm-project#5131)

* [Model] Initialize Phi-3-vision support (vllm-project#4986)

* [Kernel] Add punica dimensions for Granite 13b (vllm-project#5559)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>

* [misc][typo] fix typo (vllm-project#5620)

* [Misc] Fix typo (vllm-project#5618)

* [CI] Avoid naming different metrics with the same name in performance benchmark (vllm-project#5615)

* [bugfix][distributed] improve p2p capability test (vllm-project#5612)

[bugfix][distributed] do not error if two processes do not agree on p2p capability (vllm-project#5612)

* [Misc] Remove import from transformers logging (vllm-project#5625)

* [CI/Build][Misc] Update Pytest Marker for VLMs (vllm-project#5623)

* [ci] Deprecate original CI template (vllm-project#5624)

Signed-off-by: kevin <kevin@anyscale.com>

* [Misc] Add OpenTelemetry support (vllm-project#4687)

This PR adds basic support for OpenTelemetry distributed tracing.
It includes changes to enable tracing functionality and improve monitoring capabilities.

I've also added a markdown with print-screens to guide users how to use this feature. You can find it here

* [Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization (vllm-project#5542)

* [ci] Setup Release pipeline and build release wheels with cache (vllm-project#5610)

Signed-off-by: kevin <kevin@anyscale.com>

* [Model] LoRA support added for command-r (vllm-project#5178)

* [Bugfix] Fix for inconsistent behaviour related to sampling and repetition penalties  (vllm-project#5639)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

* [Doc] Added cerebrium as Integration option (vllm-project#5553)

* [Bugfix] Fix CUDA version check for mma warning suppression (vllm-project#5642)

* [Bugfix] Fix w8a8 benchmarks for int8 case (vllm-project#5643)

* [Bugfix] Fix Phi-3 Long RoPE scaling implementation (vllm-project#5628)

* [Bugfix] Added test for sampling repetition penalty bug. (vllm-project#5659)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

* [Bugfix][CI/Build][AMD][ROCm]Fixed the cmake build bug which generate garbage on certain devices (vllm-project#5641)

* [misc][distributed] use 127.0.0.1 for single-node (vllm-project#5619)

* [Model] Add FP8 kv cache for Qwen2 (vllm-project#5656)

* [Bugfix] Fix sampling_params passed incorrectly in Phi3v example (vllm-project#5684)

* [Misc]Add param max-model-len in benchmark_latency.py (vllm-project#5629)

* [CI/Build] Add tqdm to dependencies (vllm-project#5680)

* [ci] Add A100 queue into AWS CI template (vllm-project#5648)

Signed-off-by: kevin <kevin@anyscale.com>

* [Frontend][Bugfix] Fix preemption_mode -> preemption-mode for CLI arg in arg_utils.py (vllm-project#5688)

* [ci][distributed] add tests for custom allreduce (vllm-project#5689)

* [Bugfix] AsyncLLMEngine hangs with asyncio.run (vllm-project#5654)

* [Doc] Update docker references (vllm-project#5614)

Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>

* [Misc] Add per channel support for static activation quantization; update w8a8 schemes to share base classes (vllm-project#5650)

* [ci] Limit num gpus if specified for A100 (vllm-project#5694)

Signed-off-by: kevin <kevin@anyscale.com>

* [Misc] Improve conftest (vllm-project#5681)

* [Bugfix][Doc] FIx Duplicate Explicit Target Name Errors (vllm-project#5703)

* [Kernel] Update Cutlass int8 kernel configs for SM90 (vllm-project#5514)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

* [Model] Port over CLIPVisionModel for VLMs (vllm-project#5591)

* [Kernel] Update Cutlass int8 kernel configs for SM80 (vllm-project#5275)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

* [Bugfix] Fix the CUDA version check for FP8 support in the CUTLASS kernels (vllm-project#5715)

* [Frontend] Add FlexibleArgumentParser to support both underscore and dash in names (vllm-project#5718)

* [distributed][misc] use fork by default for mp (vllm-project#5669)

* [Model] MLPSpeculator speculative decoding support (vllm-project#4947)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>

* [Kernel] Add punica dimension for Qwen2 LoRA (vllm-project#5441)

* [BugFix] Fix test_phi3v.py (vllm-project#5725)

* [Bugfix] Add  fully sharded layer for QKVParallelLinearWithLora (vllm-project#5665)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>

* [Core][Distributed] add shm broadcast (vllm-project#5399)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>

* [Kernel][CPU] Add Quick `gelu` to CPU (vllm-project#5717)

* [Doc] Documentation on supported hardware for quantization methods (vllm-project#5745)

* [BugFix] exclude version 1.15.0 for modelscope (vllm-project#5668)

* [ci][test] fix ca test in main (vllm-project#5746)

* [LoRA] Add support for pinning lora adapters in the LRU cache (vllm-project#5603)

* [CI][Hardware][Intel GPU] add Intel GPU(XPU) ci pipeline (vllm-project#5616)

* [Model] Support Qwen-VL and Qwen-VL-Chat models with text-only inputs (vllm-project#5710)

Co-authored-by: Roger Wang <ywang@roblox.com>

* [Misc] Remove vllm-project#4789 workaround left in vllm/entrypoints/openai/run_batch.py (vllm-project#5756)

* [Bugfix] Fix pin_lora error in TPU executor (vllm-project#5760)

* [Docs][TPU] Add installation tip for TPU (vllm-project#5761)

* [core][distributed] improve shared memory broadcast (vllm-project#5754)

* [BugFix] [Kernel] Add Cutlass2x fallback kernels (vllm-project#5744)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

* [Distributed] Add send and recv helpers (vllm-project#5719)

* [Bugfix] Add phi3v resize for dynamic shape and fix torchvision requirement (vllm-project#5772)

* [doc][faq] add warning to download models for every nodes (vllm-project#5783)

* post-rebase api adjustments

* [Doc] Add "Suggest edit" button to doc pages (vllm-project#5789)

* [Doc] Add Phi-3-medium to list of supported models (vllm-project#5788)

* [Bugfix] Fix FlexibleArgumentParser replaces _ with - for actual args (vllm-project#5795)

* [ci] Remove aws template (vllm-project#5757)

Signed-off-by: kevin <kevin@anyscale.com>

* [Doc] Add notice about breaking changes to VLMs (vllm-project#5818)

* [Speculative Decoding] Support draft model on different tensor-parallel size than target model (vllm-project#5414)

* add pin_lora to habana components

* add WA for model loader

* fix api mismatches with ray

* tensor parallel fixes

* workers cpu alignment fix

* [Misc] Remove useless code in cpu_worker (vllm-project#5824)

* prefill/decode metadata fixes

* [Core] Add fault tolerance for `RayTokenizerGroupPool` (vllm-project#5748)

* re-enable attn metadata trimming

* worker_use_ray fix

* [doc][distributed] add both gloo and nccl tests (vllm-project#5834)

* [CI/Build] Add unit testing for FlexibleArgumentParser (vllm-project#5798)

* [Misc] Update `w4a16` `compressed-tensors` support to include `w8a16` (vllm-project#5794)

* [Hardware][TPU] Refactor TPU backend (vllm-project#5831)

* [Hardware][AMD][CI/Build][Doc] Upgrade to ROCm 6.1, Dockerfile improvements, test fixes (vllm-project#5422)

* [Hardware][TPU] Raise errors for unsupported sampling params (vllm-project#5850)

* [CI/Build] Add E2E tests for MLPSpeculator (vllm-project#5791)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

* [Bugfix] Fix assertion in NeuronExecutor (vllm-project#5841)

* [Core] Refactor Worker and ModelRunner to consolidate control plane communication (vllm-project#5408)

Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
Signed-off-by: Stephanie <swang@anyscale.com>
Co-authored-by: Stephanie <swang@anyscale.com>

* [Misc][Doc] Add Example of using OpenAI Server with VLM (vllm-project#5832)

* [bugfix][distributed] fix shm broadcast when the queue size is full (vllm-project#5801)

* [Bugfix] Fix embedding to support 2D inputs (vllm-project#5829)

* [Bugfix][TPU] Fix KV cache size calculation (vllm-project#5860)

* [CI/Build] Refactor image test assets (vllm-project#5821)

* [Kernel] Adding bias epilogue support for `cutlass_scaled_mm` (vllm-project#5560)

Co-authored-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* [Frontend] Add tokenize/detokenize endpoints (vllm-project#5054)

* [Hardware][TPU] Support parallel sampling & Swapping (vllm-project#5855)

* [Bugfix][TPU] Fix CPU cache allocation (vllm-project#5869)

* Support CPU inference with VSX PowerPC ISA (vllm-project#5652)

* [doc] update usage of env var to avoid conflict (vllm-project#5873)

* [Misc] Add example for LLaVA-NeXT (vllm-project#5879)

* [BugFix] Fix cuda graph for MLPSpeculator (vllm-project#5875)

Co-authored-by: Abhinav Goyal <abhinav.goyal@flipkart.com>

* [Doc] Add note about context length in Phi-3-Vision example (vllm-project#5887)

* [VLM][Bugfix] Make sure that `multi_modal_kwargs` is broadcasted properly (vllm-project#5880)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>

* [Model] Add base class for LoRA-supported models (vllm-project#5018)

* [Bugfix] Fix img_sizes Parsing in Phi3-Vision (vllm-project#5888)

* [CI/Build] [1/3] Reorganize entrypoints tests (vllm-project#5526)

* add collective crash WA

* add comment to the weird mark_step

* [Model][Bugfix] Implicit model flags and reenable Phi-3-Vision (vllm-project#5896)

* [doc][misc] add note for Kubernetes users (vllm-project#5916)

* [BugFix] Fix `MLPSpeculator` handling of `num_speculative_tokens` (vllm-project#5876)

* [BugFix] Fix `min_tokens` behaviour for multiple eos tokens (vllm-project#5849)

* [CI/Build] Fix Args for `_get_logits_warper` in Sampler Test (vllm-project#5922)

* [Model] Add Gemma 2 (vllm-project#5908)

* [core][misc] remove logical block (vllm-project#5882)

* [Kernel][ROCm][AMD] fused_moe Triton configs v2 for mi300X (vllm-project#5932)

* [Hardware][TPU] Optimize KV cache swapping (vllm-project#5878)

* [VLM][BugFix] Make sure that `multi_modal_kwargs` can broadcast properly with ring buffer. (vllm-project#5905)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

* [Bugfix][Hardware][Intel CPU] Fix unpassed multi_modal_kwargs for CPU runner (vllm-project#5956)

* [Core] Registry for processing model inputs (vllm-project#5214)

Co-authored-by: ywang96 <ywang@roblox.com>

* Unmark fused_moe config json file as executable (vllm-project#5960)

* [Hardware][Intel] OpenVINO vLLM backend (vllm-project#5379)

* [Bugfix] Better error message for MLPSpeculator when `num_speculative_tokens` is set too high (vllm-project#5894)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

* [CI/Build] [2/3] Reorganize entrypoints tests (vllm-project#5904)

* [Distributed] Make it clear that % should not be in tensor dict keys. (vllm-project#5927)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>

* [Spec Decode] Introduce DraftModelRunner (vllm-project#5799)

* [Bugfix] Fix compute datatype for cutlass 3.x epilogues (vllm-project#5931)

* [ Misc ] Remove `fp8_shard_indexer` from Col/Row Parallel Linear (Simplify Weight Loading) (vllm-project#5928)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 (vllm-project#5921)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* Support Deepseek-V2 (vllm-project#4650)

Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>

* [Bugfix] Only add `Attention.kv_scale` if kv cache quantization is enabled (vllm-project#5936)

* Unmark more files as executable (vllm-project#5962)

* [Bugfix] Fix Engine Failing After Invalid Request - AsyncEngineDeadError (vllm-project#5963)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode (vllm-project#4628)

Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>, bong-furiosa <bongwon.jang@furiosa.ai>

* [Bugfix][TPU] Fix TPU sampler output (vllm-project#5978)

* [Bugfix][TPU] Fix pad slot id (vllm-project#5977)

* [Bugfix] fix missing last itl in openai completions benchmark (vllm-project#5926)

* [Misc] Extend vLLM Metrics logging API (vllm-project#5925)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>

* [Kernel] Add punica dimensions for Granite 3b and 8b (vllm-project#5930)

Signed-off-by: Joe Runde <joe@joerun.de>

* [Bugfix] Fix precisions in Gemma 1 (vllm-project#5913)

* [Misc] Update Phi-3-Vision Example (vllm-project#5981)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

* [Bugfix] Support `eos_token_id` from `config.json` (vllm-project#5954)

* [Core] Optimize `SequenceStatus.is_finished` by switching to IntEnum (vllm-project#5974)

* [Kernel] Raise an exception in MoE kernel if the batch size is larger then 65k (vllm-project#5939)

* [ CI/Build ] Added E2E Test For Compressed Tensors (vllm-project#5839)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [CI/Build] Add TP test for vision models (vllm-project#5892)

* [ CI/Build ] LM Eval Harness Based CI Testing (vllm-project#5838)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [Bugfix][CI/Build][Hardware][AMD] Install matching torchvision to fix AMD tests (vllm-project#5949)

* [CI/Build] Temporarily Remove Phi3-Vision from TP Test (vllm-project#5989)

* [CI/Build] Reuse code for checking output consistency (vllm-project#5988)

* [CI/Build] [3/3] Reorganize entrypoints tests (vllm-project#5966)

* [ci][distributed] fix device count call

[ci][distributed] fix some cuda init that makes it necessary to use spawn (vllm-project#5991)

* [Frontend]: Support base64 embedding (vllm-project#5935)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

* [Lora] Use safetensor keys instead of adapter_config.json to find unexpected modules.  (vllm-project#5909)

Co-authored-by: sang <sangcho@anyscale.com>

* [ CI ] Temporarily Disable Large LM-Eval Tests (vllm-project#6005)

Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic>

* [Misc] Fix `get_min_capability` (vllm-project#5971)

* [ Misc ] Refactor w8a8 to use `process_weights_after_load` (Simplify Weight Loading) (vllm-project#5940)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [misc][cuda] use nvml to avoid accidentally cuda initialization (vllm-project#6007)

* [Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (vllm-project#5348)

* Revert test changes

* cleanup

* llm engine cleanup

* utils.py cleanup

* custom ops refactor

* move xops to ops

* remove vllm/hpu/attn_bias.py

* whitespace fix

* revert accidental changes in rmsnorm

* Fix hpugraph hashing

* add trim_attn_metadata comment

* fix prompt bucketing:

* [ CI ] Re-enable Large Model LM Eval (vllm-project#6031)

* [doc][misc] remove deprecated api server in doc (vllm-project#6037)

* [Misc] update benchmark backend for scalellm (vllm-project#6018)

* [doc][misc] further lower visibility of simple api server (vllm-project#6041)

Co-authored-by: Simon Mo <simon.mo@hey.com>

* [Bugfix] Use RayActorError for older versions of Ray in  RayTokenizerGroupPool (vllm-project#6039)

* [Bugfix] adding chunking mechanism to fused_moe to handle large inputs (vllm-project#6029)

* add FAQ doc under 'serving' (vllm-project#5946)

* [Bugfix][Doc] Fix Doc Formatting (vllm-project#6048)

* [Bugfix] Add explicit `end_forward` calls to flashinfer (vllm-project#6044)

* [BugFix] Ensure worker model loop is always stopped at the right time (vllm-project#5987)

* [Frontend] Relax api url assertion for openai benchmarking (vllm-project#6046)

* [Model] Changes to MLPSpeculator to support tie_weights and input_scale (vllm-project#5965)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>

* [Core] Optimize block_manager_v2 vs block_manager_v1 (to make V2 default)  (vllm-project#5602)

* [Frontend] Add template related params to request (vllm-project#5709)

* [VLM] Remove `image_input_type` from VLM config (vllm-project#5852)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

* [Doc] Reinstate doc dependencies (vllm-project#6061)

* guard model loader wa for hpu

---------

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Lei Wen <wenlei03@qiyi.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: kevin <kevin@anyscale.com>
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
Signed-off-by: Stephanie <swang@anyscale.com>
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Signed-off-by: Joe Runde <joe@joerun.de>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Co-authored-by: Jianan Gu <jianan.gu@intel.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: zifeitong <zifei.tong@parasail.io>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Jie Fu (傅杰) <jiefu@tencent.com>
Co-authored-by: Allen.Dou <allen.dou@hotmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Kuntai Du <kuntai@uchicago.edu>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
Co-authored-by: Sanger Steel <sangersteel@gmail.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: leiwen83 <leiwen83@users.noreply.github.com>
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Amit Garg <gargamit@microsoft.com>
Co-authored-by: Charles Riggins <liqianchen123@foxmail.com>
Co-authored-by: Liqian Chen <liqian.chen@deeplang.ai>
Co-authored-by: zhyncs <me@zhyncs.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Abhilash Majumder <abhilash.majumder@intel.com>
Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
Co-authored-by: Bruce Fontaine <bruce@2.7182.net>
Co-authored-by: zifeitong <zifeitong@gmail.com>
Co-authored-by: sroy745 <142070531+sroy745@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Kevin H. Luu <kevin@anyscale.com>
Co-authored-by: Ronen Schaffer <ronen.schaffer@ibm.com>
Co-authored-by: sergey-tinkoff <167607910+sergey-tinkoff@users.noreply.github.com>
Co-authored-by: milo157 <43028253+milo157@users.noreply.github.com>
Co-authored-by: Shukant Pal <SukantK2002@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: DearPlanet <junsong.zhang2021.work@outlook.com>
Co-authored-by: Rafael Vasquez <rafvasq21@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Joshua Rosenkranz <joshua.rosenkranz@gmail.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
Co-authored-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Jee Li <pandaleefree@163.com>
Co-authored-by: rohithkrn <rohith.nallamaddi@gmail.com>
Co-authored-by: Murali Andoorveedu <37849411+andoorve@users.noreply.github.com>
Co-authored-by: Woo-Yeon Lee <wooyeonlee0@gmail.com>
Co-authored-by: Matt Wong <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: aws-patlange <90803007+aws-patlange@users.noreply.github.com>
Co-authored-by: Stephanie Wang <swang@cs.berkeley.edu>
Co-authored-by: Stephanie <swang@anyscale.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: sasha0552 <admin@sasha0552.org>
Co-authored-by: Chip Kerchner <49959681+ChipKerchner@users.noreply.github.com>
Co-authored-by: Abhinav Goyal <abhinav.goyal@flipkart.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: wangding zeng <155410488+zwd003@users.noreply.github.com>
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>, bong-furiosa <bongwon.jang@furiosa.ai>
Co-authored-by: mcalman <68564154+mcalman@users.noreply.github.com>
Co-authored-by: William Lin <SolitaryThinker@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: llmpros <10524065+llmpros@users.noreply.github.com>
Co-authored-by: sang <sangcho@anyscale.com>
Co-authored-by: Avshalom Manevich <12231371+avshalomman@users.noreply.github.com>
Co-authored-by: James Whedbee <jamesw@telnyx.com>
Co-authored-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
Co-authored-by: danieljannai21 <100521221+danieljannai21@users.noreply.github.com>
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple

from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: Why import BlockPool from vllm.core.block.naive_block instead of vllm.core.block.common?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to change it. It was originally in naive_block.

prashantgupta24 pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 3, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants