Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ALPN support. #120

Merged
merged 20 commits into from
Apr 14, 2015
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 155 additions & 1 deletion OpenSSL/SSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from errno import errorcode

from six import text_type as _text_type
from six import binary_type as _binary_type
from six import integer_types as integer_types
from six import int2byte, indexbytes

Expand Down Expand Up @@ -318,6 +319,56 @@ def wrapper(ssl, out, outlen, in_, inlen, arg):
)


class _ALPNSelectHelper(_CallbackExceptionHelper):
"""
Wrap a callback such that it can be used as an ALPN selection callback.
"""
def __init__(self, callback):
_CallbackExceptionHelper.__init__(self)

@wraps(callback)
def wrapper(ssl, out, outlen, in_, inlen, arg):
try:
conn = Connection._reverse_mapping[ssl]

# The string passed to us is made up of multiple
# length-prefixed bytestrings. We need to split that into a
# list.
instr = _ffi.buffer(in_, inlen)[:]
protolist = []
while instr:
encoded_len = indexbytes(instr, 0)
proto = instr[1:encoded_len + 1]
protolist.append(proto)
instr = instr[encoded_len + 1:]

# Call the callback
outstr = callback(conn, protolist)

if not isinstance(outstr, _binary_type):
raise TypeError("ALPN callback must return a bytestring.")

# Save our callback arguments on the connection object to make
# sure that they don't get freed before OpenSSL can use them.
# Then, return them in the appropriate output parameters.
conn._alpn_select_callback_args = [
_ffi.new("unsigned char *", len(outstr)),
_ffi.new("unsigned char[]", outstr),
]
outlen[0] = conn._alpn_select_callback_args[0][0]
out[0] = conn._alpn_select_callback_args[1]
return 0
except Exception as e:
self._problems.append(e)
return 2 # SSL_TLSEXT_ERR_ALERT_FATAL

self.callback = _ffi.callback(
"int (*)(SSL *, unsigned char **, unsigned char *, "
"const unsigned char *, unsigned int, void *)",
wrapper
)


def _asFileDescriptor(obj):
fd = None
if not isinstance(obj, integer_types):
Expand Down Expand Up @@ -348,6 +399,22 @@ def SSLeay_version(type):



def _requires_alpn(func):
"""
Wraps any function that requires ALPN support in OpenSSL, ensuring that
NotImplementedError is raised if ALPN support is not present.
"""
@wraps(func)
def wrapper(*args, **kwargs):
if not _lib.Cryptography_HAS_ALPN:
raise NotImplementedError("ALPN not available.")

return func(*args, **kwargs)

return wrapper



class Session(object):
pass

Expand Down Expand Up @@ -409,6 +476,8 @@ def __init__(self, method):
self._npn_advertise_callback = None
self._npn_select_helper = None
self._npn_select_callback = None
self._alpn_select_helper = None
self._alpn_select_callback = None

# SSL_CTX_set_app_data(self->ctx, self);
# SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE |
Expand Down Expand Up @@ -923,7 +992,6 @@ def wrapper(ssl, alert, arg):
_lib.SSL_CTX_set_tlsext_servername_callback(
self._context, self._tlsext_servername_callback)


def set_npn_advertise_callback(self, callback):
"""
Specify a callback function that will be called when offering `Next
Expand Down Expand Up @@ -956,6 +1024,44 @@ def set_npn_select_callback(self, callback):
_lib.SSL_CTX_set_next_proto_select_cb(
self._context, self._npn_select_callback, _ffi.NULL)

@_requires_alpn
def set_alpn_protos(self, protos):
"""
Specify the clients ALPN protocol list.

These protocols are offered to the server during protocol negotiation.

:param protos: A list of the protocols to be offered to the server.
This list should be a Python list of bytestrings representing the
protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``.
"""
# Take the list of protocols and join them together, prefixing them
# with their lengths.
protostr = b''.join(
chain.from_iterable((int2byte(len(p)), p) for p in protos)
)

# Build a C string from the list. We don't need to save this off
# because OpenSSL immediately copies the data out.
input_str = _ffi.new("unsigned char[]", protostr)
input_str_len = _ffi.cast("unsigned", len(protostr))
_lib.SSL_CTX_set_alpn_protos(self._context, input_str, input_str_len)

@_requires_alpn
def set_alpn_select_callback(self, callback):
"""

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First line is too long here - its a nit but perhaps

"""Set the callback to handle ALPN protocol choice.

:param ....

Set the callback to handle ALPN protocol choice.

:param callback: The callback function. It will be invoked with two
arguments: the Connection, and a list of offered protocols as
bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return
one of those bytestrings, the chosen protocol.
"""
self._alpn_select_helper = _ALPNSelectHelper(callback)
self._alpn_select_callback = self._alpn_select_helper.callback
_lib.SSL_CTX_set_alpn_select_cb(
self._context, self._alpn_select_callback, _ffi.NULL)

ContextType = Context


Expand Down Expand Up @@ -987,6 +1093,12 @@ def __init__(self, context, socket=None):
self._npn_advertise_callback_args = None
self._npn_select_callback_args = None

# References to strings used for Application Layer Protocol
# Negotiation. These strings get copied at some point but it's well
# after the callback returns, so we have to hang them somewhere to
# avoid them getting freed.
self._alpn_select_callback_args = None

self._reverse_mapping[self._ssl] = self

if socket is None:
Expand Down Expand Up @@ -1025,6 +1137,8 @@ def _raise_ssl_error(self, ssl, result):
self._context._npn_advertise_helper.raise_if_problem()
if self._context._npn_select_helper is not None:
self._context._npn_select_helper.raise_if_problem()
if self._context._alpn_select_helper is not None:
self._context._alpn_select_helper.raise_if_problem()

error = _lib.SSL_get_error(ssl, result)
if error == _lib.SSL_ERROR_WANT_READ:
Expand Down Expand Up @@ -1757,6 +1871,46 @@ def get_next_proto_negotiated(self):

return _ffi.buffer(data[0], data_len[0])[:]

@_requires_alpn
def set_alpn_protos(self, protos):
"""

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another minor nit here. Perhaps:

"""Specify the clients ALPN protocol list.

These protocols are offered to the server during protocol negotiation.

:param...

Specify the client's ALPN protocol list.

These protocols are offered to the server during protocol negotiation.

:param protos: A list of the protocols to be offered to the server.
This list should be a Python list of bytestrings representing the
protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``.
"""
# Take the list of protocols and join them together, prefixing them
# with their lengths.
protostr = b''.join(
chain.from_iterable((int2byte(len(p)), p) for p in protos)
)

# Build a C string from the list. We don't need to save this off
# because OpenSSL immediately copies the data out.
input_str = _ffi.new("unsigned char[]", protostr)
input_str_len = _ffi.cast("unsigned", len(protostr))
_lib.SSL_set_alpn_protos(self._ssl, input_str, input_str_len)


def get_alpn_proto_negotiated(self):
"""Get the protocol that was negotiated by ALPN."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know PEP8 says to use this single-line style of docstring. Can you change this to put the triple quotes on their own lines, anyway? (I recognize what a jerk-like comment this is). (This message approved by Hynek, future king of pyOpenSSL)

if not _lib.Cryptography_HAS_ALPN:
raise NotImplementedError("ALPN not available")

data = _ffi.new("unsigned char **")
data_len = _ffi.new("unsigned int *")

_lib.SSL_get0_alpn_selected(self._ssl, data, data_len)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if none was selected? Does it return NULL, or an empty string? Ah, typical openssl. It returns len 0 and data is unspecified - so we can likely segfault with the current code.

I suggest a
if not data_len:
return b''
check.

if not data_len:
return b''

return _ffi.buffer(data[0], data_len[0])[:]



ConnectionType = Connection

Expand Down
Loading