diff --git a/docs/type_param_demo.py b/docs/type_param_demo.py index 177cc6cc..ba7a2387 100644 --- a/docs/type_param_demo.py +++ b/docs/type_param_demo.py @@ -70,7 +70,11 @@ def get(self, key: K, default: V) -> V: ... def get(self, key: K, default: T) -> Union[V, T]: ... def get(self, key: K, default=None): - """Return the mapped value, or the specified default.""" + """Return the mapped value, or the specified default. + + :param key: Key to retrieve. + :param default: Default value to return if key is not present. + """ ... def __len__(self) -> int: diff --git a/sphinx_immaterial/apidoc/python/parameter_objects.py b/sphinx_immaterial/apidoc/python/parameter_objects.py index d02592c1..56969bc1 100644 --- a/sphinx_immaterial/apidoc/python/parameter_objects.py +++ b/sphinx_immaterial/apidoc/python/parameter_objects.py @@ -253,16 +253,29 @@ def get_objects( PythonDomain.get_objects = get_objects # type: ignore[assignment] +def _fix_pending_xrefs_to_type_params( + type_param_symbols: dict[str, str], parent: docutils.nodes.Node +) -> None: + for xref in parent.findall(condition=sphinx.addnodes.pending_xref): + if xref["refdomain"] == "py" and xref["reftype"] in ("class", "param"): + p_symbol = type_param_symbols.get(xref["reftarget"]) + if p_symbol is not None: + xref["reftarget"] = p_symbol + xref["refspecific"] = False + + def _add_parameter_links_to_signature( env: sphinx.environment.BuildEnvironment, signode: sphinx.addnodes.desc_signature, type_param_symbol_prefix: str, function_param_symbol_prefix: str, -) -> Dict[str, docutils.nodes.Element]: +) -> tuple[dict[str, docutils.nodes.Element], dict[str, str]]: """Cross-links parameter names in signature to parameter objects. Returns: - Map of parameter name to original (not linked) parameter node. + Tuple of: + - Map of parameter name to original (not linked) parameter node. + - Map of type parameter name to parameter object symbol. """ sig_param_nodes: Dict[str, docutils.nodes.Element] = {} @@ -336,15 +349,13 @@ def _collect_parameters( refnode["implicit_sig_param_ref"] = True name_node.replace_self(refnode) - # Also cross-link references to type parameters in annotations. - for xref in signode.findall(condition=sphinx.addnodes.pending_xref): - if xref["refdomain"] == "py" and xref["reftype"] in ("class", "param"): - p_symbol = type_param_symbols.get(xref["reftarget"]) - if p_symbol is not None: - xref["reftarget"] = p_symbol - xref["refspecific"] = False + if type_param_symbols: + # Also cross-link references to type parameters in annotations. + _fix_pending_xrefs_to_type_params(type_param_symbols, signode) + for parent in sig_param_nodes.values(): + _fix_pending_xrefs_to_type_params(type_param_symbols, parent) - return sig_param_nodes + return sig_param_nodes, type_param_symbols def _collate_parameter_symbols( @@ -550,6 +561,8 @@ def _cross_link_parameters( env = app.env assert isinstance(env, sphinx.environment.BuildEnvironment) + type_param_symbols: dict[str, str] = {} + # Collect the docutils nodes corresponding to the declarations of the # parameters in each signature, and turn the parameter names into # cross-links to the parameter description. @@ -559,9 +572,11 @@ def _cross_link_parameters( # e.g. `x : int = 10` rather than just `x`. sig_param_nodes_for_signature = [] for signode, symbol, function_symbol in zip(signodes, symbols, function_symbols): - sig_param_nodes_for_signature.append( - _add_parameter_links_to_signature(env, signode, symbol, function_symbol) + sig_param_nodes, sig_type_param_symbols = _add_parameter_links_to_signature( + env, signode, symbol, function_symbol ) + sig_param_nodes_for_signature.append(sig_param_nodes) + type_param_symbols.update(sig_type_param_symbols) # Find all parameter descriptions in the object description body, and mark # them as the target for cross links to that parameter. Also substitute in @@ -576,6 +591,10 @@ def _cross_link_parameters( noindex=noindex, ) + # Fix any remaining references to type parameters. + if type_param_symbols: + _fix_pending_xrefs_to_type_params(type_param_symbols, content) + if not noindex: py = cast(sphinx.domains.python.PythonDomain, env.get_domain("py")) diff --git a/tests/python_apigen_test.py b/tests/python_apigen_test.py index 41bd0c94..a5c8eb63 100644 --- a/tests/python_apigen_test.py +++ b/tests/python_apigen_test.py @@ -25,7 +25,15 @@ def apigen_make_app(tmp_path: pathlib.Path, make_app): def make(extra_conf: str = "", **kwargs): (tmp_path / "conf.py").write_text(conf + extra_conf, encoding="utf-8") - (tmp_path / "index.rst").write_text("", encoding="utf-8") + (tmp_path / "index.rst").write_text( + """ +.. python-apigen-group:: Public Members + +.. python-apigen-group:: Classes + +""", + encoding="utf-8", + ) return make_app(srcdir=SphinxPath(str(tmp_path)), **kwargs) yield make @@ -157,3 +165,24 @@ def test_pure_python_property(apigen_make_app): assert member.name == "baz" assert len(member.siblings) == 1 assert member.siblings[0].name == "bar" + + +@pytest.mark.skipif( + sphinx.version_info < (7, 1), + reason=f"Type parameters are not supported by Sphinx {sphinx.version_info}", +) +def test_type_params(apigen_make_app): + """Tests that references to type parameters are all resolved.""" + testmod = "python_apigen_test_modules.type_params" + app = apigen_make_app( + confoverrides=dict( + python_apigen_modules={ + testmod: "api/", + }, + nitpicky=True, + ), + ) + app.build() + print(app._status.getvalue()) + print(app._warning.getvalue()) + assert not app._warning.getvalue() diff --git a/tests/python_apigen_test_modules/type_params.py b/tests/python_apigen_test_modules/type_params.py new file mode 100644 index 00000000..6a0f895f --- /dev/null +++ b/tests/python_apigen_test_modules/type_params.py @@ -0,0 +1,29 @@ +from typing import TypeVar + +T = TypeVar("T") + + +def foo(x: T) -> T: + """Foo function. + + :param x: Something or other. + """ + return x + + +def bar(x: T) -> T: + return x + + +class C: + def get(self, x: T, y: T) -> T: + """Get function. + + :param x: Something or other. + :param y: Another param. + :type y: T + """ + return x + + +__all__ = ["foo", "bar", "C"]