diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index 23aeae9f11bed..9be07c594ad64 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -52,7 +52,6 @@
     ToolMessage,
     ToolMessageChunk,
 )
-from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
 from langchain_core.output_parsers import (
     JsonOutputParser,
     PydanticOutputParser,
@@ -385,9 +384,11 @@ def validate_environment(cls, values: Dict) -> Dict:
             values["temperature"] = 1e-8
 
         client_params = {
-            "api_key": values["groq_api_key"].get_secret_value()
-            if values["groq_api_key"]
-            else None,
+            "api_key": (
+                values["groq_api_key"].get_secret_value()
+                if values["groq_api_key"]
+                else None
+            ),
             "base_url": values["groq_api_base"],
             "timeout": values["request_timeout"],
             "max_retries": values["max_retries"],
@@ -502,42 +503,6 @@ def _stream(
     ) -> Iterator[ChatGenerationChunk]:
         message_dicts, params = self._create_message_dicts(messages, stop)
 
-        # groq api does not support streaming with tools yet
-        if "tools" in kwargs:
-            response = self.client.create(
-                messages=message_dicts, **{**params, **kwargs}
-            )
-            chat_result = self._create_chat_result(response)
-            generation = chat_result.generations[0]
-            message = cast(AIMessage, generation.message)
-            tool_call_chunks = [
-                create_tool_call_chunk(
-                    name=rtc["function"].get("name"),
-                    args=rtc["function"].get("arguments"),
-                    id=rtc.get("id"),
-                    index=rtc.get("index"),
-                )
-                for rtc in message.additional_kwargs.get("tool_calls", [])
-            ]
-            chunk_ = ChatGenerationChunk(
-                message=AIMessageChunk(
-                    content=message.content,
-                    additional_kwargs=message.additional_kwargs,
-                    tool_call_chunks=tool_call_chunks,
-                    usage_metadata=message.usage_metadata,
-                ),
-                generation_info=generation.generation_info,
-            )
-            if run_manager:
-                geninfo = chunk_.generation_info or {}
-                run_manager.on_llm_new_token(
-                    chunk_.text,
-                    chunk=chunk_,
-                    logprobs=geninfo.get("logprobs"),
-                )
-            yield chunk_
-            return
-
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
@@ -574,42 +539,6 @@ async def _astream(
     ) -> AsyncIterator[ChatGenerationChunk]:
         message_dicts, params = self._create_message_dicts(messages, stop)
 
-        # groq api does not support streaming with tools yet
-        if "tools" in kwargs:
-            response = await self.async_client.create(
-                messages=message_dicts, **{**params, **kwargs}
-            )
-            chat_result = self._create_chat_result(response)
-            generation = chat_result.generations[0]
-            message = cast(AIMessage, generation.message)
-            tool_call_chunks = [
-                {
-                    "name": rtc["function"].get("name"),
-                    "args": rtc["function"].get("arguments"),
-                    "id": rtc.get("id"),
-                    "index": rtc.get("index"),
-                }
-                for rtc in message.additional_kwargs.get("tool_calls", [])
-            ]
-            chunk_ = ChatGenerationChunk(
-                message=AIMessageChunk(
-                    content=message.content,
-                    additional_kwargs=message.additional_kwargs,
-                    tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
-                    usage_metadata=message.usage_metadata,
-                ),
-                generation_info=generation.generation_info,
-            )
-            if run_manager:
-                geninfo = chunk_.generation_info or {}
-                await run_manager.on_llm_new_token(
-                    chunk_.text,
-                    chunk=chunk_,
-                    logprobs=geninfo.get("logprobs"),
-                )
-            yield chunk_
-            return
-
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
diff --git a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
index caf848b56544e..741d2c847455d 100644
--- a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -1,42 +1,6 @@
 # serializer version: 1
 # name: TestGroqStandard.test_serdes[serialized]
   dict({
-    'graph': dict({
-      'edges': list([
-        dict({
-          'source': 0,
-          'target': 1,
-        }),
-        dict({
-          'source': 1,
-          'target': 2,
-        }),
-      ]),
-      'nodes': list([
-        dict({
-          'data': 'ChatGroqInput',
-          'id': 0,
-          'type': 'schema',
-        }),
-        dict({
-          'data': dict({
-            'id': list([
-              'langchain_groq',
-              'chat_models',
-              'ChatGroq',
-            ]),
-            'name': 'ChatGroq',
-          }),
-          'id': 1,
-          'type': 'runnable',
-        }),
-        dict({
-          'data': 'ChatGroqOutput',
-          'id': 2,
-          'type': 'schema',
-        }),
-      ]),
-    }),
     'id': list([
      'langchain_groq',
      'chat_models',
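
With the non-streaming fallback for "tools" removed in _stream/_astream, tool calls are expected to flow through the ordinary streaming path as tool_call_chunks on AIMessageChunk. The snippet below is a minimal usage sketch, not part of the diff: it assumes a valid GROQ_API_KEY in the environment, and the model name and get_weather tool are illustrative only.

# Hedged sketch (assumptions: GROQ_API_KEY is set; model name is illustrative).
from langchain_core.tools import tool
from langchain_groq import ChatGroq

@tool
def get_weather(city: str) -> str:
    """Return a fake weather report for a city."""
    return f"It is sunny in {city}."

llm = ChatGroq(model="llama3-8b-8192", temperature=0)  # hypothetical model choice
llm_with_tools = llm.bind_tools([get_weather])

# Accumulate streamed chunks; AIMessageChunk supports `+`, which merges
# partial tool_call_chunks into complete tool_calls on the aggregated message.
aggregate = None
for chunk in llm_with_tools.stream("What is the weather in Paris?"):
    aggregate = chunk if aggregate is None else aggregate + chunk

print(aggregate.tool_calls)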