Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve type hints for str vs bytes
Browse files Browse the repository at this point in the history
aronbierbaum committed Sep 8, 2022

Verified

This commit was signed with the committer’s verified signature.
aronbierbaum Aron Bierbaum
1 parent 0a9b609 commit 80e9ad4
Showing 3 changed files with 125 additions and 104 deletions.
2 changes: 0 additions & 2 deletions lxml-stubs/cssselect.pyi
Original file line number Diff line number Diff line change
@@ -5,8 +5,6 @@ from lxml import etree
# dummy for missing stubs
def __getattr__(name) -> Any: ...

_DictAnyStr = Union[Dict[str, str], Dict[bytes, bytes]]

class CSSSelector(etree.XPath):
def __init__(
self,
219 changes: 121 additions & 98 deletions lxml-stubs/etree.pyi
Original file line number Diff line number Diff line change
@@ -3,9 +3,11 @@
# Any use of `Any` below means I couldn't figure out the type.

from os import PathLike
import sys
from typing import (
IO,
Any,
AnyStr,
Callable,
Dict,
Iterable,
@@ -32,7 +34,11 @@ def __getattr__(name: str) -> Any: ...
# unnecessary constraint. It seems reasonable to constrain each
# List/Dict argument to use one type consistently, though, and it is
# necessary in order to keep these brief.
_AnyStr = Union[str, bytes]
if sys.version_info[0] >= 3:
_StrResult = str
else:
_StrResult = Union[str, bytes]

_AnySmartStr = Union[
"_ElementUnicodeResult", "_PyElementUnicodeResult", "_ElementStringResult"
]
@@ -44,21 +50,27 @@ _XPathObject = Union[
bool,
float,
_AnySmartStr,
_AnyStr,
_StrResult,
List[
Union[
"_Element",
_AnySmartStr,
_AnyStr,
Tuple[Optional[_AnyStr], Optional[_AnyStr]],
_StrResult,
Tuple[Optional[_StrResult], Optional[_StrResult]],
]
],
]
_AnyParser = Union["XMLParser", "HTMLParser"]
_ListAnyStr = Union[List[str], List[bytes]]
_DictAnyStr = Union[Dict[str, str], Dict[bytes, bytes]]
_Dict_Tuple2AnyStr_Any = Union[Dict[Tuple[str, str], Any], Tuple[bytes, bytes], Any]
_xpath = Union["XPath", _AnyStr]
if sys.version_info[0] >= 3:
_ListAnyStr = List[str]
_DictAnyStr = Dict[str, str]
else:
_ListAnyStr = Union[List[str], List[bytes]]
_DictAnyStr = Union[Dict[str, str], Dict[bytes, bytes]]
_StrOrBytes = Union[str, bytes]
_ValueType = Union[str, bytes, QName]
_InputDictAnyStr = [Dict[str, str], Dict[bytes, bytes]]
_ExtensionsDict = Dict[Tuple[_StrOrBytes, _StrOrBytes], Any]

# See https://github.com/python/typing/pull/273
# Due to Mapping having invariant key types, Mapping[Union[A, B], ...]
@@ -81,7 +93,7 @@ _KnownEncodings = Literal[
"us-ascii",
]
_ElementOrTree = Union[_Element, _ElementTree]
_FileSource = Union[_AnyStr, IO[Any], PathLike[Any]]
_FileSource = Union[_StrOrBytes, IO[AnyStr], PathLike[str]]

class ElementChildIterator(Iterator["_Element"]):
def __iter__(self) -> "ElementChildIterator": ...
@@ -91,21 +103,21 @@ class _ElementUnicodeResult(str):
is_attribute: bool
is_tail: bool
is_text: bool
attrname: Optional[_AnyStr]
attrname: Optional[_StrResult]
def getparent(self) -> Optional["_Element"]: ...

class _PyElementUnicodeResult(str):
is_attribute: bool
is_tail: bool
is_text: bool
attrname: Optional[_AnyStr]
attrname: Optional[_StrResult]
def getparent(self) -> Optional["_Element"]: ...

class _ElementStringResult(bytes):
is_attribute: bool
is_tail: bool
is_text: bool
attrname: Optional[_AnyStr]
attrname: Optional[_StrResult]
def getparent(self) -> Optional["_Element"]: ...

class DocInfo:
@@ -152,9 +164,9 @@ class _Element(Iterable["_Element"], Sized):
) -> List["_Element"]: ...
def clear(self) -> None: ...
@overload
def get(self, key: _TagName) -> Optional[str]: ...
def get(self, key: _TagName) -> Optional[_StrResult]: ...
@overload
def get(self, key: _TagName, default: _T) -> Union[str, _T]: ...
def get(self, key: _TagName, default: _T) -> Union[_StrResult, _T]: ...
def getnext(self) -> Optional[_Element]: ...
def getparent(self) -> Optional[_Element]: ...
def getprevious(self) -> Optional[_Element]: ...
@@ -163,7 +175,7 @@ class _Element(Iterable["_Element"], Sized):
self, child: _Element, start: Optional[int] = ..., stop: Optional[int] = ...
) -> int: ...
def insert(self, index: int, element: _Element) -> None: ...
def items(self) -> Sequence[Tuple[_AnyStr, _AnyStr]]: ...
def items(self) -> Sequence[Tuple[_StrResult, _StrResult]]: ...
def iter(
self, tag: Optional[_TagSelector] = ..., *tags: _TagSelector
) -> Iterator[_Element]: ...
@@ -189,35 +201,35 @@ class _Element(Iterable["_Element"], Sized):
tag: Optional[_TagSelector] = ...,
with_tail: bool = False,
*tags: _TagSelector,
) -> Iterator[_AnyStr]: ...
def keys(self) -> Sequence[_AnyStr]: ...
) -> Iterator[_StrResult]: ...
def keys(self) -> Sequence[_StrResult]: ...
def makeelement(
self,
_tag: _TagName,
attrib: Optional[_DictAnyStr] = ...,
attrib: Optional[_InputDictAnyStr] = ...,
nsmap: Optional[_NSMapArg] = ...,
**_extra: Any,
) -> _Element: ...
def remove(self, element: _Element) -> None: ...
def replace(self, old_element: _Element, new_element: _Element) -> None: ...
def set(self, key: _TagName, value: _AnyStr) -> None: ...
def values(self) -> Sequence[_AnyStr]: ...
def set(self, key: _TagName, value: _ValueType) -> None: ...
def values(self) -> Sequence[_StrResult]: ...
def xpath(
self,
_path: _AnyStr,
_path: _StrOrBytes,
namespaces: Optional[_NonDefaultNSMapArg] = ...,
extensions: Any = ...,
smart_strings: bool = ...,
**_variables: _XPathObject,
) -> _XPathObject: ...
tag = ... # type: str
tag = ... # type: _StrResult
attrib = ... # type: _Attrib
text = ... # type: Optional[str]
tail = ... # type: Optional[str]
prefix = ... # type: str
text = ... # type: Optional[_StrResult]
tail = ... # type: Optional[_StrResult]
prefix = ... # type: Optional[_StrResult]
sourceline = ... # Optional[int]
@property
def nsmap(self) -> Dict[Optional[str], str]: ...
def nsmap(self) -> Dict[Optional[_StrResult], Optional[_StrResult]]: ...
base = ... # type: Optional[str]

class ElementBase(_Element): ...
@@ -250,34 +262,36 @@ class _ElementTree:
self,
source: _FileSource,
parser: Optional[_AnyParser] = ...,
base_url: Optional[_AnyStr] = ...,
base_url: Optional[_StrOrBytes] = ...,
) -> _Element: ...
def write(
self,
file: _FileSource,
encoding: _AnyStr = ...,
method: _AnyStr = ...,
encoding: Optional[_StrOrBytes] = ...,
method: _StrOrBytes = ...,
pretty_print: bool = ...,
xml_declaration: Any = ...,
with_tail: Any = ...,
standalone: bool = ...,
doctype: _StrOrBytes = ...,
compression: int = ...,
exclusive: bool = ...,
inclusive_ns_prefixes: Iterable[_StrOrBytes] = ...,
with_comments: bool = ...,
inclusive_ns_prefixes: _ListAnyStr = ...,
strip_text: bool = ...,
) -> None: ...
def write_c14n(
self,
file: _FileSource,
with_comments: bool = ...,
compression: int = ...,
inclusive_ns_prefixes: Iterable[_AnyStr] = ...,
inclusive_ns_prefixes: Iterable[_StrOrBytes] = ...,
) -> None: ...
def _setroot(self, root: _Element) -> None: ...
def xinclude(self) -> None: ...
def xpath(
self,
_path: _AnyStr,
_path: _StrOrBytes,
namespaces: Optional[_NonDefaultNSMapArg] = ...,
extensions: Any = ...,
smart_strings: bool = ...,
@@ -286,7 +300,7 @@ class _ElementTree:
def xslt(
self,
_xslt: XSLT,
extensions: Optional[_Dict_Tuple2AnyStr_Any] = ...,
extensions: Optional[_ExtensionsDict] = ...,
access_control: Optional[XSLTAccessControl] = ...,
**_variables: Any,
) -> _ElementTree: ...
@@ -295,35 +309,41 @@ class __ContentOnlyEleement(_Element): ...
class _Comment(__ContentOnlyEleement): ...

class _ProcessingInstruction(__ContentOnlyEleement):
target: _AnyStr
target: _StrResult

class _Attrib:
def __setitem__(self, key: _AnyStr, value: _AnyStr) -> None: ...
def __delitem__(self, key: _AnyStr) -> None: ...
def __setitem__(self, key: _TagName, value: _ValueType) -> None: ...
def __delitem__(self, key: _TagName) -> None: ...
def update(
self,
sequence_or_dict: Union[
_Attrib, Mapping[_AnyStr, _AnyStr], Sequence[Tuple[_AnyStr, _AnyStr]]
_Attrib, Mapping[_TagName, _ValueType], Sequence[Tuple[_TagName, _ValueType]]
],
) -> None: ...
def pop(self, key: _AnyStr, default: _AnyStr) -> _AnyStr: ...
@overload
def pop(self, key: _TagName) -> _StrResult: ...
@overload
def pop(self, key: _TagName, default: _T) -> Union[_StrResult, _T]: ...
def clear(self) -> None: ...
def __repr__(self) -> str: ...
def __copy__(self) -> _DictAnyStr: ...
def __deepcopy__(self, memo: Dict[Any, Any]) -> _DictAnyStr: ...
def __getitem__(self, key: _AnyStr) -> _AnyStr: ...
def __getitem__(self, key: _TagName) -> _StrResult: ...
def __bool__(self) -> bool: ...
def __len__(self) -> int: ...
def get(self, key: _AnyStr, default: _AnyStr = ...) -> Optional[_AnyStr]: ...
@overload
def get(self, key: _TagName) -> Optional[_StrResult]: ...
@overload
def get(self, key: _TagName, default: _T) -> Union[_StrResult, _T]: ...
def keys(self) -> _ListAnyStr: ...
def __iter__(self) -> Iterator[_AnyStr]: ... # actually _AttribIterator
def iterkeys(self) -> Iterator[_AnyStr]: ...
def __iter__(self) -> Iterator[_StrResult]: ... # actually _AttribIterator
def iterkeys(self) -> Iterator[_StrResult]: ...
def values(self) -> _ListAnyStr: ...
def itervalues(self) -> Iterator[_AnyStr]: ...
def items(self) -> List[Tuple[_AnyStr, _AnyStr]]: ...
def iteritems(self) -> Iterator[Tuple[_AnyStr, _AnyStr]]: ...
def has_key(self, key: _AnyStr) -> bool: ...
def __contains__(self, key: _AnyStr) -> bool: ...
def itervalues(self) -> Iterator[_StrResult]: ...
def items(self) -> List[Tuple[_StrResult, _StrResult]]: ...
def iteritems(self) -> Iterator[Tuple[_StrResult, _StrResult]]: ...
def has_key(self, key: _TagName) -> bool: ...
def __contains__(self, key: _TagName) -> bool: ...
def __richcmp__(self, other: _Attrib, op: int) -> bool: ...

class QName:
@@ -332,8 +352,8 @@ class QName:
text = ... # type: str
def __init__(
self,
text_or_uri_element: Union[None, _AnyStr, _Element],
tag: Optional[_AnyStr] = ...,
text_or_uri_element: Union[None, _TagName, _Element],
tag: Optional[_TagName] = ...,
) -> None: ...

class _XSLTResultTree(_ElementTree, SupportsBytes):
@@ -343,11 +363,11 @@ class _XSLTQuotedStringParam: ...

# https://lxml.de/parsing.html#the-target-parser-interface
class ParserTarget(Protocol):
def comment(self, text: _AnyStr) -> None: ...
def comment(self, text: _StrResult) -> None: ...
def close(self) -> Any: ...
def data(self, data: _AnyStr) -> None: ...
def end(self, tag: _AnyStr) -> None: ...
def start(self, tag: _AnyStr, attrib: Dict[_AnyStr, _AnyStr]) -> None: ...
def data(self, data: _StrResult) -> None: ...
def end(self, tag: _StrResult) -> None: ...
def start(self, tag: _StrResult, attrib: Dict[_StrResult, _StrResult]) -> None: ...

class ElementClassLookup: ...

@@ -367,7 +387,7 @@ class _BaseParser:
def makeelement(
self,
_tag: _TagName,
attrib: Optional[Union[_DictAnyStr, _Attrib]] = ...,
attrib: Optional[Union[_InputDictAnyStr, _Attrib]] = ...,
nsmap: Optional[_NSMapArg] = ...,
**_extra: Any,
) -> _Element: ...
@@ -381,12 +401,12 @@ class _BaseParser:
class _FeedParser(_BaseParser):
def __getattr__(self, name: str) -> Any: ... # Incomplete
def close(self) -> _Element: ...
def feed(self, data: _AnyStr) -> None: ...
def feed(self, data: _StrOrBytes) -> None: ...

class XMLParser(_FeedParser):
def __init__(
self,
encoding: Optional[_AnyStr] = ...,
encoding: Optional[_StrOrBytes] = ...,
attribute_defaults: bool = ...,
dtd_validation: bool = ...,
load_dtd: bool = ...,
@@ -409,7 +429,7 @@ class XMLParser(_FeedParser):
class HTMLParser(_FeedParser):
def __init__(
self,
encoding: Optional[_AnyStr] = ...,
encoding: Optional[_StrOrBytes] = ...,
collect_ids: bool = ...,
compact: bool = ...,
huge_tree: bool = ...,
@@ -430,10 +450,10 @@ class _ResolverRegistry:
class Resolver:
def resolve(self, system_url: str, public_id: str): ...
def resolve_file(
self, f: IO[Any], context: Any, *, base_url: Optional[_AnyStr], close: bool
self, f: IO[Any], context: Any, *, base_url: Optional[_StrOrBytes], close: bool
): ...
def resolve_string(
self, string: _AnyStr, context: Any, *, base_url: Optional[_AnyStr]
self, string: _StrOrBytes, context: Any, *, base_url: Optional[_StrOrBytes]
): ...

class XMLSchema(_Validator):
@@ -450,121 +470,124 @@ class XSLT:
def __init__(
self,
xslt_input: _ElementOrTree,
extensions: _Dict_Tuple2AnyStr_Any = ...,
extensions: Optional[_ExtensionsDict] = ...,
regexp: bool = ...,
access_control: XSLTAccessControl = ...,
access_control: Optional[XSLTAccessControl] = ...,
) -> None: ...
def __call__(
self,
_input: _ElementOrTree,
profile_run: bool = ...,
**kwargs: Union[_AnyStr, _XSLTQuotedStringParam],
**kwargs: Union[_StrOrBytes, XPath, _XSLTQuotedStringParam],
) -> _XSLTResultTree: ...
@staticmethod
def strparam(s: _AnyStr) -> _XSLTQuotedStringParam: ...
def strparam(s: _StrOrBytes) -> _XSLTQuotedStringParam: ...

def Comment(text: Optional[_AnyStr] = ...) -> _Comment: ...
def Comment(text: Optional[_StrOrBytes] = ...) -> _Comment: ...
def Element(
_tag: _TagName,
attrib: Optional[_DictAnyStr] = ...,
attrib: Optional[_InputDictAnyStr] = ...,
nsmap: Optional[_NSMapArg] = ...,
**extra: _AnyStr,
**extra: _StrOrBytes,
) -> _Element: ...
def SubElement(
_parent: _Element,
_tag: _TagName,
attrib: Optional[_DictAnyStr] = ...,
attrib: Optional[_InputDictAnyStr] = ...,
nsmap: Optional[_NSMapArg] = ...,
**extra: _AnyStr,
**extra: _StrOrBytes,
) -> _Element: ...
def ElementTree(
element: _Element = ...,
file: _FileSource = ...,
parser: _AnyParser = ...,
) -> _ElementTree: ...
def ProcessingInstruction(
target: _AnyStr, text: _AnyStr = ...
target: _StrOrBytes, text: Optional[_StrOrBytes] = ...
) -> _ProcessingInstruction: ...

PI = ProcessingInstruction

def HTML(
text: _AnyStr,
text: _StrOrBytes,
parser: Optional[HTMLParser] = ...,
base_url: Optional[_AnyStr] = ...,
base_url: Optional[_StrOrBytes] = ...,
) -> _Element: ...
def XML(
text: _AnyStr,
text: _StrOrBytes,
parser: Optional[XMLParser] = ...,
base_url: Optional[_AnyStr] = ...,
base_url: Optional[_StrOrBytes] = ...,
) -> _Element: ...
def cleanup_namespaces(
tree_or_element: _ElementOrTree,
top_nsmap: Optional[_NSMapArg] = ...,
keep_ns_prefixes: Optional[Iterable[_AnyStr]] = ...,
keep_ns_prefixes: Optional[Iterable[_StrOrBytes]] = ...,
) -> None: ...
def parse(
source: _FileSource,
parser: _AnyParser = ...,
base_url: _AnyStr = ...,
base_url: _StrOrBytes = ...,
) -> Union[_ElementTree, Any]: ...
@overload
def fromstring(
text: _AnyStr,
text: _StrOrBytes,
parser: None = ...,
*,
base_url: _AnyStr = ...,
base_url: _StrOrBytes = ...,
) -> _Element: ...
@overload
def fromstring(
text: _AnyStr,
text: _StrOrBytes,
parser: _AnyParser = ...,
*,
base_url: _AnyStr = ...,
base_url: _StrOrBytes = ...,
) -> Union[_Element, Any]: ...
@overload
def tostring(
element_or_tree: _ElementOrTree,
encoding: Union[Type[str], Literal["unicode"]],
method: str = ...,
method: _StrOrBytes = ...,
xml_declaration: bool = ...,
pretty_print: bool = ...,
with_tail: bool = ...,
standalone: bool = ...,
doctype: str = ...,
exclusive: bool = ...,
with_comments: bool = ...,
inclusive_ns_prefixes: Any = ...,
with_comments: bool = ...,
skip_text: bool = ...,
) -> str: ...
@overload
def tostring(
element_or_tree: _ElementOrTree,
# Should be anything but "unicode", cannot be typed
encoding: Optional[_KnownEncodings] = None,
method: str = ...,
method: _StrOrBytes = ...,
xml_declaration: bool = ...,
pretty_print: bool = ...,
with_tail: bool = ...,
standalone: bool = ...,
doctype: str = ...,
exclusive: bool = ...,
with_comments: bool = ...,
inclusive_ns_prefixes: Any = ...,
with_comments: bool = ...,
skip_text: bool = ...,
) -> bytes: ...
@overload
def tostring(
element_or_tree: _ElementOrTree,
encoding: Union[str, type] = ...,
method: str = ...,
method: _StrOrBytes = ...,
xml_declaration: bool = ...,
pretty_print: bool = ...,
with_tail: bool = ...,
standalone: bool = ...,
doctype: str = ...,
exclusive: bool = ...,
with_comments: bool = ...,
inclusive_ns_prefixes: Any = ...,
) -> _AnyStr: ...
with_comments: bool = ...,
skip_text: bool = ...,
) -> _StrOrBytes: ...

class _ErrorLog: ...
class Error(Exception): ...
@@ -596,7 +619,7 @@ class _XPathEvaluatorBase: ...
class XPath(_XPathEvaluatorBase):
def __init__(
self,
path: _AnyStr,
path: _StrOrBytes,
*,
namespaces: Optional[_NonDefaultNSMapArg] = ...,
extensions: Any = ...,
@@ -611,7 +634,7 @@ class XPath(_XPathEvaluatorBase):
class ETXPath(XPath):
def __init__(
self,
path: _AnyStr,
path: _StrOrBytes,
*,
extensions: Any = ...,
regexp: bool = ...,
@@ -628,8 +651,8 @@ class XPathElementEvaluator(_XPathEvaluatorBase):
regexp: bool = ...,
smart_strings: bool = ...,
) -> None: ...
def __call__(self, _path: _AnyStr, **_variables: _XPathObject) -> _XPathObject: ...
def register_namespace(self, prefix: _AnyStr, uri: _AnyStr) -> None: ...
def __call__(self, _path: _StrOrBytes, **_variables: _XPathObject) -> _XPathObject: ...
def register_namespace(self, prefix: _StrOrBytes, uri: _StrOrBytes) -> None: ...
def register_namespaces(
self, namespaces: Optional[_NonDefaultNSMapArg]
) -> None: ...
@@ -670,9 +693,9 @@ def XPathEvaluator(
smart_strings: bool = ...,
) -> Union[XPathElementEvaluator, XPathDocumentEvaluator]: ...

_ElementFactory = Callable[[Any, Dict[_AnyStr, _AnyStr]], _Element]
_CommentFactory = Callable[[_AnyStr], _Comment]
_ProcessingInstructionFactory = Callable[[_AnyStr, _AnyStr], _ProcessingInstruction]
_ElementFactory = Callable[[_StrOrBytes, Dict[_StrOrBytes, _StrOrBytes]], _Element]
_CommentFactory = Callable[[_StrOrBytes], _Comment]
_ProcessingInstructionFactory = Callable[[_StrOrBytes, _StrOrBytes], _ProcessingInstruction]

class TreeBuilder:
def __init__(
@@ -685,10 +708,10 @@ class TreeBuilder:
insert_pis: bool = ...,
) -> None: ...
def close(self) -> _Element: ...
def comment(self, text: _AnyStr) -> None: ...
def data(self, data: _AnyStr) -> None: ...
def end(self, tag: _AnyStr) -> None: ...
def pi(self, target: _AnyStr, data: Optional[_AnyStr] = ...) -> Any: ...
def start(self, tag: _AnyStr, attrib: Dict[_AnyStr, _AnyStr]) -> None: ...
def comment(self, text: _StrOrBytes) -> None: ...
def data(self, data: _StrOrBytes) -> None: ...
def end(self, tag: _TagName) -> None: ...
def pi(self, target: _StrOrBytes, data: Optional[_StrOrBytes] = ...) -> Any: ...
def start(self, tag: _TagName, attrib: Dict[_StrOrBytes, _StrOrBytes]) -> None: ...

def iselement(element: Any) -> TypeGuard[_Element]: ...
8 changes: 4 additions & 4 deletions lxml-stubs/html/__init__.pyi
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@ from typing_extensions import Literal
if TYPE_CHECKING:
from ..etree import HTMLParser as _HTMLParser
from ..etree import XMLParser as _XMLParser
from ..etree import _AnySmartStr, _AnyStr, _BaseParser, _Element
from ..etree import _AnySmartStr, _BaseParser, _Element, _StrOrBytes

_HANDLE_FALURES = Literal["ignore", "discard", None]

@@ -70,16 +70,16 @@ class XHTMLParser(_XMLParser):
pass

def document_fromstring(
html: "_AnyStr", parser: "_BaseParser" = ..., ensure_head_body: bool = ..., **kw
html: _StrOrBytes, parser: "_BaseParser" = ..., ensure_head_body: bool = ..., **kw
) -> "_Element": ...
def fragments_fromstring(
html: "_AnyStr",
html: _StrOrBytes,
no_leading_text: bool = ...,
base_url: str = ...,
parser: "_BaseParser" = ...,
**kw
) -> "_Element": ...
def fromstring(
html: "_AnyStr", base_url: str = ..., parser: "_BaseParser" = ..., **kw
html: _StrOrBytes, base_url: Optional[str] = ..., parser: "_BaseParser" = ..., **kw
) -> "_Element": ...
def __getattr__(name: str) -> Any: ... # incomplete

0 comments on commit 80e9ad4

Please sign in to comment.