diff --git a/docs/typing.rst b/docs/typing.rst index 0bee90e0a..5077ed0e6 100644 --- a/docs/typing.rst +++ b/docs/typing.rst @@ -634,3 +634,7 @@ you may use the special ``\from`` escape code to import them: \from typing import Optional as _Opt, Literal def lookup(array: Array[T], index: Literal[0] = 0) -> _Opt[T]: \doc + +You may also add free-form text the beginning or the end of the generated stub. +To do so, add an entry that matches on ``module_name.__prefix__`` or +``module_name.__suffix__``. diff --git a/src/stubgen.py b/src/stubgen.py index d6156a1de..b27ba6f98 100755 --- a/src/stubgen.py +++ b/src/stubgen.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# pyright: strict + """ stubgen.py: nanobind stub generation tool @@ -136,6 +138,11 @@ class NbType(Protocol): @dataclass class ReplacePattern: + """ + A compiled query (regular expression) and replacement pattern. Patterns can + be loaded using the ``load_pattern_file()`` function dfined below + """ + # A replacement patterns as produced by ``load_pattern_file()`` below query: Pattern[str] lines: List[str] @@ -614,10 +621,23 @@ def process_general(m: Match[str]) -> str: return s - def apply_pattern(self, value: object, pattern: ReplacePattern, match: Match[str]) -> None: + def apply_pattern(self, query: str, value: object) -> bool: """ - Called when ``value`` matched an entry of a pattern file + Check if ``value`` matches an entry of a pattern file. Applies the + pattern and returns ``True`` in that case, otherwise returns ``False``. """ + + match: Optional[Match[str]] = None + pattern: Optional[ReplacePattern] = None + + for pattern in self.patterns: + match = pattern.query.search(query) + if match: + break + + if not match or not pattern: + return False + for line in pattern.lines: ls = line.strip() if ls == "\\doc": @@ -663,6 +683,9 @@ def apply_pattern(self, value: object, pattern: ReplacePattern, match: Match[str line = line.replace(f"\\{k}", v) self.write_ln(line) + # Success, pattern was applied + return True + def put(self, value: object, name: Optional[str] = None, parent: Optional[object] = None) -> None: old_prefix = self.prefix @@ -675,13 +698,8 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object self.prefix = self.prefix + (("." + name) if name else "") # Check if an entry in a provided pattern file matches - if self.prefix: - for pattern in self.patterns: - match = pattern.query.search(self.prefix) - if match: - # If so, don't recurse - self.apply_pattern(value, pattern, match) - return + if self.apply_pattern(self.prefix, value): + return # Exclude various standard elements found in modules, classes, etc. if name in SKIP_LIST: @@ -713,8 +731,11 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object # Do not recurse into submodules, but include a directive to import them self.import_object(value.__name__, name=None, as_name=name) return - for name, child in getmembers(value): - self.put(child, name=name, parent=value) + else: + self.apply_pattern(self.prefix + ".__prefix__", None) + for name, child in getmembers(value): + self.put(child, name=name, parent=value) + self.apply_pattern(self.prefix + ".__suffix__", None) elif self.is_function(tp): value = cast(NbFunction, value) self.put_function(value, name, parent) @@ -996,7 +1017,10 @@ def get(self) -> str: if s: s += "\n" s += self.put_abstract_enum_class() + + # Append the main generated stub s += self.output + return s.rstrip() + "\n" def put_abstract_enum_class(self) -> str: @@ -1143,6 +1167,11 @@ def parse_options(args: List[str]) -> argparse.Namespace: def load_pattern_file(fname: str) -> List[ReplacePattern]: + """ + Load a pattern file from disk and return a list of pattern instances that + includes precompiled versions of all of the contained regular expressions. + """ + with open(fname, "r") as f: f_lines = f.readlines() @@ -1150,7 +1179,7 @@ def load_pattern_file(fname: str) -> List[ReplacePattern]: def add_pattern(query: str, lines: List[str]): # Exactly 1 empty line at the end - while lines and lines[-1].isspace(): + while lines and (lines[-1].isspace() or len(lines[-1]) == 0): lines.pop() lines.append("") diff --git a/tests/pattern_file.nb b/tests/pattern_file.nb index 7ddf3e63f..8b9939f58 100644 --- a/tests/pattern_file.nb +++ b/tests/pattern_file.nb @@ -11,3 +11,9 @@ tweak_me: # Apply a pattern to multiple places __(lt|gt)__: def __\1__(self, arg: int, /) -> bool: ... + +test_typing_ext.__prefix__: + # a prefix + +test_typing_ext.__suffix__: + # a suffix diff --git a/tests/test_typing_ext.pyi.ref b/tests/test_typing_ext.pyi.ref index a0a64ff2f..0f03e31de 100644 --- a/tests/test_typing_ext.pyi.ref +++ b/tests/test_typing_ext.pyi.ref @@ -3,6 +3,8 @@ from .submodule import F as F, f as f2 from collections.abc import Iterable from typing import Self, Optional, TypeAlias, TypeVar, Generic +# a prefix + @my_decorator class CustomSignature(Iterable[int]): @my_decorator @@ -52,3 +54,5 @@ def tweak_me(arg: int): prior docstring remains preserved """ + +# a suffix