From c911a43405a989774f78085d8e4b0e4ea53c52d3 Mon Sep 17 00:00:00 2001
From: Mircea Pasoi
Date: Tue, 20 Jun 2023 16:48:55 -0700
Subject: [PATCH 1/3] Add async support for HuggingFaceTextGenInference

---
 .../llms/huggingface_text_gen_inference.py | 86 ++++++++++++++++++-
 1 file changed, 84 insertions(+), 2 deletions(-)

diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py
index 56b95636dc0f0..b7d3b41bec679 100644
--- a/langchain/llms/huggingface_text_gen_inference.py
+++ b/langchain/llms/huggingface_text_gen_inference.py
@@ -4,7 +4,10 @@
 
 from pydantic import Extra, Field, root_validator
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 
 
@@ -26,10 +29,13 @@ class HuggingFaceTextGenInference(LLM):
     - seed: The seed to use when generating text.
     - inference_server_url: The URL of the inference server to use.
     - timeout: The timeout value in seconds to use while connecting to inference server.
+    - server_kwargs: The keyword arguments to pass to the inference server (e.g. headers)
     - client: The client object used to communicate with the inference server.
+    - async_client: The async client object used to communicate with the inference server.
 
     Methods:
     - _call: Generates text based on a given prompt and stop sequences.
+    - _acall: Async generates text based on a given prompt and stop sequences.
     - _llm_type: Returns the type of LLM.
     """
 
@@ -78,8 +84,10 @@ class HuggingFaceTextGenInference(LLM):
     seed: Optional[int] = None
     inference_server_url: str = ""
     timeout: int = 120
+    server_kwargs: dict[str, Any] = Field(default_factory=dict)
     stream: bool = False
     client: Any
+    async_client: Any
 
     class Config:
         """Configuration for this pydantic object."""
@@ -94,7 +102,14 @@ def validate_environment(cls, values: Dict) -> Dict:
             import text_generation
 
             values["client"] = text_generation.Client(
-                values["inference_server_url"], timeout=values["timeout"]
+                values["inference_server_url"],
+                timeout=values["timeout"],
+                **values["server_kwargs"],
+            )
+            values["async_client"] = text_generation.AsyncClient(
+                values["inference_server_url"],
+                timeout=values["timeout"],
+                **values["server_kwargs"],
             )
         except ImportError:
             raise ImportError(
@@ -171,3 +186,70 @@ def _call(
                         text_callback(token.text)
                     text += token.text
         return text
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[list[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        if stop is None:
+            stop = self.stop_sequences
+        else:
+            stop += self.stop_sequences
+
+        if not self.stream:
+            res = await self.async_client.generate(
+                prompt,
+                stop_sequences=stop,
+                max_new_tokens=self.max_new_tokens,
+                top_k=self.top_k,
+                top_p=self.top_p,
+                typical_p=self.typical_p,
+                temperature=self.temperature,
+                repetition_penalty=self.repetition_penalty,
+                seed=self.seed,
+                **kwargs,
+            )
+            # remove stop sequences from the end of the generated text
+            for stop_seq in stop:
+                if stop_seq in res.generated_text:
+                    res.generated_text = res.generated_text[
+                        : res.generated_text.index(stop_seq)
+                    ]
+            text: str = res.generated_text
+        else:
+            text_callback = None
+            if run_manager:
+                text_callback = partial(
+                    run_manager.on_llm_new_token, verbose=self.verbose
+                )
+            params = {
+                **{
+                    "stop_sequences": stop,
+                    "max_new_tokens": self.max_new_tokens,
+                    "top_k": self.top_k,
+                    "top_p": self.top_p,
+                    "typical_p": self.typical_p,
+                    "temperature": self.temperature,
+                    "repetition_penalty": self.repetition_penalty,
+                    "seed": self.seed,
+                },
+                **kwargs,
+            }
+            text = ""
+            async for res in self.async_client.generate_stream(prompt, **params):
+                token = res.token
+                is_stop = False
+                for stop_seq in stop:
+                    if stop_seq in token.text:
+                        is_stop = True
+                        break
+                if is_stop:
+                    break
+                if not token.special:
+                    if text_callback:
+                        await text_callback(token.text)
+                    text += token.text
+        return text

From 6ddc3ff218f4a691a1a9c41b4639d6d605fdc871 Mon Sep 17 00:00:00 2001
From: Mircea Pasoi
Date: Tue, 20 Jun 2023 19:43:57 -0700
Subject: [PATCH 2/3] Lint fixes

---
 langchain/llms/huggingface_text_gen_inference.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py
index b7d3b41bec679..080d12837b3b1 100644
--- a/langchain/llms/huggingface_text_gen_inference.py
+++ b/langchain/llms/huggingface_text_gen_inference.py
@@ -29,9 +29,9 @@ class HuggingFaceTextGenInference(LLM):
     - seed: The seed to use when generating text.
     - inference_server_url: The URL of the inference server to use.
     - timeout: The timeout value in seconds to use while connecting to inference server.
-    - server_kwargs: The keyword arguments to pass to the inference server (e.g. headers)
+    - server_kwargs: The keyword arguments to pass to the inference server.
     - client: The client object used to communicate with the inference server.
-    - async_client: The async client object used to communicate with the inference server.
+    - async_client: The async client object used to communicate with the server.
 
     Methods:
     - _call: Generates text based on a given prompt and stop sequences.

From eb477f89a98ae9cad88f5fa3ece90614fc508089 Mon Sep 17 00:00:00 2001
From: Dev 2049
Date: Tue, 20 Jun 2023 22:30:42 -0700
Subject: [PATCH 3/3] lint

---
 langchain/llms/huggingface_text_gen_inference.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py
index 080d12837b3b1..10bc73613cb21 100644
--- a/langchain/llms/huggingface_text_gen_inference.py
+++ b/langchain/llms/huggingface_text_gen_inference.py
@@ -84,7 +84,7 @@ class HuggingFaceTextGenInference(LLM):
     seed: Optional[int] = None
     inference_server_url: str = ""
     timeout: int = 120
-    server_kwargs: dict[str, Any] = Field(default_factory=dict)
+    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
     stream: bool = False
     client: Any
     async_client: Any
@@ -190,7 +190,7 @@ def _call(
     async def _acall(
         self,
         prompt: str,
-        stop: Optional[list[str]] = None,
+        stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
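With the three patches applied, the new `async_client` path is reachable through the standard async entry points on `LLM`. Below is a minimal usage sketch, not part of the patch series itself: the server URL, authorization header, and prompt are placeholders, and it assumes a text-generation-inference server is already running at that address with the `text_generation` package installed locally.

```python
import asyncio

from langchain.llms import HuggingFaceTextGenInference

# Placeholder endpoint and auth header -- point these at a real
# text-generation-inference deployment before running.
llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8080/",
    max_new_tokens=512,
    temperature=0.8,
    server_kwargs={"headers": {"Authorization": "Bearer <token>"}},
)


async def main() -> None:
    # agenerate() dispatches to the new _acall, which awaits the
    # text_generation.AsyncClient created in validate_environment.
    result = await llm.agenerate(["What is Deep Learning?"])
    print(result.generations[0][0].text)


asyncio.run(main())
```

With `stream=True`, `_acall` instead iterates `generate_stream` and forwards each non-special token through `run_manager.on_llm_new_token`, so an async callback handler receives partial output as it arrives.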