Skip to content

Commit

Permalink
Fix an if-splitting bug.
Browse files Browse the repository at this point in the history
In our implementation of COMPARE_OP, we were looping through all possible
bindings of the left and right side, then calling the comparison magic method
with all the left bindings that could not be matched definitely, and the
*entire right variable*. Just as for the left side, we should be filtering out
the right bindings that were already matched.

PiperOrigin-RevId: 343433846
  • Loading branch information
rchen152 committed Nov 23, 2020
1 parent 7471549 commit 74033c2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
9 changes: 9 additions & 0 deletions pytype/tests/py3/test_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,15 @@ class MyIterable(Iterable[T]): ...
def f(x: MyIterable[int]) -> Union[int, str]: ...
""")

def test_str_none_eq(self):
self.Check("""
from typing import Optional
def f(x: str, y: Optional[str]) -> str:
if x == y:
return y
return x
""")


class SplitTestPy3(test_base.TargetPython3FeatureTest):
"""Tests for if-splitting in Python 3."""
Expand Down
11 changes: 7 additions & 4 deletions pytype/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2056,17 +2056,20 @@ def _cmp_rel(self, state, op_name, x, y):
# A variable of the values without a special cmp_rel implementation. Needed
# because overloaded __eq__ implementations do not necessarily return a
# bool; see, e.g., test_overloaded in test_cmp.
leftover = self.program.NewVariable()
leftover_x = self.program.NewVariable()
leftover_y = self.program.NewVariable()
for b1 in x.bindings:
for b2 in y.bindings:
val = compare.cmp_rel(self, getattr(slots, op_name), b1.data, b2.data)
if val is None:
leftover.AddBinding(b1.data, {b1}, state.node)
leftover_x.AddBinding(b1.data, {b1}, state.node)
leftover_y.AddBinding(b2.data, {b2}, state.node)
else:
ret.AddBinding(self.convert.bool_values[val], {b1, b2}, state.node)
if leftover.bindings:
if leftover_x.bindings:
op = "__%s__" % op_name.lower()
state, leftover_ret = self.call_binary_operator(state, op, leftover, y)
state, leftover_ret = self.call_binary_operator(
state, op, leftover_x, leftover_y)
ret.PasteVariable(leftover_ret, state.node)
return state, ret

Expand Down

0 comments on commit 74033c2

Please sign in to comment.