diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 3e9a0f3bd..4480fb35e 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -30,6 +30,10 @@ from marshmallow.validate import And, Length from marshmallow.warnings import RemovedInMarshmallow4Warning +if typing.TYPE_CHECKING: + from marshmallow.schema import SchemaMeta + + __all__ = [ "Field", "Raw", @@ -535,10 +539,10 @@ class ParentSchema(Schema): def __init__( self, nested: SchemaABC - | type + | SchemaMeta | str - | dict[str, Field | type] - | typing.Callable[[], SchemaABC | type | dict[str, Field | type]], + | dict[str, Field | type[Field]] + | typing.Callable[[], SchemaABC | SchemaMeta | dict[str, Field | type[Field]]], *, dump_default: typing.Any = missing_, default: typing.Any = missing_, @@ -698,7 +702,7 @@ class AlbumSchema(Schema): def __init__( self, - nested: SchemaABC | type | str | typing.Callable[[], SchemaABC], + nested: SchemaABC | SchemaMeta | str | typing.Callable[[], SchemaABC], field_name: str, **kwargs, ): @@ -749,7 +753,7 @@ class List(Field): #: Default error messages. default_error_messages = {"invalid": "Not a valid list."} - def __init__(self, cls_or_instance: Field | type, **kwargs): + def __init__(self, cls_or_instance: Field | type[Field], **kwargs): super().__init__(**kwargs) try: self.inner = resolve_field_instance(cls_or_instance) @@ -1553,8 +1557,8 @@ class Mapping(Field): def __init__( self, - keys: Field | type | None = None, - values: Field | type | None = None, + keys: Field | type[Field] | None = None, + values: Field | type[Field] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -1876,7 +1880,7 @@ def __init__( self, enum: type[EnumType], *, - by_value: bool | Field | type = False, + by_value: bool | Field | type[Field] = False, **kwargs, ): super().__init__(**kwargs) diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 6f15a0f3c..0439c0ee0 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -56,7 +56,7 @@ def _get_fields(attrs): # This function allows Schemas to inherit from non-Schema classes and ensures # inheritance according to the MRO -def _get_fields_by_mro(klass): +def _get_fields_by_mro(klass: SchemaMeta): """Collect fields from a class, following its method resolution order. The class itself is excluded from the search; only its parents are checked. Get fields from ``_declared_fields`` if available, else use ``__dict__``. @@ -124,7 +124,7 @@ def __new__(mcs, name, bases, attrs): @classmethod def get_declared_fields( mcs, - klass: type, + klass: SchemaMeta, cls_fields: list, inherited_fields: list, dict_cls: type = dict, @@ -417,7 +417,7 @@ def dict_class(self) -> type: @classmethod def from_dict( cls, - fields: dict[str, ma_fields.Field | type], + fields: dict[str, ma_fields.Field | type[ma_fields.Field]], *, name: str = "GeneratedSchema", ) -> type: