
[Feature]: obtain logits #11397

Open · 1 task done
zhc7 opened this issue Dec 21, 2024 · 20 comments

@zhc7 commented Dec 21, 2024

🚀 The feature, motivation and pitch

Same as issue #185, which was closed without being resolved.

Alternatives

No response

Additional context

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@DarkLight1337 (Member) commented Dec 21, 2024

A bit of a hack, but right now you can initialize

llm = LLM(..., task="embed", override_pooler_config=PoolerConfig(pooling_type="ALL"))

and call llm.encode(...) to get the hidden states (logits) directly.

Note that this is not intended usage of the embed task, which is supposed to return the hidden state of a single embedding token. Another task will be added in a future PR to explicitly retrieve all of the output hidden states.
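
For reference, a minimal sketch of this workaround (the import path for PoolerConfig, the model name, and the exact output attribute are assumptions and may differ across vLLM versions):

    from vllm import LLM
    from vllm.config import PoolerConfig  # assumed import path; may vary by version

    # Run the model with the pooling ("embed") runner and pool ALL token positions.
    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",  # illustrative model name
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL"),
    )

    outputs = llm.encode(["The quick brown fox jumps over the lazy dog"])

    # With pooling_type="ALL", each request should return one hidden-state vector
    # per prompt token; the attribute holding them (.outputs.data here) is an
    # assumption and may be named differently in your vLLM version.
    hidden_states = outputs[0].outputs.data
    print(len(hidden_states))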

@zhc7 (Author) commented Dec 21, 2024

Thank you for your response! But I am wondering whether there could be a way to both generate and return logits. Since the logits are already known during generation, obtaining them from another instance seems inefficient. Maybe there could be a field, like logprobs, that returns the top-k logit values.

But still, thanks for the temporary workaround.

@zhc7 (Author) commented Dec 23, 2024

For those who have also come across this problem, I found that the key logic is here:

                              sampling_tensors.repetition_penalties)

    # Use float32 to apply temperature scaling.
    # Use in-place division to avoid creating a new tensor.
    logits = logits.to(torch.float)
    logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

    if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
        logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                    sampling_tensors.top_ks)

You can do anything you want to the logits here, and if you want to return them, you can modify:

    prompt_logprobs, sample_logprobs = get_logprobs(
        logprobs, sampling_metadata, maybe_deferred_sample_results)

    return _build_sampler_output(
        maybe_deferred_sample_results,
        sampling_metadata,
        prompt_logprobs,
        sample_logprobs,
        on_device_tensors=on_device_tensors,
        skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

E.g. replace logprobs with logits. The type of logprobs is List[SampleLogprobs] and SampleLogprobs = List[Dict[int, Logprob]], so you can do:

        # Walk each sequence's list of per-token top-k dicts and overwrite the
        # stored logprob with the value you actually want to return.
        for val, lst in zip(something, sample_logprobs):
            for d in lst:
                for k in d.keys():
                    d[k].logprob = anything_you_want_to_obtain

This is a bit destructive, but it works more efficiently than running another model.
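
If you go with this hack, the smuggled values come back through the normal logprobs field of the request output. A minimal sketch of reading them out on the caller side, assuming the patch above overwrote each Logprob.logprob with the raw logit (the model name is illustrative):

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2-7B-Instruct")  # illustrative model name
    params = SamplingParams(temperature=0.7, max_tokens=8, logprobs=5)

    outputs = llm.generate(["Hello, my name is"], params)

    # outputs[0].outputs[0].logprobs is a list with one dict per generated token,
    # mapping token id -> Logprob; after the patch, .logprob holds whatever value
    # was stored in the sampler (e.g. the raw logit) instead of a log-probability.
    for step, logprob_dict in enumerate(outputs[0].outputs[0].logprobs):
        top_values = {tok_id: lp.logprob for tok_id, lp in logprob_dict.items()}
        print(step, top_values)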

@Dineshkumar-Anandan-ZS0367

@DarkLight1337, kindly help me with this!

This is the output for logprobs=5 from the Qwen2-VL-7B model:

[{785: Logprob(logprob=0.0, rank=1, decoded_token='The'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {4462: Logprob(logprob=0.0, rank=1, decoded_token=' member'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {594: Logprob(logprob=0.0, rank=1, decoded_token="'s"), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {2400: Logprob(logprob=0.0, rank=1, decoded_token=' date'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {315: Logprob(logprob=0.0, rank=1, decoded_token=' of'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {7194: Logprob(logprob=0.0, rank=1, decoded_token=' birth'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {320: Logprob(logprob=0.0, rank=1, decoded_token=' ('), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {96576: Logprob(logprob=0.0, rank=1, decoded_token='DOB'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {8: Logprob(logprob=0.0, rank=1, decoded_token=')'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {374: Logprob(logprob=0.0, rank=1, decoded_token=' is'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {220: Logprob(logprob=0.0, rank=1, decoded_token=' '), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {17: Logprob(logprob=0.0, rank=1, decoded_token='2'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {22: Logprob(logprob=0.0, rank=1, decoded_token='7'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {14: Logprob(logprob=0.0, rank=1, 
decoded_token='/'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {15: Logprob(logprob=0.0, rank=1, decoded_token='0'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {23: Logprob(logprob=0.0, rank=1, decoded_token='8'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {14: Logprob(logprob=0.0, rank=1, decoded_token='/'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {16: Logprob(logprob=0.0, rank=1, decoded_token='1'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {24: Logprob(logprob=0.0, rank=1, decoded_token='9'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {21: Logprob(logprob=0.0, rank=1, decoded_token='6'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {19: Logprob(logprob=0.0, rank=1, decoded_token='4'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {13: Logprob(logprob=0.0, rank=1, decoded_token='.'), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}, {151645: Logprob(logprob=0.0, rank=1, decoded_token=''), 2: Logprob(logprob=-inf, rank=2, decoded_token='#'), 0: Logprob(logprob=-inf, rank=3, decoded_token='!'), 3: Logprob(logprob=-inf, rank=4, decoded_token='$'), 1: Logprob(logprob=-inf, rank=5, decoded_token='"')}]

@DarkLight1337 (Member)

@DarkLight1337, kindly help me with this!

This is the output for logprobs=5 from the Qwen2-VL-7B model

Please follow the above comment to obtain the logits instead of the logprobs.

@Dineshkumar-Anandan-ZS0367

@DarkLight1337, are there any mistakes here?

This is the code, right, sir:

    logits = logits.to(torch.float)
    logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

    print("before logits: ", logits)

    if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
        logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                    sampling_tensors.top_ks)

    if do_min_p:
        logits = _apply_min_p(logits, sampling_tensors.min_ps)


    print("after top p&k: ", logits)

    # We use float32 for probabilities and log probabilities.
    # Compute the probabilities.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    # Compute the log probabilities.
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    print("probs: ", probs)

I am getting,

before logits: tensor([[ 50.8878, 80.5463, 81.7951, ..., -18.3415, -18.3415, -18.3415]],
device='cuda:0')

after top p&k: tensor([[-inf, -inf, -inf, ..., -inf, -inf, -inf]], device='cuda:0')

probs: tensor([[0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')

@DarkLight1337 (Member)

I think you just need to insert this

        for val, lst in zip(something, sample_logprobs):
            for d in lst:
                for k in d.keys():
                    d[k].logprob = anything_you_want_to_obtain

before _build_sampler_output is called.

@Dineshkumar-Anandan-ZS0367

@DarkLight1337

I just commented out this piece of code:

    if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
        logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                    sampling_tensors.top_ks)

After that, this is the outcome:
CompletionOutput(index=0, text="The member's date of birth (DOB) is 27/08/1964.", token_ids=(785, 4462, 594, 2400, 315, 7194, 320, 96576, 8, 374, 220, 17, 22, 14, 15, 23, 14, 16, 24, 21, 19, 13, 151645), cumulative_logprob=-0.006753989486014689, logprobs=[{785: Logprob(logprob=-0.006748148240149021, rank=1, decoded_token='The'), 58: Logprob(logprob=-5.001865386962891, rank=2, decoded_token='['), 15505: Logprob(logprob=-31.226261138916016, rank=3, decoded_token='[['), 73594: Logprob(logprob=-33.723819732666016, rank=4, decoded_token='```'), 8420: Logprob(logprob=-40.59212112426758, rank=5, decoded_token='Here'), 1249: Logprob(logprob=-42.46529006958008, rank=6, decoded_token='To'), 63: Logprob(logprob=-46.83601760864258, rank=7, decoded_token='`'), 28715: Logprob(logprob=-46.83601760864258, rank=8, decoded_token='Based'), 37909: Logprob(logprob=-49.33357620239258, rank=9, decoded_token='Bounding')

But these values are not in the proper format, and unwanted tokens are also present in this output.

@DarkLight1337 (Member)

The values in logprobs are actually logits now, so you should already have the necessary information. What format exactly do you need?

@Dineshkumar-Anandan-ZS0367 commented Jan 6, 2025

Apologies, Cyrus (@DarkLight1337) sir.

But I don't see the logits for this token value: 27/08/1964.

How can I get this?

These were the sampling params:

    sampling_params = SamplingParams(
        temperature=0.1,
        top_p=0.001,
        repetition_penalty=1.05,
        max_tokens=2048,
        stop_token_ids=[],
        logprobs=20,
    )

@DarkLight1337 (Member) commented Jan 6, 2025

I think you can adjust _build_sampler_output so that the logits of the input tokens are also included (I think they are stored in prompt_logprobs), not just the output tokens. Honestly I'm not familiar with this part of the code either...

Actually, maybe you can set prompt_logprobs inside SamplingParams, which should return the "logprobs" (actually logits) of the prompt tokens.
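
For what it's worth, a minimal sketch of requesting prompt_logprobs (which, with the sampler patch above, would carry logits rather than true logprobs); the model name is illustrative:

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2-VL-7B-Instruct")  # illustrative model name
    params = SamplingParams(
        temperature=0.1,
        max_tokens=32,
        logprobs=5,
        prompt_logprobs=1,  # also return per-token values for the prompt
    )

    out = llm.generate(["The member's date of birth is"], params)[0]

    # The first entry is None: there is no prediction for the very first prompt token.
    for pos, entry in enumerate(out.prompt_logprobs):
        if entry is not None:
            print(pos, {tok_id: lp.logprob for tok_id, lp in entry.items()})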

@ControllableGeneration commented Jan 14, 2025

A bit of a hack, but right now you can initialize

llm = LLM(..., task="embed", override_pooler_config=PoolerConfig(pooling_type="ALL"))

and call llm.encode(...) to get the hidden states (logits) directly.

Note that this is not intended usage of the embed task, which is supposed to return the hidden state of a single embedding token. Another task will be added in a future PR to explicitly retrieve all of the output hidden states.

This is more like what I want! I think this method can save computation if one only wants to get the logits of the prompt.

However, there are two problems for which I wonder if @DarkLight1337 has any solutions:

  1. A model which supports the "generate" runner but is initialized for the "pooling" runner disables LLM.generate().
  2. llm.encode() gives hidden states rather than logits (logits require one linear mapping from hidden states; such processing is implemented in the model.compute_logits method, but it requires an entry point).

@DarkLight1337 (Member)

A model which supports the "generate" runner but is initialized for the "pooling" runner disables LLM.generate().

This is by design since we don't currently support running both generate and pooling model workers at the same time.

llm.encode() gives hidden states rather than logits (logits require one linear mapping from hidden states; such processing is implemented in the model.compute_logits method, but it requires an entry point).

For the TP=1 case, you can access the model instance directly via

llm = LLM(...)
model = llm.llm_engine.model_executor.driver_worker.model_runner.model

You can compute the logits by calling model.compute_logits manually.

@ControllableGeneration

Thanks a lot @DarkLight1337! But the workflow is still not right. In order to call compute_logits, I need hidden_states, which I can get from llm.encode(), and sampling_metadata. However, sampling_metadata is not an attribute of vllm.worker.pooling_model_runner.PoolingModelRunner. This means that if I work with task="embed", I cannot get the sampling_metadata required by model.compute_logits.

Can you help me out with this?

And regarding the first question:

A model which supports the "generate" runner but is initialized for the "pooling" runner disables LLM.generate().

This is by design since we don't currently support running both generate and pooling model workers at the same time.

What I need is to start one model with llm = LLM(...) so that only one copy occupies the GPU, and then be able to both "generate" and "encode". Maybe I should modify sampling_params.py, but I am not sure. Can you give me a hint on how to save GPU memory so that I can generate and also get the logits of prompts at the same time?

@DarkLight1337 (Member)

In order to call compute_logits, I need hidden_states, which I can get from llm.encode(), and sampling_metadata. However, sampling_metadata is not an attribute of vllm.worker.pooling_model_runner.PoolingModelRunner. This means that if I work with task="embed", I cannot get the sampling_metadata required by model.compute_logits.

I see. In that case, you can just call model.lm_head directly on the hidden states.
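
A rough sketch of that, assuming TP=1, an unquantized model, and that the pooled output exposes per-token hidden states (the import path, model name, and output attribute are assumptions that may differ across vLLM versions):

    import torch
    from vllm import LLM
    from vllm.config import PoolerConfig  # assumed import path

    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",  # illustrative model name
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL"),
    )
    model = llm.llm_engine.model_executor.driver_worker.model_runner.model

    outputs = llm.encode(["The quick brown fox"])
    hidden = torch.as_tensor(outputs[0].outputs.data)  # assumed shape: (num_tokens, hidden_size)

    # Project hidden states onto the vocabulary using the LM head weight.
    # For an unquantized model this is a plain matmul with lm_head.weight.
    lm_head_weight = model.lm_head.weight
    hidden = hidden.to(device=lm_head_weight.device, dtype=lm_head_weight.dtype)
    with torch.no_grad():
        logits = hidden @ lm_head_weight.t()  # (num_tokens, vocab_size)
    print(logits.shape)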

@DarkLight1337 (Member)

Can you give me a hint on how to save GPU memory so that I can generate and also get the logits of prompts at the same time?

You can apply the change as suggested in a previous comment: #11397 (comment)

@ControllableGeneration commented Jan 15, 2025

Can you give me a hint on how to save GPU memory so that I can generate and also get the logits of prompts at the same time?

You can apply the change as suggested in a previous comment: #11397 (comment)

Thanks a lot for staying with me here, @DarkLight1337!

I have tested the code and realized that prompt_logprobs is a list of Nones, while sample_logprobs, as suggested in #11397 (comment), is a list of lists of Logprob dicts, so sample_logprobs does indeed have values. But sample_logprobs is not what I want, because I want to obtain the logits of the prompts, and sample_logprobs has a variable length compared to the prompt token sequences.

I do notice that with "sampling_params = SamplingParams(temperature=0., prompt_logprobs=1)", llm.generate() produces "prompt_logprobs=[None, PromptLogprobs, PromptLogprobs, ...]". Why does it start with a None, when that is not the case for logprobs when I set "sampling_params = SamplingParams(temperature=0., logprobs=1)"?

At the same time, the prompt_logprobs in sampler.py are still lists of Nones instead of lists of PromptLogprobs. This is also strange to me, and makes me wonder whether the computation of prompt_logprobs happens somewhere else. This is supported by the fact that the logits in sampler.py have shape (1, vocab_size), and these logits, covering only a single token, are allocated to sample_logprobs instead of prompt_logprobs.

@DarkLight1337 (Member)

At this point, your guess is as good as mine (I'm not familiar with this part of the code). It is possible that the first element is None because no logprobs are generated for the first token (the logprobs output by the model at the first token are for the second token). I suggest looking into the code further to better understand the details.

@ControllableGeneration

Alrighty. Let me dig deeper, and hopefully I can report back to the community with some neat solutions.

@ControllableGeneration

Overall, this is quite helpful as a starting point if you only want to get prompt logits. You just need to modify the ModelRunner class a little bit.
