From 64026af910114d68d94b851fb88e1574422848ad Mon Sep 17 00:00:00 2001 From: Max Murin Date: Tue, 3 Jan 2023 10:56:09 -0800 Subject: [PATCH 1/3] fix bug, add test --- mypy/checker.py | 22 ++++- mypy/checkexpr.py | 134 ++++++++++++++++++------------- mypy/nodes.py | 5 +- mypy/treetransform.py | 1 - test-data/unit/check-unions.test | 17 ++++ 5 files changed, 117 insertions(+), 62 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 1c8956ae6722..9af959757ea5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4455,7 +4455,7 @@ def visit_for_stmt(self, s: ForStmt) -> None: if s.is_async: iterator_type, item_type = self.analyze_async_iterable_item_type(s.expr) else: - iterator_type, item_type = self.analyze_iterable_item_type(s.expr) + iterator_type, item_type = self.analyze_iterable_item_expression(s.expr) s.inferred_item_type = item_type s.inferred_iterator_type = iterator_type self.analyze_index_variables(s.index, item_type, s.index_type is None, s) @@ -4472,7 +4472,7 @@ def analyze_async_iterable_item_type(self, expr: Expression) -> tuple[Type, Type ) return iterator, item_type - def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]: + def analyze_iterable_item_expression(self, expr: Expression) -> tuple[Type, Type]: """Analyse iterable expression and return iterator and iterator item types.""" echk = self.expr_checker iterable = get_proper_type(echk.accept(expr)) @@ -4491,6 +4491,24 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]: # Non-tuple iterable. return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0] + def analyze_iterable_item_type(self, type: Type, context: Context) -> tuple[Type, Type]: + """Analyse iterable expression and return iterator and iterator item types.""" + echk = self.expr_checker + iterable = get_proper_type(type) + iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0] + + if isinstance(iterable, TupleType): + joined: Type = UninhabitedType() + for item in iterable.items: + joined = join_types(joined, item) + return iterator, joined + else: + # Non-tuple iterable. + return ( + iterator, + echk.check_method_call_by_name("__next__", iterator, [], [], context)[0], + ) + def analyze_range_native_int_type(self, expr: Expression) -> Type | None: """Try to infer native int item type from arguments to range(...). diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b97c78cba2fc..142fb36d880a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2887,75 +2887,102 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: That is, 'a < b > c == d' is check as 'a < b and b > c and c == d' """ result: Type | None = None - sub_result: Type | None = None + sub_result: Type # Check each consecutive operand pair and their operator for left, right, operator in zip(e.operands, e.operands[1:], e.operators): left_type = self.accept(left) - method_type: mypy.types.Type | None = None - if operator == "in" or operator == "not in": + """ + This case covers both iterables and containers, which have different meanings. + For a container, the in operator calls the __contains__ method. + For an iterable, the in operator iterates over the iterable, and compares each item one-by-one. + We allow `in` for a union of containers and iterables as long as at least one of them matches the + type of the left operand, as the operation will simply return False if the union's container/iterator + type doesn't match the left operand. + """ + # If the right operand has partial type, look it up without triggering # a "Need type annotation ..." message, as it would be noise. right_type = self.find_partial_type_ref_fast_path(right) if right_type is None: right_type = self.accept(right) # Validate the right operand - # Keep track of whether we get type check errors (these won't be reported, they - # are just to verify whether something is valid typing wise). - with self.msg.filter_errors(save_filtered_errors=True) as local_errors: - _, method_type = self.check_method_call_by_name( - method="__contains__", - base_type=right_type, - args=[left], - arg_kinds=[ARG_POS], - context=e, - ) + right_type = get_proper_type(right_type) + item_types: Sequence[Type] = [right_type] + if isinstance(right_type, UnionType): + item_types = list(right_type.items) sub_result = self.bool_type() - # Container item type for strict type overlap checks. Note: we need to only - # check for nominal type, because a usual "Unsupported operands for in" - # will be reported for types incompatible with __contains__(). - # See testCustomContainsCheckStrictEquality for an example. - cont_type = self.chk.analyze_container_item_type(right_type) - if isinstance(right_type, PartialType): - # We don't really know if this is an error or not, so just shut up. - pass - elif ( - local_errors.has_new_errors() - and - # is_valid_var_arg is True for any Iterable - self.is_valid_var_arg(right_type) - ): - _, itertype = self.chk.analyze_iterable_item_type(right) - method_type = CallableType( - [left_type], - [nodes.ARG_POS], - [None], - self.bool_type(), - self.named_type("builtins.function"), - ) - if not is_subtype(left_type, itertype): - self.msg.unsupported_operand_types("in", left_type, right_type, e) - # Only show dangerous overlap if there are no other errors. - elif ( - not local_errors.has_new_errors() - and cont_type - and self.dangerous_comparison( - left_type, cont_type, original_container=right_type - ) - ): - self.msg.dangerous_comparison(left_type, cont_type, "container", e) - else: - self.msg.add_errors(local_errors.filtered_errors()) + + container_types: list[Type] = [] + iterable_types: list[Type] = [] + failed_out = False + encountered_partial_type = False + + for item_type in item_types: + # Keep track of whether we get type check errors (these won't be reported, they + # are just to verify whether something is valid typing wise). + with self.msg.filter_errors(save_filtered_errors=True) as container_errors: + self.check_method_call_by_name( + method="__contains__", + base_type=item_type, + args=[left], + arg_kinds=[ARG_POS], + context=e, + original_type=right_type, + ) + + # Container item type for strict type overlap checks. Note: we need to only + # check for nominal type, because a usual "Unsupported operands for in" + # will be reported for types incompatible with __contains__(). + # See testCustomContainsCheckStrictEquality for an example. + cont_type = self.chk.analyze_container_item_type(item_type) + + if isinstance(item_type, PartialType): + # We don't really know if this is an error or not, so just shut up. + encountered_partial_type = True + pass + elif ( + container_errors.has_new_errors() + and + # is_valid_var_arg is True for any Iterable + self.is_valid_var_arg(item_type) + ): + # it's not a container, but it is an iterable + with self.msg.filter_errors(save_filtered_errors=True) as iterable_errors: + _, itertype = self.chk.analyze_iterable_item_type(item_type, e) + if iterable_errors.has_new_errors(): + self.msg.add_errors(iterable_errors.filtered_errors()) + failed_out = True + else: + iterable_types.append(itertype) + elif not container_errors.has_new_errors() and cont_type: + container_types.append(cont_type) + else: + self.msg.add_errors(container_errors.filtered_errors()) + failed_out = True + + if not encountered_partial_type and not failed_out: + iterable_type = UnionType.make_union(iterable_types) + if not is_subtype(left_type, iterable_type): + if len(container_types) == 0: + self.msg.unsupported_operand_types("in", left_type, right_type, e) + else: + container_type = UnionType.make_union(container_types) + if self.dangerous_comparison( + left_type, container_type, original_container=right_type + ): + self.msg.dangerous_comparison( + left_type, container_type, "container", e + ) + elif operator in operators.op_methods: method = operators.op_methods[operator] with ErrorWatcher(self.msg.errors) as w: - sub_result, method_type = self.check_op( - method, left_type, right, e, allow_reverse=True - ) + sub_result, _ = self.check_op(method, left_type, right, e, allow_reverse=True) # Only show dangerous overlap if there are no other errors. See # testCustomEqCheckStrictEquality for an example. @@ -2983,12 +3010,9 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: right_type = try_getting_literal(right_type) if self.dangerous_comparison(left_type, right_type): self.msg.dangerous_comparison(left_type, right_type, "identity", e) - method_type = None else: raise RuntimeError(f"Unknown comparison operator {operator}") - e.method_types.append(method_type) - # Determine type of boolean-and of result and sub_result if result is None: result = sub_result @@ -4618,7 +4642,7 @@ def check_for_comp(self, e: GeneratorExpr | DictionaryComprehension) -> None: if is_async: _, sequence_type = self.chk.analyze_async_iterable_item_type(sequence) else: - _, sequence_type = self.chk.analyze_iterable_item_type(sequence) + _, sequence_type = self.chk.analyze_iterable_item_expression(sequence) self.chk.analyze_index_variables(index, sequence_type, True, e) for condition in conditions: self.accept(condition) diff --git a/mypy/nodes.py b/mypy/nodes.py index 80ab787f4a9c..ee5600382d67 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2017,20 +2017,17 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class ComparisonExpr(Expression): """Comparison expression (e.g. a < b > c < d).""" - __slots__ = ("operators", "operands", "method_types") + __slots__ = ("operators", "operands") __match_args__ = ("operands", "operators") operators: list[str] operands: list[Expression] - # Inferred type for the operator methods (when relevant; None for 'is'). - method_types: list[mypy.types.Type | None] def __init__(self, operators: list[str], operands: list[Expression]) -> None: super().__init__() self.operators = operators self.operands = operands - self.method_types = [] def pairwise(self) -> Iterator[tuple[str, Expression, Expression]]: """If this comparison expr is "a < b is c == d", yields the sequence diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 535f50d5cf5e..887980c10a6c 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -530,7 +530,6 @@ def visit_op_expr(self, node: OpExpr) -> OpExpr: def visit_comparison_expr(self, node: ComparisonExpr) -> ComparisonExpr: new = ComparisonExpr(node.operators, self.expressions(node.operands)) - new.method_types = [self.optional_type(t) for t in node.method_types] return new def visit_cast_expr(self, node: CastExpr) -> CastExpr: diff --git a/test-data/unit/check-unions.test b/test-data/unit/check-unions.test index a561c29e54f7..435c1e077231 100644 --- a/test-data/unit/check-unions.test +++ b/test-data/unit/check-unions.test @@ -1183,3 +1183,20 @@ def foo( yield i foo([1]) [builtins fixtures/list.pyi] + +[case testUnionIterableContainer] +from typing import Iterable, Container, Union + +i: Iterable[str] +c: Container[str] +u: Union[Iterable[str], Container[str]] +ni: Union[Iterable[str], int] +nc: Union[Container[str], int] + +'x' in i +'x' in c +'x' in u +'x' in ni # E: Unsupported right operand type for in ("Union[Iterable[str], int]") +'x' in nc # E: Unsupported right operand type for in ("Union[Container[str], int]") +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi] From 81829d252ac06aa1337510f83e67e5c9270fbbda Mon Sep 17 00:00:00 2001 From: Max Murin Date: Tue, 3 Jan 2023 12:16:45 -0800 Subject: [PATCH 2/3] change method name to better describe its function, don't remove method types --- mypy/checker.py | 10 ++++++---- mypy/checkexpr.py | 24 +++++++++++++++++++----- mypy/nodes.py | 5 ++++- mypy/treetransform.py | 1 + 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 9af959757ea5..4ab22ca316d1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4455,7 +4455,7 @@ def visit_for_stmt(self, s: ForStmt) -> None: if s.is_async: iterator_type, item_type = self.analyze_async_iterable_item_type(s.expr) else: - iterator_type, item_type = self.analyze_iterable_item_expression(s.expr) + iterator_type, item_type = self.analyze_iterable_item_type(s.expr) s.inferred_item_type = item_type s.inferred_iterator_type = iterator_type self.analyze_index_variables(s.index, item_type, s.index_type is None, s) @@ -4472,7 +4472,7 @@ def analyze_async_iterable_item_type(self, expr: Expression) -> tuple[Type, Type ) return iterator, item_type - def analyze_iterable_item_expression(self, expr: Expression) -> tuple[Type, Type]: + def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]: """Analyse iterable expression and return iterator and iterator item types.""" echk = self.expr_checker iterable = get_proper_type(echk.accept(expr)) @@ -4491,8 +4491,10 @@ def analyze_iterable_item_expression(self, expr: Expression) -> tuple[Type, Type # Non-tuple iterable. return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0] - def analyze_iterable_item_type(self, type: Type, context: Context) -> tuple[Type, Type]: - """Analyse iterable expression and return iterator and iterator item types.""" + def analyze_iterable_item_type_without_expression( + self, type: Type, context: Context + ) -> tuple[Type, Type]: + """Analyse iterable type and return iterator and iterator item types.""" echk = self.expr_checker iterable = get_proper_type(type) iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0] diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 142fb36d880a..3f8887917c37 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2925,7 +2925,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # Keep track of whether we get type check errors (these won't be reported, they # are just to verify whether something is valid typing wise). with self.msg.filter_errors(save_filtered_errors=True) as container_errors: - self.check_method_call_by_name( + _, method_type = self.check_method_call_by_name( method="__contains__", base_type=item_type, args=[left], @@ -2933,7 +2933,6 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: context=e, original_type=right_type, ) - # Container item type for strict type overlap checks. Note: we need to only # check for nominal type, because a usual "Unsupported operands for in" # will be reported for types incompatible with __contains__(). @@ -2952,14 +2951,25 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: ): # it's not a container, but it is an iterable with self.msg.filter_errors(save_filtered_errors=True) as iterable_errors: - _, itertype = self.chk.analyze_iterable_item_type(item_type, e) + _, itertype = self.chk.analyze_iterable_item_type_without_expression( + item_type, e + ) if iterable_errors.has_new_errors(): self.msg.add_errors(iterable_errors.filtered_errors()) failed_out = True else: + method_type = CallableType( + [left_type], + [nodes.ARG_POS], + [None], + self.bool_type(), + self.named_type("builtins.function"), + ) + e.method_types.append(method_type) iterable_types.append(itertype) elif not container_errors.has_new_errors() and cont_type: container_types.append(cont_type) + e.method_types.append(method_type) else: self.msg.add_errors(container_errors.filtered_errors()) failed_out = True @@ -2982,7 +2992,10 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: method = operators.op_methods[operator] with ErrorWatcher(self.msg.errors) as w: - sub_result, _ = self.check_op(method, left_type, right, e, allow_reverse=True) + sub_result, method_type = self.check_op( + method, left_type, right, e, allow_reverse=True + ) + e.method_types.append(method_type) # Only show dangerous overlap if there are no other errors. See # testCustomEqCheckStrictEquality for an example. @@ -3010,6 +3023,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: right_type = try_getting_literal(right_type) if self.dangerous_comparison(left_type, right_type): self.msg.dangerous_comparison(left_type, right_type, "identity", e) + e.method_types.append(None) else: raise RuntimeError(f"Unknown comparison operator {operator}") @@ -4642,7 +4656,7 @@ def check_for_comp(self, e: GeneratorExpr | DictionaryComprehension) -> None: if is_async: _, sequence_type = self.chk.analyze_async_iterable_item_type(sequence) else: - _, sequence_type = self.chk.analyze_iterable_item_expression(sequence) + _, sequence_type = self.chk.analyze_iterable_item_type(sequence) self.chk.analyze_index_variables(index, sequence_type, True, e) for condition in conditions: self.accept(condition) diff --git a/mypy/nodes.py b/mypy/nodes.py index ee5600382d67..80ab787f4a9c 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2017,17 +2017,20 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class ComparisonExpr(Expression): """Comparison expression (e.g. a < b > c < d).""" - __slots__ = ("operators", "operands") + __slots__ = ("operators", "operands", "method_types") __match_args__ = ("operands", "operators") operators: list[str] operands: list[Expression] + # Inferred type for the operator methods (when relevant; None for 'is'). + method_types: list[mypy.types.Type | None] def __init__(self, operators: list[str], operands: list[Expression]) -> None: super().__init__() self.operators = operators self.operands = operands + self.method_types = [] def pairwise(self) -> Iterator[tuple[str, Expression, Expression]]: """If this comparison expr is "a < b is c == d", yields the sequence diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 887980c10a6c..535f50d5cf5e 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -530,6 +530,7 @@ def visit_op_expr(self, node: OpExpr) -> OpExpr: def visit_comparison_expr(self, node: ComparisonExpr) -> ComparisonExpr: new = ComparisonExpr(node.operators, self.expressions(node.operands)) + new.method_types = [self.optional_type(t) for t in node.method_types] return new def visit_cast_expr(self, node: CastExpr) -> CastExpr: From 05abbb56a30dcf2508833253c08958f7bdaab252 Mon Sep 17 00:00:00 2001 From: Max Murin Date: Thu, 26 Jan 2023 13:56:43 -0600 Subject: [PATCH 3/3] change block comment format --- mypy/checkexpr.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3f8887917c37..5b593c3fb0bd 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2894,14 +2894,12 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: left_type = self.accept(left) if operator == "in" or operator == "not in": - """ - This case covers both iterables and containers, which have different meanings. - For a container, the in operator calls the __contains__ method. - For an iterable, the in operator iterates over the iterable, and compares each item one-by-one. - We allow `in` for a union of containers and iterables as long as at least one of them matches the - type of the left operand, as the operation will simply return False if the union's container/iterator - type doesn't match the left operand. - """ + # This case covers both iterables and containers, which have different meanings. + # For a container, the in operator calls the __contains__ method. + # For an iterable, the in operator iterates over the iterable, and compares each item one-by-one. + # We allow `in` for a union of containers and iterables as long as at least one of them matches the + # type of the left operand, as the operation will simply return False if the union's container/iterator + # type doesn't match the left operand. # If the right operand has partial type, look it up without triggering # a "Need type annotation ..." message, as it would be noise.