Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubtest: adjust symtable logic #16823

Merged
merged 6 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 30 additions & 29 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import collections.abc
import copy
import enum
import functools
import importlib
import importlib.machinery
import inspect
Expand Down Expand Up @@ -310,35 +311,23 @@ def _verify_exported_names(
)


def _get_imported_symbol_names(runtime: types.ModuleType) -> frozenset[str] | None:
"""Retrieve the names in the global namespace which are known to be imported.
@functools.lru_cache
def _module_symbol_table(runtime: types.ModuleType) -> symtable.SymbolTable | None:
"""Retrieve the symbol table for the module (or None on failure).

1). Use inspect to retrieve the source code of the module
2). Use symtable to parse the source and retrieve names that are known to be imported
from other modules.

If either of the above steps fails, return `None`.

Note that if a set of names is returned,
it won't include names imported via `from foo import *` imports.
1) Use inspect to retrieve the source code of the module
2) Use symtable to parse the source (and use what symtable knows for its purposes)
"""
try:
source = inspect.getsource(runtime)
except (OSError, TypeError, SyntaxError):
return None

if not source.strip():
# The source code for the module was an empty file,
# no point in parsing it with symtable
return frozenset()

try:
module_symtable = symtable.symtable(source, runtime.__name__, "exec")
return symtable.symtable(source, runtime.__name__, "exec")
except SyntaxError:
return None

return frozenset(sym.get_name() for sym in module_symtable.get_symbols() if sym.is_imported())


@verify.register(nodes.MypyFile)
def verify_mypyfile(
Expand Down Expand Up @@ -369,25 +358,37 @@ def verify_mypyfile(
if not o.module_hidden and (not is_probably_private(m) or hasattr(runtime, m))
}

imported_symbols = _get_imported_symbol_names(runtime)

def _belongs_to_runtime(r: types.ModuleType, attr: str) -> bool:
"""Heuristics to determine whether a name originates from another module."""
obj = getattr(r, attr)
if isinstance(obj, types.ModuleType):
return False
if callable(obj):
# It's highly likely to be a class or a function if it's callable,
# so the __module__ attribute will give a good indication of which module it comes from

symbol_table = _module_symbol_table(r)
if symbol_table is not None:
try:
obj_mod = obj.__module__
except Exception:
symbol = symbol_table.lookup(attr)
except KeyError:
pass
else:
if isinstance(obj_mod, str):
return bool(obj_mod == r.__name__)
if imported_symbols is not None:
return attr not in imported_symbols
if symbol.is_imported():
# symtable says we got this from another module
return False
# But we can't just return True here, because symtable doesn't know about symbols
# that come from `from module import *`
if symbol.is_assigned():
# symtable knows we assigned this symbol in the module
return True

# The __module__ attribute is unreliable for anything except functions and classes,
# but it's our best guess at this point
try:
obj_mod = obj.__module__
except Exception:
pass
else:
if isinstance(obj_mod, str):
return bool(obj_mod == r.__name__)
hauntsaninja marked this conversation as resolved.
Show resolved Hide resolved
return True

runtime_public_contents = (
Expand Down
18 changes: 18 additions & 0 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,24 @@ def test_missing_no_runtime_all(self) -> Iterator[Case]:
yield Case(stub="", runtime="from json.scanner import NUMBER_RE", error=None)
yield Case(stub="", runtime="from string import ascii_letters", error=None)

@collect_cases
def test_missing_no_runtime_all_terrible(self) -> Iterator[Case]:
yield Case(
stub="",
runtime="""
import sys
import types
import __future__
_m = types.SimpleNamespace()
_m.annotations = __future__.annotations
sys.modules["_terrible_stubtest_test_module"] = _m

from _terrible_stubtest_test_module import *
assert annotations
""",
error=None,
)

@collect_cases
def test_non_public_1(self) -> Iterator[Case]:
yield Case(
Expand Down
Loading