From f7bf91410824e9090677943d51c46eed290bf736 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Sat, 1 Jun 2019 13:50:15 -0700 Subject: [PATCH] Fix `pytest.mark.parametrize` when the argvalue is an iterator --- changelog/5354.bugfix.rst | 1 + src/_pytest/mark/structures.py | 10 +++++++--- testing/test_mark.py | 22 ++++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 changelog/5354.bugfix.rst diff --git a/changelog/5354.bugfix.rst b/changelog/5354.bugfix.rst new file mode 100644 index 00000000000..812ea8364aa --- /dev/null +++ b/changelog/5354.bugfix.rst @@ -0,0 +1 @@ +Fix ``pytest.mark.parametrize`` when the argvalues is an iterator. diff --git a/src/_pytest/mark/structures.py b/src/_pytest/mark/structures.py index 561ccc3f45d..9602e8acfa3 100644 --- a/src/_pytest/mark/structures.py +++ b/src/_pytest/mark/structures.py @@ -113,14 +113,18 @@ def _parse_parametrize_args(argnames, argvalues, **_): force_tuple = len(argnames) == 1 else: force_tuple = False - parameters = [ + return argnames, force_tuple + + @staticmethod + def _parse_parametrize_parameters(argvalues, force_tuple): + return [ ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues ] - return argnames, parameters @classmethod def _for_parametrize(cls, argnames, argvalues, func, config, function_definition): - argnames, parameters = cls._parse_parametrize_args(argnames, argvalues) + argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues) + parameters = cls._parse_parametrize_parameters(argvalues, force_tuple) del argvalues if parameters: diff --git a/testing/test_mark.py b/testing/test_mark.py index 5bd97d547e8..dd9d352309d 100644 --- a/testing/test_mark.py +++ b/testing/test_mark.py @@ -413,6 +413,28 @@ def test_func(a, b): assert result.ret == 0 +def test_parametrize_iterator(testdir): + """parametrize should work with generators (#5354).""" + py_file = testdir.makepyfile( + """\ + import pytest + + def gen(): + yield 1 + yield 2 + yield 3 + + @pytest.mark.parametrize('a', gen()) + def test(a): + assert a >= 1 + """ + ) + result = testdir.runpytest(py_file) + assert result.ret == 0 + # should not skip any tests + result.stdout.fnmatch_lines(["*3 passed*"]) + + class TestFunctional(object): def test_merging_markers_deep(self, testdir): # issue 199 - propagate markers into nested classes