Skip to content

Commit

Permalink
[websockets] Handle clean disconnection when disconnect request occur…
Browse files Browse the repository at this point in the history
…s before the first connect attempt
  • Loading branch information
c-jimenez committed Jul 4, 2024
1 parent bdb36b4 commit eabee12
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 115 deletions.
2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ BIN_DIR:=$(ROOT_DIR)/bin
PARALLEL_BUILD?=-j 4

# Build type can be either Debug or Release
BUILD_TYPE?=Release
BUILD_TYPE?=Debug

# Logger configuration
EXTERNAL_LOGGER?=OFF
Expand Down
250 changes: 136 additions & 114 deletions src/websockets/libwebsockets/LibWebsocketClientPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,17 @@ int LibWebsocketClientPool::eventCallback(struct lws* wsi, enum lws_callback_rea
{
lws_set_timeout(waiting_client->m_wsi, static_cast<pending_timeout>(1), LWS_TO_KILL_SYNC);
}
lws_vhost_destroy(waiting_client->m_vhost);
waiting_client->m_vhost = nullptr;
waiting_client->m_connected = false;
waiting_client->m_disconnect_process_done = true;
if (waiting_client->m_vhost)
{
lws_vhost_destroy(waiting_client->m_vhost);
}

std::lock_guard<std::mutex> lock(waiting_client->m_disconnect_mutex);
waiting_client->m_protocol.clear();
waiting_client->m_vhost = nullptr;
waiting_client->m_connected = false;
waiting_client->m_disconnect_process_in_progress = false;
waiting_client->m_disconnect_process_done = true;
waiting_client->m_disconnect_cond_var.notify_all();
}
}
Expand Down Expand Up @@ -242,6 +249,7 @@ LibWebsocketClientPool::Client::Client(LibWebsocketClientPool& pool)
m_connected(false),
m_disconnect_cond_var(),
m_disconnect_mutex(),
m_disconnect_process_in_progress(false),
m_disconnect_process_done(false),
m_context(m_pool.m_context),
m_vhost(nullptr),
Expand Down Expand Up @@ -282,6 +290,8 @@ bool LibWebsocketClientPool::Client::connect(const std::string& url,
{
bool ret = false;

std::lock_guard<std::mutex> lock(m_disconnect_mutex);

// Check if thread is alive and if a listener has been registered
if (!m_vhost && m_listener)
{
Expand Down Expand Up @@ -334,19 +344,21 @@ bool LibWebsocketClientPool::Client::disconnect()
{
bool ret = false;

std::unique_lock<std::mutex> lock(m_disconnect_mutex);

// Check if connected
if (m_vhost)
if (!m_disconnect_process_in_progress && !m_protocol.empty())
{
// Schedule disconnection
m_retry_interval = 0;
m_disconnect_process_done = false;
m_retry_interval = 0;
m_disconnect_process_in_progress = true;
m_disconnect_process_done = false;
m_pool.m_waiting_disconnect_queue.push(this);
lws_cancel_service(m_context);

// Wait actual disconnection
if (std::this_thread::get_id() != m_pool.m_thread->get_id())
{
std::unique_lock<std::mutex> lock(m_disconnect_mutex);
m_disconnect_cond_var.wait(lock, [&] { return m_disconnect_process_done; });
}
}
Expand Down Expand Up @@ -374,6 +386,8 @@ bool LibWebsocketClientPool::Client::send(const void* data, size_t size)
{
bool ret = false;

std::lock_guard<std::mutex> lock(m_disconnect_mutex);

// Check if connected
if (m_connected)
{
Expand Down Expand Up @@ -434,136 +448,144 @@ void LibWebsocketClientPool::Client::connectCallback(struct lws_sorted_usec_list
ScheduleData* schedule_data = lws_container_of(sul, ScheduleData, sched_list);
if (schedule_data)
{
// Check if vhost has been created
Client* client = schedule_data->client;
if (!client->m_vhost)
Client* client = schedule_data->client;
std::lock_guard<std::mutex> lock(client->m_disconnect_mutex);

// Check if a disconnect process is in progress
if (!client->m_disconnect_process_in_progress)
{
// Define callback
struct lws_protocols protocols[] = {
{"LibWebsocketClientPoolClient", &LibWebsocketClientPool::Client::eventCallback, 0, 0, 0, client, 0},
LWS_PROTOCOL_LIST_TERM};

// Fill vhost information
struct lws_context_creation_info vhost_info;
memset(&vhost_info, 0, sizeof(vhost_info));
vhost_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
vhost_info.port = CONTEXT_PORT_NO_LISTEN;
vhost_info.timeout_secs = client->m_connect_timeout;
vhost_info.connect_timeout_secs = client->m_connect_timeout;
vhost_info.protocols = protocols;
vhost_info.log_cx = &pool->m_logs_context;
if (client->m_url.protocol() == "wss")
// Check if vhost has been created
if (!client->m_vhost)
{
if (!client->m_credentials.tls12_cipher_list.empty())
// Define callback
struct lws_protocols protocols[] = {
{"LibWebsocketClientPoolClient", &LibWebsocketClientPool::Client::eventCallback, 0, 0, 0, client, 0},
LWS_PROTOCOL_LIST_TERM};

// Fill vhost information
struct lws_context_creation_info vhost_info;
memset(&vhost_info, 0, sizeof(vhost_info));
vhost_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
vhost_info.port = CONTEXT_PORT_NO_LISTEN;
vhost_info.timeout_secs = client->m_connect_timeout;
vhost_info.connect_timeout_secs = client->m_connect_timeout;
vhost_info.protocols = protocols;
vhost_info.log_cx = &pool->m_logs_context;
if (client->m_url.protocol() == "wss")
{
vhost_info.client_ssl_cipher_list = client->m_credentials.tls12_cipher_list.c_str();
}
if (!client->m_credentials.tls13_cipher_list.empty())
{
vhost_info.client_tls_1_3_plus_cipher_list = client->m_credentials.tls13_cipher_list.c_str();
}
if (client->m_credentials.encoded_pem_certificates)
{
// Use PEM encoded data
if (!client->m_credentials.server_certificate_ca.empty())
if (!client->m_credentials.tls12_cipher_list.empty())
{
vhost_info.client_ssl_ca_mem = client->m_credentials.server_certificate_ca.c_str();
vhost_info.client_ssl_ca_mem_len = static_cast<unsigned int>(client->m_credentials.server_certificate_ca.size());
vhost_info.client_ssl_cipher_list = client->m_credentials.tls12_cipher_list.c_str();
}
if (!client->m_credentials.client_certificate.empty())
if (!client->m_credentials.tls13_cipher_list.empty())
{
vhost_info.client_ssl_cert_mem = client->m_credentials.client_certificate.c_str();
vhost_info.client_ssl_cert_mem_len = static_cast<unsigned int>(client->m_credentials.client_certificate.size());
vhost_info.client_tls_1_3_plus_cipher_list = client->m_credentials.tls13_cipher_list.c_str();
}
if (!client->m_credentials.client_certificate_private_key.empty())
if (client->m_credentials.encoded_pem_certificates)
{
vhost_info.client_ssl_key_mem = client->m_credentials.client_certificate_private_key.c_str();
vhost_info.client_ssl_key_mem_len =
static_cast<unsigned int>(client->m_credentials.client_certificate_private_key.size());
// Use PEM encoded data
if (!client->m_credentials.server_certificate_ca.empty())
{
vhost_info.client_ssl_ca_mem = client->m_credentials.server_certificate_ca.c_str();
vhost_info.client_ssl_ca_mem_len =
static_cast<unsigned int>(client->m_credentials.server_certificate_ca.size());
}
if (!client->m_credentials.client_certificate.empty())
{
vhost_info.client_ssl_cert_mem = client->m_credentials.client_certificate.c_str();
vhost_info.client_ssl_cert_mem_len = static_cast<unsigned int>(client->m_credentials.client_certificate.size());
}
if (!client->m_credentials.client_certificate_private_key.empty())
{
vhost_info.client_ssl_key_mem = client->m_credentials.client_certificate_private_key.c_str();
vhost_info.client_ssl_key_mem_len =
static_cast<unsigned int>(client->m_credentials.client_certificate_private_key.size());
}
}
}
else
{
// Load PEM files from filesystem
if (!client->m_credentials.server_certificate_ca.empty())
else
{
vhost_info.client_ssl_ca_filepath = client->m_credentials.server_certificate_ca.c_str();
// Load PEM files from filesystem
if (!client->m_credentials.server_certificate_ca.empty())
{
vhost_info.client_ssl_ca_filepath = client->m_credentials.server_certificate_ca.c_str();
}
if (!client->m_credentials.client_certificate.empty())
{
vhost_info.client_ssl_cert_filepath = client->m_credentials.client_certificate.c_str();
}
if (!client->m_credentials.client_certificate_private_key.empty())
{
vhost_info.client_ssl_private_key_filepath = client->m_credentials.client_certificate_private_key.c_str();
}
}
if (!client->m_credentials.client_certificate.empty())
if (!client->m_credentials.client_certificate_private_key_passphrase.empty())
{
vhost_info.client_ssl_cert_filepath = client->m_credentials.client_certificate.c_str();
vhost_info.client_ssl_private_key_password =
client->m_credentials.client_certificate_private_key_passphrase.c_str();
}
if (!client->m_credentials.client_certificate_private_key.empty())
{
vhost_info.client_ssl_private_key_filepath = client->m_credentials.client_certificate_private_key.c_str();
}
}
if (!client->m_credentials.client_certificate_private_key_passphrase.empty())
{
vhost_info.client_ssl_private_key_password = client->m_credentials.client_certificate_private_key_passphrase.c_str();
}
}

// Create vhost
client->m_vhost = lws_create_vhost(client->m_context, &vhost_info);
}
if (client->m_vhost)
{
// Connexion parameters
struct lws_client_connect_info connect_info;
memset(&connect_info, 0, sizeof(connect_info));
connect_info.context = client->m_context;
connect_info.vhost = client->m_vhost;
connect_info.address = client->m_url.address().c_str();
connect_info.path = client->m_url.path().c_str();
connect_info.host = connect_info.address;
connect_info.origin = connect_info.address;
if (client->m_url.protocol() == "wss")
// Create vhost
client->m_vhost = lws_create_vhost(client->m_context, &vhost_info);
}
if (client->m_vhost)
{
connect_info.ssl_connection = LCCSCF_USE_SSL;
if (client->m_credentials.allow_selfsigned_certificates)
// Connexion parameters
struct lws_client_connect_info connect_info;
memset(&connect_info, 0, sizeof(connect_info));
connect_info.context = client->m_context;
connect_info.vhost = client->m_vhost;
connect_info.address = client->m_url.address().c_str();
connect_info.path = client->m_url.path().c_str();
connect_info.host = connect_info.address;
connect_info.origin = connect_info.address;
if (client->m_url.protocol() == "wss")
{
connect_info.ssl_connection |= LCCSCF_ALLOW_SELFSIGNED;
connect_info.ssl_connection = LCCSCF_USE_SSL;
if (client->m_credentials.allow_selfsigned_certificates)
{
connect_info.ssl_connection |= LCCSCF_ALLOW_SELFSIGNED;
}
if (client->m_credentials.allow_expired_certificates)
{
connect_info.ssl_connection |= LCCSCF_ALLOW_EXPIRED;
}
if (client->m_credentials.accept_untrusted_certificates)
{
connect_info.ssl_connection |= LCCSCF_ALLOW_INSECURE;
}
if (client->m_credentials.skip_server_name_check)
{
connect_info.ssl_connection |= LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK;
}
connect_info.port = 443;
}
if (client->m_credentials.allow_expired_certificates)
else
{
connect_info.ssl_connection |= LCCSCF_ALLOW_EXPIRED;
connect_info.port = 80;
}
if (client->m_credentials.accept_untrusted_certificates)
if (client->m_url.port())
{
connect_info.ssl_connection |= LCCSCF_ALLOW_INSECURE;
connect_info.port = static_cast<int>(client->m_url.port());
}
if (client->m_credentials.skip_server_name_check)
connect_info.protocol = client->m_protocol.c_str();
connect_info.local_protocol_name = "LibWebsocketClientPoolClient";
connect_info.pwsi = &client->m_wsi;
connect_info.retry_and_idle_policy = &client->m_retry_policy;
connect_info.userdata = client;

// Start connection
if (!lws_client_connect_via_info(&connect_info))
{
connect_info.ssl_connection |= LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK;
// Schedule a retry
client->m_retry_count = 0;
lws_retry_sul_schedule(pool->m_context,
0,
sul,
&client->m_retry_policy,
&LibWebsocketClientPool::Client::connectCallback,
&client->m_retry_count);
}
connect_info.port = 443;
}
else
{
connect_info.port = 80;
}
if (client->m_url.port())
{
connect_info.port = static_cast<int>(client->m_url.port());
}
connect_info.protocol = client->m_protocol.c_str();
connect_info.local_protocol_name = "LibWebsocketClientPoolClient";
connect_info.pwsi = &client->m_wsi;
connect_info.retry_and_idle_policy = &client->m_retry_policy;
connect_info.userdata = client;

// Start connection
if (!lws_client_connect_via_info(&connect_info))
{
// Schedule a retry
client->m_retry_count = 0;
lws_retry_sul_schedule(pool->m_context,
0,
sul,
&client->m_retry_policy,
&LibWebsocketClientPool::Client::connectCallback,
&client->m_retry_count);
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/websockets/libwebsockets/LibWebsocketClientPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ class LibWebsocketClientPool
std::condition_variable m_disconnect_cond_var;
/** @brief Disconnect mutex */
std::mutex m_disconnect_mutex;
/** @brief Indicate that the disconnect process is in progress */
bool m_disconnect_process_in_progress;
/** @brief Indicate that the disconnect process is done */
bool m_disconnect_process_done;

Expand Down

0 comments on commit eabee12

Please sign in to comment.