Skip to content

Commit

Permalink
Handle exceptions in ALPN select callback.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukasa committed Apr 12, 2015
1 parent 3e9f152 commit 7197911
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 35 deletions.
90 changes: 55 additions & 35 deletions OpenSSL/SSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,56 @@ def wrapper(ssl, out, outlen, in_, inlen, arg):
)


class _AlpnSelectHelper(_CallbackExceptionHelper):
"""
Wrap a callback such that it can be used as an ALPN selection callback.
"""
def __init__(self, callback):
_CallbackExceptionHelper.__init__(self)

@wraps(callback)
def wrapper(ssl, out, outlen, in_, inlen, arg):
try:
conn = Connection._reverse_mapping[ssl]

# The string passed to us is made up of multiple
# length-prefixed bytestrings. We need to split that into a
# list.
instr = _ffi.buffer(in_, inlen)[:]
protolist = []
while instr:
l = indexbytes(instr, 0)
proto = instr[1:l+1]
protolist.append(proto)
instr = instr[l+1:]

# Call the callback
outstr = callback(conn, protolist)

if not isinstance(outstr, _binary_type):
raise TypeError("ALPN callback must return a bytestring.")

# Save our callback arguments on the connection object to make
# sure that they don't get freed before OpenSSL can use them.
# Then, return them in the appropriate output parameters.
conn._alpn_select_callback_args = [
_ffi.new("unsigned char *", len(outstr)),
_ffi.new("unsigned char[]", outstr),
]
outlen[0] = conn._alpn_select_callback_args[0][0]
out[0] = conn._alpn_select_callback_args[1]
return 0
except Exception as e:
self._problems.append(e)
return 2 # SSL_TLSEXT_ERR_ALERT_FATAL

self.callback = _ffi.callback(
"int (*)(SSL *, unsigned char **, unsigned char *, "
"const unsigned char *, unsigned int, void *)",
wrapper
)


def _asFileDescriptor(obj):
fd = None
if not isinstance(obj, integer_types):
Expand Down Expand Up @@ -408,6 +458,7 @@ def __init__(self, method):
self._npn_advertise_callback = None
self._npn_select_helper = None
self._npn_select_callback = None
self._alpn_select_helper = None
self._alpn_select_callback = None

# SSL_CTX_set_app_data(self->ctx, self);
Expand Down Expand Up @@ -991,41 +1042,8 @@ def set_alpn_select_callback(self, callback):
bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return
one of those bytestrings, the chosen protocol.
"""
@wraps(callback)
def wrapper(ssl, out, outlen, in_, inlen, arg):
conn = Connection._reverse_mapping[ssl]

# The string passed to us is made up of multiple length-prefixed
# bytestrings. We need to split that into a list.
instr = _ffi.buffer(in_, inlen)[:]
protolist = []
while instr:
l = indexbytes(instr, 0)
proto = instr[1:l+1]
protolist.append(proto)
instr = instr[l+1:]

# Call the callback
outstr = callback(conn, protolist)

if not isinstance(outstr, _binary_type):
raise TypeError("ALPN callback must return a bytestring.")

# Save our callback arguments on the connection object to make sure
# that they don't get freed before OpenSSL can use them. Then,
# return them in the appropriate output parameters.
conn._alpn_select_callback_args = [
_ffi.new("unsigned char *", len(outstr)),
_ffi.new("unsigned char[]", outstr),
]
outlen[0] = conn._alpn_select_callback_args[0][0]
out[0] = conn._alpn_select_callback_args[1]
return 0

self._alpn_select_callback = _ffi.callback(
"int (*)(SSL *, unsigned char **, unsigned char *, "
"const unsigned char *, unsigned int, void *)",
wrapper)
self._alpn_select_helper = _AlpnSelectHelper(callback)
self._alpn_select_callback = self._alpn_select_helper.callback
_lib.SSL_CTX_set_alpn_select_cb(
self._context, self._alpn_select_callback, _ffi.NULL)

Expand Down Expand Up @@ -1104,6 +1122,8 @@ def _raise_ssl_error(self, ssl, result):
self._context._npn_advertise_helper.raise_if_problem()
if self._context._npn_select_helper is not None:
self._context._npn_select_helper.raise_if_problem()
if self._context._alpn_select_helper is not None:
self._context._alpn_select_helper.raise_if_problem()

error = _lib.SSL_get_error(ssl, result)
if error == _lib.SSL_ERROR_WANT_READ:
Expand Down
33 changes: 33 additions & 0 deletions OpenSSL/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,39 @@ def test_alpn_no_server(self):
self.assertEqual(client.get_alpn_proto_negotiated(), b'')


def test_alpn_callback_exception(self):
"""
Test that we can handle exceptions in the ALPN select callback.
"""
select_args = []
def select(conn, options):
select_args.append((conn, options))
raise TypeError

client_context = Context(TLSv1_METHOD)
client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])

server_context = Context(TLSv1_METHOD)
server_context.set_alpn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()

# If the client doesn't return anything, the connection will fail.
self.assertRaises(TypeError, self._interactInMemory, server, client)
self.assertEqual([(server, [b'http/1.1', b'spdy/2'])], select_args)



class SessionTests(TestCase):
"""
Expand Down

0 comments on commit 7197911

Please sign in to comment.