Truncate end of text token in HuggingFaceClient #2643

Merged 1 commit on Jun 4, 2024
10 changes: 10 additions & 0 deletions src/helm/clients/auto_client.py
@@ -5,6 +5,7 @@
from retrying import Attempt, RetryError

from helm.benchmark.model_deployment_registry import ModelDeployment, get_model_deployment
from helm.benchmark.tokenizer_config_registry import get_tokenizer_config
from helm.common.file_caches.file_cache import FileCache
from helm.common.file_caches.local_file_cache import LocalFileCache
from helm.common.credentials_utils import provide_api_key
@@ -89,6 +90,9 @@ def _get_client(self, model_deployment_name: str) -> Client:
"hf_auth_token": lambda: self.credentials.get("huggingfaceAuthToken", None), # HuggingFace
"file_cache": lambda: self._get_file_cache(host_organization), # Text-to-image models
"endpoint": lambda: self.credentials.get(host_organization + "Endpoint", None), # Palmyra
"end_of_text_token": lambda: self._get_end_of_text_token(
tokenizer_name=model_deployment.tokenizer_name or model_deployment.name
),
},
)
client = create_object(client_spec)
@@ -214,3 +218,9 @@ def _get_file_cache(self, host_organization: str) -> FileCache:
# Initialize `FileCache` for text-to-image model APIs
local_file_cache_path: str = os.path.join(self.file_storage_path, "output", host_organization)
return LocalFileCache(local_file_cache_path, file_extension="png")

def _get_end_of_text_token(self, tokenizer_name: str) -> Optional[str]:
tokenizer_config = get_tokenizer_config(tokenizer_name)
if tokenizer_config is None:
raise ValueError(f"Could not find tokenizer_config for tokenizer {tokenizer_name}")
return tokenizer_config.end_of_text_token
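
For reference, the new lookup can be illustrated in isolation. The sketch below is hypothetical: the tokenizer name and the "<|endoftext|>" value are placeholders, and in HELM the config comes from the tokenizer config registry rather than a local dict.

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class TokenizerConfig:
    """Only the field this PR cares about; the real config carries more metadata."""
    name: str
    end_of_text_token: Optional[str] = None


# Hypothetical registry contents; real values come from HELM's tokenizer config registry.
_TOKENIZER_CONFIGS: Dict[str, TokenizerConfig] = {
    "huggingface/gpt2": TokenizerConfig(name="huggingface/gpt2", end_of_text_token="<|endoftext|>"),
}


def get_end_of_text_token(tokenizer_name: str) -> Optional[str]:
    # Mirrors AutoClient._get_end_of_text_token: a missing config fails loudly.
    tokenizer_config = _TOKENIZER_CONFIGS.get(tokenizer_name)
    if tokenizer_config is None:
        raise ValueError(f"Could not find tokenizer_config for tokenizer {tokenizer_name}")
    return tokenizer_config.end_of_text_token


print(get_end_of_text_token("huggingface/gpt2"))  # -> <|endoftext|>

AutoClient then feeds this value into the client constructor through the new "end_of_text_token" entry in the client spec's constructor args, using the deployment's tokenizer name and falling back to the deployment name when no tokenizer name is set.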
26 changes: 22 additions & 4 deletions src/helm/clients/client.py
@@ -44,7 +44,12 @@ def make_cache_key(raw_request: Mapping, request: Request) -> Mapping:
return {**raw_request}


def truncate_sequence(sequence: GeneratedOutput, request: Request, print_warning: bool = True) -> GeneratedOutput:
def truncate_sequence(
sequence: GeneratedOutput,
request: Request,
end_of_text_token: Optional[str] = None,
print_warning: bool = True,
) -> GeneratedOutput:
"""
Certain providers have bugs where they aren't respecting max_tokens,
stop_sequences and the end of text token, so as a hack, we have to manually
@@ -63,7 +68,11 @@ def truncate_sequence(sequence: GeneratedOutput, request: Request, print_warning
hlog("WARNING: don't know how to handle echo_prompt and max_tokens > 0, not truncating")
return sequence

for stop in request.stop_sequences:
if end_of_text_token:
stop_sequences = request.stop_sequences + [end_of_text_token]
else:
stop_sequences = request.stop_sequences
for stop in stop_sequences:
# Find `stop` in the text
try:
new_text = sequence.text[: sequence.text.index(stop)]
@@ -115,7 +124,12 @@ def truncate_sequence(sequence: GeneratedOutput, request: Request, print_warning


def truncate_and_tokenize_response_text(
text: str, request: Request, tokenizer: Tokenizer, tokenizer_name: str, original_finish_reason: str = "endoftext"
text: str,
request: Request,
tokenizer: Tokenizer,
tokenizer_name: str,
end_of_text_token: Optional[str] = None,
original_finish_reason: str = "endoftext",
) -> GeneratedOutput:
"""Truncate a string-only response to respect stop_sequences and max_tokens.

@@ -138,7 +152,11 @@ def truncate_and_tokenize_response_text(
if request.echo_prompt:
raise Exception("truncate_and_tokenize_response_text() does not support requests with echo_prompt = True")

for stop_sequence in request.stop_sequences:
if end_of_text_token:
stop_sequences = request.stop_sequences + [end_of_text_token]
else:
stop_sequences = request.stop_sequences
for stop_sequence in stop_sequences:
try:
text = text[: text.index(stop_sequence)]
finish_reason = "stop"
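
The net effect of the client.py changes is that the tokenizer's end-of-text token is treated as one more stop sequence during truncation. A minimal standalone sketch of that behavior follows; the real truncate_sequence also trims the token list, adjusts logprobs, and sets finish_reason, which this sketch omits.

from typing import List, Optional


def truncate_at_stops(text: str, stop_sequences: List[str], end_of_text_token: Optional[str] = None) -> str:
    # Same idea as the updated truncate_sequence / truncate_and_tokenize_response_text:
    # append the end-of-text token to the stop sequences, then cut at each stop found.
    if end_of_text_token:
        stop_sequences = stop_sequences + [end_of_text_token]
    for stop in stop_sequences:
        try:
            text = text[: text.index(stop)]
        except ValueError:  # stop sequence not present in the text
            pass
    return text


# A completion that spells out the end-of-text token is now clipped,
# even when the request itself declared no stop sequences.
print(truncate_at_stops("The answer is 42.<|endoftext|>junk", [], end_of_text_token="<|endoftext|>"))
# -> The answer is 42.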
4 changes: 3 additions & 1 deletion src/helm/clients/huggingface_client.py
@@ -269,6 +269,7 @@ def __init__(
cache_config: CacheConfig,
tokenizer: Tokenizer,
pretrained_model_name_or_path: Optional[str] = None,
end_of_text_token: Optional[str] = None,
**kwargs,
):
super().__init__(cache_config=cache_config)
@@ -281,6 +282,7 @@ def __init__(
self._wrapped_tokenizer: WrappedPreTrainedTokenizer = tokenizer.get_wrapped_tokenizer()
self._tokenizer = tokenizer
self._kwargs = _process_huggingface_client_kwargs(kwargs)
self._end_of_text_token = end_of_text_token

def make_request(self, request: Request) -> RequestResult:
# Embedding not supported for this model
@@ -348,7 +350,7 @@ def do_it() -> Dict[str, Any]:
sequence_logprob += logprob

completion = GeneratedOutput(text=raw_completion["text"], logprob=sequence_logprob, tokens=tokens)
completion = truncate_sequence(completion, request)
completion = truncate_sequence(completion, request, end_of_text_token=self._end_of_text_token)
completions.append(completion)

return RequestResult(