Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve params in text #311

Merged
merged 3 commits into from
Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions src/doc_builder/autodoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,36 +533,52 @@ def resolve_links_in_text(text, package, mapping, page_info):
prefix = f"/docs/{package_name}/{version}/{language}/"

def _resolve_link(search):
object_name, last_char = search.groups()
object_or_param_name, last_char = search.groups()
# Deal with external libs first.
lib_name = object_name.split(".")[0]
lib_name = object_or_param_name.split(".")[0]
if lib_name.startswith("~"):
lib_name = lib_name[1:]
if lib_name in HUGGINFACE_LIBS and lib_name != package_name:
link = get_external_object_link(object_name, page_info)
link = get_external_object_link(object_or_param_name, page_info)
return f"{link}{last_char}"
object_name, param_name = None, None
# If `#` is in the name, assume it's a link to the function/method parameter.
if "#" in object_or_param_name:
object_name_for_param = object_or_param_name.split("#", 1)[0]
# Strip preceding `~` if it's there.
object_name_for_param = (
object_name_for_param[1:] if object_name_for_param.startswith("~") else object_name_for_param
)
obj = find_object_in_package(object_name_for_param, package)
param_name = object_or_param_name.split("#", 1)[-1]
# If the name begins with `~`, we shortcut to the last part.
if object_name.startswith("~"):
obj = find_object_in_package(object_name[1:], package)
object_name = object_name.split(".")[-1]
elif object_or_param_name.startswith("~"):
obj = find_object_in_package(object_or_param_name[1:], package)
object_name = object_or_param_name.split(".")[-1]
else:
obj = find_object_in_package(object_name, package)
obj = find_object_in_package(object_or_param_name, package)
object_name = object_or_param_name
# Object not found, return the original link text.
if obj is None:
return f"`{object_name}`{last_char}"
return f"`{object_or_param_name}`{last_char}"

link_name = object_name if param_name is None else param_name

# If the object is not a class, we add ()
if not isinstance(obj, (type, property)):
object_name = f"{object_name}()"
# If the link points to an object and the object is not a class, we add ()
if param_name is None and not isinstance(obj, (type, property)):
link_name = f"{link_name}()"

# Link to the anchor
anchor = get_shortest_path(obj, package)
if anchor not in mapping:
return f"`{object_name}`{last_char}"
return f"`{link_name}`{last_char}"
page = f"{prefix}{mapping[anchor]}"
if param_name:
anchor = f"{anchor}.{param_name}"
if "#" in page:
return f"[{object_name}]({page}){last_char}"
return f"[{link_name}]({page}){last_char}"
else:
return f"[{object_name}]({page}#{anchor}){last_char}"
return f"[{link_name}]({page}#{anchor}){last_char}"

return re.sub(r"\[`([^`]+)`\]([^\(])", _resolve_link, text)

Expand Down
12 changes: 12 additions & 0 deletions tests/test_autodoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,18 @@ def test_resolve_links_in_text(self):
),
)

self.assertEqual(
resolve_links_in_text(
"Link to [`transformers.BertModel.forward#input_ids`], [`~transformers.BertModel.forward#input_ids`].",
transformers,
small_mapping,
page_info,
),
(
"Link to [input_ids](/docs/transformers/main/en/model_doc/bert.html#transformers.BertModel.forward.input_ids), [input_ids](/docs/transformers/main/en/model_doc/bert.html#transformers.BertModel.forward.input_ids)."
),
)

self.assertEqual(
resolve_links_in_text(
"This is a regular [`link`](url)",
Expand Down