Skip to content

Commit

Permalink
validate_call type params fix (#9760)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Jun 26, 2024
1 parent 04f3a46 commit fcd2010
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 5 deletions.
3 changes: 2 additions & 1 deletion pydantic/_internal/_typing_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def get_function_type_hints(

globalns = add_module_globals(function)
type_hints = {}
type_params: tuple[Any] = getattr(function, '__type_params__', ()) # type: ignore
for name, value in annotations.items():
if include_keys is not None and name not in include_keys:
continue
Expand All @@ -315,7 +316,7 @@ def get_function_type_hints(
elif isinstance(value, str):
value = _make_forward_ref(value)

type_hints[name] = eval_type_backport(value, globalns, types_namespace)
type_hints[name] = eval_type_backport(value, globalns, types_namespace, type_params)

return type_hints

Expand Down
19 changes: 17 additions & 2 deletions pydantic/_internal/_validate_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ class ValidateCallWrapper:
'__dict__', # required for __module__
)

def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
def __init__(
self,
function: Callable[..., Any],
config: ConfigDict | None,
validate_return: bool,
namespace: dict[str, Any] | None,
):
if isinstance(function, partial):
func = function.func
schema_type = func
Expand All @@ -36,7 +42,16 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
self.__qualname__ = function.__qualname__
self.__module__ = function.__module__

namespace = _typing_extra.add_module_globals(function, None)
global_ns = _typing_extra.add_module_globals(function, None)
# TODO: this is a bit of a hack, we should probably have a better way to handle this
# specifically, we shouldn't be pumping the namespace full of type_params
# when we take namespace and type_params arguments in eval_type_backport
type_params = getattr(schema_type, '__type_params__', ())
namespace = {
**{param.__name__: param for param in type_params},
**(global_ns or {}),
**(namespace or {}),
}
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
Expand Down
6 changes: 4 additions & 2 deletions pydantic/validate_call_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import functools
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload

from ._internal import _validate_call
from ._internal import _typing_extra, _validate_call

__all__ = ('validate_call',)

Expand Down Expand Up @@ -46,12 +46,14 @@ def validate_call(
Returns:
The decorated function.
"""
local_ns = _typing_extra.parent_frame_namespace()

def validate(function: AnyCallableT) -> AnyCallableT:
if isinstance(function, (classmethod, staticmethod)):
name = type(function).__name__
raise TypeError(f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)')
validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return)

validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return, local_ns)

@functools.wraps(function)
def wrapper_function(*args, **kwargs):
Expand Down
57 changes: 57 additions & 0 deletions tests/test_validate_call.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import inspect
import re
import sys
from datetime import datetime, timezone
from functools import partial
from typing import Any, List, Tuple
Expand Down Expand Up @@ -803,3 +804,59 @@ def foo(bar: 'list[int | str]') -> 'list[int | str]':
'input': {'not a str or int'},
},
]


@pytest.mark.skipif(sys.version_info < (3, 12), reason='requires Python 3.12+ for PEP 695 syntax with generics')
def test_validate_call_with_pep_695_syntax() -> None:
"""Note: validate_call still doesn't work properly with generics, see https://github.com/pydantic/pydantic/issues/7796.
This test is just to ensure that the syntax is accepted and doesn't raise a NameError."""
globs = {}
exec(
"""
from typing import Iterable
from pydantic import validate_call
@validate_call
def find_max_no_validate_return[T](args: Iterable[T]) -> T:
return sorted(args, reverse=True)[0]
@validate_call(validate_return=True)
def find_max_validate_return[T](args: Iterable[T]) -> T:
return sorted(args, reverse=True)[0]
""",
globs,
)
functions = [globs['find_max_no_validate_return'], globs['find_max_validate_return']]
for find_max in functions:
assert len(find_max.__type_params__) == 1
assert find_max([1, 2, 10, 5]) == 10

with pytest.raises(ValidationError):
find_max(1)


class M0(BaseModel):
z: int


M = M0


def test_uses_local_ns():
class M1(BaseModel):
y: int

M = M1 # noqa: F841

def foo():
class M2(BaseModel):
z: int

M = M2

@validate_call
def bar(m: M) -> M:
return m

assert bar({'z': 1}) == M2(z=1)

0 comments on commit fcd2010

Please sign in to comment.