Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[patch]: stringify tool non-content blocks #24626

Merged
merged 5 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions libs/core/langchain_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,14 +1485,7 @@ def _format_output(
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
) -> Union[ToolMessage, Any]:
if tool_call_id:
# NOTE: This will fail to stringify lists which aren't actually content blocks
# but whose first element happens to be a string or dict. Tools should avoid
# returning such contents.
if not isinstance(content, str) and not (
isinstance(content, list)
and content
and isinstance(content[0], (str, dict))
):
if not _is_message_content_type(content):
content = _stringify(content)
return ToolMessage(
content,
Expand All @@ -1505,6 +1498,26 @@ def _format_output(
return content


def _is_message_content_type(obj: Any) -> bool:
"""Check for OpenAI or Anthropic format tool message content."""
if isinstance(obj, str):
return True
elif isinstance(obj, list) and all(_is_message_content_block(e) for e in obj):
return True
else:
return False


def _is_message_content_block(obj: Any) -> bool:
"""Check for OpenAI or Anthropic format tool message content blocks."""
if isinstance(obj, str):
return True
elif isinstance(obj, dict):
return obj.get("type", None) in ("text", "image_url", "image", "json")
else:
return False


def _stringify(content: Any) -> str:
try:
return json.dumps(content)
Expand Down
46 changes: 46 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
StructuredTool,
Tool,
ToolException,
_is_message_content_block,
_is_message_content_type,
tool,
)
from langchain_core.utils.function_calling import convert_to_openai_function
Expand Down Expand Up @@ -1623,3 +1625,47 @@ def foo(a: int, b: str) -> str:
"title": pydantic_model.__name__,
"type": "object",
}


valid_tool_result_blocks = [
"foo",
{"type": "text", "text": "foo"},
{"type": "text", "blah": "foo"}, # note, only 'type' key is currently checked
{"type": "image_url", "image_url": {}}, # openai format
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "123",
},
}, # anthropic format
{"type": "json", "json": {}}, # bedrock format
]
invalid_tool_result_blocks = [
{"text": "foo"}, # missing type
{"results": "foo"}, # not content blocks
]


@pytest.mark.parametrize(
("obj", "expected"),
[
*([[block, True] for block in valid_tool_result_blocks]),
*([[block, False] for block in invalid_tool_result_blocks]),
],
)
def test__is_message_content_block(obj: Any, expected: bool) -> None:
assert _is_message_content_block(obj) is expected


@pytest.mark.parametrize(
("obj", "expected"),
[
["foo", True],
[valid_tool_result_blocks, True],
[invalid_tool_result_blocks, False],
],
)
def test__is_message_content_type(obj: Any, expected: bool) -> None:
assert _is_message_content_type(obj) is expected
Loading