diff --git a/stdlib/pickle.pyi b/stdlib/pickle.pyi index 19564f31178e..da6c5d11794a 100644 --- a/stdlib/pickle.pyi +++ b/stdlib/pickle.pyi @@ -1,11 +1,18 @@ import sys -from typing import IO, Any, Callable, ClassVar, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union +from typing import Any, Callable, ClassVar, Iterable, Iterator, Mapping, Optional, Protocol, Tuple, Type, Union HIGHEST_PROTOCOL: int DEFAULT_PROTOCOL: int bytes_types: Tuple[Type[Any], ...] # undocumented +class _ReadableFileobj(Protocol): + def read(self, __n: int) -> bytes: ... + def readline(self) -> bytes: ... + +class _WritableFileobj(Protocol): + def write(self, __b: bytes) -> Any: ... + if sys.version_info >= (3, 8): # TODO: holistic design for buffer interface (typing.Buffer?) class PickleBuffer: @@ -15,22 +22,32 @@ if sys.version_info >= (3, 8): def release(self) -> None: ... _BufferCallback = Optional[Callable[[PickleBuffer], Any]] def dump( - obj: Any, file: IO[bytes], protocol: int | None = ..., *, fix_imports: bool = ..., buffer_callback: _BufferCallback = ... + obj: Any, + file: _WritableFileobj, + protocol: int | None = ..., + *, + fix_imports: bool = ..., + buffer_callback: _BufferCallback = ..., ) -> None: ... def dumps( obj: Any, protocol: int | None = ..., *, fix_imports: bool = ..., buffer_callback: _BufferCallback = ... ) -> bytes: ... def load( - file: IO[bytes], *, fix_imports: bool = ..., encoding: str = ..., errors: str = ..., buffers: Iterable[Any] | None = ... + file: _ReadableFileobj, + *, + fix_imports: bool = ..., + encoding: str = ..., + errors: str = ..., + buffers: Iterable[Any] | None = ..., ) -> Any: ... def loads( __data: bytes, *, fix_imports: bool = ..., encoding: str = ..., errors: str = ..., buffers: Iterable[Any] | None = ... ) -> Any: ... else: - def dump(obj: Any, file: IO[bytes], protocol: int | None = ..., *, fix_imports: bool = ...) -> None: ... + def dump(obj: Any, file: _WritableFileobj, protocol: int | None = ..., *, fix_imports: bool = ...) -> None: ... def dumps(obj: Any, protocol: int | None = ..., *, fix_imports: bool = ...) -> bytes: ... - def load(file: IO[bytes], *, fix_imports: bool = ..., encoding: str = ..., errors: str = ...) -> Any: ... + def load(file: _ReadableFileobj, *, fix_imports: bool = ..., encoding: str = ..., errors: str = ...) -> Any: ... def loads(data: bytes, *, fix_imports: bool = ..., encoding: str = ..., errors: str = ...) -> Any: ... class PickleError(Exception): ... @@ -53,11 +70,16 @@ class Pickler: if sys.version_info >= (3, 8): def __init__( - self, file: IO[bytes], protocol: int | None = ..., *, fix_imports: bool = ..., buffer_callback: _BufferCallback = ... + self, + file: _WritableFileobj, + protocol: int | None = ..., + *, + fix_imports: bool = ..., + buffer_callback: _BufferCallback = ..., ) -> None: ... def reducer_override(self, obj: Any) -> Any: ... else: - def __init__(self, file: IO[bytes], protocol: int | None = ..., *, fix_imports: bool = ...) -> None: ... + def __init__(self, file: _WritableFileobj, protocol: int | None = ..., *, fix_imports: bool = ...) -> None: ... def dump(self, __obj: Any) -> None: ... def clear_memo(self) -> None: ... def persistent_id(self, obj: Any) -> Any: ... @@ -68,7 +90,7 @@ class Unpickler: if sys.version_info >= (3, 8): def __init__( self, - file: IO[bytes], + file: _ReadableFileobj, *, fix_imports: bool = ..., encoding: str = ..., @@ -76,7 +98,9 @@ class Unpickler: buffers: Iterable[Any] | None = ..., ) -> None: ... else: - def __init__(self, file: IO[bytes], *, fix_imports: bool = ..., encoding: str = ..., errors: str = ...) -> None: ... + def __init__( + self, file: _ReadableFileobj, *, fix_imports: bool = ..., encoding: str = ..., errors: str = ... + ) -> None: ... def load(self) -> Any: ... def find_class(self, __module_name: str, __global_name: str) -> Any: ... def persistent_load(self, pid: Any) -> Any: ...