diff --git a/mypy/semanal.py b/mypy/semanal.py index eceb96ca63ee..51310e4f3e4d 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1915,7 +1915,7 @@ def get_all_bases_tvars( except TypeTranslationError: # This error will be caught later. continue - base_tvars = base.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope)) + base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope)) tvars.extend(base_tvars) return remove_dups(tvars) @@ -1933,7 +1933,7 @@ def get_and_bind_all_tvars(self, type_exprs: list[Expression]) -> list[TypeVarLi except TypeTranslationError: # This error will be caught later. continue - base_tvars = base.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope)) + base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope)) tvars.extend(base_tvars) tvars = remove_dups(tvars) # Variables are defined in order of textual appearance. tvar_defs = [] @@ -3294,7 +3294,7 @@ def analyze_alias( ) return None, [], set(), [] - found_type_vars = typ.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope)) + found_type_vars = typ.accept(TypeVarLikeQuery(self, self.tvar_scope)) tvar_defs: list[TypeVarLikeType] = [] namespace = self.qualified_name(name) with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)): diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 823e74e7e283..c5324357117b 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -315,7 +315,7 @@ class TypeQuery(SyntheticTypeVisitor[T]): # TODO: check that we don't have existing violations of this rule. """ - def __init__(self, strategy: Callable[[Iterable[T]], T]) -> None: + def __init__(self, strategy: Callable[[list[T]], T]) -> None: self.strategy = strategy # Keep track of the type aliases already visited. This is needed to avoid # infinite recursion on types like A = Union[int, List[A]]. diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 28f293613d50..0755b21854de 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -4,7 +4,6 @@ import itertools from contextlib import contextmanager -from itertools import chain from typing import Callable, Iterable, Iterator, List, Sequence, Tuple, TypeVar from typing_extensions import Final, Protocol @@ -203,8 +202,6 @@ def __init__( allow_type_any: bool = False, ) -> None: self.api = api - self.lookup_qualified = api.lookup_qualified - self.lookup_fqn_func = api.lookup_fully_qualified self.fail_func = api.fail self.note_func = api.note self.tvar_scope = tvar_scope @@ -244,6 +241,14 @@ def __init__( # Allow variables typed as Type[Any] and type (useful for base classes). self.allow_type_any = allow_type_any + def lookup_qualified( + self, name: str, ctx: Context, suppress_errors: bool = False + ) -> SymbolTableNode | None: + return self.api.lookup_qualified(name, ctx, suppress_errors) + + def lookup_fully_qualified(self, name: str) -> SymbolTableNode: + return self.api.lookup_fully_qualified(name) + def visit_unbound_type(self, t: UnboundType, defining_literal: bool = False) -> Type: typ = self.visit_unbound_type_nonoptional(t, defining_literal) if t.optional: @@ -1408,14 +1413,17 @@ def tvar_scope_frame(self) -> Iterator[None]: yield self.tvar_scope = old_scope + def find_type_var_likes(self, t: Type, include_callables: bool = True) -> TypeVarLikeList: + return t.accept( + TypeVarLikeQuery(self.api, self.tvar_scope, include_callables=include_callables) + ) + def infer_type_variables(self, type: CallableType) -> list[tuple[str, TypeVarLikeExpr]]: """Return list of unique type variables referred to in a callable.""" names: list[str] = [] tvars: list[TypeVarLikeExpr] = [] for arg in type.arg_types: - for name, tvar_expr in arg.accept( - TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope) - ): + for name, tvar_expr in self.find_type_var_likes(arg): if name not in names: names.append(name) tvars.append(tvar_expr) @@ -1423,12 +1431,13 @@ def infer_type_variables(self, type: CallableType) -> list[tuple[str, TypeVarLik # look inside Callable types. Type variables only appearing in # functions in the return type belong to those functions, not the # function we're currently analyzing. - for name, tvar_expr in type.ret_type.accept( - TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope, include_callables=False) - ): + for name, tvar_expr in self.find_type_var_likes(type.ret_type, include_callables=False): if name not in names: names.append(name) tvars.append(tvar_expr) + + if not names: + return [] # Fast path return list(zip(names, tvars)) def bind_function_type_variables( @@ -1546,7 +1555,7 @@ def named_type( line: int = -1, column: int = -1, ) -> Instance: - node = self.lookup_fqn_func(fully_qualified_name) + node = self.lookup_fully_qualified(fully_qualified_name) assert isinstance(node.node, TypeInfo) any_type = AnyType(TypeOfAny.special_form) if args is not None: @@ -1785,7 +1794,9 @@ def set_any_tvars( return TypeAliasType(node, [any_type] * len(node.alias_tvars), newline, newcolumn) -def remove_dups(tvars: Iterable[T]) -> list[T]: +def remove_dups(tvars: list[T]) -> list[T]: + if len(tvars) <= 1: + return tvars # Get unique elements in order of appearance all_tvars: set[T] = set() new_tvars: list[T] = [] @@ -1796,8 +1807,13 @@ def remove_dups(tvars: Iterable[T]) -> list[T]: return new_tvars -def flatten_tvars(ll: Iterable[list[T]]) -> list[T]: - return remove_dups(chain.from_iterable(ll)) +def flatten_tvars(lists: list[list[T]]) -> list[T]: + result: list[T] = [] + for lst in lists: + for item in lst: + if item not in result: + result.append(item) + return result class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]): @@ -1805,17 +1821,15 @@ class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]): def __init__( self, - lookup: Callable[[str, Context], SymbolTableNode | None], + api: SemanticAnalyzerCoreInterface, scope: TypeVarLikeScope, *, include_callables: bool = True, - include_bound_tvars: bool = False, ) -> None: - self.include_callables = include_callables - self.lookup = lookup - self.scope = scope - self.include_bound_tvars = include_bound_tvars super().__init__(flatten_tvars) + self.api = api + self.scope = scope + self.include_callables = include_callables # Only include type variables in type aliases args. This would be anyway # that case if we expand (as target variables would be overridden with args) # and it may cause infinite recursion on invalid (diverging) recursive aliases. @@ -1833,16 +1847,16 @@ def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList: if name.endswith("args"): if name.endswith(".args") or name.endswith(".kwargs"): base = ".".join(name.split(".")[:-1]) - n = self.lookup(base, t) + n = self.api.lookup_qualified(base, t) if n is not None and isinstance(n.node, ParamSpecExpr): node = n name = base if node is None: - node = self.lookup(name, t) + node = self.api.lookup_qualified(name, t) if ( node and isinstance(node.node, TypeVarLikeExpr) - and (self.include_bound_tvars or self.scope.get_binding(node) is None) + and self.scope.get_binding(node) is None ): assert isinstance(node.node, TypeVarLikeExpr) return [(name, node.node)]