[RFC]: Improve guided decoding (logit_processor) APIs and performance. #5423

Open
rkooo567 opened this issue Jun 11, 2024 · 15 comments · May be fixed by #6273

Comments

@rkooo567
Collaborator

rkooo567 commented Jun 11, 2024

Motivation.

Currently, the guided decoding and logit processor APIs are incomplete and have several issues. This RFC is intended to lay out the problems and possible solutions. Some of the issues may already have been addressed, and there are PRs out for them.

There are 3 major issues.

  • Guided decoding is not supported in SamplingParams.
  • It is not possible to support batch/async logit processing.
  • Upon failures, the engine dies.

Proposed Change.

API

Guided decoding parameters are not supported in SamplingParams. This is addressed by #4130.

Performance

Currently, logit processors are applied row by row in a blocking loop (logits_row = logits_processor(prompt_tokens_ids, ...). Instead, we can use parallel processing (e.g., Ray or a thread pool) to improve logit processing performance. We use this mechanism internally at Anyscale. We'd like to support this feature in OSS, and would like to improve the logit processor API to support 1. async processing and 2. batching.
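A minimal, standard-library-only sketch of the two strategies described above. It assumes a hypothetical `compute_allowed_tokens` function standing in for an expensive per-sequence mask builder (such as walking a JSON-schema automaton), uses nested Python lists in place of a `torch.Tensor` of logits, and uses `concurrent.futures.ThreadPoolExecutor` rather than Ray.

```python
import math
from concurrent.futures import ThreadPoolExecutor
from typing import List, Sequence, Set

NEG_INF = -math.inf


def compute_allowed_tokens(token_ids: Sequence[int], vocab_size: int) -> Set[int]:
    """Hypothetical stand-in for an expensive per-sequence mask computation.
    Toy rule: a token may not appear twice in the same sequence."""
    return {t for t in range(vocab_size) if t not in token_ids}


def apply_mask(row: List[float], allowed: Set[int]) -> List[float]:
    # Disallowed tokens get -inf, so they can never be sampled.
    return [x if i in allowed else NEG_INF for i, x in enumerate(row)]


def process_row_by_row(logits: List[List[float]],
                       batch_token_ids: List[List[int]]) -> List[List[float]]:
    # Current behaviour: compute and apply one mask at a time, blocking.
    vocab_size = len(logits[0])
    return [apply_mask(row, compute_allowed_tokens(ids, vocab_size))
            for row, ids in zip(logits, batch_token_ids)]


def process_in_parallel(logits: List[List[float]],
                        batch_token_ids: List[List[int]],
                        pool: ThreadPoolExecutor) -> List[List[float]]:
    # Proposed behaviour: compute all masks concurrently, then apply them.
    vocab_size = len(logits[0])
    masks = pool.map(lambda ids: compute_allowed_tokens(ids, vocab_size),
                     batch_token_ids)
    return [apply_mask(row, mask) for row, mask in zip(logits, masks)]


if __name__ == "__main__":
    logits = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    token_ids = [[0], [1, 2]]
    with ThreadPoolExecutor(max_workers=2) as pool:
        # Both strategies produce the same masked logits.
        assert process_in_parallel(logits, token_ids, pool) == \
            process_row_by_row(logits, token_ids)
```

Threads only give a real speedup when the per-row work releases the GIL (for example, native code); CPU-bound pure-Python mask building would need processes or Ray actors instead, which is one reason Ray is mentioned here.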

This requires the logit processor to be

  • stateful (to use a tool like Ray or thread pool).

The proposed interface:

from typing import List

import torch


# LogitProcessorConfig (proposed) and vLLM's SequenceGroupMetadata are
# referenced as forward annotations so this sketch needs no vLLM imports.
class LogitPostProcessor:
    def initialize(self, logit_processor_config: "LogitProcessorConfig") -> None:
        """Initialize the post processor. The post processor may hold state,
        such as a thread pool or Ray actors; that state should be created
        here.
        """
        ...

    def prepare(
            self,
            seq_group_metadata_list: List["SequenceGroupMetadata"]) -> None:
        """Asynchronously prepare logit masks."""
        ...

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply the prepared masks to the given logits."""
        ...

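# For each model, we will have:

def compute_logits(self, hidden_states, sampling_metadata):
    # Computes logits and applies the masks started by prepare_logits().
    ...

def prepare_logits(self, seq_group_metadata_list):
    # Starts (asynchronously) preparing logit masks for this batch.
    ...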

prepare and apply assume 1:1 calls: once prepare is called, apply must be called before prepare is called again. I think this is a safe assumption. Alternatively, we could make prepare return an object, but that would make the interface surface larger, so I don't prefer that solution (but I am open to feedback!)
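One way the 1:1 contract could be enforced, sketched under assumptions: `ToyLogitPostProcessor` is hypothetical, logits are nested lists rather than `torch.Tensor`, each batch entry is a list of token ids rather than `SequenceGroupMetadata`, and the masking rule is a caller-supplied function. The thread pool is created once in the constructor (the long-lived state the RFC refers to), `prepare()` only submits work and returns, and `apply()` blocks on the results.

```python
import math
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, List, Optional, Sequence, Set

NEG_INF = -math.inf


class ToyLogitPostProcessor:
    """Hypothetical processor illustrating the 1:1 prepare/apply contract."""

    def __init__(self, mask_fn: Callable[[Sequence[int]], Set[int]],
                 max_workers: int = 4) -> None:
        self._mask_fn = mask_fn
        # Long-lived state: created once and reused across batches.
        self._pool = ThreadPoolExecutor(max_workers=max_workers)
        self._pending: Optional[List[Future]] = None

    def prepare(self, batch_token_ids: List[List[int]]) -> None:
        """Submit mask computation for the batch and return immediately."""
        if self._pending is not None:
            raise RuntimeError("prepare() called again before apply()")
        self._pending = [self._pool.submit(self._mask_fn, ids)
                         for ids in batch_token_ids]

    def apply(self, logits: List[List[float]]) -> List[List[float]]:
        """Wait for the prepared masks and apply them row by row."""
        if self._pending is None:
            raise RuntimeError("apply() called without a matching prepare()")
        futures, self._pending = self._pending, None
        out = []
        for row, fut in zip(logits, futures):
            allowed = fut.result()  # blocks only if this mask is not ready yet
            out.append([x if i in allowed else NEG_INF
                        for i, x in enumerate(row)])
        return out

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)


if __name__ == "__main__":
    proc = ToyLogitPostProcessor(lambda ids: {t for t in range(3) if t not in ids})
    proc.prepare([[0], [1, 2]])
    # ... the model forward pass would run here, overlapping with mask work ...
    print(proc.apply([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]))
    proc.shutdown()
```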

This is the example usage of the API

        # each model will have prepare_logits API
        self.model.prepare_logits(seq_group_metadata_list)
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            **multi_modal_kwargs,
        )
        # Compute the logits. logit processors are applied here.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

We are also considering upstreaming a Ray-based batch processing implementation with lmformatenforcer.

Failure Handling

When using a stateful logit processor, requests can fail. For example, if we use Ray, Ray actors can die. Or there could be an issue with the user's schema that cannot be caught ahead of time.

When this happens, we should fail the seq_group immediately. We will introduce a new status, FINISHED_INTERNAL_ERROR = enum.auto(), in class SequenceStatus(enum.Enum). If any logit processor fails, we will mark the relevant seq_group as failed, and the request will be aborted.
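A simplified sketch of the proposed failure handling. The enum below is a cut-down stand-in for vLLM's `SequenceStatus` (only a few members are shown, and `is_finished` is simplified); `ToySeqGroup` and `run_processors` are hypothetical. For clarity it invokes the processor once per group; a batched `prepare()` would additionally need to attribute an exception to the group that caused it.

```python
import enum
from dataclasses import dataclass
from typing import Callable, List


class SequenceStatus(enum.Enum):
    # Cut-down stand-in; the real enum has more members.
    RUNNING = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_INTERNAL_ERROR = enum.auto()  # the new status proposed here

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        return status is not SequenceStatus.RUNNING


@dataclass
class ToySeqGroup:
    request_id: str
    status: SequenceStatus = SequenceStatus.RUNNING
    error: str = ""


def run_processors(groups: List[ToySeqGroup],
                   processor: Callable[[ToySeqGroup], None]) -> List[ToySeqGroup]:
    """Run the processor for each group. A failure (e.g. a dead Ray actor or
    an unsatisfiable schema) fails only that group instead of the engine.
    Returns the groups that are still running."""
    for group in groups:
        try:
            processor(group)
        except Exception as exc:
            group.status = SequenceStatus.FINISHED_INTERNAL_ERROR
            group.error = repr(exc)
    return [g for g in groups if not SequenceStatus.is_finished(g.status)]
```

The engine would then abort the requests whose groups ended in FINISHED_INTERNAL_ERROR and keep serving the rest of the batch.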

Feedback Period.

No response

CC List.

cc @simon-mo @Yard1

Any Other Things.

No response

@rkooo567 rkooo567 added the RFC label Jun 11, 2024
@simon-mo
Collaborator

cc @njhill @br3no @mmoskal

@br3no
Contributor

br3no commented Jun 11, 2024

I have a few questions:

It is not supported from SamplingParams

Can you elaborate on why you think placing the guided decoding parameters in the SamplingParams is a good idea? As I commented in #4130, I think they conceptually overlap with the logits processors implementing the guided decoding, which are already in the SamplingParams.

This requires logit processor to be

  • stateful (to use a tool like Ray or thread pool).
    ...

Do you maybe mean stateless? If not, what do you mean exactly?

Regarding the topic of statefulness: we probably don't want to limit ourselves to stateless logits processors. If we manage to make the API so that it is easy to implement stateful logits processors, we would already make things much better. E.g. I think that a very good thing to address would be to add infrastructure for pooling stateful objects and making it easy to define that one such object should not be shared across sequences and requests, or at least should be reset before being used.
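A minimal sketch of the pooling idea, assuming a hypothetical `ResettablePool` and a caller-supplied `reset` function: an expensive stateful object (for example a compiled grammar guide) is built once, checked out by one sequence at a time, and reset when it is returned, so state never leaks across sequences or requests.

```python
import threading
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")


class ResettablePool(Generic[T]):
    """Hypothetical pool that hands each stateful object to one user at a time."""

    def __init__(self, factory: Callable[[], T],
                 reset: Callable[[T], None]) -> None:
        self._factory = factory  # expensive: builds a fresh object
        self._reset = reset      # cheap: clears per-sequence state
        self._free: List[T] = []
        self._lock = threading.Lock()
        self.created = 0         # number of objects built so far

    def acquire(self) -> T:
        with self._lock:
            if self._free:
                return self._free.pop()
            self.created += 1
        # Build outside the lock so a slow factory does not block release().
        return self._factory()

    def release(self, obj: T) -> None:
        self._reset(obj)  # reset before the object can be reused
        with self._lock:
            self._free.append(obj)
```

Whether a given guide can be reset cheaply is library-specific, which is why `reset` is supplied by the caller in this sketch.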

Could you also please elaborate on the new LogitsPostProcessor API you propose? Is this the API to be implemented by logits processors? Or is this an API to be implemented by the models?

Are there maybe some type annotations missing for the return values of e.g. prepare? If this method does not return anything, this means the LogitsPostProcessor is stateful, right? Shouldn't we aim for a stateless design here, to make parallelization easier?

I might have misunderstood the proposal though. So, I'd be really happy if you could elaborate on it.

All in all, I would be very interested in improvements in this area, so I'm glad you're working on it!

@rkooo567
Collaborator Author

rkooo567 commented Jun 11, 2024

Can you elaborate on why you think placing the guided decoding parameters in the SamplingParams is a good idea? As I commented in #4130, I think they conceptually overlap with the logits processors implementing the guided decoding, which are already in the SamplingParams.

It's like moving the functionality into the core API. Right now, it is implemented as an add-on (only working with the OpenAI server), and it doesn't work with tools like https://github.com/anyscale/ray-llm (because we use the core API directly). It requires code that breaks the abstraction barrier (i.e., creating the logit processor ourselves), and given that guided decoding is a core function, I feel having the API in SamplingParams makes sense.

Do you maybe mean stateless? If not, what do you mean exactly?

To reduce the time needed to prepare masks for JSON mode, we want to use parallel processing tools such as a thread pool or Ray. This requires the logit processor to be "stateful", because we don't want to recreate actors or thread pools every time a logit processor is requested (they should be created in __init__).

E.g. I think that a very good thing to address would be to add infrastructure for pooling stateful objects and making it easy to define that one such object should not be shared across sequences and requests, or at least should be reset before being used.

+1. I think it'd be an implementation of part 2.

Could you also please elaborate on the new LogitsPostProcessor API you propose? Is this the API to be implemented by logits processors? Or is this an API to be implemented by the models?

It will replace the private _apply_logit_processor API inside logit_processor.py. Right now, we apply the logit mask row by row. Instead, we will 1. find the relevant logit processor that was created, and 2. call logit_processor.prepare(seq_group_metadata_list) followed by logit_processor.apply(logits).

Are there maybe some type annotations missing for the return values of e.g. prepare? If this method does not return anything, this means the LogitsPostProcessor is stateful, right? Shouldn't we aim for a stateless design here, to make parallelization easier?

You are right that prepare and apply are stateful. We could also make it stateless, like this:

        masks = self.model.prepare_logits(seq_group_metadata_list)
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            **multi_modal_kwargs,
        )
        # Compute the logits. logit processors are applied here.
        logits = self.model.compute_logits(hidden_states, sampling_metadata, masks)

But I found it easier to just make it fully stateful.
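For comparison, a sketch of the stateless alternative shown above, in which the prepare step returns a handle instead of storing it. `prepare_masks` and `apply_masks` are hypothetical plain functions, logits are nested lists, and the handle is a list of `concurrent.futures.Future` objects. Because nothing is stored between the two calls, two batches can be in flight at once and their masks applied in either order, which the stateful version forbids.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, List, Sequence, Set

NEG_INF = float("-inf")
MaskFn = Callable[[Sequence[int]], Set[int]]


def prepare_masks(pool: ThreadPoolExecutor, mask_fn: MaskFn,
                  batch_token_ids: List[List[int]]) -> List[Future]:
    """Start computing one mask per sequence; the returned futures are the
    only record of the work, so this function keeps no state."""
    return [pool.submit(mask_fn, ids) for ids in batch_token_ids]


def apply_masks(logits: List[List[float]],
                masks: List[Future]) -> List[List[float]]:
    """Apply masks produced by prepare_masks() for the same batch."""
    out = []
    for row, fut in zip(logits, masks):
        allowed = fut.result()  # blocks only if this mask is still pending
        out.append([x if i in allowed else NEG_INF
                    for i, x in enumerate(row)])
    return out


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=2) as pool:
        batch_a = prepare_masks(pool, lambda ids: {0}, [[7]])
        batch_b = prepare_masks(pool, lambda ids: {1, 2}, [[7]])
        # Applied out of order: fine, since each handle travels with its batch.
        print(apply_masks([[1.0, 2.0, 3.0]], batch_b))
        print(apply_masks([[1.0, 2.0, 3.0]], batch_a))
```

The trade-off is the one noted in the RFC: the handle becomes part of the interface (it has to be threaded through to compute_logits), in exchange for not having to track per-batch state inside the processor.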

Hope this clarifies the proposal a little bit!

@simon-mo
Collaborator

We should make this work with the following RFCs:

@NadavShmayo #4769
@mmoskal #4775
@mmoskal #2888
@maxdebayser @njhill #5329
@lynkz-matt-psaltis #5006

@rkooo567
Collaborator Author

rkooo567 commented Jun 12, 2024

My initial thoughts:

@mmoskal
Contributor

mmoskal commented Jun 13, 2024

Some ideas:

  • Maybe initialize() can be async? The reason is so that we don't start scheduling sequences while the processor is still initializing (in case it takes a few seconds).
  • Add some sort of free() API so resources can be freed.
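A sketch of the async-initialization idea using `asyncio`; `AsyncInitProcessor`, `schedulable`, and the sleep standing in for slow setup (e.g. starting Ray actors) are all hypothetical. Requests that need the processor stay queued until initialization has completed, while the event loop remains free for other work.

```python
import asyncio
from typing import List, Tuple


class AsyncInitProcessor:
    """Hypothetical processor whose setup may take seconds."""

    def __init__(self) -> None:
        self._ready = False

    async def initialize(self, setup_seconds: float) -> None:
        await asyncio.sleep(setup_seconds)  # stand-in for slow setup work
        self._ready = True

    def is_ready(self) -> bool:
        return self._ready


def schedulable(waiting: List[str], proc: AsyncInitProcessor) -> List[str]:
    # Nothing that depends on the processor is scheduled until it is ready.
    return list(waiting) if proc.is_ready() else []


async def demo() -> Tuple[List[str], List[str]]:
    proc = AsyncInitProcessor()
    init_task = asyncio.create_task(proc.initialize(0.01))
    before = schedulable(["req-1"], proc)  # still initializing
    await init_task
    after = schedulable(["req-1"], proc)
    return before, after


if __name__ == "__main__":
    print(asyncio.run(demo()))  # ([], ['req-1'])
```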

With an additional post-sampling callback, this would subsume my SequenceController #4775 :

    from typing import Dict, List, Tuple


    class LogitPostProcessor:
        # Proposed additional post-sampling callback (other methods omitted).
        def sampled(self, seq: 'Sequence', token_id: int,
                    logprobs: Dict[int, 'Logprob']) -> Tuple[int, List[int], bool]:
            """
            Informs the controller a given token has been sampled.
            Returns the number of tokens to backtrack, the tokens to append,
            and whether to stop.
            """
            if token_id == seq.eos_token_id:
                return 0, [], True
            return 0, [token_id], False
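To make the return value concrete, here is a hypothetical helper (`apply_controller_result`) showing how an engine might apply the `(backtrack, tokens_to_append, stop)` triple to a sequence's token ids. It assumes the list passed in does not yet include the just-sampled token, which matches the example above, where the sampled token is returned in the list to append. Appending several tokens covers fast-forward tokens; a non-zero first element covers backtracking.

```python
from typing import List, Tuple


def apply_controller_result(
        tokens: List[int],
        result: Tuple[int, List[int], bool]) -> Tuple[List[int], bool]:
    """Apply a (backtrack, tokens_to_append, stop) triple to a sequence.

    `tokens` excludes the token that was just sampled; the callback decides
    whether that token is kept by including it in `tokens_to_append`.
    Returns the updated token list and whether generation should stop.
    """
    backtrack, to_append, stop = result
    if not 0 <= backtrack <= len(tokens):
        raise ValueError(f"cannot backtrack {backtrack} of {len(tokens)} tokens")
    kept = tokens[:len(tokens) - backtrack]  # drop the last `backtrack` tokens
    return kept + list(to_append), stop
```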

@rkooo567
Collaborator Author

With an additional post-sampling callback, this would subsume my SequenceController #4775 :

I see. I found that API too limited for our particular use case because, as you know, it is applied after sampling is done (whereas we want to apply the logit processor to the final logits). It would be great if we can subsume it.

add some sort of free() API so resources can be freed

I am open to it, but right now there are no specific use cases.

maybe initialize() can be async? the reason is that we don't start scheduling sequences, where the processor is still initializing (in case it takes a few seconds)

How is this guaranteed now?

@br3no
Contributor

br3no commented Jun 14, 2024

@rkooo567 thanks, let me see if I understand it:

The idea is that the logits processors will be asked to prepare their masks asynchronously and in the meantime the model is going to be run. Once both are ready, the logits are computed by having the model call apply.

This means that the whole process needs to guarantee that there is one logits processor instance per request per sequence. Correct?

The implementation will need to be very careful to avoid contention issues.


Regarding the combination of this with the other PRs: I'm still struggling a bit to understand what general design we need. Let me explain:

The logits processors are now applied in the models; so the general signature of the operation is

compute_logits(hidden_states: Tensor, ...) -> Tensor

We want to support ff-tokens or backtracking (e.g. #4775). These things happen a few layers above the model and don't fit this API above.

So we're talking about different things in different abstraction layers at the same time.

Am I the only one? Is the design clear to you folks? If so, I would appreciate it a lot if someone could describe where which types of object would play which role.

@mmoskal
Contributor

mmoskal commented Jun 14, 2024

@br3no One thing that took me a while to see is that there is only one LogitPostProcessor per LLMEngine - it handles logits for all sequences in the current batch.

There was some discussion of allowing a list of those, but IMHO it's easy to write a LogitPostProcessor that bundles an arbitrary number of LogitPostProcessors, so I think there's no need to have a list of post processors in vLLM.
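A sketch of such a bundling processor, assuming the prepare/apply API from the RFC; `PostProcessor` and `CompositePostProcessor` are hypothetical names, and logits are nested lists instead of `torch.Tensor`. All children are prepared before any is applied, so asynchronous preparation can overlap, and masks are applied in order, each child seeing the previous child's output.

```python
from typing import List, Protocol, Sequence

Logits = List[List[float]]  # stand-in for torch.Tensor


class PostProcessor(Protocol):
    # Hypothetical structural type for the proposed prepare/apply API.
    def prepare(self, batch: Sequence[object]) -> None: ...

    def apply(self, logits: Logits) -> Logits: ...


class CompositePostProcessor:
    """Presents several post processors as a single one."""

    def __init__(self, processors: Sequence[PostProcessor]) -> None:
        self._processors = list(processors)

    def prepare(self, batch: Sequence[object]) -> None:
        # Start every child's (possibly asynchronous) preparation before
        # any apply(), so their work can overlap.
        for p in self._processors:
            p.prepare(batch)

    def apply(self, logits: Logits) -> Logits:
        # Apply children in order; each one sees the previous one's output.
        for p in self._processors:
            logits = p.apply(logits)
        return logits
```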

I'm the one asking for ff_tokens and backtracking; I don't think @rkooo567 is working on this now.

@njhill
Member

njhill commented Jun 15, 2024

@rkooo567 @simon-mo @mmoskal some additional thoughts after we talked offline yesterday:

It's a concern that the current support is kind of broken: it doesn't work for input batches or beam search due to the stateful/concurrency issue. So I wonder if we could prioritize some simpler immediate fixes for that, along with the egregious performance overhead in JSON mode caused by having to construct a new CFGGuide instance every time, i.e. before the more significant rework to introduce batched application and the prepare step. WDYT?

A couple of other thoughts about the proposed interface:

  • Why would we need an initialize method, couldn't a regular constructor be used for this?
  • I'm not sure that it's a good idea to expose List[SequenceGroupMetadata] in this API ... I had assumed SequenceGroupMetadata is an internal data structure that we want the freedom to change without breaking 3rd party LogitsProcessor impls. Probably should have some simpler dataclass or abstract class designed specifically for the API.

@br3no
Contributor

br3no commented Jun 15, 2024

@mmoskal thanks for your answer! I would also like to support ff-tokens, since I think this would help alleviate the performance issues.

@njhill I’m not familiar with lm-format-enforcer, but for the Outlines processors now only the CFG one is problematic. The others are now stateless. Should we concentrate on a “fix” for the output_format: json issue? This would involve an object pool for the CFGGuide for that particular use case. Or am I missing other aspects here?

@rkooo567
Collaborator Author

There was some discussion of allowing a list of those, but IMHO it's easy to write a LogitPostProcessor that bundles an arbitrary number of LogitPostProcessors, so I think there's no need to have a list of post processors in vLLM.

I agree. I have the impression that the current interface is a little over-designed with some vague implementation in mind. For ff-tokens and backtracking, I would like to see the implementation first; otherwise it is very difficult to design the interface (that's why we punted). I don't think the interface I propose here will get in the way of getting there (the logit processor API also doesn't feel like a very stable API yet, so we have time to iterate).

It's a concern that the current support is kind of broken, it doesn't work for input batches or beam search due to the stateful/concurrency thing. So I wonder if we could prioritize some simpler immediate fixes for that along with the egregious performance overhead with json mode due to having to construct a new CFGGuide instance every time. i.e. before the more significant rework to introduce batched application and the prepare step... WDYT?

Does that mean supporting stateful logit processors first (i.e., merging the open PR)? I am okay with this.

Why would we need an initialize method, couldn't a regular constructor be used for this?

I think a regular constructor could work. The main reason was that we need to pass the decoding config to the logit processor, and since it lives inside the model, the required change was big. Actually, I think a constructor makes more sense.

I'm not sure that it's a good idea to expose List[SequenceGroupMetadata] in this API ... I had assumed SequenceGroupMetadata is an internal datastructure that we want the freedom to change without breaking 3rd party LogitsProcessor impls. Probably should have some simpler dataclass or abstract class designed specifically for the API.

Yeah, that is a good point. For our internal implementation, we only need seq_data, seq_ids, request_id, and sampling params.
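A sketch of what such a dedicated type might look like, limited to the fields listed above; `LogitsProcessorInput` is a hypothetical name, `seq_data` is simplified to token-id lists, and `sampling_params` is left untyped. Being a separate, frozen dataclass means the internal `SequenceGroupMetadata` can change without breaking third-party processors.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass(frozen=True)
class LogitsProcessorInput:
    """Hypothetical public view of one sequence group for logit processors."""
    request_id: str
    seq_ids: List[int]
    # seq_id -> token ids so far (simplified stand-in for seq_data).
    seq_data: Dict[int, List[int]]
    # Stand-in for SamplingParams; kept opaque in this sketch.
    sampling_params: Any = None

    def token_ids(self, seq_id: int) -> List[int]:
        # Return a copy so processors cannot mutate engine-owned state.
        return list(self.seq_data[seq_id])
```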

@mmoskal mmoskal linked a pull request Jul 9, 2024 that will close this issue
@mmoskal
Contributor

mmoskal commented Jul 9, 2024

I did a first pass on this in #6273


This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 26, 2024
@aarnphm
Contributor

aarnphm commented Nov 5, 2024

Hi all, for those following this thread: I started benchmarking the current performance of guided decoding in vLLM in #10046.


6 participants