From 5a8f9085b9bb2283b1dd116987ce6d9564636e7a Mon Sep 17 00:00:00 2001 From: Sylvain MARIE Date: Mon, 29 Jun 2020 17:40:37 +0200 Subject: [PATCH] Now a case can be parametrized using `@parametrize` : no need to use `@cases_generator` anymore. Fixes #106 (second half of it) --- pytest_cases/case_parametrizer_new.py | 26 ++++--- pytest_cases/common_pytest.py | 101 +++++++++++++++++++++++++- 2 files changed, 115 insertions(+), 12 deletions(-) diff --git a/pytest_cases/case_parametrizer_new.py b/pytest_cases/case_parametrizer_new.py index a76c28e0..c19944dc 100644 --- a/pytest_cases/case_parametrizer_new.py +++ b/pytest_cases/case_parametrizer_new.py @@ -13,7 +13,7 @@ from .common_mini_six import string_types from .common_others import get_code_first_line, AUTO, AUTO2 -from .common_pytest import safe_isclass, copy_pytest_marks +from .common_pytest import safe_isclass, copy_pytest_marks, get_callspecs from .case_funcs_new import matches_tag_query, is_case_function, is_case_class, CaseInfo from .fixture_parametrize_plus import parametrize_plus, lazy_value @@ -146,7 +146,7 @@ def get_pytest_parametrize_args(cases_funs # type: List[Callable] return [c for _f in cases_funs for c in case_to_argvalues(_f)] -def case_to_argvalues(f # type: Callable +def case_to_argvalues(case_fun # type: Callable ): # type: (...) -> List[lazy_value] """Transform a single case into one or several `lazy_value` to be used in `@parametrize` @@ -164,21 +164,29 @@ def case_to_argvalues(f # type: Callable id = None marks = () - case_info = CaseInfo.get_from(f) + case_info = CaseInfo.get_from(case_fun) if case_info is not None: id = case_info.id marks = case_info.marks if id is None: # default test id from function name - if f.__name__.startswith('case_'): - id = f.__name__[5:] - elif f.__name__.startswith('cases_'): - id = f.__name__[6:] + if case_fun.__name__.startswith('case_'): + id = case_fun.__name__[5:] + elif case_fun.__name__.startswith('cases_'): + id = case_fun.__name__[6:] else: - id = f.__name__ + id = case_fun.__name__ - return lazy_value(f, id=id, marks=marks) + # get the list of all calls that pytest *would* have made for such a (possibly parametrized) function + calls = get_callspecs(case_fun) + + if len(calls) == 0: + # single unparametrized case function + return (lazy_value(case_fun, id=id, marks=marks),) + else: + # parametrized. create one version of the callable for each parametrized call + return tuple(lazy_value(partial(case_fun, **c.funcargs), id="%s-%s" % (id, c.id), marks=c.marks) for c in calls) def import_default_cases_module(f, alt_name=False): diff --git a/pytest_cases/common_pytest.py b/pytest_cases/common_pytest.py index 14b7ccee..cdf16317 100644 --- a/pytest_cases/common_pytest.py +++ b/pytest_cases/common_pytest.py @@ -3,9 +3,9 @@ import warnings try: # python 3.3+ - from inspect import signature + from inspect import signature, Parameter except ImportError: - from funcsigs import signature # noqa + from funcsigs import signature, Parameter # noqa from distutils.version import LooseVersion from inspect import isgeneratorfunction, isclass @@ -604,7 +604,7 @@ def get_pytest_nodeid(metafunc): from _pytest.fixtures import scopes as pt_scopes except ImportError: # pytest 2 - from _pytest.python import scopes as pt_scopes # noqa + from _pytest.python import scopes as pt_scopes, Metafunc # noqa def get_pytest_scopenum(scope_str): @@ -653,3 +653,98 @@ def mini_idvalset(argnames, argvalues, idx): for val, argname in zip(argvalues, argnames) ] return "-".join(this_id) + + +from _pytest.python import Metafunc + +try: + from _pytest.compat import getfuncargnames +except ImportError: + import sys + + def num_mock_patch_args(function): + """ return number of arguments used up by mock arguments (if any) """ + patchings = getattr(function, "patchings", None) + if not patchings: + return 0 + + mock_sentinel = getattr(sys.modules.get("mock"), "DEFAULT", object()) + ut_mock_sentinel = getattr(sys.modules.get("unittest.mock"), "DEFAULT", object()) + + return len( + [ + p + for p in patchings + if not p.attribute_name + and (p.new is mock_sentinel or p.new is ut_mock_sentinel) + ] + ) + + def getfuncargnames(function, cls=None): + """Returns the names of a function's mandatory arguments.""" + parameters = signature(function).parameters + + arg_names = tuple( + p.name + for p in parameters.values() + if ( + p.kind is Parameter.POSITIONAL_OR_KEYWORD + or p.kind is Parameter.KEYWORD_ONLY + ) + and p.default is Parameter.empty + ) + + # If this function should be treated as a bound method even though + # it's passed as an unbound method or function, remove the first + # parameter name. + if cls and not isinstance(cls.__dict__.get(function.__name__, None), staticmethod): + arg_names = arg_names[1:] + # Remove any names that will be replaced with mocks. + if hasattr(function, "__wrapped__"): + arg_names = arg_names[num_mock_patch_args(function):] + return arg_names + + +class MiniFuncDef(object): + __slots__ = ('nodeid',) + + def __init__(self, nodeid): + self.nodeid = nodeid + + +class MiniMetafunc(Metafunc): + def __init__(self, func): + self.config = None + self.function = func + self.definition = MiniFuncDef(func.__name__) + self._calls = [] + # non-default parameters + self.fixturenames = getfuncargnames(func) + + +def get_callspecs(func): + """ + Returns a list of pytest CallSpec objects corresponding to calls that should be made for this parametrized function. + This mini-helper assumes no complex things (scope='function', indirect=False, no fixtures, no custom configuration) + + :param func: + :return: + """ + meta = MiniMetafunc(func) + + pmarks = get_pytest_parametrize_marks(func) + for pmark in pmarks: + if len(pmark.param_names) == 1: + argvals = tuple(v if is_marked_parameter_value(v) else (v,) for v in pmark.param_values) + else: + argvals = pmark.param_values + meta.parametrize(argnames=pmark.param_names, argvalues=argvals, ids=pmark.param_ids, + # use indirect = False and scope = 'function' to avoid having to implement complex patches + indirect=False, scope='function') + + if not has_pytest_param: + # fix the CallSpec2 instances so that the marks appear + for c in meta._calls: + c.marks = list(c.keywords.values()) + + return meta._calls