Skip to content

Commit

Permalink
Trying stuff out
Browse files Browse the repository at this point in the history
  • Loading branch information
ParthSareen committed Nov 11, 2024
1 parent 97aa167 commit 4e8bd4b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
30 changes: 20 additions & 10 deletions ollama/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from __future__ import annotations

from types import UnionType
from typing import Any, Callable, List, Mapping, Optional, Union, get_args, get_origin
from ollama._types import Tool
from collections.abc import Sequence, Set
from typing import Dict, Set as TypeSet
import sys

# Type compatibility layer
if sys.version_info >= (3, 10):
from types import UnionType

def is_union(tp: Any) -> bool:
return get_origin(tp) in (Union, UnionType)
else:

def is_union(tp: Any) -> bool:
return get_origin(tp) is Union


# Map both the type and the type reference to the same JSON type
TYPE_MAP = {
Expand Down Expand Up @@ -47,8 +59,7 @@

def _get_json_type(python_type: Any) -> str | List[str]:
# Handle Optional types (Union[type, None] and type | None)
origin = get_origin(python_type)
if origin is UnionType or origin is Union:
if is_union(python_type):
args = get_args(python_type)
# Filter out None/NoneType from union args
non_none_args = [arg for arg in args if arg not in (None, type(None))]
Expand All @@ -60,16 +71,16 @@ def _get_json_type(python_type: Any) -> str | List[str]:
return 'null'

# Handle generic types (List[int], Dict[str, int], etc.)
if origin is not None:
if get_origin(python_type) is not None:
# Get the base type (List, Dict, etc.)
base_type = TYPE_MAP.get(origin, None)
base_type = TYPE_MAP.get(get_origin(python_type), None)
if base_type:
return base_type
# If it's a subclass of known abstract base classes, map to appropriate type
if isinstance(origin, type):
if issubclass(origin, (list, Sequence, tuple, set, Set)):
if isinstance(get_origin(python_type), type):
if issubclass(get_origin(python_type), (list, Sequence, tuple, set, Set)):
return 'array'
if issubclass(origin, (dict, Mapping)):
if issubclass(get_origin(python_type), (dict, Mapping)):
return 'object'

# Handle both type objects and type references
Expand All @@ -90,8 +101,7 @@ def _get_json_type(python_type: Any) -> str | List[str]:


def _is_optional_type(python_type: Any) -> bool:
origin = get_origin(python_type)
if origin is UnionType or origin is Union:
if is_union(python_type):
args = get_args(python_type)
return any(arg in (None, type(None)) for arg in args)
return False
Expand Down
8 changes: 4 additions & 4 deletions tests/test_utils_legacy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

import sys
# import sys
from typing import Dict, List, Mapping, Optional, Sequence, Set, Tuple, Union

import pytest
# import pytest

if sys.version_info >= (3, 10):
pytest.skip('Python 3.9 or lower is required', allow_module_level=True)
# if sys.version_info >= (3, 10):
# pytest.skip('Python 3.9 or lower is required', allow_module_level=True)

from ollama._utils import _get_json_type, convert_function_to_tool

Expand Down

0 comments on commit 4e8bd4b

Please sign in to comment.