Skip to content

Commit

Permalink
Add option to force unions (and options) to be rendered with bars (#418)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
hoodmane and pre-commit-ci[bot] authored Feb 8, 2024
1 parent 2156242 commit 601583f
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
10 changes: 9 additions & 1 deletion src/sphinx_autodoc_typehints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ def format_annotation(annotation: Any, config: Config) -> str: # noqa: C901, PL
args_format = "\\[{}]"
formatted_args: str | None = ""

always_use_bars_union: bool = getattr(config, "always_use_bars_union", True)
is_bars_union = full_name == "types.UnionType" or (
always_use_bars_union and type(annotation).__qualname__ == "_UnionGenericAlias"
)
if is_bars_union:
full_name = ""

# Some types require special handling
if full_name == "typing.NewType":
args_format = f"\\(``{annotation.__name__}``, {{}})"
Expand Down Expand Up @@ -248,7 +255,7 @@ def format_annotation(annotation: Any, config: Config) -> str: # noqa: C901, PL
formatted_args = f"\\[\\[{', '.join(fmt[:-1])}], {fmt[-1]}]"
elif full_name == "typing.Literal":
formatted_args = f"\\[{', '.join(f'``{arg!r}``' for arg in args)}]"
elif full_name == "types.UnionType":
elif is_bars_union:
return " | ".join([format_annotation(arg, config) for arg in args])

if args and not formatted_args:
Expand Down Expand Up @@ -929,6 +936,7 @@ def setup(app: Sphinx) -> dict[str, bool]:
app.add_config_value("typehints_use_rtype", True, "env") # noqa: FBT003
app.add_config_value("typehints_defaults", None, "env")
app.add_config_value("simplify_optional_unions", True, "env") # noqa: FBT003
app.add_config_value("always_use_bars_union", False, "env") # noqa: FBT003
app.add_config_value("typehints_formatter", None, "env")
app.add_config_value("typehints_use_signature", False, "env") # noqa: FBT003
app.add_config_value("typehints_use_signature_return", False, "env") # noqa: FBT003
Expand Down
46 changes: 39 additions & 7 deletions tests/test_sphinx_autodoc_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t

@pytest.mark.parametrize(("annotation", "expected_result"), _CASES)
def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str) -> None:
conf = create_autospec(Config, _annotation_globals=globals())
conf = create_autospec(Config, _annotation_globals=globals(), always_use_bars_union=False)
result = format_annotation(annotation, conf)
assert result == expected_result

Expand All @@ -377,7 +377,12 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
# encapsulate Union in typing.Optional
expected_result_not_simplified = ":py:data:`~typing.Optional`\\ \\[" + expected_result_not_simplified
expected_result_not_simplified += "]"
conf = create_autospec(Config, simplify_optional_unions=False, _annotation_globals=globals())
conf = create_autospec(
Config,
simplify_optional_unions=False,
_annotation_globals=globals(),
always_use_bars_union=False,
)
assert format_annotation(annotation, conf) == expected_result_not_simplified

# Test with the "fully_qualified" flag turned on
Expand All @@ -397,7 +402,12 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
expected_result = expected_result.replace("~nptyping", "nptyping")
expected_result = expected_result.replace("~numpy", "numpy")
expected_result = expected_result.replace("~" + __name__, __name__)
conf = create_autospec(Config, typehints_fully_qualified=True, _annotation_globals=globals())
conf = create_autospec(
Config,
typehints_fully_qualified=True,
_annotation_globals=globals(),
always_use_bars_union=False,
)
assert format_annotation(annotation, conf) == expected_result

# Test for the correct role (class vs data) using the official Sphinx inventory
Expand All @@ -413,6 +423,26 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
assert m.group("role") == expected_role


@pytest.mark.parametrize(
("annotation", "expected_result"),
[
("int | float", ":py:class:`int` | :py:class:`float`"),
("int | float | None", ":py:class:`int` | :py:class:`float` | :py:obj:`None`"),
("Union[int, float]", ":py:class:`int` | :py:class:`float`"),
("Union[int, float, None]", ":py:class:`int` | :py:class:`float` | :py:obj:`None`"),
("Optional[int | float]", ":py:class:`int` | :py:class:`float` | :py:obj:`None`"),
("Optional[Union[int, float]]", ":py:class:`int` | :py:class:`float` | :py:obj:`None`"),
("Union[int | float, str]", ":py:class:`int` | :py:class:`float` | :py:class:`str`"),
("Union[int, float] | str", ":py:class:`int` | :py:class:`float` | :py:class:`str`"),
],
)
@pytest.mark.skipif(not PY310_PLUS, reason="| union doesn't work before py310")
def test_always_use_bars_union(annotation: str, expected_result: str) -> None:
conf = create_autospec(Config, always_use_bars_union=True)
result = format_annotation(eval(annotation), conf) # noqa: S307
assert result == expected_result


@pytest.mark.parametrize("library", [typing, typing_extensions], ids=["typing", "typing_extensions"])
@pytest.mark.parametrize(
("annotation", "params", "expected_result"),
Expand Down Expand Up @@ -519,12 +549,13 @@ class dummy_module.DataClass(x)


def maybe_fix_py310(expected_contents: str) -> str:
if sys.version_info >= (3, 11):
return expected_contents
if not PY310_PLUS:
return expected_contents.replace('"', "")

for old, new in [
("bool | None", '"Optional"["bool"]'),
("str | None", '"Optional"["str"]'),
('"str" | "None"', '"Optional"["str"]'),
]:
expected_contents = expected_contents.replace(old, new)
return expected_contents
Expand All @@ -550,15 +581,16 @@ def test_sphinx_output_future_annotations(app: SphinxTestApp, status: StringIO)
Method docstring.
Parameters:
* **x** (bool | None) -- foo
* **x** ("bool" | "None") -- foo
* **y** ("int" | "str" | "float") -- bar
* **z** (str | None) -- baz
* **z** ("str" | "None") -- baz
Return type:
"str"
"""
expected_contents = dedent(expected_contents)
expected_contents = maybe_fix_py310(dedent(expected_contents))
assert contents == expected_contents

Expand Down

0 comments on commit 601583f

Please sign in to comment.