Skip to content

Commit

Permalink
add exec_type_checking
Browse files Browse the repository at this point in the history
  • Loading branch information
zhPavel committed Aug 28, 2024
1 parent 570d839 commit b2a57f8
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/adaptix/_internal/type_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
make_norm_type,
normalize_type,
)
from .type_evaler import exec_type_checking
51 changes: 51 additions & 0 deletions src/adaptix/_internal/type_tools/type_evaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import ast
import inspect
from collections.abc import Callable, Sequence
from types import ModuleType


def make_fragments_collector(*, typing_modules: Sequence[str]) -> Callable[[ast.Module], list[ast.stmt]]:
def check_condition(expr: ast.expr) -> bool:
# searches for `TYPE_CHECKING`
if isinstance(expr, ast.Name) and isinstance(expr.ctx, ast.Load):
return True

# searches for `typing.TYPE_CHECKING`
if ( # noqa: SIM103
isinstance(expr, ast.Attribute)
and expr.attr == "TYPE_CHECKING"
and isinstance(expr.ctx, ast.Load)
and isinstance(expr.value, ast.Name)
and expr.value.id in typing_modules
and isinstance(expr.value.ctx, ast.Load)
):
return True
return False

def collect_type_checking_only_fragments(module: ast.Module) -> list[ast.stmt]:
fragments = []
for stmt in module.body:
if isinstance(stmt, ast.If) and not stmt.orelse and check_condition(stmt.test):
fragments.extend(stmt.body)

return fragments

return collect_type_checking_only_fragments


default_collector = make_fragments_collector(typing_modules=["typing"])


def exec_type_checking(
module: ModuleType,
*,
collector: Callable[[ast.Module], list[ast.stmt]] = default_collector,
) -> None:
source = inspect.getsource(module)
fragments = collector(ast.parse(source))
code = compile(ast.Module(fragments, type_ignores=[]), f"<exec_type_checking of {module}>", "exec")
namespace = module.__dict__.copy()
exec(code, namespace) # noqa: S102
for k, v in namespace.items():
if not hasattr(module, k):
setattr(module, k, v)
5 changes: 5 additions & 0 deletions src/adaptix/type_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from adaptix._internal.type_tools import exec_type_checking

__all__ = (
"exec_type_checking",
)
26 changes: 22 additions & 4 deletions tests/tests_helpers/tests_helpers/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import importlib.util
import inspect
import re
import runpy
Expand Down Expand Up @@ -189,6 +190,15 @@ def load_namespace(
return SimpleNamespace(**ns_dict)


@contextmanager
def temp_module(module: ModuleType):
sys.modules[module.__name__] = module
try:
yield
finally:
sys.modules.pop(module.__name__, None)


@contextmanager
def load_namespace_keeping_module(
file_name: str,
Expand All @@ -202,11 +212,19 @@ def load_namespace_keeping_module(
module = ModuleType(run_name)
for attr, value in ns.__dict__.items():
setattr(module, attr, value)
sys.modules[run_name] = module
try:

with temp_module(module):
yield ns
finally:
sys.modules.pop(run_name, None)


def import_local_module(file_path: Path, name: Optional[str] = None) -> ModuleType:
if name is None:
name = file_path.stem

spec = importlib.util.spec_from_file_location(name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


def with_notes(exc: E, *notes: Union[str, list[str]]) -> E:
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/type_tools/data_type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import typing
from collections.abc import Sequence
from typing import TYPE_CHECKING

if TYPE_CHECKING:
IntSeq = Sequence[int]


if typing.TYPE_CHECKING:
StrSeq = Sequence[str]


class Foo:
a: bool
b: "IntSeq"
c: "StrSeq"
18 changes: 18 additions & 0 deletions tests/unit/type_tools/test_type_evaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from collections.abc import Sequence
from pathlib import Path

from tests_helpers.misc import import_local_module, temp_module

from adaptix._internal.type_tools import get_all_type_hints
from adaptix.type_tools import exec_type_checking


def test_exec_type_checking():
module = import_local_module(Path(__file__).with_name("data_type_checking.py"))
with temp_module(module):
exec_type_checking(module)
assert get_all_type_hints(module.Foo) == {
"a": bool,
"b": Sequence[int],
"c": Sequence[str],
}

0 comments on commit b2a57f8

Please sign in to comment.