From 1605eb7e58321544d2c2dea030b2c78779a918da Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 19 Sep 2022 10:42:40 -0500 Subject: [PATCH] Work around mypy limitations (#191) Currently mypy doesn't interpret things other than classes as being compatible with `Type[T]` (they literally have to be instances of `type`). This makes it hard to properly type the `decode` methods or `Decoder` constructors, as passing type-like-things (e.g. `Union[int, str]`) will error under mypy. `pyright` doesn't have this limitation. For now we add a fallback `overload` to these methods that infers the decoder type as `Decoder[Any]`. This means that `mypy` will no longer error, but will require an explicit annotation to infer the type of the output of `decode`. If/once the `TypeForm` PEP lands, this hack can be removed. Until then `pyright` support for these annotations will be much better than `mypy` support. --- msgspec/json.pyi | 15 +++++++++++++++ msgspec/msgpack.pyi | 16 ++++++++++++++++ tests/basic_typing_examples.py | 28 +++++++++++++++++++++++++++- 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/msgspec/json.pyi b/msgspec/json.pyi index 48b49e70..2aadb4f7 100644 --- a/msgspec/json.pyi +++ b/msgspec/json.pyi @@ -33,6 +33,7 @@ class Encoder: class Decoder(Generic[T]): type: Type[T] dec_hook: dec_hook_sig + @overload def __init__( self: Decoder[Any], @@ -46,6 +47,13 @@ class Decoder(Generic[T]): *, dec_hook: dec_hook_sig = None, ) -> None: ... + @overload + def __init__( + self: Decoder[Any], + type: Any = ..., + *, + dec_hook: dec_hook_sig = None, + ) -> None: ... def decode(self, data: bytes) -> T: ... @overload @@ -61,6 +69,13 @@ def decode( type: Type[T] = ..., dec_hook: dec_hook_sig = None, ) -> T: ... +@overload +def decode( + buf: bytes, + *, + type: Any = ..., + dec_hook: dec_hook_sig = None, +) -> Any: ... def encode(obj: Any, *, enc_hook: enc_hook_sig = None) -> bytes: ... def schema(type: Any) -> Dict[str, Any]: ... def schema_components( diff --git a/msgspec/msgpack.pyi b/msgspec/msgpack.pyi index b751618c..53c63e0e 100644 --- a/msgspec/msgpack.pyi +++ b/msgspec/msgpack.pyi @@ -32,6 +32,14 @@ class Decoder(Generic[T]): dec_hook: dec_hook_sig = None, ext_hook: ext_hook_sig = None, ) -> None: ... + @overload + def __init__( + self: Decoder[Any], + type: Any = ..., + *, + dec_hook: dec_hook_sig = None, + ext_hook: ext_hook_sig = None, + ) -> None: ... def decode(self, data: bytes) -> T: ... class Encoder: @@ -63,4 +71,12 @@ def decode( dec_hook: dec_hook_sig = None, ext_hook: ext_hook_sig = None, ) -> T: ... +@overload +def decode( + buf: bytes, + *, + type: Any = ..., + dec_hook: dec_hook_sig = None, + ext_hook: ext_hook_sig = None, +) -> Any: ... def encode(obj: Any, *, enc_hook: enc_hook_sig = None) -> bytes: ... diff --git a/tests/basic_typing_examples.py b/tests/basic_typing_examples.py index 0be5584b..e9b9e6c4 100644 --- a/tests/basic_typing_examples.py +++ b/tests/basic_typing_examples.py @@ -1,7 +1,9 @@ # fmt: off +from __future__ import annotations + import datetime import pickle -from typing import List, Any, Type +from typing import List, Any, Type, Union import msgspec def check___version__() -> None: @@ -391,6 +393,14 @@ def check_msgpack_Decoder_decode_typed() -> None: reveal_type(o) # assert ("List" in typ or "list" in typ) and "int" in typ +def check_msgpack_Decoder_decode_union() -> None: + # Pyright doesn't require the annotation, but mypy does until TypeForm + # is supported. This is mostly checking that no error happens here. + dec: msgspec.msgpack.Decoder[Union[int, str]] = msgspec.msgpack.Decoder(Union[int, str]) + o = dec.decode(b'') + reveal_type(o) # assert ("int" in typ and "str" in typ) + + def check_msgpack_Decoder_decode_type_comment() -> None: dec = msgspec.msgpack.Decoder() # type: msgspec.msgpack.Decoder[List[int]] b = msgspec.msgpack.encode([1, 2, 3]) @@ -414,6 +424,11 @@ def check_msgpack_decode_typed() -> None: reveal_type(o) # assert ("List" in typ or "list" in typ) and "int" in typ +def check_msgpack_decode_typed_union() -> None: + o: Union[int, str] = msgspec.msgpack.decode(b"", type=Union[int, str]) + reveal_type(o) # assert "int" in typ and "str" in typ + + def check_msgpack_encode_enc_hook() -> None: msgspec.msgpack.encode(object(), enc_hook=lambda x: None) @@ -495,6 +510,12 @@ def check_json_Decoder_decode_type_comment() -> None: reveal_type(o) # assert ("List" in typ or "list" in typ) and "int" in typ +def check_json_Decoder_decode_union() -> None: + dec: msgspec.json.Decoder[Union[int, str]] = msgspec.json.Decoder(Union[int, str]) + o = dec.decode(b'') + reveal_type(o) # assert ("int" in typ and "str" in typ) + + def check_json_decode_any() -> None: b = msgspec.json.encode([1, 2, 3]) o = msgspec.json.decode(b) @@ -509,6 +530,11 @@ def check_json_decode_typed() -> None: reveal_type(o) # assert ("List" in typ or "list" in typ) and "int" in typ +def check_json_decode_typed_union() -> None: + o: Union[int, str] = msgspec.json.decode(b"", type=Union[int, str]) + reveal_type(o) # assert "int" in typ and "str" in typ + + def check_json_encode_enc_hook() -> None: msgspec.json.encode(object(), enc_hook=lambda x: None)