Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Passing Functions as Tools #321

Merged
merged 32 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4383603
WIP tool parsing
ParthSareen Nov 8, 2024
afe7db6
Managing multiple type options
ParthSareen Nov 9, 2024
8fee892
Add tool parsing and processing
ParthSareen Nov 11, 2024
0e5a940
Formatting and todos
ParthSareen Nov 11, 2024
1ef75a7
TODOs
ParthSareen Nov 11, 2024
93c7a63
wip
ParthSareen Nov 11, 2024
e5dc2b8
add annotations import for old tests
ParthSareen Nov 11, 2024
aa20015
Exhaustive type matching
ParthSareen Nov 11, 2024
d79538e
Ruff fix
ParthSareen Nov 11, 2024
97aa167
WIP trying tests out
ParthSareen Nov 11, 2024
8ec5123
Trying stuff out
ParthSareen Nov 11, 2024
efb775b
Multi-line docstrings and exhaustive tests
ParthSareen Nov 12, 2024
2efa54a
Walrus op for cleanup
ParthSareen Nov 12, 2024
1f089f7
Stringify return type arrays to not break server
ParthSareen Nov 13, 2024
fe8d143
WIP
ParthSareen Nov 14, 2024
67321a8
Organization, cleanup, pydantic serialization, update tests
ParthSareen Nov 14, 2024
2cc0b40
Typing fix
ParthSareen Nov 14, 2024
e68700c
Python3.8+ compatibility
ParthSareen Nov 14, 2024
f452fab
Add str -> str valid json mapping and add test
ParthSareen Nov 14, 2024
ca16670
Code cleanup and organization
ParthSareen Nov 14, 2024
7dcb598
Test unhappy parse path
ParthSareen Nov 14, 2024
7c5c294
Code cleanup + organize and add tests for type serialization
ParthSareen Nov 14, 2024
16c868a
Update to have graceful handling and not raise - added tests as well
ParthSareen Nov 15, 2024
718412a
Making good use of pydantic
ParthSareen Nov 18, 2024
e7bb55f
Add yields and test
ParthSareen Nov 18, 2024
7396ab6
Simplified parsing and fixed required - added tests
ParthSareen Nov 18, 2024
0d9eec0
Add tool.model_validate
ParthSareen Nov 18, 2024
ed3ba8a
Code style updates
ParthSareen Nov 19, 2024
a4ec34a
Add better messaging for chat
ParthSareen Nov 19, 2024
6d9c156
Addressing comments + cleanup + optional tool
ParthSareen Nov 19, 2024
c5c61a3
Better docstring parsing and some fixes
ParthSareen Nov 20, 2024
b0e0409
Bugfix/image encoding (#327)
ParthSareen Nov 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from typing import (
Any,
Callable,
Literal,
Mapping,
Optional,
Expand All @@ -22,6 +23,9 @@

import sys


from ollama._utils import convert_function_to_tool

if sys.version_info < (3, 9):
from typing import Iterator, AsyncIterator
else:
Expand Down Expand Up @@ -284,7 +288,7 @@ def chat(
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
Expand All @@ -293,6 +297,30 @@ def chat(
"""
Create a chat response using the requested model.

Args:
tools:
A JSON schema as a dict, an Ollama Tool or a Python Function.
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
stream: Whether to stream the response.
format: The format of the response.

Example:
def add_two_numbers(a: int, b: int) -> int:
'''
Add two numbers together.

Args:
a: First number to add
b: Second number to add

Returns:
int: The sum of a and b
'''
return a + b

client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])

Raises `RequestError` if a model is not provided.

Raises `ResponseError` if the request could not be fulfilled.
Expand Down Expand Up @@ -750,7 +778,7 @@ async def chat(
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
Expand All @@ -771,6 +799,30 @@ async def chat(
"""
Create a chat response using the requested model.

Args:
tools:
A JSON schema as a dict, an Ollama Tool or a Python Function.
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
stream: Whether to stream the response.
format: The format of the response.

Example:
def add_two_numbers(a: int, b: int) -> int:
'''
Add two numbers together.

Args:
a: First number to add
b: Second number to add

Returns:
int: The sum of a and b
'''
return a + b

await client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])

Raises `RequestError` if a model is not provided.

Raises `ResponseError` if the request could not be fulfilled.
Expand Down Expand Up @@ -1075,9 +1127,9 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
)


def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]]) -> Iterator[Tool]:
for tool in tools or []:
yield Tool.model_validate(tool)
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]:
for unprocessed_tool in tools or []:
yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)


def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
Expand Down
64 changes: 37 additions & 27 deletions ollama/_types.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
import json
from base64 import b64encode
from base64 import b64decode, b64encode
from pathlib import Path
from datetime import datetime
from typing import (
Any,
Literal,
Mapping,
Optional,
Sequence,
Union,
)
from typing_extensions import Annotated
from typing import Any, Mapping, Optional, Union, Sequence

from typing_extensions import Annotated, Literal

from pydantic import (
BaseModel,
ByteSize,
ConfigDict,
Field,
FilePath,
Base64Str,
model_serializer,
)
from pydantic.json_schema import JsonSchemaValue


class SubscriptableBaseModel(BaseModel):
Expand Down Expand Up @@ -95,16 +87,26 @@ class BaseGenerateRequest(BaseStreamableRequest):


class Image(BaseModel):
value: Union[FilePath, Base64Str, bytes]
value: Union[str, bytes, Path]

# This overloads the `model_dump` method and returns values depending on the type of the `value` field
@model_serializer
def serialize_model(self):
if isinstance(self.value, Path):
return b64encode(self.value.read_bytes()).decode()
elif isinstance(self.value, bytes):
return b64encode(self.value).decode()
return self.value
if isinstance(self.value, (Path, bytes)):
return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()

if isinstance(self.value, str):
if Path(self.value).exists():
return b64encode(Path(self.value).read_bytes()).decode()

if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
raise ValueError(f'File {self.value} does not exist')

try:
# Try to decode to check if it's already base64
b64decode(self.value)
return self.value
except Exception:
raise ValueError('Invalid image data, expected base64 string or path to image file') from Exception


class GenerateRequest(BaseGenerateRequest):
Expand Down Expand Up @@ -222,20 +224,27 @@ class Function(SubscriptableBaseModel):


class Tool(SubscriptableBaseModel):
type: Literal['function'] = 'function'
type: Optional[Literal['function']] = 'function'

class Function(SubscriptableBaseModel):
name: str
description: str
name: Optional[str] = None
description: Optional[str] = None

class Parameters(SubscriptableBaseModel):
type: str
type: Optional[Literal['object']] = 'object'
required: Optional[Sequence[str]] = None
properties: Optional[JsonSchemaValue] = None

parameters: Parameters
class Property(SubscriptableBaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
ParthSareen marked this conversation as resolved.
Show resolved Hide resolved

type: Optional[str] = None
description: Optional[str] = None

properties: Optional[Mapping[str, Property]] = None

function: Function
parameters: Optional[Parameters] = None

function: Optional[Function] = None


class ChatRequest(BaseGenerateRequest):
Expand Down Expand Up @@ -335,6 +344,7 @@ class ModelDetails(SubscriptableBaseModel):

class ListResponse(SubscriptableBaseModel):
class Model(SubscriptableBaseModel):
model: Optional[str] = None
modified_at: Optional[datetime] = None
digest: Optional[str] = None
size: Optional[ByteSize] = None
Expand Down
87 changes: 87 additions & 0 deletions ollama/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import annotations
from collections import defaultdict
import inspect
from typing import Callable, Union
import re

import pydantic
from ollama._types import Tool


def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
parsed_docstring = defaultdict(str)
if not doc_string:
return parsed_docstring

key = hash(doc_string)
for line in doc_string.splitlines():
lowered_line = line.lower().strip()
if lowered_line.startswith('args:'):
key = 'args'
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'):
key = '_'

else:
# maybe change to a list and join later
parsed_docstring[key] += f'{line.strip()}\n'

last_key = None
for line in parsed_docstring['args'].splitlines():
line = line.strip()
if ':' in line:
# Split the line on either:
# 1. A parenthetical expression like (integer) - captured in group 1
# 2. A colon :
# Followed by optional whitespace. Only split on first occurrence.
parts = re.split(r'(?:\(([^)]*)\)|:)\s*', line, maxsplit=1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tend to avoid regexp when possible since it's hard to grok. In this scenario, a simpler solution would be to split on the mandatory : than parse the pre-colon and post-colon sections independently. Here's an example that passes your tests

  for line in parsed_docstring['args'].splitlines():
    pre, _, post = line.partition(':')
    if not pre.strip():
      continue
    if not post.strip() and last_key:
      parsed_docstring[last_key] += ' ' + pre
      continue

    arg_name, _, _ = pre.replace('(', ' ').partition(' ')
    last_key = arg_name.strip()

    parsed_docstring[last_key] = post.strip()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO @ParthSareen to spin out issue


arg_name = parts[0].strip()
last_key = arg_name

# Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon
arg_description = parts[-1].strip()
if len(parts) > 2 and parts[1]: # Has parenthetical content
arg_description = parts[-1].split(':', 1)[-1].strip()

parsed_docstring[last_key] = arg_description

elif last_key and line:
parsed_docstring[last_key] += ' ' + line

return parsed_docstring


def convert_function_to_tool(func: Callable) -> Tool:
doc_string_hash = hash(inspect.getdoc(func))
parsed_docstring = _parse_docstring(inspect.getdoc(func))
schema = type(
func.__name__,
(pydantic.BaseModel,),
{
'__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()},
'__signature__': inspect.signature(func),
'__doc__': parsed_docstring[doc_string_hash],
},
).model_json_schema()

for k, v in schema.get('properties', {}).items():
ParthSareen marked this conversation as resolved.
Show resolved Hide resolved
# If type is missing, the default is string
types = {t.get('type', 'string') for t in v.get('anyOf')} if 'anyOf' in v else {v.get('type', 'string')}
if 'null' in types:
schema['required'].remove(k)
types.discard('null')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If null is the only type (for some reason), this will be an empty string

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is okay, IMO something like:

def (a:None, b:type(None)):
  ...

is extremely unlikely


schema['properties'][k] = {
'description': parsed_docstring[k],
'type': ', '.join(types),
}

tool = Tool(
function=Tool.Function(
name=func.__name__,
description=schema.get('description', ''),
parameters=Tool.Function.Parameters(**schema),
)
)

return Tool.model_validate(tool)
56 changes: 55 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
import io
import json
from pydantic import ValidationError
import pytest
import tempfile
from pathlib import Path
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response
from PIL import Image

from ollama._client import Client, AsyncClient
from ollama._client import Client, AsyncClient, _copy_tools


class PrefixPattern(URIPattern):
Expand Down Expand Up @@ -982,3 +983,56 @@ def test_headers():
)
assert client._client.headers['x-custom'] == 'value'
assert client._client.headers['content-type'] == 'application/json'


def test_copy_tools():
def func1(x: int) -> str:
"""Simple function 1.
Args:
x (integer): A number
"""
pass

def func2(y: str) -> int:
"""Simple function 2.
Args:
y (string): A string
"""
pass

# Test with list of functions
tools = list(_copy_tools([func1, func2]))
assert len(tools) == 2
assert tools[0].function.name == 'func1'
assert tools[1].function.name == 'func2'

# Test with empty input
assert list(_copy_tools()) == []
assert list(_copy_tools(None)) == []
assert list(_copy_tools([])) == []

# Test with mix of functions and tool dicts
tool_dict = {
'type': 'function',
'function': {
'name': 'test',
'description': 'Test function',
'parameters': {
'type': 'object',
'properties': {'x': {'type': 'string', 'description': 'A string'}},
'required': ['x'],
},
},
}

tools = list(_copy_tools([func1, tool_dict]))
assert len(tools) == 2
assert tools[0].function.name == 'func1'
assert tools[1].function.name == 'test'


def test_tool_validation():
# Raises ValidationError when used as it is a generator
with pytest.raises(ValidationError):
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
list(_copy_tools([invalid_tool]))
Loading