diff --git a/.coveragerc b/.coveragerc index 847e7525..f20b132b 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,2 +1,4 @@ [run] -omit=hyper/httplib_compat.py +omit = + hyper/compat.py + hyper/httplib_compat.py diff --git a/.travis.yml b/.travis.yml index a82905d9..7779057b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,6 @@ language: python python: + - "2.7" - "3.3" install: diff --git a/CONTRIBUTORS.rst b/CONTRIBUTORS.rst index 69e271e6..29cd4797 100644 --- a/CONTRIBUTORS.rst +++ b/CONTRIBUTORS.rst @@ -13,3 +13,7 @@ In chronological order: - Sriram Ganesan (@elricL) - Implemented the Huffman encoding/decoding logic. + +- Alek Storm (@alekstorm) + + - Implemented Python 2.7 support. diff --git a/conftest.py b/conftest.py index 95a52255..8546d6e2 100644 --- a/conftest.py +++ b/conftest.py @@ -2,6 +2,10 @@ import pytest import os import json +import sys + +if sys.version_info[0] == 2: + from codecs import open # This pair of generator expressions are pretty lame, but building lists is a # bad idea as I plan to have a substantial number of tests here. diff --git a/hyper/__init__.py b/hyper/__init__.py index 7827d9d4..a726be7b 100644 --- a/hyper/__init__.py +++ b/hyper/__init__.py @@ -11,10 +11,10 @@ from .http20.connection import HTTP20Connection from .http20.response import HTTP20Response -# Throw import errors on Python 2. +# Throw import errors on Python <2.7 and 3.0-3.2. import sys as _sys -if _sys.version_info[0] < 3 or _sys.version_info[1] < 3: - raise ImportError("hyper only supports Python 3.3 or higher.") +if _sys.version_info < (2,7) or (3,0) <= _sys.version_info < (3,3): + raise ImportError("hyper only supports Python 2.7 and Python 3.3 or higher.") __all__ = [HTTP20Response, HTTP20Connection] diff --git a/hyper/compat.py b/hyper/compat.py new file mode 100644 index 00000000..ed48a643 --- /dev/null +++ b/hyper/compat.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +""" +hyper/compat +~~~~~~~~~ + +Normalizes the Python 2/3 API for internal use. +""" +import sys +import zlib + +# Syntax sugar. +_ver = sys.version_info + +#: Python 2.x? +is_py2 = (_ver[0] == 2) + +#: Python 3.x? +is_py3 = (_ver[0] == 3) + +if is_py2: + from urlparse import urlparse + + def to_byte(char): + return ord(char) + + def decode_hex(b): + return b.decode('hex') + + # The standard zlib.compressobj() accepts only positional arguments. + def zlib_compressobj(level=6, method=zlib.DEFLATED, wbits=15, memlevel=8, + strategy=zlib.Z_DEFAULT_STRATEGY): + return zlib.compressobj(level, method, wbits, memlevel, strategy) + +elif is_py3: + from urllib.parse import urlparse + + def to_byte(char): + return char + + def decode_hex(b): + return bytes.fromhex(b) + + zlib_compressobj = zlib.compressobj diff --git a/hyper/contrib.py b/hyper/contrib.py index 0633f5f6..a684eac2 100644 --- a/hyper/contrib.py +++ b/hyper/contrib.py @@ -15,7 +15,7 @@ HTTPAdapter = object from hyper import HTTP20Connection -from urllib.parse import urlparse +from hyper.compat import urlparse class HTTP20Adapter(HTTPAdapter): """ diff --git a/hyper/http20/connection.py b/hyper/http20/connection.py index 86e9fdd4..adbb57b4 100644 --- a/hyper/http20/connection.py +++ b/hyper/http20/connection.py @@ -43,7 +43,7 @@ class HTTP20Connection(object): :class:`FlowControlManager ` will be used. """ - def __init__(self, host, port=None, *, window_manager=None, **kwargs): + def __init__(self, host, port=None, window_manager=None, **kwargs): """ Creates an HTTP/2.0 connection to a specific server. """ @@ -160,7 +160,6 @@ def connect(self): if self._sock is None: sock = socket.create_connection((self.host, self.port), 5) sock = wrap_socket(sock, self.host) - assert sock.selected_npn_protocol() == 'HTTP-draft-09/2.0' self._sock = sock # We need to send the connection header immediately on this diff --git a/hyper/http20/hpack.py b/hyper/http20/hpack.py index 3aefbcf5..d7efffa3 100644 --- a/hyper/http20/hpack.py +++ b/hyper/http20/hpack.py @@ -10,6 +10,7 @@ import collections import logging +from ..compat import to_byte from .huffman import HuffmanDecoder, HuffmanEncoder from hyper.http20.huffman_constants import ( REQUEST_CODES, REQUEST_CODES_LENGTH, RESPONSE_CODES, RESPONSE_CODES_LENGTH @@ -55,13 +56,13 @@ def decode_integer(data, prefix_bits): mask = 0xFF >> (8 - prefix_bits) index = 0 - number = data[index] & mask + number = to_byte(data[index]) & mask if (number == max_number): while True: index += 1 - next_byte = data[index] + next_byte = to_byte(data[index]) if next_byte >= 128: number += (next_byte - 128) * multiple(index) @@ -407,7 +408,7 @@ def _encode_indexed(self, index): """ field = encode_integer(index, 7) field[0] = field[0] | 0x80 # we set the top bit - return field + return bytes(field) def _encode_literal(self, name, value, indexing, huffman=False): """ @@ -415,7 +416,7 @@ def _encode_literal(self, name, value, indexing, huffman=False): is True, the header will be added to the header table: otherwise it will not. """ - prefix = bytes([0x00 if indexing else 0x40]) + prefix = b'\x00' if indexing else b'\x40' if huffman: name = self.huffman_coder.encode(name) @@ -428,7 +429,7 @@ def _encode_literal(self, name, value, indexing, huffman=False): name_len[0] |= 0x80 value_len[0] |= 0x80 - return b''.join([prefix, name_len, name, value_len, value]) + return b''.join([prefix, bytes(name_len), name, bytes(value_len), value]) def _encode_indexed_literal(self, index, value, indexing, huffman=False): """ @@ -449,7 +450,7 @@ def _encode_indexed_literal(self, index, value, indexing, huffman=False): if huffman: value_len[0] |= 0x80 - return b''.join([name, value_len, value]) + return b''.join([bytes(name), bytes(value_len), value]) class Decoder(object): @@ -572,11 +573,12 @@ def decode(self, data): while current_index < data_len: # Work out what kind of header we're decoding. # If the high bit is 1, it's an indexed field. - indexed = bool(data[current_index] & 0x80) + current = to_byte(data[current_index]) + indexed = bool(current & 0x80) # Otherwise, if the second-highest bit is 1 it's a field that # doesn't alter the header table. - literal_no_index = bool(data[current_index] & 0x40) + literal_no_index = bool(current & 0x40) if indexed: header, consumed = self._decode_indexed(data[current_index:]) @@ -678,7 +680,7 @@ def _decode_literal(self, data, should_index): # If the low six bits of the first byte are nonzero, the header # name is indexed. - first_byte = data[0] + first_byte = to_byte(data[0]) if first_byte & 0x3F: # Indexed header name. @@ -701,7 +703,7 @@ def _decode_literal(self, data, should_index): length, consumed = decode_integer(data, 7) name = data[consumed:consumed + length] - if data[0] & 0x80: + if to_byte(data[0]) & 0x80: name = self.huffman_coder.decode(name) total_consumed = consumed + length + 1 # Since we moved forward 1. @@ -711,7 +713,7 @@ def _decode_literal(self, data, should_index): length, consumed = decode_integer(data, 7) value = data[consumed:consumed + length] - if data[0] & 0x80: + if to_byte(data[0]) & 0x80: value = self.huffman_coder.decode(value) # Updated the total consumed length. diff --git a/hyper/http20/huffman.py b/hyper/http20/huffman.py index a911f08f..1909e900 100644 --- a/hyper/http20/huffman.py +++ b/hyper/http20/huffman.py @@ -6,9 +6,9 @@ An implementation of a bitwise prefix tree specially built for decoding Huffman-coded content where we already know the Huffman table. """ +from ..compat import to_byte, decode_hex from .exceptions import HPACKDecodingError - def _pad_binary(bin_str, req_len=8): """ Given a binary string (returned by bin()), pad it to a full byte length. @@ -21,7 +21,7 @@ def _hex_to_bin_str(hex_string): Given a Python bytestring, returns a string representing those bytes in unicode form. """ - unpadded_bin_string_list = map(bin, hex_string) + unpadded_bin_string_list = (bin(to_byte(c)) for c in hex_string) padded_bin_string_list = map(_pad_binary, unpadded_bin_string_list) bitwise_message = "".join(padded_bin_string_list) return bitwise_message @@ -60,7 +60,7 @@ def decode(self, encoded_string): """ number = _hex_to_bin_str(encoded_string) cur_node = self.root - decoded_message = [] + decoded_message = bytearray() try: for digit in number: @@ -103,9 +103,10 @@ def encode(self, bytes_to_encode): # Turn each byte into its huffman code. These codes aren't necessarily # octet aligned, so keep track of how far through an octet we are. To # handle this cleanly, just use a single giant integer. - for letter in bytes_to_encode: - bin_int_len = self.huffman_code_list_lengths[letter] - bin_int = self.huffman_code_list[letter] & (2 ** (bin_int_len + 1) - 1) + for char in bytes_to_encode: + byte = to_byte(char) + bin_int_len = self.huffman_code_list_lengths[byte] + bin_int = self.huffman_code_list[byte] & (2 ** (bin_int_len + 1) - 1) final_num <<= bin_int_len final_num |= bin_int final_int_len += bin_int_len @@ -115,10 +116,11 @@ def encode(self, bytes_to_encode): final_num <<= bits_to_be_padded final_num |= (1 << (bits_to_be_padded)) - 1 - # Convert the number to hex and strip off the leading '0x' - final_num = hex(final_num)[2:] + # Convert the number to hex and strip off the leading '0x' and the + # trailing 'L', if present. + final_num = hex(final_num)[2:].rstrip('L') # If this is odd, prepend a zero. final_num = '0' + final_num if len(final_num) % 2 != 0 else final_num - return bytes.fromhex(final_num) + return decode_hex(final_num) diff --git a/hyper/http20/tls.py b/hyper/http20/tls.py index fa031a14..c84465a2 100644 --- a/hyper/http20/tls.py +++ b/hyper/http20/tls.py @@ -8,6 +8,8 @@ import ssl import os.path as path +from ..compat import is_py3 + # Right now we support draft 9. SUPPORTED_PROTOCOLS = ['http/1.1', 'HTTP-draft-09/2.0'] @@ -18,35 +20,45 @@ # to. _context = None +# Exposed here so it can be monkey-patched in integration tests. +_verify_mode = ssl.CERT_REQUIRED + # Work out where our certificates are. cert_loc = path.join(path.dirname(__file__), '..', 'certs.pem') -def wrap_socket(socket, server_hostname): - """ - A vastly simplified SSL wrapping function. We'll probably extend this to - do more things later. - """ - global _context +if is_py3: # pragma: no cover + def wrap_socket(socket, server_hostname): + """ + A vastly simplified SSL wrapping function. We'll probably extend this to + do more things later. + """ + global _context - if _context is None: # pragma: no cover - _context = _init_context() + if _context is None: # pragma: no cover + _context = _init_context() - if ssl.HAS_SNI: - return _context.wrap_socket(socket, server_hostname=server_hostname) + if ssl.HAS_SNI: + return _context.wrap_socket(socket, server_hostname=server_hostname) - return _context.wrap_socket(socket) # pragma: no cover + wrapped = _context.wrap_socket(socket) # pragma: no cover + assert wrapped.selected_npn_protocol() == 'HTTP-draft-09/2.0' + return wrapped +else: # pragma: no cover + def wrap_socket(socket, server_hostname): + return ssl.wrap_socket(socket, ssl_version=ssl.PROTOCOL_SSLv23, + ca_certs=cert_loc, cert_reqs=_verify_mode) -def _init_context(): +def _init_context(): # pragma: no cover """ Creates the singleton SSLContext we use. """ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) context.set_default_verify_paths() context.load_verify_locations(cafile=cert_loc) - context.verify_mode = ssl.CERT_REQUIRED + context.verify_mode = _verify_mode try: context.set_npn_protocols(SUPPORTED_PROTOCOLS) diff --git a/test/server.py b/test/server.py index 61a813be..0267f73d 100644 --- a/test/server.py +++ b/test/server.py @@ -17,13 +17,14 @@ import ssl import sys +from hyper.compat import is_py3 from hyper.http20.hpack import Encoder from hyper.http20.huffman import HuffmanEncoder from hyper.http20.huffman_constants import ( RESPONSE_CODES, RESPONSE_CODES_LENGTH ) -class SocketServerThread(threading.Thread): +class _SocketServerThreadBase(threading.Thread): """ This method stolen wholesale from shazow/urllib3. @@ -32,22 +33,18 @@ class SocketServerThread(threading.Thread): :param ready_event: Event which gets set when the socket handler is ready to receive requests. """ - def __init__(self, socket_handler, host='localhost', port=8081, - ready_event=None): + def __init__(self, socket_handler, host='localhost', ready_event=None): threading.Thread.__init__(self) self.socket_handler = socket_handler self.host = host self.ready_event = ready_event - self.cxt = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - self.cxt.set_npn_protocols(['HTTP-draft-09/2.0']) - self.cxt.load_cert_chain(certfile='test/certs/server.crt', keyfile='test/certs/server.key') def _start_server(self): sock = socket.socket(socket.AF_INET6) if sys.platform != 'win32': sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock = self.cxt.wrap_socket(sock, server_side=True) + sock = self._wrap_socket(sock) sock.bind((self.host, 0)) self.port = sock.getsockname()[1] @@ -60,10 +57,38 @@ def _start_server(self): self.socket_handler(sock) sock.close() + def _wrap_socket(self, sock): + raise NotImplementedError() + def run(self): self.server = self._start_server() +class _SocketServerThreadPy2(_SocketServerThreadBase): + def _wrap_socket(self, sock): + return ssl.wrap_socket(sock, server_side=True, + certfile='test/certs/server.crt', + keyfile='test/certs/server.key') + + +class _SocketServerThreadPy3(_SocketServerThreadBase): + def __init__(self, socket_handler, host='localhost', ready_event=None): + super().__init__(socket_handler, host, ready_event) + self.cxt = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.cxt.set_npn_protocols(['HTTP-draft-09/2.0']) + self.cxt.load_cert_chain(certfile='test/certs/server.crt', + keyfile='test/certs/server.key') + + def _wrap_socket(self, sock): + return self.cxt.wrap_socket(sock, server_side=True) + + +if is_py3: + SocketServerThread = _SocketServerThreadPy3 +else: + SocketServerThread = _SocketServerThreadPy2 + + class SocketLevelTest(object): """ A test-class that defines a few helper methods for running socket-level @@ -101,4 +126,3 @@ def tear_down(self): Tears down the testing thread. """ self.server_thread.join(0.1) - diff --git a/test/test_hyper.py b/test/test_hyper.py index 7be6432a..7f3f3f39 100644 --- a/test/test_hyper.py +++ b/test/test_hyper.py @@ -14,6 +14,7 @@ from hyper.http20.response import HTTP20Response from hyper.http20.exceptions import HPACKDecodingError, HPACKEncodingError from hyper.http20.window import FlowControlManager +from hyper.compat import zlib_compressobj from hyper.contrib import HTTP20Adapter import pytest import zlib @@ -1251,7 +1252,7 @@ def test_status_is_stripped_from_headers(self): def test_response_transparently_decrypts_gzip(self): headers = {':status': '200', 'content-encoding': 'gzip'} - c = zlib.compressobj(wbits=24) + c = zlib_compressobj(wbits=24) body = c.compress(b'this is test data') body += c.flush() resp = HTTP20Response(headers, DummyStream(body)) @@ -1260,7 +1261,7 @@ def test_response_transparently_decrypts_gzip(self): def test_response_transparently_decrypts_real_deflate(self): headers = {':status': '200', 'content-encoding': 'deflate'} - c = zlib.compressobj(wbits=zlib.MAX_WBITS) + c = zlib_compressobj(wbits=zlib.MAX_WBITS) body = c.compress(b'this is test data') body += c.flush() resp = HTTP20Response(headers, DummyStream(body)) @@ -1269,7 +1270,7 @@ def test_response_transparently_decrypts_real_deflate(self): def test_response_transparently_decrypts_wrong_deflate(self): headers = {':status': '200', 'content-encoding': 'deflate'} - c = zlib.compressobj(wbits=-zlib.MAX_WBITS) + c = zlib_compressobj(wbits=-zlib.MAX_WBITS) body = c.compress(b'this is test data') body += c.flush() resp = HTTP20Response(headers, DummyStream(body)) @@ -1346,6 +1347,7 @@ def test_adapter_reuses_connections(self): # Some utility classes for the tests. class NullEncoder(object): + @staticmethod def encode(headers): return '\n'.join("%s%s" % (name, val) for name, val in headers) diff --git a/test/test_import.py b/test/test_import.py index 7cf6c03d..b8e3a1f2 100644 --- a/test/test_import.py +++ b/test/test_import.py @@ -6,7 +6,7 @@ class TestImportPython2(object): def test_cannot_import_python_2(self, monkeypatch): - monkeypatch.setattr(sys, 'version_info', (2, 7, 7, 'final', 0)) + monkeypatch.setattr(sys, 'version_info', (2, 6, 5, 'final', 0)) with pytest.raises(ImportError): imp.reload(hyper) diff --git a/test/test_integration.py b/test/test_integration.py index 685486f8..7b77da42 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -12,6 +12,7 @@ import hyper import pytest from hyper import HTTP20Connection +from hyper.compat import is_py3 from hyper.contrib import HTTP20Adapter from hyper.http20.frame import ( Frame, SettingsFrame, WindowUpdateFrame, DataFrame, HeadersFrame, @@ -26,8 +27,9 @@ from server import SocketLevelTest # Turn off certificate verification for the tests. -hyper.http20.tls._context = hyper.http20.tls._init_context() -hyper.http20.tls._context.verify_mode = ssl.CERT_NONE +hyper.http20.tls._verify_mode = ssl.CERT_NONE +if is_py3: + hyper.http20.tls._context = hyper.http20.tls._init_context() def decode_frame(frame_data): f, length = Frame.parse_frame_header(frame_data[:8])