Skip to content

Commit

Permalink
use Self for return type annotation (#379)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidism authored Jun 2, 2023
2 parents 5efbdb4 + 1e49fd5 commit cc134bb
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Unreleased
- Implement ``format_map``, ``casefold``, ``removeprefix``, and ``removesuffix``
methods. :issue:`370`
- Fix static typing for basic ``str`` methods on ``Markup``. :issue:`358`
- Use ``Self`` for annotating return types. :pr:`379`


Version 2.1.2
Expand Down
48 changes: 26 additions & 22 deletions src/markupsafe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Markup(str):

def __new__(
cls, base: t.Any = "", encoding: t.Optional[str] = None, errors: str = "strict"
) -> "Markup":
) -> "te.Self":
if hasattr(base, "__html__"):
base = base.__html__()

Expand All @@ -79,30 +79,30 @@ def __new__(

return super().__new__(cls, base, encoding, errors)

def __html__(self) -> "Markup":
def __html__(self) -> "te.Self":
return self

def __add__(self, other: t.Union[str, "HasHTML"]) -> "Markup":
def __add__(self, other: t.Union[str, "HasHTML"]) -> "te.Self":
if isinstance(other, str) or hasattr(other, "__html__"):
return self.__class__(super().__add__(self.escape(other)))

return NotImplemented

def __radd__(self, other: t.Union[str, "HasHTML"]) -> "Markup":
def __radd__(self, other: t.Union[str, "HasHTML"]) -> "te.Self":
if isinstance(other, str) or hasattr(other, "__html__"):
return self.escape(other).__add__(self)

return NotImplemented

def __mul__(self, num: "te.SupportsIndex") -> "Markup":
def __mul__(self, num: "te.SupportsIndex") -> "te.Self":
if isinstance(num, int):
return self.__class__(super().__mul__(num))

return NotImplemented

__rmul__ = __mul__

def __mod__(self, arg: t.Any) -> "Markup":
def __mod__(self, arg: t.Any) -> "te.Self":
if isinstance(arg, tuple):
# a tuple of arguments, each wrapped
arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg)
Expand All @@ -118,26 +118,28 @@ def __mod__(self, arg: t.Any) -> "Markup":
def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"

def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "Markup":
def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "te.Self":
return self.__class__(super().join(map(self.escape, seq)))

join.__doc__ = str.join.__doc__

def split( # type: ignore
def split( # type: ignore[override]
self, sep: t.Optional[str] = None, maxsplit: int = -1
) -> t.List["Markup"]:
) -> t.List["te.Self"]:
return [self.__class__(v) for v in super().split(sep, maxsplit)]

split.__doc__ = str.split.__doc__

def rsplit( # type: ignore
def rsplit( # type: ignore[override]
self, sep: t.Optional[str] = None, maxsplit: int = -1
) -> t.List["Markup"]:
) -> t.List["te.Self"]:
return [self.__class__(v) for v in super().rsplit(sep, maxsplit)]

rsplit.__doc__ = str.rsplit.__doc__

def splitlines(self, keepends: bool = False) -> t.List["Markup"]: # type: ignore
def splitlines( # type: ignore[override]
self, keepends: bool = False
) -> t.List["te.Self"]:
return [self.__class__(v) for v in super().splitlines(keepends)]

splitlines.__doc__ = str.splitlines.__doc__
Expand All @@ -164,10 +166,10 @@ def striptags(self) -> str:
value = _strip_comments_re.sub("", self)
value = _strip_tags_re.sub("", value)
value = " ".join(value.split())
return Markup(value).unescape()
return self.__class__(value).unescape()

@classmethod
def escape(cls, s: t.Any) -> "Markup":
def escape(cls, s: t.Any) -> "te.Self":
"""Escape a string. Calls :func:`escape` and ensures that for
subclasses the correct type is returned.
"""
Expand All @@ -176,7 +178,7 @@ def escape(cls, s: t.Any) -> "Markup":
if rv.__class__ is not cls:
return cls(rv)

return rv
return rv # type: ignore[return-value]

__getitem__ = _simple_escaping_wrapper(str.__getitem__)
capitalize = _simple_escaping_wrapper(str.capitalize)
Expand All @@ -200,25 +202,27 @@ def escape(cls, s: t.Any) -> "Markup":
removeprefix = _simple_escaping_wrapper(str.removeprefix)
removesuffix = _simple_escaping_wrapper(str.removesuffix)

def partition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]:
def partition(self, sep: str) -> t.Tuple["te.Self", "te.Self", "te.Self"]:
l, s, r = super().partition(self.escape(sep))
cls = self.__class__
return cls(l), cls(s), cls(r)

def rpartition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]:
def rpartition(self, sep: str) -> t.Tuple["te.Self", "te.Self", "te.Self"]:
l, s, r = super().rpartition(self.escape(sep))
cls = self.__class__
return cls(l), cls(s), cls(r)

def format(self, *args: t.Any, **kwargs: t.Any) -> "Markup":
def format(self, *args: t.Any, **kwargs: t.Any) -> "te.Self":
formatter = EscapeFormatter(self.escape)
return self.__class__(formatter.vformat(self, args, kwargs))

def format_map(self, map: t.Mapping[str, t.Any]) -> str: # type: ignore[override]
def format_map( # type: ignore[override]
self, map: t.Mapping[str, t.Any]
) -> "te.Self":
formatter = EscapeFormatter(self.escape)
return self.__class__(formatter.vformat(self, (), map))

def __html_format__(self, format_spec: str) -> "Markup":
def __html_format__(self, format_spec: str) -> "te.Self":
if format_spec:
raise ValueError("Unsupported format specification for Markup.")

Expand Down Expand Up @@ -273,8 +277,8 @@ def __init__(self, obj: t.Any, escape: t.Callable[[t.Any], Markup]) -> None:
self.obj = obj
self.escape = escape

def __getitem__(self, item: t.Any) -> "_MarkupEscapeHelper":
return _MarkupEscapeHelper(self.obj[item], self.escape)
def __getitem__(self, item: t.Any) -> "te.Self":
return self.__class__(self.obj[item], self.escape)

def __str__(self) -> str:
return str(self.escape(self.obj))
Expand Down

0 comments on commit cc134bb

Please sign in to comment.