Skip to content

Commit

Permalink
Fix types in tool tests (#2285)
Browse files Browse the repository at this point in the history
* fixed types related to function calling

* polishing

* fixed types in tests
  • Loading branch information
davorrunje authored Apr 5, 2024
1 parent 0e0895f commit 0c0f953
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions autogen/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref
from pydantic.json_schema import JsonSchemaValue

def type2schema(t: Optional[Type[Any]]) -> JsonSchemaValue:
def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
Expand Down Expand Up @@ -55,7 +55,7 @@ def model_dump_json(model: BaseModel) -> str:

JsonSchemaValue = Dict[str, Any] # type: ignore[misc]

def type2schema(t: Optional[Type[Any]]) -> JsonSchemaValue:
def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
Expand Down
4 changes: 1 addition & 3 deletions autogen/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ class ToolFunction(BaseModel):
function: Annotated[Function, Field(description="Function under tool")]


def get_parameter_json_schema(
k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
) -> JsonSchemaValue:
def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API
Args:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ files = [
"autogen/_pydantic.py",
"autogen/function_utils.py",
"autogen/io",
"test/test_pydantic.py",
"test/test_function_utils.py",
"test/io",
]

Expand Down
32 changes: 16 additions & 16 deletions test/test_function_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import inspect
import unittest.mock
from typing import Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

import pytest
from pydantic import BaseModel, Field
Expand All @@ -25,11 +25,11 @@
)


def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1, *, d):
def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1, *, d): # type: ignore[no-untyped-def]
pass


def g(
def g( # type: ignore[empty-body]
a: Annotated[str, "Parameter a"],
b: int = 2,
c: Annotated[float, "Parameter c"] = 0.1,
Expand All @@ -39,7 +39,7 @@ def g(
pass


async def a_g(
async def a_g( # type: ignore[empty-body]
a: Annotated[str, "Parameter a"],
b: int = 2,
c: Annotated[float, "Parameter c"] = 0.1,
Expand Down Expand Up @@ -83,7 +83,7 @@ class B(BaseModel):
b: float
c: str

expected = {
expected: Dict[str, Any] = {
"description": "b",
"properties": {"b": {"title": "B", "type": "number"}, "c": {"title": "C", "type": "string"}},
"required": ["b", "c"],
Expand All @@ -107,7 +107,7 @@ def test_get_default_values() -> None:


def test_get_param_annotations() -> None:
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0):
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0): # type: ignore[no-untyped-def]
pass

expected = {"a": Annotated[str, "Parameter a"], "c": Annotated[float, "Parameter c"]}
Expand All @@ -119,14 +119,14 @@ def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"]


def test_get_missing_annotations() -> None:
def _f1(a: str, b=2):
def _f1(a: str, b=2): # type: ignore[no-untyped-def]
pass

missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f1), ["a"])
assert missing == set()
assert unannotated_with_default == {"b"}

def _f2(a: str, b) -> str:
def _f2(a: str, b) -> str: # type: ignore[empty-body,no-untyped-def]
"ok"

missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f2), ["a", "b"])
Expand All @@ -142,7 +142,7 @@ def _f3() -> None:


def test_get_parameters() -> None:
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0):
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0): # type: ignore[no-untyped-def]
pass

typed_signature = get_typed_signature(f)
Expand All @@ -165,7 +165,7 @@ def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"]


def test_get_function_schema_no_return_type() -> None:
def f(a: Annotated[str, "Parameter a"], b: int, c: float = 0.1):
def f(a: Annotated[str, "Parameter a"], b: int, c: float = 0.1): # type: ignore[no-untyped-def]
pass

expected = (
Expand All @@ -182,7 +182,7 @@ def f(a: Annotated[str, "Parameter a"], b: int, c: float = 0.1):
def test_get_function_schema_unannotated_with_default() -> None:
with unittest.mock.patch("autogen.function_utils.logger.warning") as mock_logger_warning:

def f(
def f( # type: ignore[no-untyped-def]
a: Annotated[str, "Parameter a"], b=2, c: Annotated[float, "Parameter c"] = 0.1, d="whatever", e=None
) -> str:
return "ok"
Expand All @@ -195,7 +195,7 @@ def f(


def test_get_function_schema_missing() -> None:
def f(a: Annotated[str, "Parameter a"], b, c: Annotated[float, "Parameter c"] = 0.1) -> float:
def f(a: Annotated[str, "Parameter a"], b, c: Annotated[float, "Parameter c"] = 0.1) -> float: # type: ignore[no-untyped-def, empty-body]
pass

expected = (
Expand Down Expand Up @@ -291,7 +291,7 @@ class Currency(BaseModel):


def test_get_function_schema_pydantic() -> None:
def currency_calculator(
def currency_calculator( # type: ignore[empty-body]
base: Annotated[Currency, "Base currency: amount and currency symbol"],
quote_currency: Annotated[CurrencySymbol, "Quote currency symbol (default: 'EUR')"] = "EUR",
) -> Currency:
Expand Down Expand Up @@ -346,12 +346,12 @@ def currency_calculator(

def test_get_load_param_if_needed_function() -> None:
assert get_load_param_if_needed_function(CurrencySymbol) is None
assert get_load_param_if_needed_function(Currency)({"currency": "USD", "amount": 123.45}, Currency) == Currency(
assert get_load_param_if_needed_function(Currency)({"currency": "USD", "amount": 123.45}, Currency) == Currency( # type: ignore[misc]
currency="USD", amount=123.45
)

f = get_load_param_if_needed_function(Annotated[Currency, "amount and a symbol of a currency"])
actual = f({"currency": "USD", "amount": 123.45}, Currency)
actual = f({"currency": "USD", "amount": 123.45}, Currency) # type: ignore[misc]
expected = Currency(currency="USD", amount=123.45)
assert actual == expected, actual

Expand Down Expand Up @@ -391,7 +391,7 @@ async def f(
assert actual[1] == "EUR"


def test_serialize_to_json():
def test_serialize_to_json() -> None:
assert serialize_to_str("abc") == "abc"
assert serialize_to_str(123) == "123"
assert serialize_to_str([123, 456]) == "[123, 456]"
Expand Down

0 comments on commit 0c0f953

Please sign in to comment.