diff --git a/WebTunnel/include/Poco/WebTunnel/SocketDispatcher.h b/WebTunnel/include/Poco/WebTunnel/SocketDispatcher.h index cbfcacd8..3c717d0d 100644 --- a/WebTunnel/include/Poco/WebTunnel/SocketDispatcher.h +++ b/WebTunnel/include/Poco/WebTunnel/SocketDispatcher.h @@ -108,6 +108,9 @@ class WebTunnel_API SocketDispatcher: public Poco::Runnable void closeSocket(const Poco::Net::StreamSocket& socket); /// Closes and removes a socket and its associated handler from the SocketDispatcher. + bool hasSocket(const Poco::Net::StreamSocket& socket); + /// Returns true if the socket is active in the SocketDispatcher. + void stop(); /// Stops the SocketDispatcher and removes all sockets. @@ -256,6 +259,7 @@ class WebTunnel_API SocketDispatcher: public Poco::Runnable void updateSocketImpl(const Poco::Net::StreamSocket& socket, int mode, Poco::Timespan timeout); void removeSocketImpl(const Poco::Net::StreamSocket& socket); void closeSocketImpl(Poco::Net::StreamSocket& socket); + bool hasSocketImpl(const Poco::Net::StreamSocket& socket) const; void resetImpl(); void sendBytesImpl(Poco::Net::StreamSocket& socket, Poco::Buffer&& buffer, int flags); void shutdownSendImpl(Poco::Net::StreamSocket& socket); @@ -278,6 +282,7 @@ class WebTunnel_API SocketDispatcher: public Poco::Runnable friend class UpdateSocketNotification; friend class RemoveSocketNotification; friend class CloseSocketNotification; + friend class HasSocketNotification; friend class ResetNotification; friend class SendBytesNotification; friend class ShutdownSendNotification; diff --git a/WebTunnel/src/RemotePortForwarder.cpp b/WebTunnel/src/RemotePortForwarder.cpp index f31fadbd..8d5be68f 100644 --- a/WebTunnel/src/RemotePortForwarder.cpp +++ b/WebTunnel/src/RemotePortForwarder.cpp @@ -89,12 +89,22 @@ RemotePortForwarder::~RemotePortForwarder() void RemotePortForwarder::stop() { + const Poco::Timestamp::TimeDiff STOP_TIMEOUT = 500000; + _dispatcher.queueTask( [pSelf=this](SocketDispatcher& dispatcher) { pSelf->closeWebSocket(RPF_CLOSE_GRACEFUL, true); } ); + + Poco::Timestamp closeTime; + while (_dispatcher.hasSocket(*_pWebSocket) && !closeTime.isElapsed(STOP_TIMEOUT)) + { + _logger.debug("Waiting for WebSocket closing handshake to complete..."s); + Poco::Thread::sleep(20); + } + _dispatcher.removeSocket(*_pWebSocket); } diff --git a/WebTunnel/src/SocketDispatcher.cpp b/WebTunnel/src/SocketDispatcher.cpp index 3d921df4..b01b57d2 100644 --- a/WebTunnel/src/SocketDispatcher.cpp +++ b/WebTunnel/src/SocketDispatcher.cpp @@ -128,6 +128,35 @@ class CloseSocketNotification: public SocketDispatcher::TaskNotification }; +class HasSocketNotification: public SocketDispatcher::TaskNotification +{ +public: + using Ptr = Poco::AutoPtr; + + HasSocketNotification(SocketDispatcher& dispatcher, const Poco::Net::StreamSocket& socket): + TaskNotification(dispatcher), + _socket(socket) + { + } + + void execute() + { + AutoSetEvent ase(_done); + + _result = _dispatcher.hasSocketImpl(_socket); + } + + bool result() const + { + return _result; + } + +private: + Poco::Net::StreamSocket _socket; + bool _result = false; +}; + + class ResetNotification: public SocketDispatcher::TaskNotification { public: @@ -296,6 +325,23 @@ void SocketDispatcher::closeSocket(const Poco::Net::StreamSocket& socket) } +bool SocketDispatcher::hasSocket(const Poco::Net::StreamSocket& socket) +{ + if (inDispatcherThread()) + { + return hasSocketImpl(socket); + } + else + { + HasSocketNotification::Ptr pNf = new HasSocketNotification(*this, socket); + _queue.enqueueNotification(pNf); + _pollSet.wakeUp(); + pNf->wait(); + return pNf->result(); + } +} + + void SocketDispatcher::sendBytes(Poco::Net::StreamSocket& socket, const void* buffer, std::size_t length, int options) { if (inDispatcherThread()) @@ -553,14 +599,18 @@ void SocketDispatcher::updateSocketImpl(const Poco::Net::StreamSocket& socket, i void SocketDispatcher::removeSocketImpl(const Poco::Net::StreamSocket& socket) { - _logger.trace("Removing socket %?d..."s, socket.impl()->sockfd()); - _socketMap.erase(socket); - try - { - _pollSet.remove(socket); - } - catch (Poco::IOException&) + auto it = _socketMap.find(socket); + if (it != _socketMap.end()) { + _logger.trace("Removing socket %?d..."s, socket.impl()->sockfd()); + _socketMap.erase(it); + try + { + _pollSet.remove(socket); + } + catch (Poco::IOException&) + { + } } } @@ -580,6 +630,12 @@ void SocketDispatcher::closeSocketImpl(Poco::Net::StreamSocket& socket) } +bool SocketDispatcher::hasSocketImpl(const Poco::Net::StreamSocket& socket) const +{ + return _socketMap.find(socket) != _socketMap.end(); +} + + void SocketDispatcher::resetImpl() { _socketMap.clear();