From 25b93cc4c0427fddd9878f5ca217bc805a8ea365 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 31 Jul 2024 16:42:38 -0700 Subject: [PATCH] core[patch]: stringify tool non-content blocks (#24626) This is a slightly breaking bugfix. It shouldn't cause many issues, since no model can handle ToolMessage.content that is neither a string nor a list of content blocks anyway. --- libs/core/langchain_core/tools.py | 29 ++++++++++----- libs/core/tests/unit_tests/test_tools.py | 46 ++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 2bcf88404fad0..4cf12f050e255 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -1485,14 +1485,7 @@ def _format_output( content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str ) -> Union[ToolMessage, Any]: if tool_call_id: - # NOTE: This will fail to stringify lists which aren't actually content blocks - # but whose first element happens to be a string or dict. Tools should avoid - # returning such contents. 
- if not isinstance(content, str) and not ( - isinstance(content, list) - and content - and isinstance(content[0], (str, dict)) - ): + if not _is_message_content_type(content): content = _stringify(content) return ToolMessage( content, @@ -1505,6 +1498,26 @@ def _format_output( return content +def _is_message_content_type(obj: Any) -> bool: + """Check for OpenAI or Anthropic format tool message content.""" + if isinstance(obj, str): + return True + elif isinstance(obj, list) and all(_is_message_content_block(e) for e in obj): + return True + else: + return False + + +def _is_message_content_block(obj: Any) -> bool: + """Check for OpenAI or Anthropic format tool message content blocks.""" + if isinstance(obj, str): + return True + elif isinstance(obj, dict): + return obj.get("type", None) in ("text", "image_url", "image", "json") + else: + return False + + def _stringify(content: Any) -> str: try: return json.dumps(content) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 1ba9e5961ab30..f2a040fc916df 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -32,6 +32,8 @@ StructuredTool, Tool, ToolException, + _is_message_content_block, + _is_message_content_type, tool, ) from langchain_core.utils.function_calling import convert_to_openai_function @@ -1623,3 +1625,47 @@ def foo(a: int, b: str) -> str: "title": pydantic_model.__name__, "type": "object", } + + +valid_tool_result_blocks = [ + "foo", + {"type": "text", "text": "foo"}, + {"type": "text", "blah": "foo"}, # note, only 'type' key is currently checked + {"type": "image_url", "image_url": {}}, # openai format + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "123", + }, + }, # anthropic format + {"type": "json", "json": {}}, # bedrock format +] +invalid_tool_result_blocks = [ + {"text": "foo"}, # missing type + {"results": "foo"}, # not content blocks +] + + 
+@pytest.mark.parametrize( + ("obj", "expected"), + [ + *([[block, True] for block in valid_tool_result_blocks]), + *([[block, False] for block in invalid_tool_result_blocks]), + ], +) +def test__is_message_content_block(obj: Any, expected: bool) -> None: + assert _is_message_content_block(obj) is expected + + +@pytest.mark.parametrize( + ("obj", "expected"), + [ + ["foo", True], + [valid_tool_result_blocks, True], + [invalid_tool_result_blocks, False], + ], +) +def test__is_message_content_type(obj: Any, expected: bool) -> None: + assert _is_message_content_type(obj) is expected