diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 6fe4ca4694..6c1a7b50c2 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -1969,15 +1969,13 @@ unregister_task(asyncio_state *state, PyObject *task) static int enter_task(asyncio_state *state, PyObject *loop, PyObject *task) { - PyObject *item; - Py_hash_t hash; - hash = PyObject_Hash(loop); - if (hash == -1) { + int is_insert; + PyObject *item = _PyDict_SetDefault(state->current_tasks, loop, task, 1, + &is_insert); + if (item == NULL) { return -1; } - item = _PyDict_GetItem_KnownHash(state->current_tasks, loop, hash); - if (item != NULL) { - Py_INCREF(item); + if (!is_insert) { PyErr_Format( PyExc_RuntimeError, "Cannot enter into task %R while another " \ @@ -1986,36 +1984,51 @@ enter_task(asyncio_state *state, PyObject *loop, PyObject *task) Py_DECREF(item); return -1; } - if (PyErr_Occurred()) { - return -1; - } - return _PyDict_SetItem_KnownHash(state->current_tasks, loop, task, hash); + assert(item == task); + Py_DECREF(item); + return 0; } +struct task_matches_arg { + PyObject *task; + PyObject *item; +}; + +static int +task_matches(PyObject *item, void *data) +{ + struct task_matches_arg *arg = (struct task_matches_arg *)data; + arg->item = Py_NewRef(item); + return arg->task == item; +} static int leave_task(asyncio_state *state, PyObject *loop, PyObject *task) /*[clinic end generated code: output=0ebf6db4b858fb41 input=51296a46313d1ad8]*/ { - PyObject *item; - Py_hash_t hash; - hash = PyObject_Hash(loop); - if (hash == -1) { - return -1; - } - item = _PyDict_GetItem_KnownHash(state->current_tasks, loop, hash); - if (item != task) { - if (item == NULL) { - /* Not entered, replace with None */ - item = Py_None; + struct task_matches_arg arg; + arg.task = task; + arg.item = Py_None; + int res = _PyDict_DelItemIf(state->current_tasks, loop, &task_matches, &arg); + if (res != 0) { + if (PyErr_ExceptionMatches(PyExc_KeyError)) { + // use custom error message if loop is not present + goto fail; } - PyErr_Format( - PyExc_RuntimeError, - "Leaving task %R does not match the current task %R.", - task, item, NULL); return -1; } - return _PyDict_DelItem_KnownHash(state->current_tasks, loop, hash); + if (arg.item == task) { + Py_XDECREF(arg.item); + return 0; + } + +fail: + PyErr_Format( + PyExc_RuntimeError, + "Leaving task %R does not match the current task %R.", + task, arg.item, NULL); + Py_XDECREF(arg.item); + return -1; } /* ----- Task */