Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inference of Compare nodes #979

Merged
merged 3 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
"""this module contains a set of functions to handle inference on astroid trees
"""

import ast
import functools
import itertools
import operator
from typing import Any, Iterable

import wrapt

Expand Down Expand Up @@ -798,6 +800,98 @@ def infer_binop(self, context=None):
nodes.BinOp._infer_binop = _infer_binop
nodes.BinOp._infer = infer_binop

COMPARE_OPS = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

COMPARE_OPS: Dict[str, Callable[[Any, Any], bool]] = { ... }

That would help inferring op_func and expr correctly.

"==": operator.eq,
"!=": operator.ne,
"<": operator.lt,
"<=": operator.le,
">": operator.gt,
">=": operator.ge,
"in": lambda a, b: a in b,
"not in": lambda a, b: a not in b,
}
UNINFERABLE_OPS = {
"is",
"is not",
}
Comment on lines +805 to +808
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a frozenset



def _to_literal(node: nodes.NodeNG) -> Any:
# Can raise SyntaxError or ValueError from ast.literal_eval
# Is this the stupidest idea or the simplest idea?
return ast.literal_eval(node.as_string())
Comment on lines +813 to +814
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like that! Something like this here would also work, but isn't as clean

    if isinstance(node, nodes.Const):
        return node.value
    if isinstance(node, nodes.Tuple):
        return tuple(_to_literal(val) for val in node.elts)
    if isinstance(node, nodes.List):
        return [_to_literal(val) for val in node.elts]
    if isinstance(node, nodes.Set):
        return set(_to_literal(val) for val in node.elts)
    if isinstance(node, nodes.Dict):
        return {_to_literal(k): _to_literal(v) for k, v in node.items}



def _do_compare(
left_iter: Iterable[nodes.NodeNG], op: str, right_iter: Iterable[nodes.NodeNG]
) -> "bool | type[util.Uninferable]":
"""
If all possible combinations are either True or False, return that:
>>> _do_compare([1, 2], '<=', [3, 4])
True
>>> _do_compare([1, 2], '==', [3, 4])
False

If any item is uninferable, or if some combinations are True and some
are False, return Uninferable:
>>> _do_compare([1, 3], '<=', [2, 4])
util.Uninferable
"""
retval = None
if op in UNINFERABLE_OPS:
return util.Uninferable
op_func = COMPARE_OPS[op]

for left, right in itertools.product(left_iter, right_iter):
if left is util.Uninferable or right is util.Uninferable:
return util.Uninferable

try:
left, right = _to_literal(left), _to_literal(right)
except (SyntaxError, ValueError):
return util.Uninferable

try:
expr = op_func(left, right)
except TypeError as exc:
raise AstroidTypeError from exc

if retval is None:
retval = expr
elif retval != expr:
return util.Uninferable
# (or both, but "True | False" is basically the same)

return retval # it was all the same value


def _infer_compare(self: nodes.Compare, context: contextmod.InferenceContext) -> Any:
"""Chained comparison inference logic."""
retval = True

ops = self.ops
left_node = self.left
lhs = list(left_node.infer(context=context))
# should we break early if first element is uninferable?
for op, right_node in ops:
# eagerly evaluate rhs so that values can be re-used as lhs
rhs = list(right_node.infer(context=context))
try:
retval = _do_compare(lhs, op, rhs)
except AstroidTypeError:
retval = util.Uninferable
break
if retval is not True:
break # short-circuit
lhs = rhs # continue
if retval is util.Uninferable:
yield retval
else:
yield nodes.Const(retval)


nodes.Compare._infer = _infer_compare
Pierre-Sassoulas marked this conversation as resolved.
Show resolved Hide resolved


def _infer_augassign(self, context=None):
"""Inference logic for augmented binary operations."""
Expand Down
255 changes: 255 additions & 0 deletions tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5216,6 +5216,261 @@ def f(**kwargs):
assert next(extract_node(code).infer()).as_string() == "{'f': 1}"


@pytest.mark.parametrize(
"op,result",
[
("<", False),
("<=", True),
("==", True),
(">=", True),
(">", False),
("!=", False),
],
)
def test_compare(op, result):
code = """
123 {} 123
""".format(
op
)
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value == result


@pytest.mark.xfail(reason="uninferable")
@pytest.mark.parametrize(
"op,result",
[
("is", True),
("is not", False),
],
)
def test_compare_identity(op, result):
code = """
obj = object()
obj {} obj
""".format(
op
)
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value == result


@pytest.mark.parametrize(
"op,result",
[
("in", True),
("not in", False),
],
)
def test_compare_membership(op, result):
code = """
1 {} [1, 2, 3]
""".format(
op
)
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value == result


@pytest.mark.parametrize(
"lhs,rhs,result",
[
(1, 1, True),
(1, 1.1, True),
(1.1, 1, False),
(1.0, 1.0, True),
("abc", "def", True),
("abc", "", False),
([], [1], True),
((1, 2), (2, 3), True),
((1, 0), (1,), False),
(True, True, True),
(True, False, False),
(False, 1, True),
(1 + 0j, 2 + 0j, util.Uninferable),
(+0.0, -0.0, True),
(0, "1", util.Uninferable),
(b"\x00", b"\x01", True),
],
)
def test_compare_lesseq_types(lhs, rhs, result):
code = """
{lhs!r} <= {rhs!r}
""".format(
lhs=lhs, rhs=rhs
)
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value == result


def test_compare_chained():
code = """
3 < 5 > 3
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is True


def test_compare_inferred_members():
code = """
a = 11
b = 13
a < b
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is True


def test_compare_instance_members():
code = """
class A:
value = 123
class B:
@property
def value(self):
return 456
A().value < B().value
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is True


@pytest.mark.xfail(reason="unimplemented")
def test_compare_dynamic():
code = """
class A:
def __le__(self, other):
return True
A() <= None
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is True


def test_compare_uninferable_member():
code = """
from unknown import UNKNOWN
0 <= UNKNOWN
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred is util.Uninferable


def test_compare_chained_comparisons_shortcircuit_on_false():
code = """
from unknown import UNKNOWN
2 < 1 < UNKNOWN
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is False


def test_compare_chained_comparisons_continue_on_true():
code = """
from unknown import UNKNOWN
1 < 2 < UNKNOWN
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred is util.Uninferable


@pytest.mark.xfail(reason="unimplemented")
def test_compare_known_false_branch():
code = """
a = 'hello'
if 1 < 2:
a = 'goodbye'
a
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
assert isinstance(inferred[0], nodes.Const)
assert inferred[0].value == "hello"


def test_compare_ifexp_constant():
code = """
a = 'hello' if 1 < 2 else 'goodbye'
a
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
assert isinstance(inferred[0], nodes.Const)
assert inferred[0].value == "hello"


def test_compare_typeerror():
code = """
123 <= "abc"
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
assert inferred[0] is util.Uninferable


def test_compare_multiple_possibilites():
code = """
from unknown import UNKNOWN
a = 1
if UNKNOWN:
a = 2
b = 3
if UNKNOWN:
b = 4
a < b
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
# All possible combinations are true: (1 < 3), (1 < 4), (2 < 3), (2 < 4)
assert inferred[0].value is True


def test_compare_ambiguous_multiple_possibilites():
code = """
from unknown import UNKNOWN
a = 1
if UNKNOWN:
a = 3
b = 2
if UNKNOWN:
b = 4
a < b
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
# Not all possible combinations are true: (1 < 2), (1 < 4), (3 !< 2), (3 < 4)
assert inferred[0] is util.Uninferable


def test_compare_nonliteral():
code = """
def func(a, b):
return (a, b) <= (1, 2) #@
"""
return_node = extract_node(code)
node = return_node.value
inferred = list(node.infer()) # should not raise ValueError
assert len(inferred) == 1
assert inferred[0] is util.Uninferable


def test_limit_inference_result_amount():
"""Test setting limit inference result amount"""
code = """
Expand Down