Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some semantic analyzer micro-optimizations #14367

Merged
merged 1 commit into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1915,7 +1915,7 @@ def get_all_bases_tvars(
except TypeTranslationError:
# This error will be caught later.
continue
base_tvars = base.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope))
base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope))
tvars.extend(base_tvars)
return remove_dups(tvars)

Expand All @@ -1933,7 +1933,7 @@ def get_and_bind_all_tvars(self, type_exprs: list[Expression]) -> list[TypeVarLi
except TypeTranslationError:
# This error will be caught later.
continue
base_tvars = base.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope))
base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope))
tvars.extend(base_tvars)
tvars = remove_dups(tvars) # Variables are defined in order of textual appearance.
tvar_defs = []
Expand Down Expand Up @@ -3294,7 +3294,7 @@ def analyze_alias(
)
return None, [], set(), []

found_type_vars = typ.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope))
found_type_vars = typ.accept(TypeVarLikeQuery(self, self.tvar_scope))
tvar_defs: list[TypeVarLikeType] = []
namespace = self.qualified_name(name)
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):
Expand Down
2 changes: 1 addition & 1 deletion mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ class TypeQuery(SyntheticTypeVisitor[T]):
# TODO: check that we don't have existing violations of this rule.
"""

def __init__(self, strategy: Callable[[Iterable[T]], T]) -> None:
def __init__(self, strategy: Callable[[list[T]], T]) -> None:
self.strategy = strategy
# Keep track of the type aliases already visited. This is needed to avoid
# infinite recursion on types like A = Union[int, List[A]].
Expand Down
58 changes: 36 additions & 22 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import itertools
from contextlib import contextmanager
from itertools import chain
from typing import Callable, Iterable, Iterator, List, Sequence, Tuple, TypeVar
from typing_extensions import Final, Protocol

Expand Down Expand Up @@ -203,8 +202,6 @@ def __init__(
allow_type_any: bool = False,
) -> None:
self.api = api
self.lookup_qualified = api.lookup_qualified
self.lookup_fqn_func = api.lookup_fully_qualified
self.fail_func = api.fail
self.note_func = api.note
self.tvar_scope = tvar_scope
Expand Down Expand Up @@ -244,6 +241,14 @@ def __init__(
# Allow variables typed as Type[Any] and type (useful for base classes).
self.allow_type_any = allow_type_any

def lookup_qualified(
self, name: str, ctx: Context, suppress_errors: bool = False
) -> SymbolTableNode | None:
return self.api.lookup_qualified(name, ctx, suppress_errors)

def lookup_fully_qualified(self, name: str) -> SymbolTableNode:
return self.api.lookup_fully_qualified(name)

def visit_unbound_type(self, t: UnboundType, defining_literal: bool = False) -> Type:
typ = self.visit_unbound_type_nonoptional(t, defining_literal)
if t.optional:
Expand Down Expand Up @@ -1408,27 +1413,31 @@ def tvar_scope_frame(self) -> Iterator[None]:
yield
self.tvar_scope = old_scope

def find_type_var_likes(self, t: Type, include_callables: bool = True) -> TypeVarLikeList:
return t.accept(
TypeVarLikeQuery(self.api, self.tvar_scope, include_callables=include_callables)
)

def infer_type_variables(self, type: CallableType) -> list[tuple[str, TypeVarLikeExpr]]:
"""Return list of unique type variables referred to in a callable."""
names: list[str] = []
tvars: list[TypeVarLikeExpr] = []
for arg in type.arg_types:
for name, tvar_expr in arg.accept(
TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope)
):
for name, tvar_expr in self.find_type_var_likes(arg):
if name not in names:
names.append(name)
tvars.append(tvar_expr)
# When finding type variables in the return type of a function, don't
# look inside Callable types. Type variables only appearing in
# functions in the return type belong to those functions, not the
# function we're currently analyzing.
for name, tvar_expr in type.ret_type.accept(
TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope, include_callables=False)
):
for name, tvar_expr in self.find_type_var_likes(type.ret_type, include_callables=False):
if name not in names:
names.append(name)
tvars.append(tvar_expr)

if not names:
return [] # Fast path
return list(zip(names, tvars))

def bind_function_type_variables(
Expand Down Expand Up @@ -1546,7 +1555,7 @@ def named_type(
line: int = -1,
column: int = -1,
) -> Instance:
node = self.lookup_fqn_func(fully_qualified_name)
node = self.lookup_fully_qualified(fully_qualified_name)
assert isinstance(node.node, TypeInfo)
any_type = AnyType(TypeOfAny.special_form)
if args is not None:
Expand Down Expand Up @@ -1785,7 +1794,9 @@ def set_any_tvars(
return TypeAliasType(node, [any_type] * len(node.alias_tvars), newline, newcolumn)


def remove_dups(tvars: Iterable[T]) -> list[T]:
def remove_dups(tvars: list[T]) -> list[T]:
if len(tvars) <= 1:
return tvars
# Get unique elements in order of appearance
all_tvars: set[T] = set()
new_tvars: list[T] = []
Expand All @@ -1796,26 +1807,29 @@ def remove_dups(tvars: Iterable[T]) -> list[T]:
return new_tvars


def flatten_tvars(ll: Iterable[list[T]]) -> list[T]:
return remove_dups(chain.from_iterable(ll))
def flatten_tvars(lists: list[list[T]]) -> list[T]:
result: list[T] = []
for lst in lists:
for item in lst:
if item not in result:
result.append(item)
return result


class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]):
"""Find TypeVar and ParamSpec references in an unbound type."""

def __init__(
self,
lookup: Callable[[str, Context], SymbolTableNode | None],
api: SemanticAnalyzerCoreInterface,
scope: TypeVarLikeScope,
*,
include_callables: bool = True,
include_bound_tvars: bool = False,
) -> None:
self.include_callables = include_callables
self.lookup = lookup
self.scope = scope
self.include_bound_tvars = include_bound_tvars
super().__init__(flatten_tvars)
self.api = api
self.scope = scope
self.include_callables = include_callables
# Only include type variables in type aliases args. This would be anyway
# that case if we expand (as target variables would be overridden with args)
# and it may cause infinite recursion on invalid (diverging) recursive aliases.
Expand All @@ -1833,16 +1847,16 @@ def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList:
if name.endswith("args"):
if name.endswith(".args") or name.endswith(".kwargs"):
base = ".".join(name.split(".")[:-1])
n = self.lookup(base, t)
n = self.api.lookup_qualified(base, t)
if n is not None and isinstance(n.node, ParamSpecExpr):
node = n
name = base
if node is None:
node = self.lookup(name, t)
node = self.api.lookup_qualified(name, t)
if (
node
and isinstance(node.node, TypeVarLikeExpr)
and (self.include_bound_tvars or self.scope.get_binding(node) is None)
and self.scope.get_binding(node) is None
):
assert isinstance(node.node, TypeVarLikeExpr)
return [(name, node.node)]
Expand Down