From 71979112cb9f52a9237884bf4d3b60d714acb6dc Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Sun, 12 Apr 2015 09:11:49 -0400 Subject: [PATCH] Handle exceptions in ALPN select callback. --- OpenSSL/SSL.py | 90 ++++++++++++++++++++++++---------------- OpenSSL/test/test_ssl.py | 33 +++++++++++++++ 2 files changed, 88 insertions(+), 35 deletions(-) diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 74b5c0eb3..da9dbd6d6 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -317,6 +317,56 @@ def wrapper(ssl, out, outlen, in_, inlen, arg): ) +class _AlpnSelectHelper(_CallbackExceptionHelper): + """ + Wrap a callback such that it can be used as an ALPN selection callback. + """ + def __init__(self, callback): + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, out, outlen, in_, inlen, arg): + try: + conn = Connection._reverse_mapping[ssl] + + # The string passed to us is made up of multiple + # length-prefixed bytestrings. We need to split that into a + # list. + instr = _ffi.buffer(in_, inlen)[:] + protolist = [] + while instr: + l = indexbytes(instr, 0) + proto = instr[1:l+1] + protolist.append(proto) + instr = instr[l+1:] + + # Call the callback + outstr = callback(conn, protolist) + + if not isinstance(outstr, _binary_type): + raise TypeError("ALPN callback must return a bytestring.") + + # Save our callback arguments on the connection object to make + # sure that they don't get freed before OpenSSL can use them. + # Then, return them in the appropriate output parameters. + conn._alpn_select_callback_args = [ + _ffi.new("unsigned char *", len(outstr)), + _ffi.new("unsigned char[]", outstr), + ] + outlen[0] = conn._alpn_select_callback_args[0][0] + out[0] = conn._alpn_select_callback_args[1] + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)", + wrapper + ) + + def _asFileDescriptor(obj): fd = None if not isinstance(obj, integer_types): @@ -408,6 +458,7 @@ def __init__(self, method): self._npn_advertise_callback = None self._npn_select_helper = None self._npn_select_callback = None + self._alpn_select_helper = None self._alpn_select_callback = None # SSL_CTX_set_app_data(self->ctx, self); @@ -991,41 +1042,8 @@ def set_alpn_select_callback(self, callback): bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return one of those bytestrings, the chosen protocol. """ - @wraps(callback) - def wrapper(ssl, out, outlen, in_, inlen, arg): - conn = Connection._reverse_mapping[ssl] - - # The string passed to us is made up of multiple length-prefixed - # bytestrings. We need to split that into a list. - instr = _ffi.buffer(in_, inlen)[:] - protolist = [] - while instr: - l = indexbytes(instr, 0) - proto = instr[1:l+1] - protolist.append(proto) - instr = instr[l+1:] - - # Call the callback - outstr = callback(conn, protolist) - - if not isinstance(outstr, _binary_type): - raise TypeError("ALPN callback must return a bytestring.") - - # Save our callback arguments on the connection object to make sure - # that they don't get freed before OpenSSL can use them. Then, - # return them in the appropriate output parameters. - conn._alpn_select_callback_args = [ - _ffi.new("unsigned char *", len(outstr)), - _ffi.new("unsigned char[]", outstr), - ] - outlen[0] = conn._alpn_select_callback_args[0][0] - out[0] = conn._alpn_select_callback_args[1] - return 0 - - self._alpn_select_callback = _ffi.callback( - "int (*)(SSL *, unsigned char **, unsigned char *, " - "const unsigned char *, unsigned int, void *)", - wrapper) + self._alpn_select_helper = _AlpnSelectHelper(callback) + self._alpn_select_callback = self._alpn_select_helper.callback _lib.SSL_CTX_set_alpn_select_cb( self._context, self._alpn_select_callback, _ffi.NULL) @@ -1104,6 +1122,8 @@ def _raise_ssl_error(self, ssl, result): self._context._npn_advertise_helper.raise_if_problem() if self._context._npn_select_helper is not None: self._context._npn_select_helper.raise_if_problem() + if self._context._alpn_select_helper is not None: + self._context._alpn_select_helper.raise_if_problem() error = _lib.SSL_get_error(ssl, result) if error == _lib.SSL_ERROR_WANT_READ: diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 68bb67987..ad2cd6d09 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -1776,6 +1776,39 @@ def test_alpn_no_server(self): self.assertEqual(client.get_alpn_proto_negotiated(), b'') + def test_alpn_callback_exception(self): + """ + Test that we can handle exceptions in the ALPN select callback. + """ + select_args = [] + def select(conn, options): + select_args.append((conn, options)) + raise TypeError + + client_context = Context(TLSv1_METHOD) + client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + + server_context = Context(TLSv1_METHOD) + server_context.set_alpn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + # If the client doesn't return anything, the connection will fail. + self.assertRaises(TypeError, self._interactInMemory, server, client) + self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args) + + class SessionTests(TestCase): """