diff --git a/lualib-src/lualib-tls.c b/lualib-src/lualib-tls.c index 983766c..513cafc 100644 --- a/lualib-src/lualib-tls.c +++ b/lualib-src/lualib-tls.c @@ -9,16 +9,23 @@ #include #include #include +#include #include "silly.h" #define ssl_malloc silly_malloc #define ssl_free silly_free +struct ctx_entry { + SSL_CTX *ptr; + X509 *cert; +}; + struct ctx { int mode; int alpn; - SSL_CTX *ptr; + int entry_count; + struct ctx_entry entries[1]; }; struct tls { @@ -28,14 +35,27 @@ struct tls { BIO *out_bio; }; +static void +ctx_destroy(struct ctx *ctx) +{ + int i; + for (i = 0; i < ctx->entry_count; i++) { + if (ctx->entries[i].ptr != NULL) { + SSL_CTX_free(ctx->entries[i].ptr); + } + if (ctx->entries[i].cert != NULL) { + X509_free(ctx->entries[i].cert); + } + } + ctx->entry_count = 0; +} + static int lctx_free(lua_State *L) { struct ctx *ctx; ctx = (struct ctx *)luaL_checkudata(L, 1, "TLS_CTX"); - if (ctx->ptr!= NULL) - SSL_CTX_free(ctx->ptr); - ctx->ptr= NULL; + ctx_destroy(ctx); return 0; } @@ -52,17 +72,20 @@ ltls_free(lua_State *L) static struct ctx * -new_tls_ctx(lua_State *L, SSL_CTX *ptr, int mode) +new_tls_ctx(lua_State *L, int mode, int ctx_count) { + int size; struct ctx *ctx; - ctx = (struct ctx*)lua_newuserdatauv(L, sizeof(*ctx), 0); + size = offsetof(struct ctx, entries) + ctx_count * sizeof(struct ctx_entry); + ctx = (struct ctx*)lua_newuserdatauv(L, size, 0); if (luaL_newmetatable(L, "TLS_CTX")) { lua_pushcfunction(L, lctx_free); lua_setfield(L, -2, "__gc"); } - ctx->ptr = ptr; + memset(ctx, 0, size); ctx->mode = mode; ctx->alpn = 0; + ctx->entry_count = ctx_count; lua_setmetatable(L, -2); return ctx; } @@ -90,13 +113,15 @@ static int lctx_client(lua_State *L) { SSL_CTX *ptr; + struct ctx *ctx; ptr = SSL_CTX_new(TLS_method()); if (ptr == NULL) { lua_pushnil(L); lua_pushstring(L, "SSL_CTX_new fail"); return 2; } - new_tls_ctx(L, ptr, 'C'); + ctx = new_tls_ctx(L, 'C', 1); + ctx->entries[0].ptr = ptr; return 1; } @@ -120,7 +145,7 @@ int alpn_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, unsigned char *outx; (void)ssl; struct ctx *ctx = (struct ctx *)arg; - if (ctx->ptr == NULL) + if (ctx->entry_count == 0) return SSL_TLSEXT_ERR_NOACK; if (ctx->alpn == 1) { alpn = alpn_h2; @@ -138,61 +163,145 @@ int alpn_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, } } -static int -lctx_server(lua_State *L) +static const char * +fill_entry(lua_State *L, struct ctx_entry *entry, int stk) { - int r, alpn; - SSL_CTX *ptr; - struct ctx *ctx; + int ret; + FILE *fp = NULL; + X509 *cert = NULL; + SSL_CTX *ptr = NULL; + const char *err = NULL; const char *certpath, *keypath; + int top = lua_gettop(L); ptr = SSL_CTX_new(TLS_method()); if (ptr == NULL) { - lua_pushnil(L); - lua_pushstring(L, "SSL_CTX_new"); - return 2; + err = "SSL_CTX_new fail"; + goto fail; } - certpath = luaL_checkstring(L, 1); - keypath = luaL_checkstring(L, 2); - //create tls_ctx first so that - //even lctx_server fail - //gc also will free the ptr - ctx = new_tls_ctx(L, ptr, 'S'); - r = SSL_CTX_use_certificate_file(ptr, certpath, SSL_FILETYPE_PEM); - if (r != 1) { - lua_pop(L, 1); - lua_pushnil(L); - lua_pushstring(L, "SSL_CTX_use_certificate_file"); - return 2; + SSL_CTX_set_min_proto_version(ptr, TLS1_1_VERSION); + lua_getfield(L, stk, "cert"); + certpath = luaL_checkstring(L, -1); + lua_getfield(L, stk, "cert_key"); + keypath = luaL_checkstring(L, -1); + fp = fopen(certpath, "r"); + if (fp == NULL) { + err = "open certificate file fail"; + goto fail; } - r = SSL_CTX_use_PrivateKey_file(ptr, keypath, SSL_FILETYPE_PEM); - if (r != 1) { - lua_pop(L, 1); - lua_pushnil(L); - lua_pushstring(L, "SSL_CTX_use_PrivateKey_file"); - return 2; + cert = PEM_read_X509(fp, NULL, NULL, NULL); + fclose(fp); + fp = NULL; + if (cert == NULL) { + err = "read certificate file fail"; + goto fail; + } + ret = SSL_CTX_use_certificate_chain_file(ptr, certpath); + if (ret != 1) { + err = "SSL_CTX_use_certificate_file"; + goto fail; + } + ret = SSL_CTX_use_PrivateKey_file(ptr, keypath, SSL_FILETYPE_PEM); + if (ret != 1) { + printf("SSL_CTX_use_PrivateKey_file fail:%s\n", ERR_error_string(ERR_get_error(), NULL)); + err = "SSL_CTX_use_PrivateKey_file"; + goto fail; + } + ret = SSL_CTX_check_private_key(ptr); + if (ret != 1) { + err = "SSL_CTX_check_private_key"; + goto fail; + } + lua_settop(L, top); + entry->ptr = ptr; + entry->cert = cert; + return NULL; +fail: + lua_settop(L, top); + if (fp != NULL) + fclose(fp); + if (ptr != NULL) + SSL_CTX_free(ptr); + if (cert != NULL) + X509_free(cert); + return err; +} + +static int ssl_servername_cb(SSL *s, int *ad, void *arg) +{ + int i; + SSL_CTX *ptr = NULL; + const char *servername; + struct ctx *ctx = (struct ctx *) arg; + (void)ad; + servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name); + if (servername != NULL) { + for (i = 0; i < ctx->entry_count; i++) { + X509 *cert = ctx->entries[i].cert; + if (cert == NULL) + continue; + if (X509_check_host(cert, servername, 0, 0, NULL) == 1) { + ptr = ctx->entries[i].ptr; + break; + } + } } - r = SSL_CTX_check_private_key(ptr); - if (r != 1) { + if (ptr == NULL) { + ptr = ctx->entries[0].ptr; + } + SSL_set_SSL_CTX(s, ptr); + return SSL_TLSEXT_ERR_OK; +} + +static int +lctx_server(lua_State *L) +{ + SSL_CTX *ptr; + struct ctx *ctx; + const char *err = NULL; + int i, ncert, r, alpn; + ncert = luaL_len(L, 1); + ctx = new_tls_ctx(L, 'S', ncert); + for (i = 0; i < ctx->entry_count; i++) { + int absidx; + struct ctx_entry *entry; + lua_rawgeti(L, 1, i + 1); + absidx = lua_absindex(L, -1); + entry = &ctx->entries[i]; + err = fill_entry(L, entry, absidx); lua_pop(L, 1); + if (err != NULL) + break; + ptr = entry->ptr; + SSL_CTX_set_tlsext_servername_callback(ptr, ssl_servername_cb); + SSL_CTX_set_tlsext_servername_arg(ptr, ctx); + } + if (err != NULL) { + ctx_destroy(ctx); lua_pushnil(L); - lua_pushstring(L, "SSL_CTX_check_private_key"); + lua_pushstring(L, err); return 2; } - if (lua_type(L, 3) != LUA_TNIL) { + if (lua_type(L, 2) != LUA_TNIL) { const char *cipher; - cipher = luaL_checkstring(L, 3); - r = SSL_CTX_set_cipher_list(ptr, cipher); - if (r != 1) { - lua_pop(L, 1); - lua_pushnil(L); - lua_pushstring(L, "SSL_CTX_set_cipher_list"); - return 2; + cipher = luaL_checkstring(L, 2); + for (i = 0; i < ctx->entry_count; i++) { + SSL_CTX *ptr = ctx->entries[i].ptr; + r = SSL_CTX_set_cipher_list(ptr, cipher); + if (r != 1) { + ctx_destroy(ctx); + lua_pushnil(L); + lua_pushstring(L, "SSL_CTX_set_cipher_list"); + return 2; + } } } - alpn = luaL_optinteger(L, 4, 0); + alpn = luaL_optinteger(L, 3, 0); if (alpn == 1) { ctx->alpn = alpn; - SSL_CTX_set_alpn_select_cb(ptr, alpn_cb, ctx); + for (i = 0; i < ctx->entry_count; i++) { + SSL_CTX *ptr = ctx->entries[i].ptr; + SSL_CTX_set_alpn_select_cb(ptr, alpn_cb, ctx); + } } return 1; } @@ -209,7 +318,7 @@ ltls_open(lua_State *L) hostname = lua_tostring(L, 3); alpn = luaL_optinteger(L, 4, 0); tls = new_tls(L, fd); - tls->ssl = SSL_new(ctx->ptr); + tls->ssl = SSL_new(ctx->entries[0].ptr); if (tls->ssl == NULL) luaL_error(L, "SSL_new fail"); if (alpn == 1) diff --git a/lualib/http/server.lua b/lualib/http/server.lua index c065c51..1846244 100644 --- a/lualib/http/server.lua +++ b/lualib/http/server.lua @@ -131,7 +131,7 @@ local function httpd(scheme, handler) end --request line local method, uri, ver = - first:match("(%w+)%s+(.-)%s+HTTP/([%d|.]+)\r\n") + first:match("(%w+)%s+(.-)%s+HTTP/([%d|.]+)\r\n") assert(method and uri and ver) res.method = method res.version = ver @@ -173,8 +173,7 @@ local server = { fd2 = tls.listen { disp = httpd("https", handler), port = conf.tls_port, - key = conf.tls_key, - cert = conf.tls_cert, + certs = conf.tls_certs, } end return fd1, fd2 diff --git a/lualib/http2/stream.lua b/lualib/http2/stream.lua index 44272bc..d77ce30 100644 --- a/lualib/http2/stream.lua +++ b/lualib/http2/stream.lua @@ -303,23 +303,23 @@ local function frame_winupdate(ch, id, flag, dat) end local frame_client = { -[FRAME_HEADERS] = frame_header_client, -[FRAME_DATA] = frame_data, -[FRAME_RST] = frame_rst, -[FRAME_SETTINGS] = frame_settings, -[FRAME_PING] = frame_ping, -[FRAME_GOAWAY] = frame_goaway, -[FRAME_WINUPDATE] = frame_winupdate, + [FRAME_HEADERS] = frame_header_client, + [FRAME_DATA] = frame_data, + [FRAME_RST] = frame_rst, + [FRAME_SETTINGS] = frame_settings, + [FRAME_PING] = frame_ping, + [FRAME_GOAWAY] = frame_goaway, + [FRAME_WINUPDATE] = frame_winupdate, } local frame_server = { -[FRAME_HEADERS] = frame_header_server, -[FRAME_DATA] = frame_data, -[FRAME_RST] = frame_rst, -[FRAME_SETTINGS] = frame_settings, -[FRAME_PING] = frame_ping, -[FRAME_GOAWAY] = frame_goaway, -[FRAME_WINUPDATE] = frame_winupdate, + [FRAME_HEADERS] = frame_header_server, + [FRAME_DATA] = frame_data, + [FRAME_RST] = frame_rst, + [FRAME_SETTINGS] = frame_settings, + [FRAME_PING] = frame_ping, + [FRAME_GOAWAY] = frame_goaway, + [FRAME_WINUPDATE] = frame_winupdate, } local function common_dispatch(ch, frame_process) @@ -516,9 +516,8 @@ end function M.listen(conf) return tls_listen { disp = httpd(conf.handler), + certs = conf.tls_certs, port = conf.tls_port, - key = conf.tls_key, - cert = conf.tls_cert, alpn = "h2" } end @@ -559,7 +558,6 @@ function M.ack(s, status, header, endstream) return tls_write(ch.fd, dat) end - function M.write(s, dat, continue) local ch = s.channel if not ch then diff --git a/lualib/sys/tls.lua b/lualib/sys/tls.lua index 4dce65a..a1ca8a1 100644 --- a/lualib/sys/tls.lua +++ b/lualib/sys/tls.lua @@ -88,7 +88,7 @@ function EVENT.data(fd, message) return end local delim = s.delim - tls.message(s.ssl, message) + tls.message(s.ssl, message) if not delim then --non suspend read return end @@ -150,6 +150,7 @@ end function M.listen(conf) assert(conf.port) assert(conf.disp) + assert(#conf.certs > 0) local portid = core.tcp_listen(conf.port, socket_dispatch, conf.backlog) if not portid then return nil @@ -157,7 +158,7 @@ function M.listen(conf) tls = require "sys.tls.tls" ctx = ctx or require "sys.tls.ctx" local mode = alpn_mode[conf.alpn] - local c = ctx.server(conf.cert, conf.key, conf.ciphers, mode) + local c = ctx.server(conf.certs, conf.ciphers, mode) local s = new_socket(portid, c, nil, nil) s.ctx = c s.disp = conf.disp diff --git a/test/testhttp2.lua b/test/testhttp2.lua index 13b468a..c2fd065 100644 --- a/test/testhttp2.lua +++ b/test/testhttp2.lua @@ -14,8 +14,12 @@ return function() end http2.listen { tls_port = ":8081", - tls_cert= "test/cert.pem", - tls_key = "test/key.pem", + tls_certs = { + { + cert = "test/cert.pem", + cert_key = "test/key.pem", + } + }, handler = function(stream) core.sleep(math.random(1, 300)) local header = stream:read() diff --git a/test/testtcp.lua b/test/testtcp.lua index 9b58833..6c411fa 100644 --- a/test/testtcp.lua +++ b/test/testtcp.lua @@ -18,8 +18,12 @@ end) local tlsfd = tls.listen { port = ":10002", - cert = "test/cert.pem", - key = "test/key.pem", + certs = { + { + cert= "test/cert.pem", + cert_key = "test/key.pem", + }, + }, disp = function(fd, addr) if listen_cb then listen_cb(fd) diff --git a/test/testwebsocket.lua b/test/testwebsocket.lua index 4a0375e..c83b761 100644 --- a/test/testwebsocket.lua +++ b/test/testwebsocket.lua @@ -5,8 +5,12 @@ local listen_cb websocket.listen { port = ":10003", tls_port = ":10004", - tls_cert= "test/cert.pem", - tls_key = "test/key.pem", + tls_certs = { + { + cert = "test/cert.pem", + cert_key = "test/key.pem", + } + }, handler = function(sock) local dat, typ = sock:read() testaux.asserteq(typ, "ping", "server read type `ping`")