diff --git a/pytest_cases/tests/advanced/test_memoize.py b/pytest_cases/tests/advanced/test_memoize.py index 6ddac0ba..28267941 100644 --- a/pytest_cases/tests/advanced/test_memoize.py +++ b/pytest_cases/tests/advanced/test_memoize.py @@ -1,4 +1,5 @@ from pytest_cases import cases_data, CaseDataGetter, THIS_MODULE, case_tags +from pytest_cases.tests.utils import nb_pytest_parameters, get_pytest_param try: # python 3.2+ from functools import lru_cache @@ -75,17 +76,20 @@ def test_c(case_data # type: CaseDataGetter def test_assert_parametrized(): """Asserts that all tests are parametrized with the correct number of cases""" - assert len(test_a.pytestmark) == 1 - assert len(test_a.pytestmark[0].args) == 2 - assert test_a.pytestmark[0].args[0] == 'case_data' - assert len(test_a.pytestmark[0].args[1]) == 2 - - assert len(test_b.pytestmark) == 1 - assert len(test_b.pytestmark[0].args) == 2 - assert test_b.pytestmark[0].args[0] == 'case_data' - assert len(test_b.pytestmark[0].args[1]) == 2 - - assert len(test_c.pytestmark) == 1 - assert len(test_c.pytestmark[0].args) == 2 - assert test_c.pytestmark[0].args[0] == 'case_data' - assert len(test_c.pytestmark[0].args[1]) == 3 + assert nb_pytest_parameters(test_a) == 1 + param_args = get_pytest_param(test_a, 0) + assert len(param_args) == 2 + assert param_args[0] == 'case_data' + assert len(param_args[1]) == 2 + + assert nb_pytest_parameters(test_b) == 1 + param_args = get_pytest_param(test_b, 0) + assert len(param_args) == 2 + assert param_args[0] == 'case_data' + assert len(param_args[1]) == 2 + + assert nb_pytest_parameters(test_c) == 1 + param_args = get_pytest_param(test_c, 0) + assert len(param_args) == 2 + assert param_args[0] == 'case_data' + assert len(param_args[1]) == 3 diff --git a/pytest_cases/tests/advanced/test_memoize_generators.py b/pytest_cases/tests/advanced/test_memoize_generators.py index 434bdaef..c846e9e2 100644 --- a/pytest_cases/tests/advanced/test_memoize_generators.py +++ b/pytest_cases/tests/advanced/test_memoize_generators.py @@ -1,4 +1,5 @@ from pytest_cases import cases_data, THIS_MODULE, cases_generator, CaseDataGetter, extract_cases_from_module +from pytest_cases.tests.utils import nb_pytest_parameters, get_pytest_param try: # python 3+: type hints from pytest_cases import CaseData @@ -46,12 +47,14 @@ def test_assert_cases_are_here(): def test_assert_parametrized(): """Asserts that test_b is parametrized with the correct number of cases""" - assert len(test_a.pytestmark) == 1 - assert len(test_a.pytestmark[0].args) == 2 - assert test_a.pytestmark[0].args[0] == 'case_data' - assert len(test_a.pytestmark[0].args[1]) == 3 - - assert len(test_b.pytestmark) == 1 - assert len(test_b.pytestmark[0].args) == 2 - assert test_b.pytestmark[0].args[0] == 'case_data' - assert len(test_b.pytestmark[0].args[1]) == 3 + assert nb_pytest_parameters(test_a) == 1 + param_args = get_pytest_param(test_a, 0) + assert len(param_args) == 2 + assert param_args[0] == 'case_data' + assert len(param_args[1]) == 3 + + assert nb_pytest_parameters(test_b) == 1 + param_args = get_pytest_param(test_b, 0) + assert len(param_args) == 2 + assert param_args[0] == 'case_data' + assert len(param_args[1]) == 3 diff --git a/pytest_cases/tests/advanced/test_parameters.py b/pytest_cases/tests/advanced/test_parameters.py index db765af2..16483ef0 100644 --- a/pytest_cases/tests/advanced/test_parameters.py +++ b/pytest_cases/tests/advanced/test_parameters.py @@ -1,6 +1,7 @@ import pytest from pytest_cases.tests.example_code import super_function_i_want_to_test +from pytest_cases.tests.utils import nb_pytest_parameters, get_pytest_param from pytest_cases import cases_data, CaseDataGetter, THIS_MODULE, cases_generator try: @@ -57,12 +58,14 @@ def test_with_parameters(case_data, # type: CaseDataGetter def test_assert_parametrized(): """Asserts that all tests are parametrized with the correct number of cases""" - assert len(test_with_parameters.pytestmark) == 2 + assert nb_pytest_parameters(test_with_parameters) == 2 - assert len(test_with_parameters.pytestmark[0].args) == 2 - assert test_with_parameters.pytestmark[0].args[0] == 'version' - assert len(test_with_parameters.pytestmark[0].args[1]) == 2 + param_args = get_pytest_param(test_with_parameters, 0) + assert len(param_args) == 2 + assert param_args[0] == 'version' + assert len(param_args[1]) == 2 - assert len(test_with_parameters.pytestmark[1].args) == 2 - assert test_with_parameters.pytestmark[1].args[0] == 'case_data' - assert len(test_with_parameters.pytestmark[1].args[1]) == 1 + 1 + 2 * 2 + param_args = get_pytest_param(test_with_parameters, 1) + assert len(param_args) == 2 + assert param_args[0] == 'case_data' + assert len(param_args[1]) == 1 + 1 + 2 * 2 diff --git a/pytest_cases/tests/utils.py b/pytest_cases/tests/utils.py new file mode 100644 index 00000000..524e4ed0 --- /dev/null +++ b/pytest_cases/tests/utils.py @@ -0,0 +1,16 @@ +def nb_pytest_parameters(f): + try: + # new pytest + return len(f.pytestmark) + except AttributeError: + # old pytest + return len(f.parametrize.args) / 2 + + +def get_pytest_param(f, i): + try: + # new pytest + return f.pytestmark[i].args + except AttributeError: + # old pytest + return f.parametrize.args[2*i:2*(i+1)]