Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a TypeGuard for dataclasses.is_dataclass(); refine asdict(), astuple(), fields(), replace() #9362

Merged
merged 15 commits into from
Jan 28, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions stdlib/dataclasses.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ import sys
import types
from builtins import type as Type # alias to avoid name clashes with fields named "type"
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Generic, Protocol, TypeVar, overload
from typing_extensions import Literal, TypeAlias
from typing import Any, ClassVar, Generic, Protocol, TypeVar, overload
from typing_extensions import Literal, TypeAlias, TypeGuard

if sys.version_info >= (3, 9):
from types import GenericAlias
Expand All @@ -30,6 +30,11 @@ __all__ = [
if sys.version_info >= (3, 10):
__all__ += ["KW_ONLY"]

class _DataclassInstance(Protocol):
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]

_DataclassT = TypeVar("_DataclassT", bound=_DataclassInstance)

# define _MISSING_TYPE as an enum within the type stubs,
# even though that is not really its type at runtime
# this allows us to use Literal[_MISSING_TYPE.MISSING]
Expand All @@ -44,13 +49,13 @@ if sys.version_info >= (3, 10):
class KW_ONLY: ...

@overload
def asdict(obj: Any) -> dict[str, Any]: ...
def asdict(obj: _DataclassInstance) -> dict[str, Any]: ...
@overload
def asdict(obj: Any, *, dict_factory: Callable[[list[tuple[str, Any]]], _T]) -> _T: ...
def asdict(obj: _DataclassInstance, *, dict_factory: Callable[[list[tuple[str, Any]]], _T]) -> _T: ...
@overload
def astuple(obj: Any) -> tuple[Any, ...]: ...
def astuple(obj: _DataclassInstance) -> tuple[Any, ...]: ...
@overload
def astuple(obj: Any, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ...
def astuple(obj: _DataclassInstance, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ...

if sys.version_info >= (3, 8):
# cls argument is now positional-only
Expand Down Expand Up @@ -212,8 +217,13 @@ else:
metadata: Mapping[Any, Any] | None = ...,
) -> Any: ...

def fields(class_or_instance: Any) -> tuple[Field[Any], ...]: ...
def is_dataclass(obj: Any) -> bool: ...
def fields(class_or_instance: _DataclassInstance | type[_DataclassInstance]) -> tuple[Field[Any], ...]: ...
@overload
def is_dataclass(obj: _DataclassInstance | type[_DataclassInstance]) -> bool: ...
AlexWaygood marked this conversation as resolved.
Show resolved Hide resolved
@overload
def is_dataclass(obj: type) -> TypeGuard[type[_DataclassInstance]]: ...
@overload
def is_dataclass(obj: object) -> TypeGuard[_DataclassInstance | type[_DataclassInstance]]: ...

class FrozenInstanceError(AttributeError): ...

Expand Down Expand Up @@ -285,4 +295,4 @@ else:
frozen: bool = ...,
) -> type: ...

def replace(__obj: _T, **changes: Any) -> _T: ...
def replace(__obj: _DataclassT, **changes: Any) -> _DataclassT: ...
67 changes: 67 additions & 0 deletions test_cases/stdlib/check_dataclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import dataclasses as dc
AlexWaygood marked this conversation as resolved.
Show resolved Hide resolved
from typing import Any, Dict, Tuple, Type
from typing_extensions import assert_type


@dc.dataclass
class Foo:
attr: str


assert_type(dc.fields(Foo), Tuple[dc.Field[Any], ...])
# These should fail due to the fact it's a dataclass class, not an instance
dc.asdict(Foo) # type: ignore
AlexWaygood marked this conversation as resolved.
Show resolved Hide resolved
dc.astuple(Foo) # type: ignore
dc.replace(Foo) # type: ignore

if dc.is_dataclass(Foo):
# The inferred type doesn't change
# if it's already known to be a subtype of type[_DataclassInstance]
assert_type(Foo, Type[Foo])

f = Foo(attr="attr")

assert_type(dc.fields(f), Tuple[dc.Field[Any], ...])
assert_type(dc.asdict(f), Dict[str, Any])
assert_type(dc.astuple(f), Tuple[Any, ...])
assert_type(dc.replace(f, attr="new"), Foo)

if dc.is_dataclass(f):
# The inferred type doesn't change
# if it's already known to be a subtype of _DataclassInstance
assert_type(f, Foo)


def test_other_isdataclass_overloads(x: type, y: object) -> None:
dc.fields(x) # TODO: why does this pass mypy? It should fail, ideally...
AlexWaygood marked this conversation as resolved.
Show resolved Hide resolved
dc.fields(y) # type: ignore

dc.asdict(x) # type: ignore
dc.asdict(y) # type: ignore

dc.astuple(x) # type: ignore
dc.astuple(y) # type: ignore

dc.replace(x) # type: ignore
dc.replace(y) # type: ignore

if dc.is_dataclass(x):
assert_type(dc.fields(x), Tuple[dc.Field[Any], ...])
# These should fail due to the fact it's a dataclass class, not an instance
dc.asdict(x) # type: ignore
dc.astuple(x) # type: ignore
dc.replace(x) # type: ignore

if dc.is_dataclass(y):
assert_type(dc.fields(y), Tuple[dc.Field[Any], ...])
# These should fail due to the fact we don't know
# whether it's a dataclass class or a dataclass instance
dc.asdict(y) # type: ignore
dc.astuple(y) # type: ignore
dc.replace(y) # type: ignore

if dc.is_dataclass(y) and not isinstance(y, type):
assert_type(dc.fields(y), Tuple[dc.Field[Any], ...])
assert_type(dc.asdict(y), Dict[str, Any])
assert_type(dc.astuple(y), Tuple[Any, ...])
dc.replace(y)