Skip to content

Commit

Permalink
improve typing for _simple_escaping_wrapper (#376)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidism authored Jun 2, 2023
2 parents 910ffff + 04d7d4d commit 60e7eb9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ Version 2.1.3

Unreleased

- Fix static typing for basic ``str`` methods on ``Markup``. :issue:`358`


Version 2.1.2
-------------
Expand Down
53 changes: 24 additions & 29 deletions src/markupsafe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ class HasHTML(te.Protocol):
def __html__(self) -> str:
pass

_P = te.ParamSpec("_P")


__version__ = "2.1.3.dev"

_strip_comments_re = re.compile(r"<!--.*?-->", re.DOTALL)
_strip_tags_re = re.compile(r"<.*?>", re.DOTALL)


def _simple_escaping_wrapper(name: str) -> t.Callable[..., "Markup"]:
orig = getattr(str, name)

@functools.wraps(orig)
def wrapped(self: "Markup", *args: t.Any, **kwargs: t.Any) -> "Markup":
args = _escape_argspec(list(args), enumerate(args), self.escape) # type: ignore
def _simple_escaping_wrapper(func: "t.Callable[_P, str]") -> "t.Callable[_P, Markup]":
@functools.wraps(func)
def wrapped(self: "Markup", *args: "_P.args", **kwargs: "_P.kwargs") -> "Markup":
arg_list = _escape_argspec(list(args), enumerate(args), self.escape)
_escape_argspec(kwargs, kwargs.items(), self.escape)
return self.__class__(orig(self, *args, **kwargs))
return self.__class__(func(self, *arg_list, **kwargs)) # type: ignore[arg-type]

return wrapped
return wrapped # type: ignore[return-value]


class Markup(str):
Expand Down Expand Up @@ -177,27 +177,22 @@ def escape(cls, s: t.Any) -> "Markup":

return rv

for method in (
"__getitem__",
"capitalize",
"title",
"lower",
"upper",
"replace",
"ljust",
"rjust",
"lstrip",
"rstrip",
"center",
"strip",
"translate",
"expandtabs",
"swapcase",
"zfill",
):
locals()[method] = _simple_escaping_wrapper(method)

del method
__getitem__ = _simple_escaping_wrapper(str.__getitem__)
capitalize = _simple_escaping_wrapper(str.capitalize)
title = _simple_escaping_wrapper(str.title)
lower = _simple_escaping_wrapper(str.lower)
upper = _simple_escaping_wrapper(str.upper)
replace = _simple_escaping_wrapper(str.replace)
ljust = _simple_escaping_wrapper(str.ljust)
rjust = _simple_escaping_wrapper(str.rjust)
lstrip = _simple_escaping_wrapper(str.lstrip)
rstrip = _simple_escaping_wrapper(str.rstrip)
center = _simple_escaping_wrapper(str.center)
strip = _simple_escaping_wrapper(str.strip)
translate = _simple_escaping_wrapper(str.translate)
expandtabs = _simple_escaping_wrapper(str.expandtabs)
swapcase = _simple_escaping_wrapper(str.swapcase)
zfill = _simple_escaping_wrapper(str.zfill)

def partition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]:
l, s, r = super().partition(self.escape(sep))
Expand Down

0 comments on commit 60e7eb9

Please sign in to comment.