Skip to content

Commit

Permalink
ENH Add a css wrapper to generated types (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoodmane authored Nov 9, 2023
1 parent 8d9d554 commit 9735dc6
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 65 deletions.
54 changes: 48 additions & 6 deletions src/sphinx_autodoc_typehints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AnyStr, Callable, ForwardRef, NewType, TypeVar, get_type_hints

from docutils import nodes
from docutils.frontend import OptionParser
from docutils.parsers.rst import Parser as RstParser
from docutils.parsers.rst import states
from docutils.utils import new_document
from sphinx.ext.autodoc.mock import mock
from sphinx.util import logging
from sphinx.util import logging, rst
from sphinx.util.inspect import signature as sphinx_signature
from sphinx.util.inspect import stringify_signature

Expand Down Expand Up @@ -209,7 +211,7 @@ def format_annotation(annotation: Any, config: Config) -> str: # noqa: C901, PL
fully_qualified: bool = getattr(config, "typehints_fully_qualified", False)
prefix = "" if fully_qualified or full_name == class_name else "~"
role = "data" if module == "typing" and class_name in _PYDATA_ANNOTATIONS else "class"
args_format = "\\[{}]"
args_format = "\\ \\[{}]"
formatted_args: str | None = ""

# Some types require special handling
Expand Down Expand Up @@ -242,9 +244,9 @@ def format_annotation(annotation: Any, config: Config) -> str: # noqa: C901, PL
args = tuple(x for x in args if x is not type(None))
elif full_name in ("typing.Callable", "collections.abc.Callable") and args and args[0] is not ...:
fmt = [format_annotation(arg, config) for arg in args]
formatted_args = f"\\[\\[{', '.join(fmt[:-1])}], {fmt[-1]}]"
formatted_args = f"\\ \\[\\[{', '.join(fmt[:-1])}], {fmt[-1]}]"
elif full_name == "typing.Literal":
formatted_args = f"\\[{', '.join(f'``{arg!r}``' for arg in args)}]"
formatted_args = f"\\ \\[{', '.join(f'``{arg!r}``' for arg in args)}]"
elif full_name == "types.UnionType":
return " | ".join([format_annotation(arg, config) for arg in args])

Expand Down Expand Up @@ -724,7 +726,7 @@ def _inject_signature( # noqa: C901
if annotation is None:
type_annotation = f":type {arg_name}: "
else:
formatted_annotation = format_annotation(annotation, app.config)
formatted_annotation = add_type_css_class(format_annotation(annotation, app.config))
type_annotation = f":type {arg_name}: {formatted_annotation}"

if app.config.typehints_defaults:
Expand Down Expand Up @@ -843,7 +845,7 @@ def _inject_rtype( # noqa: PLR0913
if not app.config.typehints_use_rtype and r.found_return and " -- " in lines[insert_index]:
return

formatted_annotation = format_annotation(type_hints["return"], app.config)
formatted_annotation = add_type_css_class(format_annotation(type_hints["return"], app.config))

if r.found_param and insert_index < len(lines) and lines[insert_index].strip():
insert_index -= 1
Expand Down Expand Up @@ -874,6 +876,45 @@ def validate_config(app: Sphinx, env: BuildEnvironment, docnames: list[str]) ->
raise ValueError(msg)


def unescape(escaped: str) -> str:
# For some reason the string we get has a bunch of null bytes in it??
# Remove them...
escaped = escaped.replace("\x00", "")
# For some reason the extra slash before spaces gets lost between the .rst
# source and when this directive is called. So don't replace "\<space>" =>
# "<space>"
return re.sub(r"\\([^ ])", r"\1", escaped)


def add_type_css_class(type_rst: str) -> str:
return f":sphinx_autodoc_typehints_type:`{rst.escape(type_rst)}`"


def sphinx_autodoc_typehints_type_role(
_role: str,
_rawtext: str,
text: str,
_lineno: int,
inliner: states.Inliner,
_options: dict[str, Any] | None = None,
_content: list[str] | None = None,
) -> tuple[list[Node], list[Node]]:
"""
Add css tag around rendered type.
The body should be escaped rst. This renders its body as rst and wraps the
result in <span class="sphinx_autodoc_typehints-type"> </span>
"""
unescaped = unescape(text)
# the typestubs for docutils don't have any info about Inliner
doc = new_document("", inliner.document.settings) # type: ignore[attr-defined]
RstParser().parse(unescaped, doc)
n = nodes.inline(text)
n["classes"].append("sphinx_autodoc_typehints-type")
n += doc.children[0].children
return [n], []


def setup(app: Sphinx) -> dict[str, bool]:
app.add_config_value("always_document_param_types", False, "html") # noqa: FBT003
app.add_config_value("typehints_fully_qualified", False, "env") # noqa: FBT003
Expand All @@ -884,6 +925,7 @@ def setup(app: Sphinx) -> dict[str, bool]:
app.add_config_value("typehints_formatter", None, "env")
app.add_config_value("typehints_use_signature", False, "env") # noqa: FBT003
app.add_config_value("typehints_use_signature_return", False, "env") # noqa: FBT003
app.add_role("sphinx_autodoc_typehints_type", sphinx_autodoc_typehints_type_role)
app.connect("env-before-read-docs", validate_config) # config may be changed after “config-inited” event
app.connect("autodoc-process-signature", process_signature)
app.connect("autodoc-process-docstring", process_docstring)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def function_with_escaped_default(x: str = "\b"): # noqa: ANN201, ARG001
Function docstring.
Parameters:
**x** (*a.b.c*) -- foo
**x** (a.b.c) -- foo
""",
)
def function_with_unresolvable_annotation(x: a.b.c): # noqa: ANN201, ARG001, F821
Expand Down
111 changes: 53 additions & 58 deletions tests/test_sphinx_autodoc_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,87 +201,87 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
(type, ":py:class:`type`"),
(collections.abc.Callable, ":py:class:`~collections.abc.Callable`"),
(Type, ":py:class:`~typing.Type`"),
(Type[A], ":py:class:`~typing.Type`\\[:py:class:`~%s.A`]" % __name__),
(Type[A], ":py:class:`~typing.Type`\\ \\[:py:class:`~%s.A`]" % __name__),
(Any, ":py:data:`~typing.Any`"),
(AnyStr, ":py:data:`~typing.AnyStr`"),
(Generic[T], ":py:class:`~typing.Generic`\\[:py:class:`~typing.TypeVar`\\(``T``)]"),
(Generic[T], ":py:class:`~typing.Generic`\\ \\[:py:class:`~typing.TypeVar`\\(``T``)]"),
(Mapping, ":py:class:`~typing.Mapping`"),
(
Mapping[T, int], # type: ignore[valid-type]
":py:class:`~typing.Mapping`\\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]",
":py:class:`~typing.Mapping`\\ \\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]",
),
(
Mapping[str, V_contra], # type: ignore[valid-type]
":py:class:`~typing.Mapping`\\[:py:class:`str`, :py:class:`~typing.TypeVar`\\("
":py:class:`~typing.Mapping`\\ \\[:py:class:`str`, :py:class:`~typing.TypeVar`\\("
"``V_contra``, contravariant=True)]",
),
(
Mapping[T, U_co], # type: ignore[valid-type]
":py:class:`~typing.Mapping`\\[:py:class:`~typing.TypeVar`\\(``T``), "
":py:class:`~typing.Mapping`\\ \\[:py:class:`~typing.TypeVar`\\(``T``), "
":py:class:`~typing.TypeVar`\\(``U_co``, covariant=True)]",
),
(Mapping[str, bool], ":py:class:`~typing.Mapping`\\[:py:class:`str`, :py:class:`bool`]"),
(Mapping[str, bool], ":py:class:`~typing.Mapping`\\ \\[:py:class:`str`, :py:class:`bool`]"),
(Dict, ":py:class:`~typing.Dict`"),
(
Dict[T, int], # type: ignore[valid-type]
":py:class:`~typing.Dict`\\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]",
":py:class:`~typing.Dict`\\ \\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]",
),
(
Dict[str, V_contra], # type: ignore[valid-type]
":py:class:`~typing.Dict`\\[:py:class:`str`, :py:class:`~typing.TypeVar`\\(``V_contra``, contravariant=True)]",
":py:class:`~typing.Dict`\\ \\[:py:class:`str`, :py:class:`~typing.TypeVar`\\(``V_contra``, contravariant=True)]", # noqa: E501
),
(
Dict[T, U_co], # type: ignore[valid-type]
":py:class:`~typing.Dict`\\[:py:class:`~typing.TypeVar`\\(``T``),"
":py:class:`~typing.Dict`\\ \\[:py:class:`~typing.TypeVar`\\(``T``),"
" :py:class:`~typing.TypeVar`\\(``U_co``, covariant=True)]",
),
(Dict[str, bool], ":py:class:`~typing.Dict`\\[:py:class:`str`, :py:class:`bool`]"),
(Dict[str, bool], ":py:class:`~typing.Dict`\\ \\[:py:class:`str`, :py:class:`bool`]"),
(Tuple, ":py:data:`~typing.Tuple`"),
(Tuple[str, bool], ":py:data:`~typing.Tuple`\\[:py:class:`str`, :py:class:`bool`]"),
(Tuple[int, int, int], ":py:data:`~typing.Tuple`\\[:py:class:`int`, :py:class:`int`, :py:class:`int`]"),
(Tuple[str, ...], ":py:data:`~typing.Tuple`\\[:py:class:`str`, :py:data:`...<Ellipsis>`]"),
(Tuple[str, bool], ":py:data:`~typing.Tuple`\\ \\[:py:class:`str`, :py:class:`bool`]"),
(Tuple[int, int, int], ":py:data:`~typing.Tuple`\\ \\[:py:class:`int`, :py:class:`int`, :py:class:`int`]"),
(Tuple[str, ...], ":py:data:`~typing.Tuple`\\ \\[:py:class:`str`, :py:data:`...<Ellipsis>`]"),
(Union, ":py:data:`~typing.Union`"),
(Union[str, bool], ":py:data:`~typing.Union`\\[:py:class:`str`, :py:class:`bool`]"),
(Union[str, bool, None], ":py:data:`~typing.Union`\\[:py:class:`str`, :py:class:`bool`, :py:obj:`None`]"),
pytest.param(Union[str, Any], ":py:data:`~typing.Union`\\[:py:class:`str`, :py:data:`~typing.Any`]"),
(Optional[str], ":py:data:`~typing.Optional`\\[:py:class:`str`]"),
(Union[str, None], ":py:data:`~typing.Optional`\\[:py:class:`str`]"),
(Union[str, bool], ":py:data:`~typing.Union`\\ \\[:py:class:`str`, :py:class:`bool`]"),
(Union[str, bool, None], ":py:data:`~typing.Union`\\ \\[:py:class:`str`, :py:class:`bool`, :py:obj:`None`]"),
pytest.param(Union[str, Any], ":py:data:`~typing.Union`\\ \\[:py:class:`str`, :py:data:`~typing.Any`]"),
(Optional[str], ":py:data:`~typing.Optional`\\ \\[:py:class:`str`]"),
(Union[str, None], ":py:data:`~typing.Optional`\\ \\[:py:class:`str`]"),
(
Optional[Union[str, bool]],
":py:data:`~typing.Union`\\[:py:class:`str`, :py:class:`bool`, :py:obj:`None`]",
":py:data:`~typing.Union`\\ \\[:py:class:`str`, :py:class:`bool`, :py:obj:`None`]",
),
(Callable, ":py:data:`~typing.Callable`"),
(Callable[..., int], ":py:data:`~typing.Callable`\\[:py:data:`...<Ellipsis>`, :py:class:`int`]"),
(Callable[[int], int], ":py:data:`~typing.Callable`\\[\\[:py:class:`int`], :py:class:`int`]"),
(Callable[..., int], ":py:data:`~typing.Callable`\\ \\[:py:data:`...<Ellipsis>`, :py:class:`int`]"),
(Callable[[int], int], ":py:data:`~typing.Callable`\\ \\[\\[:py:class:`int`], :py:class:`int`]"),
(
Callable[[int, str], bool],
":py:data:`~typing.Callable`\\[\\[:py:class:`int`, :py:class:`str`], :py:class:`bool`]",
":py:data:`~typing.Callable`\\ \\[\\[:py:class:`int`, :py:class:`str`], :py:class:`bool`]",
),
(
Callable[[int, str], None],
":py:data:`~typing.Callable`\\[\\[:py:class:`int`, :py:class:`str`], :py:obj:`None`]",
":py:data:`~typing.Callable`\\ \\[\\[:py:class:`int`, :py:class:`str`], :py:obj:`None`]",
),
(
Callable[[T], T],
":py:data:`~typing.Callable`\\[\\[:py:class:`~typing.TypeVar`\\(``T``)],"
":py:data:`~typing.Callable`\\ \\[\\[:py:class:`~typing.TypeVar`\\(``T``)],"
" :py:class:`~typing.TypeVar`\\(``T``)]",
),
(
AbcCallable[[int, str], bool], # type: ignore[valid-type,misc,type-arg]
":py:class:`~collections.abc.Callable`\\[\\[:py:class:`int`, :py:class:`str`], :py:class:`bool`]",
":py:class:`~collections.abc.Callable`\\ \\[\\[:py:class:`int`, :py:class:`str`], :py:class:`bool`]",
),
(Pattern, ":py:class:`~typing.Pattern`"),
(Pattern[str], ":py:class:`~typing.Pattern`\\[:py:class:`str`]"),
(Pattern[str], ":py:class:`~typing.Pattern`\\ \\[:py:class:`str`]"),
(IO, ":py:class:`~typing.IO`"),
(IO[str], ":py:class:`~typing.IO`\\[:py:class:`str`]"),
(IO[str], ":py:class:`~typing.IO`\\ \\[:py:class:`str`]"),
(Metaclass, ":py:class:`~%s.Metaclass`" % __name__),
(A, ":py:class:`~%s.A`" % __name__),
(B, ":py:class:`~%s.B`" % __name__),
(B[int], ":py:class:`~%s.B`\\[:py:class:`int`]" % __name__),
(B[int], ":py:class:`~%s.B`\\ \\[:py:class:`int`]" % __name__),
(C, ":py:class:`~%s.C`" % __name__),
(D, ":py:class:`~%s.D`" % __name__),
(E, ":py:class:`~%s.E`" % __name__),
(E[int], ":py:class:`~%s.E`\\[:py:class:`int`]" % __name__),
(E[int], ":py:class:`~%s.E`\\ \\[:py:class:`int`]" % __name__),
(W, f":py:{'class' if PY310_PLUS else 'func'}:`~typing.NewType`\\(``W``, :py:class:`str`)"),
(T, ":py:class:`~typing.TypeVar`\\(``T``)"),
(U_co, ":py:class:`~typing.TypeVar`\\(``U_co``, covariant=True)"),
Expand All @@ -306,17 +306,17 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
# Zero-length tuple remains
(Tuple[()], ":py:data:`~typing.Tuple`"),
# Internal single tuple with simple types is flattened in the output
(Tuple[(int,)], ":py:data:`~typing.Tuple`\\[:py:class:`int`]"),
(Tuple[(int, int)], ":py:data:`~typing.Tuple`\\[:py:class:`int`, :py:class:`int`]"),
(Tuple[(int,)], ":py:data:`~typing.Tuple`\\ \\[:py:class:`int`]"),
(Tuple[(int, int)], ":py:data:`~typing.Tuple`\\ \\[:py:class:`int`, :py:class:`int`]"),
# Ellipsis in single tuple also gets flattened
(Tuple[(int, ...)], ":py:data:`~typing.Tuple`\\[:py:class:`int`, :py:data:`...<Ellipsis>`]"),
(Tuple[(int, ...)], ":py:data:`~typing.Tuple`\\ \\[:py:class:`int`, :py:data:`...<Ellipsis>`]"),
(
RecList,
":py:data:`~typing.Union`\\[:py:class:`int`, :py:class:`~typing.List`\\[RecList]]",
":py:data:`~typing.Union`\\ \\[:py:class:`int`, :py:class:`~typing.List`\\ \\[RecList]]",
),
(
MutualRecA,
":py:data:`~typing.Union`\\[:py:class:`bool`, :py:class:`~typing.List`\\[MutualRecB]]",
":py:data:`~typing.Union`\\ \\[:py:class:`bool`, :py:class:`~typing.List`\\ \\[MutualRecB]]",
),
]

Expand All @@ -327,39 +327,39 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
(
nptyping.NDArray[nptyping.Shape["*"], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[*], "
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[*], "
":py:class:`~numpy.float64`]"
),
),
(
nptyping.NDArray[nptyping.Shape["64"], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[64],"
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[64],"
" :py:class:`~numpy.float64`]"
),
),
(
nptyping.NDArray[nptyping.Shape["*, *"], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[*, "
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[*, "
"*], :py:class:`~numpy.float64`]"
),
),
(
nptyping.NDArray[nptyping.Shape["*, ..."], nptyping.Float],
":py:class:`~nptyping.ndarray.NDArray`\\[:py:data:`~typing.Any`, :py:class:`~numpy.float64`]",
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:data:`~typing.Any`, :py:class:`~numpy.float64`]",
),
(
nptyping.NDArray[nptyping.Shape["*, 3"], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[*, 3"
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[*, 3"
"], :py:class:`~numpy.float64`]"
),
),
(
nptyping.NDArray[nptyping.Shape["3, ..."], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[3, "
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[3, "
"...], :py:class:`~numpy.float64`]"
),
),
Expand All @@ -379,7 +379,7 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
# subsequent tests
expected_result_not_simplified = expected_result.replace(", ``None``", "")
# encapsulate Union in typing.Optional
expected_result_not_simplified = ":py:data:`~typing.Optional`\\[" + expected_result_not_simplified
expected_result_not_simplified = ":py:data:`~typing.Optional`\\ \\[" + expected_result_not_simplified
expected_result_not_simplified += "]"
conf = create_autospec(Config, simplify_optional_unions=False, _annotation_globals=globals())
assert format_annotation(annotation, conf) == expected_result_not_simplified
Expand Down Expand Up @@ -421,11 +421,11 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
@pytest.mark.parametrize(
("annotation", "params", "expected_result"),
[
("ClassVar", int, ":py:data:`~typing.ClassVar`\\[:py:class:`int`]"),
("ClassVar", int, ":py:data:`~typing.ClassVar`\\ \\[:py:class:`int`]"),
("NoReturn", None, ":py:data:`~typing.NoReturn`"),
("Literal", ("a", 1), ":py:data:`~typing.Literal`\\[``'a'``, ``1``]"),
("Literal", ("a", 1), ":py:data:`~typing.Literal`\\ \\[``'a'``, ``1``]"),
("Type", None, ":py:class:`~typing.Type`"),
("Type", (A,), f":py:class:`~typing.Type`\\[:py:class:`~{__name__}.A`]"),
("Type", (A,), f":py:class:`~typing.Type`\\ \\[:py:class:`~{__name__}.A`]"),
],
)
def test_format_annotation_both_libs(library: ModuleType, annotation: str, params: Any, expected_result: str) -> None:
Expand Down Expand Up @@ -524,16 +524,11 @@ class dummy_module.DataClass(x)

def maybe_fix_py310(expected_contents: str) -> str:
if not PY310_PLUS:
return expected_contents
return expected_contents.replace('"', "")

for old, new in [
("*bool** | **None*", '"Optional"["bool"]'),
("*int** | **str** | **float*", '"int" | "str" | "float"'),
("*str** | **None*", '"Optional"["str"]'),
("(*bool*)", '("bool")'),
("(*int*", '("int"'),
(" str", ' "str"'),
('"Optional"["str"]', '"Optional"["str"]'),
('"Optional"["Callable"[["int", "bytes"], "int"]]', '"Optional"["Callable"[["int", "bytes"], "int"]]'),
("bool | None", '"Optional"["bool"]'),
("str | None", '"Optional"["str"]'),
]:
expected_contents = expected_contents.replace(old, new)
return expected_contents
Expand All @@ -559,14 +554,14 @@ def test_sphinx_output_future_annotations(app: SphinxTestApp, status: StringIO)
Method docstring.
Parameters:
* **x** (*bool** | **None*) -- foo
* **x** (bool | None) -- foo
* **y** (*int** | **str** | **float*) -- bar
* **y** ("int" | "str" | "float") -- bar
* **z** (*str** | **None*) -- baz
* **z** (str | None) -- baz
Return type:
str
"str"
"""
expected_contents = maybe_fix_py310(dedent(expected_contents))
assert contents == expected_contents
Expand Down Expand Up @@ -625,7 +620,7 @@ def test_sphinx_output_defaults(
("formatter_config_val", "expected"),
[
(None, ['("bool") -- foo', '("int") -- bar', '"str"']),
(lambda ann, conf: "Test", ["(*Test*) -- foo", "(*Test*) -- bar", "Test"]), # noqa: ARG005
(lambda ann, conf: "Test", ["(Test) -- foo", "(Test) -- bar", "Test"]), # noqa: ARG005
("some string", Exception("needs to be callable or `None`")),
],
)
Expand Down

0 comments on commit 9735dc6

Please sign in to comment.