[RFC]: Hidden states processor #12249

Open
DarkLight1337 opened this issue Jan 21, 2025 · 11 comments

@DarkLight1337 (Member) commented Jan 21, 2025

Motivation.

Since #10674, vLLM uses Pooler to extract hidden states from the model and convert them to embeddings, class probabilities, and so on. However, this is still not user-friendly enough.

Proposed Change.

Similar to LogitsProcessor (#1469), we can pass a custom HiddenStatesProcessor in SamplingParams and PoolingParams to postprocess the hidden states and return them in the output. This provides maximum flexibility and enables the same model to be used for different downstream tasks.

# Note that we can use a different processor each time we call `llm.generate`
outputs = llm.generate(..., sampling_params=SamplingParams(hidden_states_processor=...))
custom_outputs = outputs.hidden_states_processor_outputs

The interface of HiddenStatesProcessor is similar to VllmModelForTextGeneration.compute_logits and VllmModelForPooling.pooler:

import torch
# TypeVar defaults (PEP 696) require typing_extensions on Python < 3.13
from typing_extensions import Protocol, TypeVar

H = TypeVar("H", default=torch.Tensor)
R = TypeVar("R", default=torch.Tensor)

class HiddenStatesProcessor(Protocol[H, R]):
    def __call__(self, model: VllmModel[H], hidden_states: H) -> R:
        ...

The default poolers for each downstream task will be implemented as built-in HiddenStatesProcessor classes.

  • IdentityHiddenStatesProcessor: Returns hidden states directly (mainly for reward models)
  • NormalizeHiddenStatesProcessor: Applies normalization to hidden states (mainly for prompt embedding)
  • SoftmaxHiddenStatesProcessor: Applies softmax to hidden states (mainly for classification)
  • StepHiddenStatesProcessor: Applies step processing to hidden states (mainly for PRM models)

The existing pooling APIs (LLM.encode, LLM.embed, LLM.score, etc.) will be updated to use these HiddenStatesProcessors automatically.
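
For illustration, here is a minimal sketch of how two of these built-ins could implement the HiddenStatesProcessor protocol above (the exact implementations are assumptions, not settled code):

import torch
import torch.nn.functional as F

class NormalizeHiddenStatesProcessor:
    """L2-normalize the hidden states (mainly for prompt embedding)."""

    def __call__(self, model, hidden_states: torch.Tensor) -> torch.Tensor:
        # `model` is accepted to satisfy the protocol but is not needed here
        return F.normalize(hidden_states, p=2, dim=-1)

class SoftmaxHiddenStatesProcessor:
    """Convert hidden states to class probabilities (mainly for classification)."""

    def __call__(self, model, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.softmax(hidden_states, dim=-1)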

To get logits from the hidden states, we can have a new hidden states processor that references the LM head of the model:

class LogitsFromHiddenStatesProcessor(HiddenStatesProcessor):
    def __init__(self, lm_head_name: str = "lm_head") -> None:
        self.lm_head_name = lm_head_name

    def __call__(self, model: VllmModel, hidden_states: torch.Tensor) -> torch.Tensor:
        lm_head = getattr(model, self.lm_head_name)
        assert isinstance(lm_head, VocabParallelEmbedding)

        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        return logits

With this design, we can also generate multi-modal outputs:

class ImageFromHiddenStatesProcessor(HiddenStatesProcessor[torch.Tensor, list[Image]]):
    def __init__(self, decoder_name: str = "image_decoder") -> None:
        self.decoder_name = decoder_name
        self._to_pil_image = torchvision.transforms.v2.ToPILImage()

    # Suppose hidden_states is the output of the model's encoder without calling the vision decoder
    def __call__(self, model: VllmModel, hidden_states: torch.Tensor) -> list[Image]:
        image_decoder = getattr(model, self.decoder_name)
        images = image_decoder(hidden_states)  # Shape: [N, C, H, W]
        return [self._to_pil_image(image) for image in images.cpu()]

(Note: This is just one potential approach to generate multi-modal outputs in vLLM. Other methods are still up for discussion.)

Some issues to be addressed:

  • How to handle TP/PP properly? Should the processor be aware of this?
  • The hidden states processor is not known at startup time, so it is excluded from model profiling. This may lead to OOM issues, especially if the hidden states processor calls a significant portion of the model.

Feedback Period.

Around 2 weeks? See when I have time to work on this...

CC List.

@simon-mo @youkaichao @robertgshaw2-redhat @ywang96 @Isotr0py @maxdebayser @flaviabeo @HwwwwwwwH

Any Other Things.

Since the regular model runner can also return hidden states, we should consider merging the functionality of PoolingModelRunner with the regular ModelRunner in V1 (#8779) to simplify our codebase. I think the only difference is that PoolingModelRunner uses dummy KV caches?

@youkaichao (Member) commented:

> we can pass a custom HiddenStatesProcessor in SamplingParams

Can we pass it as an argument to LLM? I don't think people would use a different processor for each request. It should be the same across one inference instance.
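
Concretely, an instance-level version might look something like the sketch below (the hidden_states_processor constructor argument is hypothetical, not an existing API):

# Hypothetical sketch: a single processor fixed at engine construction time.
llm = LLM(model="...", hidden_states_processor=SoftmaxHiddenStatesProcessor())

# Every request served by this engine then uses the same processor.
outputs = llm.generate(prompts, sampling_params=SamplingParams())
custom_outputs = outputs.hidden_states_processor_outputs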

@DarkLight1337 (Member, Author) commented Jan 21, 2025:

> Can we pass it as an argument to LLM? I don't think people would use a different processor for each request. It should be the same across one inference instance.

Some users want to use the same LLM engine for online generation and embedding, but not necessarily both in the same request (see #11905). It would be a waste of resources to run both in that case.

@youkaichao (Member) commented:

I don't think we need to support that. One engine should do one task; otherwise the code would be super complicated and difficult to optimize.

@youkaichao (Member) commented:

Complicated sampling parameters were a major factor in why vLLM became slower previously. We should be very careful about runtime costs that are incurred per-request. In general, I'm fine with adding engine-level features that do not affect per-request performance.

@DarkLight1337 (Member, Author) commented Jan 21, 2025:

I understand your concern. Keep in mind though that hidden states processors are optional and should not affect performance except when they are being used. Unless I'm mistaken, our Poolers aren't being optimized either, so the performance should be similar to what we have now.

@DarkLight1337 (Member, Author) commented Jan 21, 2025:

We can pass a dictionary of "allowed" hidden states processors to LLM at startup time so vLLM can profile and optimize them. Then at inference time, we only let the user select from these processors. Would that alleviate your concerns?

@youkaichao (Member) commented:

> Keep in mind though that hidden states processors are optional and should not affect performance except when they are being used.

Adding sampling parameters will in general slow down inference, because sometimes we need to pass the object across processes.

I think hidden states processors should only be instance-level: users should only be able to specify one processor per instance.

@youkaichao (Member) commented:

> We can pass a dictionary of "allowed" hidden states processors to LLM at startup time so vLLM can profile and optimize them. Then at inference time, we only let the user select from these processors.

That would be even worse, since you would need to validate the processor at runtime, per request.

@DarkLight1337 (Member, Author) commented:

> > We can pass a dictionary of "allowed" hidden states processors to LLM at startup time so vLLM can profile and optimize them. Then at inference time, we only let the user select from these processors.
>
> That would be even worse, since you would need to validate the processor at runtime, per request.

I mean that the user only passes the (string) key of the processor from the initial dictionary, without sending the actual object.
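
Something like the following sketch (all of the argument names here are hypothetical, just to illustrate the shape of the API):

# Hypothetical: processors are registered by name at startup so vLLM can profile them.
llm = LLM(
    model="...",
    hidden_states_processors={
        "softmax": SoftmaxHiddenStatesProcessor(),
        "normalize": NormalizeHiddenStatesProcessor(),
    },
)

# Per request, only the string key is sent; no processor object crosses process boundaries.
outputs = llm.generate(
    prompts,
    sampling_params=SamplingParams(hidden_states_processor="softmax"),
)
custom_outputs = outputs.hidden_states_processor_outputs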

@comaniac (Collaborator) commented:

I'm also in favor of setting it at startup time so that we can better profile and avoid OOM. In general, allowing one endpoint to support various pooling mechanisms seems uncommon. We could mark it as a limitation for now, and think about improvements in the future if there is high demand.

@youkaichao (Member) commented:

> OOM

Out-of-memory errors are also a valid concern. If the engine can serve multiple types of requests, then the memory profiling stage becomes super complicated, and I doubt it is even possible. Under high load, lots of edge cases can occur.

There are also scheduling challenges: generation and embedding can require different amounts of memory, and a single token budget would not be enough to quantify and bound the memory usage.
