diff --git a/src/helm/proxy/clients/huggingface_client.py b/src/helm/proxy/clients/huggingface_client.py
index e9f93aa9f9..9bb6f24f57 100644
--- a/src/helm/proxy/clients/huggingface_client.py
+++ b/src/helm/proxy/clients/huggingface_client.py
@@ -2,6 +2,7 @@
 import torch
 from dataclasses import asdict
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
 from typing import Any, Dict, List, Optional
 
 from helm.common.cache import Cache, CacheConfig
@@ -35,6 +36,22 @@ def resolve_alias(model_name: str) -> str:
     return _MODEL_NAME_ALIASES.get(model_name, model_name)
 
 
+class StopAtSpecificTokenCriteria(StoppingCriteria):
+    """Stops generation once the sequence ends with the given stop token ids."""
+
+    def __init__(self, stop_sequence: List[int]):
+        super().__init__()
+        self.stop_sequence = stop_sequence
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        # Create a tensor from the stop_sequence
+        stop_sequence_tensor = torch.tensor(self.stop_sequence, device=input_ids.device, dtype=input_ids.dtype)
+
+        # Check if the current sequence ends with the stop_sequence
+        current_sequence = input_ids[:, -len(self.stop_sequence) :]
+        return torch.all(current_sequence == stop_sequence_tensor).item()
+
+
 class HuggingFaceServer:
     """A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""
 
@@ -72,9 +89,10 @@ def serve_request(self, raw_request: Dict[str, Any]):
             raw_request["stop_sequences"], return_token_type_ids=False, add_special_tokens=False
         )
         assert len(stop_sequence_ids.input_ids) == 1, "Total number of stop words should be 1."
-        assert len(stop_sequence_ids.input_ids[0]) == 1, "Total number of tokens in each stop word should be 1."
+        # A single-token stop sequence can be handled natively by `generate` as an EOS token id.
+        if len(stop_sequence_ids.input_ids[0]) == 1:
+            raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
         del raw_request["stop_sequences"]
-        raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
 
         # Strip out irrelevant parameters
         relevant_raw_request = {
@@ -83,9 +101,19 @@ def serve_request(self, raw_request: Dict[str, Any]):
             key: value
             for key, value in raw_request.items()
             if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
         }
 
+        # A multi-token stop sequence cannot be expressed as a single `eos_token_id`,
+        # so stop on it with a custom stopping criterion instead.
+        stopping_criteria = StoppingCriteriaList()
+        if stop_sequence_ids is not None and len(stop_sequence_ids.input_ids[0]) > 1:
+            stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_ids.input_ids[0]))
+
         # Use HuggingFace's `generate` method.
-        output = self.model.generate(**encoded_input, **relevant_raw_request)
+        output = self.model.generate(
+            **encoded_input,
+            **relevant_raw_request,
+            stopping_criteria=stopping_criteria if len(stopping_criteria) > 0 else None,
+        )
         sequences = output.sequences
         scores = output.scores
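
For reviewers, here is a minimal, self-contained sketch of how the new `StopAtSpecificTokenCriteria` is expected to be exercised, mirroring what `serve_request` now does for multi-token stop sequences. The `gpt2` checkpoint, the prompt, and the `"\nQuestion:"` stop string are illustrative assumptions, not part of this change:

```python
# Assumes StopAtSpecificTokenCriteria from this diff is importable.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.stopping_criteria import StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# A stop sequence that tokenizes to several ids, so it cannot be passed
# to `generate` as a single `eos_token_id`.
stop_ids = tokenizer("\nQuestion:", add_special_tokens=False).input_ids

stopping_criteria = StoppingCriteriaList([StopAtSpecificTokenCriteria(stop_sequence=stop_ids)])

encoded_input = tokenizer("Q: What is the capital of France?\nA:", return_tensors="pt")
output = model.generate(**encoded_input, max_new_tokens=50, stopping_criteria=stopping_criteria)
print(tokenizer.decode(output[0]))
```

As with Hugging Face stopping criteria in general, generation halts only after the stop sequence has been emitted, so the stop text still appears in the raw output; trimming it from the completion remains the client's job.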
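
And a quick tensor-level illustration of the suffix check inside `__call__`, since it leans on slicing and broadcasting; the token ids here are made up:

```python
import torch

input_ids = torch.tensor([[464, 3139, 286, 198, 24361, 25]])  # shape (batch=1, seq_len=6)
stop_sequence = [198, 24361, 25]
stop_tensor = torch.tensor(stop_sequence, dtype=input_ids.dtype)

# The (1, 3) trailing window broadcasts against the (3,) stop tensor,
# so torch.all is True only if every trailing token matches.
print(torch.all(input_ids[:, -len(stop_sequence):] == stop_tensor).item())  # True
```

Note that `torch.all` reduces over the batch dimension as well, so with batch size greater than one the criterion only fires once every sequence in the batch ends with the stop tokens.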