Skip to content

Commit

Permalink
stubgen: added the ability to include a custom prefix/suffix in gener…
Browse files Browse the repository at this point in the history
…ated stubs
  • Loading branch information
wjakob committed Mar 3, 2024
1 parent 0d191b4 commit 4240a97
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 12 deletions.
4 changes: 4 additions & 0 deletions docs/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,7 @@ you may use the special ``\from`` escape code to import them:
\from typing import Optional as _Opt, Literal
def lookup(array: Array[T], index: Literal[0] = 0) -> _Opt[T]:
\doc
You may also add free-form text the beginning or the end of the generated stub.
To do so, add an entry that matches on ``module_name.__prefix__`` or
``module_name.__suffix__``.
53 changes: 41 additions & 12 deletions src/stubgen.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# pyright: strict

"""
stubgen.py: nanobind stub generation tool
Expand Down Expand Up @@ -136,6 +138,11 @@ class NbType(Protocol):

@dataclass
class ReplacePattern:
"""
A compiled query (regular expression) and replacement pattern. Patterns can
be loaded using the ``load_pattern_file()`` function dfined below
"""

# A replacement patterns as produced by ``load_pattern_file()`` below
query: Pattern[str]
lines: List[str]
Expand Down Expand Up @@ -614,10 +621,23 @@ def process_general(m: Match[str]) -> str:

return s

def apply_pattern(self, value: object, pattern: ReplacePattern, match: Match[str]) -> None:
def apply_pattern(self, query: str, value: object) -> bool:
"""
Called when ``value`` matched an entry of a pattern file
Check if ``value`` matches an entry of a pattern file. Applies the
pattern and returns ``True`` in that case, otherwise returns ``False``.
"""

match: Optional[Match[str]] = None
pattern: Optional[ReplacePattern] = None

for pattern in self.patterns:
match = pattern.query.search(query)
if match:
break

if not match or not pattern:
return False

for line in pattern.lines:
ls = line.strip()
if ls == "\\doc":
Expand Down Expand Up @@ -663,6 +683,9 @@ def apply_pattern(self, value: object, pattern: ReplacePattern, match: Match[str
line = line.replace(f"\\{k}", v)
self.write_ln(line)

# Success, pattern was applied
return True

def put(self, value: object, name: Optional[str] = None, parent: Optional[object] = None) -> None:
old_prefix = self.prefix

Expand All @@ -675,13 +698,8 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object
self.prefix = self.prefix + (("." + name) if name else "")

# Check if an entry in a provided pattern file matches
if self.prefix:
for pattern in self.patterns:
match = pattern.query.search(self.prefix)
if match:
# If so, don't recurse
self.apply_pattern(value, pattern, match)
return
if self.apply_pattern(self.prefix, value):
return

# Exclude various standard elements found in modules, classes, etc.
if name in SKIP_LIST:
Expand Down Expand Up @@ -713,8 +731,11 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object
# Do not recurse into submodules, but include a directive to import them
self.import_object(value.__name__, name=None, as_name=name)
return
for name, child in getmembers(value):
self.put(child, name=name, parent=value)
else:
self.apply_pattern(self.prefix + ".__prefix__", None)
for name, child in getmembers(value):
self.put(child, name=name, parent=value)
self.apply_pattern(self.prefix + ".__suffix__", None)
elif self.is_function(tp):
value = cast(NbFunction, value)
self.put_function(value, name, parent)
Expand Down Expand Up @@ -996,7 +1017,10 @@ def get(self) -> str:
if s:
s += "\n"
s += self.put_abstract_enum_class()

# Append the main generated stub
s += self.output

return s.rstrip() + "\n"

def put_abstract_enum_class(self) -> str:
Expand Down Expand Up @@ -1143,14 +1167,19 @@ def parse_options(args: List[str]) -> argparse.Namespace:


def load_pattern_file(fname: str) -> List[ReplacePattern]:
"""
Load a pattern file from disk and return a list of pattern instances that
includes precompiled versions of all of the contained regular expressions.
"""

with open(fname, "r") as f:
f_lines = f.readlines()

patterns: List[ReplacePattern] = []

def add_pattern(query: str, lines: List[str]):
# Exactly 1 empty line at the end
while lines and lines[-1].isspace():
while lines and (lines[-1].isspace() or len(lines[-1]) == 0):
lines.pop()
lines.append("")

Expand Down
6 changes: 6 additions & 0 deletions tests/pattern_file.nb
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ tweak_me:
# Apply a pattern to multiple places
__(lt|gt)__:
def __\1__(self, arg: int, /) -> bool: ...

test_typing_ext.__prefix__:
# a prefix

test_typing_ext.__suffix__:
# a suffix
4 changes: 4 additions & 0 deletions tests/test_typing_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ from .submodule import F as F, f as f2
from collections.abc import Iterable
from typing import Self, Optional, TypeAlias, TypeVar, Generic

# a prefix

@my_decorator
class CustomSignature(Iterable[int]):
@my_decorator
Expand Down Expand Up @@ -52,3 +54,5 @@ def tweak_me(arg: int):
prior docstring
remains preserved
"""

# a suffix

0 comments on commit 4240a97

Please sign in to comment.