Skip to content

Commit

Permalink
Releasing the GIL before calling ucp_worker_progress() (#802)
Browse files Browse the repository at this point in the history
* adding `with gil` to all Cython callback function

* Releasing the GIL before calling ucp_worker_progress()
  • Loading branch information
madsbk authored Oct 27, 2021
1 parent 121ad2b commit e28d770
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 15 deletions.
4 changes: 3 additions & 1 deletion ucp/_libs/transfer_am.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ logger = logging.getLogger("ucx")


IF CY_UCP_AM_SUPPORTED:
cdef void _send_nbx_callback(void *request, ucs_status_t status, void *user_data):
cdef void _send_nbx_callback(
void *request, ucs_status_t status, void *user_data
) with gil:
cdef UCXRequest req
cdef dict req_info
cdef str name, ucx_status_msg, msg
Expand Down
2 changes: 1 addition & 1 deletion ucp/_libs/transfer_common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ from ..exceptions import UCXCanceled, UCXError, log_errors
# This callback function is currently needed by stream_send_nb and
# tag_send_nb transfer functions, as well as UCXEndpoint and UCXWorker
# flush methods.
cdef void _send_callback(void *request, ucs_status_t status):
cdef void _send_callback(void *request, ucs_status_t status) with gil:
cdef UCXRequest req
cdef dict req_info
cdef str name, ucx_status_msg, msg
Expand Down
2 changes: 1 addition & 1 deletion ucp/_libs/transfer_stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def stream_send_nb(

cdef void _stream_recv_callback(
void *request, ucs_status_t status, size_t length
):
) with gil:
cdef UCXRequest req
cdef dict req_info
cdef str name, ucx_status_msg, msg
Expand Down
2 changes: 1 addition & 1 deletion ucp/_libs/transfer_tag.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def tag_send_nb(

cdef void _tag_recv_callback(
void *request, ucs_status_t status, ucp_tag_recv_info_t *info
):
) with gil:
cdef UCXRequest req
cdef dict req_info
cdef str name, ucx_status_msg, msg
Expand Down
2 changes: 1 addition & 1 deletion ucp/_libs/ucx_api_dep.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ cdef extern from "ucp/api/ucp.h":

ucp_datatype_t ucp_dt_make_contig(size_t elem_size)

unsigned ucp_worker_progress(ucp_worker_h worker)
unsigned ucp_worker_progress(ucp_worker_h worker) nogil

ctypedef struct ucp_tag_recv_info_t:
ucp_tag_t sender_tag
Expand Down
4 changes: 2 additions & 2 deletions ucp/_libs/ucx_endpoint.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from ..exceptions import UCXCanceled, UCXConnectionReset, UCXError
logger = logging.getLogger("ucx")


cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status):
cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) with gil:
cdef UCXEndpoint ucx_ep = <UCXEndpoint> arg
assert ucx_ep.worker.initialized

Expand All @@ -35,7 +35,7 @@ cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status):

cdef (ucp_err_handler_cb_t, uintptr_t) _get_error_callback(
str tls, bint endpoint_error_handling
) except *:
) except * with gil:
cdef ucp_err_handler_cb_t err_cb = <ucp_err_handler_cb_t>NULL
cdef ucs_status_t *cb_status = <ucs_status_t *>NULL

Expand Down
2 changes: 1 addition & 1 deletion ucp/_libs/ucx_listener.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ from .ucx_api_dep cimport *
from ..exceptions import log_errors


cdef void _listener_callback(ucp_conn_request_h conn_request, void *args):
cdef void _listener_callback(ucp_conn_request_h conn_request, void *args) with gil:
"""Callback function used by UCXListener"""
cdef dict cb_data = <dict> args

Expand Down
5 changes: 3 additions & 2 deletions ucp/_libs/ucx_worker.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,9 @@ cdef class UCXWorker(UCXObject):
the call-back function given to UCXListener, tag_send_nb, and tag_recv_nb.
"""
assert self.initialized
while ucp_worker_progress(self._handle) != 0:
pass
with nogil:
while ucp_worker_progress(self._handle) != 0:
pass

@property
def handle(self):
Expand Down
4 changes: 2 additions & 2 deletions ucp/_libs/ucx_worker_cb.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ IF CY_UCP_AM_SUPPORTED:
ucs_status_t status,
size_t length,
void *user_data
):
) with gil:
cdef bytearray buf
cdef UCXRequest req
cdef dict req_info
Expand Down Expand Up @@ -88,7 +88,7 @@ IF CY_UCP_AM_SUPPORTED:
void *data,
size_t length,
const ucp_am_recv_param_t *param
):
) with gil:
cdef UCXWorker worker = <UCXWorker>arg
cdef dict am_recv_pool = worker._am_recv_pool
cdef dict am_recv_wait = worker._am_recv_wait
Expand Down
6 changes: 3 additions & 3 deletions ucp/continuous_ucx_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def __init__(self, worker, event_loop, epoll_fd):
super().__init__(worker, event_loop)

# Creating a job that is ready straightaway but with low priority.
# Calling `await event_loop.sock_recv(rsock, 1)` will return when
# all non-IO tasks are finished.
# Calling `await self.event_loop.sock_recv(self.rsock, 1)` will
# return when all non-IO tasks are finished.
# See <https://stackoverflow.com/a/48491563>.
self.rsock, wsock = socket.socketpair()
self.rsock.setblocking(0)
Expand Down Expand Up @@ -99,7 +99,7 @@ async def _arm_worker(self):
del worker

# This IO task returns when all non-IO tasks are finished.
# Notice, we do NOT hold a reference to `ctx` while waiting.
# Notice, we do NOT hold a reference to `worker` while waiting.
await self.event_loop.sock_recv(self.rsock, 1)

worker = self.weakref_worker()
Expand Down

0 comments on commit e28d770

Please sign in to comment.