Skip to content

Commit

Permalink
fix(RemotePortForwarder): wait for WebSocket close handshake to compl…
Browse files Browse the repository at this point in the history
…ete when stopping
  • Loading branch information
obiltschnig committed Nov 20, 2024
1 parent 9aebdc4 commit 210a720
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 7 deletions.
5 changes: 5 additions & 0 deletions WebTunnel/include/Poco/WebTunnel/SocketDispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ class WebTunnel_API SocketDispatcher: public Poco::Runnable
void closeSocket(const Poco::Net::StreamSocket& socket);
/// Closes and removes a socket and its associated handler from the SocketDispatcher.

bool hasSocket(const Poco::Net::StreamSocket& socket);
/// Returns true if the socket is active in the SocketDispatcher.

void stop();
/// Stops the SocketDispatcher and removes all sockets.

Expand Down Expand Up @@ -256,6 +259,7 @@ class WebTunnel_API SocketDispatcher: public Poco::Runnable
void updateSocketImpl(const Poco::Net::StreamSocket& socket, int mode, Poco::Timespan timeout);
void removeSocketImpl(const Poco::Net::StreamSocket& socket);
void closeSocketImpl(Poco::Net::StreamSocket& socket);
bool hasSocketImpl(const Poco::Net::StreamSocket& socket) const;
void resetImpl();
void sendBytesImpl(Poco::Net::StreamSocket& socket, Poco::Buffer<char>&& buffer, int flags);
void shutdownSendImpl(Poco::Net::StreamSocket& socket);
Expand All @@ -278,6 +282,7 @@ class WebTunnel_API SocketDispatcher: public Poco::Runnable
friend class UpdateSocketNotification;
friend class RemoveSocketNotification;
friend class CloseSocketNotification;
friend class HasSocketNotification;
friend class ResetNotification;
friend class SendBytesNotification;
friend class ShutdownSendNotification;
Expand Down
10 changes: 10 additions & 0 deletions WebTunnel/src/RemotePortForwarder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,22 @@ RemotePortForwarder::~RemotePortForwarder()

void RemotePortForwarder::stop()
{
const Poco::Timestamp::TimeDiff STOP_TIMEOUT = 500000;

_dispatcher.queueTask(
[pSelf=this](SocketDispatcher& dispatcher)
{
pSelf->closeWebSocket(RPF_CLOSE_GRACEFUL, true);
}
);

Poco::Timestamp closeTime;
while (_dispatcher.hasSocket(*_pWebSocket) && !closeTime.isElapsed(STOP_TIMEOUT))
{
_logger.debug("Waiting for WebSocket closing handshake to complete..."s);
Poco::Thread::sleep(20);
}
_dispatcher.removeSocket(*_pWebSocket);
}


Expand Down
70 changes: 63 additions & 7 deletions WebTunnel/src/SocketDispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,35 @@ class CloseSocketNotification: public SocketDispatcher::TaskNotification
};


class HasSocketNotification: public SocketDispatcher::TaskNotification
{
public:
using Ptr = Poco::AutoPtr<HasSocketNotification>;

HasSocketNotification(SocketDispatcher& dispatcher, const Poco::Net::StreamSocket& socket):
TaskNotification(dispatcher),
_socket(socket)
{
}

void execute()
{
AutoSetEvent ase(_done);

_result = _dispatcher.hasSocketImpl(_socket);
}

bool result() const
{
return _result;
}

private:
Poco::Net::StreamSocket _socket;
bool _result = false;
};


class ResetNotification: public SocketDispatcher::TaskNotification
{
public:
Expand Down Expand Up @@ -296,6 +325,23 @@ void SocketDispatcher::closeSocket(const Poco::Net::StreamSocket& socket)
}


bool SocketDispatcher::hasSocket(const Poco::Net::StreamSocket& socket)
{
if (inDispatcherThread())
{
return hasSocketImpl(socket);
}
else
{
HasSocketNotification::Ptr pNf = new HasSocketNotification(*this, socket);
_queue.enqueueNotification(pNf);
_pollSet.wakeUp();
pNf->wait();
return pNf->result();
}
}


void SocketDispatcher::sendBytes(Poco::Net::StreamSocket& socket, const void* buffer, std::size_t length, int options)
{
if (inDispatcherThread())
Expand Down Expand Up @@ -553,14 +599,18 @@ void SocketDispatcher::updateSocketImpl(const Poco::Net::StreamSocket& socket, i

void SocketDispatcher::removeSocketImpl(const Poco::Net::StreamSocket& socket)
{
_logger.trace("Removing socket %?d..."s, socket.impl()->sockfd());
_socketMap.erase(socket);
try
{
_pollSet.remove(socket);
}
catch (Poco::IOException&)
auto it = _socketMap.find(socket);
if (it != _socketMap.end())
{
_logger.trace("Removing socket %?d..."s, socket.impl()->sockfd());
_socketMap.erase(it);
try
{
_pollSet.remove(socket);
}
catch (Poco::IOException&)
{
}
}
}

Expand All @@ -580,6 +630,12 @@ void SocketDispatcher::closeSocketImpl(Poco::Net::StreamSocket& socket)
}


bool SocketDispatcher::hasSocketImpl(const Poco::Net::StreamSocket& socket) const
{
return _socketMap.find(socket) != _socketMap.end();
}


void SocketDispatcher::resetImpl()
{
_socketMap.clear();
Expand Down

0 comments on commit 210a720

Please sign in to comment.