diff --git a/environment-dev.yml b/environment-dev.yml index 131815a..9906a96 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -19,4 +19,4 @@ dependencies: # Test dependencies - doctest >= 2.4.6 - pytest - - jupyter_kernel_test>=0.5,<0.6 + - jupyter_kernel_test>=0.6,<0.7 diff --git a/include/xeus-zmq/xserver_zmq.hpp b/include/xeus-zmq/xserver_zmq.hpp index e8238ec..220b290 100644 --- a/include/xeus-zmq/xserver_zmq.hpp +++ b/include/xeus-zmq/xserver_zmq.hpp @@ -11,17 +11,18 @@ #define XEUS_SERVER_IMPL_HPP #include "zmq.hpp" +#include "zmq_addon.hpp" #include "xeus/xeus_context.hpp" #include "xeus/xkernel_configuration.hpp" #include "xeus/xserver.hpp" #include "xeus-zmq.hpp" -#include "xauthentication.hpp" #include "xthread.hpp" namespace xeus { + class xauthentication; class xpublisher; class xheartbeat; class xtrivial_messenger; @@ -39,6 +40,8 @@ namespace xeus ~xserver_zmq() override; + zmq::multipart_t serialize_iopub(xpub_message&& msg); + using xserver::notify_internal_listener; protected: @@ -67,6 +70,9 @@ namespace xeus zmq::socket_t m_publisher_controller; zmq::socket_t m_heartbeat_controller; + using authentication_ptr = std::unique_ptr; + authentication_ptr p_auth; + publisher_ptr p_publisher; heartbeat_ptr p_heartbeat; @@ -76,8 +82,6 @@ namespace xeus using trivial_messenger_ptr = std::unique_ptr; trivial_messenger_ptr p_messenger; - using authentication_ptr = std::unique_ptr; - authentication_ptr p_auth; nl::json::error_handler_t m_error_handler; bool m_request_stop; diff --git a/include/xeus-zmq/xserver_zmq_split.hpp b/include/xeus-zmq/xserver_zmq_split.hpp index 1e53713..c851184 100644 --- a/include/xeus-zmq/xserver_zmq_split.hpp +++ b/include/xeus-zmq/xserver_zmq_split.hpp @@ -18,11 +18,11 @@ #include "xeus/xkernel_configuration.hpp" #include "xeus-zmq.hpp" -#include "xauthentication.hpp" #include "xthread.hpp" namespace xeus { + class xauthentication; class xcontrol; class xheartbeat; class xpublisher; @@ -54,6 +54,8 @@ namespace xeus xmessage deserialize(zmq::multipart_t& wire_msg) const; + zmq::multipart_t serialize_iopub(xpub_message&& msg); + protected: xcontrol_messenger& get_control_messenger_impl() override; @@ -82,6 +84,9 @@ namespace xeus virtual void start_server(zmq::multipart_t& wire_msg) = 0; + using authentication_ptr = std::unique_ptr; + authentication_ptr p_auth; + controller_ptr p_controller; heartbeat_ptr p_heartbeat; publisher_ptr p_publisher; @@ -92,8 +97,6 @@ namespace xeus xthread m_iopub_thread; xthread m_shell_thread; - using authentication_ptr = std::unique_ptr; - authentication_ptr p_auth; nl::json::error_handler_t m_error_handler; std::atomic m_control_stopped; diff --git a/src/xpublisher.cpp b/src/xpublisher.cpp index 87d10be..7593e38 100644 --- a/src/xpublisher.cpp +++ b/src/xpublisher.cpp @@ -7,24 +7,27 @@ * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ -#include #include +#include -#include "zmq_addon.hpp" #include "xeus-zmq/xmiddleware.hpp" #include "xpublisher.hpp" namespace xeus { xpublisher::xpublisher(zmq::context_t& context, + std::function serialize_iopub_msg_cb, const std::string& transport, const std::string& ip, const std::string& port) - : m_publisher(context, zmq::socket_type::pub) + : m_publisher(context, zmq::socket_type::xpub) , m_listener(context, zmq::socket_type::sub) , m_controller(context, zmq::socket_type::rep) + , m_serialize_iopub_msg_cb(std::move(serialize_iopub_msg_cb)) { init_socket(m_publisher, transport, ip, port); + // Set xpub_verbose option to 1 to pass all subscription messages (not only unique ones). + m_publisher.set(zmq::sockopt::xpub_verbose, 1); m_listener.set(zmq::sockopt::subscribe, ""); m_listener.bind(get_publisher_end_point()); m_controller.set(zmq::sockopt::linger, get_socket_linger()); @@ -35,6 +38,16 @@ namespace xeus { } + xpub_message xpublisher::create_xpub_message(const std::string& topic) + { + xmessage_base_data data; + data.m_header = xeus::make_header("iopub_welcome", "", ""); + data.m_content["subscription"] = topic; + xpub_message p_msg("", std::move(data)); + + return p_msg; + } + std::string xpublisher::get_port() const { return get_socket_port(m_publisher); @@ -44,12 +57,13 @@ namespace xeus { zmq::pollitem_t items[] = { { m_listener, 0, ZMQ_POLLIN, 0 }, - { m_controller, 0, ZMQ_POLLIN, 0 } + { m_controller, 0, ZMQ_POLLIN, 0 }, + { m_publisher, 0, ZMQ_POLLIN, 0 } }; while (true) { - zmq::poll(&items[0], 2, std::chrono::milliseconds(-1)); + zmq::poll(&items[0], 3, std::chrono::milliseconds(-1)); if (items[0].revents & ZMQ_POLLIN) { @@ -66,6 +80,43 @@ namespace xeus wire_msg.send(m_controller); break; } + + if (items[2].revents & ZMQ_POLLIN) + { + // Received event: Single frame + // Either `1{subscription-topic}` for subscription + // or `0{subscription-topic}` for unsubscription + zmq::multipart_t wire_msg; + wire_msg.recv(m_publisher); + + // Received event should be a single frame + if (wire_msg.size() != 1) + { + throw std::runtime_error("ERROR: Received message on XPUB is not a single frame"); + } + + zmq::message_t frame = wire_msg.pop(); + + // Event is one byte 0 = unsub or 1 = sub, followed by topic + uint8_t *event = (uint8_t *)frame.data(); + // If subscription (unsubscription is ignored) + if (event[0] == 1) + { + std::string topic((char *)(event + 1), frame.size() - 1); + if (m_serialize_iopub_msg_cb) + { + // Construct the `iopub_welcome` message + xpub_message p_msg = create_xpub_message(topic); + zmq::multipart_t iopub_welcome_wire_msg = m_serialize_iopub_msg_cb(std::move(p_msg)); + // Send the `iopub_welcome` message + iopub_welcome_wire_msg.send(m_publisher); + } + else + { + throw std::runtime_error("ERROR: IOPUB serialization callback not set"); + } + } + } } } } diff --git a/src/xpublisher.hpp b/src/xpublisher.hpp index ba59d30..7a54203 100644 --- a/src/xpublisher.hpp +++ b/src/xpublisher.hpp @@ -10,9 +10,13 @@ #ifndef XEUS_PUBLISHER_HPP #define XEUS_PUBLISHER_HPP +#include #include #include "zmq.hpp" +#include "zmq_addon.hpp" + +#include "xeus/xmessage.hpp" namespace xeus { @@ -21,6 +25,7 @@ namespace xeus public: xpublisher(zmq::context_t& context, + std::function serialize_iopub_msg_cb, const std::string& transport, const std::string& ip, const std::string& port); @@ -33,9 +38,13 @@ namespace xeus private: + xpub_message create_xpub_message(const std::string& topic); + zmq::socket_t m_publisher; zmq::socket_t m_listener; zmq::socket_t m_controller; + + std::function m_serialize_iopub_msg_cb; }; } diff --git a/src/xserver_zmq.cpp b/src/xserver_zmq.cpp index 9c4e8b6..0521ba2 100644 --- a/src/xserver_zmq.cpp +++ b/src/xserver_zmq.cpp @@ -10,8 +10,8 @@ #include #include -#include "zmq_addon.hpp" #include "xeus/xguid.hpp" +#include "xeus-zmq/xauthentication.hpp" #include "xeus-zmq/xserver_zmq.hpp" #include "xeus-zmq/xmiddleware.hpp" #include "xeus-zmq/xzmq_serializer.hpp" @@ -31,12 +31,14 @@ namespace xeus , m_publisher_pub(context, zmq::socket_type::pub) , m_publisher_controller(context, zmq::socket_type::req) , m_heartbeat_controller(context, zmq::socket_type::req) - , p_publisher(new xpublisher(context, config.m_transport, config.m_ip, config.m_iopub_port)) + , p_auth(make_xauthentication(config.m_signature_scheme, config.m_key)) + , p_publisher(new xpublisher(context, + std::bind(&xserver_zmq::serialize_iopub, this, std::placeholders::_1), + config.m_transport, config.m_ip, config.m_iopub_port)) , p_heartbeat(new xheartbeat(context, config.m_transport, config.m_ip, config.m_hb_port)) , m_iopub_thread() , m_hb_thread() , p_messenger(new xtrivial_messenger(this)) - , p_auth(make_xauthentication(config.m_signature_scheme, config.m_key)) , m_error_handler(eh) , m_request_stop(false) { @@ -207,6 +209,11 @@ namespace xeus (void)m_heartbeat_controller.recv(response); } + zmq::multipart_t xserver_zmq::serialize_iopub(xpub_message&& msg) + { + return xzmq_serializer::serialize_iopub(std::move(msg), *p_auth, m_error_handler); + } + std::unique_ptr make_xserver_zmq(xcontext& context, const xconfiguration& config, nl::json::error_handler_t eh) diff --git a/src/xserver_zmq_split.cpp b/src/xserver_zmq_split.cpp index 824a6d2..4267da4 100644 --- a/src/xserver_zmq_split.cpp +++ b/src/xserver_zmq_split.cpp @@ -12,6 +12,7 @@ #include "zmq_addon.hpp" #include "xeus/xguid.hpp" +#include "xeus-zmq/xauthentication.hpp" #include "xeus-zmq/xserver_zmq_split.hpp" #include "xeus-zmq/xmiddleware.hpp" #include "xeus-zmq/xzmq_serializer.hpp" @@ -26,15 +27,17 @@ namespace xeus xserver_zmq_split::xserver_zmq_split(zmq::context_t& context, const xconfiguration& config, nl::json::error_handler_t eh) - : p_controller(new xcontrol(context, config.m_transport, config.m_ip ,config.m_control_port, this)) + : p_auth(make_xauthentication(config.m_signature_scheme, config.m_key)) + , p_controller(new xcontrol(context, config.m_transport, config.m_ip ,config.m_control_port, this)) , p_heartbeat(new xheartbeat(context, config.m_transport, config.m_ip, config.m_hb_port)) - , p_publisher(new xpublisher(context, config.m_transport, config.m_ip, config.m_iopub_port)) + , p_publisher(new xpublisher(context, + std::bind(&xserver_zmq_split::serialize_iopub, this, std::placeholders::_1), + config.m_transport, config.m_ip, config.m_iopub_port)) , p_shell(new xshell(context, config.m_transport, config.m_ip ,config.m_shell_port, config.m_stdin_port, this)) , m_control_thread() , m_hb_thread() , m_iopub_thread() , m_shell_thread() - , p_auth(make_xauthentication(config.m_signature_scheme, config.m_key)) , m_error_handler(eh) , m_control_stopped(false) { @@ -157,5 +160,10 @@ namespace xeus { return m_control_stopped; } + + zmq::multipart_t xserver_zmq_split::serialize_iopub(xpub_message&& msg) + { + return xzmq_serializer::serialize_iopub(std::move(msg), *p_auth, m_error_handler); + } } diff --git a/test/test_kernel.py b/test/test_kernel.py index f681a0e..d025c76 100644 --- a/test/test_kernel.py +++ b/test/test_kernel.py @@ -31,6 +31,10 @@ def test_xeus_stderr(self): self.assertEqual(output_msgs[0]['content']['name'], 'stderr') self.assertEqual(output_msgs[0]['content']['text'], 'error') +class XeusIopubWelcomeTests(jupyter_kernel_test.IopubWelcomeTests): + + kernel_name = "test_kernel" + support_iopub_welcome = True if __name__ == '__main__': unittest.main() diff --git a/test/test_kernel_control.py b/test/test_kernel_control.py index c1607c3..74294e1 100644 --- a/test/test_kernel_control.py +++ b/test/test_kernel_control.py @@ -31,6 +31,10 @@ def test_xeus_stderr(self): self.assertEqual(output_msgs[0]['content']['name'], 'stderr') self.assertEqual(output_msgs[0]['content']['text'], 'error') +class XeusIopubWelcomeTests(jupyter_kernel_test.IopubWelcomeTests): + + kernel_name = "test_kernel_control" + support_iopub_welcome = True if __name__ == '__main__': unittest.main() diff --git a/test/test_kernel_shell.py b/test/test_kernel_shell.py index 1b64afc..5b9aa25 100644 --- a/test/test_kernel_shell.py +++ b/test/test_kernel_shell.py @@ -31,6 +31,10 @@ def test_xeus_stderr(self): self.assertEqual(output_msgs[0]['content']['name'], 'stderr') self.assertEqual(output_msgs[0]['content']['text'], 'error') +class XeusIopubWelcomeTests(jupyter_kernel_test.IopubWelcomeTests): + + kernel_name = "test_kernel_shell" + support_iopub_welcome = True if __name__ == '__main__': unittest.main()