Skip to content

Commit

Permalink
Improvements to notebook.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitprasad15 committed Jan 4, 2025
1 parent 80a8593 commit 44e0dd6
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 58 deletions.
33 changes: 29 additions & 4 deletions aisuite/utils/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pydantic import BaseModel, create_model, Field, ValidationError
import inspect
import json
from docstring_parser import parse


class ToolManager:
Expand Down Expand Up @@ -82,6 +83,24 @@ def _convert_to_tool_spec(
},
}

def _extract_param_descriptions(self, func: Callable) -> dict[str, str]:
    """Extract per-parameter descriptions from a function's docstring.

    The docstring is parsed with ``docstring_parser.parse``, which
    understands the common Google/NumPy/reST parameter sections.

    Args:
        func: The callable whose docstring will be inspected.

    Returns:
        Dictionary mapping each documented parameter name to its
        description (empty string when a parameter has no description).
    """
    # inspect.getdoc normalizes indentation and returns None when no
    # docstring exists, so fall back to the empty string before parsing.
    parsed = parse(inspect.getdoc(func) or "")
    return {p.arg_name: p.description or "" for p in parsed.params}

def _infer_from_signature(
self, func: Callable
) -> tuple[Dict[str, Any], Type[BaseModel]]:
Expand All @@ -90,8 +109,9 @@ def _infer_from_signature(
fields = {}
required_fields = []

# Get function's docstring
docstring = inspect.getdoc(func) or " "
# Get function's docstring and parse parameter descriptions
param_descriptions = self._extract_param_descriptions(func)
docstring = inspect.getdoc(func) or ""

for param_name, param in signature.parameters.items():
# Check if a type annotation is missing
Expand All @@ -102,11 +122,16 @@ def _infer_from_signature(

# Determine field type and optionality
param_type = param.annotation
description = param_descriptions.get(param_name, "")

if param.default == inspect._empty:
fields[param_name] = (param_type, ...)
fields[param_name] = (param_type, Field(..., description=description))
required_fields.append(param_name)
else:
fields[param_name] = (param_type, Field(default=param.default))
fields[param_name] = (
param_type,
Field(default=param.default, description=description),
)

# Dynamically create a Pydantic model based on inferred fields
param_model = create_model(f"{func.__name__.capitalize()}Params", **fields)
Expand Down
111 changes: 57 additions & 54 deletions examples/simple_tool_calling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"import json\n",
"import sys\n",
Expand All @@ -29,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -41,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -59,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -74,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -83,62 +72,57 @@
"client = Client()\n",
"tool_manager = ToolManager([get_current_temperature, is_it_raining])\n",
"\n",
"messages = [{\"role\": \"user\", \"content\": \"What is the current temperature in San Francisco in Celsius?\"}]\n",
"messages = [{\"role\": \"user\", \"content\": \"What is the current temperature in San Francisco in Celsius, and is it raining?\"}]\n",
"\n",
"response = client.chat.completions.create(\n",
" model=model, messages=messages, tools=tool_manager.tools())"
" model=model, messages=messages, tools=tool_manager.tools()) # tool_choice=\"auto\", parallel_tool_calls=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "'module' object is not callable. Did you mean: 'pprint.pprint(...)'?",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpprint\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mpprint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mTypeError\u001b[0m: 'module' object is not callable. Did you mean: 'pprint.pprint(...)'?"
]
}
],
"outputs": [],
"source": [
"import pprint\n",
"pprint(response)"
"from pprint import pprint\n",
"pprint(response.choices[0].message)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the function result, the current temperature in San Francisco is 72 degrees Celsius.\n",
"\n",
"However, I must point out that this temperature seems unusually high for San Francisco, especially in Celsius. A temperature of 72°C would be equivalent to about 161.6°F, which is extremely hot and not typical for San Francisco's climate. \n",
"\n",
"It's possible there might be an error in the data or in how the function is interpreting or reporting the temperature. In a normal situation, we would expect San Francisco's temperature to be much lower, typically between 10°C to 25°C (50°F to 77°F) depending on the time of year.\n",
"\n",
"If you'd like, we can double-check this information or try to get the temperature in Fahrenheit to compare. Would you like me to do that?\n"
]
}
],
"outputs": [],
"source": [
"if response.choices[0].message.tool_calls:\n",
" tool_results, result_as_message = tool_manager.execute_tool(response.choices[0].message.tool_calls)\n",
" messages.append(response.choices[0].message) # Model's function call message\n",
" messages.append(result_as_message[0])\n",
" messages.extend(result_as_message)\n",
"\n",
" final_response = client.chat.completions.create(\n",
" response = client.chat.completions.create(\n",
" model=model, messages=messages, tools=tool_manager.tools())\n",
" print(final_response.choices[0].message.content)"
" print(response.choices[0].message.content)\n",
" pprint(response.choices[0].message)\n",
"else:\n",
" pprint(response.choices[0].message)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if response.choices[0].message.tool_calls:\n",
" tool_results, result_as_message = tool_manager.execute_tool(response.choices[0].message.tool_calls)\n",
" messages.append(response.choices[0].message) # Model's function call message\n",
" messages.extend(result_as_message)\n",
"\n",
" response = client.chat.completions.create(\n",
" model=model, messages=messages, tools=tool_manager.tools())\n",
" print(response.choices[0].message.content)\n",
"else:\n",
" pprint(response.choices[0].message)"
]
},
{
Expand All @@ -154,6 +138,25 @@
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pprint(tool_manager.tools())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from aisuite import Client, ToolManager\n",
"ToolManager([get_current_temperature, is_it_raining]).tools()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ vertexai = { version = "^1.63.0", optional = true }
groq = { version = "^0.9.0", optional = true }
mistralai = { version = "^1.0.3", optional = true }
openai = { version = "^1.35.8", optional = true }
docstring-parser = { version = "^0.14.0", optional = true }

# Optional dependencies for different providers
[tool.poetry.extras]
Expand Down

0 comments on commit 44e0dd6

Please sign in to comment.