Skip to content

Commit

Permalink
stubtest: Improve heuristics for determining whether global-namespace…
Browse files Browse the repository at this point in the history
… names are imported (#14270)

Stubtest currently has both false-positives and false-negatives when it
comes to verifying constants in the global namespace of a module.

This PR fixes the false positive by using `inspect.getsourcelines()` to
dynamically retrieve the module source code. It then uses `symtable` to
analyse that source code to gather a list of names which are known to be
imported.

The PR fixes the false negative by only using the `__module__` heuristic
on objects which are callable. The vast majority of callable objects
will be types or functions. For these objects, the `__module__`
attribute will give a good indication of whether the object originates
from another module or not; for other objects, it's less useful.
  • Loading branch information
AlexWaygood authored Dec 22, 2022
1 parent 2514610 commit 31b0413
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
50 changes: 44 additions & 6 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import pkgutil
import re
import symtable
import sys
import traceback
import types
Expand Down Expand Up @@ -283,6 +284,36 @@ def _verify_exported_names(
)


def _get_imported_symbol_names(runtime: types.ModuleType) -> frozenset[str] | None:
"""Retrieve the names in the global namespace which are known to be imported.
1). Use inspect to retrieve the source code of the module
2). Use symtable to parse the source and retrieve names that are known to be imported
from other modules.
If either of the above steps fails, return `None`.
Note that if a set of names is returned,
it won't include names imported via `from foo import *` imports.
"""
try:
source = inspect.getsource(runtime)
except (OSError, TypeError, SyntaxError):
return None

if not source.strip():
# The source code for the module was an empty file,
# no point in parsing it with symtable
return frozenset()

try:
module_symtable = symtable.symtable(source, runtime.__name__, "exec")
except SyntaxError:
return None

return frozenset(sym.get_name() for sym in module_symtable.get_symbols() if sym.is_imported())


@verify.register(nodes.MypyFile)
def verify_mypyfile(
stub: nodes.MypyFile, runtime: MaybeMissing[types.ModuleType], object_path: list[str]
Expand Down Expand Up @@ -312,15 +343,22 @@ def verify_mypyfile(
if not o.module_hidden and (not is_probably_private(m) or hasattr(runtime, m))
}

imported_symbols = _get_imported_symbol_names(runtime)

def _belongs_to_runtime(r: types.ModuleType, attr: str) -> bool:
obj = getattr(r, attr)
try:
obj_mod = getattr(obj, "__module__", None)
except Exception:
if isinstance(obj, types.ModuleType):
return False
if obj_mod is not None:
return bool(obj_mod == r.__name__)
return not isinstance(obj, types.ModuleType)
if callable(obj):
try:
obj_mod = getattr(obj, "__module__", None)
except Exception:
return False
if obj_mod is not None:
return bool(obj_mod == r.__name__)
if imported_symbols is not None:
return attr not in imported_symbols
return True

runtime_public_contents = (
runtime_all_as_set
Expand Down
3 changes: 3 additions & 0 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,9 @@ def test_missing_no_runtime_all(self) -> Iterator[Case]:
yield Case(stub="", runtime="import sys", error=None)
yield Case(stub="", runtime="def g(): ...", error="g")
yield Case(stub="", runtime="CONSTANT = 0", error="CONSTANT")
yield Case(stub="", runtime="import re; constant = re.compile('foo')", error="constant")
yield Case(stub="", runtime="from json.scanner import NUMBER_RE", error=None)
yield Case(stub="", runtime="from string import ascii_letters", error=None)

@collect_cases
def test_non_public_1(self) -> Iterator[Case]:
Expand Down

0 comments on commit 31b0413

Please sign in to comment.