Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
ParthSareen committed Nov 19, 2024
1 parent 0d9eec0 commit 65db34a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 63 deletions.
8 changes: 8 additions & 0 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,14 @@ def chat(
"""
Create a chat response using the requested model.
Args:
tools (Sequence[Union[Mapping[str, Any], Tool, Callable]]):
A JSON schema as a dict, an Ollama Tool or a Python Function.
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
For more information, see: https://google.github.io/styleguide/pyguide.html#38-Docstrings
stream (bool): Whether to stream the response.
format (Optional[Literal['', 'json']]): The format of the response.
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Expand Down
2 changes: 1 addition & 1 deletion ollama/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ class ModelDetails(SubscriptableBaseModel):

class ListResponse(SubscriptableBaseModel):
class Model(SubscriptableBaseModel):
name: Optional[str] = None
model: Optional[str] = None
modified_at: Optional[datetime] = None
digest: Optional[str] = None
size: Optional[ByteSize] = None
Expand Down
100 changes: 39 additions & 61 deletions ollama/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections import defaultdict
import inspect
from typing import Callable, Union

Expand All @@ -7,96 +8,73 @@


def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
parsed_docstring = {'description': ''}
parsed_docstring = defaultdict(str)
if not doc_string:
return parsed_docstring

lowered_doc_string = doc_string.lower()

if 'args:' not in lowered_doc_string:
parsed_docstring['description'] = lowered_doc_string.strip()
return parsed_docstring

else:
parsed_docstring['description'] = lowered_doc_string.split('args:')[0].strip()
args_section = lowered_doc_string.split('args:')[1]

if 'returns:' in lowered_doc_string:
# Return section can be captured and used
args_section = args_section.split('returns:')[0]
# change name
key = 'func_description'
parsed_docstring[key] = ''
for line in lowered_doc_string.splitlines():
if line.startswith('args:'):
key = 'args'
elif line.startswith('returns:') or line.startswith('yields:') or line.startswith('raises:'):
key = '_'

if 'yields:' in lowered_doc_string:
args_section = args_section.split('yields:')[0]
else:
# maybe change to a list and join later
parsed_docstring[key] += f'{line.strip()}\n'

cur_var = None
for line in args_section.split('\n'):
last_key = None
for line in parsed_docstring['args'].splitlines():
line = line.strip()
if not line:
continue
if ':' not in line:
# Continuation of the previous parameter's description
if cur_var:
parsed_docstring[cur_var] += f' {line}'
continue

# For the case with: `param_name (type)`: ...
if '(' in line:
param_name = line.split('(')[0]
param_desc = line.split('):')[1]

# For the case with: `param_name: ...`
else:
param_name, param_desc = line.split(':', 1)
if ':' in line and not line.startswith('args'):
# Split on first occurrence of '(' or ':' to separate arg name from description
split_char = '(' if '(' in line else ':'
arg_name, rest = line.split(split_char, 1)

last_key = arg_name.strip()
# Get description after the colon
arg_description = rest.split(':', 1)[1].strip() if split_char == '(' else rest.strip()
parsed_docstring[last_key] = arg_description

parsed_docstring[param_name.strip()] = param_desc.strip()
cur_var = param_name.strip()
elif last_key and line:
parsed_docstring[last_key] += ' ' + line

return parsed_docstring


def convert_function_to_tool(func: Callable) -> Tool:
parsed_docstring = _parse_docstring(inspect.getdoc(func))
schema = type(
func.__name__,
(pydantic.BaseModel,),
{
'__annotations__': {k: v.annotation for k, v in inspect.signature(func).parameters.items()},
'__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()},
'__signature__': inspect.signature(func),
'__doc__': inspect.getdoc(func),
'__doc__': parsed_docstring.get('func_description', ''),
},
).model_json_schema()

properties = {}
required = []
parsed_docstring = _parse_docstring(schema.get('description'))
for k, v in schema.get('properties', {}).items():
prop = {
# think about how no type is handled
types = {t.get('type', 'string') for t in v.get('anyOf', [])} if 'anyOf' in v else {v.get('type', 'string')}
if 'null' in types:
schema['required'].remove(k)
types.discard('null')

schema['properties'][k] = {
'description': parsed_docstring.get(k, ''),
'type': v.get('type'),
'type': ', '.join(types),
}

if 'anyOf' in v:
is_optional = any(t.get('type') == 'null' for t in v['anyOf'])
types = [t.get('type', 'string') for t in v['anyOf'] if t.get('type') != 'null']
prop['type'] = types[0] if len(types) == 1 else str(types)
if not is_optional:
required.append(k)
else:
if prop['type'] != 'null':
required.append(k)

properties[k] = prop

schema['properties'] = properties

tool = Tool(
function=Tool.Function(
name=func.__name__,
description=parsed_docstring.get('description'),
parameters=Tool.Function.Parameters(
type='object',
properties=schema.get('properties', {}),
required=required,
),
description=schema.get('description', ''),
parameters=Tool.Function.Parameters(**schema),
)
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def only_description():

tool = convert_function_to_tool(only_description).model_dump()
assert tool['function']['description'] == 'a function with only a description.'
assert tool['function']['parameters'] == {'type': 'object', 'properties': {}, 'required': []}
assert tool['function']['parameters'] == {'type': 'object', 'properties': {}, 'required': None}

def only_description_with_args(x: int, y: int):
"""
Expand Down

0 comments on commit 65db34a

Please sign in to comment.