diff --git a/docs/changelog.rst b/docs/changelog.rst index fb39eb8..9b6d684 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,6 +2,7 @@ Changelog ========= +* :bug:`85` Fixed marbles handling of deprecated assertEquals and similar methods * :release:`0.9.5 <2018-06-24>` * :support:`80` Added support for ``pandas<0.24`` * :bug:`58` Fixed test failure on OSX diff --git a/marbles/core/marbles/core/marbles.py b/marbles/core/marbles/core/marbles.py index 76b6fa8..d7bffd4 100644 --- a/marbles/core/marbles/core/marbles.py +++ b/marbles/core/marbles/core/marbles.py @@ -432,9 +432,19 @@ def _find_msg_argument(signature): The index of the ``msg`` param, the default value for it, and the number of non-``msg`` positional parameters we expect. ''' - names = signature.parameters.keys() + names = list(signature.parameters.keys()) + if len(names) == 2: + param_kinds = [signature.parameters[name].kind for name in names] + if (param_kinds[0] == inspect.Parameter.VAR_POSITIONAL + and param_kinds[1] == inspect.Parameter.VAR_KEYWORD): + # This is likely an assertion wrapper like + # assertEquals(*args, **kwargs), in which case we should + # just forward the arguments along and catch them in the + # wrapped assertion method (see + # https://github.com/twosigma/marbles/issue/85). + return sys.maxsize, None, len(names) try: - msg_idx = list(names).index('msg') + msg_idx = names.index('msg') default_msg = signature.parameters['msg'].default except ValueError: # 'msg' is not in list # It's likely that this is a custom assertion that's just diff --git a/marbles/core/tests/test_marbles.py b/marbles/core/tests/test_marbles.py index 065e625..38c0ed7 100644 --- a/marbles/core/tests/test_marbles.py +++ b/marbles/core/tests/test_marbles.py @@ -82,6 +82,16 @@ def test_success(self): def test_failure(self): self.assertTrue(False, note='some note') + def test_deprecated_assertEquals_success(self): + x = 1 + y = 1 + self.assertEquals(x, y) + + def test_deprecated_assertEquals_failure(self): + x = 1 + y = 2 + self.assertEquals(x, y) + def test_fail_without_msg_without_note(self): self.fail() @@ -345,6 +355,23 @@ def test_annotated_assertion_error_raised(self): with self.assertRaises(ContextualAssertionError): self.case.test_failure() + def test_deprecated_assertEquals_success(self): + '''Does the deprecated assertEquals method still work?''' + if self._use_annotated_test_case: + with self.assertRaises(AnnotationError): + self.case.test_deprecated_assertEquals_success() + else: + self.case.test_deprecated_assertEquals_success() + + def test_deprecated_assertEquals_failure(self): + '''Does the deprecated assertEquals method work on failure?''' + if self._use_annotated_test_case: + with self.assertRaises(AnnotationError): + self.case.test_deprecated_assertEquals_failure() + else: + with self.assertRaises(ContextualAssertionError): + self.case.test_deprecated_assertEquals_failure() + def test_fail_handles_note_properly(self): '''Does TestCase.fail() deal with note the right way?''' if self._use_annotated_test_case: @@ -515,7 +542,7 @@ def test_get_stack(self): self.assertEqual(e.filename, os.path.abspath(__file__)) # This isn't great because I have to change it every time I # add/remove imports but oh well - self.assertEqual(e.linenumber, 211) + self.assertEqual(e.linenumber, 221) def test_assert_stmt_indicates_line(self): '''Does e.assert_stmt indicate the line from the source code?'''