From 68ee8b3f15b744339c156bfeb583a414061ab22d Mon Sep 17 00:00:00 2001 From: Cheryl Sabella Date: Sat, 20 May 2023 12:07:40 -0400 Subject: [PATCH] gh-56276: Add tests to test_compare (#3199) Co-authored-by: Terry Jan Reedy Co-authored-by: Oleg Iarygin --- Lib/test/test_compare.py | 426 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 409 insertions(+), 17 deletions(-) diff --git a/Lib/test/test_compare.py b/Lib/test/test_compare.py index 2b3faed7965bc1..8166b0eea306e3 100644 --- a/Lib/test/test_compare.py +++ b/Lib/test/test_compare.py @@ -1,21 +1,27 @@ +"""Test equality and order comparisons.""" import unittest from test.support import ALWAYS_EQ +from fractions import Fraction +from decimal import Decimal -class Empty: - def __repr__(self): - return '' -class Cmp: - def __init__(self,arg): - self.arg = arg +class ComparisonSimpleTest(unittest.TestCase): + """Test equality and order comparisons for some simple cases.""" - def __repr__(self): - return '' % self.arg + class Empty: + def __repr__(self): + return '' - def __eq__(self, other): - return self.arg == other + class Cmp: + def __init__(self, arg): + self.arg = arg + + def __repr__(self): + return '' % self.arg + + def __eq__(self, other): + return self.arg == other -class ComparisonTest(unittest.TestCase): set1 = [2, 2.0, 2, 2+0j, Cmp(2.0)] set2 = [[1], (3,), None, Empty()] candidates = set1 + set2 @@ -32,16 +38,15 @@ def test_id_comparisons(self): # Ensure default comparison compares id() of args L = [] for i in range(10): - L.insert(len(L)//2, Empty()) + L.insert(len(L)//2, self.Empty()) for a in L: for b in L: - self.assertEqual(a == b, id(a) == id(b), - 'a=%r, b=%r' % (a, b)) + self.assertEqual(a == b, a is b, 'a=%r, b=%r' % (a, b)) def test_ne_defaults_to_not_eq(self): - a = Cmp(1) - b = Cmp(1) - c = Cmp(2) + a = self.Cmp(1) + b = self.Cmp(1) + c = self.Cmp(2) self.assertIs(a == b, True) self.assertIs(a != b, False) self.assertIs(a != c, True) @@ -114,5 +119,392 @@ def test_issue_1393(self): self.assertEqual(ALWAYS_EQ, y) +class ComparisonFullTest(unittest.TestCase): + """Test equality and ordering comparisons for built-in types and + user-defined classes that implement relevant combinations of rich + comparison methods. + """ + + class CompBase: + """Base class for classes with rich comparison methods. + + The "x" attribute should be set to an underlying value to compare. + + Derived classes have a "meth" tuple attribute listing names of + comparison methods implemented. See assert_total_order(). + """ + + # Class without any rich comparison methods. + class CompNone(CompBase): + meth = () + + # Classes with all combinations of value-based equality comparison methods. + class CompEq(CompBase): + meth = ("eq",) + def __eq__(self, other): + return self.x == other.x + + class CompNe(CompBase): + meth = ("ne",) + def __ne__(self, other): + return self.x != other.x + + class CompEqNe(CompBase): + meth = ("eq", "ne") + def __eq__(self, other): + return self.x == other.x + def __ne__(self, other): + return self.x != other.x + + # Classes with all combinations of value-based less/greater-than order + # comparison methods. + class CompLt(CompBase): + meth = ("lt",) + def __lt__(self, other): + return self.x < other.x + + class CompGt(CompBase): + meth = ("gt",) + def __gt__(self, other): + return self.x > other.x + + class CompLtGt(CompBase): + meth = ("lt", "gt") + def __lt__(self, other): + return self.x < other.x + def __gt__(self, other): + return self.x > other.x + + # Classes with all combinations of value-based less/greater-or-equal-than + # order comparison methods + class CompLe(CompBase): + meth = ("le",) + def __le__(self, other): + return self.x <= other.x + + class CompGe(CompBase): + meth = ("ge",) + def __ge__(self, other): + return self.x >= other.x + + class CompLeGe(CompBase): + meth = ("le", "ge") + def __le__(self, other): + return self.x <= other.x + def __ge__(self, other): + return self.x >= other.x + + # It should be sufficient to combine the comparison methods only within + # each group. + all_comp_classes = ( + CompNone, + CompEq, CompNe, CompEqNe, # equal group + CompLt, CompGt, CompLtGt, # less/greater-than group + CompLe, CompGe, CompLeGe) # less/greater-or-equal group + + def create_sorted_instances(self, class_, values): + """Create objects of type `class_` and return them in a list. + + `values` is a list of values that determines the value of data + attribute `x` of each object. + + Objects in the returned list are sorted by their identity. They + assigned values in `values` list order. By assign decreasing + values to objects with increasing identities, testcases can assert + that order comparison is performed by value and not by identity. + """ + + instances = [class_() for __ in range(len(values))] + instances.sort(key=id) + # Assign the provided values to the instances. + for inst, value in zip(instances, values): + inst.x = value + return instances + + def assert_equality_only(self, a, b, equal): + """Assert equality result and that ordering is not implemented. + + a, b: Instances to be tested (of same or different type). + equal: Boolean indicating the expected equality comparison results. + """ + self.assertEqual(a == b, equal) + self.assertEqual(b == a, equal) + self.assertEqual(a != b, not equal) + self.assertEqual(b != a, not equal) + with self.assertRaisesRegex(TypeError, "not supported"): + a < b + with self.assertRaisesRegex(TypeError, "not supported"): + a <= b + with self.assertRaisesRegex(TypeError, "not supported"): + a > b + with self.assertRaisesRegex(TypeError, "not supported"): + a >= b + with self.assertRaisesRegex(TypeError, "not supported"): + b < a + with self.assertRaisesRegex(TypeError, "not supported"): + b <= a + with self.assertRaisesRegex(TypeError, "not supported"): + b > a + with self.assertRaisesRegex(TypeError, "not supported"): + b >= a + + def assert_total_order(self, a, b, comp, a_meth=None, b_meth=None): + """Test total ordering comparison of two instances. + + a, b: Instances to be tested (of same or different type). + + comp: -1, 0, or 1 indicates that the expected order comparison + result for operations that are supported by the classes is + a <, ==, or > b. + + a_meth, b_meth: Either None, indicating that all rich comparison + methods are available, aa for builtins, or the tuple (subset) + of "eq", "ne", "lt", "le", "gt", and "ge" that are available + for the corresponding instance (of a user-defined class). + """ + self.assert_eq_subtest(a, b, comp, a_meth, b_meth) + self.assert_ne_subtest(a, b, comp, a_meth, b_meth) + self.assert_lt_subtest(a, b, comp, a_meth, b_meth) + self.assert_le_subtest(a, b, comp, a_meth, b_meth) + self.assert_gt_subtest(a, b, comp, a_meth, b_meth) + self.assert_ge_subtest(a, b, comp, a_meth, b_meth) + + # The body of each subtest has form: + # + # if value-based comparison methods: + # expect what the testcase defined for a op b and b rop a; + # else: no value-based comparison + # expect default behavior of object for a op b and b rop a. + + def assert_eq_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "eq" in a_meth or "eq" in b_meth: + self.assertEqual(a == b, comp == 0) + self.assertEqual(b == a, comp == 0) + else: + self.assertEqual(a == b, a is b) + self.assertEqual(b == a, a is b) + + def assert_ne_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or not {"ne", "eq"}.isdisjoint(a_meth + b_meth): + self.assertEqual(a != b, comp != 0) + self.assertEqual(b != a, comp != 0) + else: + self.assertEqual(a != b, a is not b) + self.assertEqual(b != a, a is not b) + + def assert_lt_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "lt" in a_meth or "gt" in b_meth: + self.assertEqual(a < b, comp < 0) + self.assertEqual(b > a, comp < 0) + else: + with self.assertRaisesRegex(TypeError, "not supported"): + a < b + with self.assertRaisesRegex(TypeError, "not supported"): + b > a + + def assert_le_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "le" in a_meth or "ge" in b_meth: + self.assertEqual(a <= b, comp <= 0) + self.assertEqual(b >= a, comp <= 0) + else: + with self.assertRaisesRegex(TypeError, "not supported"): + a <= b + with self.assertRaisesRegex(TypeError, "not supported"): + b >= a + + def assert_gt_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "gt" in a_meth or "lt" in b_meth: + self.assertEqual(a > b, comp > 0) + self.assertEqual(b < a, comp > 0) + else: + with self.assertRaisesRegex(TypeError, "not supported"): + a > b + with self.assertRaisesRegex(TypeError, "not supported"): + b < a + + def assert_ge_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "ge" in a_meth or "le" in b_meth: + self.assertEqual(a >= b, comp >= 0) + self.assertEqual(b <= a, comp >= 0) + else: + with self.assertRaisesRegex(TypeError, "not supported"): + a >= b + with self.assertRaisesRegex(TypeError, "not supported"): + b <= a + + def test_objects(self): + """Compare instances of type 'object'.""" + a = object() + b = object() + self.assert_equality_only(a, a, True) + self.assert_equality_only(a, b, False) + + def test_comp_classes_same(self): + """Compare same-class instances with comparison methods.""" + + for cls in self.all_comp_classes: + with self.subTest(cls): + instances = self.create_sorted_instances(cls, (1, 2, 1)) + + # Same object. + self.assert_total_order(instances[0], instances[0], 0, + cls.meth, cls.meth) + + # Different objects, same value. + self.assert_total_order(instances[0], instances[2], 0, + cls.meth, cls.meth) + + # Different objects, value ascending for ascending identities. + self.assert_total_order(instances[0], instances[1], -1, + cls.meth, cls.meth) + + # different objects, value descending for ascending identities. + # This is the interesting case to assert that order comparison + # is performed based on the value and not based on the identity. + self.assert_total_order(instances[1], instances[2], +1, + cls.meth, cls.meth) + + def test_comp_classes_different(self): + """Compare different-class instances with comparison methods.""" + + for cls_a in self.all_comp_classes: + for cls_b in self.all_comp_classes: + with self.subTest(a=cls_a, b=cls_b): + a1 = cls_a() + a1.x = 1 + b1 = cls_b() + b1.x = 1 + b2 = cls_b() + b2.x = 2 + + self.assert_total_order( + a1, b1, 0, cls_a.meth, cls_b.meth) + self.assert_total_order( + a1, b2, -1, cls_a.meth, cls_b.meth) + + def test_str_subclass(self): + """Compare instances of str and a subclass.""" + class StrSubclass(str): + pass + + s1 = str("a") + s2 = str("b") + c1 = StrSubclass("a") + c2 = StrSubclass("b") + c3 = StrSubclass("b") + + self.assert_total_order(s1, s1, 0) + self.assert_total_order(s1, s2, -1) + self.assert_total_order(c1, c1, 0) + self.assert_total_order(c1, c2, -1) + self.assert_total_order(c2, c3, 0) + + self.assert_total_order(s1, c2, -1) + self.assert_total_order(s2, c3, 0) + self.assert_total_order(c1, s2, -1) + self.assert_total_order(c2, s2, 0) + + def test_numbers(self): + """Compare number types.""" + + # Same types. + i1 = 1001 + i2 = 1002 + self.assert_total_order(i1, i1, 0) + self.assert_total_order(i1, i2, -1) + + f1 = 1001.0 + f2 = 1001.1 + self.assert_total_order(f1, f1, 0) + self.assert_total_order(f1, f2, -1) + + q1 = Fraction(2002, 2) + q2 = Fraction(2003, 2) + self.assert_total_order(q1, q1, 0) + self.assert_total_order(q1, q2, -1) + + d1 = Decimal('1001.0') + d2 = Decimal('1001.1') + self.assert_total_order(d1, d1, 0) + self.assert_total_order(d1, d2, -1) + + c1 = 1001+0j + c2 = 1001+1j + self.assert_equality_only(c1, c1, True) + self.assert_equality_only(c1, c2, False) + + + # Mixing types. + for n1, n2 in ((i1,f1), (i1,q1), (i1,d1), (f1,q1), (f1,d1), (q1,d1)): + self.assert_total_order(n1, n2, 0) + for n1 in (i1, f1, q1, d1): + self.assert_equality_only(n1, c1, True) + + def test_sequences(self): + """Compare list, tuple, and range.""" + l1 = [1, 2] + l2 = [2, 3] + self.assert_total_order(l1, l1, 0) + self.assert_total_order(l1, l2, -1) + + t1 = (1, 2) + t2 = (2, 3) + self.assert_total_order(t1, t1, 0) + self.assert_total_order(t1, t2, -1) + + r1 = range(1, 2) + r2 = range(2, 2) + self.assert_equality_only(r1, r1, True) + self.assert_equality_only(r1, r2, False) + + self.assert_equality_only(t1, l1, False) + self.assert_equality_only(l1, r1, False) + self.assert_equality_only(r1, t1, False) + + def test_bytes(self): + """Compare bytes and bytearray.""" + bs1 = b'a1' + bs2 = b'b2' + self.assert_total_order(bs1, bs1, 0) + self.assert_total_order(bs1, bs2, -1) + + ba1 = bytearray(b'a1') + ba2 = bytearray(b'b2') + self.assert_total_order(ba1, ba1, 0) + self.assert_total_order(ba1, ba2, -1) + + self.assert_total_order(bs1, ba1, 0) + self.assert_total_order(bs1, ba2, -1) + self.assert_total_order(ba1, bs1, 0) + self.assert_total_order(ba1, bs2, -1) + + def test_sets(self): + """Compare set and frozenset.""" + s1 = {1, 2} + s2 = {1, 2, 3} + self.assert_total_order(s1, s1, 0) + self.assert_total_order(s1, s2, -1) + + f1 = frozenset(s1) + f2 = frozenset(s2) + self.assert_total_order(f1, f1, 0) + self.assert_total_order(f1, f2, -1) + + self.assert_total_order(s1, f1, 0) + self.assert_total_order(s1, f2, -1) + self.assert_total_order(f1, s1, 0) + self.assert_total_order(f1, s2, -1) + + def test_mappings(self): + """ Compare dict. + """ + d1 = {1: "a", 2: "b"} + d2 = {2: "b", 3: "c"} + d3 = {3: "c", 2: "b"} + self.assert_equality_only(d1, d1, True) + self.assert_equality_only(d1, d2, False) + self.assert_equality_only(d2, d3, True) + + if __name__ == '__main__': unittest.main()