add stop support in hf backends (#10)
bluecoconut committed May 12, 2023
1 parent 96c6d5a commit d40f681
Showing 1 changed file with 29 additions and 3 deletions.
lambdaprompt/backends.py
@@ -1,7 +1,7 @@
 import os
 import aiohttp
 from pydantic import BaseModel, Extra
-from typing import Optional
+from typing import Optional, Union, List
 import tenacity
 

@@ -88,7 +88,7 @@ class Parameters(RequestBackend.Parameters):
         model: str = 'text-davinci-003'
         presence_penalty: float = 0.2
         frequency_penalty: float = 0.2
-        stop: Optional[str]
+        stop: Optional[Union[str, List[str]]]
 
     def __init__(self, openai_api_key=None, **param_override):
         self.openai_api_key = openai_api_key or os.environ.get("OPENAI_API_KEY")
@@ -151,6 +151,7 @@ class Parameters(Backend.Parameters):
         top_p: float = 0.92
         top_k: int = 0
         repetition_penalty: float = 1.1
+        stop: Optional[Union[str, List[str]]]
 
     def __init__(self, model_name, torch_dtype=None, trust_remote_code=True, use_auth_token=None, **param_override):
         import torch
@@ -181,12 +182,37 @@ def preprocess(self, prompt):

     async def __call__(self, prompt, **kwargs):
         import torch
+        from transformers import StoppingCriteriaList
+        genkwargs = self.parse_param(**kwargs)
+
+        def get_stopping_for_ends(end_ids):
+            # This assumes each stop is a clean "end_id" token sequence.
+            # We stay in token-id land rather than decoding and checking text, so this could cause some weirdness.
+            if len(end_ids) == 0:
+                return StoppingCriteriaList([lambda *args, **kwargs: False])
+            max_stop_length = max(x.shape[0] for x in end_ids)
+            def stop_on_any(input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
+                last_tokens = input_ids[0, -max_stop_length:]
+                for end_id in end_ids:
+                    if torch.equal(last_tokens[-end_id.shape[0]:], end_id):
+                        return True
+                return False
+
+            return StoppingCriteriaList([stop_on_any])
+
+        stop = genkwargs.pop("stop", None) or []
+        if isinstance(stop, str):
+            stop = [stop]
+        end_ids = [self.tokenizer(x, return_tensors="pt").input_ids[0].to(self.model.device) for x in stop]
         s = self.preprocess(prompt)
         input_ids = self.tokenizer(s, return_tensors="pt").input_ids
         input_ids = input_ids.to(self.model.device)
         with torch.no_grad():
-            output_ids = self.model.generate(input_ids, **self.parse_param(**kwargs))
+            output_ids = self.model.generate(input_ids, stopping_criteria=get_stopping_for_ends(end_ids), **genkwargs)
         new_tokens = output_ids[0, len(input_ids[0]):]
+        for end_id in end_ids:
+            if torch.equal(new_tokens[-end_id.shape[0]:], end_id):
+                new_tokens = new_tokens[:-end_id.shape[0]]
         output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
         return output_text
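
For context, a minimal usage sketch of the new stop parameter follows. The backend class name HuggingFaceBackend and the model name are illustrative assumptions (the diff shows only this backend's Parameters, its __call__, and an __init__ taking model_name); stop accepts either a single string or a list of strings.

import asyncio
from lambdaprompt.backends import HuggingFaceBackend  # hypothetical name; the diff shows only the backend internals

async def main():
    backend = HuggingFaceBackend("gpt2")  # __init__(model_name, ...) per the diff; "gpt2" is an arbitrary example
    # stop: Optional[Union[str, List[str]]] -- generation halts once the token ids of any
    # stop sequence appear at the end of the output, and that sequence is trimmed from the result.
    text = await backend("Q: What is 2 + 2?\nA:", stop=["\n", "Q:"])
    print(text)

asyncio.run(main())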
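
The in-diff comment flags that the check stays in token-id land, which can misfire: a stop string may tokenize differently inside generated text than it does on its own, so the id-level comparison can miss it. A text-level alternative (a sketch against the standard transformers StoppingCriteria API, not part of this commit) decodes the continuation and matches strings instead:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

class StopOnText(StoppingCriteria):
    # Stop when any stop string appears in the decoded continuation (text-level check).
    def __init__(self, tokenizer, prompt_length, stops):
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length  # number of prompt tokens to skip when decoding
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        text = self.tokenizer.decode(input_ids[0, self.prompt_length:], skip_special_tokens=True)
        return any(s in text for s in self.stops)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt").input_ids
output_ids = model.generate(
    input_ids,
    max_new_tokens=32,
    stopping_criteria=StoppingCriteriaList([StopOnText(tokenizer, input_ids.shape[1], ["\n"])]),
)

Decoding on every generation step costs more than the id-level comparison in the diff, which is presumably why the commit stays in token-id space; the trade-off is robustness against tokenization mismatches.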

