Skip to content

Commit

Permalink
stubgen: don't include aliases to functions/classes of different pack…
Browse files Browse the repository at this point in the history
…ages
  • Loading branch information
wjakob committed Mar 3, 2024
1 parent 2341d14 commit e1cb670
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
64 changes: 53 additions & 11 deletions src/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,17 @@ class and repeatedly call ``.put()`` to register modules or contents within the
# Type of an entry of the ``__nb_signature__`` tuple of nanobind getters and setters.
NbGetterSetterSignature = Tuple[str, str]

class NamedObject(Protocol):
"""
Typing protocol representing a an object with __name__ and __module__ members
"""
__module__: str
__name__: str

class NbFunction(Protocol):
"""
Typing protocol representing a nanobind function with its __nb_signature__ property
"""
__module__: Literal["nanobind"]
__name__: Literal["nb_func", "nb_method"]
__nb_signature__: Tuple[NbFunctionSignature, ...]
Expand All @@ -123,13 +133,15 @@ class NbGetterSetter(Protocol):


class NbStaticProperty(Protocol):
"""Typing protocol representing a nanobind static property"""
__module__: Literal["nanobind"]
__name__: Literal["nb_static_property"]
fget: NbGetterSetter
fset: NbGetterSetter


class NbType(Protocol):
"""typing protocol representing a nanobind type object"""
__module__: Literal["nanobind"]
__name__: Literal["nb_type"]
__nb_signature__: str
Expand All @@ -155,6 +167,8 @@ def __init__(
module: types.ModuleType,
include_docstrings: bool = True,
include_private: bool = False,
include_internal_imports: bool = True,
include_external_imports: bool = False,
max_expr_length: int = 50,
patterns: List[ReplacePattern] = [],
) -> None:
Expand All @@ -167,6 +181,12 @@ def __init__(
# Include private members that start or end with a single underscore?
self.include_private = include_private

# Include types and functions imported from the same package (but a different module)
self.include_internal_imports = include_internal_imports

# Include types and functions imported from external packages?
self.include_external_imports = include_external_imports

# Maximal length (in characters) before an expression gets abbreviated as '...'
self.max_expr_length = max_expr_length

Expand Down Expand Up @@ -435,12 +455,18 @@ def put_nb_static_property(self, name: Optional[str], prop: NbStaticProperty):

def put_type(self, tp: NbType, name: Optional[str]):
"""Append a 'nb_type' type object"""
if name and (name != tp.__name__ or self.module.__name__ != tp.__module__):
if self.module.__name__ == tp.__module__:
# This is an alias of a type in the same module
tp_name, tp_mod_name = tp.__name__, tp.__module__
mod_name = self.module.__name__

if name and (name != tp_name or mod_name != tp_mod_name):
same_module = tp_mod_name == mod_name
same_toplevel_module = tp_mod_name.split(".")[0] == mod_name.split(".")[0]

if same_module:
# This is an alias of a type in the same module or same top-level module
alias_tp = self.import_object("typing", "TypeAlias")
self.write_ln(f"{name}: {alias_tp} = {tp.__name__}\n")
else:
self.write_ln(f"{name}: {alias_tp} = {tp_name}\n")
elif self.include_external_imports or (same_toplevel_module and self.include_internal_imports):
# Import from a different module
self.put_value(tp, name)
else:
Expand Down Expand Up @@ -475,7 +501,7 @@ def put_type(self, tp: NbType, name: Optional[str]):
self.write_ln(self.simplify_types(s))
self.output = self.output[:-1] + ":\n"
else:
self.write_ln(f"class {tp.__name__}:")
self.write_ln(f"class {tp_name}:")
if tp_bases is None:
tp_bases = getattr(tp, "__orig_bases__", None)
if tp_bases is None:
Expand Down Expand Up @@ -531,16 +557,27 @@ def put_value(self, value: object, name: str, parent: Optional[object] = None, a
"""
tp = type(value)

# Ignore module imports of non-type values like 'from typing import Optional'
if (
not self.include_external_imports
and tp.__module__ == "typing"
and str(value) == f"typing.{name}"
):
return

if isinstance(parent, type) and issubclass(tp, parent) and self.is_enum(parent):
# This is an entry of an enumeration
self.write_ln(f"{name}: {self.type_str(tp)}")
if value.__doc__ and self.include_docstrings:
self.put_docstr(value.__doc__)
self.write("\n")
elif self.is_function(tp) or isinstance(value, type):
# This is a function or a type, import it from its actual source
value = cast(type, value)
self.import_object(value.__module__, value.__name__, name)
named_value = cast(NamedObject, value)
same_toplevel_module = named_value.__module__.split(".")[0] == self.module.__name__.split(".")[0]

if self.include_external_imports or (same_toplevel_module and self.include_internal_imports):
# This is a function or a type, import it from its actual source
self.import_object(named_value.__module__, named_value.__name__, name)
else:
value_str = self.expr_str(value, abbrev)

Expand Down Expand Up @@ -728,6 +765,10 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object

if ismodule(value):
if len(self.stack) != 1:
is_external = value.__name__.split(".")[0] != self.module.__name__.split(".")[0]
if not self.include_external_imports and is_external:
return

# Do not recurse into submodules, but include a directive to import them
self.import_object(value.__name__, name=None, as_name=name)
return
Expand Down Expand Up @@ -831,8 +872,9 @@ def import_object(
def expr_str(self, e: Any, abbrev: bool = True) -> Optional[str]:
"""
Attempt to convert a value into valid Python syntax that regenerates
that value. When ``abbrev`` is True, give up and replace with '...' if
the expression is too complicated to be included in the stubs
that value. When ``abbrev`` is True, the implementation gives up and
returns ``None`` when the expression is considered to be too
complicated.
"""
tp = type(e)
for t in [bool, int, type(None), type(builtins.Ellipsis)]:
Expand Down
4 changes: 4 additions & 0 deletions tests/py_stub_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
else:
import typing

# Ignore a type and a function from elsewhere. These shouldn't be included in
# the stub by default
from os import PathLike, getcwd

del sys

C = 123
Expand Down

0 comments on commit e1cb670

Please sign in to comment.