Skip to content

Commit

Permalink
Add ALPN support.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukasa committed Jun 7, 2014
1 parent 06ddbf3 commit 5763ae4
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 1 deletion.
112 changes: 111 additions & 1 deletion OpenSSL/SSL.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from sys import platform
from functools import wraps, partial
from itertools import count
from itertools import count, chain
from weakref import WeakValueDictionary
from errno import errorcode

from six import text_type as _text_type
from six import integer_types as integer_types
from six import int2byte, indexbytes

from OpenSSL._util import (
ffi as _ffi,
Expand Down Expand Up @@ -293,6 +294,7 @@ def __init__(self, method):
self._info_callback = None
self._tlsext_servername_callback = None
self._app_data = None
self._alpn_select_callback = None

# SSL_CTX_set_app_data(self->ctx, self);
# SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE |
Expand Down Expand Up @@ -809,6 +811,73 @@ def wrapper(ssl, alert, arg):
_lib.SSL_CTX_set_tlsext_servername_callback(
self._context, self._tlsext_servername_callback)

def set_alpn_protos(self, protos):
"""
Specify the list of protocols that will get offered to the server for
ALPN negotiation.
:param protos: A list of the protocols to be offered to the server.
This list should be a Python list of bytestrings representing the
protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``.
"""
# Take the list of protocols and join them together, prefixing them
# with their lengths.
protostr = b''.join(
chain.from_iterable((int2byte(len(p)), p) for p in protos)
)

# Build a C string from the list. We don't need to save this off
# because OpenSSL immediately copies the data out.
input_str = _ffi.new("unsigned char[]", protostr)
input_str_len = _ffi.new("unsigned", len(protostr))
_lib.SSL_CTX_set_alpn_protos(self._context, input_str)
return

def set_alpn_select_callback(self, callback):
"""
Specify a callback that will be called when the client offers ALPN
protocols.
:param callback: The callback function. It will be invoked with two
arguments: the Connection, and a list of offered protocols as
bytestrings, e.g ``[b'http/1.1', b'spdy/2']``. It should return
one of those bytestrings, then chosen protocol.
"""
@wraps(callback)
def wrapper(ssl, out, outlen, in_, inlen, arg):
conn = Connection._reverse_mapping[ssl]

# The string passed to us is made up of multiple length-prefixed
# bytestrings. We need to split that into a list.
instr = _ffi.buffer(in_, inlen)[:]
protolist = []
while instr:
l = indexbytes(instr, 0)
proto = instr[1:l+1]
protolist.append(proto)
instr = instr[l+1:]

# Call the callback
outstr = callback(conn, protolist)

# Save our callback arguments on the connection object to make sure
# that they don't get freed before OpenSSL can use them. Then,
# return them in the appropriate output parameters.
conn._alpn_select_callback_args = [
_ffi.new("unsigned char *", len(outstr)),
_ffi.new("unsigned char[]", outstr),
]
outlen[0] = conn._alpn_select_callback_args[0][0]
out[0] = conn._alpn_select_callback_args[1]
return 0

self._alpn_select_callback = _ffi.callback(
"int (*)(SSL *, unsigned char **, unsigned char *, "
"const unsigned char *, unsigned int, void *",
wrapper)
_lib.SSL_CTX_set_alpn_select_cb(
self._context, self._alpn_select_callback, _ffi.NULL)

ContextType = Context


Expand All @@ -833,6 +902,12 @@ def __init__(self, context, socket=None):
self._ssl = _ffi.gc(ssl, _lib.SSL_free)
self._context = context

# References to strings used for Application Layer Protocol
# Negotiation. These strings get copied at some point but it's well
# after the callback returns, so we have to hang them somewhere to
# avoid them getting freed.
self._alpn_select_callback_args = None

self._reverse_mapping[self._ssl] = self

if socket is None:
Expand Down Expand Up @@ -1551,6 +1626,41 @@ def get_cipher_version(self):
return version.decode("utf-8")


def set_alpn_protos(self, protos):
"""
Specify the list of protocols that will get offered to the server for
ALPN negotiation.
:param protos: A list of the protocols to be offered to the server.
This list should be a Python list of bytestrings representing the
protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``.
"""
# Take the list of protocols and join them together, prefixing them
# with their lengths.
protostr = b''.join(
chain.from_iterable((int2byte(len(p)), p) for p in protos)
)

# Build a C string from the list. We don't need to save this off
# because OpenSSL immediately copies the data out.
input_str = _ffi.new("unsigned char[]", protostr)
input_str_len = _ffi.new("unsigned", len(protostr))
_lib.SSL_set_alpn_protos(self._ssl, input_str)
return


def get_alpn_proto_negotiated(self):
"""
Get the protocol that was negotiated by ALPN.
"""
data = _ffi.new("unsigned char **")
data_len = _ffi.new("unsigned int *")

_lib.SSL_get0_alpn_selected(self._ssl, data, data_len)

return _ffi.buffer(data[0], data_len[0])[:]



ConnectionType = Connection

Expand Down
117 changes: 117 additions & 0 deletions OpenSSL/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,123 @@ def servername(conn):



class ApplicationLayerProtoNegotiationTests(TestCase, _LoopbackMixin):
"""
Tests for ALPN in PyOpenSSL.
"""
def test_alpn_success(self):
"""
Tests that clients and servers that agree on the negotiated ALPN
protocol can correct establish a connection, and that the agreed
protocol is reported by the connections.
"""
select_args = []
def select(conn, options):
select_args.append((conn, options))
return b'spdy/2'

client_context = Context(TLSv1_METHOD)
client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])

server_context = Context(TLSv1_METHOD)
server_context.set_alpn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()

self._interactInMemory(server, client)

self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)

self.assertEqual(server.get_alpn_proto_negotiated(), b'spdy/2')
self.assertEqual(client.get_alpn_proto_negotiated(), b'spdy/2')


def test_alpn_set_on_connection(self):
"""
The same as test_alpn_success, but setting the ALPN protocols on the
connection rather than the context.
"""
select_args = []
def select(conn, options):
select_args.append((conn, options))
return b'spdy/2'

# Setup the client context but don't set any ALPN protocols.
client_context = Context(TLSv1_METHOD)

server_context = Context(TLSv1_METHOD)
server_context.set_alpn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

# Set the ALPN protocols on the client connection.
client = Connection(client_context, None)
client.set_alpn_protos([b'http/1.1', b'spdy/2'])
client.set_connect_state()

self._interactInMemory(server, client)

self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)

self.assertEqual(server.get_alpn_proto_negotiated(), b'spdy/2')
self.assertEqual(client.get_alpn_proto_negotiated(), b'spdy/2')


def test_alpn_server_fail(self):
"""
Tests that when clients and servers cannot agree on what protocol to
use next that the TLS connection does not get established.
"""
select_args = []
def select(conn, options):
select_args.append((conn, options))
return b''

client_context = Context(TLSv1_METHOD)
client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])

server_context = Context(TLSv1_METHOD)
server_context.set_alpn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()

# If the client doesn't return anything, the connection will fail.
self.assertRaises(Error, self._interactInMemory, server, client)

self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)



class SessionTests(TestCase):
"""
Unit tests for :py:obj:`OpenSSL.SSL.Session`.
Expand Down
38 changes: 38 additions & 0 deletions doc/api/ssl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,27 @@ Context objects have the following methods:
.. versionadded:: 0.13


.. py:method:: Context.set_alpn_protos(protos)
Specify the protocols that the client is prepared to speak after the TLS
connection has been negotiated, using Application Layer Protocol
Negotiation.

*protos* should be a list of protocols that the client is offering, each
as a bytestring. For example, ``[b'http/1.1', b'spdy/2']``.


.. py:method:: Context.set_alpn_select_callback(callback)
Specify a callback function that will be called on the server when a client
offers protocols using Application Layer Protocol Negotiation.

*callback* should be the callback function. It will be invoked with two
arguments: the :py:class:`Connection`, and a list of offered protocols as
bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return one of
these bytestrings, the chosen protocol.


.. _openssl-session:

Session objects
Expand Down Expand Up @@ -806,6 +827,23 @@ Connection objects have the following methods:
.. versionadded:: 0.15


.. py:method:: Connection.set_alpn_protos(protos)
Specify the protocols that the client is prepared to speak after the TLS
connection has been negotiated, using Application Layer Protocol
Negotiation.

*protos* should be a list of protocols that the client is offering, each
as a bytestring. For example, ``[b'http/1.1', b'spdy/2']``.


.. py:method:: Connection.get_alpn_proto_negotiated()
Get the protocol that was negotiated by Application Layer Protocol
Negotiation. Returns a bytestring of the protocol name. If no protocol has
been negotiated yet, returns an empty string.


.. Rubric:: Footnotes

.. [#connection-context-socket] Actually, all that is required is an object that
Expand Down

0 comments on commit 5763ae4

Please sign in to comment.