Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Sampling controller interface #6273

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

mmoskal
Copy link
Contributor

@mmoskal mmoskal commented Jul 9, 2024

This patch adds SamplingController object on LLMEngine. It subsumes LogitsProcessor functions in SamplingParams as suggested in #5423. Instead of calling the class LogitPostProcessor I used SamplingController since it's a bit broader than just dealing with logits (in particular it allows for fast forward tokens, see below). I'm happy to rename if needed.

The basic idea is that the engine holds an instance of SamplingController and calls methods on it to influence the sampling process. In every step the following methods are called:

  • prepare(sampling_metadata: SamplingMetadata): this is meant to start computation of logit biases for sequences in the batch; the controller will likely need to store mapping from sequences to indices in logit tensor
  • model forward pass is started
  • transform_logits(logits: torch.Tensor) -> torch.Tensor is called on the entire logit tensor (for all sequences in the batch)
  • sampling is performed as usual
  • transform_sampler_output(output: SamplerOutput) -> SamplerOutput is called on the output of the sampler

In case of an empty step (where no sequences are scheduled to run), the empty_step() method is called, instead of the three methods mentioned above. This is to allow the controller to perform cleanup.

To be clear, the only way to use this right now, is to derive from SamplingController and use vllm as a library.

The transformation of sampler output is primary useful in conjunction with the newly added SequenceOutput.fast_forward_tokens field. If this is set, these tokens are to be added to the sequence instead of the sampled token.

An example, where fast forward tokens are useful is generating data adhering to a certain JSON schema. The controller first forces {"name":" to be generated, then the model generates John", the controller forces ,\n"age":, model generates 42, and so on. Another example is chain-of-thought reasoning, where after the model generated a sentence, the controller forces more instructions for the model, the model generates more text, and so on. If used, these greatly speed up generation process.

The SamplingController is passed from LLMEngine to worker/executor via ExecuteModelRequest and ModelRunnerInputBase. Only driver worker receives the controller (as it's the only one that does sampling).

I also added request_id to SequenceGroupToSample (which is referenced from SamplingMetadata) to allow the controller to identify the sequences in the batch.

There are two places where I had to make changes to allow empty fast forward tokens:

  • SequenceData.update_num_computed_tokens
  • BlockTable.append_token_ids
    The main reason to support empty tokens is for the cases when the computation of the next token mask did not complete in time. Instead of waiting for it to finish, or aborting the sequence, the controller can instead do an "empty pass" on that particular sequence and try it again in the next step, without holding up the rest of the batch. Of course, if this happens several times, the controller is free to terminate the sequence.

Status: I have this working with AICI. I think this is good to go, though it may need some tests.

CC @rkooo567 @simon-mo @GindaChen @cadedaniel @njhill

FIX #5423

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@mmoskal
Copy link
Contributor Author

mmoskal commented Jul 11, 2024

@rkooo567 @simon-mo any comments?

@Peng-YM
Copy link

Peng-YM commented Jul 15, 2024

Looks great, thanks for your contribution!

@DarkLight1337 DarkLight1337 requested review from rkooo567 and simon-mo and removed request for rkooo567 July 16, 2024 07:46
@rkooo567
Copy link
Collaborator

I will take a look at it today!

@mmoskal
Copy link
Contributor Author

mmoskal commented Jul 18, 2024

@rkooo567 any updates?

Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will need review from some of more stakeholders. Maybe @njhill and @simon-mo But the part I need looks pretty good to me. I am going to try prototyping our internal implementation by early next week and get back to you asap.

I have 2 questions.

  • For SamplingController, currently, there's no way to pass this. What's the API you are thinking to set this? (is it supposed to be set in a fork level?)
  • For fast forward token, are you planning to upstream any implementation?


for i, token_block in enumerate(token_blocks):
self._blocks.append_token_ids(first_block_idx + i, token_block)
# don't bother appending anything, if no new token_ids were generated
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: is this change related to this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this happens when the "fast forward" is 0-tokens long - see comment in PR description about BlockTable.append_token_ids

@@ -912,6 +941,26 @@ def prune(self,
self.seq_ids = seq_ids


class SamplingController:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it an abstract class?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also in the docstring, we should probably mention this class is a singleton and stateful. And prepare & transform logits need to be called 1 tot 1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added docstrings.

It seems python doesn't have abstract classes, only abstract methods. I can imagine use cases where you only override some of these methods, and none of them you have to override, so I don't think it makes sense to make any of them abstract.

One thing we could do is to make the engine.sampling_controller field non-optional and use this base class as a no-op implementation. Not sure if that would be cleaner though?

vllm/sequence.py Show resolved Hide resolved
@@ -468,8 +474,8 @@ def lora_int_id(self) -> int:

def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if self.is_prefill():
# If still in initial prefill phase, raise Error.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the case it is not "initial"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens for fast-forward tokens - if there is more than one, we switch to prefill mode.


if (ctrl := model_input.sampling_controller) is not None:
assert model_input.sampling_metadata is not None
ctrl.prepare(model_input.sampling_metadata)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been changed from RFC (we pass seq group metadata, here, we are using sampling params). Is there any way to hook this with seq group metadata?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are SamplingParams in sampling_metadata (specifically in SequenceGroupToSample); there is also request_id and sequence_id - I think this is the stuff you said you needed. The main reason to use sampling metadata is that it also has all the logit positions already computed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. yeah I guess what we have in sampling metdata is probably sufficient. let me get back to you soon about this.

@rkooo567 rkooo567 requested a review from njhill July 19, 2024 17:50
@mmoskal
Copy link
Contributor Author

mmoskal commented Jul 19, 2024

Regarding usage pattern - it currently requires you to construct LLMEngine yourself. I plan to do a PR on the openai api_server.py to make it a bit more modular, so you can have your own api_server that extends the "openai" one. See here for an usage example: https://github.com/microsoft/aici/blob/split_toktrie/py/pyaici/vllm_server.py

Regarding fast-forward tokens - they work with the current patch, provided you enable chunked prefill. AICI supports them and so does the Low-level Guidance parser I've been working on (right now https://github.com/mmoskal/llguidance but it will be moving gh orgs shortly)

I pushed the docstring changes, and now it seems some tests are failing, not sure what that is about...

@rkooo567
Copy link
Collaborator

okay, let me prototype our internal changes with this API and I will get back to you asap. In the meantime let's wait for additional community reviews. other than that, it looks okay to me (I don't understand the fast forward token part that well, so I did loose review on it btw)

@cadedaniel
Copy link
Collaborator

cadedaniel commented Jul 19, 2024

Can I review the fast forward part? It intersects with spec decode, and there's a few ways to implement it.

I will review this weekend.

@cadedaniel
Copy link
Collaborator

cadedaniel commented Jul 22, 2024

On the topic of fast-forward tokens:

I agree we should support this optimization in vLLM, it's part of how SGLang's json mode has such good performance. However I feel we should decouple the implementation/design from this sampling controller pull request. There are design tradeoffs in how we implement fast-forward tokens, and we should consider the different ways before going with one.

Concretely, there's a different design that I believe is better for vLLM: use a specialized Worker which wraps the normal Worker, which contains the specialized logic which requires fast-forward tokens. For example, JSON mode could be implemented at as a JSONModeWorker, which wraps a normal worker within it.

The design benefit here is that from the wrapped Worker's perspective, jump tokens are specified using the same API as chunked prefill, speculative decoding scoring, and any other algorithms which require multiple query tokens per sequence (e.g. classifier free guidance, or speculative edits (fireworks blog post calls it "fast apply", vllm issue)). Even if we just consider speculative decoding, we'd love to be able to combine speculative decoding with jump-tokens in JSON mode for extra speedup.

Two concrete downsides of the current design in this PR:

  • It conflates scheduling status of a sequence with multi-query decode. We shouldn't need to make a sequence "PREFILL" to get multi-query decode, because that currently instructs the backend to use a different path for that kernel (e.g. fire off different attention kernels). For performance we should be able to have the backend fire off a single attention kernel which computes both multi-query and single-query attention scores. See this design doc by @LiuXiaoxuanPKU for more information.
  • It makes on-device code live outside of the Worker paradigm; e.g. a sampling controller that exists at the Engine level. To fully support different hardware backends and distributed configurations, we want to delegate the responsibility of all device code to below the Executor and Worker level. This can be trivially fixable in this PR by moving it to the Worker, but then we have the challenges of 1) the API only supports nvidia/amd backends, and 2) the interaction with Scheduler no longer works because the PR assumes worker 0 and engine share the same process. this is changing in SPMD refactor for other performance reasons.

Edit: one last use case: I believe tree scoring in speculative decoding could use the same API as multi-path jump decoding.

@Harsha-Nori
Copy link

Hey all, I'm the lead maintainer for the Guidance project (https://github.com/guidance-ai/guidance), and we've had many requests for guidance support on vLLM, which this PR would enable :). Guidance leverages "jump"/fast forward tokens in the majority of programs users write -- including many use cases beyond JSON mode -- so it'd be great for us to have infrastructure to leverage here in vLLM. Happy to help in any way I can to get this implemented!

@cadedaniel
Copy link
Collaborator

@mmoskal @Harsha-Nori I recommend following up with @zhuohan123 for best approach on jump tokens. the proposed approach has downsides for vLLM performance and maintanence.

Copy link

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

Copy link

mergify bot commented Nov 26, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mmoskal.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 26, 2024
@github-actions github-actions bot added unstale and removed stale labels Nov 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[RFC]: Improve guided decoding (logit_processor) APIs and performance.
5 participants