diff --git a/messaging/__init__.py b/messaging/__init__.py index c6a40f7199f5f5..66e69e330e85b1 100644 --- a/messaging/__init__.py +++ b/messaging/__init__.py @@ -25,14 +25,12 @@ def pub_sock(endpoint): def sub_sock(endpoint, poller=None, addr="127.0.0.1", conflate=False, timeout=None): sock = SubSocket() - sock.connect(context, endpoint, conflate) + addr = addr.encode('utf8') + sock.connect(context, endpoint, addr, conflate) if timeout is not None: sock.setTimeout(timeout) - if addr != "127.0.0.1": - raise NotImplementedError("Only localhost supported") - if poller is not None: poller.registerSocket(sock) return sock @@ -94,14 +92,15 @@ def recv_sock(sock, wait=False): return dat def recv_one(sock): - return log.Event.from_bytes(sock.receive()) + dat = sock.receive() + if dat is not None: + dat = log.Event.from_bytes(dat) + return dat def recv_one_or_none(sock): dat = sock.receive(non_blocking=True) - if dat is not None: - log.Event.from_bytes(dat) - + dat = log.Event.from_bytes(dat) return dat def recv_one_retry(sock): @@ -167,6 +166,9 @@ def update_msgs(self, cur_time, msgs): self.frame += 1 self.updated = dict.fromkeys(self.updated, False) for msg in msgs: + if msg is None: + continue + s = msg.which() self.updated[s] = True self.rcv_time[s] = cur_time diff --git a/messaging/impl_zmq.cc b/messaging/impl_zmq.cc index 05696dcb2aabde..77c35a6567597e 100644 --- a/messaging/impl_zmq.cc +++ b/messaging/impl_zmq.cc @@ -73,7 +73,7 @@ ZMQMessage::~ZMQMessage() { } -void ZMQSubSocket::connect(Context *context, std::string endpoint, bool conflate){ +void ZMQSubSocket::connect(Context *context, std::string endpoint, std::string address, bool conflate){ sock = zmq_socket(context->getRawContext(), ZMQ_SUB); assert(sock); @@ -87,7 +87,7 @@ void ZMQSubSocket::connect(Context *context, std::string endpoint, bool conflate int reconnect_ivl = 500; zmq_setsockopt(sock, ZMQ_RECONNECT_IVL_MAX, &reconnect_ivl, sizeof(reconnect_ivl)); - full_endpoint = "tcp://127.0.0.1:"; + full_endpoint = "tcp://" + address + ":"; full_endpoint += std::to_string(get_port(endpoint)); std::cout << "ZMQ SUB: " << full_endpoint << std::endl; @@ -109,15 +109,6 @@ Message * ZMQSubSocket::receive(bool non_blocking){ r = new ZMQMessage; r->init((char*)zmq_msg_data(&msg), zmq_msg_size(&msg)); } - // else { - // std::cout << "endpoint: " << full_endpoint << std::endl; - // std::cout << "Receive error: " << zmq_strerror(errno) << std::endl; - // std::cout << "non_blocking: " << non_blocking << std::endl; - // int timeout = 123; - // size_t sz = sizeof(int); - // zmq_getsockopt(sock, ZMQ_RCVTIMEO, &timeout, &sz); - // std::cout << "timeout: " << timeout << std::endl; - // } zmq_msg_close(&msg); return r; diff --git a/messaging/impl_zmq.hpp b/messaging/impl_zmq.hpp index cfe105f4422ea3..a40d623e3af849 100644 --- a/messaging/impl_zmq.hpp +++ b/messaging/impl_zmq.hpp @@ -32,7 +32,7 @@ class ZMQSubSocket : public SubSocket { void * sock; std::string full_endpoint; public: - void connect(Context *context, std::string endpoint, bool conflate=false); + void connect(Context *context, std::string endpoint, std::string address, bool conflate=false); void setTimeout(int timeout); void * getRawSocket() {return sock;} Message *receive(bool non_blocking=false); diff --git a/messaging/messaging.cc b/messaging/messaging.cc index 352b04a6d969d9..a1178a510872a6 100644 --- a/messaging/messaging.cc +++ b/messaging/messaging.cc @@ -13,7 +13,14 @@ SubSocket * SubSocket::create(){ SubSocket * SubSocket::create(Context * context, std::string endpoint){ SubSocket *s = SubSocket::create(); - s->connect(context, endpoint); + s->connect(context, endpoint, "127.0.0.1"); + + return s; +} + +SubSocket * SubSocket::create(Context * context, std::string endpoint, std::string address){ + SubSocket *s = SubSocket::create(); + s->connect(context, endpoint, address); return s; } diff --git a/messaging/messaging.hpp b/messaging/messaging.hpp index f379dc14343e62..95e34c0d67f4dc 100644 --- a/messaging/messaging.hpp +++ b/messaging/messaging.hpp @@ -23,12 +23,13 @@ class Message { class SubSocket { public: - virtual void connect(Context *context, std::string endpoint, bool conflate=false) = 0; + virtual void connect(Context *context, std::string endpoint, std::string address, bool conflate=false) = 0; virtual void setTimeout(int timeout) = 0; virtual Message *receive(bool non_blocking=false) = 0; virtual void * getRawSocket() = 0; static SubSocket * create(); static SubSocket * create(Context * context, std::string endpoint); + static SubSocket * create(Context * context, std::string endpoint, std::string address); virtual ~SubSocket(){}; }; diff --git a/messaging/messaging.pxd b/messaging/messaging.pxd index cb8ca07eab0f04..5e3da7241544cf 100644 --- a/messaging/messaging.pxd +++ b/messaging/messaging.pxd @@ -23,7 +23,7 @@ cdef extern from "messaging.hpp": cdef cppclass SubSocket: @staticmethod SubSocket * create() - void connect(Context *, string, bool) + void connect(Context *, string, string, bool) Message * receive(bool) void setTimeout(int) diff --git a/messaging/messaging_pyx.pyx b/messaging/messaging_pyx.pyx index e2ac79cb2b43f7..a4ac5172b04ae7 100644 --- a/messaging/messaging_pyx.pyx +++ b/messaging/messaging_pyx.pyx @@ -66,8 +66,8 @@ cdef class SubSocket: self.is_owner = False self.socket = ptr - def connect(self, Context context, string endpoint, bool conflate=False): - self.socket.connect(context.context, endpoint, conflate) + def connect(self, Context context, string endpoint, string address=b"127.0.0.1", bool conflate=False): + self.socket.connect(context.context, endpoint, address, conflate) def setTimeout(self, int timeout): self.socket.setTimeout(timeout)