diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 3aff07faecd4c..4779d26244203 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -22,7 +22,7 @@ from pydantic import BaseModel from typing_extensions import TypedDict, get_args, get_origin, is_typeddict -from langchain_core._api import deprecated +from langchain_core._api import beta, deprecated from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage from langchain_core.utils.json_schema import dereference_refs from langchain_core.utils.pydantic import is_basemodel_subclass @@ -494,21 +494,28 @@ def convert_to_openai_tool( return {"type": "function", "function": oai_function} +@beta() def tool_example_to_messages( - input: str, tool_calls: list[BaseModel], tool_outputs: Optional[list[str]] = None + input: str, + tool_calls: list[BaseModel], + tool_outputs: Optional[list[str]] = None, + *, + ai_response: Optional[str] = None, ) -> list[BaseMessage]: """Convert an example into a list of messages that can be fed into an LLM. This code is an adapter that converts a single example to a list of messages that can be fed into a chat model. - The list of messages per example corresponds to: + The list of messages per example by default corresponds to: 1) HumanMessage: contains the content from which content should be extracted. 2) AIMessage: contains the extracted information from the model 3) ToolMessage: contains confirmation to the model that the model requested a tool correctly. + If `ai_response` is specified, there will be a final AIMessage with that response. + The ToolMessage is required because some chat models are hyper-optimized for agents rather than for an extraction use case. @@ -519,6 +526,7 @@ def tool_example_to_messages( tool_outputs: Optional[List[str]], a list of tool call outputs. Does not need to be provided. If not provided, a placeholder value will be inserted. Defaults to None. + ai_response: Optional[str], if provided, content for a final AIMessage. Returns: A list of messages @@ -584,6 +592,9 @@ class Person(BaseModel): ) for output, tool_call_dict in zip(tool_outputs, openai_tool_calls): messages.append(ToolMessage(content=output, tool_call_id=tool_call_dict["id"])) # type: ignore + + if ai_response: + messages.append(AIMessage(content=ai_response)) return messages diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 4eaa3da2b19ab..ba4c50187f139 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -679,6 +679,24 @@ def test_tool_outputs() -> None: ] assert messages[2].content == "Output1" + # Test final AI response + messages = tool_example_to_messages( + input="This is an example", + tool_calls=[ + FakeCall(data="ToolCall1"), + ], + tool_outputs=["Output1"], + ai_response="The output is Output1", + ) + assert len(messages) == 4 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert isinstance(messages[2], ToolMessage) + assert isinstance(messages[3], AIMessage) + response = messages[3] + assert response.content == "The output is Output1" + assert not response.tool_calls + @pytest.mark.parametrize("use_extension_typed_dict", [True, False]) @pytest.mark.parametrize("use_extension_annotated", [True, False]) diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py index 61f45e63b9bee..5ef6f99c82148 100644 --- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py @@ -17,6 +17,7 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import tool +from langchain_core.utils.function_calling import tool_example_to_messages from pydantic import BaseModel, Field from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import Field as FieldV1 @@ -857,33 +858,20 @@ def test_structured_few_shot_examples(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") model_with_tools = model.bind_tools([my_adder_tool], tool_choice="any") - function_name = "my_adder_tool" - function_args = {"a": 1, "b": 2} function_result = json.dumps({"result": 3}) - messages_string_content = [ - HumanMessage("What is 1 + 2"), - AIMessage( - "", - tool_calls=[ - { - "name": function_name, - "args": function_args, - "id": "abc123", - "type": "tool_call", - }, - ], - ), - ToolMessage( - function_result, - name=function_name, - tool_call_id="abc123", - ), - AIMessage(function_result), - HumanMessage("What is 3 + 4"), - ] - result_string_content = model_with_tools.invoke(messages_string_content) - assert isinstance(result_string_content, AIMessage) + tool_schema = my_adder_tool.args_schema + assert tool_schema is not None + few_shot_messages = tool_example_to_messages( + "What is 1 + 2", + [tool_schema(a=1, b=2)], + tool_outputs=[function_result], + ai_response=function_result, + ) + + messages = few_shot_messages + [HumanMessage("What is 3 + 4")] + result = model_with_tools.invoke(messages) + assert isinstance(result, AIMessage) def test_image_inputs(self, model: BaseChatModel) -> None: if not self.supports_image_inputs: