Skip to content

Commit

Permalink
tls support sni cert
Browse files Browse the repository at this point in the history
  • Loading branch information
findstr authored and zhoupy committed Jan 24, 2024
1 parent 229f505 commit 9fb99d7
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 77 deletions.
207 changes: 158 additions & 49 deletions lualib-src/lualib-tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,23 @@
#include <openssl/bio.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/x509v3.h>

#include "silly.h"

#define ssl_malloc silly_malloc
#define ssl_free silly_free

struct ctx_entry {
SSL_CTX *ptr;
X509 *cert;
};

struct ctx {
int mode;
int alpn;
SSL_CTX *ptr;
int entry_count;
struct ctx_entry entries[1];
};

struct tls {
Expand All @@ -28,14 +35,27 @@ struct tls {
BIO *out_bio;
};

static void
ctx_destroy(struct ctx *ctx)
{
int i;
for (i = 0; i < ctx->entry_count; i++) {
if (ctx->entries[i].ptr != NULL) {
SSL_CTX_free(ctx->entries[i].ptr);
}
if (ctx->entries[i].cert != NULL) {
X509_free(ctx->entries[i].cert);
}
}
ctx->entry_count = 0;
}

static int
lctx_free(lua_State *L)
{
struct ctx *ctx;
ctx = (struct ctx *)luaL_checkudata(L, 1, "TLS_CTX");
if (ctx->ptr!= NULL)
SSL_CTX_free(ctx->ptr);
ctx->ptr= NULL;
ctx_destroy(ctx);
return 0;
}

Expand All @@ -52,17 +72,20 @@ ltls_free(lua_State *L)


static struct ctx *
new_tls_ctx(lua_State *L, SSL_CTX *ptr, int mode)
new_tls_ctx(lua_State *L, int mode, int ctx_count)
{
int size;
struct ctx *ctx;
ctx = (struct ctx*)lua_newuserdatauv(L, sizeof(*ctx), 0);
size = offsetof(struct ctx, entries) + ctx_count * sizeof(struct ctx_entry);
ctx = (struct ctx*)lua_newuserdatauv(L, size, 0);
if (luaL_newmetatable(L, "TLS_CTX")) {
lua_pushcfunction(L, lctx_free);
lua_setfield(L, -2, "__gc");
}
ctx->ptr = ptr;
memset(ctx, 0, size);
ctx->mode = mode;
ctx->alpn = 0;
ctx->entry_count = ctx_count;
lua_setmetatable(L, -2);
return ctx;
}
Expand Down Expand Up @@ -90,13 +113,15 @@ static int
lctx_client(lua_State *L)
{
SSL_CTX *ptr;
struct ctx *ctx;
ptr = SSL_CTX_new(TLS_method());
if (ptr == NULL) {
lua_pushnil(L);
lua_pushstring(L, "SSL_CTX_new fail");
return 2;
}
new_tls_ctx(L, ptr, 'C');
ctx = new_tls_ctx(L, 'C', 1);
ctx->entries[0].ptr = ptr;
return 1;
}

Expand All @@ -120,7 +145,7 @@ int alpn_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen,
unsigned char *outx;
(void)ssl;
struct ctx *ctx = (struct ctx *)arg;
if (ctx->ptr == NULL)
if (ctx->entry_count == 0)
return SSL_TLSEXT_ERR_NOACK;
if (ctx->alpn == 1) {
alpn = alpn_h2;
Expand All @@ -138,61 +163,145 @@ int alpn_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen,
}
}

static int
lctx_server(lua_State *L)
static const char *
fill_entry(lua_State *L, struct ctx_entry *entry, int stk)
{
int r, alpn;
SSL_CTX *ptr;
struct ctx *ctx;
int ret;
FILE *fp = NULL;
X509 *cert = NULL;
SSL_CTX *ptr = NULL;
const char *err = NULL;
const char *certpath, *keypath;
int top = lua_gettop(L);
ptr = SSL_CTX_new(TLS_method());
if (ptr == NULL) {
lua_pushnil(L);
lua_pushstring(L, "SSL_CTX_new");
return 2;
err = "SSL_CTX_new fail";
goto fail;
}
certpath = luaL_checkstring(L, 1);
keypath = luaL_checkstring(L, 2);
//create tls_ctx first so that
//even lctx_server fail
//gc also will free the ptr
ctx = new_tls_ctx(L, ptr, 'S');
r = SSL_CTX_use_certificate_file(ptr, certpath, SSL_FILETYPE_PEM);
if (r != 1) {
lua_pop(L, 1);
lua_pushnil(L);
lua_pushstring(L, "SSL_CTX_use_certificate_file");
return 2;
SSL_CTX_set_min_proto_version(ptr, TLS1_1_VERSION);
lua_getfield(L, stk, "cert");
certpath = luaL_checkstring(L, -1);
lua_getfield(L, stk, "cert_key");
keypath = luaL_checkstring(L, -1);
fp = fopen(certpath, "r");
if (fp == NULL) {
err = "open certificate file fail";
goto fail;
}
r = SSL_CTX_use_PrivateKey_file(ptr, keypath, SSL_FILETYPE_PEM);
if (r != 1) {
lua_pop(L, 1);
lua_pushnil(L);
lua_pushstring(L, "SSL_CTX_use_PrivateKey_file");
return 2;
cert = PEM_read_X509(fp, NULL, NULL, NULL);
fclose(fp);
fp = NULL;
if (cert == NULL) {
err = "read certificate file fail";
goto fail;
}
ret = SSL_CTX_use_certificate_chain_file(ptr, certpath);
if (ret != 1) {
err = "SSL_CTX_use_certificate_file";
goto fail;
}
ret = SSL_CTX_use_PrivateKey_file(ptr, keypath, SSL_FILETYPE_PEM);
if (ret != 1) {
printf("SSL_CTX_use_PrivateKey_file fail:%s\n", ERR_error_string(ERR_get_error(), NULL));
err = "SSL_CTX_use_PrivateKey_file";
goto fail;
}
ret = SSL_CTX_check_private_key(ptr);
if (ret != 1) {
err = "SSL_CTX_check_private_key";
goto fail;
}
lua_settop(L, top);
entry->ptr = ptr;
entry->cert = cert;
return NULL;
fail:
lua_settop(L, top);
if (fp != NULL)
fclose(fp);
if (ptr != NULL)
SSL_CTX_free(ptr);
if (cert != NULL)
X509_free(cert);
return err;
}

static int ssl_servername_cb(SSL *s, int *ad, void *arg)
{
int i;
SSL_CTX *ptr = NULL;
const char *servername;
struct ctx *ctx = (struct ctx *) arg;
(void)ad;
servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);
if (servername != NULL) {
for (i = 0; i < ctx->entry_count; i++) {
X509 *cert = ctx->entries[i].cert;
if (cert == NULL)
continue;
if (X509_check_host(cert, servername, 0, 0, NULL) == 1) {
ptr = ctx->entries[i].ptr;
break;
}
}
}
r = SSL_CTX_check_private_key(ptr);
if (r != 1) {
if (ptr == NULL) {
ptr = ctx->entries[0].ptr;
}
SSL_set_SSL_CTX(s, ptr);
return SSL_TLSEXT_ERR_OK;
}

static int
lctx_server(lua_State *L)
{
SSL_CTX *ptr;
struct ctx *ctx;
const char *err = NULL;
int i, ncert, r, alpn;
ncert = luaL_len(L, 1);
ctx = new_tls_ctx(L, 'S', ncert);
for (i = 0; i < ctx->entry_count; i++) {
int absidx;
struct ctx_entry *entry;
lua_rawgeti(L, 1, i + 1);
absidx = lua_absindex(L, -1);
entry = &ctx->entries[i];
err = fill_entry(L, entry, absidx);
lua_pop(L, 1);
if (err != NULL)
break;
ptr = entry->ptr;
SSL_CTX_set_tlsext_servername_callback(ptr, ssl_servername_cb);
SSL_CTX_set_tlsext_servername_arg(ptr, ctx);
}
if (err != NULL) {
ctx_destroy(ctx);
lua_pushnil(L);
lua_pushstring(L, "SSL_CTX_check_private_key");
lua_pushstring(L, err);
return 2;
}
if (lua_type(L, 3) != LUA_TNIL) {
if (lua_type(L, 2) != LUA_TNIL) {
const char *cipher;
cipher = luaL_checkstring(L, 3);
r = SSL_CTX_set_cipher_list(ptr, cipher);
if (r != 1) {
lua_pop(L, 1);
lua_pushnil(L);
lua_pushstring(L, "SSL_CTX_set_cipher_list");
return 2;
cipher = luaL_checkstring(L, 2);
for (i = 0; i < ctx->entry_count; i++) {
SSL_CTX *ptr = ctx->entries[i].ptr;
r = SSL_CTX_set_cipher_list(ptr, cipher);
if (r != 1) {
ctx_destroy(ctx);
lua_pushnil(L);
lua_pushstring(L, "SSL_CTX_set_cipher_list");
return 2;
}
}
}
alpn = luaL_optinteger(L, 4, 0);
alpn = luaL_optinteger(L, 3, 0);
if (alpn == 1) {
ctx->alpn = alpn;
SSL_CTX_set_alpn_select_cb(ptr, alpn_cb, ctx);
for (i = 0; i < ctx->entry_count; i++) {
SSL_CTX *ptr = ctx->entries[i].ptr;
SSL_CTX_set_alpn_select_cb(ptr, alpn_cb, ctx);
}
}
return 1;
}
Expand All @@ -209,7 +318,7 @@ ltls_open(lua_State *L)
hostname = lua_tostring(L, 3);
alpn = luaL_optinteger(L, 4, 0);
tls = new_tls(L, fd);
tls->ssl = SSL_new(ctx->ptr);
tls->ssl = SSL_new(ctx->entries[0].ptr);
if (tls->ssl == NULL)
luaL_error(L, "SSL_new fail");
if (alpn == 1)
Expand Down
5 changes: 2 additions & 3 deletions lualib/http/server.lua
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ local function httpd(scheme, handler)
end
--request line
local method, uri, ver =
first:match("(%w+)%s+(.-)%s+HTTP/([%d|.]+)\r\n")
first:match("(%w+)%s+(.-)%s+HTTP/([%d|.]+)\r\n")
assert(method and uri and ver)
res.method = method
res.version = ver
Expand Down Expand Up @@ -173,8 +173,7 @@ local server = {
fd2 = tls.listen {
disp = httpd("https", handler),
port = conf.tls_port,
key = conf.tls_key,
cert = conf.tls_cert,
certs = conf.tls_certs,
}
end
return fd1, fd2
Expand Down
32 changes: 15 additions & 17 deletions lualib/http2/stream.lua
Original file line number Diff line number Diff line change
Expand Up @@ -303,23 +303,23 @@ local function frame_winupdate(ch, id, flag, dat)
end

local frame_client = {
[FRAME_HEADERS] = frame_header_client,
[FRAME_DATA] = frame_data,
[FRAME_RST] = frame_rst,
[FRAME_SETTINGS] = frame_settings,
[FRAME_PING] = frame_ping,
[FRAME_GOAWAY] = frame_goaway,
[FRAME_WINUPDATE] = frame_winupdate,
[FRAME_HEADERS] = frame_header_client,
[FRAME_DATA] = frame_data,
[FRAME_RST] = frame_rst,
[FRAME_SETTINGS] = frame_settings,
[FRAME_PING] = frame_ping,
[FRAME_GOAWAY] = frame_goaway,
[FRAME_WINUPDATE] = frame_winupdate,
}

local frame_server = {
[FRAME_HEADERS] = frame_header_server,
[FRAME_DATA] = frame_data,
[FRAME_RST] = frame_rst,
[FRAME_SETTINGS] = frame_settings,
[FRAME_PING] = frame_ping,
[FRAME_GOAWAY] = frame_goaway,
[FRAME_WINUPDATE] = frame_winupdate,
[FRAME_HEADERS] = frame_header_server,
[FRAME_DATA] = frame_data,
[FRAME_RST] = frame_rst,
[FRAME_SETTINGS] = frame_settings,
[FRAME_PING] = frame_ping,
[FRAME_GOAWAY] = frame_goaway,
[FRAME_WINUPDATE] = frame_winupdate,
}

local function common_dispatch(ch, frame_process)
Expand Down Expand Up @@ -516,9 +516,8 @@ end
function M.listen(conf)
return tls_listen {
disp = httpd(conf.handler),
certs = conf.tls_certs,
port = conf.tls_port,
key = conf.tls_key,
cert = conf.tls_cert,
alpn = "h2"
}
end
Expand Down Expand Up @@ -559,7 +558,6 @@ function M.ack(s, status, header, endstream)
return tls_write(ch.fd, dat)
end


function M.write(s, dat, continue)
local ch = s.channel
if not ch then
Expand Down
Loading

0 comments on commit 9fb99d7

Please sign in to comment.