From c4ae94c10b915ccd852b67984a2c2c15924c3306 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Wed, 20 Nov 2024 15:49:50 -0800 Subject: [PATCH] Passing Functions as Tools (#321) * Functions can now be passed as tools --- ollama/_client.py | 62 ++++++- ollama/_types.py | 64 ++++---- ollama/_utils.py | 87 ++++++++++ tests/test_client.py | 56 ++++++- tests/test_type_serialization.py | 45 +++++- tests/test_utils.py | 270 +++++++++++++++++++++++++++++++ 6 files changed, 545 insertions(+), 39 deletions(-) create mode 100644 ollama/_utils.py create mode 100644 tests/test_utils.py diff --git a/ollama/_client.py b/ollama/_client.py index 095d901d..a8a19d35 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -10,6 +10,7 @@ from typing import ( Any, + Callable, Literal, Mapping, Optional, @@ -22,6 +23,9 @@ import sys + +from ollama._utils import convert_function_to_tool + if sys.version_info < (3, 9): from typing import Iterator, AsyncIterator else: @@ -284,7 +288,7 @@ def chat( model: str = '', messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, *, - tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: bool = False, format: Optional[Literal['', 'json']] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, @@ -293,6 +297,30 @@ def chat( """ Create a chat response using the requested model. + Args: + tools: + A JSON schema as a dict, an Ollama Tool or a Python Function. + Python functions need to follow Google style docstrings to be converted to an Ollama Tool. + For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings + stream: Whether to stream the response. + format: The format of the response. + + Example: + def add_two_numbers(a: int, b: int) -> int: + ''' + Add two numbers together. + + Args: + a: First number to add + b: Second number to add + + Returns: + int: The sum of a and b + ''' + return a + b + + client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...]) + Raises `RequestError` if a model is not provided. Raises `ResponseError` if the request could not be fulfilled. @@ -750,7 +778,7 @@ async def chat( model: str = '', messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, *, - tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: Literal[True] = True, format: Optional[Literal['', 'json']] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, @@ -771,6 +799,30 @@ async def chat( """ Create a chat response using the requested model. + Args: + tools: + A JSON schema as a dict, an Ollama Tool or a Python Function. + Python functions need to follow Google style docstrings to be converted to an Ollama Tool. + For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings + stream: Whether to stream the response. + format: The format of the response. + + Example: + def add_two_numbers(a: int, b: int) -> int: + ''' + Add two numbers together. + + Args: + a: First number to add + b: Second number to add + + Returns: + int: The sum of a and b + ''' + return a + b + + await client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...]) + Raises `RequestError` if a model is not provided. Raises `ResponseError` if the request could not be fulfilled. @@ -1075,9 +1127,9 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message] ) -def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]]) -> Iterator[Tool]: - for tool in tools or []: - yield Tool.model_validate(tool) +def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]: + for unprocessed_tool in tools or []: + yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool) def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]: diff --git a/ollama/_types.py b/ollama/_types.py index 968099dc..bcf88969 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -1,26 +1,18 @@ import json -from base64 import b64encode +from base64 import b64decode, b64encode from pathlib import Path from datetime import datetime -from typing import ( - Any, - Literal, - Mapping, - Optional, - Sequence, - Union, -) -from typing_extensions import Annotated +from typing import Any, Mapping, Optional, Union, Sequence + +from typing_extensions import Annotated, Literal from pydantic import ( BaseModel, ByteSize, + ConfigDict, Field, - FilePath, - Base64Str, model_serializer, ) -from pydantic.json_schema import JsonSchemaValue class SubscriptableBaseModel(BaseModel): @@ -95,16 +87,26 @@ class BaseGenerateRequest(BaseStreamableRequest): class Image(BaseModel): - value: Union[FilePath, Base64Str, bytes] + value: Union[str, bytes, Path] - # This overloads the `model_dump` method and returns values depending on the type of the `value` field @model_serializer def serialize_model(self): - if isinstance(self.value, Path): - return b64encode(self.value.read_bytes()).decode() - elif isinstance(self.value, bytes): - return b64encode(self.value).decode() - return self.value + if isinstance(self.value, (Path, bytes)): + return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode() + + if isinstance(self.value, str): + if Path(self.value).exists(): + return b64encode(Path(self.value).read_bytes()).decode() + + if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'): + raise ValueError(f'File {self.value} does not exist') + + try: + # Try to decode to check if it's already base64 + b64decode(self.value) + return self.value + except Exception: + raise ValueError('Invalid image data, expected base64 string or path to image file') from Exception class GenerateRequest(BaseGenerateRequest): @@ -222,20 +224,27 @@ class Function(SubscriptableBaseModel): class Tool(SubscriptableBaseModel): - type: Literal['function'] = 'function' + type: Optional[Literal['function']] = 'function' class Function(SubscriptableBaseModel): - name: str - description: str + name: Optional[str] = None + description: Optional[str] = None class Parameters(SubscriptableBaseModel): - type: str + type: Optional[Literal['object']] = 'object' required: Optional[Sequence[str]] = None - properties: Optional[JsonSchemaValue] = None - parameters: Parameters + class Property(SubscriptableBaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + type: Optional[str] = None + description: Optional[str] = None + + properties: Optional[Mapping[str, Property]] = None - function: Function + parameters: Optional[Parameters] = None + + function: Optional[Function] = None class ChatRequest(BaseGenerateRequest): @@ -335,6 +344,7 @@ class ModelDetails(SubscriptableBaseModel): class ListResponse(SubscriptableBaseModel): class Model(SubscriptableBaseModel): + model: Optional[str] = None modified_at: Optional[datetime] = None digest: Optional[str] = None size: Optional[ByteSize] = None diff --git a/ollama/_utils.py b/ollama/_utils.py new file mode 100644 index 00000000..c0b67c99 --- /dev/null +++ b/ollama/_utils.py @@ -0,0 +1,87 @@ +from __future__ import annotations +from collections import defaultdict +import inspect +from typing import Callable, Union +import re + +import pydantic +from ollama._types import Tool + + +def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: + parsed_docstring = defaultdict(str) + if not doc_string: + return parsed_docstring + + key = hash(doc_string) + for line in doc_string.splitlines(): + lowered_line = line.lower().strip() + if lowered_line.startswith('args:'): + key = 'args' + elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'): + key = '_' + + else: + # maybe change to a list and join later + parsed_docstring[key] += f'{line.strip()}\n' + + last_key = None + for line in parsed_docstring['args'].splitlines(): + line = line.strip() + if ':' in line: + # Split the line on either: + # 1. A parenthetical expression like (integer) - captured in group 1 + # 2. A colon : + # Followed by optional whitespace. Only split on first occurrence. + parts = re.split(r'(?:\(([^)]*)\)|:)\s*', line, maxsplit=1) + + arg_name = parts[0].strip() + last_key = arg_name + + # Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon + arg_description = parts[-1].strip() + if len(parts) > 2 and parts[1]: # Has parenthetical content + arg_description = parts[-1].split(':', 1)[-1].strip() + + parsed_docstring[last_key] = arg_description + + elif last_key and line: + parsed_docstring[last_key] += ' ' + line + + return parsed_docstring + + +def convert_function_to_tool(func: Callable) -> Tool: + doc_string_hash = hash(inspect.getdoc(func)) + parsed_docstring = _parse_docstring(inspect.getdoc(func)) + schema = type( + func.__name__, + (pydantic.BaseModel,), + { + '__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()}, + '__signature__': inspect.signature(func), + '__doc__': parsed_docstring[doc_string_hash], + }, + ).model_json_schema() + + for k, v in schema.get('properties', {}).items(): + # If type is missing, the default is string + types = {t.get('type', 'string') for t in v.get('anyOf')} if 'anyOf' in v else {v.get('type', 'string')} + if 'null' in types: + schema['required'].remove(k) + types.discard('null') + + schema['properties'][k] = { + 'description': parsed_docstring[k], + 'type': ', '.join(types), + } + + tool = Tool( + function=Tool.Function( + name=func.__name__, + description=schema.get('description', ''), + parameters=Tool.Function.Parameters(**schema), + ) + ) + + return Tool.model_validate(tool) diff --git a/tests/test_client.py b/tests/test_client.py index 1dd22925..fbd01bda 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,7 @@ import os import io import json +from pydantic import ValidationError import pytest import tempfile from pathlib import Path @@ -8,7 +9,7 @@ from werkzeug.wrappers import Request, Response from PIL import Image -from ollama._client import Client, AsyncClient +from ollama._client import Client, AsyncClient, _copy_tools class PrefixPattern(URIPattern): @@ -982,3 +983,56 @@ def test_headers(): ) assert client._client.headers['x-custom'] == 'value' assert client._client.headers['content-type'] == 'application/json' + + +def test_copy_tools(): + def func1(x: int) -> str: + """Simple function 1. + Args: + x (integer): A number + """ + pass + + def func2(y: str) -> int: + """Simple function 2. + Args: + y (string): A string + """ + pass + + # Test with list of functions + tools = list(_copy_tools([func1, func2])) + assert len(tools) == 2 + assert tools[0].function.name == 'func1' + assert tools[1].function.name == 'func2' + + # Test with empty input + assert list(_copy_tools()) == [] + assert list(_copy_tools(None)) == [] + assert list(_copy_tools([])) == [] + + # Test with mix of functions and tool dicts + tool_dict = { + 'type': 'function', + 'function': { + 'name': 'test', + 'description': 'Test function', + 'parameters': { + 'type': 'object', + 'properties': {'x': {'type': 'string', 'description': 'A string'}}, + 'required': ['x'], + }, + }, + } + + tools = list(_copy_tools([func1, tool_dict])) + assert len(tools) == 2 + assert tools[0].function.name == 'func1' + assert tools[1].function.name == 'test' + + +def test_tool_validation(): + # Raises ValidationError when used as it is a generator + with pytest.raises(ValidationError): + invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}} + list(_copy_tools([invalid_tool])) diff --git a/tests/test_type_serialization.py b/tests/test_type_serialization.py index f127b03f..e3e8268c 100644 --- a/tests/test_type_serialization.py +++ b/tests/test_type_serialization.py @@ -1,15 +1,48 @@ -from base64 import b64decode, b64encode +from base64 import b64encode +from pathlib import Path +import pytest from ollama._types import Image +import tempfile -def test_image_serialization(): - # Test bytes serialization +def test_image_serialization_bytes(): image_bytes = b'test image bytes' + encoded_string = b64encode(image_bytes).decode() img = Image(value=image_bytes) - assert img.model_dump() == b64encode(image_bytes).decode() + assert img.model_dump() == encoded_string - # Test base64 string serialization + +def test_image_serialization_base64_string(): b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n' img = Image(value=b64_str) - assert img.model_dump() == b64decode(b64_str).decode() + assert img.model_dump() == b64_str # Should return as-is if valid base64 + + +def test_image_serialization_plain_string(): + img = Image(value='not a path or base64') + assert img.model_dump() == 'not a path or base64' # Should return as-is + + +def test_image_serialization_path(): + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(b'test file content') + temp_file.flush() + img = Image(value=Path(temp_file.name)) + assert img.model_dump() == b64encode(b'test file content').decode() + + +def test_image_serialization_string_path(): + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(b'test file content') + temp_file.flush() + img = Image(value=temp_file.name) + assert img.model_dump() == b64encode(b'test file content').decode() + + with pytest.raises(ValueError): + img = Image(value='some_path/that/does/not/exist.png') + img.model_dump() + + with pytest.raises(ValueError): + img = Image(value='not an image') + img.model_dump() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..9fb1e3b2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,270 @@ +import json +import sys +from typing import Dict, List, Mapping, Sequence, Set, Tuple, Union + + +from ollama._utils import convert_function_to_tool + + +def test_function_to_tool_conversion(): + def add_numbers(x: int, y: Union[int, None] = None) -> int: + """Add two numbers together. + args: + x (integer): The first number + y (integer, optional): The second number + + Returns: + integer: The sum of x and y + """ + return x + y + + tool = convert_function_to_tool(add_numbers).model_dump() + + assert tool['type'] == 'function' + assert tool['function']['name'] == 'add_numbers' + assert tool['function']['description'] == 'Add two numbers together.' + assert tool['function']['parameters']['type'] == 'object' + assert tool['function']['parameters']['properties']['x']['type'] == 'integer' + assert tool['function']['parameters']['properties']['x']['description'] == 'The first number' + assert tool['function']['parameters']['required'] == ['x'] + + +def test_function_with_no_args(): + def simple_func(): + """ + A simple function with no arguments. + Args: + None + Returns: + None + """ + pass + + tool = convert_function_to_tool(simple_func).model_dump() + assert tool['function']['name'] == 'simple_func' + assert tool['function']['description'] == 'A simple function with no arguments.' + assert tool['function']['parameters']['properties'] == {} + + +def test_function_with_all_types(): + if sys.version_info >= (3, 10): + + def all_types( + x: int, + y: str, + z: list[int], + w: dict[str, int], + v: int | str | None, + ) -> int | dict[str, int] | str | list[int] | None: + """ + A function with all types. + Args: + x (integer): The first number + y (string): The second number + z (array): The third number + w (object): The fourth number + v (integer | string | None): The fifth number + """ + pass + else: + + def all_types( + x: int, + y: str, + z: Sequence, + w: Mapping[str, int], + d: Dict[str, int], + s: Set[int], + t: Tuple[int, str], + l: List[int], # noqa: E741 + o: Union[int, None], + ) -> Union[Mapping[str, int], str, None]: + """ + A function with all types. + Args: + x (integer): The first number + y (string): The second number + z (array): The third number + w (object): The fourth number + d (object): The fifth number + s (array): The sixth number + t (array): The seventh number + l (array): The eighth number + o (integer | None): The ninth number + """ + pass + + tool_json = convert_function_to_tool(all_types).model_dump_json() + tool = json.loads(tool_json) + assert tool['function']['parameters']['properties']['x']['type'] == 'integer' + assert tool['function']['parameters']['properties']['y']['type'] == 'string' + + if sys.version_info >= (3, 10): + assert tool['function']['parameters']['properties']['z']['type'] == 'array' + assert tool['function']['parameters']['properties']['w']['type'] == 'object' + assert set(x.strip().strip("'") for x in tool['function']['parameters']['properties']['v']['type'].removeprefix('[').removesuffix(']').split(',')) == {'string', 'integer'} + assert tool['function']['parameters']['properties']['v']['type'] != 'null' + assert tool['function']['parameters']['required'] == ['x', 'y', 'z', 'w'] + else: + assert tool['function']['parameters']['properties']['z']['type'] == 'array' + assert tool['function']['parameters']['properties']['w']['type'] == 'object' + assert tool['function']['parameters']['properties']['d']['type'] == 'object' + assert tool['function']['parameters']['properties']['s']['type'] == 'array' + assert tool['function']['parameters']['properties']['t']['type'] == 'array' + assert tool['function']['parameters']['properties']['l']['type'] == 'array' + assert tool['function']['parameters']['properties']['o']['type'] == 'integer' + assert tool['function']['parameters']['properties']['o']['type'] != 'null' + assert tool['function']['parameters']['required'] == ['x', 'y', 'z', 'w', 'd', 's', 't', 'l'] + + +def test_function_docstring_parsing(): + from typing import List, Dict, Any + + def func_with_complex_docs(x: int, y: List[str]) -> Dict[str, Any]: + """ + Test function with complex docstring. + + Args: + x (integer): A number + with multiple lines + y (array of string): A list + with multiple lines + + Returns: + object: A dictionary + with multiple lines + """ + pass + + tool = convert_function_to_tool(func_with_complex_docs).model_dump() + assert tool['function']['description'] == 'Test function with complex docstring.' + assert tool['function']['parameters']['properties']['x']['description'] == 'A number with multiple lines' + assert tool['function']['parameters']['properties']['y']['description'] == 'A list with multiple lines' + + +def test_skewed_docstring_parsing(): + def add_two_numbers(x: int, y: int) -> int: + """ + Add two numbers together. + Args: + x (integer): : The first number + + + + + y (integer ): The second number + Returns: + integer: The sum of x and y + """ + pass + + tool = convert_function_to_tool(add_two_numbers).model_dump() + assert tool['function']['parameters']['properties']['x']['description'] == ': The first number' + assert tool['function']['parameters']['properties']['y']['description'] == 'The second number' + + +def test_function_with_no_docstring(): + def no_docstring(): + pass + + def no_docstring_with_args(x: int, y: int): + pass + + tool = convert_function_to_tool(no_docstring).model_dump() + assert tool['function']['description'] == '' + + tool = convert_function_to_tool(no_docstring_with_args).model_dump() + assert tool['function']['description'] == '' + assert tool['function']['parameters']['properties']['x']['description'] == '' + assert tool['function']['parameters']['properties']['y']['description'] == '' + + +def test_function_with_only_description(): + def only_description(): + """ + A function with only a description. + """ + pass + + tool = convert_function_to_tool(only_description).model_dump() + assert tool['function']['description'] == 'A function with only a description.' + assert tool['function']['parameters'] == {'type': 'object', 'properties': {}, 'required': None} + + def only_description_with_args(x: int, y: int): + """ + A function with only a description. + """ + pass + + tool = convert_function_to_tool(only_description_with_args).model_dump() + assert tool['function']['description'] == 'A function with only a description.' + assert tool['function']['parameters'] == { + 'type': 'object', + 'properties': { + 'x': {'type': 'integer', 'description': ''}, + 'y': {'type': 'integer', 'description': ''}, + }, + 'required': ['x', 'y'], + } + + +def test_function_with_yields(): + def function_with_yields(x: int, y: int): + """ + A function with yields section. + + Args: + x: the first number + y: the second number + + Yields: + The sum of x and y + """ + pass + + tool = convert_function_to_tool(function_with_yields).model_dump() + assert tool['function']['description'] == 'A function with yields section.' + assert tool['function']['parameters']['properties']['x']['description'] == 'the first number' + assert tool['function']['parameters']['properties']['y']['description'] == 'the second number' + + +def test_function_with_no_types(): + def no_types(a, b): + """ + A function with no types. + """ + pass + + tool = convert_function_to_tool(no_types).model_dump() + assert tool['function']['parameters']['properties']['a']['type'] == 'string' + assert tool['function']['parameters']['properties']['b']['type'] == 'string' + + +def test_function_with_parentheses(): + def func_with_parentheses(a: int, b: int) -> int: + """ + A function with parentheses. + Args: + a: First (:thing) number to add + b: Second number to add + Returns: + int: The sum of a and b + """ + pass + + def func_with_parentheses_and_args(a: int, b: int): + """ + A function with parentheses and args. + Args: + a(integer) : First (:thing) number to add + b(integer) :Second number to add + """ + pass + + tool = convert_function_to_tool(func_with_parentheses).model_dump() + assert tool['function']['parameters']['properties']['a']['description'] == 'First (:thing) number to add' + assert tool['function']['parameters']['properties']['b']['description'] == 'Second number to add' + + tool = convert_function_to_tool(func_with_parentheses_and_args).model_dump() + assert tool['function']['parameters']['properties']['a']['description'] == 'First (:thing) number to add' + assert tool['function']['parameters']['properties']['b']['description'] == 'Second number to add'