Add async support for HuggingFaceTextGenInference #6507

Merged
langchain/llms/huggingface_text_gen_inference.py (85 changes: 83 additions & 2 deletions)
@@ -4,7 +4,10 @@

 from pydantic import Extra, Field, root_validator
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM


@@ -26,10 +29,13 @@ class HuggingFaceTextGenInference(LLM):
     - seed: The seed to use when generating text.
     - inference_server_url: The URL of the inference server to use.
     - timeout: The timeout value in seconds to use while connecting to inference server.
+    - server_kwargs: The keyword arguments to pass to the inference server.
     - client: The client object used to communicate with the inference server.
+    - async_client: The async client object used to communicate with the server.
 
     Methods:
     - _call: Generates text based on a given prompt and stop sequences.
+    - _acall: Asynchronously generates text based on a given prompt and stop sequences.
     - _llm_type: Returns the type of LLM.
     """

@@ -78,8 +84,10 @@ class HuggingFaceTextGenInference(LLM):
     seed: Optional[int] = None
     inference_server_url: str = ""
     timeout: int = 120
+    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
     stream: bool = False
     client: Any
+    async_client: Any
 
     class Config:
         """Configuration for this pydantic object."""
@@ -94,7 +102,14 @@ def validate_environment(cls, values: Dict) -> Dict:
             import text_generation
 
             values["client"] = text_generation.Client(
-                values["inference_server_url"], timeout=values["timeout"]
+                values["inference_server_url"],
+                timeout=values["timeout"],
+                **values["server_kwargs"],
             )
+            values["async_client"] = text_generation.AsyncClient(
+                values["inference_server_url"],
+                timeout=values["timeout"],
+                **values["server_kwargs"],
+            )
         except ImportError:
             raise ImportError(
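
Both clients come from the `text_generation` package and point at the same server; the async one exposes an awaitable `generate` and an async-iterable `generate_stream`, which the new `_acall` below relies on. A standalone sketch of that API, assuming a server is running at a placeholder URL:

import asyncio

import text_generation

async def main() -> None:
    client = text_generation.AsyncClient("http://localhost:8080/")  # placeholder URL
    response = await client.generate("What is deep learning?", max_new_tokens=20)
    print(response.generated_text)

asyncio.run(main())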
@@ -171,3 +186,69 @@ def _call(
                         text_callback(token.text)
                     text += token.text
         return text
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        if stop is None:
+            stop = self.stop_sequences
+        else:
+            stop += self.stop_sequences
+
+        if not self.stream:
+            res = await self.async_client.generate(
+                prompt,
+                stop_sequences=stop,
+                max_new_tokens=self.max_new_tokens,
+                top_k=self.top_k,
+                top_p=self.top_p,
+                typical_p=self.typical_p,
+                temperature=self.temperature,
+                repetition_penalty=self.repetition_penalty,
+                seed=self.seed,
+                **kwargs,
+            )
+            # remove stop sequences from the end of the generated text
+            for stop_seq in stop:
+                if stop_seq in res.generated_text:
+                    res.generated_text = res.generated_text[
+                        : res.generated_text.index(stop_seq)
+                    ]
+            text: str = res.generated_text
+        else:
+            text_callback = None
+            if run_manager:
+                text_callback = partial(
+                    run_manager.on_llm_new_token, verbose=self.verbose
+                )
+            params = {
+                **{
+                    "stop_sequences": stop,
+                    "max_new_tokens": self.max_new_tokens,
+                    "top_k": self.top_k,
+                    "top_p": self.top_p,
+                    "typical_p": self.typical_p,
+                    "temperature": self.temperature,
+                    "repetition_penalty": self.repetition_penalty,
+                    "seed": self.seed,
+                },
+                **kwargs,
+            }
+            text = ""
+            async for res in self.async_client.generate_stream(prompt, **params):
+                token = res.token
+                is_stop = False
+                for stop_seq in stop:
+                    if stop_seq in token.text:
+                        is_stop = True
+                        break
+                if is_stop:
+                    break
+                if not token.special:
+                    if text_callback:
+                        await text_callback(token.text)
+                    text += token.text
+        return text
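
With the async client wired in, the model can now be driven from async code through the base LLM interface; `agenerate` dispatches to the `_acall` defined above. A minimal sketch, again with a placeholder URL:

import asyncio

from langchain.llms import HuggingFaceTextGenInference

async def main() -> None:
    llm = HuggingFaceTextGenInference(
        inference_server_url="http://localhost:8080/",  # placeholder URL
    )
    # Each prompt is generated via the async client without blocking the event loop.
    result = await llm.agenerate(["What is deep learning?", "Define attention."])
    for generations in result.generations:
        print(generations[0].text)

asyncio.run(main())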