Skip to content

Commit

Permalink
groq: add back streaming tool calls (#26391)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis committed Sep 12, 2024
1 parent 396c0ae commit 54c8508
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 112 deletions.
81 changes: 5 additions & 76 deletions libs/partners/groq/langchain_groq/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
Expand Down Expand Up @@ -385,9 +384,11 @@ def validate_environment(cls, values: Dict) -> Dict:
values["temperature"] = 1e-8

client_params = {
"api_key": values["groq_api_key"].get_secret_value()
if values["groq_api_key"]
else None,
"api_key": (
values["groq_api_key"].get_secret_value()
if values["groq_api_key"]
else None
),
"base_url": values["groq_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
Expand Down Expand Up @@ -502,42 +503,6 @@ def _stream(
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)

# groq api does not support streaming with tools yet
if "tools" in kwargs:
response = self.client.create(
messages=message_dicts, **{**params, **kwargs}
)
chat_result = self._create_chat_result(response)
generation = chat_result.generations[0]
message = cast(AIMessage, generation.message)
tool_call_chunks = [
create_tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc.get("index"),
)
for rtc in message.additional_kwargs.get("tool_calls", [])
]
chunk_ = ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
tool_call_chunks=tool_call_chunks,
usage_metadata=message.usage_metadata,
),
generation_info=generation.generation_info,
)
if run_manager:
geninfo = chunk_.generation_info or {}
run_manager.on_llm_new_token(
chunk_.text,
chunk=chunk_,
logprobs=geninfo.get("logprobs"),
)
yield chunk_
return

params = {**params, **kwargs, "stream": True}

default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
Expand Down Expand Up @@ -574,42 +539,6 @@ async def _astream(
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)

# groq api does not support streaming with tools yet
if "tools" in kwargs:
response = await self.async_client.create(
messages=message_dicts, **{**params, **kwargs}
)
chat_result = self._create_chat_result(response)
generation = chat_result.generations[0]
message = cast(AIMessage, generation.message)
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc.get("index"),
}
for rtc in message.additional_kwargs.get("tool_calls", [])
]
chunk_ = ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
),
generation_info=generation.generation_info,
)
if run_manager:
geninfo = chunk_.generation_info or {}
await run_manager.on_llm_new_token(
chunk_.text,
chunk=chunk_,
logprobs=geninfo.get("logprobs"),
)
yield chunk_
return

params = {**params, **kwargs, "stream": True}

default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
Expand Down
Original file line number Diff line number Diff line change
@@ -1,42 +1,6 @@
# serializer version: 1
# name: TestGroqStandard.test_serdes[serialized]
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'ChatGroqInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain_groq',
'chat_models',
'ChatGroq',
]),
'name': 'ChatGroq',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ChatGroqOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain_groq',
'chat_models',
Expand Down

0 comments on commit 54c8508

Please sign in to comment.