From 04d7d4dae151b7ad44e7ae99a3c2aca79bab6abc Mon Sep 17 00:00:00 2001 From: David Salvisberg Date: Wed, 12 Apr 2023 12:07:46 +0200 Subject: [PATCH] improve typing for `_simple_escaping_wrapper` --- CHANGES.rst | 2 ++ src/markupsafe/__init__.py | 53 +++++++++++++++++--------------------- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index e25c1563..90476f76 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,8 @@ Version 2.1.3 Unreleased +- Fix static typing for basic ``str`` methods on ``Markup``. :issue:`358` + Version 2.1.2 ------------- diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py index b2faa466..c593a133 100644 --- a/src/markupsafe/__init__.py +++ b/src/markupsafe/__init__.py @@ -10,6 +10,8 @@ class HasHTML(te.Protocol): def __html__(self) -> str: pass + _P = te.ParamSpec("_P") + __version__ = "2.1.3.dev" @@ -17,16 +19,14 @@ def __html__(self) -> str: _strip_tags_re = re.compile(r"<.*?>", re.DOTALL) -def _simple_escaping_wrapper(name: str) -> t.Callable[..., "Markup"]: - orig = getattr(str, name) - - @functools.wraps(orig) - def wrapped(self: "Markup", *args: t.Any, **kwargs: t.Any) -> "Markup": - args = _escape_argspec(list(args), enumerate(args), self.escape) # type: ignore +def _simple_escaping_wrapper(func: "t.Callable[_P, str]") -> "t.Callable[_P, Markup]": + @functools.wraps(func) + def wrapped(self: "Markup", *args: "_P.args", **kwargs: "_P.kwargs") -> "Markup": + arg_list = _escape_argspec(list(args), enumerate(args), self.escape) _escape_argspec(kwargs, kwargs.items(), self.escape) - return self.__class__(orig(self, *args, **kwargs)) + return self.__class__(func(self, *arg_list, **kwargs)) # type: ignore[arg-type] - return wrapped + return wrapped # type: ignore[return-value] class Markup(str): @@ -177,27 +177,22 @@ def escape(cls, s: t.Any) -> "Markup": return rv - for method in ( - "__getitem__", - "capitalize", - "title", - "lower", - "upper", - "replace", - "ljust", - "rjust", - "lstrip", - "rstrip", - "center", - "strip", - "translate", - "expandtabs", - "swapcase", - "zfill", - ): - locals()[method] = _simple_escaping_wrapper(method) - - del method + __getitem__ = _simple_escaping_wrapper(str.__getitem__) + capitalize = _simple_escaping_wrapper(str.capitalize) + title = _simple_escaping_wrapper(str.title) + lower = _simple_escaping_wrapper(str.lower) + upper = _simple_escaping_wrapper(str.upper) + replace = _simple_escaping_wrapper(str.replace) + ljust = _simple_escaping_wrapper(str.ljust) + rjust = _simple_escaping_wrapper(str.rjust) + lstrip = _simple_escaping_wrapper(str.lstrip) + rstrip = _simple_escaping_wrapper(str.rstrip) + center = _simple_escaping_wrapper(str.center) + strip = _simple_escaping_wrapper(str.strip) + translate = _simple_escaping_wrapper(str.translate) + expandtabs = _simple_escaping_wrapper(str.expandtabs) + swapcase = _simple_escaping_wrapper(str.swapcase) + zfill = _simple_escaping_wrapper(str.zfill) def partition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]: l, s, r = super().partition(self.escape(sep))