Skip to content

Commit

Permalink
fix: crash when stopping WebTunnelAgent under load; bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
obiltschnig committed Nov 16, 2024
1 parent 4328f15 commit 0260e6e
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 105 deletions.
3 changes: 2 additions & 1 deletion WebTunnel/WebTunnelAgent/src/WebTunnelAgent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ class WebTunnelAgent: public Poco::Util::ServerApplication
std::string password = config().getString("webtunnel.password"s, ""s);
if (!username.empty())
{
logger().debug("Authenticating as %s."s, username);
Poco::Net::HTTPBasicCredentials creds(username, password);
creds.authenticate(request);
}
Expand Down Expand Up @@ -977,7 +978,7 @@ class WebTunnelAgent: public Poco::Util::ServerApplication

const std::string WebTunnelAgent::SEC_WEBSOCKET_PROTOCOL("Sec-WebSocket-Protocol");
const std::string WebTunnelAgent::WEBTUNNEL_PROTOCOL("com.appinf.webtunnel.server/1.0");
const std::string WebTunnelAgent::WEBTUNNEL_AGENT("WebTunnelAgent/1.17.2");
const std::string WebTunnelAgent::WEBTUNNEL_AGENT("WebTunnelAgent/2.0.0");


POCO_SERVER_MAIN(WebTunnelAgent)
91 changes: 91 additions & 0 deletions WebTunnel/include/Poco/WebTunnel/SocketDispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,31 @@ namespace Poco {
namespace WebTunnel {


class AutoSetEvent
{
public:
AutoSetEvent(Poco::Event& event):
_event(event)
{
}

~AutoSetEvent()
{
try
{
_event.set();
}
catch (...)
{
poco_unexpected();
}
}

private:
Poco::Event& _event;
};


class WebTunnel_API SocketDispatcher: public Poco::Runnable
/// SocketDispatcher implements a multi-threaded variant of the
/// Reactor pattern, optimized for forwarding data from one
Expand Down Expand Up @@ -102,6 +127,72 @@ class WebTunnel_API SocketDispatcher: public Poco::Runnable
/// Shuts down the sending direction of the socket, but only after
/// all pending sends has been sent.

class WebTunnel_API TaskNotification: public Poco::Notification
{
public:
using Ptr = Poco::AutoPtr<TaskNotification>;

enum
{
TASK_WAIT_TIMEOUT = 30000
};

TaskNotification(SocketDispatcher& dispatcher):
_dispatcher(dispatcher)
{
}

~TaskNotification() = default;

void wait()
{
_done.wait(TASK_WAIT_TIMEOUT);
}

virtual void execute() = 0;

protected:
SocketDispatcher& _dispatcher;
Poco::Event _done;
};

template <class Fn>
class FunctorTaskNotification: public TaskNotification
{
public:
using Ptr = Poco::AutoPtr<FunctorTaskNotification>;

FunctorTaskNotification(SocketDispatcher& dispatcher, Fn&& fn):
TaskNotification(dispatcher),
_fn(std::move(fn))
{
}

void execute()
{
AutoSetEvent ase(_done);

_fn(_dispatcher);
}

private:
Fn _fn;
};

template <class Fn>
void queueTask(Fn&& fn)
/// Enqueues a task for execution in the dispatcher thread.
/// The task is given as a lambda expression or functor.
{
typename FunctorTaskNotification<Fn>::Ptr pTask = new FunctorTaskNotification<Fn>(*this, std::move(fn));
_queue.enqueueNotification(pTask);
_pollSet.wakeUp();
if (!inDispatcherThread())
{
pTask->wait();
}
}

protected:
struct PendingSend
{
Expand Down
62 changes: 30 additions & 32 deletions WebTunnel/src/RemotePortForwarder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ RemotePortForwarder::~RemotePortForwarder()

void RemotePortForwarder::stop()
{
if (_pWebSocket)
{
closeWebSocket(RPF_CLOSE_GRACEFUL, true);
}
_dispatcher.queueTask(
[pSelf=this](SocketDispatcher& dispatcher)
{
pSelf->closeWebSocket(RPF_CLOSE_GRACEFUL, true);
}
);
}


Expand Down Expand Up @@ -530,45 +532,41 @@ void RemotePortForwarder::sendResponse(Poco::UInt16 channel, Poco::UInt8 opcode,

void RemotePortForwarder::closeWebSocket(CloseReason reason, bool active)
{
{
if (!_pWebSocket || !_pWebSocket->impl()->initialized()) return;
if (_webSocketFlags & CF_CLOSED_LOCAL) return;

if (_webSocketFlags & CF_CLOSED_LOCAL) return;

if (_logger.debug())
{
_logger.debug("Closing WebSocket, reason: %d, active: %b"s, static_cast<int>(reason), active);
}
try
if (_logger.debug())
{
_logger.debug("Closing WebSocket, reason: %d, active: %b"s, static_cast<int>(reason), active);
}
try
{
if (reason == RPF_CLOSE_GRACEFUL)
{
if (reason == RPF_CLOSE_GRACEFUL)
try
{
try
{
if (active)
{
char buffer[2];
Poco::MemoryOutputStream ostr(buffer, sizeof(buffer));
Poco::BinaryWriter writer(ostr, Poco::BinaryWriter::NETWORK_BYTE_ORDER);
writer << static_cast<Poco::UInt16>(Poco::Net::WebSocket::WS_NORMAL_CLOSE);
_dispatcher.sendBytes(*_pWebSocket, buffer, sizeof(buffer), Poco::Net::WebSocket::FRAME_FLAG_FIN | Poco::Net::WebSocket::FRAME_OP_CLOSE);
}
}
catch (Poco::Exception&)
if (active)
{
char buffer[2];
Poco::MemoryOutputStream ostr(buffer, sizeof(buffer));
Poco::BinaryWriter writer(ostr, Poco::BinaryWriter::NETWORK_BYTE_ORDER);
writer << static_cast<Poco::UInt16>(Poco::Net::WebSocket::WS_NORMAL_CLOSE);
_dispatcher.sendBytes(*_pWebSocket, buffer, sizeof(buffer), Poco::Net::WebSocket::FRAME_FLAG_FIN | Poco::Net::WebSocket::FRAME_OP_CLOSE);
}
}
for (ChannelMap::iterator it = _channelMap.begin(); it != _channelMap.end(); ++it)
catch (Poco::Exception&)
{
_dispatcher.removeSocket(it->second.socket);
}
_channelMap.clear();
_dispatcher.shutdownSend(*_pWebSocket);
}
catch (Poco::Exception& exc)
for (ChannelMap::iterator it = _channelMap.begin(); it != _channelMap.end(); ++it)
{
_logger.log(exc);
_dispatcher.removeSocket(it->second.socket);
}
_channelMap.clear();
_dispatcher.shutdownSend(*_pWebSocket);
}
catch (Poco::Exception& exc)
{
_logger.log(exc);
}

_dispatcher.updateSocket(*_pWebSocket, Poco::Net::PollSet::POLL_READ, _closeTimeout);
Expand Down
86 changes: 14 additions & 72 deletions WebTunnel/src/SocketDispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,67 +26,10 @@ namespace Poco {
namespace WebTunnel {


class AutoSetEvent
class AddSocketNotification: public SocketDispatcher::TaskNotification
{
public:
AutoSetEvent(Poco::Event& event):
_event(event)
{
}

~AutoSetEvent()
{
try
{
_event.set();
}
catch (...)
{
poco_unexpected();
}
}

private:
Poco::Event& _event;
};


class TaskNotification: public Poco::Notification
{
public:
typedef Poco::AutoPtr<TaskNotification> Ptr;

enum
{
TASK_WAIT_TIMEOUT = 30000
};

TaskNotification(SocketDispatcher& dispatcher):
_dispatcher(dispatcher)
{
}

~TaskNotification()
{
}

void wait()
{
_done.wait(TASK_WAIT_TIMEOUT);
}

virtual void execute() = 0;

protected:
SocketDispatcher& _dispatcher;
Poco::Event _done;
};


class AddSocketNotification: public TaskNotification
{
public:
typedef Poco::AutoPtr<AddSocketNotification> Ptr;
using Ptr = Poco::AutoPtr<AddSocketNotification>;

AddSocketNotification(SocketDispatcher& dispatcher, const Poco::Net::StreamSocket& socket, const SocketDispatcher::SocketHandler::Ptr& pHandler, int mode, Poco::Timespan timeout):
TaskNotification(dispatcher),
Expand All @@ -112,10 +55,10 @@ class AddSocketNotification: public TaskNotification
};


class UpdateSocketNotification: public TaskNotification
class UpdateSocketNotification: public SocketDispatcher::TaskNotification
{
public:
typedef Poco::AutoPtr<UpdateSocketNotification> Ptr;
using Ptr = Poco::AutoPtr<UpdateSocketNotification>;

UpdateSocketNotification(SocketDispatcher& dispatcher, const Poco::Net::StreamSocket& socket, int mode, Poco::Timespan timeout):
TaskNotification(dispatcher),
Expand All @@ -139,10 +82,10 @@ class UpdateSocketNotification: public TaskNotification
};


class RemoveSocketNotification: public TaskNotification
class RemoveSocketNotification: public SocketDispatcher::TaskNotification
{
public:
typedef Poco::AutoPtr<RemoveSocketNotification> Ptr;
using Ptr = Poco::AutoPtr<RemoveSocketNotification>;

RemoveSocketNotification(SocketDispatcher& dispatcher, const Poco::Net::StreamSocket& socket):
TaskNotification(dispatcher),
Expand All @@ -162,10 +105,10 @@ class RemoveSocketNotification: public TaskNotification
};


class CloseSocketNotification: public TaskNotification
class CloseSocketNotification: public SocketDispatcher::TaskNotification
{
public:
typedef Poco::AutoPtr<CloseSocketNotification> Ptr;
using Ptr = Poco::AutoPtr<CloseSocketNotification>;

CloseSocketNotification(SocketDispatcher& dispatcher, const Poco::Net::StreamSocket& socket):
TaskNotification(dispatcher),
Expand All @@ -185,10 +128,10 @@ class CloseSocketNotification: public TaskNotification
};


class ResetNotification: public TaskNotification
class ResetNotification: public SocketDispatcher::TaskNotification
{
public:
typedef Poco::AutoPtr<ResetNotification> Ptr;
using Ptr = Poco::AutoPtr<ResetNotification>;

ResetNotification(SocketDispatcher& dispatcher):
TaskNotification(dispatcher)
Expand All @@ -204,10 +147,10 @@ class ResetNotification: public TaskNotification
};


class SendBytesNotification: public TaskNotification
class SendBytesNotification: public SocketDispatcher::TaskNotification
{
public:
typedef Poco::AutoPtr<SendBytesNotification> Ptr;
using Ptr = Poco::AutoPtr<SendBytesNotification>;

SendBytesNotification(SocketDispatcher& dispatcher, const Poco::Net::StreamSocket& socket, const void* pBuffer, std::size_t length, int options):
TaskNotification(dispatcher),
Expand All @@ -231,10 +174,10 @@ class SendBytesNotification: public TaskNotification
};


class ShutdownSendNotification: public TaskNotification
class ShutdownSendNotification: public SocketDispatcher::TaskNotification
{
public:
typedef Poco::AutoPtr<ShutdownSendNotification> Ptr;
using Ptr = Poco::AutoPtr<ShutdownSendNotification>;

ShutdownSendNotification(SocketDispatcher& dispatcher, const Poco::Net::StreamSocket& socket):
TaskNotification(dispatcher),
Expand Down Expand Up @@ -303,7 +246,6 @@ void SocketDispatcher::reset()

void SocketDispatcher::addSocket(const Poco::Net::StreamSocket& socket, SocketHandler::Ptr pHandler, int mode, Poco::Timespan timeout)
{
poco_assert (!socket.getBlocking());
AddSocketNotification::Ptr pNf = new AddSocketNotification(*this, socket, pHandler, mode, timeout);
_queue.enqueueNotification(pNf);
_pollSet.wakeUp();
Expand Down

0 comments on commit 0260e6e

Please sign in to comment.