Fix huggingface logprob (stanford-crfm#1964)
ruixin31 committed Dec 5, 2023
1 parent 771a084 commit 24de541
Showing 2 changed files with 95 additions and 31 deletions.
104 changes: 73 additions & 31 deletions src/helm/proxy/clients/huggingface_client.py
@@ -82,45 +82,75 @@ def serve_request(self, raw_request: Dict[str, Any]):
                stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_input_ids))
            del raw_request["stop_sequences"]

        # Strip out irrelevant parameters
        relevant_raw_request = {
            key: raw_request[key]
            for key in raw_request
            if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
        }
        # Check if we need to compute the perplexity of the prompt (#1497)
        compute_logprobs_only = (
            raw_request["max_new_tokens"] == 0
            and raw_request["num_return_sequences"] == 1
            and raw_request["echo_prompt"]
        )

        # Use HuggingFace's `generate` method.
        output = self.model.generate(
            **encoded_input,
            **relevant_raw_request,
            stopping_criteria=stopping_criteria,
        )
        sequences = output.sequences
        scores = output.scores
        if compute_logprobs_only:
            with torch.no_grad():
                output = self.model(encoded_input["input_ids"])
            sequences = encoded_input["input_ids"]
            scores = output.logits
        else:
            # Strip out irrelevant parameters
            relevant_raw_request = {
                key: raw_request[key]
                for key in raw_request
                if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
            }

            output = self.model.generate(
                **encoded_input,
                **relevant_raw_request,
                stopping_criteria=stopping_criteria,
            )
            sequences = output.sequences
            scores = output.scores

        prompt_tokens_logprobs = []
        prompt_tokens_top_logprobs_dicts: List[Dict] = []
        if compute_logprobs_only:
            # Append the logprob of the first token of the prompt.
            prompt_tokens_logprobs.append(0.0)
            prompt_tokens_top_logprobs_dicts.append({})

            # Compute logprobs of prompt tokens.
            for completion_id in range(raw_request["num_return_sequences"]):
                for i in range(len(sequences[completion_id]) - 1):
                    logprobs = torch.nn.functional.log_softmax(scores[completion_id][i], dim=0)
                    topk_logprobs = torch.topk(logprobs, k=top_k_per_token)
                    prompt_tokens_top_logprobs_dicts.append(
                        {
                            self.tokenizer.convert_ids_to_tokens(k.item()): v.item()
                            for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values)
                        }
                    )
                    prompt_tokens_logprobs.append(logprobs[sequences[completion_id][i + 1]].item())

        # Compute logprobs for each completed sequence.
        all_logprobs_of_chosen_tokens = []
        all_top_logprobs_dicts = []
        # Compute logprobs of generated tokens for each completed sequence.
        all_generated_tokens_logprobs = []
        all_generated_tokens_top_logprobs_dicts = []
        for completion_id in range(raw_request["num_return_sequences"]):
            logprobs_of_chosen_tokens = []
            top_logprobs_dicts = []
            generated_tokens_logprobs = []
            generated_tokens_top_logprobs_dicts = []
            for i in range(len(sequences[completion_id]) - len(encoded_input.input_ids[0])):
                logprobs = torch.nn.functional.log_softmax(scores[i][completion_id], dim=0)

                # Get top tokens in terms of log probability.
                topk_logprobs = torch.topk(logprobs, k=top_k_per_token)
                top_logprobs_dicts.append(
                generated_tokens_top_logprobs_dicts.append(
                    {
                        self.tokenizer.convert_ids_to_tokens(k.item()): v.item()
                        for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values)
                    }
                )

                # Get log probability of chosen token.
                j = i + len(encoded_input.input_ids[0])
                logprobs_of_chosen_tokens.append(logprobs[sequences[completion_id][j]].item())
            all_logprobs_of_chosen_tokens.append(logprobs_of_chosen_tokens)
            all_top_logprobs_dicts.append(top_logprobs_dicts)
                generated_tokens_logprobs.append(logprobs[sequences[completion_id][j]].item())
            all_generated_tokens_logprobs.append(generated_tokens_logprobs)
            all_generated_tokens_top_logprobs_dicts.append(generated_tokens_top_logprobs_dicts)

        # Remove prompt from the start of each sequence if echo_prompt is False.
        if not raw_request["echo_prompt"]:
@@ -130,15 +160,17 @@ def serve_request(self, raw_request: Dict[str, Any]):
        all_decoded_text = self.tokenizer.batch_decode(sequences)

        completions = []
        for decoded_text, tokens, logprobs_of_chosen_tokens, top_logprobs_dicts in zip(
            all_decoded_text, all_tokens, all_logprobs_of_chosen_tokens, all_top_logprobs_dicts
        for decoded_text, tokens, generated_tokens_logprobs, generated_tokens_top_logprobs_dicts in zip(
            all_decoded_text, all_tokens, all_generated_tokens_logprobs, all_generated_tokens_top_logprobs_dicts
        ):
            completions.append(
                {
                    "text": decoded_text,
                    "tokens": tokens,
                    "logprobs": logprobs_of_chosen_tokens,
                    "top_logprobs_dicts": top_logprobs_dicts,
                    "logprobs": generated_tokens_logprobs,
                    "top_logprobs_dicts": generated_tokens_top_logprobs_dicts,
                    "prompt_logprobs": prompt_tokens_logprobs,
                    "prompt_top_logprobs_dicts": prompt_tokens_top_logprobs_dicts,
                }
            )

@@ -229,8 +261,18 @@ def do_it():
            if request.echo_prompt:
                # Add prompt to list of generated tokens.
                generated_tokens = raw_completion["tokens"][response["input_length"] :]
                for token_text in raw_completion["tokens"][: response["input_length"]]:
                    tokens.append(Token(text=token_text, logprob=0.0, top_logprobs={}))
                if raw_completion.get("prompt_logprobs") and raw_completion.get("prompt_top_logprobs_dicts"):
                    for token_text, logprob, top_logprobs_dict in zip(
                        raw_completion["tokens"][: response["input_length"]],
                        raw_completion["prompt_logprobs"][: response["input_length"]],
                        raw_completion["prompt_top_logprobs_dicts"][: response["input_length"]],
                    ):
                        tokens.append(Token(text=token_text, logprob=logprob, top_logprobs=top_logprobs_dict))
                        sequence_logprob += logprob
                else:
                    for token_text in raw_completion["tokens"][: response["input_length"]]:
                        tokens.append(Token(text=token_text, logprob=0.0, top_logprobs={}))

            else:
                generated_tokens = raw_completion["tokens"]

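For reference, the logprob-only branch added above amounts to a single forward pass over the prompt followed by a log-softmax at each position, so each prompt token is scored against the distribution predicted from the tokens before it. A minimal standalone sketch of that idea (not part of the commit; the gpt2 checkpoint and variable names are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

encoded_input = tokenizer("I am a computer scientist.", return_tensors="pt")
with torch.no_grad():
    # A plain forward pass suffices when max_new_tokens == 0: logits has shape
    # (batch, sequence_length, vocab_size) and covers every prompt position.
    logits = model(encoded_input["input_ids"]).logits

input_ids = encoded_input["input_ids"][0]
prompt_logprobs = [0.0]  # the first prompt token has no preceding context to condition on
for i in range(len(input_ids) - 1):
    # Log-distribution over the token at position i + 1, conditioned on tokens 0..i.
    logprobs = torch.nn.functional.log_softmax(logits[0][i], dim=0)
    prompt_logprobs.append(logprobs[input_ids[i + 1]].item())

print(sum(prompt_logprobs))  # total log-likelihood of the prompt under the model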
22 changes: 22 additions & 0 deletions src/helm/proxy/clients/test_huggingface_client.py
@@ -85,3 +85,25 @@ def test_gptj_6b(self):
            )
        )
        assert len(result.completions) == 3

    def test_logprob(self):
        prompt: str = "I am a computer scientist."
        result: RequestResult = self.client.make_request(
            Request(
                model="openai/gpt2",
                model_deployment="huggingface/gpt2",
                prompt=prompt,
                num_completions=1,
                max_tokens=0,
                echo_prompt=True,
            )
        )
        assert result.completions[0].text.startswith(
            prompt
        ), "echo_prompt was set to true. Expected the prompt at the beginning of each completion"
        total_logprob: float = 0
        assert len(result.completions[0].tokens) == 6, "Expected 6 tokens in the completion"
        for token in result.completions[0].tokens[1:]:
            assert token.logprob != 0
            total_logprob += token.logprob
        assert result.completions[0].logprob == pytest.approx(total_logprob)
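Because the underlying motivation is computing the perplexity of the prompt (#1497), it may help to note how the per-token logprobs checked above relate to perplexity: it is the exponential of the negative mean logprob. A small illustrative helper, not part of the commit:

import math
from typing import List

def prompt_perplexity(token_logprobs: List[float]) -> float:
    """Perplexity of a prompt from its per-token logprobs: exp(-mean logprob)."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

# For example, with the tokens checked in test_logprob above (skipping the first
# token, whose logprob is reported as 0.0):
#     prompt_perplexity([token.logprob for token in result.completions[0].tokens[1:]])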
