Skip to content

Commit

Permalink
core[patch]: support final AIMessage responses in `tool_example_to_messages` (#28267)
Browse files Browse the repository at this point in the history

We have a test
[test_structured_few_shot_examples](https://github.com/langchain-ai/langchain/blob/ad4333ca032033097c663dfe818c5c892c368bd6/libs/standard-tests/langchain_tests/integration_tests/chat_models.py#L546)
in standard integration tests that implements a version of tool-calling
few shot examples that works with ~all tested providers. The formulation
supported by ~all providers is: `human message, tool call, tool message,
AI response`.

Here we update
`langchain_core.utils.function_calling.tool_example_to_messages` to
support this formulation.

The `tool_example_to_messages` util is undocumented outside of our API
reference. IMO, if we are testing that this function works across all
providers, it can be helpful to feature it in our guides. The structured
few-shot examples we document at the moment require users to implement
this function and can be simplified.
  • Loading branch information
ccurme authored Nov 22, 2024
1 parent a5fcbe6 commit a433039
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 28 deletions.
17 changes: 14 additions & 3 deletions libs/core/langchain_core/utils/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pydantic import BaseModel
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict

from langchain_core._api import deprecated
from langchain_core._api import beta, deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
Expand Down Expand Up @@ -494,21 +494,28 @@ def convert_to_openai_tool(
return {"type": "function", "function": oai_function}


@beta()
def tool_example_to_messages(
input: str, tool_calls: list[BaseModel], tool_outputs: Optional[list[str]] = None
input: str,
tool_calls: list[BaseModel],
tool_outputs: Optional[list[str]] = None,
*,
ai_response: Optional[str] = None,
) -> list[BaseMessage]:
"""Convert an example into a list of messages that can be fed into an LLM.
This code is an adapter that converts a single example to a list of messages
that can be fed into a chat model.
The list of messages per example corresponds to:
The list of messages per example by default corresponds to:
1) HumanMessage: contains the content from which content should be extracted.
2) AIMessage: contains the extracted information from the model
3) ToolMessage: contains confirmation to the model that the model requested a tool
correctly.
If `ai_response` is specified, there will be a final AIMessage with that response.
The ToolMessage is required because some chat models are hyper-optimized for agents
rather than for an extraction use case.
Expand All @@ -519,6 +526,7 @@ def tool_example_to_messages(
tool_outputs: Optional[List[str]], a list of tool call outputs.
Does not need to be provided. If not provided, a placeholder value
will be inserted. Defaults to None.
ai_response: Optional[str], if provided, content for a final AIMessage.
Returns:
A list of messages
Expand Down Expand Up @@ -584,6 +592,9 @@ class Person(BaseModel):
)
for output, tool_call_dict in zip(tool_outputs, openai_tool_calls):
messages.append(ToolMessage(content=output, tool_call_id=tool_call_dict["id"])) # type: ignore

if ai_response:
messages.append(AIMessage(content=ai_response))
return messages


Expand Down
18 changes: 18 additions & 0 deletions libs/core/tests/unit_tests/utils/test_function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,24 @@ def test_tool_outputs() -> None:
]
assert messages[2].content == "Output1"

# Test final AI response
messages = tool_example_to_messages(
input="This is an example",
tool_calls=[
FakeCall(data="ToolCall1"),
],
tool_outputs=["Output1"],
ai_response="The output is Output1",
)
assert len(messages) == 4
assert isinstance(messages[0], HumanMessage)
assert isinstance(messages[1], AIMessage)
assert isinstance(messages[2], ToolMessage)
assert isinstance(messages[3], AIMessage)
response = messages[3]
assert response.content == "The output is Output1"
assert not response.tool_calls


@pytest.mark.parametrize("use_extension_typed_dict", [True, False])
@pytest.mark.parametrize("use_extension_annotated", [True, False])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_core.utils.function_calling import tool_example_to_messages
from pydantic import BaseModel, Field
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as FieldV1
Expand Down Expand Up @@ -857,33 +858,20 @@ def test_structured_few_shot_examples(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([my_adder_tool], tool_choice="any")
function_name = "my_adder_tool"
function_args = {"a": 1, "b": 2}
function_result = json.dumps({"result": 3})

messages_string_content = [
HumanMessage("What is 1 + 2"),
AIMessage(
"",
tool_calls=[
{
"name": function_name,
"args": function_args,
"id": "abc123",
"type": "tool_call",
},
],
),
ToolMessage(
function_result,
name=function_name,
tool_call_id="abc123",
),
AIMessage(function_result),
HumanMessage("What is 3 + 4"),
]
result_string_content = model_with_tools.invoke(messages_string_content)
assert isinstance(result_string_content, AIMessage)
tool_schema = my_adder_tool.args_schema
assert tool_schema is not None
few_shot_messages = tool_example_to_messages(
"What is 1 + 2",
[tool_schema(a=1, b=2)],
tool_outputs=[function_result],
ai_response=function_result,
)

messages = few_shot_messages + [HumanMessage("What is 3 + 4")]
result = model_with_tools.invoke(messages)
assert isinstance(result, AIMessage)

def test_image_inputs(self, model: BaseChatModel) -> None:
if not self.supports_image_inputs:
Expand Down

0 comments on commit a433039

Please sign in to comment.