From 4fef3ccb7cdd53b553e50a10b7bb7d522ae8a82e Mon Sep 17 00:00:00 2001 From: deanlee Date: Wed, 29 May 2024 20:16:31 +0800 Subject: [PATCH] improve blocking receive --- SConscript | 5 +- msgq/impl_msgq.cc | 104 +++++++++++++++-------------------- msgq/ipc.pxd | 2 +- msgq/ipc_pyx.pyx | 4 +- msgq/tests/test_messaging.py | 25 +++++++++ 5 files changed, 77 insertions(+), 63 deletions(-) diff --git a/SConscript b/SConscript index 147eb3042..d11c17c54 100644 --- a/SConscript +++ b/SConscript @@ -15,7 +15,7 @@ msgq_objects = env.SharedObject([ 'msgq/msgq.cc', ]) msgq = env.Library('msgq', msgq_objects) -msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", common]) +msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", 'pthread',common]) # Build Vision IPC vipc_files = ['visionipc.cc', 'visionipc_server.cc', 'visionipc_client.cc', 'visionbuf.cc'] @@ -31,7 +31,7 @@ visionipc = env.Library('visionipc', vipc_objects) vipc_frameworks = [] -vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq"] +vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq", 'pthread'] if arch == "Darwin": vipc_frameworks.append('OpenCL') else: @@ -45,4 +45,5 @@ if GetOption('extras'): [f'{visionipc_dir.abspath}/test_runner.cc', f'{visionipc_dir.abspath}/visionipc_tests.cc'], LIBS=['pthread'] + vipc_libs, FRAMEWORKS=vipc_frameworks) +msgq = [msgq, 'pthread'] Export('visionipc', 'msgq', 'msgq_python') diff --git a/msgq/impl_msgq.cc b/msgq/impl_msgq.cc index b23991351..88404e74d 100644 --- a/msgq/impl_msgq.cc +++ b/msgq/impl_msgq.cc @@ -1,20 +1,12 @@ #include #include #include -#include +#include #include -#include #include "msgq/impl_msgq.h" - -volatile sig_atomic_t msgq_do_exit = 0; - -void sig_handler(int signal) { - assert(signal == SIGINT || signal == SIGTERM); - msgq_do_exit = 1; -} - +using namespace std::chrono; MSGQContext::MSGQContext() { } @@ -70,61 +62,55 @@ int 
MSGQSubSocket::connect(Context *context, std::string endpoint, std::string a return 0; } - -Message * MSGQSubSocket::receive(bool non_blocking){ - msgq_do_exit = 0; - - void (*prev_handler_sigint)(int); - void (*prev_handler_sigterm)(int); - if (!non_blocking){ - prev_handler_sigint = std::signal(SIGINT, sig_handler); - prev_handler_sigterm = std::signal(SIGTERM, sig_handler); - } - - msgq_msg_t msg; - - MSGQMessage *r = NULL; - +Message *MSGQSubSocket::receive(bool non_blocking) { + msgq_msg_t msg{}; int rc = msgq_msg_recv(&msg, q); - // Hack to implement blocking read with a poller. Don't use this - while (!non_blocking && rc == 0 && msgq_do_exit == 0){ - msgq_pollitem_t items[1]; - items[0].q = q; - - int t = (timeout != -1) ? timeout : 100; - - int n = msgq_poll(items, 1, t); - rc = msgq_msg_recv(&msg, q); - - // The poll indicated a message was ready, but the receive failed. Try again - if (n == 1 && rc == 0){ - continue; - } - - if (timeout != -1){ - break; + if (rc == 0 && !non_blocking) { + sigset_t mask; + sigset_t old_mask; + sigemptyset(&mask); + sigaddset(&mask, SIGINT); + sigaddset(&mask, SIGTERM); + sigaddset(&mask, SIGUSR2); + + pthread_sigmask(SIG_BLOCK, &mask, &old_mask); + + int64_t timeout_ns = static_cast<int64_t>((timeout != -1) ? 
timeout : 1000) * 1000000; + auto start = steady_clock::now(); + + // Continue receiving messages until timeout or interruption by SIGINT or SIGTERM + while (rc == 0 && timeout_ns > 0) { + struct timespec ts { + timeout_ns / 1000000000, + timeout_ns % 1000000000, + }; + + int ret = sigtimedwait(&mask, nullptr, &ts); + if (ret == SIGINT || ret == SIGTERM) { + // Ensure signal handling is not missed + raise(ret); + break; + } else if (ret == -1 && errno == EAGAIN && timeout != -1) { + break; // Timed out + } + + rc = msgq_msg_recv(&msg, q); + + if (timeout != -1) { + timeout_ns -= duration_cast<nanoseconds>(steady_clock::now() - start).count(); + start = steady_clock::now(); // Update start time + } } + pthread_sigmask(SIG_SETMASK, &old_mask, nullptr); } - - if (!non_blocking){ - std::signal(SIGINT, prev_handler_sigint); - std::signal(SIGTERM, prev_handler_sigterm); - } - - errno = msgq_do_exit ? EINTR : 0; - - if (rc > 0){ - if (msgq_do_exit){ - msgq_msg_close(&msg); // Free unused message on exit - } else { - r = new MSGQMessage; - r->takeOwnership(msg.data, msg.size); - } + if (rc > 0) { + MSGQMessage *r = new MSGQMessage; + r->takeOwnership(msg.data, msg.size); + return r; } - - return (Message*)r; + return nullptr; } void MSGQSubSocket::setTimeout(int t){ diff --git a/msgq/ipc.pxd b/msgq/ipc.pxd index 2c7ac963e..ca33ea0f8 100644 --- a/msgq/ipc.pxd +++ b/msgq/ipc.pxd @@ -50,7 +50,7 @@ cdef extern from "msgq/ipc.h": @staticmethod SubSocket * create() int connect(Context *, string, string, bool) - Message * receive(bool) + Message * receive(bool) nogil void setTimeout(int) cdef cppclass PubSocket: diff --git a/msgq/ipc_pyx.pyx b/msgq/ipc_pyx.pyx index d8797f395..10b324bfc 100644 --- a/msgq/ipc_pyx.pyx +++ b/msgq/ipc_pyx.pyx @@ -196,7 +196,9 @@ cdef class SubSocket: self.socket.setTimeout(timeout) def receive(self, bool non_blocking=False): - msg = self.socket.receive(non_blocking) + cdef cppMessage *msg + with nogil: + msg = self.socket.receive(non_blocking) if msg == NULL: 
# If a blocking read returns no message check errno if SIGINT was caught in the C++ code diff --git a/msgq/tests/test_messaging.py b/msgq/tests/test_messaging.py index 40dfd7f00..01e40838d 100644 --- a/msgq/tests/test_messaging.py +++ b/msgq/tests/test_messaging.py @@ -1,5 +1,8 @@ import os +import pytest import random +import signal +import threading import time import string import msgq @@ -67,3 +70,25 @@ def test_receive_timeout(self): recvd = sub_sock.receive() assert (time.monotonic() - start_time) < 0.2 assert recvd is None + + def test_receive_interrupts_on_sigint(self): + sock = random_sock() + sub_sock = msgq.sub_sock(sock) + + # Send SIGINT after a short delay + pid = os.getpid() + def send_sigint(): + time.sleep(.5) + os.kill(pid, signal.SIGINT) + + # Start a thread to send SIGINT + thread = threading.Thread(target=send_sigint) + thread.start() + + with pytest.raises(KeyboardInterrupt): + start_time = time.monotonic() + recvd = sub_sock.receive() + assert (time.monotonic() - start_time) < 0.5 + assert recvd is None + + thread.join()