diff --git a/changes/2465-daviskirk.md b/changes/2465-daviskirk.md new file mode 100644 index 0000000000..da317c5820 --- /dev/null +++ b/changes/2465-daviskirk.md @@ -0,0 +1 @@ +Support user defined generic field types in generic models. diff --git a/pydantic/generics.py b/pydantic/generics.py index ad224a477a..d337883a31 100644 --- a/pydantic/generics.py +++ b/pydantic/generics.py @@ -167,7 +167,12 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any: # If all arguments are the same, there is no need to modify the # type or create a new object at all return type_ - if origin_type is not None and isinstance(type_, typing_base) and not isinstance(origin_type, typing_base): + if ( + origin_type is not None + and isinstance(type_, typing_base) + and not isinstance(origin_type, typing_base) + and getattr(type_, '_name', None) is not None + ): # In python < 3.9 generic aliases don't exist so any of these like `list`, # `type` or `collections.abc.Callable` need to be translated. # See: https://www.python.org/dev/peps/pep-0585 diff --git a/tests/test_generics.py b/tests/test_generics.py index d1e42d8666..9824293722 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,6 +1,20 @@ import sys from enum import Enum -from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Generic, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) import pytest from typing_extensions import Literal @@ -808,6 +822,30 @@ class Model(GenericModel, Generic[T]): assert replace_types(list[Union[str, list, T]], {T: int}) == list[Union[str, list, int]] +@skip_36 +def test_replace_types_with_user_defined_generic_type_field(): + """Test that using user defined generic types as generic model fields are handled correctly.""" + + T = TypeVar('T') + KT = TypeVar('KT') + VT = TypeVar('VT') + + class GenericMapping(Mapping[KT, VT]): + pass + + class GenericList(List[T]): + pass + + class Model(GenericModel, Generic[T, KT, VT]): + + map_field: GenericMapping[KT, VT] + list_field: GenericList[T] + + assert replace_types(Model, {T: bool, KT: str, VT: int}) == Model[bool, str, int] + assert replace_types(Model[T, KT, VT], {T: bool, KT: str, VT: int}) == Model[bool, str, int] + assert replace_types(Model[T, VT, KT], {T: bool, KT: str, VT: int}) == Model[T, VT, KT][bool, int, str] + + @skip_36 def test_replace_types_identity_on_unchanged(): T = TypeVar('T') @@ -1071,3 +1109,21 @@ class GModel(GenericModel, Generic[FieldType, ValueType]): Fields = Literal['foo', 'bar'] m = GModel[Fields, str](field={'foo': 'x'}) assert m.dict() == {'field': {'foo': 'x'}} + + +@skip_36 +def test_generic_with_user_defined_generic_field(): + T = TypeVar('T') + + class GenericList(List[T]): + pass + + class Model(GenericModel, Generic[T]): + + field: GenericList[T] + + model = Model[int](field=[5]) + assert model.field[0] == 5 + + with pytest.raises(ValidationError): + model = Model[int](field=['a'])