Skip to content

Commit

Permalink
Passing Functions as Tools (ollama#321)
Browse files Browse the repository at this point in the history
* Functions can now be passed as tools
  • Loading branch information
ParthSareen authored and pressdarling committed Dec 1, 2024
1 parent 2105f4e commit c4ae94c
Show file tree
Hide file tree
Showing 6 changed files with 545 additions and 39 deletions.
62 changes: 57 additions & 5 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from typing import (
Any,
Callable,
Literal,
Mapping,
Optional,
Expand All @@ -22,6 +23,9 @@

import sys


from ollama._utils import convert_function_to_tool

if sys.version_info < (3, 9):
from typing import Iterator, AsyncIterator
else:
Expand Down Expand Up @@ -284,7 +288,7 @@ def chat(
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
Expand All @@ -293,6 +297,30 @@ def chat(
"""
Create a chat response using the requested model.
Args:
tools:
A JSON schema as a dict, an Ollama Tool or a Python Function.
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
stream: Whether to stream the response.
format: The format of the response.
Example:
def add_two_numbers(a: int, b: int) -> int:
'''
Add two numbers together.
Args:
a: First number to add
b: Second number to add
Returns:
int: The sum of a and b
'''
return a + b
client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Expand Down Expand Up @@ -750,7 +778,7 @@ async def chat(
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
Expand All @@ -771,6 +799,30 @@ async def chat(
"""
Create a chat response using the requested model.
Args:
tools:
A JSON schema as a dict, an Ollama Tool or a Python Function.
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
stream: Whether to stream the response.
format: The format of the response.
Example:
def add_two_numbers(a: int, b: int) -> int:
'''
Add two numbers together.
Args:
a: First number to add
b: Second number to add
Returns:
int: The sum of a and b
'''
return a + b
await client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Expand Down Expand Up @@ -1075,9 +1127,9 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
)


def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]]) -> Iterator[Tool]:
for tool in tools or []:
yield Tool.model_validate(tool)
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]:
for unprocessed_tool in tools or []:
yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)


def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
Expand Down
64 changes: 37 additions & 27 deletions ollama/_types.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
import json
from base64 import b64encode
from base64 import b64decode, b64encode
from pathlib import Path
from datetime import datetime
from typing import (
Any,
Literal,
Mapping,
Optional,
Sequence,
Union,
)
from typing_extensions import Annotated
from typing import Any, Mapping, Optional, Union, Sequence

from typing_extensions import Annotated, Literal

from pydantic import (
BaseModel,
ByteSize,
ConfigDict,
Field,
FilePath,
Base64Str,
model_serializer,
)
from pydantic.json_schema import JsonSchemaValue


class SubscriptableBaseModel(BaseModel):
Expand Down Expand Up @@ -95,16 +87,26 @@ class BaseGenerateRequest(BaseStreamableRequest):


class Image(BaseModel):
value: Union[FilePath, Base64Str, bytes]
value: Union[str, bytes, Path]

# This overloads the `model_dump` method and returns values depending on the type of the `value` field
@model_serializer
def serialize_model(self):
if isinstance(self.value, Path):
return b64encode(self.value.read_bytes()).decode()
elif isinstance(self.value, bytes):
return b64encode(self.value).decode()
return self.value
if isinstance(self.value, (Path, bytes)):
return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()

if isinstance(self.value, str):
if Path(self.value).exists():
return b64encode(Path(self.value).read_bytes()).decode()

if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
raise ValueError(f'File {self.value} does not exist')

try:
# Try to decode to check if it's already base64
b64decode(self.value)
return self.value
except Exception:
raise ValueError('Invalid image data, expected base64 string or path to image file') from Exception


class GenerateRequest(BaseGenerateRequest):
Expand Down Expand Up @@ -222,20 +224,27 @@ class Function(SubscriptableBaseModel):


class Tool(SubscriptableBaseModel):
type: Literal['function'] = 'function'
type: Optional[Literal['function']] = 'function'

class Function(SubscriptableBaseModel):
name: str
description: str
name: Optional[str] = None
description: Optional[str] = None

class Parameters(SubscriptableBaseModel):
type: str
type: Optional[Literal['object']] = 'object'
required: Optional[Sequence[str]] = None
properties: Optional[JsonSchemaValue] = None

parameters: Parameters
class Property(SubscriptableBaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

type: Optional[str] = None
description: Optional[str] = None

properties: Optional[Mapping[str, Property]] = None

function: Function
parameters: Optional[Parameters] = None

function: Optional[Function] = None


class ChatRequest(BaseGenerateRequest):
Expand Down Expand Up @@ -335,6 +344,7 @@ class ModelDetails(SubscriptableBaseModel):

class ListResponse(SubscriptableBaseModel):
class Model(SubscriptableBaseModel):
model: Optional[str] = None
modified_at: Optional[datetime] = None
digest: Optional[str] = None
size: Optional[ByteSize] = None
Expand Down
87 changes: 87 additions & 0 deletions ollama/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import annotations
from collections import defaultdict
import inspect
from typing import Callable, Union
import re

import pydantic
from ollama._types import Tool


def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
parsed_docstring = defaultdict(str)
if not doc_string:
return parsed_docstring

key = hash(doc_string)
for line in doc_string.splitlines():
lowered_line = line.lower().strip()
if lowered_line.startswith('args:'):
key = 'args'
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'):
key = '_'

else:
# maybe change to a list and join later
parsed_docstring[key] += f'{line.strip()}\n'

last_key = None
for line in parsed_docstring['args'].splitlines():
line = line.strip()
if ':' in line:
# Split the line on either:
# 1. A parenthetical expression like (integer) - captured in group 1
# 2. A colon :
# Followed by optional whitespace. Only split on first occurrence.
parts = re.split(r'(?:\(([^)]*)\)|:)\s*', line, maxsplit=1)

arg_name = parts[0].strip()
last_key = arg_name

# Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon
arg_description = parts[-1].strip()
if len(parts) > 2 and parts[1]: # Has parenthetical content
arg_description = parts[-1].split(':', 1)[-1].strip()

parsed_docstring[last_key] = arg_description

elif last_key and line:
parsed_docstring[last_key] += ' ' + line

return parsed_docstring


def convert_function_to_tool(func: Callable) -> Tool:
doc_string_hash = hash(inspect.getdoc(func))
parsed_docstring = _parse_docstring(inspect.getdoc(func))
schema = type(
func.__name__,
(pydantic.BaseModel,),
{
'__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()},
'__signature__': inspect.signature(func),
'__doc__': parsed_docstring[doc_string_hash],
},
).model_json_schema()

for k, v in schema.get('properties', {}).items():
# If type is missing, the default is string
types = {t.get('type', 'string') for t in v.get('anyOf')} if 'anyOf' in v else {v.get('type', 'string')}
if 'null' in types:
schema['required'].remove(k)
types.discard('null')

schema['properties'][k] = {
'description': parsed_docstring[k],
'type': ', '.join(types),
}

tool = Tool(
function=Tool.Function(
name=func.__name__,
description=schema.get('description', ''),
parameters=Tool.Function.Parameters(**schema),
)
)

return Tool.model_validate(tool)
56 changes: 55 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
import io
import json
from pydantic import ValidationError
import pytest
import tempfile
from pathlib import Path
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response
from PIL import Image

from ollama._client import Client, AsyncClient
from ollama._client import Client, AsyncClient, _copy_tools


class PrefixPattern(URIPattern):
Expand Down Expand Up @@ -982,3 +983,56 @@ def test_headers():
)
assert client._client.headers['x-custom'] == 'value'
assert client._client.headers['content-type'] == 'application/json'


def test_copy_tools():
def func1(x: int) -> str:
"""Simple function 1.
Args:
x (integer): A number
"""
pass

def func2(y: str) -> int:
"""Simple function 2.
Args:
y (string): A string
"""
pass

# Test with list of functions
tools = list(_copy_tools([func1, func2]))
assert len(tools) == 2
assert tools[0].function.name == 'func1'
assert tools[1].function.name == 'func2'

# Test with empty input
assert list(_copy_tools()) == []
assert list(_copy_tools(None)) == []
assert list(_copy_tools([])) == []

# Test with mix of functions and tool dicts
tool_dict = {
'type': 'function',
'function': {
'name': 'test',
'description': 'Test function',
'parameters': {
'type': 'object',
'properties': {'x': {'type': 'string', 'description': 'A string'}},
'required': ['x'],
},
},
}

tools = list(_copy_tools([func1, tool_dict]))
assert len(tools) == 2
assert tools[0].function.name == 'func1'
assert tools[1].function.name == 'test'


def test_tool_validation():
# Raises ValidationError when used as it is a generator
with pytest.raises(ValidationError):
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
list(_copy_tools([invalid_tool]))
Loading

0 comments on commit c4ae94c

Please sign in to comment.