forked from ollama/ollama-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Passing Functions as Tools (ollama#321)
* Functions can now be passed as tools
- Loading branch information
1 parent
2105f4e
commit c4ae94c
Showing
6 changed files
with
545 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from __future__ import annotations | ||
from collections import defaultdict | ||
import inspect | ||
from typing import Callable, Union | ||
import re | ||
|
||
import pydantic | ||
from ollama._types import Tool | ||
|
||
|
||
def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: | ||
parsed_docstring = defaultdict(str) | ||
if not doc_string: | ||
return parsed_docstring | ||
|
||
key = hash(doc_string) | ||
for line in doc_string.splitlines(): | ||
lowered_line = line.lower().strip() | ||
if lowered_line.startswith('args:'): | ||
key = 'args' | ||
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'): | ||
key = '_' | ||
|
||
else: | ||
# maybe change to a list and join later | ||
parsed_docstring[key] += f'{line.strip()}\n' | ||
|
||
last_key = None | ||
for line in parsed_docstring['args'].splitlines(): | ||
line = line.strip() | ||
if ':' in line: | ||
# Split the line on either: | ||
# 1. A parenthetical expression like (integer) - captured in group 1 | ||
# 2. A colon : | ||
# Followed by optional whitespace. Only split on first occurrence. | ||
parts = re.split(r'(?:\(([^)]*)\)|:)\s*', line, maxsplit=1) | ||
|
||
arg_name = parts[0].strip() | ||
last_key = arg_name | ||
|
||
# Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon | ||
arg_description = parts[-1].strip() | ||
if len(parts) > 2 and parts[1]: # Has parenthetical content | ||
arg_description = parts[-1].split(':', 1)[-1].strip() | ||
|
||
parsed_docstring[last_key] = arg_description | ||
|
||
elif last_key and line: | ||
parsed_docstring[last_key] += ' ' + line | ||
|
||
return parsed_docstring | ||
|
||
|
||
def convert_function_to_tool(func: Callable) -> Tool: | ||
doc_string_hash = hash(inspect.getdoc(func)) | ||
parsed_docstring = _parse_docstring(inspect.getdoc(func)) | ||
schema = type( | ||
func.__name__, | ||
(pydantic.BaseModel,), | ||
{ | ||
'__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()}, | ||
'__signature__': inspect.signature(func), | ||
'__doc__': parsed_docstring[doc_string_hash], | ||
}, | ||
).model_json_schema() | ||
|
||
for k, v in schema.get('properties', {}).items(): | ||
# If type is missing, the default is string | ||
types = {t.get('type', 'string') for t in v.get('anyOf')} if 'anyOf' in v else {v.get('type', 'string')} | ||
if 'null' in types: | ||
schema['required'].remove(k) | ||
types.discard('null') | ||
|
||
schema['properties'][k] = { | ||
'description': parsed_docstring[k], | ||
'type': ', '.join(types), | ||
} | ||
|
||
tool = Tool( | ||
function=Tool.Function( | ||
name=func.__name__, | ||
description=schema.get('description', ''), | ||
parameters=Tool.Function.Parameters(**schema), | ||
) | ||
) | ||
|
||
return Tool.model_validate(tool) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.