Skip to content

Commit

Permalink
Merge pull request #1 from python/master
Browse files Browse the repository at this point in the history
Fix daemon crash on malformed NamedTuple (python#14119)
  • Loading branch information
ChristianWitzler authored Nov 21, 2022
2 parents 15c37df + c660354 commit 1cd9216
Show file tree
Hide file tree
Showing 20 changed files with 213 additions and 40 deletions.
4 changes: 2 additions & 2 deletions docs/source/generics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ achieved by combining with :py:func:`@overload <typing.overload>`:

.. code-block:: python
from typing import Any, Callable, TypeVar, overload
from typing import Any, Callable, Optional, TypeVar, overload
F = TypeVar('F', bound=Callable[..., Any])
Expand All @@ -736,7 +736,7 @@ achieved by combining with :py:func:`@overload <typing.overload>`:
def atomic(*, savepoint: bool = True) -> Callable[[F], F]: ...
# Implementation
def atomic(__func: Callable[..., Any] = None, *, savepoint: bool = True):
def atomic(__func: Optional[Callable[..., Any]] = None, *, savepoint: bool = True):
def decorator(func: Callable[..., Any]):
... # Code goes here
if __func is not None:
Expand Down
1 change: 1 addition & 0 deletions misc/sync-typeshed.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def main() -> None:
commits_to_cherry_pick = [
"780534b13722b7b0422178c049a1cbbf4ea4255b", # LiteralString reverts
"5319fa34a8004c1568bb6f032a07b8b14cc95bed", # sum reverts
"0062994228fb62975c6cef4d2c80d00c7aa1c545", # ctypes reverts
]
for commit in commits_to_cherry_pick:
subprocess.run(["git", "cherry-pick", commit], check=True)
Expand Down
1 change: 1 addition & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ class FreezeTypeVarsVisitor(TypeTraverserVisitor):
def visit_callable_type(self, t: CallableType) -> None:
for v in t.variables:
v.id.meta_level = 0
super().visit_callable_type(t)


def lookup_member_var_or_accessor(info: TypeInfo, name: str, is_lvalue: bool) -> SymbolNode | None:
Expand Down
4 changes: 4 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,7 @@ class ClassDef(Statement):
"analyzed",
"has_incompatible_baseclass",
"deco_line",
"removed_statements",
)

__match_args__ = ("name", "defs")
Expand All @@ -1086,6 +1087,8 @@ class ClassDef(Statement):
keywords: dict[str, Expression]
analyzed: Expression | None
has_incompatible_baseclass: bool
# Used by special forms like NamedTuple and TypedDict to store invalid statements
removed_statements: list[Statement]

def __init__(
self,
Expand All @@ -1111,6 +1114,7 @@ def __init__(
self.has_incompatible_baseclass = False
# Used for error reporting (to keep backwad compatibility with pre-3.8)
self.deco_line: int | None = None
self.removed_statements = []

def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_class_def(self)
Expand Down
8 changes: 7 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,13 @@ def visit_decorator(self, dec: Decorator) -> None:
dec.var.is_classmethod = True
self.check_decorated_function_is_method("classmethod", dec)
elif refers_to_fullname(
d, ("builtins.property", "abc.abstractproperty", "functools.cached_property")
d,
(
"builtins.property",
"abc.abstractproperty",
"functools.cached_property",
"enum.property",
),
):
removed.append(i)
dec.func.is_property = True
Expand Down
19 changes: 14 additions & 5 deletions mypy/semanal_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
NameExpr,
PassStmt,
RefExpr,
Statement,
StrExpr,
SymbolTable,
SymbolTableNode,
Expand Down Expand Up @@ -111,7 +112,7 @@ def analyze_namedtuple_classdef(
if result is None:
# This is a valid named tuple, but some types are incomplete.
return True, None
items, types, default_items = result
items, types, default_items, statements = result
if is_func_scope and "@" not in defn.name:
defn.name += "@" + str(defn.line)
existing_info = None
Expand All @@ -123,31 +124,35 @@ def analyze_namedtuple_classdef(
defn.analyzed = NamedTupleExpr(info, is_typed=True)
defn.analyzed.line = defn.line
defn.analyzed.column = defn.column
defn.defs.body = statements
# All done: this is a valid named tuple with all types known.
return True, info
# This can't be a valid named tuple.
return False, None

def check_namedtuple_classdef(
self, defn: ClassDef, is_stub_file: bool
) -> tuple[list[str], list[Type], dict[str, Expression]] | None:
) -> tuple[list[str], list[Type], dict[str, Expression], list[Statement]] | None:
"""Parse and validate fields in named tuple class definition.
Return a three tuple:
Return a four tuple:
* field names
* field types
* field default values
* valid statements
or None, if any of the types are not ready.
"""
if self.options.python_version < (3, 6) and not is_stub_file:
self.fail("NamedTuple class syntax is only supported in Python 3.6", defn)
return [], [], {}
return [], [], {}, []
if len(defn.base_type_exprs) > 1:
self.fail("NamedTuple should be a single base", defn)
items: list[str] = []
types: list[Type] = []
default_items: dict[str, Expression] = {}
statements: list[Statement] = []
for stmt in defn.defs.body:
statements.append(stmt)
if not isinstance(stmt, AssignmentStmt):
# Still allow pass or ... (for empty namedtuples).
if isinstance(stmt, PassStmt) or (
Expand All @@ -160,9 +165,13 @@ def check_namedtuple_classdef(
# And docstrings.
if isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, StrExpr):
continue
statements.pop()
defn.removed_statements.append(stmt)
self.fail(NAMEDTUP_CLASS_ERROR, stmt)
elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr):
# An assignment, but an invalid one.
statements.pop()
defn.removed_statements.append(stmt)
self.fail(NAMEDTUP_CLASS_ERROR, stmt)
else:
# Append name and type in this case...
Expand Down Expand Up @@ -199,7 +208,7 @@ def check_namedtuple_classdef(
)
else:
default_items[name] = stmt.rvalue
return items, types, default_items
return items, types, default_items, statements

def check_namedtuple(
self, node: Expression, var_name: str | None, is_func_scope: bool
Expand Down
2 changes: 2 additions & 0 deletions mypy/semanal_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,11 @@ def analyze_typeddict_classdef_fields(
):
statements.append(stmt)
else:
defn.removed_statements.append(stmt)
self.fail(TPDICT_CLASS_ERROR, stmt)
elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr):
# An assignment, but an invalid one.
defn.removed_statements.append(stmt)
self.fail(TPDICT_CLASS_ERROR, stmt)
else:
name = stmt.lvalues[0].name
Expand Down
2 changes: 2 additions & 0 deletions mypy/server/aststrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def visit_class_def(self, node: ClassDef) -> None:
]
with self.enter_class(node.info):
super().visit_class_def(node)
node.defs.body.extend(node.removed_statements)
node.removed_statements = []
TypeState.reset_subtype_caches_for(node.info)
# Kill the TypeInfo, since there is none before semantic analysis.
node.info = CLASSDEF_NO_INFO
Expand Down
7 changes: 7 additions & 0 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
UninhabitedType,
UnionType,
get_proper_type,
has_recursive_types,
)


Expand Down Expand Up @@ -157,6 +158,12 @@ def test_type_alias_expand_all(self) -> None:
[self.fx.a, self.fx.a], Instance(self.fx.std_tuplei, [self.fx.a])
)

def test_recursive_nested_in_non_recursive(self) -> None:
A, _ = self.fx.def_alias_1(self.fx.a)
NA = self.fx.non_rec_alias(Instance(self.fx.gi, [UnboundType("T")]), ["T"], [A])
assert not NA.is_recursive
assert has_recursive_types(NA)

def test_indirection_no_infinite_recursion(self) -> None:
A, _ = self.fx.def_alias_1(self.fx.a)
visitor = TypeIndirectionVisitor()
Expand Down
10 changes: 7 additions & 3 deletions mypy/test/typefixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,13 @@ def def_alias_2(self, base: Instance) -> tuple[TypeAliasType, Type]:
A.alias = AN
return A, target

def non_rec_alias(self, target: Type) -> TypeAliasType:
AN = TypeAlias(target, "__main__.A", -1, -1)
return TypeAliasType(AN, [])
def non_rec_alias(
self, target: Type, alias_tvars: list[str] | None = None, args: list[Type] | None = None
) -> TypeAliasType:
AN = TypeAlias(target, "__main__.A", -1, -1, alias_tvars=alias_tvars)
if args is None:
args = []
return TypeAliasType(AN, args)


class InterfaceTypeFixture(TypeFixture):
Expand Down
24 changes: 8 additions & 16 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,24 +404,16 @@ def visit_placeholder_type(self, t: PlaceholderType) -> T:
return self.query_types(t.args)

def visit_type_alias_type(self, t: TypeAliasType) -> T:
# Skip type aliases already visited types to avoid infinite recursion.
# TODO: Ideally we should fire subvisitors here (or use caching) if we care
# about duplicates.
if t in self.seen_aliases:
return self.strategy([])
self.seen_aliases.add(t)
if self.skip_alias_target:
return self.query_types(t.args)
return get_proper_type(t).accept(self)

def query_types(self, types: Iterable[Type]) -> T:
"""Perform a query for a list of types.
Use the strategy to combine the results.
Skip type aliases already visited types to avoid infinite recursion.
"""
res: list[T] = []
for t in types:
if isinstance(t, TypeAliasType):
# Avoid infinite recursion for recursive type aliases.
# TODO: Ideally we should fire subvisitors here (or use caching) if we care
# about duplicates.
if t in self.seen_aliases:
continue
self.seen_aliases.add(t)
res.append(t.accept(self))
return self.strategy(res)
"""Perform a query for a list of types using the strategy to combine the results."""
return self.strategy([t.accept(self) for t in types])
2 changes: 1 addition & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
if fullname == "builtins.None":
return NoneType()
elif fullname == "typing.Any" or fullname == "builtins.Any":
return AnyType(TypeOfAny.explicit)
return AnyType(TypeOfAny.explicit, line=t.line, column=t.column)
elif fullname in FINAL_TYPE_NAMES:
self.fail(
"Final can be only used as an outermost qualifier in a variable annotation",
Expand Down
26 changes: 19 additions & 7 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,30 +278,42 @@ def _expand_once(self) -> Type:
self.alias.target, self.alias.alias_tvars, self.args, self.line, self.column
)

def _partial_expansion(self) -> tuple[ProperType, bool]:
def _partial_expansion(self, nothing_args: bool = False) -> tuple[ProperType, bool]:
# Private method mostly for debugging and testing.
unroller = UnrollAliasVisitor(set())
unrolled = self.accept(unroller)
if nothing_args:
alias = self.copy_modified(args=[UninhabitedType()] * len(self.args))
else:
alias = self
unrolled = alias.accept(unroller)
assert isinstance(unrolled, ProperType)
return unrolled, unroller.recursed

def expand_all_if_possible(self) -> ProperType | None:
def expand_all_if_possible(self, nothing_args: bool = False) -> ProperType | None:
"""Attempt a full expansion of the type alias (including nested aliases).
If the expansion is not possible, i.e. the alias is (mutually-)recursive,
return None.
return None. If nothing_args is True, replace all type arguments with an
UninhabitedType() (used to detect recursively defined aliases).
"""
unrolled, recursed = self._partial_expansion()
unrolled, recursed = self._partial_expansion(nothing_args=nothing_args)
if recursed:
return None
return unrolled

@property
def is_recursive(self) -> bool:
"""Whether this type alias is recursive.
Note this doesn't check generic alias arguments, but only if this alias
*definition* is recursive. The property value thus can be cached on the
underlying TypeAlias node. If you want to include all nested types, use
has_recursive_types() function.
"""
assert self.alias is not None, "Unfixed type alias"
is_recursive = self.alias._is_recursive
if is_recursive is None:
is_recursive = self.expand_all_if_possible() is None
is_recursive = self.expand_all_if_possible(nothing_args=True) is None
# We cache the value on the underlying TypeAlias node as an optimization,
# since the value is the same for all instances of the same alias.
self.alias._is_recursive = is_recursive
Expand Down Expand Up @@ -3259,7 +3271,7 @@ def __init__(self) -> None:
super().__init__(any)

def visit_type_alias_type(self, t: TypeAliasType) -> bool:
return t.is_recursive
return t.is_recursive or self.query_types(t.args)


def has_recursive_types(typ: Type) -> bool:
Expand Down
6 changes: 5 additions & 1 deletion mypy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,11 @@ def parse_gray_color(cup: bytes) -> str:


def should_force_color() -> bool:
return bool(int(os.getenv("MYPY_FORCE_COLOR", os.getenv("FORCE_COLOR", "0"))))
env_var = os.getenv("MYPY_FORCE_COLOR", os.getenv("FORCE_COLOR", "0"))
try:
return bool(int(env_var))
except ValueError:
return bool(env_var)


class FancyFormatter:
Expand Down
2 changes: 0 additions & 2 deletions test-data/unit/check-class-namedtuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,6 @@ class X(typing.NamedTuple):
[out]
main:6: error: Invalid statement in NamedTuple definition; expected "field_name: field_type [= default]"
main:7: error: Invalid statement in NamedTuple definition; expected "field_name: field_type [= default]"
main:7: error: Type cannot be declared in assignment to non-self attribute
main:7: error: "int" has no attribute "x"
main:9: error: Non-default NamedTuple fields cannot follow default fields

[builtins fixtures/list.pyi]
Expand Down
14 changes: 14 additions & 0 deletions test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -6334,3 +6334,17 @@ reveal_type(D().meth)
[out2]
tmp/m.py:4: note: Revealed type is "def [Self <: lib.C] (self: Self`0, other: Self`0) -> Self`0"
tmp/m.py:5: note: Revealed type is "def (other: m.D) -> m.D"

[case testIncrementalNestedGenericCallableCrash]
from typing import TypeVar, Callable

T = TypeVar("T")

class B:
def foo(self) -> Callable[[T], T]: ...

class C(B):
def __init__(self) -> None:
self.x = self.foo()
[out]
[out2]
Loading

0 comments on commit 1cd9216

Please sign in to comment.