Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add on_tool_start/end callbacks #1879

Merged
merged 1 commit into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dspy/predict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
from .multi_chain_comparison import MultiChainComparison
from .predict import Predict
from .program_of_thought import ProgramOfThought
from .react import ReAct
from .react import ReAct, Tool
from .retry import Retry
from .parallel import Parallel
2 changes: 2 additions & 0 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature
from dspy.adapters.json_adapter import get_annotation_name
from dspy.utils.callback import with_callbacks
from typing import Callable, Any, get_type_hints, get_origin, Literal

class Tool:
Expand All @@ -19,6 +20,7 @@ def __init__(self, func: Callable, name: str = None, desc: str = None, args: dic
for k, v in (args or get_type_hints(annotations_func)).items() if k != 'return'
}

@with_callbacks
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

Expand Down
39 changes: 39 additions & 0 deletions dspy/utils/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,38 @@ def on_adapter_parse_end(
"""
pass

def on_tool_start(
self,
call_id: str,
instance: Any,
inputs: Dict[str, Any],
):
"""A handler triggered when a tool is called.

Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
instance: The Tool instance.
inputs: The inputs to the Tool's __call__ method. Each arguments is stored as
a key-value pair in a dictionary.
"""
pass

def on_tool_end(
self,
call_id: str,
outputs: Optional[Dict[str, Any]],
exception: Optional[Exception] = None,
):
"""A handler triggered after a tool is executed.

Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
outputs: The outputs of the Tool's __call__ method. If the method is interrupted by
an exception, this will be None.
exception: If an exception is raised during the execution, it will be stored here.
"""
pass


def with_callbacks(fn):
@functools.wraps(fn)
Expand Down Expand Up @@ -256,6 +288,9 @@ def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) -
else:
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")

if isinstance(instance, dspy.Tool):
return callback.on_tool_start

# We treat everything else as a module.
return callback.on_module_start

Expand All @@ -272,5 +307,9 @@ def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) ->
return callback.on_adapter_parse_end
else:
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")

if isinstance(instance, dspy.Tool):
return callback.on_tool_end

# We treat everything else as a module.
return callback.on_module_end
41 changes: 41 additions & 0 deletions tests/callback/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def on_adapter_parse_start(self, call_id, instance, inputs):
def on_adapter_parse_end(self, call_id, outputs, exception):
self.calls.append({"handler": "on_adapter_parse_end", "outputs": outputs, "exception": exception})

def on_tool_start(self, call_id, instance, inputs):
self.calls.append({"handler": "on_tool_start", "instance": instance, "inputs": inputs})

def on_tool_end(self, call_id, outputs, exception):
self.calls.append({"handler": "on_tool_end", "outputs": outputs, "exception": exception})


@pytest.mark.parametrize(
("args", "kwargs"),
Expand Down Expand Up @@ -181,6 +187,41 @@ def test_callback_complex_module():
]


def test_tool_calls():
callback = MyCallback()
dspy.settings.configure(callbacks=[callback])

def tool_1(query: str) -> str:
"""A dummy tool function."""
return "result 1"

def tool_2(query: str) -> str:
"""Another dummy tool function."""
return "result 2"

class MyModule(dspy.Module):
def __init__(self):
self.tools = [dspy.Tool(tool_1), dspy.Tool(tool_2)]

def forward(self, query: str) -> str:
query = self.tools[0](query)
return self.tools[1](query)

module = MyModule()
result = module("query")

assert result == "result 2"
assert len(callback.calls) == 6
assert [call["handler"] for call in callback.calls] == [
"on_module_start",
"on_tool_start",
"on_tool_end",
"on_tool_start",
"on_tool_end",
"on_module_end",
]


def test_active_id():
# Test the call ID is generated and handled properly
class CustomCallback(BaseCallback):
Expand Down
Loading