Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssl: enable client-side ssl session cache #634

Merged
merged 2 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions debian/changelog
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
yass (1.5.14-1) UNRELEASED; urgency=medium

* ssl: enable client-side ssl session cache.
* ssl: deduplicate all ceritificates.

-- Chilledheart <keeyou-cn@outlook.com> Mon, 8 Jan 2024 20:41:54 +0800
yass (1.5.13-1) UNRELEASED; urgency=medium

* gtk: add server sni support.
Expand Down
3 changes: 2 additions & 1 deletion src/cli/cli_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1472,7 +1472,8 @@ void CliConnection::OnConnect() {
<< " connect " << remote_domain();
// create lazy
if (enable_upstream_tls_) {
channel_ = ssl_stream::create(*io_context_,
channel_ = ssl_stream::create(ssl_socket_data_index(),
*io_context_,
remote_host_ips_,
remote_host_sni_,
remote_port_,
Expand Down
11 changes: 10 additions & 1 deletion src/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,19 @@ class Connection {
/// \param peer_endpoint the peer endpoint
/// \param the number of connection id
/// \param the pointer of tlsext ctx
/// \param the ssl client data index
void on_accept(asio::ip::tcp::socket&& socket,
const asio::ip::tcp::endpoint& endpoint,
const asio::ip::tcp::endpoint& peer_endpoint,
int connection_id,
tlsext_ctx_t *tlsext_ctx) {
tlsext_ctx_t *tlsext_ctx,
int ssl_socket_data_index) {
downlink_->on_accept(std::move(socket));
endpoint_ = endpoint;
peer_endpoint_ = peer_endpoint;
connection_id_ = connection_id;
tlsext_ctx_.reset(tlsext_ctx);
ssl_socket_data_index_ = ssl_socket_data_index;
}

/// Enter the start phase, begin to read requests
Expand Down Expand Up @@ -284,6 +287,10 @@ class Connection {
return *tlsext_ctx_;
}

int ssl_socket_data_index() const {
return ssl_socket_data_index_;
}

protected:
/// the peek current io
bool DoPeek() {
Expand All @@ -308,6 +315,8 @@ class Connection {
int connection_id_ = -1;
/// the tlsext ctx
std::unique_ptr<tlsext_ctx_t> tlsext_ctx_;
/// the ssl client data index
int ssl_socket_data_index_ = -1;

/// if https fallback
bool upstream_https_fallback_;
Expand Down
30 changes: 22 additions & 8 deletions src/content_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "crypto/crypter_export.hpp"
#include "network.hpp"
#include "net/x509_util.hpp"
#include "net/ssl_socket.hpp"

#define MAX_LISTEN_ADDRESSES 30

Expand Down Expand Up @@ -73,6 +74,7 @@ class ContentServer {
}

~ContentServer() {
client_instance_ = nullptr;
work_guard_.reset();
}

Expand Down Expand Up @@ -259,7 +261,7 @@ class ContentServer {
SetTCPKeepAlive(socket.native_handle(), ec);
SetSocketTcpNoDelay(&socket, ec);
conn->on_accept(std::move(socket), ctx.endpoint, ctx.peer_endpoint,
connection_id, tlsext_ctx);
connection_id, tlsext_ctx, ssl_socket_data_index_);
conn->set_disconnect_cb(
[this, conn]() mutable { on_disconnect(conn); });
connection_map_.insert(std::make_pair(connection_id, conn));
Expand Down Expand Up @@ -400,7 +402,7 @@ class ContentServer {
// SSL_CTX_set1_ech_keys

// Deduplicate all certificates minted from the SSL_CTX in memory.
SSL_CTX_set0_buffer_pool(ssl_ctx_.native_handle(), x509_util::GetBufferPool());
SSL_CTX_set0_buffer_pool(ssl_ctx_.native_handle(), net::x509_util::GetBufferPool());
}

void setup_ssl_ctx_alpn_cb(tlsext_ctx_t *tlsext_ctx) {
Expand Down Expand Up @@ -554,29 +556,38 @@ class ContentServer {
}
VLOG(1) << "Alpn support (client) enabled";

#if 0
client_instance_ = this;
ssl_socket_data_index_ = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);

// Disable the internal session cache. Session caching is handled
// externally (i.e. by SSLClientSessionCache).
SSL_CTX_set_session_cache_mode(upstream_ssl_ctx_.native_handle(),
SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL);
SSL_CTX_sess_set_new_cb(upstream_ssl_ctx_.native_handle(), NewSessionCallback);
#endif

SSL_CTX_set_timeout(upstream_ssl_ctx_.native_handle(), 1 * 60 * 60 /* one hour */);

SSL_CTX_set_grease_enabled(upstream_ssl_ctx_.native_handle(), 1);

// Deduplicate all certificates minted from the SSL_CTX in memory.
SSL_CTX_set0_buffer_pool(upstream_ssl_ctx_.native_handle(), x509_util::GetBufferPool());
SSL_CTX_set0_buffer_pool(upstream_ssl_ctx_.native_handle(), net::x509_util::GetBufferPool());
}

#if 0
private:
int ssl_socket_data_index_ = -1;
static ContentServer<T> *client_instance_;
static ContentServer *GetInstance() { return client_instance_; }
net::SSLSocket* GetClientSocketFromSSL(const SSL* ssl) {
DCHECK(ssl);
net::SSLSocket* socket = static_cast<net::SSLSocket*>(SSL_get_ex_data(ssl, ssl_socket_data_index_));
DCHECK(socket);
return socket;
}

static int NewSessionCallback(SSL* ssl, SSL_SESSION* session) {
SSLClientSocketImpl* socket = GetInstance()->GetClientSocketFromSSL(ssl);
net::SSLSocket* socket = GetInstance()->GetClientSocketFromSSL(ssl);
return socket->NewSessionCallback(session);
}
#endif

private:
asio::io_context &io_context_;
Expand Down Expand Up @@ -619,4 +630,7 @@ class ContentServer {
T factory_;
};

template<typename T>
ContentServer<T> *ContentServer<T>::client_instance_ = nullptr;

#endif // H_CONTENT_SERVER
38 changes: 35 additions & 3 deletions src/net/ssl_socket.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2023 Chilledheart */
/* Copyright (c) 2023-2024 Chilledheart */

#include "net/ssl_socket.hpp"
#include "network.hpp"

#include <absl/container/flat_hash_map.h>

namespace net {

namespace {
Expand All @@ -17,16 +19,22 @@ const int kCertVerifyPending = 1;
const int kDefaultOpenSSLBufferSize = 17 * 1024;
} // namespace

SSLSocket::SSLSocket(asio::io_context *io_context,
static constexpr int kMaximumSSLCache = 1024;
static absl::flat_hash_map<asio::ip::address, bssl::UniquePtr<SSL_SESSION>> g_ssl_lru_cache;

SSLSocket::SSLSocket(int ssl_socket_data_index,
asio::io_context *io_context,
asio::ip::tcp::socket* socket,
SSL_CTX* ssl_ctx,
bool https_fallback,
const std::string& host_name)
: io_context_(io_context), stream_socket_(socket),
: ssl_socket_data_index_(ssl_socket_data_index),
io_context_(io_context), stream_socket_(socket),
early_data_enabled_(absl::GetFlag(FLAGS_tls13_early_data)),
pending_read_error_(kSSLClientSocketNoPendingResult) {
DCHECK(!ssl_);
ssl_.reset(SSL_new(ssl_ctx));
CHECK_NE(0, SSL_set_ex_data(ssl_.get(), ssl_socket_data_index_, this));

// TODO: reuse SSL session

Expand Down Expand Up @@ -151,6 +159,7 @@ int SSLSocket::Connect(CompletionOnceCallback callback) {
}

SSLSocket::~SSLSocket() {
CHECK_NE(0, SSL_set_ex_data(ssl_.get(), ssl_socket_data_index_, nullptr));
VLOG(1) << "SSLSocket " << this << " freed memory";
}

Expand Down Expand Up @@ -361,6 +370,29 @@ void SSLSocket::WaitWrite(WaitCallback &&cb) {
});
}

int SSLSocket::NewSessionCallback(SSL_SESSION* session) {
asio::ip::address ip_addr;
if (SSL_CIPHER_get_kx_nid(SSL_SESSION_get0_cipher(session)) == NID_kx_rsa) {
// If RSA key exchange was used, additionally key the cache with the
// destination IP address. Of course, if a proxy is being used, the
// semantics of this are a little complex, but we're doing our best. See
// https://crbug.com/969684
asio::error_code ec;
auto ip_endpoint = stream_socket_->remote_endpoint(ec);
if (ec) {
return 0;
}
ip_addr = ip_endpoint.address();
}

// OpenSSL optionally passes ownership of |session|. Returning one signals
// that this function has claimed it.
g_ssl_lru_cache[ip_addr] = bssl::UniquePtr<SSL_SESSION>(session);
if (g_ssl_lru_cache.size() >= kMaximumSSLCache)
g_ssl_lru_cache.clear();
return 1;
}

void SSLSocket::OnWaitRead(asio::error_code ec) {
if (disconnected_)
return;
Expand Down
9 changes: 7 additions & 2 deletions src/net/ssl_socket.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2023 Chilledheart */
/* Copyright (c) 2023-2024 Chilledheart */

#ifndef H_NET_SSL_SOCKET
#define H_NET_SSL_SOCKET
Expand Down Expand Up @@ -46,7 +46,8 @@ using WaitCallback = absl::AnyInvocable<void(asio::error_code ec)>;

class SSLSocket : public RefCountedThreadSafe<SSLSocket> {
public:
SSLSocket(asio::io_context *io_context,
SSLSocket(int ssl_socket_data_index,
asio::io_context *io_context,
asio::ip::tcp::socket* socket,
SSL_CTX* ssl_ctx,
bool https_fallback,
Expand Down Expand Up @@ -78,6 +79,9 @@ class SSLSocket : public RefCountedThreadSafe<SSLSocket> {
const std::string& negotiated_protocol() const {
return negotiated_protocol_;
}

int NewSessionCallback(SSL_SESSION* session);

protected:
void OnWaitRead(asio::error_code ec);
void OnWaitWrite(asio::error_code ec);
Expand All @@ -101,6 +105,7 @@ class SSLSocket : public RefCountedThreadSafe<SSLSocket> {
int MapLastOpenSSLError(int ssl_error);

private:
int ssl_socket_data_index_;
asio::io_context* io_context_;
asio::ip::tcp::socket* stream_socket_;

Expand Down
4 changes: 2 additions & 2 deletions src/net/x509_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include "net/x509_util.hpp"

namespace x509_util {
namespace net::x509_util {

namespace {

Expand All @@ -24,4 +24,4 @@ CRYPTO_BUFFER_POOL* GetBufferPool() {
return g_buffer_pool_singleton.pool();
}

} // namespace x509_util
} // namespace net::x509_util
4 changes: 2 additions & 2 deletions src/net/x509_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
#include <openssl/base.h>
#include <openssl/pool.h>

namespace x509_util {
namespace net::x509_util {

// Returns a CRYPTO_BUFFER_POOL for deduplicating certificates.
CRYPTO_BUFFER_POOL* GetBufferPool();

} // namespace x509_util
} // namespace net::x509_util

#endif // H_NET_X509_UTIL
3 changes: 2 additions & 1 deletion src/server/server_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,8 @@ void ServerConnection::OnConnect() {
host_name = request_.endpoint().address().to_string();
}
if (enable_upstream_tls_) {
channel_ = ssl_stream::create(*io_context_,
channel_ = ssl_stream::create(ssl_socket_data_index(),
*io_context_,
std::string(), host_name, port,
this, upstream_https_fallback_,
upstream_ssl_ctx_);
Expand Down
8 changes: 5 additions & 3 deletions src/ssl_stream.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2023 Chilledheart */
/* Copyright (c) 2023-2024 Chilledheart */

#ifndef H_SSL_STREAM
#define H_SSL_STREAM
Expand All @@ -17,14 +17,16 @@ class ssl_stream : public stream {

/// construct a ssl stream object with ss protocol
///
/// \param ssl_socket_data_index the ssl client data index
/// \param io_context the io context associated with the service
/// \param host_ips the ip addresses used with endpoint
/// \param host_sni the sni name used with endpoint
/// \param port the sni port used with endpoint
/// \param channel the underlying data channel used in stream
/// \param https_fallback the data channel falls back to https (alpn)
/// \param ssl_ctx the ssl context object for tls data transfer
ssl_stream(asio::io_context& io_context,
ssl_stream(int ssl_socket_data_index,
asio::io_context& io_context,
const std::string& host_ips,
const std::string& host_sni,
uint16_t port,
Expand All @@ -34,7 +36,7 @@ class ssl_stream : public stream {
: stream(io_context, host_ips, host_sni, port, channel),
https_fallback_(https_fallback),
enable_tls_(true),
ssl_socket_(net::SSLSocket::Create(&io_context, &socket_, ssl_ctx->native_handle(), https_fallback, host_sni)) {
ssl_socket_(net::SSLSocket::Create(ssl_socket_data_index, &io_context, &socket_, ssl_ctx->native_handle(), https_fallback, host_sni)) {
}

~ssl_stream() override {}
Expand Down
3 changes: 3 additions & 0 deletions yass.spec.in
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ for embedded devices and low end boxes.
%{_mandir}/man1/yass_cli.1*

%changelog
* Mon Jan 8 2024 Chilledheart <keeyou-cn@outlook.com> - 1.5.14-1
- ssl: enable client-side ssl session cache.
- ssl: deduplicate all ceritificates.
* Sun Jan 7 2024 Chilledheart <keeyou-cn@outlook.com> - 1.5.13-1
- gtk: add server sni support.
* Sun Jan 7 2024 Chilledheart <keeyou-cn@outlook.com> - 1.5.12-1
Expand Down
Loading