Skip to content

Commit

Permalink
improve tools (langchain-ai#6062)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchase17 authored and Undertone0809 committed Jun 19, 2023
1 parent 883a891 commit 7f04c96
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 14 deletions.
51 changes: 39 additions & 12 deletions langchain/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,14 @@ def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {
field_name: (
model.__fields__[field_name].type_,
model.__fields__[field_name].default,
)
for field_name in field_names
if field_name in model.__fields__
}
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
fields[field_name] = (field.type_, field.field_info)
return create_model(name, **fields) # type: ignore


def get_filtered_args(
def _get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
Expand All @@ -100,15 +96,22 @@ def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydandic schema
func: Function to generate the schema from
Returns:
A pydantic model with the same arguments as the function
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__:
del inferred_model.__fields__["run_manager"]
# Pydantic adds placeholder virtual fields we need to strip
filtered_args = get_filtered_args(inferred_model, func)
valid_properties = _get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(filtered_args)
f"{model_name}Schema", inferred_model, list(valid_properties)
)


Expand Down Expand Up @@ -534,6 +537,30 @@ def from_function(
infer_schema: bool = True,
**kwargs: Any,
) -> StructuredTool:
"""Create tool from a given function.
A classmethod that helps to create a tool from a function.
Args:
func: The function from which to create a tool
name: The name of the tool. Defaults to the function name
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the result directly or as a callback
args_schema: The schema of the tool's input arguments
infer_schema: Whether to infer the schema from the function's signature
**kwargs: Additional arguments to pass to the tool
Returns:
The tool
Examples:
... code-block:: python
def add(a: int, b: int) -> int:
\"\"\"Add two numbers\"\"\"
return a + b
tool = StructuredTool.from_function(add)
tool.run(1, 2) # 3
"""
name = name or func.__name__
description = description or func.__doc__
assert (
Expand Down
38 changes: 36 additions & 2 deletions tests/unit_tests/tools/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,39 @@ def test_tool_lambda_args_schema() -> None:
assert tool.args == expected_args


def test_structured_tool_from_function_docstring() -> None:
"""Test that structured tools can be created from functions."""

def foo(bar: int, baz: str) -> str:
"""Docstring
Args:
bar: int
baz: str
"""
raise NotImplementedError()

structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
}

assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"title": "fooSchemaSchema",
"type": "object",
"required": ["bar", "baz"],
}

prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()


def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = StructuredTool.from_function(
Expand Down Expand Up @@ -577,12 +610,13 @@ def foo(bar: int, baz: str) -> str:
}

assert structured_tool.args_schema.schema() == {
"title": "fooSchemaSchema",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"title": "fooSchemaSchema",
"type": "object",
"required": ["bar", "baz"],
}

prefix = "foo(bar: int, baz: str) -> str - "
Expand Down

0 comments on commit 7f04c96

Please sign in to comment.