[Misc]: hidden states using vllm #3594

Closed
ra-MANUJ-an opened this issue Mar 24, 2024 · 2 comments

Comments

@ra-MANUJ-an

Anything you want to discuss about vllm.

The following snippet extracts embeddings from selected layers of an LLM:

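import torch


def process_row(prompt: str, model, tokenizer, layers_to_use: list, remove_period: bool):
    """
    Processes a row of data and returns the last-token embedding for each requested layer.
    """
    if remove_period:
        # Strip trailing periods and spaces from the prompt.
        prompt = prompt.rstrip(". ")
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        # Generate exactly one token so that hidden states for the prompt are returned.
        outputs = model.generate(inputs.input_ids, output_hidden_states=True, return_dict_in_generate=True, max_new_tokens=1, min_new_tokens=1)
    embeddings = {}
    for layer in layers_to_use:
        # Indices: first generation step, chosen layer, first batch item, last prompt token.
        last_hidden_state = outputs.hidden_states[0][layer][0][-1]
        # .cpu() is needed before .numpy() when the model runs on a GPU.
        embeddings[layer] = [last_hidden_state.cpu().numpy().tolist()]
    return embeddings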

This is a pretty standard approach, but it is slow. Is there a way to use vLLM to make it faster without calling generate every time? I've tried batching, but that is slow too. Any help is appreciated!
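One option that stays within Hugging Face Transformers, offered as a sketch under assumptions rather than a tested answer: since only the prompt's hidden states are read, a single batched forward pass with output_hidden_states=True should return the same per-layer tensors as the first generate step, without the sampling machinery. The helper names last_token_embeddings and embed_batch are hypothetical, and the tokenizer is assumed to have a pad token set (Llama tokenizers often need tokenizer.pad_token = tokenizer.eos_token first).

```python
import torch


def last_token_embeddings(hidden_states, attention_mask, layers_to_use):
    """Return, per layer, the hidden state of the last non-padding token of each prompt.

    hidden_states: tuple of [batch, seq_len, hidden] tensors, as returned by a
        Hugging Face model called with output_hidden_states=True.
    attention_mask: [batch, seq_len] tensor of 0/1 values from the tokenizer.
    """
    seq_len = attention_mask.shape[1]
    positions = torch.arange(seq_len, device=attention_mask.device)
    # Largest position whose mask is 1; works for left and right padding.
    last_idx = (attention_mask * positions).argmax(dim=1)
    rows = torch.arange(attention_mask.shape[0], device=attention_mask.device)
    return {
        layer: hidden_states[layer][rows, last_idx].float().cpu()
        for layer in layers_to_use
    }


def embed_batch(prompts, model, tokenizer, layers_to_use):
    """Hypothetical batched replacement for process_row (one forward pass per batch)."""
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    return last_token_embeddings(
        outputs.hidden_states, inputs["attention_mask"], layers_to_use
    )
```

Each returned tensor has one row per prompt, so a whole batch is converted with a single .tolist() call instead of one generate call per prompt.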

One way to get the last hidden state with vLLM is as follows:

from vllm import LLM, SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata


llm = LLM(model=path_to_llama2)


# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = llm.llm_engine.workers[0].model.config.vocab_size
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = llm.llm_engine.workers[0].scheduler_config.max_num_batched_tokens
max_num_seqs = llm.llm_engine.workers[0].scheduler_config.max_num_seqs
prompt = train[0]
prompt_token_ids = llm.llm_engine.tokenizer.encode(prompt)  # e.g. [2, 100, 524, 10]
seqs = []
    
group_id = 1
seq_data = SequenceData(prompt_token_ids)
seq = SequenceGroupMetadata(
    request_id=str(group_id),
    is_prompt=True,
    seq_data={group_id: seq_data},
    sampling_params=sampling_params,
    block_tables=None,
)
seqs.append(seq)
input_tokens, input_positions, input_metadata = llm.llm_engine.workers[0]._prepare_inputs(
    seqs)
prompt_len = len(seq_data.prompt_token_ids)
input_tokens = input_tokens[:prompt_len]
input_positions = input_positions[:prompt_len]
# Execute the model.
num_layers = llm.llm_engine.workers[0].model_config.get_num_layers(llm.llm_engine.workers[0].parallel_config)
tempOut = llm.llm_engine.workers[0].model.model(
    input_ids=input_tokens,
    positions=input_positions,
    kv_caches=[(None, None)] * num_layers,
    input_metadata=input_metadata,
    cache_events=None,
)
print(tempOut.size())

However, this does not give me the hidden states of all layers. Is there another, faster way to get them?
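One generic PyTorch technique for capturing per-layer outputs is a forward hook on each layer. The sketch below uses a toy stack of linear layers as a stand-in and has not been tested against vLLM; the HiddenStateRecorder class is hypothetical, and treating the first element of a tuple output as the hidden state is an assumption.

```python
import torch
from torch import nn


class HiddenStateRecorder:
    """Records the output of each given layer during a forward pass."""

    def __init__(self, layers):
        self.outputs = {}
        self._handles = [
            layer.register_forward_hook(self._make_hook(index))
            for index, layer in enumerate(layers)
        ]

    def _make_hook(self, index):
        def hook(module, inputs, output):
            # Assumption: a tuple output carries the hidden state first.
            hidden = output[0] if isinstance(output, tuple) else output
            self.outputs[index] = hidden.detach()
        return hook

    def remove(self):
        # Detach all hooks so later forward passes are unaffected.
        for handle in self._handles:
            handle.remove()


# Toy stand-in for a stack of decoder layers, only to show the mechanics.
torch.manual_seed(0)
toy_layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
recorder = HiddenStateRecorder(toy_layers)

hidden = torch.randn(2, 4)
with torch.no_grad():
    for layer in toy_layers:
        hidden = layer(hidden)
recorder.remove()

print(sorted(recorder.outputs))  # one recorded entry per layer
```

For the vLLM snippet above, the hooks would presumably be registered on the decoder layers of the model before calling it. Where those layers live, and whether a layer's first output alone equals the per-layer hidden state (some implementations return a separate residual alongside it), should be verified against the installed vLLM version.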


This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 29, 2024

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!

@github-actions github-actions bot closed this as not planned Nov 29, 2024