diff --git a/changelog/5062.feature.rst b/changelog/5062.feature.rst new file mode 100644 index 00000000000..e311d161d52 --- /dev/null +++ b/changelog/5062.feature.rst @@ -0,0 +1 @@ +Unroll calls to ``all`` to full for-loops for better failure messages, especially when using Generator Expressions. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 18506d2e187..e3870d43509 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -964,6 +964,8 @@ def visit_Call_35(self, call): """ visit `ast.Call` nodes on Python3.5 and after """ + if isinstance(call.func, ast.Name) and call.func.id == "all": + return self._visit_all(call) new_func, func_expl = self.visit(call.func) arg_expls = [] new_args = [] @@ -987,6 +989,27 @@ def visit_Call_35(self, call): outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl) return res, outer_expl + def _visit_all(self, call): + """Special rewrite for the builtin all function, see #5062""" + if not isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp)): + return + gen_exp = call.args[0] + assertion_module = ast.Module( + body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)] + ) + AssertionRewriter(module_path=None, config=None).run(assertion_module) + for_loop = ast.For( + iter=gen_exp.generators[0].iter, + target=gen_exp.generators[0].target, + body=assertion_module.body, + orelse=[], + ) + self.statements.append(for_loop) + return ( + ast.Num(n=1), + "", + ) # Return an empty expression, all the asserts are in the for_loop + def visit_Starred(self, starred): # From Python 3.5, a Starred node can appear in a function call res, expl = self.visit(starred.value) @@ -997,6 +1020,8 @@ def visit_Call_legacy(self, call): """ visit `ast.Call nodes on 3.4 and below` """ + if isinstance(call.func, ast.Name) and call.func.id == "all": + return self._visit_all(call) new_func, func_expl = self.visit(call.func) arg_expls = [] new_args = [] diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 72bfbcc55af..e819075b588 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -656,6 +656,12 @@ def __repr__(self): else: assert lines == ["assert 0 == 1\n + where 1 = \\n{ \\n~ \\n}.a"] + def test_unroll_expression(self): + def f(): + assert all(x == 1 for x in range(10)) + + assert "0 == 1" in getmsg(f) + def test_custom_repr_non_ascii(self): def f(): class A(object): @@ -671,6 +677,53 @@ def __repr__(self): assert "UnicodeDecodeError" not in msg assert "UnicodeEncodeError" not in msg + def test_unroll_generator(self, testdir): + testdir.makepyfile( + """ + def check_even(num): + if num % 2 == 0: + return True + return False + + def test_generator(): + odd_list = list(range(1,9,2)) + assert all(check_even(num) for num in odd_list)""" + ) + result = testdir.runpytest() + result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"]) + + def test_unroll_list_comprehension(self, testdir): + testdir.makepyfile( + """ + def check_even(num): + if num % 2 == 0: + return True + return False + + def test_list_comprehension(): + odd_list = list(range(1,9,2)) + assert all([check_even(num) for num in odd_list])""" + ) + result = testdir.runpytest() + result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"]) + + def test_for_loop(self, testdir): + testdir.makepyfile( + """ + def check_even(num): + if num % 2 == 0: + return True + return False + + def test_for_loop(): + odd_list = list(range(1,9,2)) + for num in odd_list: + assert check_even(num) + """ + ) + result = testdir.runpytest() + result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"]) + class TestRewriteOnImport(object): def test_pycache_is_a_file(self, testdir):