From 56b13313075fea0b01c3bb51ed6988aecf835f17 Mon Sep 17 00:00:00 2001
From: phinney <20641413+fzyxh@users.noreply.github.com>
Date: Wed, 11 Oct 2023 13:42:21 +0800
Subject: [PATCH 1/2] Fix: stop word should be 1

---
 src/helm/proxy/clients/huggingface_client.py | 29 ++++++++++++++++++--
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/src/helm/proxy/clients/huggingface_client.py b/src/helm/proxy/clients/huggingface_client.py
index e9f93aa9f9..08e6f3dde1 100644
--- a/src/helm/proxy/clients/huggingface_client.py
+++ b/src/helm/proxy/clients/huggingface_client.py
@@ -2,6 +2,8 @@
 import torch
 from dataclasses import asdict
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
+    STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings
 from typing import Any, Dict, List, Optional

 from helm.common.cache import Cache, CacheConfig
@@ -35,6 +37,21 @@ def resolve_alias(model_name: str) -> str:
     return _MODEL_NAME_ALIASES.get(model_name, model_name)


+class StopAtSpecificTokenCriteria(StoppingCriteria):
+    def __init__(self, stop_sequence: List[int]= None):
+        super().__init__()
+        self.stop_sequence = stop_sequence
+
+    # @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        # Create a tensor from the stop_sequence
+        stop_sequence_tensor = torch.tensor(self.stop_sequence, device=input_ids.device, dtype=input_ids.dtype)
+
+        # Check if the current sequence ends with the stop_sequence
+        current_sequence = input_ids[:, -len(self.stop_sequence):]
+        return torch.all(current_sequence == stop_sequence_tensor).item()
+
+
 class HuggingFaceServer:
     """A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""

@@ -72,9 +89,10 @@ def serve_request(self, raw_request: Dict[str, Any]):
             raw_request["stop_sequences"], return_token_type_ids=False, add_special_tokens=False
         )
         assert len(stop_sequence_ids.input_ids) == 1, "Total number of stop words should be 1."
-        assert len(stop_sequence_ids.input_ids[0]) == 1, "Total number of tokens in each stop word should be 1."
+        # assert len(stop_sequence_ids.input_ids[0]) == 1, "Total number of tokens in each stop word should be 1."
+        if len(stop_sequence_ids.input_ids[0]) == 1:
+            raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
         del raw_request["stop_sequences"]
-        raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]

         # Strip out irrelevant parameters
         relevant_raw_request = {
@@ -83,8 +101,13 @@ def serve_request(self, raw_request: Dict[str, Any]):
             if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
         }

+        stopping_criteria = StoppingCriteriaList()
+        if stop_sequence_ids != None:
+            stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_ids.input_ids[0]))
+
         # Use HuggingFace's `generate` method.
-        output = self.model.generate(**encoded_input, **relevant_raw_request)
+        output = self.model.generate(**encoded_input, **relevant_raw_request,
+            stopping_criteria=stopping_criteria if len(stop_sequence_ids.input_ids[0])>1 else None)
         sequences = output.sequences
         scores = output.scores


From c0d1464d3a9581037c32619403b5e706f36fa456 Mon Sep 17 00:00:00 2001
From: phinney <20641413+fzyxh@users.noreply.github.com>
Date: Thu, 12 Oct 2023 10:14:09 +0800
Subject: [PATCH 2/2] Update: reformat huggingface_client

---
 src/helm/proxy/clients/huggingface_client.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/helm/proxy/clients/huggingface_client.py b/src/helm/proxy/clients/huggingface_client.py
index 08e6f3dde1..9bb6f24f57 100644
--- a/src/helm/proxy/clients/huggingface_client.py
+++ b/src/helm/proxy/clients/huggingface_client.py
@@ -2,8 +2,12 @@
 import torch
 from dataclasses import asdict
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
-    STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings
+from transformers.generation.stopping_criteria import (
+    StoppingCriteria,
+    StoppingCriteriaList,
+    STOPPING_CRITERIA_INPUTS_DOCSTRING,
+    add_start_docstrings,
+)
 from typing import Any, Dict, List, Optional

 from helm.common.cache import Cache, CacheConfig
@@ -38,7 +42,7 @@ def resolve_alias(model_name: str) -> str:


 class StopAtSpecificTokenCriteria(StoppingCriteria):
-    def __init__(self, stop_sequence: List[int]= None):
+    def __init__(self, stop_sequence: List[int] = None):
         super().__init__()
         self.stop_sequence = stop_sequence

@@ -48,7 +52,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
         stop_sequence_tensor = torch.tensor(self.stop_sequence, device=input_ids.device, dtype=input_ids.dtype)

         # Check if the current sequence ends with the stop_sequence
-        current_sequence = input_ids[:, -len(self.stop_sequence):]
+        current_sequence = input_ids[:, -len(self.stop_sequence) :]
         return torch.all(current_sequence == stop_sequence_tensor).item()


@@ -106,8 +110,11 @@ def serve_request(self, raw_request: Dict[str, Any]):
             stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_ids.input_ids[0]))

         # Use HuggingFace's `generate` method.
-        output = self.model.generate(**encoded_input, **relevant_raw_request,
-            stopping_criteria=stopping_criteria if len(stop_sequence_ids.input_ids[0])>1 else None)
+        output = self.model.generate(
+            **encoded_input,
+            **relevant_raw_request,
+            stopping_criteria=stopping_criteria if len(stop_sequence_ids.input_ids[0]) > 1 else None,
+        )
         sequences = output.sequences
         scores = output.scores
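
For reference, the standalone sketch below illustrates the pattern these patches rely on: a stop word that tokenizes to more than one id cannot be expressed via `eos_token_id`, so a custom `StoppingCriteria` checks whether the generated ids end with the full stop sequence and is passed to `generate` through a `StoppingCriteriaList`. The `gpt2` checkpoint, the prompt, and the "END" stop string are placeholder assumptions chosen for illustration, not part of the patch; note that, as in the patched code, the tail comparison uses `torch.all` over the whole batch, so it is only exact for batch size 1.

    # Minimal sketch of the multi-token stop-sequence pattern, outside of HELM.
    # Assumptions (not from the patch): the "gpt2" checkpoint, the prompt,
    # and the "END" stop string are placeholders for illustration.
    from typing import List

    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteria,
        StoppingCriteriaList,
    )


    class StopAtSpecificTokenCriteria(StoppingCriteria):
        def __init__(self, stop_sequence: List[int]):
            super().__init__()
            self.stop_sequence = stop_sequence

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # Stop once the generated ids end with the full stop sequence.
            # torch.all spans the batch, so this is exact only for batch size 1.
            stop = torch.tensor(self.stop_sequence, device=input_ids.device, dtype=input_ids.dtype)
            return bool(torch.all(input_ids[:, -len(self.stop_sequence):] == stop).item())


    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # A stop word may tokenize to several ids, which eos_token_id cannot express.
    stop_ids = tokenizer("END", add_special_tokens=False).input_ids
    encoded = tokenizer("Q: What does HELM stand for?\nA:", return_tensors="pt")

    output = model.generate(
        **encoded,
        max_new_tokens=50,
        stopping_criteria=StoppingCriteriaList([StopAtSpecificTokenCriteria(stop_sequence=stop_ids)]),
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))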