Skip to content

Commit

Permalink
Addressing comments + cleanup + optional tool
Browse files Browse the repository at this point in the history
  • Loading branch information
ParthSareen committed Nov 20, 2024
1 parent a4ec34a commit 6d9c156
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 132 deletions.
17 changes: 6 additions & 11 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def chat(
Example:
def add_two_numbers(a: int, b: int) -> int:
\"""
'''
Add two numbers together.
Args:
Expand All @@ -316,7 +316,7 @@ def add_two_numbers(a: int, b: int) -> int:
Returns:
int: The sum of a and b
\"""
'''
return a + b
client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
Expand Down Expand Up @@ -809,7 +809,7 @@ async def chat(
Example:
def add_two_numbers(a: int, b: int) -> int:
\"""
'''
Add two numbers together.
Args:
Expand All @@ -818,10 +818,10 @@ def add_two_numbers(a: int, b: int) -> int:
Returns:
int: The sum of a and b
\"""
'''
return a + b
client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
await client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
Raises `RequestError` if a model is not provided.
Expand Down Expand Up @@ -1128,10 +1128,7 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]


def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]:
if not tools:
return []

for unprocessed_tool in tools:
for unprocessed_tool in tools or []:
yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)


Expand Down Expand Up @@ -1207,8 +1204,6 @@ def _parse_host(host: Optional[str]) -> str:
'https://[0001:002:003:0004::1]:56789/path'
>>> _parse_host('[0001:002:003:0004::1]:56789/path/')
'http://[0001:002:003:0004::1]:56789/path'
>>> _parse_host('http://host.docker.internal:11434/path')
'http://host.docker.internal:11434/path'
"""

host, port = host or '', 11434
Expand Down
16 changes: 8 additions & 8 deletions ollama/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,27 +216,27 @@ class Function(SubscriptableBaseModel):


class Tool(SubscriptableBaseModel):
type: Literal['function'] = 'function'
type: Optional[Literal['function']] = 'function'

class Function(SubscriptableBaseModel):
name: str
description: str
name: Optional[str] = None
description: Optional[str] = None

class Parameters(SubscriptableBaseModel):
type: Literal['object'] = 'object'
type: Optional[Literal['object']] = 'object'
required: Optional[Sequence[str]] = None

class Property(SubscriptableBaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

type: str
description: str
type: Optional[str] = None
description: Optional[str] = None

properties: Optional[Mapping[str, Property]] = None

parameters: Parameters
parameters: Optional[Parameters] = None

function: Function
function: Optional[Function] = None


class ChatRequest(BaseGenerateRequest):
Expand Down
12 changes: 5 additions & 7 deletions ollama/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
if not doc_string:
return parsed_docstring

lowered_doc_string = doc_string.lower()

key = hash(doc_string)
parsed_docstring[key] = ''
for line in lowered_doc_string.splitlines():
if line.startswith('args:'):
for line in doc_string.splitlines():
lowered_line = line.lower()
if lowered_line.startswith('args:'):
key = 'args'
elif line.startswith('returns:') or line.startswith('yields:') or line.startswith('raises:'):
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'):
key = '_'

else:
Expand All @@ -29,7 +27,7 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
last_key = None
for line in parsed_docstring['args'].splitlines():
line = line.strip()
if ':' in line and not line.startswith('args'):
if ':' in line and not line.lower().startswith('args:'):
# Split on first occurrence of '(' or ':' to separate arg name from description
split_char = '(' if '(' in line else ':'
arg_name, rest = line.split(split_char, 1)
Expand Down
15 changes: 2 additions & 13 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,25 +1025,14 @@ def func2(y: str) -> int:
},
}

tool_json = json.loads(json.dumps(tool_dict))
tools = list(_copy_tools([func1, tool_dict, tool_json]))
assert len(tools) == 3
tools = list(_copy_tools([func1, tool_dict]))
assert len(tools) == 2
assert tools[0].function.name == 'func1'
assert tools[1].function.name == 'test'
assert tools[2].function.name == 'test'


def test_tool_validation():
# Test that malformed tool dictionaries are rejected
# Raises ValidationError when used as it is a generator
with pytest.raises(ValidationError):
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
list(_copy_tools([invalid_tool]))

# Test missing required fields
incomplete_tool = {
'type': 'function',
'function': {'name': 'test'}, # missing description and parameters
}
with pytest.raises(ValidationError):
list(_copy_tools([incomplete_tool]))
82 changes: 1 addition & 81 deletions tests/test_type_serialization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from base64 import b64decode, b64encode

import pytest


from ollama._types import Image, Tool
from ollama._types import Image


def test_image_serialization():
Expand All @@ -16,81 +14,3 @@ def test_image_serialization():
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n'
img = Image(value=b64_str)
assert img.model_dump() == b64decode(b64_str).decode()


def test_tool_serialization():
# Test valid tool serialization
tool = Tool(
function=Tool.Function(
name='add_two_numbers',
description='Add two numbers together.',
parameters=Tool.Function.Parameters(
type='object',
properties={
'a': Tool.Function.Parameters.Property(
type='integer',
description='The first number',
),
'b': Tool.Function.Parameters.Property(
type='integer',
description='The second number',
),
},
required=['a', 'b'],
),
)
)
assert tool.model_dump() == {
'type': 'function',
'function': {
'name': 'add_two_numbers',
'description': 'Add two numbers together.',
'parameters': {
'type': 'object',
'properties': {
'a': {
'type': 'integer',
'description': 'The first number',
},
'b': {
'type': 'integer',
'description': 'The second number',
},
},
'required': ['a', 'b'],
},
},
}

# Test invalid type
with pytest.raises(ValueError):
property = Tool.Function.Parameters.Property(
type=lambda x: x, # Invalid type
description='Invalid type',
)
Tool.model_validate(
Tool(
function=Tool.Function(
parameters=Tool.Function.Parameters(
properties={
'x': property,
}
)
)
)
)

# Test invalid parameters type
with pytest.raises(ValueError):
Tool.model_validate(
Tool(
function=Tool.Function(
name='test',
description='Test',
parameters=Tool.Function.Parameters(
type='invalid_type', # Must be 'object'
properties={},
),
)
)
)
24 changes: 12 additions & 12 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def test_function_to_tool_conversion():
def add_numbers(x: int, y: Union[int, None] = None) -> int:
"""Add two numbers together.
Args:
args:
x (integer): The first number
y (integer, optional): The second number
Expand All @@ -22,10 +22,10 @@ def add_numbers(x: int, y: Union[int, None] = None) -> int:

assert tool['type'] == 'function'
assert tool['function']['name'] == 'add_numbers'
assert tool['function']['description'] == 'add two numbers together.'
assert tool['function']['description'] == 'Add two numbers together.'
assert tool['function']['parameters']['type'] == 'object'
assert tool['function']['parameters']['properties']['x']['type'] == 'integer'
assert tool['function']['parameters']['properties']['x']['description'] == 'the first number'
assert tool['function']['parameters']['properties']['x']['description'] == 'The first number'
assert tool['function']['parameters']['required'] == ['x']


Expand All @@ -42,7 +42,7 @@ def simple_func():

tool = convert_function_to_tool(simple_func).model_dump()
assert tool['function']['name'] == 'simple_func'
assert tool['function']['description'] == 'a simple function with no arguments.'
assert tool['function']['description'] == 'A simple function with no arguments.'
assert tool['function']['parameters']['properties'] == {}


Expand Down Expand Up @@ -137,9 +137,9 @@ def func_with_complex_docs(x: int, y: List[str]) -> Dict[str, Any]:
pass

tool = convert_function_to_tool(func_with_complex_docs).model_dump()
assert tool['function']['description'] == 'test function with complex docstring.'
assert tool['function']['parameters']['properties']['x']['description'] == 'a number with multiple lines'
assert tool['function']['parameters']['properties']['y']['description'] == 'a list with multiple lines'
assert tool['function']['description'] == 'Test function with complex docstring.'
assert tool['function']['parameters']['properties']['x']['description'] == 'A number with multiple lines'
assert tool['function']['parameters']['properties']['y']['description'] == 'A list with multiple lines'


def test_skewed_docstring_parsing():
Expand All @@ -159,8 +159,8 @@ def add_two_numbers(x: int, y: int) -> int:
pass

tool = convert_function_to_tool(add_two_numbers).model_dump()
assert tool['function']['parameters']['properties']['x']['description'] == ': the first number'
assert tool['function']['parameters']['properties']['y']['description'] == 'the second number'
assert tool['function']['parameters']['properties']['x']['description'] == ': The first number'
assert tool['function']['parameters']['properties']['y']['description'] == 'The second number'


def test_function_with_no_docstring():
Expand All @@ -187,7 +187,7 @@ def only_description():
pass

tool = convert_function_to_tool(only_description).model_dump()
assert tool['function']['description'] == 'a function with only a description.'
assert tool['function']['description'] == 'A function with only a description.'
assert tool['function']['parameters'] == {'type': 'object', 'properties': {}, 'required': None}

def only_description_with_args(x: int, y: int):
Expand All @@ -197,7 +197,7 @@ def only_description_with_args(x: int, y: int):
pass

tool = convert_function_to_tool(only_description_with_args).model_dump()
assert tool['function']['description'] == 'a function with only a description.'
assert tool['function']['description'] == 'A function with only a description.'
assert tool['function']['parameters'] == {
'type': 'object',
'properties': {
Expand All @@ -223,7 +223,7 @@ def function_with_yields(x: int, y: int):
pass

tool = convert_function_to_tool(function_with_yields).model_dump()
assert tool['function']['description'] == 'a function with yields section.'
assert tool['function']['description'] == 'A function with yields section.'
assert tool['function']['parameters']['properties']['x']['description'] == 'the first number'
assert tool['function']['parameters']['properties']['y']['description'] == 'the second number'

Expand Down

0 comments on commit 6d9c156

Please sign in to comment.