diff --git a/.coveragerc b/.coveragerc index f20b132b..f5fa49a8 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,3 +2,4 @@ omit = hyper/compat.py hyper/httplib_compat.py + hyper/ssl_compat.py diff --git a/.gitignore b/.gitignore index e113636e..fe5fcae5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +build/ env/ dist/ *.egg-info/ diff --git a/hyper/compat.py b/hyper/compat.py index ed48a643..e89bf524 100644 --- a/hyper/compat.py +++ b/hyper/compat.py @@ -5,19 +5,30 @@ Normalizes the Python 2/3 API for internal use. """ +from contextlib import contextmanager import sys import zlib -# Syntax sugar. -_ver = sys.version_info +try: + from . import ssl_compat +except ImportError: + # TODO log? + ssl_compat = None -#: Python 2.x? -is_py2 = (_ver[0] == 2) +_ver = sys.version_info +is_py2 = _ver[0] == 2 +is_py3 = _ver[0] == 3 +is_py3_3 = is_py3 and _ver[1] == 3 -#: Python 3.x? -is_py3 = (_ver[0] == 3) +@contextmanager +def ignore_missing(): + try: + yield + except (AttributeError, NotImplementedError): # pragma: no cover + pass if is_py2: + ssl = ssl_compat from urlparse import urlparse def to_byte(char): @@ -41,3 +52,8 @@ def decode_hex(b): return bytes.fromhex(b) zlib_compressobj = zlib.compressobj + + if is_py3_3: + ssl = ssl_compat + else: + import ssl diff --git a/hyper/http20/frame.py b/hyper/http20/frame.py index afe41d12..5b1723d9 100644 --- a/hyper/http20/frame.py +++ b/hyper/http20/frame.py @@ -12,7 +12,6 @@ # The maximum length of a frame. Some frames have shorter maximum lengths. FRAME_MAX_LEN = (2 ** 14) - 1 - class Frame(object): """ The base class for all HTTP/2.0 frames. @@ -82,7 +81,8 @@ class DataFrame(Frame): associated with a stream. One or more DATA frames are used, for instance, to carry HTTP request or response payloads. """ - defined_flags = [('END_STREAM', 0x01)] + defined_flags = [('END_STREAM', 0x01), ('END_SEGMENT', 0x02), + ('PAD_LOW', 0x10), ('PAD_HIGH', 0x20)] type = 0 @@ -90,18 +90,40 @@ def __init__(self, stream_id): super(DataFrame, self).__init__(stream_id) self.data = b'' + self.low_padding = 0 + self.high_padding = 0 # Data frames may not be stream 0. if not self.stream_id: raise ValueError() def serialize(self): - data = self.build_frame_header(len(self.data)) - data += self.data - return data + padding_data = b'' + if 'PAD_LOW' in self.flags: + if 'PAD_HIGH' in self.flags: + padding_data = struct.pack('!BB', self.high_padding, self.low_padding) + else: + padding_data = struct.pack('!B', self.low_padding) + padding = b'\0' * self.total_padding + body = b''.join([padding_data, self.data, padding]) + header = self.build_frame_header(len(body)) + return header + body def parse_body(self, data): - self.data = data + padding_data_length = 0 + if 'PAD_LOW' in self.flags: + if 'PAD_HIGH' in self.flags: + self.high_padding, self.low_padding = struct.unpack('!BB', data[:2]) + padding_data_length = 2 + else: + self.low_padding = struct.unpack('!B', data[:1])[0] + padding_data_length = 1 + self.data = data[padding_data_length:len(data)-self.total_padding] + + @property + def total_padding(self): + """Return the total length of the padding, if any.""" + return (self.high_padding << 8) + self.low_padding class PriorityFrame(Frame): @@ -187,9 +209,8 @@ class SettingsFrame(Frame): # attributes. HEADER_TABLE_SIZE = 0x01 ENABLE_PUSH = 0x02 - MAX_CONCURRENT_STREAMS = 0x04 - INITIAL_WINDOW_SIZE = 0x07 - FLOW_CONTROL_OPTIONS = 0x0A + MAX_CONCURRENT_STREAMS = 0x03 + INITIAL_WINDOW_SIZE = 0x04 def __init__(self, stream_id): super(SettingsFrame, self).__init__(stream_id) @@ -201,19 +222,19 @@ def __init__(self, stream_id): raise ValueError() def serialize(self): - # Each setting consumes 8 bytes. - length = len(self.settings) * 8 + # Each setting consumes 5 bytes. + length = len(self.settings) * 5 data = self.build_frame_header(length) for setting, value in self.settings.items(): - data += struct.pack("!LL", setting & 0x00FFFFFF, value) + data += struct.pack("!BL", setting & 0xFF, value) return data def parse_body(self, data): - for i in range(0, len(data), 8): - name, value = struct.unpack("!LL", data[i:i+8]) + for i in range(0, len(data), 5): + name, value = struct.unpack("!BL", data[i:i+5]) self.settings[name] = value @@ -315,7 +336,7 @@ class WindowUpdateFrame(Frame): can indirectly cause the propagation of flow control information toward the original sender. """ - type = 0x09 + type = 0x08 def __init__(self, stream_id): super(WindowUpdateFrame, self).__init__(stream_id) @@ -348,8 +369,11 @@ class HeadersFrame(DataFrame): defined_flags = [ ('END_STREAM', 0x01), + ('END_SEGMENT', 0x02), ('END_HEADERS', 0x04), - ('PRIORITY', 0x08) + ('PRIORITY', 0x08), + ('PAD_LOW', 0x10), + ('PAD_HIGH', 0x20), ] def __init__(self, stream_id): @@ -386,21 +410,22 @@ class ContinuationFrame(DataFrame): Much like the HEADERS frame, hyper treats this as an opaque data frame with different flags and a different type. """ - type = 0x0A + type = 0x09 - defined_flags = [('END_HEADERS', 0x04)] + defined_flags = [('END_HEADERS', 0x04), ('PAD_LOW', 0x10), ('PAD_HIGH', 0x20)] # A map of type byte to frame class. -FRAMES = { - 0x00: DataFrame, - 0x01: HeadersFrame, - 0x02: PriorityFrame, - 0x03: RstStreamFrame, - 0x04: SettingsFrame, - 0x05: PushPromiseFrame, - 0x06: PingFrame, - 0x07: GoAwayFrame, - 0x09: WindowUpdateFrame, - 0x0A: ContinuationFrame -} +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame, +] +FRAMES = {cls.type: cls for cls in _FRAME_CLASSES} diff --git a/hyper/http20/hpack.py b/hyper/http20/hpack.py index 02903979..b6ab823e 100644 --- a/hyper/http20/hpack.py +++ b/hyper/http20/hpack.py @@ -637,7 +637,7 @@ def _decode_indexed(self, data): # set. Otherwise, decode it as an integer with a 7-bit prefix: that's # our new header table max size. if not index: - next_byte = data[consumed] + next_byte = to_byte(data[consumed]) if next_byte & 0x80: self.reference_set = set() diff --git a/hyper/http20/stream.py b/hyper/http20/stream.py index b0705c98..cddb9a00 100644 --- a/hyper/http20/stream.py +++ b/hyper/http20/stream.py @@ -140,6 +140,11 @@ def listlen(list): # Append the data to the buffer. data.append(frame.data) + # Increase the window size. Only do this if the data frame contains + # actual data. + size = len(frame.data) + frame.total_padding + increment = self._in_window_manager._handle_frame(size) + # If that was the last frame, we're done here. if 'END_STREAM' in frame.flags: self.state = ( @@ -148,9 +153,6 @@ def listlen(list): ) break - # Increase the window size. Only do this if the data frame contains - # actual data. - increment = self._in_window_manager._handle_frame(len(frame.data)) if increment: w = WindowUpdateFrame(self.stream_id) w.window_increment = increment diff --git a/hyper/http20/tls.py b/hyper/http20/tls.py index c84465a2..1019d461 100644 --- a/hyper/http20/tls.py +++ b/hyper/http20/tls.py @@ -5,75 +5,57 @@ Contains the TLS/SSL logic for use in hyper. """ -import ssl import os.path as path -from ..compat import is_py3 +from ..compat import ignore_missing, ssl -# Right now we support draft 9. -SUPPORTED_PROTOCOLS = ['http/1.1', 'HTTP-draft-09/2.0'] +NPN_PROTOCOL = 'h2-10' +SUPPORTED_NPN_PROTOCOLS = ['http/1.1', NPN_PROTOCOL] # We have a singleton SSLContext object. There's no reason to be creating one -# per connection. We're using v23 right now until someone gives me a reason not -# to. +# per connection. _context = None -# Exposed here so it can be monkey-patched in integration tests. -_verify_mode = ssl.CERT_REQUIRED - - # Work out where our certificates are. cert_loc = path.join(path.dirname(__file__), '..', 'certs.pem') +def wrap_socket(sock, server_hostname): + """ + A vastly simplified SSL wrapping function. We'll probably extend this to + do more things later. + """ + global _context -if is_py3: # pragma: no cover - def wrap_socket(socket, server_hostname): - """ - A vastly simplified SSL wrapping function. We'll probably extend this to - do more things later. - """ - global _context - - if _context is None: # pragma: no cover - _context = _init_context() - - if ssl.HAS_SNI: - return _context.wrap_socket(socket, server_hostname=server_hostname) + if _context is None: # pragma: no cover + _context = _init_context() - wrapped = _context.wrap_socket(socket) # pragma: no cover - assert wrapped.selected_npn_protocol() == 'HTTP-draft-09/2.0' - return wrapped -else: # pragma: no cover - def wrap_socket(socket, server_hostname): - return ssl.wrap_socket(socket, ssl_version=ssl.PROTOCOL_SSLv23, - ca_certs=cert_loc, cert_reqs=_verify_mode) + # the spec requires SNI support + ssl_sock = _context.wrap_socket(sock, server_hostname=server_hostname) + # Setting SSLContext.check_hostname to True only verifies that the + # post-handshake servername matches that of the certificate. We also need to + # check that it matches the requested one. + ssl.match_hostname(ssl_sock.getpeercert(), server_hostname) + with ignore_missing(): + assert ssl_sock.selected_npn_protocol() == NPN_PROTOCOL + return ssl_sock -def _init_context(): # pragma: no cover +def _init_context(): """ Creates the singleton SSLContext we use. """ - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) context.set_default_verify_paths() context.load_verify_locations(cafile=cert_loc) - context.verify_mode = _verify_mode - - try: - context.set_npn_protocols(SUPPORTED_PROTOCOLS) - except (AttributeError, NotImplementedError): # pragma: no cover - pass + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True - # We do our best to do better security - try: - context.options |= ssl.OP_NO_SSLv2 - except AttributeError: # pragma: no cover - pass + with ignore_missing(): + context.set_npn_protocols(SUPPORTED_NPN_PROTOCOLS) - try: - context.options |= ssl.OP_NO_COMPRESSION - except AttributeError: # pragma: no cover - pass + # required by the spec + context.options |= ssl.OP_NO_COMPRESSION return context diff --git a/hyper/httplib_compat.py b/hyper/httplib_compat.py index 6b551a76..315f24bb 100644 --- a/hyper/httplib_compat.py +++ b/hyper/httplib_compat.py @@ -15,7 +15,7 @@ except ImportError: import httplib -import ssl +from .compat import ssl from .http20.tls import wrap_socket # If there's no NPN support, we're going to drop all support for HTTP/2.0. diff --git a/hyper/ssl_compat.py b/hyper/ssl_compat.py new file mode 100644 index 00000000..e12bf1a5 --- /dev/null +++ b/hyper/ssl_compat.py @@ -0,0 +1,314 @@ +# -*- coding: utf-8 -*- +""" +hyper/ssl_compat +~~~~~~~~~ + +Shoves pyOpenSSL into an API that looks like the standard Python 3.x ssl module. + +Currently exposes exactly those attributes, classes, and methods that we +actually use in hyper (all method signatures are complete, however). May be +expanded to something more general-purpose in the future. +""" +try: + import StringIO as BytesIO +except ImportError: + from io import BytesIO +import errno +import socket +import time + +from OpenSSL import SSL as ossl + +CERT_NONE = ossl.VERIFY_NONE +CERT_REQUIRED = ossl.VERIFY_PEER | ossl.VERIFY_FAIL_IF_NO_PEER_CERT + +_OPENSSL_ATTRS = dict( + OP_NO_COMPRESSION='OP_NO_COMPRESSION', + PROTOCOL_TLSv1_2='TLSv1_2_METHOD', +) + +for external, internal in _OPENSSL_ATTRS.items(): + value = getattr(ossl, internal, None) + if value: + locals()[external] = value + +OP_ALL = 0 +for bit in [31] + list(range(10)): # TODO figure out the names of these other flags + OP_ALL |= 1 << bit + +HAS_NPN = False # TODO + +def _proxy(method): + return lambda self, *args, **kwargs: getattr(self._conn, method)(*args, **kwargs) + +# TODO missing some attributes +class SSLError(OSError): + pass + +class CertificateError(SSLError): + pass + + +# lifted from the Python 3.4 stdlib +def _dnsname_match(dn, hostname, max_wildcards=1): + """ + Matching according to RFC 6125, section 6.4.3. + + See http://tools.ietf.org/html/rfc6125#section-6.4.3 + """ + pats = [] + if not dn: + return False + + parts = dn.split(r'.') + leftmost = parts[0] + remainder = parts[1:] + + wildcards = leftmost.count('*') + if wildcards > max_wildcards: + # Issue #17980: avoid denials of service by refusing more + # than one wildcard per fragment. A survery of established + # policy among SSL implementations showed it to be a + # reasonable choice. + raise CertificateError( + "too many wildcards in certificate DNS name: " + repr(dn)) + + # speed up common case w/o wildcards + if not wildcards: + return dn.lower() == hostname.lower() + + # RFC 6125, section 6.4.3, subitem 1. + # The client SHOULD NOT attempt to match a presented identifier in which + # the wildcard character comprises a label other than the left-most label. + if leftmost == '*': + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append('[^.]+') + elif leftmost.startswith('xn--') or hostname.startswith('xn--'): + # RFC 6125, section 6.4.3, subitem 3. + # The client SHOULD NOT attempt to match a presented identifier + # where the wildcard character is embedded within an A-label or + # U-label of an internationalized domain name. + pats.append(re.escape(leftmost)) + else: + # Otherwise, '*' matches any dotless string, e.g. www* + pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) + + # add the remaining fragments, ignore any wildcards + for frag in remainder: + pats.append(re.escape(frag)) + + pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + return pat.match(hostname) + + +# lifted from the Python 3.4 stdlib +def match_hostname(cert, hostname): + """ + Verify that ``cert`` (in decoded format as returned by + ``SSLSocket.getpeercert())`` matches the ``hostname``. RFC 2818 and RFC + 6125 rules are followed, but IP addresses are not accepted for ``hostname``. + + ``CertificateError`` is raised on failure. On success, the function returns + nothing. + """ + if not cert: + raise ValueError("empty or no certificate, match_hostname needs a " + "SSL socket or SSL context with either " + "CERT_OPTIONAL or CERT_REQUIRED") + dnsnames = [] + san = cert.get('subjectAltName', ()) + for key, value in san: + if key == 'DNS': + if _dnsname_match(value, hostname): + return + dnsnames.append(value) + if not dnsnames: + # The subject is only checked when there is no dNSName entry + # in subjectAltName + for sub in cert.get('subject', ()): + for key, value in sub: + # XXX according to RFC 2818, the most specific Common Name + # must be used. + if key == 'commonName': + if _dnsname_match(value, hostname): + return + dnsnames.append(value) + if len(dnsnames) > 1: + raise CertificateError("hostname %r " + "doesn't match either of %s" + % (hostname, ', '.join(map(repr, dnsnames)))) + elif len(dnsnames) == 1: + raise CertificateError("hostname %r " + "doesn't match %r" + % (hostname, dnsnames[0])) + else: + raise CertificateError("no appropriate commonName or " + "subjectAltName fields were found") + + +class SSLSocket(object): + SSL_TIMEOUT = 3 + SSL_RETRY = .01 + + def __init__(self, conn, server_side, do_handshake_on_connect, + suppress_ragged_eofs, server_hostname, check_hostname): + self._conn = conn + self._do_handshake_on_connect = do_handshake_on_connect + self._suppress_ragged_eofs = suppress_ragged_eofs + self._check_hostname = check_hostname + + if server_side: + self._conn.set_accept_state() + else: + if server_hostname: + self._conn.set_tlsext_host_name(server_hostname.encode('utf-8')) + self._conn.set_connect_state() # FIXME does this override do_handshake_on_connect=False? + + if self.connected and self._do_handshake_on_connect: + self.do_handshake() + + @property + def connected(self): + try: + self._conn.getpeername() + except socket.error as e: + if e.errno != errno.ENOTCONN: + # It's an exception other than the one we expected if we're not + # connected. + raise + return False + return True + + # Lovingly stolen from CherryPy (http://svn.cherrypy.org/tags/cherrypy-3.2.1/cherrypy/wsgiserver/ssl_pyopenssl.py). + def _safe_ssl_call(self, suppress_ragged_eofs, call, *args, **kwargs): + """Wrap the given call with SSL error-trapping.""" + start = time.time() + while True: + try: + return call(*args, **kwargs) + except (ossl.WantReadError, ossl.WantWriteError): + # Sleep and try again. This is dangerous, because it means + # the rest of the stack has no way of differentiating + # between a "new handshake" error and "client dropped". + # Note this isn't an endless loop: there's a timeout below. + time.sleep(self.SSL_RETRY) + except ossl.Error as e: + if suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'): + return b'' + raise socket.error(e.args[0]) + + if time.time() - start > self.SSL_TIMEOUT: + raise socket.timeout('timed out') + + def connect(self, address): + self._conn.connect(address) + if self._do_handshake_on_connect: + self.do_handshake() + + def do_handshake(self): + self._safe_ssl_call(False, self._conn.do_handshake) + if self._check_hostname: + match_hostname(self.getpeercert(), self._conn.get_servername().decode('utf-8')) + + def recv(self, bufsize, flags=None): + return self._safe_ssl_call(self._suppress_ragged_eofs, self._conn.recv, + bufsize, flags) + + def send(self, data, flags=None): + return self._safe_ssl_call(False, self._conn.send, data, flags) + + def selected_npn_protocol(self): + raise NotImplementedError() + + def getpeercert(self): + def resolve_alias(alias): + return dict( + C='countryName', + ST='stateOrProvinceName', + L='localityName', + O='organizationName', + OU='organizationalUnitName', + CN='commonName', + ).get(alias, alias) + + def to_components(name): + # TODO Verify that these are actually *supposed* to all be single-element + # tuples, and that's not just a quirk of the examples I've seen. + return tuple([((resolve_alias(name.decode('utf-8')), value.decode('utf-8')),) for name, value in name.get_components()]) + + # The standard getpeercert() takes the nice X509 object tree returned + # by OpenSSL and turns it into a dict according to some format it seems + # to have made up on the spot. Here, we do our best to emulate that. + cert = self._conn.get_peer_certificate() + result = dict( + issuer=to_components(cert.get_issuer()), + subject=to_components(cert.get_subject()), + version=cert.get_subject(), + serialNumber=cert.get_serial_number(), + notBefore=cert.get_notBefore(), + notAfter=cert.get_notAfter(), + ) + # TODO extensions, including subjectAltName (see _decode_certificate in _ssl.c) + return result + + # a dash of magic to reduce boilerplate + for method in ['accept', 'bind', 'close', 'getsockname', 'listen']: + locals()[method] = _proxy(method) + + +class SSLContext(object): + def __init__(self, protocol): + self.protocol = protocol + self._ctx = ossl.Context(protocol) + self.options = OP_ALL + self.check_hostname = False + + @property + def options(self): + return self._options + + @options.setter + def options(self, value): + self._options = value + self._ctx.set_options(value) + + @property + def verify_mode(self): + return self._ctx.get_verify_mode() + + @verify_mode.setter + def verify_mode(self, value): + # TODO verify exception is raised on failure + self._ctx.set_verify(value, lambda conn, cert, errnum, errdepth, ok: ok) + + def set_default_verify_paths(self): + self._ctx.set_default_verify_paths() + + def load_verify_locations(self, cafile=None, capath=None, cadata=None): + # TODO factor out common code + if cafile is not None: + cafile = cafile.encode('utf-8') + if capath is not None: + capath = capath.encode('utf-8') + self._ctx.load_verify_locations(cafile, capath) + if cadata is not None: + self._ctx.load_verify_locations(BytesIO(cadata)) + + def load_cert_chain(self, certfile, keyfile=None, password=None): + self._ctx.use_certificate_file(certfile) + if password is not None: + self._ctx.set_password_cb(lambda max_length, prompt_twice, userdata: password) + self._ctx.use_privatekey_file(keyfile or certfile) + + def set_npn_protocols(self, protocols): + # TODO + raise NotImplementedError() + + def wrap_socket(self, sock, server_side=False, do_handshake_on_connect=True, + suppress_ragged_eofs=True, server_hostname=None): + conn = ossl.Connection(self._ctx, sock) + return SSLSocket(conn, server_side, do_handshake_on_connect, + suppress_ragged_eofs, server_hostname, + # TODO what if this is changed after the fact? + self.check_hostname) diff --git a/setup.py b/setup.py index b0cb07aa..8025b4bc 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,8 @@ #!/usr/bin/env python - +# -*- coding: utf-8 -*- +import itertools import os +import re import sys try: @@ -9,7 +11,6 @@ from distutils.core import setup # Get the version -import re version_regex = r'__version__ = ["\']([^"\']*)["\']' with open('hyper/__init__.py', 'r') as f: text = f.read() @@ -25,6 +26,13 @@ os.system('python setup.py sdist upload') sys.exit() +py_version = sys.version_info[:2] + +def resolve_install_requires(): + if py_version in [(2,7), (3,3)]: + return ['pyOpenSSL>=0.14'] + return [] + packages = ['hyper', 'hyper.http20'] setup( @@ -47,5 +55,6 @@ 'Programming Language :: Python', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.3', - ] + ], + install_requires=resolve_install_requires(), ) diff --git a/test/server.py b/test/server.py index 7ed67463..13739743 100644 --- a/test/server.py +++ b/test/server.py @@ -14,17 +14,18 @@ import threading import socket -import ssl import sys -from hyper.compat import is_py3 +from hyper import HTTP20Connection +from hyper.compat import ssl from hyper.http20.hpack import Encoder from hyper.http20.huffman import HuffmanEncoder from hyper.http20.huffman_constants import ( REQUEST_CODES, REQUEST_CODES_LENGTH ) +from hyper.http20.tls import NPN_PROTOCOL -class _SocketServerThreadBase(threading.Thread): +class SocketServerThread(threading.Thread): """ This method stolen wholesale from shazow/urllib3. @@ -40,11 +41,17 @@ def __init__(self, socket_handler, host='localhost', ready_event=None): self.host = host self.ready_event = ready_event + self.cxt = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + if ssl.HAS_NPN: + self.cxt.set_npn_protocols([NPN_PROTOCOL]) + self.cxt.load_cert_chain(certfile='test/certs/server.crt', + keyfile='test/certs/server.key') + def _start_server(self): sock = socket.socket(socket.AF_INET6) if sys.platform != 'win32': sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock = self._wrap_socket(sock) + sock = self.cxt.wrap_socket(sock, server_side=True) sock.bind((self.host, 0)) self.port = sock.getsockname()[1] @@ -64,31 +71,6 @@ def run(self): self.server = self._start_server() -class _SocketServerThreadPy2(_SocketServerThreadBase): - def _wrap_socket(self, sock): - return ssl.wrap_socket(sock, server_side=True, - certfile='test/certs/server.crt', - keyfile='test/certs/server.key') - - -class _SocketServerThreadPy3(_SocketServerThreadBase): - def __init__(self, socket_handler, host='localhost', ready_event=None): - super().__init__(socket_handler, host, ready_event) - self.cxt = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - self.cxt.set_npn_protocols(['HTTP-draft-09/2.0']) - self.cxt.load_cert_chain(certfile='test/certs/server.crt', - keyfile='test/certs/server.key') - - def _wrap_socket(self, sock): - return self.cxt.wrap_socket(sock, server_side=True) - - -if is_py3: - SocketServerThread = _SocketServerThreadPy3 -else: - SocketServerThread = _SocketServerThreadPy2 - - class SocketLevelTest(object): """ A test-class that defines a few helper methods for running socket-level @@ -106,13 +88,16 @@ def _start_server(self, socket_handler): ready_event = threading.Event() self.server_thread = SocketServerThread( socket_handler=socket_handler, - ready_event=ready_event + ready_event=ready_event, ) self.server_thread.start() ready_event.wait() self.host = self.server_thread.host self.port = self.server_thread.port + def get_connection(self): + return HTTP20Connection(self.host, self.port) + def get_encoder(self): """ Returns a HPACK encoder set up for responses. diff --git a/test/test_hyper.py b/test/test_hyper.py index 7fb23df7..2e866b97 100644 --- a/test/test_hyper.py +++ b/test/test_hyper.py @@ -48,10 +48,14 @@ def test_base_frame_cant_parse_body(self): class TestDataFrame(object): - def test_data_frame_has_only_one_flag(self): + payload = b'\x00\x08\x00\x01\x00\x00\x00\x01testdata' + payload_with_low_padding = b'\x00\x13\x00\x11\x00\x00\x00\x01\x0Atestdata' + b'\0' * 10 + payload_with_high_and_low_padding = b'\x06\x14\x00\x31\x00\x00\x00\x01\x06\x0Atestdata' + b'\0' * (6 * 256 + 10) + + def test_data_frame_has_correct_flags(self): f = DataFrame(1) flags = f.parse_flags(0xFF) - assert flags == set(['END_STREAM']) + assert flags == set(['END_STREAM', 'END_SEGMENT', 'PAD_LOW', 'PAD_HIGH']) def test_data_frame_serializes_properly(self): f = DataFrame(1) @@ -59,15 +63,55 @@ def test_data_frame_serializes_properly(self): f.data = b'testdata' s = f.serialize() - assert s == b'\x00\x08\x00\x01\x00\x00\x00\x01testdata' + assert s == self.payload + + def test_data_frame_with_low_padding_serializes_properly(self): + f = DataFrame(1) + f.flags = set(['END_STREAM', 'PAD_LOW']) + f.data = b'testdata' + f.low_padding = 10 + + s = f.serialize() + assert s == self.payload_with_low_padding + + def test_data_frame_with_high_and_low_padding_serializes_properly(self): + f = DataFrame(1) + f.flags = set(['END_STREAM', 'PAD_LOW', 'PAD_HIGH']) + f.data = b'testdata' + f.high_padding = 6 + f.low_padding = 10 + + s = f.serialize() + assert s == self.payload_with_high_and_low_padding def test_data_frame_parses_properly(self): - s = b'\x00\x08\x00\x01\x00\x00\x00\x01testdata' - f, length = Frame.parse_frame_header(s[:8]) - f.parse_body(s[8:8 + length]) + f, length = Frame.parse_frame_header(self.payload[:8]) + f.parse_body(self.payload[8:8 + length]) assert isinstance(f, DataFrame) assert f.flags == set(['END_STREAM']) + assert f.low_padding == 0 + assert f.high_padding == 0 + assert f.data == b'testdata' + + def test_data_frame_with_low_padding_parses_properly(self): + f, length = Frame.parse_frame_header(self.payload_with_low_padding[:8]) + f.parse_body(self.payload_with_low_padding[8:8 + length]) + + assert isinstance(f, DataFrame) + assert f.flags == set(['END_STREAM', 'PAD_LOW']) + assert f.low_padding == 10 + assert f.high_padding == 0 + assert f.data == b'testdata' + + def test_data_frame_with_high_and_low_padding_parses_properly(self): + f, length = Frame.parse_frame_header(self.payload_with_high_and_low_padding[:8]) + f.parse_body(self.payload_with_high_and_low_padding[8:8 + length]) + + assert isinstance(f, DataFrame) + assert f.flags == set(['END_STREAM', 'PAD_LOW', 'PAD_HIGH']) + assert f.low_padding == 10 + assert f.high_padding == 6 assert f.data == b'testdata' def test_data_frame_comes_on_a_stream(self): @@ -141,6 +185,21 @@ def test_rst_stream_frame_must_have_body_length_four(self): class TestSettingsFrame(object): + serialized = ( + b'\x00\x14\x04\x01\x00\x00\x00\x00' + # Frame header + b'\x01\x00\x00\x10\x00' + # HEADER_TABLE_SIZE + b'\x02\x00\x00\x00\x00' + # ENABLE_PUSH + b'\x03\x00\x00\x00\x64' + # MAX_CONCURRENT_STREAMS + b'\x04\x00\x00\xFF\xFF' # INITIAL_WINDOW_SIZE + ) + + settings = { + SettingsFrame.HEADER_TABLE_SIZE: 4096, + SettingsFrame.ENABLE_PUSH: 0, + SettingsFrame.MAX_CONCURRENT_STREAMS: 100, + SettingsFrame.INITIAL_WINDOW_SIZE: 65535, + } + def test_settings_frame_has_only_one_flag(self): f = SettingsFrame(0) flags = f.parse_flags(0xFF) @@ -149,45 +208,18 @@ def test_settings_frame_has_only_one_flag(self): def test_settings_frame_serializes_properly(self): f = SettingsFrame(0) f.parse_flags(0xFF) - f.settings = { - SettingsFrame.HEADER_TABLE_SIZE: 4096, - SettingsFrame.ENABLE_PUSH: 0, - SettingsFrame.MAX_CONCURRENT_STREAMS: 100, - SettingsFrame.INITIAL_WINDOW_SIZE: 65535, - SettingsFrame.FLOW_CONTROL_OPTIONS: 1, - } + f.settings = self.settings s = f.serialize() - assert s == ( - b'\x00\x28\x04\x01\x00\x00\x00\x00' + # Frame header - b'\x00\x00\x00\x01\x00\x00\x10\x00' + # HEADER_TABLE_SIZE - b'\x00\x00\x00\x02\x00\x00\x00\x00' + # ENABLE_PUSH - b'\x00\x00\x00\x04\x00\x00\x00\x64' + # MAX_CONCURRENT_STREAMS - b'\x00\x00\x00\x0A\x00\x00\x00\x01' + # FLOW_CONTROL_OPTIONS - b'\x00\x00\x00\x07\x00\x00\xFF\xFF' # INITIAL_WINDOW_SIZE - ) + assert s == self.serialized def test_settings_frame_parses_properly(self): - s = ( - b'\x00\x28\x04\x01\x00\x00\x00\x00' + # Frame header - b'\x00\x00\x00\x01\x00\x00\x10\x00' + # HEADER_TABLE_SIZE - b'\x00\x00\x00\x02\x00\x00\x00\x00' + # ENABLE_PUSH - b'\x00\x00\x00\x04\x00\x00\x00\x64' + # MAX_CONCURRENT_STREAMS - b'\x00\x00\x00\x0A\x00\x00\x00\x01' + # FLOW_CONTROL_OPTIONS - b'\x00\x00\x00\x07\x00\x00\xFF\xFF' # INITIAL_WINDOW_SIZE - ) - f, length = Frame.parse_frame_header(s[:8]) - f.parse_body(s[8:8 + length]) + f, length = Frame.parse_frame_header(self.serialized[:8]) + f.parse_body(self.serialized[8:8 + length]) assert isinstance(f, SettingsFrame) assert f.flags == set(['ACK']) - assert f.settings == { - SettingsFrame.HEADER_TABLE_SIZE: 4096, - SettingsFrame.ENABLE_PUSH: 0, - SettingsFrame.MAX_CONCURRENT_STREAMS: 100, - SettingsFrame.INITIAL_WINDOW_SIZE: 65535, - SettingsFrame.FLOW_CONTROL_OPTIONS: 1, - } + assert f.settings == self.settings def test_settings_frames_never_have_streams(self): with pytest.raises(ValueError): @@ -297,10 +329,10 @@ def test_window_update_serializes_properly(self): f.window_increment = 512 s = f.serialize() - assert s == b'\x00\x04\x09\x00\x00\x00\x00\x00\x00\x00\x02\x00' + assert s == b'\x00\x04\x08\x00\x00\x00\x00\x00\x00\x00\x02\x00' def test_windowupdate_frame_parses_properly(self): - s = b'\x00\x04\x09\x00\x00\x00\x00\x00\x00\x00\x02\x00' + s = b'\x00\x04\x08\x00\x00\x00\x00\x00\x00\x00\x02\x00' f, length = Frame.parse_frame_header(s[:8]) f.parse_body(s[8:8 + length]) @@ -314,11 +346,12 @@ def test_headers_frame_flags(self): f = HeadersFrame(1) flags = f.parse_flags(0xFF) - assert flags == set(['END_STREAM', 'END_HEADERS', 'PRIORITY']) + assert flags == set(['END_STREAM', 'END_SEGMENT', 'END_HEADERS', + 'PRIORITY', 'PAD_LOW', 'PAD_HIGH']) def test_headers_frame_serialize_with_priority_properly(self): f = HeadersFrame(1) - f.parse_flags(0xFF) + f.parse_flags(0x0D) f.priority = (2 ** 30) + 1 f.data = b'hello world' @@ -331,7 +364,7 @@ def test_headers_frame_serialize_with_priority_properly(self): def test_headers_frame_serialize_without_priority_properly(self): f = HeadersFrame(1) - f.parse_flags(0xFF) + f.parse_flags(0x0D) f.data = b'hello world' s = f.serialize() @@ -360,21 +393,21 @@ def test_continuation_frame_flags(self): f = ContinuationFrame(1) flags = f.parse_flags(0xFF) - assert flags == set(['END_HEADERS']) + assert flags == set(['END_HEADERS', 'PAD_LOW', 'PAD_HIGH']) def test_continuation_frame_serializes(self): f = ContinuationFrame(1) - f.parse_flags(0xFF) + f.parse_flags(0x04) f.data = b'hello world' s = f.serialize() assert s == ( - b'\x00\x0B\x0A\x04\x00\x00\x00\x01' + + b'\x00\x0B\x09\x04\x00\x00\x00\x01' + b'hello world' ) def test_continuation_frame_parses_properly(self): - s = b'\x00\x0B\x0A\x04\x00\x00\x00\x01hello world' + s = b'\x00\x0B\x09\x04\x00\x00\x00\x01hello world' f, length = Frame.parse_frame_header(s[:8]) f.parse_body(s[8:8 + length]) @@ -1146,6 +1179,34 @@ def test_windowupdate_frames_update_windows(self): assert s._out_flow_control_window == 65535 + 1000 + def test_flow_control_manager_update_includes_padding(self): + out_frames = [] + in_frames = [] + + def send_cb(frame): + out_frames.append(frame) + + def recv_cb(s): + def inner(): + s.receive_frame(in_frames.pop(0)) + return inner + + start_window = 65535 + s = Stream(1, send_cb, None, None, None, None, FlowControlManager(start_window)) + s._recv_cb = recv_cb(s) + s.state = STATE_HALF_CLOSED_LOCAL + + # Provide two data frames to read. + f = DataFrame(1) + f.data = b'hi there!' + f.low_padding = 10 + f.flags.add('END_STREAM') + in_frames.append(f) + + data = s._read() + assert data == b'hi there!' + assert s._in_window_manager.window_size == start_window - f.low_padding - len(data) + def test_stream_reading_works(self): out_frames = [] in_frames = [] @@ -1158,7 +1219,7 @@ def inner(): s.receive_frame(in_frames.pop(0)) return inner - s = Stream(1, send_cb, None, None, None, None, None) + s = Stream(1, send_cb, None, None, None, None, FlowControlManager(65535)) s._recv_cb = recv_cb(s) s.state = STATE_HALF_CLOSED_LOCAL diff --git a/test/test_integration.py b/test/test_integration.py index 2dc0b1ae..45fd249b 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -7,12 +7,10 @@ hitting the network, so that's alright. """ import requests -import ssl import threading import hyper import pytest -from hyper import HTTP20Connection -from hyper.compat import is_py3 +from hyper.compat import ssl from hyper.contrib import HTTP20Adapter from hyper.http20.frame import ( Frame, SettingsFrame, WindowUpdateFrame, DataFrame, HeadersFrame, @@ -27,9 +25,9 @@ from server import SocketLevelTest # Turn off certificate verification for the tests. -hyper.http20.tls._verify_mode = ssl.CERT_NONE -if is_py3: +if ssl is not None: hyper.http20.tls._context = hyper.http20.tls._init_context() + hyper.http20.tls._context.verify_mode = ssl.CERT_NONE def decode_frame(frame_data): f, length = Frame.parse_frame_header(frame_data[:8]) @@ -83,7 +81,7 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTP20Connection(self.host, self.port) + conn = self.get_connection() conn.connect() send_event.wait() @@ -117,7 +115,7 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTP20Connection(self.host, self.port) + conn = self.get_connection() conn.connect() send_event.wait() @@ -172,7 +170,7 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTP20Connection(self.host, self.port) + conn = self.get_connection() conn.putrequest('GET', '/') conn.endheaders() @@ -226,7 +224,7 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - with HTTP20Connection(self.host, self.port) as conn: + with self.get_connection() as conn: conn.connect() send_event.wait() @@ -257,7 +255,7 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTP20Connection(self.host, self.port) + conn = self.get_connection() conn.request('GET', '/') resp = conn.getresponse() @@ -293,7 +291,7 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTP20Connection(self.host, self.port) + conn = self.get_connection() conn.request('GET', '/') resp = conn.getresponse() @@ -333,7 +331,7 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTP20Connection(self.host, self.port) + conn = self.get_connection() conn.connect() # Confirm the connection is closed. @@ -368,7 +366,7 @@ def socket_handler(listener): recv_event.wait() self._start_server(socket_handler) - conn = HTTP20Connection(self.host, self.port) + conn = self.get_connection() with pytest.raises(ConnectionError): conn.connect() @@ -456,9 +454,9 @@ def socket_handler(listener): self._start_server(socket_handler) s = requests.Session() - s.mount('https://%s' % self.host, HTTP20Adapter()) + s.mount('http://%s' % self.host, HTTP20Adapter()) r = s.post( - 'https://%s:%s/some/path' % (self.host, self.port), + 'http://%s:%s/some/path' % (self.host, self.port), data='hi there', ) diff --git a/test_requirements.txt b/test_requirements.txt index df88a33a..9549cc44 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,5 +1,5 @@ py==1.4.19 pytest==2.5.1 -requests pytest-xdist pytest-cov +requests