From f500f5fe2bac17a219ad8abac3a36753400e5433 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 28 Sep 2024 14:41:18 -0700 Subject: [PATCH] Make union support unhashable objects --- Lib/test/test_typing.py | 10 +- Objects/unionobject.c | 475 ++++++++++++++++++++-------------------- 2 files changed, 235 insertions(+), 250 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index d61e3ab5974cbe..b99b8badf7b1f7 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -2062,16 +2062,10 @@ class B(metaclass=UnhashableMeta): ... self.assertEqual(Union[A, B].__args__, (A, B)) union1 = Union[A, B] - with self.assertRaises(TypeError): - hash(union1) - union2 = Union[int, B] - with self.assertRaises(TypeError): - hash(union2) - union3 = Union[A, int] - with self.assertRaises(TypeError): - hash(union3) + + self.assertEqual(len({union1, union2, union3}), 3) def test_repr(self): u = Union[Employee, int] diff --git a/Objects/unionobject.c b/Objects/unionobject.c index 3cf735f53153b8..d2ed67f987daf1 100644 --- a/Objects/unionobject.c +++ b/Objects/unionobject.c @@ -5,12 +5,11 @@ #include "pycore_unionobject.h" -static PyObject *make_union(PyObject *); - - typedef struct { PyObject_HEAD - PyObject *args; + PyObject *args; // all args (tuple) + PyObject *hashable_args; // frozenset or NULL + PyObject *unhashable_args; // tuple or NULL PyObject *parameters; PyObject *weakreflist; } unionobject; @@ -26,6 +25,8 @@ unionobject_dealloc(PyObject *self) } Py_XDECREF(alias->args); + Py_XDECREF(alias->hashable_args); + Py_XDECREF(alias->unhashable_args); Py_XDECREF(alias->parameters); Py_TYPE(self)->tp_free(self); } @@ -35,6 +36,8 @@ union_traverse(PyObject *self, visitproc visit, void *arg) { unionobject *alias = (unionobject *)self; Py_VISIT(alias->args); + Py_VISIT(alias->hashable_args); + Py_VISIT(alias->unhashable_args); Py_VISIT(alias->parameters); return 0; } @@ -43,15 +46,70 @@ static Py_hash_t union_hash(PyObject *self) { unionobject *alias = (unionobject *)self; - PyObject *args = PyFrozenSet_New(alias->args); - if (args == NULL) { - return (Py_hash_t)-1; + Py_hash_t hash; + if (alias->hashable_args) { + hash = PyObject_Hash(alias->hashable_args); + if (hash == -1) { + return -1; + } + } + else { + hash = 604; + } + // Mix in the ids of all the unhashable args. + if (alias->unhashable_args) { + assert(PyTuple_CheckExact(alias->unhashable_args)); + Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args); + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i); + hash ^= (Py_hash_t)arg; + } } - Py_hash_t hash = PyObject_Hash(args); - Py_DECREF(args); return hash; } +static int +unions_equal(unionobject *a, unionobject *b) +{ + int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ); + if (result == -1) { + return -1; + } + if (result == 0) { + return 0; + } + if (a->unhashable_args && b->unhashable_args) { + Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args); + if (n != PyTuple_GET_SIZE(b->unhashable_args)) { + return 0; + } + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i); + int result = PySequence_Contains(b->unhashable_args, arg_a); + if (result == -1) { + return -1; + } + if (!result) { + return 0; + } + } + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i); + int result = PySequence_Contains(a->unhashable_args, arg_b); + if (result == -1) { + return -1; + } + if (!result) { + return 0; + } + } + } + else if (a->unhashable_args || b->unhashable_args) { + return 0; + } + return 1; +} + static PyObject * union_richcompare(PyObject *a, PyObject *b, int op) { @@ -59,95 +117,130 @@ union_richcompare(PyObject *a, PyObject *b, int op) Py_RETURN_NOTIMPLEMENTED; } - PyObject *a_set = PySet_New(((unionobject*)a)->args); - if (a_set == NULL) { + int equal = unions_equal((unionobject*)a, (unionobject*)b); + if (equal == -1) { return NULL; } - PyObject *b_set = PySet_New(((unionobject*)b)->args); - if (b_set == NULL) { - Py_DECREF(a_set); - return NULL; + if (op == Py_EQ) { + return PyBool_FromLong(equal); + } + else { + return PyBool_FromLong(!equal); } - PyObject *result = PyObject_RichCompare(a_set, b_set, op); - Py_DECREF(b_set); - Py_DECREF(a_set); - return result; } -static int -is_same(PyObject *left, PyObject *right) +typedef struct { + PyObject *args; // list + PyObject *hashable_args; // set + PyObject *unhashable_args; // list or NULL + bool is_checked; // whether to call type_check() +} unionbuilder; + +static bool unionbuilder_add_tuple(unionbuilder *, PyObject *); +static PyObject *make_union(unionbuilder *); +static PyObject *type_check(PyObject *, const char *); + +static bool +unionbuilder_init(unionbuilder *ub, bool is_checked) { - int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right); - return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right; + ub->args = PyList_New(0); + if (ub->args == NULL) { + return false; + } + ub->hashable_args = PySet_New(NULL); + if (ub->hashable_args == NULL) { + Py_DECREF(ub->args); + return false; + } + ub->unhashable_args = NULL; + ub->is_checked = is_checked; + return true; } -static int -contains(PyObject **items, Py_ssize_t size, PyObject *obj) +static void +unionbuilder_finalize(unionbuilder *ub) { - for (Py_ssize_t i = 0; i < size; i++) { - int is_duplicate = is_same(items[i], obj); - if (is_duplicate) { // -1 or 1 - return is_duplicate; - } - } - return 0; + Py_DECREF(ub->args); + Py_DECREF(ub->hashable_args); + Py_XDECREF(ub->unhashable_args); } -static PyObject * -merge(PyObject **items1, Py_ssize_t size1, - PyObject **items2, Py_ssize_t size2) +static bool +unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg) { - PyObject *tuple = NULL; - Py_ssize_t pos = 0; - - for (Py_ssize_t i = 0; i < size2; i++) { - PyObject *arg = items2[i]; - int is_duplicate = contains(items1, size1, arg); - if (is_duplicate < 0) { - Py_XDECREF(tuple); - return NULL; - } - if (is_duplicate) { - continue; + Py_hash_t hash = PyObject_Hash(arg); + if (hash == -1) { + PyErr_Clear(); + if (ub->unhashable_args == NULL) { + ub->unhashable_args = PyList_New(0); + if (ub->unhashable_args == NULL) { + return false; + } } - - if (tuple == NULL) { - tuple = PyTuple_New(size1 + size2 - i); - if (tuple == NULL) { - return NULL; + else { + int contains = PySequence_Contains(ub->unhashable_args, arg); + if (contains < 0) { + return false; } - for (; pos < size1; pos++) { - PyObject *a = items1[pos]; - PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a)); + if (contains == 1) { + return true; } } - PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg)); - pos++; + if (PyList_Append(ub->unhashable_args, arg) < 0) { + return false; + } } - - if (tuple) { - (void) _PyTuple_Resize(&tuple, pos); + else { + int contains = PySet_Contains(ub->hashable_args, arg); + if (contains < 0) { + return false; + } + if (contains == 1) { + return true; + } + if (PySet_Add(ub->hashable_args, arg) < 0) { + return false; + } } - return tuple; + return PyList_Append(ub->args, arg) == 0; } -static PyObject ** -get_types(PyObject **obj, Py_ssize_t *size) +static bool +unionbuilder_add_single(unionbuilder *ub, PyObject *arg) { - if (*obj == Py_None) { - *obj = (PyObject *)&_PyNone_Type; + if (Py_IsNone(arg)) { + arg = (PyObject *)&_PyNone_Type; // immortal, so no refcounting needed + } + else if (_PyUnion_Check(arg)) { + PyObject *args = ((unionobject *)arg)->args; + return unionbuilder_add_tuple(ub, args); } - if (_PyUnion_Check(*obj)) { - PyObject *args = ((unionobject *) *obj)->args; - *size = PyTuple_GET_SIZE(args); - return &PyTuple_GET_ITEM(args, 0); + if (ub->is_checked) { + PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type."); + if (type == NULL) { + return false; + } + bool result = unionbuilder_add_single_unchecked(ub, type); + Py_DECREF(type); + return result; } else { - *size = 1; - return obj; + return unionbuilder_add_single_unchecked(ub, arg); } } +static bool +unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple) +{ + Py_ssize_t n = PyTuple_GET_SIZE(tuple); + for (Py_ssize_t i = 0; i < n; i++) { + if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) { + return false; + } + } + return true; +} + static int is_unionable(PyObject *obj) { @@ -168,19 +261,18 @@ _Py_union_type_or(PyObject* self, PyObject* other) Py_RETURN_NOTIMPLEMENTED; } - Py_ssize_t size1, size2; - PyObject **items1 = get_types(&self, &size1); - PyObject **items2 = get_types(&other, &size2); - PyObject *tuple = merge(items1, size1, items2, size2); - if (tuple == NULL) { - if (PyErr_Occurred()) { - return NULL; - } - return Py_NewRef(self); + unionbuilder ub; + // unchecked because we already checked is_unionable() + if (!unionbuilder_init(&ub, false)) { + return NULL; + } + if (!unionbuilder_add_single(&ub, self) || + !unionbuilder_add_single(&ub, other)) { + unionbuilder_finalize(&ub); + return NULL; } - PyObject *new_union = make_union(tuple); - Py_DECREF(tuple); + PyObject *new_union = make_union(&ub); return new_union; } @@ -206,6 +298,18 @@ union_repr(PyObject *self) goto error; } } + +#if 0 + PyUnicodeWriter_WriteUTF8(writer, "|args=", 6); + PyUnicodeWriter_WriteRepr(writer, alias->args); + PyUnicodeWriter_WriteUTF8(writer, "|h=", 3); + PyUnicodeWriter_WriteRepr(writer, alias->hashable_args); + if (alias->unhashable_args) { + PyUnicodeWriter_WriteUTF8(writer, "|u=", 3); + PyUnicodeWriter_WriteRepr(writer, alias->unhashable_args); + } +#endif + return PyUnicodeWriter_Finish(writer); error: @@ -235,21 +339,7 @@ union_getitem(PyObject *self, PyObject *item) return NULL; } - PyObject *res; - Py_ssize_t nargs = PyTuple_GET_SIZE(newargs); - if (nargs == 0) { - res = make_union(newargs); - } - else { - res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0)); - for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) { - PyObject *arg = PyTuple_GET_ITEM(newargs, iarg); - Py_SETREF(res, PyNumber_Or(res, arg)); - if (res == NULL) { - break; - } - } - } + PyObject *res = _Py_union_from_tuple(newargs); Py_DECREF(newargs); return res; } @@ -367,159 +457,24 @@ type_check(PyObject *arg, const char *msg) return result; } -static int -add_object_to_union_args(PyObject *args_list, PyObject *args_set, - PyObject *unhashable_args, PyObject *obj) -{ - if (Py_IS_TYPE(obj, &_PyUnion_Type)) { - PyObject *args = ((unionobject *) obj)->args; - Py_ssize_t size = PyTuple_GET_SIZE(args); - for (Py_ssize_t i = 0; i < size; i++) { - PyObject *arg = PyTuple_GET_ITEM(args, i); - if (add_object_to_union_args(args_list, args_set, unhashable_args, arg) < 0) { - return -1; - } - } - return 0; - } - PyObject *type = type_check(obj, "Union[arg, ...]: each arg must be a type."); - if (type == NULL) { - return -1; - } - int contains = PySet_Contains(args_set, type); - if (contains < 0) { - if (!PyErr_ExceptionMatches(PyExc_TypeError)) { - Py_DECREF(type); - return -1; - } - PyErr_Clear(); - if (PyList_Append(unhashable_args, obj) < 0) { - Py_DECREF(type); - return -1; - } - Py_DECREF(type); - return 0; - } - else if (contains == 1) { - Py_DECREF(type); - return 0; - } - if (PyList_Append(args_list, type) < 0) { - Py_DECREF(type); - return -1; - } - if (PySet_Add(args_set, type) < 0) { - Py_DECREF(type); - return -1; - } - Py_DECREF(type); - return 0; -} - PyObject * _Py_union_from_tuple(PyObject *args) { - PyObject *args_list = PyList_New(0); - if (args_list == NULL) { - return NULL; - } - PyObject *args_set = PySet_New(NULL); - if (args_set == NULL) { - Py_DECREF(args_list); - return NULL; - } - PyObject *unhashable_args = PyList_New(0); - if (unhashable_args == NULL) { - Py_DECREF(args_list); - Py_DECREF(args_set); + unionbuilder ub; + if (!unionbuilder_init(&ub, true)) { return NULL; } - if (!PyTuple_CheckExact(args)) { - if (add_object_to_union_args(args_list, args_set, unhashable_args, args) < 0) { - Py_DECREF(args_list); - Py_DECREF(args_set); - Py_DECREF(unhashable_args); + if (PyTuple_CheckExact(args)) { + if (!unionbuilder_add_tuple(&ub, args)) { return NULL; } } else { - Py_ssize_t size = PyTuple_GET_SIZE(args); - for (Py_ssize_t i = 0; i < size; i++) { - PyObject *arg = PyTuple_GET_ITEM(args, i); - if (add_object_to_union_args(args_list, args_set, unhashable_args, arg) < 0) { - Py_DECREF(args_list); - Py_DECREF(args_set); - Py_DECREF(unhashable_args); - return NULL; - } - } - } - Py_DECREF(args_set); - Py_ssize_t num_unhashable = PyList_Size(unhashable_args); - if (num_unhashable < 0) { - Py_DECREF(args_list); - Py_DECREF(unhashable_args); - return NULL; - } - if (num_unhashable > 0) { - PyObject *new_unhashable = PyList_New(0); - if (new_unhashable == NULL) { - Py_DECREF(args_list); - Py_DECREF(unhashable_args); - Py_DECREF(new_unhashable); + if (!unionbuilder_add_single(&ub, args)) { return NULL; } - for (Py_ssize_t i = 0; i < num_unhashable; i++) { - PyObject *obj = PyList_GetItemRef(unhashable_args, i); - if (obj == NULL) { - Py_DECREF(args_list); - Py_DECREF(unhashable_args); - Py_DECREF(new_unhashable); - return NULL; - } - int contains = PySequence_Contains(new_unhashable, obj); - if (contains < 0) { - Py_DECREF(args_list); - Py_DECREF(unhashable_args); - Py_DECREF(new_unhashable); - Py_DECREF(obj); - return NULL; - } - if (contains == 1) { - Py_DECREF(obj); - continue; - } - if (PyList_Append(args_list, obj) < 0) { - Py_DECREF(args_list); - Py_DECREF(unhashable_args); - Py_DECREF(new_unhashable); - Py_DECREF(obj); - return NULL; - } - Py_DECREF(obj); - } - Py_DECREF(new_unhashable); - } - Py_DECREF(unhashable_args); - if (PyList_GET_SIZE(args_list) == 0) { - Py_DECREF(args_list); - PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types."); - return NULL; - } - else if (PyList_GET_SIZE(args_list) == 1) { - PyObject *result = PyList_GET_ITEM(args_list, 0); - Py_INCREF(result); - Py_DECREF(args_list); - return result; } - PyObject *args_tuple = PyList_AsTuple(args_list); - Py_DECREF(args_list); - if (args_tuple == NULL) { - return NULL; - } - PyObject *u = make_union(args_tuple); - Py_DECREF(args_tuple); - return u; + return make_union(&ub); } static PyObject * @@ -566,18 +521,54 @@ PyTypeObject _PyUnion_Type = { }; static PyObject * -make_union(PyObject *args) +make_union(unionbuilder *ub) { - assert(PyTuple_CheckExact(args)); + Py_ssize_t n = PyList_GET_SIZE(ub->args); + if (n == 0) { + PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types."); + unionbuilder_finalize(ub); + return NULL; + } + if (n == 1) { + PyObject *result = PyList_GET_ITEM(ub->args, 0); + Py_INCREF(result); + unionbuilder_finalize(ub); + return result; + } + + PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL; + args = PyList_AsTuple(ub->args); + if (args == NULL) { + goto error; + } + hashable_args = PyFrozenSet_New(ub->hashable_args); + if (hashable_args == NULL) { + goto error; + } + if (ub->unhashable_args != NULL) { + unhashable_args = PyList_AsTuple(ub->unhashable_args); + if (unhashable_args == NULL) { + goto error; + } + } unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type); if (result == NULL) { - return NULL; + goto error; } + unionbuilder_finalize(ub); result->parameters = NULL; - result->args = Py_NewRef(args); + result->args = args; + result->hashable_args = hashable_args; + result->unhashable_args = unhashable_args; result->weakreflist = NULL; _PyObject_GC_TRACK(result); return (PyObject*)result; +error: + Py_XDECREF(args); + Py_XDECREF(hashable_args); + Py_XDECREF(unhashable_args); + unionbuilder_finalize(ub); + return NULL; }