diff --git a/cht-petals/models.py b/cht-petals/models.py
index 80241c6..b899107 100644
--- a/cht-petals/models.py
+++ b/cht-petals/models.py
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 from petals import AutoDistributedModelForCausalLM
@@ -42,12 +42,12 @@ class PetalsBasedModel(ChatModel):
     def generate(
         cls,
         messages: list,
+        stop: Optional[Union[str, List[str]]] = None,
         temperature: float = 0.9,
         top_p: float = 0.9,
         n: int = 1,
         stream: bool = False,
         max_tokens: int = 128,
-        stop: str = "/s>",
         **kwargs,
     ) -> List:
         prompt = cls.stitch_prompt(messages, cls.PROMPT_TEMPLATE)
@@ -61,19 +61,19 @@ def generate(
             top_p=top_p,
             max_new_tokens=max_tokens,
         )
-        outputs = cls.safe_decode(cls.tokenizer, outputs[0, n_input_tokens:], streaming=stream, stop=stop)
+        outputs = cls.safe_decode(cls.tokenizer, outputs[0, n_input_tokens:], stop_tokens=stop)
         return [outputs]
 
     @classmethod
     def generate_streaming(
         cls,
         messages: list,
+        stop: Optional[Union[str, List[str]]] = None,
         temperature: float = 0.9,
         top_p: float = 0.9,
         n: int = 1,
         stream: bool = False,
         max_tokens: int = 128,
-        stop: str = "/s>",
         session=None,
         inputs=None,
         **kwargs,
@@ -95,7 +95,7 @@ def generate_streaming(
         )
         delta = outputs[0, n_input_tokens:].tolist()
         token_count = len(delta)  # noqa
-        outputs = cls.safe_decode(cls.tokenizer, delta, streaming=stream, stop=stop)
+        outputs = cls.safe_decode(cls.tokenizer, delta, stop_tokens=stop)
         if not outputs:
             return None  # end
         outputs = outputs.lstrip() if inputs is not None else outputs
@@ -153,13 +153,15 @@ def stitch_prompt(messages: list, prompt_template: Dict[str, str]) -> str:
     def safe_decode(
         tokenizer: PreTrainedTokenizer,
         outputs: Union[torch.Tensor, List[int]],
-        streaming: bool = False,
-        stop: str = "/s>",
+        stop_tokens: Optional[Union[str, List[str]]] = None,
     ) -> str:
         # Workaround to make SentencePiece .decode() keep leading spaces in a token
         fake_token = tokenizer("^")["input_ids"][0]
         outputs = outputs.tolist() if isinstance(outputs, torch.Tensor) else outputs
-        result = tokenizer.decode([fake_token] + outputs)
-        if streaming:
-            return result.lstrip("<s>").lstrip(stop)
-        return result.lstrip("<s>").rsplit("</s>", 1)[0].rsplit(stop, 1)[0].strip()
+        result = tokenizer.decode([fake_token] + outputs).replace("<s>", "")
+
+        # Accept a single string or a list of strings; treat None as no stop tokens
+        stop_tokens = [stop_tokens] if isinstance(stop_tokens, str) else (stop_tokens or [])
+        for stop_token in stop_tokens:
+            result = result.split(stop_token)[0]
+        return result
diff --git a/cht-petals/routes.py b/cht-petals/routes.py
index 3f46925..f129723 100644
--- a/cht-petals/routes.py
+++ b/cht-petals/routes.py
@@ -13,7 +13,7 @@ class ChatCompletionInput(BaseModel):
     model: str
     messages: List[dict]
-    stop: Optional[Union[str, List[str]]] = "/s>"
+    stop: Optional[Union[str, List[str]]] = ["</s>", "/s>"]
     temperature: float = 1.0
     top_p: float = 1.0
     n: int = 1